diff --git a/embedchain/chunkers/base_chunker.py b/embedchain/chunkers/base_chunker.py
index 09fb8805..36e58064 100644
--- a/embedchain/chunkers/base_chunker.py
+++ b/embedchain/chunkers/base_chunker.py
@@ -10,7 +10,7 @@ class BaseChunker(JSONSerializable):
         self.text_splitter = text_splitter
         self.data_type = None
 
-    def create_chunks(self, loader, src):
+    def create_chunks(self, loader, src, app_id=None):
         """
         Loads data and chunks it.
 
@@ -18,13 +18,18 @@ class BaseChunker(JSONSerializable):
             the raw data.
         :param src: The data to be handled by the loader. Can be a URL for
             remote sources or local content for local loaders.
+        :param app_id: App id used to generate the doc_id.
         """
         documents = []
-        ids = []
+        chunk_ids = []
         idMap = {}
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]
+        # Prefix app_id in the document id if app_id is not None to
+        # distinguish between different documents stored in the same
+        # elasticsearch or opensearch index
+        doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
         metadatas = []
         for data in data_records:
             content = data["content"]
@@ -41,12 +46,12 @@
                 chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                 if idMap.get(chunk_id) is None:
                     idMap[chunk_id] = True
-                    ids.append(chunk_id)
+                    chunk_ids.append(chunk_id)
                     documents.append(chunk)
                     metadatas.append(meta_data)
         return {
             "documents": documents,
-            "ids": ids,
+            "ids": chunk_ids,
             "metadatas": metadatas,
             "doc_id": doc_id,
         }
diff --git a/embedchain/chunkers/images.py b/embedchain/chunkers/images.py
index 22e4a7fd..853e027a 100644
--- a/embedchain/chunkers/images.py
+++ b/embedchain/chunkers/images.py
@@ -20,7 +20,7 @@ class ImagesChunker(BaseChunker):
         )
         super().__init__(image_splitter)
 
-    def create_chunks(self, loader, src):
+    def create_chunks(self, loader, src, app_id=None):
         """
         Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image
 
@@ -35,6 +35,7 @@
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]
+        doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
         metadatas = []
         for data in data_records:
             meta_data = data["meta_data"]
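A quick illustration of the prefixing behaviour both chunkers now share. This is a minimal sketch, not library code: the helper name is ours, the expression mirrors the diff, and the sample doc_id "123" is borrowed from the tests further down.

    from typing import Optional

    def prefix_doc_id(doc_id: str, app_id: Optional[str] = None) -> str:
        # Same expression as in base_chunker.py / images.py: prefix only
        # when an app_id is supplied, otherwise leave the doc_id untouched.
        return f"{app_id}--{doc_id}" if app_id is not None else doc_id

    assert prefix_doc_id("123") == "123"
    assert prefix_doc_id("123", app_id="app1") == "app1--123"

The "--" separator keeps both parts recoverable from the combined key when several apps write to one Elasticsearch or OpenSearch index.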
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index dc8fd578..83ae259b 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -268,14 +268,16 @@ class EmbedChain(JSONSerializable):
         elif chunker.data_type.value in [item.value for item in IndirectDataType]:
             # These types have a indirect source reference
             # As long as the reference is the same, they can be updated.
-            existing_embeddings_data = self.db.get(
-                where={
-                    "url": src,
-                },
+            where = {"url": src}
+            if self.config.id is not None:
+                where.update({"app_id": self.config.id})
+
+            existing_embeddings = self.db.get(
+                where=where,
                 limit=1,
             )
-            if len(existing_embeddings_data.get("metadatas", [])) > 0:
-                return existing_embeddings_data["metadatas"][0]["doc_id"]
+            if len(existing_embeddings.get("metadatas", [])) > 0:
+                return existing_embeddings["metadatas"][0]["doc_id"]
             else:
                 return None
         elif chunker.data_type.value in [item.value for item in SpecialDataType]:
@@ -283,14 +285,16 @@
             # These types don't contain indirect references.
             # Through custom logic, they can be attributed to a source and be updated.
             if chunker.data_type == DataType.QNA_PAIR:
                 # QNA_PAIRs update the answer if the question already exists.
-                existing_embeddings_data = self.db.get(
-                    where={
-                        "question": src[0],
-                    },
+                where = {"question": src[0]}
+                if self.config.id is not None:
+                    where.update({"app_id": self.config.id})
+
+                existing_embeddings = self.db.get(
+                    where=where,
                     limit=1,
                 )
-                if len(existing_embeddings_data.get("metadatas", [])) > 0:
-                    return existing_embeddings_data["metadatas"][0]["doc_id"]
+                if len(existing_embeddings.get("metadatas", [])) > 0:
+                    return existing_embeddings["metadatas"][0]["doc_id"]
                 else:
                     return None
             else:
@@ -326,9 +330,10 @@
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
         """
         existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
+        app_id = self.config.id if self.config is not None else None
 
         # Create chunks
-        embeddings_data = chunker.create_chunks(loader, src)
+        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id)
         # spread chunking results
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
@@ -345,12 +350,11 @@
             self.db.delete({"doc_id": existing_doc_id})
 
         # get existing ids, and discard doc if any common id exist.
-        where = {"app_id": self.config.id} if self.config.id is not None else {}
-        # where={"url": src}
-        db_result = self.db.get(
-            ids=ids,
-            where=where,  # optional filter
-        )
+        where = {"url": src}
+        if self.config.id is not None:
+            where.update({"metadata.app_id": self.config.id})
+
+        db_result = self.db.get(ids=ids, where=where)  # optional filter
        existing_ids = set(db_result["ids"])

        if len(existing_ids):
diff --git a/embedchain/vectordb/opensearch.py b/embedchain/vectordb/opensearch.py
index cf07fcec..4121d53b 100644
--- a/embedchain/vectordb/opensearch.py
+++ b/embedchain/vectordb/opensearch.py
@@ -85,19 +85,29 @@ class OpenSearchDB(BaseVectorDB):
        :return: ids
        :type: Set[str]
        """
+        query = {}
        if ids:
-            query = {"query": {"bool": {"must": [{"ids": {"values": ids}}]}}}
+            query["query"] = {"bool": {"must": [{"ids": {"values": ids}}]}}
        else:
-            query = {"query": {"bool": {"must": []}}}
+            query["query"] = {"bool": {"must": []}}
+
        if "app_id" in where:
            app_id = where["app_id"]
            query["query"]["bool"]["must"].append({"term": {"metadata.app_id": app_id}})

        # OpenSearch syntax is different from Elasticsearch
-        response = self.client.search(index=self._get_index(), body=query, _source=False, size=limit)
+        response = self.client.search(index=self._get_index(), body=query, _source=True, size=limit)
        docs = response["hits"]["hits"]
        ids = [doc["_id"] for doc in docs]
-        return {"ids": set(ids)}
+        doc_ids = [doc["_source"]["metadata"]["doc_id"] for doc in docs]
+
+        # Result is modified for compatibility with other vector databases
+        # TODO: Add method in vector database to return result in a standard format
+        result = {"ids": ids, "metadatas": []}
+
+        for doc_id in doc_ids:
+            result["metadatas"].append({"doc_id": doc_id})
+
+        return result

    def add(
        self, embeddings: List[str], documents: List[str], metadatas: List[object], ids: List[str], skip_embedding: bool
@@ -204,6 +214,14 @@
        # delete index in Es
        self.client.indices.delete(index=self._get_index())

+    def delete(self, where):
+        """Deletes a document from the OpenSearch index"""
+        if "doc_id" not in where:
+            raise ValueError("doc_id is required to delete a document")
+
+        query = {"query": {"bool": {"must": [{"term": {"metadata.doc_id": where["doc_id"]}}]}}}
+        self.client.delete_by_query(index=self._get_index(), body=query)
+
    def _get_index(self) -> str:
        """Get the OpenSearch index for a collection
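For reference, the new delete() boils down to a single delete-by-query round trip. A standalone sketch of the same call using opensearch-py directly; the host, index name, and doc_id value here are assumptions for illustration, not taken from the diff:

    from opensearchpy import OpenSearch

    client = OpenSearch(hosts=["http://localhost:9200"])  # placeholder host

    # Match every chunk whose metadata carries the given doc_id and remove
    # them in one request, as OpenSearchDB.delete does above.
    doc_id = "app1--123"  # hypothetical app_id-prefixed doc id
    query = {"query": {"bool": {"must": [{"term": {"metadata.doc_id": doc_id}}]}}}
    client.delete_by_query(index="embedchain_store", body=query)  # placeholder index

Because doc_ids are now app_id-prefixed, a query like this removes only the matching app's copy of a document even when several apps share one index.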
diff --git a/tests/chunkers/test_image_chunker.py b/tests/chunkers/test_image_chunker.py
index 299ff98d..eead2862 100644
--- a/tests/chunkers/test_image_chunker.py
+++ b/tests/chunkers/test_image_chunker.py
@@ -17,14 +17,15 @@ class TestImageChunker(unittest.TestCase):
         chunker.set_data_type(DataType.IMAGES)
 
         image_path = "./tmp/image.jpeg"
-        result = chunker.create_chunks(MockLoader(), image_path)
+        app_id = "app1"
+        result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)
 
         expected_chunks = {
-            "doc_id": "123",
+            "doc_id": f"{app_id}--123",
             "documents": [image_path],
             "embeddings": ["embedding"],
             "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
-            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+            "metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
         }
         self.assertEqual(expected_chunks, result)
 
@@ -37,14 +38,15 @@
         chunker.set_data_type(DataType.IMAGES)
 
         image_path = "./tmp/image.jpeg"
-        result = chunker.create_chunks(MockLoader(), image_path)
+        app_id = "app1"
+        result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)
 
         expected_chunks = {
-            "doc_id": "123",
+            "doc_id": f"{app_id}--123",
             "documents": [image_path],
             "embeddings": ["embedding"],
             "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
-            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+            "metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
         }
         self.assertEqual(expected_chunks, result)
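The text-chunker tests below also drop the unittest.TestCase base in favour of plain pytest-style classes. The mechanical difference, shown with illustrative names only:

    import unittest

    # Before: TestCase subclass using assert* helper methods
    class TestFooUnittest(unittest.TestCase):
        def test_count(self):
            self.assertEqual(len("ab cd".split()), 2)

    # After: bare class discovered by pytest, plain assert statements;
    # pytest's assertion rewriting keeps the rich failure output
    class TestFooPytest:
        def test_count(self):
            assert len("ab cd".split()) == 2

Bare asserts read closer to the expression being checked, with no loss of diagnostic detail under pytest.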
diff --git a/tests/chunkers/test_text.py b/tests/chunkers/test_text.py
index c6fabe98..9eb73133 100644
--- a/tests/chunkers/test_text.py
+++ b/tests/chunkers/test_text.py
@@ -1,31 +1,35 @@
 # ruff: noqa: E501
-import unittest
-
 from embedchain.chunkers.text import TextChunker
 from embedchain.config import ChunkerConfig
 from embedchain.models.data_type import DataType
 
 
-class TestTextChunker(unittest.TestCase):
-    def test_chunks(self):
+class TestTextChunker:
+    def test_chunks_without_app_id(self):
         """
         Test the chunks generated by TextChunker.
-        # TODO: Not a very precise test.
         """
         chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
         chunker = TextChunker(config=chunker_config)
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
         result = chunker.create_chunks(MockLoader(), text)
         documents = result["documents"]
-        self.assertGreaterEqual(len(documents), 5)
-
-        # Additional test cases can be added to cover different scenarios
+        assert len(documents) > 5
+
+    def test_chunks_with_app_id(self):
+        """
+        Test the chunks generated by TextChunker with app_id
+        """
+        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
+        chunker = TextChunker(config=chunker_config)
+        text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
+        chunker.set_data_type(DataType.TEXT)
+        result = chunker.create_chunks(MockLoader(), text, app_id="app1")
+        documents = result["documents"]
+        assert len(documents) > 5
 
     def test_big_chunksize(self):
         """
@@ -36,12 +40,9 @@ class TestTextChunker(unittest.TestCase):
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
         result = chunker.create_chunks(MockLoader(), text)
         documents = result["documents"]
-
-        self.assertEqual(len(documents), 1)
+        assert len(documents) == 1
 
     def test_small_chunksize(self):
         """
@@ -53,14 +54,9 @@ class TestTextChunker(unittest.TestCase):
         text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"""
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
         result = chunker.create_chunks(MockLoader(), text)
         documents = result["documents"]
-
-        print(documents)
-
-        self.assertEqual(len(documents), len(text))
+        assert len(documents) == len(text)
 
     def test_word_count(self):
         chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
@@ -69,7 +65,7 @@
         chunker = TextChunker(config=chunker_config)
         chunker.set_data_type(DataType.TEXT)
 
         document = ["ab cd", "ef gh"]
         result = chunker.get_word_count(document)
-        self.assertEqual(result, 4)
+        assert result == 4
 
 
 class MockLoader:
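Finally, the where-filter construction that now appears in both lookup paths of embedchain.py follows one pattern. A self-contained sketch; the helper name is ours, and whether the key is "app_id" or "metadata.app_id" depends on the vector database, as the diff shows:

    from typing import Optional

    def build_where(base: dict, app_id: Optional[str]) -> dict:
        # Start from the source-specific filter ({"url": ...} or
        # {"question": ...}) and scope it to one app when an id is set.
        where = dict(base)
        if app_id is not None:
            where["app_id"] = app_id
        return where

    assert build_where({"url": "https://example.com"}, None) == {"url": "https://example.com"}
    assert build_where({"url": "https://example.com"}, "app1") == {
        "url": "https://example.com",
        "app_id": "app1",
    }

Without the app_id term, two apps sharing an index would see each other's documents as duplicates and skip or overwrite them, which is the failure mode this patch is addressing.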