refactor: classes and configs (#528)
This commit is contained in:
@@ -1,92 +0,0 @@
|
||||
from string import Template
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.QueryConfig import QueryConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
DEFAULT_PROMPT = """
|
||||
You are a chatbot having a conversation with a human. You are given chat
|
||||
history and context.
|
||||
You need to answer the query considering context, chat history and your knowledge base. If you don't know the answer or the answer is neither contained in the context nor in history, then simply say "I don't know".
|
||||
|
||||
$context
|
||||
|
||||
History: $history
|
||||
|
||||
Query: $query
|
||||
|
||||
Helpful Answer:
|
||||
""" # noqa:E501
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class ChatConfig(QueryConfig):
|
||||
"""
|
||||
Config for the `chat` method, inherits from `QueryConfig`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
number_documents=None,
|
||||
template: Template = None,
|
||||
model=None,
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
top_p=None,
|
||||
stream: bool = False,
|
||||
deployment_name=None,
|
||||
system_prompt: Optional[str] = None,
|
||||
where=None,
|
||||
):
|
||||
"""
|
||||
Initializes the ChatConfig instance.
|
||||
|
||||
:param number_documents: Number of documents to pull from the database as
|
||||
context.
|
||||
:param template: Optional. The `Template` instance to use as a template for
|
||||
prompt.
|
||||
:param model: Optional. Controls the OpenAI model used.
|
||||
:param temperature: Optional. Controls the randomness of the model's output.
|
||||
Higher values (closer to 1) make output more random,lower values make it more
|
||||
deterministic.
|
||||
:param max_tokens: Optional. Controls how many tokens are generated.
|
||||
:param top_p: Optional. Controls the diversity of words.Higher values
|
||||
(closer to 1) make word selection more diverse, lower values make words less
|
||||
diverse.
|
||||
:param stream: Optional. Control if response is streamed back to the user
|
||||
:param deployment_name: t.b.a.
|
||||
:param system_prompt: Optional. System prompt string.
|
||||
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||
:raises ValueError: If the template is not valid as template should contain
|
||||
$context and $query and $history
|
||||
"""
|
||||
if template is None:
|
||||
template = DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
# History is set as 0 to ensure that there is always a history, that way,
|
||||
# there don't have to be two templates. Having two templates would make it
|
||||
# complicated because the history is not user controlled.
|
||||
super().__init__(
|
||||
number_documents=number_documents,
|
||||
template=template,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
history=[0],
|
||||
stream=stream,
|
||||
deployment_name=deployment_name,
|
||||
system_prompt=system_prompt,
|
||||
where=where,
|
||||
)
|
||||
|
||||
def set_history(self, history):
|
||||
"""
|
||||
Chat history is not user provided and not set at initialization time
|
||||
|
||||
:param history: (string) history to set
|
||||
"""
|
||||
self.history = history
|
||||
return
|
||||
@@ -1,9 +1,13 @@
|
||||
from .AddConfig import AddConfig, ChunkerConfig # noqa: F401
|
||||
from .apps.AppConfig import AppConfig # noqa: F401
|
||||
from .apps.CustomAppConfig import CustomAppConfig # noqa: F401
|
||||
from .apps.OpenSourceAppConfig import OpenSourceAppConfig # noqa: F401
|
||||
from .BaseConfig import BaseConfig # noqa: F401
|
||||
from .ChatConfig import ChatConfig # noqa: F401
|
||||
from .QueryConfig import QueryConfig # noqa: F401
|
||||
from .vectordbs.ElasticsearchDBConfig import \
|
||||
ElasticsearchDBConfig # noqa: F401
|
||||
# flake8: noqa: F401
|
||||
|
||||
from .AddConfig import AddConfig, ChunkerConfig
|
||||
from .apps.AppConfig import AppConfig
|
||||
from .apps.CustomAppConfig import CustomAppConfig
|
||||
from .apps.OpenSourceAppConfig import OpenSourceAppConfig
|
||||
from .BaseConfig import BaseConfig
|
||||
from .embedder.BaseEmbedderConfig import BaseEmbedderConfig
|
||||
from .embedder.BaseEmbedderConfig import BaseEmbedderConfig as EmbedderConfig
|
||||
from .llm.base_llm_config import BaseLlmConfig
|
||||
from .llm.base_llm_config import BaseLlmConfig as LlmConfig
|
||||
from .vectordbs.ChromaDbConfig import ChromaDbConfig
|
||||
from .vectordbs.ElasticsearchDBConfig import ElasticsearchDBConfig
|
||||
|
||||
@@ -1,14 +1,5 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
from chromadb.utils import embedding_functions
|
||||
except RuntimeError:
|
||||
from embedchain.utils import use_pysqlite3
|
||||
|
||||
use_pysqlite3()
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
from .BaseAppConfig import BaseAppConfig
|
||||
@@ -23,44 +14,14 @@ class AppConfig(BaseAppConfig):
|
||||
def __init__(
|
||||
self,
|
||||
log_level=None,
|
||||
host=None,
|
||||
port=None,
|
||||
id=None,
|
||||
collection_name=None,
|
||||
collect_metrics: Optional[bool] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||
:param host: Optional. Hostname for the database server.
|
||||
:param port: Optional. Port for the database server.
|
||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
|
||||
"""
|
||||
super().__init__(
|
||||
log_level=log_level,
|
||||
embedding_fn=AppConfig.default_embedding_function(),
|
||||
host=host,
|
||||
port=port,
|
||||
id=id,
|
||||
collection_name=collection_name,
|
||||
collect_metrics=collect_metrics,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def default_embedding_function():
|
||||
"""
|
||||
Sets embedding function to default (`text-embedding-ada-002`).
|
||||
|
||||
:raises ValueError: If the template is not valid as template should contain
|
||||
$context and $query
|
||||
:returns: The default embedding function for the app class.
|
||||
"""
|
||||
if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
|
||||
raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided") # noqa:E501
|
||||
return embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
organization_id=os.getenv("OPENAI_ORGANIZATION"),
|
||||
model_name="text-embedding-ada-002",
|
||||
)
|
||||
super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, collection_name=collection_name)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.BaseConfig import BaseConfig
|
||||
from embedchain.config.vectordbs import ElasticsearchDBConfig
|
||||
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||
from embedchain.models import VectorDatabases, VectorDimensions
|
||||
from embedchain.vectordb.base_vector_db import BaseVectorDB
|
||||
|
||||
|
||||
class BaseAppConfig(BaseConfig, JSONSerializable):
|
||||
@@ -14,81 +14,38 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
|
||||
def __init__(
|
||||
self,
|
||||
log_level=None,
|
||||
embedding_fn=None,
|
||||
db=None,
|
||||
host=None,
|
||||
port=None,
|
||||
db: Optional[BaseVectorDB] = None,
|
||||
id=None,
|
||||
collection_name=None,
|
||||
collect_metrics: bool = True,
|
||||
db_type: VectorDatabases = None,
|
||||
vector_dim: VectorDimensions = None,
|
||||
es_config: ElasticsearchDBConfig = None,
|
||||
chroma_settings: dict = {},
|
||||
collection_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||
:param embedding_fn: Embedding function to use.
|
||||
:param db: Optional. (Vector) database instance to use for embeddings.
|
||||
:param host: Optional. Hostname for the database server.
|
||||
:param port: Optional. Port for the database server.
|
||||
:param db: Optional. (Vector) database instance to use for embeddings. Deprecated in favor of app(..., db).
|
||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
|
||||
:param db_type: Optional. type of Vector database to use
|
||||
:param vector_dim: Vector dimension generated by embedding fn
|
||||
:param db_type: Optional. Initializes a default vector database of the given type.
|
||||
Using the `db` argument is preferred.
|
||||
:param es_config: Optional. elasticsearch database config to be used for connection
|
||||
:param chroma_settings: Optional. Chroma settings for connection.
|
||||
:param collection_name: Optional. Default collection name.
|
||||
It's recommended to use app.set_collection_name() instead.
|
||||
"""
|
||||
self._setup_logging(log_level)
|
||||
self.collection_name = collection_name if collection_name else "embedchain_store"
|
||||
self.db = BaseAppConfig.get_db(
|
||||
db=db,
|
||||
embedding_fn=embedding_fn,
|
||||
host=host,
|
||||
port=port,
|
||||
db_type=db_type,
|
||||
vector_dim=vector_dim,
|
||||
collection_name=self.collection_name,
|
||||
es_config=es_config,
|
||||
chroma_settings=chroma_settings,
|
||||
)
|
||||
self.id = id
|
||||
self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False
|
||||
return
|
||||
self.collection_name = collection_name
|
||||
|
||||
@staticmethod
|
||||
def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config, chroma_settings):
|
||||
"""
|
||||
Get db based on db_type, db with default database (`ChromaDb`)
|
||||
:param Optional. (Vector) database to use for embeddings.
|
||||
:param embedding_fn: Embedding function to use in database.
|
||||
:param host: Optional. Hostname for the database server.
|
||||
:param port: Optional. Port for the database server.
|
||||
:param db_type: Optional. db type to use. Supported values (`es`, `chroma`)
|
||||
:param vector_dim: Vector dimension generated by embedding fn
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
:param es_config: Optional. elasticsearch database config to be used for connection
|
||||
:raises ValueError: BaseAppConfig knows no default embedding function.
|
||||
:returns: database instance
|
||||
"""
|
||||
if db:
|
||||
return db
|
||||
|
||||
if embedding_fn is None:
|
||||
raise ValueError("ChromaDb cannot be instantiated without an embedding function")
|
||||
|
||||
if db_type == VectorDatabases.ELASTICSEARCH:
|
||||
from embedchain.vectordb.elasticsearch_db import ElasticsearchDB
|
||||
|
||||
return ElasticsearchDB(
|
||||
embedding_fn=embedding_fn, vector_dim=vector_dim, collection_name=collection_name, es_config=es_config
|
||||
self._db = db
|
||||
logging.warning(
|
||||
"DEPRECATION WARNING: Please supply the database as the second parameter during app init. "
|
||||
"Such as `app(config=config, db=db)`."
|
||||
)
|
||||
|
||||
from embedchain.vectordb.chroma_db import ChromaDB
|
||||
|
||||
return ChromaDB(embedding_fn=embedding_fn, host=host, port=port, chroma_settings=chroma_settings)
|
||||
if collection_name:
|
||||
logging.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
|
||||
return
|
||||
|
||||
def _setup_logging(self, debug_level):
|
||||
level = logging.WARNING # Default level
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from chromadb.api.types import Documents, Embeddings
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from embedchain.config.vectordbs import ElasticsearchDBConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
from embedchain.models import (EmbeddingFunctions, Providers, VectorDatabases,
|
||||
VectorDimensions)
|
||||
|
||||
from .BaseAppConfig import BaseAppConfig
|
||||
|
||||
@@ -22,123 +18,23 @@ class CustomAppConfig(BaseAppConfig):
|
||||
def __init__(
|
||||
self,
|
||||
log_level=None,
|
||||
embedding_fn: EmbeddingFunctions = None,
|
||||
embedding_fn_model=None,
|
||||
db=None,
|
||||
host=None,
|
||||
port=None,
|
||||
id=None,
|
||||
collection_name=None,
|
||||
provider: Providers = None,
|
||||
open_source_app_config=None,
|
||||
deployment_name=None,
|
||||
collect_metrics: Optional[bool] = None,
|
||||
db_type: VectorDatabases = None,
|
||||
es_config: ElasticsearchDBConfig = None,
|
||||
chroma_settings: dict = {},
|
||||
collection_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||
:param embedding_fn: Optional. Embedding function to use.
|
||||
:param embedding_fn_model: Optional. Model name to use for embedding function.
|
||||
:param db: Optional. (Vector) database to use for embeddings.
|
||||
:param host: Optional. Hostname for the database server.
|
||||
:param port: Optional. Port for the database server.
|
||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
:param provider: Optional. (Providers): LLM Provider to use.
|
||||
:param open_source_app_config: Optional. Config instance needed for open source apps.
|
||||
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
|
||||
:param db_type: Optional. type of Vector database to use.
|
||||
:param es_config: Optional. elasticsearch database config to be used for connection
|
||||
:param chroma_settings: Optional. Chroma settings for connection.
|
||||
:param collection_name: Optional. Default collection name.
|
||||
It's recommended to use app.set_collection_name() instead.
|
||||
"""
|
||||
if provider:
|
||||
self.provider = provider
|
||||
else:
|
||||
raise ValueError("CustomApp must have a provider assigned.")
|
||||
|
||||
self.open_source_app_config = open_source_app_config
|
||||
|
||||
super().__init__(
|
||||
log_level=log_level,
|
||||
embedding_fn=CustomAppConfig.embedding_function(
|
||||
embedding_function=embedding_fn, model=embedding_fn_model, deployment_name=deployment_name
|
||||
),
|
||||
db=db,
|
||||
host=host,
|
||||
port=port,
|
||||
id=id,
|
||||
collection_name=collection_name,
|
||||
collect_metrics=collect_metrics,
|
||||
db_type=db_type,
|
||||
vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn),
|
||||
es_config=es_config,
|
||||
chroma_settings=chroma_settings,
|
||||
log_level=log_level, db=db, id=id, collect_metrics=collect_metrics, collection_name=collection_name
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def langchain_default_concept(embeddings: Any):
|
||||
"""
|
||||
Langchains default function layout for embeddings.
|
||||
"""
|
||||
|
||||
def embed_function(texts: Documents) -> Embeddings:
|
||||
return embeddings.embed_documents(texts)
|
||||
|
||||
return embed_function
|
||||
|
||||
@staticmethod
|
||||
def embedding_function(embedding_function: EmbeddingFunctions, model: str = None, deployment_name: str = None):
|
||||
if not isinstance(embedding_function, EmbeddingFunctions):
|
||||
raise ValueError(
|
||||
f"Invalid option: '{embedding_function}'. Expecting one of the following options: {list(map(lambda x: x.value, EmbeddingFunctions))}" # noqa: E501
|
||||
)
|
||||
|
||||
if embedding_function == EmbeddingFunctions.OPENAI:
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
if model:
|
||||
embeddings = OpenAIEmbeddings(model=model)
|
||||
else:
|
||||
if deployment_name:
|
||||
embeddings = OpenAIEmbeddings(deployment=deployment_name)
|
||||
else:
|
||||
embeddings = OpenAIEmbeddings()
|
||||
return CustomAppConfig.langchain_default_concept(embeddings)
|
||||
|
||||
elif embedding_function == EmbeddingFunctions.HUGGING_FACE:
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
embeddings = HuggingFaceEmbeddings(model_name=model)
|
||||
return CustomAppConfig.langchain_default_concept(embeddings)
|
||||
|
||||
elif embedding_function == EmbeddingFunctions.VERTEX_AI:
|
||||
from langchain.embeddings import VertexAIEmbeddings
|
||||
|
||||
embeddings = VertexAIEmbeddings(model_name=model)
|
||||
return CustomAppConfig.langchain_default_concept(embeddings)
|
||||
|
||||
elif embedding_function == EmbeddingFunctions.GPT4ALL:
|
||||
# Note: We could use langchains GPT4ALL embedding, but it's not available in all versions.
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model)
|
||||
|
||||
@staticmethod
|
||||
def get_vector_dimension(embedding_function: EmbeddingFunctions):
|
||||
if not isinstance(embedding_function, EmbeddingFunctions):
|
||||
raise ValueError(f"Invalid option: '{embedding_function}'.")
|
||||
|
||||
if embedding_function == EmbeddingFunctions.OPENAI:
|
||||
return VectorDimensions.OPENAI.value
|
||||
|
||||
elif embedding_function == EmbeddingFunctions.HUGGING_FACE:
|
||||
return VectorDimensions.HUGGING_FACE.value
|
||||
|
||||
elif embedding_function == EmbeddingFunctions.VERTEX_AI:
|
||||
return VectorDimensions.VERTEX_AI.value
|
||||
|
||||
elif embedding_function == EmbeddingFunctions.GPT4ALL:
|
||||
return VectorDimensions.GPT4ALL.value
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
from .BaseAppConfig import BaseAppConfig
|
||||
@@ -16,47 +14,21 @@ class OpenSourceAppConfig(BaseAppConfig):
|
||||
def __init__(
|
||||
self,
|
||||
log_level=None,
|
||||
host=None,
|
||||
port=None,
|
||||
id=None,
|
||||
collection_name=None,
|
||||
collect_metrics: Optional[bool] = None,
|
||||
model=None,
|
||||
collection_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
:param host: Optional. Hostname for the database server.
|
||||
:param port: Optional. Port for the database server.
|
||||
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
|
||||
:param model: Optional. GPT4ALL uses the model to instantiate the class.
|
||||
So unlike `App`, it has to be provided before querying.
|
||||
:param collection_name: Optional. Default collection name.
|
||||
It's recommended to use app.db.set_collection_name() instead.
|
||||
"""
|
||||
self.model = model or "orca-mini-3b.ggmlv3.q4_0.bin"
|
||||
|
||||
super().__init__(
|
||||
log_level=log_level,
|
||||
embedding_fn=OpenSourceAppConfig.default_embedding_function(),
|
||||
host=host,
|
||||
port=port,
|
||||
id=id,
|
||||
collection_name=collection_name,
|
||||
collect_metrics=collect_metrics,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def default_embedding_function():
|
||||
"""
|
||||
Sets embedding function to default (`all-MiniLM-L6-v2`).
|
||||
|
||||
:returns: The default embedding function
|
||||
"""
|
||||
try:
|
||||
return embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
|
||||
except ValueError as e:
|
||||
print(e)
|
||||
raise ModuleNotFoundError(
|
||||
"The open source app requires extra dependencies. Install with `pip install embedchain[opensource]`"
|
||||
) from None
|
||||
super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, collection_name=collection_name)
|
||||
|
||||
10
embedchain/config/embedder/BaseEmbedderConfig.py
Normal file
10
embedchain/config/embedder/BaseEmbedderConfig.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class BaseEmbedderConfig:
|
||||
def __init__(self, model: Optional[str] = None, deployment_name: Optional[str] = None):
|
||||
self.model = model
|
||||
self.deployment_name = deployment_name
|
||||
0
embedchain/config/embedder/__init__.py
Normal file
0
embedchain/config/embedder/__init__.py
Normal file
0
embedchain/config/llm/__init__.py
Normal file
0
embedchain/config/llm/__init__.py
Normal file
@@ -50,7 +50,7 @@ history_re = re.compile(r"\$\{*history\}*")
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class QueryConfig(BaseConfig):
|
||||
class BaseLlmConfig(BaseConfig):
|
||||
"""
|
||||
Config for the `query` method.
|
||||
"""
|
||||
@@ -63,7 +63,6 @@ class QueryConfig(BaseConfig):
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
top_p=None,
|
||||
history=None,
|
||||
stream: bool = False,
|
||||
deployment_name=None,
|
||||
system_prompt: Optional[str] = None,
|
||||
@@ -84,7 +83,6 @@ class QueryConfig(BaseConfig):
|
||||
:param top_p: Optional. Controls the diversity of words. Higher values
|
||||
(closer to 1) make word selection more diverse, lower values make words less
|
||||
diverse.
|
||||
:param history: Optional. A list of strings to consider as history.
|
||||
:param stream: Optional. Control if response is streamed back to user
|
||||
:param deployment_name: t.b.a.
|
||||
:param system_prompt: Optional. System prompt string.
|
||||
@@ -97,19 +95,8 @@ class QueryConfig(BaseConfig):
|
||||
else:
|
||||
self.number_documents = number_documents
|
||||
|
||||
if not history:
|
||||
self.history = None
|
||||
else:
|
||||
if len(history) == 0:
|
||||
self.history = None
|
||||
else:
|
||||
self.history = history
|
||||
|
||||
if template is None:
|
||||
if self.history is None:
|
||||
template = DEFAULT_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE
|
||||
template = DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
self.temperature = temperature if temperature else 0
|
||||
self.max_tokens = max_tokens if max_tokens else 1000
|
||||
@@ -121,10 +108,7 @@ class QueryConfig(BaseConfig):
|
||||
if self.validate_template(template):
|
||||
self.template = template
|
||||
else:
|
||||
if self.history is None:
|
||||
raise ValueError("`template` should have `query` and `context` keys")
|
||||
else:
|
||||
raise ValueError("`template` should have `query`, `context` and `history` keys")
|
||||
raise ValueError("`template` should have `query` and `context` keys and potentially `history` (if used).")
|
||||
|
||||
if not isinstance(stream, bool):
|
||||
raise ValueError("`stream` should be bool")
|
||||
@@ -138,11 +122,13 @@ class QueryConfig(BaseConfig):
|
||||
:param template: the template to validate
|
||||
:return: Boolean, valid (true) or invalid (false)
|
||||
"""
|
||||
if self.history is None:
|
||||
return re.search(query_re, template.template) and re.search(context_re, template.template)
|
||||
else:
|
||||
return (
|
||||
re.search(query_re, template.template)
|
||||
and re.search(context_re, template.template)
|
||||
and re.search(history_re, template.template)
|
||||
)
|
||||
return re.search(query_re, template.template) and re.search(context_re, template.template)
|
||||
|
||||
def _validate_template_history(self, template: Template):
|
||||
"""
|
||||
validate the history template for history
|
||||
|
||||
:param template: the template to validate
|
||||
:return: Boolean, valid (true) or invalid (false)
|
||||
"""
|
||||
return re.search(history_re, template.template)
|
||||
17
embedchain/config/vectordbs/BaseVectorDbConfig.py
Normal file
17
embedchain/config/vectordbs/BaseVectorDbConfig.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.BaseConfig import BaseConfig
|
||||
|
||||
|
||||
class BaseVectorDbConfig(BaseConfig):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
):
|
||||
self.collection_name = collection_name or "embedchain_store"
|
||||
self.dir = dir or "db"
|
||||
self.host = host
|
||||
self.port = port
|
||||
21
embedchain/config/vectordbs/ChromaDbConfig.py
Normal file
21
embedchain/config/vectordbs/ChromaDbConfig.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class ChromaDbConfig(BaseVectorDbConfig):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
chroma_settings: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
:param chroma_settings: Optional. Chroma settings for connection.
|
||||
"""
|
||||
self.chroma_settings = chroma_settings
|
||||
super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)
|
||||
@@ -1,17 +1,25 @@
|
||||
from typing import Dict, List, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from embedchain.config.BaseConfig import BaseConfig
|
||||
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class ElasticsearchDBConfig(BaseConfig):
|
||||
"""
|
||||
Config to initialize an elasticsearch client.
|
||||
:param es_url. elasticsearch url or list of nodes url to be used for connection
|
||||
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
|
||||
"""
|
||||
|
||||
def __init__(self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
|
||||
class ElasticsearchDBConfig(BaseVectorDbConfig):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
es_url: Union[str, List[str]] = None,
|
||||
**ES_EXTRA_PARAMS: Dict[str, any],
|
||||
):
|
||||
"""
|
||||
Config to initialize an elasticsearch client.
|
||||
:param es_url. elasticsearch url or list of nodes url to be used for connection
|
||||
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
|
||||
"""
|
||||
# self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
|
||||
self.ES_URL = es_url
|
||||
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
|
||||
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
|
||||
Reference in New Issue
Block a user