diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py index 8ec1add1..ff236a7b 100644 --- a/embedchain/vectordb/chroma_db.py +++ b/embedchain/vectordb/chroma_db.py @@ -18,15 +18,20 @@ class ChromaDB(BaseVectorDB): if host and port: logging.info(f"Connecting to ChromaDB server: {host}:{port}") self.settings = Settings(chroma_server_host=host, chroma_server_http_port=port) + self.client = chromadb.HttpClient(self.settings) else: if db_dir is None: db_dir = "db" - self.settings = Settings(persist_directory=db_dir, anonymized_telemetry=False, allow_reset=True) + self.settings = Settings(anonymized_telemetry=False, allow_reset=True) + self.client = chromadb.PersistentClient( + path=db_dir, + settings=self.settings, + ) super().__init__() def _get_or_create_db(self): """Get or create the database.""" - return chromadb.Client(self.settings) + return self.client def _get_or_create_collection(self): """Get or create the collection."""