[OpenSearch] Add chunks specific to an app_id if present (#765)
@@ -10,7 +10,7 @@ class BaseChunker(JSONSerializable):
         self.text_splitter = text_splitter
         self.data_type = None

-    def create_chunks(self, loader, src):
+    def create_chunks(self, loader, src, app_id=None):
         """
         Loads data and chunks it.

@@ -18,13 +18,18 @@ class BaseChunker(JSONSerializable):
         the raw data.
         :param src: The data to be handled by the loader. Can be a URL for
         remote sources or local content for local loaders.
+        :param app_id: App id used to generate the doc_id.
         """
         documents = []
-        ids = []
+        chunk_ids = []
         idMap = {}
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]
+        # Prefix app_id in the document id if app_id is not None to
+        # distinguish between different documents stored in the same
+        # elasticsearch or opensearch index
+        doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
         metadatas = []
         for data in data_records:
             content = data["content"]
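The doc_id prefixing is plain string composition. A minimal standalone sketch of the behavior (the helper name prefix_doc_id is illustrative, not part of the commit):

    def prefix_doc_id(doc_id, app_id=None):
        # Prefix the app_id so documents from different apps sharing one
        # Elasticsearch/OpenSearch index get distinct document ids.
        return f"{app_id}--{doc_id}" if app_id is not None else doc_id

    assert prefix_doc_id("123") == "123"
    assert prefix_doc_id("123", app_id="app1") == "app1--123"

Since app_id defaults to None, existing callers keep unprefixed doc_ids, so the change is backwards compatible.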
@@ -41,12 +46,12 @@ class BaseChunker(JSONSerializable):
             chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
             if idMap.get(chunk_id) is None:
                 idMap[chunk_id] = True
-                ids.append(chunk_id)
+                chunk_ids.append(chunk_id)
                 documents.append(chunk)
                 metadatas.append(meta_data)
         return {
             "documents": documents,
-            "ids": ids,
+            "ids": chunk_ids,
             "metadatas": metadatas,
             "doc_id": doc_id,
         }
@@ -20,7 +20,7 @@ class ImagesChunker(BaseChunker):
         )
         super().__init__(image_splitter)

-    def create_chunks(self, loader, src):
+    def create_chunks(self, loader, src, app_id=None):
         """
         Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image

@@ -35,6 +35,7 @@ class ImagesChunker(BaseChunker):
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]
+        doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
         metadatas = []
         for data in data_records:
             meta_data = data["meta_data"]
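ImagesChunker now mirrors BaseChunker: the same optional app_id parameter and the same doc_id prefix. Both chunkers return one dictionary shape; an illustrative sketch for a call with app_id="app1" (placeholder values, not captured output):

    result = chunker.create_chunks(loader, src, app_id="app1")
    # {
    #     "documents": ["<chunk text>", ...],
    #     "ids": ["<sha256 of chunk + url>", ...],
    #     "metadatas": [{"url": "<url>", "doc_id": "app1--<doc_id>", ...}, ...],
    #     "doc_id": "app1--<doc_id>",
    # }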
@@ -268,14 +268,16 @@ class EmbedChain(JSONSerializable):
         elif chunker.data_type.value in [item.value for item in IndirectDataType]:
             # These types have a indirect source reference
             # As long as the reference is the same, they can be updated.
-            existing_embeddings_data = self.db.get(
-                where={
-                    "url": src,
-                },
+            where = {"url": src}
+            if self.config.id is not None:
+                where.update({"app_id": self.config.id})
+
+            existing_embeddings = self.db.get(
+                where=where,
                 limit=1,
             )
-            if len(existing_embeddings_data.get("metadatas", [])) > 0:
-                return existing_embeddings_data["metadatas"][0]["doc_id"]
+            if len(existing_embeddings.get("metadatas", [])) > 0:
+                return existing_embeddings["metadatas"][0]["doc_id"]
             else:
                 return None
         elif chunker.data_type.value in [item.value for item in SpecialDataType]:
@@ -283,14 +285,16 @@ class EmbedChain(JSONSerializable):
             # Through custom logic, they can be attributed to a source and be updated.
             if chunker.data_type == DataType.QNA_PAIR:
                 # QNA_PAIRs update the answer if the question already exists.
-                existing_embeddings_data = self.db.get(
-                    where={
-                        "question": src[0],
-                    },
+                where = {"question": src[0]}
+                if self.config.id is not None:
+                    where.update({"app_id": self.config.id})
+
+                existing_embeddings = self.db.get(
+                    where=where,
                     limit=1,
                 )
-                if len(existing_embeddings_data.get("metadatas", [])) > 0:
-                    return existing_embeddings_data["metadatas"][0]["doc_id"]
+                if len(existing_embeddings.get("metadatas", [])) > 0:
+                    return existing_embeddings["metadatas"][0]["doc_id"]
                 else:
                     return None
             else:
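Both branches of _get_existing_doc_id now build the lookup filter the same way: start from the type-specific key ("url" or "question"), then narrow by app when the app has an id. A minimal sketch of the shared pattern (the helper name build_where is illustrative):

    def build_where(base, app_id=None):
        # Narrow the lookup to one app's documents when an app id is set,
        # mirroring the where.update({"app_id": ...}) calls above.
        where = dict(base)
        if app_id is not None:
            where.update({"app_id": app_id})
        return where

    assert build_where({"url": "https://example.com"}, app_id="app1") == {
        "url": "https://example.com",
        "app_id": "app1",
    }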
@@ -326,9 +330,10 @@ class EmbedChain(JSONSerializable):
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
         """
         existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
+        app_id = self.config.id if self.config is not None else None

         # Create chunks
-        embeddings_data = chunker.create_chunks(loader, src)
+        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id)
         # spread chunking results
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
@@ -345,12 +350,11 @@ class EmbedChain(JSONSerializable):
             self.db.delete({"doc_id": existing_doc_id})

         # get existing ids, and discard doc if any common id exist.
-        where = {"app_id": self.config.id} if self.config.id is not None else {}
-        # where={"url": src}
-        db_result = self.db.get(
-            ids=ids,
-            where=where,  # optional filter
-        )
+        where = {"url": src}
+        if self.config.id is not None:
+            where.update({"metadata.app_id": self.config.id})
+
+        db_result = self.db.get(ids=ids, where=where)  # optional filter
         existing_ids = set(db_result["ids"])

         if len(existing_ids):
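Note the filter key here is the dotted path "metadata.app_id", not the bare "app_id" used in _get_existing_doc_id; the dotted form names the nested metadata field directly, matching the OpenSearch term filter in the next hunk. An illustrative sketch of the resulting filter when config.id is "app1":

    where = {"url": "https://example.com"}
    where.update({"metadata.app_id": "app1"})
    # where == {"url": "https://example.com", "metadata.app_id": "app1"}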
@@ -85,19 +85,29 @@ class OpenSearchDB(BaseVectorDB):
         :return: ids
         :type: Set[str]
         """
+        query = {}
         if ids:
-            query = {"query": {"bool": {"must": [{"ids": {"values": ids}}]}}}
+            query["query"] = {"bool": {"must": [{"ids": {"values": ids}}]}}
         else:
-            query = {"query": {"bool": {"must": []}}}
+            query["query"] = {"bool": {"must": []}}

         if "app_id" in where:
             app_id = where["app_id"]
             query["query"]["bool"]["must"].append({"term": {"metadata.app_id": app_id}})

         # OpenSearch syntax is different from Elasticsearch
-        response = self.client.search(index=self._get_index(), body=query, _source=False, size=limit)
+        response = self.client.search(index=self._get_index(), body=query, _source=True, size=limit)
         docs = response["hits"]["hits"]
         ids = [doc["_id"] for doc in docs]
-        return {"ids": set(ids)}
+        doc_ids = [doc["_source"]["metadata"]["doc_id"] for doc in docs]
+
+        # Result is modified for compatibility with other vector databases
+        # TODO: Add method in vector database to return result in a standard format
+        result = {"ids": ids, "metadatas": []}
+
+        for doc_id in doc_ids:
+            result["metadatas"].append({"doc_id": doc_id})
+
+        return result

     def add(
         self, embeddings: List[str], documents: List[str], metadatas: List[object], ids: List[str], skip_embedding: bool
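For a call such as get(ids=[...], where={"app_id": "app1"}, limit=1), the body assembled above comes out as follows (a sketch of the constructed query dict, with a placeholder chunk id):

    query = {
        "query": {
            "bool": {
                "must": [
                    {"ids": {"values": ["<chunk-id>"]}},
                    {"term": {"metadata.app_id": "app1"}},
                ]
            }
        }
    }

The search now runs with _source=True so that metadata.doc_id can be read back from each hit and returned in the metadatas list.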
@@ -204,6 +214,14 @@ class OpenSearchDB(BaseVectorDB):
         # delete index in Es
         self.client.indices.delete(index=self._get_index())

+    def delete(self, where):
+        """Deletes a document from the OpenSearch index"""
+        if "doc_id" not in where:
+            raise ValueError("doc_id is required to delete a document")
+
+        query = {"query": {"bool": {"must": [{"term": {"metadata.doc_id": where["doc_id"]}}]}}}
+        self.client.delete_by_query(index=self._get_index(), body=query)
+
     def _get_index(self) -> str:
         """Get the OpenSearch index for a collection

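The new delete method removes every chunk whose metadata.doc_id matches, using OpenSearch's delete_by_query API. A hedged usage sketch (db stands for an initialized OpenSearchDB instance):

    # Deletes all chunks of one document; raises ValueError if the
    # filter lacks a "doc_id" key.
    db.delete(where={"doc_id": "app1--123"})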
@@ -17,14 +17,15 @@ class TestImageChunker(unittest.TestCase):
         chunker.set_data_type(DataType.IMAGES)

         image_path = "./tmp/image.jpeg"
-        result = chunker.create_chunks(MockLoader(), image_path)
+        app_id = "app1"
+        result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)

         expected_chunks = {
-            "doc_id": "123",
+            "doc_id": f"{app_id}--123",
             "documents": [image_path],
             "embeddings": ["embedding"],
             "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
-            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+            "metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
         }
         self.assertEqual(expected_chunks, result)

@@ -37,14 +38,15 @@ class TestImageChunker(unittest.TestCase):
         chunker.set_data_type(DataType.IMAGES)

         image_path = "./tmp/image.jpeg"
-        result = chunker.create_chunks(MockLoader(), image_path)
+        app_id = "app1"
+        result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)

         expected_chunks = {
-            "doc_id": "123",
+            "doc_id": f"{app_id}--123",
             "documents": [image_path],
             "embeddings": ["embedding"],
             "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
-            "metadatas": [{"data_type": "images", "doc_id": "123", "url": "none"}],
+            "metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
         }
         self.assertEqual(expected_chunks, result)

@@ -1,31 +1,35 @@
 # ruff: noqa: E501

-import unittest
-
 from embedchain.chunkers.text import TextChunker
 from embedchain.config import ChunkerConfig
 from embedchain.models.data_type import DataType


-class TestTextChunker(unittest.TestCase):
-    def test_chunks(self):
+class TestTextChunker:
+    def test_chunks_without_app_id(self):
         """
         Test the chunks generated by TextChunker.
-        # TODO: Not a very precise test.
         """
         chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
         chunker = TextChunker(config=chunker_config)
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)

         result = chunker.create_chunks(MockLoader(), text)

         documents = result["documents"]
-        self.assertGreaterEqual(len(documents), 5)
-        # Additional test cases can be added to cover different scenarios
+        assert len(documents) > 5
+
+    def test_chunks_with_app_id(self):
+        """
+        Test the chunks generated by TextChunker with app_id
+        """
+        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
+        chunker = TextChunker(config=chunker_config)
+        text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
+        chunker.set_data_type(DataType.TEXT)
+        result = chunker.create_chunks(MockLoader(), text)
+        documents = result["documents"]
+        assert len(documents) > 5

     def test_big_chunksize(self):
         """
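This and the following hunks also migrate the test class from unittest to pytest conventions: pytest collects any class named Test* without requiring a base class, and plain assert statements replace the self.assert* helpers. A minimal sketch of the style:

    class TestExample:
        def test_sum(self):
            # pytest's assertion rewriting reports both sides on failure.
            assert sum([1, 2, 3]) == 6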
@@ -36,12 +40,9 @@ class TestTextChunker(unittest.TestCase):
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
-
         result = chunker.create_chunks(MockLoader(), text)
-
         documents = result["documents"]
-
-        self.assertEqual(len(documents), 1)
+        assert len(documents) == 1

     def test_small_chunksize(self):
         """
@@ -53,14 +54,9 @@ class TestTextChunker(unittest.TestCase):
         text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"""
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
-
         result = chunker.create_chunks(MockLoader(), text)
-
         documents = result["documents"]
-        print(documents)
-
-        self.assertEqual(len(documents), len(text))
-
+        assert len(documents) == len(text)

     def test_word_count(self):
         chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
@@ -69,7 +65,7 @@ class TestTextChunker(unittest.TestCase):

         document = ["ab cd", "ef gh"]
         result = chunker.get_word_count(document)
-        self.assertEqual(result, 4)
+        assert result == 4


 class MockLoader: