refactor: classes and configs (#528)

cachho
2023-09-05 10:12:58 +02:00
committed by GitHub
parent 387b042a49
commit 344e7470f6
50 changed files with 1221 additions and 997 deletions

View File

@@ -8,3 +8,5 @@ from embedchain.apps.Llama2App import Llama2App # noqa: F401
from embedchain.apps.OpenSourceApp import OpenSourceApp # noqa: F401
from embedchain.apps.PersonApp import (PersonApp, # noqa: F401
PersonOpenSourceApp)
from embedchain.vectordb.chroma_db import ChromaDB # noqa: F401
from embedchain.vectordb.elasticsearch_db import ElasticsearchDB # noqa: F401

View File

@@ -1,10 +1,12 @@
from typing import Optional
import openai
from embedchain.config import AppConfig, ChatConfig
from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
ChromaDbConfig)
from embedchain.embedchain import EmbedChain
from embedchain.embedder.openai_embedder import OpenAiEmbedder
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.llm.openai_llm import OpenAiLlm
from embedchain.vectordb.chroma_db import ChromaDB
@register_deserializable
@@ -18,7 +20,13 @@ class App(EmbedChain):
dry_run(query): test your prompt without consuming tokens.
"""
def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None):
def __init__(
self,
config: AppConfig = None,
llm_config: BaseLlmConfig = None,
chromadb_config: Optional[ChromaDbConfig] = None,
system_prompt: Optional[str] = None,
):
"""
:param config: AppConfig instance to load as configuration. Optional.
:param system_prompt: System prompt string. Optional.
@@ -26,38 +34,8 @@ class App(EmbedChain):
if config is None:
config = AppConfig()
super().__init__(config, system_prompt)
llm = OpenAiLlm(config=llm_config)
embedder = OpenAiEmbedder(config=BaseEmbedderConfig(model="text-embedding-ada-002"))
database = ChromaDB(config=chromadb_config)
def get_llm_model_answer(self, prompt, config: ChatConfig):
messages = []
system_prompt = (
self.system_prompt
if self.system_prompt is not None
else config.system_prompt
if config.system_prompt is not None
else None
)
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
response = openai.ChatCompletion.create(
model=config.model or "gpt-3.5-turbo-0613",
messages=messages,
temperature=config.temperature,
max_tokens=config.max_tokens,
top_p=config.top_p,
stream=config.stream,
)
if config.stream:
return self._stream_llm_model_response(response)
else:
return response["choices"][0]["message"]["content"]
def _stream_llm_model_response(self, response):
"""
This is a generator for streaming response from the OpenAI completions API
"""
for line in response:
chunk = line["choices"][0].get("delta", {}).get("content", "")
yield chunk
super().__init__(config, llm, db=database, embedder=embedder, system_prompt=system_prompt)
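For reference, a minimal sketch of how the refactored App now composes its parts, based on the new signature above. The config values are illustrative and an OPENAI_API_KEY is assumed to be set:

from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig, ChromaDbConfig

app = App(
    config=AppConfig(log_level="WARNING"),
    llm_config=BaseLlmConfig(temperature=0, max_tokens=1000),
    chromadb_config=ChromaDbConfig(collection_name="demo", dir="db"),
)
app.add("https://example.org")  # illustrative source
answer = app.query("What is this page about?")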

View File

@@ -1,12 +1,11 @@
import logging
from typing import List, Optional
from typing import Optional
from langchain.schema import BaseMessage
from embedchain.config import ChatConfig, CustomAppConfig
from embedchain.config import CustomAppConfig
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.models import Providers
from embedchain.llm.base_llm import BaseLlm
from embedchain.vectordb.base_vector_db import BaseVectorDB
@register_deserializable
@@ -20,143 +19,49 @@ class CustomApp(EmbedChain):
dry_run(query): test your prompt without consuming tokens.
"""
def __init__(self, config: CustomAppConfig = None, system_prompt: Optional[str] = None):
def __init__(
self,
config: CustomAppConfig = None,
llm: BaseLlm = None,
db: BaseVectorDB = None,
embedder: BaseEmbedder = None,
system_prompt: Optional[str] = None,
):
"""
:param config: Optional. `CustomAppConfig` instance to load as configuration.
:raises ValueError: Config must be provided for custom app
:param system_prompt: Optional. System prompt string.
"""
# Config is not required, it has a default
if config is None:
raise ValueError("Config must be provided for custom app")
config = CustomAppConfig()
self.provider = config.provider
if llm is None:
raise ValueError("LLM must be provided for custom app. Please import from `embedchain.llm`.")
if db is None:
raise ValueError("Database must be provided for custom app. Please import from `embedchain.vectordb`.")
if embedder is None:
raise ValueError("Embedder must be provided for custom app. Please import from `embedchain.embedder`.")
if config.provider == Providers.GPT4ALL:
from embedchain import OpenSourceApp
# Because these models run locally, they should have an instance running when the custom app is created
self.open_source_app = OpenSourceApp(config=config.open_source_app_config)
super().__init__(config, system_prompt)
def set_llm_model(self, provider: Providers):
self.provider = provider
if provider == Providers.GPT4ALL:
raise ValueError(
"GPT4ALL needs to be instantiated with the model known, please create a new app instance instead"
if not isinstance(config, CustomAppConfig):
raise TypeError(
"Config is not a `CustomAppConfig` instance. "
"Please make sure the type is right and that you are passing an instance."
)
if not isinstance(llm, BaseLlm):
raise TypeError(
"LLM is not a `BaseLlm` instance. "
"Please make sure the type is right and that you are passing an instance."
)
if not isinstance(db, BaseVectorDB):
raise TypeError(
"Database is not a `BaseVectorDB` instance. "
"Please make sure the type is right and that you are passing an instance."
)
if not isinstance(embedder, BaseEmbedder):
raise TypeError(
"Embedder is not a `BaseEmbedder` instance. "
"Please make sure the type is right and that you are passing an instance."
)
def get_llm_model_answer(self, prompt, config: ChatConfig):
# TODO: Quitting the streaming response here for now.
# Idea: https://gist.github.com/jvelezmagic/03ddf4c452d011aae36b2a0f73d72f68
if config.stream:
raise NotImplementedError(
"Streaming responses have not been implemented for this model yet. Please disable."
)
if config.system_prompt is None and self.system_prompt is not None:
config.system_prompt = self.system_prompt
try:
if self.provider == Providers.OPENAI:
return CustomApp._get_openai_answer(prompt, config)
if self.provider == Providers.ANTHROPHIC:
return CustomApp._get_athrophic_answer(prompt, config)
if self.provider == Providers.VERTEX_AI:
return CustomApp._get_vertex_answer(prompt, config)
if self.provider == Providers.GPT4ALL:
return self.open_source_app._get_gpt4all_answer(prompt, config)
if self.provider == Providers.AZURE_OPENAI:
return CustomApp._get_azure_openai_answer(prompt, config)
except ImportError as e:
raise ModuleNotFoundError(e.msg) from None
@staticmethod
def _get_openai_answer(prompt: str, config: ChatConfig) -> str:
from langchain.chat_models import ChatOpenAI
chat = ChatOpenAI(
temperature=config.temperature,
model=config.model or "gpt-3.5-turbo",
max_tokens=config.max_tokens,
streaming=config.stream,
)
if config.top_p and config.top_p != 1:
logging.warning("Config option `top_p` is not supported by this model.")
messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
return chat(messages).content
@staticmethod
def _get_athrophic_answer(prompt: str, config: ChatConfig) -> str:
from langchain.chat_models import ChatAnthropic
chat = ChatAnthropic(temperature=config.temperature, model=config.model)
if config.max_tokens and config.max_tokens != 1000:
logging.warning("Config option `max_tokens` is not supported by this model.")
messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
return chat(messages).content
@staticmethod
def _get_vertex_answer(prompt: str, config: ChatConfig) -> str:
from langchain.chat_models import ChatVertexAI
chat = ChatVertexAI(temperature=config.temperature, model=config.model, max_output_tokens=config.max_tokens)
if config.top_p and config.top_p != 1:
logging.warning("Config option `top_p` is not supported by this model.")
messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
return chat(messages).content
@staticmethod
def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
from langchain.chat_models import AzureChatOpenAI
if not config.deployment_name:
raise ValueError("Deployment name must be provided for Azure OpenAI")
chat = AzureChatOpenAI(
deployment_name=config.deployment_name,
openai_api_version="2023-05-15",
model_name=config.model or "gpt-3.5-turbo",
temperature=config.temperature,
max_tokens=config.max_tokens,
streaming=config.stream,
)
if config.top_p and config.top_p != 1:
logging.warning("Config option `top_p` is not supported by this model.")
messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
return chat(messages).content
@staticmethod
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
from langchain.schema import HumanMessage, SystemMessage
messages = []
if system_prompt:
messages.append(SystemMessage(content=system_prompt))
messages.append(HumanMessage(content=prompt))
return messages
def _stream_llm_model_response(self, response):
"""
This is a generator for streaming response from the OpenAI completions API
"""
for line in response:
chunk = line["choices"][0].get("delta", {}).get("content", "")
yield chunk
super().__init__(config=config, llm=llm, db=db, embedder=embedder, system_prompt=system_prompt)
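Under the new dependency-injection style, a CustomApp is assembled from explicit components, matching the type checks above. A minimal sketch (OpenAI components chosen for illustration; OPENAI_API_KEY assumed):

from embedchain import CustomApp
from embedchain.config import CustomAppConfig
from embedchain.embedder.openai_embedder import OpenAiEmbedder
from embedchain.llm.openai_llm import OpenAiLlm
from embedchain.vectordb.chroma_db import ChromaDB

app = CustomApp(
    config=CustomAppConfig(),
    llm=OpenAiLlm(),
    db=ChromaDB(),
    embedder=OpenAiEmbedder(),
)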

View File

@@ -1,13 +1,15 @@
import os
from typing import Optional
from langchain.llms import Replicate
from embedchain.config import AppConfig, ChatConfig
from embedchain.embedchain import EmbedChain
from embedchain.apps.CustomApp import CustomApp
from embedchain.config import CustomAppConfig
from embedchain.embedder.openai_embedder import OpenAiEmbedder
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.llm.llama2_llm import Llama2Llm
from embedchain.vectordb.chroma_db import ChromaDB
class Llama2App(EmbedChain):
@register_deserializable
class Llama2App(CustomApp):
"""
The EmbedChain Llama2App class.
Has two functions: add and query.
@@ -16,25 +18,15 @@ class Llama2App(EmbedChain):
query(query): finds answer to the given query using vector database and LLM.
"""
def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None):
def __init__(self, config: CustomAppConfig = None, system_prompt: Optional[str] = None):
"""
:param config: AppConfig instance to load as configuration. Optional.
:param config: CustomAppConfig instance to load as configuration. Optional.
:param system_prompt: System prompt string. Optional.
"""
if "REPLICATE_API_TOKEN" not in os.environ:
raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
if config is None:
config = AppConfig()
config = CustomAppConfig()
super().__init__(config, system_prompt)
def get_llm_model_answer(self, prompt, config: ChatConfig = None):
# TODO: Move the model and other inputs into config
if self.system_prompt or config.system_prompt:
raise ValueError("Llama2App does not support `system_prompt`")
llm = Replicate(
model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
input={"temperature": 0.75, "max_length": 500, "top_p": 1},
super().__init__(
config=config, llm=Llama2Llm(), db=ChromaDB(), embedder=OpenAiEmbedder(), system_prompt=system_prompt
)
return llm(prompt)
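Since Llama2App is now a thin CustomApp subclass, instantiation reduces to a single call. A sketch (requires REPLICATE_API_TOKEN for the Replicate-hosted Llama 2 model and OPENAI_API_KEY for the default OpenAiEmbedder):

from embedchain import Llama2App

app = Llama2App()  # wires Llama2Llm, ChromaDB and OpenAiEmbedder by default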

View File

@@ -1,9 +1,13 @@
import logging
from typing import Iterable, Optional, Union
from typing import Optional
from embedchain.config import ChatConfig, OpenSourceAppConfig
from embedchain.config import (BaseEmbedderConfig, BaseLlmConfig,
ChromaDbConfig, OpenSourceAppConfig)
from embedchain.embedchain import EmbedChain
from embedchain.embedder.gpt4all_embedder import GPT4AllEmbedder
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.llm.gpt4all_llm import GPT4ALLLlm
from embedchain.vectordb.chroma_db import ChromaDB
gpt4all_model = None
@@ -20,7 +24,12 @@ class OpenSourceApp(EmbedChain):
query(query): finds answer to the given query using vector database and LLM.
"""
def __init__(self, config: OpenSourceAppConfig = None, system_prompt: Optional[str] = None):
def __init__(
self,
config: OpenSourceAppConfig = None,
chromadb_config: Optional[ChromaDbConfig] = None,
system_prompt: Optional[str] = None,
):
"""
:param config: OpenSourceAppConfig instance to load as configuration. Optional.
`ef` defaults to open source.
@@ -30,42 +39,19 @@ class OpenSourceApp(EmbedChain):
if not config:
config = OpenSourceAppConfig()
if not isinstance(config, OpenSourceAppConfig):
raise ValueError(
"OpenSourceApp needs a OpenSourceAppConfig passed to it. "
"You can import it with `from embedchain.config import OpenSourceAppConfig`"
)
if not config.model:
raise ValueError("OpenSourceApp needs a model to be instantiated. Maybe you passed the wrong config type?")
self.instance = OpenSourceApp._get_instance(config.model)
logging.info("Successfully loaded open source embedding model.")
super().__init__(config, system_prompt)
def get_llm_model_answer(self, prompt, config: ChatConfig):
return self._get_gpt4all_answer(prompt=prompt, config=config)
llm = GPT4ALLLlm(config=BaseLlmConfig(model="orca-mini-3b.ggmlv3.q4_0.bin"))
embedder = GPT4AllEmbedder(config=BaseEmbedderConfig(model="all-MiniLM-L6-v2"))
database = ChromaDB(config=chromadb_config)
@staticmethod
def _get_instance(model):
try:
from gpt4all import GPT4All
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The GPT4All python package is not installed. Please install it with `pip install embedchain[opensource]`" # noqa E501
) from None
return GPT4All(model)
def _get_gpt4all_answer(self, prompt: str, config: ChatConfig) -> Union[str, Iterable]:
if config.model and config.model != self.config.model:
raise RuntimeError(
"OpenSourceApp does not support switching models at runtime. Please create a new app instance."
)
if self.system_prompt or config.system_prompt:
raise ValueError("OpenSourceApp does not support `system_prompt`")
response = self.instance.generate(
prompt=prompt,
streaming=config.stream,
top_p=config.top_p,
max_tokens=config.max_tokens,
temp=config.temperature,
)
return response
super().__init__(config, llm=llm, db=database, embedder=embedder, system_prompt=system_prompt)
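The open-source app now composes local defaults (GPT4All LLM, SentenceTransformer embedder, Chroma). A usage sketch, assuming the extra dependencies are installed via pip install "embedchain[opensource]" and the dir value is illustrative:

from embedchain import OpenSourceApp
from embedchain.config import ChromaDbConfig

app = OpenSourceApp(chromadb_config=ChromaDbConfig(dir="local_db"))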

View File

@@ -2,9 +2,10 @@ from string import Template
from embedchain.apps.App import App
from embedchain.apps.OpenSourceApp import OpenSourceApp
from embedchain.config import ChatConfig, QueryConfig
from embedchain.config import BaseLlmConfig
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
from embedchain.config.QueryConfig import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
from embedchain.config.llm.base_llm_config import (DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY)
from embedchain.helper_classes.json_serializable import register_deserializable
@@ -23,7 +24,7 @@ class EmbedChainPersonApp:
self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style." # noqa:E501
super().__init__(config)
def add_person_template_to_config(self, default_prompt: str, config: ChatConfig = None):
def add_person_template_to_config(self, default_prompt: str, config: BaseLlmConfig = None):
"""
This method checks if the config object contains a prompt template
if yes it adds the person prompt to it and return the updated config
@@ -44,7 +45,7 @@ class EmbedChainPersonApp:
config.template = template
else:
# if no config is present at all, initialize the config with person prompt and default template
config = QueryConfig(
config = BaseLlmConfig(
template=template,
)
@@ -58,11 +59,11 @@ class PersonApp(EmbedChainPersonApp, App):
Extends functionality from EmbedChainPersonApp and App
"""
def query(self, input_query, config: QueryConfig = None, dry_run=False):
def query(self, input_query, config: BaseLlmConfig = None, dry_run=False):
config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None)
return super().query(input_query, config, dry_run, where=None)
def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
return super().chat(input_query, config, dry_run, where)
@@ -74,10 +75,10 @@ class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
Extends functionality from EmbedChainPersonApp and OpenSourceApp
"""
def query(self, input_query, config: QueryConfig = None, dry_run=False):
def query(self, input_query, config: BaseLlmConfig = None, dry_run=False):
config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
return super().query(input_query, config, dry_run)
def chat(self, input_query, config: ChatConfig = None, dry_run=False):
def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False):
config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
return super().chat(input_query, config, dry_run)
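PersonApp itself only swaps the config types; a usage sketch, assuming the existing constructor that takes the person's name first and that OPENAI_API_KEY is set (the name is illustrative):

from embedchain import PersonApp

app = PersonApp("Yoda")
answer = app.query("How do I stay patient?")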

View File

@@ -1,26 +1,25 @@
from embedchain import CustomApp
from embedchain.config import AddConfig, CustomAppConfig, QueryConfig
from embedchain.config import AddConfig, CustomAppConfig, LlmConfig
from embedchain.embedder.openai_embedder import OpenAiEmbedder
from embedchain.helper_classes.json_serializable import (
JSONSerializable, register_deserializable)
from embedchain.models import EmbeddingFunctions, Providers
from embedchain.llm.openai_llm import OpenAiLlm
from embedchain.vectordb.chroma_db import ChromaDB
@register_deserializable
class BaseBot(JSONSerializable):
def __init__(self, app_config=None):
if app_config is None:
app_config = CustomAppConfig(embedding_fn=EmbeddingFunctions.OPENAI, provider=Providers.OPENAI)
self.app_config = app_config
self.app = CustomApp(config=self.app_config)
def __init__(self):
self.app = CustomApp(config=CustomAppConfig(), llm=OpenAiLlm(), db=ChromaDB(), embedder=OpenAiEmbedder())
def add(self, data, config: AddConfig = None):
"""Add data to the bot"""
config = config if config else AddConfig()
self.app.add(data, config=config)
def query(self, query, config: QueryConfig = None):
def query(self, query, config: LlmConfig = None):
"""Query bot"""
config = config if config else QueryConfig()
config = config
return self.app.query(query, config=config)
def start(self):

View File

@@ -6,6 +6,8 @@ import discord
from discord import app_commands
from discord.ext import commands
from embedchain.helper_classes.json_serializable import register_deserializable
from .base import BaseBot
intents = discord.Intents.default()
@@ -17,6 +19,7 @@ tree = app_commands.CommandTree(client)
# https://discord.com/api/oauth2/authorize?client_id={DISCORD_CLIENT_ID}&permissions=2048&scope=bot
@register_deserializable
class DiscordBot(BaseBot):
def __init__(self, *args, **kwargs):
BaseBot.__init__(self, *args, **kwargs)

View File

@@ -5,7 +5,6 @@ from typing import List, Optional
from fastapi_poe import PoeBot, run
from embedchain.config import QueryConfig
from embedchain.helper_classes.json_serializable import register_deserializable
from .base import BaseBot
@@ -46,7 +45,6 @@ class PoeBot(BaseBot, PoeBot):
)
except Exception as e:
logging.error(f"Error when processing the chat history. Message is being sent without history. Error: {e}")
logging.warning(history)
answer = self.handle_message(last_message, history)
yield self.text_event(answer)
@@ -69,8 +67,8 @@ class PoeBot(BaseBot, PoeBot):
def ask_bot(self, message, history: List[str]):
try:
config = QueryConfig(history=history)
response = self.query(message, config)
self.app.llm.set_history(history=history)
response = self.query(message)
except Exception:
logging.exception(f"Failed to query {message}.")
response = "An error occurred. Please try again!"

View File

@@ -1,92 +0,0 @@
from string import Template
from typing import Optional
from embedchain.config.QueryConfig import QueryConfig
from embedchain.helper_classes.json_serializable import register_deserializable
DEFAULT_PROMPT = """
You are a chatbot having a conversation with a human. You are given chat
history and context.
You need to answer the query considering context, chat history and your knowledge base. If you don't know the answer or the answer is neither contained in the context nor in history, then simply say "I don't know".
$context
History: $history
Query: $query
Helpful Answer:
""" # noqa:E501
DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
@register_deserializable
class ChatConfig(QueryConfig):
"""
Config for the `chat` method, inherits from `QueryConfig`.
"""
def __init__(
self,
number_documents=None,
template: Template = None,
model=None,
temperature=None,
max_tokens=None,
top_p=None,
stream: bool = False,
deployment_name=None,
system_prompt: Optional[str] = None,
where=None,
):
"""
Initializes the ChatConfig instance.
:param number_documents: Number of documents to pull from the database as
context.
:param template: Optional. The `Template` instance to use as a template for
prompt.
:param model: Optional. Controls the OpenAI model used.
:param temperature: Optional. Controls the randomness of the model's output.
Higher values (closer to 1) make output more random,lower values make it more
deterministic.
:param max_tokens: Optional. Controls how many tokens are generated.
:param top_p: Optional. Controls the diversity of words.Higher values
(closer to 1) make word selection more diverse, lower values make words less
diverse.
:param stream: Optional. Control if response is streamed back to the user
:param deployment_name: t.b.a.
:param system_prompt: Optional. System prompt string.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:raises ValueError: If the template is not valid as template should contain
$context and $query and $history
"""
if template is None:
template = DEFAULT_PROMPT_TEMPLATE
# History is set as 0 to ensure that there is always a history, that way,
# there don't have to be two templates. Having two templates would make it
# complicated because the history is not user controlled.
super().__init__(
number_documents=number_documents,
template=template,
model=model,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
history=[0],
stream=stream,
deployment_name=deployment_name,
system_prompt=system_prompt,
where=where,
)
def set_history(self, history):
"""
Chat history is not user provided and not set at initialization time
:param history: (string) history to set
"""
self.history = history
return

View File

@@ -1,9 +1,13 @@
from .AddConfig import AddConfig, ChunkerConfig # noqa: F401
from .apps.AppConfig import AppConfig # noqa: F401
from .apps.CustomAppConfig import CustomAppConfig # noqa: F401
from .apps.OpenSourceAppConfig import OpenSourceAppConfig # noqa: F401
from .BaseConfig import BaseConfig # noqa: F401
from .ChatConfig import ChatConfig # noqa: F401
from .QueryConfig import QueryConfig # noqa: F401
from .vectordbs.ElasticsearchDBConfig import \
ElasticsearchDBConfig # noqa: F401
# flake8: noqa: F401
from .AddConfig import AddConfig, ChunkerConfig
from .apps.AppConfig import AppConfig
from .apps.CustomAppConfig import CustomAppConfig
from .apps.OpenSourceAppConfig import OpenSourceAppConfig
from .BaseConfig import BaseConfig
from .embedder.BaseEmbedderConfig import BaseEmbedderConfig
from .embedder.BaseEmbedderConfig import BaseEmbedderConfig as EmbedderConfig
from .llm.base_llm_config import BaseLlmConfig
from .llm.base_llm_config import BaseLlmConfig as LlmConfig
from .vectordbs.ChromaDbConfig import ChromaDbConfig
from .vectordbs.ElasticsearchDBConfig import ElasticsearchDBConfig
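With the re-exported aliases above, callers can use the shorter LlmConfig and EmbedderConfig names. A sketch with illustrative values:

from embedchain.config import ChromaDbConfig, EmbedderConfig, LlmConfig

llm_config = LlmConfig(temperature=0.2, max_tokens=500)
embedder_config = EmbedderConfig(model="text-embedding-ada-002")
db_config = ChromaDbConfig(collection_name="my_collection")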

View File

@@ -1,14 +1,5 @@
import os
from typing import Optional
try:
from chromadb.utils import embedding_functions
except RuntimeError:
from embedchain.utils import use_pysqlite3
use_pysqlite3()
from chromadb.utils import embedding_functions
from embedchain.helper_classes.json_serializable import register_deserializable
from .BaseAppConfig import BaseAppConfig
@@ -23,44 +14,14 @@ class AppConfig(BaseAppConfig):
def __init__(
self,
log_level=None,
host=None,
port=None,
id=None,
collection_name=None,
collect_metrics: Optional[bool] = None,
collection_name: Optional[str] = None,
):
"""
:param log_level: Optional. (String) Debug level
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
:param id: Optional. ID of the app. Document metadata will have this id.
:param collection_name: Optional. Collection name for the database.
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
"""
super().__init__(
log_level=log_level,
embedding_fn=AppConfig.default_embedding_function(),
host=host,
port=port,
id=id,
collection_name=collection_name,
collect_metrics=collect_metrics,
)
@staticmethod
def default_embedding_function():
"""
Sets embedding function to default (`text-embedding-ada-002`).
:raises ValueError: If the template is not valid as template should contain
$context and $query
:returns: The default embedding function for the app class.
"""
if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided") # noqa:E501
return embedding_functions.OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
organization_id=os.getenv("OPENAI_ORGANIZATION"),
model_name="text-embedding-ada-002",
)
super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, collection_name=collection_name)

View File

@@ -1,9 +1,9 @@
import logging
from typing import Optional
from embedchain.config.BaseConfig import BaseConfig
from embedchain.config.vectordbs import ElasticsearchDBConfig
from embedchain.helper_classes.json_serializable import JSONSerializable
from embedchain.models import VectorDatabases, VectorDimensions
from embedchain.vectordb.base_vector_db import BaseVectorDB
class BaseAppConfig(BaseConfig, JSONSerializable):
@@ -14,81 +14,38 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
def __init__(
self,
log_level=None,
embedding_fn=None,
db=None,
host=None,
port=None,
db: Optional[BaseVectorDB] = None,
id=None,
collection_name=None,
collect_metrics: bool = True,
db_type: VectorDatabases = None,
vector_dim: VectorDimensions = None,
es_config: ElasticsearchDBConfig = None,
chroma_settings: dict = {},
collection_name: Optional[str] = None,
):
"""
:param log_level: Optional. (String) Debug level
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
:param embedding_fn: Embedding function to use.
:param db: Optional. (Vector) database instance to use for embeddings.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
:param db: Optional. (Vector) database instance to use for embeddings. Deprecated in favor of app(..., db).
:param id: Optional. ID of the app. Document metadata will have this id.
:param collection_name: Optional. Collection name for the database.
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
:param db_type: Optional. type of Vector database to use
:param vector_dim: Vector dimension generated by embedding fn
:param db_type: Optional. Initializes a default vector database of the given type.
Using the `db` argument is preferred.
:param es_config: Optional. elasticsearch database config to be used for connection
:param chroma_settings: Optional. Chroma settings for connection.
:param collection_name: Optional. Default collection name.
It's recommended to use app.set_collection_name() instead.
"""
self._setup_logging(log_level)
self.collection_name = collection_name if collection_name else "embedchain_store"
self.db = BaseAppConfig.get_db(
db=db,
embedding_fn=embedding_fn,
host=host,
port=port,
db_type=db_type,
vector_dim=vector_dim,
collection_name=self.collection_name,
es_config=es_config,
chroma_settings=chroma_settings,
)
self.id = id
self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False
return
self.collection_name = collection_name
@staticmethod
def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config, chroma_settings):
"""
Get db based on db_type, db with default database (`ChromaDb`)
:param Optional. (Vector) database to use for embeddings.
:param embedding_fn: Embedding function to use in database.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
:param db_type: Optional. db type to use. Supported values (`es`, `chroma`)
:param vector_dim: Vector dimension generated by embedding fn
:param collection_name: Optional. Collection name for the database.
:param es_config: Optional. elasticsearch database config to be used for connection
:raises ValueError: BaseAppConfig knows no default embedding function.
:returns: database instance
"""
if db:
return db
if embedding_fn is None:
raise ValueError("ChromaDb cannot be instantiated without an embedding function")
if db_type == VectorDatabases.ELASTICSEARCH:
from embedchain.vectordb.elasticsearch_db import ElasticsearchDB
return ElasticsearchDB(
embedding_fn=embedding_fn, vector_dim=vector_dim, collection_name=collection_name, es_config=es_config
self._db = db
logging.warning(
"DEPRECATION WARNING: Please supply the database as the second parameter during app init. "
"Such as `app(config=config, db=db)`."
)
from embedchain.vectordb.chroma_db import ChromaDB
return ChromaDB(embedding_fn=embedding_fn, host=host, port=port, chroma_settings=chroma_settings)
if collection_name:
logging.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
return
def _setup_logging(self, debug_level):
level = logging.WARNING # Default level

View File

@@ -1,12 +1,8 @@
from typing import Any, Optional
from typing import Optional
from chromadb.api.types import Documents, Embeddings
from dotenv import load_dotenv
from embedchain.config.vectordbs import ElasticsearchDBConfig
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.models import (EmbeddingFunctions, Providers, VectorDatabases,
VectorDimensions)
from .BaseAppConfig import BaseAppConfig
@@ -22,123 +18,23 @@ class CustomAppConfig(BaseAppConfig):
def __init__(
self,
log_level=None,
embedding_fn: EmbeddingFunctions = None,
embedding_fn_model=None,
db=None,
host=None,
port=None,
id=None,
collection_name=None,
provider: Providers = None,
open_source_app_config=None,
deployment_name=None,
collect_metrics: Optional[bool] = None,
db_type: VectorDatabases = None,
es_config: ElasticsearchDBConfig = None,
chroma_settings: dict = {},
collection_name: Optional[str] = None,
):
"""
:param log_level: Optional. (String) Debug level
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
:param embedding_fn: Optional. Embedding function to use.
:param embedding_fn_model: Optional. Model name to use for embedding function.
:param db: Optional. (Vector) database to use for embeddings.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
:param id: Optional. ID of the app. Document metadata will have this id.
:param collection_name: Optional. Collection name for the database.
:param provider: Optional. (Providers): LLM Provider to use.
:param open_source_app_config: Optional. Config instance needed for open source apps.
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
:param db_type: Optional. type of Vector database to use.
:param es_config: Optional. elasticsearch database config to be used for connection
:param chroma_settings: Optional. Chroma settings for connection.
:param collection_name: Optional. Default collection name.
It's recommended to use app.set_collection_name() instead.
"""
if provider:
self.provider = provider
else:
raise ValueError("CustomApp must have a provider assigned.")
self.open_source_app_config = open_source_app_config
super().__init__(
log_level=log_level,
embedding_fn=CustomAppConfig.embedding_function(
embedding_function=embedding_fn, model=embedding_fn_model, deployment_name=deployment_name
),
db=db,
host=host,
port=port,
id=id,
collection_name=collection_name,
collect_metrics=collect_metrics,
db_type=db_type,
vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn),
es_config=es_config,
chroma_settings=chroma_settings,
log_level=log_level, db=db, id=id, collect_metrics=collect_metrics, collection_name=collection_name
)
@staticmethod
def langchain_default_concept(embeddings: Any):
"""
Langchains default function layout for embeddings.
"""
def embed_function(texts: Documents) -> Embeddings:
return embeddings.embed_documents(texts)
return embed_function
@staticmethod
def embedding_function(embedding_function: EmbeddingFunctions, model: str = None, deployment_name: str = None):
if not isinstance(embedding_function, EmbeddingFunctions):
raise ValueError(
f"Invalid option: '{embedding_function}'. Expecting one of the following options: {list(map(lambda x: x.value, EmbeddingFunctions))}" # noqa: E501
)
if embedding_function == EmbeddingFunctions.OPENAI:
from langchain.embeddings import OpenAIEmbeddings
if model:
embeddings = OpenAIEmbeddings(model=model)
else:
if deployment_name:
embeddings = OpenAIEmbeddings(deployment=deployment_name)
else:
embeddings = OpenAIEmbeddings()
return CustomAppConfig.langchain_default_concept(embeddings)
elif embedding_function == EmbeddingFunctions.HUGGING_FACE:
from langchain.embeddings import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(model_name=model)
return CustomAppConfig.langchain_default_concept(embeddings)
elif embedding_function == EmbeddingFunctions.VERTEX_AI:
from langchain.embeddings import VertexAIEmbeddings
embeddings = VertexAIEmbeddings(model_name=model)
return CustomAppConfig.langchain_default_concept(embeddings)
elif embedding_function == EmbeddingFunctions.GPT4ALL:
# Note: We could use langchains GPT4ALL embedding, but it's not available in all versions.
from chromadb.utils import embedding_functions
return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model)
@staticmethod
def get_vector_dimension(embedding_function: EmbeddingFunctions):
if not isinstance(embedding_function, EmbeddingFunctions):
raise ValueError(f"Invalid option: '{embedding_function}'.")
if embedding_function == EmbeddingFunctions.OPENAI:
return VectorDimensions.OPENAI.value
elif embedding_function == EmbeddingFunctions.HUGGING_FACE:
return VectorDimensions.HUGGING_FACE.value
elif embedding_function == EmbeddingFunctions.VERTEX_AI:
return VectorDimensions.VERTEX_AI.value
elif embedding_function == EmbeddingFunctions.GPT4ALL:
return VectorDimensions.GPT4ALL.value

View File

@@ -1,7 +1,5 @@
from typing import Optional
from chromadb.utils import embedding_functions
from embedchain.helper_classes.json_serializable import register_deserializable
from .BaseAppConfig import BaseAppConfig
@@ -16,47 +14,21 @@ class OpenSourceAppConfig(BaseAppConfig):
def __init__(
self,
log_level=None,
host=None,
port=None,
id=None,
collection_name=None,
collect_metrics: Optional[bool] = None,
model=None,
collection_name: Optional[str] = None,
):
"""
:param log_level: Optional. (String) Debug level
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
:param id: Optional. ID of the app. Document metadata will have this id.
:param collection_name: Optional. Collection name for the database.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
:param model: Optional. GPT4ALL uses the model to instantiate the class.
So unlike `App`, it has to be provided before querying.
:param collection_name: Optional. Default collection name.
It's recommended to use app.db.set_collection_name() instead.
"""
self.model = model or "orca-mini-3b.ggmlv3.q4_0.bin"
super().__init__(
log_level=log_level,
embedding_fn=OpenSourceAppConfig.default_embedding_function(),
host=host,
port=port,
id=id,
collection_name=collection_name,
collect_metrics=collect_metrics,
)
@staticmethod
def default_embedding_function():
"""
Sets embedding function to default (`all-MiniLM-L6-v2`).
:returns: The default embedding function
"""
try:
return embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
except ValueError as e:
print(e)
raise ModuleNotFoundError(
"The open source app requires extra dependencies. Install with `pip install embedchain[opensource]`"
) from None
super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, collection_name=collection_name)

View File

@@ -0,0 +1,10 @@
from typing import Optional
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class BaseEmbedderConfig:
def __init__(self, model: Optional[str] = None, deployment_name: Optional[str] = None):
self.model = model
self.deployment_name = deployment_name
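A short sketch of the two supported fields (values illustrative; deployment_name targets Azure-style deployments, as used by OpenAiEmbedder further down):

from embedchain.config import BaseEmbedderConfig

cfg = BaseEmbedderConfig(model="text-embedding-ada-002")
azure_cfg = BaseEmbedderConfig(deployment_name="my-ada-deployment")  # hypothetical deployment name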

View File

@@ -50,7 +50,7 @@ history_re = re.compile(r"\$\{*history\}*")
@register_deserializable
class QueryConfig(BaseConfig):
class BaseLlmConfig(BaseConfig):
"""
Config for the `query` method.
"""
@@ -63,7 +63,6 @@ class QueryConfig(BaseConfig):
temperature=None,
max_tokens=None,
top_p=None,
history=None,
stream: bool = False,
deployment_name=None,
system_prompt: Optional[str] = None,
@@ -84,7 +83,6 @@ class QueryConfig(BaseConfig):
:param top_p: Optional. Controls the diversity of words. Higher values
(closer to 1) make word selection more diverse, lower values make words less
diverse.
:param history: Optional. A list of strings to consider as history.
:param stream: Optional. Control if response is streamed back to user
:param deployment_name: t.b.a.
:param system_prompt: Optional. System prompt string.
@@ -97,19 +95,8 @@ class QueryConfig(BaseConfig):
else:
self.number_documents = number_documents
if not history:
self.history = None
else:
if len(history) == 0:
self.history = None
else:
self.history = history
if template is None:
if self.history is None:
template = DEFAULT_PROMPT_TEMPLATE
else:
template = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE
template = DEFAULT_PROMPT_TEMPLATE
self.temperature = temperature if temperature else 0
self.max_tokens = max_tokens if max_tokens else 1000
@@ -121,10 +108,7 @@ class QueryConfig(BaseConfig):
if self.validate_template(template):
self.template = template
else:
if self.history is None:
raise ValueError("`template` should have `query` and `context` keys")
else:
raise ValueError("`template` should have `query`, `context` and `history` keys")
raise ValueError("`template` should have `query` and `context` keys and potentially `history` (if used).")
if not isinstance(stream, bool):
raise ValueError("`stream` should be bool")
@@ -138,11 +122,13 @@ class QueryConfig(BaseConfig):
:param template: the template to validate
:return: Boolean, valid (true) or invalid (false)
"""
if self.history is None:
return re.search(query_re, template.template) and re.search(context_re, template.template)
else:
return (
re.search(query_re, template.template)
and re.search(context_re, template.template)
and re.search(history_re, template.template)
)
return re.search(query_re, template.template) and re.search(context_re, template.template)
def _validate_template_history(self, template: Template):
"""
validate the history template for history
:param template: the template to validate
:return: Boolean, valid (true) or invalid (false)
"""
return re.search(history_re, template.template)
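A sketch of a custom prompt template under the renamed BaseLlmConfig; per the validation above it must contain $context and $query, while $history is only checked separately when chat history is injected:

from string import Template
from embedchain.config import BaseLlmConfig

config = BaseLlmConfig(
    template=Template("Context: $context\nQuestion: $query\nAnswer:"),
    temperature=0,
    stream=False,
)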

View File

@@ -0,0 +1,17 @@
from typing import Optional
from embedchain.config.BaseConfig import BaseConfig
class BaseVectorDbConfig(BaseConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
host: Optional[str] = None,
port: Optional[str] = None,
):
self.collection_name = collection_name or "embedchain_store"
self.dir = dir or "db"
self.host = host
self.port = port

View File

@@ -0,0 +1,21 @@
from typing import Optional
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class ChromaDbConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
host: Optional[str] = None,
port: Optional[str] = None,
chroma_settings: Optional[dict] = None,
):
"""
:param chroma_settings: Optional. Chroma settings for connection.
"""
self.chroma_settings = chroma_settings
super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)
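A sketch of pointing ChromaDB at a remote server through the new config object (host and port are placeholders):

from embedchain.config import ChromaDbConfig
from embedchain.vectordb.chroma_db import ChromaDB

db = ChromaDB(config=ChromaDbConfig(host="localhost", port="8000", collection_name="docs"))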

View File

@@ -1,17 +1,25 @@
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union
from embedchain.config.BaseConfig import BaseConfig
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class ElasticsearchDBConfig(BaseConfig):
"""
Config to initialize an elasticsearch client.
:param es_url. elasticsearch url or list of nodes url to be used for connection
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
"""
def __init__(self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
class ElasticsearchDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
es_url: Union[str, List[str]] = None,
**ES_EXTRA_PARAMS: Dict[str, any],
):
"""
Config to initialize an elasticsearch client.
:param es_url: elasticsearch url or list of node urls to be used for connection
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
"""
# self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
self.ES_URL = es_url
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
super().__init__(collection_name=collection_name, dir=dir)
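A sketch of the Elasticsearch config; the URL is a placeholder, and any extra keyword arguments are forwarded to the underlying elasticsearch client via ES_EXTRA_PARAMS:

from embedchain.config import ElasticsearchDBConfig

es_config = ElasticsearchDBConfig(
    es_url="https://localhost:9200",
    collection_name="docs",
    verify_certs=False,  # example of an extra param passed through
)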

View File

@@ -11,30 +11,37 @@ from typing import Dict, Optional
import requests
from dotenv import load_dotenv
from langchain.docstore.document import Document
from langchain.memory import ConversationBufferMemory
from tenacity import retry, stop_after_attempt, wait_fixed
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, ChatConfig, QueryConfig
from embedchain.config import AddConfig, BaseLlmConfig
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
from embedchain.data_formatter import DataFormatter
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.helper_classes.json_serializable import JSONSerializable
from embedchain.llm.base_llm import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType
from embedchain.utils import detect_datatype
from embedchain.vectordb.base_vector_db import BaseVectorDB
load_dotenv()
ABS_PATH = os.getcwd()
DB_DIR = os.path.join(ABS_PATH, "db")
HOME_DIR = str(Path.home())
CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
class EmbedChain(JSONSerializable):
def __init__(self, config: BaseAppConfig, system_prompt: Optional[str] = None):
def __init__(
self,
config: BaseAppConfig,
llm: BaseLlm,
db: BaseVectorDB = None,
embedder: BaseEmbedder = None,
system_prompt: Optional[str] = None,
):
"""
Initializes the EmbedChain instance, sets up a vector DB client and
creates a collection.
@@ -44,17 +51,40 @@ class EmbedChain(JSONSerializable):
"""
self.config = config
self.system_prompt = system_prompt
self.collection = self.config.db._get_or_create_collection(self.config.collection_name)
self.db = self.config.db
# Add subclasses
## Llm
self.llm = llm
## Database
# Database has support for config assignment for backwards compatibility
if db is None and (not hasattr(self.config, "db") or self.config.db is None):
raise ValueError("App requires Database.")
self.db = db or self.config.db
## Embedder
if embedder is None:
raise ValueError("App requires Embedder.")
self.embedder = embedder
# Initialize database
self.db._set_embedder(self.embedder)
self.db._initialize()
# Set collection name from app config for backwards compatibility.
if config.collection_name:
self.db.set_collection_name(config.collection_name)
# Add variables that are "shortcuts"
if system_prompt:
self.llm.config.system_prompt = system_prompt
# Attributes that aren't subclass related.
self.user_asks = []
self.is_docs_site_instance = False
self.online = False
self.memory = ConversationBufferMemory()
# Send anonymous telemetry
self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
self.u_id = self._load_or_generate_user_id()
# NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
# if (self.config.collect_metrics):
# raise ConnectionRefusedError("Collection of metrics should not be allowed.")
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
thread_telemetry.start()
@@ -227,10 +257,10 @@ class EmbedChain(JSONSerializable):
metadatas = new_metadatas
# Count before, to calculate a delta in the end.
chunks_before_addition = self.count()
chunks_before_addition = self.db.count()
self.db.add(documents=documents, metadatas=metadatas, ids=ids)
count_new_chunks = self.count() - chunks_before_addition
count_new_chunks = self.db.count() - chunks_before_addition
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks
@@ -244,13 +274,7 @@ class EmbedChain(JSONSerializable):
)
]
def get_llm_model_answer(self):
"""
Usually implemented by child class
"""
raise NotImplementedError
def retrieve_from_database(self, input_query, config: QueryConfig, where=None):
def retrieve_from_database(self, input_query, config: Optional[BaseLlmConfig] = None, where=None):
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query
@@ -260,11 +284,12 @@ class EmbedChain(JSONSerializable):
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The content of the document that matched your query.
"""
query_config = config or self.llm.config
if where is not None:
where = where
elif config is not None and config.where is not None:
where = config.where
elif query_config is not None and query_config.where is not None:
where = query_config.where
else:
where = {}
@@ -273,64 +298,21 @@ class EmbedChain(JSONSerializable):
contents = self.db.query(
input_query=input_query,
n_results=config.number_documents,
n_results=query_config.number_documents,
where=where,
)
return contents
def _append_search_and_context(self, context, web_search_result):
return f"{context}\nWeb Search Result: {web_search_result}"
def generate_prompt(self, input_query, contexts, config: QueryConfig, **kwargs):
"""
Generates a prompt based on the given query and context, ready to be
passed to an LLM
:param input_query: The query to use.
:param contexts: List of similar documents to the query used as context.
:param config: Optional. The `QueryConfig` instance to use as
configuration options.
:return: The prompt
"""
context_string = (" | ").join(contexts)
web_search_result = kwargs.get("web_search_result", "")
if web_search_result:
context_string = self._append_search_and_context(context_string, web_search_result)
if not config.history:
prompt = config.template.substitute(context=context_string, query=input_query)
else:
prompt = config.template.substitute(context=context_string, query=input_query, history=config.history)
return prompt
def get_answer_from_llm(self, prompt, config: ChatConfig):
"""
Gets an answer based on the given query and context by passing it
to an LLM.
:param query: The query to use.
:param context: Similar documents to the query used as context.
:return: The answer.
"""
return self.get_llm_model_answer(prompt, config)
def access_search_and_get_results(self, input_query):
from langchain.tools import DuckDuckGoSearchRun
search = DuckDuckGoSearchRun()
logging.info(f"Access search to get answers for {input_query}")
return search.run(input_query)
def query(self, input_query, config: QueryConfig = None, dry_run=False, where=None):
def query(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an
LLM as context to get the answer.
:param input_query: The query to use.
:param config: Optional. The `QueryConfig` instance to use as
configuration options.
:param config: Optional. The `LlmConfig` instance to use as configuration options.
This is used for one method call. To persistently use a config, declare it during app init.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response.
You can use it to test your prompt, including the context provided
@@ -340,41 +322,16 @@ class EmbedChain(JSONSerializable):
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
"""
if config is None:
config = QueryConfig()
if self.is_docs_site_instance:
config.template = DOCS_SITE_PROMPT_TEMPLATE
config.number_documents = 5
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
contexts = self.retrieve_from_database(input_query, config, where)
prompt = self.generate_prompt(input_query, contexts, config, **k)
logging.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt, config)
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
# Send anonymous telemetry
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("query",))
thread_telemetry.start()
if isinstance(answer, str):
logging.info(f"Answer: {answer}")
return answer
else:
return self._stream_query_response(answer)
return answer
def _stream_query_response(self, answer):
streamed_answer = ""
for chunk in answer:
streamed_answer = streamed_answer + chunk
yield chunk
logging.info(f"Answer: {streamed_answer}")
def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
"""
Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an
@@ -382,8 +339,8 @@ class EmbedChain(JSONSerializable):
Maintains the whole conversation in memory.
:param input_query: The query to use.
:param config: Optional. The `ChatConfig` instance to use as
configuration options.
:param config: Optional. The `LlmConfig` instance to use as configuration options.
This is used for one method call. To persistently use a config, declare it during app init.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response.
You can use it to test your prompt, including the context provided
@@ -393,50 +350,14 @@ class EmbedChain(JSONSerializable):
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
"""
if config is None:
config = ChatConfig()
if self.is_docs_site_instance:
config.template = DOCS_SITE_PROMPT_TEMPLATE
config.number_documents = 5
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
contexts = self.retrieve_from_database(input_query, config, where)
chat_history = self.memory.load_memory_variables({})["history"]
if chat_history:
config.set_history(chat_history)
prompt = self.generate_prompt(input_query, contexts, config, **k)
logging.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt, config)
self.memory.chat_memory.add_user_message(input_query)
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
# Send anonymous telemetry
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("chat",))
thread_telemetry.start()
if isinstance(answer, str):
self.memory.chat_memory.add_ai_message(answer)
logging.info(f"Answer: {answer}")
return answer
else:
# this is a streamed response and needs to be handled differently.
return self._stream_chat_response(answer)
def _stream_chat_response(self, answer):
streamed_answer = ""
for chunk in answer:
streamed_answer = streamed_answer + chunk
yield chunk
self.memory.chat_memory.add_ai_message(streamed_answer)
logging.info(f"Answer: {streamed_answer}")
return answer
def set_collection(self, collection_name):
"""
@@ -444,34 +365,36 @@ class EmbedChain(JSONSerializable):
:param collection_name: The name of the collection to use.
"""
self.collection = self.config.db._get_or_create_collection(collection_name)
self.db.set_collection_name(collection_name)
# Create the collection if it does not exist
self.db._get_or_create_collection(collection_name)
# TODO: Check whether it is necessary to assign to the `self.collection` attribute,
# since the main purpose is the creation.
def count(self) -> int:
"""
Count the number of embeddings.
DEPRECATED IN FAVOR OF `db.count()`
:return: The number of embeddings.
"""
logging.warning("DEPRECATION WARNING: Please use `db.count()` instead of `count()`.")
return self.db.count()
def reset(self):
"""
Resets the database. Deletes all embeddings irreversibly.
`App` does not have to be reinitialized after using this method.
DEPRECATED IN FAVOR OF `db.reset()`
"""
# Send anonymous telemetry
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
thread_telemetry.start()
collection_name = self.collection.name
logging.warning("DEPRECATION WARNING: Please use `db.reset()` instead of `reset()`.")
self.db.reset()
self.collection = self.config.db._get_or_create_collection(collection_name)
# Todo: Automatically recreating a collection with the same name cannot be the best way to handle a reset.
# A downside of this implementation is, if you have two instances,
# the other instance will not get the updated `self.collection` attribute.
# A better way would be to create the collection if it is called again after being reset.
# That means, checking if collection exists in the db-consuming methods, and creating it if it doesn't.
# That's an extra steps for all uses, just to satisfy a niche use case in a niche method. For now, this will do.
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):
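With query and chat now delegating to self.llm, a config can still be supplied per call as a one-off override while the app-level LLM config persists. A sketch (values illustrative, OPENAI_API_KEY assumed):

from embedchain import App
from embedchain.config import LlmConfig

app = App()
app.add("https://example.org")  # illustrative source
answer = app.query("Summarize the source.", config=LlmConfig(temperature=0.5, number_documents=3))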

View File

@@ -0,0 +1,45 @@
from typing import Any, Callable, Optional
from embedchain.config.embedder.BaseEmbedderConfig import BaseEmbedderConfig
try:
from chromadb.api.types import Documents, Embeddings
except RuntimeError:
from embedchain.utils import use_pysqlite3
use_pysqlite3()
from chromadb.api.types import Documents, Embeddings
class BaseEmbedder:
"""
Class that manages everything regarding embeddings. Including embedding function, loaders and chunkers.
Embedding functions and vector dimensions are set based on the child class you choose.
To manually overwrite you can use this classes `set_...` methods.
"""
def __init__(self, config: Optional[BaseEmbedderConfig] = FileNotFoundError):
if config is None:
self.config = BaseEmbedderConfig()
else:
self.config = config
def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
if not hasattr(embedding_fn, "__call__"):
raise ValueError("Embedding function is not a function")
self.embedding_fn = embedding_fn
def set_vector_dimension(self, vector_dimension: int):
self.vector_dimension = vector_dimension
@staticmethod
def _langchain_default_concept(embeddings: Any):
"""
Langchains default function layout for embeddings.
"""
def embed_function(texts: Documents) -> Embeddings:
return embeddings.embed_documents(texts)
return embed_function
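The concrete embedders below follow the same pattern: pick a default model, build an embedding callable, then register it together with its vector dimension. A sketch of a custom subclass (model choice is illustrative; 384 is the output size of all-MiniLM-L6-v2):

from chromadb.utils import embedding_functions
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base_embedder import BaseEmbedder

class MiniLmEmbedder(BaseEmbedder):
    def __init__(self, config=None):
        super().__init__(config=config or BaseEmbedderConfig(model="all-MiniLM-L6-v2"))
        # Reuse Chroma's SentenceTransformer wrapper as the embedding callable.
        embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.config.model)
        self.set_embedding_fn(embedding_fn=embedding_fn)
        self.set_vector_dimension(vector_dimension=384)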

View File

@@ -0,0 +1,21 @@
from typing import Optional
from chromadb.utils import embedding_functions
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.models import EmbeddingFunctions
class GPT4AllEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
# Note: We could use langchain's GPT4All embedding, but it's not available in all versions.
super().__init__(config=config)
if self.config.model is None:
self.config.model = "all-MiniLM-L6-v2"
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.config.model)
self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = EmbeddingFunctions.GPT4ALL.value
self.set_vector_dimension(vector_dimension=vector_dimension)

@@ -0,0 +1,19 @@
from typing import Optional
from langchain.embeddings import HuggingFaceEmbeddings
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.models import EmbeddingFunctions
class HuggingFaceEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config=config)
embeddings = HuggingFaceEmbeddings(model_name=self.config.model)
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = EmbeddingFunctions.HUGGING_FACE.value
self.set_vector_dimension(vector_dimension=vector_dimension)

@@ -0,0 +1,40 @@
import os
from typing import Optional
from langchain.embeddings import OpenAIEmbeddings
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.models import EmbeddingFunctions
try:
from chromadb.utils import embedding_functions
except RuntimeError:
from embedchain.utils import use_pysqlite3
use_pysqlite3()
from chromadb.utils import embedding_functions
class OpenAiEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config=config)
if self.config.model is None:
self.config.model = "text-embedding-ada-002"
if self.config.deployment_name:
embeddings = OpenAIEmbeddings(deployment=self.config.deployment_name)
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
else:
if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
raise ValueError(
"OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided"
) # noqa:E501
embedding_fn = embedding_functions.OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
organization_id=os.getenv("OPENAI_ORGANIZATION"),
model_name=self.config.model,
)
self.set_embedding_fn(embedding_fn=embedding_fn)
self.set_vector_dimension(vector_dimension=EmbeddingFunctions.OPENAI.value)

@@ -0,0 +1,19 @@
from typing import Optional
from langchain.embeddings import VertexAIEmbeddings
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.models import EmbeddingFunctions
class VertexAiEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config=config)
embeddings = VertexAIEmbeddings(model_name=config.model)
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = EmbeddingFunctions.VERTEX_AI.value
self.set_vector_dimension(vector_dimension=vector_dimension)

@@ -0,0 +1,29 @@
import logging
from typing import Optional
from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class AntrophicLlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config)
def get_llm_model_answer(self, prompt):
return AntrophicLlm._get_athrophic_answer(prompt=prompt, config=self.config)
@staticmethod
def _get_athrophic_answer(prompt: str, config: BaseLlmConfig) -> str:
from langchain.chat_models import ChatAnthropic
chat = ChatAnthropic(temperature=config.temperature, model=config.model)
if config.max_tokens and config.max_tokens != 1000:
logging.warning("Config option `max_tokens` is not supported by this model.")
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
return chat(messages).content

@@ -0,0 +1,39 @@
import logging
from typing import Optional
from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class AzureOpenAiLlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config)
def get_llm_model_answer(self, prompt):
return AzureOpenAiLlm._get_azure_openai_answer(prompt=prompt, config=self.config)
@staticmethod
def _get_azure_openai_answer(prompt: str, config: BaseLlmConfig) -> str:
from langchain.chat_models import AzureChatOpenAI
if not config.deployment_name:
raise ValueError("Deployment name must be provided for Azure OpenAI")
chat = AzureChatOpenAI(
deployment_name=config.deployment_name,
openai_api_version="2023-05-15",
model_name=config.model or "gpt-3.5-turbo",
temperature=config.temperature,
max_tokens=config.max_tokens,
streaming=config.stream,
)
if config.top_p and config.top_p != 1:
logging.warning("Config option `top_p` is not supported by this model.")
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
return chat(messages).content
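As the `ValueError` above indicates, Azure always needs a deployment name. A hedged sketch of a matching config, with the `BaseLlmConfig` keyword arguments assumed from the attributes this class reads and an assumed module path:

from embedchain.config import BaseLlmConfig
from embedchain.llm.azure_openai_llm import AzureOpenAiLlm  # module path assumed

llm = AzureOpenAiLlm(
    config=BaseLlmConfig(
        deployment_name="my-gpt35-deployment",  # hypothetical Azure deployment
        model="gpt-3.5-turbo",
        temperature=0.2,
    )
)
answer = llm.get_llm_model_answer("Summarize this refactor in one sentence.")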

embedchain/llm/base_llm.py

@@ -0,0 +1,214 @@
import logging
from typing import List, Optional
from langchain.memory import ConversationBufferMemory
from langchain.schema import BaseMessage
from embedchain.helper_classes.json_serializable import JSONSerializable
from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base_llm_config import (
DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE)
class BaseLlm(JSONSerializable):
def __init__(self, config: Optional[BaseLlmConfig] = None):
if config is None:
self.config = BaseLlmConfig()
else:
self.config = config
self.memory = ConversationBufferMemory()
self.is_docs_site_instance = False
self.online = False
self.history: any = None
def get_llm_model_answer(self):
"""
Usually implemented by child class
"""
raise NotImplementedError
def set_history(self, history: any):
self.history = history
def update_history(self):
chat_history = self.memory.load_memory_variables({})["history"]
if chat_history:
self.set_history(chat_history)
def generate_prompt(self, input_query, contexts, **kwargs):
"""
Generates a prompt based on the given query and context, ready to be
passed to an LLM
:param input_query: The query to use.
:param contexts: List of similar documents to the query used as context.
:param web_search_result: Optional. Passed via kwargs; appended to the context when web search is enabled.
:return: The prompt
"""
context_string = (" | ").join(contexts)
web_search_result = kwargs.get("web_search_result", "")
if web_search_result:
context_string = self._append_search_and_context(context_string, web_search_result)
if not self.history:
prompt = self.config.template.substitute(context=context_string, query=input_query)
else:
# check if it's the default template without history
if (
not self.config._validate_template_history(self.config.template)
and self.config.template.template == DEFAULT_PROMPT
):
# swap in the template with history
prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
context=context_string, query=input_query, history=self.history
)
elif not self.config._validate_template_history(self.config.template):
logging.warning("Template does not include `$history` key. History is not included in prompt.")
prompt = self.config.template.substitute(context=context_string, query=input_query)
else:
prompt = self.config.template.substitute(
context=context_string, query=input_query, history=self.history
)
return prompt
def _append_search_and_context(self, context, web_search_result):
return f"{context}\nWeb Search Result: {web_search_result}"
def get_answer_from_llm(self, prompt):
"""
Gets an answer based on the given query and context by passing it
to an LLM.
:param query: The query to use.
:param context: Similar documents to the query used as context.
:return: The answer.
"""
return self.get_llm_model_answer(prompt)
def access_search_and_get_results(self, input_query):
from langchain.tools import DuckDuckGoSearchRun
search = DuckDuckGoSearchRun()
logging.info(f"Access search to get answers for {input_query}")
return search.run(input_query)
def _stream_query_response(self, answer):
streamed_answer = ""
for chunk in answer:
streamed_answer = streamed_answer + chunk
yield chunk
logging.info(f"Answer: {streamed_answer}")
def _stream_chat_response(self, answer):
streamed_answer = ""
for chunk in answer:
streamed_answer = streamed_answer + chunk
yield chunk
self.memory.chat_memory.add_ai_message(streamed_answer)
logging.info(f"Answer: {streamed_answer}")
def query(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None):
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an
LLM as context to get the answer.
:param input_query: The query to use.
:param config: Optional. The `LlmConfig` instance to use as configuration options.
This is used for one method call. To persistently use a config, declare it during app init.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response.
You can use it to test your prompt, including the context provided
by the vector database's doc retrieval.
The only thing the dry run does not consider is the cut-off due to
the `max_tokens` parameter.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
"""
query_config = config or self.config
if self.is_docs_site_instance:
query_config.template = DOCS_SITE_PROMPT_TEMPLATE
query_config.number_documents = 5
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
prompt = self.generate_prompt(input_query, contexts, **k)
logging.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt)
if isinstance(answer, str):
logging.info(f"Answer: {answer}")
return answer
else:
return self._stream_query_response(answer)
def chat(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None):
"""
Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an
LLM as context to get the answer.
Maintains the whole conversation in memory.
:param input_query: The query to use.
:param config: Optional. The `LlmConfig` instance to use as configuration options.
This is used for one method call. To persistently use a config, declare it during app init.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response.
You can use it to test your prompt, including the context provided
by the vector database's doc retrieval.
The only thing the dry run does not consider is the cut-off due to
the `max_tokens` parameter.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
"""
query_config = config or self.config
if self.is_docs_site_instance:
query_config.template = DOCS_SITE_PROMPT_TEMPLATE
query_config.number_documents = 5
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
self.update_history()
prompt = self.generate_prompt(input_query, contexts, **k)
logging.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt)
self.memory.chat_memory.add_user_message(input_query)
if isinstance(answer, str):
self.memory.chat_memory.add_ai_message(answer)
logging.info(f"Answer: {answer}")
# NOTE: Adding to history before and after. This could be seen as redundant.
# If we change it, we have to change the tests (no big deal).
self.update_history()
return answer
else:
# this is a streamed response and needs to be handled differently.
return self._stream_chat_response(answer)
@staticmethod
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
from langchain.schema import HumanMessage, SystemMessage
messages = []
if system_prompt:
messages.append(SystemMessage(content=system_prompt))
messages.append(HumanMessage(content=prompt))
return messages
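A short sketch of the dry-run flow documented above: build the final prompt from retrieved contexts without calling the model, so no API key or network call is needed. `OpenAiLlm` stands in for any concrete `BaseLlm`; its module path and the example contexts are assumptions:

from embedchain.llm.openai_llm import OpenAiLlm  # module path assumed

llm = OpenAiLlm()  # default BaseLlmConfig
contexts = [
    "Chunk: vector databases store embeddings for similarity search.",
    "Chunk: embeddings map text to fixed-size vectors.",
]
prompt = llm.query("What is a vector database?", contexts, dry_run=True)
print(prompt)  # the substituted template; nothing is sent to the LLM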

@@ -0,0 +1,47 @@
from typing import Iterable, Optional, Union
from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class GPT4ALLLlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config)
if self.config.model is None:
self.config.model = "orca-mini-3b.ggmlv3.q4_0.bin"
self.instance = GPT4ALLLlm._get_instance(self.config.model)
def get_llm_model_answer(self, prompt):
return self._get_gpt4all_answer(prompt=prompt, config=self.config)
@staticmethod
def _get_instance(model):
try:
from gpt4all import GPT4All
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The GPT4All python package is not installed. Please install it with `pip install embedchain[opensource]`" # noqa E501
) from None
return GPT4All(model_name=model)
def _get_gpt4all_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
if config.model and config.model != self.config.model:
raise RuntimeError(
"OpenSourceApp does not support switching models at runtime. Please create a new app instance."
)
if config.system_prompt:
raise ValueError("OpenSourceApp does not support `system_prompt`")
response = self.instance.generate(
prompt=prompt,
streaming=config.stream,
top_p=config.top_p,
max_tokens=config.max_tokens,
temp=config.temperature,
)
return response

@@ -0,0 +1,27 @@
import os
from typing import Optional
from langchain.llms import Replicate
from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class Llama2Llm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
if "REPLICATE_API_TOKEN" not in os.environ:
raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
super().__init__(config=config)
def get_llm_model_answer(self, prompt):
# TODO: Move the model and other inputs into config
if self.config.system_prompt:
raise ValueError("Llama2App does not support `system_prompt`")
llm = Replicate(
model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
input={"temperature": self.config.temperature or 0.75, "max_length": 500, "top_p": self.config.top_p},
)
return llm(prompt)

@@ -0,0 +1,43 @@
from typing import Optional
import openai
from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class OpenAiLlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config)
# NOTE: This class does not use langchain. One reason is that `top_p` is not supported.
def get_llm_model_answer(self, prompt):
messages = []
if self.config.system_prompt:
messages.append({"role": "system", "content": self.config.system_prompt})
messages.append({"role": "user", "content": prompt})
response = openai.ChatCompletion.create(
model=self.config.model or "gpt-3.5-turbo-0613",
messages=messages,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
top_p=self.config.top_p,
stream=self.config.stream,
)
if self.config.stream:
return self._stream_llm_model_response(response)
else:
return response["choices"][0]["message"]["content"]
def _stream_llm_model_response(self, response):
"""
This is a generator for streaming response from the OpenAI completions API
"""
for line in response:
chunk = line["choices"][0].get("delta", {}).get("content", "")
yield chunk
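A hedged usage sketch for the streaming path: with `stream=True`, `get_llm_model_answer` returns the generator above, yielding content chunks as they arrive. The `BaseLlmConfig` keyword names and the module path are assumptions; `OPENAI_API_KEY` must be set:

from embedchain.config import BaseLlmConfig
from embedchain.llm.openai_llm import OpenAiLlm  # module path assumed

llm = OpenAiLlm(config=BaseLlmConfig(model="gpt-3.5-turbo-0613", stream=True))
for chunk in llm.get_llm_model_answer("Explain embeddings in one paragraph."):
    print(chunk, end="", flush=True)  # chunks arrive incrementally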

@@ -0,0 +1,29 @@
import logging
from typing import Optional
from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class VertexAiLlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config)
def get_llm_model_answer(self, prompt):
return VertexAiLlm._get_vertexai_answer(prompt=prompt, config=self.config)
@staticmethod
def _get_vertexai_answer(prompt: str, config: BaseLlmConfig) -> str:
from langchain.chat_models import ChatVertexAI
chat = ChatVertexAI(temperature=config.temperature, model=config.model)
if config.top_p and config.top_p != 1:
logging.warning("Config option `top_p` is not supported by this model.")
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
return chat(messages).content

@@ -1,11 +1,22 @@
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.helper_classes.json_serializable import JSONSerializable
class BaseVectorDB(JSONSerializable):
"""Base class for vector database."""
def __init__(self):
def __init__(self, config: BaseVectorDbConfig):
self.client = self._get_or_create_db()
self.config: BaseVectorDbConfig = config
def _initialize(self):
"""
This method is needed because the `embedder` attribute needs to be set externally before it can be initialized,
so it can't be done in __init__ in one step.
"""
raise NotImplementedError
def _get_or_create_db(self):
"""Get or create the database."""
@@ -14,6 +25,9 @@ class BaseVectorDB(JSONSerializable):
def _get_or_create_collection(self):
raise NotImplementedError
def _set_embedder(self, embedder: BaseEmbedder):
self.embedder = embedder
def get(self):
raise NotImplementedError
@@ -28,3 +42,6 @@ class BaseVectorDB(JSONSerializable):
def reset(self):
raise NotImplementedError
def set_collection_name(self, name: str):
raise NotImplementedError
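A minimal sketch of the two-step wiring the `_initialize` docstring describes, using `ChromaDB` as the concrete database. Normally `App`/`EmbedChain` performs these (private) calls for you; module paths are assumed, and the default config is assumed to carry a collection name:

from embedchain.embedder.openai_embedder import OpenAiEmbedder  # module path assumed
from embedchain.vectordb.chroma_db import ChromaDB  # module path assumed

db = ChromaDB()                      # falls back to a default ChromaDbConfig
db._set_embedder(OpenAiEmbedder())   # injects embedding function + vector dimension; needs OPENAI_API_KEY
db._initialize()                     # only now can the collection be created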

@@ -1,53 +1,63 @@
import logging
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from chromadb.errors import InvalidDimensionException
from langchain.docstore.document import Document
from embedchain.config import ChromaDbConfig
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.vectordb.base_vector_db import BaseVectorDB
try:
import chromadb
from chromadb.config import Settings
from chromadb.errors import InvalidDimensionException
except RuntimeError:
from embedchain.utils import use_pysqlite3
use_pysqlite3()
import chromadb
from chromadb.config import Settings
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.vectordb.base_vector_db import BaseVectorDB
from chromadb.config import Settings
from chromadb.errors import InvalidDimensionException
@register_deserializable
class ChromaDB(BaseVectorDB):
"""Vector database using ChromaDB."""
def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None, chroma_settings={}):
self.embedding_fn = embedding_fn
if not hasattr(embedding_fn, "__call__"):
raise ValueError("Embedding function is not a function")
def __init__(self, config: Optional[ChromaDbConfig] = None):
if config:
self.config = config
else:
self.config = ChromaDbConfig()
self.settings = Settings()
for key, value in chroma_settings.items():
if hasattr(self.settings, key):
setattr(self.settings, key, value)
if self.config.chroma_settings:
for key, value in self.config.chroma_settings.items():
if hasattr(self.settings, key):
setattr(self.settings, key, value)
if host and port:
logging.info(f"Connecting to ChromaDB server: {host}:{port}")
self.settings.chroma_server_host = host
self.settings.chroma_server_http_port = port
if self.config.host and self.config.port:
logging.info(f"Connecting to ChromaDB server: {self.config.host}:{self.config.port}")
self.settings.chroma_server_host = self.config.host
self.settings.chroma_server_http_port = self.config.port
self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
else:
if db_dir is None:
db_dir = "db"
if self.config.dir is None:
self.config.dir = "db"
self.settings.persist_directory = db_dir
self.settings.persist_directory = self.config.dir
self.settings.is_persistent = True
self.client = chromadb.Client(self.settings)
super().__init__()
super().__init__(config=self.config)
def _initialize(self):
"""
This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
"""
if not self.embedder:
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
self._get_or_create_collection(self.config.collection_name)
def _get_or_create_db(self):
"""Get or create the database."""
@@ -55,9 +65,11 @@ class ChromaDB(BaseVectorDB):
def _get_or_create_collection(self, name):
"""Get or create the collection."""
if not hasattr(self, "embedder") or not self.embedder:
raise ValueError("Cannot create a Chroma database collection without an embedder.")
self.collection = self.client.get_or_create_collection(
name=name,
embedding_function=self.embedding_fn,
embedding_function=self.embedder.embedding_fn,
)
return self.collection
@@ -119,9 +131,37 @@ class ChromaDB(BaseVectorDB):
contents = [result[0].page_content for result in results_formatted]
return contents
def set_collection_name(self, name: str):
self.config.collection_name = name
self._get_or_create_collection(self.config.collection_name)
def count(self) -> int:
"""
Count the number of embeddings.
:return: The number of embeddings.
"""
return self.collection.count()
def reset(self):
"""
Resets the database. Deletes all embeddings irreversibly.
`App` does not have to be reinitialized after using this method.
"""
# Delete all data from the database
self.client.reset()
try:
self.client.reset()
except ValueError:
raise ValueError(
"For safety reasons, resetting is disabled."
'Please enable it by including `chromadb_settings={"allow_reset": True}` in your ChromaDbConfig'
) from None
# Recreate
self._get_or_create_collection(self.config.collection_name)
# Todo: Automatically recreating a collection with the same name cannot be the best way to handle a reset.
# A downside of this implementation is, if you have two instances,
# the other instance will not get the updated `self.collection` attribute.
# A better way would be to create the collection if it is called again after being reset.
# That means, checking if collection exists in the db-consuming methods, and creating it if it doesn't.
# That's an extra step for all uses, just to satisfy a niche use case in a niche method. For now, this will do.
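A hedged sketch of enabling reset, following the error message above; the `ChromaDbConfig` keyword names mirror the attributes this class reads and are otherwise assumed:

from embedchain.config import ChromaDbConfig
from embedchain.vectordb.chroma_db import ChromaDB  # module path assumed

config = ChromaDbConfig(dir="db", chroma_settings={"allow_reset": True})
db = ChromaDB(config=config)
# ... set an embedder and add documents, then:
# db.reset()  # wipes all embeddings and recreates the collection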

@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List
from typing import Any, Dict, List
try:
from elasticsearch import Elasticsearch
@@ -10,7 +10,6 @@ except ImportError:
from embedchain.config import ElasticsearchDBConfig
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.models.VectorDimensions import VectorDimensions
from embedchain.vectordb.base_vector_db import BaseVectorDB
@@ -18,43 +17,40 @@ from embedchain.vectordb.base_vector_db import BaseVectorDB
class ElasticsearchDB(BaseVectorDB):
def __init__(
self,
es_config: ElasticsearchDBConfig = None,
embedding_fn: Callable[[list[str]], list[str]] = None,
vector_dim: VectorDimensions = None,
collection_name: str = None,
config: ElasticsearchDBConfig = None,
es_config: ElasticsearchDBConfig = None, # Backwards compatibility
):
"""
Elasticsearch as vector database
:param es_config: Elasticsearch database config to be used for the connection
:param embedding_fn: Function to generate embedding vectors.
:param vector_dim: Vector dimension generated by embedding fn
:param collection_name: Optional. Collection name for the database.
"""
if not hasattr(embedding_fn, "__call__"):
raise ValueError("Embedding function is not a function")
if es_config is None:
if config is None and es_config is None:
raise ValueError("ElasticsearchDBConfig is required")
if vector_dim is None:
raise ValueError("Vector Dimension is required to refer correct index and mapping")
if collection_name is None:
raise ValueError("collection name is required. It cannot be empty")
self.embedding_fn = embedding_fn
self.config = config or es_config
self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS)
self.vector_dim = vector_dim
self.es_index = f"{collection_name}_{self.vector_dim}"
# Call parent init here because embedder is needed
super().__init__(config=self.config)
def _initialize(self):
"""
This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
"""
index_settings = {
"mappings": {
"properties": {
"text": {"type": "text"},
"embeddings": {"type": "dense_vector", "index": False, "dims": self.vector_dim},
"embeddings": {"type": "dense_vector", "index": False, "dims": self.embedder.vector_dimension},
}
}
}
if not self.client.indices.exists(index=self.es_index):
es_index = self._get_index()
if not self.client.indices.exists(index=es_index):
# create index if not exist
print("Creating index", self.es_index, index_settings)
self.client.indices.create(index=self.es_index, body=index_settings)
super().__init__()
print("Creating index", es_index, index_settings)
self.client.indices.create(index=es_index, body=index_settings)
def _get_or_create_db(self):
return self.client
@@ -85,17 +81,17 @@ class ElasticsearchDB(BaseVectorDB):
:param ids: ids of docs
"""
docs = []
embeddings = self.embedding_fn(documents)
embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
docs.append(
{
"_index": self.es_index,
"_index": self._get_index(),
"_id": id,
"_source": {"text": text, "metadata": metadata, "embeddings": embeddings},
}
)
bulk(self.client, docs)
self.client.indices.refresh(index=self.es_index)
self.client.indices.refresh(index=self._get_index())
return
def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
@@ -105,7 +101,7 @@ class ElasticsearchDB(BaseVectorDB):
:param n_results: no of similar documents to fetch from database
:param where: Optional. to filter data
"""
input_query_vector = self.embedding_fn(input_query)
input_query_vector = self.embedder.embedding_fn(input_query)
query_vector = input_query_vector[0]
query = {
"script_score": {
@@ -120,11 +116,14 @@ class ElasticsearchDB(BaseVectorDB):
app_id = where["app_id"]
query["script_score"]["query"]["bool"]["must"] = [{"term": {"metadata.app_id": app_id}}]
_source = ["text"]
response = self.client.search(index=self.es_index, query=query, _source=_source, size=n_results)
response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results)
docs = response["hits"]["hits"]
contents = [doc["_source"]["text"] for doc in docs]
return contents
def set_collection_name(self, name: str):
self.config.collection_name = name
def count(self) -> int:
query = {"match_all": {}}
response = self.client.count(index=self.es_index, query=query)
@@ -136,3 +135,8 @@ class ElasticsearchDB(BaseVectorDB):
if self.client.indices.exists(index=self.es_index):
# delete index in Es
self.client.indices.delete(index=self.es_index)
def _get_index(self):
# NOTE: The method is preferred to an attribute, because if collection name changes,
# it's always up-to-date.
return f"{self.config.collection_name}_{self.config.vector_dim}"