[Feature] Update db.query to return source of context (#831)
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -174,4 +174,5 @@ test-db
|
|||||||
|
|
||||||
notebooks/*.yaml
|
notebooks/*.yaml
|
||||||
.ipynb_checkpoints/
|
.ipynb_checkpoints/
|
||||||
|
|
||||||
!configs/*.yaml
|
!configs/*.yaml
|
||||||
|
|||||||
@@ -500,13 +500,17 @@ class EmbedChain(JSONSerializable):
|
|||||||
|
|
||||||
db_query = ClipProcessor.get_text_features(query=input_query)
|
db_query = ClipProcessor.get_text_features(query=input_query)
|
||||||
|
|
||||||
contents = self.db.query(
|
contexts = self.db.query(
|
||||||
input_query=db_query,
|
input_query=db_query,
|
||||||
n_results=query_config.number_documents,
|
n_results=query_config.number_documents,
|
||||||
where=where,
|
where=where,
|
||||||
skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
|
skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
|
||||||
)
|
)
|
||||||
return contents
|
|
||||||
|
if len(contexts) > 0 and isinstance(contexts[0], tuple):
|
||||||
|
contexts = list(map(lambda x: x[0], contexts))
|
||||||
|
|
||||||
|
return contexts
|
||||||
|
|
||||||
def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str:
|
def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -41,15 +41,15 @@ class LlmFactory:
|
|||||||
|
|
||||||
class EmbedderFactory:
|
class EmbedderFactory:
|
||||||
provider_to_class = {
|
provider_to_class = {
|
||||||
|
"azure_openai": "embedchain.embedder.openai.OpenAIEmbedder",
|
||||||
"gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder",
|
"gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder",
|
||||||
"huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
|
"huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
|
||||||
"vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
|
|
||||||
"azure_openai": "embedchain.embedder.openai.OpenAIEmbedder",
|
|
||||||
"openai": "embedchain.embedder.openai.OpenAIEmbedder",
|
"openai": "embedchain.embedder.openai.OpenAIEmbedder",
|
||||||
|
"vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
|
||||||
}
|
}
|
||||||
provider_to_config_class = {
|
provider_to_config_class = {
|
||||||
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
|
||||||
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||||
|
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -72,16 +72,18 @@ class VectorDBFactory:
|
|||||||
"elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
|
"elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
|
||||||
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
|
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
|
||||||
"pinecone": "embedchain.vectordb.pinecone.PineconeDB",
|
"pinecone": "embedchain.vectordb.pinecone.PineconeDB",
|
||||||
"weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
|
|
||||||
"qdrant": "embedchain.vectordb.qdrant.QdrantDB",
|
"qdrant": "embedchain.vectordb.qdrant.QdrantDB",
|
||||||
|
"weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
|
||||||
|
"zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB",
|
||||||
}
|
}
|
||||||
provider_to_config_class = {
|
provider_to_config_class = {
|
||||||
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
|
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
|
||||||
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
|
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
|
||||||
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
|
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
|
||||||
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
|
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
|
||||||
"weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
|
|
||||||
"qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
|
"qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
|
||||||
|
"weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
|
||||||
|
"zilliz": "embedchain.config.vectordb.zilliz.ZillizDBConfig",
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from chromadb import Collection, QueryResult
|
from chromadb import Collection, QueryResult
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
@@ -191,7 +191,9 @@ class ChromaDB(BaseVectorDB):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
def query(
|
||||||
|
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
||||||
|
) -> List[Tuple[str, str, str]]:
|
||||||
"""
|
"""
|
||||||
Query contents from vector database based on vector similarity
|
Query contents from vector database based on vector similarity
|
||||||
|
|
||||||
@@ -204,8 +206,8 @@ class ChromaDB(BaseVectorDB):
|
|||||||
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
|
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
|
||||||
:type skip_embedding: bool
|
:type skip_embedding: bool
|
||||||
:raises InvalidDimensionException: Dimensions do not match.
|
:raises InvalidDimensionException: Dimensions do not match.
|
||||||
:return: The content of the document that matched your query.
|
:return: The content of the document that matched your query, url of the source, doc_id
|
||||||
:rtype: List[str]
|
:rtype: List[Tuple[str,str,str]]
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if skip_embedding:
|
if skip_embedding:
|
||||||
@@ -231,8 +233,14 @@ class ChromaDB(BaseVectorDB):
|
|||||||
" embeddings, is used to retrieve an embedding from the database."
|
" embeddings, is used to retrieve an embedding from the database."
|
||||||
) from None
|
) from None
|
||||||
results_formatted = self._format_result(result)
|
results_formatted = self._format_result(result)
|
||||||
contents = [result[0].page_content for result in results_formatted]
|
contexts = []
|
||||||
return contents
|
for result in results_formatted:
|
||||||
|
context = result[0].page_content
|
||||||
|
metadata = result[0].metadata
|
||||||
|
source = metadata["url"]
|
||||||
|
doc_id = metadata["doc_id"]
|
||||||
|
contexts.append((context, source, doc_id))
|
||||||
|
return contexts
|
||||||
|
|
||||||
def set_collection_name(self, name: str):
|
def set_collection_name(self, name: str):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from elasticsearch import Elasticsearch
|
from elasticsearch import Elasticsearch
|
||||||
@@ -135,7 +135,9 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
bulk(self.client, docs)
|
bulk(self.client, docs)
|
||||||
self.client.indices.refresh(index=self._get_index())
|
self.client.indices.refresh(index=self._get_index())
|
||||||
|
|
||||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
def query(
|
||||||
|
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
||||||
|
) -> List[Tuple[str, str, str]]:
|
||||||
"""
|
"""
|
||||||
query contents from vector data base based on vector similarity
|
query contents from vector data base based on vector similarity
|
||||||
|
|
||||||
@@ -147,8 +149,9 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
:type where: Dict[str, any]
|
:type where: Dict[str, any]
|
||||||
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
|
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
|
||||||
:type skip_embedding: bool
|
:type skip_embedding: bool
|
||||||
:return: Database contents that are the result of the query
|
:return: The context of the document that matched your query, url of the source, doc_id
|
||||||
:rtype: List[str]
|
|
||||||
|
:rtype: List[Tuple[str,str,str]]
|
||||||
"""
|
"""
|
||||||
if skip_embedding:
|
if skip_embedding:
|
||||||
query_vector = input_query
|
query_vector = input_query
|
||||||
@@ -156,6 +159,7 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
input_query_vector = self.embedder.embedding_fn(input_query)
|
input_query_vector = self.embedder.embedding_fn(input_query)
|
||||||
query_vector = input_query_vector[0]
|
query_vector = input_query_vector[0]
|
||||||
|
|
||||||
|
# `https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html`
|
||||||
query = {
|
query = {
|
||||||
"script_score": {
|
"script_score": {
|
||||||
"query": {"bool": {"must": [{"exists": {"field": "text"}}]}},
|
"query": {"bool": {"must": [{"exists": {"field": "text"}}]}},
|
||||||
@@ -167,11 +171,17 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
}
|
}
|
||||||
if "app_id" in where:
|
if "app_id" in where:
|
||||||
app_id = where["app_id"]
|
app_id = where["app_id"]
|
||||||
query["script_score"]["query"]["bool"]["must"] = [{"term": {"metadata.app_id": app_id}}]
|
query["script_score"]["query"] = {"match": {"metadata.app_id": app_id}}
|
||||||
_source = ["text"]
|
_source = ["text", "metadata.url", "metadata.doc_id"]
|
||||||
response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results)
|
response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results)
|
||||||
docs = response["hits"]["hits"]
|
docs = response["hits"]["hits"]
|
||||||
contents = [doc["_source"]["text"] for doc in docs]
|
contents = []
|
||||||
|
for doc in docs:
|
||||||
|
context = doc["_source"]["text"]
|
||||||
|
metadata = doc["_source"]["metadata"]
|
||||||
|
source = metadata["url"]
|
||||||
|
doc_id = metadata["doc_id"]
|
||||||
|
contents.append(tuple((context, source, doc_id)))
|
||||||
return contents
|
return contents
|
||||||
|
|
||||||
def set_collection_name(self, name: str):
|
def set_collection_name(self, name: str):
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional, Set
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from opensearchpy import OpenSearch
|
from opensearchpy import OpenSearch
|
||||||
@@ -145,7 +145,9 @@ class OpenSearchDB(BaseVectorDB):
|
|||||||
bulk(self.client, docs)
|
bulk(self.client, docs)
|
||||||
self.client.indices.refresh(index=self._get_index())
|
self.client.indices.refresh(index=self._get_index())
|
||||||
|
|
||||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
def query(
|
||||||
|
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
||||||
|
) -> List[Tuple[str, str, str]]:
|
||||||
"""
|
"""
|
||||||
query contents from vector data base based on vector similarity
|
query contents from vector data base based on vector similarity
|
||||||
|
|
||||||
@@ -157,8 +159,8 @@ class OpenSearchDB(BaseVectorDB):
|
|||||||
:type where: Dict[str, any]
|
:type where: Dict[str, any]
|
||||||
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
|
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
|
||||||
:type skip_embedding: bool
|
:type skip_embedding: bool
|
||||||
:return: Database contents that are the result of the query
|
:return: The content of the document that matched your query, url of the source, doc_id
|
||||||
:rtype: List[str]
|
:rtype: List[Tuple[str,str,str]]
|
||||||
"""
|
"""
|
||||||
# TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists
|
# TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists
|
||||||
embeddings = OpenAIEmbeddings()
|
embeddings = OpenAIEmbeddings()
|
||||||
@@ -185,7 +187,13 @@ class OpenSearchDB(BaseVectorDB):
|
|||||||
pre_filter=pre_filter,
|
pre_filter=pre_filter,
|
||||||
k=n_results,
|
k=n_results,
|
||||||
)
|
)
|
||||||
contents = [doc.page_content for doc in docs]
|
|
||||||
|
contents = []
|
||||||
|
for doc in docs:
|
||||||
|
context = doc.page_content
|
||||||
|
source = doc.metadata["url"]
|
||||||
|
doc_id = doc.metadata["doc_id"]
|
||||||
|
contents.append(tuple((context, source, doc_id)))
|
||||||
return contents
|
return contents
|
||||||
|
|
||||||
def set_collection_name(self, name: str):
|
def set_collection_name(self, name: str):
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pinecone
|
import pinecone
|
||||||
@@ -118,7 +118,9 @@ class PineconeDB(BaseVectorDB):
|
|||||||
for i in range(0, len(docs), self.BATCH_SIZE):
|
for i in range(0, len(docs), self.BATCH_SIZE):
|
||||||
self.client.upsert(docs[i : i + self.BATCH_SIZE])
|
self.client.upsert(docs[i : i + self.BATCH_SIZE])
|
||||||
|
|
||||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
def query(
|
||||||
|
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
||||||
|
) -> List[Tuple[str, str, str]]:
|
||||||
"""
|
"""
|
||||||
query contents from vector database based on vector similarity
|
query contents from vector database based on vector similarity
|
||||||
:param input_query: list of query string
|
:param input_query: list of query string
|
||||||
@@ -129,16 +131,22 @@ class PineconeDB(BaseVectorDB):
|
|||||||
:type where: Dict[str, any]
|
:type where: Dict[str, any]
|
||||||
:param skip_embedding: Optional. if True, input_query is already embedded
|
:param skip_embedding: Optional. if True, input_query is already embedded
|
||||||
:type skip_embedding: bool
|
:type skip_embedding: bool
|
||||||
:return: Database contents that are the result of the query
|
:return: The content of the document that matched your query, url of the source, doc_id
|
||||||
:rtype: List[str]
|
:rtype: List[Tuple[str,str,str]]
|
||||||
"""
|
"""
|
||||||
if not skip_embedding:
|
if not skip_embedding:
|
||||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||||
else:
|
else:
|
||||||
query_vector = input_query
|
query_vector = input_query
|
||||||
contents = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
|
data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
|
||||||
embeddings = list(map(lambda content: content["metadata"]["text"], contents["matches"]))
|
contents = []
|
||||||
return embeddings
|
for doc in data["matches"]:
|
||||||
|
metadata = doc["metadata"]
|
||||||
|
context = metadata["text"]
|
||||||
|
source = metadata["url"]
|
||||||
|
doc_id = metadata["doc_id"]
|
||||||
|
contents.append(tuple((context, source, doc_id)))
|
||||||
|
return contents
|
||||||
|
|
||||||
def set_collection_name(self, name: str):
|
def set_collection_name(self, name: str):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
@@ -160,7 +160,9 @@ class QdrantDB(BaseVectorDB):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
def query(
|
||||||
|
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
||||||
|
) -> List[Tuple[str, str, str]]:
|
||||||
"""
|
"""
|
||||||
query contents from vector database based on vector similarity
|
query contents from vector database based on vector similarity
|
||||||
:param input_query: list of query string
|
:param input_query: list of query string
|
||||||
@@ -172,8 +174,8 @@ class QdrantDB(BaseVectorDB):
|
|||||||
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
|
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
|
||||||
generated or not
|
generated or not
|
||||||
:type skip_embedding: bool
|
:type skip_embedding: bool
|
||||||
:return: Database contents that are the result of the query
|
:return: The context of the document that matched your query, url of the source, doc_id
|
||||||
:rtype: List[str]
|
:rtype: List[Tuple[str,str,str]]
|
||||||
"""
|
"""
|
||||||
if not skip_embedding:
|
if not skip_embedding:
|
||||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||||
@@ -199,9 +201,14 @@ class QdrantDB(BaseVectorDB):
|
|||||||
query_vector=query_vector,
|
query_vector=query_vector,
|
||||||
limit=n_results,
|
limit=n_results,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = []
|
response = []
|
||||||
for result in results:
|
for result in results:
|
||||||
response.append(result.payload.get("text", ""))
|
context = result.payload["text"]
|
||||||
|
metadata = result.payload["metadata"]
|
||||||
|
source = metadata["url"]
|
||||||
|
doc_id = metadata["doc_id"]
|
||||||
|
response.append(tuple((context, source, doc_id)))
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def count(self) -> int:
|
def count(self) -> int:
|
||||||
@@ -211,3 +218,15 @@ class QdrantDB(BaseVectorDB):
|
|||||||
def reset(self):
|
def reset(self):
|
||||||
self.client.delete_collection(collection_name=self.collection_name)
|
self.client.delete_collection(collection_name=self.collection_name)
|
||||||
self._initialize()
|
self._initialize()
|
||||||
|
|
||||||
|
def set_collection_name(self, name: str):
|
||||||
|
"""
|
||||||
|
Set the name of the collection. A collection is an isolated space for vectors.
|
||||||
|
|
||||||
|
:param name: Name of the collection.
|
||||||
|
:type name: str
|
||||||
|
"""
|
||||||
|
if not isinstance(name, str):
|
||||||
|
raise TypeError("Collection name must be a string")
|
||||||
|
self.config.collection_name = name
|
||||||
|
self.collection_name = self._get_or_create_collection()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import weaviate
|
import weaviate
|
||||||
@@ -194,7 +194,9 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
)
|
)
|
||||||
batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
|
batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
|
||||||
|
|
||||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
def query(
|
||||||
|
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
||||||
|
) -> List[Tuple[str, str, str]]:
|
||||||
"""
|
"""
|
||||||
query contents from vector database based on vector similarity
|
query contents from vector database based on vector similarity
|
||||||
:param input_query: list of query string
|
:param input_query: list of query string
|
||||||
@@ -206,14 +208,15 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
|
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
|
||||||
generated or not
|
generated or not
|
||||||
:type skip_embedding: bool
|
:type skip_embedding: bool
|
||||||
:return: Database contents that are the result of the query
|
:return: The context of the document that matched your query, url of the source, doc_id
|
||||||
:rtype: List[str]
|
:rtype: List[Tuple[str,str,str]]
|
||||||
"""
|
"""
|
||||||
if not skip_embedding:
|
if not skip_embedding:
|
||||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||||
else:
|
else:
|
||||||
query_vector = input_query
|
query_vector = input_query
|
||||||
keys = set(where.keys() if where is not None else set())
|
keys = set(where.keys() if where is not None else set())
|
||||||
|
data_fields = ["text"]
|
||||||
if len(keys.intersection(self.metadata_keys)) != 0:
|
if len(keys.intersection(self.metadata_keys)) != 0:
|
||||||
weaviate_where_operands = []
|
weaviate_where_operands = []
|
||||||
for key in keys:
|
for key in keys:
|
||||||
@@ -231,7 +234,7 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
|
weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
|
||||||
|
|
||||||
results = (
|
results = (
|
||||||
self.client.query.get(self.index_name, ["text"])
|
self.client.query.get(self.index_name, data_fields)
|
||||||
.with_where(weaviate_where_clause)
|
.with_where(weaviate_where_clause)
|
||||||
.with_near_vector({"vector": query_vector})
|
.with_near_vector({"vector": query_vector})
|
||||||
.with_limit(n_results)
|
.with_limit(n_results)
|
||||||
@@ -239,16 +242,13 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
results = (
|
results = (
|
||||||
self.client.query.get(self.index_name, ["text"])
|
self.client.query.get(self.index_name, data_fields)
|
||||||
.with_near_vector({"vector": query_vector})
|
.with_near_vector({"vector": query_vector})
|
||||||
.with_limit(n_results)
|
.with_limit(n_results)
|
||||||
.do()
|
.do()
|
||||||
)
|
)
|
||||||
matched_tokens = []
|
contexts = results["data"]["Get"].get(self.index_name)
|
||||||
for result in results["data"]["Get"].get(self.index_name):
|
return contexts
|
||||||
matched_tokens.append(result["text"])
|
|
||||||
|
|
||||||
return matched_tokens
|
|
||||||
|
|
||||||
def set_collection_name(self, name: str):
|
def set_collection_name(self, name: str):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from typing import Dict, List, Optional
|
import logging
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from embedchain.config import ZillizDBConfig
|
from embedchain.config import ZillizDBConfig
|
||||||
from embedchain.helper.json_serializable import register_deserializable
|
from embedchain.helper.json_serializable import register_deserializable
|
||||||
@@ -61,6 +62,7 @@ class ZillizVectorDB(BaseVectorDB):
|
|||||||
:type name: str
|
:type name: str
|
||||||
"""
|
"""
|
||||||
if utility.has_collection(name):
|
if utility.has_collection(name):
|
||||||
|
logging.info(f"[ZillizDB]: found an existing collection {name}, make sure the auto-id is disabled.")
|
||||||
self.collection = Collection(name)
|
self.collection = Collection(name)
|
||||||
else:
|
else:
|
||||||
fields = [
|
fields = [
|
||||||
@@ -124,7 +126,9 @@ class ZillizVectorDB(BaseVectorDB):
|
|||||||
self.collection.flush()
|
self.collection.flush()
|
||||||
self.client.flush(self.config.collection_name)
|
self.client.flush(self.config.collection_name)
|
||||||
|
|
||||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
def query(
|
||||||
|
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
||||||
|
) -> List[Tuple[str, str, str]]:
|
||||||
"""
|
"""
|
||||||
Query contents from vector data base based on vector similarity
|
Query contents from vector data base based on vector similarity
|
||||||
|
|
||||||
@@ -135,8 +139,8 @@ class ZillizVectorDB(BaseVectorDB):
|
|||||||
:param where: to filter data
|
:param where: to filter data
|
||||||
:type where: str
|
:type where: str
|
||||||
:raises InvalidDimensionException: Dimensions do not match.
|
:raises InvalidDimensionException: Dimensions do not match.
|
||||||
:return: The content of the document that matched your query.
|
:return: The context of the document that matched your query, url of the source, doc_id
|
||||||
:rtype: List[str]
|
:rtype: List[Tuple[str,str,str]]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.collection.is_empty:
|
if self.collection.is_empty:
|
||||||
@@ -145,13 +149,14 @@ class ZillizVectorDB(BaseVectorDB):
|
|||||||
if not isinstance(where, str):
|
if not isinstance(where, str):
|
||||||
where = None
|
where = None
|
||||||
|
|
||||||
|
output_fields = ["text", "url", "doc_id"]
|
||||||
if skip_embedding:
|
if skip_embedding:
|
||||||
query_vector = input_query
|
query_vector = input_query
|
||||||
query_result = self.client.search(
|
query_result = self.client.search(
|
||||||
collection_name=self.config.collection_name,
|
collection_name=self.config.collection_name,
|
||||||
data=query_vector,
|
data=query_vector,
|
||||||
limit=n_results,
|
limit=n_results,
|
||||||
output_fields=["text"],
|
output_fields=output_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -162,13 +167,16 @@ class ZillizVectorDB(BaseVectorDB):
|
|||||||
collection_name=self.config.collection_name,
|
collection_name=self.config.collection_name,
|
||||||
data=[query_vector],
|
data=[query_vector],
|
||||||
limit=n_results,
|
limit=n_results,
|
||||||
output_fields=["text"],
|
output_fields=output_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
doc_list = []
|
doc_list = []
|
||||||
for query in query_result:
|
for query in query_result:
|
||||||
doc_list.append(query[0]["entity"]["text"])
|
data = query[0]["entity"]
|
||||||
|
context = data["text"]
|
||||||
|
source = data["url"]
|
||||||
|
doc_id = data["doc_id"]
|
||||||
|
doc_list.append(tuple((context, source, doc_id)))
|
||||||
return doc_list
|
return doc_list
|
||||||
|
|
||||||
def count(self) -> int:
|
def count(self) -> int:
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
|
|||||||
app_with_settings.db.add(
|
app_with_settings.db.add(
|
||||||
embeddings=[[0, 0, 0]],
|
embeddings=[[0, 0, 0]],
|
||||||
documents=["document"],
|
documents=["document"],
|
||||||
metadatas=[{"value": "somevalue"}],
|
metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
|
||||||
ids=["id"],
|
ids=["id"],
|
||||||
skip_embedding=True,
|
skip_embedding=True,
|
||||||
)
|
)
|
||||||
@@ -158,13 +158,13 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
|
|||||||
"documents": ["document"],
|
"documents": ["document"],
|
||||||
"embeddings": None,
|
"embeddings": None,
|
||||||
"ids": ["id"],
|
"ids": ["id"],
|
||||||
"metadatas": [{"value": "somevalue"}],
|
"metadatas": [{"url": "url_1", "doc_id": "doc_id_1"}],
|
||||||
}
|
}
|
||||||
|
|
||||||
assert data == expected_value
|
assert data == expected_value
|
||||||
|
|
||||||
data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
|
data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
|
||||||
expected_value = ["document"]
|
expected_value = [("document", "url_1", "doc_id_1")]
|
||||||
|
|
||||||
assert data == expected_value
|
assert data == expected_value
|
||||||
app_with_settings.db.reset()
|
app_with_settings.db.reset()
|
||||||
@@ -299,3 +299,35 @@ def test_chroma_db_collection_reset():
|
|||||||
app2.db.reset()
|
app2.db.reset()
|
||||||
app3.db.reset()
|
app3.db.reset()
|
||||||
app4.db.reset()
|
app4.db.reset()
|
||||||
|
|
||||||
|
|
||||||
|
def test_chroma_db_collection_query(app_with_settings):
|
||||||
|
app_with_settings.db.reset()
|
||||||
|
|
||||||
|
assert app_with_settings.db.count() == 0
|
||||||
|
|
||||||
|
app_with_settings.db.add(
|
||||||
|
embeddings=[[0, 0, 0]],
|
||||||
|
documents=["document"],
|
||||||
|
metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
|
||||||
|
ids=["id"],
|
||||||
|
skip_embedding=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert app_with_settings.db.count() == 1
|
||||||
|
|
||||||
|
app_with_settings.db.add(
|
||||||
|
embeddings=[[0, 1, 0]],
|
||||||
|
documents=["document2"],
|
||||||
|
metadatas=[{"url": "url_2", "doc_id": "doc_id_2"}],
|
||||||
|
ids=["id2"],
|
||||||
|
skip_embedding=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert app_with_settings.db.count() == 2
|
||||||
|
|
||||||
|
data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True)
|
||||||
|
expected_value = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")]
|
||||||
|
|
||||||
|
assert data == expected_value
|
||||||
|
app_with_settings.db.reset()
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class TestEsDB(unittest.TestCase):
|
|||||||
# Create some dummy data.
|
# Create some dummy data.
|
||||||
embeddings = [[1, 2, 3], [4, 5, 6]]
|
embeddings = [[1, 2, 3], [4, 5, 6]]
|
||||||
documents = ["This is a document.", "This is another document."]
|
documents = ["This is a document.", "This is another document."]
|
||||||
metadatas = [{}, {}]
|
metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
|
||||||
ids = ["doc_1", "doc_2"]
|
ids = ["doc_1", "doc_2"]
|
||||||
|
|
||||||
# Add the data to the database.
|
# Add the data to the database.
|
||||||
@@ -40,8 +40,17 @@ class TestEsDB(unittest.TestCase):
|
|||||||
search_response = {
|
search_response = {
|
||||||
"hits": {
|
"hits": {
|
||||||
"hits": [
|
"hits": [
|
||||||
{"_source": {"text": "This is a document."}, "_score": 0.9},
|
{
|
||||||
{"_source": {"text": "This is another document."}, "_score": 0.8},
|
"_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
|
||||||
|
"_score": 0.9,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_source": {
|
||||||
|
"text": "This is another document.",
|
||||||
|
"metadata": {"url": "url_2", "doc_id": "doc_id_2"},
|
||||||
|
},
|
||||||
|
"_score": 0.8,
|
||||||
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -54,7 +63,9 @@ class TestEsDB(unittest.TestCase):
|
|||||||
results = self.db.query(query, n_results=2, where={}, skip_embedding=False)
|
results = self.db.query(query, n_results=2, where={}, skip_embedding=False)
|
||||||
|
|
||||||
# Assert that the results are correct.
|
# Assert that the results are correct.
|
||||||
self.assertEqual(results, ["This is a document.", "This is another document."])
|
self.assertEqual(
|
||||||
|
results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
|
||||||
|
)
|
||||||
|
|
||||||
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
|
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
|
||||||
def test_query_with_skip_embedding(self, mock_client):
|
def test_query_with_skip_embedding(self, mock_client):
|
||||||
@@ -68,7 +79,7 @@ class TestEsDB(unittest.TestCase):
|
|||||||
# Create some dummy data.
|
# Create some dummy data.
|
||||||
embeddings = [[1, 2, 3], [4, 5, 6]]
|
embeddings = [[1, 2, 3], [4, 5, 6]]
|
||||||
documents = ["This is a document.", "This is another document."]
|
documents = ["This is a document.", "This is another document."]
|
||||||
metadatas = [{}, {}]
|
metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
|
||||||
ids = ["doc_1", "doc_2"]
|
ids = ["doc_1", "doc_2"]
|
||||||
|
|
||||||
# Add the data to the database.
|
# Add the data to the database.
|
||||||
@@ -77,8 +88,17 @@ class TestEsDB(unittest.TestCase):
|
|||||||
search_response = {
|
search_response = {
|
||||||
"hits": {
|
"hits": {
|
||||||
"hits": [
|
"hits": [
|
||||||
{"_source": {"text": "This is a document."}, "_score": 0.9},
|
{
|
||||||
{"_source": {"text": "This is another document."}, "_score": 0.8},
|
"_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
|
||||||
|
"_score": 0.9,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"_source": {
|
||||||
|
"text": "This is another document.",
|
||||||
|
"metadata": {"url": "url_2", "doc_id": "doc_id_2"},
|
||||||
|
},
|
||||||
|
"_score": 0.8,
|
||||||
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -91,7 +111,9 @@ class TestEsDB(unittest.TestCase):
|
|||||||
results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
|
results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
|
||||||
|
|
||||||
# Assert that the results are correct.
|
# Assert that the results are correct.
|
||||||
self.assertEqual(results, ["This is a document.", "This is another document."])
|
self.assertEqual(
|
||||||
|
results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
|
||||||
|
)
|
||||||
|
|
||||||
def test_init_without_url(self):
|
def test_init_without_url(self):
|
||||||
# Make sure it's not loaded from env
|
# Make sure it's not loaded from env
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ class TestZillizDBCollection:
|
|||||||
# Mock the MilvusClient search method
|
# Mock the MilvusClient search method
|
||||||
with patch.object(zilliz_db.client, "search") as mock_search:
|
with patch.object(zilliz_db.client, "search") as mock_search:
|
||||||
# Mock the search result
|
# Mock the search result
|
||||||
mock_search.return_value = [[{"entity": {"text": "result_doc"}}]]
|
mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]]
|
||||||
|
|
||||||
# Call the query method with skip_embedding=True
|
# Call the query method with skip_embedding=True
|
||||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
|
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
|
||||||
@@ -133,11 +133,11 @@ class TestZillizDBCollection:
|
|||||||
collection_name=mock_config.collection_name,
|
collection_name=mock_config.collection_name,
|
||||||
data=["query_text"],
|
data=["query_text"],
|
||||||
limit=1,
|
limit=1,
|
||||||
output_fields=["text"],
|
output_fields=["text", "url", "doc_id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert that the query result matches the expected result
|
# Assert that the query result matches the expected result
|
||||||
assert query_result == ["result_doc"]
|
assert query_result == [("result_doc", "url_1", "doc_id_1")]
|
||||||
|
|
||||||
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
|
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
|
||||||
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
|
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
|
||||||
@@ -162,7 +162,7 @@ class TestZillizDBCollection:
|
|||||||
mock_embedder.embedding_fn.return_value = ["query_vector"]
|
mock_embedder.embedding_fn.return_value = ["query_vector"]
|
||||||
|
|
||||||
# Mock the search result
|
# Mock the search result
|
||||||
mock_search.return_value = [[{"entity": {"text": "result_doc"}}]]
|
mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]]
|
||||||
|
|
||||||
# Call the query method with skip_embedding=False
|
# Call the query method with skip_embedding=False
|
||||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
|
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
|
||||||
@@ -172,8 +172,8 @@ class TestZillizDBCollection:
|
|||||||
collection_name=mock_config.collection_name,
|
collection_name=mock_config.collection_name,
|
||||||
data=["query_vector"],
|
data=["query_vector"],
|
||||||
limit=1,
|
limit=1,
|
||||||
output_fields=["text"],
|
output_fields=["text", "url", "doc_id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert that the query result matches the expected result
|
# Assert that the query result matches the expected result
|
||||||
assert query_result == ["result_doc"]
|
assert query_result == [("result_doc", "url_1", "doc_id_1")]
|
||||||
|
|||||||
Reference in New Issue
Block a user