[OpenSearch] Add chunks specific to an app_id if present (#765)
This commit is contained in:
@@ -10,7 +10,7 @@ class BaseChunker(JSONSerializable):
|
||||
self.text_splitter = text_splitter
|
||||
self.data_type = None
|
||||
|
||||
def create_chunks(self, loader, src):
|
||||
def create_chunks(self, loader, src, app_id=None):
|
||||
"""
|
||||
Loads data and chunks it.
|
||||
|
||||
@@ -18,13 +18,18 @@ class BaseChunker(JSONSerializable):
|
||||
the raw data.
|
||||
:param src: The data to be handled by the loader. Can be a URL for
|
||||
remote sources or local content for local loaders.
|
||||
:param app_id: App id used to generate the doc_id.
|
||||
"""
|
||||
documents = []
|
||||
ids = []
|
||||
chunk_ids = []
|
||||
idMap = {}
|
||||
data_result = loader.load_data(src)
|
||||
data_records = data_result["data"]
|
||||
doc_id = data_result["doc_id"]
|
||||
# Prefix app_id in the document id if app_id is not None to
|
||||
# distinguish between different documents stored in the same
|
||||
# elasticsearch or opensearch index
|
||||
doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
|
||||
metadatas = []
|
||||
for data in data_records:
|
||||
content = data["content"]
|
||||
@@ -41,12 +46,12 @@ class BaseChunker(JSONSerializable):
|
||||
chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
|
||||
if idMap.get(chunk_id) is None:
|
||||
idMap[chunk_id] = True
|
||||
ids.append(chunk_id)
|
||||
chunk_ids.append(chunk_id)
|
||||
documents.append(chunk)
|
||||
metadatas.append(meta_data)
|
||||
return {
|
||||
"documents": documents,
|
||||
"ids": ids,
|
||||
"ids": chunk_ids,
|
||||
"metadatas": metadatas,
|
||||
"doc_id": doc_id,
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ class ImagesChunker(BaseChunker):
|
||||
)
|
||||
super().__init__(image_splitter)
|
||||
|
||||
def create_chunks(self, loader, src):
|
||||
def create_chunks(self, loader, src, app_id=None):
|
||||
"""
|
||||
Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image
|
||||
|
||||
@@ -35,6 +35,7 @@ class ImagesChunker(BaseChunker):
|
||||
data_result = loader.load_data(src)
|
||||
data_records = data_result["data"]
|
||||
doc_id = data_result["doc_id"]
|
||||
doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
|
||||
metadatas = []
|
||||
for data in data_records:
|
||||
meta_data = data["meta_data"]
|
||||
|
||||
@@ -268,14 +268,16 @@ class EmbedChain(JSONSerializable):
|
||||
elif chunker.data_type.value in [item.value for item in IndirectDataType]:
|
||||
# These types have a indirect source reference
|
||||
# As long as the reference is the same, they can be updated.
|
||||
existing_embeddings_data = self.db.get(
|
||||
where={
|
||||
"url": src,
|
||||
},
|
||||
where = {"url": src}
|
||||
if self.config.id is not None:
|
||||
where.update({"app_id": self.config.id})
|
||||
|
||||
existing_embeddings = self.db.get(
|
||||
where=where,
|
||||
limit=1,
|
||||
)
|
||||
if len(existing_embeddings_data.get("metadatas", [])) > 0:
|
||||
return existing_embeddings_data["metadatas"][0]["doc_id"]
|
||||
if len(existing_embeddings.get("metadatas", [])) > 0:
|
||||
return existing_embeddings["metadatas"][0]["doc_id"]
|
||||
else:
|
||||
return None
|
||||
elif chunker.data_type.value in [item.value for item in SpecialDataType]:
|
||||
@@ -283,14 +285,16 @@ class EmbedChain(JSONSerializable):
|
||||
# Through custom logic, they can be attributed to a source and be updated.
|
||||
if chunker.data_type == DataType.QNA_PAIR:
|
||||
# QNA_PAIRs update the answer if the question already exists.
|
||||
existing_embeddings_data = self.db.get(
|
||||
where={
|
||||
"question": src[0],
|
||||
},
|
||||
where = {"question": src[0]}
|
||||
if self.config.id is not None:
|
||||
where.update({"app_id": self.config.id})
|
||||
|
||||
existing_embeddings = self.db.get(
|
||||
where=where,
|
||||
limit=1,
|
||||
)
|
||||
if len(existing_embeddings_data.get("metadatas", [])) > 0:
|
||||
return existing_embeddings_data["metadatas"][0]["doc_id"]
|
||||
if len(existing_embeddings.get("metadatas", [])) > 0:
|
||||
return existing_embeddings["metadatas"][0]["doc_id"]
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
@@ -326,9 +330,10 @@ class EmbedChain(JSONSerializable):
|
||||
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
|
||||
"""
|
||||
existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
|
||||
app_id = self.config.id if self.config is not None else None
|
||||
|
||||
# Create chunks
|
||||
embeddings_data = chunker.create_chunks(loader, src)
|
||||
embeddings_data = chunker.create_chunks(loader, src, app_id=app_id)
|
||||
# spread chunking results
|
||||
documents = embeddings_data["documents"]
|
||||
metadatas = embeddings_data["metadatas"]
|
||||
@@ -345,12 +350,11 @@ class EmbedChain(JSONSerializable):
|
||||
self.db.delete({"doc_id": existing_doc_id})
|
||||
|
||||
# get existing ids, and discard doc if any common id exist.
|
||||
where = {"app_id": self.config.id} if self.config.id is not None else {}
|
||||
# where={"url": src}
|
||||
db_result = self.db.get(
|
||||
ids=ids,
|
||||
where=where, # optional filter
|
||||
)
|
||||
where = {"url": src}
|
||||
if self.config.id is not None:
|
||||
where.update({"metadata.app_id": self.config.id})
|
||||
|
||||
db_result = self.db.get(ids=ids, where=where) # optional filter
|
||||
existing_ids = set(db_result["ids"])
|
||||
|
||||
if len(existing_ids):
|
||||
|
||||
@@ -85,19 +85,29 @@ class OpenSearchDB(BaseVectorDB):
|
||||
:return: ids
|
||||
:type: Set[str]
|
||||
"""
|
||||
query = {}
|
||||
if ids:
|
||||
query = {"query": {"bool": {"must": [{"ids": {"values": ids}}]}}}
|
||||
query["query"] = {"bool": {"must": [{"ids": {"values": ids}}]}}
|
||||
else:
|
||||
query = {"query": {"bool": {"must": []}}}
|
||||
query["query"] = {"bool": {"must": []}}
|
||||
|
||||
if "app_id" in where:
|
||||
app_id = where["app_id"]
|
||||
query["query"]["bool"]["must"].append({"term": {"metadata.app_id": app_id}})
|
||||
|
||||
# OpenSearch syntax is different from Elasticsearch
|
||||
response = self.client.search(index=self._get_index(), body=query, _source=False, size=limit)
|
||||
response = self.client.search(index=self._get_index(), body=query, _source=True, size=limit)
|
||||
docs = response["hits"]["hits"]
|
||||
ids = [doc["_id"] for doc in docs]
|
||||
return {"ids": set(ids)}
|
||||
doc_ids = [doc["_source"]["metadata"]["doc_id"] for doc in docs]
|
||||
|
||||
# Result is modified for compatibility with other vector databases
|
||||
# TODO: Add method in vector database to return result in a standard format
|
||||
result = {"ids": ids, "metadatas": []}
|
||||
|
||||
for doc_id in doc_ids:
|
||||
result["metadatas"].append({"doc_id": doc_id})
|
||||
return result
|
||||
|
||||
def add(
|
||||
self, embeddings: List[str], documents: List[str], metadatas: List[object], ids: List[str], skip_embedding: bool
|
||||
@@ -204,6 +214,14 @@ class OpenSearchDB(BaseVectorDB):
|
||||
# delete index in Es
|
||||
self.client.indices.delete(index=self._get_index())
|
||||
|
||||
def delete(self, where):
|
||||
"""Deletes a document from the OpenSearch index"""
|
||||
if "doc_id" not in where:
|
||||
raise ValueError("doc_id is required to delete a document")
|
||||
|
||||
query = {"query": {"bool": {"must": [{"term": {"metadata.doc_id": where["doc_id"]}}]}}}
|
||||
self.client.delete_by_query(index=self._get_index(), body=query)
|
||||
|
||||
def _get_index(self) -> str:
|
||||
"""Get the OpenSearch index for a collection
|
||||
|
||||
|
||||
@@ -17,14 +17,15 @@ class TestImageChunker(unittest.TestCase):
|
||||
chunker.set_data_type(DataType.IMAGES)
|
||||
|
||||
image_path = "./tmp/image.jpeg"
|
||||
result = chunker.create_chunks(MockLoader(), image_path)
|
||||
app_id = "app1"
|
||||
result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)
|
||||
|
||||
expected_chunks = {
|
||||
"doc_id": "123",
|
||||
"doc_id": f"{app_id}--123",
|
||||
"documents": [image_path],
|
||||
"embeddings": ["embedding"],
|
||||
"ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
|
||||
"metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
|
||||
"metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
|
||||
}
|
||||
self.assertEqual(expected_chunks, result)
|
||||
|
||||
@@ -37,14 +38,15 @@ class TestImageChunker(unittest.TestCase):
|
||||
chunker.set_data_type(DataType.IMAGES)
|
||||
|
||||
image_path = "./tmp/image.jpeg"
|
||||
result = chunker.create_chunks(MockLoader(), image_path)
|
||||
app_id = "app1"
|
||||
result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)
|
||||
|
||||
expected_chunks = {
|
||||
"doc_id": "123",
|
||||
"doc_id": f"{app_id}--123",
|
||||
"documents": [image_path],
|
||||
"embeddings": ["embedding"],
|
||||
"ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
|
||||
"metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
|
||||
"metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
|
||||
}
|
||||
self.assertEqual(expected_chunks, result)
|
||||
|
||||
|
||||
@@ -1,31 +1,35 @@
|
||||
# ruff: noqa: E501
|
||||
|
||||
import unittest
|
||||
|
||||
from embedchain.chunkers.text import TextChunker
|
||||
from embedchain.config import ChunkerConfig
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
|
||||
class TestTextChunker(unittest.TestCase):
|
||||
def test_chunks(self):
|
||||
class TestTextChunker:
|
||||
def test_chunks_without_app_id(self):
|
||||
"""
|
||||
Test the chunks generated by TextChunker.
|
||||
# TODO: Not a very precise test.
|
||||
"""
|
||||
chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
|
||||
chunker = TextChunker(config=chunker_config)
|
||||
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
|
||||
# Data type must be set manually in the test
|
||||
chunker.set_data_type(DataType.TEXT)
|
||||
|
||||
result = chunker.create_chunks(MockLoader(), text)
|
||||
|
||||
documents = result["documents"]
|
||||
assert len(documents) > 5
|
||||
|
||||
self.assertGreaterEqual(len(documents), 5)
|
||||
|
||||
# Additional test cases can be added to cover different scenarios
|
||||
def test_chunks_with_app_id(self):
|
||||
"""
|
||||
Test the chunks generated by TextChunker with app_id
|
||||
"""
|
||||
chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
|
||||
chunker = TextChunker(config=chunker_config)
|
||||
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
|
||||
chunker.set_data_type(DataType.TEXT)
|
||||
result = chunker.create_chunks(MockLoader(), text)
|
||||
documents = result["documents"]
|
||||
assert len(documents) > 5
|
||||
|
||||
def test_big_chunksize(self):
|
||||
"""
|
||||
@@ -36,12 +40,9 @@ class TestTextChunker(unittest.TestCase):
|
||||
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
|
||||
# Data type must be set manually in the test
|
||||
chunker.set_data_type(DataType.TEXT)
|
||||
|
||||
result = chunker.create_chunks(MockLoader(), text)
|
||||
|
||||
documents = result["documents"]
|
||||
|
||||
self.assertEqual(len(documents), 1)
|
||||
assert len(documents) == 1
|
||||
|
||||
def test_small_chunksize(self):
|
||||
"""
|
||||
@@ -53,14 +54,9 @@ class TestTextChunker(unittest.TestCase):
|
||||
text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"""
|
||||
# Data type must be set manually in the test
|
||||
chunker.set_data_type(DataType.TEXT)
|
||||
|
||||
result = chunker.create_chunks(MockLoader(), text)
|
||||
|
||||
documents = result["documents"]
|
||||
|
||||
print(documents)
|
||||
|
||||
self.assertEqual(len(documents), len(text))
|
||||
assert len(documents) == len(text)
|
||||
|
||||
def test_word_count(self):
|
||||
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
|
||||
@@ -69,7 +65,7 @@ class TestTextChunker(unittest.TestCase):
|
||||
|
||||
document = ["ab cd", "ef gh"]
|
||||
result = chunker.get_word_count(document)
|
||||
self.assertEqual(result, 4)
|
||||
assert result == 4
|
||||
|
||||
|
||||
class MockLoader:
|
||||
|
||||
Reference in New Issue
Block a user