diff --git a/embedchain/config/InitConfig.py b/embedchain/config/InitConfig.py
index 47fe7e64..26990879 100644
--- a/embedchain/config/InitConfig.py
+++ b/embedchain/config/InitConfig.py
@@ -11,22 +11,31 @@ class InitConfig(BaseConfig):
     Config to initialize an embedchain `App` instance.
     """
 
-    def __init__(self, log_level=None, ef=None, db=None, host=None, port=None):
+    def __init__(self, log_level=None, ef=None, db=None, host=None, port=None, id=None):
         """
         :param log_level: Optional. (String) Debug level
         ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
         :param ef: Optional. Embedding function to use.
         :param db: Optional. (Vector) database to use for embeddings.
         :param host: Optional. Hostname for the database server.
         :param port: Optional. Port for the database server.
+        :param id: Optional. ID of the app. Document metadata will have this id.
         """
         self._setup_logging(log_level)
 
         self.ef = ef
-        self.db = db
-
         self.host = host
         self.port = port
+        self.id = id
+
+        if db is None:
+            # Default to ChromaDB. It is built from self.ef, so this block
+            # must stay below the `self.ef = ef` assignment.
+            from embedchain.vectordb.chroma_db import ChromaDB
+
+            self.db = ChromaDB(ef=self.ef)
+        else:
+            self.db = db
         return
 
     def _set_embedding_function(self, ef):
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 0327675a..d7bd55f3 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -97,9 +97,10 @@ class EmbedChain:
         metadatas = embeddings_data["metadatas"]
         ids = embeddings_data["ids"]
         # get existing ids, and discard doc if any common id exist.
+        where = {"app_id": self.config.id} if self.config.id is not None else {}
         existing_docs = self.collection.get(
             ids=ids,
-            # where={"url": src}
+            where=where,  # optional filter
         )
         existing_ids = set(existing_docs["ids"])
 
@@ -113,6 +114,10 @@ class EmbedChain:
 
             ids = list(data_dict.keys())
             documents, metadatas = zip(*data_dict.values())
+
+        # Add app id in metadatas so that they can be queried on later
+        if self.config.id is not None:
+            metadatas = [{**m, "app_id": self.config.id} for m in metadatas]
 
         chunks_before_addition = self.count()
 
@@ -144,11 +149,13 @@ class EmbedChain:
         :param config: The query configuration.
         :return: The content of the document that matched your query.
         """
+        where = {"app_id": self.config.id} if self.config.id is not None else {}  # optional filter
         result = self.collection.query(
             query_texts=[
                 input_query,
             ],
             n_results=config.number_documents,
+            where=where,
         )
         results_formatted = self._format_result(result)
         contents = [result[0].page_content for result in results_formatted]
diff --git a/setup.py b/setup.py
index 961a02ac..49a41982 100644
--- a/setup.py
+++ b/setup.py
@@ -33,7 +33,7 @@ setuptools.setup(
         "gpt4all",
         "sentence_transformers",
         "docx2txt",
-        "pydantic==1.10.8",
+        "pydantic==1.10.8"
     ],
     extras_require={"dev": ["black", "ruff", "isort", "pytest"]},
 )