feat: collection name everywhere (#310)
Co-authored-by: cachho <admin@ch-webdev.com>
This commit is contained in:
@@ -20,6 +20,10 @@ from chromadb.utils import embedding_functions
|
|||||||
config = AppConfig(log_level="DEBUG")
|
config = AppConfig(log_level="DEBUG")
|
||||||
naval_chat_bot = App(config)
|
naval_chat_bot = App(config)
|
||||||
|
|
||||||
|
# Example: specify a custom collection name
|
||||||
|
config = AppConfig(collection_name="naval_chat_bot")
|
||||||
|
naval_chat_bot = App(config)
|
||||||
|
|
||||||
# Example: define your own chunker config for `youtube_video`
|
# Example: define your own chunker config for `youtube_video`
|
||||||
chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=100, length_function=len)
|
chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=100, length_function=len)
|
||||||
naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(chunker=chunker_config))
|
naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(chunker=chunker_config))
|
||||||
|
|||||||
@@ -4,11 +4,12 @@ title: '🔍 Query configurations'
|
|||||||
|
|
||||||
## AppConfig
|
## AppConfig
|
||||||
|
|
||||||
| option | description | type | default |
|
| option | description | type | default |
|
||||||
|-------------|-----------------------|---------------------------------|------------------------|
|
|-----------|-----------------------|---------------------------------|------------------------|
|
||||||
| log_level | log level | string | WARNING |
|
| log_level | log level | string | WARNING |
|
||||||
| embedding_fn| embedding function | chromadb.utils.embedding_functions | \{text-embedding-ada-002\} |
|
| embedding_fn| embedding function | chromadb.utils.embedding_functions | \{text-embedding-ada-002\} |
|
||||||
| db | vector database (experimental) | BaseVectorDB | ChromaDB |
|
| db | vector database (experimental) | BaseVectorDB | ChromaDB |
|
||||||
|
| collection_name | initial collection name for the database | string | embedchain_store |
|
||||||
|
|
||||||
|
|
||||||
## AddConfig
|
## AddConfig
|
||||||
|
|||||||
@@ -16,16 +16,22 @@ class AppConfig(BaseAppConfig):
|
|||||||
Config to initialize an embedchain custom `App` instance, with extra config options.
|
Config to initialize an embedchain custom `App` instance, with extra config options.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, log_level=None, host=None, port=None, id=None):
|
def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None):
|
||||||
"""
|
"""
|
||||||
:param log_level: Optional. (String) Debug level
|
:param log_level: Optional. (String) Debug level
|
||||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||||
:param host: Optional. Hostname for the database server.
|
:param host: Optional. Hostname for the database server.
|
||||||
:param port: Optional. Port for the database server.
|
:param port: Optional. Port for the database server.
|
||||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||||
|
:param collection_name: Optional. Collection name for the database.
|
||||||
"""
|
"""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
log_level=log_level, embedding_fn=AppConfig.default_embedding_function(), host=host, port=port, id=id
|
log_level=log_level,
|
||||||
|
embedding_fn=AppConfig.default_embedding_function(),
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
id=id,
|
||||||
|
collection_name=collection_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -8,19 +8,21 @@ class BaseAppConfig(BaseConfig):
|
|||||||
Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
|
Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=None, id=None):
|
def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=None, id=None, collection_name=None):
|
||||||
"""
|
"""
|
||||||
:param log_level: Optional. (String) Debug level
|
:param log_level: Optional. (String) Debug level
|
||||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||||
:param embedding_fn: Embedding function to use.
|
:param embedding_fn: Embedding function to use.
|
||||||
:param db: Optional. (Vector) database instance to use for embeddings.
|
:param db: Optional. (Vector) database instance to use for embeddings.
|
||||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
|
||||||
:param host: Optional. Hostname for the database server.
|
:param host: Optional. Hostname for the database server.
|
||||||
:param port: Optional. Port for the database server.
|
:param port: Optional. Port for the database server.
|
||||||
|
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||||
|
:param collection_name: Optional. Collection name for the database.
|
||||||
"""
|
"""
|
||||||
self._setup_logging(log_level)
|
self._setup_logging(log_level)
|
||||||
|
|
||||||
self.db = db if db else BaseAppConfig.default_db(embedding_fn=embedding_fn, host=host, port=port)
|
self.db = db if db else BaseAppConfig.default_db(embedding_fn=embedding_fn, host=host, port=port)
|
||||||
|
self.collection_name = collection_name if collection_name else "embedchain_store"
|
||||||
self.id = id
|
self.id = id
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -24,8 +24,8 @@ class CustomAppConfig(BaseAppConfig):
|
|||||||
host=None,
|
host=None,
|
||||||
port=None,
|
port=None,
|
||||||
id=None,
|
id=None,
|
||||||
|
collection_name=None,
|
||||||
provider: Providers = None,
|
provider: Providers = None,
|
||||||
model=None,
|
|
||||||
open_source_app_config=None,
|
open_source_app_config=None,
|
||||||
deployment_name=None,
|
deployment_name=None,
|
||||||
):
|
):
|
||||||
@@ -35,9 +35,10 @@ class CustomAppConfig(BaseAppConfig):
|
|||||||
:param embedding_fn: Optional. Embedding function to use.
|
:param embedding_fn: Optional. Embedding function to use.
|
||||||
:param embedding_fn_model: Optional. Model name to use for embedding function.
|
:param embedding_fn_model: Optional. Model name to use for embedding function.
|
||||||
:param db: Optional. (Vector) database to use for embeddings.
|
:param db: Optional. (Vector) database to use for embeddings.
|
||||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
|
||||||
:param host: Optional. Hostname for the database server.
|
:param host: Optional. Hostname for the database server.
|
||||||
:param port: Optional. Port for the database server.
|
:param port: Optional. Port for the database server.
|
||||||
|
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||||
|
:param collection_name: Optional. Collection name for the database.
|
||||||
:param provider: Optional. (Providers): LLM Provider to use.
|
:param provider: Optional. (Providers): LLM Provider to use.
|
||||||
:param open_source_app_config: Optional. Config instance needed for open source apps.
|
:param open_source_app_config: Optional. Config instance needed for open source apps.
|
||||||
"""
|
"""
|
||||||
@@ -58,6 +59,7 @@ class CustomAppConfig(BaseAppConfig):
|
|||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
id=id,
|
id=id,
|
||||||
|
collection_name=collection_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -8,11 +8,12 @@ class OpenSourceAppConfig(BaseAppConfig):
|
|||||||
Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options.
|
Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, log_level=None, host=None, port=None, id=None, model=None):
|
def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None, model=None):
|
||||||
"""
|
"""
|
||||||
:param log_level: Optional. (String) Debug level
|
:param log_level: Optional. (String) Debug level
|
||||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||||
:param id: Optional. ID of the app. Document metadata will have this id.
|
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||||
|
:param collection_name: Optional. Collection name for the database.
|
||||||
:param host: Optional. Hostname for the database server.
|
:param host: Optional. Hostname for the database server.
|
||||||
:param port: Optional. Port for the database server.
|
:param port: Optional. Port for the database server.
|
||||||
:param model: Optional. GPT4ALL uses the model to instantiate the class.
|
:param model: Optional. GPT4ALL uses the model to instantiate the class.
|
||||||
@@ -26,6 +27,7 @@ class OpenSourceAppConfig(BaseAppConfig):
|
|||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
id=id,
|
id=id,
|
||||||
|
collection_name=collection_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ class EmbedChain:
|
|||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.db_client = self.config.db.client
|
self.db_client = self.config.db.client
|
||||||
self.collection = self.config.db.collection
|
self.collection = self.config.db._get_or_create_collection(self.config.collection_name)
|
||||||
self.user_asks = []
|
self.user_asks = []
|
||||||
self.is_docs_site_instance = False
|
self.is_docs_site_instance = False
|
||||||
self.online = False
|
self.online = False
|
||||||
@@ -325,6 +325,14 @@ class EmbedChain:
|
|||||||
memory.chat_memory.add_ai_message(streamed_answer)
|
memory.chat_memory.add_ai_message(streamed_answer)
|
||||||
logging.info(f"Answer: {streamed_answer}")
|
logging.info(f"Answer: {streamed_answer}")
|
||||||
|
|
||||||
|
def set_collection(self, collection_name):
|
||||||
|
"""
|
||||||
|
Set the collection to use.
|
||||||
|
|
||||||
|
:param collection_name: The name of the collection to use.
|
||||||
|
"""
|
||||||
|
self.collection = self.config.db._get_or_create_collection(collection_name)
|
||||||
|
|
||||||
def count(self):
|
def count(self):
|
||||||
"""
|
"""
|
||||||
Count the number of embeddings.
|
Count the number of embeddings.
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ class BaseVectorDB:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.client = self._get_or_create_db()
|
self.client = self._get_or_create_db()
|
||||||
self.collection = self._get_or_create_collection()
|
|
||||||
|
|
||||||
def _get_or_create_db(self):
|
def _get_or_create_db(self):
|
||||||
"""Get or create the database."""
|
"""Get or create the database."""
|
||||||
|
|||||||
@@ -39,9 +39,9 @@ class ChromaDB(BaseVectorDB):
|
|||||||
"""Get or create the database."""
|
"""Get or create the database."""
|
||||||
return self.client
|
return self.client
|
||||||
|
|
||||||
def _get_or_create_collection(self):
|
def _get_or_create_collection(self, name):
|
||||||
"""Get or create the collection."""
|
"""Get or create the collection."""
|
||||||
return self.client.get_or_create_collection(
|
return self.client.get_or_create_collection(
|
||||||
"embedchain_store",
|
name=name,
|
||||||
embedding_function=self.embedding_fn,
|
embedding_function=self.embedding_fn,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -72,3 +72,186 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
|
self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
|
||||||
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
|
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
|
||||||
|
|
||||||
|
class TestChromaDbDuplicateHandling:
|
||||||
|
def test_duplicates_throw_warning(self, caplog):
|
||||||
|
"""
|
||||||
|
Test that add duplicates throws an error.
|
||||||
|
"""
|
||||||
|
# Start with a clean app
|
||||||
|
App().reset()
|
||||||
|
|
||||||
|
app = App()
|
||||||
|
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||||
|
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||||
|
assert "Insert of existing embedding ID: 0" in caplog.text
|
||||||
|
assert "Add of existing embedding ID: 0" in caplog.text
|
||||||
|
|
||||||
|
def test_duplicates_collections_no_warning(self, caplog):
|
||||||
|
"""
|
||||||
|
Test that different collections can have duplicates.
|
||||||
|
"""
|
||||||
|
# NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
|
||||||
|
|
||||||
|
# Start with a clean app
|
||||||
|
App().reset()
|
||||||
|
|
||||||
|
app = App()
|
||||||
|
app.set_collection("test_collection_1")
|
||||||
|
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||||
|
app.set_collection("test_collection_2")
|
||||||
|
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||||
|
assert "Insert of existing embedding ID: 0" not in caplog.text # not
|
||||||
|
assert "Add of existing embedding ID: 0" not in caplog.text # not
|
||||||
|
|
||||||
|
|
||||||
|
class TestChromaDbCollection(unittest.TestCase):
|
||||||
|
def test_init_with_default_collection(self):
|
||||||
|
"""
|
||||||
|
Test if the `App` instance is initialized with the correct default collection name.
|
||||||
|
"""
|
||||||
|
app = App()
|
||||||
|
|
||||||
|
self.assertEqual(app.collection.name, "embedchain_store")
|
||||||
|
|
||||||
|
def test_init_with_custom_collection(self):
|
||||||
|
"""
|
||||||
|
Test if the `App` instance is initialized with the correct custom collection name.
|
||||||
|
"""
|
||||||
|
config = AppConfig(collection_name="test_collection")
|
||||||
|
app = App(config)
|
||||||
|
|
||||||
|
self.assertEqual(app.collection.name, "test_collection")
|
||||||
|
|
||||||
|
def test_set_collection(self):
|
||||||
|
"""
|
||||||
|
Test if the `App` collection is correctly switched using the `set_collection` method.
|
||||||
|
"""
|
||||||
|
app = App()
|
||||||
|
app.set_collection("test_collection")
|
||||||
|
|
||||||
|
self.assertEqual(app.collection.name, "test_collection")
|
||||||
|
|
||||||
|
def test_changes_encapsulated(self):
|
||||||
|
"""
|
||||||
|
Test that changes to one collection do not affect the other collection
|
||||||
|
"""
|
||||||
|
# Start with a clean app
|
||||||
|
App().reset()
|
||||||
|
|
||||||
|
app = App()
|
||||||
|
app.set_collection("test_collection_1")
|
||||||
|
# Collection should be empty when created
|
||||||
|
self.assertEqual(app.count(), 0)
|
||||||
|
|
||||||
|
app.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||||
|
# After adding, should contain one item
|
||||||
|
self.assertEqual(app.count(), 1)
|
||||||
|
|
||||||
|
app.set_collection("test_collection_2")
|
||||||
|
# New collection is empty
|
||||||
|
self.assertEqual(app.count(), 0)
|
||||||
|
|
||||||
|
# Adding to new collection should not effect existing collection
|
||||||
|
app.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||||
|
app.set_collection("test_collection_1")
|
||||||
|
# Should still be 1, not 2.
|
||||||
|
self.assertEqual(app.count(), 1)
|
||||||
|
|
||||||
|
def test_collections_are_persistent(self):
|
||||||
|
"""
|
||||||
|
Test that a collection can be picked up later.
|
||||||
|
"""
|
||||||
|
# Start with a clean app
|
||||||
|
App().reset()
|
||||||
|
|
||||||
|
app = App()
|
||||||
|
app.set_collection("test_collection_1")
|
||||||
|
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||||
|
del app
|
||||||
|
|
||||||
|
app = App()
|
||||||
|
app.set_collection("test_collection_1")
|
||||||
|
self.assertEqual(app.count(), 1)
|
||||||
|
|
||||||
|
def test_parallel_collections(self):
|
||||||
|
"""
|
||||||
|
Test that two apps can have different collections open in parallel.
|
||||||
|
Switching the names will allow instant access to the collection of
|
||||||
|
the other app.
|
||||||
|
"""
|
||||||
|
# Start clean
|
||||||
|
App().reset()
|
||||||
|
|
||||||
|
# Create two apps
|
||||||
|
app1 = App(AppConfig(collection_name="test_collection_1"))
|
||||||
|
app2 = App(AppConfig(collection_name="test_collection_2"))
|
||||||
|
|
||||||
|
# app2 has been created last, but adding to app1 will still write to collection 1.
|
||||||
|
app1.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||||
|
self.assertEqual(app1.count(), 1)
|
||||||
|
self.assertEqual(app2.count(), 0)
|
||||||
|
|
||||||
|
# Add data
|
||||||
|
app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
|
||||||
|
app2.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||||
|
|
||||||
|
# Swap names and test
|
||||||
|
app1.set_collection('test_collection_2')
|
||||||
|
self.assertEqual(app1.count(), 1)
|
||||||
|
app2.set_collection('test_collection_1')
|
||||||
|
self.assertEqual(app2.count(), 3)
|
||||||
|
|
||||||
|
def test_ids_share_collections(self):
|
||||||
|
"""
|
||||||
|
Different ids should still share collections.
|
||||||
|
"""
|
||||||
|
# Start clean
|
||||||
|
App().reset()
|
||||||
|
|
||||||
|
# Create two apps
|
||||||
|
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1"))
|
||||||
|
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2"))
|
||||||
|
|
||||||
|
# Add data
|
||||||
|
app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
|
||||||
|
app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
|
||||||
|
|
||||||
|
# Both should have the same collection
|
||||||
|
self.assertEqual(app1.count(), 3)
|
||||||
|
self.assertEqual(app2.count(), 3)
|
||||||
|
|
||||||
|
def test_reset(self):
|
||||||
|
"""
|
||||||
|
Resetting should hit all collections and ids.
|
||||||
|
"""
|
||||||
|
# Start clean
|
||||||
|
App().reset()
|
||||||
|
|
||||||
|
# Create four apps.
|
||||||
|
# app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
|
||||||
|
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1"))
|
||||||
|
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2"))
|
||||||
|
app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1"))
|
||||||
|
app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4"))
|
||||||
|
|
||||||
|
# Each one of them get data
|
||||||
|
app1.collection.add(embeddings=[0, 0, 0], ids=["1"])
|
||||||
|
app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
|
||||||
|
app3.collection.add(embeddings=[0, 0, 0], ids=["3"])
|
||||||
|
app4.collection.add(embeddings=[0, 0, 0], ids=["4"])
|
||||||
|
|
||||||
|
# Resetting the first one should reset them all.
|
||||||
|
app1.reset()
|
||||||
|
|
||||||
|
# Reinstantiate them
|
||||||
|
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1"))
|
||||||
|
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2"))
|
||||||
|
app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3"))
|
||||||
|
app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3"))
|
||||||
|
|
||||||
|
# All should be empty
|
||||||
|
self.assertEqual(app1.count(), 0)
|
||||||
|
self.assertEqual(app2.count(), 0)
|
||||||
|
self.assertEqual(app3.count(), 0)
|
||||||
|
self.assertEqual(app4.count(), 0)
|
||||||
|
|||||||
Reference in New Issue
Block a user