[OpenSearch] Add chunks specific to an app_id if present (#765)

Deshraj Yadav, 2023-10-04 15:46:22 -07:00, committed by GitHub
parent 352e71461d
commit 64a34cac32
6 changed files with 81 additions and 55 deletions


@@ -10,7 +10,7 @@ class BaseChunker(JSONSerializable):
         self.text_splitter = text_splitter
         self.data_type = None
 
-    def create_chunks(self, loader, src):
+    def create_chunks(self, loader, src, app_id=None):
         """
         Loads data and chunks it.
@@ -18,13 +18,18 @@ class BaseChunker(JSONSerializable):
         the raw data.
         :param src: The data to be handled by the loader. Can be a URL for
         remote sources or local content for local loaders.
+        :param app_id: App id used to generate the doc_id.
         """
         documents = []
-        ids = []
+        chunk_ids = []
         idMap = {}
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]
+        # Prefix app_id in the document id if app_id is not None to
+        # distinguish between different documents stored in the same
+        # elasticsearch or opensearch index
+        doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
         metadatas = []
         for data in data_records:
             content = data["content"]
@@ -41,12 +46,12 @@ class BaseChunker(JSONSerializable):
                 chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                 if idMap.get(chunk_id) is None:
                     idMap[chunk_id] = True
-                    ids.append(chunk_id)
+                    chunk_ids.append(chunk_id)
                     documents.append(chunk)
                     metadatas.append(meta_data)
         return {
             "documents": documents,
-            "ids": ids,
+            "ids": chunk_ids,
             "metadatas": metadatas,
             "doc_id": doc_id,
         }
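
A minimal sketch of the resulting behavior, using a hypothetical stand-in loader (DummyLoader, the sample text, and the app id below are illustrative, not part of this commit):

import hashlib

class DummyLoader:
    # Hypothetical loader; the return shape mirrors the
    # {"data": [...], "doc_id": ...} contract used in the diff above.
    def load_data(self, src):
        return {
            "data": [{"content": src, "meta_data": {"url": "local"}}],
            "doc_id": "abc123",
        }

def make_doc_id(app_id, doc_id):
    # Same prefixing rule as BaseChunker.create_chunks
    return f"{app_id}--{doc_id}" if app_id is not None else doc_id

result = DummyLoader().load_data("hello world")
print(make_doc_id("my_app", result["doc_id"]))  # my_app--abc123
print(make_doc_id(None, result["doc_id"]))      # abc123

# Chunk ids remain content-addressed: the same chunk+url pair always
# hashes to the same id, regardless of app_id.
print(hashlib.sha256(("hello world" + "local").encode()).hexdigest())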


@@ -20,7 +20,7 @@ class ImagesChunker(BaseChunker):
         )
         super().__init__(image_splitter)
 
-    def create_chunks(self, loader, src):
+    def create_chunks(self, loader, src, app_id=None):
         """
         Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image
@@ -35,6 +35,7 @@ class ImagesChunker(BaseChunker):
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]
+        doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
         metadatas = []
         for data in data_records:
             meta_data = data["meta_data"]


@@ -268,14 +268,16 @@ class EmbedChain(JSONSerializable):
         elif chunker.data_type.value in [item.value for item in IndirectDataType]:
             # These types have an indirect source reference.
             # As long as the reference is the same, they can be updated.
-            existing_embeddings_data = self.db.get(
-                where={
-                    "url": src,
-                },
+            where = {"url": src}
+            if self.config.id is not None:
+                where.update({"app_id": self.config.id})
+            existing_embeddings = self.db.get(
+                where=where,
                 limit=1,
             )
-            if len(existing_embeddings_data.get("metadatas", [])) > 0:
-                return existing_embeddings_data["metadatas"][0]["doc_id"]
+            if len(existing_embeddings.get("metadatas", [])) > 0:
+                return existing_embeddings["metadatas"][0]["doc_id"]
             else:
                 return None
         elif chunker.data_type.value in [item.value for item in SpecialDataType]:
@@ -283,14 +285,16 @@ class EmbedChain(JSONSerializable):
             # Through custom logic, they can be attributed to a source and be updated.
             if chunker.data_type == DataType.QNA_PAIR:
                 # QNA_PAIRs update the answer if the question already exists.
-                existing_embeddings_data = self.db.get(
-                    where={
-                        "question": src[0],
-                    },
+                where = {"question": src[0]}
+                if self.config.id is not None:
+                    where.update({"app_id": self.config.id})
+                existing_embeddings = self.db.get(
+                    where=where,
                     limit=1,
                 )
-                if len(existing_embeddings_data.get("metadatas", [])) > 0:
-                    return existing_embeddings_data["metadatas"][0]["doc_id"]
+                if len(existing_embeddings.get("metadatas", [])) > 0:
+                    return existing_embeddings["metadatas"][0]["doc_id"]
                 else:
                     return None
             else:
@@ -326,9 +330,10 @@ class EmbedChain(JSONSerializable):
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
         """
         existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
+        app_id = self.config.id if self.config is not None else None
 
         # Create chunks
-        embeddings_data = chunker.create_chunks(loader, src)
+        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id)
         # spread chunking results
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
@@ -345,12 +350,11 @@ class EmbedChain(JSONSerializable):
             self.db.delete({"doc_id": existing_doc_id})
 
         # get existing ids, and discard doc if any common id exists.
-        where = {"app_id": self.config.id} if self.config.id is not None else {}
-        # where={"url": src}
-        db_result = self.db.get(
-            ids=ids,
-            where=where,  # optional filter
-        )
+        where = {"url": src}
+        if self.config.id is not None:
+            where.update({"metadata.app_id": self.config.id})
+        db_result = self.db.get(ids=ids, where=where)  # optional filter
         existing_ids = set(db_result["ids"])
 
         if len(existing_ids):
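
Both lookup paths above now build their "where" filter the same way: start from the natural key for the data type and add the app scope only when an app id is configured. One detail worth noting: _get_existing_doc_id filters on the key "app_id", while the load path filters on "metadata.app_id". A small sketch of the shared pattern (build_where is a hypothetical helper, not in the commit):

from typing import Any, Dict, Optional

def build_where(base: Dict[str, Any], app_id: Optional[str], key: str = "app_id") -> Dict[str, Any]:
    # Copy the natural filter, then scope it to the app only when an
    # app id is configured; `key` makes the two spellings explicit.
    where = dict(base)
    if app_id is not None:
        where[key] = app_id
    return where

print(build_where({"url": "https://example.com"}, "my_app"))
# {'url': 'https://example.com', 'app_id': 'my_app'}
print(build_where({"url": "https://example.com"}, None, key="metadata.app_id"))
# {'url': 'https://example.com'}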


@@ -85,19 +85,29 @@ class OpenSearchDB(BaseVectorDB):
         :return: ids
         :type: Set[str]
         """
+        query = {}
         if ids:
-            query = {"query": {"bool": {"must": [{"ids": {"values": ids}}]}}}
+            query["query"] = {"bool": {"must": [{"ids": {"values": ids}}]}}
         else:
-            query = {"query": {"bool": {"must": []}}}
+            query["query"] = {"bool": {"must": []}}
+
+        if "app_id" in where:
+            app_id = where["app_id"]
+            query["query"]["bool"]["must"].append({"term": {"metadata.app_id": app_id}})
 
         # OpenSearch syntax is different from Elasticsearch
-        response = self.client.search(index=self._get_index(), body=query, _source=False, size=limit)
+        response = self.client.search(index=self._get_index(), body=query, _source=True, size=limit)
         docs = response["hits"]["hits"]
         ids = [doc["_id"] for doc in docs]
-        return {"ids": set(ids)}
+        doc_ids = [doc["_source"]["metadata"]["doc_id"] for doc in docs]
+
+        # Result is modified for compatibility with other vector databases
+        # TODO: Add method in vector database to return result in a standard format
+        result = {"ids": ids, "metadatas": []}
+        for doc_id in doc_ids:
+            result["metadatas"].append({"doc_id": doc_id})
+
+        return result
 
     def add(
         self, embeddings: List[str], documents: List[str], metadatas: List[object], ids: List[str], skip_embedding: bool
@@ -204,6 +214,14 @@ class OpenSearchDB(BaseVectorDB):
         # delete the index in OpenSearch
         self.client.indices.delete(index=self._get_index())
 
+    def delete(self, where):
+        """Deletes a document from the OpenSearch index"""
+        if "doc_id" not in where:
+            raise ValueError("doc_id is required to delete a document")
+
+        query = {"query": {"bool": {"must": [{"term": {"metadata.doc_id": where["doc_id"]}}]}}}
+        self.client.delete_by_query(index=self._get_index(), body=query)
+
     def _get_index(self) -> str:
         """Get the OpenSearch index for a collection


@@ -17,14 +17,15 @@ class TestImageChunker(unittest.TestCase):
         chunker.set_data_type(DataType.IMAGES)
 
         image_path = "./tmp/image.jpeg"
-        result = chunker.create_chunks(MockLoader(), image_path)
+        app_id = "app1"
+        result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)
 
         expected_chunks = {
-            "doc_id": "123",
+            "doc_id": f"{app_id}--123",
             "documents": [image_path],
             "embeddings": ["embedding"],
             "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
-            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+            "metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
         }
         self.assertEqual(expected_chunks, result)
@@ -37,14 +38,15 @@ class TestImageChunker(unittest.TestCase):
         chunker.set_data_type(DataType.IMAGES)
 
         image_path = "./tmp/image.jpeg"
-        result = chunker.create_chunks(MockLoader(), image_path)
+        app_id = "app1"
+        result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)
 
         expected_chunks = {
-            "doc_id": "123",
+            "doc_id": f"{app_id}--123",
             "documents": [image_path],
             "embeddings": ["embedding"],
             "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
-            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+            "metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
         }
         self.assertEqual(expected_chunks, result)


@@ -1,31 +1,35 @@
 # ruff: noqa: E501
 
-import unittest
-
 from embedchain.chunkers.text import TextChunker
 from embedchain.config import ChunkerConfig
 from embedchain.models.data_type import DataType
 
-class TestTextChunker(unittest.TestCase):
-    def test_chunks(self):
+class TestTextChunker:
+    def test_chunks_without_app_id(self):
         """
         Test the chunks generated by TextChunker.
         # TODO: Not a very precise test.
         """
         chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
         chunker = TextChunker(config=chunker_config)
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
-
         result = chunker.create_chunks(MockLoader(), text)
-
         documents = result["documents"]
-        self.assertGreaterEqual(len(documents), 5)
-
-        # Additional test cases can be added to cover different scenarios
+        assert len(documents) > 5
+
+    def test_chunks_with_app_id(self):
+        """
+        Test the chunks generated by TextChunker with app_id
+        """
+        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
+        chunker = TextChunker(config=chunker_config)
+        text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
+        chunker.set_data_type(DataType.TEXT)
+        result = chunker.create_chunks(MockLoader(), text, app_id="app1")
+        documents = result["documents"]
+        assert len(documents) > 5
 
     def test_big_chunksize(self):
         """
@@ -36,12 +40,9 @@ class TestTextChunker(unittest.TestCase):
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
-
         result = chunker.create_chunks(MockLoader(), text)
-
         documents = result["documents"]
-        self.assertEqual(len(documents), 1)
+        assert len(documents) == 1
 
     def test_small_chunksize(self):
         """
@@ -53,14 +54,9 @@ class TestTextChunker(unittest.TestCase):
         text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"""
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
-
         result = chunker.create_chunks(MockLoader(), text)
-
         documents = result["documents"]
-
-        print(documents)
-        self.assertEqual(len(documents), len(text))
+        assert len(documents) == len(text)
 
     def test_word_count(self):
         chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
@@ -69,7 +65,7 @@ class TestTextChunker(unittest.TestCase):
         document = ["ab cd", "ef gh"]
         result = chunker.get_word_count(document)
-        self.assertEqual(result, 4)
+        assert result == 4
 
 
 class MockLoader:
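
The MockLoader body is cut off here. Based on the expectations above (doc_id "123", url "none"), a plausible reconstruction looks like the following; this is an inference, not part of the commit:

class MockLoader:
    @staticmethod
    def load_data(src):
        # Assumed shape, mirroring the {"data": [...], "doc_id": ...}
        # contract consumed by BaseChunker.create_chunks and the
        # doc_id/url values the tests above expect.
        return {
            "doc_id": "123",
            "data": [{"content": src, "meta_data": {"url": "none"}}],
        }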