refactor: classes and configs (#528)
@@ -8,3 +8,5 @@ from embedchain.apps.Llama2App import Llama2App # noqa: F401
from embedchain.apps.OpenSourceApp import OpenSourceApp # noqa: F401
from embedchain.apps.PersonApp import (PersonApp, # noqa: F401
                                       PersonOpenSourceApp)
from embedchain.vectordb.chroma_db import ChromaDB # noqa: F401
from embedchain.vectordb.elasticsearch_db import ElasticsearchDB # noqa: F401

@@ -1,10 +1,12 @@
from typing import Optional

import openai

from embedchain.config import AppConfig, ChatConfig
from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
                               ChromaDbConfig)
from embedchain.embedchain import EmbedChain
from embedchain.embedder.openai_embedder import OpenAiEmbedder
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.llm.openai_llm import OpenAiLlm
from embedchain.vectordb.chroma_db import ChromaDB


@register_deserializable
@@ -18,7 +20,13 @@ class App(EmbedChain):
    dry_run(query): test your prompt without consuming tokens.
    """

    def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None):
    def __init__(
        self,
        config: AppConfig = None,
        llm_config: BaseLlmConfig = None,
        chromadb_config: Optional[ChromaDbConfig] = None,
        system_prompt: Optional[str] = None,
    ):
        """
        :param config: AppConfig instance to load as configuration. Optional.
        :param system_prompt: System prompt string. Optional.
@@ -26,38 +34,8 @@ class App(EmbedChain):
        if config is None:
            config = AppConfig()

        super().__init__(config, system_prompt)
        llm = OpenAiLlm(config=llm_config)
        embedder = OpenAiEmbedder(config=BaseEmbedderConfig(model="text-embedding-ada-002"))
        database = ChromaDB(config=chromadb_config)

    def get_llm_model_answer(self, prompt, config: ChatConfig):
        messages = []
        system_prompt = (
            self.system_prompt
            if self.system_prompt is not None
            else config.system_prompt
            if config.system_prompt is not None
            else None
        )
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        response = openai.ChatCompletion.create(
            model=config.model or "gpt-3.5-turbo-0613",
            messages=messages,
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            top_p=config.top_p,
            stream=config.stream,
        )

        if config.stream:
            return self._stream_llm_model_response(response)
        else:
            return response["choices"][0]["message"]["content"]

    def _stream_llm_model_response(self, response):
        """
        This is a generator for streaming response from the OpenAI completions API
        """
        for line in response:
            chunk = line["choices"][0].get("delta", {}).get("content", "")
            yield chunk
        super().__init__(config, llm, db=database, embedder=embedder, system_prompt=system_prompt)

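Taken together, the new App constructor builds its own OpenAiLlm, OpenAiEmbedder and ChromaDB instead of talking to the OpenAI API itself. A minimal usage sketch of the new signature; the argument values are illustrative assumptions, not part of this diff:

    from embedchain import App
    from embedchain.config import AppConfig, BaseLlmConfig, ChromaDbConfig

    app = App(
        config=AppConfig(),
        llm_config=BaseLlmConfig(temperature=0.2, max_tokens=500),   # assumed example values
        chromadb_config=ChromaDbConfig(collection_name="my-docs"),   # optional, defaults apply
    )
    app.add("https://example.com/faq.html")
    print(app.query("What does the FAQ say about refunds?"))
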
@@ -1,12 +1,11 @@
import logging
from typing import List, Optional
from typing import Optional

from langchain.schema import BaseMessage

from embedchain.config import ChatConfig, CustomAppConfig
from embedchain.config import CustomAppConfig
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.models import Providers
from embedchain.llm.base_llm import BaseLlm
from embedchain.vectordb.base_vector_db import BaseVectorDB


@register_deserializable
@@ -20,143 +19,49 @@ class CustomApp(EmbedChain):
    dry_run(query): test your prompt without consuming tokens.
    """

    def __init__(self, config: CustomAppConfig = None, system_prompt: Optional[str] = None):
    def __init__(
        self,
        config: CustomAppConfig = None,
        llm: BaseLlm = None,
        db: BaseVectorDB = None,
        embedder: BaseEmbedder = None,
        system_prompt: Optional[str] = None,
    ):
        """
        :param config: Optional. `CustomAppConfig` instance to load as configuration.
        :raises ValueError: Config must be provided for custom app
        :param system_prompt: Optional. System prompt string.
        """
        # Config is not required, it has a default
        if config is None:
            raise ValueError("Config must be provided for custom app")
            config = CustomAppConfig()

        self.provider = config.provider
        if llm is None:
            raise ValueError("LLM must be provided for custom app. Please import from `embedchain.llm`.")
        if db is None:
            raise ValueError("Database must be provided for custom app. Please import from `embedchain.vectordb`.")
        if embedder is None:
            raise ValueError("Embedder must be provided for custom app. Please import from `embedchain.embedder`.")

        if config.provider == Providers.GPT4ALL:
            from embedchain import OpenSourceApp

            # Because these models run locally, they should have an instance running when the custom app is created
            self.open_source_app = OpenSourceApp(config=config.open_source_app_config)

        super().__init__(config, system_prompt)

    def set_llm_model(self, provider: Providers):
        self.provider = provider
        if provider == Providers.GPT4ALL:
            raise ValueError(
                "GPT4ALL needs to be instantiated with the model known, please create a new app instance instead"
        if not isinstance(config, CustomAppConfig):
            raise TypeError(
                "Config is not a `CustomAppConfig` instance. "
                "Please make sure the type is right and that you are passing an instance."
            )
        if not isinstance(llm, BaseLlm):
            raise TypeError(
                "LLM is not a `BaseLlm` instance. "
                "Please make sure the type is right and that you are passing an instance."
            )
        if not isinstance(db, BaseVectorDB):
            raise TypeError(
                "Database is not a `BaseVectorDB` instance. "
                "Please make sure the type is right and that you are passing an instance."
            )
        if not isinstance(embedder, BaseEmbedder):
            raise TypeError(
                "Embedder is not a `BaseEmbedder` instance. "
                "Please make sure the type is right and that you are passing an instance."
            )

    def get_llm_model_answer(self, prompt, config: ChatConfig):
        # TODO: Quitting the streaming response here for now.
        # Idea: https://gist.github.com/jvelezmagic/03ddf4c452d011aae36b2a0f73d72f68
        if config.stream:
            raise NotImplementedError(
                "Streaming responses have not been implemented for this model yet. Please disable."
            )

        if config.system_prompt is None and self.system_prompt is not None:
            config.system_prompt = self.system_prompt

        try:
            if self.provider == Providers.OPENAI:
                return CustomApp._get_openai_answer(prompt, config)

            if self.provider == Providers.ANTHROPHIC:
                return CustomApp._get_athrophic_answer(prompt, config)

            if self.provider == Providers.VERTEX_AI:
                return CustomApp._get_vertex_answer(prompt, config)

            if self.provider == Providers.GPT4ALL:
                return self.open_source_app._get_gpt4all_answer(prompt, config)

            if self.provider == Providers.AZURE_OPENAI:
                return CustomApp._get_azure_openai_answer(prompt, config)

        except ImportError as e:
            raise ModuleNotFoundError(e.msg) from None

    @staticmethod
    def _get_openai_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import ChatOpenAI

        chat = ChatOpenAI(
            temperature=config.temperature,
            model=config.model or "gpt-3.5-turbo",
            max_tokens=config.max_tokens,
            streaming=config.stream,
        )

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content

    @staticmethod
    def _get_athrophic_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import ChatAnthropic

        chat = ChatAnthropic(temperature=config.temperature, model=config.model)

        if config.max_tokens and config.max_tokens != 1000:
            logging.warning("Config option `max_tokens` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content

    @staticmethod
    def _get_vertex_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import ChatVertexAI

        chat = ChatVertexAI(temperature=config.temperature, model=config.model, max_output_tokens=config.max_tokens)

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content

    @staticmethod
    def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
        from langchain.chat_models import AzureChatOpenAI

        if not config.deployment_name:
            raise ValueError("Deployment name must be provided for Azure OpenAI")

        chat = AzureChatOpenAI(
            deployment_name=config.deployment_name,
            openai_api_version="2023-05-15",
            model_name=config.model or "gpt-3.5-turbo",
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            streaming=config.stream,
        )

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content

    @staticmethod
    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
        from langchain.schema import HumanMessage, SystemMessage

        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))
        return messages

    def _stream_llm_model_response(self, response):
        """
        This is a generator for streaming response from the OpenAI completions API
        """
        for line in response:
            chunk = line["choices"][0].get("delta", {}).get("content", "")
            yield chunk
        super().__init__(config=config, llm=llm, db=db, embedder=embedder, system_prompt=system_prompt)

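With the provider-specific branches removed, CustomApp becomes pure composition: the caller hands it an LLM, a vector database and an embedder. A sketch of the new call shape, mirroring how the bot base class later in this diff constructs it:

    from embedchain import CustomApp
    from embedchain.config import CustomAppConfig
    from embedchain.embedder.openai_embedder import OpenAiEmbedder
    from embedchain.llm.openai_llm import OpenAiLlm
    from embedchain.vectordb.chroma_db import ChromaDB

    app = CustomApp(
        config=CustomAppConfig(),
        llm=OpenAiLlm(),            # any BaseLlm subclass
        db=ChromaDB(),              # any BaseVectorDB subclass
        embedder=OpenAiEmbedder(),  # any BaseEmbedder subclass
    )
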
@@ -1,13 +1,15 @@
import os
from typing import Optional

from langchain.llms import Replicate

from embedchain.config import AppConfig, ChatConfig
from embedchain.embedchain import EmbedChain
from embedchain.apps.CustomApp import CustomApp
from embedchain.config import CustomAppConfig
from embedchain.embedder.openai_embedder import OpenAiEmbedder
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.llm.llama2_llm import Llama2Llm
from embedchain.vectordb.chroma_db import ChromaDB


class Llama2App(EmbedChain):
@register_deserializable
class Llama2App(CustomApp):
    """
    The EmbedChain Llama2App class.
    Has two functions: add and query.
@@ -16,25 +18,15 @@ class Llama2App(EmbedChain):
    query(query): finds answer to the given query using vector database and LLM.
    """

    def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None):
    def __init__(self, config: CustomAppConfig = None, system_prompt: Optional[str] = None):
        """
        :param config: AppConfig instance to load as configuration. Optional.
        :param config: CustomAppConfig instance to load as configuration. Optional.
        :param system_prompt: System prompt string. Optional.
        """
        if "REPLICATE_API_TOKEN" not in os.environ:
            raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")

        if config is None:
            config = AppConfig()
            config = CustomAppConfig()

        super().__init__(config, system_prompt)

    def get_llm_model_answer(self, prompt, config: ChatConfig = None):
        # TODO: Move the model and other inputs into config
        if self.system_prompt or config.system_prompt:
            raise ValueError("Llama2App does not support `system_prompt`")
        llm = Replicate(
            model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
            input={"temperature": 0.75, "max_length": 500, "top_p": 1},
        super().__init__(
            config=config, llm=Llama2Llm(), db=ChromaDB(), embedder=OpenAiEmbedder(), system_prompt=system_prompt
        )
        return llm(prompt)

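Llama2App is now a thin CustomApp preset: the Replicate call moves into Llama2Llm and the app only wires the pieces together. Usage stays roughly as before; a sketch, assuming a valid Replicate token:

    import os
    os.environ["REPLICATE_API_TOKEN"] = "r8_..."  # still required before instantiation

    from embedchain import Llama2App

    app = Llama2App()  # equivalent to CustomApp(config, llm=Llama2Llm(), db=ChromaDB(), embedder=OpenAiEmbedder())
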
@@ -1,9 +1,13 @@
import logging
from typing import Iterable, Optional, Union
from typing import Optional

from embedchain.config import ChatConfig, OpenSourceAppConfig
from embedchain.config import (BaseEmbedderConfig, BaseLlmConfig,
                               ChromaDbConfig, OpenSourceAppConfig)
from embedchain.embedchain import EmbedChain
from embedchain.embedder.gpt4all_embedder import GPT4AllEmbedder
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.llm.gpt4all_llm import GPT4ALLLlm
from embedchain.vectordb.chroma_db import ChromaDB

gpt4all_model = None

@@ -20,7 +24,12 @@ class OpenSourceApp(EmbedChain):
    query(query): finds answer to the given query using vector database and LLM.
    """

    def __init__(self, config: OpenSourceAppConfig = None, system_prompt: Optional[str] = None):
    def __init__(
        self,
        config: OpenSourceAppConfig = None,
        chromadb_config: Optional[ChromaDbConfig] = None,
        system_prompt: Optional[str] = None,
    ):
        """
        :param config: OpenSourceAppConfig instance to load as configuration. Optional.
        `ef` defaults to open source.
@@ -30,42 +39,19 @@ class OpenSourceApp(EmbedChain):
        if not config:
            config = OpenSourceAppConfig()

        if not isinstance(config, OpenSourceAppConfig):
            raise ValueError(
                "OpenSourceApp needs a OpenSourceAppConfig passed to it. "
                "You can import it with `from embedchain.config import OpenSourceAppConfig`"
            )

        if not config.model:
            raise ValueError("OpenSourceApp needs a model to be instantiated. Maybe you passed the wrong config type?")

        self.instance = OpenSourceApp._get_instance(config.model)

        logging.info("Successfully loaded open source embedding model.")
        super().__init__(config, system_prompt)

    def get_llm_model_answer(self, prompt, config: ChatConfig):
        return self._get_gpt4all_answer(prompt=prompt, config=config)
        llm = GPT4ALLLlm(config=BaseLlmConfig(model="orca-mini-3b.ggmlv3.q4_0.bin"))
        embedder = GPT4AllEmbedder(config=BaseEmbedderConfig(model="all-MiniLM-L6-v2"))
        database = ChromaDB(config=chromadb_config)

    @staticmethod
    def _get_instance(model):
        try:
            from gpt4all import GPT4All
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The GPT4All python package is not installed. Please install it with `pip install embedchain[opensource]`" # noqa E501
            ) from None

        return GPT4All(model)

    def _get_gpt4all_answer(self, prompt: str, config: ChatConfig) -> Union[str, Iterable]:
        if config.model and config.model != self.config.model:
            raise RuntimeError(
                "OpenSourceApp does not support switching models at runtime. Please create a new app instance."
            )

        if self.system_prompt or config.system_prompt:
            raise ValueError("OpenSourceApp does not support `system_prompt`")

        response = self.instance.generate(
            prompt=prompt,
            streaming=config.stream,
            top_p=config.top_p,
            max_tokens=config.max_tokens,
            temp=config.temperature,
        )
        return response
        super().__init__(config, llm=llm, db=database, embedder=embedder, system_prompt=system_prompt)

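The GPT4All plumbing moves into GPT4ALLLlm and GPT4AllEmbedder; OpenSourceApp now just picks local defaults. A usage sketch (the model names are the defaults shown above, the directory value is an illustrative assumption):

    from embedchain import OpenSourceApp
    from embedchain.config import ChromaDbConfig, OpenSourceAppConfig

    app = OpenSourceApp(
        config=OpenSourceAppConfig(),                    # LLM defaults to orca-mini-3b.ggmlv3.q4_0.bin
        chromadb_config=ChromaDbConfig(dir="local_db"),  # optional
    )
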
@@ -2,9 +2,10 @@ from string import Template

from embedchain.apps.App import App
from embedchain.apps.OpenSourceApp import OpenSourceApp
from embedchain.config import ChatConfig, QueryConfig
from embedchain.config import BaseLlmConfig
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
from embedchain.config.QueryConfig import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
from embedchain.config.llm.base_llm_config import (DEFAULT_PROMPT,
                                                    DEFAULT_PROMPT_WITH_HISTORY)
from embedchain.helper_classes.json_serializable import register_deserializable


@@ -23,7 +24,7 @@ class EmbedChainPersonApp:
        self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style." # noqa:E501
        super().__init__(config)

    def add_person_template_to_config(self, default_prompt: str, config: ChatConfig = None):
    def add_person_template_to_config(self, default_prompt: str, config: BaseLlmConfig = None):
        """
        This method checks if the config object contains a prompt template
        if yes it adds the person prompt to it and return the updated config
@@ -44,7 +45,7 @@ class EmbedChainPersonApp:
            config.template = template
        else:
            # if no config is present at all, initialize the config with person prompt and default template
            config = QueryConfig(
            config = BaseLlmConfig(
                template=template,
            )

@@ -58,11 +59,11 @@ class PersonApp(EmbedChainPersonApp, App):
    Extends functionality from EmbedChainPersonApp and App
    """

    def query(self, input_query, config: QueryConfig = None, dry_run=False):
    def query(self, input_query, config: BaseLlmConfig = None, dry_run=False):
        config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None)
        return super().query(input_query, config, dry_run, where=None)

    def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
    def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
        config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
        return super().chat(input_query, config, dry_run, where)

@@ -74,10 +75,10 @@ class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
    Extends functionality from EmbedChainPersonApp and OpenSourceApp
    """

    def query(self, input_query, config: QueryConfig = None, dry_run=False):
    def query(self, input_query, config: BaseLlmConfig = None, dry_run=False):
        config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
        return super().query(input_query, config, dry_run)

    def chat(self, input_query, config: ChatConfig = None, dry_run=False):
    def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False):
        config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
        return super().chat(input_query, config, dry_run)

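PersonApp keeps its behaviour but now builds BaseLlmConfig objects instead of QueryConfig/ChatConfig. A sketch of the call path; the constructor arguments and values are illustrative assumptions based on the hunk above:

    from embedchain.apps.PersonApp import PersonApp
    from embedchain.config import AppConfig, BaseLlmConfig

    app = PersonApp("Yoda", config=AppConfig())
    # per-call options now use the unified LLM config class
    print(app.query("Explain embeddings", config=BaseLlmConfig(temperature=0.1)))
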
@@ -1,26 +1,25 @@
from embedchain import CustomApp
from embedchain.config import AddConfig, CustomAppConfig, QueryConfig
from embedchain.config import AddConfig, CustomAppConfig, LlmConfig
from embedchain.embedder.openai_embedder import OpenAiEmbedder
from embedchain.helper_classes.json_serializable import (
    JSONSerializable, register_deserializable)
from embedchain.models import EmbeddingFunctions, Providers
from embedchain.llm.openai_llm import OpenAiLlm
from embedchain.vectordb.chroma_db import ChromaDB


@register_deserializable
class BaseBot(JSONSerializable):
    def __init__(self, app_config=None):
        if app_config is None:
            app_config = CustomAppConfig(embedding_fn=EmbeddingFunctions.OPENAI, provider=Providers.OPENAI)
        self.app_config = app_config
        self.app = CustomApp(config=self.app_config)
    def __init__(self):
        self.app = CustomApp(config=CustomAppConfig(), llm=OpenAiLlm(), db=ChromaDB(), embedder=OpenAiEmbedder())

    def add(self, data, config: AddConfig = None):
        """Add data to the bot"""
        config = config if config else AddConfig()
        self.app.add(data, config=config)

    def query(self, query, config: QueryConfig = None):
    def query(self, query, config: LlmConfig = None):
        """Query bot"""
        config = config if config else QueryConfig()
        config = config
        return self.app.query(query, config=config)

    def start(self):

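The bot base class now assembles a CustomApp from explicit parts and passes an LlmConfig straight through to query. A usage sketch; the data source and query string are placeholders:

    from embedchain.bots.base import BaseBot
    from embedchain.config import LlmConfig

    bot = BaseBot()
    bot.add("https://example.com/docs.html")
    print(bot.query("How do I get started?", config=LlmConfig(number_documents=3)))
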
@@ -6,6 +6,8 @@ import discord
from discord import app_commands
from discord.ext import commands

from embedchain.helper_classes.json_serializable import register_deserializable

from .base import BaseBot

intents = discord.Intents.default()
@@ -17,6 +19,7 @@ tree = app_commands.CommandTree(client)
# https://discord.com/api/oauth2/authorize?client_id={DISCORD_CLIENT_ID}&permissions=2048&scope=bot


@register_deserializable
class DiscordBot(BaseBot):
    def __init__(self, *args, **kwargs):
        BaseBot.__init__(self, *args, **kwargs)

@@ -5,7 +5,6 @@ from typing import List, Optional

from fastapi_poe import PoeBot, run

from embedchain.config import QueryConfig
from embedchain.helper_classes.json_serializable import register_deserializable

from .base import BaseBot
@@ -46,7 +45,6 @@ class PoeBot(BaseBot, PoeBot):
            )
        except Exception as e:
            logging.error(f"Error when processing the chat history. Message is being sent without history. Error: {e}")
            logging.warning(history)
        answer = self.handle_message(last_message, history)
        yield self.text_event(answer)

@@ -69,8 +67,8 @@ class PoeBot(BaseBot, PoeBot):

    def ask_bot(self, message, history: List[str]):
        try:
            config = QueryConfig(history=history)
            response = self.query(message, config)
            self.app.llm.set_history(history=history)
            response = self.query(message)
        except Exception:
            logging.exception(f"Failed to query {message}.")
            response = "An error occurred. Please try again!"

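Chat history is no longer carried inside QueryConfig; it is set on the app's llm component before querying. A condensed sketch of the new pattern (the history strings are placeholders):

    # inside a bot handler (sketch)
    self.app.llm.set_history(history=["user: hi", "bot: hello"])
    answer = self.query(message)
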
@@ -1,92 +0,0 @@
|
||||
from string import Template
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.QueryConfig import QueryConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
DEFAULT_PROMPT = """
|
||||
You are a chatbot having a conversation with a human. You are given chat
|
||||
history and context.
|
||||
You need to answer the query considering context, chat history and your knowledge base. If you don't know the answer or the answer is neither contained in the context nor in history, then simply say "I don't know".
|
||||
|
||||
$context
|
||||
|
||||
History: $history
|
||||
|
||||
Query: $query
|
||||
|
||||
Helpful Answer:
|
||||
""" # noqa:E501
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class ChatConfig(QueryConfig):
|
||||
"""
|
||||
Config for the `chat` method, inherits from `QueryConfig`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
number_documents=None,
|
||||
template: Template = None,
|
||||
model=None,
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
top_p=None,
|
||||
stream: bool = False,
|
||||
deployment_name=None,
|
||||
system_prompt: Optional[str] = None,
|
||||
where=None,
|
||||
):
|
||||
"""
|
||||
Initializes the ChatConfig instance.
|
||||
|
||||
:param number_documents: Number of documents to pull from the database as
|
||||
context.
|
||||
:param template: Optional. The `Template` instance to use as a template for
|
||||
prompt.
|
||||
:param model: Optional. Controls the OpenAI model used.
|
||||
:param temperature: Optional. Controls the randomness of the model's output.
|
||||
Higher values (closer to 1) make output more random,lower values make it more
|
||||
deterministic.
|
||||
:param max_tokens: Optional. Controls how many tokens are generated.
|
||||
:param top_p: Optional. Controls the diversity of words.Higher values
|
||||
(closer to 1) make word selection more diverse, lower values make words less
|
||||
diverse.
|
||||
:param stream: Optional. Control if response is streamed back to the user
|
||||
:param deployment_name: t.b.a.
|
||||
:param system_prompt: Optional. System prompt string.
|
||||
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||
:raises ValueError: If the template is not valid as template should contain
|
||||
$context and $query and $history
|
||||
"""
|
||||
if template is None:
|
||||
template = DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
# History is set as 0 to ensure that there is always a history, that way,
|
||||
# there don't have to be two templates. Having two templates would make it
|
||||
# complicated because the history is not user controlled.
|
||||
super().__init__(
|
||||
number_documents=number_documents,
|
||||
template=template,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
history=[0],
|
||||
stream=stream,
|
||||
deployment_name=deployment_name,
|
||||
system_prompt=system_prompt,
|
||||
where=where,
|
||||
)
|
||||
|
||||
def set_history(self, history):
|
||||
"""
|
||||
Chat history is not user provided and not set at initialization time
|
||||
|
||||
:param history: (string) history to set
|
||||
"""
|
||||
self.history = history
|
||||
return
|
||||
@@ -1,9 +1,13 @@
from .AddConfig import AddConfig, ChunkerConfig # noqa: F401
from .apps.AppConfig import AppConfig # noqa: F401
from .apps.CustomAppConfig import CustomAppConfig # noqa: F401
from .apps.OpenSourceAppConfig import OpenSourceAppConfig # noqa: F401
from .BaseConfig import BaseConfig # noqa: F401
from .ChatConfig import ChatConfig # noqa: F401
from .QueryConfig import QueryConfig # noqa: F401
from .vectordbs.ElasticsearchDBConfig import \
    ElasticsearchDBConfig # noqa: F401
# flake8: noqa: F401

from .AddConfig import AddConfig, ChunkerConfig
from .apps.AppConfig import AppConfig
from .apps.CustomAppConfig import CustomAppConfig
from .apps.OpenSourceAppConfig import OpenSourceAppConfig
from .BaseConfig import BaseConfig
from .embedder.BaseEmbedderConfig import BaseEmbedderConfig
from .embedder.BaseEmbedderConfig import BaseEmbedderConfig as EmbedderConfig
from .llm.base_llm_config import BaseLlmConfig
from .llm.base_llm_config import BaseLlmConfig as LlmConfig
from .vectordbs.ChromaDbConfig import ChromaDbConfig
from .vectordbs.ElasticsearchDBConfig import ElasticsearchDBConfig

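The package root now exposes the new config classes, with short aliases that refer to the same objects:

    from embedchain.config import (BaseEmbedderConfig, BaseLlmConfig,
                                   EmbedderConfig, LlmConfig)

    assert LlmConfig is BaseLlmConfig
    assert EmbedderConfig is BaseEmbedderConfig
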
@@ -1,14 +1,5 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
from chromadb.utils import embedding_functions
|
||||
except RuntimeError:
|
||||
from embedchain.utils import use_pysqlite3
|
||||
|
||||
use_pysqlite3()
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
from .BaseAppConfig import BaseAppConfig
|
||||
@@ -23,44 +14,14 @@ class AppConfig(BaseAppConfig):
|
||||
def __init__(
|
||||
self,
|
||||
log_level=None,
|
||||
host=None,
|
||||
port=None,
|
||||
id=None,
|
||||
collection_name=None,
|
||||
collect_metrics: Optional[bool] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||
:param host: Optional. Hostname for the database server.
|
||||
:param port: Optional. Port for the database server.
|
||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
|
||||
"""
|
||||
super().__init__(
|
||||
log_level=log_level,
|
||||
embedding_fn=AppConfig.default_embedding_function(),
|
||||
host=host,
|
||||
port=port,
|
||||
id=id,
|
||||
collection_name=collection_name,
|
||||
collect_metrics=collect_metrics,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def default_embedding_function():
|
||||
"""
|
||||
Sets embedding function to default (`text-embedding-ada-002`).
|
||||
|
||||
:raises ValueError: If the template is not valid as template should contain
|
||||
$context and $query
|
||||
:returns: The default embedding function for the app class.
|
||||
"""
|
||||
if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
|
||||
raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided") # noqa:E501
|
||||
return embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
organization_id=os.getenv("OPENAI_ORGANIZATION"),
|
||||
model_name="text-embedding-ada-002",
|
||||
)
|
||||
super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, collection_name=collection_name)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.BaseConfig import BaseConfig
|
||||
from embedchain.config.vectordbs import ElasticsearchDBConfig
|
||||
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||
from embedchain.models import VectorDatabases, VectorDimensions
|
||||
from embedchain.vectordb.base_vector_db import BaseVectorDB
|
||||
|
||||
|
||||
class BaseAppConfig(BaseConfig, JSONSerializable):
|
||||
@@ -14,81 +14,38 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
|
||||
def __init__(
|
||||
self,
|
||||
log_level=None,
|
||||
embedding_fn=None,
|
||||
db=None,
|
||||
host=None,
|
||||
port=None,
|
||||
db: Optional[BaseVectorDB] = None,
|
||||
id=None,
|
||||
collection_name=None,
|
||||
collect_metrics: bool = True,
|
||||
db_type: VectorDatabases = None,
|
||||
vector_dim: VectorDimensions = None,
|
||||
es_config: ElasticsearchDBConfig = None,
|
||||
chroma_settings: dict = {},
|
||||
collection_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||
:param embedding_fn: Embedding function to use.
|
||||
:param db: Optional. (Vector) database instance to use for embeddings.
|
||||
:param host: Optional. Hostname for the database server.
|
||||
:param port: Optional. Port for the database server.
|
||||
:param db: Optional. (Vector) database instance to use for embeddings. Deprecated in favor of app(..., db).
|
||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
|
||||
:param db_type: Optional. type of Vector database to use
|
||||
:param vector_dim: Vector dimension generated by embedding fn
|
||||
:param db_type: Optional. Initializes a default vector database of the given type.
|
||||
Using the `db` argument is preferred.
|
||||
:param es_config: Optional. elasticsearch database config to be used for connection
|
||||
:param chroma_settings: Optional. Chroma settings for connection.
|
||||
:param collection_name: Optional. Default collection name.
|
||||
It's recommended to use app.set_collection_name() instead.
|
||||
"""
|
||||
self._setup_logging(log_level)
|
||||
self.collection_name = collection_name if collection_name else "embedchain_store"
|
||||
self.db = BaseAppConfig.get_db(
|
||||
db=db,
|
||||
embedding_fn=embedding_fn,
|
||||
host=host,
|
||||
port=port,
|
||||
db_type=db_type,
|
||||
vector_dim=vector_dim,
|
||||
collection_name=self.collection_name,
|
||||
es_config=es_config,
|
||||
chroma_settings=chroma_settings,
|
||||
)
|
||||
self.id = id
|
||||
self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False
|
||||
return
|
||||
self.collection_name = collection_name
|
||||
|
||||
@staticmethod
|
||||
def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config, chroma_settings):
|
||||
"""
|
||||
Get db based on db_type, db with default database (`ChromaDb`)
|
||||
:param Optional. (Vector) database to use for embeddings.
|
||||
:param embedding_fn: Embedding function to use in database.
|
||||
:param host: Optional. Hostname for the database server.
|
||||
:param port: Optional. Port for the database server.
|
||||
:param db_type: Optional. db type to use. Supported values (`es`, `chroma`)
|
||||
:param vector_dim: Vector dimension generated by embedding fn
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
:param es_config: Optional. elasticsearch database config to be used for connection
|
||||
:raises ValueError: BaseAppConfig knows no default embedding function.
|
||||
:returns: database instance
|
||||
"""
|
||||
if db:
|
||||
return db
|
||||
|
||||
if embedding_fn is None:
|
||||
raise ValueError("ChromaDb cannot be instantiated without an embedding function")
|
||||
|
||||
if db_type == VectorDatabases.ELASTICSEARCH:
|
||||
from embedchain.vectordb.elasticsearch_db import ElasticsearchDB
|
||||
|
||||
return ElasticsearchDB(
|
||||
embedding_fn=embedding_fn, vector_dim=vector_dim, collection_name=collection_name, es_config=es_config
|
||||
self._db = db
|
||||
logging.warning(
|
||||
"DEPRECATION WARNING: Please supply the database as the second parameter during app init. "
|
||||
"Such as `app(config=config, db=db)`."
|
||||
)
|
||||
|
||||
from embedchain.vectordb.chroma_db import ChromaDB
|
||||
|
||||
return ChromaDB(embedding_fn=embedding_fn, host=host, port=port, chroma_settings=chroma_settings)
|
||||
if collection_name:
|
||||
logging.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
|
||||
return
|
||||
|
||||
def _setup_logging(self, debug_level):
|
||||
level = logging.WARNING # Default level
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from chromadb.api.types import Documents, Embeddings
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from embedchain.config.vectordbs import ElasticsearchDBConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
from embedchain.models import (EmbeddingFunctions, Providers, VectorDatabases,
|
||||
VectorDimensions)
|
||||
|
||||
from .BaseAppConfig import BaseAppConfig
|
||||
|
||||
@@ -22,123 +18,23 @@ class CustomAppConfig(BaseAppConfig):
|
||||
def __init__(
|
||||
self,
|
||||
log_level=None,
|
||||
embedding_fn: EmbeddingFunctions = None,
|
||||
embedding_fn_model=None,
|
||||
db=None,
|
||||
host=None,
|
||||
port=None,
|
||||
id=None,
|
||||
collection_name=None,
|
||||
provider: Providers = None,
|
||||
open_source_app_config=None,
|
||||
deployment_name=None,
|
||||
collect_metrics: Optional[bool] = None,
|
||||
db_type: VectorDatabases = None,
|
||||
es_config: ElasticsearchDBConfig = None,
|
||||
chroma_settings: dict = {},
|
||||
collection_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||
:param embedding_fn: Optional. Embedding function to use.
|
||||
:param embedding_fn_model: Optional. Model name to use for embedding function.
|
||||
:param db: Optional. (Vector) database to use for embeddings.
|
||||
:param host: Optional. Hostname for the database server.
|
||||
:param port: Optional. Port for the database server.
|
||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
:param provider: Optional. (Providers): LLM Provider to use.
|
||||
:param open_source_app_config: Optional. Config instance needed for open source apps.
|
||||
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
|
||||
:param db_type: Optional. type of Vector database to use.
|
||||
:param es_config: Optional. elasticsearch database config to be used for connection
|
||||
:param chroma_settings: Optional. Chroma settings for connection.
|
||||
:param collection_name: Optional. Default collection name.
|
||||
It's recommended to use app.set_collection_name() instead.
|
||||
"""
|
||||
if provider:
|
||||
self.provider = provider
|
||||
else:
|
||||
raise ValueError("CustomApp must have a provider assigned.")
|
||||
|
||||
self.open_source_app_config = open_source_app_config
|
||||
|
||||
super().__init__(
|
||||
log_level=log_level,
|
||||
embedding_fn=CustomAppConfig.embedding_function(
|
||||
embedding_function=embedding_fn, model=embedding_fn_model, deployment_name=deployment_name
|
||||
),
|
||||
db=db,
|
||||
host=host,
|
||||
port=port,
|
||||
id=id,
|
||||
collection_name=collection_name,
|
||||
collect_metrics=collect_metrics,
|
||||
db_type=db_type,
|
||||
vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn),
|
||||
es_config=es_config,
|
||||
chroma_settings=chroma_settings,
|
||||
log_level=log_level, db=db, id=id, collect_metrics=collect_metrics, collection_name=collection_name
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def langchain_default_concept(embeddings: Any):
|
||||
"""
|
||||
Langchains default function layout for embeddings.
|
||||
"""
|
||||
|
||||
def embed_function(texts: Documents) -> Embeddings:
|
||||
return embeddings.embed_documents(texts)
|
||||
|
||||
return embed_function
|
||||
|
||||
@staticmethod
|
||||
def embedding_function(embedding_function: EmbeddingFunctions, model: str = None, deployment_name: str = None):
|
||||
if not isinstance(embedding_function, EmbeddingFunctions):
|
||||
raise ValueError(
|
||||
f"Invalid option: '{embedding_function}'. Expecting one of the following options: {list(map(lambda x: x.value, EmbeddingFunctions))}" # noqa: E501
|
||||
)
|
||||
|
||||
if embedding_function == EmbeddingFunctions.OPENAI:
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
if model:
|
||||
embeddings = OpenAIEmbeddings(model=model)
|
||||
else:
|
||||
if deployment_name:
|
||||
embeddings = OpenAIEmbeddings(deployment=deployment_name)
|
||||
else:
|
||||
embeddings = OpenAIEmbeddings()
|
||||
return CustomAppConfig.langchain_default_concept(embeddings)
|
||||
|
||||
elif embedding_function == EmbeddingFunctions.HUGGING_FACE:
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
embeddings = HuggingFaceEmbeddings(model_name=model)
|
||||
return CustomAppConfig.langchain_default_concept(embeddings)
|
||||
|
||||
elif embedding_function == EmbeddingFunctions.VERTEX_AI:
|
||||
from langchain.embeddings import VertexAIEmbeddings
|
||||
|
||||
embeddings = VertexAIEmbeddings(model_name=model)
|
||||
return CustomAppConfig.langchain_default_concept(embeddings)
|
||||
|
||||
elif embedding_function == EmbeddingFunctions.GPT4ALL:
|
||||
# Note: We could use langchains GPT4ALL embedding, but it's not available in all versions.
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model)
|
||||
|
||||
@staticmethod
|
||||
def get_vector_dimension(embedding_function: EmbeddingFunctions):
|
||||
if not isinstance(embedding_function, EmbeddingFunctions):
|
||||
raise ValueError(f"Invalid option: '{embedding_function}'.")
|
||||
|
||||
if embedding_function == EmbeddingFunctions.OPENAI:
|
||||
return VectorDimensions.OPENAI.value
|
||||
|
||||
elif embedding_function == EmbeddingFunctions.HUGGING_FACE:
|
||||
return VectorDimensions.HUGGING_FACE.value
|
||||
|
||||
elif embedding_function == EmbeddingFunctions.VERTEX_AI:
|
||||
return VectorDimensions.VERTEX_AI.value
|
||||
|
||||
elif embedding_function == EmbeddingFunctions.GPT4ALL:
|
||||
return VectorDimensions.GPT4ALL.value
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
from .BaseAppConfig import BaseAppConfig
|
||||
@@ -16,47 +14,21 @@ class OpenSourceAppConfig(BaseAppConfig):
|
||||
def __init__(
|
||||
self,
|
||||
log_level=None,
|
||||
host=None,
|
||||
port=None,
|
||||
id=None,
|
||||
collection_name=None,
|
||||
collect_metrics: Optional[bool] = None,
|
||||
model=None,
|
||||
collection_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
:param host: Optional. Hostname for the database server.
|
||||
:param port: Optional. Port for the database server.
|
||||
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
|
||||
:param model: Optional. GPT4ALL uses the model to instantiate the class.
|
||||
So unlike `App`, it has to be provided before querying.
|
||||
:param collection_name: Optional. Default collection name.
|
||||
It's recommended to use app.db.set_collection_name() instead.
|
||||
"""
|
||||
self.model = model or "orca-mini-3b.ggmlv3.q4_0.bin"
|
||||
|
||||
super().__init__(
|
||||
log_level=log_level,
|
||||
embedding_fn=OpenSourceAppConfig.default_embedding_function(),
|
||||
host=host,
|
||||
port=port,
|
||||
id=id,
|
||||
collection_name=collection_name,
|
||||
collect_metrics=collect_metrics,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def default_embedding_function():
|
||||
"""
|
||||
Sets embedding function to default (`all-MiniLM-L6-v2`).
|
||||
|
||||
:returns: The default embedding function
|
||||
"""
|
||||
try:
|
||||
return embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
|
||||
except ValueError as e:
|
||||
print(e)
|
||||
raise ModuleNotFoundError(
|
||||
"The open source app requires extra dependencies. Install with `pip install embedchain[opensource]`"
|
||||
) from None
|
||||
super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, collection_name=collection_name)
|
||||
|
||||
embedchain/config/embedder/BaseEmbedderConfig.py (new file, +10 lines)
@@ -0,0 +1,10 @@
from typing import Optional

from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class BaseEmbedderConfig:
    def __init__(self, model: Optional[str] = None, deployment_name: Optional[str] = None):
        self.model = model
        self.deployment_name = deployment_name
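BaseEmbedderConfig is a plain value object consumed by the embedder classes. A sketch of how the apps above use it; the Azure-style deployment_name example is an assumption:

    from embedchain.config import BaseEmbedderConfig
    from embedchain.embedder.openai_embedder import OpenAiEmbedder

    embedder = OpenAiEmbedder(config=BaseEmbedderConfig(model="text-embedding-ada-002"))
    azure_cfg = BaseEmbedderConfig(deployment_name="my-ada-deployment")  # assumed example
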
embedchain/config/embedder/__init__.py (new file, empty)
embedchain/config/llm/__init__.py (new file, empty)
@@ -50,7 +50,7 @@ history_re = re.compile(r"\$\{*history\}*")


@register_deserializable
class QueryConfig(BaseConfig):
class BaseLlmConfig(BaseConfig):
    """
    Config for the `query` method.
    """
@@ -63,7 +63,6 @@ class QueryConfig(BaseConfig):
        temperature=None,
        max_tokens=None,
        top_p=None,
        history=None,
        stream: bool = False,
        deployment_name=None,
        system_prompt: Optional[str] = None,
@@ -84,7 +83,6 @@ class QueryConfig(BaseConfig):
        :param top_p: Optional. Controls the diversity of words. Higher values
        (closer to 1) make word selection more diverse, lower values make words less
        diverse.
        :param history: Optional. A list of strings to consider as history.
        :param stream: Optional. Control if response is streamed back to user
        :param deployment_name: t.b.a.
        :param system_prompt: Optional. System prompt string.
@@ -97,19 +95,8 @@ class QueryConfig(BaseConfig):
        else:
            self.number_documents = number_documents

        if not history:
            self.history = None
        else:
            if len(history) == 0:
                self.history = None
            else:
                self.history = history

        if template is None:
            if self.history is None:
                template = DEFAULT_PROMPT_TEMPLATE
            else:
                template = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE
            template = DEFAULT_PROMPT_TEMPLATE

        self.temperature = temperature if temperature else 0
        self.max_tokens = max_tokens if max_tokens else 1000
@@ -121,10 +108,7 @@ class QueryConfig(BaseConfig):
        if self.validate_template(template):
            self.template = template
        else:
            if self.history is None:
                raise ValueError("`template` should have `query` and `context` keys")
            else:
                raise ValueError("`template` should have `query`, `context` and `history` keys")
            raise ValueError("`template` should have `query` and `context` keys and potentially `history` (if used).")

        if not isinstance(stream, bool):
            raise ValueError("`stream` should be bool")
@@ -138,11 +122,13 @@ class QueryConfig(BaseConfig):
        :param template: the template to validate
        :return: Boolean, valid (true) or invalid (false)
        """
        if self.history is None:
            return re.search(query_re, template.template) and re.search(context_re, template.template)
        else:
            return (
                re.search(query_re, template.template)
                and re.search(context_re, template.template)
                and re.search(history_re, template.template)
            )
        return re.search(query_re, template.template) and re.search(context_re, template.template)

    def _validate_template_history(self, template: Template):
        """
        validate the history template for history

        :param template: the template to validate
        :return: Boolean, valid (true) or invalid (false)
        """
        return re.search(history_re, template.template)
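BaseLlmConfig keeps QueryConfig's template validation but no longer branches on history: a template must always contain $context and $query, and may optionally use $history. A sketch:

    from string import Template
    from embedchain.config import BaseLlmConfig

    ok = BaseLlmConfig(template=Template("Use this context: $context\nAnswer: $query"))
    # A template missing $context or $query raises
    # ValueError("`template` should have `query` and `context` keys and potentially `history` (if used).")
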
embedchain/config/vectordbs/BaseVectorDbConfig.py (new file, +17 lines)
@@ -0,0 +1,17 @@
from typing import Optional

from embedchain.config.BaseConfig import BaseConfig


class BaseVectorDbConfig(BaseConfig):
    def __init__(
        self,
        collection_name: Optional[str] = None,
        dir: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[str] = None,
    ):
        self.collection_name = collection_name or "embedchain_store"
        self.dir = dir or "db"
        self.host = host
        self.port = port
embedchain/config/vectordbs/ChromaDbConfig.py (new file, +21 lines)
@@ -0,0 +1,21 @@
from typing import Optional

from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class ChromaDbConfig(BaseVectorDbConfig):
    def __init__(
        self,
        collection_name: Optional[str] = None,
        dir: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[str] = None,
        chroma_settings: Optional[dict] = None,
    ):
        """
        :param chroma_settings: Optional. Chroma settings for connection.
        """
        self.chroma_settings = chroma_settings
        super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)
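The vector database now carries its own config instead of reading host and port from the app config. A sketch of passing it through; the values are placeholders:

    from embedchain.config import ChromaDbConfig
    from embedchain.vectordb.chroma_db import ChromaDB

    db = ChromaDB(config=ChromaDbConfig(collection_name="docs", dir="db", host="localhost", port="8000"))
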
@@ -1,17 +1,25 @@
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union

from embedchain.config.BaseConfig import BaseConfig
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class ElasticsearchDBConfig(BaseConfig):
    """
    Config to initialize an elasticsearch client.
    :param es_url. elasticsearch url or list of nodes url to be used for connection
    :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
    """

    def __init__(self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
class ElasticsearchDBConfig(BaseVectorDbConfig):
    def __init__(
        self,
        collection_name: Optional[str] = None,
        dir: Optional[str] = None,
        es_url: Union[str, List[str]] = None,
        **ES_EXTRA_PARAMS: Dict[str, any],
    ):
        """
        Config to initialize an elasticsearch client.
        :param es_url. elasticsearch url or list of nodes url to be used for connection
        :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
        """
        # self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
        self.ES_URL = es_url
        self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS

        super().__init__(collection_name=collection_name, dir=dir)

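ElasticsearchDBConfig follows the same pattern; extra keyword arguments are collected into ES_EXTRA_PARAMS and forwarded to the elasticsearch client. A sketch, where the auth keyword is only an example of what you might forward:

    from embedchain.config import ElasticsearchDBConfig

    es_config = ElasticsearchDBConfig(
        collection_name="docs",
        es_url="https://localhost:9200",
        basic_auth=("elastic", "changeme"),  # goes into ES_EXTRA_PARAMS
    )
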
@@ -11,30 +11,37 @@ from typing import Dict, Optional
import requests
from dotenv import load_dotenv
from langchain.docstore.document import Document
from langchain.memory import ConversationBufferMemory
from tenacity import retry, stop_after_attempt, wait_fixed

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, ChatConfig, QueryConfig
from embedchain.config import AddConfig, BaseLlmConfig
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
from embedchain.data_formatter import DataFormatter
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.helper_classes.json_serializable import JSONSerializable
from embedchain.llm.base_llm import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType
from embedchain.utils import detect_datatype
from embedchain.vectordb.base_vector_db import BaseVectorDB

load_dotenv()

ABS_PATH = os.getcwd()
DB_DIR = os.path.join(ABS_PATH, "db")
HOME_DIR = str(Path.home())
CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")


class EmbedChain(JSONSerializable):
    def __init__(self, config: BaseAppConfig, system_prompt: Optional[str] = None):
    def __init__(
        self,
        config: BaseAppConfig,
        llm: BaseLlm,
        db: BaseVectorDB = None,
        embedder: BaseEmbedder = None,
        system_prompt: Optional[str] = None,
    ):
        """
        Initializes the EmbedChain instance, sets up a vector DB client and
        creates a collection.
@@ -44,17 +51,40 @@ class EmbedChain(JSONSerializable):
        """

        self.config = config
        self.system_prompt = system_prompt
        self.collection = self.config.db._get_or_create_collection(self.config.collection_name)
        self.db = self.config.db

        # Add subclasses
        ## Llm
        self.llm = llm
        ## Database
        # Database has support for config assignment for backwards compatibility
        if db is None and (not hasattr(self.config, "db") or self.config.db is None):
            raise ValueError("App requires Database.")
        self.db = db or self.config.db
        ## Embedder
        if embedder is None:
            raise ValueError("App requires Embedder.")
        self.embedder = embedder

        # Initialize database
        self.db._set_embedder(self.embedder)
        self.db._initialize()
        # Set collection name from app config for backwards compatibility.
        if config.collection_name:
            self.db.set_collection_name(config.collection_name)

        # Add variables that are "shortcuts"
        if system_prompt:
            self.llm.config.system_prompt = system_prompt

        # Attributes that aren't subclass related.
        self.user_asks = []
        self.is_docs_site_instance = False
        self.online = False
        self.memory = ConversationBufferMemory()

        # Send anonymous telemetry
        self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
        self.u_id = self._load_or_generate_user_id()
        # NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
        # if (self.config.collect_metrics):
        #     raise ConnectionRefusedError("Collection of metrics should not be allowed.")
        thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
        thread_telemetry.start()

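This is the seam every app now goes through: EmbedChain receives the three components, attaches the embedder to the database, and initializes it. A condensed sketch of what an app subclass does, using the names from the App hunk above:

    # inside an app subclass __init__ (sketch)
    llm = OpenAiLlm(config=llm_config)
    embedder = OpenAiEmbedder(config=BaseEmbedderConfig(model="text-embedding-ada-002"))
    database = ChromaDB(config=chromadb_config)
    super().__init__(config, llm, db=database, embedder=embedder, system_prompt=system_prompt)
    # EmbedChain then calls database._set_embedder(embedder) and database._initialize()
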
@@ -227,10 +257,10 @@ class EmbedChain(JSONSerializable):
            metadatas = new_metadatas

        # Count before, to calculate a delta in the end.
        chunks_before_addition = self.count()
        chunks_before_addition = self.db.count()

        self.db.add(documents=documents, metadatas=metadatas, ids=ids)
        count_new_chunks = self.count() - chunks_before_addition
        count_new_chunks = self.db.count() - chunks_before_addition
        print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
        return list(documents), metadatas, ids, count_new_chunks

@@ -244,13 +274,7 @@ class EmbedChain(JSONSerializable):
            )
        ]

    def get_llm_model_answer(self):
        """
        Usually implemented by child class
        """
        raise NotImplementedError

    def retrieve_from_database(self, input_query, config: QueryConfig, where=None):
    def retrieve_from_database(self, input_query, config: Optional[BaseLlmConfig] = None, where=None):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query
@@ -260,11 +284,12 @@ class EmbedChain(JSONSerializable):
        :param where: Optional. A dictionary of key-value pairs to filter the database results.
        :return: The content of the document that matched your query.
        """
        query_config = config or self.llm.config

        if where is not None:
            where = where
        elif config is not None and config.where is not None:
            where = config.where
        elif query_config is not None and query_config.where is not None:
            where = query_config.where
        else:
            where = {}

@@ -273,64 +298,21 @@ class EmbedChain(JSONSerializable):

        contents = self.db.query(
            input_query=input_query,
            n_results=config.number_documents,
            n_results=query_config.number_documents,
            where=where,
        )

        return contents

    def _append_search_and_context(self, context, web_search_result):
        return f"{context}\nWeb Search Result: {web_search_result}"

    def generate_prompt(self, input_query, contexts, config: QueryConfig, **kwargs):
        """
        Generates a prompt based on the given query and context, ready to be
        passed to an LLM

        :param input_query: The query to use.
        :param contexts: List of similar documents to the query used as context.
        :param config: Optional. The `QueryConfig` instance to use as
        configuration options.
        :return: The prompt
        """
        context_string = (" | ").join(contexts)
        web_search_result = kwargs.get("web_search_result", "")
        if web_search_result:
            context_string = self._append_search_and_context(context_string, web_search_result)
        if not config.history:
            prompt = config.template.substitute(context=context_string, query=input_query)
        else:
            prompt = config.template.substitute(context=context_string, query=input_query, history=config.history)
        return prompt

    def get_answer_from_llm(self, prompt, config: ChatConfig):
        """
        Gets an answer based on the given query and context by passing it
        to an LLM.

        :param query: The query to use.
        :param context: Similar documents to the query used as context.
        :return: The answer.
        """

        return self.get_llm_model_answer(prompt, config)

    def access_search_and_get_results(self, input_query):
        from langchain.tools import DuckDuckGoSearchRun

        search = DuckDuckGoSearchRun()
        logging.info(f"Access search to get answers for {input_query}")
        return search.run(input_query)

    def query(self, input_query, config: QueryConfig = None, dry_run=False, where=None):
    def query(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        :param input_query: The query to use.
        :param config: Optional. The `QueryConfig` instance to use as
        configuration options.
        :param config: Optional. The `LlmConfig` instance to use as configuration options.
        This is used for one method call. To persistently use a config, declare it during app init.
        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
        the LLM. The purpose is to test the prompt, not the response.
        You can use it to test your prompt, including the context provided
@@ -340,41 +322,16 @@ class EmbedChain(JSONSerializable):
        :param where: Optional. A dictionary of key-value pairs to filter the database results.
        :return: The answer to the query.
        """
        if config is None:
            config = QueryConfig()
        if self.is_docs_site_instance:
            config.template = DOCS_SITE_PROMPT_TEMPLATE
            config.number_documents = 5
        k = {}
        if self.online:
            k["web_search_result"] = self.access_search_and_get_results(input_query)
        contexts = self.retrieve_from_database(input_query, config, where)
        prompt = self.generate_prompt(input_query, contexts, config, **k)
        logging.info(f"Prompt: {prompt}")

        if dry_run:
            return prompt

        answer = self.get_answer_from_llm(prompt, config)
        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
        answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
|
||||
|
||||
# Send anonymous telemetry
|
||||
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("query",))
|
||||
thread_telemetry.start()
|
||||
|
||||
if isinstance(answer, str):
|
||||
logging.info(f"Answer: {answer}")
|
||||
return answer
|
||||
else:
|
||||
return self._stream_query_response(answer)
|
||||
return answer
|
||||
|
||||
def _stream_query_response(self, answer):
|
||||
streamed_answer = ""
|
||||
for chunk in answer:
|
||||
streamed_answer = streamed_answer + chunk
|
||||
yield chunk
|
||||
logging.info(f"Answer: {streamed_answer}")
|
||||
|
||||
def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
|
||||
def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
|
||||
"""
|
||||
Queries the vector database on the given input query.
|
||||
Gets relevant doc based on the query and then passes it to an
|
||||
@@ -382,8 +339,8 @@ class EmbedChain(JSONSerializable):
|
||||
|
||||
Maintains the whole conversation in memory.
|
||||
:param input_query: The query to use.
|
||||
:param config: Optional. The `ChatConfig` instance to use as
|
||||
configuration options.
|
||||
:param config: Optional. The `LlmConfig` instance to use as configuration options.
|
||||
This is used for one method call. To persistently use a config, declare it during app init.
|
||||
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
|
||||
the LLM. The purpose is to test the prompt, not the response.
|
||||
You can use it to test your prompt, including the context provided
|
||||
@@ -393,50 +350,14 @@ class EmbedChain(JSONSerializable):
|
||||
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||
:return: The answer to the query.
|
||||
"""
|
||||
if config is None:
|
||||
config = ChatConfig()
|
||||
if self.is_docs_site_instance:
|
||||
config.template = DOCS_SITE_PROMPT_TEMPLATE
|
||||
config.number_documents = 5
|
||||
k = {}
|
||||
if self.online:
|
||||
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
||||
contexts = self.retrieve_from_database(input_query, config, where)
|
||||
|
||||
chat_history = self.memory.load_memory_variables({})["history"]
|
||||
|
||||
if chat_history:
|
||||
config.set_history(chat_history)
|
||||
|
||||
prompt = self.generate_prompt(input_query, contexts, config, **k)
|
||||
logging.info(f"Prompt: {prompt}")
|
||||
|
||||
if dry_run:
|
||||
return prompt
|
||||
|
||||
answer = self.get_answer_from_llm(prompt, config)
|
||||
|
||||
self.memory.chat_memory.add_user_message(input_query)
|
||||
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
|
||||
answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
|
||||
|
||||
# Send anonymous telemetry
|
||||
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("chat",))
|
||||
thread_telemetry.start()
|
||||
|
||||
if isinstance(answer, str):
|
||||
self.memory.chat_memory.add_ai_message(answer)
|
||||
logging.info(f"Answer: {answer}")
|
||||
return answer
|
||||
else:
|
||||
# this is a streamed response and needs to be handled differently.
|
||||
return self._stream_chat_response(answer)
|
||||
|
||||
def _stream_chat_response(self, answer):
|
||||
streamed_answer = ""
|
||||
for chunk in answer:
|
||||
streamed_answer = streamed_answer + chunk
|
||||
yield chunk
|
||||
self.memory.chat_memory.add_ai_message(streamed_answer)
|
||||
logging.info(f"Answer: {streamed_answer}")
|
||||
return answer
|
||||
|
||||
def set_collection(self, collection_name):
|
||||
"""
|
||||
@@ -444,34 +365,36 @@ class EmbedChain(JSONSerializable):
|
||||
|
||||
:param collection_name: The name of the collection to use.
|
||||
"""
|
||||
self.collection = self.config.db._get_or_create_collection(collection_name)
|
||||
self.db.set_collection_name(collection_name)
|
||||
# Create the collection if it does not exist
|
||||
self.db._get_or_create_collection(collection_name)
|
||||
# TODO: Check whether it is necessary to assign to the `self.collection` attribute,
|
||||
# since the main purpose is the creation.
|
||||
|
||||
def count(self) -> int:
|
||||
"""
|
||||
Count the number of embeddings.
|
||||
|
||||
DEPRECATED IN FAVOR OF `db.count()`
|
||||
|
||||
:return: The number of embeddings.
|
||||
"""
|
||||
logging.warning("DEPRECATION WARNING: Please use `db.count()` instead of `count()`.")
|
||||
return self.db.count()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the database. Deletes all embeddings irreversibly.
|
||||
`App` does not have to be reinitialized after using this method.
|
||||
|
||||
DEPRECATED IN FAVOR OF `db.reset()`
|
||||
"""
|
||||
# Send anonymous telemetry
|
||||
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
|
||||
thread_telemetry.start()
|
||||
|
||||
collection_name = self.collection.name
|
||||
logging.warning("DEPRECATION WARNING: Please use `db.reset()` instead of `reset()`.")
|
||||
self.db.reset()
|
||||
self.collection = self.config.db._get_or_create_collection(collection_name)
|
||||
# Todo: Automatically recreating a collection with the same name cannot be the best way to handle a reset.
|
||||
# A downside of this implementation is, if you have two instances,
|
||||
# the other instance will not get the updated `self.collection` attribute.
|
||||
# A better way would be to create the collection if it is called again after being reset.
|
||||
# That means, checking if collection exists in the db-consuming methods, and creating it if it doesn't.
|
||||
# That's an extra steps for all uses, just to satisfy a niche use case in a niche method. For now, this will do.
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
|
||||
def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):
|
||||
|
||||
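The refactor above moves prompt construction, retrieval filtering, and streaming out of EmbedChain and delegates them to the `llm` and `db` collaborators. A minimal sketch of that flow, using only classes introduced in this commit (it assumes the embedchain package at this commit and an OPENAI_API_KEY in the environment; the question string is made up):

from embedchain.config import BaseLlmConfig, ChromaDbConfig
from embedchain.embedder.openai_embedder import OpenAiEmbedder
from embedchain.llm.openai_llm import OpenAiLlm
from embedchain.vectordb.chroma_db import ChromaDB

# Wire the three collaborators the same way EmbedChain.__init__ does:
# the db receives the embedder, then initializes its collection.
llm = OpenAiLlm(config=BaseLlmConfig())
db = ChromaDB(config=ChromaDbConfig())
db._set_embedder(OpenAiEmbedder())
db._initialize()

# EmbedChain.query() now boils down to these two calls:
contexts = db.query(input_query="What is Embedchain?", n_results=1, where={})
answer = llm.query(input_query="What is Embedchain?", contexts=contexts)
print(answer)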
0  embedchain/embedder/__init__.py  Normal file
45  embedchain/embedder/base_embedder.py  Normal file
@@ -0,0 +1,45 @@
from typing import Any, Callable, Optional

from embedchain.config.embedder.BaseEmbedderConfig import BaseEmbedderConfig

try:
    from chromadb.api.types import Documents, Embeddings
except RuntimeError:
    from embedchain.utils import use_pysqlite3

    use_pysqlite3()
    from chromadb.api.types import Documents, Embeddings


class BaseEmbedder:
    """
    Class that manages everything regarding embeddings. Including embedding function, loaders and chunkers.

    Embedding functions and vector dimensions are set based on the child class you choose.
    To manually overwrite you can use this class's `set_...` methods.
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        if config is None:
            self.config = BaseEmbedderConfig()
        else:
            self.config = config

    def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
        if not hasattr(embedding_fn, "__call__"):
            raise ValueError("Embedding function is not a function")
        self.embedding_fn = embedding_fn

    def set_vector_dimension(self, vector_dimension: int):
        self.vector_dimension = vector_dimension

    @staticmethod
    def _langchain_default_concept(embeddings: Any):
        """
        Langchain's default function layout for embeddings.
        """

        def embed_function(texts: Documents) -> Embeddings:
            return embeddings.embed_documents(texts)

        return embed_function
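BaseEmbedder only requires a callable embedding function and a vector dimension. A toy subclass (purely illustrative, with made-up pseudo-embeddings) shows the contract that the concrete embedders below follow:

from typing import Optional

from embedchain.config.embedder.BaseEmbedderConfig import BaseEmbedderConfig
from embedchain.embedder.base_embedder import BaseEmbedder


class ToyEmbedder(BaseEmbedder):
    """Illustrative only: registers a fake embedding function and its dimension."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config=config)

        def embed(texts):
            # deterministic 2-dimensional pseudo-embeddings, not real vectors
            return [[float(len(t)), float(sum(map(ord, t)) % 997)] for t in texts]

        self.set_embedding_fn(embedding_fn=embed)
        self.set_vector_dimension(vector_dimension=2)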
21  embedchain/embedder/gpt4all_embedder.py  Normal file
@@ -0,0 +1,21 @@
from typing import Optional

from chromadb.utils import embedding_functions

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.models import EmbeddingFunctions


class GPT4AllEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        # Note: We could use langchain's GPT4All embedding, but it's not available in all versions.
        super().__init__(config=config)
        if self.config.model is None:
            self.config.model = "all-MiniLM-L6-v2"

        embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.config.model)
        self.set_embedding_fn(embedding_fn=embedding_fn)

        vector_dimension = EmbeddingFunctions.GPT4ALL.value
        self.set_vector_dimension(vector_dimension=vector_dimension)

19  embedchain/embedder/huggingface_embedder.py  Normal file
@@ -0,0 +1,19 @@
from typing import Optional

from langchain.embeddings import HuggingFaceEmbeddings

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.models import EmbeddingFunctions


class HuggingFaceEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config=config)

        embeddings = HuggingFaceEmbeddings(model_name=self.config.model)
        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
        self.set_embedding_fn(embedding_fn=embedding_fn)

        vector_dimension = EmbeddingFunctions.HUGGING_FACE.value
        self.set_vector_dimension(vector_dimension=vector_dimension)

40  embedchain/embedder/openai_embedder.py  Normal file
@@ -0,0 +1,40 @@
import os
from typing import Optional

from langchain.embeddings import OpenAIEmbeddings

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.models import EmbeddingFunctions

try:
    from chromadb.utils import embedding_functions
except RuntimeError:
    from embedchain.utils import use_pysqlite3

    use_pysqlite3()
    from chromadb.utils import embedding_functions


class OpenAiEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config=config)
        if self.config.model is None:
            self.config.model = "text-embedding-ada-002"

        if self.config.deployment_name:
            embeddings = OpenAIEmbeddings(deployment=self.config.deployment_name)
            embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
        else:
            if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
                raise ValueError(
                    "OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided"
                )  # noqa:E501
            embedding_fn = embedding_functions.OpenAIEmbeddingFunction(
                api_key=os.getenv("OPENAI_API_KEY"),
                organization_id=os.getenv("OPENAI_ORGANIZATION"),
                model_name=self.config.model,
            )

        self.set_embedding_fn(embedding_fn=embedding_fn)
        self.set_vector_dimension(vector_dimension=EmbeddingFunctions.OPENAI.value)
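Two configuration paths exist in OpenAiEmbedder: a langchain-backed Azure deployment when `deployment_name` is set, otherwise Chroma's OpenAIEmbeddingFunction driven by environment variables. A short sketch (the key and the deployment name are placeholders):

import os

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.openai_embedder import OpenAiEmbedder

# Plain OpenAI path: requires OPENAI_API_KEY (or OPENAI_ORGANIZATION) to be set.
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")
embedder = OpenAiEmbedder()  # model falls back to "text-embedding-ada-002"

# Azure-style path: a deployment name routes construction through langchain's OpenAIEmbeddings.
azure_config = BaseEmbedderConfig()
azure_config.deployment_name = "my-ada-002-deployment"  # assumed name
azure_embedder = OpenAiEmbedder(config=azure_config)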
19  embedchain/embedder/vertexai_embedder.py  Normal file
@@ -0,0 +1,19 @@
from typing import Optional

from langchain.embeddings import VertexAIEmbeddings

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.models import EmbeddingFunctions


class VertexAiEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config=config)

        embeddings = VertexAIEmbeddings(model_name=config.model)
        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
        self.set_embedding_fn(embedding_fn=embedding_fn)

        vector_dimension = EmbeddingFunctions.VERTEX_AI.value
        self.set_vector_dimension(vector_dimension=vector_dimension)

0  embedchain/llm/__init__.py  Normal file
29  embedchain/llm/antrophic_llm.py  Normal file
@@ -0,0 +1,29 @@
import logging
from typing import Optional

from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm

from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class AntrophicLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        return AntrophicLlm._get_athrophic_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_athrophic_answer(prompt: str, config: BaseLlmConfig) -> str:
        from langchain.chat_models import ChatAnthropic

        chat = ChatAnthropic(temperature=config.temperature, model=config.model)

        if config.max_tokens and config.max_tokens != 1000:
            logging.warning("Config option `max_tokens` is not supported by this model.")

        messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content

39  embedchain/llm/azure_openai_llm.py  Normal file
@@ -0,0 +1,39 @@
import logging
from typing import Optional

from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm

from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class AzureOpenAiLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        return AzureOpenAiLlm._get_azure_openai_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_azure_openai_answer(prompt: str, config: BaseLlmConfig) -> str:
        from langchain.chat_models import AzureChatOpenAI

        if not config.deployment_name:
            raise ValueError("Deployment name must be provided for Azure OpenAI")

        chat = AzureChatOpenAI(
            deployment_name=config.deployment_name,
            openai_api_version="2023-05-15",
            model_name=config.model or "gpt-3.5-turbo",
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            streaming=config.stream,
        )

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content
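AzureOpenAiLlm insists on a `deployment_name` and warns that `top_p` is ignored. A minimal sketch, assuming the usual Azure OpenAI environment variables are already set and using a placeholder deployment name:

from embedchain.config import BaseLlmConfig
from embedchain.llm.azure_openai_llm import AzureOpenAiLlm

config = BaseLlmConfig()
config.deployment_name = "gpt-35-turbo-deployment"  # placeholder; mandatory for this class
llm = AzureOpenAiLlm(config=config)
print(llm.get_llm_model_answer("Reply with one word: ready"))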
214  embedchain/llm/base_llm.py  Normal file
@@ -0,0 +1,214 @@
import logging
from typing import List, Optional

from langchain.memory import ConversationBufferMemory
from langchain.schema import BaseMessage
from embedchain.helper_classes.json_serializable import JSONSerializable

from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base_llm_config import (
    DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
    DOCS_SITE_PROMPT_TEMPLATE)


class BaseLlm(JSONSerializable):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        if config is None:
            self.config = BaseLlmConfig()
        else:
            self.config = config

        self.memory = ConversationBufferMemory()
        self.is_docs_site_instance = False
        self.online = False
        self.history: any = None

    def get_llm_model_answer(self):
        """
        Usually implemented by child class
        """
        raise NotImplementedError

    def set_history(self, history: any):
        self.history = history

    def update_history(self):
        chat_history = self.memory.load_memory_variables({})["history"]
        if chat_history:
            self.set_history(chat_history)

    def generate_prompt(self, input_query, contexts, **kwargs):
        """
        Generates a prompt based on the given query and context, ready to be
        passed to an LLM

        :param input_query: The query to use.
        :param contexts: List of similar documents to the query used as context.
        :param config: Optional. The `QueryConfig` instance to use as
        configuration options.
        :return: The prompt
        """
        context_string = (" | ").join(contexts)
        web_search_result = kwargs.get("web_search_result", "")
        if web_search_result:
            context_string = self._append_search_and_context(context_string, web_search_result)
        if not self.history:
            prompt = self.config.template.substitute(context=context_string, query=input_query)
        else:
            # check if it's the default template without history
            if (
                not self.config._validate_template_history(self.config.template)
                and self.config.template.template == DEFAULT_PROMPT
            ):
                # swap in the template with history
                prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
                    context=context_string, query=input_query, history=self.history
                )
            elif not self.config._validate_template_history(self.config.template):
                logging.warning("Template does not include `$history` key. History is not included in prompt.")
                prompt = self.config.template.substitute(context=context_string, query=input_query)
            else:
                prompt = self.config.template.substitute(
                    context=context_string, query=input_query, history=self.history
                )
        return prompt

    def _append_search_and_context(self, context, web_search_result):
        return f"{context}\nWeb Search Result: {web_search_result}"

    def get_answer_from_llm(self, prompt):
        """
        Gets an answer based on the given query and context by passing it
        to an LLM.

        :param query: The query to use.
        :param context: Similar documents to the query used as context.
        :return: The answer.
        """

        return self.get_llm_model_answer(prompt)

    def access_search_and_get_results(self, input_query):
        from langchain.tools import DuckDuckGoSearchRun

        search = DuckDuckGoSearchRun()
        logging.info(f"Access search to get answers for {input_query}")
        return search.run(input_query)

    def _stream_query_response(self, answer):
        streamed_answer = ""
        for chunk in answer:
            streamed_answer = streamed_answer + chunk
            yield chunk
        logging.info(f"Answer: {streamed_answer}")

    def _stream_chat_response(self, answer):
        streamed_answer = ""
        for chunk in answer:
            streamed_answer = streamed_answer + chunk
            yield chunk
        self.memory.chat_memory.add_ai_message(streamed_answer)
        logging.info(f"Answer: {streamed_answer}")

    def query(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        :param input_query: The query to use.
        :param config: Optional. The `LlmConfig` instance to use as configuration options.
        This is used for one method call. To persistently use a config, declare it during app init.
        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
        the LLM. The purpose is to test the prompt, not the response.
        You can use it to test your prompt, including the context provided
        by the vector database's doc retrieval.
        The only thing the dry run does not consider is the cut-off due to
        the `max_tokens` parameter.
        :param where: Optional. A dictionary of key-value pairs to filter the database results.
        :return: The answer to the query.
        """
        query_config = config or self.config

        if self.is_docs_site_instance:
            query_config.template = DOCS_SITE_PROMPT_TEMPLATE
            query_config.number_documents = 5
        k = {}
        if self.online:
            k["web_search_result"] = self.access_search_and_get_results(input_query)
        prompt = self.generate_prompt(input_query, contexts, **k)
        logging.info(f"Prompt: {prompt}")

        if dry_run:
            return prompt

        answer = self.get_answer_from_llm(prompt)

        if isinstance(answer, str):
            logging.info(f"Answer: {answer}")
            return answer
        else:
            return self._stream_query_response(answer)

    def chat(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None):
        """
        Queries the vector database on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        Maintains the whole conversation in memory.
        :param input_query: The query to use.
        :param config: Optional. The `LlmConfig` instance to use as configuration options.
        This is used for one method call. To persistently use a config, declare it during app init.
        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
        the LLM. The purpose is to test the prompt, not the response.
        You can use it to test your prompt, including the context provided
        by the vector database's doc retrieval.
        The only thing the dry run does not consider is the cut-off due to
        the `max_tokens` parameter.
        :param where: Optional. A dictionary of key-value pairs to filter the database results.
        :return: The answer to the query.
        """
        query_config = config or self.config

        if self.is_docs_site_instance:
            query_config.template = DOCS_SITE_PROMPT_TEMPLATE
            query_config.number_documents = 5
        k = {}
        if self.online:
            k["web_search_result"] = self.access_search_and_get_results(input_query)

        self.update_history()

        prompt = self.generate_prompt(input_query, contexts, **k)
        logging.info(f"Prompt: {prompt}")

        if dry_run:
            return prompt

        answer = self.get_answer_from_llm(prompt)

        self.memory.chat_memory.add_user_message(input_query)

        if isinstance(answer, str):
            self.memory.chat_memory.add_ai_message(answer)
            logging.info(f"Answer: {answer}")

            # NOTE: Adding to history before and after. This could be seen as redundant.
            # If we change it, we have to change the tests (no big deal).
            self.update_history()

            return answer
        else:
            # this is a streamed response and needs to be handled differently.
            return self._stream_chat_response(answer)

    @staticmethod
    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
        from langchain.schema import HumanMessage, SystemMessage

        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))
        return messages
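The base class carries all prompt assembly, history handling, web search, and streaming; a concrete LLM only has to implement `get_llm_model_answer`. A toy subclass makes that contract visible:

from embedchain.llm.base_llm import BaseLlm


class EchoLlm(BaseLlm):
    """Illustrative only: returns the rendered prompt instead of calling a model."""

    def get_llm_model_answer(self, prompt):
        return f"ECHO: {prompt[:120]}"


llm = EchoLlm()
# dry_run=True returns the prompt built from the contexts, without "calling" the model
print(llm.query(input_query="What is this?", contexts=["ctx one", "ctx two"], dry_run=True))
# a normal call goes through get_llm_model_answer()
print(llm.query(input_query="What is this?", contexts=["ctx one", "ctx two"]))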
47  embedchain/llm/gpt4all_llm.py  Normal file
@@ -0,0 +1,47 @@
from typing import Iterable, Optional, Union

from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm

from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class GPT4ALLLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)
        if self.config.model is None:
            self.config.model = "orca-mini-3b.ggmlv3.q4_0.bin"
        self.instance = GPT4ALLLlm._get_instance(self.config.model)

    def get_llm_model_answer(self, prompt):
        return self._get_gpt4all_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_instance(model):
        try:
            from gpt4all import GPT4All
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The GPT4All python package is not installed. Please install it with `pip install embedchain[opensource]`"  # noqa E501
            ) from None

        return GPT4All(model_name=model)

    def _get_gpt4all_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
        if config.model and config.model != self.config.model:
            raise RuntimeError(
                "OpenSourceApp does not support switching models at runtime. Please create a new app instance."
            )

        if config.system_prompt:
            raise ValueError("OpenSourceApp does not support `system_prompt`")

        response = self.instance.generate(
            prompt=prompt,
            streaming=config.stream,
            top_p=config.top_p,
            max_tokens=config.max_tokens,
            temp=config.temperature,
        )
        return response
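GPT4ALLLlm pins a single local model per instance and rejects `system_prompt`. A short usage sketch (requires `pip install embedchain[opensource]`; the default model file is fetched on first use and the context string is made up):

from embedchain.llm.gpt4all_llm import GPT4ALLLlm

llm = GPT4ALLLlm()  # loads "orca-mini-3b.ggmlv3.q4_0.bin" by default
answer = llm.query(
    input_query="What does this library do?",
    contexts=["Embedchain wires LLMs, embedders and vector databases together."],
)
print(answer)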
27  embedchain/llm/llama2_llm.py  Normal file
@@ -0,0 +1,27 @@
import os
from typing import Optional

from langchain.llms import Replicate

from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm

from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class Llama2Llm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        if "REPLICATE_API_TOKEN" not in os.environ:
            raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        # TODO: Move the model and other inputs into config
        if self.config.system_prompt:
            raise ValueError("Llama2App does not support `system_prompt`")
        llm = Replicate(
            model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
            input={"temperature": self.config.temperature or 0.75, "max_length": 500, "top_p": self.config.top_p},
        )
        return llm(prompt)
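Llama2Llm refuses to construct without a Replicate token, so the environment has to be prepared first (the token below is a placeholder):

import os

os.environ["REPLICATE_API_TOKEN"] = "r8_placeholder"  # must exist before Llama2Llm() is constructed

from embedchain.llm.llama2_llm import Llama2Llm

llm = Llama2Llm()
print(llm.get_llm_model_answer("Say hello in five words."))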
43  embedchain/llm/openai_llm.py  Normal file
@@ -0,0 +1,43 @@
from typing import Optional

import openai

from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm

from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class OpenAiLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)

    # NOTE: This class does not use langchain. One reason is that `top_p` is not supported.

    def get_llm_model_answer(self, prompt):
        messages = []
        if self.config.system_prompt:
            messages.append({"role": "system", "content": self.config.system_prompt})
        messages.append({"role": "user", "content": prompt})
        response = openai.ChatCompletion.create(
            model=self.config.model or "gpt-3.5-turbo-0613",
            messages=messages,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            top_p=self.config.top_p,
            stream=self.config.stream,
        )

        if self.config.stream:
            return self._stream_llm_model_response(response)
        else:
            return response["choices"][0]["message"]["content"]

    def _stream_llm_model_response(self, response):
        """
        This is a generator for streaming response from the OpenAI completions API
        """
        for line in response:
            chunk = line["choices"][0].get("delta", {}).get("content", "")
            yield chunk
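With `stream` enabled the method returns a generator of content chunks instead of a string; the sketch below consumes it (assumes OPENAI_API_KEY is set and the prompt is made up):

from embedchain.config import BaseLlmConfig
from embedchain.llm.openai_llm import OpenAiLlm

config = BaseLlmConfig()
config.stream = True  # get_llm_model_answer() now yields chunks as they arrive
llm = OpenAiLlm(config=config)

for chunk in llm.get_llm_model_answer("Stream a short greeting."):
    print(chunk, end="", flush=True)
print()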
29  embedchain/llm/vertex_ai_llm.py  Normal file
@@ -0,0 +1,29 @@
import logging
from typing import Optional

from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm

from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class VertexAiLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        return VertexAiLlm._get_athrophic_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_athrophic_answer(prompt: str, config: BaseLlmConfig) -> str:
        from langchain.chat_models import ChatVertexAI

        chat = ChatVertexAI(temperature=config.temperature, model=config.model)

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content

@@ -1,11 +1,22 @@
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.helper_classes.json_serializable import JSONSerializable


class BaseVectorDB(JSONSerializable):
    """Base class for vector database."""

    def __init__(self):
    def __init__(self, config: BaseVectorDbConfig):
        self.client = self._get_or_create_db()
        self.config: BaseVectorDbConfig = config

    def _initialize(self):
        """
        This method is needed because `embedder` attribute needs to be set externally before it can be initialized.

        So it can't be done in __init__ in one step.
        """
        raise NotImplementedError

    def _get_or_create_db(self):
        """Get or create the database."""

@@ -14,6 +25,9 @@ class BaseVectorDB(JSONSerializable):
    def _get_or_create_collection(self):
        raise NotImplementedError

    def _set_embedder(self, embedder: BaseEmbedder):
        self.embedder = embedder

    def get(self):
        raise NotImplementedError

@@ -28,3 +42,6 @@ class BaseVectorDB(JSONSerializable):

    def reset(self):
        raise NotImplementedError

    def set_collection_name(self, name: str):
        raise NotImplementedError
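The two-step lifecycle (construct, then `_set_embedder`, then `_initialize`) exists because the embedder is injected by the app after the database object already exists. A toy in-memory backend illustrates the hooks a concrete database overrides; it is a sketch, not a usable store:

from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
from embedchain.vectordb.base_vector_db import BaseVectorDB


class InMemoryDB(BaseVectorDB):
    """Illustrative only: shows which hooks a real backend has to fill in."""

    def __init__(self, config: BaseVectorDbConfig):
        self.docs = {}
        super().__init__(config=config)

    def _get_or_create_db(self):
        return self.docs  # stands in for a real client

    def _initialize(self):
        # only called once the app has injected the embedder via _set_embedder()
        if not getattr(self, "embedder", None):
            raise ValueError("Embedder not set.")

    def _get_or_create_collection(self):
        return self.docs

    def set_collection_name(self, name: str):
        self.config.collection_name = name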
@@ -1,53 +1,63 @@
import logging
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from chromadb.errors import InvalidDimensionException
from langchain.docstore.document import Document

from embedchain.config import ChromaDbConfig
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.vectordb.base_vector_db import BaseVectorDB

try:
    import chromadb
    from chromadb.config import Settings
    from chromadb.errors import InvalidDimensionException
except RuntimeError:
    from embedchain.utils import use_pysqlite3

    use_pysqlite3()
    import chromadb

    from chromadb.config import Settings

from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.vectordb.base_vector_db import BaseVectorDB
from chromadb.config import Settings
from chromadb.errors import InvalidDimensionException


@register_deserializable
class ChromaDB(BaseVectorDB):
    """Vector database using ChromaDB."""

    def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None, chroma_settings={}):
        self.embedding_fn = embedding_fn

        if not hasattr(embedding_fn, "__call__"):
            raise ValueError("Embedding function is not a function")
    def __init__(self, config: Optional[ChromaDbConfig] = None):
        if config:
            self.config = config
        else:
            self.config = ChromaDbConfig()

        self.settings = Settings()
        for key, value in chroma_settings.items():
            if hasattr(self.settings, key):
                setattr(self.settings, key, value)
        if self.config.chroma_settings:
            for key, value in self.config.chroma_settings.items():
                if hasattr(self.settings, key):
                    setattr(self.settings, key, value)

        if host and port:
            logging.info(f"Connecting to ChromaDB server: {host}:{port}")
            self.settings.chroma_server_host = host
            self.settings.chroma_server_http_port = port
        if self.config.host and self.config.port:
            logging.info(f"Connecting to ChromaDB server: {self.config.host}:{self.config.port}")
            self.settings.chroma_server_host = self.config.host
            self.settings.chroma_server_http_port = self.config.port
            self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"

        else:
            if db_dir is None:
                db_dir = "db"
            if self.config.dir is None:
                self.config.dir = "db"

            self.settings.persist_directory = db_dir
            self.settings.persist_directory = self.config.dir
            self.settings.is_persistent = True

        self.client = chromadb.Client(self.settings)
        super().__init__()
        super().__init__(config=self.config)

    def _initialize(self):
        """
        This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
        """
        if not self.embedder:
            raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
        self._get_or_create_collection(self.config.collection_name)

    def _get_or_create_db(self):
        """Get or create the database."""

@@ -55,9 +65,11 @@ class ChromaDB(BaseVectorDB):

    def _get_or_create_collection(self, name):
        """Get or create the collection."""
        if not hasattr(self, "embedder") or not self.embedder:
            raise ValueError("Cannot create a Chroma database collection without an embedder.")
        self.collection = self.client.get_or_create_collection(
            name=name,
            embedding_function=self.embedding_fn,
            embedding_function=self.embedder.embedding_fn,
        )
        return self.collection

@@ -119,9 +131,37 @@ class ChromaDB(BaseVectorDB):
        contents = [result[0].page_content for result in results_formatted]
        return contents

    def set_collection_name(self, name: str):
        self.config.collection_name = name
        self._get_or_create_collection(self.config.collection_name)

    def count(self) -> int:
        """
        Count the number of embeddings.

        :return: The number of embeddings.
        """
        return self.collection.count()

    def reset(self):
        """
        Resets the database. Deletes all embeddings irreversibly.
        `App` does not have to be reinitialized after using this method.
        """
        # Delete all data from the database
        self.client.reset()
        try:
            self.client.reset()
        except ValueError:
            raise ValueError(
                "For safety reasons, resetting is disabled."
                'Please enable it by including `chromadb_settings={"allow_reset": True}` in your ChromaDbConfig'
            ) from None
        # Recreate
        self._get_or_create_collection(self.config.collection_name)

        # Todo: Automatically recreating a collection with the same name cannot be the best way to handle a reset.
        # A downside of this implementation is, if you have two instances,
        # the other instance will not get the updated `self.collection` attribute.
        # A better way would be to create the collection if it is called again after being reset.
        # That means, checking if collection exists in the db-consuming methods, and creating it if it doesn't.
        # That's an extra step for all uses, just to satisfy a niche use case in a niche method. For now, this will do.
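ChromaDB now takes all connection details from ChromaDbConfig: `host`/`port` switch it to client/server mode, otherwise `dir` selects the local persistence path, and `chroma_settings` is forwarded onto chromadb's Settings, which is also where `allow_reset` has to be enabled before `reset()` works. A local sketch with assumed values:

from embedchain.config import ChromaDbConfig
from embedchain.embedder.gpt4all_embedder import GPT4AllEmbedder
from embedchain.vectordb.chroma_db import ChromaDB

config = ChromaDbConfig()
config.dir = "my_local_db"                      # assumed path; used when no host/port is set
config.chroma_settings = {"allow_reset": True}  # otherwise reset() raises the error above

db = ChromaDB(config=config)
db._set_embedder(GPT4AllEmbedder())
db._initialize()
print(db.count())
db.reset()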
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List
from typing import Any, Dict, List

try:
    from elasticsearch import Elasticsearch
@@ -10,7 +10,6 @@ except ImportError:

from embedchain.config import ElasticsearchDBConfig
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.models.VectorDimensions import VectorDimensions
from embedchain.vectordb.base_vector_db import BaseVectorDB


@@ -18,43 +17,40 @@ from embedchain.vectordb.base_vector_db import BaseVectorDB
class ElasticsearchDB(BaseVectorDB):
    def __init__(
        self,
        es_config: ElasticsearchDBConfig = None,
        embedding_fn: Callable[[list[str]], list[str]] = None,
        vector_dim: VectorDimensions = None,
        collection_name: str = None,
        config: ElasticsearchDBConfig = None,
        es_config: ElasticsearchDBConfig = None,  # Backwards compatibility
    ):
        """
        Elasticsearch as vector database
        :param es_config: Elasticsearch database config to be used for connection
        :param embedding_fn: Function to generate embedding vectors.
        :param vector_dim: Vector dimension generated by embedding fn
        :param collection_name: Optional. Collection name for the database.
        """
        if not hasattr(embedding_fn, "__call__"):
            raise ValueError("Embedding function is not a function")
        if es_config is None:
        if config is None and es_config is None:
            raise ValueError("ElasticsearchDBConfig is required")
        if vector_dim is None:
            raise ValueError("Vector Dimension is required to refer correct index and mapping")
        if collection_name is None:
            raise ValueError("collection name is required. It cannot be empty")
        self.embedding_fn = embedding_fn
        self.config = config or es_config
        self.client = Elasticsearch(es_config.ES_URL, **es_config.ES_EXTRA_PARAMS)
        self.vector_dim = vector_dim
        self.es_index = f"{collection_name}_{self.vector_dim}"

        # Call parent init here because embedder is needed
        super().__init__(config=self.config)

    def _initialize(self):
        """
        This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
        """
        index_settings = {
            "mappings": {
                "properties": {
                    "text": {"type": "text"},
                    "embeddings": {"type": "dense_vector", "index": False, "dims": self.vector_dim},
                    "embeddings": {"type": "dense_vector", "index": False, "dims": self.embedder.vector_dimension},
                }
            }
        }
        if not self.client.indices.exists(index=self.es_index):
        es_index = self._get_index()
        if not self.client.indices.exists(index=es_index):
            # create index if not exist
            print("Creating index", self.es_index, index_settings)
            self.client.indices.create(index=self.es_index, body=index_settings)
            super().__init__()
            print("Creating index", es_index, index_settings)
            self.client.indices.create(index=es_index, body=index_settings)

    def _get_or_create_db(self):
        return self.client

@@ -85,17 +81,17 @@ class ElasticsearchDB(BaseVectorDB):
        :param ids: ids of docs
        """
        docs = []
        embeddings = self.embedding_fn(documents)
        embeddings = self.config.embedding_fn(documents)
        for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
            docs.append(
                {
                    "_index": self.es_index,
                    "_index": self._get_index(),
                    "_id": id,
                    "_source": {"text": text, "metadata": metadata, "embeddings": embeddings},
                }
            )
        bulk(self.client, docs)
        self.client.indices.refresh(index=self.es_index)
        self.client.indices.refresh(index=self._get_index())
        return

    def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:

@@ -105,7 +101,7 @@ class ElasticsearchDB(BaseVectorDB):
        :param n_results: no of similar documents to fetch from database
        :param where: Optional. to filter data
        """
        input_query_vector = self.embedding_fn(input_query)
        input_query_vector = self.config.embedding_fn(input_query)
        query_vector = input_query_vector[0]
        query = {
            "script_score": {

@@ -120,11 +116,14 @@ class ElasticsearchDB(BaseVectorDB):
            app_id = where["app_id"]
            query["script_score"]["query"]["bool"]["must"] = [{"term": {"metadata.app_id": app_id}}]
        _source = ["text"]
        response = self.client.search(index=self.es_index, query=query, _source=_source, size=n_results)
        response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results)
        docs = response["hits"]["hits"]
        contents = [doc["_source"]["text"] for doc in docs]
        return contents

    def set_collection_name(self, name: str):
        self.config.collection_name = name

    def count(self) -> int:
        query = {"match_all": {}}
        response = self.client.count(index=self.es_index, query=query)

@@ -136,3 +135,8 @@ class ElasticsearchDB(BaseVectorDB):
        if self.client.indices.exists(index=self.es_index):
            # delete index in Es
            self.client.indices.delete(index=self.es_index)

    def _get_index(self):
        # NOTE: The method is preferred to an attribute, because if collection name changes,
        # it's always up-to-date.
        return f"{self.config.collection_name}_{self.config.vector_dim}"
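The closing note explains why `_get_index()` is computed on demand rather than cached: after `set_collection_name()` the index name should follow automatically. A small self-contained illustration of that design choice (the names and dimension are made up):

class IndexName:
    """Illustrative only: derive the index name on the fly, as ElasticsearchDB does."""

    def __init__(self, collection_name: str, vector_dim: int):
        self.collection_name = collection_name
        self.vector_dim = vector_dim

    def set_collection_name(self, name: str) -> None:
        self.collection_name = name

    def _get_index(self) -> str:
        return f"{self.collection_name}_{self.vector_dim}"


ex = IndexName("embedchain_store", 1536)
ex.set_collection_name("docs")
assert ex._get_index() == "docs_1536"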