[Improvements] Package improvements (#993)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2023-12-05 23:42:45 -08:00
committed by GitHub
parent 1d4e00ccef
commit 51b4966801
13 changed files with 96 additions and 40 deletions

View File

@@ -133,6 +133,7 @@ class ChromaDB(BaseVectorDB):
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
**kwargs: Optional[Dict[str, Any]],
) -> Any:
"""
Add vectors to chroma database
@@ -198,6 +199,7 @@ class ChromaDB(BaseVectorDB):
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
Query contents from vector database based on vector similarity
@@ -225,6 +227,7 @@ class ChromaDB(BaseVectorDB):
],
n_results=n_results,
where=self._generate_where_clause(where),
**kwargs,
)
else:
result = self.collection.query(
@@ -233,6 +236,7 @@ class ChromaDB(BaseVectorDB):
],
n_results=n_results,
where=self._generate_where_clause(where),
**kwargs,
)
except InvalidDimensionException as e:
raise InvalidDimensionException(

View File

@@ -105,6 +105,7 @@ class ElasticsearchDB(BaseVectorDB):
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
) -> Any:
"""
add data in vector database
@@ -142,6 +143,7 @@ class ElasticsearchDB(BaseVectorDB):
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
query contents from vector data base based on vector similarity

View File

@@ -1,6 +1,6 @@
import logging
import time
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from tqdm import tqdm
@@ -121,6 +121,7 @@ class OpenSearchDB(BaseVectorDB):
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
):
"""Add data in vector database.
@@ -154,7 +155,7 @@ class OpenSearchDB(BaseVectorDB):
]
# Perform bulk operation
bulk(self.client, batch_entries)
bulk(self.client, batch_entries, **kwargs)
self.client.indices.refresh(index=self._get_index())
# Sleep to avoid rate limiting
@@ -167,6 +168,7 @@ class OpenSearchDB(BaseVectorDB):
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
query contents from vector data base based on vector similarity
@@ -209,6 +211,7 @@ class OpenSearchDB(BaseVectorDB):
metadata_field="metadata",
pre_filter=pre_filter,
k=n_results,
**kwargs,
)
contexts = []

View File

@@ -10,6 +10,7 @@ except ImportError:
from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils import chunks
from embedchain.vectordb.base import BaseVectorDB
@@ -92,6 +93,7 @@ class PineconeDB(BaseVectorDB):
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
):
"""add data in vector database
@@ -104,7 +106,6 @@ class PineconeDB(BaseVectorDB):
"""
docs = []
print("Adding documents to Pinecone...")
embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
docs.append(
@@ -115,8 +116,8 @@ class PineconeDB(BaseVectorDB):
}
)
for i in range(0, len(docs), self.BATCH_SIZE):
self.client.upsert(docs[i : i + self.BATCH_SIZE])
for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches..."):
self.client.upsert(chunk, **kwargs)
def query(
self,
@@ -125,6 +126,7 @@ class PineconeDB(BaseVectorDB):
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
query contents from vector database based on vector similarity
@@ -146,7 +148,7 @@ class PineconeDB(BaseVectorDB):
query_vector = self.embedder.embedding_fn([input_query])[0]
else:
query_vector = input_query
data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True, **kwargs)
contexts = []
for doc in data["matches"]:
metadata = doc["metadata"]

View File

@@ -1,7 +1,7 @@
import copy
import os
import uuid
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
try:
from qdrant_client import QdrantClient
@@ -127,6 +127,7 @@ class QdrantDB(BaseVectorDB):
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
):
"""add data in vector database
:param embeddings: list of embeddings for the corresponding documents to be added
@@ -158,6 +159,7 @@ class QdrantDB(BaseVectorDB):
payloads=payloads[i : i + self.BATCH_SIZE],
vectors=embeddings[i : i + self.BATCH_SIZE],
),
**kwargs,
)
def query(
@@ -167,6 +169,7 @@ class QdrantDB(BaseVectorDB):
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
query contents from vector database based on vector similarity
@@ -208,6 +211,7 @@ class QdrantDB(BaseVectorDB):
query_filter=models.Filter(must=qdrant_must_filters),
query_vector=query_vector,
limit=n_results,
**kwargs,
)
contexts = []

View File

@@ -1,6 +1,6 @@
import copy
import os
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
try:
import weaviate
@@ -158,6 +158,7 @@ class WeaviateDB(BaseVectorDB):
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
):
"""add data in vector database
:param embeddings: list of embeddings for the corresponding documents to be added
@@ -192,7 +193,9 @@ class WeaviateDB(BaseVectorDB):
class_name=self.index_name + "_metadata",
vector=embedding,
)
batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
batch.add_reference(
obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata", **kwargs
)
def query(
self,
@@ -201,6 +204,7 @@ class WeaviateDB(BaseVectorDB):
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
query contents from vector database based on vector similarity

View File

@@ -1,5 +1,5 @@
import logging
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from embedchain.config import ZillizDBConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -113,6 +113,7 @@ class ZillizVectorDB(BaseVectorDB):
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
):
"""Add to database"""
if not skip_embedding:
@@ -120,7 +121,7 @@ class ZillizVectorDB(BaseVectorDB):
for id, doc, metadata, embedding in zip(ids, documents, metadatas, embeddings):
data = {**metadata, "id": id, "text": doc, "embeddings": embedding}
self.client.insert(collection_name=self.config.collection_name, data=data)
self.client.insert(collection_name=self.config.collection_name, data=data, **kwargs)
self.collection.load()
self.collection.flush()
@@ -133,6 +134,7 @@ class ZillizVectorDB(BaseVectorDB):
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
Query contents from vector data base based on vector similarity
@@ -165,6 +167,7 @@ class ZillizVectorDB(BaseVectorDB):
data=query_vector,
limit=n_results,
output_fields=output_fields,
**kwargs,
)
else:
@@ -176,6 +179,7 @@ class ZillizVectorDB(BaseVectorDB):
data=[query_vector],
limit=n_results,
output_fields=output_fields,
**kwargs,
)
contexts = []