Fix: Changed keyword from assisstant to secretary (#2937)

This commit is contained in:
Varun Mohanta
2025-07-08 10:57:25 +05:30
committed by GitHub
parent 6866e56d7a
commit aae5989e78
35 changed files with 351 additions and 122 deletions

View File

@@ -4,6 +4,7 @@ from collections import defaultdict
import numpy as np import numpy as np
from openai import OpenAI from openai import OpenAI
from mem0.memory.utils import extract_json from mem0.memory.utils import extract_json
client = OpenAI() client = OpenAI()

View File

@@ -8,10 +8,12 @@ export CARTESIA_API_KEY=your_cartesia_api_key
""" """
from textwrap import dedent from textwrap import dedent
from agno.agent import Agent from agno.agent import Agent
from agno.models.openai import OpenAIChat from agno.models.openai import OpenAIChat
from agno.tools.cartesia import CartesiaTools from agno.tools.cartesia import CartesiaTools
from agno.utils.audio import write_audio_to_file from agno.utils.audio import write_audio_to_file
from mem0 import MemoryClient from mem0 import MemoryClient
memory_client = MemoryClient() memory_client = MemoryClient()

View File

@@ -1,4 +1,4 @@
from agents import Agent, Runner, function_tool, enable_verbose_stdout_logging from agents import Agent, Runner, enable_verbose_stdout_logging, function_tool
from dotenv import load_dotenv from dotenv import load_dotenv
from mem0 import MemoryClient from mem0 import MemoryClient

View File

@@ -11,7 +11,6 @@ except ImportError:
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"] PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]

View File

@@ -82,6 +82,12 @@ class AzureOpenAILLM(LLMBase):
str: The generated response. str: The generated response.
""" """
user_prompt = messages[-1]['content']
user_prompt = user_prompt.replace("assistant", "ai")
messages[-1]['content'] = user_prompt
common_params = { common_params = {
"model": self.config.model, "model": self.config.model,
"messages": messages, "messages": messages,

View File

@@ -48,6 +48,13 @@ class AzureOpenAIStructuredLLM(LLMBase):
Returns: Returns:
str: The generated response. str: The generated response.
""" """
user_prompt = messages[-1]['content']
user_prompt = user_prompt.replace("assistant", "ai")
messages[-1]['content'] = user_prompt
params = { params = {
"model": self.config.model, "model": self.config.model,
"messages": messages, "messages": messages,

View File

@@ -1,6 +1,8 @@
import os import os
import requests
from typing import Dict, List, Optional from typing import Dict, List, Optional
import requests
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase

View File

@@ -1,8 +1,8 @@
import json import json
import os import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
from openai import OpenAI
from openai import OpenAI
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase

View File

@@ -1,5 +1,5 @@
import re
import hashlib import hashlib
import re
from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT

View File

@@ -5,8 +5,8 @@ from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from mem0.vector_stores.base import VectorStoreBase
from mem0.memory.utils import extract_json from mem0.memory.utils import extract_json
from mem0.vector_stores.base import VectorStoreBase
try: try:
from azure.core.credentials import AzureKeyCredential from azure.core.credentials import AzureKeyCredential

View File

@@ -8,12 +8,31 @@ from mem0.vector_stores.base import VectorStoreBase
try: try:
import pymochow import pymochow
from pymochow.configuration import Configuration
from pymochow.auth.bce_credentials import BceCredentials from pymochow.auth.bce_credentials import BceCredentials
from pymochow.model.enum import FieldType, MetricType, IndexType, TableState, ServerErrCode from pymochow.configuration import Configuration
from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement
from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector
from pymochow.exception import ServerError from pymochow.exception import ServerError
from pymochow.model.enum import (
FieldType,
IndexType,
MetricType,
ServerErrCode,
TableState,
)
from pymochow.model.schema import (
AutoBuildRowCountIncrement,
Field,
FilteringIndex,
HNSWParams,
Schema,
VectorIndex,
)
from pymochow.model.table import (
FloatVector,
Partition,
Row,
VectorSearchConfig,
VectorTopkSearchRequest,
)
except ImportError: except ImportError:
raise ImportError("The 'pymochow' library is required. Please install it using 'pip install pymochow'.") raise ImportError("The 'pymochow' library is required. Please install it using 'pip install pymochow'.")

View File

@@ -1,12 +1,12 @@
import logging import logging
from typing import List, Optional, Dict, Any from typing import Any, Dict, List, Optional
from pydantic import BaseModel from pydantic import BaseModel
try: try:
from pymongo import MongoClient from pymongo import MongoClient
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError from pymongo.errors import PyMongoError
from pymongo.operations import SearchIndexModel
except ImportError: except ImportError:
raise ImportError("The 'pymongo' library is required. Please install it using 'pip install pymongo'.") raise ImportError("The 'pymongo' library is required. Please install it using 'pip install pymongo'.")

View File

@@ -11,8 +11,8 @@ from redisvl.index import SearchIndex
from redisvl.query import VectorQuery from redisvl.query import VectorQuery
from redisvl.query.filter import Tag from redisvl.query.filter import Tag
from mem0.vector_stores.base import VectorStoreBase
from mem0.memory.utils import extract_json from mem0.memory.utils import extract_json
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -1,13 +1,10 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
import os import os
import sys import sys
from logging.config import fileConfig
from alembic import context
from dotenv import load_dotenv from dotenv import load_dotenv
from sqlalchemy import engine_from_config, pool
# Add the parent directory to the Python path # Add the parent directory to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -15,9 +12,8 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Load environment variables # Load environment variables
load_dotenv() load_dotenv()
# Import your models here # Import your models here - moved after path setup
from app.database import Base from app.database import Base # noqa: E402
from app.models import * # Import all your models
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.

View File

@@ -7,9 +7,8 @@ Create Date: 2025-04-19 00:59:56.244203
""" """
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '0b53c747049a' revision: str = '0b53c747049a'

View File

@@ -5,11 +5,11 @@ Revises: 0b53c747049a
Create Date: 2023-06-01 10:00:00.000000 Create Date: 2023-06-01 10:00:00.000000
""" """
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
import uuid import uuid
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'add_config_table' revision = 'add_config_table'
down_revision = '0b53c747049a' down_revision = '0b53c747049a'

View File

@@ -8,8 +8,6 @@ Create Date: 2025-06-04 01:59:41.637440
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = 'afd00efbd06b' revision: str = 'afd00efbd06b'

View File

@@ -1,7 +1,8 @@
import os import os
from dotenv import load_dotenv
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker from sqlalchemy.orm import declarative_base, sessionmaker
from dotenv import load_dotenv
# load .env file (make sure you have DATABASE_URL set) # load .env file (make sure you have DATABASE_URL set)
load_dotenv() load_dotenv()

View File

@@ -15,22 +15,22 @@ Key features:
- Environment variable parsing for API keys - Environment variable parsing for API keys
""" """
import logging import contextvars
import datetime
import json import json
from mcp.server.fastmcp import FastMCP import logging
from mcp.server.sse import SseServerTransport import uuid
from app.database import SessionLocal
from app.models import Memory, MemoryAccessLog, MemoryState, MemoryStatusHistory
from app.utils.db import get_user_and_app
from app.utils.memory import get_memory_client from app.utils.memory import get_memory_client
from app.utils.permissions import check_memory_access_permissions
from dotenv import load_dotenv
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
import contextvars from mcp.server.fastmcp import FastMCP
import os from mcp.server.sse import SseServerTransport
from dotenv import load_dotenv
from app.database import SessionLocal
from app.models import Memory, MemoryState, MemoryStatusHistory, MemoryAccessLog
from app.utils.db import get_user_and_app
import uuid
import datetime
from app.utils.permissions import check_memory_access_permissions
from qdrant_client import models as qdrant_models from qdrant_client import models as qdrant_models
# Load environment variables # Load environment variables
@@ -410,32 +410,32 @@ async def handle_get_message(request: Request):
async def handle_post_message(request: Request): async def handle_post_message(request: Request):
return await handle_post_message(request) return await handle_post_message(request)
async def handle_post_message(request: Request): # async def handle_post_message(request: Request):
"""Handle POST messages for SSE""" # """Handle POST messages for SSE"""
try: # try:
body = await request.body() # body = await request.body()
# Create a simple receive function that returns the body # # Create a simple receive function that returns the body
async def receive(): # async def receive():
return {"type": "http.request", "body": body, "more_body": False} # return {"type": "http.request", "body": body, "more_body": False}
# Create a simple send function that does nothing # # Create a simple send function that does nothing
async def send(message): # async def send(message):
return {} # return {}
# Call handle_post_message with the correct arguments # # Call handle_post_message with the correct arguments
await sse.handle_post_message(request.scope, receive, send) # await sse.handle_post_message(request.scope, receive, send)
# Return a success response # # Return a success response
return {"status": "ok"} # return {"status": "ok"}
finally: # finally:
pass # pass
# Clean up context variable # # Clean up context variable
# client_name_var.reset(client_token) # # client_name_var.reset(client_token)
def setup_mcp_server(app: FastAPI): def setup_mcp_server(app: FastAPI):
"""Setup MCP server with the FastAPI application""" """Setup MCP server with the FastAPI application"""
mcp._mcp_server.name = f"mem0-mcp-server" mcp._mcp_server.name = "mem0-mcp-server"
# Include MCP router in the FastAPI app # Include MCP router in the FastAPI app
app.include_router(mcp_router) app.include_router(mcp_router)

View File

@@ -1,15 +1,25 @@
import datetime
import enum import enum
import uuid import uuid
import datetime
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import (
Column, String, Boolean, ForeignKey, Enum, Table,
DateTime, JSON, Integer, UUID, Index, event
)
from sqlalchemy.orm import relationship
from app.database import Base from app.database import Base
from sqlalchemy.orm import Session
from app.utils.categorization import get_categories_for_memory from app.utils.categorization import get_categories_for_memory
from sqlalchemy import (
JSON,
UUID,
Boolean,
Column,
DateTime,
Enum,
ForeignKey,
Index,
Integer,
String,
Table,
event,
)
from sqlalchemy.orm import Session, relationship
def get_current_utc_time(): def get_current_utc_time():

View File

@@ -1,6 +1,6 @@
from .memories import router as memories_router
from .apps import router as apps_router from .apps import router as apps_router
from .stats import router as stats_router
from .config import router as config_router from .config import router as config_router
from .memories import router as memories_router
from .stats import router as stats_router
__all__ = ["memories_router", "apps_router", "stats_router", "config_router"] __all__ = ["memories_router", "apps_router", "stats_router", "config_router"]

View File

@@ -1,11 +1,11 @@
from typing import Optional from typing import Optional
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import func, desc
from app.database import get_db from app.database import get_db
from app.models import App, Memory, MemoryAccessLog, MemoryState from app.models import App, Memory, MemoryAccessLog, MemoryState
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import desc, func
from sqlalchemy.orm import Session, joinedload
router = APIRouter(prefix="/api/v1/apps", tags=["apps"]) router = APIRouter(prefix="/api/v1/apps", tags=["apps"])

View File

@@ -1,12 +1,11 @@
import os from typing import Any, Dict, Optional
import json
from typing import Dict, Any, Optional
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.database import get_db from app.database import get_db
from app.models import Config as ConfigModel from app.models import Config as ConfigModel
from app.utils.memory import reset_memory_client from app.utils.memory import reset_memory_client
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
router = APIRouter(prefix="/api/v1/config", tags=["config"]) router = APIRouter(prefix="/api/v1/config", tags=["config"])

View File

@@ -1,23 +1,28 @@
from datetime import datetime, UTC
from typing import List, Optional, Set
from uuid import UUID, uuid4
import logging import logging
import os from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Query from typing import List, Optional, Set
from sqlalchemy.orm import Session, joinedload from uuid import UUID
from fastapi_pagination import Page, Params
from fastapi_pagination.ext.sqlalchemy import paginate as sqlalchemy_paginate
from pydantic import BaseModel
from sqlalchemy import or_, func
from app.utils.memory import get_memory_client
from app.database import get_db from app.database import get_db
from app.models import ( from app.models import (
Memory, MemoryState, MemoryAccessLog, App, AccessControl,
MemoryStatusHistory, User, Category, AccessControl, Config as ConfigModel App,
Category,
Memory,
MemoryAccessLog,
MemoryState,
MemoryStatusHistory,
User,
) )
from app.schemas import MemoryResponse, PaginatedMemoryResponse from app.schemas import MemoryResponse
from app.utils.memory import get_memory_client
from app.utils.permissions import check_memory_access_permissions from app.utils.permissions import check_memory_access_permissions
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi_pagination import Page, Params
from fastapi_pagination.ext.sqlalchemy import paginate as sqlalchemy_paginate
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session, joinedload
router = APIRouter(prefix="/api/v1/memories", tags=["memories"]) router = APIRouter(prefix="/api/v1/memories", tags=["memories"])
@@ -412,7 +417,7 @@ async def pause_memories(
).all() ).all()
for memory in memories: for memory in memories:
update_memory_state(db, memory.id, state, user_id) update_memory_state(db, memory.id, state, user_id)
return {"message": f"Successfully paused all memories"} return {"message": "Successfully paused all memories"}
if memory_ids: if memory_ids:
# Pause specific memories # Pause specific memories

View File

@@ -1,8 +1,7 @@
from app.database import get_db
from app.models import App, Memory, MemoryState, User
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.database import get_db
from app.models import User, Memory, App, MemoryState
router = APIRouter(prefix="/api/v1/stats", tags=["stats"]) router = APIRouter(prefix="/api/v1/stats", tags=["stats"])

View File

@@ -1,8 +1,10 @@
from datetime import datetime from datetime import datetime
from typing import Optional, List from typing import List, Optional
from uuid import UUID from uuid import UUID
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
class MemoryBase(BaseModel): class MemoryBase(BaseModel):
content: str content: str
metadata_: Optional[dict] = Field(default_factory=dict) metadata_: Optional[dict] = Field(default_factory=dict)

View File

@@ -1,11 +1,11 @@
import logging import logging
from typing import List from typing import List
from app.utils.prompts import MEMORY_CATEGORIZATION_PROMPT
from dotenv import load_dotenv from dotenv import load_dotenv
from openai import OpenAI from openai import OpenAI
from pydantic import BaseModel from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential from tenacity import retry, stop_after_attempt, wait_exponential
from app.utils.prompts import MEMORY_CATEGORIZATION_PROMPT
load_dotenv() load_dotenv()
openai_client = OpenAI() openai_client = OpenAI()

View File

@@ -1,7 +1,8 @@
from sqlalchemy.orm import Session
from app.models import User, App
from typing import Tuple from typing import Tuple
from app.models import App, User
from sqlalchemy.orm import Session
def get_or_create_user(db: Session, user_id: str) -> User: def get_or_create_user(db: Session, user_id: str) -> User:
"""Get or create a user with the given user_id""" """Get or create a user with the given user_id"""

View File

@@ -27,16 +27,15 @@ Example configuration that will be automatically adjusted:
} }
""" """
import os
import json
import hashlib import hashlib
import json
import os
import socket import socket
import platform
from mem0 import Memory
from app.database import SessionLocal from app.database import SessionLocal
from app.models import Config as ConfigModel from app.models import Config as ConfigModel
from mem0 import Memory
_memory_client = None _memory_client = None
_config_hash = None _config_hash = None

View File

@@ -1,7 +1,8 @@
from typing import Optional from typing import Optional
from uuid import UUID from uuid import UUID
from app.models import App, Memory, MemoryState
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.models import Memory, App, MemoryState
def check_memory_access_permissions( def check_memory_access_permissions(

View File

@@ -1,13 +1,14 @@
import datetime import datetime
from fastapi import FastAPI
from app.database import engine, Base, SessionLocal
from app.mcp_server import setup_mcp_server
from app.routers import memories_router, apps_router, stats_router, config_router
from fastapi_pagination import add_pagination
from fastapi.middleware.cors import CORSMiddleware
from app.models import User, App
from uuid import uuid4 from uuid import uuid4
from app.config import USER_ID, DEFAULT_APP_ID
from app.config import DEFAULT_APP_ID, USER_ID
from app.database import Base, SessionLocal, engine
from app.mcp_server import setup_mcp_server
from app.models import App, User
from app.routers import apps_router, config_router, memories_router, stats_router
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi_pagination import add_pagination
app = FastAPI(title="OpenMemory API") app = FastAPI(title="OpenMemory API")

View File

@@ -1,4 +1,4 @@
from unittest.mock import patch, ANY from unittest.mock import ANY, patch
import pytest import pytest

View File

@@ -0,0 +1,175 @@
from unittest.mock import MagicMock, patch
from mem0.memory.main import Memory
def test_memory_configuration_without_env_vars():
"""Test Memory configuration with mock config instead of environment variables"""
# Mock configuration without relying on environment variables
mock_config = {
"llm": {
"provider": "openai",
"config": {
"model": "gpt-4",
"temperature": 0.1,
"max_tokens": 1500,
}
},
"vector_store": {
"provider": "chroma",
"config": {
"collection_name": "test_collection",
"path": "./test_db",
}
},
"embedder": {
"provider": "openai",
"config": {
"model": "text-embedding-ada-002",
}
}
}
# Test messages similar to the main.py file
test_messages = [
{"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."},
{"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."}
]
# Mock the Memory class methods to avoid actual API calls
with patch.object(Memory, '__init__', return_value=None):
with patch.object(Memory, 'from_config') as mock_from_config:
with patch.object(Memory, 'add') as mock_add:
with patch.object(Memory, 'get_all') as mock_get_all:
# Configure mocks
mock_memory_instance = MagicMock()
mock_from_config.return_value = mock_memory_instance
mock_add.return_value = {
"results": [
{"id": "1", "text": "Alex is a vegetarian"},
{"id": "2", "text": "Alex is allergic to nuts"}
]
}
mock_get_all.return_value = [
{"id": "1", "text": "Alex is a vegetarian", "metadata": {"category": "dietary_preferences"}},
{"id": "2", "text": "Alex is allergic to nuts", "metadata": {"category": "allergies"}}
]
# Test the workflow
mem = Memory.from_config(config_dict=mock_config)
assert mem is not None
# Test adding memories
result = mock_add(test_messages, user_id="alice", metadata={"category": "book_recommendations"})
assert "results" in result
assert len(result["results"]) == 2
# Test retrieving memories
all_memories = mock_get_all(user_id="alice")
assert len(all_memories) == 2
assert any("vegetarian" in memory["text"] for memory in all_memories)
assert any("allergic to nuts" in memory["text"] for memory in all_memories)
def test_azure_config_structure():
"""Test that Azure configuration structure is properly formatted"""
# Test Azure configuration structure (without actual credentials)
azure_config = {
"llm": {
"provider": "azure_openai",
"config": {
"model": "gpt-4",
"temperature": 0.1,
"max_tokens": 1500,
"azure_kwargs": {
"azure_deployment": "test-deployment",
"api_version": "2023-12-01-preview",
"azure_endpoint": "https://test.openai.azure.com/",
"api_key": "test-key",
}
}
},
"vector_store": {
"provider": "azure_ai_search",
"config": {
"service_name": "test-service",
"api_key": "test-key",
"collection_name": "test-collection",
"embedding_model_dims": 1536,
}
},
"embedder": {
"provider": "azure_openai",
"config": {
"model": "text-embedding-ada-002",
"api_key": "test-key",
"azure_kwargs": {
"api_version": "2023-12-01-preview",
"azure_deployment": "test-embedding-deployment",
"azure_endpoint": "https://test.openai.azure.com/",
"api_key": "test-key",
}
}
}
}
# Validate configuration structure
assert "llm" in azure_config
assert "vector_store" in azure_config
assert "embedder" in azure_config
# Validate Azure-specific configurations
assert azure_config["llm"]["provider"] == "azure_openai"
assert "azure_kwargs" in azure_config["llm"]["config"]
assert "azure_deployment" in azure_config["llm"]["config"]["azure_kwargs"]
assert azure_config["vector_store"]["provider"] == "azure_ai_search"
assert "service_name" in azure_config["vector_store"]["config"]
assert azure_config["embedder"]["provider"] == "azure_openai"
assert "azure_kwargs" in azure_config["embedder"]["config"]
def test_memory_messages_format():
"""Test that memory messages are properly formatted"""
# Test message format from main.py
messages = [
{"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."},
{"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."}
]
# Validate message structure
assert len(messages) == 2
assert all("role" in msg for msg in messages)
assert all("content" in msg for msg in messages)
# Validate roles
assert messages[0]["role"] == "user"
assert messages[1]["role"] == "assistant"
# Validate content
assert "vegetarian" in messages[0]["content"].lower()
assert "allergic to nuts" in messages[0]["content"].lower()
assert "vegetarian" in messages[1]["content"].lower()
assert "nut allergy" in messages[1]["content"].lower()
def test_safe_update_prompt_constant():
"""Test the SAFE_UPDATE_PROMPT constant from main.py"""
SAFE_UPDATE_PROMPT = """
Based on the user's latest messages, what new preference can be inferred?
Reply only in this json_object format:
"""
# Validate prompt structure
assert isinstance(SAFE_UPDATE_PROMPT, str)
assert "user's latest messages" in SAFE_UPDATE_PROMPT
assert "json_object format" in SAFE_UPDATE_PROMPT
assert len(SAFE_UPDATE_PROMPT.strip()) > 0

View File

@@ -1,11 +1,16 @@
from unittest.mock import Mock, patch, PropertyMock from unittest.mock import Mock, PropertyMock, patch
import pytest import pytest
from pymochow.exception import ServerError
from pymochow.model.enum import ServerErrCode, TableState
from pymochow.model.table import (
FloatVector,
Table,
VectorSearchConfig,
VectorTopkSearchRequest,
)
from mem0.vector_stores.baidu import BaiduDB from mem0.vector_stores.baidu import BaiduDB
from pymochow.model.enum import TableState, ServerErrCode
from pymochow.model.table import VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table
from pymochow.exception import ServerError
@pytest.fixture @pytest.fixture

View File

@@ -1,5 +1,7 @@
import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
from mem0.vector_stores.mongodb import MongoDB from mem0.vector_stores.mongodb import MongoDB