diff --git a/embedchain/chunkers/base_chunker.py b/embedchain/chunkers/base_chunker.py index 7130d30b..76b38f38 100644 --- a/embedchain/chunkers/base_chunker.py +++ b/embedchain/chunkers/base_chunker.py @@ -41,7 +41,6 @@ class BaseChunker(JSONSerializable): url = meta_data["url"] chunks = self.get_chunks(content) - for chunk in chunks: chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest() chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py index fc985efd..1bff1e62 100644 --- a/embedchain/data_formatter/data_formatter.py +++ b/embedchain/data_formatter/data_formatter.py @@ -1,5 +1,5 @@ from importlib import import_module -from typing import Any, Dict +from typing import Optional from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config import AddConfig @@ -16,7 +16,13 @@ class DataFormatter(JSONSerializable): .add or .add_local method call """ - def __init__(self, data_type: DataType, config: AddConfig, kwargs: Dict[str, Any]): + def __init__( + self, + data_type: DataType, + config: AddConfig, + loader: Optional[BaseLoader] = None, + chunker: Optional[BaseChunker] = None, + ): """ Initialize a dataformatter, set data type and chunker based on datatype. @@ -25,15 +31,15 @@ class DataFormatter(JSONSerializable): :param config: AddConfig instance with nested loader and chunker config attributes. :type config: AddConfig """ - self.loader = self._get_loader(data_type=data_type, config=config.loader, kwargs=kwargs) - self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, kwargs=kwargs) + self.loader = self._get_loader(data_type=data_type, config=config.loader, loader=loader) + self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, chunker=chunker) def _lazy_load(self, module_path: str): module_path, class_name = module_path.rsplit(".", 1) module = import_module(module_path) return getattr(module, class_name) - def _get_loader(self, data_type: DataType, config: LoaderConfig, kwargs: Dict[str, Any]) -> BaseLoader: + def _get_loader(self, data_type: DataType, config: LoaderConfig, loader: Optional[BaseLoader]) -> BaseLoader: """ Returns the appropriate data loader for the given data type. @@ -68,8 +74,8 @@ class DataFormatter(JSONSerializable): DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader", } - if data_type == DataType.CUSTOM or ("loader" in kwargs): - loader_class: type = kwargs.get("loader", None) + if data_type == DataType.CUSTOM or loader is not None: + loader_class: type = loader if loader_class: return loader_class elif data_type in loaders: @@ -82,7 +88,7 @@ class DataFormatter(JSONSerializable): check `https://docs.embedchain.ai/data-sources/overview`." ) - def _get_chunker(self, data_type: DataType, config: ChunkerConfig, kwargs: Dict[str, Any]) -> BaseChunker: + def _get_chunker(self, data_type: DataType, config: ChunkerConfig, chunker: Optional[BaseChunker]) -> BaseChunker: """Returns the appropriate chunker for the given data type (updated for lazy loading).""" chunker_classes = { DataType.YOUTUBE_VIDEO: "embedchain.chunkers.youtube_video.YoutubeVideoChunker", @@ -108,12 +114,8 @@ class DataFormatter(JSONSerializable): DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker", } - if "chunker" in kwargs: - chunker_class = kwargs.get("chunker", None) - if chunker_class: - chunker = chunker_class(config) - chunker.set_data_type(data_type) - return chunker + if chunker is not None: + return chunker elif data_type in chunker_classes: chunker_class = self._lazy_load(chunker_classes[data_type]) chunker = chunker_class(config) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 76cdea79..813bcf5c 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -133,7 +133,9 @@ class EmbedChain(JSONSerializable): metadata: Optional[Dict[str, Any]] = None, config: Optional[AddConfig] = None, dry_run=False, - **kwargs: Dict[str, Any], + loader: Optional[BaseLoader] = None, + chunker: Optional[BaseChunker] = None, + **kwargs: Optional[Dict[str, Any]], ): """ Adds the data from the given URL to the vector db. @@ -192,9 +194,9 @@ class EmbedChain(JSONSerializable): self.user_asks.append([source, data_type.value, metadata]) - data_formatter = DataFormatter(data_type, config, kwargs) + data_formatter = DataFormatter(data_type, config, loader, chunker) documents, metadatas, _ids, new_chunks = self._load_and_embed( - data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run + data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run, **kwargs ) if data_type in {DataType.DOCS_SITE}: self.is_docs_site_instance = True @@ -238,7 +240,7 @@ class EmbedChain(JSONSerializable): data_type: Optional[DataType] = None, metadata: Optional[Dict[str, Any]] = None, config: Optional[AddConfig] = None, - **kwargs: Dict[str, Any], + **kwargs: Optional[Dict[str, Any]], ): """ Adds the data from the given URL to the vector db. @@ -269,7 +271,7 @@ class EmbedChain(JSONSerializable): data_type=data_type, metadata=metadata, config=config, - kwargs=kwargs, + **kwargs, ) def _get_existing_doc_id(self, chunker: BaseChunker, src: Any): @@ -338,6 +340,7 @@ class EmbedChain(JSONSerializable): metadata: Optional[Dict[str, Any]] = None, source_hash: Optional[str] = None, dry_run=False, + **kwargs: Optional[Dict[str, Any]], ): """ Loads the data from the given URL, chunks it, and adds it to database. @@ -431,6 +434,7 @@ class EmbedChain(JSONSerializable): metadatas=metadatas, ids=ids, skip_embedding=(chunker.data_type == DataType.IMAGES), + **kwargs, ) count_new_chunks = self.db.count() - chunks_before_addition @@ -448,7 +452,12 @@ class EmbedChain(JSONSerializable): ] def _retrieve_from_database( - self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False + self, + input_query: str, + config: Optional[BaseLlmConfig] = None, + where=None, + citations: bool = False, + **kwargs: Optional[Dict[str, Any]], ) -> Union[List[Tuple[str, str, str]], List[str]]: """ Queries the vector database based on the given input query. @@ -492,6 +501,7 @@ class EmbedChain(JSONSerializable): where=where, skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"), citations=citations, + **kwargs, ) return contexts @@ -526,9 +536,13 @@ class EmbedChain(JSONSerializable): or the dry run result :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]] """ - citations = kwargs.get("citations", False) + if "citations" in kwargs: + citations = kwargs.pop("citations") + else: + citations = False + contexts = self._retrieve_from_database( - input_query=input_query, config=config, where=where, citations=citations + input_query=input_query, config=config, where=where, citations=citations, **kwargs ) if citations and len(contexts) > 0 and isinstance(contexts[0], tuple): contexts_data_for_llm_query = list(map(lambda x: x[0], contexts)) @@ -579,9 +593,13 @@ class EmbedChain(JSONSerializable): or the dry run result :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]] """ - citations = kwargs.get("citations", False) + if "citations" in kwargs: + citations = kwargs.pop("citations") + else: + citations = False + contexts = self._retrieve_from_database( - input_query=input_query, config=config, where=where, citations=citations + input_query=input_query, config=config, where=where, citations=citations, **kwargs ) if citations and len(contexts) > 0 and isinstance(contexts[0], tuple): contexts_data_for_llm_query = list(map(lambda x: x[0], contexts)) diff --git a/embedchain/loaders/github.py b/embedchain/loaders/github.py index e19ab2cc..fcb6af87 100644 --- a/embedchain/loaders/github.py +++ b/embedchain/loaders/github.py @@ -196,7 +196,6 @@ class GithubLoader(BaseLoader): logging.info(f"Total repos found: {repos_results.totalCount}") for repo_result in tqdm(repos_results, total=repos_results.totalCount, desc="Loading discussions from github"): teams = repo_result.get_teams() - # import pdb; pdb.set_trace() for team in teams: team_discussions = team.get_discussions() for discussion in team_discussions: diff --git a/embedchain/utils.py b/embedchain/utils.py index 1f5309df..ba6efd21 100644 --- a/embedchain/utils.py +++ b/embedchain/utils.py @@ -1,3 +1,4 @@ +import itertools import json import logging import os @@ -6,6 +7,7 @@ import string from typing import Any from schema import Optional, Or, Schema +from tqdm import tqdm from embedchain.models.data_type import DataType @@ -422,3 +424,16 @@ def validate_config(config_data): ) return schema.validate(config_data) + + +def chunks(iterable, batch_size=100, desc="Processing chunks"): + """A helper function to break an iterable into chunks of size batch_size.""" + it = iter(iterable) + total_size = len(iterable) + + with tqdm(total=total_size, desc=desc, unit="batch") as pbar: + chunk = tuple(itertools.islice(it, batch_size)) + while chunk: + yield chunk + pbar.update(len(chunk)) + chunk = tuple(itertools.islice(it, batch_size)) diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py index f32cf525..86c5ff38 100644 --- a/embedchain/vectordb/chroma.py +++ b/embedchain/vectordb/chroma.py @@ -133,6 +133,7 @@ class ChromaDB(BaseVectorDB): metadatas: List[object], ids: List[str], skip_embedding: bool, + **kwargs: Optional[Dict[str, Any]], ) -> Any: """ Add vectors to chroma database @@ -198,6 +199,7 @@ class ChromaDB(BaseVectorDB): where: Dict[str, any], skip_embedding: bool, citations: bool = False, + **kwargs: Optional[Dict[str, Any]], ) -> Union[List[Tuple[str, str, str]], List[str]]: """ Query contents from vector database based on vector similarity @@ -225,6 +227,7 @@ class ChromaDB(BaseVectorDB): ], n_results=n_results, where=self._generate_where_clause(where), + **kwargs, ) else: result = self.collection.query( @@ -233,6 +236,7 @@ class ChromaDB(BaseVectorDB): ], n_results=n_results, where=self._generate_where_clause(where), + **kwargs, ) except InvalidDimensionException as e: raise InvalidDimensionException( diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py index 5ae6fd7c..e3b25042 100644 --- a/embedchain/vectordb/elasticsearch.py +++ b/embedchain/vectordb/elasticsearch.py @@ -105,6 +105,7 @@ class ElasticsearchDB(BaseVectorDB): metadatas: List[object], ids: List[str], skip_embedding: bool, + **kwargs: Optional[Dict[str, any]], ) -> Any: """ add data in vector database @@ -142,6 +143,7 @@ class ElasticsearchDB(BaseVectorDB): where: Dict[str, any], skip_embedding: bool, citations: bool = False, + **kwargs: Optional[Dict[str, Any]], ) -> Union[List[Tuple[str, str, str]], List[str]]: """ query contents from vector data base based on vector similarity diff --git a/embedchain/vectordb/opensearch.py b/embedchain/vectordb/opensearch.py index 8ba3be1e..da51b600 100644 --- a/embedchain/vectordb/opensearch.py +++ b/embedchain/vectordb/opensearch.py @@ -1,6 +1,6 @@ import logging import time -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from tqdm import tqdm @@ -121,6 +121,7 @@ class OpenSearchDB(BaseVectorDB): metadatas: List[object], ids: List[str], skip_embedding: bool, + **kwargs: Optional[Dict[str, any]], ): """Add data in vector database. @@ -154,7 +155,7 @@ class OpenSearchDB(BaseVectorDB): ] # Perform bulk operation - bulk(self.client, batch_entries) + bulk(self.client, batch_entries, **kwargs) self.client.indices.refresh(index=self._get_index()) # Sleep to avoid rate limiting @@ -167,6 +168,7 @@ class OpenSearchDB(BaseVectorDB): where: Dict[str, any], skip_embedding: bool, citations: bool = False, + **kwargs: Optional[Dict[str, Any]], ) -> Union[List[Tuple[str, str, str]], List[str]]: """ query contents from vector data base based on vector similarity @@ -209,6 +211,7 @@ class OpenSearchDB(BaseVectorDB): metadata_field="metadata", pre_filter=pre_filter, k=n_results, + **kwargs, ) contexts = [] diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py index c3420c09..a794ccb8 100644 --- a/embedchain/vectordb/pinecone.py +++ b/embedchain/vectordb/pinecone.py @@ -10,6 +10,7 @@ except ImportError: from embedchain.config.vectordb.pinecone import PineconeDBConfig from embedchain.helpers.json_serializable import register_deserializable +from embedchain.utils import chunks from embedchain.vectordb.base import BaseVectorDB @@ -92,6 +93,7 @@ class PineconeDB(BaseVectorDB): metadatas: List[object], ids: List[str], skip_embedding: bool, + **kwargs: Optional[Dict[str, any]], ): """add data in vector database @@ -104,7 +106,6 @@ class PineconeDB(BaseVectorDB): """ docs = [] print("Adding documents to Pinecone...") - embeddings = self.embedder.embedding_fn(documents) for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings): docs.append( @@ -115,8 +116,8 @@ class PineconeDB(BaseVectorDB): } ) - for i in range(0, len(docs), self.BATCH_SIZE): - self.client.upsert(docs[i : i + self.BATCH_SIZE]) + for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches..."): + self.client.upsert(chunk, **kwargs) def query( self, @@ -125,6 +126,7 @@ class PineconeDB(BaseVectorDB): where: Dict[str, any], skip_embedding: bool, citations: bool = False, + **kwargs: Optional[Dict[str, any]], ) -> Union[List[Tuple[str, str, str]], List[str]]: """ query contents from vector database based on vector similarity @@ -146,7 +148,7 @@ class PineconeDB(BaseVectorDB): query_vector = self.embedder.embedding_fn([input_query])[0] else: query_vector = input_query - data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True) + data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True, **kwargs) contexts = [] for doc in data["matches"]: metadata = doc["metadata"] diff --git a/embedchain/vectordb/qdrant.py b/embedchain/vectordb/qdrant.py index 3fe3888b..6656b5d7 100644 --- a/embedchain/vectordb/qdrant.py +++ b/embedchain/vectordb/qdrant.py @@ -1,7 +1,7 @@ import copy import os import uuid -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union try: from qdrant_client import QdrantClient @@ -127,6 +127,7 @@ class QdrantDB(BaseVectorDB): metadatas: List[object], ids: List[str], skip_embedding: bool, + **kwargs: Optional[Dict[str, any]], ): """add data in vector database :param embeddings: list of embeddings for the corresponding documents to be added @@ -158,6 +159,7 @@ class QdrantDB(BaseVectorDB): payloads=payloads[i : i + self.BATCH_SIZE], vectors=embeddings[i : i + self.BATCH_SIZE], ), + **kwargs, ) def query( @@ -167,6 +169,7 @@ class QdrantDB(BaseVectorDB): where: Dict[str, any], skip_embedding: bool, citations: bool = False, + **kwargs: Optional[Dict[str, Any]], ) -> Union[List[Tuple[str, str, str]], List[str]]: """ query contents from vector database based on vector similarity @@ -208,6 +211,7 @@ class QdrantDB(BaseVectorDB): query_filter=models.Filter(must=qdrant_must_filters), query_vector=query_vector, limit=n_results, + **kwargs, ) contexts = [] diff --git a/embedchain/vectordb/weaviate.py b/embedchain/vectordb/weaviate.py index 6ff329cb..ac3d9b57 100644 --- a/embedchain/vectordb/weaviate.py +++ b/embedchain/vectordb/weaviate.py @@ -1,6 +1,6 @@ import copy import os -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union try: import weaviate @@ -158,6 +158,7 @@ class WeaviateDB(BaseVectorDB): metadatas: List[object], ids: List[str], skip_embedding: bool, + **kwargs: Optional[Dict[str, any]], ): """add data in vector database :param embeddings: list of embeddings for the corresponding documents to be added @@ -192,7 +193,9 @@ class WeaviateDB(BaseVectorDB): class_name=self.index_name + "_metadata", vector=embedding, ) - batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata") + batch.add_reference( + obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata", **kwargs + ) def query( self, @@ -201,6 +204,7 @@ class WeaviateDB(BaseVectorDB): where: Dict[str, any], skip_embedding: bool, citations: bool = False, + **kwargs: Optional[Dict[str, Any]], ) -> Union[List[Tuple[str, str, str]], List[str]]: """ query contents from vector database based on vector similarity diff --git a/embedchain/vectordb/zilliz.py b/embedchain/vectordb/zilliz.py index 0608c12f..e4806817 100644 --- a/embedchain/vectordb/zilliz.py +++ b/embedchain/vectordb/zilliz.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from embedchain.config import ZillizDBConfig from embedchain.helpers.json_serializable import register_deserializable @@ -113,6 +113,7 @@ class ZillizVectorDB(BaseVectorDB): metadatas: List[object], ids: List[str], skip_embedding: bool, + **kwargs: Optional[Dict[str, any]], ): """Add to database""" if not skip_embedding: @@ -120,7 +121,7 @@ class ZillizVectorDB(BaseVectorDB): for id, doc, metadata, embedding in zip(ids, documents, metadatas, embeddings): data = {**metadata, "id": id, "text": doc, "embeddings": embedding} - self.client.insert(collection_name=self.config.collection_name, data=data) + self.client.insert(collection_name=self.config.collection_name, data=data, **kwargs) self.collection.load() self.collection.flush() @@ -133,6 +134,7 @@ class ZillizVectorDB(BaseVectorDB): where: Dict[str, any], skip_embedding: bool, citations: bool = False, + **kwargs: Optional[Dict[str, Any]], ) -> Union[List[Tuple[str, str, str]], List[str]]: """ Query contents from vector data base based on vector similarity @@ -165,6 +167,7 @@ class ZillizVectorDB(BaseVectorDB): data=query_vector, limit=n_results, output_fields=output_fields, + **kwargs, ) else: @@ -176,6 +179,7 @@ class ZillizVectorDB(BaseVectorDB): data=[query_vector], limit=n_results, output_fields=output_fields, + **kwargs, ) contexts = [] diff --git a/tests/vectordb/test_pinecone.py b/tests/vectordb/test_pinecone.py index bf0a485d..cc0b4b14 100644 --- a/tests/vectordb/test_pinecone.py +++ b/tests/vectordb/test_pinecone.py @@ -57,11 +57,11 @@ class TestPinecone: db.add(vectors, documents, metadatas, ids, True) expected_pinecone_upsert_args = [ - {"id": "doc1", "metadata": {"text": "This is a document."}, "values": [0, 0, 0]}, - {"id": "doc2", "metadata": {"text": "This is another document."}, "values": [1, 1, 1]}, + {"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}}, + {"id": "doc2", "values": [1, 1, 1], "metadata": {"text": "This is another document."}}, ] # Assert that the Pinecone client was called to upsert the documents - pinecone_client_mock.upsert.assert_called_once_with(expected_pinecone_upsert_args) + pinecone_client_mock.upsert.assert_called_once_with(tuple(expected_pinecone_upsert_args)) @patch("embedchain.vectordb.pinecone.pinecone") def test_query_documents(self, pinecone_mock):