Fix user_id functionality (#2548)

Dev Khant
2025-04-16 13:32:33 +05:30
committed by GitHub
parent 541030d69c
commit 3613e2f14a
9 changed files with 86 additions and 49 deletions

View File

@@ -1,6 +1,7 @@
 import logging
 import os
 import warnings
+import hashlib
 from functools import wraps
 from typing import Any, Dict, List, Optional, Union
@@ -83,6 +84,9 @@ class MemoryClient:
         if not self.api_key:
             raise ValueError("Mem0 API Key not provided. Please provide an API Key.")

+        # Create MD5 hash of API key for user_id
+        self.user_id = hashlib.md5(self.api_key.encode()).hexdigest()
+
         self.client = httpx.Client(
             base_url=self.host,
             headers={"Authorization": f"Token {self.api_key}", "Mem0-User-ID": self.user_id},

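Note (not part of the diff): the client now derives a stable anonymous identifier from the API key, so every MemoryClient built with the same key sends the same Mem0-User-ID header. A minimal sketch of the derivation, with a made-up key:

import hashlib

api_key = "m0-example-key"  # hypothetical value, for illustration only
user_id = hashlib.md5(api_key.encode()).hexdigest()
print(user_id)  # 32-character hex digest, identical across client instances using this key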
View File

@@ -6,9 +6,12 @@ from pydantic import BaseModel, Field
 from mem0.embeddings.configs import EmbedderConfig
 from mem0.graphs.configs import GraphStoreConfig
 from mem0.llms.configs import LlmConfig
-from mem0.memory.setup import mem0_dir
 from mem0.vector_stores.configs import VectorStoreConfig

+# Set up the directory path
+home_dir = os.path.expanduser("~")
+mem0_dir = os.environ.get("MEM0_DIR") or os.path.join(home_dir, ".mem0")

 class MemoryItem(BaseModel):
     id: str = Field(..., description="The unique identifier for the text data")

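For context on the lines added above: the directory resolution prefers an explicit MEM0_DIR environment variable and otherwise falls back to ~/.mem0. A small sketch of the override behaviour (the path is made up):

import os

os.environ["MEM0_DIR"] = "/tmp/mem0-demo"  # hypothetical override
mem0_dir = os.environ.get("MEM0_DIR") or os.path.join(os.path.expanduser("~"), ".mem0")
assert mem0_dir == "/tmp/mem0-demo"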
View File

@@ -7,7 +7,7 @@ class AzureAISearchConfig(BaseModel):
     collection_name: str = Field("mem0", description="Name of the collection")
     service_name: str = Field(None, description="Azure AI Search service name")
     api_key: str = Field(None, description="API key for the Azure AI Search service")
-    embedding_model_dims: int = Field(None, description="Dimension of the embedding vector")
+    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
     compression_type: Optional[str] = Field(
         None, description="Type of vector compression to use. Options: 'scalar', 'binary', or None"
     )

View File

@@ -7,7 +7,9 @@ class LangchainConfig(BaseModel):
     try:
         from langchain_community.vectorstores import VectorStore
     except ImportError:
-        raise ImportError("The 'langchain_community' library is required. Please install it using 'pip install langchain_community'.")
+        raise ImportError(
+            "The 'langchain_community' library is required. Please install it using 'pip install langchain_community'."
+        )

     VectorStore: ClassVar[type] = VectorStore
     client: VectorStore = Field(description="Existing VectorStore instance")

View File

@@ -35,9 +35,7 @@ class MemoryGraph:
             self.config.graph_store.config.password,
         )
         self.embedding_model = EmbedderFactory.create(
-            self.config.embedder.provider,
-            self.config.embedder.config,
-            self.config.vector_store.config
+            self.config.embedder.provider, self.config.embedder.config, self.config.vector_store.config
         )
         self.llm_provider = "openai_structured"

View File

@@ -1,3 +1,4 @@
+import os
 import asyncio
 import concurrent
 import hashlib
@@ -18,7 +19,7 @@ from mem0.configs.prompts import (
     get_update_memory_messages,
 )
 from mem0.memory.base import MemoryBase
-from mem0.memory.setup import setup_config
+from mem0.memory.setup import setup_config, mem0_dir
 from mem0.memory.storage import SQLiteManager
 from mem0.memory.telemetry import capture_event
 from mem0.memory.utils import (
@@ -62,6 +63,16 @@ class Memory(MemoryBase):
             self.graph = MemoryGraph(self.config)
             self.enable_graph = True

+        self.config.vector_store.config.collection_name = "mem0_migrations"
+        if self.config.vector_store.provider in ["faiss", "qdrant"]:
+            provider_path = f"migrations_{self.config.vector_store.provider}"
+            self.config.vector_store.config.path = os.path.join(mem0_dir, provider_path)
+            os.makedirs(self.config.vector_store.config.path, exist_ok=True)
+
+        self._telemetry_vector_store = VectorStoreFactory.create(
+            self.config.vector_store.provider, self.config.vector_store.config
+        )
         capture_event("mem0.init", self, {"sync_type": "sync"})

     @classmethod

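Rough illustration (not part of the diff) of where the dedicated telemetry store ends up for the local providers handled above; other providers keep their configured path and only the collection name changes to "mem0_migrations":

import os

mem0_dir = os.environ.get("MEM0_DIR") or os.path.join(os.path.expanduser("~"), ".mem0")
provider = "qdrant"  # or "faiss"
print(os.path.join(mem0_dir, f"migrations_{provider}"))  # e.g. ~/.mem0/migrations_qdrant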
View File

@@ -3,6 +3,7 @@ import os
 import uuid

 # Set up the directory path
+VECTOR_ID = str(uuid.uuid4())
 home_dir = os.path.expanduser("~")
 mem0_dir = os.environ.get("MEM0_DIR") or os.path.join(home_dir, ".mem0")
 os.makedirs(mem0_dir, exist_ok=True)
@@ -29,3 +30,27 @@ def get_user_id():
         return user_id
     except Exception:
         return "anonymous_user"
+
+
+def get_or_create_user_id(vector_store):
+    """Store user_id in vector store and return it."""
+    user_id = get_user_id()
+
+    # Try to get existing user_id from vector store
+    try:
+        existing = vector_store.get(vector_id=VECTOR_ID)
+        if existing and hasattr(existing, "payload") and existing.payload and "user_id" in existing.payload:
+            return existing.payload["user_id"]
+    except:
+        pass
+
+    # If we get here, we need to insert the user_id
+    try:
+        dims = getattr(vector_store, "embedding_model_dims", 1)
+        vector_store.insert(
+            vectors=[[0.0] * dims], payloads=[{"user_id": user_id, "type": "user_identity"}], ids=[VECTOR_ID]
+        )
+    except:
+        pass
+
+    return user_id

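A minimal usage sketch of the new helper (not part of the diff): FakeStore and FakeRecord are hypothetical stand-ins that implement only the two calls get_or_create_user_id relies on, get(vector_id=...) and insert(vectors=..., payloads=..., ids=...):

from mem0.memory.setup import get_or_create_user_id

class FakeRecord:
    def __init__(self, payload):
        self.payload = payload

class FakeStore:
    embedding_model_dims = 8  # arbitrary; only used to size the placeholder vector

    def __init__(self):
        self.rows = {}

    def get(self, vector_id):
        return self.rows.get(vector_id)

    def insert(self, vectors, payloads, ids):
        for vid, payload in zip(ids, payloads):
            self.rows[vid] = FakeRecord(payload)

store = FakeStore()
first = get_or_create_user_id(store)   # generates a user_id and persists it under VECTOR_ID
second = get_or_create_user_id(store)  # finds the stored payload and returns the same value
assert first == second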
View File

@@ -6,9 +6,11 @@ import sys
 from posthog import Posthog

 import mem0
-from mem0.memory.setup import get_user_id, setup_config
+from mem0.memory.setup import get_or_create_user_id

 MEM0_TELEMETRY = os.environ.get("MEM0_TELEMETRY", "True")
+PROJECT_API_KEY="phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX"
+HOST="https://us.i.posthog.com"

 if isinstance(MEM0_TELEMETRY, str):
     MEM0_TELEMETRY = MEM0_TELEMETRY.lower() in ("true", "1", "yes")
@@ -21,11 +23,11 @@ logging.getLogger("urllib3").setLevel(logging.CRITICAL + 1)

 class AnonymousTelemetry:
-    def __init__(self, project_api_key, host):
-        self.posthog = Posthog(project_api_key=project_api_key, host=host)
-
-        # Call setup config to ensure that the user_id is generated
-        setup_config()
-        self.user_id = get_user_id()
+    def __init__(self, vector_store=None):
+        self.posthog = Posthog(project_api_key=PROJECT_API_KEY, host=HOST)
+
+        self.user_id = get_or_create_user_id(vector_store)

         if not MEM0_TELEMETRY:
             self.posthog.disabled = True
@@ -50,14 +52,16 @@ class AnonymousTelemetry:
         self.posthog.shutdown()

-# Initialize AnonymousTelemetry
-telemetry = AnonymousTelemetry(
-    project_api_key="phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX",
-    host="https://us.i.posthog.com",
-)
+client_telemetry = AnonymousTelemetry()


 def capture_event(event_name, memory_instance, additional_data=None):
+    oss_telemetry = AnonymousTelemetry(
+        vector_store=memory_instance._telemetry_vector_store
+        if hasattr(memory_instance, "_telemetry_vector_store")
+        else None,
+    )
     event_data = {
         "collection": memory_instance.collection_name,
         "vector_size": memory_instance.embedding_model.config.embedding_dims,
@@ -73,7 +77,7 @@ def capture_event(event_name, memory_instance, additional_data=None):
     if additional_data:
         event_data.update(additional_data)
-    telemetry.capture_event(event_name, event_data)
+    oss_telemetry.capture_event(event_name, event_data)


 def capture_client_event(event_name, instance, additional_data=None):
@@ -83,4 +87,4 @@ def capture_client_event(event_name, instance, additional_data=None):
     if additional_data:
         event_data.update(additional_data)
-    telemetry.capture_event(event_name, event_data, instance.user_email)
+    client_telemetry.capture_event(event_name, event_data, instance.user_email)

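Side note on the surrounding context lines: telemetry can still be disabled entirely via the MEM0_TELEMETRY environment variable; a quick sketch of how that string is interpreted:

for raw in ("True", "1", "yes", "false", "0", "off"):
    enabled = raw.lower() in ("true", "1", "yes")
    print(f"MEM0_TELEMETRY={raw!r} -> telemetry {'enabled' if enabled else 'disabled'}")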
View File

@@ -5,7 +5,9 @@ from pydantic import BaseModel
 try:
     from langchain_community.vectorstores import VectorStore
 except ImportError:
-    raise ImportError("The 'langchain_community' library is required. Please install it using 'pip install langchain_community'.")
+    raise ImportError(
+        "The 'langchain_community' library is required. Please install it using 'pip install langchain_community'."
+    )

 from mem0.vector_stores.base import VectorStoreBase
@@ -15,11 +17,12 @@ class OutputData(BaseModel):
     score: Optional[float]  # distance
     payload: Optional[Dict]  # metadata


 class Langchain(VectorStoreBase):
     def __init__(self, client: VectorStore, collection_name: str = "mem0"):
         self.client = client
         self.collection_name = collection_name

     def _parse_output(self, data: Dict) -> List[OutputData]:
         """
         Parse the output data.
@@ -31,17 +34,17 @@ class Langchain(VectorStoreBase):
             List[OutputData]: Parsed output data.
         """
         # Check if input is a list of Document objects
-        if isinstance(data, list) and all(hasattr(doc, 'metadata') for doc in data if hasattr(doc, '__dict__')):
+        if isinstance(data, list) and all(hasattr(doc, "metadata") for doc in data if hasattr(doc, "__dict__")):
             result = []
             for doc in data:
                 entry = OutputData(
                     id=getattr(doc, "id", None),
                     score=None,  # Document objects typically don't include scores
-                    payload=getattr(doc, "metadata", {})
+                    payload=getattr(doc, "metadata", {}),
                 )
                 result.append(entry)
             return result

         # Original format handling
         keys = ["ids", "distances", "metadatas"]
         values = []
@@ -70,26 +73,20 @@ class Langchain(VectorStoreBase):
         self.collection_name = name
         return self.client

-    def insert(self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None):
+    def insert(
+        self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
+    ):
         """
         Insert vectors into the LangChain vectorstore.
         """
         # Check if client has add_embeddings method
         if hasattr(self.client, "add_embeddings"):
             # Some LangChain vectorstores have a direct add_embeddings method
-            self.client.add_embeddings(
-                embeddings=vectors,
-                metadatas=payloads,
-                ids=ids
-            )
+            self.client.add_embeddings(embeddings=vectors, metadatas=payloads, ids=ids)
         else:
             # Fallback to add_texts method
             texts = [payload.get("data", "") for payload in payloads] if payloads else [""] * len(vectors)
-            self.client.add_texts(
-                texts=texts,
-                metadatas=payloads,
-                ids=ids
-            )
+            self.client.add_texts(texts=texts, metadatas=payloads, ids=ids)

     def search(self, query: str, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None):
         """
@@ -97,16 +94,9 @@ class Langchain(VectorStoreBase):
         """
         # For each vector, perform a similarity search
         if filters:
-            results = self.client.similarity_search_by_vector(
-                embedding=vectors,
-                k=limit,
-                filter=filters
-            )
+            results = self.client.similarity_search_by_vector(embedding=vectors, k=limit, filter=filters)
         else:
-            results = self.client.similarity_search_by_vector(
-                embedding=vectors,
-                k=limit
-            )
+            results = self.client.similarity_search_by_vector(embedding=vectors, k=limit)

         final_results = self._parse_output(results)
         return final_results
@@ -133,26 +123,26 @@ class Langchain(VectorStoreBase):
             doc = docs[0]
             return self._parse_output([doc])[0]
         return None

     def list_cols(self):
         """
         List all collections.
         """
         # LangChain doesn't have collections
         return [self.collection_name]

     def delete_col(self):
         """
         Delete a collection.
         """
         self.client.delete(ids=None)

     def col_info(self):
         """
         Get information about a collection.
         """
         return {"name": self.collection_name}

     def list(self, filters=None, limit=None):
         """
         List all vectors in a collection.