From aae5989e78a6188b3b047c104d960c9ad0927e75 Mon Sep 17 00:00:00 2001 From: Varun Mohanta <140868411+V-Silpin@users.noreply.github.com> Date: Tue, 8 Jul 2025 10:57:25 +0530 Subject: [PATCH] Fix: Changed keyword from assisstant to secretary (#2937) --- evaluation/metrics/llm_judge.py | 1 + .../misc/diet_assistant_voice_cartesia.py | 2 + examples/misc/test.py | 2 +- mem0/llms/aws_bedrock.py | 1 - mem0/llms/azure_openai.py | 6 + mem0/llms/azure_openai_structured.py | 7 + mem0/llms/sarvam.py | 4 +- mem0/llms/vllm.py | 2 +- mem0/memory/utils.py | 2 +- mem0/vector_stores/azure_ai_search.py | 2 +- mem0/vector_stores/baidu.py | 27 ++- mem0/vector_stores/mongodb.py | 4 +- mem0/vector_stores/redis.py | 2 +- openmemory/api/alembic/env.py | 16 +- .../0b53c747049a_initial_migration.py | 3 +- .../api/alembic/versions/add_config_table.py | 6 +- ...0efbd06b_add_unique_user_id_constraints.py | 2 - openmemory/api/app/database.py | 3 +- openmemory/api/app/mcp_server.py | 62 +++---- openmemory/api/app/models.py | 24 ++- openmemory/api/app/routers/__init__.py | 6 +- openmemory/api/app/routers/apps.py | 6 +- openmemory/api/app/routers/config.py | 11 +- openmemory/api/app/routers/memories.py | 35 ++-- openmemory/api/app/routers/stats.py | 5 +- openmemory/api/app/schemas.py | 4 +- openmemory/api/app/utils/categorization.py | 2 +- openmemory/api/app/utils/db.py | 5 +- openmemory/api/app/utils/memory.py | 7 +- openmemory/api/app/utils/permissions.py | 3 +- openmemory/api/main.py | 17 +- tests/embeddings/test_gemini_emeddings.py | 2 +- tests/test_memory_integration.py | 175 ++++++++++++++++++ tests/vector_stores/test_baidu.py | 13 +- tests/vector_stores/test_mongodb.py | 4 +- 35 files changed, 351 insertions(+), 122 deletions(-) create mode 100644 tests/test_memory_integration.py diff --git a/evaluation/metrics/llm_judge.py b/evaluation/metrics/llm_judge.py index 20acc4af..55c946a0 100644 --- a/evaluation/metrics/llm_judge.py +++ b/evaluation/metrics/llm_judge.py @@ -4,6 +4,7 @@ from collections import defaultdict import numpy as np from openai import OpenAI + from mem0.memory.utils import extract_json client = OpenAI() diff --git a/examples/misc/diet_assistant_voice_cartesia.py b/examples/misc/diet_assistant_voice_cartesia.py index 62b2bf35..2fb2f6c1 100644 --- a/examples/misc/diet_assistant_voice_cartesia.py +++ b/examples/misc/diet_assistant_voice_cartesia.py @@ -8,10 +8,12 @@ export CARTESIA_API_KEY=your_cartesia_api_key """ from textwrap import dedent + from agno.agent import Agent from agno.models.openai import OpenAIChat from agno.tools.cartesia import CartesiaTools from agno.utils.audio import write_audio_to_file + from mem0 import MemoryClient memory_client = MemoryClient() diff --git a/examples/misc/test.py b/examples/misc/test.py index 06c73a9d..ec21f6ac 100644 --- a/examples/misc/test.py +++ b/examples/misc/test.py @@ -1,4 +1,4 @@ -from agents import Agent, Runner, function_tool, enable_verbose_stdout_logging +from agents import Agent, Runner, enable_verbose_stdout_logging, function_tool from dotenv import load_dotenv from mem0 import MemoryClient diff --git a/mem0/llms/aws_bedrock.py b/mem0/llms/aws_bedrock.py index b170f298..8266beef 100644 --- a/mem0/llms/aws_bedrock.py +++ b/mem0/llms/aws_bedrock.py @@ -11,7 +11,6 @@ except ImportError: from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase - PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"] diff --git a/mem0/llms/azure_openai.py b/mem0/llms/azure_openai.py index 1dcb5f3f..9a04a804 100644 --- a/mem0/llms/azure_openai.py +++ b/mem0/llms/azure_openai.py @@ -82,6 +82,12 @@ class AzureOpenAILLM(LLMBase): str: The generated response. """ + user_prompt = messages[-1]['content'] + + user_prompt = user_prompt.replace("assistant", "ai") + + messages[-1]['content'] = user_prompt + common_params = { "model": self.config.model, "messages": messages, diff --git a/mem0/llms/azure_openai_structured.py b/mem0/llms/azure_openai_structured.py index 1a746fb2..a9361fc5 100644 --- a/mem0/llms/azure_openai_structured.py +++ b/mem0/llms/azure_openai_structured.py @@ -48,6 +48,13 @@ class AzureOpenAIStructuredLLM(LLMBase): Returns: str: The generated response. """ + + user_prompt = messages[-1]['content'] + + user_prompt = user_prompt.replace("assistant", "ai") + + messages[-1]['content'] = user_prompt + params = { "model": self.config.model, "messages": messages, diff --git a/mem0/llms/sarvam.py b/mem0/llms/sarvam.py index b3389c94..6ef836ed 100644 --- a/mem0/llms/sarvam.py +++ b/mem0/llms/sarvam.py @@ -1,6 +1,8 @@ import os -import requests from typing import Dict, List, Optional + +import requests + from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase diff --git a/mem0/llms/vllm.py b/mem0/llms/vllm.py index 52abf427..6aa13add 100644 --- a/mem0/llms/vllm.py +++ b/mem0/llms/vllm.py @@ -1,8 +1,8 @@ import json import os from typing import Dict, List, Optional -from openai import OpenAI +from openai import OpenAI from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase diff --git a/mem0/memory/utils.py b/mem0/memory/utils.py index 9018d546..00a0a36b 100644 --- a/mem0/memory/utils.py +++ b/mem0/memory/utils.py @@ -1,5 +1,5 @@ -import re import hashlib +import re from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT diff --git a/mem0/vector_stores/azure_ai_search.py b/mem0/vector_stores/azure_ai_search.py index 7e06cff7..b6ebde37 100644 --- a/mem0/vector_stores/azure_ai_search.py +++ b/mem0/vector_stores/azure_ai_search.py @@ -5,8 +5,8 @@ from typing import List, Optional from pydantic import BaseModel -from mem0.vector_stores.base import VectorStoreBase from mem0.memory.utils import extract_json +from mem0.vector_stores.base import VectorStoreBase try: from azure.core.credentials import AzureKeyCredential diff --git a/mem0/vector_stores/baidu.py b/mem0/vector_stores/baidu.py index 0a3ed139..2c211abe 100644 --- a/mem0/vector_stores/baidu.py +++ b/mem0/vector_stores/baidu.py @@ -8,12 +8,31 @@ from mem0.vector_stores.base import VectorStoreBase try: import pymochow - from pymochow.configuration import Configuration from pymochow.auth.bce_credentials import BceCredentials - from pymochow.model.enum import FieldType, MetricType, IndexType, TableState, ServerErrCode - from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement - from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector + from pymochow.configuration import Configuration from pymochow.exception import ServerError + from pymochow.model.enum import ( + FieldType, + IndexType, + MetricType, + ServerErrCode, + TableState, + ) + from pymochow.model.schema import ( + AutoBuildRowCountIncrement, + Field, + FilteringIndex, + HNSWParams, + Schema, + VectorIndex, + ) + from pymochow.model.table import ( + FloatVector, + Partition, + Row, + VectorSearchConfig, + VectorTopkSearchRequest, + ) except ImportError: raise ImportError("The 'pymochow' library is required. Please install it using 'pip install pymochow'.") diff --git a/mem0/vector_stores/mongodb.py b/mem0/vector_stores/mongodb.py index 5b212eed..ef82c6e6 100644 --- a/mem0/vector_stores/mongodb.py +++ b/mem0/vector_stores/mongodb.py @@ -1,12 +1,12 @@ import logging -from typing import List, Optional, Dict, Any +from typing import Any, Dict, List, Optional from pydantic import BaseModel try: from pymongo import MongoClient - from pymongo.operations import SearchIndexModel from pymongo.errors import PyMongoError + from pymongo.operations import SearchIndexModel except ImportError: raise ImportError("The 'pymongo' library is required. Please install it using 'pip install pymongo'.") diff --git a/mem0/vector_stores/redis.py b/mem0/vector_stores/redis.py index 8e85055f..7fb1ada9 100644 --- a/mem0/vector_stores/redis.py +++ b/mem0/vector_stores/redis.py @@ -11,8 +11,8 @@ from redisvl.index import SearchIndex from redisvl.query import VectorQuery from redisvl.query.filter import Tag -from mem0.vector_stores.base import VectorStoreBase from mem0.memory.utils import extract_json +from mem0.vector_stores.base import VectorStoreBase logger = logging.getLogger(__name__) diff --git a/openmemory/api/alembic/env.py b/openmemory/api/alembic/env.py index b4295b69..278cc65f 100644 --- a/openmemory/api/alembic/env.py +++ b/openmemory/api/alembic/env.py @@ -1,13 +1,10 @@ -from logging.config import fileConfig - -from sqlalchemy import engine_from_config -from sqlalchemy import pool - -from alembic import context - import os import sys +from logging.config import fileConfig + +from alembic import context from dotenv import load_dotenv +from sqlalchemy import engine_from_config, pool # Add the parent directory to the Python path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -15,9 +12,8 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # Load environment variables load_dotenv() -# Import your models here -from app.database import Base -from app.models import * # Import all your models +# Import your models here - moved after path setup +from app.database import Base # noqa: E402 # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/openmemory/api/alembic/versions/0b53c747049a_initial_migration.py b/openmemory/api/alembic/versions/0b53c747049a_initial_migration.py index fb834314..6bbfbcca 100644 --- a/openmemory/api/alembic/versions/0b53c747049a_initial_migration.py +++ b/openmemory/api/alembic/versions/0b53c747049a_initial_migration.py @@ -7,9 +7,8 @@ Create Date: 2025-04-19 00:59:56.244203 """ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision: str = '0b53c747049a' diff --git a/openmemory/api/alembic/versions/add_config_table.py b/openmemory/api/alembic/versions/add_config_table.py index cc7c8268..b53488f9 100644 --- a/openmemory/api/alembic/versions/add_config_table.py +++ b/openmemory/api/alembic/versions/add_config_table.py @@ -5,11 +5,11 @@ Revises: 0b53c747049a Create Date: 2023-06-01 10:00:00.000000 """ -from alembic import op -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql import uuid +import sqlalchemy as sa +from alembic import op + # revision identifiers, used by Alembic. revision = 'add_config_table' down_revision = '0b53c747049a' diff --git a/openmemory/api/alembic/versions/afd00efbd06b_add_unique_user_id_constraints.py b/openmemory/api/alembic/versions/afd00efbd06b_add_unique_user_id_constraints.py index a685f4e2..bec325c3 100644 --- a/openmemory/api/alembic/versions/afd00efbd06b_add_unique_user_id_constraints.py +++ b/openmemory/api/alembic/versions/afd00efbd06b_add_unique_user_id_constraints.py @@ -8,8 +8,6 @@ Create Date: 2025-06-04 01:59:41.637440 from typing import Sequence, Union from alembic import op -import sqlalchemy as sa - # revision identifiers, used by Alembic. revision: str = 'afd00efbd06b' diff --git a/openmemory/api/app/database.py b/openmemory/api/app/database.py index 6404ce70..4ab4eaaa 100644 --- a/openmemory/api/app/database.py +++ b/openmemory/api/app/database.py @@ -1,7 +1,8 @@ import os + +from dotenv import load_dotenv from sqlalchemy import create_engine from sqlalchemy.orm import declarative_base, sessionmaker -from dotenv import load_dotenv # load .env file (make sure you have DATABASE_URL set) load_dotenv() diff --git a/openmemory/api/app/mcp_server.py b/openmemory/api/app/mcp_server.py index 16bcf6c5..f4759628 100644 --- a/openmemory/api/app/mcp_server.py +++ b/openmemory/api/app/mcp_server.py @@ -15,22 +15,22 @@ Key features: - Environment variable parsing for API keys """ -import logging +import contextvars +import datetime import json -from mcp.server.fastmcp import FastMCP -from mcp.server.sse import SseServerTransport +import logging +import uuid + +from app.database import SessionLocal +from app.models import Memory, MemoryAccessLog, MemoryState, MemoryStatusHistory +from app.utils.db import get_user_and_app from app.utils.memory import get_memory_client +from app.utils.permissions import check_memory_access_permissions +from dotenv import load_dotenv from fastapi import FastAPI, Request from fastapi.routing import APIRouter -import contextvars -import os -from dotenv import load_dotenv -from app.database import SessionLocal -from app.models import Memory, MemoryState, MemoryStatusHistory, MemoryAccessLog -from app.utils.db import get_user_and_app -import uuid -import datetime -from app.utils.permissions import check_memory_access_permissions +from mcp.server.fastmcp import FastMCP +from mcp.server.sse import SseServerTransport from qdrant_client import models as qdrant_models # Load environment variables @@ -410,32 +410,32 @@ async def handle_get_message(request: Request): async def handle_post_message(request: Request): return await handle_post_message(request) -async def handle_post_message(request: Request): - """Handle POST messages for SSE""" - try: - body = await request.body() +# async def handle_post_message(request: Request): +# """Handle POST messages for SSE""" +# try: +# body = await request.body() - # Create a simple receive function that returns the body - async def receive(): - return {"type": "http.request", "body": body, "more_body": False} +# # Create a simple receive function that returns the body +# async def receive(): +# return {"type": "http.request", "body": body, "more_body": False} - # Create a simple send function that does nothing - async def send(message): - return {} +# # Create a simple send function that does nothing +# async def send(message): +# return {} - # Call handle_post_message with the correct arguments - await sse.handle_post_message(request.scope, receive, send) +# # Call handle_post_message with the correct arguments +# await sse.handle_post_message(request.scope, receive, send) - # Return a success response - return {"status": "ok"} - finally: - pass - # Clean up context variable - # client_name_var.reset(client_token) +# # Return a success response +# return {"status": "ok"} +# finally: +# pass +# # Clean up context variable +# # client_name_var.reset(client_token) def setup_mcp_server(app: FastAPI): """Setup MCP server with the FastAPI application""" - mcp._mcp_server.name = f"mem0-mcp-server" + mcp._mcp_server.name = "mem0-mcp-server" # Include MCP router in the FastAPI app app.include_router(mcp_router) diff --git a/openmemory/api/app/models.py b/openmemory/api/app/models.py index 7bc3546d..66541013 100644 --- a/openmemory/api/app/models.py +++ b/openmemory/api/app/models.py @@ -1,15 +1,25 @@ +import datetime import enum import uuid -import datetime + import sqlalchemy as sa -from sqlalchemy import ( - Column, String, Boolean, ForeignKey, Enum, Table, - DateTime, JSON, Integer, UUID, Index, event -) -from sqlalchemy.orm import relationship from app.database import Base -from sqlalchemy.orm import Session from app.utils.categorization import get_categories_for_memory +from sqlalchemy import ( + JSON, + UUID, + Boolean, + Column, + DateTime, + Enum, + ForeignKey, + Index, + Integer, + String, + Table, + event, +) +from sqlalchemy.orm import Session, relationship def get_current_utc_time(): diff --git a/openmemory/api/app/routers/__init__.py b/openmemory/api/app/routers/__init__.py index 4210577c..519e7edd 100644 --- a/openmemory/api/app/routers/__init__.py +++ b/openmemory/api/app/routers/__init__.py @@ -1,6 +1,6 @@ -from .memories import router as memories_router from .apps import router as apps_router -from .stats import router as stats_router from .config import router as config_router +from .memories import router as memories_router +from .stats import router as stats_router -__all__ = ["memories_router", "apps_router", "stats_router", "config_router"] \ No newline at end of file +__all__ = ["memories_router", "apps_router", "stats_router", "config_router"] diff --git a/openmemory/api/app/routers/apps.py b/openmemory/api/app/routers/apps.py index 36584f92..97f0dc89 100644 --- a/openmemory/api/app/routers/apps.py +++ b/openmemory/api/app/routers/apps.py @@ -1,11 +1,11 @@ from typing import Optional from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, Query -from sqlalchemy.orm import Session, joinedload -from sqlalchemy import func, desc from app.database import get_db from app.models import App, Memory, MemoryAccessLog, MemoryState +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy import desc, func +from sqlalchemy.orm import Session, joinedload router = APIRouter(prefix="/api/v1/apps", tags=["apps"]) diff --git a/openmemory/api/app/routers/config.py b/openmemory/api/app/routers/config.py index cab9630a..7eaae4bf 100644 --- a/openmemory/api/app/routers/config.py +++ b/openmemory/api/app/routers/config.py @@ -1,12 +1,11 @@ -import os -import json -from typing import Dict, Any, Optional -from fastapi import APIRouter, HTTPException, Depends -from pydantic import BaseModel, Field -from sqlalchemy.orm import Session +from typing import Any, Dict, Optional + from app.database import get_db from app.models import Config as ConfigModel from app.utils.memory import reset_memory_client +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session router = APIRouter(prefix="/api/v1/config", tags=["config"]) diff --git a/openmemory/api/app/routers/memories.py b/openmemory/api/app/routers/memories.py index ff7d3c4b..f73aa702 100644 --- a/openmemory/api/app/routers/memories.py +++ b/openmemory/api/app/routers/memories.py @@ -1,23 +1,28 @@ -from datetime import datetime, UTC -from typing import List, Optional, Set -from uuid import UUID, uuid4 import logging -import os -from fastapi import APIRouter, Depends, HTTPException, Query -from sqlalchemy.orm import Session, joinedload -from fastapi_pagination import Page, Params -from fastapi_pagination.ext.sqlalchemy import paginate as sqlalchemy_paginate -from pydantic import BaseModel -from sqlalchemy import or_, func -from app.utils.memory import get_memory_client +from datetime import UTC, datetime +from typing import List, Optional, Set +from uuid import UUID from app.database import get_db from app.models import ( - Memory, MemoryState, MemoryAccessLog, App, - MemoryStatusHistory, User, Category, AccessControl, Config as ConfigModel + AccessControl, + App, + Category, + Memory, + MemoryAccessLog, + MemoryState, + MemoryStatusHistory, + User, ) -from app.schemas import MemoryResponse, PaginatedMemoryResponse +from app.schemas import MemoryResponse +from app.utils.memory import get_memory_client from app.utils.permissions import check_memory_access_permissions +from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi_pagination import Page, Params +from fastapi_pagination.ext.sqlalchemy import paginate as sqlalchemy_paginate +from pydantic import BaseModel +from sqlalchemy import func +from sqlalchemy.orm import Session, joinedload router = APIRouter(prefix="/api/v1/memories", tags=["memories"]) @@ -412,7 +417,7 @@ async def pause_memories( ).all() for memory in memories: update_memory_state(db, memory.id, state, user_id) - return {"message": f"Successfully paused all memories"} + return {"message": "Successfully paused all memories"} if memory_ids: # Pause specific memories diff --git a/openmemory/api/app/routers/stats.py b/openmemory/api/app/routers/stats.py index 047721f1..c609d372 100644 --- a/openmemory/api/app/routers/stats.py +++ b/openmemory/api/app/routers/stats.py @@ -1,8 +1,7 @@ +from app.database import get_db +from app.models import App, Memory, MemoryState, User from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session -from app.database import get_db -from app.models import User, Memory, App, MemoryState - router = APIRouter(prefix="/api/v1/stats", tags=["stats"]) diff --git a/openmemory/api/app/schemas.py b/openmemory/api/app/schemas.py index 35a512f6..f5462e7f 100644 --- a/openmemory/api/app/schemas.py +++ b/openmemory/api/app/schemas.py @@ -1,8 +1,10 @@ from datetime import datetime -from typing import Optional, List +from typing import List, Optional from uuid import UUID + from pydantic import BaseModel, Field, validator + class MemoryBase(BaseModel): content: str metadata_: Optional[dict] = Field(default_factory=dict) diff --git a/openmemory/api/app/utils/categorization.py b/openmemory/api/app/utils/categorization.py index d3554691..e20c4005 100644 --- a/openmemory/api/app/utils/categorization.py +++ b/openmemory/api/app/utils/categorization.py @@ -1,11 +1,11 @@ import logging from typing import List +from app.utils.prompts import MEMORY_CATEGORIZATION_PROMPT from dotenv import load_dotenv from openai import OpenAI from pydantic import BaseModel from tenacity import retry, stop_after_attempt, wait_exponential -from app.utils.prompts import MEMORY_CATEGORIZATION_PROMPT load_dotenv() openai_client = OpenAI() diff --git a/openmemory/api/app/utils/db.py b/openmemory/api/app/utils/db.py index bf2fcf9c..50a90f6a 100644 --- a/openmemory/api/app/utils/db.py +++ b/openmemory/api/app/utils/db.py @@ -1,7 +1,8 @@ -from sqlalchemy.orm import Session -from app.models import User, App from typing import Tuple +from app.models import App, User +from sqlalchemy.orm import Session + def get_or_create_user(db: Session, user_id: str) -> User: """Get or create a user with the given user_id""" diff --git a/openmemory/api/app/utils/memory.py b/openmemory/api/app/utils/memory.py index a7319b94..921e4dff 100644 --- a/openmemory/api/app/utils/memory.py +++ b/openmemory/api/app/utils/memory.py @@ -27,16 +27,15 @@ Example configuration that will be automatically adjusted: } """ -import os -import json import hashlib +import json +import os import socket -import platform -from mem0 import Memory from app.database import SessionLocal from app.models import Config as ConfigModel +from mem0 import Memory _memory_client = None _config_hash = None diff --git a/openmemory/api/app/utils/permissions.py b/openmemory/api/app/utils/permissions.py index 4b8a04ed..060caf96 100644 --- a/openmemory/api/app/utils/permissions.py +++ b/openmemory/api/app/utils/permissions.py @@ -1,7 +1,8 @@ from typing import Optional from uuid import UUID + +from app.models import App, Memory, MemoryState from sqlalchemy.orm import Session -from app.models import Memory, App, MemoryState def check_memory_access_permissions( diff --git a/openmemory/api/main.py b/openmemory/api/main.py index 049d9e51..923ba8cf 100644 --- a/openmemory/api/main.py +++ b/openmemory/api/main.py @@ -1,13 +1,14 @@ import datetime -from fastapi import FastAPI -from app.database import engine, Base, SessionLocal -from app.mcp_server import setup_mcp_server -from app.routers import memories_router, apps_router, stats_router, config_router -from fastapi_pagination import add_pagination -from fastapi.middleware.cors import CORSMiddleware -from app.models import User, App from uuid import uuid4 -from app.config import USER_ID, DEFAULT_APP_ID + +from app.config import DEFAULT_APP_ID, USER_ID +from app.database import Base, SessionLocal, engine +from app.mcp_server import setup_mcp_server +from app.models import App, User +from app.routers import apps_router, config_router, memories_router, stats_router +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi_pagination import add_pagination app = FastAPI(title="OpenMemory API") diff --git a/tests/embeddings/test_gemini_emeddings.py b/tests/embeddings/test_gemini_emeddings.py index e3254702..dff834df 100644 --- a/tests/embeddings/test_gemini_emeddings.py +++ b/tests/embeddings/test_gemini_emeddings.py @@ -1,4 +1,4 @@ -from unittest.mock import patch, ANY +from unittest.mock import ANY, patch import pytest diff --git a/tests/test_memory_integration.py b/tests/test_memory_integration.py new file mode 100644 index 00000000..899eb76d --- /dev/null +++ b/tests/test_memory_integration.py @@ -0,0 +1,175 @@ +from unittest.mock import MagicMock, patch + +from mem0.memory.main import Memory + + +def test_memory_configuration_without_env_vars(): + """Test Memory configuration with mock config instead of environment variables""" + + # Mock configuration without relying on environment variables + mock_config = { + "llm": { + "provider": "openai", + "config": { + "model": "gpt-4", + "temperature": 0.1, + "max_tokens": 1500, + } + }, + "vector_store": { + "provider": "chroma", + "config": { + "collection_name": "test_collection", + "path": "./test_db", + } + }, + "embedder": { + "provider": "openai", + "config": { + "model": "text-embedding-ada-002", + } + } + } + + # Test messages similar to the main.py file + test_messages = [ + {"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."}, + {"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."} + ] + + # Mock the Memory class methods to avoid actual API calls + with patch.object(Memory, '__init__', return_value=None): + with patch.object(Memory, 'from_config') as mock_from_config: + with patch.object(Memory, 'add') as mock_add: + with patch.object(Memory, 'get_all') as mock_get_all: + + # Configure mocks + mock_memory_instance = MagicMock() + mock_from_config.return_value = mock_memory_instance + + mock_add.return_value = { + "results": [ + {"id": "1", "text": "Alex is a vegetarian"}, + {"id": "2", "text": "Alex is allergic to nuts"} + ] + } + + mock_get_all.return_value = [ + {"id": "1", "text": "Alex is a vegetarian", "metadata": {"category": "dietary_preferences"}}, + {"id": "2", "text": "Alex is allergic to nuts", "metadata": {"category": "allergies"}} + ] + + # Test the workflow + mem = Memory.from_config(config_dict=mock_config) + assert mem is not None + + # Test adding memories + result = mock_add(test_messages, user_id="alice", metadata={"category": "book_recommendations"}) + assert "results" in result + assert len(result["results"]) == 2 + + # Test retrieving memories + all_memories = mock_get_all(user_id="alice") + assert len(all_memories) == 2 + assert any("vegetarian" in memory["text"] for memory in all_memories) + assert any("allergic to nuts" in memory["text"] for memory in all_memories) + + +def test_azure_config_structure(): + """Test that Azure configuration structure is properly formatted""" + + # Test Azure configuration structure (without actual credentials) + azure_config = { + "llm": { + "provider": "azure_openai", + "config": { + "model": "gpt-4", + "temperature": 0.1, + "max_tokens": 1500, + "azure_kwargs": { + "azure_deployment": "test-deployment", + "api_version": "2023-12-01-preview", + "azure_endpoint": "https://test.openai.azure.com/", + "api_key": "test-key", + } + } + }, + "vector_store": { + "provider": "azure_ai_search", + "config": { + "service_name": "test-service", + "api_key": "test-key", + "collection_name": "test-collection", + "embedding_model_dims": 1536, + } + }, + "embedder": { + "provider": "azure_openai", + "config": { + "model": "text-embedding-ada-002", + "api_key": "test-key", + "azure_kwargs": { + "api_version": "2023-12-01-preview", + "azure_deployment": "test-embedding-deployment", + "azure_endpoint": "https://test.openai.azure.com/", + "api_key": "test-key", + } + } + } + } + + # Validate configuration structure + assert "llm" in azure_config + assert "vector_store" in azure_config + assert "embedder" in azure_config + + # Validate Azure-specific configurations + assert azure_config["llm"]["provider"] == "azure_openai" + assert "azure_kwargs" in azure_config["llm"]["config"] + assert "azure_deployment" in azure_config["llm"]["config"]["azure_kwargs"] + + assert azure_config["vector_store"]["provider"] == "azure_ai_search" + assert "service_name" in azure_config["vector_store"]["config"] + + assert azure_config["embedder"]["provider"] == "azure_openai" + assert "azure_kwargs" in azure_config["embedder"]["config"] + + +def test_memory_messages_format(): + """Test that memory messages are properly formatted""" + + # Test message format from main.py + messages = [ + {"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."}, + {"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."} + ] + + # Validate message structure + assert len(messages) == 2 + assert all("role" in msg for msg in messages) + assert all("content" in msg for msg in messages) + + # Validate roles + assert messages[0]["role"] == "user" + assert messages[1]["role"] == "assistant" + + # Validate content + assert "vegetarian" in messages[0]["content"].lower() + assert "allergic to nuts" in messages[0]["content"].lower() + assert "vegetarian" in messages[1]["content"].lower() + assert "nut allergy" in messages[1]["content"].lower() + + +def test_safe_update_prompt_constant(): + """Test the SAFE_UPDATE_PROMPT constant from main.py""" + + SAFE_UPDATE_PROMPT = """ +Based on the user's latest messages, what new preference can be inferred? +Reply only in this json_object format: +""" + + # Validate prompt structure + assert isinstance(SAFE_UPDATE_PROMPT, str) + assert "user's latest messages" in SAFE_UPDATE_PROMPT + assert "json_object format" in SAFE_UPDATE_PROMPT + assert len(SAFE_UPDATE_PROMPT.strip()) > 0 diff --git a/tests/vector_stores/test_baidu.py b/tests/vector_stores/test_baidu.py index 987c298b..981c7790 100644 --- a/tests/vector_stores/test_baidu.py +++ b/tests/vector_stores/test_baidu.py @@ -1,11 +1,16 @@ -from unittest.mock import Mock, patch, PropertyMock +from unittest.mock import Mock, PropertyMock, patch import pytest +from pymochow.exception import ServerError +from pymochow.model.enum import ServerErrCode, TableState +from pymochow.model.table import ( + FloatVector, + Table, + VectorSearchConfig, + VectorTopkSearchRequest, +) from mem0.vector_stores.baidu import BaiduDB -from pymochow.model.enum import TableState, ServerErrCode -from pymochow.model.table import VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table -from pymochow.exception import ServerError @pytest.fixture diff --git a/tests/vector_stores/test_mongodb.py b/tests/vector_stores/test_mongodb.py index 812abfc2..abb299e0 100644 --- a/tests/vector_stores/test_mongodb.py +++ b/tests/vector_stores/test_mongodb.py @@ -1,5 +1,7 @@ -import pytest from unittest.mock import MagicMock, patch + +import pytest + from mem0.vector_stores.mongodb import MongoDB