from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, model_validator


class PGVectorConfig(BaseModel):
    """Configuration for the pgvector (PostgreSQL) vector store provider."""

    dbname: str = Field("postgres", description="Default name for the database")
    collection_name: str = Field("mem0", description="Default name for the collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
    user: Optional[str] = Field(None, description="Database user")
    password: Optional[str] = Field(None, description="Database password")
    host: Optional[str] = Field(None, description="Database host. Default is localhost")
    port: Optional[int] = Field(None, description="Database port. Default is 5432")

    @model_validator(mode="before")
    @classmethod
    def check_auth_and_connection(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require both credentials and both connection parameters.

        The original check used `and`, so it only rejected configs where BOTH
        values were missing and silently accepted half-specified ones,
        contradicting the "Both ... must be provided" error messages.
        """
        user, password = values.get("user"), values.get("password")
        host, port = values.get("host"), values.get("port")
        if not user or not password:
            raise ValueError("Both 'user' and 'password' must be provided.")
        if not host or not port:
            raise ValueError("Both 'host' and 'port' must be provided.")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject unknown keys so typos in config surface as clear errors."""
        allowed_fields = set(cls.model_fields.keys())
        extra_fields = set(values.keys()) - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values
""" capture_event("mem0.get", self, {"memory_id": memory_id}) - memory = self.vector_store.get(name=self.collection_name, vector_id=memory_id) + memory = self.vector_store.get(vector_id=memory_id) if not memory: return None @@ -210,9 +209,7 @@ class Memory(MemoryBase): filters["run_id"] = run_id capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit}) - memories = self.vector_store.list( - name=self.collection_name, filters=filters, limit=limit - ) + memories = self.vector_store.list(filters=filters, limit=limit) excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"} return [ @@ -258,9 +255,7 @@ class Memory(MemoryBase): capture_event("mem0.search", self, {"filters": len(filters), "limit": limit}) embeddings = self.embedding_model.embed(query) - memories = self.vector_store.search( - name=self.collection_name, query=embeddings, limit=limit, filters=filters - ) + memories = self.vector_store.search(query=embeddings, limit=limit, filters=filters) excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"} @@ -330,7 +325,7 @@ class Memory(MemoryBase): ) capture_event("mem0.delete_all", self, {"filters": len(filters)}) - memories = self.vector_store.list(name=self.collection_name, filters=filters)[0] + memories = self.vector_store.list(filters=filters)[0] for memory in memories: self._delete_memory_tool(memory.id) return {'message': 'Memories deleted successfully!'} @@ -358,7 +353,6 @@ class Memory(MemoryBase): metadata["created_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat() self.vector_store.insert( - name=self.collection_name, vectors=[embeddings], ids=[memory_id], payloads=[metadata], @@ -367,9 +361,7 @@ class Memory(MemoryBase): return memory_id def _update_memory_tool(self, memory_id, data, metadata=None): - existing_memory = self.vector_store.get( - name=self.collection_name, vector_id=memory_id - ) + existing_memory = 
self.vector_store.get(vector_id=memory_id) prev_value = existing_memory.payload.get("data") new_metadata = metadata or {} @@ -387,7 +379,6 @@ class Memory(MemoryBase): embeddings = self.embedding_model.embed(data) self.vector_store.update( - name=self.collection_name, vector_id=memory_id, vector=embeddings, payload=new_metadata, @@ -397,18 +388,16 @@ class Memory(MemoryBase): def _delete_memory_tool(self, memory_id): logging.info(f"Deleting memory with {memory_id=}") - existing_memory = self.vector_store.get( - name=self.collection_name, vector_id=memory_id - ) + existing_memory = self.vector_store.get(vector_id=memory_id) prev_value = existing_memory.payload["data"] - self.vector_store.delete(name=self.collection_name, vector_id=memory_id) + self.vector_store.delete(vector_id=memory_id) self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1) def reset(self): """ Reset the memory store. """ - self.vector_store.delete_col(name=self.collection_name) + self.vector_store.delete_col() self.db.reset() capture_event("mem0.reset", self) diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index 1569c1fa..d3397567 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -52,6 +52,7 @@ class VectorStoreFactory: provider_to_class = { "qdrant": "mem0.vector_stores.qdrant.Qdrant", "chroma": "mem0.vector_stores.chroma.ChromaDB", + "pgvector": "mem0.vector_stores.pgvector.PGVector" } @classmethod diff --git a/mem0/vector_stores/chroma.py b/mem0/vector_stores/chroma.py index 1399f08b..6a2402b8 100644 --- a/mem0/vector_stores/chroma.py +++ b/mem0/vector_stores/chroma.py @@ -54,6 +54,7 @@ class ChromaDB(VectorStoreBase): self.client = chromadb.Client(self.settings) + self.collection_name = collection_name self.collection = self.create_col(collection_name) def _parse_output(self, data): @@ -109,12 +110,11 @@ class ChromaDB(VectorStoreBase): ) return collection - def insert(self, name, vectors, payloads=None, ids=None): + def insert(self, vectors, 
payloads=None, ids=None): """ Insert vectors into a collection. Args: - name (str): Name of the collection. vectors (list): List of vectors to insert. payloads (list, optional): List of payloads corresponding to vectors. Defaults to None. ids (list, optional): List of IDs corresponding to vectors. Defaults to None. @@ -122,12 +122,11 @@ class ChromaDB(VectorStoreBase): self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads) - def search(self, name, query, limit=5, filters=None): + def search(self, query, limit=5, filters=None): """ Search for similar vectors. Args: - name (str): Name of the collection. query (list): Query vector. limit (int, optional): Number of results to return. Defaults to 5. filters (dict, optional): Filters to apply to the search. Defaults to None. @@ -139,23 +138,21 @@ class ChromaDB(VectorStoreBase): final_results = self._parse_output(results) return final_results - def delete(self, name, vector_id): + def delete(self, vector_id): """ Delete a vector by ID. Args: - name (str): Name of the collection. vector_id (int): ID of the vector to delete. """ self.collection.delete(ids=vector_id) - def update(self, name, vector_id, vector=None, payload=None): + def update(self, vector_id, vector=None, payload=None): """ Update a vector and its payload. Args: - name (str): Name of the collection. vector_id (int): ID of the vector to update. vector (list, optional): Updated vector. Defaults to None. payload (dict, optional): Updated payload. Defaults to None. @@ -163,12 +160,11 @@ class ChromaDB(VectorStoreBase): self.collection.update(ids=vector_id, embeddings=vector, metadatas=payload) - def get(self, name, vector_id): + def get(self, vector_id): """ Retrieve a vector by ID. Args: - name (str): Name of the collection. vector_id (int): ID of the vector to retrieve. Returns: @@ -186,33 +182,24 @@ class ChromaDB(VectorStoreBase): """ return self.client.list_collections() - def delete_col(self, name): - """ - Delete a collection. 
+ def delete_col(self): + """ Delete a collection. """ + self.client.delete_collection(name=self.collection_name) - Args: - name (str): Name of the collection to delete. - """ - self.client.delete_collection(name=name) - - def col_info(self, name): + def col_info(self): """ Get information about a collection. - Args: - name (str): Name of the collection. - Returns: dict: Collection information. """ - return self.client.get_collection(name=name) + return self.client.get_collection(name=self.collection_name) - def list(self, name, filters=None, limit=100): + def list(self, filters=None, limit=100): """ List all vectors in a collection. Args: - name (str): Name of the collection. filters (dict, optional): Filters to apply to the list. limit (int, optional): Number of vectors to return. Defaults to 100. diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py index 28e4e1c4..6308afb3 100644 --- a/mem0/vector_stores/configs.py +++ b/mem0/vector_stores/configs.py @@ -13,7 +13,8 @@ class VectorStoreConfig(BaseModel): _provider_configs: Dict[str, str] = { "qdrant": "QdrantConfig", - "chroma": "ChromaDbConfig" + "chroma": "ChromaDbConfig", + "pgvector": "PGVectorConfig" } @model_validator(mode="after") diff --git a/mem0/vector_stores/pgvector.py b/mem0/vector_stores/pgvector.py new file mode 100644 index 00000000..66ffa926 --- /dev/null +++ b/mem0/vector_stores/pgvector.py @@ -0,0 +1,241 @@ +import json +from typing import Optional, List, Dict, Any +from pydantic import BaseModel + +try: + import psycopg2 + from psycopg2.extras import execute_values +except ImportError: + raise ImportError("PGVector requires extra dependencies. 
class OutputData(BaseModel):
    """Normalized result row returned by the PGVector store."""

    id: Optional[str]       # row UUID rendered as a string
    score: Optional[float]  # distance for search results; None for get/list
    payload: Optional[dict] # JSONB payload (memory data + metadata)


class PGVector(VectorStoreBase):
    """Vector store backed by PostgreSQL with the pgvector extension.

    NOTE(review): `collection_name` is interpolated into SQL as an identifier
    (table names cannot be bound as query parameters). It originates from
    validated config, not user input, but keep it that way.
    """

    def __init__(
        self,
        dbname,
        collection_name,
        embedding_model_dims,
        user,
        password,
        host,
        port,
    ):
        """
        Initialize the PGVector database.

        Args:
            dbname (str): Database name.
            collection_name (str): Collection name (used as the table name).
            embedding_model_dims (int): Dimension of the embedding vector.
            user (str): Database user.
            password (str): Database password.
            host (str, optional): Database host.
            port (int, optional): Database port.
        """
        self.collection_name = collection_name

        self.conn = psycopg2.connect(
            dbname=dbname,
            user=user,
            password=password,
            host=host,
            port=port,
        )
        self.cur = self.conn.cursor()

        # Create the backing table only on first use.
        if collection_name not in self.list_cols():
            self.create_col(embedding_model_dims)

    def create_col(self, embedding_model_dims):
        """
        Create a new collection (table in PostgreSQL).

        Requires the `vector` extension to already be installed
        (`CREATE EXTENSION IF NOT EXISTS vector;`).

        Args:
            embedding_model_dims (int): Dimension of the embedding vector.
        """
        self.cur.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.collection_name} (
                id UUID PRIMARY KEY,
                vector vector({embedding_model_dims}),
                payload JSONB
            );
            """
        )
        self.conn.commit()

    def insert(self, vectors, payloads=None, ids=None):
        """
        Insert vectors into a collection.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): Payloads per vector. Defaults to
                empty dicts. (The original crashed with TypeError when omitted,
                despite being documented as optional.)
            ids (List[str], optional): IDs per vector. Random UUIDs are
                generated when omitted.
        """
        import uuid  # local import: keeps the module's dependency surface unchanged

        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]
        if payloads is None:
            payloads = [{} for _ in vectors]

        json_payloads = [json.dumps(payload) for payload in payloads]
        data = list(zip(ids, vectors, json_payloads))
        execute_values(
            self.cur,
            f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s",
            data,
        )
        self.conn.commit()

    def search(self, query, limit=5, filters=None):
        """
        Search for similar vectors using the pgvector `<->` (L2) operator.

        Args:
            query (List[float]): Query vector. psycopg2 adapts the list to a
                Postgres array, which the `::vector` cast converts — assumes
                the pgvector array casts are available (pgvector >= 0.4).
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Equality filters applied to the JSONB
                payload. Defaults to None.

        Returns:
            List[OutputData]: Results ordered by ascending distance.
        """
        filter_conditions = []
        filter_params = []

        if filters:
            for k, v in filters.items():
                filter_conditions.append("payload->>%s = %s")
                filter_params.extend([k, str(v)])

        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""

        # Parameter order must mirror placeholder order: query vector first,
        # then filter key/value pairs, then the LIMIT.
        self.cur.execute(
            f"""
            SELECT id, vector <-> %s::vector AS distance, payload
            FROM {self.collection_name}
            {filter_clause}
            ORDER BY distance
            LIMIT %s
            """,
            (query, *filter_params, limit),
        )

        results = self.cur.fetchall()
        return [OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results]

    def delete(self, vector_id):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self.cur.execute(f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,))
        self.conn.commit()

    def update(self, vector_id, vector=None, payload=None):
        """
        Update a vector and/or its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        if vector:
            self.cur.execute(
                f"UPDATE {self.collection_name} SET vector = %s WHERE id = %s",
                (vector, vector_id),
            )
        if payload:
            self.cur.execute(
                f"UPDATE {self.collection_name} SET payload = %s WHERE id = %s",
                (psycopg2.extras.Json(payload), vector_id),
            )
        self.conn.commit()

    def get(self, vector_id) -> OutputData:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved row, or None if no row matches.
        """
        self.cur.execute(
            f"SELECT id, vector, payload FROM {self.collection_name} WHERE id = %s",
            (vector_id,),
        )
        result = self.cur.fetchone()
        if not result:
            return None
        return OutputData(id=str(result[0]), score=None, payload=result[2])

    def list_cols(self) -> List[str]:
        """
        List all collections (public tables).

        Returns:
            List[str]: List of table names in the public schema.
        """
        self.cur.execute(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
        )
        return [row[0] for row in self.cur.fetchall()]

    def delete_col(self):
        """Delete the collection (drop its table)."""
        self.cur.execute(f"DROP TABLE IF EXISTS {self.collection_name}")
        self.conn.commit()

    def col_info(self):
        """
        Get information about the collection.

        Returns:
            Dict[str, Any]: Name, row count, and pretty-printed total size.
        """
        self.cur.execute(
            f"""
            SELECT
                table_name,
                (SELECT COUNT(*) FROM {self.collection_name}) as row_count,
                (SELECT pg_size_pretty(pg_total_relation_size('{self.collection_name}'))) as total_size
            FROM information_schema.tables
            WHERE table_schema = 'public' AND table_name = %s
            """,
            (self.collection_name,),
        )
        result = self.cur.fetchone()
        return {"name": result[0], "count": result[1], "size": result[2]}

    def list(self, filters=None, limit=100):
        """
        List vectors in the collection.

        Args:
            filters (Dict, optional): Equality filters on the JSONB payload.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[List[OutputData]]: Single-element outer list, matching the
                caller convention `vector_store.list(...)[0]`.
        """
        filter_conditions = []
        filter_params = []

        if filters:
            for k, v in filters.items():
                filter_conditions.append("payload->>%s = %s")
                filter_params.extend([k, str(v)])

        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""

        query = f"""
            SELECT id, vector, payload
            FROM {self.collection_name}
            {filter_clause}
            LIMIT %s
        """

        self.cur.execute(query, (*filter_params, limit))

        results = self.cur.fetchall()
        return [[OutputData(id=str(r[0]), score=None, payload=r[2]) for r in results]]

    def __del__(self):
        """Best-effort close of cursor/connection when the object is collected.

        Swallows errors: __del__ may run during interpreter shutdown or after
        the connection has already been closed, and must never raise.
        """
        try:
            if hasattr(self, "cur"):
                self.cur.close()
            if hasattr(self, "conn"):
                self.conn.close()
        except Exception:
            pass
Skipping creation.") return self.client.create_collection( - collection_name=name, + collection_name=self.collection_name, vectors_config=VectorParams(size=vector_size, distance=distance, on_disk=on_disk), ) - def insert(self, name, vectors, payloads=None, ids=None): + def insert(self, vectors, payloads=None, ids=None): """ Insert vectors into a collection. Args: - name (str): Name of the collection. vectors (list): List of vectors to insert. payloads (list, optional): List of payloads corresponding to vectors. Defaults to None. ids (list, optional): List of IDs corresponding to vectors. Defaults to None. @@ -102,7 +102,7 @@ class Qdrant(VectorStoreBase): ) for idx, vector in enumerate(vectors) ] - self.client.upsert(collection_name=name, points=points) + self.client.upsert(collection_name=self.collection_name, points=points) def _create_filter(self, filters): """ @@ -128,12 +128,11 @@ class Qdrant(VectorStoreBase): ) return Filter(must=conditions) if conditions else None - def search(self, name, query, limit=5, filters=None): + def search(self, query, limit=5, filters=None): """ Search for similar vectors. Args: - name (str): Name of the collection. query (list): Query vector. limit (int, optional): Number of results to return. Defaults to 5. filters (dict, optional): Filters to apply to the search. Defaults to None. @@ -143,54 +142,51 @@ class Qdrant(VectorStoreBase): """ query_filter = self._create_filter(filters) if filters else None hits = self.client.search( - collection_name=name, + collection_name=self.collection_name, query_vector=query, query_filter=query_filter, limit=limit, ) return hits - def delete(self, name, vector_id): + def delete(self, vector_id): """ Delete a vector by ID. Args: - name (str): Name of the collection. vector_id (int): ID of the vector to delete. 
""" self.client.delete( - collection_name=name, + collection_name=self.collection_name, points_selector=PointIdsList( points=[vector_id], ), ) - def update(self, name, vector_id, vector=None, payload=None): + def update(self, vector_id, vector=None, payload=None): """ Update a vector and its payload. Args: - name (str): Name of the collection. vector_id (int): ID of the vector to update. vector (list, optional): Updated vector. Defaults to None. payload (dict, optional): Updated payload. Defaults to None. """ point = PointStruct(id=vector_id, vector=vector, payload=payload) - self.client.upsert(collection_name=name, points=[point]) + self.client.upsert(collection_name=self.collection_name, points=[point]) - def get(self, name, vector_id): + def get(self, vector_id): """ Retrieve a vector by ID. Args: - name (str): Name of the collection. vector_id (int): ID of the vector to retrieve. Returns: dict: Retrieved vector. """ result = self.client.retrieve( - collection_name=name, ids=[vector_id], with_payload=True + collection_name=self.collection_name, ids=[vector_id], with_payload=True ) return result[0] if result else None @@ -203,33 +199,24 @@ class Qdrant(VectorStoreBase): """ return self.client.get_collections() - def delete_col(self, name): - """ - Delete a collection. + def delete_col(self): + """ Delete a collection. """ + self.client.delete_collection(collection_name=self.collection_name) - Args: - name (str): Name of the collection to delete. - """ - self.client.delete_collection(collection_name=name) - - def col_info(self, name): + def col_info(self): """ Get information about a collection. - Args: - name (str): Name of the collection. - Returns: dict: Collection information. """ - return self.client.get_collection(collection_name=name) + return self.client.get_collection(collection_name=self.collection_name) - def list(self, name, filters=None, limit=100): + def list(self, filters=None, limit=100): """ List all vectors in a collection. 
Args: - name (str): Name of the collection. limit (int, optional): Number of vectors to return. Defaults to 100. Returns: @@ -237,7 +224,7 @@ class Qdrant(VectorStoreBase): """ query_filter = self._create_filter(filters) if filters else None result = self.client.scroll( - collection_name=name, + collection_name=self.collection_name, scroll_filter=query_filter, limit=limit, with_payload=True,