From d590e4423b516992173a5a7ec0b16df152ab8783 Mon Sep 17 00:00:00 2001 From: Candido Sales Gomes Date: Thu, 20 Jul 2023 17:17:53 -0400 Subject: [PATCH] update: chroma v0.4.0 (#330) --- embedchain/vectordb/chroma_db.py | 15 ++++----------- pyproject.toml | 4 ++-- tests/vectordb/test_chroma_db.py | 8 ++++---- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py index 62fe9d59..8ec1add1 100644 --- a/embedchain/vectordb/chroma_db.py +++ b/embedchain/vectordb/chroma_db.py @@ -1,6 +1,7 @@ import logging import chromadb +from chromadb.config import Settings from embedchain.vectordb.base_vector_db import BaseVectorDB @@ -16,24 +17,16 @@ class ChromaDB(BaseVectorDB): if host and port: logging.info(f"Connecting to ChromaDB server: {host}:{port}") - self.client_settings = chromadb.config.Settings( - chroma_api_impl="rest", - chroma_server_host=host, - chroma_server_http_port=port, - ) + self.settings = Settings(chroma_server_host=host, chroma_server_http_port=port) else: if db_dir is None: db_dir = "db" - self.client_settings = chromadb.config.Settings( - chroma_db_impl="duckdb+parquet", - persist_directory=db_dir, - anonymized_telemetry=False, - ) + self.settings = Settings(persist_directory=db_dir, anonymized_telemetry=False, allow_reset=True) super().__init__() def _get_or_create_db(self): """Get or create the database.""" - return chromadb.Client(self.client_settings) + return chromadb.Client(self.settings) def _get_or_create_collection(self): """Get or create the collection.""" diff --git a/pyproject.toml b/pyproject.toml index 39206aa4..618b4357 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,10 +82,10 @@ color = true [tool.poetry.dependencies] python = ">=3.9,<3.9.7 || >3.9.7,<4.0" python-dotenv = "^1.0.0" -langchain = "^0.0.205" +langchain = "^0.0.237" requests = "^2.31.0" openai = "^0.27.5" -chromadb ="^0.3.26" +chromadb ="^0.4.2" youtube-transcript-api = "^0.6.1" beautifulsoup4 = "^4.12.2" pypdf = "^3.11.0" diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index ca68bb03..de009b9b 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -6,6 +6,7 @@ from unittest.mock import patch from embedchain import App from embedchain.config import AppConfig from embedchain.vectordb.chroma_db import ChromaDB, chromadb +from chromadb.config import Settings class TestChromaDbHosts(unittest.TestCase): @@ -19,10 +20,9 @@ class TestChromaDbHosts(unittest.TestCase): with patch.object(chromadb, "Client") as mock_client: _db = ChromaDB(host=host, port=port, embedding_fn=len) - expected_settings = chromadb.config.Settings( - chroma_api_impl="rest", - chroma_server_host=host, - chroma_server_http_port=port, + expected_settings = Settings( + chroma_server_host="test-host", + chroma_server_http_port="1234", ) mock_client.assert_called_once_with(expected_settings)