@@ -4,7 +4,7 @@ import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
@@ -364,7 +364,7 @@ class App(EmbedChain):
|
||||
def from_config(
|
||||
cls,
|
||||
config_path: Optional[str] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[dict[str, Any]] = None,
|
||||
auto_deploy: bool = False,
|
||||
yaml_path: Optional[str] = None,
|
||||
):
|
||||
@@ -374,7 +374,7 @@ class App(EmbedChain):
|
||||
:param config_path: Path to the YAML or JSON configuration file.
|
||||
:type config_path: Optional[str]
|
||||
:param config: A dictionary containing the configuration.
|
||||
:type config: Optional[Dict[str, Any]]
|
||||
:type config: Optional[dict[str, Any]]
|
||||
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
|
||||
:type auto_deploy: bool, optional
|
||||
:param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
@@ -53,7 +53,7 @@ class PoeBot(BaseBot, PoeBot):
|
||||
answer = self.handle_message(last_message, history)
|
||||
yield self.text_event(answer)
|
||||
|
||||
def handle_message(self, message, history: Optional[List[str]] = None):
|
||||
def handle_message(self, message, history: Optional[list[str]] = None):
|
||||
if message.startswith("/add "):
|
||||
response = self.add_data(message)
|
||||
else:
|
||||
@@ -70,7 +70,7 @@ class PoeBot(BaseBot, PoeBot):
|
||||
# response = "Some error occurred while adding data."
|
||||
# return response
|
||||
|
||||
def ask_bot(self, message, history: List[str]):
|
||||
def ask_bot(self, message, history: list[str]):
|
||||
try:
|
||||
self.app.llm.set_history(history=history)
|
||||
response = self.query(message)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import os # noqa: F401
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from gptcache import cache # noqa: F401
|
||||
from gptcache.adapter.adapter import adapt # noqa: F401
|
||||
@@ -15,7 +15,7 @@ from gptcache.similarity_evaluation.exact_match import \
|
||||
ExactMatchEvaluation # noqa: F401
|
||||
|
||||
|
||||
def gptcache_pre_function(data: Dict[str, Any], **params: Dict[str, Any]):
|
||||
def gptcache_pre_function(data: dict[str, Any], **params: dict[str, Any]):
|
||||
return data["input_query"]
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import builtins
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from importlib import import_module
|
||||
from typing import Callable, Optional
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
|
||||
@@ -12,10 +12,10 @@ class BaseConfig(JSONSerializable):
|
||||
"""Initializes a configuration class for a class."""
|
||||
pass
|
||||
|
||||
def as_dict(self) -> Dict[str, Any]:
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return config object as a dict
|
||||
|
||||
:return: config object as dict
|
||||
:rtype: Dict[str, Any]
|
||||
:rtype: dict[str, Any]
|
||||
"""
|
||||
return vars(self)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
@@ -30,7 +30,7 @@ class CacheSimilarityEvalConfig(BaseConfig):
|
||||
self.positive = positive
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[Dict[str, Any]]):
|
||||
def from_config(config: Optional[dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheSimilarityEvalConfig()
|
||||
else:
|
||||
@@ -65,7 +65,7 @@ class CacheInitConfig(BaseConfig):
|
||||
self.auto_flush = auto_flush
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[Dict[str, Any]]):
|
||||
def from_config(config: Optional[dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheInitConfig()
|
||||
else:
|
||||
@@ -86,7 +86,7 @@ class CacheConfig(BaseConfig):
|
||||
self.init_config = init_config
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[Dict[str, Any]]):
|
||||
def from_config(config: Optional[dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheConfig()
|
||||
else:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import re
|
||||
from string import Template
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
@@ -68,12 +68,12 @@ class BaseLlmConfig(BaseConfig):
|
||||
stream: bool = False,
|
||||
deployment_name: Optional[str] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
where: Dict[str, Any] = None,
|
||||
where: dict[str, Any] = None,
|
||||
query_type: Optional[str] = None,
|
||||
callbacks: Optional[List] = None,
|
||||
callbacks: Optional[list] = None,
|
||||
api_key: Optional[str] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
model_kwargs: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for the LLM.
|
||||
@@ -106,7 +106,7 @@ class BaseLlmConfig(BaseConfig):
|
||||
:param system_prompt: System prompt string, defaults to None
|
||||
:type system_prompt: Optional[str], optional
|
||||
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
|
||||
:type where: Dict[str, Any], optional
|
||||
:type where: dict[str, Any], optional
|
||||
:param api_key: The api key of the custom endpoint, defaults to None
|
||||
:type api_key: Optional[str], optional
|
||||
:param endpoint: The api url of the custom endpoint, defaults to None
|
||||
@@ -114,7 +114,7 @@ class BaseLlmConfig(BaseConfig):
|
||||
:param model_kwargs: A dictionary of key-value pairs to pass to the model, defaults to None
|
||||
:type model_kwargs: Optional[Dict[str, Any]], optional
|
||||
:param callbacks: Langchain callback functions to use, defaults to None
|
||||
:type callbacks: Optional[List], optional
|
||||
:type callbacks: Optional[list], optional
|
||||
:param query_type: The type of query to use, defaults to None
|
||||
:type query_type: Optional[str], optional
|
||||
:raises ValueError: If the template is not valid as template should
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
@@ -11,9 +11,9 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
es_url: Union[str, List[str]] = None,
|
||||
es_url: Union[str, list[str]] = None,
|
||||
cloud_id: Optional[str] = None,
|
||||
**ES_EXTRA_PARAMS: Dict[str, any],
|
||||
**ES_EXTRA_PARAMS: dict[str, any],
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for an Elasticsearch client.
|
||||
@@ -23,13 +23,13 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
|
||||
:param dir: Path to the database directory, where the database is stored, defaults to None
|
||||
:type dir: Optional[str], optional
|
||||
:param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
|
||||
:type es_url: Union[str, List[str]], optional
|
||||
:type es_url: Union[str, list[str]], optional
|
||||
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
|
||||
:type ES_EXTRA_PARAMS: Dict[str, Any], optional
|
||||
:type ES_EXTRA_PARAMS: dict[str, Any], optional
|
||||
"""
|
||||
if es_url and cloud_id:
|
||||
raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
|
||||
# self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
|
||||
# self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]):
|
||||
self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
|
||||
self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
|
||||
if not self.ES_URL and not self.CLOUD_ID:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
@@ -9,11 +9,11 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
|
||||
def __init__(
|
||||
self,
|
||||
opensearch_url: str,
|
||||
http_auth: Tuple[str, str],
|
||||
http_auth: tuple[str, str],
|
||||
vector_dimension: int = 1536,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
**extra_params: Dict[str, any],
|
||||
**extra_params: dict[str, any],
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for an OpenSearch client.
|
||||
@@ -23,7 +23,7 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
|
||||
:param opensearch_url: URL of the OpenSearch domain
|
||||
:type opensearch_url: str, Eg, "http://localhost:9200"
|
||||
:param http_auth: Tuple of username and password
|
||||
:type http_auth: Tuple[str, str], Eg, ("username", "password")
|
||||
:type http_auth: tuple[str, str], Eg, ("username", "password")
|
||||
:param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model)
|
||||
:type vector_dimension: int, optional
|
||||
:param dir: Path to the database directory, where the database is stored, defaults to None
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
@@ -12,7 +12,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
|
||||
dir: Optional[str] = None,
|
||||
vector_dimension: int = 1536,
|
||||
metric: Optional[str] = "cosine",
|
||||
**extra_params: Dict[str, any],
|
||||
**extra_params: dict[str, any],
|
||||
):
|
||||
self.metric = metric
|
||||
self.vector_dimension = vector_dimension
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
@@ -15,10 +15,10 @@ class QdrantDBConfig(BaseVectorDbConfig):
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
hnsw_config: Optional[Dict[str, any]] = None,
|
||||
quantization_config: Optional[Dict[str, any]] = None,
|
||||
hnsw_config: Optional[dict[str, any]] = None,
|
||||
quantization_config: Optional[dict[str, any]] = None,
|
||||
on_disk: Optional[bool] = None,
|
||||
**extra_params: Dict[str, any],
|
||||
**extra_params: dict[str, any],
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for a qdrant client.
|
||||
@@ -28,9 +28,9 @@ class QdrantDBConfig(BaseVectorDbConfig):
|
||||
:param dir: Path to the database directory, where the database is stored, defaults to None
|
||||
:type dir: Optional[str], optional
|
||||
:param hnsw_config: Params for HNSW index
|
||||
:type hnsw_config: Optional[Dict[str, any]], defaults to None
|
||||
:type hnsw_config: Optional[dict[str, any]], defaults to None
|
||||
:param quantization_config: Params for quantization, if None - quantization will be disabled
|
||||
:type quantization_config: Optional[Dict[str, any]], defaults to None
|
||||
:type quantization_config: Optional[dict[str, any]], defaults to None
|
||||
:param on_disk: If true - point`s payload will not be stored in memory.
|
||||
It will be read from the disk every time it is requested.
|
||||
This setting saves RAM by (slightly) increasing the response time.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
@@ -10,7 +10,7 @@ class WeaviateDBConfig(BaseVectorDbConfig):
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
**extra_params: Dict[str, any],
|
||||
**extra_params: dict[str, any],
|
||||
):
|
||||
self.extra_params = extra_params
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
|
||||
@@ -2,7 +2,7 @@ import hashlib
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from langchain.docstore.document import Document
|
||||
@@ -136,12 +136,12 @@ class EmbedChain(JSONSerializable):
|
||||
self,
|
||||
source: Any,
|
||||
data_type: Optional[DataType] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
config: Optional[AddConfig] = None,
|
||||
dry_run=False,
|
||||
loader: Optional[BaseLoader] = None,
|
||||
chunker: Optional[BaseChunker] = None,
|
||||
**kwargs: Optional[Dict[str, Any]],
|
||||
**kwargs: Optional[dict[str, Any]],
|
||||
):
|
||||
"""
|
||||
Adds the data from the given URL to the vector db.
|
||||
@@ -154,7 +154,7 @@ class EmbedChain(JSONSerializable):
|
||||
defaults to None
|
||||
:type data_type: Optional[DataType], optional
|
||||
:param metadata: Metadata associated with the data source., defaults to None
|
||||
:type metadata: Optional[Dict[str, Any]], optional
|
||||
:type metadata: Optional[dict[str, Any]], optional
|
||||
:param config: The `AddConfig` instance to use as configuration options., defaults to None
|
||||
:type config: Optional[AddConfig], optional
|
||||
:raises ValueError: Invalid data type
|
||||
@@ -243,9 +243,9 @@ class EmbedChain(JSONSerializable):
|
||||
self,
|
||||
source: Any,
|
||||
data_type: Optional[DataType] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
config: Optional[AddConfig] = None,
|
||||
**kwargs: Optional[Dict[str, Any]],
|
||||
**kwargs: Optional[dict[str, Any]],
|
||||
):
|
||||
"""
|
||||
Adds the data from the given URL to the vector db.
|
||||
@@ -261,7 +261,7 @@ class EmbedChain(JSONSerializable):
|
||||
defaults to None
|
||||
:type data_type: Optional[DataType], optional
|
||||
:param metadata: Metadata associated with the data source., defaults to None
|
||||
:type metadata: Optional[Dict[str, Any]], optional
|
||||
:type metadata: Optional[dict[str, Any]], optional
|
||||
:param config: The `AddConfig` instance to use as configuration options., defaults to None
|
||||
:type config: Optional[AddConfig], optional
|
||||
:raises ValueError: Invalid data type
|
||||
@@ -342,11 +342,11 @@ class EmbedChain(JSONSerializable):
|
||||
loader: BaseLoader,
|
||||
chunker: BaseChunker,
|
||||
src: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
source_hash: Optional[str] = None,
|
||||
add_config: Optional[AddConfig] = None,
|
||||
dry_run=False,
|
||||
**kwargs: Optional[Dict[str, Any]],
|
||||
**kwargs: Optional[dict[str, Any]],
|
||||
):
|
||||
"""
|
||||
Loads the data from the given URL, chunks it, and adds it to database.
|
||||
@@ -359,7 +359,7 @@ class EmbedChain(JSONSerializable):
|
||||
:param source_hash: Hexadecimal hash of the source.
|
||||
:param dry_run: Optional. A dry run returns chunks and doesn't update DB.
|
||||
:type dry_run: bool, defaults to False
|
||||
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
|
||||
:return: (list) documents (embedded text), (list) metadata, (list) ids, (int) number of chunks
|
||||
"""
|
||||
existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
|
||||
app_id = self.config.id if self.config is not None else None
|
||||
@@ -464,8 +464,8 @@ class EmbedChain(JSONSerializable):
|
||||
config: Optional[BaseLlmConfig] = None,
|
||||
where=None,
|
||||
citations: bool = False,
|
||||
**kwargs: Optional[Dict[str, Any]],
|
||||
) -> Union[List[Tuple[str, str, str]], List[str]]:
|
||||
**kwargs: Optional[dict[str, Any]],
|
||||
) -> Union[list[tuple[str, str, str]], list[str]]:
|
||||
"""
|
||||
Queries the vector database based on the given input query.
|
||||
Gets relevant doc based on the query
|
||||
@@ -479,7 +479,7 @@ class EmbedChain(JSONSerializable):
|
||||
:param citations: A boolean to indicate if db should fetch citation source
|
||||
:type citations: bool
|
||||
:return: List of contents of the document that matched your query
|
||||
:rtype: List[str]
|
||||
:rtype: list[str]
|
||||
"""
|
||||
query_config = config or self.llm.config
|
||||
if where is not None:
|
||||
@@ -507,10 +507,10 @@ class EmbedChain(JSONSerializable):
|
||||
input_query: str,
|
||||
config: BaseLlmConfig = None,
|
||||
dry_run=False,
|
||||
where: Optional[Dict] = None,
|
||||
where: Optional[dict] = None,
|
||||
citations: bool = False,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
|
||||
**kwargs: dict[str, Any],
|
||||
) -> Union[tuple[str, list[tuple[str, dict]]], str]:
|
||||
"""
|
||||
Queries the vector database based on the given input query.
|
||||
Gets relevant doc based on the query and then passes it to an
|
||||
@@ -525,13 +525,13 @@ class EmbedChain(JSONSerializable):
|
||||
the LLM. The purpose is to test the prompt, not the response., defaults to False
|
||||
:type dry_run: bool, optional
|
||||
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
|
||||
:type where: Optional[Dict[str, str]], optional
|
||||
:type where: Optional[dict[str, str]], optional
|
||||
:param kwargs: To read more params for the query function. Ex. we use citations boolean
|
||||
param to return context along with the answer
|
||||
:type kwargs: Dict[str, Any]
|
||||
:type kwargs: dict[str, Any]
|
||||
:return: The answer to the query, with citations if the citation flag is True
|
||||
or the dry run result
|
||||
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
|
||||
:rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
|
||||
"""
|
||||
contexts = self._retrieve_from_database(
|
||||
input_query=input_query, config=config, where=where, citations=citations, **kwargs
|
||||
@@ -572,10 +572,10 @@ class EmbedChain(JSONSerializable):
|
||||
config: Optional[BaseLlmConfig] = None,
|
||||
dry_run=False,
|
||||
session_id: str = "default",
|
||||
where: Optional[Dict[str, str]] = None,
|
||||
where: Optional[dict[str, str]] = None,
|
||||
citations: bool = False,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
|
||||
**kwargs: dict[str, Any],
|
||||
) -> Union[tuple[str, list[tuple[str, dict]]], str]:
|
||||
"""
|
||||
Queries the vector database on the given input query.
|
||||
Gets relevant doc based on the query and then passes it to an
|
||||
@@ -594,13 +594,13 @@ class EmbedChain(JSONSerializable):
|
||||
:param session_id: The session id to use for chat history, defaults to 'default'.
|
||||
:type session_id: Optional[str], optional
|
||||
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
|
||||
:type where: Optional[Dict[str, str]], optional
|
||||
:type where: Optional[dict[str, str]], optional
|
||||
:param kwargs: To read more params for the query function. Ex. we use citations boolean
|
||||
param to return context along with the answer
|
||||
:type kwargs: Dict[str, Any]
|
||||
:type kwargs: dict[str, Any]
|
||||
:return: The answer to the query, with citations if the citation flag is True
|
||||
or the dry run result
|
||||
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
|
||||
:rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
|
||||
"""
|
||||
contexts = self._retrieve_from_database(
|
||||
input_query=input_query, config=config, where=where, citations=citations, **kwargs
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, Callable, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
from embedchain.config.embedder.base import BaseEmbedderConfig
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import queue
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Union
|
||||
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
@@ -29,7 +29,7 @@ class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
|
||||
super().__init__()
|
||||
self.q = q
|
||||
|
||||
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
|
||||
def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
with self.q.mutex:
|
||||
self.q.queue.clear()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from string import Template
|
||||
from typing import Any, Dict, Type, TypeVar, Union
|
||||
from typing import Any, Type, TypeVar, Union
|
||||
|
||||
T = TypeVar("T", bound="JSONSerializable")
|
||||
|
||||
@@ -84,7 +84,7 @@ class JSONSerializable:
|
||||
return cls()
|
||||
|
||||
@staticmethod
|
||||
def _auto_encoder(obj: Any) -> Union[Dict[str, Any], None]:
|
||||
def _auto_encoder(obj: Any) -> Union[dict[str, Any], None]:
|
||||
"""
|
||||
Automatically encode an object for JSON serialization.
|
||||
|
||||
@@ -126,7 +126,7 @@ class JSONSerializable:
|
||||
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
|
||||
|
||||
@classmethod
|
||||
def _auto_decoder(cls, dct: Dict[str, Any]) -> Any:
|
||||
def _auto_decoder(cls, dct: dict[str, Any]) -> Any:
|
||||
"""
|
||||
Automatically decode a dictionary to an object during JSON deserialization.
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.schema import BaseMessage as LCBaseMessage
|
||||
|
||||
@@ -55,7 +56,7 @@ class BaseLlm(JSONSerializable):
|
||||
app_id: str,
|
||||
question: str,
|
||||
answer: str,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
session_id: str = "default",
|
||||
):
|
||||
chat_message = ChatMessage()
|
||||
@@ -64,7 +65,7 @@ class BaseLlm(JSONSerializable):
|
||||
self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
|
||||
self.update_history(app_id=app_id, session_id=session_id)
|
||||
|
||||
def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
|
||||
def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
|
||||
"""
|
||||
Generates a prompt based on the given query and context, ready to be
|
||||
passed to an LLM
|
||||
@@ -72,7 +73,7 @@ class BaseLlm(JSONSerializable):
|
||||
:param input_query: The query to use.
|
||||
:type input_query: str
|
||||
:param contexts: List of similar documents to the query used as context.
|
||||
:type contexts: List[str]
|
||||
:type contexts: list[str]
|
||||
:return: The prompt
|
||||
:rtype: str
|
||||
"""
|
||||
@@ -170,7 +171,7 @@ class BaseLlm(JSONSerializable):
|
||||
yield chunk
|
||||
logging.info(f"Answer: {streamed_answer}")
|
||||
|
||||
def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
|
||||
def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
|
||||
"""
|
||||
Queries the vector database based on the given input query.
|
||||
Gets relevant doc based on the query and then passes it to an
|
||||
@@ -179,7 +180,7 @@ class BaseLlm(JSONSerializable):
|
||||
:param input_query: The query to use.
|
||||
:type input_query: str
|
||||
:param contexts: Embeddings retrieved from the database to be used as context.
|
||||
:type contexts: List[str]
|
||||
:type contexts: list[str]
|
||||
:param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
|
||||
To persistently use a config, declare it during app init., defaults to None
|
||||
:type config: Optional[BaseLlmConfig], optional
|
||||
@@ -223,7 +224,7 @@ class BaseLlm(JSONSerializable):
|
||||
self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
|
||||
|
||||
def chat(
|
||||
self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
|
||||
self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
|
||||
):
|
||||
"""
|
||||
Queries the vector database on the given input query.
|
||||
@@ -235,7 +236,7 @@ class BaseLlm(JSONSerializable):
|
||||
:param input_query: The query to use.
|
||||
:type input_query: str
|
||||
:param contexts: Embeddings retrieved from the database to be used as context.
|
||||
:type contexts: List[str]
|
||||
:type contexts: list[str]
|
||||
:param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
|
||||
To persistently use a config, declare it during app init., defaults to None
|
||||
:type config: Optional[BaseLlmConfig], optional
|
||||
@@ -281,7 +282,7 @@ class BaseLlm(JSONSerializable):
|
||||
self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
|
||||
|
||||
@staticmethod
|
||||
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
|
||||
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> list[LCBaseMessage]:
|
||||
"""
|
||||
Construct a list of langchain messages
|
||||
|
||||
@@ -290,7 +291,7 @@ class BaseLlm(JSONSerializable):
|
||||
:param system_prompt: System prompt, defaults to None
|
||||
:type system_prompt: Optional[str], optional
|
||||
:return: List of messages
|
||||
:rtype: List[BaseMessage]
|
||||
:rtype: list[BaseMessage]
|
||||
"""
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Generator, Optional, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import google.generativeai as genai
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Iterable, Optional, Union
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
@@ -12,7 +12,7 @@ from embedchain.llm.base import BaseLlm
|
||||
|
||||
@register_deserializable
|
||||
class OpenAILlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[Dict[str, Any]] = None):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[dict[str, Any]] = None):
|
||||
self.functions = functions
|
||||
super().__init__(config=config)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import hashlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from embedchain.config import AddConfig
|
||||
from embedchain.data_formatter.data_formatter import DataFormatter
|
||||
@@ -15,7 +15,7 @@ from embedchain.utils.misc import detect_datatype
|
||||
class DirectoryLoader(BaseLoader):
|
||||
"""Load data from a directory."""
|
||||
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
||||
def __init__(self, config: Optional[dict[str, Any]] = None):
|
||||
super().__init__()
|
||||
config = config or {}
|
||||
self.recursive = config.get("recursive", True)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
|
||||
@@ -10,7 +10,7 @@ from embedchain.utils.misc import clean_string
|
||||
|
||||
|
||||
class DiscourseLoader(BaseLoader):
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
||||
def __init__(self, config: Optional[dict[str, Any]] = None):
|
||||
super().__init__()
|
||||
if not config:
|
||||
raise ValueError(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import hashlib
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from dropbox.files import FileMetadata
|
||||
|
||||
@@ -29,7 +28,7 @@ class DropboxLoader(BaseLoader):
|
||||
except exceptions.AuthError as ex:
|
||||
raise ValueError("Invalid Dropbox access token. Please verify your token and try again.") from ex
|
||||
|
||||
def _download_folder(self, path: str, local_root: str) -> List[FileMetadata]:
|
||||
def _download_folder(self, path: str, local_root: str) -> list[FileMetadata]:
|
||||
"""Download a folder from Dropbox and save it preserving the directory structure."""
|
||||
entries = self.dbx.files_list_folder(path).entries
|
||||
for entry in entries:
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -20,7 +20,7 @@ VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"])
|
||||
class GithubLoader(BaseLoader):
|
||||
"""Load data from GitHub search query."""
|
||||
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
||||
def __init__(self, config: Optional[dict[str, Any]] = None):
|
||||
super().__init__()
|
||||
if not config:
|
||||
raise ValueError(
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
from email import message_from_bytes
|
||||
from email.utils import parsedate_to_datetime
|
||||
from textwrap import dedent
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
@@ -57,7 +57,7 @@ class GmailReader:
|
||||
token.write(creds.to_json())
|
||||
return creds
|
||||
|
||||
def load_emails(self) -> List[Dict]:
|
||||
def load_emails(self) -> list[dict]:
|
||||
response = self.service.users().messages().list(userId="me", q=self.query).execute()
|
||||
messages = response.get("messages", [])
|
||||
|
||||
@@ -67,7 +67,7 @@ class GmailReader:
|
||||
raw_message = self.service.users().messages().get(userId="me", id=message_id, format="raw").execute()
|
||||
return base64.urlsafe_b64decode(raw_message["raw"])
|
||||
|
||||
def _parse_email(self, raw_email) -> Dict:
|
||||
def _parse_email(self, raw_email) -> dict:
|
||||
mime_msg = message_from_bytes(raw_email)
|
||||
return {
|
||||
"subject": self._get_header(mime_msg, "Subject"),
|
||||
@@ -124,7 +124,7 @@ class GmailLoader(BaseLoader):
|
||||
return {"doc_id": self._generate_doc_id(query, data), "data": data}
|
||||
|
||||
@staticmethod
|
||||
def _process_email(email: Dict) -> str:
|
||||
def _process_email(email: dict) -> str:
|
||||
content = BeautifulSoup(email["body"], "html.parser").get_text()
|
||||
content = clean_string(content)
|
||||
return dedent(
|
||||
@@ -137,6 +137,6 @@ class GmailLoader(BaseLoader):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_doc_id(query: str, data: List[Dict]) -> str:
|
||||
def _generate_doc_id(query: str, data: list[dict]) -> str:
|
||||
content_strings = [email["content"] for email in data]
|
||||
return hashlib.sha256((query + ", ".join(content_strings)).encode()).hexdigest()
|
||||
|
||||
@@ -2,7 +2,7 @@ import hashlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Union
|
||||
from typing import Union
|
||||
|
||||
import requests
|
||||
|
||||
@@ -16,14 +16,14 @@ class JSONReader:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def load_data(json_data: Union[Dict, str]) -> List[str]:
|
||||
def load_data(json_data: Union[dict, str]) -> list[str]:
|
||||
"""Load data from a JSON structure.
|
||||
|
||||
Args:
|
||||
json_data (Union[Dict, str]): The JSON data to load.
|
||||
json_data (Union[dict, str]): The JSON data to load.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of strings representing the leaf nodes of the JSON.
|
||||
list[str]: A list of strings representing the leaf nodes of the JSON.
|
||||
"""
|
||||
if isinstance(json_data, str):
|
||||
json_data = json.loads(json_data)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils.misc import clean_string
|
||||
|
||||
|
||||
class MySQLLoader(BaseLoader):
|
||||
def __init__(self, config: Optional[Dict[str, Any]]):
|
||||
def __init__(self, config: Optional[dict[str, Any]]):
|
||||
super().__init__()
|
||||
if not config:
|
||||
raise ValueError(
|
||||
@@ -20,7 +20,7 @@ class MySQLLoader(BaseLoader):
|
||||
self.cursor = None
|
||||
self._setup_loader(config=config)
|
||||
|
||||
def _setup_loader(self, config: Dict[str, Any]):
|
||||
def _setup_loader(self, config: dict[str, Any]):
|
||||
try:
|
||||
import mysql.connector as sqlconnector
|
||||
except ImportError as e:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
|
||||
@@ -15,7 +15,7 @@ class NotionDocument:
|
||||
A simple Document class to hold the text and additional information of a page.
|
||||
"""
|
||||
|
||||
def __init__(self, text: str, extra_info: Dict[str, Any]):
|
||||
def __init__(self, text: str, extra_info: dict[str, Any]):
|
||||
self.text = text
|
||||
self.extra_info = extra_info
|
||||
|
||||
@@ -82,7 +82,7 @@ class NotionPageLoader:
|
||||
result_lines = "\n".join(result_lines_arr)
|
||||
return result_lines
|
||||
|
||||
def load_data(self, page_ids: List[str]) -> List[NotionDocument]:
|
||||
def load_data(self, page_ids: list[str]) -> list[NotionDocument]:
|
||||
"""Load data from the given list of page IDs."""
|
||||
docs = []
|
||||
for page_id in page_ids:
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
|
||||
class PostgresLoader(BaseLoader):
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
||||
def __init__(self, config: Optional[dict[str, Any]] = None):
|
||||
super().__init__()
|
||||
if not config:
|
||||
raise ValueError(f"Must provide the valid config. Received: {config}")
|
||||
@@ -15,7 +15,7 @@ class PostgresLoader(BaseLoader):
|
||||
self.cursor = None
|
||||
self._setup_loader(config=config)
|
||||
|
||||
def _setup_loader(self, config: Dict[str, Any]):
|
||||
def _setup_loader(self, config: dict[str, Any]):
|
||||
try:
|
||||
import psycopg
|
||||
except ImportError as e:
|
||||
|
||||
@@ -2,7 +2,7 @@ import hashlib
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import certifi
|
||||
|
||||
@@ -13,7 +13,7 @@ SLACK_API_BASE_URL = "https://www.slack.com/api/"
|
||||
|
||||
|
||||
class SlackLoader(BaseLoader):
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
||||
def __init__(self, config: Optional[dict[str, Any]] = None):
|
||||
super().__init__()
|
||||
|
||||
self.config = config if config else {}
|
||||
@@ -24,7 +24,7 @@ class SlackLoader(BaseLoader):
|
||||
self.client = None
|
||||
self._setup_loader(self.config)
|
||||
|
||||
def _setup_loader(self, config: Dict[str, Any]):
|
||||
def _setup_loader(self, config: dict[str, Any]):
|
||||
try:
|
||||
from slack_sdk import WebClient
|
||||
except ImportError as e:
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from embedchain.constants import SQLITE_PATH
|
||||
from embedchain.memory.message import ChatMessage
|
||||
@@ -67,7 +67,7 @@ class ChatHistory:
|
||||
self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, (app_id, session_id))
|
||||
self.connection.commit()
|
||||
|
||||
def get(self, app_id, session_id, num_rounds=10, display_format=False) -> List[ChatMessage]:
|
||||
def get(self, app_id, session_id, num_rounds=10, display_format=False) -> list[ChatMessage]:
|
||||
"""
|
||||
Get the most recent num_rounds rounds of conversations
|
||||
between human and AI, for a given app_id.
|
||||
@@ -114,7 +114,7 @@ class ChatHistory:
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
def _serialize_json(metadata: Dict[str, Any]):
|
||||
def _serialize_json(metadata: dict[str, Any]):
|
||||
return json.dumps(metadata)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
|
||||
@@ -18,9 +18,9 @@ class BaseMessage(JSONSerializable):
|
||||
created_by: str
|
||||
|
||||
# Any additional info.
|
||||
metadata: Dict[str, Any]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
def __init__(self, content: str, created_by: str, metadata: Optional[Dict[str, Any]] = None) -> None:
|
||||
def __init__(self, content: str, created_by: str, metadata: Optional[dict[str, Any]] = None) -> None:
|
||||
super().__init__()
|
||||
self.content = content
|
||||
self.created_by = created_by
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
def merge_metadata_dict(left: Optional[Dict[str, Any]], right: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
def merge_metadata_dict(left: Optional[dict[str, Any]], right: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
|
||||
"""
|
||||
Merge the metadatas of two BaseMessage types.
|
||||
|
||||
Args:
|
||||
left (Dict[str, Any]): metadata of human message
|
||||
right (Dict[str, Any]): metadata of AI message
|
||||
left (dict[str, Any]): metadata of human message
|
||||
right (dict[str, Any]): metadata of AI message
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: combined metadata dict with dedup
|
||||
dict[str, Any]: combined metadata dict with dedup
|
||||
to be saved in db.
|
||||
"""
|
||||
if not left and not right:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from chromadb import Collection, QueryResult
|
||||
from langchain.docstore.document import Document
|
||||
@@ -76,7 +76,7 @@ class ChromaDB(BaseVectorDB):
|
||||
return self.client
|
||||
|
||||
@staticmethod
|
||||
def _generate_where_clause(where: Dict[str, any]) -> Dict[str, any]:
|
||||
def _generate_where_clause(where: dict[str, any]) -> dict[str, any]:
|
||||
# If only one filter is supplied, return it as is
|
||||
# (no need to wrap in $and based on chroma docs)
|
||||
if len(where.keys()) <= 1:
|
||||
@@ -105,18 +105,18 @@ class ChromaDB(BaseVectorDB):
|
||||
)
|
||||
return self.collection
|
||||
|
||||
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
|
||||
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
|
||||
"""
|
||||
Get existing doc ids present in vector database
|
||||
|
||||
:param ids: list of doc ids to check for existence
|
||||
:type ids: List[str]
|
||||
:type ids: list[str]
|
||||
:param where: Optional. to filter data
|
||||
:type where: Dict[str, Any]
|
||||
:type where: dict[str, Any]
|
||||
:param limit: Optional. maximum number of documents
|
||||
:type limit: Optional[int]
|
||||
:return: Existing documents.
|
||||
:rtype: List[str]
|
||||
:rtype: list[str]
|
||||
"""
|
||||
args = {}
|
||||
if ids:
|
||||
@@ -129,23 +129,23 @@ class ChromaDB(BaseVectorDB):
|
||||
|
||||
def add(
|
||||
self,
|
||||
embeddings: List[List[float]],
|
||||
documents: List[str],
|
||||
metadatas: List[object],
|
||||
ids: List[str],
|
||||
**kwargs: Optional[Dict[str, Any]],
|
||||
embeddings: list[list[float]],
|
||||
documents: list[str],
|
||||
metadatas: list[object],
|
||||
ids: list[str],
|
||||
**kwargs: Optional[dict[str, Any]],
|
||||
) -> Any:
|
||||
"""
|
||||
Add vectors to chroma database
|
||||
|
||||
:param embeddings: list of embeddings to add
|
||||
:type embeddings: List[List[str]]
|
||||
:type embeddings: list[list[str]]
|
||||
:param documents: Documents
|
||||
:type documents: List[str]
|
||||
:type documents: list[str]
|
||||
:param metadatas: Metadatas
|
||||
:type metadatas: List[object]
|
||||
:type metadatas: list[object]
|
||||
:param ids: ids
|
||||
:type ids: List[str]
|
||||
:type ids: list[str]
|
||||
"""
|
||||
size = len(documents)
|
||||
if len(documents) != size or len(metadatas) != size or len(ids) != size:
|
||||
@@ -182,27 +182,27 @@ class ChromaDB(BaseVectorDB):
|
||||
|
||||
def query(
|
||||
self,
|
||||
input_query: List[str],
|
||||
input_query: list[str],
|
||||
n_results: int,
|
||||
where: Dict[str, any],
|
||||
where: dict[str, any],
|
||||
citations: bool = False,
|
||||
**kwargs: Optional[Dict[str, Any]],
|
||||
) -> Union[List[Tuple[str, Dict]], List[str]]:
|
||||
**kwargs: Optional[dict[str, Any]],
|
||||
) -> Union[list[tuple[str, dict]], list[str]]:
|
||||
"""
|
||||
Query contents from vector database based on vector similarity
|
||||
|
||||
:param input_query: list of query string
|
||||
:type input_query: List[str]
|
||||
:type input_query: list[str]
|
||||
:param n_results: no of similar documents to fetch from database
|
||||
:type n_results: int
|
||||
:param where: to filter data
|
||||
:type where: Dict[str, Any]
|
||||
:type where: dict[str, Any]
|
||||
:param citations: we use citations boolean param to return context along with the answer.
|
||||
:type citations: bool, default is False.
|
||||
:raises InvalidDimensionException: Dimensions do not match.
|
||||
:return: The content of the document that matched your query,
|
||||
along with url of the source and doc_id (if citations flag is true)
|
||||
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
|
||||
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
|
||||
"""
|
||||
try:
|
||||
result = self.collection.query(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
try:
|
||||
from elasticsearch import Elasticsearch
|
||||
@@ -84,14 +84,14 @@ class ElasticsearchDB(BaseVectorDB):
|
||||
def _get_or_create_collection(self, name):
|
||||
"""Note: nothing to return here. Discuss later"""
|
||||
|
||||
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
|
||||
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
|
||||
"""
|
||||
Get existing doc ids present in vector database
|
||||
|
||||
:param ids: _list of doc ids to check for existence
|
||||
:type ids: List[str]
|
||||
:type ids: list[str]
|
||||
:param where: to filter data
|
||||
:type where: Dict[str, any]
|
||||
:type where: dict[str, any]
|
||||
:return: ids
|
||||
:rtype: Set[str]
|
||||
"""
|
||||
@@ -110,22 +110,22 @@ class ElasticsearchDB(BaseVectorDB):
|
||||
|
||||
def add(
|
||||
self,
|
||||
embeddings: List[List[float]],
|
||||
documents: List[str],
|
||||
metadatas: List[object],
|
||||
ids: List[str],
|
||||
**kwargs: Optional[Dict[str, any]],
|
||||
embeddings: list[list[float]],
|
||||
documents: list[str],
|
||||
metadatas: list[object],
|
||||
ids: list[str],
|
||||
**kwargs: Optional[dict[str, any]],
|
||||
) -> Any:
|
||||
"""
|
||||
add data in vector database
|
||||
:param embeddings: list of embeddings to add
|
||||
:type embeddings: List[List[str]]
|
||||
:type embeddings: list[list[str]]
|
||||
:param documents: list of texts to add
|
||||
:type documents: List[str]
|
||||
:type documents: list[str]
|
||||
:param metadatas: list of metadata associated with docs
|
||||
:type metadatas: List[object]
|
||||
:type metadatas: list[object]
|
||||
:param ids: ids of docs
|
||||
:type ids: List[str]
|
||||
:type ids: list[str]
|
||||
"""
|
||||
|
||||
embeddings = self.embedder.embedding_fn(documents)
|
||||
@@ -154,27 +154,27 @@ class ElasticsearchDB(BaseVectorDB):
|
||||
|
||||
def query(
|
||||
self,
|
||||
input_query: List[str],
|
||||
input_query: list[str],
|
||||
n_results: int,
|
||||
where: Dict[str, any],
|
||||
where: dict[str, any],
|
||||
citations: bool = False,
|
||||
**kwargs: Optional[Dict[str, Any]],
|
||||
) -> Union[List[Tuple[str, Dict]], List[str]]:
|
||||
**kwargs: Optional[dict[str, Any]],
|
||||
) -> Union[list[tuple[str, dict]], list[str]]:
|
||||
"""
|
||||
query contents from vector database based on vector similarity
|
||||
|
||||
:param input_query: list of query string
|
||||
:type input_query: List[str]
|
||||
:type input_query: list[str]
|
||||
:param n_results: no of similar documents to fetch from database
|
||||
:type n_results: int
|
||||
:param where: Optional. to filter data
|
||||
:type where: Dict[str, any]
|
||||
:type where: dict[str, any]
|
||||
:return: The context of the document that matched your query, url of the source, doc_id
|
||||
:param citations: we use citations boolean param to return context along with the answer.
|
||||
:type citations: bool, default is False.
|
||||
:return: The content of the document that matched your query,
|
||||
along with url of the source and doc_id (if citations flag is true)
|
||||
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
|
||||
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
|
||||
"""
|
||||
input_query_vector = self.embedder.embedding_fn(input_query)
|
||||
query_vector = input_query_vector[0]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -78,17 +78,17 @@ class OpenSearchDB(BaseVectorDB):
|
||||
"""Note: nothing to return here. Discuss later"""
|
||||
|
||||
def get(
|
||||
self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None
|
||||
) -> Set[str]:
|
||||
self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None
|
||||
) -> set[str]:
|
||||
"""
|
||||
Get existing doc ids present in vector database
|
||||
|
||||
:param ids: _list of doc ids to check for existence
|
||||
:type ids: List[str]
|
||||
:type ids: list[str]
|
||||
:param where: to filter data
|
||||
:type where: Dict[str, any]
|
||||
:type where: dict[str, any]
|
||||
:return: ids
|
||||
:type: Set[str]
|
||||
:type: set[str]
|
||||
"""
|
||||
query = {}
|
||||
if ids:
|
||||
@@ -116,19 +116,19 @@ class OpenSearchDB(BaseVectorDB):
|
||||
|
||||
def add(
|
||||
self,
|
||||
embeddings: List[List[str]],
|
||||
documents: List[str],
|
||||
metadatas: List[object],
|
||||
ids: List[str],
|
||||
**kwargs: Optional[Dict[str, any]],
|
||||
embeddings: list[list[str]],
|
||||
documents: list[str],
|
||||
metadatas: list[object],
|
||||
ids: list[str],
|
||||
**kwargs: Optional[dict[str, any]],
|
||||
):
|
||||
"""Add data in vector database.
|
||||
|
||||
Args:
|
||||
embeddings (List[List[str]]): List of embeddings to add.
|
||||
documents (List[str]): List of texts to add.
|
||||
metadatas (List[object]): List of metadata associated with docs.
|
||||
ids (List[str]): IDs of docs.
|
||||
embeddings (list[list[str]]): list of embeddings to add.
|
||||
documents (list[str]): list of texts to add.
|
||||
metadatas (list[object]): list of metadata associated with docs.
|
||||
ids (list[str]): IDs of docs.
|
||||
"""
|
||||
for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"):
|
||||
batch_end = batch_start + self.BATCH_SIZE
|
||||
@@ -156,26 +156,26 @@ class OpenSearchDB(BaseVectorDB):
|
||||
|
||||
def query(
|
||||
self,
|
||||
input_query: List[str],
|
||||
input_query: list[str],
|
||||
n_results: int,
|
||||
where: Dict[str, any],
|
||||
where: dict[str, any],
|
||||
citations: bool = False,
|
||||
**kwargs: Optional[Dict[str, Any]],
|
||||
) -> Union[List[Tuple[str, Dict]], List[str]]:
|
||||
**kwargs: Optional[dict[str, Any]],
|
||||
) -> Union[list[tuple[str, dict]], list[str]]:
|
||||
"""
|
||||
query contents from vector database based on vector similarity
|
||||
|
||||
:param input_query: list of query string
|
||||
:type input_query: List[str]
|
||||
:type input_query: list[str]
|
||||
:param n_results: no of similar documents to fetch from database
|
||||
:type n_results: int
|
||||
:param where: Optional. to filter data
|
||||
:type where: Dict[str, any]
|
||||
:type where: dict[str, any]
|
||||
:param citations: we use citations boolean param to return context along with the answer.
|
||||
:type citations: bool, default is False.
|
||||
:return: The content of the document that matched your query,
|
||||
along with url of the source and doc_id (if citations flag is true)
|
||||
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
|
||||
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
|
||||
"""
|
||||
embeddings = OpenAIEmbeddings()
|
||||
docsearch = OpenSearchVectorSearch(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
try:
|
||||
import pinecone
|
||||
@@ -67,14 +67,14 @@ class PineconeDB(BaseVectorDB):
|
||||
)
|
||||
return pinecone.Index(self.index_name)
|
||||
|
||||
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
|
||||
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
|
||||
"""
|
||||
Get existing doc ids present in vector database
|
||||
|
||||
:param ids: _list of doc ids to check for existence
|
||||
:type ids: List[str]
|
||||
:type ids: list[str]
|
||||
:param where: to filter data
|
||||
:type where: Dict[str, any]
|
||||
:type where: dict[str, any]
|
||||
:return: ids
|
||||
:rtype: Set[str]
|
||||
"""
|
||||
@@ -88,20 +88,20 @@ class PineconeDB(BaseVectorDB):
|
||||
|
||||
def add(
|
||||
self,
|
||||
embeddings: List[List[float]],
|
||||
documents: List[str],
|
||||
metadatas: List[object],
|
||||
ids: List[str],
|
||||
**kwargs: Optional[Dict[str, any]],
|
||||
embeddings: list[list[float]],
|
||||
documents: list[str],
|
||||
metadatas: list[object],
|
||||
ids: list[str],
|
||||
**kwargs: Optional[dict[str, any]],
|
||||
):
|
||||
"""add data in vector database
|
||||
|
||||
:param documents: list of texts to add
|
||||
:type documents: List[str]
|
||||
:type documents: list[str]
|
||||
:param metadatas: list of metadata associated with docs
|
||||
:type metadatas: List[object]
|
||||
:type metadatas: list[object]
|
||||
:param ids: ids of docs
|
||||
:type ids: List[str]
|
||||
:type ids: list[str]
|
||||
"""
|
||||
docs = []
|
||||
print("Adding documents to Pinecone...")
|
||||
@@ -120,25 +120,25 @@ class PineconeDB(BaseVectorDB):
|
||||
|
||||
def query(
|
||||
self,
|
||||
input_query: List[str],
|
||||
input_query: list[str],
|
||||
n_results: int,
|
||||
where: Dict[str, any],
|
||||
where: dict[str, any],
|
||||
citations: bool = False,
|
||||
**kwargs: Optional[Dict[str, any]],
|
||||
) -> Union[List[Tuple[str, Dict]], List[str]]:
|
||||
**kwargs: Optional[dict[str, any]],
|
||||
) -> Union[list[tuple[str, dict]], list[str]]:
|
||||
"""
|
||||
query contents from vector database based on vector similarity
|
||||
:param input_query: list of query string
|
||||
:type input_query: List[str]
|
||||
:type input_query: list[str]
|
||||
:param n_results: no of similar documents to fetch from database
|
||||
:type n_results: int
|
||||
:param where: Optional. to filter data
|
||||
:type where: Dict[str, any]
|
||||
:type where: dict[str, any]
|
||||
:param citations: we use citations boolean param to return context along with the answer.
|
||||
:type citations: bool, default is False.
|
||||
:return: The content of the document that matched your query,
|
||||
along with url of the source and doc_id (if citations flag is true)
|
||||
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
|
||||
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
|
||||
"""
|
||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||
data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True, **kwargs)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import copy
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
try:
|
||||
from qdrant_client import QdrantClient
|
||||
@@ -69,14 +69,14 @@ class QdrantDB(BaseVectorDB):
|
||||
def _get_or_create_collection(self):
|
||||
return f"{self.config.collection_name}-{self.embedder.vector_dimension}".lower().replace("_", "-")
|
||||
|
||||
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
|
||||
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
|
||||
"""
|
||||
Get existing doc ids present in vector database
|
||||
|
||||
:param ids: _list of doc ids to check for existence
|
||||
:type ids: List[str]
|
||||
:type ids: list[str]
|
||||
:param where: to filter data
|
||||
:type where: Dict[str, any]
|
||||
:type where: dict[str, any]
|
||||
:param limit: The number of entries to be fetched
|
||||
:type limit: Optional int, defaults to None
|
||||
:return: All the existing IDs
|
||||
@@ -122,21 +122,21 @@ class QdrantDB(BaseVectorDB):
|
||||
|
||||
def add(
|
||||
self,
|
||||
embeddings: List[List[float]],
|
||||
documents: List[str],
|
||||
metadatas: List[object],
|
||||
ids: List[str],
|
||||
**kwargs: Optional[Dict[str, any]],
|
||||
embeddings: list[list[float]],
|
||||
documents: list[str],
|
||||
metadatas: list[object],
|
||||
ids: list[str],
|
||||
**kwargs: Optional[dict[str, any]],
|
||||
):
|
||||
"""add data in vector database
|
||||
:param embeddings: list of embeddings for the corresponding documents to be added
|
||||
:type documents: List[List[float]]
|
||||
:type documents: list[list[float]]
|
||||
:param documents: list of texts to add
|
||||
:type documents: List[str]
|
||||
:type documents: list[str]
|
||||
:param metadatas: list of metadata associated with docs
|
||||
:type metadatas: List[object]
|
||||
:type metadatas: list[object]
|
||||
:param ids: ids of docs
|
||||
:type ids: List[str]
|
||||
:type ids: list[str]
|
||||
"""
|
||||
embeddings = self.embedder.embedding_fn(documents)
|
||||
|
||||
@@ -159,25 +159,25 @@ class QdrantDB(BaseVectorDB):
|
||||
|
||||
def query(
|
||||
self,
|
||||
input_query: List[str],
|
||||
input_query: list[str],
|
||||
n_results: int,
|
||||
where: Dict[str, any],
|
||||
where: dict[str, any],
|
||||
citations: bool = False,
|
||||
**kwargs: Optional[Dict[str, Any]],
|
||||
) -> Union[List[Tuple[str, Dict]], List[str]]:
|
||||
**kwargs: Optional[dict[str, Any]],
|
||||
) -> Union[list[tuple[str, dict]], list[str]]:
|
||||
"""
|
||||
query contents from vector database based on vector similarity
|
||||
:param input_query: list of query string
|
||||
:type input_query: List[str]
|
||||
:type input_query: list[str]
|
||||
:param n_results: no of similar documents to fetch from database
|
||||
:type n_results: int
|
||||
:param where: Optional. to filter data
|
||||
:type where: Dict[str, any]
|
||||
:type where: dict[str, any]
|
||||
:param citations: we use citations boolean param to return context along with the answer.
|
||||
:type citations: bool, default is False.
|
||||
:return: The content of the document that matched your query,
|
||||
along with url of the source and doc_id (if citations flag is true)
|
||||
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
|
||||
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
|
||||
"""
|
||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||
keys = set(where.keys() if where is not None else set())
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import copy
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
try:
|
||||
import weaviate
|
||||
@@ -117,13 +117,13 @@ class WeaviateDB(BaseVectorDB):
|
||||
|
||||
self.client.schema.create(class_obj)
|
||||
|
||||
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
|
||||
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
|
||||
"""
|
||||
Get existing doc ids present in vector database
|
||||
:param ids: _list of doc ids to check for existance
|
||||
:type ids: List[str]
|
||||
:type ids: list[str]
|
||||
:param where: to filter data
|
||||
:type where: Dict[str, any]
|
||||
:type where: dict[str, any]
|
||||
:return: ids
|
||||
:rtype: Set[str]
|
||||
"""
|
||||
@@ -153,21 +153,21 @@ class WeaviateDB(BaseVectorDB):
|
||||
|
||||
def add(
|
||||
self,
|
||||
embeddings: List[List[float]],
|
||||
documents: List[str],
|
||||
metadatas: List[object],
|
||||
ids: List[str],
|
||||
**kwargs: Optional[Dict[str, any]],
|
||||
embeddings: list[list[float]],
|
||||
documents: list[str],
|
||||
metadatas: list[object],
|
||||
ids: list[str],
|
||||
**kwargs: Optional[dict[str, any]],
|
||||
):
|
||||
"""add data in vector database
|
||||
:param embeddings: list of embeddings for the corresponding documents to be added
|
||||
:type documents: List[List[float]]
|
||||
:type documents: list[list[float]]
|
||||
:param documents: list of texts to add
|
||||
:type documents: List[str]
|
||||
:type documents: list[str]
|
||||
:param metadatas: list of metadata associated with docs
|
||||
:type metadatas: List[object]
|
||||
:type metadatas: list[object]
|
||||
:param ids: ids of docs
|
||||
:type ids: List[str]
|
||||
:type ids: list[str]
|
||||
"""
|
||||
embeddings = self.embedder.embedding_fn(documents)
|
||||
self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3) # Configure batch
|
||||
@@ -192,25 +192,25 @@ class WeaviateDB(BaseVectorDB):
|
||||
|
||||
def query(
|
||||
self,
|
||||
input_query: List[str],
|
||||
input_query: list[str],
|
||||
n_results: int,
|
||||
where: Dict[str, any],
|
||||
where: dict[str, any],
|
||||
citations: bool = False,
|
||||
**kwargs: Optional[Dict[str, Any]],
|
||||
) -> Union[List[Tuple[str, Dict]], List[str]]:
|
||||
**kwargs: Optional[dict[str, Any]],
|
||||
) -> Union[list[tuple[str, dict]], list[str]]:
|
||||
"""
|
||||
query contents from vector database based on vector similarity
|
||||
:param input_query: list of query string
|
||||
:type input_query: List[str]
|
||||
:type input_query: list[str]
|
||||
:param n_results: no of similar documents to fetch from database
|
||||
:type n_results: int
|
||||
:param where: Optional. to filter data
|
||||
:type where: Dict[str, any]
|
||||
:type where: dict[str, any]
|
||||
:param citations: we use citations boolean param to return context along with the answer.
|
||||
:type citations: bool, default is False.
|
||||
:return: The content of the document that matched your query,
|
||||
along with url of the source and doc_id (if citations flag is true)
|
||||
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
|
||||
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
|
||||
"""
|
||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||
keys = set(where.keys() if where is not None else set())
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from embedchain.config import ZillizDBConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
@@ -88,14 +88,14 @@ class ZillizVectorDB(BaseVectorDB):
|
||||
self.collection.create_index("embeddings", index)
|
||||
return self.collection
|
||||
|
||||
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
|
||||
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
|
||||
"""
|
||||
Get existing doc ids present in vector database
|
||||
|
||||
:param ids: list of doc ids to check for existence
|
||||
:type ids: List[str]
|
||||
:type ids: list[str]
|
||||
:param where: Optional. to filter data
|
||||
:type where: Dict[str, Any]
|
||||
:type where: dict[str, Any]
|
||||
:param limit: Optional. maximum number of documents
|
||||
:type limit: Optional[int]
|
||||
:return: Existing documents.
|
||||
@@ -115,11 +115,11 @@ class ZillizVectorDB(BaseVectorDB):
|
||||
|
||||
def add(
|
||||
self,
|
||||
embeddings: List[List[float]],
|
||||
documents: List[str],
|
||||
metadatas: List[object],
|
||||
ids: List[str],
|
||||
**kwargs: Optional[Dict[str, any]],
|
||||
embeddings: list[list[float]],
|
||||
documents: list[str],
|
||||
metadatas: list[object],
|
||||
ids: list[str],
|
||||
**kwargs: Optional[dict[str, any]],
|
||||
):
|
||||
"""Add to database"""
|
||||
embeddings = self.embedder.embedding_fn(documents)
|
||||
@@ -134,17 +134,17 @@ class ZillizVectorDB(BaseVectorDB):
|
||||
|
||||
def query(
|
||||
self,
|
||||
input_query: List[str],
|
||||
input_query: list[str],
|
||||
n_results: int,
|
||||
where: Dict[str, any],
|
||||
where: dict[str, any],
|
||||
citations: bool = False,
|
||||
**kwargs: Optional[Dict[str, Any]],
|
||||
) -> Union[List[Tuple[str, Dict]], List[str]]:
|
||||
**kwargs: Optional[dict[str, Any]],
|
||||
) -> Union[list[tuple[str, dict]], list[str]]:
|
||||
"""
|
||||
Query contents from vector database based on vector similarity
|
||||
|
||||
:param input_query: list of query string
|
||||
:type input_query: List[str]
|
||||
:type input_query: list[str]
|
||||
:param n_results: no of similar documents to fetch from database
|
||||
:type n_results: int
|
||||
:param where: to filter data
|
||||
@@ -154,7 +154,7 @@ class ZillizVectorDB(BaseVectorDB):
|
||||
:type citations: bool, default is False.
|
||||
:return: The content of the document that matched your query,
|
||||
along with url of the source and doc_id (if citations flag is true)
|
||||
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
|
||||
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
|
||||
"""
|
||||
|
||||
if self.collection.is_empty:
|
||||
@@ -200,7 +200,7 @@ class ZillizVectorDB(BaseVectorDB):
|
||||
"""
|
||||
return self.collection.num_entities
|
||||
|
||||
def reset(self, collection_names: List[str] = None):
|
||||
def reset(self, collection_names: list[str] = None):
|
||||
"""
|
||||
Resets the database. Deletes all embeddings irreversibly.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user