From 0de9491c61598ace9dcd5e05fa380f5b20e7112d Mon Sep 17 00:00:00 2001
From: Sandra Serrano <69241131+sandrasgg@users.noreply.github.com>
Date: Tue, 9 Jan 2024 18:35:24 +0100
Subject: [PATCH] #1128 | Remove deprecated type hints from typing module (#1131)

---
 embedchain/app.py                           |  6 +--
 embedchain/bots/poe.py                      |  6 +--
 embedchain/cache.py                         |  4 +-
 embedchain/config/add_config.py             |  3 +-
 embedchain/config/base_config.py            |  6 +--
 embedchain/config/cache_config.py           |  8 ++--
 embedchain/config/llm/base.py               | 12 ++---
 embedchain/config/vectordb/elasticsearch.py | 12 ++---
 embedchain/config/vectordb/opensearch.py    |  8 ++--
 embedchain/config/vectordb/pinecone.py      |  4 +-
 embedchain/config/vectordb/qdrant.py        | 12 ++---
 embedchain/config/vectordb/weaviate.py      |  4 +-
 embedchain/embedchain.py                    | 50 ++++++++++-----------
 embedchain/embedder/base.py                 |  3 +-
 embedchain/helpers/callbacks.py             |  4 +-
 embedchain/helpers/json_serializable.py     |  6 +--
 embedchain/llm/base.py                      | 21 ++++-----
 embedchain/llm/google.py                    |  3 +-
 embedchain/llm/gpt4all.py                   |  3 +-
 embedchain/llm/ollama.py                    |  3 +-
 embedchain/llm/openai.py                    |  4 +-
 embedchain/loaders/directory_loader.py      |  4 +-
 embedchain/loaders/discourse.py             |  4 +-
 embedchain/loaders/dropbox.py               |  3 +-
 embedchain/loaders/github.py                |  4 +-
 embedchain/loaders/gmail.py                 | 10 ++---
 embedchain/loaders/json.py                  |  8 ++--
 embedchain/loaders/mysql.py                 |  6 +--
 embedchain/loaders/notion.py                |  6 +--
 embedchain/loaders/postgres.py              |  6 +--
 embedchain/loaders/slack.py                 |  6 +--
 embedchain/memory/base.py                   |  6 +--
 embedchain/memory/message.py                |  6 +--
 embedchain/memory/utils.py                  | 10 ++---
 embedchain/vectordb/chroma.py               | 44 +++++++++---------
 embedchain/vectordb/elasticsearch.py        | 40 ++++++++---------
 embedchain/vectordb/opensearch.py           | 44 +++++++++---------
 embedchain/vectordb/pinecone.py             | 38 ++++++++--------
 embedchain/vectordb/qdrant.py               | 40 ++++++++---------
 embedchain/vectordb/weaviate.py             | 40 ++++++++---------
 embedchain/vectordb/zilliz.py               | 32 ++++++-------
 41 files changed, 272 insertions(+), 267 deletions(-)

diff --git a/embedchain/app.py b/embedchain/app.py
index 45f422c9..72f5bfd6 100644
--- a/embedchain/app.py
+++ b/embedchain/app.py
@@ -4,7 +4,7 @@ import logging
 import os
 import sqlite3
 import uuid
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import requests
 import yaml
@@ -364,7 +364,7 @@ class App(EmbedChain):
     def from_config(
         cls,
         config_path: Optional[str] = None,
-        config: Optional[Dict[str, Any]] = None,
+        config: Optional[dict[str, Any]] = None,
         auto_deploy: bool = False,
         yaml_path: Optional[str] = None,
     ):
@@ -374,7 +374,7 @@ class App(EmbedChain):
         :param config_path: Path to the YAML or JSON configuration file.
         :type config_path: Optional[str]
         :param config: A dictionary containing the configuration.
-        :type config: Optional[Dict[str, Any]]
+        :type config: Optional[dict[str, Any]]
         :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
         :type auto_deploy: bool, optional
         :param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.
diff --git a/embedchain/bots/poe.py b/embedchain/bots/poe.py
index 762d090b..4753eed8 100644
--- a/embedchain/bots/poe.py
+++ b/embedchain/bots/poe.py
@@ -1,7 +1,7 @@
 import argparse
 import logging
 import os
-from typing import List, Optional
+from typing import Optional
 
 from embedchain.helpers.json_serializable import register_deserializable
 
@@ -53,7 +53,7 @@ class PoeBot(BaseBot, PoeBot):
         answer = self.handle_message(last_message, history)
         yield self.text_event(answer)
 
-    def handle_message(self, message, history: Optional[List[str]] = None):
+    def handle_message(self, message, history: Optional[list[str]] = None):
         if message.startswith("/add "):
             response = self.add_data(message)
         else:
@@ -70,7 +70,7 @@ class PoeBot(BaseBot, PoeBot):
     #         response = "Some error occurred while adding data."
     #         return response
 
-    def ask_bot(self, message, history: List[str]):
+    def ask_bot(self, message, history: list[str]):
         try:
             self.app.llm.set_history(history=history)
             response = self.query(message)
diff --git a/embedchain/cache.py b/embedchain/cache.py
index ba1675ed..4dd4ccc4 100644
--- a/embedchain/cache.py
+++ b/embedchain/cache.py
@@ -1,6 +1,6 @@
 import logging
 import os  # noqa: F401
-from typing import Any, Dict
+from typing import Any
 
 from gptcache import cache  # noqa: F401
 from gptcache.adapter.adapter import adapt  # noqa: F401
@@ -15,7 +15,7 @@ from gptcache.similarity_evaluation.exact_match import \
     ExactMatchEvaluation  # noqa: F401
 
 
-def gptcache_pre_function(data: Dict[str, Any], **params: Dict[str, Any]):
+def gptcache_pre_function(data: dict[str, Any], **params: dict[str, Any]):
     return data["input_query"]
diff --git a/embedchain/config/add_config.py b/embedchain/config/add_config.py
index 270b8a81..56686e8e 100644
--- a/embedchain/config/add_config.py
+++ b/embedchain/config/add_config.py
@@ -1,7 +1,8 @@
 import builtins
 import logging
+from collections.abc import Callable
 from importlib import import_module
-from typing import Callable, Optional
+from typing import Optional
 
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable
diff --git a/embedchain/config/base_config.py b/embedchain/config/base_config.py
index ff672f19..bf7869f4 100644
--- a/embedchain/config/base_config.py
+++ b/embedchain/config/base_config.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any
 
 from embedchain.helpers.json_serializable import JSONSerializable
 
@@ -12,10 +12,10 @@ class BaseConfig(JSONSerializable):
         """Initializes a configuration class for a class."""
         pass
 
-    def as_dict(self) -> Dict[str, Any]:
+    def as_dict(self) -> dict[str, Any]:
         """Return config object as a dict
 
         :return: config object as dict
-        :rtype: Dict[str, Any]
+        :rtype: dict[str, Any]
         """
         return vars(self)
diff --git a/embedchain/config/cache_config.py b/embedchain/config/cache_config.py
index 1d115be8..ef8bd1fb 100644
--- a/embedchain/config/cache_config.py
+++ b/embedchain/config/cache_config.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -30,7 +30,7 @@ class CacheSimilarityEvalConfig(BaseConfig):
         self.positive = positive
 
     @staticmethod
-    def from_config(config: Optional[Dict[str, Any]]):
+    def from_config(config: Optional[dict[str, Any]]):
         if config is None:
             return CacheSimilarityEvalConfig()
         else:
@@ -65,7 +65,7 @@ class CacheInitConfig(BaseConfig):
         self.auto_flush = auto_flush
 
     @staticmethod
-    def from_config(config: Optional[Dict[str, Any]]):
+    def from_config(config: Optional[dict[str, Any]]):
         if config is None:
             return CacheInitConfig()
         else:
@@ -86,7 +86,7 @@ class CacheConfig(BaseConfig):
         self.init_config = init_config
 
     @staticmethod
-    def from_config(config: Optional[Dict[str, Any]]):
+    def from_config(config: Optional[dict[str, Any]]):
         if config is None:
             return CacheConfig()
         else:
diff --git a/embedchain/config/llm/base.py b/embedchain/config/llm/base.py
index 7fa84c94..83d2f7d1 100644
--- a/embedchain/config/llm/base.py
+++ b/embedchain/config/llm/base.py
@@ -1,7 +1,7 @@
 import logging
 import re
 from string import Template
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -68,12 +68,12 @@ class BaseLlmConfig(BaseConfig):
         stream: bool = False,
         deployment_name: Optional[str] = None,
         system_prompt: Optional[str] = None,
-        where: Dict[str, Any] = None,
+        where: dict[str, Any] = None,
         query_type: Optional[str] = None,
-        callbacks: Optional[List] = None,
+        callbacks: Optional[list] = None,
         api_key: Optional[str] = None,
         endpoint: Optional[str] = None,
-        model_kwargs: Optional[Dict[str, Any]] = None,
+        model_kwargs: Optional[dict[str, Any]] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -106,7 +106,7 @@ class BaseLlmConfig(BaseConfig):
         :param system_prompt: System prompt string, defaults to None
         :type system_prompt: Optional[str], optional
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
-        :type where: Dict[str, Any], optional
+        :type where: dict[str, Any], optional
         :param api_key: The api key of the custom endpoint, defaults to None
         :type api_key: Optional[str], optional
         :param endpoint: The api url of the custom endpoint, defaults to None
         :param model_kwargs: A dictionary of key-value pairs to pass to the model, defaults to None
         :type model_kwargs: Optional[Dict[str, Any]], optional
         :param callbacks: Langchain callback functions to use, defaults to None
-        :type callbacks: Optional[List], optional
+        :type callbacks: Optional[list], optional
         :param query_type: The type of query to use, defaults to None
         :type query_type: Optional[str], optional
         :raises ValueError: If the template is not valid as template should
diff --git a/embedchain/config/vectordb/elasticsearch.py b/embedchain/config/vectordb/elasticsearch.py
index 7ccf4226..700a7192 100644
--- a/embedchain/config/vectordb/elasticsearch.py
+++ b/embedchain/config/vectordb/elasticsearch.py
@@ -1,5 +1,5 @@
 import os
-from typing import Dict, List, Optional, Union
+from typing import Optional, Union
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -11,9 +11,9 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
         self,
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
-        es_url: Union[str, List[str]] = None,
+        es_url: Union[str, list[str]] = None,
         cloud_id: Optional[str] = None,
-        **ES_EXTRA_PARAMS: Dict[str, any],
+        **ES_EXTRA_PARAMS: dict[str, any],
     ):
         """
         Initializes a configuration class instance for an Elasticsearch client.
@@ -23,13 +23,13 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
         :param dir: Path to the database directory, where the database is stored, defaults to None
         :type dir: Optional[str], optional
         :param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
-        :type es_url: Union[str, List[str]], optional
+        :type es_url: Union[str, list[str]], optional
         :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
-        :type ES_EXTRA_PARAMS: Dict[str, Any], optional
+        :type ES_EXTRA_PARAMS: dict[str, Any], optional
         """
         if es_url and cloud_id:
             raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
-        # self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
+        # self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]):
         self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
         self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
         if not self.ES_URL and not self.CLOUD_ID:
diff --git a/embedchain/config/vectordb/opensearch.py b/embedchain/config/vectordb/opensearch.py
index d8dc9a10..1e112772 100644
--- a/embedchain/config/vectordb/opensearch.py
+++ b/embedchain/config/vectordb/opensearch.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Tuple
+from typing import Optional
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -9,11 +9,11 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
     def __init__(
         self,
         opensearch_url: str,
-        http_auth: Tuple[str, str],
+        http_auth: tuple[str, str],
         vector_dimension: int = 1536,
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
-        **extra_params: Dict[str, any],
+        **extra_params: dict[str, any],
     ):
         """
         Initializes a configuration class instance for an OpenSearch client.
@@ -23,7 +23,7 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
         :param opensearch_url: URL of the OpenSearch domain
         :type opensearch_url: str, Eg, "http://localhost:9200"
         :param http_auth: Tuple of username and password
-        :type http_auth: Tuple[str, str], Eg, ("username", "password")
+        :type http_auth: tuple[str, str], Eg, ("username", "password")
         :param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model)
         :type vector_dimension: int, optional
         :param dir: Path to the database directory, where the database is stored, defaults to None
diff --git a/embedchain/config/vectordb/pinecone.py b/embedchain/config/vectordb/pinecone.py
index e9165fdc..efb98c79 100644
--- a/embedchain/config/vectordb/pinecone.py
+++ b/embedchain/config/vectordb/pinecone.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Optional
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -12,7 +12,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
         dir: Optional[str] = None,
         vector_dimension: int = 1536,
         metric: Optional[str] = "cosine",
-        **extra_params: Dict[str, any],
+        **extra_params: dict[str, any],
     ):
         self.metric = metric
         self.vector_dimension = vector_dimension
diff --git a/embedchain/config/vectordb/qdrant.py b/embedchain/config/vectordb/qdrant.py
index 98e7c3f9..1268913e 100644
--- a/embedchain/config/vectordb/qdrant.py
+++ b/embedchain/config/vectordb/qdrant.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Optional
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -15,10 +15,10 @@ class QdrantDBConfig(BaseVectorDbConfig):
         self,
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
-        hnsw_config: Optional[Dict[str, any]] = None,
-        quantization_config: Optional[Dict[str, any]] = None,
+        hnsw_config: Optional[dict[str, any]] = None,
+        quantization_config: Optional[dict[str, any]] = None,
         on_disk: Optional[bool] = None,
-        **extra_params: Dict[str, any],
+        **extra_params: dict[str, any],
     ):
         """
         Initializes a configuration class instance for a qdrant client.
@@ -28,9 +28,9 @@ class QdrantDBConfig(BaseVectorDbConfig):
         :param dir: Path to the database directory, where the database is stored, defaults to None
         :type dir: Optional[str], optional
         :param hnsw_config: Params for HNSW index
-        :type hnsw_config: Optional[Dict[str, any]], defaults to None
+        :type hnsw_config: Optional[dict[str, any]], defaults to None
         :param quantization_config: Params for quantization, if None - quantization will be disabled
-        :type quantization_config: Optional[Dict[str, any]], defaults to None
+        :type quantization_config: Optional[dict[str, any]], defaults to None
         :param on_disk: If true - point`s payload will not be stored in memory.
         It will be read from the disk every time it is requested.
         This setting saves RAM by (slightly) increasing the response time.
diff --git a/embedchain/config/vectordb/weaviate.py b/embedchain/config/vectordb/weaviate.py
index 2db24134..3f5a353a 100644
--- a/embedchain/config/vectordb/weaviate.py
+++ b/embedchain/config/vectordb/weaviate.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Optional
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -10,7 +10,7 @@ class WeaviateDBConfig(BaseVectorDbConfig):
         self,
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
-        **extra_params: Dict[str, any],
+        **extra_params: dict[str, any],
     ):
         self.extra_params = extra_params
         super().__init__(collection_name=collection_name, dir=dir)
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 148d1f88..d99f038f 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -2,7 +2,7 @@ import hashlib
 import json
 import logging
 import sqlite3
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
@@ -136,12 +136,12 @@ class EmbedChain(JSONSerializable):
         self,
         source: Any,
         data_type: Optional[DataType] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         config: Optional[AddConfig] = None,
         dry_run=False,
         loader: Optional[BaseLoader] = None,
         chunker: Optional[BaseChunker] = None,
-        **kwargs: Optional[Dict[str, Any]],
+        **kwargs: Optional[dict[str, Any]],
     ):
         """
         Adds the data from the given URL to the vector db.
@@ -154,7 +154,7 @@ class EmbedChain(JSONSerializable):
         defaults to None
         :type data_type: Optional[DataType], optional
         :param metadata: Metadata associated with the data source., defaults to None
-        :type metadata: Optional[Dict[str, Any]], optional
+        :type metadata: Optional[dict[str, Any]], optional
         :param config: The `AddConfig` instance to use as configuration options., defaults to None
         :type config: Optional[AddConfig], optional
         :raises ValueError: Invalid data type
@@ -243,9 +243,9 @@
         self,
         source: Any,
         data_type: Optional[DataType] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         config: Optional[AddConfig] = None,
-        **kwargs: Optional[Dict[str, Any]],
+        **kwargs: Optional[dict[str, Any]],
     ):
         """
         Adds the data from the given URL to the vector db.
@@ -261,7 +261,7 @@
         defaults to None
         :type data_type: Optional[DataType], optional
         :param metadata: Metadata associated with the data source., defaults to None
-        :type metadata: Optional[Dict[str, Any]], optional
+        :type metadata: Optional[dict[str, Any]], optional
         :param config: The `AddConfig` instance to use as configuration options., defaults to None
         :type config: Optional[AddConfig], optional
         :raises ValueError: Invalid data type
@@ -342,11 +342,11 @@
         loader: BaseLoader,
         chunker: BaseChunker,
         src: Any,
-        metadata: Optional[Dict[str, Any]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         source_hash: Optional[str] = None,
         add_config: Optional[AddConfig] = None,
         dry_run=False,
-        **kwargs: Optional[Dict[str, Any]],
+        **kwargs: Optional[dict[str, Any]],
     ):
         """
         Loads the data from the given URL, chunks it, and adds it to database.
@@ -359,7 +359,7 @@
         :param source_hash: Hexadecimal hash of the source.
         :param dry_run: Optional.
         A dry run returns chunks and doesn't update DB.
         :type dry_run: bool, defaults to False
-        :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
+        :return: (list) documents (embedded text), (list) metadata, (list) ids, (int) number of chunks
         """
         existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
         app_id = self.config.id if self.config is not None else None
@@ -464,8 +464,8 @@
         config: Optional[BaseLlmConfig] = None,
         where=None,
         citations: bool = False,
-        **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+        **kwargs: Optional[dict[str, Any]],
+    ) -> Union[list[tuple[str, str, str]], list[str]]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query
@@ -479,7 +479,7 @@
         :param citations: A boolean to indicate if db should fetch citation source
         :type citations: bool
         :return: List of contents of the document that matched your query
-        :rtype: List[str]
+        :rtype: list[str]
         """
         query_config = config or self.llm.config
         if where is not None:
@@ -507,10 +507,10 @@
         input_query: str,
         config: BaseLlmConfig = None,
         dry_run=False,
-        where: Optional[Dict] = None,
+        where: Optional[dict] = None,
         citations: bool = False,
-        **kwargs: Dict[str, Any],
-    ) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
+        **kwargs: dict[str, Any],
+    ) -> Union[tuple[str, list[tuple[str, dict]]], str]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -525,13 +525,13 @@
         the LLM. The purpose is to test the prompt, not the response., defaults to False
         :type dry_run: bool, optional
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
-        :type where: Optional[Dict[str, str]], optional
+        :type where: Optional[dict[str, str]], optional
         :param kwargs: To read more params for the query function. Ex. we use citations boolean
         param to return context along with the answer
-        :type kwargs: Dict[str, Any]
+        :type kwargs: dict[str, Any]
         :return: The answer to the query, with citations if the citation flag is True
         or the dry run result
-        :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
+        :rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
         """
         contexts = self._retrieve_from_database(
             input_query=input_query, config=config, where=where, citations=citations, **kwargs
@@ -572,10 +572,10 @@
         config: Optional[BaseLlmConfig] = None,
         dry_run=False,
         session_id: str = "default",
-        where: Optional[Dict[str, str]] = None,
+        where: Optional[dict[str, str]] = None,
         citations: bool = False,
-        **kwargs: Dict[str, Any],
-    ) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
+        **kwargs: dict[str, Any],
+    ) -> Union[tuple[str, list[tuple[str, dict]]], str]:
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -594,13 +594,13 @@
         :param session_id: The session id to use for chat history, defaults to 'default'.
         :type session_id: Optional[str], optional
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
-        :type where: Optional[Dict[str, str]], optional
+        :type where: Optional[dict[str, str]], optional
         :param kwargs: To read more params for the query function. Ex. we use citations boolean
         param to return context along with the answer
-        :type kwargs: Dict[str, Any]
+        :type kwargs: dict[str, Any]
         :return: The answer to the query, with citations if the citation flag is True
         or the dry run result
-        :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
+        :rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
         """
         contexts = self._retrieve_from_database(
             input_query=input_query, config=config, where=where, citations=citations, **kwargs
diff --git a/embedchain/embedder/base.py b/embedchain/embedder/base.py
index 69f7a6c5..d0a0b082 100644
--- a/embedchain/embedder/base.py
+++ b/embedchain/embedder/base.py
@@ -1,4 +1,5 @@
-from typing import Any, Callable, Optional
+from collections.abc import Callable
+from typing import Any, Optional
 
 from embedchain.config.embedder.base import BaseEmbedderConfig
diff --git a/embedchain/helpers/callbacks.py b/embedchain/helpers/callbacks.py
index 3c7ab356..994e0fdc 100644
--- a/embedchain/helpers/callbacks.py
+++ b/embedchain/helpers/callbacks.py
@@ -1,5 +1,5 @@
 import queue
-from typing import Any, Dict, List, Union
+from typing import Any, Union
 
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.schema import LLMResult
@@ -29,7 +29,7 @@ class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
         super().__init__()
         self.q = q
 
-    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
+    def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> None:
         """Run when LLM starts running."""
         with self.q.mutex:
             self.q.queue.clear()
diff --git a/embedchain/helpers/json_serializable.py b/embedchain/helpers/json_serializable.py
index 2692b107..5f3179af 100644
--- a/embedchain/helpers/json_serializable.py
+++ b/embedchain/helpers/json_serializable.py
@@ -1,7 +1,7 @@
 import json
 import logging
 from string import Template
-from typing import Any, Dict, Type, TypeVar, Union
+from typing import Any, Type, TypeVar, Union
 
 T = TypeVar("T", bound="JSONSerializable")
 
@@ -84,7 +84,7 @@ class JSONSerializable:
             return cls()
 
     @staticmethod
-    def _auto_encoder(obj: Any) -> Union[Dict[str, Any], None]:
+    def _auto_encoder(obj: Any) -> Union[dict[str, Any], None]:
         """
         Automatically encode an object for JSON serialization.
 
@@ -126,7 +126,7 @@ class JSONSerializable:
             raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
 
     @classmethod
-    def _auto_decoder(cls, dct: Dict[str, Any]) -> Any:
+    def _auto_decoder(cls, dct: dict[str, Any]) -> Any:
         """
         Automatically decode a dictionary to an object during JSON deserialization.
diff --git a/embedchain/llm/base.py b/embedchain/llm/base.py
index ba464da2..1e002916 100644
--- a/embedchain/llm/base.py
+++ b/embedchain/llm/base.py
@@ -1,5 +1,6 @@
 import logging
-from typing import Any, Dict, Generator, List, Optional
+from collections.abc import Generator
+from typing import Any, Optional
 
 from langchain.schema import BaseMessage as LCBaseMessage
 
@@ -55,7 +56,7 @@ class BaseLlm(JSONSerializable):
         app_id: str,
         question: str,
         answer: str,
-        metadata: Optional[Dict[str, Any]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         session_id: str = "default",
     ):
         chat_message = ChatMessage()
@@ -64,7 +65,7 @@ class BaseLlm(JSONSerializable):
         self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
         self.update_history(app_id=app_id, session_id=session_id)
 
-    def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
+    def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
         """
         Generates a prompt based on the given query and context, ready to be
         passed to an LLM
@@ -72,7 +73,7 @@
         :param input_query: The query to use.
         :type input_query: str
         :param contexts: List of similar documents to the query used as context.
-        :type contexts: List[str]
+        :type contexts: list[str]
         :return: The prompt
         :rtype: str
         """
@@ -170,7 +171,7 @@
             yield chunk
         logging.info(f"Answer: {streamed_answer}")
 
-    def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
+    def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
        """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -179,7 +180,7 @@
         :param input_query: The query to use.
         :type input_query: str
         :param contexts: Embeddings retrieved from the database to be used as context.
-        :type contexts: List[str]
+        :type contexts: list[str]
         :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
         To persistently use a config, declare it during app init., defaults to None
         :type config: Optional[BaseLlmConfig], optional
@@ -223,7 +224,7 @@
             self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
 
     def chat(
-        self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
+        self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
     ):
         """
         Queries the vector database on the given input query.
@@ -235,7 +236,7 @@
         :param input_query: The query to use.
         :type input_query: str
         :param contexts: Embeddings retrieved from the database to be used as context.
-        :type contexts: List[str]
+        :type contexts: list[str]
         :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
         To persistently use a config, declare it during app init., defaults to None
         :type config: Optional[BaseLlmConfig], optional
@@ -281,7 +282,7 @@
             self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
 
     @staticmethod
-    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
+    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> list[LCBaseMessage]:
         """
         Construct a list of langchain messages
 
@@ -290,7 +291,7 @@
         :param system_prompt: System prompt, defaults to None
         :type system_prompt: Optional[str], optional
         :return: List of messages
-        :rtype: List[BaseMessage]
+        :rtype: list[BaseMessage]
         """
         from langchain.schema import HumanMessage, SystemMessage
diff --git a/embedchain/llm/google.py b/embedchain/llm/google.py
index 9c23f352..6f41e5e9 100644
--- a/embedchain/llm/google.py
+++ b/embedchain/llm/google.py
@@ -1,7 +1,8 @@
 import importlib
 import logging
 import os
-from typing import Any, Generator, Optional, Union
+from collections.abc import Generator
+from typing import Any, Optional, Union
 
 import google.generativeai as genai
diff --git a/embedchain/llm/gpt4all.py b/embedchain/llm/gpt4all.py
index fe4d6970..ce1e65bb 100644
--- a/embedchain/llm/gpt4all.py
+++ b/embedchain/llm/gpt4all.py
@@ -1,6 +1,7 @@
 import os
+from collections.abc import Iterable
 from pathlib import Path
-from typing import Iterable, Optional, Union
+from typing import Optional, Union
 
 from langchain.callbacks.stdout import StdOutCallbackHandler
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
diff --git a/embedchain/llm/ollama.py b/embedchain/llm/ollama.py
index 9cd5c184..237a8797 100644
--- a/embedchain/llm/ollama.py
+++ b/embedchain/llm/ollama.py
@@ -1,4 +1,5 @@
-from typing import Iterable, Optional, Union
+from collections.abc import Iterable
+from typing import Optional, Union
 
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.stdout import StdOutCallbackHandler
diff --git a/embedchain/llm/openai.py b/embedchain/llm/openai.py
index 78c2f136..639ec919 100644
--- a/embedchain/llm/openai.py
+++ b/embedchain/llm/openai.py
@@ -1,6 +1,6 @@
 import json
 import os
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from langchain.chat_models import ChatOpenAI
 from langchain.schema import AIMessage, HumanMessage, SystemMessage
@@ -12,7 +12,7 @@ from embedchain.llm.base import BaseLlm
 
 @register_deserializable
 class OpenAILlm(BaseLlm):
-    def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[Dict[str, Any]] = None):
+    def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[dict[str, Any]] = None):
         self.functions = functions
         super().__init__(config=config)
diff --git a/embedchain/loaders/directory_loader.py b/embedchain/loaders/directory_loader.py
index 51cd71db..915c0249 100644
--- a/embedchain/loaders/directory_loader.py
+++ b/embedchain/loaders/directory_loader.py
@@ -1,7 +1,7 @@
 import hashlib
 import logging
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from embedchain.config import AddConfig
 from embedchain.data_formatter.data_formatter import DataFormatter
@@ -15,7 +15,7 @@ from embedchain.utils.misc import detect_datatype
 
 class DirectoryLoader(BaseLoader):
     """Load data from a directory."""
 
-    def __init__(self, config: Optional[Dict[str, Any]] = None):
+    def __init__(self, config: Optional[dict[str, Any]] = None):
         super().__init__()
         config = config or {}
         self.recursive = config.get("recursive", True)
diff --git a/embedchain/loaders/discourse.py b/embedchain/loaders/discourse.py
index 363bcb8a..22c158ab 100644
--- a/embedchain/loaders/discourse.py
+++ b/embedchain/loaders/discourse.py
@@ -1,7 +1,7 @@
 import hashlib
 import logging
 import time
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import requests
 
@@ -10,7 +10,7 @@ from embedchain.utils.misc import clean_string
 
 
 class DiscourseLoader(BaseLoader):
-    def __init__(self, config: Optional[Dict[str, Any]] = None):
+    def __init__(self, config: Optional[dict[str, Any]] = None):
         super().__init__()
         if not config:
             raise ValueError(
diff --git a/embedchain/loaders/dropbox.py b/embedchain/loaders/dropbox.py
index c4b01f14..f31db368 100644
--- a/embedchain/loaders/dropbox.py
+++ b/embedchain/loaders/dropbox.py
@@ -1,6 +1,5 @@
 import hashlib
 import os
-from typing import List
 
 from dropbox.files import FileMetadata
 
@@ -29,7 +28,7 @@ class DropboxLoader(BaseLoader):
         except exceptions.AuthError as ex:
             raise ValueError("Invalid Dropbox access token. Please verify your token and try again.") from ex
 
-    def _download_folder(self, path: str, local_root: str) -> List[FileMetadata]:
+    def _download_folder(self, path: str, local_root: str) -> list[FileMetadata]:
         """Download a folder from Dropbox and save it preserving the directory structure."""
         entries = self.dbx.files_list_folder(path).entries
         for entry in entries:
diff --git a/embedchain/loaders/github.py b/embedchain/loaders/github.py
index d3f0d6dd..967dbe77 100644
--- a/embedchain/loaders/github.py
+++ b/embedchain/loaders/github.py
@@ -4,7 +4,7 @@ import logging
 import os
 import re
 import shlex
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from tqdm import tqdm
 
@@ -20,7 +20,7 @@ VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"])
 class GithubLoader(BaseLoader):
     """Load data from GitHub search query."""
 
-    def __init__(self, config: Optional[Dict[str, Any]] = None):
+    def __init__(self, config: Optional[dict[str, Any]] = None):
         super().__init__()
         if not config:
             raise ValueError(
diff --git a/embedchain/loaders/gmail.py b/embedchain/loaders/gmail.py
index 3487a20b..07e10273 100644
--- a/embedchain/loaders/gmail.py
+++ b/embedchain/loaders/gmail.py
@@ -5,7 +5,7 @@ import os
 from email import message_from_bytes
 from email.utils import parsedate_to_datetime
 from textwrap import dedent
-from typing import Dict, List, Optional
+from typing import Optional
 
 from bs4 import BeautifulSoup
 
@@ -57,7 +57,7 @@ class GmailReader:
                 token.write(creds.to_json())
         return creds
 
-    def load_emails(self) -> List[Dict]:
+    def load_emails(self) -> list[dict]:
         response = self.service.users().messages().list(userId="me", q=self.query).execute()
         messages = response.get("messages", [])
 
@@ -67,7 +67,7 @@ class GmailReader:
         raw_message = self.service.users().messages().get(userId="me", id=message_id, format="raw").execute()
         return base64.urlsafe_b64decode(raw_message["raw"])
 
-    def _parse_email(self, raw_email) -> Dict:
+    def _parse_email(self, raw_email) -> dict:
         mime_msg = message_from_bytes(raw_email)
         return {
             "subject": self._get_header(mime_msg, "Subject"),
@@ -124,7 +124,7 @@ class GmailLoader(BaseLoader):
         return {"doc_id": self._generate_doc_id(query, data), "data": data}
 
     @staticmethod
-    def _process_email(email: Dict) -> str:
+    def _process_email(email: dict) -> str:
         content = BeautifulSoup(email["body"], "html.parser").get_text()
         content = clean_string(content)
         return dedent(
@@ -137,6 +137,6 @@ class GmailLoader(BaseLoader):
         )
 
     @staticmethod
-    def _generate_doc_id(query: str, data: List[Dict]) -> str:
+    def _generate_doc_id(query: str, data: list[dict]) -> str:
         content_strings = [email["content"] for email in data]
         return hashlib.sha256((query + ", ".join(content_strings)).encode()).hexdigest()
diff --git a/embedchain/loaders/json.py b/embedchain/loaders/json.py
index b04fd1bd..13f4f7e6 100644
--- a/embedchain/loaders/json.py
+++ b/embedchain/loaders/json.py
@@ -2,7 +2,7 @@ import hashlib
 import json
 import os
 import re
-from typing import Dict, List, Union
+from typing import Union
 
 import requests
 
@@ -16,14 +16,14 @@ class JSONReader:
         pass
 
     @staticmethod
-    def load_data(json_data: Union[Dict, str]) -> List[str]:
+    def load_data(json_data: Union[dict, str]) -> list[str]:
         """Load data from a JSON structure.
 
         Args:
-            json_data (Union[Dict, str]): The JSON data to load.
+            json_data (Union[dict, str]): The JSON data to load.
 
         Returns:
-            List[str]: A list of strings representing the leaf nodes of the JSON.
+            list[str]: A list of strings representing the leaf nodes of the JSON.
         """
         if isinstance(json_data, str):
             json_data = json.loads(json_data)
diff --git a/embedchain/loaders/mysql.py b/embedchain/loaders/mysql.py
index f9e9d52b..7eee2893 100644
--- a/embedchain/loaders/mysql.py
+++ b/embedchain/loaders/mysql.py
@@ -1,13 +1,13 @@
 import hashlib
 import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import clean_string
 
 
 class MySQLLoader(BaseLoader):
-    def __init__(self, config: Optional[Dict[str, Any]]):
+    def __init__(self, config: Optional[dict[str, Any]]):
         super().__init__()
         if not config:
             raise ValueError(
@@ -20,7 +20,7 @@ class MySQLLoader(BaseLoader):
         self.cursor = None
         self._setup_loader(config=config)
 
-    def _setup_loader(self, config: Dict[str, Any]):
+    def _setup_loader(self, config: dict[str, Any]):
         try:
             import mysql.connector as sqlconnector
         except ImportError as e:
diff --git a/embedchain/loaders/notion.py b/embedchain/loaders/notion.py
index d51753b2..0ce8eb3f 100644
--- a/embedchain/loaders/notion.py
+++ b/embedchain/loaders/notion.py
@@ -1,7 +1,7 @@
 import hashlib
 import logging
 import os
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 
 import requests
 
@@ -15,7 +15,7 @@ class NotionDocument:
     A simple Document class to hold the text and additional information of a page.
""" - def __init__(self, text: str, extra_info: Dict[str, Any]): + def __init__(self, text: str, extra_info: dict[str, Any]): self.text = text self.extra_info = extra_info @@ -82,7 +82,7 @@ class NotionPageLoader: result_lines = "\n".join(result_lines_arr) return result_lines - def load_data(self, page_ids: List[str]) -> List[NotionDocument]: + def load_data(self, page_ids: list[str]) -> list[NotionDocument]: """Load data from the given list of page IDs.""" docs = [] for page_id in page_ids: diff --git a/embedchain/loaders/postgres.py b/embedchain/loaders/postgres.py index bd4e035f..d336248c 100644 --- a/embedchain/loaders/postgres.py +++ b/embedchain/loaders/postgres.py @@ -1,12 +1,12 @@ import hashlib import logging -from typing import Any, Dict, Optional +from typing import Any, Optional from embedchain.loaders.base_loader import BaseLoader class PostgresLoader(BaseLoader): - def __init__(self, config: Optional[Dict[str, Any]] = None): + def __init__(self, config: Optional[dict[str, Any]] = None): super().__init__() if not config: raise ValueError(f"Must provide the valid config. Received: {config}") @@ -15,7 +15,7 @@ class PostgresLoader(BaseLoader): self.cursor = None self._setup_loader(config=config) - def _setup_loader(self, config: Dict[str, Any]): + def _setup_loader(self, config: dict[str, Any]): try: import psycopg except ImportError as e: diff --git a/embedchain/loaders/slack.py b/embedchain/loaders/slack.py index 5e31a3ff..75f18738 100644 --- a/embedchain/loaders/slack.py +++ b/embedchain/loaders/slack.py @@ -2,7 +2,7 @@ import hashlib import logging import os import ssl -from typing import Any, Dict, Optional +from typing import Any, Optional import certifi @@ -13,7 +13,7 @@ SLACK_API_BASE_URL = "https://www.slack.com/api/" class SlackLoader(BaseLoader): - def __init__(self, config: Optional[Dict[str, Any]] = None): + def __init__(self, config: Optional[dict[str, Any]] = None): super().__init__() self.config = config if config else {} @@ -24,7 +24,7 @@ class SlackLoader(BaseLoader): self.client = None self._setup_loader(self.config) - def _setup_loader(self, config: Dict[str, Any]): + def _setup_loader(self, config: dict[str, Any]): try: from slack_sdk import WebClient except ImportError as e: diff --git a/embedchain/memory/base.py b/embedchain/memory/base.py index 13fb3993..c453a351 100644 --- a/embedchain/memory/base.py +++ b/embedchain/memory/base.py @@ -2,7 +2,7 @@ import json import logging import sqlite3 import uuid -from typing import Any, Dict, List, Optional +from typing import Any, Optional from embedchain.constants import SQLITE_PATH from embedchain.memory.message import ChatMessage @@ -67,7 +67,7 @@ class ChatHistory: self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, (app_id, session_id)) self.connection.commit() - def get(self, app_id, session_id, num_rounds=10, display_format=False) -> List[ChatMessage]: + def get(self, app_id, session_id, num_rounds=10, display_format=False) -> list[ChatMessage]: """ Get the most recent num_rounds rounds of conversations between human and AI, for a given app_id. 
@@ -114,7 +114,7 @@ class ChatHistory:
         return count
 
     @staticmethod
-    def _serialize_json(metadata: Dict[str, Any]):
+    def _serialize_json(metadata: dict[str, Any]):
         return json.dumps(metadata)
 
     @staticmethod
diff --git a/embedchain/memory/message.py b/embedchain/memory/message.py
index 1959aa54..cc8c3a94 100644
--- a/embedchain/memory/message.py
+++ b/embedchain/memory/message.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from embedchain.helpers.json_serializable import JSONSerializable
 
@@ -18,9 +18,9 @@ class BaseMessage(JSONSerializable):
     created_by: str
 
     # Any additional info.
-    metadata: Dict[str, Any]
+    metadata: dict[str, Any]
 
-    def __init__(self, content: str, created_by: str, metadata: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self, content: str, created_by: str, metadata: Optional[dict[str, Any]] = None) -> None:
         super().__init__()
         self.content = content
         self.created_by = created_by
diff --git a/embedchain/memory/utils.py b/embedchain/memory/utils.py
index ec60704b..8abea4dc 100644
--- a/embedchain/memory/utils.py
+++ b/embedchain/memory/utils.py
@@ -1,16 +1,16 @@
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 
-def merge_metadata_dict(left: Optional[Dict[str, Any]], right: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+def merge_metadata_dict(left: Optional[dict[str, Any]], right: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
     """
     Merge the metadatas of two BaseMessage types.
 
     Args:
-        left (Dict[str, Any]): metadata of human message
-        right (Dict[str, Any]): metadata of AI message
+        left (dict[str, Any]): metadata of human message
+        right (dict[str, Any]): metadata of AI message
 
     Returns:
-        Dict[str, Any]: combined metadata dict with dedup
+        dict[str, Any]: combined metadata dict with dedup
         to be saved in db.
     """
     if not left and not right:
diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py
index fad35167..827f12d3 100644
--- a/embedchain/vectordb/chroma.py
+++ b/embedchain/vectordb/chroma.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 from chromadb import Collection, QueryResult
 from langchain.docstore.document import Document
@@ -76,7 +76,7 @@ class ChromaDB(BaseVectorDB):
         return self.client
 
     @staticmethod
-    def _generate_where_clause(where: Dict[str, any]) -> Dict[str, any]:
+    def _generate_where_clause(where: dict[str, any]) -> dict[str, any]:
         # If only one filter is supplied, return it as is
         # (no need to wrap in $and based on chroma docs)
         if len(where.keys()) <= 1:
@@ -105,18 +105,18 @@
         )
         return self.collection
 
-    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
         """
         Get existing doc ids present in vector database
 
         :param ids: list of doc ids to check for existence
-        :type ids: List[str]
+        :type ids: list[str]
         :param where: Optional. to filter data
-        :type where: Dict[str, Any]
+        :type where: dict[str, Any]
         :param limit: Optional. maximum number of documents
         :type limit: Optional[int]
         :return: Existing documents.
-        :rtype: List[str]
+        :rtype: list[str]
         """
         args = {}
         if ids:
@@ -129,23 +129,23 @@
 
     def add(
         self,
-        embeddings: List[List[float]],
-        documents: List[str],
-        metadatas: List[object],
-        ids: List[str],
-        **kwargs: Optional[Dict[str, Any]],
+        embeddings: list[list[float]],
+        documents: list[str],
+        metadatas: list[object],
+        ids: list[str],
+        **kwargs: Optional[dict[str, Any]],
     ) -> Any:
         """
         Add vectors to chroma database
 
         :param embeddings: list of embeddings to add
-        :type embeddings: List[List[str]]
+        :type embeddings: list[list[str]]
         :param documents: Documents
-        :type documents: List[str]
+        :type documents: list[str]
         :param metadatas: Metadatas
-        :type metadatas: List[object]
+        :type metadatas: list[object]
         :param ids: ids
-        :type ids: List[str]
+        :type ids: list[str]
         """
         size = len(documents)
         if len(documents) != size or len(metadatas) != size or len(ids) != size:
@@ -182,27 +182,27 @@
 
     def query(
         self,
-        input_query: List[str],
+        input_query: list[str],
         n_results: int,
-        where: Dict[str, any],
+        where: dict[str, any],
         citations: bool = False,
-        **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, Dict]], List[str]]:
+        **kwargs: Optional[dict[str, Any]],
+    ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         Query contents from vector database based on vector similarity
 
         :param input_query: list of query string
-        :type input_query: List[str]
+        :type input_query: list[str]
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: to filter data
-        :type where: Dict[str, Any]
+        :type where: dict[str, Any]
         :param citations: we use citations boolean param to return context along with the answer.
         :type citations: bool, default is False.
         :raises InvalidDimensionException: Dimensions do not match.
         :return: The content of the document that matched your query,
         along with url of the source and doc_id (if citations flag is true)
-        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
+        :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
         """
         try:
             result = self.collection.query(
diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py
index 2ba200f7..d47344d6 100644
--- a/embedchain/vectordb/elasticsearch.py
+++ b/embedchain/vectordb/elasticsearch.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 try:
     from elasticsearch import Elasticsearch
@@ -84,14 +84,14 @@ class ElasticsearchDB(BaseVectorDB):
     def _get_or_create_collection(self, name):
         """Note: nothing to return here.
        Discuss later"""
 
-    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
         """
         Get existing doc ids present in vector database
 
         :param ids: _list of doc ids to check for existence
-        :type ids: List[str]
+        :type ids: list[str]
         :param where: to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :return: ids
         :rtype: Set[str]
         """
@@ -110,22 +110,22 @@
 
     def add(
         self,
-        embeddings: List[List[float]],
-        documents: List[str],
-        metadatas: List[object],
-        ids: List[str],
-        **kwargs: Optional[Dict[str, any]],
+        embeddings: list[list[float]],
+        documents: list[str],
+        metadatas: list[object],
+        ids: list[str],
+        **kwargs: Optional[dict[str, any]],
     ) -> Any:
         """
         add data in vector database
 
         :param embeddings: list of embeddings to add
-        :type embeddings: List[List[str]]
+        :type embeddings: list[list[str]]
         :param documents: list of texts to add
-        :type documents: List[str]
+        :type documents: list[str]
         :param metadatas: list of metadata associated with docs
-        :type metadatas: List[object]
+        :type metadatas: list[object]
         :param ids: ids of docs
-        :type ids: List[str]
+        :type ids: list[str]
         """
         embeddings = self.embedder.embedding_fn(documents)
 
@@ -154,27 +154,27 @@
 
     def query(
         self,
-        input_query: List[str],
+        input_query: list[str],
         n_results: int,
-        where: Dict[str, any],
+        where: dict[str, any],
         citations: bool = False,
-        **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, Dict]], List[str]]:
+        **kwargs: Optional[dict[str, Any]],
+    ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         query contents from vector database based on vector similarity
 
         :param input_query: list of query string
-        :type input_query: List[str]
+        :type input_query: list[str]
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: Optional. to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :return: The context of the document that matched your query, url of the source, doc_id
         :param citations: we use citations boolean param to return context along with the answer.
         :type citations: bool, default is False.
         :return: The content of the document that matched your query,
         along with url of the source and doc_id (if citations flag is true)
-        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
+        :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
         """
         input_query_vector = self.embedder.embedding_fn(input_query)
         query_vector = input_query_vector[0]
diff --git a/embedchain/vectordb/opensearch.py b/embedchain/vectordb/opensearch.py
index 49da38dc..1626f772 100644
--- a/embedchain/vectordb/opensearch.py
+++ b/embedchain/vectordb/opensearch.py
@@ -1,6 +1,6 @@
 import logging
 import time
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Optional, Union
 
 from tqdm import tqdm
 
@@ -78,17 +78,17 @@ class OpenSearchDB(BaseVectorDB):
         """Note: nothing to return here.
        Discuss later"""
 
     def get(
-        self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None
-    ) -> Set[str]:
+        self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None
+    ) -> set[str]:
         """
         Get existing doc ids present in vector database
 
         :param ids: _list of doc ids to check for existence
-        :type ids: List[str]
+        :type ids: list[str]
         :param where: to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :return: ids
-        :type: Set[str]
+        :type: set[str]
         """
         query = {}
         if ids:
@@ -116,19 +116,19 @@
 
     def add(
         self,
-        embeddings: List[List[str]],
-        documents: List[str],
-        metadatas: List[object],
-        ids: List[str],
-        **kwargs: Optional[Dict[str, any]],
+        embeddings: list[list[str]],
+        documents: list[str],
+        metadatas: list[object],
+        ids: list[str],
+        **kwargs: Optional[dict[str, any]],
     ):
         """Add data in vector database.
 
         Args:
-            embeddings (List[List[str]]): List of embeddings to add.
-            documents (List[str]): List of texts to add.
-            metadatas (List[object]): List of metadata associated with docs.
-            ids (List[str]): IDs of docs.
+            embeddings (list[list[str]]): list of embeddings to add.
+            documents (list[str]): list of texts to add.
+            metadatas (list[object]): list of metadata associated with docs.
+            ids (list[str]): IDs of docs.
         """
         for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"):
             batch_end = batch_start + self.BATCH_SIZE
@@ -156,26 +156,26 @@
 
     def query(
         self,
-        input_query: List[str],
+        input_query: list[str],
         n_results: int,
-        where: Dict[str, any],
+        where: dict[str, any],
         citations: bool = False,
-        **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, Dict]], List[str]]:
+        **kwargs: Optional[dict[str, Any]],
+    ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         query contents from vector database based on vector similarity
 
         :param input_query: list of query string
-        :type input_query: List[str]
+        :type input_query: list[str]
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: Optional. to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :param citations: we use citations boolean param to return context along with the answer.
         :type citations: bool, default is False.
         :return: The content of the document that matched your query,
         along with url of the source and doc_id (if citations flag is true)
-        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
+        :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
         """
         embeddings = OpenAIEmbeddings()
         docsearch = OpenSearchVectorSearch(
diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py
index 336967d6..16076dda 100644
--- a/embedchain/vectordb/pinecone.py
+++ b/embedchain/vectordb/pinecone.py
@@ -1,5 +1,5 @@
 import os
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional, Union
 
 try:
     import pinecone
@@ -67,14 +67,14 @@ class PineconeDB(BaseVectorDB):
         )
         return pinecone.Index(self.index_name)
 
-    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
         """
         Get existing doc ids present in vector database
 
         :param ids: _list of doc ids to check for existence
-        :type ids: List[str]
+        :type ids: list[str]
         :param where: to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :return: ids
         :rtype: Set[str]
         """
@@ -88,20 +88,20 @@
 
     def add(
         self,
-        embeddings: List[List[float]],
-        documents: List[str],
-        metadatas: List[object],
-        ids: List[str],
-        **kwargs: Optional[Dict[str, any]],
+        embeddings: list[list[float]],
+        documents: list[str],
+        metadatas: list[object],
+        ids: list[str],
+        **kwargs: Optional[dict[str, any]],
     ):
         """add data in vector database
 
         :param documents: list of texts to add
-        :type documents: List[str]
+        :type documents: list[str]
         :param metadatas: list of metadata associated with docs
-        :type metadatas: List[object]
+        :type metadatas: list[object]
         :param ids: ids of docs
-        :type ids: List[str]
+        :type ids: list[str]
         """
         docs = []
         print("Adding documents to Pinecone...")
@@ -120,25 +120,25 @@
 
     def query(
         self,
-        input_query: List[str],
+        input_query: list[str],
         n_results: int,
-        where: Dict[str, any],
+        where: dict[str, any],
         citations: bool = False,
-        **kwargs: Optional[Dict[str, any]],
-    ) -> Union[List[Tuple[str, Dict]], List[str]]:
+        **kwargs: Optional[dict[str, any]],
+    ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         query contents from vector database based on vector similarity
 
         :param input_query: list of query string
-        :type input_query: List[str]
+        :type input_query: list[str]
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: Optional. to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :param citations: we use citations boolean param to return context along with the answer.
         :type citations: bool, default is False.
         :return: The content of the document that matched your query,
         along with url of the source and doc_id (if citations flag is true)
-        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
+        :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
         """
         query_vector = self.embedder.embedding_fn([input_query])[0]
         data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True, **kwargs)
diff --git a/embedchain/vectordb/qdrant.py b/embedchain/vectordb/qdrant.py
index be2d9523..b107568d 100644
--- a/embedchain/vectordb/qdrant.py
+++ b/embedchain/vectordb/qdrant.py
@@ -1,7 +1,7 @@
 import copy
 import os
 import uuid
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 try:
     from qdrant_client import QdrantClient
@@ -69,14 +69,14 @@ class QdrantDB(BaseVectorDB):
     def _get_or_create_collection(self):
         return f"{self.config.collection_name}-{self.embedder.vector_dimension}".lower().replace("_", "-")
 
-    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
         """
         Get existing doc ids present in vector database
 
         :param ids: _list of doc ids to check for existence
-        :type ids: List[str]
+        :type ids: list[str]
         :param where: to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :param limit: The number of entries to be fetched
         :type limit: Optional int, defaults to None
         :return: All the existing IDs
@@ -122,21 +122,21 @@
 
     def add(
         self,
-        embeddings: List[List[float]],
-        documents: List[str],
-        metadatas: List[object],
-        ids: List[str],
-        **kwargs: Optional[Dict[str, any]],
+        embeddings: list[list[float]],
+        documents: list[str],
+        metadatas: list[object],
+        ids: list[str],
+        **kwargs: Optional[dict[str, any]],
     ):
         """add data in vector database
         :param embeddings: list of embeddings for the corresponding documents to be added
-        :type documents: List[List[float]]
+        :type documents: list[list[float]]
         :param documents: list of texts to add
-        :type documents: List[str]
+        :type documents: list[str]
         :param metadatas: list of metadata associated with docs
-        :type metadatas: List[object]
+        :type metadatas: list[object]
         :param ids: ids of docs
-        :type ids: List[str]
+        :type ids: list[str]
         """
         embeddings = self.embedder.embedding_fn(documents)
 
@@ -159,25 +159,25 @@
 
     def query(
         self,
-        input_query: List[str],
+        input_query: list[str],
         n_results: int,
-        where: Dict[str, any],
+        where: dict[str, any],
         citations: bool = False,
-        **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, Dict]], List[str]]:
+        **kwargs: Optional[dict[str, Any]],
+    ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         query contents from vector database based on vector similarity
 
         :param input_query: list of query string
-        :type input_query: List[str]
+        :type input_query: list[str]
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: Optional. to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :param citations: we use citations boolean param to return context along with the answer.
         :type citations: bool, default is False.
        :return: The content of the document that matched your query, along with the url of the source and doc_id (if citations flag is true)
-        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
+        :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
        """
        query_vector = self.embedder.embedding_fn([input_query])[0]
        keys = set(where.keys() if where is not None else set())
diff --git a/embedchain/vectordb/weaviate.py b/embedchain/vectordb/weaviate.py
index 1859be41..b3535bbe 100644
--- a/embedchain/vectordb/weaviate.py
+++ b/embedchain/vectordb/weaviate.py
@@ -1,6 +1,6 @@
 import copy
 import os
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 try:
     import weaviate
@@ -117,13 +117,13 @@ class WeaviateDB(BaseVectorDB):
 
         self.client.schema.create(class_obj)
 
-    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
         """
         Get existing doc ids present in vector database
         :param ids: list of doc ids to check for existence
-        :type ids: List[str]
+        :type ids: list[str]
         :param where: to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :return: ids
         :rtype: Set[str]
         """
@@ -153,21 +153,21 @@ class WeaviateDB(BaseVectorDB):
 
     def add(
         self,
-        embeddings: List[List[float]],
-        documents: List[str],
-        metadatas: List[object],
-        ids: List[str],
-        **kwargs: Optional[Dict[str, any]],
+        embeddings: list[list[float]],
+        documents: list[str],
+        metadatas: list[object],
+        ids: list[str],
+        **kwargs: Optional[dict[str, any]],
    ):
         """add data in vector database
         :param embeddings: list of embeddings for the corresponding documents to be added
-        :type documents: List[List[float]]
+        :type embeddings: list[list[float]]
         :param documents: list of texts to add
-        :type documents: List[str]
+        :type documents: list[str]
         :param metadatas: list of metadata associated with docs
-        :type metadatas: List[object]
+        :type metadatas: list[object]
         :param ids: ids of docs
-        :type ids: List[str]
+        :type ids: list[str]
         """
         embeddings = self.embedder.embedding_fn(documents)
         self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3)  # Configure batch
@@ -192,25 +192,25 @@ class WeaviateDB(BaseVectorDB):
 
     def query(
         self,
-        input_query: List[str],
+        input_query: list[str],
         n_results: int,
-        where: Dict[str, any],
+        where: dict[str, any],
         citations: bool = False,
-        **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, Dict]], List[str]]:
+        **kwargs: Optional[dict[str, Any]],
+    ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         query contents from vector database based on vector similarity
 
         :param input_query: list of query strings
-        :type input_query: List[str]
+        :type input_query: list[str]
         :param n_results: number of similar documents to fetch from database
         :type n_results: int
         :param where: Optional. to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :param citations: whether to return the source context along with the answer.
         :type citations: bool, default is False.
        :return: The content of the document that matched your query, along with the url of the source and doc_id (if citations flag is true)
-        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
+        :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
        """
        query_vector = self.embedder.embedding_fn([input_query])[0]
        keys = set(where.keys() if where is not None else set())
diff --git a/embedchain/vectordb/zilliz.py b/embedchain/vectordb/zilliz.py
index a310a0ae..65e541db 100644
--- a/embedchain/vectordb/zilliz.py
+++ b/embedchain/vectordb/zilliz.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 from embedchain.config import ZillizDBConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -88,14 +88,14 @@ class ZillizVectorDB(BaseVectorDB):
             self.collection.create_index("embeddings", index)
         return self.collection
 
-    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
         """
         Get existing doc ids present in vector database
 
         :param ids: list of doc ids to check for existence
-        :type ids: List[str]
+        :type ids: list[str]
         :param where: Optional. to filter data
-        :type where: Dict[str, Any]
+        :type where: dict[str, Any]
         :param limit: Optional. maximum number of documents
         :type limit: Optional[int]
         :return: Existing documents.
@@ -115,11 +115,11 @@ class ZillizVectorDB(BaseVectorDB):
 
     def add(
         self,
-        embeddings: List[List[float]],
-        documents: List[str],
-        metadatas: List[object],
-        ids: List[str],
-        **kwargs: Optional[Dict[str, any]],
+        embeddings: list[list[float]],
+        documents: list[str],
+        metadatas: list[object],
+        ids: list[str],
+        **kwargs: Optional[dict[str, any]],
     ):
         """Add to database"""
         embeddings = self.embedder.embedding_fn(documents)
@@ -134,17 +134,17 @@ class ZillizVectorDB(BaseVectorDB):
 
     def query(
         self,
-        input_query: List[str],
+        input_query: list[str],
         n_results: int,
-        where: Dict[str, any],
+        where: dict[str, any],
         citations: bool = False,
-        **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, Dict]], List[str]]:
+        **kwargs: Optional[dict[str, Any]],
+    ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         Query contents from vector database based on vector similarity
 
         :param input_query: list of query strings
-        :type input_query: List[str]
+        :type input_query: list[str]
         :param n_results: number of similar documents to fetch from database
         :type n_results: int
         :param where: to filter data

         :type citations: bool, default is False.
         :return: The content of the document that matched your query, along with the url of the source and doc_id (if citations flag is true)
-        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
+        :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
         """

         if self.collection.is_empty:
@@ -200,7 +200,7 @@ class ZillizVectorDB(BaseVectorDB):
         """
         return self.collection.num_entities
 
-    def reset(self, collection_names: List[str] = None):
+    def reset(self, collection_names: list[str] = None):
         """
         Resets the database. Deletes all embeddings irreversibly.
         """
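
For reference, the mechanical rule applied in every hunk above follows PEP 585: since Python 3.9 the builtin containers list, dict, tuple, and set are directly subscriptable, so the capitalized aliases imported from typing are deprecated and their imports can be dropped. A minimal before/after sketch of the rewrite; the function and its names are illustrative, not taken from the codebase:

    # Before: aliases imported from typing (deprecated since Python 3.9).
    from typing import Dict, List, Optional

    def lookup(ids: List[str], weights: Optional[Dict[str, float]] = None) -> List[float]:
        """Return one weight per id, defaulting to 0.0."""
        weights = weights or {}
        return [weights.get(i, 0.0) for i in ids]

    # After: builtin generics; only Optional still comes from typing.
    from typing import Optional

    def lookup(ids: list[str], weights: Optional[dict[str, float]] = None) -> list[float]:
        """Return one weight per id, defaulting to 0.0."""
        weights = weights or {}
        return [weights.get(i, 0.0) for i in ids]

The runtime behavior is unchanged either way; the payoff is dropped imports and annotations that read the same as the values they describe.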
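
A separate wart the patch carries over rather than introduces: several signatures above annotate filters as dict[str, any] with a lowercase any, while the qdrant, weaviate, and zilliz query methods write Optional[dict[str, Any]] for **kwargs. Lowercase any is the builtin function, not a type; the annotation evaluates without error but means nothing to a type checker. Assuming the intent is "a value of any type", a corrected sketch (not part of the patch) looks like:

    from typing import Any, Optional

    # dict[str, any] binds the annotation to the builtin any() function by mistake;
    # dict[str, Any] is the form mypy and similar checkers will accept.
    def get(ids: Optional[list[str]] = None, where: Optional[dict[str, Any]] = None) -> set[str]:
        ...

The docstrings show the same leftover in prose form: the ":rtype: Set[str]" lines in the pinecone and weaviate get docstrings keep a capitalized alias where set[str] would match the new convention.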
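
Finally, Optional and Union survive in every import above, presumably because builtin generics only require Python 3.9 while the X | Y union syntax of PEP 604 requires 3.10. If the supported floor were later raised to 3.10, return types like Union[list[tuple[str, dict]], list[str]] could be shortened, and defaults such as reset(self, collection_names: list[str] = None) in the zilliz diff, which silently accepts None, could be annotated explicitly. A sketch under that 3.10 assumption, with illustrative signatures:

    # Python 3.10+ (PEP 604): the | operator replaces Optional and Union.
    def query(input_query: str, n_results: int) -> list[tuple[str, dict]] | list[str]:
        ...

    # A None default calls for an explicit "| None" (or Optional[...] on 3.9).
    def reset(collection_names: list[str] | None = None) -> None:
        ...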