Add support for image dataset (#571)

Co-authored-by: Rupesh Bansal <rupeshbansal@Shankars-MacBook-Air.local>
Author: Rupesh Bansal
Date: 2023-10-04 09:50:40 +05:30
Committed by: GitHub
Parent: 55e9a1cbd6
Commit: d0af018b8d
19 changed files with 498 additions and 31 deletions


@@ -115,7 +115,8 @@ class ChromaDB(BaseVectorDB):
def get_advanced(self, where):
return self.collection.get(where=where, limit=1)
-def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
+def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
+        ids: List[str], skip_embedding: bool) -> Any:
"""
Add vectors to chroma database
@@ -126,7 +127,10 @@ class ChromaDB(BaseVectorDB):
:param ids: ids
:type ids: List[str]
"""
-self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
+if skip_embedding:
+    self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)
+else:
+    self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
"""
@@ -146,7 +150,7 @@ class ChromaDB(BaseVectorDB):
)
]
-def query(self, input_query: List[str], n_results: int, where: Dict[str, Any]) -> List[str]:
+def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
"""
Query contents from vector data base based on vector similarity
@@ -161,19 +165,27 @@ class ChromaDB(BaseVectorDB):
:rtype: List[str]
"""
try:
-result = self.collection.query(
-    query_texts=[
-        input_query,
-    ],
-    n_results=n_results,
-    where=where,
-)
+if skip_embedding:
+    result = self.collection.query(
+        query_embeddings=[
+            input_query,
+        ],
+        n_results=n_results,
+        where=where,
+    )
+else:
+    result = self.collection.query(
+        query_texts=[
+            input_query,
+        ],
+        n_results=n_results,
+        where=where,
+    )
except InvalidDimensionException as e:
raise InvalidDimensionException(
e.message()
+ ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
+ ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
) from None
results_formatted = self._format_result(result)
contents = [result[0].page_content for result in results_formatted]
return contents
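A matching query-side sketch under the same assumptions (instance construction and vector values are illustrative): with skip_embedding=True the input is forwarded as query_embeddings, otherwise it is embedded via query_texts as before. Note that the List[str] annotation on input_query is unchanged in the diff even though a numeric vector is expected when skip_embedding is True.

# Query with an assumed precomputed vector; the new code wraps it as query_embeddings=[input_query].
contents = db.query(
    input_query=[0.12, 0.56, 0.91],
    n_results=2,
    where={},
    skip_embedding=True,
)

# Query with raw text; the collection's embedding function is applied as before.
contents = db.query(
    input_query=["a cat sitting on a sofa"],
    n_results=2,
    where={},
    skip_embedding=False,
)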


@@ -1,5 +1,5 @@
import logging
-from typing import Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional, Set
try:
from elasticsearch import Elasticsearch
@@ -100,9 +100,10 @@ class ElasticsearchDB(BaseVectorDB):
ids = [doc["_id"] for doc in docs]
return {"ids": set(ids)}
-def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
-    """add data in vector database
+def add(self, embeddings: List[List[float]], documents: List[str], metadatas: List[object],
+        ids: List[str], skip_embedding: bool) -> Any:
+    """
+    add data in vector database
:param documents: list of texts to add
:type documents: List[str]
:param metadatas: list of metadata associated with docs
@@ -112,7 +113,9 @@ class ElasticsearchDB(BaseVectorDB):
"""
docs = []
-embeddings = self.embedder.embedding_fn(documents)
+if not skip_embedding:
+    embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
docs.append(
{
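The same pattern on the Elasticsearch side, as a hedged sketch (ElasticsearchDB construction and the example values are assumptions): when skip_embedding is True the supplied embeddings are indexed directly, otherwise they are recomputed from documents via self.embedder.embedding_fn.

# Hypothetical sketch; ElasticsearchDB construction/configuration is assumed.
es_db = ElasticsearchDB()

es_db.add(
    embeddings=[[0.12, 0.56, 0.91]],  # assumed precomputed vector, e.g. from an image model
    documents=["image: cat.jpg"],
    metadatas=[{"url": "cat.jpg"}],
    ids=["img-1"],
    skip_embedding=True,
)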
@@ -124,7 +127,7 @@ class ElasticsearchDB(BaseVectorDB):
bulk(self.client, docs)
self.client.indices.refresh(index=self._get_index())
-def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
+def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
"""
query contents from vector data base based on vector similarity
@@ -137,8 +140,12 @@ class ElasticsearchDB(BaseVectorDB):
:return: Database contents that are the result of the query
:rtype: List[str]
"""
-input_query_vector = self.embedder.embedding_fn(input_query)
-query_vector = input_query_vector[0]
+if skip_embedding:
+    query_vector = input_query
+else:
+    input_query_vector = self.embedder.embedding_fn(input_query)
+    query_vector = input_query_vector[0]
query = {
"script_score": {
"query": {"bool": {"must": [{"exists": {"field": "text"}}]}},