#1128 | Remove deprecated type hints from typing module (#1131)

This commit is contained in:
Sandra Serrano
2024-01-09 18:35:24 +01:00
committed by GitHub
parent c9df7a2020
commit 0de9491c61
41 changed files with 272 additions and 267 deletions

View File

@@ -4,7 +4,7 @@ import logging
import os import os
import sqlite3 import sqlite3
import uuid import uuid
from typing import Any, Dict, Optional from typing import Any, Optional
import requests import requests
import yaml import yaml
@@ -364,7 +364,7 @@ class App(EmbedChain):
def from_config( def from_config(
cls, cls,
config_path: Optional[str] = None, config_path: Optional[str] = None,
config: Optional[Dict[str, Any]] = None, config: Optional[dict[str, Any]] = None,
auto_deploy: bool = False, auto_deploy: bool = False,
yaml_path: Optional[str] = None, yaml_path: Optional[str] = None,
): ):
@@ -374,7 +374,7 @@ class App(EmbedChain):
:param config_path: Path to the YAML or JSON configuration file. :param config_path: Path to the YAML or JSON configuration file.
:type config_path: Optional[str] :type config_path: Optional[str]
:param config: A dictionary containing the configuration. :param config: A dictionary containing the configuration.
:type config: Optional[Dict[str, Any]] :type config: Optional[dict[str, Any]]
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
:type auto_deploy: bool, optional :type auto_deploy: bool, optional
:param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead. :param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.

View File

@@ -1,7 +1,7 @@
import argparse import argparse
import logging import logging
import os import os
from typing import List, Optional from typing import Optional
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@@ -53,7 +53,7 @@ class PoeBot(BaseBot, PoeBot):
answer = self.handle_message(last_message, history) answer = self.handle_message(last_message, history)
yield self.text_event(answer) yield self.text_event(answer)
def handle_message(self, message, history: Optional[List[str]] = None): def handle_message(self, message, history: Optional[list[str]] = None):
if message.startswith("/add "): if message.startswith("/add "):
response = self.add_data(message) response = self.add_data(message)
else: else:
@@ -70,7 +70,7 @@ class PoeBot(BaseBot, PoeBot):
# response = "Some error occurred while adding data." # response = "Some error occurred while adding data."
# return response # return response
def ask_bot(self, message, history: List[str]): def ask_bot(self, message, history: list[str]):
try: try:
self.app.llm.set_history(history=history) self.app.llm.set_history(history=history)
response = self.query(message) response = self.query(message)

View File

@@ -1,6 +1,6 @@
import logging import logging
import os # noqa: F401 import os # noqa: F401
from typing import Any, Dict from typing import Any
from gptcache import cache # noqa: F401 from gptcache import cache # noqa: F401
from gptcache.adapter.adapter import adapt # noqa: F401 from gptcache.adapter.adapter import adapt # noqa: F401
@@ -15,7 +15,7 @@ from gptcache.similarity_evaluation.exact_match import \
ExactMatchEvaluation # noqa: F401 ExactMatchEvaluation # noqa: F401
def gptcache_pre_function(data: Dict[str, Any], **params: Dict[str, Any]): def gptcache_pre_function(data: dict[str, Any], **params: dict[str, Any]):
return data["input_query"] return data["input_query"]

View File

@@ -1,7 +1,8 @@
import builtins import builtins
import logging import logging
from collections.abc import Callable
from importlib import import_module from importlib import import_module
from typing import Callable, Optional from typing import Optional
from embedchain.config.base_config import BaseConfig from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict from typing import Any
from embedchain.helpers.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable
@@ -12,10 +12,10 @@ class BaseConfig(JSONSerializable):
"""Initializes a configuration class for a class.""" """Initializes a configuration class for a class."""
pass pass
def as_dict(self) -> Dict[str, Any]: def as_dict(self) -> dict[str, Any]:
"""Return config object as a dict """Return config object as a dict
:return: config object as dict :return: config object as dict
:rtype: Dict[str, Any] :rtype: dict[str, Any]
""" """
return vars(self) return vars(self)

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional from typing import Any, Optional
from embedchain.config.base_config import BaseConfig from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@@ -30,7 +30,7 @@ class CacheSimilarityEvalConfig(BaseConfig):
self.positive = positive self.positive = positive
@staticmethod @staticmethod
def from_config(config: Optional[Dict[str, Any]]): def from_config(config: Optional[dict[str, Any]]):
if config is None: if config is None:
return CacheSimilarityEvalConfig() return CacheSimilarityEvalConfig()
else: else:
@@ -65,7 +65,7 @@ class CacheInitConfig(BaseConfig):
self.auto_flush = auto_flush self.auto_flush = auto_flush
@staticmethod @staticmethod
def from_config(config: Optional[Dict[str, Any]]): def from_config(config: Optional[dict[str, Any]]):
if config is None: if config is None:
return CacheInitConfig() return CacheInitConfig()
else: else:
@@ -86,7 +86,7 @@ class CacheConfig(BaseConfig):
self.init_config = init_config self.init_config = init_config
@staticmethod @staticmethod
def from_config(config: Optional[Dict[str, Any]]): def from_config(config: Optional[dict[str, Any]]):
if config is None: if config is None:
return CacheConfig() return CacheConfig()
else: else:

View File

@@ -1,7 +1,7 @@
import logging import logging
import re import re
from string import Template from string import Template
from typing import Any, Dict, List, Optional from typing import Any, Optional
from embedchain.config.base_config import BaseConfig from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@@ -68,12 +68,12 @@ class BaseLlmConfig(BaseConfig):
stream: bool = False, stream: bool = False,
deployment_name: Optional[str] = None, deployment_name: Optional[str] = None,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
where: Dict[str, Any] = None, where: dict[str, Any] = None,
query_type: Optional[str] = None, query_type: Optional[str] = None,
callbacks: Optional[List] = None, callbacks: Optional[list] = None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
endpoint: Optional[str] = None, endpoint: Optional[str] = None,
model_kwargs: Optional[Dict[str, Any]] = None, model_kwargs: Optional[dict[str, Any]] = None,
): ):
""" """
Initializes a configuration class instance for the LLM. Initializes a configuration class instance for the LLM.
@@ -106,7 +106,7 @@ class BaseLlmConfig(BaseConfig):
:param system_prompt: System prompt string, defaults to None :param system_prompt: System prompt string, defaults to None
:type system_prompt: Optional[str], optional :type system_prompt: Optional[str], optional
:param where: A dictionary of key-value pairs to filter the database results., defaults to None :param where: A dictionary of key-value pairs to filter the database results., defaults to None
:type where: Dict[str, Any], optional :type where: dict[str, Any], optional
:param api_key: The api key of the custom endpoint, defaults to None :param api_key: The api key of the custom endpoint, defaults to None
:type api_key: Optional[str], optional :type api_key: Optional[str], optional
:param endpoint: The api url of the custom endpoint, defaults to None :param endpoint: The api url of the custom endpoint, defaults to None
@@ -114,7 +114,7 @@ class BaseLlmConfig(BaseConfig):
:param model_kwargs: A dictionary of key-value pairs to pass to the model, defaults to None :param model_kwargs: A dictionary of key-value pairs to pass to the model, defaults to None
:type model_kwargs: Optional[Dict[str, Any]], optional :type model_kwargs: Optional[Dict[str, Any]], optional
:param callbacks: Langchain callback functions to use, defaults to None :param callbacks: Langchain callback functions to use, defaults to None
:type callbacks: Optional[List], optional :type callbacks: Optional[list], optional
:param query_type: The type of query to use, defaults to None :param query_type: The type of query to use, defaults to None
:type query_type: Optional[str], optional :type query_type: Optional[str], optional
:raises ValueError: If the template is not valid as template should :raises ValueError: If the template is not valid as template should

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Dict, List, Optional, Union from typing import Optional, Union
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@@ -11,9 +11,9 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
self, self,
collection_name: Optional[str] = None, collection_name: Optional[str] = None,
dir: Optional[str] = None, dir: Optional[str] = None,
es_url: Union[str, List[str]] = None, es_url: Union[str, list[str]] = None,
cloud_id: Optional[str] = None, cloud_id: Optional[str] = None,
**ES_EXTRA_PARAMS: Dict[str, any], **ES_EXTRA_PARAMS: dict[str, any],
): ):
""" """
Initializes a configuration class instance for an Elasticsearch client. Initializes a configuration class instance for an Elasticsearch client.
@@ -23,13 +23,13 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
:param dir: Path to the database directory, where the database is stored, defaults to None :param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional :type dir: Optional[str], optional
:param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None :param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
:type es_url: Union[str, List[str]], optional :type es_url: Union[str, list[str]], optional
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch. :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
:type ES_EXTRA_PARAMS: Dict[str, Any], optional :type ES_EXTRA_PARAMS: dict[str, Any], optional
""" """
if es_url and cloud_id: if es_url and cloud_id:
raise ValueError("Only one of `es_url` and `cloud_id` can be set.") raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
# self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]): # self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]):
self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL") self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID") self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
if not self.ES_URL and not self.CLOUD_ID: if not self.ES_URL and not self.CLOUD_ID:

View File

@@ -1,4 +1,4 @@
from typing import Dict, Optional, Tuple from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@@ -9,11 +9,11 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
def __init__( def __init__(
self, self,
opensearch_url: str, opensearch_url: str,
http_auth: Tuple[str, str], http_auth: tuple[str, str],
vector_dimension: int = 1536, vector_dimension: int = 1536,
collection_name: Optional[str] = None, collection_name: Optional[str] = None,
dir: Optional[str] = None, dir: Optional[str] = None,
**extra_params: Dict[str, any], **extra_params: dict[str, any],
): ):
""" """
Initializes a configuration class instance for an OpenSearch client. Initializes a configuration class instance for an OpenSearch client.
@@ -23,7 +23,7 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
:param opensearch_url: URL of the OpenSearch domain :param opensearch_url: URL of the OpenSearch domain
:type opensearch_url: str, Eg, "http://localhost:9200" :type opensearch_url: str, Eg, "http://localhost:9200"
:param http_auth: Tuple of username and password :param http_auth: Tuple of username and password
:type http_auth: Tuple[str, str], Eg, ("username", "password") :type http_auth: tuple[str, str], Eg, ("username", "password")
:param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model) :param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model)
:type vector_dimension: int, optional :type vector_dimension: int, optional
:param dir: Path to the database directory, where the database is stored, defaults to None :param dir: Path to the database directory, where the database is stored, defaults to None

View File

@@ -1,4 +1,4 @@
from typing import Dict, Optional from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@@ -12,7 +12,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
dir: Optional[str] = None, dir: Optional[str] = None,
vector_dimension: int = 1536, vector_dimension: int = 1536,
metric: Optional[str] = "cosine", metric: Optional[str] = "cosine",
**extra_params: Dict[str, any], **extra_params: dict[str, any],
): ):
self.metric = metric self.metric = metric
self.vector_dimension = vector_dimension self.vector_dimension = vector_dimension

View File

@@ -1,4 +1,4 @@
from typing import Dict, Optional from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@@ -15,10 +15,10 @@ class QdrantDBConfig(BaseVectorDbConfig):
self, self,
collection_name: Optional[str] = None, collection_name: Optional[str] = None,
dir: Optional[str] = None, dir: Optional[str] = None,
hnsw_config: Optional[Dict[str, any]] = None, hnsw_config: Optional[dict[str, any]] = None,
quantization_config: Optional[Dict[str, any]] = None, quantization_config: Optional[dict[str, any]] = None,
on_disk: Optional[bool] = None, on_disk: Optional[bool] = None,
**extra_params: Dict[str, any], **extra_params: dict[str, any],
): ):
""" """
Initializes a configuration class instance for a qdrant client. Initializes a configuration class instance for a qdrant client.
@@ -28,9 +28,9 @@ class QdrantDBConfig(BaseVectorDbConfig):
:param dir: Path to the database directory, where the database is stored, defaults to None :param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional :type dir: Optional[str], optional
:param hnsw_config: Params for HNSW index :param hnsw_config: Params for HNSW index
:type hnsw_config: Optional[Dict[str, any]], defaults to None :type hnsw_config: Optional[dict[str, any]], defaults to None
:param quantization_config: Params for quantization, if None - quantization will be disabled :param quantization_config: Params for quantization, if None - quantization will be disabled
:type quantization_config: Optional[Dict[str, any]], defaults to None :type quantization_config: Optional[dict[str, any]], defaults to None
:param on_disk: If true - point`s payload will not be stored in memory. :param on_disk: If true - point`s payload will not be stored in memory.
It will be read from the disk every time it is requested. It will be read from the disk every time it is requested.
This setting saves RAM by (slightly) increasing the response time. This setting saves RAM by (slightly) increasing the response time.

View File

@@ -1,4 +1,4 @@
from typing import Dict, Optional from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@@ -10,7 +10,7 @@ class WeaviateDBConfig(BaseVectorDbConfig):
self, self,
collection_name: Optional[str] = None, collection_name: Optional[str] = None,
dir: Optional[str] = None, dir: Optional[str] = None,
**extra_params: Dict[str, any], **extra_params: dict[str, any],
): ):
self.extra_params = extra_params self.extra_params = extra_params
super().__init__(collection_name=collection_name, dir=dir) super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -2,7 +2,7 @@ import hashlib
import json import json
import logging import logging
import sqlite3 import sqlite3
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Optional, Union
from dotenv import load_dotenv from dotenv import load_dotenv
from langchain.docstore.document import Document from langchain.docstore.document import Document
@@ -136,12 +136,12 @@ class EmbedChain(JSONSerializable):
self, self,
source: Any, source: Any,
data_type: Optional[DataType] = None, data_type: Optional[DataType] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
config: Optional[AddConfig] = None, config: Optional[AddConfig] = None,
dry_run=False, dry_run=False,
loader: Optional[BaseLoader] = None, loader: Optional[BaseLoader] = None,
chunker: Optional[BaseChunker] = None, chunker: Optional[BaseChunker] = None,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[dict[str, Any]],
): ):
""" """
Adds the data from the given URL to the vector db. Adds the data from the given URL to the vector db.
@@ -154,7 +154,7 @@ class EmbedChain(JSONSerializable):
defaults to None defaults to None
:type data_type: Optional[DataType], optional :type data_type: Optional[DataType], optional
:param metadata: Metadata associated with the data source., defaults to None :param metadata: Metadata associated with the data source., defaults to None
:type metadata: Optional[Dict[str, Any]], optional :type metadata: Optional[dict[str, Any]], optional
:param config: The `AddConfig` instance to use as configuration options., defaults to None :param config: The `AddConfig` instance to use as configuration options., defaults to None
:type config: Optional[AddConfig], optional :type config: Optional[AddConfig], optional
:raises ValueError: Invalid data type :raises ValueError: Invalid data type
@@ -243,9 +243,9 @@ class EmbedChain(JSONSerializable):
self, self,
source: Any, source: Any,
data_type: Optional[DataType] = None, data_type: Optional[DataType] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
config: Optional[AddConfig] = None, config: Optional[AddConfig] = None,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[dict[str, Any]],
): ):
""" """
Adds the data from the given URL to the vector db. Adds the data from the given URL to the vector db.
@@ -261,7 +261,7 @@ class EmbedChain(JSONSerializable):
defaults to None defaults to None
:type data_type: Optional[DataType], optional :type data_type: Optional[DataType], optional
:param metadata: Metadata associated with the data source., defaults to None :param metadata: Metadata associated with the data source., defaults to None
:type metadata: Optional[Dict[str, Any]], optional :type metadata: Optional[dict[str, Any]], optional
:param config: The `AddConfig` instance to use as configuration options., defaults to None :param config: The `AddConfig` instance to use as configuration options., defaults to None
:type config: Optional[AddConfig], optional :type config: Optional[AddConfig], optional
:raises ValueError: Invalid data type :raises ValueError: Invalid data type
@@ -342,11 +342,11 @@ class EmbedChain(JSONSerializable):
loader: BaseLoader, loader: BaseLoader,
chunker: BaseChunker, chunker: BaseChunker,
src: Any, src: Any,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
source_hash: Optional[str] = None, source_hash: Optional[str] = None,
add_config: Optional[AddConfig] = None, add_config: Optional[AddConfig] = None,
dry_run=False, dry_run=False,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[dict[str, Any]],
): ):
""" """
Loads the data from the given URL, chunks it, and adds it to database. Loads the data from the given URL, chunks it, and adds it to database.
@@ -359,7 +359,7 @@ class EmbedChain(JSONSerializable):
:param source_hash: Hexadecimal hash of the source. :param source_hash: Hexadecimal hash of the source.
:param dry_run: Optional. A dry run returns chunks and doesn't update DB. :param dry_run: Optional. A dry run returns chunks and doesn't update DB.
:type dry_run: bool, defaults to False :type dry_run: bool, defaults to False
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks :return: (list) documents (embedded text), (list) metadata, (list) ids, (int) number of chunks
""" """
existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src) existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
app_id = self.config.id if self.config is not None else None app_id = self.config.id if self.config is not None else None
@@ -464,8 +464,8 @@ class EmbedChain(JSONSerializable):
config: Optional[BaseLlmConfig] = None, config: Optional[BaseLlmConfig] = None,
where=None, where=None,
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]: ) -> Union[list[tuple[str, str, str]], list[str]]:
""" """
Queries the vector database based on the given input query. Queries the vector database based on the given input query.
Gets relevant doc based on the query Gets relevant doc based on the query
@@ -479,7 +479,7 @@ class EmbedChain(JSONSerializable):
:param citations: A boolean to indicate if db should fetch citation source :param citations: A boolean to indicate if db should fetch citation source
:type citations: bool :type citations: bool
:return: List of contents of the document that matched your query :return: List of contents of the document that matched your query
:rtype: List[str] :rtype: list[str]
""" """
query_config = config or self.llm.config query_config = config or self.llm.config
if where is not None: if where is not None:
@@ -507,10 +507,10 @@ class EmbedChain(JSONSerializable):
input_query: str, input_query: str,
config: BaseLlmConfig = None, config: BaseLlmConfig = None,
dry_run=False, dry_run=False,
where: Optional[Dict] = None, where: Optional[dict] = None,
citations: bool = False, citations: bool = False,
**kwargs: Dict[str, Any], **kwargs: dict[str, Any],
) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]: ) -> Union[tuple[str, list[tuple[str, dict]]], str]:
""" """
Queries the vector database based on the given input query. Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an Gets relevant doc based on the query and then passes it to an
@@ -525,13 +525,13 @@ class EmbedChain(JSONSerializable):
the LLM. The purpose is to test the prompt, not the response., defaults to False the LLM. The purpose is to test the prompt, not the response., defaults to False
:type dry_run: bool, optional :type dry_run: bool, optional
:param where: A dictionary of key-value pairs to filter the database results., defaults to None :param where: A dictionary of key-value pairs to filter the database results., defaults to None
:type where: Optional[Dict[str, str]], optional :type where: Optional[dict[str, str]], optional
:param kwargs: To read more params for the query function. Ex. we use citations boolean :param kwargs: To read more params for the query function. Ex. we use citations boolean
param to return context along with the answer param to return context along with the answer
:type kwargs: Dict[str, Any] :type kwargs: dict[str, Any]
:return: The answer to the query, with citations if the citation flag is True :return: The answer to the query, with citations if the citation flag is True
or the dry run result or the dry run result
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]] :rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
""" """
contexts = self._retrieve_from_database( contexts = self._retrieve_from_database(
input_query=input_query, config=config, where=where, citations=citations, **kwargs input_query=input_query, config=config, where=where, citations=citations, **kwargs
@@ -572,10 +572,10 @@ class EmbedChain(JSONSerializable):
config: Optional[BaseLlmConfig] = None, config: Optional[BaseLlmConfig] = None,
dry_run=False, dry_run=False,
session_id: str = "default", session_id: str = "default",
where: Optional[Dict[str, str]] = None, where: Optional[dict[str, str]] = None,
citations: bool = False, citations: bool = False,
**kwargs: Dict[str, Any], **kwargs: dict[str, Any],
) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]: ) -> Union[tuple[str, list[tuple[str, dict]]], str]:
""" """
Queries the vector database on the given input query. Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an Gets relevant doc based on the query and then passes it to an
@@ -594,13 +594,13 @@ class EmbedChain(JSONSerializable):
:param session_id: The session id to use for chat history, defaults to 'default'. :param session_id: The session id to use for chat history, defaults to 'default'.
:type session_id: Optional[str], optional :type session_id: Optional[str], optional
:param where: A dictionary of key-value pairs to filter the database results., defaults to None :param where: A dictionary of key-value pairs to filter the database results., defaults to None
:type where: Optional[Dict[str, str]], optional :type where: Optional[dict[str, str]], optional
:param kwargs: To read more params for the query function. Ex. we use citations boolean :param kwargs: To read more params for the query function. Ex. we use citations boolean
param to return context along with the answer param to return context along with the answer
:type kwargs: Dict[str, Any] :type kwargs: dict[str, Any]
:return: The answer to the query, with citations if the citation flag is True :return: The answer to the query, with citations if the citation flag is True
or the dry run result or the dry run result
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]] :rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
""" """
contexts = self._retrieve_from_database( contexts = self._retrieve_from_database(
input_query=input_query, config=config, where=where, citations=citations, **kwargs input_query=input_query, config=config, where=where, citations=citations, **kwargs

View File

@@ -1,4 +1,5 @@
from typing import Any, Callable, Optional from collections.abc import Callable
from typing import Any, Optional
from embedchain.config.embedder.base import BaseEmbedderConfig from embedchain.config.embedder.base import BaseEmbedderConfig

View File

@@ -1,5 +1,5 @@
import queue import queue
from typing import Any, Dict, List, Union from typing import Any, Union
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import LLMResult from langchain.schema import LLMResult
@@ -29,7 +29,7 @@ class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
super().__init__() super().__init__()
self.q = q self.q = q
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None: def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> None:
"""Run when LLM starts running.""" """Run when LLM starts running."""
with self.q.mutex: with self.q.mutex:
self.q.queue.clear() self.q.queue.clear()

View File

@@ -1,7 +1,7 @@
import json import json
import logging import logging
from string import Template from string import Template
from typing import Any, Dict, Type, TypeVar, Union from typing import Any, Type, TypeVar, Union
T = TypeVar("T", bound="JSONSerializable") T = TypeVar("T", bound="JSONSerializable")
@@ -84,7 +84,7 @@ class JSONSerializable:
return cls() return cls()
@staticmethod @staticmethod
def _auto_encoder(obj: Any) -> Union[Dict[str, Any], None]: def _auto_encoder(obj: Any) -> Union[dict[str, Any], None]:
""" """
Automatically encode an object for JSON serialization. Automatically encode an object for JSON serialization.
@@ -126,7 +126,7 @@ class JSONSerializable:
raise TypeError(f"Object of type {type(obj)} is not JSON serializable") raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
@classmethod @classmethod
def _auto_decoder(cls, dct: Dict[str, Any]) -> Any: def _auto_decoder(cls, dct: dict[str, Any]) -> Any:
""" """
Automatically decode a dictionary to an object during JSON deserialization. Automatically decode a dictionary to an object during JSON deserialization.

View File

@@ -1,5 +1,6 @@
import logging import logging
from typing import Any, Dict, Generator, List, Optional from collections.abc import Generator
from typing import Any, Optional
from langchain.schema import BaseMessage as LCBaseMessage from langchain.schema import BaseMessage as LCBaseMessage
@@ -55,7 +56,7 @@ class BaseLlm(JSONSerializable):
app_id: str, app_id: str,
question: str, question: str,
answer: str, answer: str,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
session_id: str = "default", session_id: str = "default",
): ):
chat_message = ChatMessage() chat_message = ChatMessage()
@@ -64,7 +65,7 @@ class BaseLlm(JSONSerializable):
self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id) self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
self.update_history(app_id=app_id, session_id=session_id) self.update_history(app_id=app_id, session_id=session_id)
def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str: def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
""" """
Generates a prompt based on the given query and context, ready to be Generates a prompt based on the given query and context, ready to be
passed to an LLM passed to an LLM
@@ -72,7 +73,7 @@ class BaseLlm(JSONSerializable):
:param input_query: The query to use. :param input_query: The query to use.
:type input_query: str :type input_query: str
:param contexts: List of similar documents to the query used as context. :param contexts: List of similar documents to the query used as context.
:type contexts: List[str] :type contexts: list[str]
:return: The prompt :return: The prompt
:rtype: str :rtype: str
""" """
@@ -170,7 +171,7 @@ class BaseLlm(JSONSerializable):
yield chunk yield chunk
logging.info(f"Answer: {streamed_answer}") logging.info(f"Answer: {streamed_answer}")
def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False): def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
""" """
Queries the vector database based on the given input query. Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an Gets relevant doc based on the query and then passes it to an
@@ -179,7 +180,7 @@ class BaseLlm(JSONSerializable):
:param input_query: The query to use. :param input_query: The query to use.
:type input_query: str :type input_query: str
:param contexts: Embeddings retrieved from the database to be used as context. :param contexts: Embeddings retrieved from the database to be used as context.
:type contexts: List[str] :type contexts: list[str]
:param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call. :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
To persistently use a config, declare it during app init., defaults to None To persistently use a config, declare it during app init., defaults to None
:type config: Optional[BaseLlmConfig], optional :type config: Optional[BaseLlmConfig], optional
@@ -223,7 +224,7 @@ class BaseLlm(JSONSerializable):
self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config) self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
def chat( def chat(
self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
): ):
""" """
Queries the vector database on the given input query. Queries the vector database on the given input query.
@@ -235,7 +236,7 @@ class BaseLlm(JSONSerializable):
:param input_query: The query to use. :param input_query: The query to use.
:type input_query: str :type input_query: str
:param contexts: Embeddings retrieved from the database to be used as context. :param contexts: Embeddings retrieved from the database to be used as context.
:type contexts: List[str] :type contexts: list[str]
:param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call. :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
To persistently use a config, declare it during app init., defaults to None To persistently use a config, declare it during app init., defaults to None
:type config: Optional[BaseLlmConfig], optional :type config: Optional[BaseLlmConfig], optional
@@ -281,7 +282,7 @@ class BaseLlm(JSONSerializable):
self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config) self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
@staticmethod @staticmethod
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]: def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> list[LCBaseMessage]:
""" """
Construct a list of langchain messages Construct a list of langchain messages
@@ -290,7 +291,7 @@ class BaseLlm(JSONSerializable):
:param system_prompt: System prompt, defaults to None :param system_prompt: System prompt, defaults to None
:type system_prompt: Optional[str], optional :type system_prompt: Optional[str], optional
:return: List of messages :return: List of messages
:rtype: List[BaseMessage] :rtype: list[BaseMessage]
""" """
from langchain.schema import HumanMessage, SystemMessage from langchain.schema import HumanMessage, SystemMessage

View File

@@ -1,7 +1,8 @@
import importlib import importlib
import logging import logging
import os import os
from typing import Any, Generator, Optional, Union from collections.abc import Generator
from typing import Any, Optional, Union
import google.generativeai as genai import google.generativeai as genai

View File

@@ -1,6 +1,7 @@
import os import os
from collections.abc import Iterable
from pathlib import Path from pathlib import Path
from typing import Iterable, Optional, Union from typing import Optional, Union
from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

View File

@@ -1,4 +1,5 @@
from typing import Iterable, Optional, Union from collections.abc import Iterable
from typing import Optional, Union
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler

View File

@@ -1,6 +1,6 @@
import json import json
import os import os
from typing import Any, Dict, Optional from typing import Any, Optional
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage from langchain.schema import AIMessage, HumanMessage, SystemMessage
@@ -12,7 +12,7 @@ from embedchain.llm.base import BaseLlm
@register_deserializable @register_deserializable
class OpenAILlm(BaseLlm): class OpenAILlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[Dict[str, Any]] = None): def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[dict[str, Any]] = None):
self.functions = functions self.functions = functions
super().__init__(config=config) super().__init__(config=config)

View File

@@ -1,7 +1,7 @@
import hashlib import hashlib
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Optional
from embedchain.config import AddConfig from embedchain.config import AddConfig
from embedchain.data_formatter.data_formatter import DataFormatter from embedchain.data_formatter.data_formatter import DataFormatter
@@ -15,7 +15,7 @@ from embedchain.utils.misc import detect_datatype
class DirectoryLoader(BaseLoader): class DirectoryLoader(BaseLoader):
"""Load data from a directory.""" """Load data from a directory."""
def __init__(self, config: Optional[Dict[str, Any]] = None): def __init__(self, config: Optional[dict[str, Any]] = None):
super().__init__() super().__init__()
config = config or {} config = config or {}
self.recursive = config.get("recursive", True) self.recursive = config.get("recursive", True)

View File

@@ -1,7 +1,7 @@
import hashlib import hashlib
import logging import logging
import time import time
from typing import Any, Dict, Optional from typing import Any, Optional
import requests import requests
@@ -10,7 +10,7 @@ from embedchain.utils.misc import clean_string
class DiscourseLoader(BaseLoader): class DiscourseLoader(BaseLoader):
def __init__(self, config: Optional[Dict[str, Any]] = None): def __init__(self, config: Optional[dict[str, Any]] = None):
super().__init__() super().__init__()
if not config: if not config:
raise ValueError( raise ValueError(

View File

@@ -1,6 +1,5 @@
import hashlib import hashlib
import os import os
from typing import List
from dropbox.files import FileMetadata from dropbox.files import FileMetadata
@@ -29,7 +28,7 @@ class DropboxLoader(BaseLoader):
except exceptions.AuthError as ex: except exceptions.AuthError as ex:
raise ValueError("Invalid Dropbox access token. Please verify your token and try again.") from ex raise ValueError("Invalid Dropbox access token. Please verify your token and try again.") from ex
def _download_folder(self, path: str, local_root: str) -> List[FileMetadata]: def _download_folder(self, path: str, local_root: str) -> list[FileMetadata]:
"""Download a folder from Dropbox and save it preserving the directory structure.""" """Download a folder from Dropbox and save it preserving the directory structure."""
entries = self.dbx.files_list_folder(path).entries entries = self.dbx.files_list_folder(path).entries
for entry in entries: for entry in entries:

View File

@@ -4,7 +4,7 @@ import logging
import os import os
import re import re
import shlex import shlex
from typing import Any, Dict, Optional from typing import Any, Optional
from tqdm import tqdm from tqdm import tqdm
@@ -20,7 +20,7 @@ VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"])
class GithubLoader(BaseLoader): class GithubLoader(BaseLoader):
"""Load data from GitHub search query.""" """Load data from GitHub search query."""
def __init__(self, config: Optional[Dict[str, Any]] = None): def __init__(self, config: Optional[dict[str, Any]] = None):
super().__init__() super().__init__()
if not config: if not config:
raise ValueError( raise ValueError(

View File

@@ -5,7 +5,7 @@ import os
from email import message_from_bytes from email import message_from_bytes
from email.utils import parsedate_to_datetime from email.utils import parsedate_to_datetime
from textwrap import dedent from textwrap import dedent
from typing import Dict, List, Optional from typing import Optional
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
@@ -57,7 +57,7 @@ class GmailReader:
token.write(creds.to_json()) token.write(creds.to_json())
return creds return creds
def load_emails(self) -> List[Dict]: def load_emails(self) -> list[dict]:
response = self.service.users().messages().list(userId="me", q=self.query).execute() response = self.service.users().messages().list(userId="me", q=self.query).execute()
messages = response.get("messages", []) messages = response.get("messages", [])
@@ -67,7 +67,7 @@ class GmailReader:
raw_message = self.service.users().messages().get(userId="me", id=message_id, format="raw").execute() raw_message = self.service.users().messages().get(userId="me", id=message_id, format="raw").execute()
return base64.urlsafe_b64decode(raw_message["raw"]) return base64.urlsafe_b64decode(raw_message["raw"])
def _parse_email(self, raw_email) -> Dict: def _parse_email(self, raw_email) -> dict:
mime_msg = message_from_bytes(raw_email) mime_msg = message_from_bytes(raw_email)
return { return {
"subject": self._get_header(mime_msg, "Subject"), "subject": self._get_header(mime_msg, "Subject"),
@@ -124,7 +124,7 @@ class GmailLoader(BaseLoader):
return {"doc_id": self._generate_doc_id(query, data), "data": data} return {"doc_id": self._generate_doc_id(query, data), "data": data}
@staticmethod @staticmethod
def _process_email(email: Dict) -> str: def _process_email(email: dict) -> str:
content = BeautifulSoup(email["body"], "html.parser").get_text() content = BeautifulSoup(email["body"], "html.parser").get_text()
content = clean_string(content) content = clean_string(content)
return dedent( return dedent(
@@ -137,6 +137,6 @@ class GmailLoader(BaseLoader):
) )
@staticmethod @staticmethod
def _generate_doc_id(query: str, data: List[Dict]) -> str: def _generate_doc_id(query: str, data: list[dict]) -> str:
content_strings = [email["content"] for email in data] content_strings = [email["content"] for email in data]
return hashlib.sha256((query + ", ".join(content_strings)).encode()).hexdigest() return hashlib.sha256((query + ", ".join(content_strings)).encode()).hexdigest()

View File

@@ -2,7 +2,7 @@ import hashlib
import json import json
import os import os
import re import re
from typing import Dict, List, Union from typing import Union
import requests import requests
@@ -16,14 +16,14 @@ class JSONReader:
pass pass
@staticmethod @staticmethod
def load_data(json_data: Union[Dict, str]) -> List[str]: def load_data(json_data: Union[dict, str]) -> list[str]:
"""Load data from a JSON structure. """Load data from a JSON structure.
Args: Args:
json_data (Union[Dict, str]): The JSON data to load. json_data (Union[dict, str]): The JSON data to load.
Returns: Returns:
List[str]: A list of strings representing the leaf nodes of the JSON. list[str]: A list of strings representing the leaf nodes of the JSON.
""" """
if isinstance(json_data, str): if isinstance(json_data, str):
json_data = json.loads(json_data) json_data = json.loads(json_data)

View File

@@ -1,13 +1,13 @@
import hashlib import hashlib
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Optional
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string from embedchain.utils.misc import clean_string
class MySQLLoader(BaseLoader): class MySQLLoader(BaseLoader):
def __init__(self, config: Optional[Dict[str, Any]]): def __init__(self, config: Optional[dict[str, Any]]):
super().__init__() super().__init__()
if not config: if not config:
raise ValueError( raise ValueError(
@@ -20,7 +20,7 @@ class MySQLLoader(BaseLoader):
self.cursor = None self.cursor = None
self._setup_loader(config=config) self._setup_loader(config=config)
def _setup_loader(self, config: Dict[str, Any]): def _setup_loader(self, config: dict[str, Any]):
try: try:
import mysql.connector as sqlconnector import mysql.connector as sqlconnector
except ImportError as e: except ImportError as e:

View File

@@ -1,7 +1,7 @@
import hashlib import hashlib
import logging import logging
import os import os
from typing import Any, Dict, List, Optional from typing import Any, Optional
import requests import requests
@@ -15,7 +15,7 @@ class NotionDocument:
A simple Document class to hold the text and additional information of a page. A simple Document class to hold the text and additional information of a page.
""" """
def __init__(self, text: str, extra_info: Dict[str, Any]): def __init__(self, text: str, extra_info: dict[str, Any]):
self.text = text self.text = text
self.extra_info = extra_info self.extra_info = extra_info
@@ -82,7 +82,7 @@ class NotionPageLoader:
result_lines = "\n".join(result_lines_arr) result_lines = "\n".join(result_lines_arr)
return result_lines return result_lines
def load_data(self, page_ids: List[str]) -> List[NotionDocument]: def load_data(self, page_ids: list[str]) -> list[NotionDocument]:
"""Load data from the given list of page IDs.""" """Load data from the given list of page IDs."""
docs = [] docs = []
for page_id in page_ids: for page_id in page_ids:

View File

@@ -1,12 +1,12 @@
import hashlib import hashlib
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Optional
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
class PostgresLoader(BaseLoader): class PostgresLoader(BaseLoader):
def __init__(self, config: Optional[Dict[str, Any]] = None): def __init__(self, config: Optional[dict[str, Any]] = None):
super().__init__() super().__init__()
if not config: if not config:
raise ValueError(f"Must provide the valid config. Received: {config}") raise ValueError(f"Must provide the valid config. Received: {config}")
@@ -15,7 +15,7 @@ class PostgresLoader(BaseLoader):
self.cursor = None self.cursor = None
self._setup_loader(config=config) self._setup_loader(config=config)
def _setup_loader(self, config: Dict[str, Any]): def _setup_loader(self, config: dict[str, Any]):
try: try:
import psycopg import psycopg
except ImportError as e: except ImportError as e:

View File

@@ -2,7 +2,7 @@ import hashlib
import logging import logging
import os import os
import ssl import ssl
from typing import Any, Dict, Optional from typing import Any, Optional
import certifi import certifi
@@ -13,7 +13,7 @@ SLACK_API_BASE_URL = "https://www.slack.com/api/"
class SlackLoader(BaseLoader): class SlackLoader(BaseLoader):
def __init__(self, config: Optional[Dict[str, Any]] = None): def __init__(self, config: Optional[dict[str, Any]] = None):
super().__init__() super().__init__()
self.config = config if config else {} self.config = config if config else {}
@@ -24,7 +24,7 @@ class SlackLoader(BaseLoader):
self.client = None self.client = None
self._setup_loader(self.config) self._setup_loader(self.config)
def _setup_loader(self, config: Dict[str, Any]): def _setup_loader(self, config: dict[str, Any]):
try: try:
from slack_sdk import WebClient from slack_sdk import WebClient
except ImportError as e: except ImportError as e:

View File

@@ -2,7 +2,7 @@ import json
import logging import logging
import sqlite3 import sqlite3
import uuid import uuid
from typing import Any, Dict, List, Optional from typing import Any, Optional
from embedchain.constants import SQLITE_PATH from embedchain.constants import SQLITE_PATH
from embedchain.memory.message import ChatMessage from embedchain.memory.message import ChatMessage
@@ -67,7 +67,7 @@ class ChatHistory:
self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, (app_id, session_id)) self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, (app_id, session_id))
self.connection.commit() self.connection.commit()
def get(self, app_id, session_id, num_rounds=10, display_format=False) -> List[ChatMessage]: def get(self, app_id, session_id, num_rounds=10, display_format=False) -> list[ChatMessage]:
""" """
Get the most recent num_rounds rounds of conversations Get the most recent num_rounds rounds of conversations
between human and AI, for a given app_id. between human and AI, for a given app_id.
@@ -114,7 +114,7 @@ class ChatHistory:
return count return count
@staticmethod @staticmethod
def _serialize_json(metadata: Dict[str, Any]): def _serialize_json(metadata: dict[str, Any]):
return json.dumps(metadata) return json.dumps(metadata)
@staticmethod @staticmethod

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Optional
from embedchain.helpers.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable
@@ -18,9 +18,9 @@ class BaseMessage(JSONSerializable):
created_by: str created_by: str
# Any additional info. # Any additional info.
metadata: Dict[str, Any] metadata: dict[str, Any]
def __init__(self, content: str, created_by: str, metadata: Optional[Dict[str, Any]] = None) -> None: def __init__(self, content: str, created_by: str, metadata: Optional[dict[str, Any]] = None) -> None:
super().__init__() super().__init__()
self.content = content self.content = content
self.created_by = created_by self.created_by = created_by

View File

@@ -1,16 +1,16 @@
from typing import Any, Dict, Optional from typing import Any, Optional
def merge_metadata_dict(left: Optional[Dict[str, Any]], right: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: def merge_metadata_dict(left: Optional[dict[str, Any]], right: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
""" """
Merge the metadatas of two BaseMessage types. Merge the metadatas of two BaseMessage types.
Args: Args:
left (Dict[str, Any]): metadata of human message left (dict[str, Any]): metadata of human message
right (Dict[str, Any]): metadata of AI message right (dict[str, Any]): metadata of AI message
Returns: Returns:
Dict[str, Any]: combined metadata dict with dedup dict[str, Any]: combined metadata dict with dedup
to be saved in db. to be saved in db.
""" """
if not left and not right: if not left and not right:

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Optional, Union
from chromadb import Collection, QueryResult from chromadb import Collection, QueryResult
from langchain.docstore.document import Document from langchain.docstore.document import Document
@@ -76,7 +76,7 @@ class ChromaDB(BaseVectorDB):
return self.client return self.client
@staticmethod @staticmethod
def _generate_where_clause(where: Dict[str, any]) -> Dict[str, any]: def _generate_where_clause(where: dict[str, any]) -> dict[str, any]:
# If only one filter is supplied, return it as is # If only one filter is supplied, return it as is
# (no need to wrap in $and based on chroma docs) # (no need to wrap in $and based on chroma docs)
if len(where.keys()) <= 1: if len(where.keys()) <= 1:
@@ -105,18 +105,18 @@ class ChromaDB(BaseVectorDB):
) )
return self.collection return self.collection
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None): def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
""" """
Get existing doc ids present in vector database Get existing doc ids present in vector database
:param ids: list of doc ids to check for existence :param ids: list of doc ids to check for existence
:type ids: List[str] :type ids: list[str]
:param where: Optional. to filter data :param where: Optional. to filter data
:type where: Dict[str, Any] :type where: dict[str, Any]
:param limit: Optional. maximum number of documents :param limit: Optional. maximum number of documents
:type limit: Optional[int] :type limit: Optional[int]
:return: Existing documents. :return: Existing documents.
:rtype: List[str] :rtype: list[str]
""" """
args = {} args = {}
if ids: if ids:
@@ -129,23 +129,23 @@ class ChromaDB(BaseVectorDB):
def add( def add(
self, self,
embeddings: List[List[float]], embeddings: list[list[float]],
documents: List[str], documents: list[str],
metadatas: List[object], metadatas: list[object],
ids: List[str], ids: list[str],
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[dict[str, Any]],
) -> Any: ) -> Any:
""" """
Add vectors to chroma database Add vectors to chroma database
:param embeddings: list of embeddings to add :param embeddings: list of embeddings to add
:type embeddings: List[List[str]] :type embeddings: list[list[str]]
:param documents: Documents :param documents: Documents
:type documents: List[str] :type documents: list[str]
:param metadatas: Metadatas :param metadatas: Metadatas
:type metadatas: List[object] :type metadatas: list[object]
:param ids: ids :param ids: ids
:type ids: List[str] :type ids: list[str]
""" """
size = len(documents) size = len(documents)
if len(documents) != size or len(metadatas) != size or len(ids) != size: if len(documents) != size or len(metadatas) != size or len(ids) != size:
@@ -182,27 +182,27 @@ class ChromaDB(BaseVectorDB):
def query( def query(
self, self,
input_query: List[str], input_query: list[str],
n_results: int, n_results: int,
where: Dict[str, any], where: dict[str, any],
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]: ) -> Union[list[tuple[str, dict]], list[str]]:
""" """
Query contents from vector database based on vector similarity Query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
:type input_query: List[str] :type input_query: list[str]
:param n_results: no of similar documents to fetch from database :param n_results: no of similar documents to fetch from database
:type n_results: int :type n_results: int
:param where: to filter data :param where: to filter data
:type where: Dict[str, Any] :type where: dict[str, Any]
:param citations: we use citations boolean param to return context along with the answer. :param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False. :type citations: bool, default is False.
:raises InvalidDimensionException: Dimensions do not match. :raises InvalidDimensionException: Dimensions do not match.
:return: The content of the document that matched your query, :return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true) along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]] :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
""" """
try: try:
result = self.collection.query( result = self.collection.query(

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Optional, Union
try: try:
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
@@ -84,14 +84,14 @@ class ElasticsearchDB(BaseVectorDB):
def _get_or_create_collection(self, name): def _get_or_create_collection(self, name):
"""Note: nothing to return here. Discuss later""" """Note: nothing to return here. Discuss later"""
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None): def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
""" """
Get existing doc ids present in vector database Get existing doc ids present in vector database
:param ids: _list of doc ids to check for existence :param ids: _list of doc ids to check for existence
:type ids: List[str] :type ids: list[str]
:param where: to filter data :param where: to filter data
:type where: Dict[str, any] :type where: dict[str, any]
:return: ids :return: ids
:rtype: Set[str] :rtype: Set[str]
""" """
@@ -110,22 +110,22 @@ class ElasticsearchDB(BaseVectorDB):
def add( def add(
self, self,
embeddings: List[List[float]], embeddings: list[list[float]],
documents: List[str], documents: list[str],
metadatas: List[object], metadatas: list[object],
ids: List[str], ids: list[str],
**kwargs: Optional[Dict[str, any]], **kwargs: Optional[dict[str, any]],
) -> Any: ) -> Any:
""" """
add data in vector database add data in vector database
:param embeddings: list of embeddings to add :param embeddings: list of embeddings to add
:type embeddings: List[List[str]] :type embeddings: list[list[str]]
:param documents: list of texts to add :param documents: list of texts to add
:type documents: List[str] :type documents: list[str]
:param metadatas: list of metadata associated with docs :param metadatas: list of metadata associated with docs
:type metadatas: List[object] :type metadatas: list[object]
:param ids: ids of docs :param ids: ids of docs
:type ids: List[str] :type ids: list[str]
""" """
embeddings = self.embedder.embedding_fn(documents) embeddings = self.embedder.embedding_fn(documents)
@@ -154,27 +154,27 @@ class ElasticsearchDB(BaseVectorDB):
def query( def query(
self, self,
input_query: List[str], input_query: list[str],
n_results: int, n_results: int,
where: Dict[str, any], where: dict[str, any],
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]: ) -> Union[list[tuple[str, dict]], list[str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
:type input_query: List[str] :type input_query: list[str]
:param n_results: no of similar documents to fetch from database :param n_results: no of similar documents to fetch from database
:type n_results: int :type n_results: int
:param where: Optional. to filter data :param where: Optional. to filter data
:type where: Dict[str, any] :type where: dict[str, any]
:return: The context of the document that matched your query, url of the source, doc_id :return: The context of the document that matched your query, url of the source, doc_id
:param citations: we use citations boolean param to return context along with the answer. :param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False. :type citations: bool, default is False.
:return: The content of the document that matched your query, :return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true) along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]] :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
""" """
input_query_vector = self.embedder.embedding_fn(input_query) input_query_vector = self.embedder.embedding_fn(input_query)
query_vector = input_query_vector[0] query_vector = input_query_vector[0]

View File

@@ -1,6 +1,6 @@
import logging import logging
import time import time
from typing import Any, Dict, List, Optional, Set, Tuple, Union from typing import Any, Optional, Union
from tqdm import tqdm from tqdm import tqdm
@@ -78,17 +78,17 @@ class OpenSearchDB(BaseVectorDB):
"""Note: nothing to return here. Discuss later""" """Note: nothing to return here. Discuss later"""
def get( def get(
self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None
) -> Set[str]: ) -> set[str]:
""" """
Get existing doc ids present in vector database Get existing doc ids present in vector database
:param ids: _list of doc ids to check for existence :param ids: _list of doc ids to check for existence
:type ids: List[str] :type ids: list[str]
:param where: to filter data :param where: to filter data
:type where: Dict[str, any] :type where: dict[str, any]
:return: ids :return: ids
:type: Set[str] :type: set[str]
""" """
query = {} query = {}
if ids: if ids:
@@ -116,19 +116,19 @@ class OpenSearchDB(BaseVectorDB):
def add( def add(
self, self,
embeddings: List[List[str]], embeddings: list[list[str]],
documents: List[str], documents: list[str],
metadatas: List[object], metadatas: list[object],
ids: List[str], ids: list[str],
**kwargs: Optional[Dict[str, any]], **kwargs: Optional[dict[str, any]],
): ):
"""Add data in vector database. """Add data in vector database.
Args: Args:
embeddings (List[List[str]]): List of embeddings to add. embeddings (list[list[str]]): list of embeddings to add.
documents (List[str]): List of texts to add. documents (list[str]): list of texts to add.
metadatas (List[object]): List of metadata associated with docs. metadatas (list[object]): list of metadata associated with docs.
ids (List[str]): IDs of docs. ids (list[str]): IDs of docs.
""" """
for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"): for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"):
batch_end = batch_start + self.BATCH_SIZE batch_end = batch_start + self.BATCH_SIZE
@@ -156,26 +156,26 @@ class OpenSearchDB(BaseVectorDB):
def query( def query(
self, self,
input_query: List[str], input_query: list[str],
n_results: int, n_results: int,
where: Dict[str, any], where: dict[str, any],
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]: ) -> Union[list[tuple[str, dict]], list[str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
:type input_query: List[str] :type input_query: list[str]
:param n_results: no of similar documents to fetch from database :param n_results: no of similar documents to fetch from database
:type n_results: int :type n_results: int
:param where: Optional. to filter data :param where: Optional. to filter data
:type where: Dict[str, any] :type where: dict[str, any]
:param citations: we use citations boolean param to return context along with the answer. :param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False. :type citations: bool, default is False.
:return: The content of the document that matched your query, :return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true) along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]] :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
""" """
embeddings = OpenAIEmbeddings() embeddings = OpenAIEmbeddings()
docsearch = OpenSearchVectorSearch( docsearch = OpenSearchVectorSearch(

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Dict, List, Optional, Tuple, Union from typing import Optional, Union
try: try:
import pinecone import pinecone
@@ -67,14 +67,14 @@ class PineconeDB(BaseVectorDB):
) )
return pinecone.Index(self.index_name) return pinecone.Index(self.index_name)
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None): def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
""" """
Get existing doc ids present in vector database Get existing doc ids present in vector database
:param ids: _list of doc ids to check for existence :param ids: _list of doc ids to check for existence
:type ids: List[str] :type ids: list[str]
:param where: to filter data :param where: to filter data
:type where: Dict[str, any] :type where: dict[str, any]
:return: ids :return: ids
:rtype: Set[str] :rtype: Set[str]
""" """
@@ -88,20 +88,20 @@ class PineconeDB(BaseVectorDB):
def add( def add(
self, self,
embeddings: List[List[float]], embeddings: list[list[float]],
documents: List[str], documents: list[str],
metadatas: List[object], metadatas: list[object],
ids: List[str], ids: list[str],
**kwargs: Optional[Dict[str, any]], **kwargs: Optional[dict[str, any]],
): ):
"""add data in vector database """add data in vector database
:param documents: list of texts to add :param documents: list of texts to add
:type documents: List[str] :type documents: list[str]
:param metadatas: list of metadata associated with docs :param metadatas: list of metadata associated with docs
:type metadatas: List[object] :type metadatas: list[object]
:param ids: ids of docs :param ids: ids of docs
:type ids: List[str] :type ids: list[str]
""" """
docs = [] docs = []
print("Adding documents to Pinecone...") print("Adding documents to Pinecone...")
@@ -120,25 +120,25 @@ class PineconeDB(BaseVectorDB):
def query( def query(
self, self,
input_query: List[str], input_query: list[str],
n_results: int, n_results: int,
where: Dict[str, any], where: dict[str, any],
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, any]], **kwargs: Optional[dict[str, any]],
) -> Union[List[Tuple[str, Dict]], List[str]]: ) -> Union[list[tuple[str, dict]], list[str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
:type input_query: List[str] :type input_query: list[str]
:param n_results: no of similar documents to fetch from database :param n_results: no of similar documents to fetch from database
:type n_results: int :type n_results: int
:param where: Optional. to filter data :param where: Optional. to filter data
:type where: Dict[str, any] :type where: dict[str, any]
:param citations: we use citations boolean param to return context along with the answer. :param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False. :type citations: bool, default is False.
:return: The content of the document that matched your query, :return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true) along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]] :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
""" """
query_vector = self.embedder.embedding_fn([input_query])[0] query_vector = self.embedder.embedding_fn([input_query])[0]
data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True, **kwargs) data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True, **kwargs)

View File

@@ -1,7 +1,7 @@
import copy import copy
import os import os
import uuid import uuid
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Optional, Union
try: try:
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
@@ -69,14 +69,14 @@ class QdrantDB(BaseVectorDB):
def _get_or_create_collection(self): def _get_or_create_collection(self):
return f"{self.config.collection_name}-{self.embedder.vector_dimension}".lower().replace("_", "-") return f"{self.config.collection_name}-{self.embedder.vector_dimension}".lower().replace("_", "-")
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None): def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
""" """
Get existing doc ids present in vector database Get existing doc ids present in vector database
:param ids: _list of doc ids to check for existence :param ids: _list of doc ids to check for existence
:type ids: List[str] :type ids: list[str]
:param where: to filter data :param where: to filter data
:type where: Dict[str, any] :type where: dict[str, any]
:param limit: The number of entries to be fetched :param limit: The number of entries to be fetched
:type limit: Optional int, defaults to None :type limit: Optional int, defaults to None
:return: All the existing IDs :return: All the existing IDs
@@ -122,21 +122,21 @@ class QdrantDB(BaseVectorDB):
def add( def add(
self, self,
embeddings: List[List[float]], embeddings: list[list[float]],
documents: List[str], documents: list[str],
metadatas: List[object], metadatas: list[object],
ids: List[str], ids: list[str],
**kwargs: Optional[Dict[str, any]], **kwargs: Optional[dict[str, any]],
): ):
"""add data in vector database """add data in vector database
:param embeddings: list of embeddings for the corresponding documents to be added :param embeddings: list of embeddings for the corresponding documents to be added
:type documents: List[List[float]] :type documents: list[list[float]]
:param documents: list of texts to add :param documents: list of texts to add
:type documents: List[str] :type documents: list[str]
:param metadatas: list of metadata associated with docs :param metadatas: list of metadata associated with docs
:type metadatas: List[object] :type metadatas: list[object]
:param ids: ids of docs :param ids: ids of docs
:type ids: List[str] :type ids: list[str]
""" """
embeddings = self.embedder.embedding_fn(documents) embeddings = self.embedder.embedding_fn(documents)
@@ -159,25 +159,25 @@ class QdrantDB(BaseVectorDB):
def query( def query(
self, self,
input_query: List[str], input_query: list[str],
n_results: int, n_results: int,
where: Dict[str, any], where: dict[str, any],
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]: ) -> Union[list[tuple[str, dict]], list[str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
:type input_query: List[str] :type input_query: list[str]
:param n_results: no of similar documents to fetch from database :param n_results: no of similar documents to fetch from database
:type n_results: int :type n_results: int
:param where: Optional. to filter data :param where: Optional. to filter data
:type where: Dict[str, any] :type where: dict[str, any]
:param citations: we use citations boolean param to return context along with the answer. :param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False. :type citations: bool, default is False.
:return: The content of the document that matched your query, :return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true) along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]] :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
""" """
query_vector = self.embedder.embedding_fn([input_query])[0] query_vector = self.embedder.embedding_fn([input_query])[0]
keys = set(where.keys() if where is not None else set()) keys = set(where.keys() if where is not None else set())

View File

@@ -1,6 +1,6 @@
import copy import copy
import os import os
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Optional, Union
try: try:
import weaviate import weaviate
@@ -117,13 +117,13 @@ class WeaviateDB(BaseVectorDB):
self.client.schema.create(class_obj) self.client.schema.create(class_obj)
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None): def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
""" """
Get existing doc ids present in vector database Get existing doc ids present in vector database
:param ids: _list of doc ids to check for existance :param ids: _list of doc ids to check for existance
:type ids: List[str] :type ids: list[str]
:param where: to filter data :param where: to filter data
:type where: Dict[str, any] :type where: dict[str, any]
:return: ids :return: ids
:rtype: Set[str] :rtype: Set[str]
""" """
@@ -153,21 +153,21 @@ class WeaviateDB(BaseVectorDB):
def add( def add(
self, self,
embeddings: List[List[float]], embeddings: list[list[float]],
documents: List[str], documents: list[str],
metadatas: List[object], metadatas: list[object],
ids: List[str], ids: list[str],
**kwargs: Optional[Dict[str, any]], **kwargs: Optional[dict[str, any]],
): ):
"""add data in vector database """add data in vector database
:param embeddings: list of embeddings for the corresponding documents to be added :param embeddings: list of embeddings for the corresponding documents to be added
:type documents: List[List[float]] :type documents: list[list[float]]
:param documents: list of texts to add :param documents: list of texts to add
:type documents: List[str] :type documents: list[str]
:param metadatas: list of metadata associated with docs :param metadatas: list of metadata associated with docs
:type metadatas: List[object] :type metadatas: list[object]
:param ids: ids of docs :param ids: ids of docs
:type ids: List[str] :type ids: list[str]
""" """
embeddings = self.embedder.embedding_fn(documents) embeddings = self.embedder.embedding_fn(documents)
self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3) # Configure batch self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3) # Configure batch
@@ -192,25 +192,25 @@ class WeaviateDB(BaseVectorDB):
def query( def query(
self, self,
input_query: List[str], input_query: list[str],
n_results: int, n_results: int,
where: Dict[str, any], where: dict[str, any],
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]: ) -> Union[list[tuple[str, dict]], list[str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
:type input_query: List[str] :type input_query: list[str]
:param n_results: no of similar documents to fetch from database :param n_results: no of similar documents to fetch from database
:type n_results: int :type n_results: int
:param where: Optional. to filter data :param where: Optional. to filter data
:type where: Dict[str, any] :type where: dict[str, any]
:param citations: we use citations boolean param to return context along with the answer. :param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False. :type citations: bool, default is False.
:return: The content of the document that matched your query, :return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true) along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]] :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
""" """
query_vector = self.embedder.embedding_fn([input_query])[0] query_vector = self.embedder.embedding_fn([input_query])[0]
keys = set(where.keys() if where is not None else set()) keys = set(where.keys() if where is not None else set())

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Optional, Union
from embedchain.config import ZillizDBConfig from embedchain.config import ZillizDBConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@@ -88,14 +88,14 @@ class ZillizVectorDB(BaseVectorDB):
self.collection.create_index("embeddings", index) self.collection.create_index("embeddings", index)
return self.collection return self.collection
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None): def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
""" """
Get existing doc ids present in vector database Get existing doc ids present in vector database
:param ids: list of doc ids to check for existence :param ids: list of doc ids to check for existence
:type ids: List[str] :type ids: list[str]
:param where: Optional. to filter data :param where: Optional. to filter data
:type where: Dict[str, Any] :type where: dict[str, Any]
:param limit: Optional. maximum number of documents :param limit: Optional. maximum number of documents
:type limit: Optional[int] :type limit: Optional[int]
:return: Existing documents. :return: Existing documents.
@@ -115,11 +115,11 @@ class ZillizVectorDB(BaseVectorDB):
def add( def add(
self, self,
embeddings: List[List[float]], embeddings: list[list[float]],
documents: List[str], documents: list[str],
metadatas: List[object], metadatas: list[object],
ids: List[str], ids: list[str],
**kwargs: Optional[Dict[str, any]], **kwargs: Optional[dict[str, any]],
): ):
"""Add to database""" """Add to database"""
embeddings = self.embedder.embedding_fn(documents) embeddings = self.embedder.embedding_fn(documents)
@@ -134,17 +134,17 @@ class ZillizVectorDB(BaseVectorDB):
def query( def query(
self, self,
input_query: List[str], input_query: list[str],
n_results: int, n_results: int,
where: Dict[str, any], where: dict[str, any],
citations: bool = False, citations: bool = False,
**kwargs: Optional[Dict[str, Any]], **kwargs: Optional[dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]: ) -> Union[list[tuple[str, dict]], list[str]]:
""" """
Query contents from vector database based on vector similarity Query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
:type input_query: List[str] :type input_query: list[str]
:param n_results: no of similar documents to fetch from database :param n_results: no of similar documents to fetch from database
:type n_results: int :type n_results: int
:param where: to filter data :param where: to filter data
@@ -154,7 +154,7 @@ class ZillizVectorDB(BaseVectorDB):
:type citations: bool, default is False. :type citations: bool, default is False.
:return: The content of the document that matched your query, :return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true) along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]] :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
""" """
if self.collection.is_empty: if self.collection.is_empty:
@@ -200,7 +200,7 @@ class ZillizVectorDB(BaseVectorDB):
""" """
return self.collection.num_entities return self.collection.num_entities
def reset(self, collection_names: List[str] = None): def reset(self, collection_names: list[str] = None):
""" """
Resets the database. Deletes all embeddings irreversibly. Resets the database. Deletes all embeddings irreversibly.
""" """