diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index b59f300a..dadda83a 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -6,7 +6,7 @@ import os import threading import uuid from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import requests from dotenv import load_dotenv @@ -200,7 +200,7 @@ class EmbedChain(JSONSerializable): data_formatter = DataFormatter(data_type, config) self.user_asks.append([source, data_type.value, metadata]) - documents, metadatas, _ids, new_chunks = self.load_and_embed_v2( + documents, metadatas, _ids, new_chunks = self.load_and_embed( data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run ) if data_type in {DataType.DOCS_SITE}: @@ -255,92 +255,6 @@ class EmbedChain(JSONSerializable): ) return self.add(source=source, data_type=data_type, metadata=metadata, config=config) - def load_and_embed( - self, - loader: BaseLoader, - chunker: BaseChunker, - src: Any, - metadata: Optional[Dict[str, Any]] = None, - source_id: Optional[str] = None, - dry_run=False, - ) -> Tuple[List[str], Dict[str, Any], List[str], int]: - """The loader to use to load the data. - - :param loader: The loader to use to load the data. - :type loader: BaseLoader - :param chunker: The chunker to use to chunk the data. - :type chunker: BaseChunker - :param src: The data to be handled by the loader. - Can be a URL for remote sources or local content for local loaders. - :type src: Any - :param metadata: Metadata associated with the data source., defaults to None - :type metadata: Dict[str, Any], optional - :param source_id: Hexadecimal hash of the source., defaults to None - :type source_id: str, optional - :param dry_run: Optional. A dry run returns chunks and doesn't update DB. - :type dry_run: bool, defaults to False - :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks - :rtype: Tuple[List[str], Dict[str, Any], List[str], int] - """ - embeddings_data = chunker.create_chunks(loader, src) - - # spread chunking results - documents = embeddings_data["documents"] - metadatas = embeddings_data["metadatas"] - ids = embeddings_data["ids"] - - # get existing ids, and discard doc if any common id exist. - where = {"app_id": self.config.id} if self.config.id is not None else {} - # where={"url": src} - db_result = self.db.get( - ids=ids, - where=where, # optional filter - ) - existing_ids = set(db_result["ids"]) - - if len(existing_ids): - data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)} - data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids} - - if not data_dict: - src_copy = src - if len(src_copy) > 50: - src_copy = src[:50] + "..." - print(f"All data from {src_copy} already exists in the database.") - # Make sure to return a matching return type - return [], [], [], 0 - - ids = list(data_dict.keys()) - documents, metadatas = zip(*data_dict.values()) - - # Loop though all metadatas and add extras. - new_metadatas = [] - for m in metadatas: - # Add app id in metadatas so that they can be queried on later - if self.config.id: - m["app_id"] = self.config.id - - # Add hashed source - m["hash"] = source_id - - # Note: Metadata is the function argument - if metadata: - # Spread whatever is in metadata into the new object. - m.update(metadata) - - new_metadatas.append(m) - metadatas = new_metadatas - - if dry_run: - return list(documents), metadatas, ids, 0 - - # Count before, to calculate a delta in the end. - chunks_before_addition = self.db.count() - - self.db.add(documents=documents, metadatas=metadatas, ids=ids) - count_new_chunks = self.db.count() - chunks_before_addition - print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}")) - return list(documents), metadatas, ids, count_new_chunks def _get_existing_doc_id(self, chunker: BaseChunker, src: Any): """ @@ -392,7 +306,7 @@ class EmbedChain(JSONSerializable): "When it should be DirectDataType, IndirectDataType or SpecialDataType." ) - def load_and_embed_v2( + def load_and_embed( self, loader: BaseLoader, chunker: BaseChunker,