From 3613e2f14a02f20a556197f7072399c73f30f057 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Wed, 16 Apr 2025 13:32:33 +0530 Subject: [PATCH] Fix `user_id` functionality (#2548) --- mem0/client/main.py | 4 ++ mem0/configs/base.py | 5 +- mem0/configs/vector_stores/azure_ai_search.py | 2 +- mem0/configs/vector_stores/langchain.py | 4 +- mem0/memory/graph_memory.py | 4 +- mem0/memory/main.py | 13 ++++- mem0/memory/setup.py | 25 ++++++++++ mem0/memory/telemetry.py | 30 +++++++----- mem0/vector_stores/langchain.py | 48 ++++++++----------- 9 files changed, 86 insertions(+), 49 deletions(-) diff --git a/mem0/client/main.py b/mem0/client/main.py index 981d1ba5..fde39237 100644 --- a/mem0/client/main.py +++ b/mem0/client/main.py @@ -1,6 +1,7 @@ import logging import os import warnings +import hashlib from functools import wraps from typing import Any, Dict, List, Optional, Union @@ -83,6 +84,9 @@ class MemoryClient: if not self.api_key: raise ValueError("Mem0 API Key not provided. Please provide an API Key.") + # Create MD5 hash of API key for user_id + self.user_id = hashlib.md5(self.api_key.encode()).hexdigest() + self.client = httpx.Client( base_url=self.host, headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id}, diff --git a/mem0/configs/base.py b/mem0/configs/base.py index f572df85..147d2593 100644 --- a/mem0/configs/base.py +++ b/mem0/configs/base.py @@ -6,9 +6,12 @@ from pydantic import BaseModel, Field from mem0.embeddings.configs import EmbedderConfig from mem0.graphs.configs import GraphStoreConfig from mem0.llms.configs import LlmConfig -from mem0.memory.setup import mem0_dir from mem0.vector_stores.configs import VectorStoreConfig +# Set up the directory path +home_dir = os.path.expanduser("~") +mem0_dir = os.environ.get("MEM0_DIR") or os.path.join(home_dir, ".mem0") + class MemoryItem(BaseModel): id: str = Field(..., description="The unique identifier for the text data") diff --git a/mem0/configs/vector_stores/azure_ai_search.py 
b/mem0/configs/vector_stores/azure_ai_search.py index 8618ac94..79cfe179 100644 --- a/mem0/configs/vector_stores/azure_ai_search.py +++ b/mem0/configs/vector_stores/azure_ai_search.py @@ -7,7 +7,7 @@ class AzureAISearchConfig(BaseModel): collection_name: str = Field("mem0", description="Name of the collection") service_name: str = Field(None, description="Azure AI Search service name") api_key: str = Field(None, description="API key for the Azure AI Search service") - embedding_model_dims: int = Field(None, description="Dimension of the embedding vector") + embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector") compression_type: Optional[str] = Field( None, description="Type of vector compression to use. Options: 'scalar', 'binary', or None" ) diff --git a/mem0/configs/vector_stores/langchain.py b/mem0/configs/vector_stores/langchain.py index 78b3533e..c4178406 100644 --- a/mem0/configs/vector_stores/langchain.py +++ b/mem0/configs/vector_stores/langchain.py @@ -7,7 +7,9 @@ class LangchainConfig(BaseModel): try: from langchain_community.vectorstores import VectorStore except ImportError: - raise ImportError("The 'langchain_community' library is required. Please install it using 'pip install langchain_community'.") + raise ImportError( + "The 'langchain_community' library is required. Please install it using 'pip install langchain_community'." 
+ ) VectorStore: ClassVar[type] = VectorStore client: VectorStore = Field(description="Existing VectorStore instance") diff --git a/mem0/memory/graph_memory.py b/mem0/memory/graph_memory.py index ca565908..18aaf41b 100644 --- a/mem0/memory/graph_memory.py +++ b/mem0/memory/graph_memory.py @@ -35,9 +35,7 @@ class MemoryGraph: self.config.graph_store.config.password, ) self.embedding_model = EmbedderFactory.create( - self.config.embedder.provider, - self.config.embedder.config, - self.config.vector_store.config + self.config.embedder.provider, self.config.embedder.config, self.config.vector_store.config ) self.llm_provider = "openai_structured" diff --git a/mem0/memory/main.py b/mem0/memory/main.py index 7fd3ebde..612ef371 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -1,3 +1,4 @@ +import os import asyncio import concurrent import hashlib @@ -18,7 +19,7 @@ from mem0.configs.prompts import ( get_update_memory_messages, ) from mem0.memory.base import MemoryBase -from mem0.memory.setup import setup_config +from mem0.memory.setup import setup_config, mem0_dir from mem0.memory.storage import SQLiteManager from mem0.memory.telemetry import capture_event from mem0.memory.utils import ( @@ -62,6 +63,16 @@ class Memory(MemoryBase): self.graph = MemoryGraph(self.config) self.enable_graph = True + self.config.vector_store.config.collection_name = "mem0_migrations" + if self.config.vector_store.provider in ["faiss", "qdrant"]: + provider_path = f"migrations_{self.config.vector_store.provider}" + self.config.vector_store.config.path = os.path.join(mem0_dir, provider_path) + os.makedirs(self.config.vector_store.config.path, exist_ok=True) + + self._telemetry_vector_store = VectorStoreFactory.create( + self.config.vector_store.provider, self.config.vector_store.config + ) + capture_event("mem0.init", self, {"sync_type": "sync"}) @classmethod diff --git a/mem0/memory/setup.py b/mem0/memory/setup.py index a22b2e13..b4fa99ae 100644 --- a/mem0/memory/setup.py +++ 
b/mem0/memory/setup.py @@ -3,6 +3,7 @@ import os import uuid # Set up the directory path +VECTOR_ID = str(uuid.uuid4()) home_dir = os.path.expanduser("~") mem0_dir = os.environ.get("MEM0_DIR") or os.path.join(home_dir, ".mem0") os.makedirs(mem0_dir, exist_ok=True) @@ -29,3 +30,27 @@ def get_user_id(): return user_id except Exception: return "anonymous_user" + + +def get_or_create_user_id(vector_store): + """Store user_id in vector store and return it.""" + user_id = get_user_id() + + # Try to get existing user_id from vector store + try: + existing = vector_store.get(vector_id=VECTOR_ID) + if existing and hasattr(existing, "payload") and existing.payload and "user_id" in existing.payload: + return existing.payload["user_id"] + except Exception: + pass + + # If we get here, we need to insert the user_id + try: + dims = getattr(vector_store, "embedding_model_dims", 1) + vector_store.insert( + vectors=[[0.0] * dims], payloads=[{"user_id": user_id, "type": "user_identity"}], ids=[VECTOR_ID] + ) + except Exception: + pass + + return user_id diff --git a/mem0/memory/telemetry.py b/mem0/memory/telemetry.py index 4efa2423..d4ad1840 100644 --- a/mem0/memory/telemetry.py +++ b/mem0/memory/telemetry.py @@ -6,9 +6,11 @@ import sys from posthog import Posthog import mem0 -from mem0.memory.setup import get_user_id, setup_config +from mem0.memory.setup import get_or_create_user_id MEM0_TELEMETRY = os.environ.get("MEM0_TELEMETRY", "True") +PROJECT_API_KEY = "phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX" +HOST = "https://us.i.posthog.com" if isinstance(MEM0_TELEMETRY, str): MEM0_TELEMETRY = MEM0_TELEMETRY.lower() in ("true", "1", "yes") @@ -21,11 +23,11 @@ logging.getLogger("urllib3").setLevel(logging.CRITICAL + 1) class AnonymousTelemetry: - def __init__(self, project_api_key, host): - self.posthog = Posthog(project_api_key=project_api_key, host=host) - # Call setup config to ensure that the user_id is generated - setup_config() - self.user_id = get_user_id() + def __init__(self, vector_store=None): +
self.posthog = Posthog(project_api_key=PROJECT_API_KEY, host=HOST) + + self.user_id = get_or_create_user_id(vector_store) + if not MEM0_TELEMETRY: self.posthog.disabled = True @@ -50,14 +52,16 @@ class AnonymousTelemetry: self.posthog.shutdown() -# Initialize AnonymousTelemetry -telemetry = AnonymousTelemetry( - project_api_key="phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX", - host="https://us.i.posthog.com", -) +client_telemetry = AnonymousTelemetry() def capture_event(event_name, memory_instance, additional_data=None): + oss_telemetry = AnonymousTelemetry( + vector_store=memory_instance._telemetry_vector_store + if hasattr(memory_instance, "_telemetry_vector_store") + else None, + ) + event_data = { "collection": memory_instance.collection_name, "vector_size": memory_instance.embedding_model.config.embedding_dims, @@ -73,7 +77,7 @@ def capture_event(event_name, memory_instance, additional_data=None): if additional_data: event_data.update(additional_data) - telemetry.capture_event(event_name, event_data) + oss_telemetry.capture_event(event_name, event_data) def capture_client_event(event_name, instance, additional_data=None): @@ -83,4 +87,4 @@ def capture_client_event(event_name, instance, additional_data=None): if additional_data: event_data.update(additional_data) - telemetry.capture_event(event_name, event_data, instance.user_email) + client_telemetry.capture_event(event_name, event_data, instance.user_email) diff --git a/mem0/vector_stores/langchain.py b/mem0/vector_stores/langchain.py index 888ed5ae..ad9a991d 100644 --- a/mem0/vector_stores/langchain.py +++ b/mem0/vector_stores/langchain.py @@ -5,7 +5,9 @@ from pydantic import BaseModel try: from langchain_community.vectorstores import VectorStore except ImportError: - raise ImportError("The 'langchain_community' library is required. Please install it using 'pip install langchain_community'.") + raise ImportError( + "The 'langchain_community' library is required. 
Please install it using 'pip install langchain_community'." + ) from mem0.vector_stores.base import VectorStoreBase @@ -15,11 +17,12 @@ class OutputData(BaseModel): score: Optional[float] # distance payload: Optional[Dict] # metadata + class Langchain(VectorStoreBase): def __init__(self, client: VectorStore, collection_name: str = "mem0"): self.client = client self.collection_name = collection_name - + def _parse_output(self, data: Dict) -> List[OutputData]: """ Parse the output data. @@ -31,17 +34,17 @@ class Langchain(VectorStoreBase): List[OutputData]: Parsed output data. """ # Check if input is a list of Document objects - if isinstance(data, list) and all(hasattr(doc, 'metadata') for doc in data if hasattr(doc, '__dict__')): + if isinstance(data, list) and all(hasattr(doc, "metadata") for doc in data if hasattr(doc, "__dict__")): result = [] for doc in data: entry = OutputData( id=getattr(doc, "id", None), score=None, # Document objects typically don't include scores - payload=getattr(doc, "metadata", {}) + payload=getattr(doc, "metadata", {}), ) result.append(entry) return result - + # Original format handling keys = ["ids", "distances", "metadatas"] values = [] @@ -70,26 +73,20 @@ class Langchain(VectorStoreBase): self.collection_name = name return self.client - def insert(self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None): + def insert( + self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None + ): """ Insert vectors into the LangChain vectorstore. 
""" # Check if client has add_embeddings method if hasattr(self.client, "add_embeddings"): # Some LangChain vectorstores have a direct add_embeddings method - self.client.add_embeddings( - embeddings=vectors, - metadatas=payloads, - ids=ids - ) + self.client.add_embeddings(embeddings=vectors, metadatas=payloads, ids=ids) else: # Fallback to add_texts method texts = [payload.get("data", "") for payload in payloads] if payloads else [""] * len(vectors) - self.client.add_texts( - texts=texts, - metadatas=payloads, - ids=ids - ) + self.client.add_texts(texts=texts, metadatas=payloads, ids=ids) def search(self, query: str, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None): """ @@ -97,16 +94,9 @@ class Langchain(VectorStoreBase): """ # For each vector, perform a similarity search if filters: - results = self.client.similarity_search_by_vector( - embedding=vectors, - k=limit, - filter=filters - ) + results = self.client.similarity_search_by_vector(embedding=vectors, k=limit, filter=filters) else: - results = self.client.similarity_search_by_vector( - embedding=vectors, - k=limit - ) + results = self.client.similarity_search_by_vector(embedding=vectors, k=limit) final_results = self._parse_output(results) return final_results @@ -133,26 +123,26 @@ class Langchain(VectorStoreBase): doc = docs[0] return self._parse_output([doc])[0] return None - + def list_cols(self): """ List all collections. """ # LangChain doesn't have collections return [self.collection_name] - + def delete_col(self): """ Delete a collection. """ self.client.delete(ids=None) - + def col_info(self): """ Get information about a collection. """ return {"name": self.collection_name} - + def list(self, filters=None, limit=None): """ List all vectors in a collection.