From 344e7470f6d5ea59230fa162f7877dfd2e296a01 Mon Sep 17 00:00:00 2001 From: cachho Date: Tue, 5 Sep 2023 10:12:58 +0200 Subject: [PATCH] refactor: classes and configs (#528) --- docs/advanced/app_types.mdx | 24 +- docs/advanced/configuration.mdx | 53 +++-- docs/advanced/query_configuration.mdx | 2 +- embedchain/__init__.py | 2 + embedchain/apps/App.py | 54 ++--- embedchain/apps/CustomApp.py | 177 ++++---------- embedchain/apps/Llama2App.py | 34 ++- embedchain/apps/OpenSourceApp.py | 58 ++--- embedchain/apps/PersonApp.py | 17 +- embedchain/bots/base.py | 17 +- embedchain/bots/discord.py | 3 + embedchain/bots/poe.py | 6 +- embedchain/config/ChatConfig.py | 92 -------- embedchain/config/__init__.py | 22 +- embedchain/config/apps/AppConfig.py | 43 +--- embedchain/config/apps/BaseAppConfig.py | 77 ++---- embedchain/config/apps/CustomAppConfig.py | 114 +-------- embedchain/config/apps/OpenSourceAppConfig.py | 36 +-- .../config/embedder/BaseEmbedderConfig.py | 10 + embedchain/config/embedder/__init__.py | 0 embedchain/config/llm/__init__.py | 0 .../base_llm_config.py} | 40 ++-- .../config/vectordbs/BaseVectorDbConfig.py | 17 ++ embedchain/config/vectordbs/ChromaDbConfig.py | 21 ++ .../config/vectordbs/ElasticsearchDBConfig.py | 28 ++- embedchain/embedchain.py | 219 ++++++------------ embedchain/embedder/__init__.py | 0 embedchain/embedder/base_embedder.py | 45 ++++ embedchain/embedder/gpt4all_embedder.py | 21 ++ embedchain/embedder/huggingface_embedder.py | 19 ++ embedchain/embedder/openai_embedder.py | 40 ++++ embedchain/embedder/vertexai_embedder.py | 19 ++ embedchain/llm/__init__.py | 0 embedchain/llm/antrophic_llm.py | 29 +++ embedchain/llm/azure_openai_llm.py | 39 ++++ embedchain/llm/base_llm.py | 214 +++++++++++++++++ embedchain/llm/gpt4all_llm.py | 47 ++++ embedchain/llm/llama2_llm.py | 27 +++ embedchain/llm/openai_llm.py | 43 ++++ embedchain/llm/vertex_ai_llm.py | 29 +++ embedchain/vectordb/base_vector_db.py | 19 +- embedchain/vectordb/chroma_db.py | 92 +++++--- embedchain/vectordb/elasticsearch_db.py | 58 ++--- tests/embedchain/test_embedchain.py | 16 +- .../helper_classes/test_json_serializable.py | 3 +- tests/{embedchain => llm}/test_chat.py | 64 +++-- .../test_generate_prompt.py | 18 +- tests/{embedchain => llm}/test_query.py | 68 +++--- tests/vectordb/test_chroma_db.py | 124 +++++----- tests/vectordb/test_elasticsearch_db.py | 18 +- 50 files changed, 1221 insertions(+), 997 deletions(-) delete mode 100644 embedchain/config/ChatConfig.py create mode 100644 embedchain/config/embedder/BaseEmbedderConfig.py create mode 100644 embedchain/config/embedder/__init__.py create mode 100644 embedchain/config/llm/__init__.py rename embedchain/config/{QueryConfig.py => llm/base_llm_config.py} (78%) create mode 100644 embedchain/config/vectordbs/BaseVectorDbConfig.py create mode 100644 embedchain/config/vectordbs/ChromaDbConfig.py create mode 100644 embedchain/embedder/__init__.py create mode 100644 embedchain/embedder/base_embedder.py create mode 100644 embedchain/embedder/gpt4all_embedder.py create mode 100644 embedchain/embedder/huggingface_embedder.py create mode 100644 embedchain/embedder/openai_embedder.py create mode 100644 embedchain/embedder/vertexai_embedder.py create mode 100644 embedchain/llm/__init__.py create mode 100644 embedchain/llm/antrophic_llm.py create mode 100644 embedchain/llm/azure_openai_llm.py create mode 100644 embedchain/llm/base_llm.py create mode 100644 embedchain/llm/gpt4all_llm.py create mode 100644 embedchain/llm/llama2_llm.py create mode 100644 
embedchain/llm/openai_llm.py create mode 100644 embedchain/llm/vertex_ai_llm.py rename tests/{embedchain => llm}/test_chat.py (62%) rename tests/{embedchain => llm}/test_generate_prompt.py (79%) rename tests/{embedchain => llm}/test_query.py (69%) diff --git a/docs/advanced/app_types.mdx b/docs/advanced/app_types.mdx index 0bc34e47..147d769f 100644 --- a/docs/advanced/app_types.mdx +++ b/docs/advanced/app_types.mdx @@ -69,16 +69,27 @@ app = OpenSourceApp() ```python from embedchain import CustomApp -from embedchain.config import CustomAppConfig -from embedchain.models import Providers, EmbeddingFunctions +from embedchain.config import (CustomAppConfig, ElasticsearchDBConfig, + EmbedderConfig, LlmConfig) +from embedchain.embedder.vertexai_embedder import VertexAiEmbedder +from embedchain.llm.vertex_ai_llm import VertexAiLlm +from embedchain.models import EmbeddingFunctions, Providers +from embedchain.vectordb.elasticsearch_db import Elasticsearch -config = CustomAppConfig(embedding_fn=EmbeddingFunctions.OPENAI, provider=Providers.OPENAI) -app = CustomApp(config) +# short +app = CustomApp(llm=VertexAiLlm(), db=Elasticsearch(), embedder=VertexAiEmbedder()) +# with configs +app = CustomApp( + config=CustomAppConfig(log_level="INFO"), + llm=VertexAiLlm(config=LlmConfig(number_documents=5)), + db=Elasticsearch(config=ElasticsearchDBConfig(es_url="...")), + embedder=VertexAiEmbedder(config=EmbedderConfig()), +) ``` - `CustomApp` is not opinionated. -- Configuration required. It's for advanced users who want to mix and match different embedding models and LLMs. Configuration required. -- while it's doing that, it's still providing abstractions through `Providers`. +- Configuration required. It's for advanced users who want to mix and match different embedding models and LLMs. +- while it's doing that, it's still providing abstractions by allowing you to import Classes from `embedchain.llm`, `embedchain.vectordb`, and `embedchain.embedder`. - paid and free/open source providers included. - Once you have imported and instantiated the app, every functionality from here onwards is the same for either type of app. 📚 - Following providers are available for an LLM @@ -87,6 +98,7 @@ app = CustomApp(config) - VERTEX_AI - GPT4ALL - AZURE_OPENAI + - LLAMA2 - Following embedding functions are available for an embedding function - OPENAI - HUGGING_FACE diff --git a/docs/advanced/configuration.mdx b/docs/advanced/configuration.mdx index 1de54eed..dc6301c8 100644 --- a/docs/advanced/configuration.mdx +++ b/docs/advanced/configuration.mdx @@ -4,6 +4,16 @@ title: '⚙️ Custom configurations' Embedchain is made to work out of the box. However, for advanced users we're also offering configuration options. All of these configuration options are optional and have sane defaults. +## Concept +The main `App` class is available in the following varieties: `CustomApp`, `OpenSourceApp` and `Llama2App` and `App`. The first is fully configurable, the others are opinionated in some aspects. + +The `App` class has three subclasses: `llm`, `db` and `embedder`. These are the core ingredients that make up an EmbedChain app. +App plus each one of the subclasses have a `config` attribute. +You can pass a `Config` instance as an argument during initialization to persistently configure a class. +These configs can be imported from `embedchain.config` + +There are `set` methods for some things that should not (only) be set at start-up, like `app.db.set_collection_name`. 
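+
+For illustration, here is a minimal sketch of how these pieces fit together, using only the classes described above (the collection name and the query string are placeholders):
+
+```python
+from embedchain import App
+from embedchain.config import AppConfig, LlmConfig
+
+# Persistent, app-level settings are passed as a Config instance at init time.
+app = App(AppConfig(log_level="DEBUG"))
+
+# Things that should change after start-up have `set` methods on the subclass.
+app.db.set_collection_name("my_collection")
+
+# Per-call behaviour (e.g. how many documents to retrieve) goes in an LlmConfig.
+answer = app.query("What is Embedchain?", config=LlmConfig(number_documents=3))
+```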
+ ## Examples ### General @@ -11,31 +21,31 @@ Embedchain is made to work out of the box. However, for advanced users we're als Here's the readme example with configuration options. ```python -import os from embedchain import App -from embedchain.config import AppConfig, AddConfig, QueryConfig, ChunkerConfig -from chromadb.utils import embedding_functions +from embedchain.config import AppConfig, AddConfig, LlmConfig, ChunkerConfig # Example: set the log level for debugging config = AppConfig(log_level="DEBUG") naval_chat_bot = App(config) # Example: specify a custom collection name -config = AppConfig(collection_name="naval_chat_bot") -naval_chat_bot = App(config) +naval_chat_bot.db.set_collection_name("naval_chat_bot") # Example: define your own chunker config for `youtube_video` chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=100, length_function=len) -naval_chat_bot.add("https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(chunker=chunker_config)) +# Example: Add your chunker config to an AddConfig to actually use it +add_config = AddConfig(chunker=chunker_config) +naval_chat_bot.add("https://www.youtube.com/watch?v=3qHkcs3kG44", config=add_config) +# Example: Reset to default add_config = AddConfig() naval_chat_bot.add("https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf", config=add_config) naval_chat_bot.add("https://nav.al/feedback", config=add_config) naval_chat_bot.add("https://nav.al/agi", config=add_config) - naval_chat_bot.add(("Who is Naval Ravikant?", "Naval Ravikant is an Indian-American entrepreneur and investor."), config=add_config) -query_config = QueryConfig() +# Change the number of documents. +query_config = LlmConfig(number_documents=5) print(naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?", config=query_config)) ``` @@ -44,11 +54,13 @@ print(naval_chat_bot.query("What unique capacity does Naval argue humans possess Here's the example of using custom prompt template with `.query` ```python -from embedchain.config import QueryConfig -from embedchain.embedchain import App from string import Template + import wikipedia +from embedchain import App +from embedchain.config import LlmConfig + einstein_chat_bot = App() # Embed Wikipedia page @@ -56,7 +68,8 @@ page = wikipedia.page("Albert Einstein") einstein_chat_bot.add(page.content) # Example: use your own custom template with `$context` and `$query` -einstein_chat_template = Template(""" +einstein_chat_template = Template( + """ You are Albert Einstein, a German-born theoretical physicist, widely ranked among the greatest and most influential scientists of all time. @@ -67,17 +80,19 @@ einstein_chat_template = Template(""" Keep the response brief. If you don't know the answer, just say that you don't know, don't try to make up an answer. Human: $query - Albert Einstein:""") -query_config = QueryConfig(template=einstein_chat_template, system_prompt="You are Albert Einstein.") + Albert Einstein:""" +) +# Example: Use the template, also add a system prompt. 
+llm_config = LlmConfig(template=einstein_chat_template, system_prompt="You are Albert Einstein.") queries = [ - "Where did you complete your studies?", - "Why did you win nobel prize?", - "Why did you divorce your first wife?", + "Where did you complete your studies?", + "Why did you win nobel prize?", + "Why did you divorce your first wife?", ] for query in queries: - response = einstein_chat_bot.query(query, config=query_config) - print("Query: ", query) - print("Response: ", response) + response = einstein_chat_bot.query(query, config=llm_config) + print("Query: ", query) + print("Response: ", response) # Output # Query: Where did you complete your studies? diff --git a/docs/advanced/query_configuration.mdx b/docs/advanced/query_configuration.mdx index ade45592..b71eba11 100644 --- a/docs/advanced/query_configuration.mdx +++ b/docs/advanced/query_configuration.mdx @@ -53,7 +53,7 @@ Default values of chunker config parameters for different `data_type`: _coming soon_ -## QueryConfig +## LlmConfig |option|description|type|default| |---|---|---|---| diff --git a/embedchain/__init__.py b/embedchain/__init__.py index 7fd6ce44..2ae5c9f6 100644 --- a/embedchain/__init__.py +++ b/embedchain/__init__.py @@ -8,3 +8,5 @@ from embedchain.apps.Llama2App import Llama2App # noqa: F401 from embedchain.apps.OpenSourceApp import OpenSourceApp # noqa: F401 from embedchain.apps.PersonApp import (PersonApp, # noqa: F401 PersonOpenSourceApp) +from embedchain.vectordb.chroma_db import ChromaDB # noqa: F401 +from embedchain.vectordb.elasticsearch_db import ElasticsearchDB # noqa: F401 diff --git a/embedchain/apps/App.py b/embedchain/apps/App.py index 693fd804..0ee8a676 100644 --- a/embedchain/apps/App.py +++ b/embedchain/apps/App.py @@ -1,10 +1,12 @@ from typing import Optional -import openai - -from embedchain.config import AppConfig, ChatConfig +from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig, + ChromaDbConfig) from embedchain.embedchain import EmbedChain +from embedchain.embedder.openai_embedder import OpenAiEmbedder from embedchain.helper_classes.json_serializable import register_deserializable +from embedchain.llm.openai_llm import OpenAiLlm +from embedchain.vectordb.chroma_db import ChromaDB @register_deserializable @@ -18,7 +20,13 @@ class App(EmbedChain): dry_run(query): test your prompt without consuming tokens. """ - def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None): + def __init__( + self, + config: AppConfig = None, + llm_config: BaseLlmConfig = None, + chromadb_config: Optional[ChromaDbConfig] = None, + system_prompt: Optional[str] = None, + ): """ :param config: AppConfig instance to load as configuration. Optional. :param system_prompt: System prompt string. Optional. 
@@ -26,38 +34,8 @@ class App(EmbedChain): if config is None: config = AppConfig() - super().__init__(config, system_prompt) + llm = OpenAiLlm(config=llm_config) + embedder = OpenAiEmbedder(config=BaseEmbedderConfig(model="text-embedding-ada-002")) + database = ChromaDB(config=chromadb_config) - def get_llm_model_answer(self, prompt, config: ChatConfig): - messages = [] - system_prompt = ( - self.system_prompt - if self.system_prompt is not None - else config.system_prompt - if config.system_prompt is not None - else None - ) - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - messages.append({"role": "user", "content": prompt}) - response = openai.ChatCompletion.create( - model=config.model or "gpt-3.5-turbo-0613", - messages=messages, - temperature=config.temperature, - max_tokens=config.max_tokens, - top_p=config.top_p, - stream=config.stream, - ) - - if config.stream: - return self._stream_llm_model_response(response) - else: - return response["choices"][0]["message"]["content"] - - def _stream_llm_model_response(self, response): - """ - This is a generator for streaming response from the OpenAI completions API - """ - for line in response: - chunk = line["choices"][0].get("delta", {}).get("content", "") - yield chunk + super().__init__(config, llm, db=database, embedder=embedder, system_prompt=system_prompt) diff --git a/embedchain/apps/CustomApp.py b/embedchain/apps/CustomApp.py index ad20d1eb..cf01c8ac 100644 --- a/embedchain/apps/CustomApp.py +++ b/embedchain/apps/CustomApp.py @@ -1,12 +1,11 @@ -import logging -from typing import List, Optional +from typing import Optional -from langchain.schema import BaseMessage - -from embedchain.config import ChatConfig, CustomAppConfig +from embedchain.config import CustomAppConfig from embedchain.embedchain import EmbedChain +from embedchain.embedder.base_embedder import BaseEmbedder from embedchain.helper_classes.json_serializable import register_deserializable -from embedchain.models import Providers +from embedchain.llm.base_llm import BaseLlm +from embedchain.vectordb.base_vector_db import BaseVectorDB @register_deserializable @@ -20,143 +19,49 @@ class CustomApp(EmbedChain): dry_run(query): test your prompt without consuming tokens. """ - def __init__(self, config: CustomAppConfig = None, system_prompt: Optional[str] = None): + def __init__( + self, + config: CustomAppConfig = None, + llm: BaseLlm = None, + db: BaseVectorDB = None, + embedder: BaseEmbedder = None, + system_prompt: Optional[str] = None, + ): """ :param config: Optional. `CustomAppConfig` instance to load as configuration. :raises ValueError: Config must be provided for custom app :param system_prompt: Optional. System prompt string. """ + # Config is not required, it has a default if config is None: - raise ValueError("Config must be provided for custom app") + config = CustomAppConfig() - self.provider = config.provider + if llm is None: + raise ValueError("LLM must be provided for custom app. Please import from `embedchain.llm`.") + if db is None: + raise ValueError("Database must be provided for custom app. Please import from `embedchain.vectordb`.") + if embedder is None: + raise ValueError("Embedder must be provided for custom app. 
Please import from `embedchain.embedder`.") - if config.provider == Providers.GPT4ALL: - from embedchain import OpenSourceApp - - # Because these models run locally, they should have an instance running when the custom app is created - self.open_source_app = OpenSourceApp(config=config.open_source_app_config) - - super().__init__(config, system_prompt) - - def set_llm_model(self, provider: Providers): - self.provider = provider - if provider == Providers.GPT4ALL: - raise ValueError( - "GPT4ALL needs to be instantiated with the model known, please create a new app instance instead" + if not isinstance(config, CustomAppConfig): + raise TypeError( + "Config is not a `CustomAppConfig` instance. " + "Please make sure the type is right and that you are passing an instance." + ) + if not isinstance(llm, BaseLlm): + raise TypeError( + "LLM is not a `BaseLlm` instance. " + "Please make sure the type is right and that you are passing an instance." + ) + if not isinstance(db, BaseVectorDB): + raise TypeError( + "Database is not a `BaseVectorDB` instance. " + "Please make sure the type is right and that you are passing an instance." + ) + if not isinstance(embedder, BaseEmbedder): + raise TypeError( + "Embedder is not a `BaseEmbedder` instance. " + "Please make sure the type is right and that you are passing an instance." ) - def get_llm_model_answer(self, prompt, config: ChatConfig): - # TODO: Quitting the streaming response here for now. - # Idea: https://gist.github.com/jvelezmagic/03ddf4c452d011aae36b2a0f73d72f68 - if config.stream: - raise NotImplementedError( - "Streaming responses have not been implemented for this model yet. Please disable." - ) - - if config.system_prompt is None and self.system_prompt is not None: - config.system_prompt = self.system_prompt - - try: - if self.provider == Providers.OPENAI: - return CustomApp._get_openai_answer(prompt, config) - - if self.provider == Providers.ANTHROPHIC: - return CustomApp._get_athrophic_answer(prompt, config) - - if self.provider == Providers.VERTEX_AI: - return CustomApp._get_vertex_answer(prompt, config) - - if self.provider == Providers.GPT4ALL: - return self.open_source_app._get_gpt4all_answer(prompt, config) - - if self.provider == Providers.AZURE_OPENAI: - return CustomApp._get_azure_openai_answer(prompt, config) - - except ImportError as e: - raise ModuleNotFoundError(e.msg) from None - - @staticmethod - def _get_openai_answer(prompt: str, config: ChatConfig) -> str: - from langchain.chat_models import ChatOpenAI - - chat = ChatOpenAI( - temperature=config.temperature, - model=config.model or "gpt-3.5-turbo", - max_tokens=config.max_tokens, - streaming=config.stream, - ) - - if config.top_p and config.top_p != 1: - logging.warning("Config option `top_p` is not supported by this model.") - - messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt) - - return chat(messages).content - - @staticmethod - def _get_athrophic_answer(prompt: str, config: ChatConfig) -> str: - from langchain.chat_models import ChatAnthropic - - chat = ChatAnthropic(temperature=config.temperature, model=config.model) - - if config.max_tokens and config.max_tokens != 1000: - logging.warning("Config option `max_tokens` is not supported by this model.") - - messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt) - - return chat(messages).content - - @staticmethod - def _get_vertex_answer(prompt: str, config: ChatConfig) -> str: - from langchain.chat_models import ChatVertexAI - - chat = 
ChatVertexAI(temperature=config.temperature, model=config.model, max_output_tokens=config.max_tokens) - - if config.top_p and config.top_p != 1: - logging.warning("Config option `top_p` is not supported by this model.") - - messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt) - - return chat(messages).content - - @staticmethod - def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str: - from langchain.chat_models import AzureChatOpenAI - - if not config.deployment_name: - raise ValueError("Deployment name must be provided for Azure OpenAI") - - chat = AzureChatOpenAI( - deployment_name=config.deployment_name, - openai_api_version="2023-05-15", - model_name=config.model or "gpt-3.5-turbo", - temperature=config.temperature, - max_tokens=config.max_tokens, - streaming=config.stream, - ) - - if config.top_p and config.top_p != 1: - logging.warning("Config option `top_p` is not supported by this model.") - - messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt) - - return chat(messages).content - - @staticmethod - def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]: - from langchain.schema import HumanMessage, SystemMessage - - messages = [] - if system_prompt: - messages.append(SystemMessage(content=system_prompt)) - messages.append(HumanMessage(content=prompt)) - return messages - - def _stream_llm_model_response(self, response): - """ - This is a generator for streaming response from the OpenAI completions API - """ - for line in response: - chunk = line["choices"][0].get("delta", {}).get("content", "") - yield chunk + super().__init__(config=config, llm=llm, db=db, embedder=embedder, system_prompt=system_prompt) diff --git a/embedchain/apps/Llama2App.py b/embedchain/apps/Llama2App.py index 4ef39220..ce8b095f 100644 --- a/embedchain/apps/Llama2App.py +++ b/embedchain/apps/Llama2App.py @@ -1,13 +1,15 @@ -import os from typing import Optional -from langchain.llms import Replicate - -from embedchain.config import AppConfig, ChatConfig -from embedchain.embedchain import EmbedChain +from embedchain.apps.CustomApp import CustomApp +from embedchain.config import CustomAppConfig +from embedchain.embedder.openai_embedder import OpenAiEmbedder +from embedchain.helper_classes.json_serializable import register_deserializable +from embedchain.llm.llama2_llm import Llama2Llm +from embedchain.vectordb.chroma_db import ChromaDB -class Llama2App(EmbedChain): +@register_deserializable +class Llama2App(CustomApp): """ The EmbedChain Llama2App class. Has two functions: add and query. @@ -16,25 +18,15 @@ class Llama2App(EmbedChain): query(query): finds answer to the given query using vector database and LLM. """ - def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None): + def __init__(self, config: CustomAppConfig = None, system_prompt: Optional[str] = None): """ - :param config: AppConfig instance to load as configuration. Optional. + :param config: CustomAppConfig instance to load as configuration. Optional. :param system_prompt: System prompt string. Optional. 
""" - if "REPLICATE_API_TOKEN" not in os.environ: - raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.") if config is None: - config = AppConfig() + config = CustomAppConfig() - super().__init__(config, system_prompt) - - def get_llm_model_answer(self, prompt, config: ChatConfig = None): - # TODO: Move the model and other inputs into config - if self.system_prompt or config.system_prompt: - raise ValueError("Llama2App does not support `system_prompt`") - llm = Replicate( - model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5", - input={"temperature": 0.75, "max_length": 500, "top_p": 1}, + super().__init__( + config=config, llm=Llama2Llm(), db=ChromaDB(), embedder=OpenAiEmbedder(), system_prompt=system_prompt ) - return llm(prompt) diff --git a/embedchain/apps/OpenSourceApp.py b/embedchain/apps/OpenSourceApp.py index 95bdc1e2..f701cd5b 100644 --- a/embedchain/apps/OpenSourceApp.py +++ b/embedchain/apps/OpenSourceApp.py @@ -1,9 +1,13 @@ import logging -from typing import Iterable, Optional, Union +from typing import Optional -from embedchain.config import ChatConfig, OpenSourceAppConfig +from embedchain.config import (BaseEmbedderConfig, BaseLlmConfig, + ChromaDbConfig, OpenSourceAppConfig) from embedchain.embedchain import EmbedChain +from embedchain.embedder.gpt4all_embedder import GPT4AllEmbedder from embedchain.helper_classes.json_serializable import register_deserializable +from embedchain.llm.gpt4all_llm import GPT4ALLLlm +from embedchain.vectordb.chroma_db import ChromaDB gpt4all_model = None @@ -20,7 +24,12 @@ class OpenSourceApp(EmbedChain): query(query): finds answer to the given query using vector database and LLM. """ - def __init__(self, config: OpenSourceAppConfig = None, system_prompt: Optional[str] = None): + def __init__( + self, + config: OpenSourceAppConfig = None, + chromadb_config: Optional[ChromaDbConfig] = None, + system_prompt: Optional[str] = None, + ): """ :param config: OpenSourceAppConfig instance to load as configuration. Optional. `ef` defaults to open source. @@ -30,42 +39,19 @@ class OpenSourceApp(EmbedChain): if not config: config = OpenSourceAppConfig() + if not isinstance(config, OpenSourceAppConfig): + raise ValueError( + "OpenSourceApp needs a OpenSourceAppConfig passed to it. " + "You can import it with `from embedchain.config import OpenSourceAppConfig`" + ) + if not config.model: raise ValueError("OpenSourceApp needs a model to be instantiated. Maybe you passed the wrong config type?") - self.instance = OpenSourceApp._get_instance(config.model) - logging.info("Successfully loaded open source embedding model.") - super().__init__(config, system_prompt) - def get_llm_model_answer(self, prompt, config: ChatConfig): - return self._get_gpt4all_answer(prompt=prompt, config=config) + llm = GPT4ALLLlm(config=BaseLlmConfig(model="orca-mini-3b.ggmlv3.q4_0.bin")) + embedder = GPT4AllEmbedder(config=BaseEmbedderConfig(model="all-MiniLM-L6-v2")) + database = ChromaDB(config=chromadb_config) - @staticmethod - def _get_instance(model): - try: - from gpt4all import GPT4All - except ModuleNotFoundError: - raise ModuleNotFoundError( - "The GPT4All python package is not installed. 
Please install it with `pip install embedchain[opensource]`" # noqa E501 - ) from None - - return GPT4All(model) - - def _get_gpt4all_answer(self, prompt: str, config: ChatConfig) -> Union[str, Iterable]: - if config.model and config.model != self.config.model: - raise RuntimeError( - "OpenSourceApp does not support switching models at runtime. Please create a new app instance." - ) - - if self.system_prompt or config.system_prompt: - raise ValueError("OpenSourceApp does not support `system_prompt`") - - response = self.instance.generate( - prompt=prompt, - streaming=config.stream, - top_p=config.top_p, - max_tokens=config.max_tokens, - temp=config.temperature, - ) - return response + super().__init__(config, llm=llm, db=database, embedder=embedder, system_prompt=system_prompt) diff --git a/embedchain/apps/PersonApp.py b/embedchain/apps/PersonApp.py index a804fa6b..971c7e35 100644 --- a/embedchain/apps/PersonApp.py +++ b/embedchain/apps/PersonApp.py @@ -2,9 +2,10 @@ from string import Template from embedchain.apps.App import App from embedchain.apps.OpenSourceApp import OpenSourceApp -from embedchain.config import ChatConfig, QueryConfig +from embedchain.config import BaseLlmConfig from embedchain.config.apps.BaseAppConfig import BaseAppConfig -from embedchain.config.QueryConfig import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY +from embedchain.config.llm.base_llm_config import (DEFAULT_PROMPT, + DEFAULT_PROMPT_WITH_HISTORY) from embedchain.helper_classes.json_serializable import register_deserializable @@ -23,7 +24,7 @@ class EmbedChainPersonApp: self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style." # noqa:E501 super().__init__(config) - def add_person_template_to_config(self, default_prompt: str, config: ChatConfig = None): + def add_person_template_to_config(self, default_prompt: str, config: BaseLlmConfig = None): """ This method checks if the config object contains a prompt template if yes it adds the person prompt to it and return the updated config @@ -44,7 +45,7 @@ class EmbedChainPersonApp: config.template = template else: # if no config is present at all, initialize the config with person prompt and default template - config = QueryConfig( + config = BaseLlmConfig( template=template, ) @@ -58,11 +59,11 @@ class PersonApp(EmbedChainPersonApp, App): Extends functionality from EmbedChainPersonApp and App """ - def query(self, input_query, config: QueryConfig = None, dry_run=False): + def query(self, input_query, config: BaseLlmConfig = None, dry_run=False): config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None) return super().query(input_query, config, dry_run, where=None) - def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None): + def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None): config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config) return super().chat(input_query, config, dry_run, where) @@ -74,10 +75,10 @@ class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp): Extends functionality from EmbedChainPersonApp and OpenSourceApp """ - def query(self, input_query, config: QueryConfig = None, dry_run=False): + def query(self, input_query, config: BaseLlmConfig = None, dry_run=False): config = self.add_person_template_to_config(DEFAULT_PROMPT, config) return super().query(input_query, config, dry_run) - def chat(self, input_query, config: ChatConfig = None, dry_run=False): + def chat(self, input_query, config: 
BaseLlmConfig = None, dry_run=False): config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config) return super().chat(input_query, config, dry_run) diff --git a/embedchain/bots/base.py b/embedchain/bots/base.py index 48cfb71c..0b47a2e8 100644 --- a/embedchain/bots/base.py +++ b/embedchain/bots/base.py @@ -1,26 +1,25 @@ from embedchain import CustomApp -from embedchain.config import AddConfig, CustomAppConfig, QueryConfig +from embedchain.config import AddConfig, CustomAppConfig, LlmConfig +from embedchain.embedder.openai_embedder import OpenAiEmbedder from embedchain.helper_classes.json_serializable import ( JSONSerializable, register_deserializable) -from embedchain.models import EmbeddingFunctions, Providers +from embedchain.llm.openai_llm import OpenAiLlm +from embedchain.vectordb.chroma_db import ChromaDB @register_deserializable class BaseBot(JSONSerializable): - def __init__(self, app_config=None): - if app_config is None: - app_config = CustomAppConfig(embedding_fn=EmbeddingFunctions.OPENAI, provider=Providers.OPENAI) - self.app_config = app_config - self.app = CustomApp(config=self.app_config) + def __init__(self): + self.app = CustomApp(config=CustomAppConfig(), llm=OpenAiLlm(), db=ChromaDB(), embedder=OpenAiEmbedder()) def add(self, data, config: AddConfig = None): """Add data to the bot""" config = config if config else AddConfig() self.app.add(data, config=config) - def query(self, query, config: QueryConfig = None): + def query(self, query, config: LlmConfig = None): """Query bot""" - config = config if config else QueryConfig() + config = config return self.app.query(query, config=config) def start(self): diff --git a/embedchain/bots/discord.py b/embedchain/bots/discord.py index 2a37cb28..fd6716ec 100644 --- a/embedchain/bots/discord.py +++ b/embedchain/bots/discord.py @@ -6,6 +6,8 @@ import discord from discord import app_commands from discord.ext import commands +from embedchain.helper_classes.json_serializable import register_deserializable + from .base import BaseBot intents = discord.Intents.default() @@ -17,6 +19,7 @@ tree = app_commands.CommandTree(client) # https://discord.com/api/oauth2/authorize?client_id={DISCORD_CLIENT_ID}&permissions=2048&scope=bot +@register_deserializable class DiscordBot(BaseBot): def __init__(self, *args, **kwargs): BaseBot.__init__(self, *args, **kwargs) diff --git a/embedchain/bots/poe.py b/embedchain/bots/poe.py index 99bb5321..17938217 100644 --- a/embedchain/bots/poe.py +++ b/embedchain/bots/poe.py @@ -5,7 +5,6 @@ from typing import List, Optional from fastapi_poe import PoeBot, run -from embedchain.config import QueryConfig from embedchain.helper_classes.json_serializable import register_deserializable from .base import BaseBot @@ -46,7 +45,6 @@ class PoeBot(BaseBot, PoeBot): ) except Exception as e: logging.error(f"Error when processing the chat history. Message is being sent without history. Error: {e}") - logging.warning(history) answer = self.handle_message(last_message, history) yield self.text_event(answer) @@ -69,8 +67,8 @@ class PoeBot(BaseBot, PoeBot): def ask_bot(self, message, history: List[str]): try: - config = QueryConfig(history=history) - response = self.query(message, config) + self.app.llm.set_history(history=history) + response = self.query(message) except Exception: logging.exception(f"Failed to query {message}.") response = "An error occurred. Please try again!" 
diff --git a/embedchain/config/ChatConfig.py b/embedchain/config/ChatConfig.py deleted file mode 100644 index c6bbcc48..00000000 --- a/embedchain/config/ChatConfig.py +++ /dev/null @@ -1,92 +0,0 @@ -from string import Template -from typing import Optional - -from embedchain.config.QueryConfig import QueryConfig -from embedchain.helper_classes.json_serializable import register_deserializable - -DEFAULT_PROMPT = """ - You are a chatbot having a conversation with a human. You are given chat - history and context. - You need to answer the query considering context, chat history and your knowledge base. If you don't know the answer or the answer is neither contained in the context nor in history, then simply say "I don't know". - - $context - - History: $history - - Query: $query - - Helpful Answer: -""" # noqa:E501 - -DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT) - - -@register_deserializable -class ChatConfig(QueryConfig): - """ - Config for the `chat` method, inherits from `QueryConfig`. - """ - - def __init__( - self, - number_documents=None, - template: Template = None, - model=None, - temperature=None, - max_tokens=None, - top_p=None, - stream: bool = False, - deployment_name=None, - system_prompt: Optional[str] = None, - where=None, - ): - """ - Initializes the ChatConfig instance. - - :param number_documents: Number of documents to pull from the database as - context. - :param template: Optional. The `Template` instance to use as a template for - prompt. - :param model: Optional. Controls the OpenAI model used. - :param temperature: Optional. Controls the randomness of the model's output. - Higher values (closer to 1) make output more random,lower values make it more - deterministic. - :param max_tokens: Optional. Controls how many tokens are generated. - :param top_p: Optional. Controls the diversity of words.Higher values - (closer to 1) make word selection more diverse, lower values make words less - diverse. - :param stream: Optional. Control if response is streamed back to the user - :param deployment_name: t.b.a. - :param system_prompt: Optional. System prompt string. - :param where: Optional. A dictionary of key-value pairs to filter the database results. - :raises ValueError: If the template is not valid as template should contain - $context and $query and $history - """ - if template is None: - template = DEFAULT_PROMPT_TEMPLATE - - # History is set as 0 to ensure that there is always a history, that way, - # there don't have to be two templates. Having two templates would make it - # complicated because the history is not user controlled. 
- super().__init__( - number_documents=number_documents, - template=template, - model=model, - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, - history=[0], - stream=stream, - deployment_name=deployment_name, - system_prompt=system_prompt, - where=where, - ) - - def set_history(self, history): - """ - Chat history is not user provided and not set at initialization time - - :param history: (string) history to set - """ - self.history = history - return diff --git a/embedchain/config/__init__.py b/embedchain/config/__init__.py index fc8bc450..9a1bd5e4 100644 --- a/embedchain/config/__init__.py +++ b/embedchain/config/__init__.py @@ -1,9 +1,13 @@ -from .AddConfig import AddConfig, ChunkerConfig # noqa: F401 -from .apps.AppConfig import AppConfig # noqa: F401 -from .apps.CustomAppConfig import CustomAppConfig # noqa: F401 -from .apps.OpenSourceAppConfig import OpenSourceAppConfig # noqa: F401 -from .BaseConfig import BaseConfig # noqa: F401 -from .ChatConfig import ChatConfig # noqa: F401 -from .QueryConfig import QueryConfig # noqa: F401 -from .vectordbs.ElasticsearchDBConfig import \ - ElasticsearchDBConfig # noqa: F401 +# flake8: noqa: F401 + +from .AddConfig import AddConfig, ChunkerConfig +from .apps.AppConfig import AppConfig +from .apps.CustomAppConfig import CustomAppConfig +from .apps.OpenSourceAppConfig import OpenSourceAppConfig +from .BaseConfig import BaseConfig +from .embedder.BaseEmbedderConfig import BaseEmbedderConfig +from .embedder.BaseEmbedderConfig import BaseEmbedderConfig as EmbedderConfig +from .llm.base_llm_config import BaseLlmConfig +from .llm.base_llm_config import BaseLlmConfig as LlmConfig +from .vectordbs.ChromaDbConfig import ChromaDbConfig +from .vectordbs.ElasticsearchDBConfig import ElasticsearchDBConfig diff --git a/embedchain/config/apps/AppConfig.py b/embedchain/config/apps/AppConfig.py index 5e74957c..0317551c 100644 --- a/embedchain/config/apps/AppConfig.py +++ b/embedchain/config/apps/AppConfig.py @@ -1,14 +1,5 @@ -import os from typing import Optional -try: - from chromadb.utils import embedding_functions -except RuntimeError: - from embedchain.utils import use_pysqlite3 - - use_pysqlite3() - from chromadb.utils import embedding_functions - from embedchain.helper_classes.json_serializable import register_deserializable from .BaseAppConfig import BaseAppConfig @@ -23,44 +14,14 @@ class AppConfig(BaseAppConfig): def __init__( self, log_level=None, - host=None, - port=None, id=None, - collection_name=None, collect_metrics: Optional[bool] = None, + collection_name: Optional[str] = None, ): """ :param log_level: Optional. (String) Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']. - :param host: Optional. Hostname for the database server. - :param port: Optional. Port for the database server. :param id: Optional. ID of the app. Document metadata will have this id. - :param collection_name: Optional. Collection name for the database. :param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain. """ - super().__init__( - log_level=log_level, - embedding_fn=AppConfig.default_embedding_function(), - host=host, - port=port, - id=id, - collection_name=collection_name, - collect_metrics=collect_metrics, - ) - - @staticmethod - def default_embedding_function(): - """ - Sets embedding function to default (`text-embedding-ada-002`). - - :raises ValueError: If the template is not valid as template should contain - $context and $query - :returns: The default embedding function for the app class. 
- """ - if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None: - raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided") # noqa:E501 - return embedding_functions.OpenAIEmbeddingFunction( - api_key=os.getenv("OPENAI_API_KEY"), - organization_id=os.getenv("OPENAI_ORGANIZATION"), - model_name="text-embedding-ada-002", - ) + super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, collection_name=collection_name) diff --git a/embedchain/config/apps/BaseAppConfig.py b/embedchain/config/apps/BaseAppConfig.py index 890045e1..c781b49a 100644 --- a/embedchain/config/apps/BaseAppConfig.py +++ b/embedchain/config/apps/BaseAppConfig.py @@ -1,9 +1,9 @@ import logging +from typing import Optional from embedchain.config.BaseConfig import BaseConfig -from embedchain.config.vectordbs import ElasticsearchDBConfig from embedchain.helper_classes.json_serializable import JSONSerializable -from embedchain.models import VectorDatabases, VectorDimensions +from embedchain.vectordb.base_vector_db import BaseVectorDB class BaseAppConfig(BaseConfig, JSONSerializable): @@ -14,81 +14,38 @@ class BaseAppConfig(BaseConfig, JSONSerializable): def __init__( self, log_level=None, - embedding_fn=None, - db=None, - host=None, - port=None, + db: Optional[BaseVectorDB] = None, id=None, - collection_name=None, collect_metrics: bool = True, - db_type: VectorDatabases = None, - vector_dim: VectorDimensions = None, - es_config: ElasticsearchDBConfig = None, - chroma_settings: dict = {}, + collection_name: Optional[str] = None, ): """ :param log_level: Optional. (String) Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']. - :param embedding_fn: Embedding function to use. - :param db: Optional. (Vector) database instance to use for embeddings. - :param host: Optional. Hostname for the database server. - :param port: Optional. Port for the database server. + :param db: Optional. (Vector) database instance to use for embeddings. Deprecated in favor of app(..., db). :param id: Optional. ID of the app. Document metadata will have this id. - :param collection_name: Optional. Collection name for the database. :param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain. - :param db_type: Optional. type of Vector database to use - :param vector_dim: Vector dimension generated by embedding fn + :param db_type: Optional. Initializes a default vector database of the given type. + Using the `db` argument is preferred. :param es_config: Optional. elasticsearch database config to be used for connection - :param chroma_settings: Optional. Chroma settings for connection. + :param collection_name: Optional. Default collection name. + It's recommended to use app.set_collection_name() instead. """ self._setup_logging(log_level) - self.collection_name = collection_name if collection_name else "embedchain_store" - self.db = BaseAppConfig.get_db( - db=db, - embedding_fn=embedding_fn, - host=host, - port=port, - db_type=db_type, - vector_dim=vector_dim, - collection_name=self.collection_name, - es_config=es_config, - chroma_settings=chroma_settings, - ) self.id = id self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False - return + self.collection_name = collection_name - @staticmethod - def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config, chroma_settings): - """ - Get db based on db_type, db with default database (`ChromaDb`) - :param Optional. 
(Vector) database to use for embeddings. - :param embedding_fn: Embedding function to use in database. - :param host: Optional. Hostname for the database server. - :param port: Optional. Port for the database server. - :param db_type: Optional. db type to use. Supported values (`es`, `chroma`) - :param vector_dim: Vector dimension generated by embedding fn - :param collection_name: Optional. Collection name for the database. - :param es_config: Optional. elasticsearch database config to be used for connection - :raises ValueError: BaseAppConfig knows no default embedding function. - :returns: database instance - """ if db: - return db - - if embedding_fn is None: - raise ValueError("ChromaDb cannot be instantiated without an embedding function") - - if db_type == VectorDatabases.ELASTICSEARCH: - from embedchain.vectordb.elasticsearch_db import ElasticsearchDB - - return ElasticsearchDB( - embedding_fn=embedding_fn, vector_dim=vector_dim, collection_name=collection_name, es_config=es_config + self._db = db + logging.warning( + "DEPRECATION WARNING: Please supply the database as the second parameter during app init. " + "Such as `app(config=config, db=db)`." ) - from embedchain.vectordb.chroma_db import ChromaDB - - return ChromaDB(embedding_fn=embedding_fn, host=host, port=port, chroma_settings=chroma_settings) + if collection_name: + logging.warning("DEPRECATION WARNING: Please supply the collection name to the database config.") + return def _setup_logging(self, debug_level): level = logging.WARNING # Default level diff --git a/embedchain/config/apps/CustomAppConfig.py b/embedchain/config/apps/CustomAppConfig.py index d6dfe486..72c598da 100644 --- a/embedchain/config/apps/CustomAppConfig.py +++ b/embedchain/config/apps/CustomAppConfig.py @@ -1,12 +1,8 @@ -from typing import Any, Optional +from typing import Optional -from chromadb.api.types import Documents, Embeddings from dotenv import load_dotenv -from embedchain.config.vectordbs import ElasticsearchDBConfig from embedchain.helper_classes.json_serializable import register_deserializable -from embedchain.models import (EmbeddingFunctions, Providers, VectorDatabases, - VectorDimensions) from .BaseAppConfig import BaseAppConfig @@ -22,123 +18,23 @@ class CustomAppConfig(BaseAppConfig): def __init__( self, log_level=None, - embedding_fn: EmbeddingFunctions = None, - embedding_fn_model=None, db=None, - host=None, - port=None, id=None, - collection_name=None, - provider: Providers = None, - open_source_app_config=None, - deployment_name=None, collect_metrics: Optional[bool] = None, - db_type: VectorDatabases = None, - es_config: ElasticsearchDBConfig = None, - chroma_settings: dict = {}, + collection_name: Optional[str] = None, ): """ :param log_level: Optional. (String) Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']. - :param embedding_fn: Optional. Embedding function to use. - :param embedding_fn_model: Optional. Model name to use for embedding function. :param db: Optional. (Vector) database to use for embeddings. - :param host: Optional. Hostname for the database server. - :param port: Optional. Port for the database server. :param id: Optional. ID of the app. Document metadata will have this id. - :param collection_name: Optional. Collection name for the database. :param provider: Optional. (Providers): LLM Provider to use. :param open_source_app_config: Optional. Config instance needed for open source apps. :param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain. 
- :param db_type: Optional. type of Vector database to use. - :param es_config: Optional. elasticsearch database config to be used for connection - :param chroma_settings: Optional. Chroma settings for connection. + :param collection_name: Optional. Default collection name. + It's recommended to use app.set_collection_name() instead. """ - if provider: - self.provider = provider - else: - raise ValueError("CustomApp must have a provider assigned.") - - self.open_source_app_config = open_source_app_config super().__init__( - log_level=log_level, - embedding_fn=CustomAppConfig.embedding_function( - embedding_function=embedding_fn, model=embedding_fn_model, deployment_name=deployment_name - ), - db=db, - host=host, - port=port, - id=id, - collection_name=collection_name, - collect_metrics=collect_metrics, - db_type=db_type, - vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn), - es_config=es_config, - chroma_settings=chroma_settings, + log_level=log_level, db=db, id=id, collect_metrics=collect_metrics, collection_name=collection_name ) - - @staticmethod - def langchain_default_concept(embeddings: Any): - """ - Langchains default function layout for embeddings. - """ - - def embed_function(texts: Documents) -> Embeddings: - return embeddings.embed_documents(texts) - - return embed_function - - @staticmethod - def embedding_function(embedding_function: EmbeddingFunctions, model: str = None, deployment_name: str = None): - if not isinstance(embedding_function, EmbeddingFunctions): - raise ValueError( - f"Invalid option: '{embedding_function}'. Expecting one of the following options: {list(map(lambda x: x.value, EmbeddingFunctions))}" # noqa: E501 - ) - - if embedding_function == EmbeddingFunctions.OPENAI: - from langchain.embeddings import OpenAIEmbeddings - - if model: - embeddings = OpenAIEmbeddings(model=model) - else: - if deployment_name: - embeddings = OpenAIEmbeddings(deployment=deployment_name) - else: - embeddings = OpenAIEmbeddings() - return CustomAppConfig.langchain_default_concept(embeddings) - - elif embedding_function == EmbeddingFunctions.HUGGING_FACE: - from langchain.embeddings import HuggingFaceEmbeddings - - embeddings = HuggingFaceEmbeddings(model_name=model) - return CustomAppConfig.langchain_default_concept(embeddings) - - elif embedding_function == EmbeddingFunctions.VERTEX_AI: - from langchain.embeddings import VertexAIEmbeddings - - embeddings = VertexAIEmbeddings(model_name=model) - return CustomAppConfig.langchain_default_concept(embeddings) - - elif embedding_function == EmbeddingFunctions.GPT4ALL: - # Note: We could use langchains GPT4ALL embedding, but it's not available in all versions. 
- from chromadb.utils import embedding_functions - - return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model) - - @staticmethod - def get_vector_dimension(embedding_function: EmbeddingFunctions): - if not isinstance(embedding_function, EmbeddingFunctions): - raise ValueError(f"Invalid option: '{embedding_function}'.") - - if embedding_function == EmbeddingFunctions.OPENAI: - return VectorDimensions.OPENAI.value - - elif embedding_function == EmbeddingFunctions.HUGGING_FACE: - return VectorDimensions.HUGGING_FACE.value - - elif embedding_function == EmbeddingFunctions.VERTEX_AI: - return VectorDimensions.VERTEX_AI.value - - elif embedding_function == EmbeddingFunctions.GPT4ALL: - return VectorDimensions.GPT4ALL.value diff --git a/embedchain/config/apps/OpenSourceAppConfig.py b/embedchain/config/apps/OpenSourceAppConfig.py index a0dd4ca4..af505632 100644 --- a/embedchain/config/apps/OpenSourceAppConfig.py +++ b/embedchain/config/apps/OpenSourceAppConfig.py @@ -1,7 +1,5 @@ from typing import Optional -from chromadb.utils import embedding_functions - from embedchain.helper_classes.json_serializable import register_deserializable from .BaseAppConfig import BaseAppConfig @@ -16,47 +14,21 @@ class OpenSourceAppConfig(BaseAppConfig): def __init__( self, log_level=None, - host=None, - port=None, id=None, - collection_name=None, collect_metrics: Optional[bool] = None, model=None, + collection_name: Optional[str] = None, ): """ :param log_level: Optional. (String) Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']. :param id: Optional. ID of the app. Document metadata will have this id. - :param collection_name: Optional. Collection name for the database. - :param host: Optional. Hostname for the database server. - :param port: Optional. Port for the database server. :param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain. :param model: Optional. GPT4ALL uses the model to instantiate the class. So unlike `App`, it has to be provided before querying. + :param collection_name: Optional. Default collection name. + It's recommended to use app.db.set_collection_name() instead. """ self.model = model or "orca-mini-3b.ggmlv3.q4_0.bin" - super().__init__( - log_level=log_level, - embedding_fn=OpenSourceAppConfig.default_embedding_function(), - host=host, - port=port, - id=id, - collection_name=collection_name, - collect_metrics=collect_metrics, - ) - - @staticmethod - def default_embedding_function(): - """ - Sets embedding function to default (`all-MiniLM-L6-v2`). - - :returns: The default embedding function - """ - try: - return embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2") - except ValueError as e: - print(e) - raise ModuleNotFoundError( - "The open source app requires extra dependencies. 
Install with `pip install embedchain[opensource]`" - ) from None + super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, collection_name=collection_name) diff --git a/embedchain/config/embedder/BaseEmbedderConfig.py b/embedchain/config/embedder/BaseEmbedderConfig.py new file mode 100644 index 00000000..3175de86 --- /dev/null +++ b/embedchain/config/embedder/BaseEmbedderConfig.py @@ -0,0 +1,10 @@ +from typing import Optional + +from embedchain.helper_classes.json_serializable import register_deserializable + + +@register_deserializable +class BaseEmbedderConfig: + def __init__(self, model: Optional[str] = None, deployment_name: Optional[str] = None): + self.model = model + self.deployment_name = deployment_name diff --git a/embedchain/config/embedder/__init__.py b/embedchain/config/embedder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/embedchain/config/llm/__init__.py b/embedchain/config/llm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/embedchain/config/QueryConfig.py b/embedchain/config/llm/base_llm_config.py similarity index 78% rename from embedchain/config/QueryConfig.py rename to embedchain/config/llm/base_llm_config.py index b4c29882..c4c4350b 100644 --- a/embedchain/config/QueryConfig.py +++ b/embedchain/config/llm/base_llm_config.py @@ -50,7 +50,7 @@ history_re = re.compile(r"\$\{*history\}*") @register_deserializable -class QueryConfig(BaseConfig): +class BaseLlmConfig(BaseConfig): """ Config for the `query` method. """ @@ -63,7 +63,6 @@ class QueryConfig(BaseConfig): temperature=None, max_tokens=None, top_p=None, - history=None, stream: bool = False, deployment_name=None, system_prompt: Optional[str] = None, @@ -84,7 +83,6 @@ class QueryConfig(BaseConfig): :param top_p: Optional. Controls the diversity of words. Higher values (closer to 1) make word selection more diverse, lower values make words less diverse. - :param history: Optional. A list of strings to consider as history. :param stream: Optional. Control if response is streamed back to user :param deployment_name: t.b.a. :param system_prompt: Optional. System prompt string. 
@@ -97,19 +95,8 @@ class QueryConfig(BaseConfig): else: self.number_documents = number_documents - if not history: - self.history = None - else: - if len(history) == 0: - self.history = None - else: - self.history = history - if template is None: - if self.history is None: - template = DEFAULT_PROMPT_TEMPLATE - else: - template = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE + template = DEFAULT_PROMPT_TEMPLATE self.temperature = temperature if temperature else 0 self.max_tokens = max_tokens if max_tokens else 1000 @@ -121,10 +108,7 @@ class QueryConfig(BaseConfig): if self.validate_template(template): self.template = template else: - if self.history is None: - raise ValueError("`template` should have `query` and `context` keys") - else: - raise ValueError("`template` should have `query`, `context` and `history` keys") + raise ValueError("`template` should have `query` and `context` keys and potentially `history` (if used).") if not isinstance(stream, bool): raise ValueError("`stream` should be bool") @@ -138,11 +122,13 @@ class QueryConfig(BaseConfig): :param template: the template to validate :return: Boolean, valid (true) or invalid (false) """ - if self.history is None: - return re.search(query_re, template.template) and re.search(context_re, template.template) - else: - return ( - re.search(query_re, template.template) - and re.search(context_re, template.template) - and re.search(history_re, template.template) - ) + return re.search(query_re, template.template) and re.search(context_re, template.template) + + def _validate_template_history(self, template: Template): + """ + validate the history template for history + + :param template: the template to validate + :return: Boolean, valid (true) or invalid (false) + """ + return re.search(history_re, template.template) diff --git a/embedchain/config/vectordbs/BaseVectorDbConfig.py b/embedchain/config/vectordbs/BaseVectorDbConfig.py new file mode 100644 index 00000000..025ee584 --- /dev/null +++ b/embedchain/config/vectordbs/BaseVectorDbConfig.py @@ -0,0 +1,17 @@ +from typing import Optional + +from embedchain.config.BaseConfig import BaseConfig + + +class BaseVectorDbConfig(BaseConfig): + def __init__( + self, + collection_name: Optional[str] = None, + dir: Optional[str] = None, + host: Optional[str] = None, + port: Optional[str] = None, + ): + self.collection_name = collection_name or "embedchain_store" + self.dir = dir or "db" + self.host = host + self.port = port diff --git a/embedchain/config/vectordbs/ChromaDbConfig.py b/embedchain/config/vectordbs/ChromaDbConfig.py new file mode 100644 index 00000000..5b728402 --- /dev/null +++ b/embedchain/config/vectordbs/ChromaDbConfig.py @@ -0,0 +1,21 @@ +from typing import Optional + +from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig +from embedchain.helper_classes.json_serializable import register_deserializable + + +@register_deserializable +class ChromaDbConfig(BaseVectorDbConfig): + def __init__( + self, + collection_name: Optional[str] = None, + dir: Optional[str] = None, + host: Optional[str] = None, + port: Optional[str] = None, + chroma_settings: Optional[dict] = None, + ): + """ + :param chroma_settings: Optional. Chroma settings for connection. 
+ """ + self.chroma_settings = chroma_settings + super().__init__(collection_name=collection_name, dir=dir, host=host, port=port) diff --git a/embedchain/config/vectordbs/ElasticsearchDBConfig.py b/embedchain/config/vectordbs/ElasticsearchDBConfig.py index 691bb778..41775a0b 100644 --- a/embedchain/config/vectordbs/ElasticsearchDBConfig.py +++ b/embedchain/config/vectordbs/ElasticsearchDBConfig.py @@ -1,17 +1,25 @@ -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union -from embedchain.config.BaseConfig import BaseConfig +from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig from embedchain.helper_classes.json_serializable import register_deserializable @register_deserializable -class ElasticsearchDBConfig(BaseConfig): - """ - Config to initialize an elasticsearch client. - :param es_url. elasticsearch url or list of nodes url to be used for connection - :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch. - """ - - def __init__(self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]): +class ElasticsearchDBConfig(BaseVectorDbConfig): + def __init__( + self, + collection_name: Optional[str] = None, + dir: Optional[str] = None, + es_url: Union[str, List[str]] = None, + **ES_EXTRA_PARAMS: Dict[str, any], + ): + """ + Config to initialize an elasticsearch client. + :param es_url. elasticsearch url or list of nodes url to be used for connection + :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch. + """ + # self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]): self.ES_URL = es_url self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS + + super().__init__(collection_name=collection_name, dir=dir) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 9f30b9c5..38f12d73 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -11,30 +11,37 @@ from typing import Dict, Optional import requests from dotenv import load_dotenv from langchain.docstore.document import Document -from langchain.memory import ConversationBufferMemory from tenacity import retry, stop_after_attempt, wait_fixed from embedchain.chunkers.base_chunker import BaseChunker -from embedchain.config import AddConfig, ChatConfig, QueryConfig +from embedchain.config import AddConfig, BaseLlmConfig from embedchain.config.apps.BaseAppConfig import BaseAppConfig -from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE from embedchain.data_formatter import DataFormatter +from embedchain.embedder.base_embedder import BaseEmbedder from embedchain.helper_classes.json_serializable import JSONSerializable +from embedchain.llm.base_llm import BaseLlm from embedchain.loaders.base_loader import BaseLoader from embedchain.models.data_type import DataType from embedchain.utils import detect_datatype +from embedchain.vectordb.base_vector_db import BaseVectorDB load_dotenv() ABS_PATH = os.getcwd() -DB_DIR = os.path.join(ABS_PATH, "db") HOME_DIR = str(Path.home()) CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain") CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json") class EmbedChain(JSONSerializable): - def __init__(self, config: BaseAppConfig, system_prompt: Optional[str] = None): + def __init__( + self, + config: BaseAppConfig, + llm: BaseLlm, + db: BaseVectorDB = None, + embedder: BaseEmbedder = None, + system_prompt: Optional[str] = None, + ): """ Initializes the EmbedChain instance, sets up a vector DB client and creates a collection. 
@@ -44,17 +51,40 @@ class EmbedChain(JSONSerializable): """ self.config = config - self.system_prompt = system_prompt - self.collection = self.config.db._get_or_create_collection(self.config.collection_name) - self.db = self.config.db + + # Add subclasses + ## Llm + self.llm = llm + ## Database + # Database has support for config assignment for backwards compatibility + if db is None and (not hasattr(self.config, "db") or self.config.db is None): + raise ValueError("App requires Database.") + self.db = db or self.config.db + ## Embedder + if embedder is None: + raise ValueError("App requires Embedder.") + self.embedder = embedder + + # Initialize database + self.db._set_embedder(self.embedder) + self.db._initialize() + # Set collection name from app config for backwards compatibility. + if config.collection_name: + self.db.set_collection_name(config.collection_name) + + # Add variables that are "shortcuts" + if system_prompt: + self.llm.config.system_prompt = system_prompt + + # Attributes that aren't subclass related. self.user_asks = [] - self.is_docs_site_instance = False - self.online = False - self.memory = ConversationBufferMemory() # Send anonymous telemetry self.s_id = self.config.id if self.config.id else str(uuid.uuid4()) self.u_id = self._load_or_generate_user_id() + # NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event. + # if (self.config.collect_metrics): + # raise ConnectionRefusedError("Collection of metrics should not be allowed.") thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",)) thread_telemetry.start() @@ -227,10 +257,10 @@ class EmbedChain(JSONSerializable): metadatas = new_metadatas # Count before, to calculate a delta in the end. - chunks_before_addition = self.count() + chunks_before_addition = self.db.count() self.db.add(documents=documents, metadatas=metadatas, ids=ids) - count_new_chunks = self.count() - chunks_before_addition + count_new_chunks = self.db.count() - chunks_before_addition print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}")) return list(documents), metadatas, ids, count_new_chunks @@ -244,13 +274,7 @@ class EmbedChain(JSONSerializable): ) ] - def get_llm_model_answer(self): - """ - Usually implemented by child class - """ - raise NotImplementedError - - def retrieve_from_database(self, input_query, config: QueryConfig, where=None): + def retrieve_from_database(self, input_query, config: Optional[BaseLlmConfig] = None, where=None): """ Queries the vector database based on the given input query. Gets relevant doc based on the query @@ -260,11 +284,12 @@ class EmbedChain(JSONSerializable): :param where: Optional. A dictionary of key-value pairs to filter the database results. :return: The content of the document that matched your query. 
""" + query_config = config or self.llm.config if where is not None: where = where - elif config is not None and config.where is not None: - where = config.where + elif query_config is not None and query_config.where is not None: + where = query_config.where else: where = {} @@ -273,64 +298,21 @@ class EmbedChain(JSONSerializable): contents = self.db.query( input_query=input_query, - n_results=config.number_documents, + n_results=query_config.number_documents, where=where, ) return contents - def _append_search_and_context(self, context, web_search_result): - return f"{context}\nWeb Search Result: {web_search_result}" - - def generate_prompt(self, input_query, contexts, config: QueryConfig, **kwargs): - """ - Generates a prompt based on the given query and context, ready to be - passed to an LLM - - :param input_query: The query to use. - :param contexts: List of similar documents to the query used as context. - :param config: Optional. The `QueryConfig` instance to use as - configuration options. - :return: The prompt - """ - context_string = (" | ").join(contexts) - web_search_result = kwargs.get("web_search_result", "") - if web_search_result: - context_string = self._append_search_and_context(context_string, web_search_result) - if not config.history: - prompt = config.template.substitute(context=context_string, query=input_query) - else: - prompt = config.template.substitute(context=context_string, query=input_query, history=config.history) - return prompt - - def get_answer_from_llm(self, prompt, config: ChatConfig): - """ - Gets an answer based on the given query and context by passing it - to an LLM. - - :param query: The query to use. - :param context: Similar documents to the query used as context. - :return: The answer. - """ - - return self.get_llm_model_answer(prompt, config) - - def access_search_and_get_results(self, input_query): - from langchain.tools import DuckDuckGoSearchRun - - search = DuckDuckGoSearchRun() - logging.info(f"Access search to get answers for {input_query}") - return search.run(input_query) - - def query(self, input_query, config: QueryConfig = None, dry_run=False, where=None): + def query(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None): """ Queries the vector database based on the given input query. Gets relevant doc based on the query and then passes it to an LLM as context to get the answer. :param input_query: The query to use. - :param config: Optional. The `QueryConfig` instance to use as - configuration options. + :param config: Optional. The `LlmConfig` instance to use as configuration options. + This is used for one method call. To persistently use a config, declare it during app init. :param dry_run: Optional. A dry run does everything except send the resulting prompt to the LLM. The purpose is to test the prompt, not the response. You can use it to test your prompt, including the context provided @@ -340,41 +322,16 @@ class EmbedChain(JSONSerializable): :param where: Optional. A dictionary of key-value pairs to filter the database results. :return: The answer to the query. 
""" - if config is None: - config = QueryConfig() - if self.is_docs_site_instance: - config.template = DOCS_SITE_PROMPT_TEMPLATE - config.number_documents = 5 - k = {} - if self.online: - k["web_search_result"] = self.access_search_and_get_results(input_query) - contexts = self.retrieve_from_database(input_query, config, where) - prompt = self.generate_prompt(input_query, contexts, config, **k) - logging.info(f"Prompt: {prompt}") - - if dry_run: - return prompt - - answer = self.get_answer_from_llm(prompt, config) + contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where) + answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run) # Send anonymous telemetry thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("query",)) thread_telemetry.start() - if isinstance(answer, str): - logging.info(f"Answer: {answer}") - return answer - else: - return self._stream_query_response(answer) + return answer - def _stream_query_response(self, answer): - streamed_answer = "" - for chunk in answer: - streamed_answer = streamed_answer + chunk - yield chunk - logging.info(f"Answer: {streamed_answer}") - - def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None): + def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None): """ Queries the vector database on the given input query. Gets relevant doc based on the query and then passes it to an @@ -382,8 +339,8 @@ class EmbedChain(JSONSerializable): Maintains the whole conversation in memory. :param input_query: The query to use. - :param config: Optional. The `ChatConfig` instance to use as - configuration options. + :param config: Optional. The `LlmConfig` instance to use as configuration options. + This is used for one method call. To persistently use a config, declare it during app init. :param dry_run: Optional. A dry run does everything except send the resulting prompt to the LLM. The purpose is to test the prompt, not the response. You can use it to test your prompt, including the context provided @@ -393,50 +350,14 @@ class EmbedChain(JSONSerializable): :param where: Optional. A dictionary of key-value pairs to filter the database results. :return: The answer to the query. """ - if config is None: - config = ChatConfig() - if self.is_docs_site_instance: - config.template = DOCS_SITE_PROMPT_TEMPLATE - config.number_documents = 5 - k = {} - if self.online: - k["web_search_result"] = self.access_search_and_get_results(input_query) - contexts = self.retrieve_from_database(input_query, config, where) - - chat_history = self.memory.load_memory_variables({})["history"] - - if chat_history: - config.set_history(chat_history) - - prompt = self.generate_prompt(input_query, contexts, config, **k) - logging.info(f"Prompt: {prompt}") - - if dry_run: - return prompt - - answer = self.get_answer_from_llm(prompt, config) - - self.memory.chat_memory.add_user_message(input_query) + contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where) + answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run) # Send anonymous telemetry thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("chat",)) thread_telemetry.start() - if isinstance(answer, str): - self.memory.chat_memory.add_ai_message(answer) - logging.info(f"Answer: {answer}") - return answer - else: - # this is a streamed response and needs to be handled differently. 
- return self._stream_chat_response(answer) - - def _stream_chat_response(self, answer): - streamed_answer = "" - for chunk in answer: - streamed_answer = streamed_answer + chunk - yield chunk - self.memory.chat_memory.add_ai_message(streamed_answer) - logging.info(f"Answer: {streamed_answer}") + return answer def set_collection(self, collection_name): """ @@ -444,34 +365,36 @@ class EmbedChain(JSONSerializable): :param collection_name: The name of the collection to use. """ - self.collection = self.config.db._get_or_create_collection(collection_name) + self.db.set_collection_name(collection_name) + # Create the collection if it does not exist + self.db._get_or_create_collection(collection_name) + # TODO: Check whether it is necessary to assign to the `self.collection` attribute, + # since the main purpose is the creation. def count(self) -> int: """ Count the number of embeddings. + DEPRECATED IN FAVOR OF `db.count()` + :return: The number of embeddings. """ + logging.warning("DEPRECATION WARNING: Please use `db.count()` instead of `count()`.") return self.db.count() def reset(self): """ Resets the database. Deletes all embeddings irreversibly. `App` does not have to be reinitialized after using this method. + + DEPRECATED IN FAVOR OF `db.reset()` """ # Send anonymous telemetry thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",)) thread_telemetry.start() - collection_name = self.collection.name + logging.warning("DEPRECATION WARNING: Please use `db.reset()` instead of `reset()`.") self.db.reset() - self.collection = self.config.db._get_or_create_collection(collection_name) - # Todo: Automatically recreating a collection with the same name cannot be the best way to handle a reset. - # A downside of this implementation is, if you have two instances, - # the other instance will not get the updated `self.collection` attribute. - # A better way would be to create the collection if it is called again after being reset. - # That means, checking if collection exists in the db-consuming methods, and creating it if it doesn't. - # That's an extra steps for all uses, just to satisfy a niche use case in a niche method. For now, this will do. @retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None): diff --git a/embedchain/embedder/__init__.py b/embedchain/embedder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/embedchain/embedder/base_embedder.py b/embedchain/embedder/base_embedder.py new file mode 100644 index 00000000..28614351 --- /dev/null +++ b/embedchain/embedder/base_embedder.py @@ -0,0 +1,45 @@ +from typing import Any, Callable, Optional + +from embedchain.config.embedder.BaseEmbedderConfig import BaseEmbedderConfig + +try: + from chromadb.api.types import Documents, Embeddings +except RuntimeError: + from embedchain.utils import use_pysqlite3 + + use_pysqlite3() + from chromadb.api.types import Documents, Embeddings + + +class BaseEmbedder: + """ + Class that manages everything regarding embeddings. Including embedding function, loaders and chunkers. + + Embedding functions and vector dimensions are set based on the child class you choose. + To manually overwrite you can use this classes `set_...` methods. 
+ """ + + def __init__(self, config: Optional[BaseEmbedderConfig] = FileNotFoundError): + if config is None: + self.config = BaseEmbedderConfig() + else: + self.config = config + + def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]): + if not hasattr(embedding_fn, "__call__"): + raise ValueError("Embedding function is not a function") + self.embedding_fn = embedding_fn + + def set_vector_dimension(self, vector_dimension: int): + self.vector_dimension = vector_dimension + + @staticmethod + def _langchain_default_concept(embeddings: Any): + """ + Langchains default function layout for embeddings. + """ + + def embed_function(texts: Documents) -> Embeddings: + return embeddings.embed_documents(texts) + + return embed_function diff --git a/embedchain/embedder/gpt4all_embedder.py b/embedchain/embedder/gpt4all_embedder.py new file mode 100644 index 00000000..9b06393a --- /dev/null +++ b/embedchain/embedder/gpt4all_embedder.py @@ -0,0 +1,21 @@ +from typing import Optional + +from chromadb.utils import embedding_functions + +from embedchain.config import BaseEmbedderConfig +from embedchain.embedder.base_embedder import BaseEmbedder +from embedchain.models import EmbeddingFunctions + + +class GPT4AllEmbedder(BaseEmbedder): + def __init__(self, config: Optional[BaseEmbedderConfig] = None): + # Note: We could use langchains GPT4ALL embedding, but it's not available in all versions. + super().__init__(config=config) + if self.config.model is None: + self.config.model = "all-MiniLM-L6-v2" + + embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.config.model) + self.set_embedding_fn(embedding_fn=embedding_fn) + + vector_dimension = EmbeddingFunctions.GPT4ALL.value + self.set_vector_dimension(vector_dimension=vector_dimension) diff --git a/embedchain/embedder/huggingface_embedder.py b/embedchain/embedder/huggingface_embedder.py new file mode 100644 index 00000000..9565ad54 --- /dev/null +++ b/embedchain/embedder/huggingface_embedder.py @@ -0,0 +1,19 @@ +from typing import Optional + +from langchain.embeddings import HuggingFaceEmbeddings + +from embedchain.config import BaseEmbedderConfig +from embedchain.embedder.base_embedder import BaseEmbedder +from embedchain.models import EmbeddingFunctions + + +class HuggingFaceEmbedder(BaseEmbedder): + def __init__(self, config: Optional[BaseEmbedderConfig] = None): + super().__init__(config=config) + + embeddings = HuggingFaceEmbeddings(model_name=self.config.model) + embedding_fn = BaseEmbedder._langchain_default_concept(embeddings) + self.set_embedding_fn(embedding_fn=embedding_fn) + + vector_dimension = EmbeddingFunctions.HUGGING_FACE.value + self.set_vector_dimension(vector_dimension=vector_dimension) diff --git a/embedchain/embedder/openai_embedder.py b/embedchain/embedder/openai_embedder.py new file mode 100644 index 00000000..b174c6a6 --- /dev/null +++ b/embedchain/embedder/openai_embedder.py @@ -0,0 +1,40 @@ +import os +from typing import Optional + +from langchain.embeddings import OpenAIEmbeddings + +from embedchain.config import BaseEmbedderConfig +from embedchain.embedder.base_embedder import BaseEmbedder +from embedchain.models import EmbeddingFunctions + +try: + from chromadb.utils import embedding_functions +except RuntimeError: + from embedchain.utils import use_pysqlite3 + + use_pysqlite3() + from chromadb.utils import embedding_functions + + +class OpenAiEmbedder(BaseEmbedder): + def __init__(self, config: Optional[BaseEmbedderConfig] = None): + super().__init__(config=config) + if 
self.config.model is None: + self.config.model = "text-embedding-ada-002" + + if self.config.deployment_name: + embeddings = OpenAIEmbeddings(deployment=self.config.deployment_name) + embedding_fn = BaseEmbedder._langchain_default_concept(embeddings) + else: + if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None: + raise ValueError( + "OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided" + ) # noqa:E501 + embedding_fn = embedding_functions.OpenAIEmbeddingFunction( + api_key=os.getenv("OPENAI_API_KEY"), + organization_id=os.getenv("OPENAI_ORGANIZATION"), + model_name=self.config.model, + ) + + self.set_embedding_fn(embedding_fn=embedding_fn) + self.set_vector_dimension(vector_dimension=EmbeddingFunctions.OPENAI.value) diff --git a/embedchain/embedder/vertexai_embedder.py b/embedchain/embedder/vertexai_embedder.py new file mode 100644 index 00000000..891de5bb --- /dev/null +++ b/embedchain/embedder/vertexai_embedder.py @@ -0,0 +1,19 @@ +from typing import Optional + +from langchain.embeddings import VertexAIEmbeddings + +from embedchain.config import BaseEmbedderConfig +from embedchain.embedder.base_embedder import BaseEmbedder +from embedchain.models import EmbeddingFunctions + + +class VertexAiEmbedder(BaseEmbedder): + def __init__(self, config: Optional[BaseEmbedderConfig] = None): + super().__init__(config=config) + + embeddings = VertexAIEmbeddings(model_name=config.model) + embedding_fn = BaseEmbedder._langchain_default_concept(embeddings) + self.set_embedding_fn(embedding_fn=embedding_fn) + + vector_dimension = EmbeddingFunctions.VERTEX_AI.value + self.set_vector_dimension(vector_dimension=vector_dimension) diff --git a/embedchain/llm/__init__.py b/embedchain/llm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/embedchain/llm/antrophic_llm.py b/embedchain/llm/antrophic_llm.py new file mode 100644 index 00000000..b996e2be --- /dev/null +++ b/embedchain/llm/antrophic_llm.py @@ -0,0 +1,29 @@ +import logging +from typing import Optional + +from embedchain.config import BaseLlmConfig +from embedchain.llm.base_llm import BaseLlm + +from embedchain.helper_classes.json_serializable import register_deserializable + + +@register_deserializable +class AntrophicLlm(BaseLlm): + def __init__(self, config: Optional[BaseLlmConfig] = None): + super().__init__(config=config) + + def get_llm_model_answer(self, prompt): + return AntrophicLlm._get_athrophic_answer(prompt=prompt, config=self.config) + + @staticmethod + def _get_athrophic_answer(prompt: str, config: BaseLlmConfig) -> str: + from langchain.chat_models import ChatAnthropic + + chat = ChatAnthropic(temperature=config.temperature, model=config.model) + + if config.max_tokens and config.max_tokens != 1000: + logging.warning("Config option `max_tokens` is not supported by this model.") + + messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt) + + return chat(messages).content diff --git a/embedchain/llm/azure_openai_llm.py b/embedchain/llm/azure_openai_llm.py new file mode 100644 index 00000000..4ced9ca1 --- /dev/null +++ b/embedchain/llm/azure_openai_llm.py @@ -0,0 +1,39 @@ +import logging +from typing import Optional + +from embedchain.config import BaseLlmConfig +from embedchain.llm.base_llm import BaseLlm + +from embedchain.helper_classes.json_serializable import register_deserializable + + +@register_deserializable +class AzureOpenAiLlm(BaseLlm): + def __init__(self, config: Optional[BaseLlmConfig] = None): + super().__init__(config=config) + + 
def get_llm_model_answer(self, prompt): + return AzureOpenAiLlm._get_azure_openai_answer(prompt=prompt, config=self.config) + + @staticmethod + def _get_azure_openai_answer(prompt: str, config: BaseLlmConfig) -> str: + from langchain.chat_models import AzureChatOpenAI + + if not config.deployment_name: + raise ValueError("Deployment name must be provided for Azure OpenAI") + + chat = AzureChatOpenAI( + deployment_name=config.deployment_name, + openai_api_version="2023-05-15", + model_name=config.model or "gpt-3.5-turbo", + temperature=config.temperature, + max_tokens=config.max_tokens, + streaming=config.stream, + ) + + if config.top_p and config.top_p != 1: + logging.warning("Config option `top_p` is not supported by this model.") + + messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt) + + return chat(messages).content diff --git a/embedchain/llm/base_llm.py b/embedchain/llm/base_llm.py new file mode 100644 index 00000000..4c25173d --- /dev/null +++ b/embedchain/llm/base_llm.py @@ -0,0 +1,214 @@ +import logging +from typing import List, Optional + +from langchain.memory import ConversationBufferMemory +from langchain.schema import BaseMessage +from embedchain.helper_classes.json_serializable import JSONSerializable + +from embedchain.config import BaseLlmConfig +from embedchain.config.llm.base_llm_config import ( + DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, + DOCS_SITE_PROMPT_TEMPLATE) + + +class BaseLlm(JSONSerializable): + def __init__(self, config: Optional[BaseLlmConfig] = None): + if config is None: + self.config = BaseLlmConfig() + else: + self.config = config + + self.memory = ConversationBufferMemory() + self.is_docs_site_instance = False + self.online = False + self.history: any = None + + def get_llm_model_answer(self): + """ + Usually implemented by child class + """ + raise NotImplementedError + + def set_history(self, history: any): + self.history = history + + def update_history(self): + chat_history = self.memory.load_memory_variables({})["history"] + if chat_history: + self.set_history(chat_history) + + def generate_prompt(self, input_query, contexts, **kwargs): + """ + Generates a prompt based on the given query and context, ready to be + passed to an LLM + + :param input_query: The query to use. + :param contexts: List of similar documents to the query used as context. + :param config: Optional. The `QueryConfig` instance to use as + configuration options. + :return: The prompt + """ + context_string = (" | ").join(contexts) + web_search_result = kwargs.get("web_search_result", "") + if web_search_result: + context_string = self._append_search_and_context(context_string, web_search_result) + if not self.history: + prompt = self.config.template.substitute(context=context_string, query=input_query) + else: + # check if it's the default template without history + if ( + not self.config._validate_template_history(self.config.template) + and self.config.template.template == DEFAULT_PROMPT + ): + # swap in the template with history + prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute( + context=context_string, query=input_query, history=self.history + ) + elif not self.config._validate_template_history(self.config.template): + logging.warning("Template does not include `$history` key. 
History is not included in prompt.") + prompt = self.config.template.substitute(context=context_string, query=input_query) + else: + prompt = self.config.template.substitute( + context=context_string, query=input_query, history=self.history + ) + return prompt + + def _append_search_and_context(self, context, web_search_result): + return f"{context}\nWeb Search Result: {web_search_result}" + + def get_answer_from_llm(self, prompt): + """ + Gets an answer based on the given query and context by passing it + to an LLM. + + :param query: The query to use. + :param context: Similar documents to the query used as context. + :return: The answer. + """ + + return self.get_llm_model_answer(prompt) + + def access_search_and_get_results(self, input_query): + from langchain.tools import DuckDuckGoSearchRun + + search = DuckDuckGoSearchRun() + logging.info(f"Access search to get answers for {input_query}") + return search.run(input_query) + + def _stream_query_response(self, answer): + streamed_answer = "" + for chunk in answer: + streamed_answer = streamed_answer + chunk + yield chunk + logging.info(f"Answer: {streamed_answer}") + + def _stream_chat_response(self, answer): + streamed_answer = "" + for chunk in answer: + streamed_answer = streamed_answer + chunk + yield chunk + self.memory.chat_memory.add_ai_message(streamed_answer) + logging.info(f"Answer: {streamed_answer}") + + def query(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None): + """ + Queries the vector database based on the given input query. + Gets relevant doc based on the query and then passes it to an + LLM as context to get the answer. + + :param input_query: The query to use. + :param config: Optional. The `LlmConfig` instance to use as configuration options. + This is used for one method call. To persistently use a config, declare it during app init. + :param dry_run: Optional. A dry run does everything except send the resulting prompt to + the LLM. The purpose is to test the prompt, not the response. + You can use it to test your prompt, including the context provided + by the vector database's doc retrieval. + The only thing the dry run does not consider is the cut-off due to + the `max_tokens` parameter. + :param where: Optional. A dictionary of key-value pairs to filter the database results. + :return: The answer to the query. + """ + query_config = config or self.config + + if self.is_docs_site_instance: + query_config.template = DOCS_SITE_PROMPT_TEMPLATE + query_config.number_documents = 5 + k = {} + if self.online: + k["web_search_result"] = self.access_search_and_get_results(input_query) + prompt = self.generate_prompt(input_query, contexts, **k) + logging.info(f"Prompt: {prompt}") + + if dry_run: + return prompt + + answer = self.get_answer_from_llm(prompt) + + if isinstance(answer, str): + logging.info(f"Answer: {answer}") + return answer + else: + return self._stream_query_response(answer) + + def chat(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None): + """ + Queries the vector database on the given input query. + Gets relevant doc based on the query and then passes it to an + LLM as context to get the answer. + + Maintains the whole conversation in memory. + :param input_query: The query to use. + :param config: Optional. The `LlmConfig` instance to use as configuration options. + This is used for one method call. To persistently use a config, declare it during app init. + :param dry_run: Optional. 
A dry run does everything except send the resulting prompt to + the LLM. The purpose is to test the prompt, not the response. + You can use it to test your prompt, including the context provided + by the vector database's doc retrieval. + The only thing the dry run does not consider is the cut-off due to + the `max_tokens` parameter. + :param where: Optional. A dictionary of key-value pairs to filter the database results. + :return: The answer to the query. + """ + query_config = config or self.config + + if self.is_docs_site_instance: + query_config.template = DOCS_SITE_PROMPT_TEMPLATE + query_config.number_documents = 5 + k = {} + if self.online: + k["web_search_result"] = self.access_search_and_get_results(input_query) + + self.update_history() + + prompt = self.generate_prompt(input_query, contexts, **k) + logging.info(f"Prompt: {prompt}") + + if dry_run: + return prompt + + answer = self.get_answer_from_llm(prompt) + + self.memory.chat_memory.add_user_message(input_query) + + if isinstance(answer, str): + self.memory.chat_memory.add_ai_message(answer) + logging.info(f"Answer: {answer}") + + # NOTE: Adding to history before and after. This could be seen as redundant. + # If we change it, we have to change the tests (no big deal). + self.update_history() + + return answer + else: + # this is a streamed response and needs to be handled differently. + return self._stream_chat_response(answer) + + @staticmethod + def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]: + from langchain.schema import HumanMessage, SystemMessage + + messages = [] + if system_prompt: + messages.append(SystemMessage(content=system_prompt)) + messages.append(HumanMessage(content=prompt)) + return messages diff --git a/embedchain/llm/gpt4all_llm.py b/embedchain/llm/gpt4all_llm.py new file mode 100644 index 00000000..1624ae9f --- /dev/null +++ b/embedchain/llm/gpt4all_llm.py @@ -0,0 +1,47 @@ +from typing import Iterable, Optional, Union + +from embedchain.config import BaseLlmConfig +from embedchain.llm.base_llm import BaseLlm + +from embedchain.helper_classes.json_serializable import register_deserializable + + +@register_deserializable +class GPT4ALLLlm(BaseLlm): + def __init__(self, config: Optional[BaseLlmConfig] = None): + super().__init__(config=config) + if self.config.model is None: + self.config.model = "orca-mini-3b.ggmlv3.q4_0.bin" + self.instance = GPT4ALLLlm._get_instance(self.config.model) + + def get_llm_model_answer(self, prompt): + return self._get_gpt4all_answer(prompt=prompt, config=self.config) + + @staticmethod + def _get_instance(model): + try: + from gpt4all import GPT4All + except ModuleNotFoundError: + raise ModuleNotFoundError( + "The GPT4All python package is not installed. Please install it with `pip install embedchain[opensource]`" # noqa E501 + ) from None + + return GPT4All(model_name=model) + + def _get_gpt4all_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]: + if config.model and config.model != self.config.model: + raise RuntimeError( + "OpenSourceApp does not support switching models at runtime. Please create a new app instance." 
+ ) + + if config.system_prompt: + raise ValueError("OpenSourceApp does not support `system_prompt`") + + response = self.instance.generate( + prompt=prompt, + streaming=config.stream, + top_p=config.top_p, + max_tokens=config.max_tokens, + temp=config.temperature, + ) + return response diff --git a/embedchain/llm/llama2_llm.py b/embedchain/llm/llama2_llm.py new file mode 100644 index 00000000..6a2d90a6 --- /dev/null +++ b/embedchain/llm/llama2_llm.py @@ -0,0 +1,27 @@ +import os +from typing import Optional + +from langchain.llms import Replicate + +from embedchain.config import BaseLlmConfig +from embedchain.llm.base_llm import BaseLlm + +from embedchain.helper_classes.json_serializable import register_deserializable + + +@register_deserializable +class Llama2Llm(BaseLlm): + def __init__(self, config: Optional[BaseLlmConfig] = None): + if "REPLICATE_API_TOKEN" not in os.environ: + raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.") + super().__init__(config=config) + + def get_llm_model_answer(self, prompt): + # TODO: Move the model and other inputs into config + if self.config.system_prompt: + raise ValueError("Llama2App does not support `system_prompt`") + llm = Replicate( + model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5", + input={"temperature": self.config.temperature or 0.75, "max_length": 500, "top_p": self.config.top_p}, + ) + return llm(prompt) diff --git a/embedchain/llm/openai_llm.py b/embedchain/llm/openai_llm.py new file mode 100644 index 00000000..320079f7 --- /dev/null +++ b/embedchain/llm/openai_llm.py @@ -0,0 +1,43 @@ +from typing import Optional + +import openai + +from embedchain.config import BaseLlmConfig +from embedchain.llm.base_llm import BaseLlm + +from embedchain.helper_classes.json_serializable import register_deserializable + + +@register_deserializable +class OpenAiLlm(BaseLlm): + def __init__(self, config: Optional[BaseLlmConfig] = None): + super().__init__(config=config) + + # NOTE: This class does not use langchain. One reason is that `top_p` is not supported. 
+
+    def get_llm_model_answer(self, prompt):
+        messages = []
+        if self.config.system_prompt:
+            messages.append({"role": "system", "content": self.config.system_prompt})
+        messages.append({"role": "user", "content": prompt})
+        response = openai.ChatCompletion.create(
+            model=self.config.model or "gpt-3.5-turbo-0613",
+            messages=messages,
+            temperature=self.config.temperature,
+            max_tokens=self.config.max_tokens,
+            top_p=self.config.top_p,
+            stream=self.config.stream,
+        )
+
+        if self.config.stream:
+            return self._stream_llm_model_response(response)
+        else:
+            return response["choices"][0]["message"]["content"]
+
+    def _stream_llm_model_response(self, response):
+        """
+        This is a generator for streaming response from the OpenAI completions API
+        """
+        for line in response:
+            chunk = line["choices"][0].get("delta", {}).get("content", "")
+            yield chunk
diff --git a/embedchain/llm/vertex_ai_llm.py b/embedchain/llm/vertex_ai_llm.py
new file mode 100644
index 00000000..b1d47ad6
--- /dev/null
+++ b/embedchain/llm/vertex_ai_llm.py
@@ -0,0 +1,29 @@
+import logging
+from typing import Optional
+
+from embedchain.config import BaseLlmConfig
+from embedchain.llm.base_llm import BaseLlm
+
+from embedchain.helper_classes.json_serializable import register_deserializable
+
+
+@register_deserializable
+class VertexAiLlm(BaseLlm):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config=config)
+
+    def get_llm_model_answer(self, prompt):
+        return VertexAiLlm._get_vertexai_answer(prompt=prompt, config=self.config)
+
+    @staticmethod
+    def _get_vertexai_answer(prompt: str, config: BaseLlmConfig) -> str:
+        from langchain.chat_models import ChatVertexAI
+
+        chat = ChatVertexAI(temperature=config.temperature, model=config.model)
+
+        if config.top_p and config.top_p != 1:
+            logging.warning("Config option `top_p` is not supported by this model.")
+
+        messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
+
+        return chat(messages).content
diff --git a/embedchain/vectordb/base_vector_db.py b/embedchain/vectordb/base_vector_db.py
index f9740c5f..1782faf9 100644
--- a/embedchain/vectordb/base_vector_db.py
+++ b/embedchain/vectordb/base_vector_db.py
@@ -1,11 +1,22 @@
+from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
+from embedchain.embedder.base_embedder import BaseEmbedder
 from embedchain.helper_classes.json_serializable import JSONSerializable


 class BaseVectorDB(JSONSerializable):
     """Base class for vector database."""

-    def __init__(self):
+    def __init__(self, config: BaseVectorDbConfig):
         self.client = self._get_or_create_db()
+        self.config: BaseVectorDbConfig = config
+
+    def _initialize(self):
+        """
+        This method is needed because the `embedder` attribute needs to be set externally before it can be initialized.
+
+        So it can't be done in __init__ in one step.
+ """ + raise NotImplementedError def _get_or_create_db(self): """Get or create the database.""" @@ -14,6 +25,9 @@ class BaseVectorDB(JSONSerializable): def _get_or_create_collection(self): raise NotImplementedError + def _set_embedder(self, embedder: BaseEmbedder): + self.embedder = embedder + def get(self): raise NotImplementedError @@ -28,3 +42,6 @@ class BaseVectorDB(JSONSerializable): def reset(self): raise NotImplementedError + + def set_collection_name(self, name: str): + raise NotImplementedError diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py index c019c1bb..c39c968d 100644 --- a/embedchain/vectordb/chroma_db.py +++ b/embedchain/vectordb/chroma_db.py @@ -1,53 +1,63 @@ import logging -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional -from chromadb.errors import InvalidDimensionException from langchain.docstore.document import Document +from embedchain.config import ChromaDbConfig +from embedchain.helper_classes.json_serializable import register_deserializable +from embedchain.vectordb.base_vector_db import BaseVectorDB + try: import chromadb + from chromadb.config import Settings + from chromadb.errors import InvalidDimensionException except RuntimeError: from embedchain.utils import use_pysqlite3 use_pysqlite3() import chromadb - -from chromadb.config import Settings - -from embedchain.helper_classes.json_serializable import register_deserializable -from embedchain.vectordb.base_vector_db import BaseVectorDB + from chromadb.config import Settings + from chromadb.errors import InvalidDimensionException @register_deserializable class ChromaDB(BaseVectorDB): """Vector database using ChromaDB.""" - def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None, chroma_settings={}): - self.embedding_fn = embedding_fn - - if not hasattr(embedding_fn, "__call__"): - raise ValueError("Embedding function is not a function") + def __init__(self, config: Optional[ChromaDbConfig] = None): + if config: + self.config = config + else: + self.config = ChromaDbConfig() self.settings = Settings() - for key, value in chroma_settings.items(): - if hasattr(self.settings, key): - setattr(self.settings, key, value) + if self.config.chroma_settings: + for key, value in self.config.chroma_settings.items(): + if hasattr(self.settings, key): + setattr(self.settings, key, value) - if host and port: - logging.info(f"Connecting to ChromaDB server: {host}:{port}") - self.settings.chroma_server_host = host - self.settings.chroma_server_http_port = port + if self.config.host and self.config.port: + logging.info(f"Connecting to ChromaDB server: {self.config.host}:{self.config.port}") + self.settings.chroma_server_host = self.config.host + self.settings.chroma_server_http_port = self.config.port self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI" - else: - if db_dir is None: - db_dir = "db" + if self.config.dir is None: + self.config.dir = "db" - self.settings.persist_directory = db_dir + self.settings.persist_directory = self.config.dir self.settings.is_persistent = True self.client = chromadb.Client(self.settings) - super().__init__() + super().__init__(config=self.config) + + def _initialize(self): + """ + This method is needed because `embedder` attribute needs to be set externally before it can be initialized. + """ + if not self.embedder: + raise ValueError("Embedder not set. 
Please set an embedder with `set_embedder` before initialization.") + self._get_or_create_collection(self.config.collection_name) def _get_or_create_db(self): """Get or create the database.""" @@ -55,9 +65,11 @@ class ChromaDB(BaseVectorDB): def _get_or_create_collection(self, name): """Get or create the collection.""" + if not hasattr(self, "embedder") or not self.embedder: + raise ValueError("Cannot create a Chroma database collection without an embedder.") self.collection = self.client.get_or_create_collection( name=name, - embedding_function=self.embedding_fn, + embedding_function=self.embedder.embedding_fn, ) return self.collection @@ -119,9 +131,37 @@ class ChromaDB(BaseVectorDB): contents = [result[0].page_content for result in results_formatted] return contents + def set_collection_name(self, name: str): + self.config.collection_name = name + self._get_or_create_collection(self.config.collection_name) + def count(self) -> int: + """ + Count the number of embeddings. + + :return: The number of embeddings. + """ return self.collection.count() def reset(self): + """ + Resets the database. Deletes all embeddings irreversibly. + `App` does not have to be reinitialized after using this method. + """ # Delete all data from the database - self.client.reset() + try: + self.client.reset() + except ValueError: + raise ValueError( + "For safety reasons, resetting is disabled." + 'Please enable it by including `chromadb_settings={"allow_reset": True}` in your ChromaDbConfig' + ) from None + # Recreate + self._get_or_create_collection(self.config.collection_name) + + # Todo: Automatically recreating a collection with the same name cannot be the best way to handle a reset. + # A downside of this implementation is, if you have two instances, + # the other instance will not get the updated `self.collection` attribute. + # A better way would be to create the collection if it is called again after being reset. + # That means, checking if collection exists in the db-consuming methods, and creating it if it doesn't. + # That's an extra steps for all uses, just to satisfy a niche use case in a niche method. For now, this will do. diff --git a/embedchain/vectordb/elasticsearch_db.py b/embedchain/vectordb/elasticsearch_db.py index 6c75909a..7a1d93f5 100644 --- a/embedchain/vectordb/elasticsearch_db.py +++ b/embedchain/vectordb/elasticsearch_db.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List +from typing import Any, Dict, List try: from elasticsearch import Elasticsearch @@ -10,7 +10,6 @@ except ImportError: from embedchain.config import ElasticsearchDBConfig from embedchain.helper_classes.json_serializable import register_deserializable -from embedchain.models.VectorDimensions import VectorDimensions from embedchain.vectordb.base_vector_db import BaseVectorDB @@ -18,43 +17,40 @@ from embedchain.vectordb.base_vector_db import BaseVectorDB class ElasticsearchDB(BaseVectorDB): def __init__( self, - es_config: ElasticsearchDBConfig = None, - embedding_fn: Callable[[list[str]], list[str]] = None, - vector_dim: VectorDimensions = None, - collection_name: str = None, + config: ElasticsearchDBConfig = None, + es_config: ElasticsearchDBConfig = None, # Backwards compatibility ): """ Elasticsearch as vector database :param es_config. elasticsearch database config to be used for connection :param embedding_fn: Function to generate embedding vectors. :param vector_dim: Vector dimension generated by embedding fn - :param collection_name: Optional. Collection name for the database. 
""" - if not hasattr(embedding_fn, "__call__"): - raise ValueError("Embedding function is not a function") - if es_config is None: + if config is None and es_config is None: raise ValueError("ElasticsearchDBConfig is required") - if vector_dim is None: - raise ValueError("Vector Dimension is required to refer correct index and mapping") - if collection_name is None: - raise ValueError("collection name is required. It cannot be empty") - self.embedding_fn = embedding_fn + self.config = config or es_config self.client = Elasticsearch(es_config.ES_URL, **es_config.ES_EXTRA_PARAMS) - self.vector_dim = vector_dim - self.es_index = f"{collection_name}_{self.vector_dim}" + + # Call parent init here because embedder is needed + super().__init__(config=self.config) + + def _initialize(self): + """ + This method is needed because `embedder` attribute needs to be set externally before it can be initialized. + """ index_settings = { "mappings": { "properties": { "text": {"type": "text"}, - "embeddings": {"type": "dense_vector", "index": False, "dims": self.vector_dim}, + "embeddings": {"type": "dense_vector", "index": False, "dims": self.embedder.vector_dimension}, } } } - if not self.client.indices.exists(index=self.es_index): + es_index = self._get_index() + if not self.client.indices.exists(index=es_index): # create index if not exist - print("Creating index", self.es_index, index_settings) - self.client.indices.create(index=self.es_index, body=index_settings) - super().__init__() + print("Creating index", es_index, index_settings) + self.client.indices.create(index=es_index, body=index_settings) def _get_or_create_db(self): return self.client @@ -85,17 +81,17 @@ class ElasticsearchDB(BaseVectorDB): :param ids: ids of docs """ docs = [] - embeddings = self.embedding_fn(documents) + embeddings = self.config.embedding_fn(documents) for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings): docs.append( { - "_index": self.es_index, + "_index": self._get_index(), "_id": id, "_source": {"text": text, "metadata": metadata, "embeddings": embeddings}, } ) bulk(self.client, docs) - self.client.indices.refresh(index=self.es_index) + self.client.indices.refresh(index=self._get_index()) return def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]: @@ -105,7 +101,7 @@ class ElasticsearchDB(BaseVectorDB): :param n_results: no of similar documents to fetch from database :param where: Optional. 
to filter data """ - input_query_vector = self.embedding_fn(input_query) + input_query_vector = self.config.embedding_fn(input_query) query_vector = input_query_vector[0] query = { "script_score": { @@ -120,11 +116,14 @@ class ElasticsearchDB(BaseVectorDB): app_id = where["app_id"] query["script_score"]["query"]["bool"]["must"] = [{"term": {"metadata.app_id": app_id}}] _source = ["text"] - response = self.client.search(index=self.es_index, query=query, _source=_source, size=n_results) + response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results) docs = response["hits"]["hits"] contents = [doc["_source"]["text"] for doc in docs] return contents + def set_collection_name(self, name: str): + self.config.collection_name = name + def count(self) -> int: query = {"match_all": {}} response = self.client.count(index=self.es_index, query=query) @@ -136,3 +135,8 @@ class ElasticsearchDB(BaseVectorDB): if self.client.indices.exists(index=self.es_index): # delete index in Es self.client.indices.delete(index=self.es_index) + + def _get_index(self): + # NOTE: The method is preferred to an attribute, because if collection name changes, + # it's always up-to-date. + return f"{self.config.collection_name}_{self.config.vector_dim}" diff --git a/tests/embedchain/test_embedchain.py b/tests/embedchain/test_embedchain.py index df2f1f81..86ad452e 100644 --- a/tests/embedchain/test_embedchain.py +++ b/tests/embedchain/test_embedchain.py @@ -3,8 +3,7 @@ import unittest from unittest.mock import patch from embedchain import App -from embedchain.config import AppConfig, CustomAppConfig -from embedchain.models import EmbeddingFunctions, Providers +from embedchain.config import AppConfig, ChromaDbConfig class TestChromaDbHostsLoglevel(unittest.TestCase): @@ -13,8 +12,8 @@ class TestChromaDbHostsLoglevel(unittest.TestCase): @patch("chromadb.api.models.Collection.Collection.add") @patch("chromadb.api.models.Collection.Collection.get") @patch("embedchain.embedchain.EmbedChain.retrieve_from_database") - @patch("embedchain.embedchain.EmbedChain.get_answer_from_llm") - @patch("embedchain.embedchain.EmbedChain.get_llm_model_answer") + @patch("embedchain.llm.base_llm.BaseLlm.get_answer_from_llm") + @patch("embedchain.llm.base_llm.BaseLlm.get_llm_model_answer") def test_whole_app( self, _mock_get, @@ -43,17 +42,14 @@ class TestChromaDbHostsLoglevel(unittest.TestCase): """ Test if the `App` instance is correctly reconstructed after a reset. 
""" - app = App( - CustomAppConfig( - provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True} - ) - ) + config = AppConfig(log_level="DEBUG", collect_metrics=False) + app = App(config=config, chromadb_config=ChromaDbConfig(chroma_settings={"allow_reset": True})) app.reset() # Make sure the client is still healthy app.db.client.heartbeat() # Make sure the collection exists, and can be added to - app.collection.add( + app.db.collection.add( embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]], metadatas=[ {"chapter": "3", "verse": "16"}, diff --git a/tests/helper_classes/test_json_serializable.py b/tests/helper_classes/test_json_serializable.py index aaa6705d..36b0f045 100644 --- a/tests/helper_classes/test_json_serializable.py +++ b/tests/helper_classes/test_json_serializable.py @@ -59,7 +59,7 @@ class TestJsonSerializable(unittest.TestCase): def test_recursive(self): """Test recursiveness with the real app""" random_id = str(random.random()) - config = AppConfig(id=random_id) + config = AppConfig(id=random_id, collect_metrics=False) # config class is set under app.config. app = App(config=config) # w/o recursion it would just be @@ -67,4 +67,5 @@ class TestJsonSerializable(unittest.TestCase): new_app: App = App.deserialize(s) # The id of the new app is the same as the first one. self.assertEqual(random_id, new_app.config.id) + # We have proven that a nested class (app.config) can be serialized and deserialized just the same. # TODO: test deeper recursion diff --git a/tests/embedchain/test_chat.py b/tests/llm/test_chat.py similarity index 62% rename from tests/embedchain/test_chat.py rename to tests/llm/test_chat.py index 874f698f..cb50a895 100644 --- a/tests/embedchain/test_chat.py +++ b/tests/llm/test_chat.py @@ -1,9 +1,11 @@ + import os import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import patch, MagicMock from embedchain import App -from embedchain.config import AppConfig, ChatConfig +from embedchain.config import AppConfig, BaseLlmConfig +from embedchain.llm.base_llm import BaseLlm class TestApp(unittest.TestCase): @@ -12,7 +14,7 @@ class TestApp(unittest.TestCase): self.app = App(config=AppConfig(collect_metrics=False)) @patch.object(App, "retrieve_from_database", return_value=["Test context"]) - @patch.object(App, "get_answer_from_llm", return_value="Test answer") + @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer") def test_chat_with_memory(self, mock_get_answer, mock_retrieve): """ This test checks the functionality of the 'chat' method in the App class with respect to the chat history @@ -28,13 +30,36 @@ class TestApp(unittest.TestCase): The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database', 'get_answer_from_llm' and 'memory' methods. 
""" - app = App() + config = AppConfig(collect_metrics=False) + app = App(config=config) first_answer = app.chat("Test query 1") self.assertEqual(first_answer, "Test answer") - self.assertEqual(len(app.memory.chat_memory.messages), 2) + self.assertEqual(len(app.llm.memory.chat_memory.messages), 2) + self.assertEqual(len(app.llm.history.splitlines()), 2) second_answer = app.chat("Test query 2") self.assertEqual(second_answer, "Test answer") - self.assertEqual(len(app.memory.chat_memory.messages), 4) + self.assertEqual(len(app.llm.memory.chat_memory.messages), 4) + self.assertEqual(len(app.llm.history.splitlines()), 4) + + @patch.object(App, "retrieve_from_database", return_value=["Test context"]) + @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer") + def test_template_replacement(self, mock_get_answer, mock_retrieve): + """ + Tests that if a default template is used and it doesn't contain history, + the default template is swapped in. + + Also tests that a dry run does not change the history + """ + config = AppConfig(collect_metrics=False) + app = App(config=config) + first_answer = app.chat("Test query 1") + self.assertEqual(first_answer, "Test answer") + self.assertEqual(len(app.llm.history.splitlines()), 2) + history = app.llm.history + dry_run = app.chat("Test query 2", dry_run=True) + self.assertIn("History:", dry_run) + self.assertEqual(history, app.llm.history) + self.assertEqual(len(app.llm.history.splitlines()), 2) @patch("chromadb.api.models.Collection.Collection.add", MagicMock) def test_chat_with_where_in_params(self): @@ -57,13 +82,14 @@ class TestApp(unittest.TestCase): """ with patch.object(self.app, "retrieve_from_database") as mock_retrieve: mock_retrieve.return_value = ["Test context"] - with patch.object(self.app, "get_llm_model_answer") as mock_answer: + with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer: mock_answer.return_value = "Test answer" - answer = self.app.chat("Test chat", where={"attribute": "value"}) + answer = self.app.chat("Test query", where={"attribute": "value"}) self.assertEqual(answer, "Test answer") - self.assertEqual(mock_retrieve.call_args[0][0], "Test chat") - self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"}) + _args, kwargs = mock_retrieve.call_args + self.assertEqual(kwargs.get('input_query'), "Test query") + self.assertEqual(kwargs.get('where'), {"attribute": "value"}) mock_answer.assert_called_once() @patch("chromadb.api.models.Collection.Collection.add", MagicMock) @@ -85,15 +111,15 @@ class TestApp(unittest.TestCase): The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and 'get_llm_model_answer' methods. 
""" - with patch.object(self.app, "retrieve_from_database") as mock_retrieve: - mock_retrieve.return_value = ["Test context"] - with patch.object(self.app, "get_llm_model_answer") as mock_answer: - mock_answer.return_value = "Test answer" - chatConfig = ChatConfig(where={"attribute": "value"}) - answer = self.app.chat("Test chat", chatConfig) + with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer: + mock_answer.return_value = "Test answer" + with patch.object(self.app.db, "query") as mock_database_query: + mock_database_query.return_value = ["Test context"] + queryConfig = BaseLlmConfig(where={"attribute": "value"}) + answer = self.app.chat("Test query", queryConfig) self.assertEqual(answer, "Test answer") - self.assertEqual(mock_retrieve.call_args[0][0], "Test chat") - self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"}) - self.assertIsInstance(mock_retrieve.call_args[0][1], ChatConfig) + _args, kwargs = mock_database_query.call_args + self.assertEqual(kwargs.get('input_query'), "Test query") + self.assertEqual(kwargs.get('where'), {"attribute": "value"}) mock_answer.assert_called_once() diff --git a/tests/embedchain/test_generate_prompt.py b/tests/llm/test_generate_prompt.py similarity index 79% rename from tests/embedchain/test_generate_prompt.py rename to tests/llm/test_generate_prompt.py index bc26d2d5..13b664ec 100644 --- a/tests/embedchain/test_generate_prompt.py +++ b/tests/llm/test_generate_prompt.py @@ -2,7 +2,7 @@ import unittest from string import Template from embedchain import App -from embedchain.config import AppConfig, QueryConfig +from embedchain.config import AppConfig, BaseLlmConfig class TestGeneratePrompt(unittest.TestCase): @@ -23,10 +23,11 @@ class TestGeneratePrompt(unittest.TestCase): input_query = "Test query" contexts = ["Context 1", "Context 2", "Context 3"] template = "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:" - config = QueryConfig(template=Template(template)) + config = BaseLlmConfig(template=Template(template)) + self.app.llm.config = config # Execute - result = self.app.generate_prompt(input_query, contexts, config) + result = self.app.llm.generate_prompt(input_query, contexts) # Assert expected_result = ( @@ -45,10 +46,11 @@ class TestGeneratePrompt(unittest.TestCase): # Setup input_query = "Test query" contexts = ["Context 1", "Context 2", "Context 3"] - config = QueryConfig() + config = BaseLlmConfig() # Execute - result = self.app.generate_prompt(input_query, contexts, config) + self.app.llm.config = config + result = self.app.llm.generate_prompt(input_query, contexts) # Assert expected_result = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=input_query) @@ -58,9 +60,11 @@ class TestGeneratePrompt(unittest.TestCase): """ Test the 'generate_prompt' method with QueryConfig containing a history attribute. 
""" - config = QueryConfig(history=["Past context 1", "Past context 2"]) + config = BaseLlmConfig() config.template = Template("Context: $context | Query: $query | History: $history") - prompt = self.app.generate_prompt("Test query", ["Test context"], config) + self.app.llm.config = config + self.app.llm.set_history(["Past context 1", "Past context 2"]) + prompt = self.app.llm.generate_prompt("Test query", ["Test context"]) expected_prompt = "Context: Test context | Query: Test query | History: ['Past context 1', 'Past context 2']" self.assertEqual(prompt, expected_prompt) diff --git a/tests/embedchain/test_query.py b/tests/llm/test_query.py similarity index 69% rename from tests/embedchain/test_query.py rename to tests/llm/test_query.py index 521686f6..55bbb766 100644 --- a/tests/embedchain/test_query.py +++ b/tests/llm/test_query.py @@ -3,7 +3,7 @@ import unittest from unittest.mock import MagicMock, patch from embedchain import App -from embedchain.config import AppConfig, QueryConfig +from embedchain.config import AppConfig, BaseLlmConfig class TestApp(unittest.TestCase): @@ -33,29 +33,35 @@ class TestApp(unittest.TestCase): """ with patch.object(self.app, "retrieve_from_database") as mock_retrieve: mock_retrieve.return_value = ["Test context"] - with patch.object(self.app, "get_llm_model_answer") as mock_answer: + with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer: mock_answer.return_value = "Test answer" - answer = self.app.query("Test query") + _answer = self.app.query(input_query="Test query") - self.assertEqual(answer, "Test answer") - self.assertEqual(mock_retrieve.call_args[0][0], "Test query") - self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig) + # Ensure retrieve_from_database was called + mock_retrieve.assert_called_once() + + # Check the call arguments + args, kwargs = mock_retrieve.call_args + input_query_arg = kwargs.get("input_query") + self.assertEqual(input_query_arg, "Test query") mock_answer.assert_called_once() @patch("openai.ChatCompletion.create") def test_query_config_app_passing(self, mock_create): mock_create.return_value = {"choices": [{"message": {"content": "response"}}]} # Mock response - config = AppConfig() - chat_config = QueryConfig(system_prompt="Test system prompt") - app = App(config=config) + config = AppConfig(collect_metrics=False) + chat_config = BaseLlmConfig(system_prompt="Test system prompt") + app = App(config=config, llm_config=chat_config) - app.get_llm_model_answer("Test query", chat_config) + app.llm.get_llm_model_answer("Test query") # Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument messages_arg = mock_create.call_args.kwargs["messages"] - self.assertEqual(messages_arg[0]["role"], "system") - self.assertEqual(messages_arg[0]["content"], "Test system prompt") + self.assertTrue(messages_arg[0].get("role"), "system") + self.assertEqual(messages_arg[0].get("content"), "Test system prompt") + self.assertTrue(messages_arg[1].get("role"), "user") + self.assertEqual(messages_arg[1].get("content"), "Test query") # TODO: Add tests for other config variables @@ -63,16 +69,18 @@ class TestApp(unittest.TestCase): def test_app_passing(self, mock_create): mock_create.return_value = {"choices": [{"message": {"content": "response"}}]} # Mock response - config = AppConfig() - chat_config = QueryConfig() - app = App(config=config, system_prompt="Test system prompt") + config = AppConfig(collect_metrics=False) + chat_config = BaseLlmConfig() + app = App(config=config, 
llm_config=chat_config, system_prompt="Test system prompt") - app.get_llm_model_answer("Test query", chat_config) + self.assertEqual(app.llm.config.system_prompt, "Test system prompt") + + app.llm.get_llm_model_answer("Test query") # Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument messages_arg = mock_create.call_args.kwargs["messages"] - self.assertEqual(messages_arg[0]["role"], "system") - self.assertEqual(messages_arg[0]["content"], "Test system prompt") + self.assertTrue(messages_arg[0].get("role"), "system") + self.assertEqual(messages_arg[0].get("content"), "Test system prompt") @patch("chromadb.api.models.Collection.Collection.add", MagicMock) def test_query_with_where_in_params(self): @@ -95,13 +103,14 @@ class TestApp(unittest.TestCase): """ with patch.object(self.app, "retrieve_from_database") as mock_retrieve: mock_retrieve.return_value = ["Test context"] - with patch.object(self.app, "get_llm_model_answer") as mock_answer: + with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer: mock_answer.return_value = "Test answer" answer = self.app.query("Test query", where={"attribute": "value"}) self.assertEqual(answer, "Test answer") - self.assertEqual(mock_retrieve.call_args[0][0], "Test query") - self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"}) + _args, kwargs = mock_retrieve.call_args + self.assertEqual(kwargs.get('input_query'), "Test query") + self.assertEqual(kwargs.get('where'), {"attribute": "value"}) mock_answer.assert_called_once() @patch("chromadb.api.models.Collection.Collection.add", MagicMock) @@ -123,15 +132,16 @@ class TestApp(unittest.TestCase): The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and 'get_llm_model_answer' methods. 
""" - with patch.object(self.app, "retrieve_from_database") as mock_retrieve: - mock_retrieve.return_value = ["Test context"] - with patch.object(self.app, "get_llm_model_answer") as mock_answer: - mock_answer.return_value = "Test answer" - queryConfig = QueryConfig(where={"attribute": "value"}) + + with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer: + mock_answer.return_value = "Test answer" + with patch.object(self.app.db, "query") as mock_database_query: + mock_database_query.return_value = ["Test context"] + queryConfig = BaseLlmConfig(where={"attribute": "value"}) answer = self.app.query("Test query", queryConfig) self.assertEqual(answer, "Test answer") - self.assertEqual(mock_retrieve.call_args[0][0], "Test query") - self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"}) - self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig) + _args, kwargs = mock_database_query.call_args + self.assertEqual(kwargs.get('input_query'), "Test query") + self.assertEqual(kwargs.get('where'), {"attribute": "value"}) mock_answer.assert_called_once() diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index d1e95d4c..3188289b 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -3,8 +3,10 @@ import unittest from unittest.mock import patch +from chromadb.config import Settings + from embedchain import App -from embedchain.config import AppConfig, CustomAppConfig +from embedchain.config import AppConfig, ChromaDbConfig from embedchain.models import EmbeddingFunctions, Providers from embedchain.vectordb.chroma_db import ChromaDB @@ -16,8 +18,9 @@ class TestChromaDbHosts(unittest.TestCase): """ host = "test-host" port = "1234" + config = ChromaDbConfig(host=host, port=port) - db = ChromaDB(host=host, port=port, embedding_fn=len) + db = ChromaDB(config=config) settings = db.client.get_settings() self.assertEqual(settings.chroma_server_host, host) self.assertEqual(settings.chroma_server_http_port, port) @@ -31,7 +34,8 @@ class TestChromaDbHosts(unittest.TestCase): "chroma_client_auth_credentials": "admin:admin", } - db = ChromaDB(host=host, port=port, embedding_fn=len, chroma_settings=chroma_auth_settings) + config = ChromaDbConfig(host=host, port=port, chroma_settings=chroma_auth_settings) + db = ChromaDB(config=config) settings = db.client.get_settings() self.assertEqual(settings.chroma_server_host, host) self.assertEqual(settings.chroma_server_http_port, port) @@ -44,37 +48,41 @@ class TestChromaDbHosts(unittest.TestCase): # Review this test class TestChromaDbHostsInit(unittest.TestCase): @patch("embedchain.vectordb.chroma_db.chromadb.Client") - def test_init_with_host_and_port(self, mock_client): + def test_app_init_with_host_and_port(self, mock_client): """ Test if the `App` instance is initialized with the correct host and port values. 
""" host = "test-host" port = "1234" - config = AppConfig(host=host, port=port, collect_metrics=False) + config = AppConfig(collect_metrics=False) + chromadb_config = ChromaDbConfig(host=host, port=port) - _app = App(config) + _app = App(config, chromadb_config=chromadb_config) - # self.assertEqual(mock_client.call_args[0][0].chroma_server_host, host) - # self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, port) + called_settings: Settings = mock_client.call_args[0][0] + + self.assertEqual(called_settings.chroma_server_host, host) + self.assertEqual(called_settings.chroma_server_http_port, port) class TestChromaDbHostsNone(unittest.TestCase): @patch("embedchain.vectordb.chroma_db.chromadb.Client") - def test_init_with_host_and_port(self, mock_client): + def test_init_with_host_and_port_none(self, mock_client): """ Test if the `App` instance is initialized without default hosts and ports. """ _app = App(config=AppConfig(collect_metrics=False)) - self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None) - self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None) + called_settings: Settings = mock_client.call_args[0][0] + self.assertEqual(called_settings.chroma_server_host, None) + self.assertEqual(called_settings.chroma_server_http_port, None) class TestChromaDbHostsLoglevel(unittest.TestCase): @patch("embedchain.vectordb.chroma_db.chromadb.Client") - def test_init_with_host_and_port(self, mock_client): + def test_init_with_host_and_port_log_level(self, mock_client): """ Test if the `App` instance is initialized without a config that does not contain default hosts and ports. """ @@ -87,11 +95,10 @@ class TestChromaDbHostsLoglevel(unittest.TestCase): class TestChromaDbDuplicateHandling: - app_with_settings = App( - CustomAppConfig( - provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True} - ) - ) + chroma_settings = {"allow_reset": True} + chroma_config = ChromaDbConfig(chroma_settings=chroma_settings) + app_config = AppConfig(collection_name=False, collect_metrics=False) + app_with_settings = App(config=app_config, chromadb_config=chroma_config) def test_duplicates_throw_warning(self, caplog): """ @@ -101,8 +108,8 @@ class TestChromaDbDuplicateHandling: self.app_with_settings.reset() app = App(config=AppConfig(collect_metrics=False)) - app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) - app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) + app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) + app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) assert "Insert of existing embedding ID: 0" in caplog.text assert "Add of existing embedding ID: 0" in caplog.text @@ -117,19 +124,18 @@ class TestChromaDbDuplicateHandling: app = App(config=AppConfig(collect_metrics=False)) app.set_collection("test_collection_1") - app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) + app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) app.set_collection("test_collection_2") - app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) + app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) assert "Insert of existing embedding ID: 0" not in caplog.text # not assert "Add of existing embedding ID: 0" not in caplog.text # not class TestChromaDbCollection(unittest.TestCase): - app_with_settings = App( - CustomAppConfig( - provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True} - ) - ) + chroma_settings = {"allow_reset": True} + chroma_config = 
ChromaDbConfig(chroma_settings=chroma_settings) + app_config = AppConfig(collection_name=False, collect_metrics=False) + app_with_settings = App(config=app_config, chromadb_config=chroma_config) def test_init_with_default_collection(self): """ @@ -137,16 +143,17 @@ class TestChromaDbCollection(unittest.TestCase): """ app = App(config=AppConfig(collect_metrics=False)) - self.assertEqual(app.collection.name, "embedchain_store") + self.assertEqual(app.db.collection.name, "embedchain_store") def test_init_with_custom_collection(self): """ Test if the `App` instance is initialized with the correct custom collection name. """ - config = AppConfig(collection_name="test_collection", collect_metrics=False) - app = App(config) + config = AppConfig(collect_metrics=False) + app = App(config=config) + app.set_collection(collection_name="test_collection") - self.assertEqual(app.collection.name, "test_collection") + self.assertEqual(app.db.collection.name, "test_collection") def test_set_collection(self): """ @@ -155,7 +162,7 @@ class TestChromaDbCollection(unittest.TestCase): app = App(config=AppConfig(collect_metrics=False)) app.set_collection("test_collection") - self.assertEqual(app.collection.name, "test_collection") + self.assertEqual(app.db.collection.name, "test_collection") def test_changes_encapsulated(self): """ @@ -169,7 +176,7 @@ class TestChromaDbCollection(unittest.TestCase): # Collection should be empty when created self.assertEqual(app.count(), 0) - app.collection.add(embeddings=[0, 0, 0], ids=["0"]) + app.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) # After adding, should contain one item self.assertEqual(app.count(), 1) @@ -178,7 +185,7 @@ class TestChromaDbCollection(unittest.TestCase): self.assertEqual(app.count(), 0) # Adding to new collection should not effect existing collection - app.collection.add(embeddings=[0, 0, 0], ids=["0"]) + app.db.collection.add(embeddings=[0, 0, 0], ids=["0"]) app.set_collection("test_collection_1") # Should still be 1, not 2. self.assertEqual(app.count(), 1) @@ -192,7 +199,7 @@ class TestChromaDbCollection(unittest.TestCase): app = App(config=AppConfig(collect_metrics=False)) app.set_collection("test_collection_1") - app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) + app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) del app app = App(config=AppConfig(collect_metrics=False)) @@ -213,13 +220,13 @@ class TestChromaDbCollection(unittest.TestCase): app2 = App(AppConfig(collection_name="test_collection_2", collect_metrics=False)) # app2 has been created last, but adding to app1 will still write to collection 1. 
-        app1.collection.add(embeddings=[0, 0, 0], ids=["0"])
-        self.assertEqual(app1.count(), 1)
-        self.assertEqual(app2.count(), 0)
+        app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
+        self.assertEqual(app1.db.count(), 1)
+        self.assertEqual(app2.db.count(), 0)
 
         # Add data
-        app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
-        app2.collection.add(embeddings=[0, 0, 0], ids=["0"])
+        app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
+        app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
 
         # Swap names and test
         app1.set_collection("test_collection_2")
@@ -235,12 +242,14 @@ class TestChromaDbCollection(unittest.TestCase):
         self.app_with_settings.reset()
 
         # Create two apps
-        app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
-        app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
+        app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
+        app1.set_collection("one_collection")
+        app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
+        app2.set_collection("one_collection")
 
         # Add data
-        app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
-        app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
+        app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
+        app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
 
         # Both should have the same collection
         self.assertEqual(app1.count(), 3)
@@ -255,25 +264,20 @@ class TestChromaDbCollection(unittest.TestCase):
 
         # Create four apps.
        # app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
-        app1 = App(
-            CustomAppConfig(
-                collection_name="one_collection",
-                id="new_app_id_1",
-                collect_metrics=False,
-                provider=Providers.OPENAI,
-                embedding_fn=EmbeddingFunctions.OPENAI,
-                chroma_settings={"allow_reset": True},
-            )
-        )
-        app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
-        app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False))
-        app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))
+        app1 = App(AppConfig(id="new_app_id_1", collect_metrics=False), chromadb_config=self.chroma_config)
+        app1.set_collection("one_collection")
+        app2 = App(AppConfig(id="new_app_id_2", collect_metrics=False))
+        app2.set_collection("one_collection")
+        app3 = App(AppConfig(id="new_app_id_1", collect_metrics=False))
+        app3.set_collection("three_collection")
+        app4 = App(AppConfig(id="new_app_id_4", collect_metrics=False))
+        app4.set_collection("four_collection")
 
         # Each one of them get data
-        app1.collection.add(embeddings=[0, 0, 0], ids=["1"])
-        app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
-        app3.collection.add(embeddings=[0, 0, 0], ids=["3"])
-        app4.collection.add(embeddings=[0, 0, 0], ids=["4"])
+        app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
+        app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
+        app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
+        app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
 
         # Resetting the first one should reset them all.
         app1.reset()
diff --git a/tests/vectordb/test_elasticsearch_db.py b/tests/vectordb/test_elasticsearch_db.py
index 4f316eae..ed75030b 100644
--- a/tests/vectordb/test_elasticsearch_db.py
+++ b/tests/vectordb/test_elasticsearch_db.py
@@ -1,7 +1,7 @@
 import unittest
-from unittest.mock import Mock
 
 from embedchain.config import ElasticsearchDBConfig
+from embedchain.embedder.base_embedder import BaseEmbedder
 from embedchain.vectordb.elasticsearch_db import ElasticsearchDB
 
 
@@ -10,24 +10,20 @@ class TestEsDB(unittest.TestCase):
         self.es_config = ElasticsearchDBConfig()
         self.vector_dim = 384
 
-    def test_init_with_invalid_embedding_fn(self):
-        # Test if an exception is raised when an invalid embedding_fn is provided
-        with self.assertRaises(ValueError):
-            ElasticsearchDB(embedding_fn=None)
-
     def test_init_with_invalid_es_config(self):
         # Test if an exception is raised when an invalid es_config is provided
         with self.assertRaises(ValueError):
-            ElasticsearchDB(embedding_fn=Mock(), es_config=None)
+            ElasticsearchDB(es_config=None)
 
     def test_init_with_invalid_vector_dim(self):
         # Test if an exception is raised when an invalid vector_dim is provided
+        embedder = BaseEmbedder()
+        embedder.set_vector_dimension(None)
         with self.assertRaises(ValueError):
-            ElasticsearchDB(embedding_fn=Mock(), es_config=self.es_config, vector_dim=None)
+            ElasticsearchDB(es_config=self.es_config)
 
     def test_init_with_invalid_collection_name(self):
         # Test if an exception is raised when an invalid collection_name is provided
+        self.es_config.collection_name = None
        with self.assertRaises(ValueError):
-            ElasticsearchDB(
-                embedding_fn=Mock(), es_config=self.es_config, vector_dim=self.vector_dim, collection_name=None
-            )
+            ElasticsearchDB(es_config=self.es_config)
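For orientation, here is a minimal usage sketch of the config-driven wiring that the Elasticsearch tests above exercise. It is not part of the patch; the `es_url` value and the dimension `384` are illustrative assumptions, and only constructors and setters that already appear in this diff are used:

```python
from embedchain.config import ElasticsearchDBConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.vectordb.elasticsearch_db import ElasticsearchDB

# Each component is configured through its own config object.
es_config = ElasticsearchDBConfig(es_url="http://localhost:9200")  # assumed URL
db = ElasticsearchDB(es_config=es_config)  # passing es_config=None raises ValueError

# The embedder carries the vector dimension separately from the database.
embedder = BaseEmbedder()
embedder.set_vector_dimension(384)  # assumed dimension for this sketch
```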