Fix: Changed keyword from assisstant to secretary (#2937)
This commit is contained in:
@@ -4,6 +4,7 @@ from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.memory.utils import extract_json
|
||||
|
||||
client = OpenAI()
|
||||
|
||||
@@ -8,10 +8,12 @@ export CARTESIA_API_KEY=your_cartesia_api_key
|
||||
"""
|
||||
|
||||
from textwrap import dedent
|
||||
|
||||
from agno.agent import Agent
|
||||
from agno.models.openai import OpenAIChat
|
||||
from agno.tools.cartesia import CartesiaTools
|
||||
from agno.utils.audio import write_audio_to_file
|
||||
|
||||
from mem0 import MemoryClient
|
||||
|
||||
memory_client = MemoryClient()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from agents import Agent, Runner, function_tool, enable_verbose_stdout_logging
|
||||
from agents import Agent, Runner, enable_verbose_stdout_logging, function_tool
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from mem0 import MemoryClient
|
||||
|
||||
@@ -11,7 +11,6 @@ except ImportError:
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]
|
||||
|
||||
|
||||
|
||||
@@ -82,6 +82,12 @@ class AzureOpenAILLM(LLMBase):
|
||||
str: The generated response.
|
||||
"""
|
||||
|
||||
user_prompt = messages[-1]['content']
|
||||
|
||||
user_prompt = user_prompt.replace("assistant", "ai")
|
||||
|
||||
messages[-1]['content'] = user_prompt
|
||||
|
||||
common_params = {
|
||||
"model": self.config.model,
|
||||
"messages": messages,
|
||||
|
||||
@@ -48,6 +48,13 @@ class AzureOpenAIStructuredLLM(LLMBase):
|
||||
Returns:
|
||||
str: The generated response.
|
||||
"""
|
||||
|
||||
user_prompt = messages[-1]['content']
|
||||
|
||||
user_prompt = user_prompt.replace("assistant", "ai")
|
||||
|
||||
messages[-1]['content'] = user_prompt
|
||||
|
||||
params = {
|
||||
"model": self.config.model,
|
||||
"messages": messages,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import os
|
||||
import requests
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
from openai import OpenAI
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import re
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
from mem0.memory.utils import extract_json
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
try:
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
|
||||
@@ -8,12 +8,31 @@ from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
try:
|
||||
import pymochow
|
||||
from pymochow.configuration import Configuration
|
||||
from pymochow.auth.bce_credentials import BceCredentials
|
||||
from pymochow.model.enum import FieldType, MetricType, IndexType, TableState, ServerErrCode
|
||||
from pymochow.model.schema import Field, Schema, VectorIndex, FilteringIndex, HNSWParams, AutoBuildRowCountIncrement
|
||||
from pymochow.model.table import Partition, Row, VectorSearchConfig, VectorTopkSearchRequest, FloatVector
|
||||
from pymochow.configuration import Configuration
|
||||
from pymochow.exception import ServerError
|
||||
from pymochow.model.enum import (
|
||||
FieldType,
|
||||
IndexType,
|
||||
MetricType,
|
||||
ServerErrCode,
|
||||
TableState,
|
||||
)
|
||||
from pymochow.model.schema import (
|
||||
AutoBuildRowCountIncrement,
|
||||
Field,
|
||||
FilteringIndex,
|
||||
HNSWParams,
|
||||
Schema,
|
||||
VectorIndex,
|
||||
)
|
||||
from pymochow.model.table import (
|
||||
FloatVector,
|
||||
Partition,
|
||||
Row,
|
||||
VectorSearchConfig,
|
||||
VectorTopkSearchRequest,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("The 'pymochow' library is required. Please install it using 'pip install pymochow'.")
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from pymongo import MongoClient
|
||||
from pymongo.operations import SearchIndexModel
|
||||
from pymongo.errors import PyMongoError
|
||||
from pymongo.operations import SearchIndexModel
|
||||
except ImportError:
|
||||
raise ImportError("The 'pymongo' library is required. Please install it using 'pip install pymongo'.")
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@ from redisvl.index import SearchIndex
|
||||
from redisvl.query import VectorQuery
|
||||
from redisvl.query.filter import Tag
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
from mem0.memory.utils import extract_json
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
import os
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
# Add the parent directory to the Python path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -15,9 +12,8 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Import your models here
|
||||
from app.database import Base
|
||||
from app.models import * # Import all your models
|
||||
# Import your models here - moved after path setup
|
||||
from app.database import Base # noqa: E402
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
|
||||
@@ -7,9 +7,8 @@ Create Date: 2025-04-19 00:59:56.244203
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '0b53c747049a'
|
||||
|
||||
@@ -5,11 +5,11 @@ Revises: 0b53c747049a
|
||||
Create Date: 2023-06-01 10:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
import uuid
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'add_config_table'
|
||||
down_revision = '0b53c747049a'
|
||||
|
||||
@@ -8,8 +8,6 @@ Create Date: 2025-06-04 01:59:41.637440
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'afd00efbd06b'
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# load .env file (make sure you have DATABASE_URL set)
|
||||
load_dotenv()
|
||||
|
||||
@@ -15,22 +15,22 @@ Key features:
|
||||
- Environment variable parsing for API keys
|
||||
"""
|
||||
|
||||
import logging
|
||||
import contextvars
|
||||
import datetime
|
||||
import json
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.server.sse import SseServerTransport
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from app.database import SessionLocal
|
||||
from app.models import Memory, MemoryAccessLog, MemoryState, MemoryStatusHistory
|
||||
from app.utils.db import get_user_and_app
|
||||
from app.utils.memory import get_memory_client
|
||||
from app.utils.permissions import check_memory_access_permissions
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.routing import APIRouter
|
||||
import contextvars
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from app.database import SessionLocal
|
||||
from app.models import Memory, MemoryState, MemoryStatusHistory, MemoryAccessLog
|
||||
from app.utils.db import get_user_and_app
|
||||
import uuid
|
||||
import datetime
|
||||
from app.utils.permissions import check_memory_access_permissions
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from qdrant_client import models as qdrant_models
|
||||
|
||||
# Load environment variables
|
||||
@@ -410,32 +410,32 @@ async def handle_get_message(request: Request):
|
||||
async def handle_post_message(request: Request):
|
||||
return await handle_post_message(request)
|
||||
|
||||
async def handle_post_message(request: Request):
|
||||
"""Handle POST messages for SSE"""
|
||||
try:
|
||||
body = await request.body()
|
||||
# async def handle_post_message(request: Request):
|
||||
# """Handle POST messages for SSE"""
|
||||
# try:
|
||||
# body = await request.body()
|
||||
|
||||
# Create a simple receive function that returns the body
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
# # Create a simple receive function that returns the body
|
||||
# async def receive():
|
||||
# return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
# Create a simple send function that does nothing
|
||||
async def send(message):
|
||||
return {}
|
||||
# # Create a simple send function that does nothing
|
||||
# async def send(message):
|
||||
# return {}
|
||||
|
||||
# Call handle_post_message with the correct arguments
|
||||
await sse.handle_post_message(request.scope, receive, send)
|
||||
# # Call handle_post_message with the correct arguments
|
||||
# await sse.handle_post_message(request.scope, receive, send)
|
||||
|
||||
# Return a success response
|
||||
return {"status": "ok"}
|
||||
finally:
|
||||
pass
|
||||
# Clean up context variable
|
||||
# client_name_var.reset(client_token)
|
||||
# # Return a success response
|
||||
# return {"status": "ok"}
|
||||
# finally:
|
||||
# pass
|
||||
# # Clean up context variable
|
||||
# # client_name_var.reset(client_token)
|
||||
|
||||
def setup_mcp_server(app: FastAPI):
|
||||
"""Setup MCP server with the FastAPI application"""
|
||||
mcp._mcp_server.name = f"mem0-mcp-server"
|
||||
mcp._mcp_server.name = "mem0-mcp-server"
|
||||
|
||||
# Include MCP router in the FastAPI app
|
||||
app.include_router(mcp_router)
|
||||
|
||||
@@ -1,15 +1,25 @@
|
||||
import datetime
|
||||
import enum
|
||||
import uuid
|
||||
import datetime
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import (
|
||||
Column, String, Boolean, ForeignKey, Enum, Table,
|
||||
DateTime, JSON, Integer, UUID, Index, event
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.database import Base
|
||||
from sqlalchemy.orm import Session
|
||||
from app.utils.categorization import get_categories_for_memory
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
UUID,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Enum,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Table,
|
||||
event,
|
||||
)
|
||||
from sqlalchemy.orm import Session, relationship
|
||||
|
||||
|
||||
def get_current_utc_time():
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from .memories import router as memories_router
|
||||
from .apps import router as apps_router
|
||||
from .stats import router as stats_router
|
||||
from .config import router as config_router
|
||||
from .memories import router as memories_router
|
||||
from .stats import router as stats_router
|
||||
|
||||
__all__ = ["memories_router", "apps_router", "stats_router", "config_router"]
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import func, desc
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import App, Memory, MemoryAccessLog, MemoryState
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import desc, func
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
router = APIRouter(prefix="/api/v1/apps", tags=["apps"])
|
||||
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, Any, Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Config as ConfigModel
|
||||
from app.utils.memory import reset_memory_client
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
router = APIRouter(prefix="/api/v1/config", tags=["config"])
|
||||
|
||||
|
||||
@@ -1,23 +1,28 @@
|
||||
from datetime import datetime, UTC
|
||||
from typing import List, Optional, Set
|
||||
from uuid import UUID, uuid4
|
||||
import logging
|
||||
import os
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from fastapi_pagination import Page, Params
|
||||
from fastapi_pagination.ext.sqlalchemy import paginate as sqlalchemy_paginate
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import or_, func
|
||||
from app.utils.memory import get_memory_client
|
||||
from datetime import UTC, datetime
|
||||
from typing import List, Optional, Set
|
||||
from uuid import UUID
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import (
|
||||
Memory, MemoryState, MemoryAccessLog, App,
|
||||
MemoryStatusHistory, User, Category, AccessControl, Config as ConfigModel
|
||||
AccessControl,
|
||||
App,
|
||||
Category,
|
||||
Memory,
|
||||
MemoryAccessLog,
|
||||
MemoryState,
|
||||
MemoryStatusHistory,
|
||||
User,
|
||||
)
|
||||
from app.schemas import MemoryResponse, PaginatedMemoryResponse
|
||||
from app.schemas import MemoryResponse
|
||||
from app.utils.memory import get_memory_client
|
||||
from app.utils.permissions import check_memory_access_permissions
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi_pagination import Page, Params
|
||||
from fastapi_pagination.ext.sqlalchemy import paginate as sqlalchemy_paginate
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
router = APIRouter(prefix="/api/v1/memories", tags=["memories"])
|
||||
|
||||
@@ -412,7 +417,7 @@ async def pause_memories(
|
||||
).all()
|
||||
for memory in memories:
|
||||
update_memory_state(db, memory.id, state, user_id)
|
||||
return {"message": f"Successfully paused all memories"}
|
||||
return {"message": "Successfully paused all memories"}
|
||||
|
||||
if memory_ids:
|
||||
# Pause specific memories
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from app.database import get_db
|
||||
from app.models import App, Memory, MemoryState, User
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from app.database import get_db
|
||||
from app.models import User, Memory, App, MemoryState
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/stats", tags=["stats"])
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
class MemoryBase(BaseModel):
|
||||
content: str
|
||||
metadata_: Optional[dict] = Field(default_factory=dict)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from app.utils.prompts import MEMORY_CATEGORIZATION_PROMPT
|
||||
from dotenv import load_dotenv
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from app.utils.prompts import MEMORY_CATEGORIZATION_PROMPT
|
||||
|
||||
load_dotenv()
|
||||
openai_client = OpenAI()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models import User, App
|
||||
from typing import Tuple
|
||||
|
||||
from app.models import App, User
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
def get_or_create_user(db: Session, user_id: str) -> User:
|
||||
"""Get or create a user with the given user_id"""
|
||||
|
||||
@@ -27,16 +27,15 @@ Example configuration that will be automatically adjusted:
|
||||
}
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import platform
|
||||
|
||||
from mem0 import Memory
|
||||
from app.database import SessionLocal
|
||||
from app.models import Config as ConfigModel
|
||||
|
||||
from mem0 import Memory
|
||||
|
||||
_memory_client = None
|
||||
_config_hash = None
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from app.models import App, Memory, MemoryState
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models import Memory, App, MemoryState
|
||||
|
||||
|
||||
def check_memory_access_permissions(
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import datetime
|
||||
from fastapi import FastAPI
|
||||
from app.database import engine, Base, SessionLocal
|
||||
from app.mcp_server import setup_mcp_server
|
||||
from app.routers import memories_router, apps_router, stats_router, config_router
|
||||
from fastapi_pagination import add_pagination
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.models import User, App
|
||||
from uuid import uuid4
|
||||
from app.config import USER_ID, DEFAULT_APP_ID
|
||||
|
||||
from app.config import DEFAULT_APP_ID, USER_ID
|
||||
from app.database import Base, SessionLocal, engine
|
||||
from app.mcp_server import setup_mcp_server
|
||||
from app.models import App, User
|
||||
from app.routers import apps_router, config_router, memories_router, stats_router
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi_pagination import add_pagination
|
||||
|
||||
app = FastAPI(title="OpenMemory API")
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from unittest.mock import patch, ANY
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
175
tests/test_memory_integration.py
Normal file
175
tests/test_memory_integration.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from mem0.memory.main import Memory
|
||||
|
||||
|
||||
def test_memory_configuration_without_env_vars():
|
||||
"""Test Memory configuration with mock config instead of environment variables"""
|
||||
|
||||
# Mock configuration without relying on environment variables
|
||||
mock_config = {
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "gpt-4",
|
||||
"temperature": 0.1,
|
||||
"max_tokens": 1500,
|
||||
}
|
||||
},
|
||||
"vector_store": {
|
||||
"provider": "chroma",
|
||||
"config": {
|
||||
"collection_name": "test_collection",
|
||||
"path": "./test_db",
|
||||
}
|
||||
},
|
||||
"embedder": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-ada-002",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Test messages similar to the main.py file
|
||||
test_messages = [
|
||||
{"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."},
|
||||
{"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."}
|
||||
]
|
||||
|
||||
# Mock the Memory class methods to avoid actual API calls
|
||||
with patch.object(Memory, '__init__', return_value=None):
|
||||
with patch.object(Memory, 'from_config') as mock_from_config:
|
||||
with patch.object(Memory, 'add') as mock_add:
|
||||
with patch.object(Memory, 'get_all') as mock_get_all:
|
||||
|
||||
# Configure mocks
|
||||
mock_memory_instance = MagicMock()
|
||||
mock_from_config.return_value = mock_memory_instance
|
||||
|
||||
mock_add.return_value = {
|
||||
"results": [
|
||||
{"id": "1", "text": "Alex is a vegetarian"},
|
||||
{"id": "2", "text": "Alex is allergic to nuts"}
|
||||
]
|
||||
}
|
||||
|
||||
mock_get_all.return_value = [
|
||||
{"id": "1", "text": "Alex is a vegetarian", "metadata": {"category": "dietary_preferences"}},
|
||||
{"id": "2", "text": "Alex is allergic to nuts", "metadata": {"category": "allergies"}}
|
||||
]
|
||||
|
||||
# Test the workflow
|
||||
mem = Memory.from_config(config_dict=mock_config)
|
||||
assert mem is not None
|
||||
|
||||
# Test adding memories
|
||||
result = mock_add(test_messages, user_id="alice", metadata={"category": "book_recommendations"})
|
||||
assert "results" in result
|
||||
assert len(result["results"]) == 2
|
||||
|
||||
# Test retrieving memories
|
||||
all_memories = mock_get_all(user_id="alice")
|
||||
assert len(all_memories) == 2
|
||||
assert any("vegetarian" in memory["text"] for memory in all_memories)
|
||||
assert any("allergic to nuts" in memory["text"] for memory in all_memories)
|
||||
|
||||
|
||||
def test_azure_config_structure():
|
||||
"""Test that Azure configuration structure is properly formatted"""
|
||||
|
||||
# Test Azure configuration structure (without actual credentials)
|
||||
azure_config = {
|
||||
"llm": {
|
||||
"provider": "azure_openai",
|
||||
"config": {
|
||||
"model": "gpt-4",
|
||||
"temperature": 0.1,
|
||||
"max_tokens": 1500,
|
||||
"azure_kwargs": {
|
||||
"azure_deployment": "test-deployment",
|
||||
"api_version": "2023-12-01-preview",
|
||||
"azure_endpoint": "https://test.openai.azure.com/",
|
||||
"api_key": "test-key",
|
||||
}
|
||||
}
|
||||
},
|
||||
"vector_store": {
|
||||
"provider": "azure_ai_search",
|
||||
"config": {
|
||||
"service_name": "test-service",
|
||||
"api_key": "test-key",
|
||||
"collection_name": "test-collection",
|
||||
"embedding_model_dims": 1536,
|
||||
}
|
||||
},
|
||||
"embedder": {
|
||||
"provider": "azure_openai",
|
||||
"config": {
|
||||
"model": "text-embedding-ada-002",
|
||||
"api_key": "test-key",
|
||||
"azure_kwargs": {
|
||||
"api_version": "2023-12-01-preview",
|
||||
"azure_deployment": "test-embedding-deployment",
|
||||
"azure_endpoint": "https://test.openai.azure.com/",
|
||||
"api_key": "test-key",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Validate configuration structure
|
||||
assert "llm" in azure_config
|
||||
assert "vector_store" in azure_config
|
||||
assert "embedder" in azure_config
|
||||
|
||||
# Validate Azure-specific configurations
|
||||
assert azure_config["llm"]["provider"] == "azure_openai"
|
||||
assert "azure_kwargs" in azure_config["llm"]["config"]
|
||||
assert "azure_deployment" in azure_config["llm"]["config"]["azure_kwargs"]
|
||||
|
||||
assert azure_config["vector_store"]["provider"] == "azure_ai_search"
|
||||
assert "service_name" in azure_config["vector_store"]["config"]
|
||||
|
||||
assert azure_config["embedder"]["provider"] == "azure_openai"
|
||||
assert "azure_kwargs" in azure_config["embedder"]["config"]
|
||||
|
||||
|
||||
def test_memory_messages_format():
|
||||
"""Test that memory messages are properly formatted"""
|
||||
|
||||
# Test message format from main.py
|
||||
messages = [
|
||||
{"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."},
|
||||
{"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."}
|
||||
]
|
||||
|
||||
# Validate message structure
|
||||
assert len(messages) == 2
|
||||
assert all("role" in msg for msg in messages)
|
||||
assert all("content" in msg for msg in messages)
|
||||
|
||||
# Validate roles
|
||||
assert messages[0]["role"] == "user"
|
||||
assert messages[1]["role"] == "assistant"
|
||||
|
||||
# Validate content
|
||||
assert "vegetarian" in messages[0]["content"].lower()
|
||||
assert "allergic to nuts" in messages[0]["content"].lower()
|
||||
assert "vegetarian" in messages[1]["content"].lower()
|
||||
assert "nut allergy" in messages[1]["content"].lower()
|
||||
|
||||
|
||||
def test_safe_update_prompt_constant():
|
||||
"""Test the SAFE_UPDATE_PROMPT constant from main.py"""
|
||||
|
||||
SAFE_UPDATE_PROMPT = """
|
||||
Based on the user's latest messages, what new preference can be inferred?
|
||||
Reply only in this json_object format:
|
||||
"""
|
||||
|
||||
# Validate prompt structure
|
||||
assert isinstance(SAFE_UPDATE_PROMPT, str)
|
||||
assert "user's latest messages" in SAFE_UPDATE_PROMPT
|
||||
assert "json_object format" in SAFE_UPDATE_PROMPT
|
||||
assert len(SAFE_UPDATE_PROMPT.strip()) > 0
|
||||
@@ -1,11 +1,16 @@
|
||||
from unittest.mock import Mock, patch, PropertyMock
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from pymochow.exception import ServerError
|
||||
from pymochow.model.enum import ServerErrCode, TableState
|
||||
from pymochow.model.table import (
|
||||
FloatVector,
|
||||
Table,
|
||||
VectorSearchConfig,
|
||||
VectorTopkSearchRequest,
|
||||
)
|
||||
|
||||
from mem0.vector_stores.baidu import BaiduDB
|
||||
from pymochow.model.enum import TableState, ServerErrCode
|
||||
from pymochow.model.table import VectorSearchConfig, VectorTopkSearchRequest, FloatVector, Table
|
||||
from pymochow.exception import ServerError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mem0.vector_stores.mongodb import MongoDB
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user