#1128 | Remove deprecated type hints from typing module (#1131)

This commit is contained in:
Sandra Serrano
2024-01-09 18:35:24 +01:00
committed by GitHub
parent c9df7a2020
commit 0de9491c61
41 changed files with 272 additions and 267 deletions

View File

@@ -4,7 +4,7 @@ import logging
import os
import sqlite3
import uuid
from typing import Any, Dict, Optional
from typing import Any, Optional
import requests
import yaml
@@ -364,7 +364,7 @@ class App(EmbedChain):
def from_config(
cls,
config_path: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None,
auto_deploy: bool = False,
yaml_path: Optional[str] = None,
):
@@ -374,7 +374,7 @@ class App(EmbedChain):
:param config_path: Path to the YAML or JSON configuration file.
:type config_path: Optional[str]
:param config: A dictionary containing the configuration.
:type config: Optional[Dict[str, Any]]
:type config: Optional[dict[str, Any]]
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
:type auto_deploy: bool, optional
:param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.

View File

@@ -1,7 +1,7 @@
import argparse
import logging
import os
from typing import List, Optional
from typing import Optional
from embedchain.helpers.json_serializable import register_deserializable
@@ -53,7 +53,7 @@ class PoeBot(BaseBot, PoeBot):
answer = self.handle_message(last_message, history)
yield self.text_event(answer)
def handle_message(self, message, history: Optional[List[str]] = None):
def handle_message(self, message, history: Optional[list[str]] = None):
if message.startswith("/add "):
response = self.add_data(message)
else:
@@ -70,7 +70,7 @@ class PoeBot(BaseBot, PoeBot):
# response = "Some error occurred while adding data."
# return response
def ask_bot(self, message, history: List[str]):
def ask_bot(self, message, history: list[str]):
try:
self.app.llm.set_history(history=history)
response = self.query(message)

View File

@@ -1,6 +1,6 @@
import logging
import os # noqa: F401
from typing import Any, Dict
from typing import Any
from gptcache import cache # noqa: F401
from gptcache.adapter.adapter import adapt # noqa: F401
@@ -15,7 +15,7 @@ from gptcache.similarity_evaluation.exact_match import \
ExactMatchEvaluation # noqa: F401
def gptcache_pre_function(data: Dict[str, Any], **params: Dict[str, Any]):
def gptcache_pre_function(data: dict[str, Any], **params: dict[str, Any]):
return data["input_query"]

View File

@@ -1,7 +1,8 @@
import builtins
import logging
from collections.abc import Callable
from importlib import import_module
from typing import Callable, Optional
from typing import Optional
from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any
from embedchain.helpers.json_serializable import JSONSerializable
@@ -12,10 +12,10 @@ class BaseConfig(JSONSerializable):
"""Initializes a configuration class for a class."""
pass
def as_dict(self) -> Dict[str, Any]:
def as_dict(self) -> dict[str, Any]:
"""Return config object as a dict
:return: config object as dict
:rtype: Dict[str, Any]
:rtype: dict[str, Any]
"""
return vars(self)

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, Optional
from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -30,7 +30,7 @@ class CacheSimilarityEvalConfig(BaseConfig):
self.positive = positive
@staticmethod
def from_config(config: Optional[Dict[str, Any]]):
def from_config(config: Optional[dict[str, Any]]):
if config is None:
return CacheSimilarityEvalConfig()
else:
@@ -65,7 +65,7 @@ class CacheInitConfig(BaseConfig):
self.auto_flush = auto_flush
@staticmethod
def from_config(config: Optional[Dict[str, Any]]):
def from_config(config: Optional[dict[str, Any]]):
if config is None:
return CacheInitConfig()
else:
@@ -86,7 +86,7 @@ class CacheConfig(BaseConfig):
self.init_config = init_config
@staticmethod
def from_config(config: Optional[Dict[str, Any]]):
def from_config(config: Optional[dict[str, Any]]):
if config is None:
return CacheConfig()
else:

View File

@@ -1,7 +1,7 @@
import logging
import re
from string import Template
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -68,12 +68,12 @@ class BaseLlmConfig(BaseConfig):
stream: bool = False,
deployment_name: Optional[str] = None,
system_prompt: Optional[str] = None,
where: Dict[str, Any] = None,
where: dict[str, Any] = None,
query_type: Optional[str] = None,
callbacks: Optional[List] = None,
callbacks: Optional[list] = None,
api_key: Optional[str] = None,
endpoint: Optional[str] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
model_kwargs: Optional[dict[str, Any]] = None,
):
"""
Initializes a configuration class instance for the LLM.
@@ -106,7 +106,7 @@ class BaseLlmConfig(BaseConfig):
:param system_prompt: System prompt string, defaults to None
:type system_prompt: Optional[str], optional
:param where: A dictionary of key-value pairs to filter the database results, defaults to None
:type where: Dict[str, Any], optional
:type where: dict[str, Any], optional
:param api_key: The api key of the custom endpoint, defaults to None
:type api_key: Optional[str], optional
:param endpoint: The api url of the custom endpoint, defaults to None
@@ -114,7 +114,7 @@ class BaseLlmConfig(BaseConfig):
:param model_kwargs: A dictionary of key-value pairs to pass to the model, defaults to None
:type model_kwargs: Optional[dict[str, Any]], optional
:param callbacks: Langchain callback functions to use, defaults to None
:type callbacks: Optional[List], optional
:type callbacks: Optional[list], optional
:param query_type: The type of query to use, defaults to None
:type query_type: Optional[str], optional
:raises ValueError: If the template is not valid as template should

View File

@@ -1,5 +1,5 @@
import os
from typing import Dict, List, Optional, Union
from typing import Optional, Union
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -11,9 +11,9 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
es_url: Union[str, List[str]] = None,
es_url: Union[str, list[str]] = None,
cloud_id: Optional[str] = None,
**ES_EXTRA_PARAMS: Dict[str, any],
**ES_EXTRA_PARAMS: dict[str, any],
):
"""
Initializes a configuration class instance for an Elasticsearch client.
@@ -23,13 +23,13 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
:type es_url: Union[str, List[str]], optional
:type es_url: Union[str, list[str]], optional
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
:type ES_EXTRA_PARAMS: Dict[str, Any], optional
:type ES_EXTRA_PARAMS: dict[str, Any], optional
"""
if es_url and cloud_id:
raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
# self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
# self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]):
self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
if not self.ES_URL and not self.CLOUD_ID:

View File

@@ -1,4 +1,4 @@
from typing import Dict, Optional, Tuple
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -9,11 +9,11 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
def __init__(
self,
opensearch_url: str,
http_auth: Tuple[str, str],
http_auth: tuple[str, str],
vector_dimension: int = 1536,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
**extra_params: Dict[str, any],
**extra_params: dict[str, any],
):
"""
Initializes a configuration class instance for an OpenSearch client.
@@ -23,7 +23,7 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
:param opensearch_url: URL of the OpenSearch domain
:type opensearch_url: str, Eg, "http://localhost:9200"
:param http_auth: Tuple of username and password
:type http_auth: Tuple[str, str], Eg, ("username", "password")
:type http_auth: tuple[str, str], Eg, ("username", "password")
:param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model)
:type vector_dimension: int, optional
:param dir: Path to the database directory, where the database is stored, defaults to None

View File

@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -12,7 +12,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
dir: Optional[str] = None,
vector_dimension: int = 1536,
metric: Optional[str] = "cosine",
**extra_params: Dict[str, any],
**extra_params: dict[str, any],
):
self.metric = metric
self.vector_dimension = vector_dimension

View File

@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -15,10 +15,10 @@ class QdrantDBConfig(BaseVectorDbConfig):
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
hnsw_config: Optional[Dict[str, any]] = None,
quantization_config: Optional[Dict[str, any]] = None,
hnsw_config: Optional[dict[str, any]] = None,
quantization_config: Optional[dict[str, any]] = None,
on_disk: Optional[bool] = None,
**extra_params: Dict[str, any],
**extra_params: dict[str, any],
):
"""
Initializes a configuration class instance for a qdrant client.
@@ -28,9 +28,9 @@ class QdrantDBConfig(BaseVectorDbConfig):
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param hnsw_config: Params for HNSW index
:type hnsw_config: Optional[Dict[str, any]], defaults to None
:type hnsw_config: Optional[dict[str, any]], defaults to None
:param quantization_config: Params for quantization, if None - quantization will be disabled
:type quantization_config: Optional[Dict[str, any]], defaults to None
:type quantization_config: Optional[dict[str, any]], defaults to None
:param on_disk: If true - point's payload will not be stored in memory.
It will be read from the disk every time it is requested.
This setting saves RAM by (slightly) increasing the response time.

View File

@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -10,7 +10,7 @@ class WeaviateDBConfig(BaseVectorDbConfig):
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
**extra_params: Dict[str, any],
**extra_params: dict[str, any],
):
self.extra_params = extra_params
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -2,7 +2,7 @@ import hashlib
import json
import logging
import sqlite3
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
from dotenv import load_dotenv
from langchain.docstore.document import Document
@@ -136,12 +136,12 @@ class EmbedChain(JSONSerializable):
self,
source: Any,
data_type: Optional[DataType] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
config: Optional[AddConfig] = None,
dry_run=False,
loader: Optional[BaseLoader] = None,
chunker: Optional[BaseChunker] = None,
**kwargs: Optional[Dict[str, Any]],
**kwargs: Optional[dict[str, Any]],
):
"""
Adds the data from the given URL to the vector db.
@@ -154,7 +154,7 @@ class EmbedChain(JSONSerializable):
defaults to None
:type data_type: Optional[DataType], optional
:param metadata: Metadata associated with the data source, defaults to None
:type metadata: Optional[Dict[str, Any]], optional
:type metadata: Optional[dict[str, Any]], optional
:param config: The `AddConfig` instance to use as configuration options., defaults to None
:type config: Optional[AddConfig], optional
:raises ValueError: Invalid data type
@@ -243,9 +243,9 @@ class EmbedChain(JSONSerializable):
self,
source: Any,
data_type: Optional[DataType] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
config: Optional[AddConfig] = None,
**kwargs: Optional[Dict[str, Any]],
**kwargs: Optional[dict[str, Any]],
):
"""
Adds the data from the given URL to the vector db.
@@ -261,7 +261,7 @@ class EmbedChain(JSONSerializable):
defaults to None
:type data_type: Optional[DataType], optional
:param metadata: Metadata associated with the data source, defaults to None
:type metadata: Optional[Dict[str, Any]], optional
:type metadata: Optional[dict[str, Any]], optional
:param config: The `AddConfig` instance to use as configuration options., defaults to None
:type config: Optional[AddConfig], optional
:raises ValueError: Invalid data type
@@ -342,11 +342,11 @@ class EmbedChain(JSONSerializable):
loader: BaseLoader,
chunker: BaseChunker,
src: Any,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
source_hash: Optional[str] = None,
add_config: Optional[AddConfig] = None,
dry_run=False,
**kwargs: Optional[Dict[str, Any]],
**kwargs: Optional[dict[str, Any]],
):
"""
Loads the data from the given URL, chunks it, and adds it to database.
@@ -359,7 +359,7 @@ class EmbedChain(JSONSerializable):
:param source_hash: Hexadecimal hash of the source.
:param dry_run: Optional. A dry run returns chunks and doesn't update DB.
:type dry_run: bool, defaults to False
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
:return: (list) documents (embedded text), (list) metadata, (list) ids, (int) number of chunks
"""
existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
app_id = self.config.id if self.config is not None else None
@@ -464,8 +464,8 @@ class EmbedChain(JSONSerializable):
config: Optional[BaseLlmConfig] = None,
where=None,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
**kwargs: Optional[dict[str, Any]],
) -> Union[list[tuple[str, str, str]], list[str]]:
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query
@@ -479,7 +479,7 @@ class EmbedChain(JSONSerializable):
:param citations: A boolean to indicate if db should fetch citation source
:type citations: bool
:return: List of contents of the document that matched your query
:rtype: List[str]
:rtype: list[str]
"""
query_config = config or self.llm.config
if where is not None:
@@ -507,10 +507,10 @@ class EmbedChain(JSONSerializable):
input_query: str,
config: BaseLlmConfig = None,
dry_run=False,
where: Optional[Dict] = None,
where: Optional[dict] = None,
citations: bool = False,
**kwargs: Dict[str, Any],
) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
**kwargs: dict[str, Any],
) -> Union[tuple[str, list[tuple[str, dict]]], str]:
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an
@@ -525,13 +525,13 @@ class EmbedChain(JSONSerializable):
the LLM. The purpose is to test the prompt, not the response., defaults to False
:type dry_run: bool, optional
:param where: A dictionary of key-value pairs to filter the database results, defaults to None
:type where: Optional[Dict[str, str]], optional
:type where: Optional[dict[str, str]], optional
:param kwargs: To read more params for the query function. Ex. we use citations boolean
param to return context along with the answer
:type kwargs: Dict[str, Any]
:type kwargs: dict[str, Any]
:return: The answer to the query, with citations if the citation flag is True
or the dry run result
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
:rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
"""
contexts = self._retrieve_from_database(
input_query=input_query, config=config, where=where, citations=citations, **kwargs
@@ -572,10 +572,10 @@ class EmbedChain(JSONSerializable):
config: Optional[BaseLlmConfig] = None,
dry_run=False,
session_id: str = "default",
where: Optional[Dict[str, str]] = None,
where: Optional[dict[str, str]] = None,
citations: bool = False,
**kwargs: Dict[str, Any],
) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
**kwargs: dict[str, Any],
) -> Union[tuple[str, list[tuple[str, dict]]], str]:
"""
Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an
@@ -594,13 +594,13 @@ class EmbedChain(JSONSerializable):
:param session_id: The session id to use for chat history, defaults to 'default'.
:type session_id: Optional[str], optional
:param where: A dictionary of key-value pairs to filter the database results, defaults to None
:type where: Optional[Dict[str, str]], optional
:type where: Optional[dict[str, str]], optional
:param kwargs: To read more params for the query function. Ex. we use citations boolean
param to return context along with the answer
:type kwargs: Dict[str, Any]
:type kwargs: dict[str, Any]
:return: The answer to the query, with citations if the citation flag is True
or the dry run result
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
:rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
"""
contexts = self._retrieve_from_database(
input_query=input_query, config=config, where=where, citations=citations, **kwargs

View File

@@ -1,4 +1,5 @@
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any, Optional
from embedchain.config.embedder.base import BaseEmbedderConfig

View File

@@ -1,5 +1,5 @@
import queue
from typing import Any, Dict, List, Union
from typing import Any, Union
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import LLMResult
@@ -29,7 +29,7 @@ class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
super().__init__()
self.q = q
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> None:
"""Run when LLM starts running."""
with self.q.mutex:
self.q.queue.clear()

View File

@@ -1,7 +1,7 @@
import json
import logging
from string import Template
from typing import Any, Dict, Type, TypeVar, Union
from typing import Any, Type, TypeVar, Union
T = TypeVar("T", bound="JSONSerializable")
@@ -84,7 +84,7 @@ class JSONSerializable:
return cls()
@staticmethod
def _auto_encoder(obj: Any) -> Union[Dict[str, Any], None]:
def _auto_encoder(obj: Any) -> Union[dict[str, Any], None]:
"""
Automatically encode an object for JSON serialization.
@@ -126,7 +126,7 @@ class JSONSerializable:
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
@classmethod
def _auto_decoder(cls, dct: Dict[str, Any]) -> Any:
def _auto_decoder(cls, dct: dict[str, Any]) -> Any:
"""
Automatically decode a dictionary to an object during JSON deserialization.

View File

@@ -1,5 +1,6 @@
import logging
from typing import Any, Dict, Generator, List, Optional
from collections.abc import Generator
from typing import Any, Optional
from langchain.schema import BaseMessage as LCBaseMessage
@@ -55,7 +56,7 @@ class BaseLlm(JSONSerializable):
app_id: str,
question: str,
answer: str,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
session_id: str = "default",
):
chat_message = ChatMessage()
@@ -64,7 +65,7 @@ class BaseLlm(JSONSerializable):
self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
self.update_history(app_id=app_id, session_id=session_id)
def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
"""
Generates a prompt based on the given query and context, ready to be
passed to an LLM
@@ -72,7 +73,7 @@ class BaseLlm(JSONSerializable):
:param input_query: The query to use.
:type input_query: str
:param contexts: List of similar documents to the query used as context.
:type contexts: List[str]
:type contexts: list[str]
:return: The prompt
:rtype: str
"""
@@ -170,7 +171,7 @@ class BaseLlm(JSONSerializable):
yield chunk
logging.info(f"Answer: {streamed_answer}")
def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an
@@ -179,7 +180,7 @@ class BaseLlm(JSONSerializable):
:param input_query: The query to use.
:type input_query: str
:param contexts: Embeddings retrieved from the database to be used as context.
:type contexts: List[str]
:type contexts: list[str]
:param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
To persistently use a config, declare it during app init., defaults to None
:type config: Optional[BaseLlmConfig], optional
@@ -223,7 +224,7 @@ class BaseLlm(JSONSerializable):
self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
def chat(
self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
):
"""
Queries the vector database on the given input query.
@@ -235,7 +236,7 @@ class BaseLlm(JSONSerializable):
:param input_query: The query to use.
:type input_query: str
:param contexts: Embeddings retrieved from the database to be used as context.
:type contexts: List[str]
:type contexts: list[str]
:param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
To persistently use a config, declare it during app init., defaults to None
:type config: Optional[BaseLlmConfig], optional
@@ -281,7 +282,7 @@ class BaseLlm(JSONSerializable):
self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
@staticmethod
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> list[LCBaseMessage]:
"""
Construct a list of langchain messages
@@ -290,7 +291,7 @@ class BaseLlm(JSONSerializable):
:param system_prompt: System prompt, defaults to None
:type system_prompt: Optional[str], optional
:return: List of messages
:rtype: List[BaseMessage]
:rtype: list[BaseMessage]
"""
from langchain.schema import HumanMessage, SystemMessage

View File

@@ -1,7 +1,8 @@
import importlib
import logging
import os
from typing import Any, Generator, Optional, Union
from collections.abc import Generator
from typing import Any, Optional, Union
import google.generativeai as genai

View File

@@ -1,6 +1,7 @@
import os
from collections.abc import Iterable
from pathlib import Path
from typing import Iterable, Optional, Union
from typing import Optional, Union
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

View File

@@ -1,4 +1,5 @@
from typing import Iterable, Optional, Union
from collections.abc import Iterable
from typing import Optional, Union
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler

View File

@@ -1,6 +1,6 @@
import json
import os
from typing import Any, Dict, Optional
from typing import Any, Optional
from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage
@@ -12,7 +12,7 @@ from embedchain.llm.base import BaseLlm
@register_deserializable
class OpenAILlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[Dict[str, Any]] = None):
def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[dict[str, Any]] = None):
self.functions = functions
super().__init__(config=config)

View File

@@ -1,7 +1,7 @@
import hashlib
import logging
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Optional
from embedchain.config import AddConfig
from embedchain.data_formatter.data_formatter import DataFormatter
@@ -15,7 +15,7 @@ from embedchain.utils.misc import detect_datatype
class DirectoryLoader(BaseLoader):
"""Load data from a directory."""
def __init__(self, config: Optional[Dict[str, Any]] = None):
def __init__(self, config: Optional[dict[str, Any]] = None):
super().__init__()
config = config or {}
self.recursive = config.get("recursive", True)

View File

@@ -1,7 +1,7 @@
import hashlib
import logging
import time
from typing import Any, Dict, Optional
from typing import Any, Optional
import requests
@@ -10,7 +10,7 @@ from embedchain.utils.misc import clean_string
class DiscourseLoader(BaseLoader):
def __init__(self, config: Optional[Dict[str, Any]] = None):
def __init__(self, config: Optional[dict[str, Any]] = None):
super().__init__()
if not config:
raise ValueError(

View File

@@ -1,6 +1,5 @@
import hashlib
import os
from typing import List
from dropbox.files import FileMetadata
@@ -29,7 +28,7 @@ class DropboxLoader(BaseLoader):
except exceptions.AuthError as ex:
raise ValueError("Invalid Dropbox access token. Please verify your token and try again.") from ex
def _download_folder(self, path: str, local_root: str) -> List[FileMetadata]:
def _download_folder(self, path: str, local_root: str) -> list[FileMetadata]:
"""Download a folder from Dropbox and save it preserving the directory structure."""
entries = self.dbx.files_list_folder(path).entries
for entry in entries:

View File

@@ -4,7 +4,7 @@ import logging
import os
import re
import shlex
from typing import Any, Dict, Optional
from typing import Any, Optional
from tqdm import tqdm
@@ -20,7 +20,7 @@ VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"])
class GithubLoader(BaseLoader):
"""Load data from GitHub search query."""
def __init__(self, config: Optional[Dict[str, Any]] = None):
def __init__(self, config: Optional[dict[str, Any]] = None):
super().__init__()
if not config:
raise ValueError(

View File

@@ -5,7 +5,7 @@ import os
from email import message_from_bytes
from email.utils import parsedate_to_datetime
from textwrap import dedent
from typing import Dict, List, Optional
from typing import Optional
from bs4 import BeautifulSoup
@@ -57,7 +57,7 @@ class GmailReader:
token.write(creds.to_json())
return creds
def load_emails(self) -> List[Dict]:
def load_emails(self) -> list[dict]:
response = self.service.users().messages().list(userId="me", q=self.query).execute()
messages = response.get("messages", [])
@@ -67,7 +67,7 @@ class GmailReader:
raw_message = self.service.users().messages().get(userId="me", id=message_id, format="raw").execute()
return base64.urlsafe_b64decode(raw_message["raw"])
def _parse_email(self, raw_email) -> Dict:
def _parse_email(self, raw_email) -> dict:
mime_msg = message_from_bytes(raw_email)
return {
"subject": self._get_header(mime_msg, "Subject"),
@@ -124,7 +124,7 @@ class GmailLoader(BaseLoader):
return {"doc_id": self._generate_doc_id(query, data), "data": data}
@staticmethod
def _process_email(email: Dict) -> str:
def _process_email(email: dict) -> str:
content = BeautifulSoup(email["body"], "html.parser").get_text()
content = clean_string(content)
return dedent(
@@ -137,6 +137,6 @@ class GmailLoader(BaseLoader):
)
@staticmethod
def _generate_doc_id(query: str, data: List[Dict]) -> str:
def _generate_doc_id(query: str, data: list[dict]) -> str:
content_strings = [email["content"] for email in data]
return hashlib.sha256((query + ", ".join(content_strings)).encode()).hexdigest()

View File

@@ -2,7 +2,7 @@ import hashlib
import json
import os
import re
from typing import Dict, List, Union
from typing import Union
import requests
@@ -16,14 +16,14 @@ class JSONReader:
pass
@staticmethod
def load_data(json_data: Union[Dict, str]) -> List[str]:
def load_data(json_data: Union[dict, str]) -> list[str]:
"""Load data from a JSON structure.
Args:
json_data (Union[Dict, str]): The JSON data to load.
json_data (Union[dict, str]): The JSON data to load.
Returns:
List[str]: A list of strings representing the leaf nodes of the JSON.
list[str]: A list of strings representing the leaf nodes of the JSON.
"""
if isinstance(json_data, str):
json_data = json.loads(json_data)

View File

@@ -1,13 +1,13 @@
import hashlib
import logging
from typing import Any, Dict, Optional
from typing import Any, Optional
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
class MySQLLoader(BaseLoader):
def __init__(self, config: Optional[Dict[str, Any]]):
def __init__(self, config: Optional[dict[str, Any]]):
super().__init__()
if not config:
raise ValueError(
@@ -20,7 +20,7 @@ class MySQLLoader(BaseLoader):
self.cursor = None
self._setup_loader(config=config)
def _setup_loader(self, config: Dict[str, Any]):
def _setup_loader(self, config: dict[str, Any]):
try:
import mysql.connector as sqlconnector
except ImportError as e:

View File

@@ -1,7 +1,7 @@
import hashlib
import logging
import os
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import requests
@@ -15,7 +15,7 @@ class NotionDocument:
A simple Document class to hold the text and additional information of a page.
"""
def __init__(self, text: str, extra_info: Dict[str, Any]):
def __init__(self, text: str, extra_info: dict[str, Any]):
self.text = text
self.extra_info = extra_info
@@ -82,7 +82,7 @@ class NotionPageLoader:
result_lines = "\n".join(result_lines_arr)
return result_lines
def load_data(self, page_ids: List[str]) -> List[NotionDocument]:
def load_data(self, page_ids: list[str]) -> list[NotionDocument]:
"""Load data from the given list of page IDs."""
docs = []
for page_id in page_ids:

View File

@@ -1,12 +1,12 @@
import hashlib
import logging
from typing import Any, Dict, Optional
from typing import Any, Optional
from embedchain.loaders.base_loader import BaseLoader
class PostgresLoader(BaseLoader):
def __init__(self, config: Optional[Dict[str, Any]] = None):
def __init__(self, config: Optional[dict[str, Any]] = None):
super().__init__()
if not config:
raise ValueError(f"Must provide the valid config. Received: {config}")
@@ -15,7 +15,7 @@ class PostgresLoader(BaseLoader):
self.cursor = None
self._setup_loader(config=config)
def _setup_loader(self, config: Dict[str, Any]):
def _setup_loader(self, config: dict[str, Any]):
try:
import psycopg
except ImportError as e:

View File

@@ -2,7 +2,7 @@ import hashlib
import logging
import os
import ssl
from typing import Any, Dict, Optional
from typing import Any, Optional
import certifi
@@ -13,7 +13,7 @@ SLACK_API_BASE_URL = "https://www.slack.com/api/"
class SlackLoader(BaseLoader):
def __init__(self, config: Optional[Dict[str, Any]] = None):
def __init__(self, config: Optional[dict[str, Any]] = None):
super().__init__()
self.config = config if config else {}
@@ -24,7 +24,7 @@ class SlackLoader(BaseLoader):
self.client = None
self._setup_loader(self.config)
def _setup_loader(self, config: Dict[str, Any]):
def _setup_loader(self, config: dict[str, Any]):
try:
from slack_sdk import WebClient
except ImportError as e:

View File

@@ -2,7 +2,7 @@ import json
import logging
import sqlite3
import uuid
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from embedchain.constants import SQLITE_PATH
from embedchain.memory.message import ChatMessage
@@ -67,7 +67,7 @@ class ChatHistory:
self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, (app_id, session_id))
self.connection.commit()
def get(self, app_id, session_id, num_rounds=10, display_format=False) -> List[ChatMessage]:
def get(self, app_id, session_id, num_rounds=10, display_format=False) -> list[ChatMessage]:
"""
Get the most recent num_rounds rounds of conversations
between human and AI, for a given app_id.
@@ -114,7 +114,7 @@ class ChatHistory:
return count
@staticmethod
def _serialize_json(metadata: Dict[str, Any]):
def _serialize_json(metadata: dict[str, Any]):
return json.dumps(metadata)
@staticmethod

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, Optional
from typing import Any, Optional
from embedchain.helpers.json_serializable import JSONSerializable
@@ -18,9 +18,9 @@ class BaseMessage(JSONSerializable):
created_by: str
# Any additional info.
metadata: Dict[str, Any]
metadata: dict[str, Any]
def __init__(self, content: str, created_by: str, metadata: Optional[Dict[str, Any]] = None) -> None:
def __init__(self, content: str, created_by: str, metadata: Optional[dict[str, Any]] = None) -> None:
super().__init__()
self.content = content
self.created_by = created_by

View File

@@ -1,16 +1,16 @@
from typing import Any, Dict, Optional
from typing import Any, Optional
def merge_metadata_dict(left: Optional[Dict[str, Any]], right: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
def merge_metadata_dict(left: Optional[dict[str, Any]], right: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
"""
Merge the metadatas of two BaseMessage types.
Args:
left (Dict[str, Any]): metadata of human message
right (Dict[str, Any]): metadata of AI message
left (dict[str, Any]): metadata of human message
right (dict[str, Any]): metadata of AI message
Returns:
Dict[str, Any]: combined metadata dict with dedup
dict[str, Any]: combined metadata dict with dedup
to be saved in db.
"""
if not left and not right:

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
from chromadb import Collection, QueryResult
from langchain.docstore.document import Document
@@ -76,7 +76,7 @@ class ChromaDB(BaseVectorDB):
return self.client
@staticmethod
def _generate_where_clause(where: Dict[str, any]) -> Dict[str, any]:
def _generate_where_clause(where: dict[str, any]) -> dict[str, any]:
# If only one filter is supplied, return it as is
# (no need to wrap in $and based on chroma docs)
if len(where.keys()) <= 1:
@@ -105,18 +105,18 @@ class ChromaDB(BaseVectorDB):
)
return self.collection
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
"""
Get existing doc ids present in vector database
:param ids: list of doc ids to check for existence
:type ids: List[str]
:type ids: list[str]
:param where: Optional. to filter data
:type where: Dict[str, Any]
:type where: dict[str, Any]
:param limit: Optional. maximum number of documents
:type limit: Optional[int]
:return: Existing documents.
:rtype: List[str]
:rtype: list[str]
"""
args = {}
if ids:
@@ -129,23 +129,23 @@ class ChromaDB(BaseVectorDB):
def add(
self,
embeddings: List[List[float]],
documents: List[str],
metadatas: List[object],
ids: List[str],
**kwargs: Optional[Dict[str, Any]],
embeddings: list[list[float]],
documents: list[str],
metadatas: list[object],
ids: list[str],
**kwargs: Optional[dict[str, Any]],
) -> Any:
"""
Add vectors to chroma database
:param embeddings: list of embeddings to add
:type embeddings: List[List[str]]
:type embeddings: list[list[float]]
:param documents: Documents
:type documents: List[str]
:type documents: list[str]
:param metadatas: Metadatas
:type metadatas: List[object]
:type metadatas: list[object]
:param ids: ids
:type ids: List[str]
:type ids: list[str]
"""
size = len(documents)
if len(documents) != size or len(metadatas) != size or len(ids) != size:
@@ -182,27 +182,27 @@ class ChromaDB(BaseVectorDB):
def query(
self,
input_query: List[str],
input_query: list[str],
n_results: int,
where: Dict[str, any],
where: dict[str, any],
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]:
**kwargs: Optional[dict[str, Any]],
) -> Union[list[tuple[str, dict]], list[str]]:
"""
Query contents from vector database based on vector similarity
:param input_query: list of query string
:type input_query: List[str]
:type input_query: list[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: to filter data
:type where: Dict[str, Any]
:type where: dict[str, Any]
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:raises InvalidDimensionException: Dimensions do not match.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
"""
try:
result = self.collection.query(

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
try:
from elasticsearch import Elasticsearch
@@ -84,14 +84,14 @@ class ElasticsearchDB(BaseVectorDB):
def _get_or_create_collection(self, name):
"""Note: nothing to return here. Discuss later"""
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
"""
Get existing doc ids present in vector database
:param ids: list of doc ids to check for existence
:type ids: List[str]
:type ids: list[str]
:param where: to filter data
:type where: Dict[str, any]
:type where: dict[str, any]
:return: ids
:rtype: set[str]
"""
@@ -110,22 +110,22 @@ class ElasticsearchDB(BaseVectorDB):
def add(
self,
embeddings: List[List[float]],
documents: List[str],
metadatas: List[object],
ids: List[str],
**kwargs: Optional[Dict[str, any]],
embeddings: list[list[float]],
documents: list[str],
metadatas: list[object],
ids: list[str],
**kwargs: Optional[dict[str, any]],
) -> Any:
"""
add data in vector database
:param embeddings: list of embeddings to add
:type embeddings: List[List[str]]
:type embeddings: list[list[float]]
:param documents: list of texts to add
:type documents: List[str]
:type documents: list[str]
:param metadatas: list of metadata associated with docs
:type metadatas: List[object]
:type metadatas: list[object]
:param ids: ids of docs
:type ids: List[str]
:type ids: list[str]
"""
embeddings = self.embedder.embedding_fn(documents)
@@ -154,27 +154,27 @@ class ElasticsearchDB(BaseVectorDB):
def query(
self,
input_query: List[str],
input_query: list[str],
n_results: int,
where: Dict[str, any],
where: dict[str, any],
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]:
**kwargs: Optional[dict[str, Any]],
) -> Union[list[tuple[str, dict]], list[str]]:
"""
query contents from vector database based on vector similarity
:param input_query: list of query string
:type input_query: List[str]
:type input_query: list[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: Optional. to filter data
:type where: Dict[str, any]
:type where: dict[str, any]
:return: The context of the document that matched your query, url of the source, doc_id
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
"""
input_query_vector = self.embedder.embedding_fn(input_query)
query_vector = input_query_vector[0]

View File

@@ -1,6 +1,6 @@
import logging
import time
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Optional, Union
from tqdm import tqdm
@@ -78,17 +78,17 @@ class OpenSearchDB(BaseVectorDB):
"""Note: nothing to return here. Discuss later"""
def get(
self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None
) -> Set[str]:
self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None
) -> set[str]:
"""
Get existing doc ids present in vector database
:param ids: list of doc ids to check for existence
:type ids: List[str]
:type ids: list[str]
:param where: to filter data
:type where: Dict[str, any]
:type where: dict[str, any]
:return: ids
:type: Set[str]
:rtype: set[str]
"""
query = {}
if ids:
@@ -116,19 +116,19 @@ class OpenSearchDB(BaseVectorDB):
def add(
self,
embeddings: List[List[str]],
documents: List[str],
metadatas: List[object],
ids: List[str],
**kwargs: Optional[Dict[str, any]],
embeddings: list[list[str]],
documents: list[str],
metadatas: list[object],
ids: list[str],
**kwargs: Optional[dict[str, any]],
):
"""Add data in vector database.
Args:
embeddings (List[List[str]]): List of embeddings to add.
documents (List[str]): List of texts to add.
metadatas (List[object]): List of metadata associated with docs.
ids (List[str]): IDs of docs.
embeddings (list[list[str]]): List of embeddings to add.
documents (list[str]): List of texts to add.
metadatas (list[object]): List of metadata associated with docs.
ids (list[str]): IDs of docs.
"""
for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"):
batch_end = batch_start + self.BATCH_SIZE
@@ -156,26 +156,26 @@ class OpenSearchDB(BaseVectorDB):
def query(
self,
input_query: List[str],
input_query: list[str],
n_results: int,
where: Dict[str, any],
where: dict[str, any],
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]:
**kwargs: Optional[dict[str, Any]],
) -> Union[list[tuple[str, dict]], list[str]]:
"""
query contents from vector database based on vector similarity
:param input_query: list of query string
:type input_query: List[str]
:type input_query: list[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: Optional. to filter data
:type where: Dict[str, any]
:type where: dict[str, any]
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
"""
embeddings = OpenAIEmbeddings()
docsearch = OpenSearchVectorSearch(

View File

@@ -1,5 +1,5 @@
import os
from typing import Dict, List, Optional, Tuple, Union
from typing import Optional, Union
try:
import pinecone
@@ -67,14 +67,14 @@ class PineconeDB(BaseVectorDB):
)
return pinecone.Index(self.index_name)
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
"""
Get existing doc ids present in vector database
:param ids: list of doc ids to check for existence
:type ids: List[str]
:type ids: list[str]
:param where: to filter data
:type where: Dict[str, any]
:type where: dict[str, any]
:return: ids
:rtype: set[str]
"""
@@ -88,20 +88,20 @@ class PineconeDB(BaseVectorDB):
def add(
self,
embeddings: List[List[float]],
documents: List[str],
metadatas: List[object],
ids: List[str],
**kwargs: Optional[Dict[str, any]],
embeddings: list[list[float]],
documents: list[str],
metadatas: list[object],
ids: list[str],
**kwargs: Optional[dict[str, any]],
):
"""add data in vector database
:param documents: list of texts to add
:type documents: List[str]
:type documents: list[str]
:param metadatas: list of metadata associated with docs
:type metadatas: List[object]
:type metadatas: list[object]
:param ids: ids of docs
:type ids: List[str]
:type ids: list[str]
"""
docs = []
print("Adding documents to Pinecone...")
@@ -120,25 +120,25 @@ class PineconeDB(BaseVectorDB):
def query(
self,
input_query: List[str],
input_query: list[str],
n_results: int,
where: Dict[str, any],
where: dict[str, any],
citations: bool = False,
**kwargs: Optional[Dict[str, any]],
) -> Union[List[Tuple[str, Dict]], List[str]]:
**kwargs: Optional[dict[str, any]],
) -> Union[list[tuple[str, dict]], list[str]]:
"""
query contents from vector database based on vector similarity
:param input_query: list of query string
:type input_query: List[str]
:type input_query: list[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: Optional. to filter data
:type where: Dict[str, any]
:type where: dict[str, any]
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
"""
query_vector = self.embedder.embedding_fn([input_query])[0]
data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True, **kwargs)

View File

@@ -1,7 +1,7 @@
import copy
import os
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
try:
from qdrant_client import QdrantClient
@@ -69,14 +69,14 @@ class QdrantDB(BaseVectorDB):
def _get_or_create_collection(self):
return f"{self.config.collection_name}-{self.embedder.vector_dimension}".lower().replace("_", "-")
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
"""
Get existing doc ids present in vector database
:param ids: list of doc ids to check for existence
:type ids: List[str]
:type ids: list[str]
:param where: to filter data
:type where: Dict[str, any]
:type where: dict[str, any]
:param limit: The number of entries to be fetched
:type limit: Optional int, defaults to None
:return: All the existing IDs
@@ -122,21 +122,21 @@ class QdrantDB(BaseVectorDB):
def add(
self,
embeddings: List[List[float]],
documents: List[str],
metadatas: List[object],
ids: List[str],
**kwargs: Optional[Dict[str, any]],
embeddings: list[list[float]],
documents: list[str],
metadatas: list[object],
ids: list[str],
**kwargs: Optional[dict[str, any]],
):
"""add data in vector database
:param embeddings: list of embeddings for the corresponding documents to be added
:type documents: List[List[float]]
:type embeddings: list[list[float]]
:param documents: list of texts to add
:type documents: List[str]
:type documents: list[str]
:param metadatas: list of metadata associated with docs
:type metadatas: List[object]
:type metadatas: list[object]
:param ids: ids of docs
:type ids: List[str]
:type ids: list[str]
"""
embeddings = self.embedder.embedding_fn(documents)
@@ -159,25 +159,25 @@ class QdrantDB(BaseVectorDB):
def query(
self,
input_query: List[str],
input_query: list[str],
n_results: int,
where: Dict[str, any],
where: dict[str, any],
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]:
**kwargs: Optional[dict[str, Any]],
) -> Union[list[tuple[str, dict]], list[str]]:
"""
query contents from vector database based on vector similarity
:param input_query: list of query string
:type input_query: List[str]
:type input_query: list[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: Optional. to filter data
:type where: Dict[str, any]
:type where: dict[str, any]
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
"""
query_vector = self.embedder.embedding_fn([input_query])[0]
keys = set(where.keys() if where is not None else set())

View File

@@ -1,6 +1,6 @@
import copy
import os
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
try:
import weaviate
@@ -117,13 +117,13 @@ class WeaviateDB(BaseVectorDB):
self.client.schema.create(class_obj)
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
"""
Get existing doc ids present in vector database
:param ids: list of doc ids to check for existence
:type ids: List[str]
:type ids: list[str]
:param where: to filter data
:type where: Dict[str, any]
:type where: dict[str, any]
:return: ids
:rtype: set[str]
"""
@@ -153,21 +153,21 @@ class WeaviateDB(BaseVectorDB):
def add(
self,
embeddings: List[List[float]],
documents: List[str],
metadatas: List[object],
ids: List[str],
**kwargs: Optional[Dict[str, any]],
embeddings: list[list[float]],
documents: list[str],
metadatas: list[object],
ids: list[str],
**kwargs: Optional[dict[str, any]],
):
"""add data in vector database
:param embeddings: list of embeddings for the corresponding documents to be added
:type documents: List[List[float]]
:type embeddings: list[list[float]]
:param documents: list of texts to add
:type documents: List[str]
:type documents: list[str]
:param metadatas: list of metadata associated with docs
:type metadatas: List[object]
:type metadatas: list[object]
:param ids: ids of docs
:type ids: List[str]
:type ids: list[str]
"""
embeddings = self.embedder.embedding_fn(documents)
self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3) # Configure batch
@@ -192,25 +192,25 @@ class WeaviateDB(BaseVectorDB):
def query(
self,
input_query: List[str],
input_query: list[str],
n_results: int,
where: Dict[str, any],
where: dict[str, any],
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]:
**kwargs: Optional[dict[str, Any]],
) -> Union[list[tuple[str, dict]], list[str]]:
"""
query contents from vector database based on vector similarity
:param input_query: list of query string
:type input_query: List[str]
:type input_query: list[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: Optional. to filter data
:type where: Dict[str, any]
:type where: dict[str, any]
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
"""
query_vector = self.embedder.embedding_fn([input_query])[0]
keys = set(where.keys() if where is not None else set())

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
from embedchain.config import ZillizDBConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -88,14 +88,14 @@ class ZillizVectorDB(BaseVectorDB):
self.collection.create_index("embeddings", index)
return self.collection
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
"""
Get existing doc ids present in vector database
:param ids: list of doc ids to check for existence
:type ids: List[str]
:type ids: list[str]
:param where: Optional. to filter data
:type where: Dict[str, Any]
:type where: dict[str, Any]
:param limit: Optional. maximum number of documents
:type limit: Optional[int]
:return: Existing documents.
@@ -115,11 +115,11 @@ class ZillizVectorDB(BaseVectorDB):
def add(
self,
embeddings: List[List[float]],
documents: List[str],
metadatas: List[object],
ids: List[str],
**kwargs: Optional[Dict[str, any]],
embeddings: list[list[float]],
documents: list[str],
metadatas: list[object],
ids: list[str],
**kwargs: Optional[dict[str, any]],
):
"""Add to database"""
embeddings = self.embedder.embedding_fn(documents)
@@ -134,17 +134,17 @@ class ZillizVectorDB(BaseVectorDB):
def query(
self,
input_query: List[str],
input_query: list[str],
n_results: int,
where: Dict[str, any],
where: dict[str, any],
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]:
**kwargs: Optional[dict[str, Any]],
) -> Union[list[tuple[str, dict]], list[str]]:
"""
Query contents from vector database based on vector similarity
:param input_query: list of query string
:type input_query: List[str]
:type input_query: list[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: to filter data
@@ -154,7 +154,7 @@ class ZillizVectorDB(BaseVectorDB):
:type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
"""
if self.collection.is_empty:
@@ -200,7 +200,7 @@ class ZillizVectorDB(BaseVectorDB):
"""
return self.collection.num_entities
def reset(self, collection_names: List[str] = None):
def reset(self, collection_names: list[str] = None):
"""
Resets the database. Deletes all embeddings irreversibly.
"""