Upgrade chromadb to version 0.4.8 and expose its settings configuration. (#517)

This commit is contained in:
wangJm
2023-09-04 14:31:08 +08:00
committed by GitHub
parent 433c4157e0
commit eecdbc5e06
6 changed files with 80 additions and 21 deletions

View File

@@ -24,6 +24,7 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
db_type: VectorDatabases = None, db_type: VectorDatabases = None,
vector_dim: VectorDimensions = None, vector_dim: VectorDimensions = None,
es_config: ElasticsearchDBConfig = None, es_config: ElasticsearchDBConfig = None,
chroma_settings: dict = {},
): ):
""" """
:param log_level: Optional. (String) Debug level :param log_level: Optional. (String) Debug level
@@ -38,6 +39,7 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
:param db_type: Optional. type of Vector database to use :param db_type: Optional. type of Vector database to use
:param vector_dim: Vector dimension generated by embedding fn :param vector_dim: Vector dimension generated by embedding fn
:param es_config: Optional. elasticsearch database config to be used for connection :param es_config: Optional. elasticsearch database config to be used for connection
:param chroma_settings: Optional. Chroma settings for connection.
""" """
self._setup_logging(log_level) self._setup_logging(log_level)
self.collection_name = collection_name if collection_name else "embedchain_store" self.collection_name = collection_name if collection_name else "embedchain_store"
@@ -50,13 +52,14 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
vector_dim=vector_dim, vector_dim=vector_dim,
collection_name=self.collection_name, collection_name=self.collection_name,
es_config=es_config, es_config=es_config,
chroma_settings=chroma_settings,
) )
self.id = id self.id = id
self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False
return return
@staticmethod @staticmethod
def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config): def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config, chroma_settings):
""" """
Get db based on db_type, db with default database (`ChromaDb`) Get db based on db_type, db with default database (`ChromaDb`)
:param Optional. (Vector) database to use for embeddings. :param Optional. (Vector) database to use for embeddings.
@@ -85,7 +88,7 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
from embedchain.vectordb.chroma_db import ChromaDB from embedchain.vectordb.chroma_db import ChromaDB
return ChromaDB(embedding_fn=embedding_fn, host=host, port=port) return ChromaDB(embedding_fn=embedding_fn, host=host, port=port, chroma_settings=chroma_settings)
def _setup_logging(self, debug_level): def _setup_logging(self, debug_level):
level = logging.WARNING # Default level level = logging.WARNING # Default level

View File

@@ -35,6 +35,7 @@ class CustomAppConfig(BaseAppConfig):
collect_metrics: Optional[bool] = None, collect_metrics: Optional[bool] = None,
db_type: VectorDatabases = None, db_type: VectorDatabases = None,
es_config: ElasticsearchDBConfig = None, es_config: ElasticsearchDBConfig = None,
chroma_settings: dict = {},
): ):
""" """
:param log_level: Optional. (String) Debug level :param log_level: Optional. (String) Debug level
@@ -51,6 +52,7 @@ class CustomAppConfig(BaseAppConfig):
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain. :param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
:param db_type: Optional. type of Vector database to use. :param db_type: Optional. type of Vector database to use.
:param es_config: Optional. elasticsearch database config to be used for connection :param es_config: Optional. elasticsearch database config to be used for connection
:param chroma_settings: Optional. Chroma settings for connection.
""" """
if provider: if provider:
self.provider = provider self.provider = provider
@@ -73,6 +75,7 @@ class CustomAppConfig(BaseAppConfig):
db_type=db_type, db_type=db_type,
vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn), vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn),
es_config=es_config, es_config=es_config,
chroma_settings=chroma_settings,
) )
@staticmethod @staticmethod

View File

@@ -22,23 +22,31 @@ from embedchain.vectordb.base_vector_db import BaseVectorDB
class ChromaDB(BaseVectorDB): class ChromaDB(BaseVectorDB):
"""Vector database using ChromaDB.""" """Vector database using ChromaDB."""
def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None): def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None, chroma_settings={}):
self.embedding_fn = embedding_fn self.embedding_fn = embedding_fn
if not hasattr(embedding_fn, "__call__"): if not hasattr(embedding_fn, "__call__"):
raise ValueError("Embedding function is not a function") raise ValueError("Embedding function is not a function")
self.settings = Settings()
for key, value in chroma_settings.items():
if hasattr(self.settings, key):
setattr(self.settings, key, value)
if host and port: if host and port:
logging.info(f"Connecting to ChromaDB server: {host}:{port}") logging.info(f"Connecting to ChromaDB server: {host}:{port}")
self.client = chromadb.HttpClient(host=host, port=port) self.settings.chroma_server_host = host
self.settings.chroma_server_http_port = port
self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
else: else:
if db_dir is None: if db_dir is None:
db_dir = "db" db_dir = "db"
self.settings = Settings(anonymized_telemetry=False, allow_reset=True)
self.client = chromadb.PersistentClient( self.settings.persist_directory = db_dir
path=db_dir, self.settings.is_persistent = True
settings=self.settings,
) self.client = chromadb.Client(self.settings)
super().__init__() super().__init__()
def _get_or_create_db(self): def _get_or_create_db(self):

View File

@@ -86,7 +86,7 @@ langchain = "^0.0.279"
requests = "^2.31.0" requests = "^2.31.0"
openai = "^0.27.5" openai = "^0.27.5"
tiktoken = "^0.4.0" tiktoken = "^0.4.0"
chromadb ="^0.4.2" chromadb ="^0.4.8"
youtube-transcript-api = "^0.6.1" youtube-transcript-api = "^0.6.1"
beautifulsoup4 = "^4.12.2" beautifulsoup4 = "^4.12.2"
pypdf = "^3.11.0" pypdf = "^3.11.0"

View File

@@ -3,7 +3,8 @@ import unittest
from unittest.mock import patch from unittest.mock import patch
from embedchain import App from embedchain import App
from embedchain.config import AppConfig from embedchain.config import AppConfig, CustomAppConfig
from embedchain.models import EmbeddingFunctions, Providers
class TestChromaDbHostsLoglevel(unittest.TestCase): class TestChromaDbHostsLoglevel(unittest.TestCase):
@@ -42,7 +43,11 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
""" """
Test if the `App` instance is correctly reconstructed after a reset. Test if the `App` instance is correctly reconstructed after a reset.
""" """
app = App() app = App(
CustomAppConfig(
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
)
)
app.reset() app.reset()
# Make sure the client is still healthy # Make sure the client is still healthy

View File

@@ -4,7 +4,8 @@ import unittest
from unittest.mock import patch from unittest.mock import patch
from embedchain import App from embedchain import App
from embedchain.config import AppConfig from embedchain.config import AppConfig, CustomAppConfig
from embedchain.models import EmbeddingFunctions, Providers
from embedchain.vectordb.chroma_db import ChromaDB from embedchain.vectordb.chroma_db import ChromaDB
@@ -21,6 +22,24 @@ class TestChromaDbHosts(unittest.TestCase):
self.assertEqual(settings.chroma_server_host, host) self.assertEqual(settings.chroma_server_host, host)
self.assertEqual(settings.chroma_server_http_port, port) self.assertEqual(settings.chroma_server_http_port, port)
def test_init_with_basic_auth(self):
host = "test-host"
port = "1234"
chroma_auth_settings = {
"chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
"chroma_client_auth_credentials": "admin:admin",
}
db = ChromaDB(host=host, port=port, embedding_fn=len, chroma_settings=chroma_auth_settings)
settings = db.client.get_settings()
self.assertEqual(settings.chroma_server_host, host)
self.assertEqual(settings.chroma_server_http_port, port)
self.assertEqual(settings.chroma_client_auth_provider, chroma_auth_settings["chroma_client_auth_provider"])
self.assertEqual(
settings.chroma_client_auth_credentials, chroma_auth_settings["chroma_client_auth_credentials"]
)
# Review this test # Review this test
class TestChromaDbHostsInit(unittest.TestCase): class TestChromaDbHostsInit(unittest.TestCase):
@@ -68,12 +87,18 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
class TestChromaDbDuplicateHandling: class TestChromaDbDuplicateHandling:
app_with_settings = App(
CustomAppConfig(
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
)
)
def test_duplicates_throw_warning(self, caplog): def test_duplicates_throw_warning(self, caplog):
""" """
Test that add duplicates throws an error. Test that add duplicates throws an error.
""" """
# Start with a clean app # Start with a clean app
App().reset() self.app_with_settings.reset()
app = App(config=AppConfig(collect_metrics=False)) app = App(config=AppConfig(collect_metrics=False))
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
@@ -88,7 +113,7 @@ class TestChromaDbDuplicateHandling:
# NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog. # NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
# Start with a clean app # Start with a clean app
App().reset() self.app_with_settings.reset()
app = App(config=AppConfig(collect_metrics=False)) app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection_1") app.set_collection("test_collection_1")
@@ -100,6 +125,12 @@ class TestChromaDbDuplicateHandling:
class TestChromaDbCollection(unittest.TestCase): class TestChromaDbCollection(unittest.TestCase):
app_with_settings = App(
CustomAppConfig(
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
)
)
def test_init_with_default_collection(self): def test_init_with_default_collection(self):
""" """
Test if the `App` instance is initialized with the correct default collection name. Test if the `App` instance is initialized with the correct default collection name.
@@ -131,7 +162,7 @@ class TestChromaDbCollection(unittest.TestCase):
Test that changes to one collection do not affect the other collection Test that changes to one collection do not affect the other collection
""" """
# Start with a clean app # Start with a clean app
App().reset() self.app_with_settings.reset()
app = App(config=AppConfig(collect_metrics=False)) app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection_1") app.set_collection("test_collection_1")
@@ -157,7 +188,7 @@ class TestChromaDbCollection(unittest.TestCase):
Test that a collection can be picked up later. Test that a collection can be picked up later.
""" """
# Start with a clean app # Start with a clean app
App().reset() self.app_with_settings.reset()
app = App(config=AppConfig(collect_metrics=False)) app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection_1") app.set_collection("test_collection_1")
@@ -175,7 +206,7 @@ class TestChromaDbCollection(unittest.TestCase):
the other app. the other app.
""" """
# Start clean # Start clean
App().reset() self.app_with_settings.reset()
# Create two apps # Create two apps
app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False)) app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False))
@@ -201,7 +232,7 @@ class TestChromaDbCollection(unittest.TestCase):
Different ids should still share collections. Different ids should still share collections.
""" """
# Start clean # Start clean
App().reset() self.app_with_settings.reset()
# Create two apps # Create two apps
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False)) app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
@@ -220,11 +251,20 @@ class TestChromaDbCollection(unittest.TestCase):
Resetting should hit all collections and ids. Resetting should hit all collections and ids.
""" """
# Start clean # Start clean
App().reset() self.app_with_settings.reset()
# Create four apps. # Create four apps.
# app1, which we are about to reset, shares an app with one, and an id with the other, none with the last. # app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False)) app1 = App(
CustomAppConfig(
collection_name="one_collection",
id="new_app_id_1",
collect_metrics=False,
provider=Providers.OPENAI,
embedding_fn=EmbeddingFunctions.OPENAI,
chroma_settings={"allow_reset": True},
)
)
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False)) app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False)) app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False))
app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False)) app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))