Store user_id in vectordb (#2466)

This commit is contained in:
Dev Khant
2025-04-11 13:37:34 +05:30
committed by GitHub
parent 19d7beef43
commit 15a3e20371
6 changed files with 57 additions and 20 deletions

View File

@@ -5,17 +5,14 @@ from functools import wraps
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import httpx import httpx
import hashlib
from mem0.memory.setup import get_user_id, setup_config
from mem0.memory.telemetry import capture_client_event from mem0.memory.telemetry import capture_client_event
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
warnings.filterwarnings("default", category=DeprecationWarning) warnings.filterwarnings("default", category=DeprecationWarning)
# Setup user config
setup_config()
class APIError(Exception): class APIError(Exception):
"""Exception raised for errors in the API.""" """Exception raised for errors in the API."""
@@ -78,11 +75,13 @@ class MemoryClient:
self.host = host or "https://api.mem0.ai" self.host = host or "https://api.mem0.ai"
self.org_id = org_id self.org_id = org_id
self.project_id = project_id self.project_id = project_id
self.user_id = get_user_id()
if not self.api_key: if not self.api_key:
raise ValueError("Mem0 API Key not provided. Please provide an API Key.") raise ValueError("Mem0 API Key not provided. Please provide an API Key.")
# Create MD5 hash of API key for user_id
self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
self.client = httpx.Client( self.client = httpx.Client(
base_url=self.host, base_url=self.host,
headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id}, headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id},

View File

@@ -6,9 +6,12 @@ from pydantic import BaseModel, Field
from mem0.embeddings.configs import EmbedderConfig from mem0.embeddings.configs import EmbedderConfig
from mem0.graphs.configs import GraphStoreConfig from mem0.graphs.configs import GraphStoreConfig
from mem0.llms.configs import LlmConfig from mem0.llms.configs import LlmConfig
from mem0.memory.setup import mem0_dir
from mem0.vector_stores.configs import VectorStoreConfig from mem0.vector_stores.configs import VectorStoreConfig
# Set up the directory path
home_dir = os.path.expanduser("~")
mem0_dir = os.environ.get("MEM0_DIR") or os.path.join(home_dir, ".mem0")
class MemoryItem(BaseModel): class MemoryItem(BaseModel):
id: str = Field(..., description="The unique identifier for the text data") id: str = Field(..., description="The unique identifier for the text data")

View File

@@ -7,7 +7,7 @@ class AzureAISearchConfig(BaseModel):
collection_name: str = Field("mem0", description="Name of the collection") collection_name: str = Field("mem0", description="Name of the collection")
service_name: str = Field(None, description="Azure AI Search service name") service_name: str = Field(None, description="Azure AI Search service name")
api_key: str = Field(None, description="API key for the Azure AI Search service") api_key: str = Field(None, description="API key for the Azure AI Search service")
embedding_model_dims: int = Field(None, description="Dimension of the embedding vector") embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
compression_type: Optional[str] = Field( compression_type: Optional[str] = Field(
None, description="Type of vector compression to use. Options: 'scalar', 'binary', or None" None, description="Type of vector compression to use. Options: 'scalar', 'binary', or None"
) )

View File

@@ -3,6 +3,7 @@ import concurrent
import hashlib import hashlib
import json import json
import logging import logging
import os
import uuid import uuid
import warnings import warnings
from datetime import datetime from datetime import datetime
@@ -18,7 +19,7 @@ from mem0.configs.prompts import (
get_update_memory_messages, get_update_memory_messages,
) )
from mem0.memory.base import MemoryBase from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config from mem0.memory.setup import mem0_dir, setup_config
from mem0.memory.storage import SQLiteManager from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event from mem0.memory.telemetry import capture_event
from mem0.memory.utils import ( from mem0.memory.utils import (
@@ -62,6 +63,15 @@ class Memory(MemoryBase):
self.graph = MemoryGraph(self.config) self.graph = MemoryGraph(self.config)
self.enable_graph = True self.enable_graph = True
self.config.vector_store.config.collection_name = "mem0_migrations"
if self.config.vector_store.provider in ["faiss", "qdrant"]:
provider_path = f"migrations_{self.config.vector_store.provider}"
self.config.vector_store.config.path = os.path.join(mem0_dir, provider_path)
os.makedirs(self.config.vector_store.config.path, exist_ok=True)
self._telemetry_vector_store = VectorStoreFactory.create(
self.config.vector_store.provider, self.config.vector_store.config
)
capture_event("mem0.init", self) capture_event("mem0.init", self)
@classmethod @classmethod

View File

@@ -3,6 +3,7 @@ import os
import uuid import uuid
# Set up the directory path # Set up the directory path
VECTOR_ID = str(uuid.uuid4())
home_dir = os.path.expanduser("~") home_dir = os.path.expanduser("~")
mem0_dir = os.environ.get("MEM0_DIR") or os.path.join(home_dir, ".mem0") mem0_dir = os.environ.get("MEM0_DIR") or os.path.join(home_dir, ".mem0")
os.makedirs(mem0_dir, exist_ok=True) os.makedirs(mem0_dir, exist_ok=True)
@@ -29,3 +30,29 @@ def get_user_id():
return user_id return user_id
except Exception: except Exception:
return "anonymous_user" return "anonymous_user"
def get_or_create_user_id(vector_store):
"""Store user_id in vector store and return it."""
user_id = get_user_id()
# Try to get existing user_id from vector store
try:
existing = vector_store.get(vector_id=VECTOR_ID)
if existing and hasattr(existing, "payload") and existing.payload and "user_id" in existing.payload:
return existing.payload["user_id"]
except:
pass
# If we get here, we need to insert the user_id
try:
dims = getattr(vector_store, "embedding_model_dims", 1)
vector_store.insert(
vectors=[[0.0] * dims],
payloads=[{"user_id": user_id, "type": "user_identity"}],
ids=[VECTOR_ID]
)
except:
pass
return user_id

View File

@@ -6,7 +6,7 @@ import sys
from posthog import Posthog from posthog import Posthog
import mem0 import mem0
from mem0.memory.setup import get_user_id, setup_config from mem0.memory.setup import get_or_create_user_id
MEM0_TELEMETRY = os.environ.get("MEM0_TELEMETRY", "True") MEM0_TELEMETRY = os.environ.get("MEM0_TELEMETRY", "True")
@@ -21,11 +21,11 @@ logging.getLogger("urllib3").setLevel(logging.CRITICAL + 1)
class AnonymousTelemetry: class AnonymousTelemetry:
def __init__(self, project_api_key, host): def __init__(self, project_api_key, host, vector_store=None):
self.posthog = Posthog(project_api_key=project_api_key, host=host) self.posthog = Posthog(project_api_key=project_api_key, host=host)
# Call setup config to ensure that the user_id is generated
setup_config() self.user_id = get_or_create_user_id(vector_store)
self.user_id = get_user_id()
if not MEM0_TELEMETRY: if not MEM0_TELEMETRY:
self.posthog.disabled = True self.posthog.disabled = True
@@ -50,14 +50,12 @@ class AnonymousTelemetry:
self.posthog.shutdown() self.posthog.shutdown()
# Initialize AnonymousTelemetry
telemetry = AnonymousTelemetry(
project_api_key="phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX",
host="https://us.i.posthog.com",
)
def capture_event(event_name, memory_instance, additional_data=None): def capture_event(event_name, memory_instance, additional_data=None):
global telemetry
# For OSS, we use the telemetry vector store to store the user_id
telemetry = AnonymousTelemetry(project_api_key="phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX", host="https://us.i.posthog.com", vector_store=memory_instance._telemetry_vector_store if hasattr(memory_instance, "_telemetry_vector_store") else None)
event_data = { event_data = {
"collection": memory_instance.collection_name, "collection": memory_instance.collection_name,
"vector_size": memory_instance.embedding_model.config.embedding_dims, "vector_size": memory_instance.embedding_model.config.embedding_dims,