Upgrade the chromadb version to 0.4.8 and open its settings configuration. (#517)
This commit is contained in:
@@ -24,6 +24,7 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
|
|||||||
db_type: VectorDatabases = None,
|
db_type: VectorDatabases = None,
|
||||||
vector_dim: VectorDimensions = None,
|
vector_dim: VectorDimensions = None,
|
||||||
es_config: ElasticsearchDBConfig = None,
|
es_config: ElasticsearchDBConfig = None,
|
||||||
|
chroma_settings: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param log_level: Optional. (String) Debug level
|
:param log_level: Optional. (String) Debug level
|
||||||
@@ -38,6 +39,7 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
|
|||||||
:param db_type: Optional. type of Vector database to use
|
:param db_type: Optional. type of Vector database to use
|
||||||
:param vector_dim: Vector dimension generated by embedding fn
|
:param vector_dim: Vector dimension generated by embedding fn
|
||||||
:param es_config: Optional. elasticsearch database config to be used for connection
|
:param es_config: Optional. elasticsearch database config to be used for connection
|
||||||
|
:param chroma_settings: Optional. Chroma settings for connection.
|
||||||
"""
|
"""
|
||||||
self._setup_logging(log_level)
|
self._setup_logging(log_level)
|
||||||
self.collection_name = collection_name if collection_name else "embedchain_store"
|
self.collection_name = collection_name if collection_name else "embedchain_store"
|
||||||
@@ -50,13 +52,14 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
|
|||||||
vector_dim=vector_dim,
|
vector_dim=vector_dim,
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
es_config=es_config,
|
es_config=es_config,
|
||||||
|
chroma_settings=chroma_settings,
|
||||||
)
|
)
|
||||||
self.id = id
|
self.id = id
|
||||||
self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False
|
self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False
|
||||||
return
|
return
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config):
|
def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config, chroma_settings):
|
||||||
"""
|
"""
|
||||||
Get db based on db_type, db with default database (`ChromaDb`)
|
Get db based on db_type, db with default database (`ChromaDb`)
|
||||||
:param Optional. (Vector) database to use for embeddings.
|
:param Optional. (Vector) database to use for embeddings.
|
||||||
@@ -85,7 +88,7 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
|
|||||||
|
|
||||||
from embedchain.vectordb.chroma_db import ChromaDB
|
from embedchain.vectordb.chroma_db import ChromaDB
|
||||||
|
|
||||||
return ChromaDB(embedding_fn=embedding_fn, host=host, port=port)
|
return ChromaDB(embedding_fn=embedding_fn, host=host, port=port, chroma_settings=chroma_settings)
|
||||||
|
|
||||||
def _setup_logging(self, debug_level):
|
def _setup_logging(self, debug_level):
|
||||||
level = logging.WARNING # Default level
|
level = logging.WARNING # Default level
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ class CustomAppConfig(BaseAppConfig):
|
|||||||
collect_metrics: Optional[bool] = None,
|
collect_metrics: Optional[bool] = None,
|
||||||
db_type: VectorDatabases = None,
|
db_type: VectorDatabases = None,
|
||||||
es_config: ElasticsearchDBConfig = None,
|
es_config: ElasticsearchDBConfig = None,
|
||||||
|
chroma_settings: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param log_level: Optional. (String) Debug level
|
:param log_level: Optional. (String) Debug level
|
||||||
@@ -51,6 +52,7 @@ class CustomAppConfig(BaseAppConfig):
|
|||||||
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
|
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
|
||||||
:param db_type: Optional. type of Vector database to use.
|
:param db_type: Optional. type of Vector database to use.
|
||||||
:param es_config: Optional. elasticsearch database config to be used for connection
|
:param es_config: Optional. elasticsearch database config to be used for connection
|
||||||
|
:param chroma_settings: Optional. Chroma settings for connection.
|
||||||
"""
|
"""
|
||||||
if provider:
|
if provider:
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
@@ -73,6 +75,7 @@ class CustomAppConfig(BaseAppConfig):
|
|||||||
db_type=db_type,
|
db_type=db_type,
|
||||||
vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn),
|
vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn),
|
||||||
es_config=es_config,
|
es_config=es_config,
|
||||||
|
chroma_settings=chroma_settings,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -22,23 +22,31 @@ from embedchain.vectordb.base_vector_db import BaseVectorDB
|
|||||||
class ChromaDB(BaseVectorDB):
|
class ChromaDB(BaseVectorDB):
|
||||||
"""Vector database using ChromaDB."""
|
"""Vector database using ChromaDB."""
|
||||||
|
|
||||||
def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None):
|
def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None, chroma_settings={}):
|
||||||
self.embedding_fn = embedding_fn
|
self.embedding_fn = embedding_fn
|
||||||
|
|
||||||
if not hasattr(embedding_fn, "__call__"):
|
if not hasattr(embedding_fn, "__call__"):
|
||||||
raise ValueError("Embedding function is not a function")
|
raise ValueError("Embedding function is not a function")
|
||||||
|
|
||||||
|
self.settings = Settings()
|
||||||
|
for key, value in chroma_settings.items():
|
||||||
|
if hasattr(self.settings, key):
|
||||||
|
setattr(self.settings, key, value)
|
||||||
|
|
||||||
if host and port:
|
if host and port:
|
||||||
logging.info(f"Connecting to ChromaDB server: {host}:{port}")
|
logging.info(f"Connecting to ChromaDB server: {host}:{port}")
|
||||||
self.client = chromadb.HttpClient(host=host, port=port)
|
self.settings.chroma_server_host = host
|
||||||
|
self.settings.chroma_server_http_port = port
|
||||||
|
self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if db_dir is None:
|
if db_dir is None:
|
||||||
db_dir = "db"
|
db_dir = "db"
|
||||||
self.settings = Settings(anonymized_telemetry=False, allow_reset=True)
|
|
||||||
self.client = chromadb.PersistentClient(
|
self.settings.persist_directory = db_dir
|
||||||
path=db_dir,
|
self.settings.is_persistent = True
|
||||||
settings=self.settings,
|
|
||||||
)
|
self.client = chromadb.Client(self.settings)
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def _get_or_create_db(self):
|
def _get_or_create_db(self):
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ langchain = "^0.0.279"
|
|||||||
requests = "^2.31.0"
|
requests = "^2.31.0"
|
||||||
openai = "^0.27.5"
|
openai = "^0.27.5"
|
||||||
tiktoken = "^0.4.0"
|
tiktoken = "^0.4.0"
|
||||||
chromadb ="^0.4.2"
|
chromadb ="^0.4.8"
|
||||||
youtube-transcript-api = "^0.6.1"
|
youtube-transcript-api = "^0.6.1"
|
||||||
beautifulsoup4 = "^4.12.2"
|
beautifulsoup4 = "^4.12.2"
|
||||||
pypdf = "^3.11.0"
|
pypdf = "^3.11.0"
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ import unittest
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from embedchain import App
|
from embedchain import App
|
||||||
from embedchain.config import AppConfig
|
from embedchain.config import AppConfig, CustomAppConfig
|
||||||
|
from embedchain.models import EmbeddingFunctions, Providers
|
||||||
|
|
||||||
|
|
||||||
class TestChromaDbHostsLoglevel(unittest.TestCase):
|
class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||||
@@ -42,7 +43,11 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
Test if the `App` instance is correctly reconstructed after a reset.
|
Test if the `App` instance is correctly reconstructed after a reset.
|
||||||
"""
|
"""
|
||||||
app = App()
|
app = App(
|
||||||
|
CustomAppConfig(
|
||||||
|
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
|
||||||
|
)
|
||||||
|
)
|
||||||
app.reset()
|
app.reset()
|
||||||
|
|
||||||
# Make sure the client is still healthy
|
# Make sure the client is still healthy
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ import unittest
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from embedchain import App
|
from embedchain import App
|
||||||
from embedchain.config import AppConfig
|
from embedchain.config import AppConfig, CustomAppConfig
|
||||||
|
from embedchain.models import EmbeddingFunctions, Providers
|
||||||
from embedchain.vectordb.chroma_db import ChromaDB
|
from embedchain.vectordb.chroma_db import ChromaDB
|
||||||
|
|
||||||
|
|
||||||
@@ -21,6 +22,24 @@ class TestChromaDbHosts(unittest.TestCase):
|
|||||||
self.assertEqual(settings.chroma_server_host, host)
|
self.assertEqual(settings.chroma_server_host, host)
|
||||||
self.assertEqual(settings.chroma_server_http_port, port)
|
self.assertEqual(settings.chroma_server_http_port, port)
|
||||||
|
|
||||||
|
def test_init_with_basic_auth(self):
|
||||||
|
host = "test-host"
|
||||||
|
port = "1234"
|
||||||
|
|
||||||
|
chroma_auth_settings = {
|
||||||
|
"chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
|
||||||
|
"chroma_client_auth_credentials": "admin:admin",
|
||||||
|
}
|
||||||
|
|
||||||
|
db = ChromaDB(host=host, port=port, embedding_fn=len, chroma_settings=chroma_auth_settings)
|
||||||
|
settings = db.client.get_settings()
|
||||||
|
self.assertEqual(settings.chroma_server_host, host)
|
||||||
|
self.assertEqual(settings.chroma_server_http_port, port)
|
||||||
|
self.assertEqual(settings.chroma_client_auth_provider, chroma_auth_settings["chroma_client_auth_provider"])
|
||||||
|
self.assertEqual(
|
||||||
|
settings.chroma_client_auth_credentials, chroma_auth_settings["chroma_client_auth_credentials"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Review this test
|
# Review this test
|
||||||
class TestChromaDbHostsInit(unittest.TestCase):
|
class TestChromaDbHostsInit(unittest.TestCase):
|
||||||
@@ -68,12 +87,18 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class TestChromaDbDuplicateHandling:
|
class TestChromaDbDuplicateHandling:
|
||||||
|
app_with_settings = App(
|
||||||
|
CustomAppConfig(
|
||||||
|
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def test_duplicates_throw_warning(self, caplog):
|
def test_duplicates_throw_warning(self, caplog):
|
||||||
"""
|
"""
|
||||||
Test that add duplicates throws an error.
|
Test that add duplicates throws an error.
|
||||||
"""
|
"""
|
||||||
# Start with a clean app
|
# Start with a clean app
|
||||||
App().reset()
|
self.app_with_settings.reset()
|
||||||
|
|
||||||
app = App(config=AppConfig(collect_metrics=False))
|
app = App(config=AppConfig(collect_metrics=False))
|
||||||
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||||
@@ -88,7 +113,7 @@ class TestChromaDbDuplicateHandling:
|
|||||||
# NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
|
# NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
|
||||||
|
|
||||||
# Start with a clean app
|
# Start with a clean app
|
||||||
App().reset()
|
self.app_with_settings.reset()
|
||||||
|
|
||||||
app = App(config=AppConfig(collect_metrics=False))
|
app = App(config=AppConfig(collect_metrics=False))
|
||||||
app.set_collection("test_collection_1")
|
app.set_collection("test_collection_1")
|
||||||
@@ -100,6 +125,12 @@ class TestChromaDbDuplicateHandling:
|
|||||||
|
|
||||||
|
|
||||||
class TestChromaDbCollection(unittest.TestCase):
|
class TestChromaDbCollection(unittest.TestCase):
|
||||||
|
app_with_settings = App(
|
||||||
|
CustomAppConfig(
|
||||||
|
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def test_init_with_default_collection(self):
|
def test_init_with_default_collection(self):
|
||||||
"""
|
"""
|
||||||
Test if the `App` instance is initialized with the correct default collection name.
|
Test if the `App` instance is initialized with the correct default collection name.
|
||||||
@@ -131,7 +162,7 @@ class TestChromaDbCollection(unittest.TestCase):
|
|||||||
Test that changes to one collection do not affect the other collection
|
Test that changes to one collection do not affect the other collection
|
||||||
"""
|
"""
|
||||||
# Start with a clean app
|
# Start with a clean app
|
||||||
App().reset()
|
self.app_with_settings.reset()
|
||||||
|
|
||||||
app = App(config=AppConfig(collect_metrics=False))
|
app = App(config=AppConfig(collect_metrics=False))
|
||||||
app.set_collection("test_collection_1")
|
app.set_collection("test_collection_1")
|
||||||
@@ -157,7 +188,7 @@ class TestChromaDbCollection(unittest.TestCase):
|
|||||||
Test that a collection can be picked up later.
|
Test that a collection can be picked up later.
|
||||||
"""
|
"""
|
||||||
# Start with a clean app
|
# Start with a clean app
|
||||||
App().reset()
|
self.app_with_settings.reset()
|
||||||
|
|
||||||
app = App(config=AppConfig(collect_metrics=False))
|
app = App(config=AppConfig(collect_metrics=False))
|
||||||
app.set_collection("test_collection_1")
|
app.set_collection("test_collection_1")
|
||||||
@@ -175,7 +206,7 @@ class TestChromaDbCollection(unittest.TestCase):
|
|||||||
the other app.
|
the other app.
|
||||||
"""
|
"""
|
||||||
# Start clean
|
# Start clean
|
||||||
App().reset()
|
self.app_with_settings.reset()
|
||||||
|
|
||||||
# Create two apps
|
# Create two apps
|
||||||
app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False))
|
app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False))
|
||||||
@@ -201,7 +232,7 @@ class TestChromaDbCollection(unittest.TestCase):
|
|||||||
Different ids should still share collections.
|
Different ids should still share collections.
|
||||||
"""
|
"""
|
||||||
# Start clean
|
# Start clean
|
||||||
App().reset()
|
self.app_with_settings.reset()
|
||||||
|
|
||||||
# Create two apps
|
# Create two apps
|
||||||
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
|
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
|
||||||
@@ -220,11 +251,20 @@ class TestChromaDbCollection(unittest.TestCase):
|
|||||||
Resetting should hit all collections and ids.
|
Resetting should hit all collections and ids.
|
||||||
"""
|
"""
|
||||||
# Start clean
|
# Start clean
|
||||||
App().reset()
|
self.app_with_settings.reset()
|
||||||
|
|
||||||
# Create four apps.
|
# Create four apps.
|
||||||
# app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
|
# app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
|
||||||
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
|
app1 = App(
|
||||||
|
CustomAppConfig(
|
||||||
|
collection_name="one_collection",
|
||||||
|
id="new_app_id_1",
|
||||||
|
collect_metrics=False,
|
||||||
|
provider=Providers.OPENAI,
|
||||||
|
embedding_fn=EmbeddingFunctions.OPENAI,
|
||||||
|
chroma_settings={"allow_reset": True},
|
||||||
|
)
|
||||||
|
)
|
||||||
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
|
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
|
||||||
app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False))
|
app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False))
|
||||||
app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))
|
app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))
|
||||||
|
|||||||
Reference in New Issue
Block a user