Fix: Changed keyword from assisstant to secretary (#2937)

This commit is contained in:
Varun Mohanta
2025-07-08 10:57:25 +05:30
committed by GitHub
parent 6866e56d7a
commit aae5989e78
35 changed files with 351 additions and 122 deletions

View File

@@ -4,6 +4,7 @@ from collections import defaultdict
import numpy as np
from openai import OpenAI
from mem0.memory.utils import extract_json
client = OpenAI()

View File

@@ -8,10 +8,12 @@ export CARTESIA_API_KEY=your_cartesia_api_key
"""
from textwrap import dedent
from agno.agent import Agent
from agno.models.openai import OpenAIChat
from agno.tools.cartesia import CartesiaTools
from agno.utils.audio import write_audio_to_file
from mem0 import MemoryClient
memory_client = MemoryClient()

View File

@@ -1,4 +1,4 @@
from agents import Agent, Runner, function_tool, enable_verbose_stdout_logging
from agents import Agent, Runner, enable_verbose_stdout_logging, function_tool
from dotenv import load_dotenv
from mem0 import MemoryClient

View File

@@ -11,7 +11,6 @@ except ImportError:
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]

View File

@@ -82,6 +82,12 @@ class AzureOpenAILLM(LLMBase):
str: The generated response.
"""
user_prompt = messages[-1]['content']
user_prompt = user_prompt.replace("assistant", "ai")
messages[-1]['content'] = user_prompt
common_params = {
"model": self.config.model,
"messages": messages,

View File

@@ -48,6 +48,13 @@ class AzureOpenAIStructuredLLM(LLMBase):
Returns:
str: The generated response.
"""
user_prompt = messages[-1]['content']
user_prompt = user_prompt.replace("assistant", "ai")
messages[-1]['content'] = user_prompt
params = {
"model": self.config.model,
"messages": messages,

View File

@@ -1,6 +1,8 @@
import os
import requests
from typing import Dict, List, Optional
import requests
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase

View File

@@ -1,8 +1,8 @@
import json
import os
from typing import Dict, List, Optional
from openai import OpenAI
from openai import OpenAI
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase

View File

@@ -1,5 +1,5 @@
import re
import hashlib
import re
from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT

View File

@@ -5,8 +5,8 @@ from typing import List, Optional
from pydantic import BaseModel
from mem0.vector_stores.base import VectorStoreBase
from mem0.memory.utils import extract_json
from mem0.vector_stores.base import VectorStoreBase
try:
from azure.core.credentials import AzureKeyCredential

View File

@@ -8,12 +8,31 @@ from mem0.vector_stores.base import VectorStoreBase
try:
import pymochow
from pymochow.configuration import Configuration
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.model.enum import FieldType, MetricType, IndexType, TableState, ServerErrCode
from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement
from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector
from pymochow.configuration import Configuration
from pymochow.exception import ServerError
from pymochow.model.enum import (
FieldType,
IndexType,
MetricType,
ServerErrCode,
TableState,
)
from pymochow.model.schema import (
AutoBuildRowCountIncrement,
Field,
FilteringIndex,
HNSWParams,
Schema,
VectorIndex,
)
from pymochow.model.table import (
FloatVector,
Partition,
Row,
VectorSearchConfig,
VectorTopkSearchRequest,
)
except ImportError:
raise ImportError("The 'pymochow' library is required. Please install it using 'pip install pymochow'.")

View File

@@ -1,12 +1,12 @@
import logging
from typing import List, Optional, Dict, Any
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
try:
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
from pymongo.operations import SearchIndexModel
except ImportError:
raise ImportError("The 'pymongo' library is required. Please install it using 'pip install pymongo'.")

View File

@@ -11,8 +11,8 @@ from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from redisvl.query.filter import Tag
from mem0.vector_stores.base import VectorStoreBase
from mem0.memory.utils import extract_json
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)

View File

@@ -1,13 +1,10 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
import os
import sys
from logging.config import fileConfig
from alembic import context
from dotenv import load_dotenv
from sqlalchemy import engine_from_config, pool
# Add the parent directory to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -15,9 +12,8 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Load environment variables
load_dotenv()
# Import your models here
from app.database import Base
from app.models import * # Import all your models
# Import your models here - moved after path setup
from app.database import Base # noqa: E402
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.

View File

@@ -7,9 +7,8 @@ Create Date: 2025-04-19 00:59:56.244203
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '0b53c747049a'

View File

@@ -5,11 +5,11 @@ Revises: 0b53c747049a
Create Date: 2023-06-01 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
import uuid
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = 'add_config_table'
down_revision = '0b53c747049a'

View File

@@ -8,8 +8,6 @@ Create Date: 2025-06-04 01:59:41.637440
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'afd00efbd06b'

View File

@@ -1,7 +1,8 @@
import os
from dotenv import load_dotenv
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
from dotenv import load_dotenv
# load .env file (make sure you have DATABASE_URL set)
load_dotenv()

View File

@@ -15,22 +15,22 @@ Key features:
- Environment variable parsing for API keys
"""
import logging
import contextvars
import datetime
import json
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
import logging
import uuid
from app.database import SessionLocal
from app.models import Memory, MemoryAccessLog, MemoryState, MemoryStatusHistory
from app.utils.db import get_user_and_app
from app.utils.memory import get_memory_client
from app.utils.permissions import check_memory_access_permissions
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.routing import APIRouter
import contextvars
import os
from dotenv import load_dotenv
from app.database import SessionLocal
from app.models import Memory, MemoryState, MemoryStatusHistory, MemoryAccessLog
from app.utils.db import get_user_and_app
import uuid
import datetime
from app.utils.permissions import check_memory_access_permissions
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from qdrant_client import models as qdrant_models
# Load environment variables
@@ -410,32 +410,32 @@ async def handle_get_message(request: Request):
async def handle_post_message(request: Request):
return await handle_post_message(request)
async def handle_post_message(request: Request):
"""Handle POST messages for SSE"""
try:
body = await request.body()
# async def handle_post_message(request: Request):
# """Handle POST messages for SSE"""
# try:
# body = await request.body()
# Create a simple receive function that returns the body
async def receive():
return {"type": "http.request", "body": body, "more_body": False}
# # Create a simple receive function that returns the body
# async def receive():
# return {"type": "http.request", "body": body, "more_body": False}
# Create a simple send function that does nothing
async def send(message):
return {}
# # Create a simple send function that does nothing
# async def send(message):
# return {}
# Call handle_post_message with the correct arguments
await sse.handle_post_message(request.scope, receive, send)
# # Call handle_post_message with the correct arguments
# await sse.handle_post_message(request.scope, receive, send)
# Return a success response
return {"status": "ok"}
finally:
pass
# Clean up context variable
# client_name_var.reset(client_token)
# # Return a success response
# return {"status": "ok"}
# finally:
# pass
# # Clean up context variable
# # client_name_var.reset(client_token)
def setup_mcp_server(app: FastAPI):
"""Setup MCP server with the FastAPI application"""
mcp._mcp_server.name = f"mem0-mcp-server"
mcp._mcp_server.name = "mem0-mcp-server"
# Include MCP router in the FastAPI app
app.include_router(mcp_router)

View File

@@ -1,15 +1,25 @@
import datetime
import enum
import uuid
import datetime
import sqlalchemy as sa
from sqlalchemy import (
Column, String, Boolean, ForeignKey, Enum, Table,
DateTime, JSON, Integer, UUID, Index, event
)
from sqlalchemy.orm import relationship
from app.database import Base
from sqlalchemy.orm import Session
from app.utils.categorization import get_categories_for_memory
from sqlalchemy import (
JSON,
UUID,
Boolean,
Column,
DateTime,
Enum,
ForeignKey,
Index,
Integer,
String,
Table,
event,
)
from sqlalchemy.orm import Session, relationship
def get_current_utc_time():

View File

@@ -1,6 +1,6 @@
from .memories import router as memories_router
from .apps import router as apps_router
from .stats import router as stats_router
from .config import router as config_router
from .memories import router as memories_router
from .stats import router as stats_router
__all__ = ["memories_router", "apps_router", "stats_router", "config_router"]
__all__ = ["memories_router", "apps_router", "stats_router", "config_router"]

View File

@@ -1,11 +1,11 @@
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import func, desc
from app.database import get_db
from app.models import App, Memory, MemoryAccessLog, MemoryState
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import desc, func
from sqlalchemy.orm import Session, joinedload
router = APIRouter(prefix="/api/v1/apps", tags=["apps"])

View File

@@ -1,12 +1,11 @@
import os
import json
from typing import Dict, Any, Optional
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from typing import Any, Dict, Optional
from app.database import get_db
from app.models import Config as ConfigModel
from app.utils.memory import reset_memory_client
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
router = APIRouter(prefix="/api/v1/config", tags=["config"])

View File

@@ -1,23 +1,28 @@
from datetime import datetime, UTC
from typing import List, Optional, Set
from uuid import UUID, uuid4
import logging
import os
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session, joinedload
from fastapi_pagination import Page, Params
from fastapi_pagination.ext.sqlalchemy import paginate as sqlalchemy_paginate
from pydantic import BaseModel
from sqlalchemy import or_, func
from app.utils.memory import get_memory_client
from datetime import UTC, datetime
from typing import List, Optional, Set
from uuid import UUID
from app.database import get_db
from app.models import (
Memory, MemoryState, MemoryAccessLog, App,
MemoryStatusHistory, User, Category, AccessControl, Config as ConfigModel
AccessControl,
App,
Category,
Memory,
MemoryAccessLog,
MemoryState,
MemoryStatusHistory,
User,
)
from app.schemas import MemoryResponse, PaginatedMemoryResponse
from app.schemas import MemoryResponse
from app.utils.memory import get_memory_client
from app.utils.permissions import check_memory_access_permissions
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi_pagination import Page, Params
from fastapi_pagination.ext.sqlalchemy import paginate as sqlalchemy_paginate
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session, joinedload
router = APIRouter(prefix="/api/v1/memories", tags=["memories"])
@@ -412,7 +417,7 @@ async def pause_memories(
).all()
for memory in memories:
update_memory_state(db, memory.id, state, user_id)
return {"message": f"Successfully paused all memories"}
return {"message": "Successfully paused all memories"}
if memory_ids:
# Pause specific memories

View File

@@ -1,8 +1,7 @@
from app.database import get_db
from app.models import App, Memory, MemoryState, User
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.database import get_db
from app.models import User, Memory, App, MemoryState
router = APIRouter(prefix="/api/v1/stats", tags=["stats"])

View File

@@ -1,8 +1,10 @@
from datetime import datetime
from typing import Optional, List
from typing import List, Optional
from uuid import UUID
from pydantic import BaseModel, Field, validator
class MemoryBase(BaseModel):
content: str
metadata_: Optional[dict] = Field(default_factory=dict)

View File

@@ -1,11 +1,11 @@
import logging
from typing import List
from app.utils.prompts import MEMORY_CATEGORIZATION_PROMPT
from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential
from app.utils.prompts import MEMORY_CATEGORIZATION_PROMPT
load_dotenv()
openai_client = OpenAI()

View File

@@ -1,7 +1,8 @@
from sqlalchemy.orm import Session
from app.models import User, App
from typing import Tuple
from app.models import App, User
from sqlalchemy.orm import Session
def get_or_create_user(db: Session, user_id: str) -> User:
"""Get or create a user with the given user_id"""

View File

@@ -27,16 +27,15 @@ Example configuration that will be automatically adjusted:
}
"""
import os
import json
import hashlib
import json
import os
import socket
import platform
from mem0 import Memory
from app.database import SessionLocal
from app.models import Config as ConfigModel
from mem0 import Memory
_memory_client = None
_config_hash = None

View File

@@ -1,7 +1,8 @@
from typing import Optional
from uuid import UUID
from app.models import App, Memory, MemoryState
from sqlalchemy.orm import Session
from app.models import Memory, App, MemoryState
def check_memory_access_permissions(

View File

@@ -1,13 +1,14 @@
import datetime
from fastapi import FastAPI
from app.database import engine, Base, SessionLocal
from app.mcp_server import setup_mcp_server
from app.routers import memories_router, apps_router, stats_router, config_router
from fastapi_pagination import add_pagination
from fastapi.middleware.cors import CORSMiddleware
from app.models import User, App
from uuid import uuid4
from app.config import USER_ID, DEFAULT_APP_ID
from app.config import DEFAULT_APP_ID, USER_ID
from app.database import Base, SessionLocal, engine
from app.mcp_server import setup_mcp_server
from app.models import App, User
from app.routers import apps_router, config_router, memories_router, stats_router
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi_pagination import add_pagination
app = FastAPI(title="OpenMemory API")

View File

@@ -1,4 +1,4 @@
from unittest.mock import patch, ANY
from unittest.mock import ANY, patch
import pytest

View File

@@ -0,0 +1,175 @@
from unittest.mock import MagicMock, patch
from mem0.memory.main import Memory
def test_memory_configuration_without_env_vars():
"""Test Memory configuration with mock config instead of environment variables"""
# Mock configuration without relying on environment variables
mock_config = {
"llm": {
"provider": "openai",
"config": {
"model": "gpt-4",
"temperature": 0.1,
"max_tokens": 1500,
}
},
"vector_store": {
"provider": "chroma",
"config": {
"collection_name": "test_collection",
"path": "./test_db",
}
},
"embedder": {
"provider": "openai",
"config": {
"model": "text-embedding-ada-002",
}
}
}
# Test messages similar to the main.py file
test_messages = [
{"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."},
{"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."}
]
# Mock the Memory class methods to avoid actual API calls
with patch.object(Memory, '__init__', return_value=None):
with patch.object(Memory, 'from_config') as mock_from_config:
with patch.object(Memory, 'add') as mock_add:
with patch.object(Memory, 'get_all') as mock_get_all:
# Configure mocks
mock_memory_instance = MagicMock()
mock_from_config.return_value = mock_memory_instance
mock_add.return_value = {
"results": [
{"id": "1", "text": "Alex is a vegetarian"},
{"id": "2", "text": "Alex is allergic to nuts"}
]
}
mock_get_all.return_value = [
{"id": "1", "text": "Alex is a vegetarian", "metadata": {"category": "dietary_preferences"}},
{"id": "2", "text": "Alex is allergic to nuts", "metadata": {"category": "allergies"}}
]
# Test the workflow
mem = Memory.from_config(config_dict=mock_config)
assert mem is not None
# Test adding memories
result = mock_add(test_messages, user_id="alice", metadata={"category": "book_recommendations"})
assert "results" in result
assert len(result["results"]) == 2
# Test retrieving memories
all_memories = mock_get_all(user_id="alice")
assert len(all_memories) == 2
assert any("vegetarian" in memory["text"] for memory in all_memories)
assert any("allergic to nuts" in memory["text"] for memory in all_memories)
def test_azure_config_structure():
"""Test that Azure configuration structure is properly formatted"""
# Test Azure configuration structure (without actual credentials)
azure_config = {
"llm": {
"provider": "azure_openai",
"config": {
"model": "gpt-4",
"temperature": 0.1,
"max_tokens": 1500,
"azure_kwargs": {
"azure_deployment": "test-deployment",
"api_version": "2023-12-01-preview",
"azure_endpoint": "https://test.openai.azure.com/",
"api_key": "test-key",
}
}
},
"vector_store": {
"provider": "azure_ai_search",
"config": {
"service_name": "test-service",
"api_key": "test-key",
"collection_name": "test-collection",
"embedding_model_dims": 1536,
}
},
"embedder": {
"provider": "azure_openai",
"config": {
"model": "text-embedding-ada-002",
"api_key": "test-key",
"azure_kwargs": {
"api_version": "2023-12-01-preview",
"azure_deployment": "test-embedding-deployment",
"azure_endpoint": "https://test.openai.azure.com/",
"api_key": "test-key",
}
}
}
}
# Validate configuration structure
assert "llm" in azure_config
assert "vector_store" in azure_config
assert "embedder" in azure_config
# Validate Azure-specific configurations
assert azure_config["llm"]["provider"] == "azure_openai"
assert "azure_kwargs" in azure_config["llm"]["config"]
assert "azure_deployment" in azure_config["llm"]["config"]["azure_kwargs"]
assert azure_config["vector_store"]["provider"] == "azure_ai_search"
assert "service_name" in azure_config["vector_store"]["config"]
assert azure_config["embedder"]["provider"] == "azure_openai"
assert "azure_kwargs" in azure_config["embedder"]["config"]
def test_memory_messages_format():
"""Test that memory messages are properly formatted"""
# Test message format from main.py
messages = [
{"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."},
{"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."}
]
# Validate message structure
assert len(messages) == 2
assert all("role" in msg for msg in messages)
assert all("content" in msg for msg in messages)
# Validate roles
assert messages[0]["role"] == "user"
assert messages[1]["role"] == "assistant"
# Validate content
assert "vegetarian" in messages[0]["content"].lower()
assert "allergic to nuts" in messages[0]["content"].lower()
assert "vegetarian" in messages[1]["content"].lower()
assert "nut allergy" in messages[1]["content"].lower()
def test_safe_update_prompt_constant():
"""Test the SAFE_UPDATE_PROMPT constant from main.py"""
SAFE_UPDATE_PROMPT = """
Based on the user's latest messages, what new preference can be inferred?
Reply only in this json_object format:
"""
# Validate prompt structure
assert isinstance(SAFE_UPDATE_PROMPT, str)
assert "user's latest messages" in SAFE_UPDATE_PROMPT
assert "json_object format" in SAFE_UPDATE_PROMPT
assert len(SAFE_UPDATE_PROMPT.strip()) > 0

View File

@@ -1,11 +1,16 @@
from unittest.mock import Mock, patch, PropertyMock
from unittest.mock import Mock, PropertyMock, patch
import pytest
from pymochow.exception import ServerError
from pymochow.model.enum import ServerErrCode, TableState
from pymochow.model.table import (
FloatVector,
Table,
VectorSearchConfig,
VectorTopkSearchRequest,
)
from mem0.vector_stores.baidu import BaiduDB
from pymochow.model.enum import TableState, ServerErrCode
from pymochow.model.table import VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table
from pymochow.exception import ServerError
@pytest.fixture

View File

@@ -1,5 +1,7 @@
import pytest
from unittest.mock import MagicMock, patch
import pytest
from mem0.vector_stores.mongodb import MongoDB