Add support for image dataset (#571)
Co-authored-by: Rupesh Bansal <rupeshbansal@Shankars-MacBook-Air.local>
This commit is contained in:
@@ -115,7 +115,8 @@ class ChromaDB(BaseVectorDB):
|
||||
def get_advanced(self, where):
|
||||
return self.collection.get(where=where, limit=1)
|
||||
|
||||
def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
|
||||
def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
|
||||
ids: List[str], skip_embedding: bool) -> Any:
|
||||
"""
|
||||
Add vectors to chroma database
|
||||
|
||||
@@ -126,7 +127,10 @@ class ChromaDB(BaseVectorDB):
|
||||
:param ids: ids
|
||||
:type ids: List[str]
|
||||
"""
|
||||
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
||||
if skip_embedding:
|
||||
self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)
|
||||
else:
|
||||
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
||||
|
||||
def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
|
||||
"""
|
||||
@@ -146,7 +150,7 @@ class ChromaDB(BaseVectorDB):
|
||||
)
|
||||
]
|
||||
|
||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, Any]) -> List[str]:
|
||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
||||
"""
|
||||
Query contents from vector data base based on vector similarity
|
||||
|
||||
@@ -161,19 +165,27 @@ class ChromaDB(BaseVectorDB):
|
||||
:rtype: List[str]
|
||||
"""
|
||||
try:
|
||||
result = self.collection.query(
|
||||
query_texts=[
|
||||
input_query,
|
||||
],
|
||||
n_results=n_results,
|
||||
where=where,
|
||||
)
|
||||
if skip_embedding:
|
||||
result = self.collection.query(
|
||||
query_embeddings=[
|
||||
input_query,
|
||||
],
|
||||
n_results=n_results,
|
||||
where=where,
|
||||
)
|
||||
else:
|
||||
result = self.collection.query(
|
||||
query_texts=[
|
||||
input_query,
|
||||
],
|
||||
n_results=n_results,
|
||||
where=where,
|
||||
)
|
||||
except InvalidDimensionException as e:
|
||||
raise InvalidDimensionException(
|
||||
e.message()
|
||||
+ ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
|
||||
+ ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
|
||||
) from None
|
||||
|
||||
results_formatted = self._format_result(result)
|
||||
contents = [result[0].page_content for result in results_formatted]
|
||||
return contents
|
||||
|
||||
Reference in New Issue
Block a user