feat: collection name everywhere (#310)
Co-authored-by: cachho <admin@ch-webdev.com>
This commit is contained in:
@@ -16,16 +16,22 @@ class AppConfig(BaseAppConfig):
|
||||
Config to initialize an embedchain custom `App` instance, with extra config options.
|
||||
"""
|
||||
|
||||
def __init__(self, log_level=None, host=None, port=None, id=None):
|
||||
def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||
:param host: Optional. Hostname for the database server.
|
||||
:param port: Optional. Port for the database server.
|
||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
"""
|
||||
super().__init__(
|
||||
log_level=log_level, embedding_fn=AppConfig.default_embedding_function(), host=host, port=port, id=id
|
||||
log_level=log_level,
|
||||
embedding_fn=AppConfig.default_embedding_function(),
|
||||
host=host,
|
||||
port=port,
|
||||
id=id,
|
||||
collection_name=collection_name,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -8,19 +8,21 @@ class BaseAppConfig(BaseConfig):
|
||||
Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
|
||||
"""
|
||||
|
||||
def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=None, id=None):
|
||||
def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=None, id=None, collection_name=None):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||
:param embedding_fn: Embedding function to use.
|
||||
:param db: Optional. (Vector) database instance to use for embeddings.
|
||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||
:param host: Optional. Hostname for the database server.
|
||||
:param port: Optional. Port for the database server.
|
||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
"""
|
||||
self._setup_logging(log_level)
|
||||
|
||||
self.db = db if db else BaseAppConfig.default_db(embedding_fn=embedding_fn, host=host, port=port)
|
||||
self.collection_name = collection_name if collection_name else "embedchain_store"
|
||||
self.id = id
|
||||
return
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ class CustomAppConfig(BaseAppConfig):
|
||||
host=None,
|
||||
port=None,
|
||||
id=None,
|
||||
collection_name=None,
|
||||
provider: Providers = None,
|
||||
model=None,
|
||||
open_source_app_config=None,
|
||||
deployment_name=None,
|
||||
):
|
||||
@@ -35,9 +35,10 @@ class CustomAppConfig(BaseAppConfig):
|
||||
:param embedding_fn: Optional. Embedding function to use.
|
||||
:param embedding_fn_model: Optional. Model name to use for embedding function.
|
||||
:param db: Optional. (Vector) database to use for embeddings.
|
||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||
:param host: Optional. Hostname for the database server.
|
||||
:param port: Optional. Port for the database server.
|
||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
:param provider: Optional. (Providers): LLM Provider to use.
|
||||
:param open_source_app_config: Optional. Config instance needed for open source apps.
|
||||
"""
|
||||
@@ -58,6 +59,7 @@ class CustomAppConfig(BaseAppConfig):
|
||||
host=host,
|
||||
port=port,
|
||||
id=id,
|
||||
collection_name=collection_name,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -8,11 +8,12 @@ class OpenSourceAppConfig(BaseAppConfig):
|
||||
Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options.
|
||||
"""
|
||||
|
||||
def __init__(self, log_level=None, host=None, port=None, id=None, model=None):
|
||||
def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None, model=None):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||
:param collection_name: Optional. Collection name for the database.
|
||||
:param host: Optional. Hostname for the database server.
|
||||
:param port: Optional. Port for the database server.
|
||||
:param model: Optional. GPT4ALL uses the model to instantiate the class.
|
||||
@@ -26,6 +27,7 @@ class OpenSourceAppConfig(BaseAppConfig):
|
||||
host=host,
|
||||
port=port,
|
||||
id=id,
|
||||
collection_name=collection_name,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -32,7 +32,7 @@ class EmbedChain:
|
||||
|
||||
self.config = config
|
||||
self.db_client = self.config.db.client
|
||||
self.collection = self.config.db.collection
|
||||
self.collection = self.config.db._get_or_create_collection(self.config.collection_name)
|
||||
self.user_asks = []
|
||||
self.is_docs_site_instance = False
|
||||
self.online = False
|
||||
@@ -325,6 +325,14 @@ class EmbedChain:
|
||||
memory.chat_memory.add_ai_message(streamed_answer)
|
||||
logging.info(f"Answer: {streamed_answer}")
|
||||
|
||||
def set_collection(self, collection_name):
|
||||
"""
|
||||
Set the collection to use.
|
||||
|
||||
:param collection_name: The name of the collection to use.
|
||||
"""
|
||||
self.collection = self.config.db._get_or_create_collection(collection_name)
|
||||
|
||||
def count(self):
|
||||
"""
|
||||
Count the number of embeddings.
|
||||
|
||||
@@ -3,7 +3,6 @@ class BaseVectorDB:
|
||||
|
||||
def __init__(self):
|
||||
self.client = self._get_or_create_db()
|
||||
self.collection = self._get_or_create_collection()
|
||||
|
||||
def _get_or_create_db(self):
|
||||
"""Get or create the database."""
|
||||
|
||||
@@ -39,9 +39,9 @@ class ChromaDB(BaseVectorDB):
|
||||
"""Get or create the database."""
|
||||
return self.client
|
||||
|
||||
def _get_or_create_collection(self):
|
||||
def _get_or_create_collection(self, name):
|
||||
"""Get or create the collection."""
|
||||
return self.client.get_or_create_collection(
|
||||
"embedchain_store",
|
||||
name=name,
|
||||
embedding_function=self.embedding_fn,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user