From 4aae2b5cca20abc0eac88b8e1616afbf66475c65 Mon Sep 17 00:00:00 2001 From: dbcontributions Date: Mon, 12 Aug 2024 15:08:44 +0530 Subject: [PATCH] added on_disk param to qdrant configs (#1677) --- mem0/configs/vector_stores/qdrant.py | 1 + mem0/vector_stores/qdrant.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/mem0/configs/vector_stores/qdrant.py b/mem0/configs/vector_stores/qdrant.py index 6c261f4f..6c40f108 100644 --- a/mem0/configs/vector_stores/qdrant.py +++ b/mem0/configs/vector_stores/qdrant.py @@ -14,6 +14,7 @@ class QdrantConfig(BaseModel): path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database") url: Optional[str] = Field(None, description="Full URL for Qdrant server") api_key: Optional[str] = Field(None, description="API key for Qdrant server") + on_disk: Optional[bool]= Field(False, description="Enables persistant storage") @model_validator(mode="before") def check_host_port_or_path(cls, values): diff --git a/mem0/vector_stores/qdrant.py b/mem0/vector_stores/qdrant.py index d4a32adf..396fc776 100644 --- a/mem0/vector_stores/qdrant.py +++ b/mem0/vector_stores/qdrant.py @@ -28,6 +28,7 @@ class Qdrant(VectorStoreBase): path, url, api_key, + on_disk ): """ Initialize the Qdrant vector store. @@ -39,6 +40,7 @@ class Qdrant(VectorStoreBase): path (str, optional): Path for local Qdrant database. url (str, optional): Full URL for Qdrant server. api_key (str, optional): API key for Qdrant server. + on_disk (bool, optional): Enables persistant storage. """ if client: self.client = client @@ -51,17 +53,17 @@ class Qdrant(VectorStoreBase): if host and port: params["host"] = host params["port"] = port - if not params: params["path"] = path - if os.path.exists(path) and os.path.isdir(path): - shutil.rmtree(path) + if not on_disk: + if os.path.exists(path) and os.path.isdir(path): + shutil.rmtree(path) self.client = QdrantClient(**params) - self.create_col(collection_name, embedding_model_dims) + self.create_col(collection_name, embedding_model_dims, on_disk) - def create_col(self, name, vector_size, distance=Distance.COSINE): + def create_col(self, name, vector_size, on_disk, distance=Distance.COSINE): """ Create a new collection. @@ -79,7 +81,7 @@ class Qdrant(VectorStoreBase): self.client.create_collection( collection_name=name, - vectors_config=VectorParams(size=vector_size, distance=distance), + vectors_config=VectorParams(size=vector_size, distance=distance, on_disk=on_disk), ) def insert(self, name, vectors, payloads=None, ids=None):