[Feature] Add citations flag in query and chat functions of App to return context along with the answer (#859)
This commit is contained in:
@@ -4,7 +4,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
@@ -438,7 +438,9 @@ class EmbedChain(JSONSerializable):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
|
def retrieve_from_database(
|
||||||
|
self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False
|
||||||
|
) -> Union[List[Tuple[str, str, str]], List[str]]:
|
||||||
"""
|
"""
|
||||||
Queries the vector database based on the given input query.
|
Queries the vector database based on the given input query.
|
||||||
Gets relevant doc based on the query
|
Gets relevant doc based on the query
|
||||||
@@ -449,6 +451,8 @@ class EmbedChain(JSONSerializable):
|
|||||||
:type config: Optional[BaseLlmConfig], optional
|
:type config: Optional[BaseLlmConfig], optional
|
||||||
:param where: A dictionary of key-value pairs to filter the database results, defaults to None
|
:param where: A dictionary of key-value pairs to filter the database results, defaults to None
|
||||||
:type where: _type_, optional
|
:type where: _type_, optional
|
||||||
|
:param citations: A boolean to indicate if db should fetch citation source
|
||||||
|
:type citations: bool
|
||||||
:return: List of contents of the document that matched your query
|
:return: List of contents of the document that matched your query
|
||||||
:rtype: List[str]
|
:rtype: List[str], if citations is False, otherwise List[Tuple[str, str, str]]
|
||||||
"""
|
"""
|
||||||
@@ -478,14 +482,19 @@ class EmbedChain(JSONSerializable):
|
|||||||
n_results=query_config.number_documents,
|
n_results=query_config.number_documents,
|
||||||
where=where,
|
where=where,
|
||||||
skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
|
skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
|
||||||
|
citations=citations,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(contexts) > 0 and isinstance(contexts[0], tuple):
|
|
||||||
contexts = list(map(lambda x: x[0], contexts))
|
|
||||||
|
|
||||||
return contexts
|
return contexts
|
||||||
|
|
||||||
def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str:
|
def query(
|
||||||
|
self,
|
||||||
|
input_query: str,
|
||||||
|
config: BaseLlmConfig = None,
|
||||||
|
dry_run=False,
|
||||||
|
where: Optional[Dict] = None,
|
||||||
|
**kwargs: Dict[str, Any],
|
||||||
|
) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]:
|
||||||
"""
|
"""
|
||||||
Queries the vector database based on the given input query.
|
Queries the vector database based on the given input query.
|
||||||
Gets relevant doc based on the query and then passes it to an
|
Gets relevant doc based on the query and then passes it to an
|
||||||
@@ -501,15 +510,31 @@ class EmbedChain(JSONSerializable):
|
|||||||
:type dry_run: bool, optional
|
:type dry_run: bool, optional
|
||||||
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
|
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
|
||||||
:type where: Optional[Dict[str, str]], optional
|
:type where: Optional[Dict[str, str]], optional
|
||||||
:return: The answer to the query or the dry run result
|
:param kwargs: Extra keyword arguments for the query function, e.g. the boolean `citations`
|
||||||
:rtype: str
|
param to return context along with the answer
|
||||||
|
:type kwargs: Dict[str, Any]
|
||||||
|
:return: The answer to the query, with citations if the citation flag is True
|
||||||
|
or the dry run result
|
||||||
|
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
|
||||||
"""
|
"""
|
||||||
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
|
citations = kwargs.get("citations", False)
|
||||||
answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
|
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
|
||||||
|
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
|
||||||
|
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
|
||||||
|
else:
|
||||||
|
contexts_data_for_llm_query = contexts
|
||||||
|
|
||||||
|
answer = self.llm.query(
|
||||||
|
input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
|
||||||
|
)
|
||||||
|
|
||||||
# Send anonymous telemetry
|
# Send anonymous telemetry
|
||||||
self.telemetry.capture(event_name="query", properties=self._telemetry_props)
|
self.telemetry.capture(event_name="query", properties=self._telemetry_props)
|
||||||
return answer
|
|
||||||
|
if citations:
|
||||||
|
return answer, contexts
|
||||||
|
else:
|
||||||
|
return answer
|
||||||
|
|
||||||
def chat(
|
def chat(
|
||||||
self,
|
self,
|
||||||
@@ -517,6 +542,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
config: Optional[BaseLlmConfig] = None,
|
config: Optional[BaseLlmConfig] = None,
|
||||||
dry_run=False,
|
dry_run=False,
|
||||||
where: Optional[Dict[str, str]] = None,
|
where: Optional[Dict[str, str]] = None,
|
||||||
|
**kwargs: Dict[str, Any],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Queries the vector database on the given input query.
|
Queries the vector database on the given input query.
|
||||||
@@ -535,15 +561,31 @@ class EmbedChain(JSONSerializable):
|
|||||||
:type dry_run: bool, optional
|
:type dry_run: bool, optional
|
||||||
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
|
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
|
||||||
:type where: Optional[Dict[str, str]], optional
|
:type where: Optional[Dict[str, str]], optional
|
||||||
:return: The answer to the query or the dry run result
|
:param kwargs: Extra keyword arguments for the chat function, e.g. the boolean `citations`
|
||||||
:rtype: str
|
param to return context along with the answer
|
||||||
|
:type kwargs: Dict[str, Any]
|
||||||
|
:return: The answer to the query, with citations if the citation flag is True
|
||||||
|
or the dry run result
|
||||||
|
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
|
||||||
"""
|
"""
|
||||||
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
|
citations = kwargs.get("citations", False)
|
||||||
answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
|
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
|
||||||
|
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
|
||||||
|
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
|
||||||
|
else:
|
||||||
|
contexts_data_for_llm_query = contexts
|
||||||
|
|
||||||
|
answer = self.llm.chat(
|
||||||
|
input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
|
||||||
|
)
|
||||||
|
|
||||||
# Send anonymous telemetry
|
# Send anonymous telemetry
|
||||||
self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
|
self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
|
||||||
|
|
||||||
return answer
|
if citations:
|
||||||
|
return answer, contexts
|
||||||
|
else:
|
||||||
|
return answer
|
||||||
|
|
||||||
def set_collection_name(self, name: str):
|
def set_collection_name(self, name: str):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -234,6 +234,7 @@ class Pipeline(EmbedChain):
|
|||||||
n_results=num_documents,
|
n_results=num_documents,
|
||||||
where=where,
|
where=where,
|
||||||
skip_embedding=False,
|
skip_embedding=False,
|
||||||
|
citations=True,
|
||||||
)
|
)
|
||||||
result = []
|
result = []
|
||||||
for c in context:
|
for c in context:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from chromadb import Collection, QueryResult
|
from chromadb import Collection, QueryResult
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
@@ -192,8 +192,13 @@ class ChromaDB(BaseVectorDB):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
self,
|
||||||
) -> List[Tuple[str, str, str]]:
|
input_query: List[str],
|
||||||
|
n_results: int,
|
||||||
|
where: Dict[str, any],
|
||||||
|
skip_embedding: bool,
|
||||||
|
citations: bool = False,
|
||||||
|
) -> Union[List[Tuple[str, str, str]], List[str]]:
|
||||||
"""
|
"""
|
||||||
Query contents from vector database based on vector similarity
|
Query contents from vector database based on vector similarity
|
||||||
|
|
||||||
@@ -205,9 +210,12 @@ class ChromaDB(BaseVectorDB):
|
|||||||
:type where: Dict[str, Any]
|
:type where: Dict[str, Any]
|
||||||
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
|
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
|
||||||
:type skip_embedding: bool
|
:type skip_embedding: bool
|
||||||
|
:param citations: If True, return each matched context as a tuple together with its source url and doc_id.
|
||||||
|
:type citations: bool, default is False.
|
||||||
:raises InvalidDimensionException: Dimensions do not match.
|
:raises InvalidDimensionException: Dimensions do not match.
|
||||||
:return: The content of the document that matched your query, url of the source, doc_id
|
:return: The content of the document that matched your query,
|
||||||
:rtype: List[Tuple[str,str,str]]
|
along with url of the source and doc_id (if citations flag is true)
|
||||||
|
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if skip_embedding:
|
if skip_embedding:
|
||||||
@@ -236,10 +244,13 @@ class ChromaDB(BaseVectorDB):
|
|||||||
contexts = []
|
contexts = []
|
||||||
for result in results_formatted:
|
for result in results_formatted:
|
||||||
context = result[0].page_content
|
context = result[0].page_content
|
||||||
metadata = result[0].metadata
|
if citations:
|
||||||
source = metadata["url"]
|
metadata = result[0].metadata
|
||||||
doc_id = metadata["doc_id"]
|
source = metadata["url"]
|
||||||
contexts.append((context, source, doc_id))
|
doc_id = metadata["doc_id"]
|
||||||
|
contexts.append((context, source, doc_id))
|
||||||
|
else:
|
||||||
|
contexts.append(context)
|
||||||
return contexts
|
return contexts
|
||||||
|
|
||||||
def set_collection_name(self, name: str):
|
def set_collection_name(self, name: str):
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from elasticsearch import Elasticsearch
|
from elasticsearch import Elasticsearch
|
||||||
@@ -136,8 +136,13 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
self.client.indices.refresh(index=self._get_index())
|
self.client.indices.refresh(index=self._get_index())
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
self,
|
||||||
) -> List[Tuple[str, str, str]]:
|
input_query: List[str],
|
||||||
|
n_results: int,
|
||||||
|
where: Dict[str, any],
|
||||||
|
skip_embedding: bool,
|
||||||
|
citations: bool = False,
|
||||||
|
) -> Union[List[Tuple[str, str, str]], List[str]]:
|
||||||
"""
|
"""
|
||||||
query contents from vector data base based on vector similarity
|
query contents from vector data base based on vector similarity
|
||||||
|
|
||||||
@@ -150,8 +155,11 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
|
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
|
||||||
:type skip_embedding: bool
|
:type skip_embedding: bool
|
||||||
:return: The context of the document that matched your query, url of the source, doc_id
|
:return: The context of the document that matched your query, url of the source, doc_id
|
||||||
|
:param citations: If True, return each matched context as a tuple together with its source url and doc_id.
|
||||||
:rtype: List[Tuple[str,str,str]]
|
:type citations: bool, default is False.
|
||||||
|
:return: The content of the document that matched your query,
|
||||||
|
along with url of the source and doc_id (if citations flag is true)
|
||||||
|
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
|
||||||
"""
|
"""
|
||||||
if skip_embedding:
|
if skip_embedding:
|
||||||
query_vector = input_query
|
query_vector = input_query
|
||||||
@@ -175,14 +183,17 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
_source = ["text", "metadata.url", "metadata.doc_id"]
|
_source = ["text", "metadata.url", "metadata.doc_id"]
|
||||||
response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results)
|
response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results)
|
||||||
docs = response["hits"]["hits"]
|
docs = response["hits"]["hits"]
|
||||||
contents = []
|
contexts = []
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
context = doc["_source"]["text"]
|
context = doc["_source"]["text"]
|
||||||
metadata = doc["_source"]["metadata"]
|
if citations:
|
||||||
source = metadata["url"]
|
metadata = doc["_source"]["metadata"]
|
||||||
doc_id = metadata["doc_id"]
|
source = metadata["url"]
|
||||||
contents.append(tuple((context, source, doc_id)))
|
doc_id = metadata["doc_id"]
|
||||||
return contents
|
contexts.append(tuple((context, source, doc_id)))
|
||||||
|
else:
|
||||||
|
contexts.append(context)
|
||||||
|
return contexts
|
||||||
|
|
||||||
def set_collection_name(self, name: str):
|
def set_collection_name(self, name: str):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional, Set, Tuple
|
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from opensearchpy import OpenSearch
|
from opensearchpy import OpenSearch
|
||||||
@@ -146,8 +146,13 @@ class OpenSearchDB(BaseVectorDB):
|
|||||||
self.client.indices.refresh(index=self._get_index())
|
self.client.indices.refresh(index=self._get_index())
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
self,
|
||||||
) -> List[Tuple[str, str, str]]:
|
input_query: List[str],
|
||||||
|
n_results: int,
|
||||||
|
where: Dict[str, any],
|
||||||
|
skip_embedding: bool,
|
||||||
|
citations: bool = False,
|
||||||
|
) -> Union[List[Tuple[str, str, str]], List[str]]:
|
||||||
"""
|
"""
|
||||||
query contents from vector data base based on vector similarity
|
query contents from vector data base based on vector similarity
|
||||||
|
|
||||||
@@ -159,8 +164,11 @@ class OpenSearchDB(BaseVectorDB):
|
|||||||
:type where: Dict[str, any]
|
:type where: Dict[str, any]
|
||||||
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
|
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
|
||||||
:type skip_embedding: bool
|
:type skip_embedding: bool
|
||||||
:return: The content of the document that matched your query, url of the source, doc_id
|
:param citations: If True, return each matched context as a tuple together with its source url and doc_id.
|
||||||
:rtype: List[Tuple[str,str,str]]
|
:type citations: bool, default is False.
|
||||||
|
:return: The content of the document that matched your query,
|
||||||
|
along with url of the source and doc_id (if citations flag is true)
|
||||||
|
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
|
||||||
"""
|
"""
|
||||||
# TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists
|
# TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists
|
||||||
embeddings = OpenAIEmbeddings()
|
embeddings = OpenAIEmbeddings()
|
||||||
@@ -188,13 +196,16 @@ class OpenSearchDB(BaseVectorDB):
|
|||||||
k=n_results,
|
k=n_results,
|
||||||
)
|
)
|
||||||
|
|
||||||
contents = []
|
contexts = []
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
context = doc.page_content
|
context = doc.page_content
|
||||||
source = doc.metadata["url"]
|
if citations:
|
||||||
doc_id = doc.metadata["doc_id"]
|
source = doc.metadata["url"]
|
||||||
contents.append(tuple((context, source, doc_id)))
|
doc_id = doc.metadata["doc_id"]
|
||||||
return contents
|
contexts.append(tuple((context, source, doc_id)))
|
||||||
|
else:
|
||||||
|
contexts.append(context)
|
||||||
|
return contexts
|
||||||
|
|
||||||
def set_collection_name(self, name: str):
|
def set_collection_name(self, name: str):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pinecone
|
import pinecone
|
||||||
@@ -119,8 +119,13 @@ class PineconeDB(BaseVectorDB):
|
|||||||
self.client.upsert(docs[i : i + self.BATCH_SIZE])
|
self.client.upsert(docs[i : i + self.BATCH_SIZE])
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
self,
|
||||||
) -> List[Tuple[str, str, str]]:
|
input_query: List[str],
|
||||||
|
n_results: int,
|
||||||
|
where: Dict[str, any],
|
||||||
|
skip_embedding: bool,
|
||||||
|
citations: bool = False,
|
||||||
|
) -> Union[List[Tuple[str, str, str]], List[str]]:
|
||||||
"""
|
"""
|
||||||
query contents from vector database based on vector similarity
|
query contents from vector database based on vector similarity
|
||||||
:param input_query: list of query string
|
:param input_query: list of query string
|
||||||
@@ -131,22 +136,28 @@ class PineconeDB(BaseVectorDB):
|
|||||||
:type where: Dict[str, any]
|
:type where: Dict[str, any]
|
||||||
:param skip_embedding: Optional. if True, input_query is already embedded
|
:param skip_embedding: Optional. if True, input_query is already embedded
|
||||||
:type skip_embedding: bool
|
:type skip_embedding: bool
|
||||||
:return: The content of the document that matched your query, url of the source, doc_id
|
:param citations: If True, return each matched context as a tuple together with its source url and doc_id.
|
||||||
:rtype: List[Tuple[str,str,str]]
|
:type citations: bool, default is False.
|
||||||
|
:return: The content of the document that matched your query,
|
||||||
|
along with url of the source and doc_id (if citations flag is true)
|
||||||
|
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
|
||||||
"""
|
"""
|
||||||
if not skip_embedding:
|
if not skip_embedding:
|
||||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||||
else:
|
else:
|
||||||
query_vector = input_query
|
query_vector = input_query
|
||||||
data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
|
data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
|
||||||
contents = []
|
contexts = []
|
||||||
for doc in data["matches"]:
|
for doc in data["matches"]:
|
||||||
metadata = doc["metadata"]
|
metadata = doc["metadata"]
|
||||||
context = metadata["text"]
|
context = metadata["text"]
|
||||||
source = metadata["url"]
|
if citations:
|
||||||
doc_id = metadata["doc_id"]
|
source = metadata["url"]
|
||||||
contents.append(tuple((context, source, doc_id)))
|
doc_id = metadata["doc_id"]
|
||||||
return contents
|
contexts.append(tuple((context, source, doc_id)))
|
||||||
|
else:
|
||||||
|
contexts.append(context)
|
||||||
|
return contexts
|
||||||
|
|
||||||
def set_collection_name(self, name: str):
|
def set_collection_name(self, name: str):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
@@ -161,8 +161,13 @@ class QdrantDB(BaseVectorDB):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
self,
|
||||||
) -> List[Tuple[str, str, str]]:
|
input_query: List[str],
|
||||||
|
n_results: int,
|
||||||
|
where: Dict[str, any],
|
||||||
|
skip_embedding: bool,
|
||||||
|
citations: bool = False,
|
||||||
|
) -> Union[List[Tuple[str, str, str]], List[str]]:
|
||||||
"""
|
"""
|
||||||
query contents from vector database based on vector similarity
|
query contents from vector database based on vector similarity
|
||||||
:param input_query: list of query string
|
:param input_query: list of query string
|
||||||
@@ -174,8 +179,11 @@ class QdrantDB(BaseVectorDB):
|
|||||||
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
|
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
|
||||||
generated or not
|
generated or not
|
||||||
:type skip_embedding: bool
|
:type skip_embedding: bool
|
||||||
:return: The context of the document that matched your query, url of the source, doc_id
|
:param citations: If True, return each matched context as a tuple together with its source url and doc_id.
|
||||||
:rtype: List[Tuple[str,str,str]]
|
:type citations: bool, default is False.
|
||||||
|
:return: The content of the document that matched your query,
|
||||||
|
along with url of the source and doc_id (if citations flag is true)
|
||||||
|
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
|
||||||
"""
|
"""
|
||||||
if not skip_embedding:
|
if not skip_embedding:
|
||||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||||
@@ -202,14 +210,17 @@ class QdrantDB(BaseVectorDB):
|
|||||||
limit=n_results,
|
limit=n_results,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = []
|
contexts = []
|
||||||
for result in results:
|
for result in results:
|
||||||
context = result.payload["text"]
|
context = result.payload["text"]
|
||||||
metadata = result.payload["metadata"]
|
if citations:
|
||||||
source = metadata["url"]
|
metadata = result.payload["metadata"]
|
||||||
doc_id = metadata["doc_id"]
|
source = metadata["url"]
|
||||||
response.append(tuple((context, source, doc_id)))
|
doc_id = metadata["doc_id"]
|
||||||
return response
|
contexts.append(tuple((context, source, doc_id)))
|
||||||
|
else:
|
||||||
|
contexts.append(context)
|
||||||
|
return contexts
|
||||||
|
|
||||||
def count(self) -> int:
|
def count(self) -> int:
|
||||||
response = self.client.get_collection(collection_name=self.collection_name)
|
response = self.client.get_collection(collection_name=self.collection_name)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import weaviate
|
import weaviate
|
||||||
@@ -58,10 +58,14 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
|
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
|
||||||
|
|
||||||
self.index_name = self._get_index_name()
|
self.index_name = self._get_index_name()
|
||||||
self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
|
self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
|
||||||
if not self.client.schema.exists(self.index_name):
|
if not self.client.schema.exists(self.index_name):
|
||||||
# id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
|
# id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
|
||||||
# The none vectorizer is crucial as we have our own custom embedding function
|
# The none vectorizer is crucial as we have our own custom embedding function
|
||||||
|
"""
|
||||||
|
TODO: wait for weaviate to add indexing on `object[]` data-type so that we can add filter while querying.
|
||||||
|
Once that is done, change `dataType` of "metadata" field to `object[]` and update the query below.
|
||||||
|
"""
|
||||||
class_obj = {
|
class_obj = {
|
||||||
"classes": [
|
"classes": [
|
||||||
{
|
{
|
||||||
@@ -106,10 +110,6 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
"name": "app_id",
|
"name": "app_id",
|
||||||
"dataType": ["text"],
|
"dataType": ["text"],
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"name": "text",
|
|
||||||
"dataType": ["text"],
|
|
||||||
},
|
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
@@ -195,8 +195,13 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
|
batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
self,
|
||||||
) -> List[Tuple[str, str, str]]:
|
input_query: List[str],
|
||||||
|
n_results: int,
|
||||||
|
where: Dict[str, any],
|
||||||
|
skip_embedding: bool,
|
||||||
|
citations: bool = False,
|
||||||
|
) -> Union[List[Tuple[str, str, str]], List[str]]:
|
||||||
"""
|
"""
|
||||||
query contents from vector database based on vector similarity
|
query contents from vector database based on vector similarity
|
||||||
:param input_query: list of query string
|
:param input_query: list of query string
|
||||||
@@ -208,15 +213,23 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
|
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
|
||||||
generated or not
|
generated or not
|
||||||
:type skip_embedding: bool
|
:type skip_embedding: bool
|
||||||
:return: The context of the document that matched your query, url of the source, doc_id
|
:param citations: If True, return each matched context as a tuple together with its source url and doc_id.
|
||||||
:rtype: List[Tuple[str,str,str]]
|
:type citations: bool, default is False.
|
||||||
|
:return: The content of the document that matched your query,
|
||||||
|
along with url of the source and doc_id (if citations flag is true)
|
||||||
|
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
|
||||||
"""
|
"""
|
||||||
if not skip_embedding:
|
if not skip_embedding:
|
||||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||||
else:
|
else:
|
||||||
query_vector = input_query
|
query_vector = input_query
|
||||||
|
|
||||||
keys = set(where.keys() if where is not None else set())
|
keys = set(where.keys() if where is not None else set())
|
||||||
data_fields = ["text"]
|
data_fields = ["text"]
|
||||||
|
|
||||||
|
if citations:
|
||||||
|
data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(self.metadata_keys)))
|
||||||
|
|
||||||
if len(keys.intersection(self.metadata_keys)) != 0:
|
if len(keys.intersection(self.metadata_keys)) != 0:
|
||||||
weaviate_where_operands = []
|
weaviate_where_operands = []
|
||||||
for key in keys:
|
for key in keys:
|
||||||
@@ -247,7 +260,18 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
.with_limit(n_results)
|
.with_limit(n_results)
|
||||||
.do()
|
.do()
|
||||||
)
|
)
|
||||||
contexts = results["data"]["Get"].get(self.index_name)
|
|
||||||
|
docs = results["data"]["Get"].get(self.index_name)
|
||||||
|
contexts = []
|
||||||
|
for doc in docs:
|
||||||
|
context = doc["text"]
|
||||||
|
if citations:
|
||||||
|
metadata = doc["metadata"][0]
|
||||||
|
source = metadata["url"]
|
||||||
|
doc_id = metadata["doc_id"]
|
||||||
|
contexts.append((context, source, doc_id))
|
||||||
|
else:
|
||||||
|
contexts.append(context)
|
||||||
return contexts
|
return contexts
|
||||||
|
|
||||||
def set_collection_name(self, name: str):
|
def set_collection_name(self, name: str):
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from embedchain.config import ZillizDBConfig
|
from embedchain.config import ZillizDBConfig
|
||||||
from embedchain.helper.json_serializable import register_deserializable
|
from embedchain.helper.json_serializable import register_deserializable
|
||||||
@@ -127,8 +127,13 @@ class ZillizVectorDB(BaseVectorDB):
|
|||||||
self.client.flush(self.config.collection_name)
|
self.client.flush(self.config.collection_name)
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
self,
|
||||||
) -> List[Tuple[str, str, str]]:
|
input_query: List[str],
|
||||||
|
n_results: int,
|
||||||
|
where: Dict[str, any],
|
||||||
|
skip_embedding: bool,
|
||||||
|
citations: bool = False,
|
||||||
|
) -> Union[List[Tuple[str, str, str]], List[str]]:
|
||||||
"""
|
"""
|
||||||
Query contents from vector data base based on vector similarity
|
Query contents from vector data base based on vector similarity
|
||||||
|
|
||||||
@@ -139,8 +144,11 @@ class ZillizVectorDB(BaseVectorDB):
|
|||||||
:param where: to filter data
|
:param where: to filter data
|
||||||
:type where: str
|
:type where: str
|
||||||
:raises InvalidDimensionException: Dimensions do not match.
|
:raises InvalidDimensionException: Dimensions do not match.
|
||||||
:return: The context of the document that matched your query, url of the source, doc_id
|
:param citations: If True, return each matched context as a tuple together with its source url and doc_id.
|
||||||
:rtype: List[Tuple[str,str,str]]
|
:type citations: bool, default is False.
|
||||||
|
:return: The content of the document that matched your query,
|
||||||
|
along with url of the source and doc_id (if citations flag is true)
|
||||||
|
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.collection.is_empty:
|
if self.collection.is_empty:
|
||||||
@@ -170,14 +178,17 @@ class ZillizVectorDB(BaseVectorDB):
|
|||||||
output_fields=output_fields,
|
output_fields=output_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
doc_list = []
|
contexts = []
|
||||||
for query in query_result:
|
for query in query_result:
|
||||||
data = query[0]["entity"]
|
data = query[0]["entity"]
|
||||||
context = data["text"]
|
context = data["text"]
|
||||||
source = data["url"]
|
if citations:
|
||||||
doc_id = data["doc_id"]
|
source = data["url"]
|
||||||
doc_list.append(tuple((context, source, doc_id)))
|
doc_id = data["doc_id"]
|
||||||
return doc_list
|
contexts.append(tuple((context, source, doc_id)))
|
||||||
|
else:
|
||||||
|
contexts.append(context)
|
||||||
|
return contexts
|
||||||
|
|
||||||
def count(self) -> int:
|
def count(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "embedchain"
|
name = "embedchain"
|
||||||
version = "0.0.88"
|
version = "0.0.89"
|
||||||
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
|
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
|
||||||
authors = [
|
authors = [
|
||||||
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
||||||
|
|||||||
@@ -163,10 +163,12 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
|
|||||||
|
|
||||||
assert data == expected_value
|
assert data == expected_value
|
||||||
|
|
||||||
data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
|
data_without_citations = app_with_settings.db.query(
|
||||||
expected_value = [("document", "url_1", "doc_id_1")]
|
input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True
|
||||||
|
)
|
||||||
|
expected_value_without_citations = ["document"]
|
||||||
|
assert data_without_citations == expected_value_without_citations
|
||||||
|
|
||||||
assert data == expected_value
|
|
||||||
app_with_settings.db.reset()
|
app_with_settings.db.reset()
|
||||||
|
|
||||||
|
|
||||||
@@ -326,8 +328,16 @@ def test_chroma_db_collection_query(app_with_settings):
|
|||||||
|
|
||||||
assert app_with_settings.db.count() == 2
|
assert app_with_settings.db.count() == 2
|
||||||
|
|
||||||
data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True)
|
data_without_citations = app_with_settings.db.query(
|
||||||
expected_value = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")]
|
input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True
|
||||||
|
)
|
||||||
|
expected_value_without_citations = ["document", "document2"]
|
||||||
|
assert data_without_citations == expected_value_without_citations
|
||||||
|
|
||||||
|
data_with_citations = app_with_settings.db.query(
|
||||||
|
input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True
|
||||||
|
)
|
||||||
|
expected_value_with_citations = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")]
|
||||||
|
assert data_with_citations == expected_value_with_citations
|
||||||
|
|
||||||
assert data == expected_value
|
|
||||||
app_with_settings.db.reset()
|
app_with_settings.db.reset()
|
||||||
|
|||||||
@@ -60,12 +60,16 @@ class TestEsDB(unittest.TestCase):
|
|||||||
|
|
||||||
# Query the database for the documents that are most similar to the query "This is a document".
|
# Query the database for the documents that are most similar to the query "This is a document".
|
||||||
query = ["This is a document"]
|
query = ["This is a document"]
|
||||||
results = self.db.query(query, n_results=2, where={}, skip_embedding=False)
|
results_without_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False)
|
||||||
|
expected_results_without_citations = ["This is a document.", "This is another document."]
|
||||||
|
self.assertEqual(results_without_citations, expected_results_without_citations)
|
||||||
|
|
||||||
# Assert that the results are correct.
|
results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True)
|
||||||
self.assertEqual(
|
expected_results_with_citations = [
|
||||||
results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
|
("This is a document.", "url_1", "doc_id_1"),
|
||||||
)
|
("This is another document.", "url_2", "doc_id_2"),
|
||||||
|
]
|
||||||
|
self.assertEqual(results_with_citations, expected_results_with_citations)
|
||||||
|
|
||||||
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
|
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
|
||||||
def test_query_with_skip_embedding(self, mock_client):
|
def test_query_with_skip_embedding(self, mock_client):
|
||||||
@@ -111,9 +115,7 @@ class TestEsDB(unittest.TestCase):
|
|||||||
results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
|
results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
|
||||||
|
|
||||||
# Assert that the results are correct.
|
# Assert that the results are correct.
|
||||||
self.assertEqual(
|
self.assertEqual(results, ["This is a document.", "This is another document."])
|
||||||
results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_init_without_url(self):
|
def test_init_without_url(self):
|
||||||
# Make sure it's not loaded from env
|
# Make sure it's not loaded from env
|
||||||
|
|||||||
@@ -75,10 +75,6 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
"name": "app_id",
|
"name": "app_id",
|
||||||
"dataType": ["text"],
|
"dataType": ["text"],
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"name": "text",
|
|
||||||
"dataType": ["text"],
|
|
||||||
},
|
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ class TestZillizDBCollection:
|
|||||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
|
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
|
||||||
|
|
||||||
# Assert that MilvusClient.search was called with the correct parameters
|
# Assert that MilvusClient.search was called with the correct parameters
|
||||||
mock_search.assert_called_once_with(
|
mock_search.assert_called_with(
|
||||||
collection_name=mock_config.collection_name,
|
collection_name=mock_config.collection_name,
|
||||||
data=["query_text"],
|
data=["query_text"],
|
||||||
limit=1,
|
limit=1,
|
||||||
@@ -137,7 +137,20 @@ class TestZillizDBCollection:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Assert that the query result matches the expected result
|
# Assert that the query result matches the expected result
|
||||||
assert query_result == [("result_doc", "url_1", "doc_id_1")]
|
assert query_result == ["result_doc"]
|
||||||
|
|
||||||
|
query_result_with_citations = zilliz_db.query(
|
||||||
|
input_query=["query_text"], n_results=1, where={}, skip_embedding=True, citations=True
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_search.assert_called_with(
|
||||||
|
collection_name=mock_config.collection_name,
|
||||||
|
data=["query_text"],
|
||||||
|
limit=1,
|
||||||
|
output_fields=["text", "url", "doc_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert query_result_with_citations == [("result_doc", "url_1", "doc_id_1")]
|
||||||
|
|
||||||
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
|
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
|
||||||
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
|
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
|
||||||
@@ -168,7 +181,7 @@ class TestZillizDBCollection:
|
|||||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
|
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
|
||||||
|
|
||||||
# Assert that MilvusClient.search was called with the correct parameters
|
# Assert that MilvusClient.search was called with the correct parameters
|
||||||
mock_search.assert_called_once_with(
|
mock_search.assert_called_with(
|
||||||
collection_name=mock_config.collection_name,
|
collection_name=mock_config.collection_name,
|
||||||
data=["query_vector"],
|
data=["query_vector"],
|
||||||
limit=1,
|
limit=1,
|
||||||
@@ -176,4 +189,17 @@ class TestZillizDBCollection:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Assert that the query result matches the expected result
|
# Assert that the query result matches the expected result
|
||||||
assert query_result == [("result_doc", "url_1", "doc_id_1")]
|
assert query_result == ["result_doc"]
|
||||||
|
|
||||||
|
query_result_with_citations = zilliz_db.query(
|
||||||
|
input_query=["query_text"], n_results=1, where={}, skip_embedding=False, citations=True
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_search.assert_called_with(
|
||||||
|
collection_name=mock_config.collection_name,
|
||||||
|
data=["query_vector"],
|
||||||
|
limit=1,
|
||||||
|
output_fields=["text", "url", "doc_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert query_result_with_citations == [("result_doc", "url_1", "doc_id_1")]
|
||||||
|
|||||||
Reference in New Issue
Block a user