docs: update docstrings (#565)

This commit is contained in:
cachho
2023-09-07 02:04:44 +02:00
committed by GitHub
parent 4754372fcd
commit 1ac8aef4de
25 changed files with 736 additions and 298 deletions

View File

@@ -12,12 +12,13 @@ from embedchain.vectordb.chroma_db import ChromaDB
 @register_deserializable
 class App(EmbedChain):
     """
-    The EmbedChain app.
-    Has two functions: add and query.
-    adds(data_type, url): adds the data from the given URL to the vector db.
+    The EmbedChain app in its simplest and most straightforward form.
+    An opinionated choice of LLM, vector database and embedding model.
+
+    Methods:
+    add(source, data_type): adds the data from the given URL to the vector db.
     query(query): finds answer to the given query using vector database and LLM.
-    dry_run(query): test your prompt without consuming tokens.
+    chat(query): finds answer to the given query using vector database and LLM, with conversation history.
     """

     def __init__(
@@ -28,8 +29,20 @@ class App(EmbedChain):
         system_prompt: Optional[str] = None,
     ):
         """
-        :param config: AppConfig instance to load as configuration. Optional.
-        :param system_prompt: System prompt string. Optional.
+        Initialize a new `App` instance. You only have a few choices to make.
+
+        :param config: Config for the app instance. This is the most basic configuration,
+        that does not fall into the LLM, database or embedder category, defaults to None
+        :type config: AppConfig, optional
+        :param llm_config: Allows you to configure the LLM, e.g. how many documents to return,
+        example: `from embedchain.config import LlmConfig`, defaults to None
+        :type llm_config: BaseLlmConfig, optional
+        :param chromadb_config: Allows you to configure the vector database,
+        example: `from embedchain.config import ChromaDbConfig`, defaults to None
+        :type chromadb_config: Optional[ChromaDbConfig], optional
+        :param system_prompt: System prompt that will be provided to the LLM as such, defaults to None
+        :type system_prompt: Optional[str], optional
         """
         if config is None:
             config = AppConfig()
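Taken together, the retouched `App` surface reads roughly like this minimal sketch (the URL and config values are hypothetical; `LlmConfig` is the import named in the docstring above):

    from embedchain import App
    from embedchain.config import AppConfig, LlmConfig

    # llm_config is optional; number_documents is one of the documented knobs
    app = App(config=AppConfig(), llm_config=LlmConfig(number_documents=3))
    app.add("https://example.com/article")          # data type is auto-detected
    print(app.query("What is the article about?"))  # single-shot answer
    print(app.chat("And who wrote it?"))            # keeps conversation history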

View File

@@ -11,26 +11,42 @@ from embedchain.vectordb.base_vector_db import BaseVectorDB
 @register_deserializable
 class CustomApp(EmbedChain):
     """
-    The custom EmbedChain app.
-    Has two functions: add and query.
-    adds(data_type, url): adds the data from the given URL to the vector db.
+    Embedchain's custom app allows for most flexibility.
+    You can craft your own mix of various LLMs, vector databases and embedding model/functions.
+
+    Methods:
+    add(source, data_type): adds the data from the given URL to the vector db.
     query(query): finds answer to the given query using vector database and LLM.
-    dry_run(query): test your prompt without consuming tokens.
+    chat(query): finds answer to the given query using vector database and LLM, with conversation history.
     """

     def __init__(
         self,
-        config: CustomAppConfig = None,
+        config: Optional[CustomAppConfig] = None,
         llm: BaseLlm = None,
         db: BaseVectorDB = None,
         embedder: BaseEmbedder = None,
         system_prompt: Optional[str] = None,
     ):
         """
-        :param config: Optional. `CustomAppConfig` instance to load as configuration.
-        :raises ValueError: Config must be provided for custom app
-        :param system_prompt: Optional. System prompt string.
+        Initialize a new `CustomApp` instance. You have to choose an LLM, database and embedder.
+
+        :param config: Config for the app instance. This is the most basic configuration,
+        that does not fall into the LLM, database or embedder category, defaults to None
+        :type config: Optional[CustomAppConfig], optional
+        :param llm: LLM Class instance. example: `from embedchain.llm.openai_llm import OpenAiLlm`, defaults to None
+        :type llm: BaseLlm
+        :param db: The database to use for storing and retrieving embeddings,
+        example: `from embedchain.vectordb.chroma_db import ChromaDb`, defaults to None
+        :type db: BaseVectorDB
+        :param embedder: The embedder (embedding model and function) used to calculate embeddings.
+        example: `from embedchain.embedder.gpt4all_embedder import GPT4AllEmbedder`, defaults to None
+        :type embedder: BaseEmbedder
+        :param system_prompt: System prompt that will be provided to the LLM as such, defaults to None
+        :type system_prompt: Optional[str], optional
+        :raises ValueError: LLM, database or embedder has not been defined.
+        :raises TypeError: LLM, database or embedder is not a valid class instance.
         """
         # Config is not required, it has a default
         if config is None:
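A sketch of wiring one up, using the component imports this commit itself shows in the bot module (all OpenAI-backed, with Chroma for storage):

    from embedchain import CustomApp
    from embedchain.config import CustomAppConfig
    from embedchain.embedder.openai_embedder import OpenAiEmbedder
    from embedchain.llm.openai_llm import OpenAiLlm
    from embedchain.vectordb.chroma_db import ChromaDB

    app = CustomApp(
        config=CustomAppConfig(),
        llm=OpenAiLlm(),            # any BaseLlm instance
        db=ChromaDB(),              # any BaseVectorDB instance
        embedder=OpenAiEmbedder(),  # any BaseEmbedder instance
    )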

View File

@@ -12,10 +12,11 @@ from embedchain.vectordb.chroma_db import ChromaDB
 class Llama2App(CustomApp):
     """
     The EmbedChain Llama2App class.
-    Has two functions: add and query.
-    adds(data_type, url): adds the data from the given URL to the vector db.
+
+    Methods:
+    add(source, data_type): adds the data from the given URL to the vector db.
     query(query): finds answer to the given query using vector database and LLM.
+    chat(query): finds answer to the given query using vector database and LLM, with conversation history.
     """

     def __init__(self, config: CustomAppConfig = None, system_prompt: Optional[str] = None):

View File

@@ -15,43 +15,64 @@ gpt4all_model = None
 @register_deserializable
 class OpenSourceApp(EmbedChain):
     """
-    The OpenSource app.
-    Same as App, but uses an open source embedding model and LLM.
-    Has two function: add and query.
-    adds(data_type, url): adds the data from the given URL to the vector db.
+    The embedchain Open Source App.
+    Comes preconfigured with the best open source LLM, embedding model, database.
+
+    Methods:
+    add(source, data_type): adds the data from the given URL to the vector db.
     query(query): finds answer to the given query using vector database and LLM.
+    chat(query): finds answer to the given query using vector database and LLM, with conversation history.
     """

     def __init__(
         self,
         config: OpenSourceAppConfig = None,
+        llm_config: BaseLlmConfig = None,
         chromadb_config: Optional[ChromaDbConfig] = None,
         system_prompt: Optional[str] = None,
     ):
         """
-        :param config: OpenSourceAppConfig instance to load as configuration. Optional.
-        `ef` defaults to open source.
-        :param system_prompt: System prompt string. Optional.
+        Initialize a new `OpenSourceApp` instance.
+        Since it's opinionated you don't have to choose an LLM, database and embedder.
+        However, you can configure those.
+
+        :param config: Config for the app instance. This is the most basic configuration,
+        that does not fall into the LLM, database or embedder category, defaults to None
+        :type config: OpenSourceAppConfig, optional
+        :param llm_config: Allows you to configure the LLM, e.g. how many documents to return.
+        example: `from embedchain.config import LlmConfig`, defaults to None
+        :type llm_config: BaseLlmConfig, optional
+        :param chromadb_config: Allows you to configure the open source database,
+        example: `from embedchain.config import ChromaDbConfig`, defaults to None
+        :type chromadb_config: Optional[ChromaDbConfig], optional
+        :param system_prompt: System prompt that will be provided to the LLM as such.
+        Please don't use it for the time being, as it's not supported. Defaults to None
+        :type system_prompt: Optional[str], optional
+        :raises TypeError: `OpenSourceAppConfig` or `LlmConfig` invalid.
         """
         logging.info("Loading open source embedding model. This may take some time...")  # noqa:E501
         if not config:
             config = OpenSourceAppConfig()

         if not isinstance(config, OpenSourceAppConfig):
-            raise ValueError(
+            raise TypeError(
                 "OpenSourceApp needs a OpenSourceAppConfig passed to it. "
                 "You can import it with `from embedchain.config import OpenSourceAppConfig`"
             )

-        if not config.model:
-            raise ValueError("OpenSourceApp needs a model to be instantiated. Maybe you passed the wrong config type?")
+        if not llm_config:
+            llm_config = BaseLlmConfig(model="orca-mini-3b.ggmlv3.q4_0.bin")
+        elif not isinstance(llm_config, BaseLlmConfig):
+            raise TypeError(
+                "The LlmConfig passed to OpenSourceApp is invalid. "
+                "You can import it with `from embedchain.config import LlmConfig`"
+            )
+        elif not llm_config.model:
+            llm_config.model = "orca-mini-3b.ggmlv3.q4_0.bin"

-        logging.info("Successfully loaded open source embedding model.")
-        llm = GPT4ALLLlm(config=BaseLlmConfig(model="orca-mini-3b.ggmlv3.q4_0.bin"))
+        llm = GPT4ALLLlm(config=llm_config)
         embedder = GPT4AllEmbedder(config=BaseEmbedderConfig(model="all-MiniLM-L6-v2"))
+        logging.error("Successfully loaded open source embedding model.")
         database = ChromaDB(config=chromadb_config)

         super().__init__(config, llm=llm, db=database, embedder=embedder, system_prompt=system_prompt)
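A sketch of the new `llm_config` path (the model name is the default shown above; other values are illustrative, and the top-level `OpenSourceApp` export is assumed):

    from embedchain import OpenSourceApp
    from embedchain.config import LlmConfig, OpenSourceAppConfig

    # If llm_config is omitted, the orca-mini GGML model above is used
    app = OpenSourceApp(
        config=OpenSourceAppConfig(),
        llm_config=LlmConfig(max_tokens=500, stream=True),
    )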

View File

@@ -19,7 +19,14 @@ class EmbedChainPersonApp:
     :param config: BaseAppConfig instance to load as configuration.
     """

-    def __init__(self, person, config: BaseAppConfig = None):
+    def __init__(self, person: str, config: BaseAppConfig = None):
+        """Initialize a new person app
+
+        :param person: Name of the person that's imitated.
+        :type person: str
+        :param config: Configuration class instance, defaults to None
+        :type config: BaseAppConfig, optional
+        """
         self.person = person
         self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style."  # noqa:E501
         super().__init__(config)
@@ -30,9 +37,12 @@ class EmbedChainPersonApp:
         if yes it adds the person prompt to it and return the updated config
         else it creates a config object with the default prompt added to the person prompt

         :param default_prompt: it is the default prompt for query or chat methods
-        :param config: Optional. The `ChatConfig` instance to use as
-        configuration options.
+        :type default_prompt: str
+        :param config: Configuration class instance, defaults to None
+        :type config: BaseLlmConfig, optional
+        :return: The `ChatConfig` instance to use as configuration options.
+        :rtype: BaseLlmConfig
         """
         template = Template(self.person_prompt + " " + default_prompt)

View File

@@ -1,3 +1,5 @@
+from typing import Any
+
 from embedchain import CustomApp
 from embedchain.config import AddConfig, CustomAppConfig, LlmConfig
 from embedchain.embedder.openai_embedder import OpenAiEmbedder
@@ -12,13 +14,30 @@ class BaseBot(JSONSerializable):
     def __init__(self):
         self.app = CustomApp(config=CustomAppConfig(), llm=OpenAiLlm(), db=ChromaDB(), embedder=OpenAiEmbedder())

-    def add(self, data, config: AddConfig = None):
-        """Add data to the bot"""
+    def add(self, data: Any, config: AddConfig = None):
+        """
+        Add data to the bot (to the vector database).
+        Auto-detects type only, so some data types might not be usable.
+
+        :param data: data to embed
+        :type data: Any
+        :param config: configuration class instance, defaults to None
+        :type config: AddConfig, optional
+        """
         config = config if config else AddConfig()
         self.app.add(data, config=config)

-    def query(self, query, config: LlmConfig = None):
-        """Query bot"""
+    def query(self, query: str, config: LlmConfig = None) -> str:
+        """
+        Query the bot
+
+        :param query: the user query
+        :type query: str
+        :param config: configuration class instance, defaults to None
+        :type config: LlmConfig, optional
+        :return: Answer
+        :rtype: str
+        """
         config = config
         return self.app.query(query, config=config)
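Assuming `BaseBot` lives at the bot module path used elsewhere in the repo, usage would look roughly like:

    from embedchain.bots.base import BaseBot  # assumed module path

    bot = BaseBot()  # wires a CustomApp with OpenAI LLM/embedder and ChromaDB
    bot.add("Unicorns are mythical horses with a single horn.")
    print(bot.query("What is a unicorn?"))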

View File

@@ -42,5 +42,13 @@ class AddConfig(BaseConfig):
         chunker: Optional[ChunkerConfig] = None,
         loader: Optional[LoaderConfig] = None,
     ):
+        """
+        Initializes a configuration class instance for the `add` method.
+
+        :param chunker: Chunker config, defaults to None
+        :type chunker: Optional[ChunkerConfig], optional
+        :param loader: Loader config, defaults to None
+        :type loader: Optional[LoaderConfig], optional
+        """
         self.loader = loader
         self.chunker = chunker

View File

@@ -1,3 +1,5 @@
+from typing import Any, Dict
+
 from embedchain.helper_classes.json_serializable import JSONSerializable
@@ -7,7 +9,13 @@ class BaseConfig(JSONSerializable):
     """

     def __init__(self):
+        """Initializes a configuration class for a class."""
         pass

-    def as_dict(self):
+    def as_dict(self) -> Dict[str, Any]:
+        """Return config object as a dict
+
+        :return: config object as dict
+        :rtype: Dict[str, Any]
+        """
         return vars(self)
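Since `as_dict` simply returns `vars(self)`, any config subclass serializes its attributes directly. A quick check with `AddConfig` from above:

    from embedchain.config import AddConfig

    config = AddConfig()
    print(config.as_dict())  # {'loader': None, 'chunker': None}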

View File

@@ -13,15 +13,23 @@ class AppConfig(BaseAppConfig):
     def __init__(
         self,
-        log_level=None,
-        id=None,
+        log_level: str = "WARNING",
+        id: Optional[str] = None,
         collect_metrics: Optional[bool] = None,
         collection_name: Optional[str] = None,
     ):
         """
-        :param log_level: Optional. (String) Debug level
-        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
-        :param id: Optional. ID of the app. Document metadata will have this id.
-        :param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
+        Initializes a configuration class instance for an App. This is the simplest form of an embedchain app.
+        Most of the configuration is done in the `App` class itself.
+
+        :param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
+        :type log_level: str, optional
+        :param id: ID of the app. Document metadata will have this id, defaults to None
+        :type id: Optional[str], optional
+        :param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
+        :type collect_metrics: Optional[bool], optional
+        :param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
+        defaults to None
+        :type collection_name: Optional[str], optional
         """
         super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, collection_name=collection_name)

View File

@@ -13,23 +13,28 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
     def __init__(
         self,
-        log_level=None,
+        log_level: str = "WARNING",
         db: Optional[BaseVectorDB] = None,
-        id=None,
+        id: Optional[str] = None,
         collect_metrics: bool = True,
         collection_name: Optional[str] = None,
     ):
         """
-        :param log_level: Optional. (String) Debug level
-        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
-        :param db: Optional. (Vector) database instance to use for embeddings. Deprecated in favor of app(..., db).
-        :param id: Optional. ID of the app. Document metadata will have this id.
-        :param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
-        :param db_type: Optional. Initializes a default vector database of the given type.
-        Using the `db` argument is preferred.
-        :param es_config: Optional. elasticsearch database config to be used for connection
-        :param collection_name: Optional. Default collection name.
-        It's recommended to use app.set_collection_name() instead.
+        Initializes a configuration class instance for an App.
+        Most of the configuration is done in the `App` class itself.
+
+        :param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
+        :type log_level: str, optional
+        :param db: A database class. It is recommended to set this directly in the `App` class, not this config,
+        defaults to None
+        :type db: Optional[BaseVectorDB], optional
+        :param id: ID of the app. Document metadata will have this id, defaults to None
+        :type id: Optional[str], optional
+        :param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
+        :type collect_metrics: Optional[bool], optional
+        :param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
+        defaults to None
+        :type collection_name: Optional[str], optional
         """
         self._setup_logging(log_level)
         self.id = id

View File

@@ -3,6 +3,7 @@ from typing import Optional
 from dotenv import load_dotenv

 from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.vectordb.base_vector_db import BaseVectorDB

 from .BaseAppConfig import BaseAppConfig
@@ -17,24 +18,29 @@ class CustomAppConfig(BaseAppConfig):
     def __init__(
         self,
-        log_level=None,
-        db=None,
-        id=None,
+        log_level: str = "WARNING",
+        db: Optional[BaseVectorDB] = None,
+        id: Optional[str] = None,
         collect_metrics: Optional[bool] = None,
         collection_name: Optional[str] = None,
     ):
         """
-        :param log_level: Optional. (String) Debug level
-        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
-        :param db: Optional. (Vector) database to use for embeddings.
-        :param id: Optional. ID of the app. Document metadata will have this id.
-        :param provider: Optional. (Providers): LLM Provider to use.
-        :param open_source_app_config: Optional. Config instance needed for open source apps.
-        :param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
-        :param collection_name: Optional. Default collection name.
-        It's recommended to use app.set_collection_name() instead.
+        Initializes a configuration class instance for a Custom App.
+        Most of the configuration is done in the `CustomApp` class itself.
+
+        :param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
+        :type log_level: str, optional
+        :param db: A database class. It is recommended to set this directly in the `CustomApp` class, not this config,
+        defaults to None
+        :type db: Optional[BaseVectorDB], optional
+        :param id: ID of the app. Document metadata will have this id, defaults to None
+        :type id: Optional[str], optional
+        :param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
+        :type collect_metrics: Optional[bool], optional
+        :param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
+        defaults to None
+        :type collection_name: Optional[str], optional
         """
         super().__init__(
             log_level=log_level, db=db, id=id, collect_metrics=collect_metrics, collection_name=collection_name
         )

View File

@@ -13,21 +13,27 @@ class OpenSourceAppConfig(BaseAppConfig):
     def __init__(
         self,
-        log_level=None,
-        id=None,
+        log_level: str = "WARNING",
+        id: Optional[str] = None,
         collect_metrics: Optional[bool] = None,
-        model=None,
+        model: str = "orca-mini-3b.ggmlv3.q4_0.bin",
         collection_name: Optional[str] = None,
     ):
         """
-        :param log_level: Optional. (String) Debug level
-        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
-        :param id: Optional. ID of the app. Document metadata will have this id.
-        :param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
-        :param model: Optional. GPT4ALL uses the model to instantiate the class.
-        So unlike `App`, it has to be provided before querying.
-        :param collection_name: Optional. Default collection name.
-        It's recommended to use app.db.set_collection_name() instead.
+        Initializes a configuration class instance for an Open Source App.
+
+        :param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
+        :type log_level: str, optional
+        :param id: ID of the app. Document metadata will have this id, defaults to None
+        :type id: Optional[str], optional
+        :param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
+        :type collect_metrics: Optional[bool], optional
+        :param model: GPT4ALL uses the model to instantiate the class.
+        Unlike `App`, it has to be provided before querying, defaults to "orca-mini-3b.ggmlv3.q4_0.bin"
+        :type model: str, optional
+        :param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
+        defaults to None
+        :type collection_name: Optional[str], optional
         """
         self.model = model or "orca-mini-3b.ggmlv3.q4_0.bin"

View File

@@ -6,5 +6,13 @@ from embedchain.helper_classes.json_serializable import register_deserializable
 @register_deserializable
 class BaseEmbedderConfig:
     def __init__(self, model: Optional[str] = None, deployment_name: Optional[str] = None):
+        """
+        Initialize a new instance of an embedder config class.
+
+        :param model: model name of the llm embedding model (not applicable to all providers), defaults to None
+        :type model: Optional[str], optional
+        :param deployment_name: deployment name for llm embedding model, defaults to None
+        :type deployment_name: Optional[str], optional
+        """
         self.model = model
         self.deployment_name = deployment_name

View File

@@ -1,6 +1,6 @@
 import re
 from string import Template
-from typing import Optional
+from typing import Any, Dict, Optional

 from embedchain.config.BaseConfig import BaseConfig
 from embedchain.helper_classes.json_serializable import register_deserializable
@@ -57,51 +57,59 @@ class BaseLlmConfig(BaseConfig):
     def __init__(
         self,
-        number_documents=None,
-        template: Template = None,
-        model=None,
-        temperature=None,
-        max_tokens=None,
-        top_p=None,
+        number_documents: int = 1,
+        template: Optional[Template] = None,
+        model: Optional[str] = None,
+        temperature: float = 0,
+        max_tokens: int = 1000,
+        top_p: float = 1,
         stream: bool = False,
-        deployment_name=None,
+        deployment_name: Optional[str] = None,
         system_prompt: Optional[str] = None,
-        where=None,
+        where: Dict[str, Any] = None,
     ):
         """
-        Initializes the QueryConfig instance.
-        :param number_documents: Number of documents to pull from the database as
-        context.
-        :param template: Optional. The `Template` instance to use as a template for
-        prompt.
-        :param model: Optional. Controls the OpenAI model used.
-        :param temperature: Optional. Controls the randomness of the model's output.
-        Higher values (closer to 1) make output more random, lower values make it more
-        deterministic.
-        :param max_tokens: Optional. Controls how many tokens are generated.
-        :param top_p: Optional. Controls the diversity of words. Higher values
-        (closer to 1) make word selection more diverse, lower values make words less
-        diverse.
-        :param stream: Optional. Control if response is streamed back to user
-        :param deployment_name: t.b.a.
-        :param system_prompt: Optional. System prompt string.
-        :param where: Optional. A dictionary of key-value pairs to filter the database results.
+        Initializes a configuration class instance for the LLM.
+        Takes the place of the former `QueryConfig` or `ChatConfig`.
+        Use `LlmConfig` as an alias to `BaseLlmConfig`.
+
+        :param number_documents: Number of documents to pull from the database as
+        context, defaults to 1
+        :type number_documents: int, optional
+        :param template: The `Template` instance to use as a template for
+        prompt, defaults to None
+        :type template: Optional[Template], optional
+        :param model: Controls the OpenAI model used, defaults to None
+        :type model: Optional[str], optional
+        :param temperature: Controls the randomness of the model's output.
+        Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
+        :type temperature: float, optional
+        :param max_tokens: Controls how many tokens are generated, defaults to 1000
+        :type max_tokens: int, optional
+        :param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
+        defaults to 1
+        :type top_p: float, optional
+        :param stream: Control if response is streamed back to user, defaults to False
+        :type stream: bool, optional
+        :param deployment_name: t.b.a., defaults to None
+        :type deployment_name: Optional[str], optional
+        :param system_prompt: System prompt string, defaults to None
+        :type system_prompt: Optional[str], optional
+        :param where: A dictionary of key-value pairs to filter the database results, defaults to None
+        :type where: Dict[str, Any], optional
         :raises ValueError: If the template is not valid as template should
-        contain $context and $query (and optionally $history).
+        contain $context and $query (and optionally $history)
+        :raises ValueError: Stream is not boolean
         """
-        if number_documents is None:
-            self.number_documents = 1
-        else:
-            self.number_documents = number_documents
-
         if template is None:
             template = DEFAULT_PROMPT_TEMPLATE

-        self.temperature = temperature if temperature else 0
-        self.max_tokens = max_tokens if max_tokens else 1000
+        self.number_documents = number_documents
+        self.temperature = temperature
+        self.max_tokens = max_tokens
         self.model = model
-        self.top_p = top_p if top_p else 1
+        self.top_p = top_p
         self.deployment_name = deployment_name
         self.system_prompt = system_prompt
@@ -115,20 +123,24 @@ class BaseLlmConfig(BaseConfig):
         self.stream = stream
         self.where = where

-    def validate_template(self, template: Template):
+    def validate_template(self, template: Template) -> bool:
         """
         validate the template

         :param template: the template to validate
-        :return: Boolean, valid (true) or invalid (false)
+        :type template: Template
+        :return: valid (true) or invalid (false)
+        :rtype: bool
         """
         return re.search(query_re, template.template) and re.search(context_re, template.template)

-    def _validate_template_history(self, template: Template):
+    def _validate_template_history(self, template: Template) -> bool:
         """
-        validate the history template for history
+        validate the template with history

         :param template: the template to validate
-        :return: Boolean, valid (true) or invalid (false)
+        :type template: Template
+        :return: valid (true) or invalid (false)
+        :rtype: bool
         """
         return re.search(history_re, template.template)
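A quick sketch of the consolidated config and the template check (values are illustrative):

    from string import Template
    from embedchain.config import LlmConfig

    config = LlmConfig(number_documents=3, temperature=0.2, max_tokens=500)

    # A custom prompt template must contain $context and $query ($history is optional)
    template = Template("Answer using this context: $context\nQuestion: $query")
    assert config.validate_template(template)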

View File

@@ -7,11 +7,23 @@ class BaseVectorDbConfig(BaseConfig):
     def __init__(
         self,
         collection_name: Optional[str] = None,
-        dir: Optional[str] = None,
+        dir: str = "db",
         host: Optional[str] = None,
         port: Optional[str] = None,
     ):
+        """
+        Initializes a configuration class instance for the vector database.
+
+        :param collection_name: Default name for the collection, defaults to None
+        :type collection_name: Optional[str], optional
+        :param dir: Path to the database directory, where the database is stored, defaults to "db"
+        :type dir: str, optional
+        :param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None
+        :type host: Optional[str], optional
+        :param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
+        :type port: Optional[str], optional
+        """
         self.collection_name = collection_name or "embedchain_store"
-        self.dir = dir or "db"
+        self.dir = dir
         self.host = host
         self.port = port

View File

@@ -14,6 +14,20 @@ class ChromaDbConfig(BaseVectorDbConfig):
         port: Optional[str] = None,
         chroma_settings: Optional[dict] = None,
     ):
+        """
+        Initializes a configuration class instance for ChromaDB.
+
+        :param collection_name: Default name for the collection, defaults to None
+        :type collection_name: Optional[str], optional
+        :param dir: Path to the database directory, where the database is stored, defaults to None
+        :type dir: Optional[str], optional
+        :param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None
+        :type host: Optional[str], optional
+        :param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
+        :type port: Optional[str], optional
+        :param chroma_settings: Chroma settings dict, defaults to None
+        :type chroma_settings: Optional[dict], optional
+        """
         """
         :param chroma_settings: Optional. Chroma settings for connection.
         """

View File

@@ -14,9 +14,16 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
         **ES_EXTRA_PARAMS: Dict[str, any],
     ):
         """
-        Config to initialize an elasticsearch client.
-        :param es_url. elasticsearch url or list of nodes url to be used for connection
+        Initializes a configuration class instance for an Elasticsearch client.
+
+        :param collection_name: Default name for the collection, defaults to None
+        :type collection_name: Optional[str], optional
+        :param dir: Path to the database directory, where the database is stored, defaults to None
+        :type dir: Optional[str], optional
+        :param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
+        :type es_url: Union[str, List[str]], optional
         :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
+        :type ES_EXTRA_PARAMS: Dict[str, Any], optional
         """
         # self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
         self.ES_URL = es_url
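For instance (the URL is hypothetical; extra keyword arguments are forwarded to the Elasticsearch client):

    from embedchain.config import ElasticsearchDBConfig

    es_config = ElasticsearchDBConfig(es_url="http://localhost:9200")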

View File

@@ -1,3 +1,4 @@
+from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.chunkers.docs_site import DocsSiteChunker
 from embedchain.chunkers.docx_file import DocxFileChunker
 from embedchain.chunkers.notion import NotionChunker
@@ -8,7 +9,9 @@ from embedchain.chunkers.text import TextChunker
 from embedchain.chunkers.web_page import WebPageChunker
 from embedchain.chunkers.youtube_video import YoutubeVideoChunker
 from embedchain.config import AddConfig
+from embedchain.config.AddConfig import ChunkerConfig, LoaderConfig
 from embedchain.helper_classes.json_serializable import JSONSerializable
+from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.csv import CsvLoader
 from embedchain.loaders.docs_site_loader import DocsSiteLoader
 from embedchain.loaders.docx_file import DocxFileLoader
@@ -29,16 +32,28 @@ class DataFormatter(JSONSerializable):
     """

     def __init__(self, data_type: DataType, config: AddConfig):
-        self.loader = self._get_loader(data_type, config.loader)
-        self.chunker = self._get_chunker(data_type, config.chunker)
+        """
+        Initialize a dataformatter, set data type and chunker based on datatype.
+
+        :param data_type: The type of the data to load and chunk.
+        :type data_type: DataType
+        :param config: AddConfig instance with nested loader and chunker config attributes.
+        :type config: AddConfig
+        """
+        self.loader = self._get_loader(data_type=data_type, config=config.loader)
+        self.chunker = self._get_chunker(data_type=data_type, config=config.chunker)

-    def _get_loader(self, data_type: DataType, config):
+    def _get_loader(self, data_type: DataType, config: LoaderConfig) -> BaseLoader:
         """
         Returns the appropriate data loader for the given data type.

         :param data_type: The type of the data to load.
-        :return: The loader for the given data type.
+        :type data_type: DataType
+        :param config: Config to initialize the loader with.
+        :type config: LoaderConfig
         :raises ValueError: If an unsupported data type is provided.
+        :return: The loader for the given data type.
+        :rtype: BaseLoader
         """
         loaders = {
             DataType.YOUTUBE_VIDEO: YoutubeVideoLoader,
@@ -53,8 +68,8 @@ class DataFormatter(JSONSerializable):
         }
         lazy_loaders = {DataType.NOTION}
         if data_type in loaders:
-            loader_class = loaders[data_type]
-            loader = loader_class()
+            loader_class: type = loaders[data_type]
+            loader: BaseLoader = loader_class()
             return loader
         elif data_type in lazy_loaders:
             if data_type == DataType.NOTION:
@@ -66,13 +81,16 @@ class DataFormatter(JSONSerializable):
         else:
             raise ValueError(f"Unsupported data type: {data_type}")

-    def _get_chunker(self, data_type: DataType, config):
-        """
-        Returns the appropriate chunker for the given data type.
+    def _get_chunker(self, data_type: DataType, config: ChunkerConfig) -> BaseChunker:
+        """Returns the appropriate chunker for the given data type.

         :param data_type: The type of the data to chunk.
-        :return: The chunker for the given data type.
+        :type data_type: DataType
+        :param config: Config to initialize the chunker with.
+        :type config: ChunkerConfig
         :raises ValueError: If an unsupported data type is provided.
+        :return: The chunker for the given data type.
+        :rtype: BaseChunker
         """
         chunker_classes = {
             DataType.YOUTUBE_VIDEO: YoutubeVideoChunker,
@@ -87,8 +105,8 @@ class DataFormatter(JSONSerializable):
             DataType.CSV: TableChunker,
         }
         if data_type in chunker_classes:
-            chunker_class = chunker_classes[data_type]
-            chunker = chunker_class(config)
+            chunker_class: type = chunker_classes[data_type]
+            chunker: BaseChunker = chunker_class(config)
             chunker.set_data_type(data_type)
             return chunker
         else:
View File

@@ -6,11 +6,10 @@ import os
 import threading
 import uuid
 from pathlib import Path
-from typing import Dict, Optional
+from typing import Any, Dict, List, Optional, Tuple

 import requests
 from dotenv import load_dotenv
-from langchain.docstore.document import Document
 from tenacity import retry, stop_after_attempt, wait_fixed

 from embedchain.chunkers.base_chunker import BaseChunker
@@ -46,8 +45,17 @@ class EmbedChain(JSONSerializable):
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.

-        :param config: BaseAppConfig instance to load as configuration.
-        :param system_prompt: Optional. System prompt string.
+        :param config: Configuration just for the app, not the db or llm or embedder.
+        :type config: BaseAppConfig
+        :param llm: Instance of the LLM you want to use.
+        :type llm: BaseLlm
+        :param db: Instance of the Database to use, defaults to None
+        :type db: BaseVectorDB, optional
+        :param embedder: instance of the embedder to use, defaults to None
+        :type embedder: BaseEmbedder, optional
+        :param system_prompt: System prompt to use in the llm query, defaults to None
+        :type system_prompt: Optional[str], optional
+        :raises ValueError: No database or embedder provided.
         """

         self.config = config
@@ -88,10 +96,13 @@ class EmbedChain(JSONSerializable):
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
         thread_telemetry.start()

-    def _load_or_generate_user_id(self):
+    def _load_or_generate_user_id(self) -> str:
         """
         Loads the user id from the config file if it exists, otherwise generates a new
         one and saves it to the config file.
+
+        :return: user id
+        :rtype: str
         """
         if not os.path.exists(CONFIG_DIR):
             os.makedirs(CONFIG_DIR)
@@ -110,9 +121,9 @@
     def add(
         self,
-        source,
+        source: Any,
         data_type: Optional[DataType] = None,
-        metadata: Optional[Dict] = None,
+        metadata: Optional[Dict[str, Any]] = None,
         config: Optional[AddConfig] = None,
     ):
         """
@@ -121,12 +132,17 @@
         and then stores the embedding to vector database.

         :param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
-        :param data_type: Optional. Automatically detected, but can be forced with this argument.
-        The type of the data to add.
-        :param metadata: Optional. Metadata associated with the data source.
-        :param config: Optional. The `AddConfig` instance to use as configuration
-        options.
+        :type source: Any
+        :param data_type: Automatically detected, but can be forced with this argument. The type of the data to add,
+        defaults to None
+        :type data_type: Optional[DataType], optional
+        :param metadata: Metadata associated with the data source, defaults to None
+        :type metadata: Optional[Dict[str, Any]], optional
+        :param config: The `AddConfig` instance to use as configuration options, defaults to None
+        :type config: Optional[AddConfig], optional
+        :raises ValueError: Invalid data type
         :return: source_id, a md5-hash of the source, in hexadecimal representation.
+        :rtype: str
         """
         if config is None:
             config = AddConfig()
@@ -177,39 +193,62 @@
         return source_id

-    def add_local(self, source, data_type=None, metadata=None, config: AddConfig = None):
+    def add_local(
+        self,
+        source: Any,
+        data_type: Optional[DataType] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        config: Optional[AddConfig] = None,
+    ):
         """
-        Warning:
-            This method is deprecated and will be removed in future versions. Use `add` instead.
         Adds the data from the given URL to the vector db.
         Loads the data, chunks it, create embedding for each chunk
         and then stores the embedding to vector database.

+        Warning:
+            This method is deprecated and will be removed in future versions. Use `add` instead.
+
         :param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
-        :param data_type: Optional. Automatically detected, but can be forced with this argument.
-        The type of the data to add.
-        :param metadata: Optional. Metadata associated with the data source.
-        :param config: Optional. The `AddConfig` instance to use as configuration
-        options.
-        :return: md5-hash of the source, in hexadecimal representation.
+        :type source: Any
+        :param data_type: Automatically detected, but can be forced with this argument. The type of the data to add,
+        defaults to None
+        :type data_type: Optional[DataType], optional
+        :param metadata: Metadata associated with the data source, defaults to None
+        :type metadata: Optional[Dict[str, Any]], optional
+        :param config: The `AddConfig` instance to use as configuration options, defaults to None
+        :type config: Optional[AddConfig], optional
+        :raises ValueError: Invalid data type
+        :return: source_id, a md5-hash of the source, in hexadecimal representation.
+        :rtype: str
         """
         logging.warning(
             "The `add_local` method is deprecated and will be removed in future versions. Please use the `add` method for both local and remote files."  # noqa: E501
         )
         return self.add(source=source, data_type=data_type, metadata=metadata, config=config)

-    def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None, source_id=None):
-        """
-        Loads the data from the given URL, chunks it, and adds it to database.
+    def load_and_embed(
+        self,
+        loader: BaseLoader,
+        chunker: BaseChunker,
+        src: Any,
+        metadata: Optional[Dict[str, Any]] = None,
+        source_id: Optional[str] = None,
+    ) -> Tuple[List[str], Dict[str, Any], List[str], int]:
+        """Loads the data from the given URL, chunks it, and adds it to database.

         :param loader: The loader to use to load the data.
+        :type loader: BaseLoader
         :param chunker: The chunker to use to chunk the data.
-        :param src: The data to be handled by the loader. Can be a URL for
-        remote sources or local content for local loaders.
-        :param metadata: Optional. Metadata associated with the data source.
-        :param source_id: Hexadecimal hash of the source.
+        :type chunker: BaseChunker
+        :param src: The data to be handled by the loader.
+        Can be a URL for remote sources or local content for local loaders.
+        :type src: Any
+        :param metadata: Metadata associated with the data source, defaults to None
+        :type metadata: Dict[str, Any], optional
+        :param source_id: Hexadecimal hash of the source, defaults to None
+        :type source_id: str, optional
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
+        :rtype: Tuple[List[str], Dict[str, Any], List[str], int]
         """
         embeddings_data = chunker.create_chunks(loader, src)
@@ -264,25 +303,19 @@
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks

-    def _format_result(self, results):
-        return [
-            (Document(page_content=result[0], metadata=result[1] or {}), result[2])
-            for result in zip(
-                results["documents"][0],
-                results["metadatas"][0],
-                results["distances"][0],
-            )
-        ]
-
-    def retrieve_from_database(self, input_query, config: Optional[BaseLlmConfig] = None, where=None):
+    def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query

         :param input_query: The query to use.
-        :param config: The query configuration.
-        :param where: Optional. A dictionary of key-value pairs to filter the database results.
-        :return: The content of the document that matched your query.
+        :type input_query: str
+        :param config: The query configuration, defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        :param where: A dictionary of key-value pairs to filter the database results, defaults to None
+        :type where: Optional[Dict[str, Any]], optional
+        :return: List of contents of the document that matched your query
+        :rtype: List[str]
         """
         query_config = config or self.llm.config
@@ -304,23 +337,24 @@
         return contents

-    def query(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
+    def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.

         :param input_query: The query to use.
-        :param config: Optional. The `LlmConfig` instance to use as configuration options.
-        This is used for one method call. To persistently use a config, declare it during app init.
-        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
-        the LLM. The purpose is to test the prompt, not the response.
-        You can use it to test your prompt, including the context provided
-        by the vector database's doc retrieval.
-        The only thing the dry run does not consider is the cut-off due to
-        the `max_tokens` parameter.
-        :param where: Optional. A dictionary of key-value pairs to filter the database results.
-        :return: The answer to the query.
+        :type input_query: str
+        :param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
+        To persistently use a config, declare it during app init, defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        :param dry_run: A dry run does everything except send the resulting prompt to
+        the LLM. The purpose is to test the prompt, not the response, defaults to False
+        :type dry_run: bool, optional
+        :param where: A dictionary of key-value pairs to filter the database results, defaults to None
+        :type where: Optional[Dict[str, str]], optional
+        :return: The answer to the query or the dry run result
+        :rtype: str
         """
         contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
         answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
@@ -331,24 +365,32 @@
         return answer

-    def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
+    def chat(
+        self,
+        input_query: str,
+        config: Optional[BaseLlmConfig] = None,
+        dry_run=False,
+        where: Optional[Dict[str, str]] = None,
+    ) -> str:
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
         LLM as context to get the answer.

         Maintains the whole conversation in memory.

         :param input_query: The query to use.
-        :param config: Optional. The `LlmConfig` instance to use as configuration options.
-        This is used for one method call. To persistently use a config, declare it during app init.
-        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
-        the LLM. The purpose is to test the prompt, not the response.
-        You can use it to test your prompt, including the context provided
-        by the vector database's doc retrieval.
-        The only thing the dry run does not consider is the cut-off due to
-        the `max_tokens` parameter.
-        :param where: Optional. A dictionary of key-value pairs to filter the database results.
-        :return: The answer to the query.
+        :type input_query: str
+        :param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
+        To persistently use a config, declare it during app init, defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        :param dry_run: A dry run does everything except send the resulting prompt to
+        the LLM. The purpose is to test the prompt, not the response, defaults to False
+        :type dry_run: bool, optional
+        :param where: A dictionary of key-value pairs to filter the database results, defaults to None
+        :type where: Optional[Dict[str, str]], optional
+        :return: The answer to the query or the dry run result
+        :rtype: str
         """
         contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
         answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
@@ -359,15 +401,18 @@
         return answer

-    def set_collection(self, collection_name):
+    def set_collection_name(self, name: str):
         """
-        Set the collection to use.
+        Set the name of the collection. A collection is an isolated space for vectors.

-        :param collection_name: The name of the collection to use.
+        Using `app.db.set_collection_name` method is preferred to this.
+
+        :param name: Name of the collection.
+        :type name: str
         """
-        self.db.set_collection_name(collection_name)
+        self.db.set_collection_name(name)
         # Create the collection if it does not exist
-        self.db._get_or_create_collection(collection_name)
+        self.db._get_or_create_collection(name)
         # TODO: Check whether it is necessary to assign to the `self.collection` attribute,
         # since the main purpose is the creation.
@@ -378,8 +423,9 @@
         DEPRECATED IN FAVOR OF `db.count()`

         :return: The number of embeddings.
+        :rtype: int
         """
-        logging.warning("DEPRECATION WARNING: Please use `db.count()` instead of `count()`.")
+        logging.warning("DEPRECATION WARNING: Please use `app.db.count()` instead of `app.count()`.")
         return self.db.count()

     def reset(self):
@@ -393,11 +439,14 @@
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
         thread_telemetry.start()

-        logging.warning("DEPRECATION WARNING: Please use `db.reset()` instead of `reset()`.")
+        logging.warning("DEPRECATION WARNING: Please use `app.db.reset()` instead of `App.reset()`.")
         self.db.reset()

     @retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
     def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):
+        """
+        Send telemetry event to the embedchain server. This is anonymous. It can be toggled off in `AppConfig`.
+        """
         if not self.config.collect_metrics:
             return
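End to end, the retyped `add`/`query` surface behaves like this sketch (the URL and filter values are hypothetical):

    from embedchain import App

    app = App()
    source_id = app.add("https://example.com/post", metadata={"tag": "demo"})

    # Dry run: builds the prompt (including retrieved context) without calling the LLM
    prompt = app.query("What is this post about?", dry_run=True)

    # Filter retrieval by document metadata
    answer = app.query("What is this post about?", where={"tag": "demo"})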

View File

@@ -19,7 +19,13 @@ class BaseEmbedder:
     To manually overwrite you can use this classes `set_...` methods.
     """

-    def __init__(self, config: Optional[BaseEmbedderConfig] = FileNotFoundError):
+    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
+        """
+        Initialize the embedder class.
+
+        :param config: embedder configuration option class, defaults to None
+        :type config: Optional[BaseEmbedderConfig], optional
+        """
         if config is None:
             self.config = BaseEmbedderConfig()
         else:
@@ -27,17 +33,35 @@
         self.vector_dimension: int

     def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
+        """
+        Set or overwrite the embedding function to be used by the database to store and retrieve documents.
+
+        :param embedding_fn: Function to be used to generate embeddings.
+        :type embedding_fn: Callable[[list[str]], list[str]]
+        :raises ValueError: Embedding function is not callable.
+        """
         if not hasattr(embedding_fn, "__call__"):
             raise ValueError("Embedding function is not a function")
         self.embedding_fn = embedding_fn

     def set_vector_dimension(self, vector_dimension: int):
+        """
+        Set or overwrite the vector dimension size
+
+        :param vector_dimension: vector dimension size
+        :type vector_dimension: int
+        """
         self.vector_dimension = vector_dimension

     @staticmethod
     def _langchain_default_concept(embeddings: Any):
         """
         Langchain's default function layout for embeddings.
+
+        :param embeddings: Langchain embeddings
+        :type embeddings: Any
+        :return: embedding function
+        :rtype: Callable
         """

         def embed_function(texts: Documents) -> Embeddings:
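A toy sketch of the `set_...` hooks (the import path is an assumption; the lambda stands in for a real embedding function):

    from embedchain.embedder.base_embedder import BaseEmbedder  # assumed path

    embedder = BaseEmbedder()
    embedder.set_vector_dimension(384)
    embedder.set_embedding_fn(lambda texts: [[0.0] * 384 for _ in texts])  # toy embeddings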

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import List, Optional from typing import Any, Dict, Generator, List, Optional
from langchain.memory import ConversationBufferMemory from langchain.memory import ConversationBufferMemory
from langchain.schema import BaseMessage from langchain.schema import BaseMessage
@@ -13,6 +13,11 @@ from embedchain.helper_classes.json_serializable import JSONSerializable
class BaseLlm(JSONSerializable): class BaseLlm(JSONSerializable):
def __init__(self, config: Optional[BaseLlmConfig] = None): def __init__(self, config: Optional[BaseLlmConfig] = None):
"""Initialize a base LLM class
:param config: LLM configuration option class, defaults to None
:type config: Optional[BaseLlmConfig], optional
"""
if config is None: if config is None:
self.config = BaseLlmConfig() self.config = BaseLlmConfig()
else: else:
@@ -21,7 +26,7 @@ class BaseLlm(JSONSerializable):
self.memory = ConversationBufferMemory() self.memory = ConversationBufferMemory()
self.is_docs_site_instance = False self.is_docs_site_instance = False
self.online = False self.online = False
self.history: any = None self.history: Any = None
def get_llm_model_answer(self): def get_llm_model_answer(self):
""" """
@@ -29,24 +34,33 @@ class BaseLlm(JSONSerializable):
""" """
raise NotImplementedError raise NotImplementedError
def set_history(self, history: any): def set_history(self, history: Any):
"""
Provide your own history.
Especially interesting for the query method, which does not internally manage conversation history.
:param history: History to set
:type history: Any
"""
self.history = history self.history = history
def update_history(self): def update_history(self):
"""Update class history attribute with history in memory (for chat method)"""
chat_history = self.memory.load_memory_variables({})["history"] chat_history = self.memory.load_memory_variables({})["history"]
if chat_history: if chat_history:
self.set_history(chat_history) self.set_history(chat_history)
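Since `query` does not manage memory on its own, a caller can inject prior turns before querying; a hedged sketch (the expected history format is not constrained by the `Any` hint):
llm.set_history("Human: What's for lunch?\nAI: Pasta.")  # stored as-is on self.history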
def generate_prompt(self, input_query, contexts, **kwargs): def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
""" """
Generates a prompt based on the given query and context, ready to be Generates a prompt based on the given query and context, ready to be
passed to an LLM passed to an LLM
:param input_query: The query to use. :param input_query: The query to use.
:type input_query: str
:param contexts: List of similar documents to the query used as context. :param contexts: List of similar documents to the query used as context.
:param config: Optional. The `QueryConfig` instance to use as :type contexts: List[str]
configuration options.
:return: The prompt :return: The prompt
:rtype: str
""" """
context_string = (" | ").join(contexts) context_string = (" | ").join(contexts)
web_search_result = kwargs.get("web_search_result", "") web_search_result = kwargs.get("web_search_result", "")
@@ -73,36 +87,67 @@ class BaseLlm(JSONSerializable):
) )
return prompt return prompt
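A worked example of what the joining step produces (illustrative values only):
contexts = ["Alice likes apples.", "Bob likes pears."]
context_string = " | ".join(contexts)
# -> "Alice likes apples. | Bob likes pears."
# generate_prompt substitutes this string, the input query, and any web
# search result into the configured prompt template.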
def _append_search_and_context(self, context, web_search_result): def _append_search_and_context(self, context: str, web_search_result: str) -> str:
"""Append web search context to existing context
:param context: Existing context
:type context: str
:param web_search_result: Web search result
:type web_search_result: str
:return: Concatenated web search result
:rtype: str
"""
return f"{context}\nWeb Search Result: {web_search_result}" return f"{context}\nWeb Search Result: {web_search_result}"
def get_answer_from_llm(self, prompt): def get_answer_from_llm(self, prompt: str):
""" """
Gets an answer based on the given query and context by passing it Gets an answer based on the given query and context by passing it
to an LLM. to an LLM.
:param query: The query to use. :param prompt: Prompt to be passed to the LLM.
:param context: Similar documents to the query used as context. :type prompt: str
:return: The answer. :return: The answer.
:rtype: str
""" """
return self.get_llm_model_answer(prompt) return self.get_llm_model_answer(prompt)
def access_search_and_get_results(self, input_query): def access_search_and_get_results(self, input_query: str):
"""
Search the internet for additional context
:param input_query: search query
:type input_query: str
:return: Search results
:rtype: str
"""
from langchain.tools import DuckDuckGoSearchRun from langchain.tools import DuckDuckGoSearchRun
search = DuckDuckGoSearchRun() search = DuckDuckGoSearchRun()
logging.info(f"Access search to get answers for {input_query}") logging.info(f"Access search to get answers for {input_query}")
return search.run(input_query) return search.run(input_query)
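How this feeds into answering is implied rather than shown in this diff; presumably the `online` flag initialized in `__init__` is what routes a query through this search (an assumption, not confirmed here):
app.llm.online = True  # assumed wiring: augment context with DuckDuckGo results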
def _stream_query_response(self, answer): def _stream_query_response(self, answer: Any) -> Generator[Any, Any, None]:
"""Generator to be used as streaming response
:param answer: Answer chunk from llm
:type answer: Any
:yield: Answer chunk from llm
:rtype: Generator[Any, Any, None]
"""
streamed_answer = "" streamed_answer = ""
for chunk in answer: for chunk in answer:
streamed_answer = streamed_answer + chunk streamed_answer = streamed_answer + chunk
yield chunk yield chunk
logging.info(f"Answer: {streamed_answer}") logging.info(f"Answer: {streamed_answer}")
def _stream_chat_response(self, answer): def _stream_chat_response(self, answer: Any) -> Generator[Any, Any, None]:
"""Generator to be used as streaming response
:param answer: Answer chunk from llm
:type answer: Any
:yield: Answer chunk from llm
:rtype: Generator[Any, Any, None]
"""
streamed_answer = "" streamed_answer = ""
for chunk in answer: for chunk in answer:
streamed_answer = streamed_answer + chunk streamed_answer = streamed_answer + chunk
@@ -110,23 +155,24 @@ class BaseLlm(JSONSerializable):
self.memory.chat_memory.add_ai_message(streamed_answer) self.memory.chat_memory.add_ai_message(streamed_answer)
logging.info(f"Answer: {streamed_answer}") logging.info(f"Answer: {streamed_answer}")
def query(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None): def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
""" """
Queries the vector database based on the given input query. Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an Gets relevant doc based on the query and then passes it to an
LLM as context to get the answer. LLM as context to get the answer.
:param input_query: The query to use. :param input_query: The query to use.
:param config: Optional. The `LlmConfig` instance to use as configuration options. :type input_query: str
This is used for one method call. To persistently use a config, declare it during app init. :param contexts: Embeddings retrieved from the database to be used as context.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to :type contexts: List[str]
the LLM. The purpose is to test the prompt, not the response. :param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
You can use it to test your prompt, including the context provided To persistently use a config, declare it during app init., defaults to None
by the vector database's doc retrieval. :type config: Optional[BaseLlmConfig], optional
The only thing the dry run does not consider is the cut-off due to :param dry_run: A dry run does everything except send the resulting prompt to
the `max_tokens` parameter. the LLM. The purpose is to test the prompt, not the response., defaults to False
:param where: Optional. A dictionary of key-value pairs to filter the database results. :type dry_run: bool, optional
:return: The answer to the query. :return: The answer to the query or the dry run result
:rtype: str
""" """
query_config = config or self.config query_config = config or self.config
@@ -150,24 +196,26 @@ class BaseLlm(JSONSerializable):
else: else:
return self._stream_query_response(answer) return self._stream_query_response(answer)
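A dry run is the cheapest way to inspect what the retriever injected; a sketch, assuming the app-level `query` forwards `dry_run` down to this method:
prompt = app.query("What is on the menu?", dry_run=True)
print(prompt)  # full rendered prompt, including retrieved context; no LLM call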
def chat(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None): def chat(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
""" """
Queries the vector database on the given input query. Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an Gets relevant doc based on the query and then passes it to an
LLM as context to get the answer. LLM as context to get the answer.
Maintains the whole conversation in memory. Maintains the whole conversation in memory.
:param input_query: The query to use. :param input_query: The query to use.
:param config: Optional. The `LlmConfig` instance to use as configuration options. :type input_query: str
This is used for one method call. To persistently use a config, declare it during app init. :param contexts: Embeddings retrieved from the database to be used as context.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to :type contexts: List[str]
the LLM. The purpose is to test the prompt, not the response. :param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
You can use it to test your prompt, including the context provided To persistently use a config, declare it during app init., defaults to None
by the vector database's doc retrieval. :type config: Optional[BaseLlmConfig], optional
The only thing the dry run does not consider is the cut-off due to :param dry_run: A dry run does everything except send the resulting prompt to
the `max_tokens` parameter. the LLM. The purpose is to test the prompt, not the response., defaults to False
:param where: Optional. A dictionary of key-value pairs to filter the database results. :type dry_run: bool, optional
:return: The answer to the query. :return: The answer to the query or the dry run result
:rtype: str
""" """
query_config = config or self.config query_config = config or self.config
@@ -205,6 +253,16 @@ class BaseLlm(JSONSerializable):
@staticmethod @staticmethod
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]: def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
"""
Construct a list of langchain messages
:param prompt: User prompt
:type prompt: str
:param system_prompt: System prompt, defaults to None
:type system_prompt: Optional[str], optional
:return: List of messages
:rtype: List[BaseMessage]
"""
from langchain.schema import HumanMessage, SystemMessage from langchain.schema import HumanMessage, SystemMessage
messages = [] messages = []
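The hunk ends right after `messages = []`; given the imports, the rest of the helper presumably looks like this (a reconstruction, not the committed code):
if system_prompt:
    messages.append(SystemMessage(content=system_prompt))
messages.append(HumanMessage(content=prompt))
return messages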
@@ -7,6 +7,11 @@ class BaseVectorDB(JSONSerializable):
"""Base class for vector database.""" """Base class for vector database."""
def __init__(self, config: BaseVectorDbConfig): def __init__(self, config: BaseVectorDbConfig):
"""Initialize the database. Save the config and client as an attribute.
:param config: Database configuration class instance.
:type config: BaseVectorDbConfig
"""
self.client = self._get_or_create_db() self.client = self._get_or_create_db()
self.config: BaseVectorDbConfig = config self.config: BaseVectorDbConfig = config
@@ -23,25 +28,50 @@ class BaseVectorDB(JSONSerializable):
raise NotImplementedError raise NotImplementedError
def _get_or_create_collection(self): def _get_or_create_collection(self):
"""Get or create a named collection."""
raise NotImplementedError raise NotImplementedError
def _set_embedder(self, embedder: BaseEmbedder): def _set_embedder(self, embedder: BaseEmbedder):
"""
The database sometimes needs to access the embedder; this method sets it persistently.
:param embedder: Embedder to be set as the embedder for this database.
:type embedder: BaseEmbedder
"""
self.embedder = embedder self.embedder = embedder
def get(self): def get(self):
"""Get database embeddings by id."""
raise NotImplementedError raise NotImplementedError
def add(self): def add(self):
"""Add to database"""
raise NotImplementedError raise NotImplementedError
def query(self): def query(self):
"""Query contents from vector data base based on vector similarity"""
raise NotImplementedError raise NotImplementedError
def count(self): def count(self) -> int:
"""
Count number of documents/chunks embedded in the database.
:return: number of documents
:rtype: int
"""
raise NotImplementedError raise NotImplementedError
def reset(self): def reset(self):
"""
Resets the database. Deletes all embeddings irreversibly.
"""
raise NotImplementedError raise NotImplementedError
def set_collection_name(self, name: str): def set_collection_name(self, name: str):
"""
Set the name of the collection. A collection is an isolated space for vectors.
:param name: Name of the collection.
:type name: str
"""
raise NotImplementedError raise NotImplementedError
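A concrete database fills in these hooks; a deliberately tiny in-memory sketch of the contract (config handling elided, names hypothetical):
from embedchain.vectordb.base_vector_db import BaseVectorDB

class InMemoryDB(BaseVectorDB):
    """Toy subclass illustrating which methods a real vector DB overrides."""

    def _get_or_create_db(self):
        return {}  # the "client" is just a dict here

    def add(self, documents, metadatas, ids):
        for doc, meta, id_ in zip(documents, metadatas, ids):
            self.client[id_] = (doc, meta)

    def count(self) -> int:
        return len(self.client)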
@@ -1,6 +1,7 @@
import logging import logging
from typing import Any, Dict, List, Optional from typing import Dict, List, Optional
from chromadb import Collection, QueryResult
from langchain.docstore.document import Document from langchain.docstore.document import Document
from embedchain.config import ChromaDbConfig from embedchain.config import ChromaDbConfig
@@ -25,6 +26,11 @@ class ChromaDB(BaseVectorDB):
"""Vector database using ChromaDB.""" """Vector database using ChromaDB."""
def __init__(self, config: Optional[ChromaDbConfig] = None): def __init__(self, config: Optional[ChromaDbConfig] = None):
"""Initialize a new ChromaDB instance
:param config: Configuration options for Chroma, defaults to None
:type config: Optional[ChromaDbConfig], optional
"""
if config: if config:
self.config = config self.config = config
else: else:
@@ -60,11 +66,19 @@ class ChromaDB(BaseVectorDB):
self._get_or_create_collection(self.config.collection_name) self._get_or_create_collection(self.config.collection_name)
def _get_or_create_db(self): def _get_or_create_db(self):
"""Get or create the database.""" """Called during initialization"""
return self.client return self.client
def _get_or_create_collection(self, name): def _get_or_create_collection(self, name: str) -> Collection:
"""Get or create the collection.""" """
Get or create a named collection.
:param name: Name of the collection
:type name: str
:raises ValueError: No embedder configured.
:return: Created collection
:rtype: Collection
"""
if not hasattr(self, "embedder") or not self.embedder: if not hasattr(self, "embedder") or not self.embedder:
raise ValueError("Cannot create a Chroma database collection without an embedder.") raise ValueError("Cannot create a Chroma database collection without an embedder.")
self.collection = self.client.get_or_create_collection( self.collection = self.client.get_or_create_collection(
@@ -76,8 +90,13 @@ class ChromaDB(BaseVectorDB):
def get(self, ids: List[str], where: Dict[str, any]) -> List[str]: def get(self, ids: List[str], where: Dict[str, any]) -> List[str]:
""" """
Get existing doc ids present in vector database Get existing doc ids present in vector database
:param ids: list of doc ids to check for existance
:param ids: list of doc ids to check for existence
:type ids: List[str]
:param where: Optional. to filter data :param where: Optional. to filter data
:type where: Dict[str, any]
:return: Existing documents.
:rtype: List[str]
""" """
existing_docs = self.collection.get( existing_docs = self.collection.get(
ids=ids, ids=ids,
@@ -86,16 +105,28 @@ class ChromaDB(BaseVectorDB):
return set(existing_docs["ids"]) return set(existing_docs["ids"])
def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any: def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
""" """
add data in vector database Add vectors to chroma database
:param documents: list of texts to add
:param metadatas: list of metadata associated with docs :param documents: Documents
:param ids: ids of docs :type documents: List[str]
:param metadatas: Metadatas
:type metadatas: List[object]
:param ids: ids
:type ids: List[str]
""" """
self.collection.add(documents=documents, metadatas=metadatas, ids=ids) self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
def _format_result(self, results): def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
"""
Format Chroma results
:param results: ChromaDB query results to format.
:type results: QueryResult
:return: Formatted results
:rtype: list[tuple[Document, float]]
"""
return [ return [
(Document(page_content=result[0], metadata=result[1] or {}), result[2]) (Document(page_content=result[0], metadata=result[1] or {}), result[2])
for result in zip( for result in zip(
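The `zip(` call is truncated here; Chroma's `QueryResult` stores parallel lists per query, so the arguments are presumably the first row of each column (a reconstruction):
        results["documents"][0],
        results["metadatas"][0],
        results["distances"][0],
    )
]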
@@ -107,11 +138,17 @@ class ChromaDB(BaseVectorDB):
def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]: def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
""" """
query contents from vector data base based on vector similarity Query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
:type input_query: List[str]
:param n_results: no of similar documents to fetch from database :param n_results: no of similar documents to fetch from database
:param where: Optional. to filter data :type n_results: int
:param where: to filter data
:type where: Dict[str, any]
:raises InvalidDimensionException: Dimensions do not match.
:return: The content of the document that matched your query. :return: The content of the document that matched your query.
:rtype: List[str]
""" """
try: try:
result = self.collection.query( result = self.collection.query(
@@ -132,21 +169,27 @@ class ChromaDB(BaseVectorDB):
return contents return contents
def set_collection_name(self, name: str): def set_collection_name(self, name: str):
"""
Set the name of the collection. A collection is an isolated space for vectors.
:param name: Name of the collection.
:type name: str
"""
self.config.collection_name = name self.config.collection_name = name
self._get_or_create_collection(self.config.collection_name) self._get_or_create_collection(self.config.collection_name)
def count(self) -> int: def count(self) -> int:
""" """
Count the number of embeddings. Count number of documents/chunks embedded in the database.
:return: The number of embeddings. :return: number of documents
:rtype: int
""" """
return self.collection.count() return self.collection.count()
def reset(self): def reset(self):
""" """
Resets the database. Deletes all embeddings irreversibly. Resets the database. Deletes all embeddings irreversibly.
`App` does not have to be reinitialized after using this method.
""" """
# Delete all data from the database # Delete all data from the database
try: try:
@@ -1,4 +1,4 @@
from typing import Any, Dict, List from typing import Dict, List, Optional, Set
try: try:
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
@@ -15,16 +15,23 @@ from embedchain.vectordb.base_vector_db import BaseVectorDB
@register_deserializable @register_deserializable
class ElasticsearchDB(BaseVectorDB): class ElasticsearchDB(BaseVectorDB):
"""
Elasticsearch as vector database
"""
def __init__( def __init__(
self, self,
config: ElasticsearchDBConfig = None, config: Optional[ElasticsearchDBConfig] = None,
es_config: ElasticsearchDBConfig = None, # Backwards compatibility es_config: Optional[ElasticsearchDBConfig] = None, # Backwards compatibility
): ):
""" """Elasticsearch as vector database.
Elasticsearch as vector database
:param es_config. elasticsearch database config to be used for connection :param config: Elasticsearch database config, defaults to None
:param embedding_fn: Function to generate embedding vectors. :type config: ElasticsearchDBConfig, optional
:param vector_dim: Vector dimension generated by embedding fn :param es_config: `es_config` is supported as an alias for `config` (for backwards compatibility),
defaults to None
:type es_config: ElasticsearchDBConfig, optional
:raises ValueError: No config provided
""" """
if config is None and es_config is None: if config is None and es_config is None:
raise ValueError("ElasticsearchDBConfig is required") raise ValueError("ElasticsearchDBConfig is required")
@@ -53,16 +60,22 @@ class ElasticsearchDB(BaseVectorDB):
self.client.indices.create(index=es_index, body=index_settings) self.client.indices.create(index=es_index, body=index_settings)
def _get_or_create_db(self): def _get_or_create_db(self):
"""Called during initialization"""
return self.client return self.client
def _get_or_create_collection(self, name): def _get_or_create_collection(self, name):
"""Note: nothing to return here. Discuss later""" """Note: nothing to return here. Discuss later"""
def get(self, ids: List[str], where: Dict[str, any]) -> List[str]: def get(self, ids: List[str], where: Dict[str, any]) -> Set[str]:
""" """
Get existing doc ids present in vector database Get existing doc ids present in vector database
:param ids: list of doc ids to check for existance
:param where: Optional. to filter data :param ids: list of doc ids to check for existence
:type ids: List[str]
:param where: to filter data
:type where: Dict[str, any]
:return: ids
:rtype: Set[str]
""" """
query = {"bool": {"must": [{"ids": {"values": ids}}]}} query = {"bool": {"must": [{"ids": {"values": ids}}]}}
if "app_id" in where: if "app_id" in where:
@@ -73,13 +86,17 @@ class ElasticsearchDB(BaseVectorDB):
ids = [doc["_id"] for doc in docs] ids = [doc["_id"] for doc in docs]
return set(ids) return set(ids)
def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any: def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
""" """add data in vector database
add data in vector database
:param documents: list of texts to add :param documents: list of texts to add
:type documents: List[str]
:param metadatas: list of metadata associated with docs :param metadatas: list of metadata associated with docs
:type metadatas: List[object]
:param ids: ids of docs :param ids: ids of docs
:type ids: List[str]
""" """
docs = [] docs = []
embeddings = self.embedder.embedding_fn(documents) embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings): for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
@@ -92,14 +109,19 @@ class ElasticsearchDB(BaseVectorDB):
) )
bulk(self.client, docs) bulk(self.client, docs)
self.client.indices.refresh(index=self._get_index()) self.client.indices.refresh(index=self._get_index())
return
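Usage mirrors the Chroma implementation; a minimal sketch with illustrative values (documents are embedded via the configured embedder, bulk-indexed, then the index is refreshed so reads see them immediately):
db.add(
    documents=["Alice likes apples."],
    metadatas=[{"app_id": "demo"}],
    ids=["doc-0"],
)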
def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]: def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
:type input_query: List[str]
:param n_results: no of similar documents to fetch from database :param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: Optional. to filter data :param where: Optional. to filter data
:type where: Dict[str, any]
:return: Database contents that are the result of the query
:rtype: List[str]
""" """
input_query_vector = self.embedder.embedding_fn(input_query) input_query_vector = self.embedder.embedding_fn(input_query)
query_vector = input_query_vector[0] query_vector = input_query_vector[0]
@@ -122,21 +144,41 @@ class ElasticsearchDB(BaseVectorDB):
return contents return contents
def set_collection_name(self, name: str): def set_collection_name(self, name: str):
"""
Set the name of the collection. A collection is an isolated space for vectors.
:param name: Name of the collection.
:type name: str
"""
self.config.collection_name = name self.config.collection_name = name
def count(self) -> int: def count(self) -> int:
"""
Count number of documents/chunks embedded in the database.
:return: number of documents
:rtype: int
"""
query = {"match_all": {}} query = {"match_all": {}}
response = self.client.count(index=self._get_index(), query=query) response = self.client.count(index=self._get_index(), query=query)
doc_count = response["count"] doc_count = response["count"]
return doc_count return doc_count
def reset(self): def reset(self):
"""
Resets the database. Deletes all embeddings irreversibly.
"""
# Delete all data from the database # Delete all data from the database
if self.client.indices.exists(index=self._get_index()): if self.client.indices.exists(index=self._get_index()):
# delete index in Es # delete index in Es
self.client.indices.delete(index=self._get_index()) self.client.indices.delete(index=self._get_index())
def _get_index(self): def _get_index(self) -> str:
"""Get the Elasticsearch index for a collection
:return: Elasticsearch index
:rtype: str
"""
# NOTE: The method is preferred to an attribute, because if collection name changes, # NOTE: The method is preferred to an attribute, because if collection name changes,
# it's always up-to-date. # it's always up-to-date.
return f"{self.config.collection_name}_{self.embedder.vector_dimension}" return f"{self.config.collection_name}_{self.embedder.vector_dimension}"
@@ -121,9 +121,9 @@ class TestChromaDbDuplicateHandling:
self.app_with_settings.reset() self.app_with_settings.reset()
app = App(config=AppConfig(collect_metrics=False)) app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection_1") app.set_collection_name("test_collection_1")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.set_collection("test_collection_2") app.set_collection_name("test_collection_2")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" not in caplog.text # not assert "Insert of existing embedding ID: 0" not in caplog.text # not
assert "Add of existing embedding ID: 0" not in caplog.text # not assert "Add of existing embedding ID: 0" not in caplog.text # not
@@ -149,16 +149,16 @@ class TestChromaDbCollection(unittest.TestCase):
""" """
config = AppConfig(collect_metrics=False) config = AppConfig(collect_metrics=False)
app = App(config=config) app = App(config=config)
app.set_collection(collection_name="test_collection") app.set_collection_name(name="test_collection")
self.assertEqual(app.db.collection.name, "test_collection") self.assertEqual(app.db.collection.name, "test_collection")
def test_set_collection(self): def test_set_collection_name(self):
""" """
Test if the `App` collection is correctly switched using the `set_collection` method. Test if the `App` collection is correctly switched using the `set_collection_name` method.
""" """
app = App(config=AppConfig(collect_metrics=False)) app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection") app.set_collection_name("test_collection")
self.assertEqual(app.db.collection.name, "test_collection") self.assertEqual(app.db.collection.name, "test_collection")
@@ -170,7 +170,7 @@ class TestChromaDbCollection(unittest.TestCase):
self.app_with_settings.reset() self.app_with_settings.reset()
app = App(config=AppConfig(collect_metrics=False)) app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection_1") app.set_collection_name("test_collection_1")
# Collection should be empty when created # Collection should be empty when created
self.assertEqual(app.count(), 0) self.assertEqual(app.count(), 0)
@@ -178,13 +178,13 @@ class TestChromaDbCollection(unittest.TestCase):
# After adding, should contain one item # After adding, should contain one item
self.assertEqual(app.count(), 1) self.assertEqual(app.count(), 1)
app.set_collection("test_collection_2") app.set_collection_name("test_collection_2")
# New collection is empty # New collection is empty
self.assertEqual(app.count(), 0) self.assertEqual(app.count(), 0)
# Adding to new collection should not affect existing collection # Adding to new collection should not affect existing collection
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
app.set_collection("test_collection_1") app.set_collection_name("test_collection_1")
# Should still be 1, not 2. # Should still be 1, not 2.
self.assertEqual(app.count(), 1) self.assertEqual(app.count(), 1)
@@ -196,12 +196,12 @@ class TestChromaDbCollection(unittest.TestCase):
self.app_with_settings.reset() self.app_with_settings.reset()
app = App(config=AppConfig(collect_metrics=False)) app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection_1") app.set_collection_name("test_collection_1")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
del app del app
app = App(config=AppConfig(collect_metrics=False)) app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection_1") app.set_collection_name("test_collection_1")
self.assertEqual(app.count(), 1) self.assertEqual(app.count(), 1)
def test_parallel_collections(self): def test_parallel_collections(self):
@@ -227,9 +227,9 @@ class TestChromaDbCollection(unittest.TestCase):
app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
# Swap names and test # Swap names and test
app1.set_collection("test_collection_2") app1.set_collection_name("test_collection_2")
self.assertEqual(app1.count(), 1) self.assertEqual(app1.count(), 1)
app2.set_collection("test_collection_1") app2.set_collection_name("test_collection_1")
self.assertEqual(app2.count(), 3) self.assertEqual(app2.count(), 3)
def test_ids_share_collections(self): def test_ids_share_collections(self):
@@ -241,9 +241,9 @@ class TestChromaDbCollection(unittest.TestCase):
# Create two apps # Create two apps
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False)) app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
app1.set_collection("one_collection") app1.set_collection_name("one_collection")
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False)) app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
app2.set_collection("one_collection") app2.set_collection_name("one_collection")
# Add data # Add data
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"]) app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
@@ -263,13 +263,13 @@ class TestChromaDbCollection(unittest.TestCase):
# Create four apps. # Create four apps.
# app1, which we are about to reset, shares a collection with one, an id with another, and nothing with the last. # app1, which we are about to reset, shares a collection with one, an id with another, and nothing with the last.
app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config) app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config)
app1.set_collection("one_collection") app1.set_collection_name("one_collection")
app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False)) app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
app2.set_collection("one_collection") app2.set_collection_name("one_collection")
app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False)) app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
app3.set_collection("three_collection") app3.set_collection_name("three_collection")
app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False)) app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False))
app4.set_collection("four_collection") app4.set_collection_name("four_collection")
# Each one of them get data # Each one of them get data
app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"]) app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])