[Bug fix] Fix embedding issue for opensearch and some other vector databases (#1163)
This commit is contained in:
@@ -27,7 +27,7 @@ class BaseChunker(JSONSerializable):
|
|||||||
chunk_ids = []
|
chunk_ids = []
|
||||||
id_map = {}
|
id_map = {}
|
||||||
min_chunk_size = config.min_chunk_size if config is not None else 1
|
min_chunk_size = config.min_chunk_size if config is not None else 1
|
||||||
logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
|
logging.info(f"Skipping chunks smaller than {min_chunk_size} characters")
|
||||||
data_result = loader.load_data(src)
|
data_result = loader.load_data(src)
|
||||||
data_records = data_result["data"]
|
data_records = data_result["data"]
|
||||||
doc_id = data_result["doc_id"]
|
doc_id = data_result["doc_id"]
|
||||||
|
|||||||
@@ -369,7 +369,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
metadatas = embeddings_data["metadatas"]
|
metadatas = embeddings_data["metadatas"]
|
||||||
ids = embeddings_data["ids"]
|
ids = embeddings_data["ids"]
|
||||||
new_doc_id = embeddings_data["doc_id"]
|
new_doc_id = embeddings_data["doc_id"]
|
||||||
embeddings = embeddings_data.get("embeddings")
|
|
||||||
if existing_doc_id and existing_doc_id == new_doc_id:
|
if existing_doc_id and existing_doc_id == new_doc_id:
|
||||||
print("Doc content has not changed. Skipping creating chunks and embeddings")
|
print("Doc content has not changed. Skipping creating chunks and embeddings")
|
||||||
return [], [], [], 0
|
return [], [], [], 0
|
||||||
@@ -433,13 +433,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
# Count before, to calculate a delta in the end.
|
# Count before, to calculate a delta in the end.
|
||||||
chunks_before_addition = self.db.count()
|
chunks_before_addition = self.db.count()
|
||||||
|
|
||||||
self.db.add(
|
self.db.add(documents=documents, metadatas=metadatas, ids=ids, **kwargs)
|
||||||
embeddings=embeddings,
|
|
||||||
documents=documents,
|
|
||||||
metadatas=metadatas,
|
|
||||||
ids=ids,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
count_new_chunks = self.db.count() - chunks_before_addition
|
count_new_chunks = self.db.count() - chunks_before_addition
|
||||||
|
|
||||||
print(f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}")
|
print(f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}")
|
||||||
|
|||||||
@@ -129,17 +129,13 @@ class ChromaDB(BaseVectorDB):
|
|||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
embeddings: list[list[float]],
|
|
||||||
documents: list[str],
|
documents: list[str],
|
||||||
metadatas: list[object],
|
metadatas: list[object],
|
||||||
ids: list[str],
|
ids: list[str],
|
||||||
**kwargs: Optional[dict[str, Any]],
|
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Add vectors to chroma database
|
Add vectors to chroma database
|
||||||
|
|
||||||
:param embeddings: list of embeddings to add
|
|
||||||
:type embeddings: list[list[str]]
|
|
||||||
:param documents: Documents
|
:param documents: Documents
|
||||||
:type documents: list[str]
|
:type documents: list[str]
|
||||||
:param metadatas: Metadatas
|
:param metadatas: Metadatas
|
||||||
|
|||||||
@@ -110,7 +110,6 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
embeddings: list[list[float]],
|
|
||||||
documents: list[str],
|
documents: list[str],
|
||||||
metadatas: list[object],
|
metadatas: list[object],
|
||||||
ids: list[str],
|
ids: list[str],
|
||||||
@@ -118,8 +117,6 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
add data in vector database
|
add data in vector database
|
||||||
:param embeddings: list of embeddings to add
|
|
||||||
:type embeddings: list[list[str]]
|
|
||||||
:param documents: list of texts to add
|
:param documents: list of texts to add
|
||||||
:type documents: list[str]
|
:type documents: list[str]
|
||||||
:param metadatas: list of metadata associated with docs
|
:param metadatas: list of metadata associated with docs
|
||||||
|
|||||||
@@ -114,22 +114,10 @@ class OpenSearchDB(BaseVectorDB):
|
|||||||
result["metadatas"].append({"doc_id": doc_id})
|
result["metadatas"].append({"doc_id": doc_id})
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def add(
|
def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]):
|
||||||
self,
|
"""Adds documents to the opensearch index"""
|
||||||
embeddings: list[list[str]],
|
|
||||||
documents: list[str],
|
|
||||||
metadatas: list[object],
|
|
||||||
ids: list[str],
|
|
||||||
**kwargs: Optional[dict[str, any]],
|
|
||||||
):
|
|
||||||
"""Add data in vector database.
|
|
||||||
|
|
||||||
Args:
|
embeddings = self.embedder.embedding_fn(documents)
|
||||||
embeddings (list[list[str]]): list of embeddings to add.
|
|
||||||
documents (list[str]): list of texts to add.
|
|
||||||
metadatas (list[object]): list of metadata associated with docs.
|
|
||||||
ids (list[str]): IDs of docs.
|
|
||||||
"""
|
|
||||||
for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"):
|
for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"):
|
||||||
batch_end = batch_start + self.BATCH_SIZE
|
batch_end = batch_start + self.BATCH_SIZE
|
||||||
batch_documents = documents[batch_start:batch_end]
|
batch_documents = documents[batch_start:batch_end]
|
||||||
|
|||||||
@@ -88,7 +88,6 @@ class PineconeDB(BaseVectorDB):
|
|||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
embeddings: list[list[float]],
|
|
||||||
documents: list[str],
|
documents: list[str],
|
||||||
metadatas: list[object],
|
metadatas: list[object],
|
||||||
ids: list[str],
|
ids: list[str],
|
||||||
|
|||||||
@@ -122,15 +122,12 @@ class QdrantDB(BaseVectorDB):
|
|||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
embeddings: list[list[float]],
|
|
||||||
documents: list[str],
|
documents: list[str],
|
||||||
metadatas: list[object],
|
metadatas: list[object],
|
||||||
ids: list[str],
|
ids: list[str],
|
||||||
**kwargs: Optional[dict[str, any]],
|
**kwargs: Optional[dict[str, any]],
|
||||||
):
|
):
|
||||||
"""add data in vector database
|
"""add data in vector database
|
||||||
:param embeddings: list of embeddings for the corresponding documents to be added
|
|
||||||
:type documents: list[list[float]]
|
|
||||||
:param documents: list of texts to add
|
:param documents: list of texts to add
|
||||||
:type documents: list[str]
|
:type documents: list[str]
|
||||||
:param metadatas: list of metadata associated with docs
|
:param metadatas: list of metadata associated with docs
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
from typing import Any, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import weaviate
|
import weaviate
|
||||||
@@ -151,17 +151,8 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
|
|
||||||
return {"ids": existing_ids}
|
return {"ids": existing_ids}
|
||||||
|
|
||||||
def add(
|
def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]):
|
||||||
self,
|
|
||||||
embeddings: list[list[float]],
|
|
||||||
documents: list[str],
|
|
||||||
metadatas: list[object],
|
|
||||||
ids: list[str],
|
|
||||||
**kwargs: Optional[dict[str, any]],
|
|
||||||
):
|
|
||||||
"""add data in vector database
|
"""add data in vector database
|
||||||
:param embeddings: list of embeddings for the corresponding documents to be added
|
|
||||||
:type documents: list[list[float]]
|
|
||||||
:param documents: list of texts to add
|
:param documents: list of texts to add
|
||||||
:type documents: list[str]
|
:type documents: list[str]
|
||||||
:param metadatas: list of metadata associated with docs
|
:param metadatas: list of metadata associated with docs
|
||||||
@@ -191,12 +182,7 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
self,
|
self, input_query: list[str], n_results: int, where: dict[str, any], citations: bool = False
|
||||||
input_query: list[str],
|
|
||||||
n_results: int,
|
|
||||||
where: dict[str, any],
|
|
||||||
citations: bool = False,
|
|
||||||
**kwargs: Optional[dict[str, Any]],
|
|
||||||
) -> Union[list[tuple[str, dict]], list[str]]:
|
) -> Union[list[tuple[str, dict]], list[str]]:
|
||||||
"""
|
"""
|
||||||
query contents from vector database based on vector similarity
|
query contents from vector database based on vector similarity
|
||||||
|
|||||||
@@ -108,7 +108,6 @@ class ZillizVectorDB(BaseVectorDB):
|
|||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
embeddings: list[list[float]],
|
|
||||||
documents: list[str],
|
documents: list[str],
|
||||||
metadatas: list[object],
|
metadatas: list[object],
|
||||||
ids: list[str],
|
ids: list[str],
|
||||||
|
|||||||
@@ -28,14 +28,13 @@ class TestEsDB(unittest.TestCase):
|
|||||||
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
|
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
|
||||||
self.assertEqual(self.db.client, mock_client.return_value)
|
self.assertEqual(self.db.client, mock_client.return_value)
|
||||||
|
|
||||||
# Create some dummy data.
|
# Create some dummy data
|
||||||
embeddings = [[1, 2, 3], [4, 5, 6]]
|
|
||||||
documents = ["This is a document.", "This is another document."]
|
documents = ["This is a document.", "This is another document."]
|
||||||
metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
|
metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
|
||||||
ids = ["doc_1", "doc_2"]
|
ids = ["doc_1", "doc_2"]
|
||||||
|
|
||||||
# Add the data to the database.
|
# Add the data to the database.
|
||||||
self.db.add(embeddings, documents, metadatas, ids)
|
self.db.add(documents, metadatas, ids)
|
||||||
|
|
||||||
search_response = {
|
search_response = {
|
||||||
"hits": {
|
"hits": {
|
||||||
|
|||||||
@@ -43,8 +43,8 @@ class TestPinecone:
|
|||||||
embedding_function = mock.Mock()
|
embedding_function = mock.Mock()
|
||||||
base_embedder = BaseEmbedder()
|
base_embedder = BaseEmbedder()
|
||||||
base_embedder.set_embedding_fn(embedding_function)
|
base_embedder.set_embedding_fn(embedding_function)
|
||||||
vectors = [[0, 0, 0], [1, 1, 1]]
|
embedding_function.return_value = [[0, 0, 0], [1, 1, 1]]
|
||||||
embedding_function.return_value = vectors
|
|
||||||
# Create a PineconeDb instance
|
# Create a PineconeDb instance
|
||||||
db = PineconeDB()
|
db = PineconeDB()
|
||||||
app_config = AppConfig(collect_metrics=False)
|
app_config = AppConfig(collect_metrics=False)
|
||||||
@@ -54,7 +54,7 @@ class TestPinecone:
|
|||||||
documents = ["This is a document.", "This is another document."]
|
documents = ["This is a document.", "This is another document."]
|
||||||
metadatas = [{}, {}]
|
metadatas = [{}, {}]
|
||||||
ids = ["doc1", "doc2"]
|
ids = ["doc1", "doc2"]
|
||||||
db.add(vectors, documents, metadatas, ids)
|
db.add(documents, metadatas, ids)
|
||||||
|
|
||||||
expected_pinecone_upsert_args = [
|
expected_pinecone_upsert_args = [
|
||||||
{"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},
|
{"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},
|
||||||
|
|||||||
@@ -75,11 +75,10 @@ class TestQdrantDB(unittest.TestCase):
|
|||||||
app_config = AppConfig(collect_metrics=False)
|
app_config = AppConfig(collect_metrics=False)
|
||||||
App(config=app_config, db=db, embedding_model=embedder)
|
App(config=app_config, db=db, embedding_model=embedder)
|
||||||
|
|
||||||
embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
|
|
||||||
documents = ["This is a test document.", "This is another test document."]
|
documents = ["This is a test document.", "This is another test document."]
|
||||||
metadatas = [{}, {}]
|
metadatas = [{}, {}]
|
||||||
ids = ["123", "456"]
|
ids = ["123", "456"]
|
||||||
db.add(embeddings, documents, metadatas, ids)
|
db.add(documents, metadatas, ids)
|
||||||
qdrant_client_mock.return_value.upsert.assert_called_once_with(
|
qdrant_client_mock.return_value.upsert.assert_called_once_with(
|
||||||
collection_name="embedchain-store-1526",
|
collection_name="embedchain-store-1526",
|
||||||
points=Batch(
|
points=Batch(
|
||||||
@@ -96,7 +95,7 @@ class TestQdrantDB(unittest.TestCase):
|
|||||||
"metadata": {"text": "This is another test document."},
|
"metadata": {"text": "This is another test document."},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
vectors=embeddings,
|
vectors=[[1, 2, 3], [4, 5, 6]],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
weaviate_client_schema_mock.exists.return_value = False
|
weaviate_client_schema_mock.exists.return_value = False
|
||||||
# Set the embedder
|
# Set the embedder
|
||||||
embedder = BaseEmbedder()
|
embedder = BaseEmbedder()
|
||||||
embedder.set_vector_dimension(1526)
|
embedder.set_vector_dimension(1536)
|
||||||
embedder.set_embedding_fn(mock_embedding_fn)
|
embedder.set_embedding_fn(mock_embedding_fn)
|
||||||
|
|
||||||
# Create a Weaviate instance
|
# Create a Weaviate instance
|
||||||
@@ -40,7 +40,7 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
expected_class_obj = {
|
expected_class_obj = {
|
||||||
"classes": [
|
"classes": [
|
||||||
{
|
{
|
||||||
"class": "Embedchain_store_1526",
|
"class": "Embedchain_store_1536",
|
||||||
"vectorizer": "none",
|
"vectorizer": "none",
|
||||||
"properties": [
|
"properties": [
|
||||||
{
|
{
|
||||||
@@ -53,12 +53,12 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "metadata",
|
"name": "metadata",
|
||||||
"dataType": ["Embedchain_store_1526_metadata"],
|
"dataType": ["Embedchain_store_1536_metadata"],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"class": "Embedchain_store_1526_metadata",
|
"class": "Embedchain_store_1536_metadata",
|
||||||
"vectorizer": "none",
|
"vectorizer": "none",
|
||||||
"properties": [
|
"properties": [
|
||||||
{
|
{
|
||||||
@@ -88,7 +88,7 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
|
|
||||||
# Assert that the Weaviate client was initialized
|
# Assert that the Weaviate client was initialized
|
||||||
weaviate_mock.Client.assert_called_once()
|
weaviate_mock.Client.assert_called_once()
|
||||||
self.assertEqual(db.index_name, "Embedchain_store_1526")
|
self.assertEqual(db.index_name, "Embedchain_store_1536")
|
||||||
weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj)
|
weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj)
|
||||||
|
|
||||||
@patch("embedchain.vectordb.weaviate.weaviate")
|
@patch("embedchain.vectordb.weaviate.weaviate")
|
||||||
@@ -97,7 +97,7 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
weaviate_client_mock = weaviate_mock.Client.return_value
|
weaviate_client_mock = weaviate_mock.Client.return_value
|
||||||
|
|
||||||
embedder = BaseEmbedder()
|
embedder = BaseEmbedder()
|
||||||
embedder.set_vector_dimension(1526)
|
embedder.set_vector_dimension(1536)
|
||||||
embedder.set_embedding_fn(mock_embedding_fn)
|
embedder.set_embedding_fn(mock_embedding_fn)
|
||||||
|
|
||||||
# Create a Weaviate instance
|
# Create a Weaviate instance
|
||||||
@@ -117,7 +117,7 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
|
|
||||||
# Set the embedder
|
# Set the embedder
|
||||||
embedder = BaseEmbedder()
|
embedder = BaseEmbedder()
|
||||||
embedder.set_vector_dimension(1526)
|
embedder.set_vector_dimension(1536)
|
||||||
embedder.set_embedding_fn(mock_embedding_fn)
|
embedder.set_embedding_fn(mock_embedding_fn)
|
||||||
|
|
||||||
# Create a Weaviate instance
|
# Create a Weaviate instance
|
||||||
@@ -126,30 +126,21 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
App(config=app_config, db=db, embedding_model=embedder)
|
App(config=app_config, db=db, embedding_model=embedder)
|
||||||
db.BATCH_SIZE = 1
|
db.BATCH_SIZE = 1
|
||||||
|
|
||||||
embeddings = [[1, 2, 3], [4, 5, 6]]
|
documents = ["This is test document"]
|
||||||
documents = ["This is a test document.", "This is another test document."]
|
metadatas = [None]
|
||||||
metadatas = [None, None]
|
ids = ["id_1"]
|
||||||
ids = ["123", "456"]
|
db.add(documents, metadatas, ids)
|
||||||
db.add(embeddings, documents, metadatas, ids)
|
|
||||||
|
|
||||||
# Check if the document was added to the database.
|
# Check if the document was added to the database.
|
||||||
weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
|
weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
|
||||||
weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
|
weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
|
||||||
data_object={"text": documents[0]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[0]
|
data_object={"text": documents[0]}, class_name="Embedchain_store_1536_metadata", vector=[1, 2, 3]
|
||||||
)
|
|
||||||
weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
|
|
||||||
data_object={"text": documents[1]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[1]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
|
weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
|
||||||
data_object={"identifier": ids[0], "text": documents[0]},
|
data_object={"text": documents[0]},
|
||||||
class_name="Embedchain_store_1526",
|
class_name="Embedchain_store_1536_metadata",
|
||||||
vector=embeddings[0],
|
vector=[1, 2, 3],
|
||||||
)
|
|
||||||
weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
|
|
||||||
data_object={"identifier": ids[1], "text": documents[1]},
|
|
||||||
class_name="Embedchain_store_1526",
|
|
||||||
vector=embeddings[1],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("embedchain.vectordb.weaviate.weaviate")
|
@patch("embedchain.vectordb.weaviate.weaviate")
|
||||||
@@ -161,7 +152,7 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
|
|
||||||
# Set the embedder
|
# Set the embedder
|
||||||
embedder = BaseEmbedder()
|
embedder = BaseEmbedder()
|
||||||
embedder.set_vector_dimension(1526)
|
embedder.set_vector_dimension(1536)
|
||||||
embedder.set_embedding_fn(mock_embedding_fn)
|
embedder.set_embedding_fn(mock_embedding_fn)
|
||||||
|
|
||||||
# Create a Weaviate instance
|
# Create a Weaviate instance
|
||||||
@@ -172,7 +163,7 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
# Query for the document.
|
# Query for the document.
|
||||||
db.query(input_query=["This is a test document."], n_results=1, where={})
|
db.query(input_query=["This is a test document."], n_results=1, where={})
|
||||||
|
|
||||||
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
|
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
|
||||||
weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
|
weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
|
||||||
|
|
||||||
@patch("embedchain.vectordb.weaviate.weaviate")
|
@patch("embedchain.vectordb.weaviate.weaviate")
|
||||||
@@ -185,7 +176,7 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
|
|
||||||
# Set the embedder
|
# Set the embedder
|
||||||
embedder = BaseEmbedder()
|
embedder = BaseEmbedder()
|
||||||
embedder.set_vector_dimension(1526)
|
embedder.set_vector_dimension(1536)
|
||||||
embedder.set_embedding_fn(mock_embedding_fn)
|
embedder.set_embedding_fn(mock_embedding_fn)
|
||||||
|
|
||||||
# Create a Weaviate instance
|
# Create a Weaviate instance
|
||||||
@@ -196,9 +187,9 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
# Query for the document.
|
# Query for the document.
|
||||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
|
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
|
||||||
|
|
||||||
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
|
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
|
||||||
weaviate_client_query_get_mock.with_where.assert_called_once_with(
|
weaviate_client_query_get_mock.with_where.assert_called_once_with(
|
||||||
{"operator": "Equal", "path": ["metadata", "Embedchain_store_1526_metadata", "doc_id"], "valueText": "123"}
|
{"operator": "Equal", "path": ["metadata", "Embedchain_store_1536_metadata", "doc_id"], "valueText": "123"}
|
||||||
)
|
)
|
||||||
weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
|
weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
|
||||||
|
|
||||||
@@ -210,7 +201,7 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
|
|
||||||
# Set the embedder
|
# Set the embedder
|
||||||
embedder = BaseEmbedder()
|
embedder = BaseEmbedder()
|
||||||
embedder.set_vector_dimension(1526)
|
embedder.set_vector_dimension(1536)
|
||||||
embedder.set_embedding_fn(mock_embedding_fn)
|
embedder.set_embedding_fn(mock_embedding_fn)
|
||||||
|
|
||||||
# Create a Weaviate instance
|
# Create a Weaviate instance
|
||||||
@@ -222,7 +213,7 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
db.reset()
|
db.reset()
|
||||||
|
|
||||||
weaviate_client_batch_mock.delete_objects.assert_called_once_with(
|
weaviate_client_batch_mock.delete_objects.assert_called_once_with(
|
||||||
"Embedchain_store_1526", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
|
"Embedchain_store_1536", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("embedchain.vectordb.weaviate.weaviate")
|
@patch("embedchain.vectordb.weaviate.weaviate")
|
||||||
@@ -233,7 +224,7 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
|
|
||||||
# Set the embedder
|
# Set the embedder
|
||||||
embedder = BaseEmbedder()
|
embedder = BaseEmbedder()
|
||||||
embedder.set_vector_dimension(1526)
|
embedder.set_vector_dimension(1536)
|
||||||
embedder.set_embedding_fn(mock_embedding_fn)
|
embedder.set_embedding_fn(mock_embedding_fn)
|
||||||
|
|
||||||
# Create a Weaviate instance
|
# Create a Weaviate instance
|
||||||
@@ -244,4 +235,4 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
# Reset the database.
|
# Reset the database.
|
||||||
db.count()
|
db.count()
|
||||||
|
|
||||||
weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1526")
|
weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1536")
|
||||||
|
|||||||
Reference in New Issue
Block a user