refactor: classes and configs (#528)

cachho
2023-09-05 10:12:58 +02:00
committed by GitHub
parent 387b042a49
commit 344e7470f6
50 changed files with 1221 additions and 997 deletions

View File

@@ -1,92 +0,0 @@
from string import Template
from typing import Optional
from embedchain.config.QueryConfig import QueryConfig
from embedchain.helper_classes.json_serializable import register_deserializable
DEFAULT_PROMPT = """
You are a chatbot having a conversation with a human. You are given chat
history and context.
You need to answer the query considering context, chat history and your knowledge base. If you don't know the answer or the answer is neither contained in the context nor in history, then simply say "I don't know".
$context
History: $history
Query: $query
Helpful Answer:
""" # noqa:E501
DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
@register_deserializable
class ChatConfig(QueryConfig):
"""
Config for the `chat` method, inherits from `QueryConfig`.
"""
def __init__(
self,
number_documents=None,
template: Template = None,
model=None,
temperature=None,
max_tokens=None,
top_p=None,
stream: bool = False,
deployment_name=None,
system_prompt: Optional[str] = None,
where=None,
):
"""
Initializes the ChatConfig instance.
:param number_documents: Number of documents to pull from the database as
context.
:param template: Optional. The `Template` instance to use as a template for the
prompt.
:param model: Optional. Controls the OpenAI model used.
:param temperature: Optional. Controls the randomness of the model's output.
Higher values (closer to 1) make the output more random, lower values make it
more deterministic.
:param max_tokens: Optional. Controls how many tokens are generated.
:param top_p: Optional. Controls the diversity of words. Higher values
(closer to 1) make word selection more diverse, lower values make words less
diverse.
:param stream: Optional. Controls whether the response is streamed back to the user.
:param deployment_name: t.b.a.
:param system_prompt: Optional. System prompt string.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:raises ValueError: If the template is not valid; a valid template must contain
$context, $query and $history.
"""
if template is None:
template = DEFAULT_PROMPT_TEMPLATE
# History is set to [0] to ensure that there is always a history, so that
# two separate templates are not needed. Maintaining two templates would be
# complicated because the history is not user controlled.
super().__init__(
number_documents=number_documents,
template=template,
model=model,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
history=[0],
stream=stream,
deployment_name=deployment_name,
system_prompt=system_prompt,
where=where,
)
def set_history(self, history):
"""
Chat history is not user provided and not set at initialization time.
:param history: (string) history to set
"""
self.history = history
return

View File

@@ -1,9 +1,13 @@
from .AddConfig import AddConfig, ChunkerConfig # noqa: F401
from .apps.AppConfig import AppConfig # noqa: F401
from .apps.CustomAppConfig import CustomAppConfig # noqa: F401
from .apps.OpenSourceAppConfig import OpenSourceAppConfig # noqa: F401
from .BaseConfig import BaseConfig # noqa: F401
from .ChatConfig import ChatConfig # noqa: F401
from .QueryConfig import QueryConfig # noqa: F401
from .vectordbs.ElasticsearchDBConfig import \
ElasticsearchDBConfig # noqa: F401
# flake8: noqa: F401
from .AddConfig import AddConfig, ChunkerConfig
from .apps.AppConfig import AppConfig
from .apps.CustomAppConfig import CustomAppConfig
from .apps.OpenSourceAppConfig import OpenSourceAppConfig
from .BaseConfig import BaseConfig
from .embedder.BaseEmbedderConfig import BaseEmbedderConfig
from .embedder.BaseEmbedderConfig import BaseEmbedderConfig as EmbedderConfig
from .llm.base_llm_config import BaseLlmConfig
from .llm.base_llm_config import BaseLlmConfig as LlmConfig
from .vectordbs.ChromaDbConfig import ChromaDbConfig
from .vectordbs.ElasticsearchDBConfig import ElasticsearchDBConfig
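With the rewritten `__init__`, the file-level `# flake8: noqa: F401` replaces the per-line comments, and the new embedder/llm configs are re-exported under shorter aliases. A minimal sketch of the resulting import surface (names exactly as exported above):

from embedchain.config import ChromaDbConfig, EmbedderConfig, LlmConfig

# LlmConfig aliases BaseLlmConfig and EmbedderConfig aliases BaseEmbedderConfig, per the re-exports above
llm_config = LlmConfig(temperature=0.2, max_tokens=500)
embedder_config = EmbedderConfig(model="text-embedding-ada-002")
db_config = ChromaDbConfig(collection_name="my_collection")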

View File

@@ -1,14 +1,5 @@
import os
from typing import Optional
try:
from chromadb.utils import embedding_functions
except RuntimeError:
from embedchain.utils import use_pysqlite3
use_pysqlite3()
from chromadb.utils import embedding_functions
from embedchain.helper_classes.json_serializable import register_deserializable
from .BaseAppConfig import BaseAppConfig
@@ -23,44 +14,14 @@ class AppConfig(BaseAppConfig):
def __init__(
self,
log_level=None,
host=None,
port=None,
id=None,
collection_name=None,
collect_metrics: Optional[bool] = None,
collection_name: Optional[str] = None,
):
"""
:param log_level: Optional. (String) Debug level
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
:param id: Optional. ID of the app. Document metadata will have this id.
:param collection_name: Optional. Collection name for the database.
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
"""
super().__init__(
log_level=log_level,
embedding_fn=AppConfig.default_embedding_function(),
host=host,
port=port,
id=id,
collection_name=collection_name,
collect_metrics=collect_metrics,
)
@staticmethod
def default_embedding_function():
"""
Returns the default embedding function (`text-embedding-ada-002`).
:raises ValueError: If neither the OPENAI_API_KEY nor the OPENAI_ORGANIZATION
environment variable is set.
:returns: The default embedding function for the app class.
"""
if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided") # noqa:E501
return embedding_functions.OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
organization_id=os.getenv("OPENAI_ORGANIZATION"),
model_name="text-embedding-ada-002",
)
super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, collection_name=collection_name)
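Everything connection-related is gone from this constructor, so building an `AppConfig` now involves only app-level concerns. A short sketch, using only the arguments that survive the refactor:

config = AppConfig(log_level="DEBUG", id="my-app", collect_metrics=False)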

View File

@@ -1,9 +1,9 @@
import logging
from typing import Optional
from embedchain.config.BaseConfig import BaseConfig
from embedchain.config.vectordbs import ElasticsearchDBConfig
from embedchain.helper_classes.json_serializable import JSONSerializable
from embedchain.models import VectorDatabases, VectorDimensions
from embedchain.vectordb.base_vector_db import BaseVectorDB
class BaseAppConfig(BaseConfig, JSONSerializable):
@@ -14,81 +14,38 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
def __init__(
self,
log_level=None,
embedding_fn=None,
db=None,
host=None,
port=None,
db: Optional[BaseVectorDB] = None,
id=None,
collection_name=None,
collect_metrics: bool = True,
db_type: VectorDatabases = None,
vector_dim: VectorDimensions = None,
es_config: ElasticsearchDBConfig = None,
chroma_settings: dict = {},
collection_name: Optional[str] = None,
):
"""
:param log_level: Optional. (String) Debug level
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
:param embedding_fn: Embedding function to use.
:param db: Optional. (Vector) database instance to use for embeddings.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
:param db: Optional. (Vector) database instance to use for embeddings. Deprecated in favor of app(..., db).
:param id: Optional. ID of the app. Document metadata will have this id.
:param collection_name: Optional. Collection name for the database.
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
:param db_type: Optional. Type of vector database to use.
:param vector_dim: Vector dimension generated by the embedding function.
:param db_type: Optional. Initializes a default vector database of the given type.
Using the `db` argument is preferred.
:param es_config: Optional. Elasticsearch database config to be used for the connection.
:param chroma_settings: Optional. Chroma settings for connection.
:param collection_name: Optional. Default collection name.
It's recommended to use app.set_collection_name() instead.
"""
self._setup_logging(log_level)
self.collection_name = collection_name if collection_name else "embedchain_store"
self.db = BaseAppConfig.get_db(
db=db,
embedding_fn=embedding_fn,
host=host,
port=port,
db_type=db_type,
vector_dim=vector_dim,
collection_name=self.collection_name,
es_config=es_config,
chroma_settings=chroma_settings,
)
self.id = id
self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False
return
self.collection_name = collection_name
@staticmethod
def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config, chroma_settings):
"""
Get the database instance based on `db_type`, falling back to the default database (`ChromaDb`).
:param db: Optional. (Vector) database to use for embeddings.
:param embedding_fn: Embedding function to use in the database.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
:param db_type: Optional. Database type to use. Supported values (`es`, `chroma`).
:param vector_dim: Vector dimension generated by the embedding function.
:param collection_name: Optional. Collection name for the database.
:param es_config: Optional. Elasticsearch database config to be used for the connection.
:raises ValueError: BaseAppConfig knows no default embedding function.
:returns: database instance
"""
if db:
return db
if embedding_fn is None:
raise ValueError("ChromaDb cannot be instantiated without an embedding function")
if db_type == VectorDatabases.ELASTICSEARCH:
from embedchain.vectordb.elasticsearch_db import ElasticsearchDB
return ElasticsearchDB(
embedding_fn=embedding_fn, vector_dim=vector_dim, collection_name=collection_name, es_config=es_config
self._db = db
logging.warning(
"DEPRECATION WARNING: Please supply the database as the second parameter during app init. "
"Such as `app(config=config, db=db)`."
)
from embedchain.vectordb.chroma_db import ChromaDB
return ChromaDB(embedding_fn=embedding_fn, host=host, port=port, chroma_settings=chroma_settings)
if collection_name:
logging.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
return
def _setup_logging(self, debug_level):
level = logging.WARNING # Default level
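The deprecation paths above spell out the new wiring: the database becomes a constructor argument of the app itself instead of being assembled inside the config. A hedged sketch of the intended call shape (the exact `App` and `ChromaDB` signatures are assumptions drawn from the warnings and the configs in this commit):

from embedchain import App  # assumed import path
from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.vectordb.chroma_db import ChromaDB  # module path as imported above

db = ChromaDB(config=ChromaDbConfig(collection_name="docs"))  # assumption: ChromaDB now accepts its own config
app = App(config=AppConfig(), db=db)  # per the warning: `app(config=config, db=db)`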

View File

@@ -1,12 +1,8 @@
from typing import Any, Optional
from typing import Optional
from chromadb.api.types import Documents, Embeddings
from dotenv import load_dotenv
from embedchain.config.vectordbs import ElasticsearchDBConfig
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.models import (EmbeddingFunctions, Providers, VectorDatabases,
VectorDimensions)
from .BaseAppConfig import BaseAppConfig
@@ -22,123 +18,23 @@ class CustomAppConfig(BaseAppConfig):
def __init__(
self,
log_level=None,
embedding_fn: EmbeddingFunctions = None,
embedding_fn_model=None,
db=None,
host=None,
port=None,
id=None,
collection_name=None,
provider: Providers = None,
open_source_app_config=None,
deployment_name=None,
collect_metrics: Optional[bool] = None,
db_type: VectorDatabases = None,
es_config: ElasticsearchDBConfig = None,
chroma_settings: dict = {},
collection_name: Optional[str] = None,
):
"""
:param log_level: Optional. (String) Debug level
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
:param embedding_fn: Optional. Embedding function to use.
:param embedding_fn_model: Optional. Model name to use for embedding function.
:param db: Optional. (Vector) database to use for embeddings.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
:param id: Optional. ID of the app. Document metadata will have this id.
:param collection_name: Optional. Collection name for the database.
:param provider: (Providers) LLM provider to use. Required; `CustomApp` raises if none is assigned.
:param open_source_app_config: Optional. Config instance needed for open source apps.
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
:param db_type: Optional. Type of vector database to use.
:param es_config: Optional. Elasticsearch database config to be used for the connection.
:param chroma_settings: Optional. Chroma settings for connection.
:param collection_name: Optional. Default collection name.
It's recommended to use app.set_collection_name() instead.
"""
if provider:
self.provider = provider
else:
raise ValueError("CustomApp must have a provider assigned.")
self.open_source_app_config = open_source_app_config
super().__init__(
log_level=log_level,
embedding_fn=CustomAppConfig.embedding_function(
embedding_function=embedding_fn, model=embedding_fn_model, deployment_name=deployment_name
),
db=db,
host=host,
port=port,
id=id,
collection_name=collection_name,
collect_metrics=collect_metrics,
db_type=db_type,
vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn),
es_config=es_config,
chroma_settings=chroma_settings,
log_level=log_level, db=db, id=id, collect_metrics=collect_metrics, collection_name=collection_name
)
@staticmethod
def langchain_default_concept(embeddings: Any):
"""
Langchain's default function layout for embeddings.
"""
def embed_function(texts: Documents) -> Embeddings:
return embeddings.embed_documents(texts)
return embed_function
@staticmethod
def embedding_function(embedding_function: EmbeddingFunctions, model: str = None, deployment_name: str = None):
if not isinstance(embedding_function, EmbeddingFunctions):
raise ValueError(
f"Invalid option: '{embedding_function}'. Expecting one of the following options: {list(map(lambda x: x.value, EmbeddingFunctions))}" # noqa: E501
)
if embedding_function == EmbeddingFunctions.OPENAI:
from langchain.embeddings import OpenAIEmbeddings
if model:
embeddings = OpenAIEmbeddings(model=model)
else:
if deployment_name:
embeddings = OpenAIEmbeddings(deployment=deployment_name)
else:
embeddings = OpenAIEmbeddings()
return CustomAppConfig.langchain_default_concept(embeddings)
elif embedding_function == EmbeddingFunctions.HUGGING_FACE:
from langchain.embeddings import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(model_name=model)
return CustomAppConfig.langchain_default_concept(embeddings)
elif embedding_function == EmbeddingFunctions.VERTEX_AI:
from langchain.embeddings import VertexAIEmbeddings
embeddings = VertexAIEmbeddings(model_name=model)
return CustomAppConfig.langchain_default_concept(embeddings)
elif embedding_function == EmbeddingFunctions.GPT4ALL:
# Note: We could use langchain's GPT4All embedding, but it's not available in all versions.
from chromadb.utils import embedding_functions
return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model)
@staticmethod
def get_vector_dimension(embedding_function: EmbeddingFunctions):
if not isinstance(embedding_function, EmbeddingFunctions):
raise ValueError(f"Invalid option: '{embedding_function}'.")
if embedding_function == EmbeddingFunctions.OPENAI:
return VectorDimensions.OPENAI.value
elif embedding_function == EmbeddingFunctions.HUGGING_FACE:
return VectorDimensions.HUGGING_FACE.value
elif embedding_function == EmbeddingFunctions.VERTEX_AI:
return VectorDimensions.VERTEX_AI.value
elif embedding_function == EmbeddingFunctions.GPT4ALL:
return VectorDimensions.GPT4ALL.value
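For reference, the removed helpers mapped an `EmbeddingFunctions` member to a Chroma-compatible callable and to its vector dimension; roughly like this (enum members per the imports above):

from embedchain.models import EmbeddingFunctions

embed_fn = CustomAppConfig.embedding_function(EmbeddingFunctions.OPENAI, model="text-embedding-ada-002")
dim = CustomAppConfig.get_vector_dimension(EmbeddingFunctions.OPENAI)  # VectorDimensions.OPENAI.value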

View File

@@ -1,7 +1,5 @@
from typing import Optional
from chromadb.utils import embedding_functions
from embedchain.helper_classes.json_serializable import register_deserializable
from .BaseAppConfig import BaseAppConfig
@@ -16,47 +14,21 @@ class OpenSourceAppConfig(BaseAppConfig):
def __init__(
self,
log_level=None,
host=None,
port=None,
id=None,
collection_name=None,
collect_metrics: Optional[bool] = None,
model=None,
collection_name: Optional[str] = None,
):
"""
:param log_level: Optional. (String) Debug level
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
:param id: Optional. ID of the app. Document metadata will have this id.
:param collection_name: Optional. Collection name for the database.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
:param model: Optional. GPT4All uses the model to instantiate the class,
so unlike `App`, it has to be provided before querying.
:param collection_name: Optional. Default collection name.
It's recommended to use app.db.set_collection_name() instead.
"""
self.model = model or "orca-mini-3b.ggmlv3.q4_0.bin"
super().__init__(
log_level=log_level,
embedding_fn=OpenSourceAppConfig.default_embedding_function(),
host=host,
port=port,
id=id,
collection_name=collection_name,
collect_metrics=collect_metrics,
)
@staticmethod
def default_embedding_function():
"""
Returns the default embedding function (`all-MiniLM-L6-v2`).
:returns: The default embedding function
"""
try:
return embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
except ValueError as e:
print(e)
raise ModuleNotFoundError(
"The open source app requires extra dependencies. Install with `pip install embedchain[opensource]`"
) from None
super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, collection_name=collection_name)
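With the slimmed constructor, an open-source config reduces to the model plus app-level flags; the GPT4All model still has to be named up front (the default shown above is used otherwise):

config = OpenSourceAppConfig(model="orca-mini-3b.ggmlv3.q4_0.bin", log_level="INFO")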

View File

@@ -0,0 +1,10 @@
from typing import Optional
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class BaseEmbedderConfig:
def __init__(self, model: Optional[str] = None, deployment_name: Optional[str] = None):
self.model = model
self.deployment_name = deployment_name
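The new `BaseEmbedderConfig` is a plain value object holding the embedder's model and, where applicable, a deployment name. Instantiating it is just:

from embedchain.config import EmbedderConfig  # alias re-exported in the new __init__

embedder_config = EmbedderConfig(model="text-embedding-ada-002")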

View File

@@ -50,7 +50,7 @@ history_re = re.compile(r"\$\{*history\}*")
@register_deserializable
class QueryConfig(BaseConfig):
class BaseLlmConfig(BaseConfig):
"""
Config for the `query` method.
"""
@@ -63,7 +63,6 @@ class QueryConfig(BaseConfig):
temperature=None,
max_tokens=None,
top_p=None,
history=None,
stream: bool = False,
deployment_name=None,
system_prompt: Optional[str] = None,
@@ -84,7 +83,6 @@ class QueryConfig(BaseConfig):
:param top_p: Optional. Controls the diversity of words. Higher values
(closer to 1) make word selection more diverse, lower values make words less
diverse.
:param history: Optional. A list of strings to consider as history.
:param stream: Optional. Controls whether the response is streamed back to the user.
:param deployment_name: t.b.a.
:param system_prompt: Optional. System prompt string.
@@ -97,19 +95,8 @@ class QueryConfig(BaseConfig):
else:
self.number_documents = number_documents
if not history:
self.history = None
else:
if len(history) == 0:
self.history = None
else:
self.history = history
if template is None:
if self.history is None:
template = DEFAULT_PROMPT_TEMPLATE
else:
template = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE
template = DEFAULT_PROMPT_TEMPLATE
self.temperature = temperature if temperature else 0
self.max_tokens = max_tokens if max_tokens else 1000
@@ -121,10 +108,7 @@ class QueryConfig(BaseConfig):
if self.validate_template(template):
self.template = template
else:
if self.history is None:
raise ValueError("`template` should have `query` and `context` keys")
else:
raise ValueError("`template` should have `query`, `context` and `history` keys")
raise ValueError("`template` should have `query` and `context` keys and potentially `history` (if used).")
if not isinstance(stream, bool):
raise ValueError("`stream` should be bool")
@@ -138,11 +122,13 @@ class QueryConfig(BaseConfig):
:param template: the template to validate
:return: Boolean, valid (true) or invalid (false)
"""
if self.history is None:
return re.search(query_re, template.template) and re.search(context_re, template.template)
else:
return (
re.search(query_re, template.template)
and re.search(context_re, template.template)
and re.search(history_re, template.template)
)
return re.search(query_re, template.template) and re.search(context_re, template.template)
def _validate_template_history(self, template: Template):
"""
Validate that the template contains a `history` placeholder.
:param template: the template to validate
:return: Boolean, valid (true) or invalid (false)
"""
return re.search(history_re, template.template)
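Because history validation now lives in its own helper, a template with only `$context` and `$query` passes the constructor's check, and `_validate_template_history` can be consulted separately. A small sketch against the regexes shown above:

from string import Template

config = BaseLlmConfig(template=Template("Context: $context Query: $query Answer:"))  # passes: has $context and $query
has_history = config._validate_template_history(config.template)  # falsy here: no $history placeholder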

View File

@@ -0,0 +1,17 @@
from typing import Optional
from embedchain.config.BaseConfig import BaseConfig
class BaseVectorDbConfig(BaseConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
host: Optional[str] = None,
port: Optional[str] = None,
):
self.collection_name = collection_name or "embedchain_store"
self.dir = dir or "db"
self.host = host
self.port = port
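`BaseVectorDbConfig` gathers the `collection_name`/`dir`/`host`/`port` arguments that were previously scattered across the app configs, with the defaults shown above as fallbacks:

local = BaseVectorDbConfig()  # collection_name == "embedchain_store", dir == "db"
remote = BaseVectorDbConfig(host="localhost", port="8000")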

View File

@@ -0,0 +1,21 @@
from typing import Optional
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class ChromaDbConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
host: Optional[str] = None,
port: Optional[str] = None,
chroma_settings: Optional[dict] = None,
):
"""
:param chroma_settings: Optional. Chroma settings for connection.
"""
self.chroma_settings = chroma_settings
super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)
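On top of the base fields, the Chroma variant only adds the settings dict. A minimal example (assuming the dict is forwarded to the chromadb client, as the old `chroma_settings` argument was):

chroma_config = ChromaDbConfig(collection_name="docs", chroma_settings={"anonymized_telemetry": False})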

View File

@@ -1,17 +1,25 @@
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union
from embedchain.config.BaseConfig import BaseConfig
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class ElasticsearchDBConfig(BaseConfig):
"""
Config to initialize an Elasticsearch client.
:param es_url: Elasticsearch URL or list of node URLs to be used for the connection.
:param ES_EXTRA_PARAMS: Extra params dict that is passed through to the Elasticsearch client.
"""
def __init__(self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
class ElasticsearchDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
es_url: Union[str, List[str]] = None,
**ES_EXTRA_PARAMS: Dict[str, any],
):
"""
Config to initialize an Elasticsearch client.
:param es_url: Elasticsearch URL or list of node URLs to be used for the connection.
:param ES_EXTRA_PARAMS: Extra params dict that is passed through to the Elasticsearch client.
"""
self.ES_URL = es_url
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
super().__init__(collection_name=collection_name, dir=dir)
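The Elasticsearch config now rides on the shared base class, so the collection name travels with the connection details, and any extra keyword arguments land in `ES_EXTRA_PARAMS` for the client. For example:

es_config = ElasticsearchDBConfig(
    collection_name="docs",
    es_url="https://localhost:9200",
    ca_certs="/path/to/http_ca.crt",  # example extra param, forwarded via ES_EXTRA_PARAMS
)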