[Improvements] Package improvements (#993)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2023-12-05 23:42:45 -08:00
committed by GitHub
parent 1d4e00ccef
commit 51b4966801
13 changed files with 96 additions and 40 deletions

View File

@@ -41,7 +41,6 @@ class BaseChunker(JSONSerializable):
url = meta_data["url"] url = meta_data["url"]
chunks = self.get_chunks(content) chunks = self.get_chunks(content)
for chunk in chunks: for chunk in chunks:
chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest() chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id

View File

@@ -1,5 +1,5 @@
from importlib import import_module from importlib import import_module
from typing import Any, Dict from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig from embedchain.config import AddConfig
@@ -16,7 +16,13 @@ class DataFormatter(JSONSerializable):
.add or .add_local method call .add or .add_local method call
""" """
def __init__(self, data_type: DataType, config: AddConfig, kwargs: Dict[str, Any]): def __init__(
self,
data_type: DataType,
config: AddConfig,
loader: Optional[BaseLoader] = None,
chunker: Optional[BaseChunker] = None,
):
""" """
Initialize a dataformatter, set data type and chunker based on datatype. Initialize a dataformatter, set data type and chunker based on datatype.
@@ -25,15 +31,15 @@ class DataFormatter(JSONSerializable):
:param config: AddConfig instance with nested loader and chunker config attributes. :param config: AddConfig instance with nested loader and chunker config attributes.
:type config: AddConfig :type config: AddConfig
""" """
self.loader = self._get_loader(data_type=data_type, config=config.loader, kwargs=kwargs) self.loader = self._get_loader(data_type=data_type, config=config.loader, loader=loader)
self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, kwargs=kwargs) self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, chunker=chunker)
def _lazy_load(self, module_path: str): def _lazy_load(self, module_path: str):
module_path, class_name = module_path.rsplit(".", 1) module_path, class_name = module_path.rsplit(".", 1)
module = import_module(module_path) module = import_module(module_path)
return getattr(module, class_name) return getattr(module, class_name)
def _get_loader(self, data_type: DataType, config: LoaderConfig, kwargs: Dict[str, Any]) -> BaseLoader: def _get_loader(self, data_type: DataType, config: LoaderConfig, loader: Optional[BaseLoader]) -> BaseLoader:
""" """
Returns the appropriate data loader for the given data type. Returns the appropriate data loader for the given data type.
@@ -68,8 +74,8 @@ class DataFormatter(JSONSerializable):
DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader", DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
} }
if data_type == DataType.CUSTOM or ("loader" in kwargs): if data_type == DataType.CUSTOM or loader is not None:
loader_class: type = kwargs.get("loader", None) loader_class: type = loader
if loader_class: if loader_class:
return loader_class return loader_class
elif data_type in loaders: elif data_type in loaders:
@@ -82,7 +88,7 @@ class DataFormatter(JSONSerializable):
check `https://docs.embedchain.ai/data-sources/overview`." check `https://docs.embedchain.ai/data-sources/overview`."
) )
def _get_chunker(self, data_type: DataType, config: ChunkerConfig, kwargs: Dict[str, Any]) -> BaseChunker: def _get_chunker(self, data_type: DataType, config: ChunkerConfig, chunker: Optional[BaseChunker]) -> BaseChunker:
"""Returns the appropriate chunker for the given data type (updated for lazy loading).""" """Returns the appropriate chunker for the given data type (updated for lazy loading)."""
chunker_classes = { chunker_classes = {
DataType.YOUTUBE_VIDEO: "embedchain.chunkers.youtube_video.YoutubeVideoChunker", DataType.YOUTUBE_VIDEO: "embedchain.chunkers.youtube_video.YoutubeVideoChunker",
@@ -108,12 +114,8 @@ class DataFormatter(JSONSerializable):
DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker", DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
} }
if "chunker" in kwargs: if chunker is not None:
chunker_class = kwargs.get("chunker", None) return chunker
if chunker_class:
chunker = chunker_class(config)
chunker.set_data_type(data_type)
return chunker
elif data_type in chunker_classes: elif data_type in chunker_classes:
chunker_class = self._lazy_load(chunker_classes[data_type]) chunker_class = self._lazy_load(chunker_classes[data_type])
chunker = chunker_class(config) chunker = chunker_class(config)

View File

@@ -133,7 +133,9 @@ class EmbedChain(JSONSerializable):
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
config: Optional[AddConfig] = None, config: Optional[AddConfig] = None,
dry_run=False, dry_run=False,
**kwargs: Dict[str, Any], loader: Optional[BaseLoader] = None,
chunker: Optional[BaseChunker] = None,
**kwargs: Optional[Dict[str, Any]],
): ):
""" """
Adds the data from the given URL to the vector db. Adds the data from the given URL to the vector db.
@@ -192,9 +194,9 @@ class EmbedChain(JSONSerializable):
self.user_asks.append([source, data_type.value, metadata]) self.user_asks.append([source, data_type.value, metadata])
data_formatter = DataFormatter(data_type, config, kwargs) data_formatter = DataFormatter(data_type, config, loader, chunker)
documents, metadatas, _ids, new_chunks = self._load_and_embed( documents, metadatas, _ids, new_chunks = self._load_and_embed(
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run, **kwargs
) )
if data_type in {DataType.DOCS_SITE}: if data_type in {DataType.DOCS_SITE}:
self.is_docs_site_instance = True self.is_docs_site_instance = True
@@ -238,7 +240,7 @@ class EmbedChain(JSONSerializable):
data_type: Optional[DataType] = None, data_type: Optional[DataType] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
config: Optional[AddConfig] = None, config: Optional[AddConfig] = None,
**kwargs: Dict[str, Any], **kwargs: Optional[Dict[str, Any]],
): ):
""" """
Adds the data from the given URL to the vector db. Adds the data from the given URL to the vector db.
@@ -269,7 +271,7 @@ class EmbedChain(JSONSerializable):
data_type=data_type, data_type=data_type,
metadata=metadata, metadata=metadata,
config=config, config=config,
kwargs=kwargs, **kwargs,
) )
def _get_existing_doc_id(self, chunker: BaseChunker, src: Any): def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
@@ -338,6 +340,7 @@ class EmbedChain(JSONSerializable):
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
source_hash: Optional[str] = None, source_hash: Optional[str] = None,
dry_run=False, dry_run=False,
**kwargs: Optional[Dict[str, Any]],
): ):
""" """
Loads the data from the given URL, chunks it, and adds it to database. Loads the data from the given URL, chunks it, and adds it to database.
@@ -431,6 +434,7 @@ class EmbedChain(JSONSerializable):
metadatas=metadatas, metadatas=metadatas,
ids=ids, ids=ids,
skip_embedding=(chunker.data_type == DataType.IMAGES), skip_embedding=(chunker.data_type == DataType.IMAGES),
**kwargs,
) )
count_new_chunks = self.db.count() - chunks_before_addition count_new_chunks = self.db.count() - chunks_before_addition
@@ -448,7 +452,12 @@ class EmbedChain(JSONSerializable):
] ]
def _retrieve_from_database( def _retrieve_from_database(
self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False self,
input_query: str,
config: Optional[BaseLlmConfig] = None,
where=None,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
Queries the vector database based on the given input query. Queries the vector database based on the given input query.
@@ -492,6 +501,7 @@ class EmbedChain(JSONSerializable):
where=where, where=where,
skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"), skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
citations=citations, citations=citations,
**kwargs,
) )
return contexts return contexts
@@ -526,9 +536,13 @@ class EmbedChain(JSONSerializable):
or the dry run result or the dry run result
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]] :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
""" """
citations = kwargs.get("citations", False) if "citations" in kwargs:
citations = kwargs.pop("citations")
else:
citations = False
contexts = self._retrieve_from_database( contexts = self._retrieve_from_database(
input_query=input_query, config=config, where=where, citations=citations input_query=input_query, config=config, where=where, citations=citations, **kwargs
) )
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple): if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts)) contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
@@ -579,9 +593,13 @@ class EmbedChain(JSONSerializable):
or the dry run result or the dry run result
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]] :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
""" """
citations = kwargs.get("citations", False) if "citations" in kwargs:
citations = kwargs.pop("citations")
else:
citations = False
contexts = self._retrieve_from_database( contexts = self._retrieve_from_database(
input_query=input_query, config=config, where=where, citations=citations input_query=input_query, config=config, where=where, citations=citations, **kwargs
) )
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple): if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts)) contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))

View File

@@ -196,7 +196,6 @@ class GithubLoader(BaseLoader):
logging.info(f"Total repos found: {repos_results.totalCount}") logging.info(f"Total repos found: {repos_results.totalCount}")
for repo_result in tqdm(repos_results, total=repos_results.totalCount, desc="Loading discussions from github"): for repo_result in tqdm(repos_results, total=repos_results.totalCount, desc="Loading discussions from github"):
teams = repo_result.get_teams() teams = repo_result.get_teams()
# import pdb; pdb.set_trace()
for team in teams: for team in teams:
team_discussions = team.get_discussions() team_discussions = team.get_discussions()
for discussion in team_discussions: for discussion in team_discussions:

View File

@@ -1,3 +1,4 @@
import itertools
import json import json
import logging import logging
import os import os
@@ -6,6 +7,7 @@ import string
from typing import Any from typing import Any
from schema import Optional, Or, Schema from schema import Optional, Or, Schema
from tqdm import tqdm
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType
@@ -422,3 +424,16 @@ def validate_config(config_data):
) )
return schema.validate(config_data) return schema.validate(config_data)
def chunks(iterable, batch_size=100, desc="Processing chunks"):
"""A helper function to break an iterable into chunks of size batch_size."""
it = iter(iterable)
total_size = len(iterable)
with tqdm(total=total_size, desc=desc, unit="batch") as pbar:
chunk = tuple(itertools.islice(it, batch_size))
while chunk:
yield chunk
pbar.update(len(chunk))
chunk = tuple(itertools.islice(it, batch_size))

View File

@@ -133,6 +133,7 @@ class ChromaDB(BaseVectorDB):
metadatas: List[object], metadatas: List[object],
ids: List[str], ids: List[str],
skip_embedding: bool, skip_embedding: bool,
**kwargs: Optional[Dict[str, Any]],
) -> Any: ) -> Any:
""" """
Add vectors to chroma database Add vectors to chroma database
@@ -198,6 +199,7 @@ class ChromaDB(BaseVectorDB):
where: Dict[str, any], where: Dict[str, any],
skip_embedding: bool, skip_embedding: bool,
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
Query contents from vector database based on vector similarity Query contents from vector database based on vector similarity
@@ -225,6 +227,7 @@ class ChromaDB(BaseVectorDB):
], ],
n_results=n_results, n_results=n_results,
where=self._generate_where_clause(where), where=self._generate_where_clause(where),
**kwargs,
) )
else: else:
result = self.collection.query( result = self.collection.query(
@@ -233,6 +236,7 @@ class ChromaDB(BaseVectorDB):
], ],
n_results=n_results, n_results=n_results,
where=self._generate_where_clause(where), where=self._generate_where_clause(where),
**kwargs,
) )
except InvalidDimensionException as e: except InvalidDimensionException as e:
raise InvalidDimensionException( raise InvalidDimensionException(

View File

@@ -105,6 +105,7 @@ class ElasticsearchDB(BaseVectorDB):
metadatas: List[object], metadatas: List[object],
ids: List[str], ids: List[str],
skip_embedding: bool, skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
) -> Any: ) -> Any:
""" """
add data in vector database add data in vector database
@@ -142,6 +143,7 @@ class ElasticsearchDB(BaseVectorDB):
where: Dict[str, any], where: Dict[str, any],
skip_embedding: bool, skip_embedding: bool,
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
query contents from vector data base based on vector similarity query contents from vector data base based on vector similarity

View File

@@ -1,6 +1,6 @@
import logging import logging
import time import time
from typing import Dict, List, Optional, Set, Tuple, Union from typing import Any, Dict, List, Optional, Set, Tuple, Union
from tqdm import tqdm from tqdm import tqdm
@@ -121,6 +121,7 @@ class OpenSearchDB(BaseVectorDB):
metadatas: List[object], metadatas: List[object],
ids: List[str], ids: List[str],
skip_embedding: bool, skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
): ):
"""Add data in vector database. """Add data in vector database.
@@ -154,7 +155,7 @@ class OpenSearchDB(BaseVectorDB):
] ]
# Perform bulk operation # Perform bulk operation
bulk(self.client, batch_entries) bulk(self.client, batch_entries, **kwargs)
self.client.indices.refresh(index=self._get_index()) self.client.indices.refresh(index=self._get_index())
# Sleep to avoid rate limiting # Sleep to avoid rate limiting
@@ -167,6 +168,7 @@ class OpenSearchDB(BaseVectorDB):
where: Dict[str, any], where: Dict[str, any],
skip_embedding: bool, skip_embedding: bool,
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
query contents from vector data base based on vector similarity query contents from vector data base based on vector similarity
@@ -209,6 +211,7 @@ class OpenSearchDB(BaseVectorDB):
metadata_field="metadata", metadata_field="metadata",
pre_filter=pre_filter, pre_filter=pre_filter,
k=n_results, k=n_results,
**kwargs,
) )
contexts = [] contexts = []

View File

@@ -10,6 +10,7 @@ except ImportError:
from embedchain.config.vectordb.pinecone import PineconeDBConfig from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils import chunks
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB
@@ -92,6 +93,7 @@ class PineconeDB(BaseVectorDB):
metadatas: List[object], metadatas: List[object],
ids: List[str], ids: List[str],
skip_embedding: bool, skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
): ):
"""add data in vector database """add data in vector database
@@ -104,7 +106,6 @@ class PineconeDB(BaseVectorDB):
""" """
docs = [] docs = []
print("Adding documents to Pinecone...") print("Adding documents to Pinecone...")
embeddings = self.embedder.embedding_fn(documents) embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings): for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
docs.append( docs.append(
@@ -115,8 +116,8 @@ class PineconeDB(BaseVectorDB):
} }
) )
for i in range(0, len(docs), self.BATCH_SIZE): for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches..."):
self.client.upsert(docs[i : i + self.BATCH_SIZE]) self.client.upsert(chunk, **kwargs)
def query( def query(
self, self,
@@ -125,6 +126,7 @@ class PineconeDB(BaseVectorDB):
where: Dict[str, any], where: Dict[str, any],
skip_embedding: bool, skip_embedding: bool,
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
@@ -146,7 +148,7 @@ class PineconeDB(BaseVectorDB):
query_vector = self.embedder.embedding_fn([input_query])[0] query_vector = self.embedder.embedding_fn([input_query])[0]
else: else:
query_vector = input_query query_vector = input_query
data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True) data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True, **kwargs)
contexts = [] contexts = []
for doc in data["matches"]: for doc in data["matches"]:
metadata = doc["metadata"] metadata = doc["metadata"]

View File

@@ -1,7 +1,7 @@
import copy import copy
import os import os
import uuid import uuid
from typing import Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
try: try:
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
@@ -127,6 +127,7 @@ class QdrantDB(BaseVectorDB):
metadatas: List[object], metadatas: List[object],
ids: List[str], ids: List[str],
skip_embedding: bool, skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
): ):
"""add data in vector database """add data in vector database
:param embeddings: list of embeddings for the corresponding documents to be added :param embeddings: list of embeddings for the corresponding documents to be added
@@ -158,6 +159,7 @@ class QdrantDB(BaseVectorDB):
payloads=payloads[i : i + self.BATCH_SIZE], payloads=payloads[i : i + self.BATCH_SIZE],
vectors=embeddings[i : i + self.BATCH_SIZE], vectors=embeddings[i : i + self.BATCH_SIZE],
), ),
**kwargs,
) )
def query( def query(
@@ -167,6 +169,7 @@ class QdrantDB(BaseVectorDB):
where: Dict[str, any], where: Dict[str, any],
skip_embedding: bool, skip_embedding: bool,
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
@@ -208,6 +211,7 @@ class QdrantDB(BaseVectorDB):
query_filter=models.Filter(must=qdrant_must_filters), query_filter=models.Filter(must=qdrant_must_filters),
query_vector=query_vector, query_vector=query_vector,
limit=n_results, limit=n_results,
**kwargs,
) )
contexts = [] contexts = []

View File

@@ -1,6 +1,6 @@
import copy import copy
import os import os
from typing import Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
try: try:
import weaviate import weaviate
@@ -158,6 +158,7 @@ class WeaviateDB(BaseVectorDB):
metadatas: List[object], metadatas: List[object],
ids: List[str], ids: List[str],
skip_embedding: bool, skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
): ):
"""add data in vector database """add data in vector database
:param embeddings: list of embeddings for the corresponding documents to be added :param embeddings: list of embeddings for the corresponding documents to be added
@@ -192,7 +193,9 @@ class WeaviateDB(BaseVectorDB):
class_name=self.index_name + "_metadata", class_name=self.index_name + "_metadata",
vector=embedding, vector=embedding,
) )
batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata") batch.add_reference(
obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata", **kwargs
)
def query( def query(
self, self,
@@ -201,6 +204,7 @@ class WeaviateDB(BaseVectorDB):
where: Dict[str, any], where: Dict[str, any],
skip_embedding: bool, skip_embedding: bool,
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from embedchain.config import ZillizDBConfig from embedchain.config import ZillizDBConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@@ -113,6 +113,7 @@ class ZillizVectorDB(BaseVectorDB):
metadatas: List[object], metadatas: List[object],
ids: List[str], ids: List[str],
skip_embedding: bool, skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
): ):
"""Add to database""" """Add to database"""
if not skip_embedding: if not skip_embedding:
@@ -120,7 +121,7 @@ class ZillizVectorDB(BaseVectorDB):
for id, doc, metadata, embedding in zip(ids, documents, metadatas, embeddings): for id, doc, metadata, embedding in zip(ids, documents, metadatas, embeddings):
data = {**metadata, "id": id, "text": doc, "embeddings": embedding} data = {**metadata, "id": id, "text": doc, "embeddings": embedding}
self.client.insert(collection_name=self.config.collection_name, data=data) self.client.insert(collection_name=self.config.collection_name, data=data, **kwargs)
self.collection.load() self.collection.load()
self.collection.flush() self.collection.flush()
@@ -133,6 +134,7 @@ class ZillizVectorDB(BaseVectorDB):
where: Dict[str, any], where: Dict[str, any],
skip_embedding: bool, skip_embedding: bool,
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
Query contents from vector data base based on vector similarity Query contents from vector data base based on vector similarity
@@ -165,6 +167,7 @@ class ZillizVectorDB(BaseVectorDB):
data=query_vector, data=query_vector,
limit=n_results, limit=n_results,
output_fields=output_fields, output_fields=output_fields,
**kwargs,
) )
else: else:
@@ -176,6 +179,7 @@ class ZillizVectorDB(BaseVectorDB):
data=[query_vector], data=[query_vector],
limit=n_results, limit=n_results,
output_fields=output_fields, output_fields=output_fields,
**kwargs,
) )
contexts = [] contexts = []

View File

@@ -57,11 +57,11 @@ class TestPinecone:
db.add(vectors, documents, metadatas, ids, True) db.add(vectors, documents, metadatas, ids, True)
expected_pinecone_upsert_args = [ expected_pinecone_upsert_args = [
{"id": "doc1", "metadata": {"text": "This is a document."}, "values": [0, 0, 0]}, {"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},
{"id": "doc2", "metadata": {"text": "This is another document."}, "values": [1, 1, 1]}, {"id": "doc2", "values": [1, 1, 1], "metadata": {"text": "This is another document."}},
] ]
# Assert that the Pinecone client was called to upsert the documents # Assert that the Pinecone client was called to upsert the documents
pinecone_client_mock.upsert.assert_called_once_with(expected_pinecone_upsert_args) pinecone_client_mock.upsert.assert_called_once_with(tuple(expected_pinecone_upsert_args))
@patch("embedchain.vectordb.pinecone.pinecone") @patch("embedchain.vectordb.pinecone.pinecone")
def test_query_documents(self, pinecone_mock): def test_query_documents(self, pinecone_mock):