Feature (OpenMemory): Add support for LLM and Embedding Providers in OpenMemory (#2794)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from .memories import router as memories_router
|
||||
from .apps import router as apps_router
|
||||
from .stats import router as stats_router
|
||||
from .config import router as config_router
|
||||
|
||||
__all__ = ["memories_router", "apps_router", "stats_router"]
|
||||
__all__ = ["memories_router", "apps_router", "stats_router", "config_router"]
|
||||
240
openmemory/api/app/routers/config.py
Normal file
240
openmemory/api/app/routers/config.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, Any, Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from app.database import get_db
|
||||
from app.models import Config as ConfigModel
|
||||
from app.utils.memory import reset_memory_client
|
||||
|
||||
# Router for the configuration endpoints; every route declared on it is
# served under the /api/v1/config prefix and tagged "config".
router = APIRouter(prefix="/api/v1/config", tags=["config"])
|
||||
|
||||
class LLMConfig(BaseModel):
    # Settings for the selected LLM. `model`, `temperature` and `max_tokens`
    # are required; `api_key` and `ollama_base_url` are optional.
    model: str = Field(default=..., description="LLM model name")
    temperature: float = Field(
        default=...,
        description="Temperature setting for the model",
    )
    max_tokens: int = Field(default=..., description="Maximum tokens to generate")
    # Either a literal key or an 'env:NAME' reference, per the description.
    api_key: Optional[str] = Field(
        default=None,
        description="API key or 'env:API_KEY' to use environment variable",
    )
    ollama_base_url: Optional[str] = Field(
        default=None,
        description="Base URL for Ollama server (e.g., http://host.docker.internal:11434)",
    )
|
||||
|
||||
class LLMProvider(BaseModel):
    # Pairs an LLM provider name with the settings for that provider.
    provider: str = Field(default=..., description="LLM provider name")
    config: LLMConfig
|
||||
|
||||
class EmbedderConfig(BaseModel):
    # Settings for the selected embedder. Only `model` is required.
    model: str = Field(default=..., description="Embedder model name")
    # Either a literal key or an 'env:NAME' reference, per the description.
    api_key: Optional[str] = Field(
        default=None,
        description="API key or 'env:API_KEY' to use environment variable",
    )
    ollama_base_url: Optional[str] = Field(
        default=None,
        description="Base URL for Ollama server (e.g., http://host.docker.internal:11434)",
    )
|
||||
|
||||
class EmbedderProvider(BaseModel):
    # Pairs an embedder provider name with the settings for that provider.
    provider: str = Field(default=..., description="Embedder provider name")
    config: EmbedderConfig
|
||||
|
||||
class OpenMemoryConfig(BaseModel):
    # OpenMemory-level settings; currently just the optional free-text
    # instructions described below.
    custom_instructions: Optional[str] = Field(
        default=None,
        description="Custom instructions for memory management and fact extraction",
    )
|
||||
|
||||
class Mem0Config(BaseModel):
    # The mem0 section of the configuration. Both providers are optional
    # and default to None when omitted from a request body.
    llm: Optional[LLMProvider] = None
    embedder: Optional[EmbedderProvider] = None
|
||||
|
||||
class ConfigSchema(BaseModel):
    # Top-level configuration document: an optional `openmemory` section
    # and a required `mem0` section.
    openmemory: Optional[OpenMemoryConfig] = None
    mem0: Mem0Config
|
||||
|
||||
def get_default_configuration():
    """Build and return a fresh default configuration dictionary.

    The defaults select the "openai" provider for both the LLM and the
    embedder, reference the API key as "env:OPENAI_API_KEY", and leave
    the OpenMemory custom instructions unset.
    """
    openai_key_ref = "env:OPENAI_API_KEY"

    llm_defaults = {
        "provider": "openai",
        "config": {
            "model": "gpt-4o-mini",
            "temperature": 0.1,
            "max_tokens": 2000,
            "api_key": openai_key_ref,
        },
    }
    embedder_defaults = {
        "provider": "openai",
        "config": {
            "model": "text-embedding-3-small",
            "api_key": openai_key_ref,
        },
    }

    return {
        "openmemory": {"custom_instructions": None},
        "mem0": {"llm": llm_defaults, "embedder": embedder_defaults},
    }
|
||||
|
||||
def get_config_from_db(db: Session, key: str = "main"):
    """Load the configuration stored under ``key``, filling in defaults.

    If no row exists for ``key``, a row holding the default configuration
    is created, committed, and the defaults are returned. Otherwise any
    missing ``openmemory`` / ``mem0`` section, or a missing or ``None``
    ``llm`` / ``embedder`` entry, is filled from the defaults, and the
    merged result is written back when it differs from what was stored.

    Args:
        db: Database session used for the query and any write-back.
        key: Key of the configuration row to load.

    Returns:
        The configuration dictionary. For an existing row this is a deep
        copy, so callers may mutate it without touching the ORM-held value.
    """
    from copy import deepcopy

    config = db.query(ConfigModel).filter(ConfigModel.key == key).first()
    default_config = get_default_configuration()

    if not config:
        # Create default config with proper provider configurations
        db_config = ConfigModel(key=key, value=default_config)
        db.add(db_config)
        db.commit()
        db.refresh(db_config)
        return default_config

    stored_value = config.value
    # Work on a copy: merging into the stored object itself would make the
    # "was it modified?" comparison below compare the object with itself,
    # so the merged defaults would never be written back.
    config_value = deepcopy(stored_value) if isinstance(stored_value, dict) else {}

    # Merge with defaults to ensure all required sections exist
    if "openmemory" not in config_value:
        config_value["openmemory"] = default_config["openmemory"]

    mem0_section = config_value.get("mem0")
    if not isinstance(mem0_section, dict):
        # Covers both a missing "mem0" key and a stored null section.
        config_value["mem0"] = default_config["mem0"]
    else:
        for section in ("llm", "embedder"):
            if mem0_section.get(section) is None:
                mem0_section[section] = default_config["mem0"][section]

    # Save the merged config back to the database only if it changed
    if config_value != stored_value:
        config.value = config_value
        db.commit()
        db.refresh(config)

    return config_value
|
||||
|
||||
def save_config_to_db(db: Session, config: Dict[str, Any], key: str = "main"):
    """Insert or update the configuration row stored under ``key``.

    Args:
        db: Database session used for the query and the commit.
        config: Configuration dictionary to store; it is deep-copied so
            the stored value never aliases the caller's object.
        key: Key of the configuration row to write.

    Returns:
        The row's ``value`` as reloaded after the commit.
    """
    from copy import deepcopy

    from sqlalchemy.orm.attributes import flag_modified

    stored_value = deepcopy(config)
    db_config = db.query(ConfigModel).filter(ConfigModel.key == key).first()

    if db_config:
        db_config.value = stored_value
        # Force the column into the UPDATE: if the caller mutated the
        # already-loaded JSON object in place, old and new values compare
        # equal and the change would otherwise be skipped.
        flag_modified(db_config, "value")
        # updated_at is deliberately not assigned: an explicit None would be
        # written as NULL and would suppress the column's onupdate default.
        # NOTE(review): assumes ConfigModel.updated_at declares onupdate, as
        # the original comment on this block stated - confirm in app.models.
    else:
        db_config = ConfigModel(key=key, value=stored_value)
        db.add(db_config)

    db.commit()
    db.refresh(db_config)
    return db_config.value
|
||||
|
||||
@router.get("/", response_model=ConfigSchema)
async def get_configuration(db: Session = Depends(get_db)):
    """Return the full current configuration as loaded by get_config_from_db."""
    return get_config_from_db(db)
|
||||
|
||||
@router.put("/", response_model=ConfigSchema)
async def update_configuration(config: ConfigSchema, db: Session = Depends(get_db)):
    """Update the stored configuration from the request body.

    The ``openmemory`` section, when provided, is merged key-by-key into
    the stored one (fields left as ``None`` are ignored). The ``mem0``
    section is replaced wholesale by the request's ``mem0`` with ``None``
    fields dropped. The result is saved and the memory client is reset.

    Returns:
        The updated configuration dictionary.
    """
    from copy import deepcopy

    # Deep copy: a shallow copy would share the nested "openmemory" dict
    # with the loaded configuration, so the merge below would silently
    # modify the loaded value as well.
    updated_config = deepcopy(get_config_from_db(db))

    # Update openmemory settings if provided
    if config.openmemory is not None:
        openmemory_section = updated_config.get("openmemory") or {}
        openmemory_section.update(config.openmemory.dict(exclude_none=True))
        updated_config["openmemory"] = openmemory_section

    # Update mem0 settings
    updated_config["mem0"] = config.mem0.dict(exclude_none=True)

    # Save the configuration to database, then rebuild the memory client
    save_config_to_db(db, updated_config)
    reset_memory_client()
    return updated_config
|
||||
|
||||
@router.post("/reset", response_model=ConfigSchema)
async def reset_configuration(db: Session = Depends(get_db)):
    """Reset the stored configuration to the default values.

    Saves the result of get_default_configuration(), resets the memory
    client, and returns the defaults.

    Raises:
        HTTPException: With status 500 if saving or resetting fails; the
            original exception is attached as the cause.
    """
    try:
        # Get the default configuration with proper provider setups
        default_config = get_default_configuration()

        # Save it as the current configuration in the database
        save_config_to_db(db, default_config)
        reset_memory_client()
        return default_config
    except Exception as e:
        # Chain the original error so its traceback is preserved.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to reset configuration: {str(e)}"
        ) from e
|
||||
|
||||
@router.get("/mem0/llm", response_model=LLMProvider)
async def get_llm_configuration(db: Session = Depends(get_db)):
    """Return only the ``mem0.llm`` part of the current configuration."""
    mem0_section = get_config_from_db(db).get("mem0", {})
    return mem0_section.get("llm", {})
|
||||
|
||||
@router.put("/mem0/llm", response_model=LLMProvider)
async def update_llm_configuration(llm_config: LLMProvider, db: Session = Depends(get_db)):
    """Replace only the ``mem0.llm`` part of the stored configuration.

    The request body (with ``None`` fields dropped) becomes the new LLM
    section; the result is saved and the memory client is reset.

    Returns:
        The new LLM section as stored.
    """
    from copy import deepcopy

    # Edit a private copy so the dictionary handed out by
    # get_config_from_db is never modified in place.
    current_config = deepcopy(get_config_from_db(db))

    # Ensure a usable mem0 section exists (missing or null)
    if not isinstance(current_config.get("mem0"), dict):
        current_config["mem0"] = {}

    # Update the LLM configuration
    current_config["mem0"]["llm"] = llm_config.dict(exclude_none=True)

    # Save the configuration to database, then rebuild the memory client
    save_config_to_db(db, current_config)
    reset_memory_client()
    return current_config["mem0"]["llm"]
|
||||
|
||||
@router.get("/mem0/embedder", response_model=EmbedderProvider)
async def get_embedder_configuration(db: Session = Depends(get_db)):
    """Return only the ``mem0.embedder`` part of the current configuration."""
    mem0_section = get_config_from_db(db).get("mem0", {})
    return mem0_section.get("embedder", {})
|
||||
|
||||
@router.put("/mem0/embedder", response_model=EmbedderProvider)
async def update_embedder_configuration(embedder_config: EmbedderProvider, db: Session = Depends(get_db)):
    """Replace only the ``mem0.embedder`` part of the stored configuration.

    The request body (with ``None`` fields dropped) becomes the new
    embedder section; the result is saved and the memory client is reset.

    Returns:
        The new embedder section as stored.
    """
    from copy import deepcopy

    # Edit a private copy so the dictionary handed out by
    # get_config_from_db is never modified in place.
    current_config = deepcopy(get_config_from_db(db))

    # Ensure a usable mem0 section exists (missing or null)
    if not isinstance(current_config.get("mem0"), dict):
        current_config["mem0"] = {}

    # Update the Embedder configuration
    current_config["mem0"]["embedder"] = embedder_config.dict(exclude_none=True)

    # Save the configuration to database, then rebuild the memory client
    save_config_to_db(db, current_config)
    reset_memory_client()
    return current_config["mem0"]["embedder"]
|
||||
|
||||
@router.get("/openmemory", response_model=OpenMemoryConfig)
async def get_openmemory_configuration(db: Session = Depends(get_db)):
    """Return only the ``openmemory`` part of the current configuration."""
    full_config = get_config_from_db(db)
    return full_config.get("openmemory", {})
|
||||
|
||||
@router.put("/openmemory", response_model=OpenMemoryConfig)
async def update_openmemory_configuration(openmemory_config: OpenMemoryConfig, db: Session = Depends(get_db)):
    """Merge the request body into the stored ``openmemory`` section.

    Fields left as ``None`` in the request are ignored, so existing values
    for them are kept. The result is saved and the memory client is reset.

    Returns:
        The merged ``openmemory`` section as stored.
    """
    from copy import deepcopy

    # Edit a private copy so the dictionary handed out by
    # get_config_from_db is never modified in place.
    current_config = deepcopy(get_config_from_db(db))

    # Ensure a usable openmemory section exists (missing or null)
    if not isinstance(current_config.get("openmemory"), dict):
        current_config["openmemory"] = {}

    # Update the OpenMemory configuration
    current_config["openmemory"].update(openmemory_config.dict(exclude_none=True))

    # Save the configuration to database, then rebuild the memory client
    save_config_to_db(db, current_config)
    reset_memory_client()
    return current_config["openmemory"]
|
||||
@@ -2,6 +2,7 @@ from datetime import datetime, UTC
|
||||
from typing import List, Optional, Set
|
||||
from uuid import UUID, uuid4
|
||||
import logging
|
||||
import os
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from fastapi_pagination import Page, Params
|
||||
@@ -13,13 +14,11 @@ from app.utils.memory import get_memory_client
|
||||
from app.database import get_db
|
||||
from app.models import (
|
||||
Memory, MemoryState, MemoryAccessLog, App,
|
||||
MemoryStatusHistory, User, Category, AccessControl
|
||||
MemoryStatusHistory, User, Category, AccessControl, Config as ConfigModel
|
||||
)
|
||||
from app.schemas import MemoryResponse, PaginatedMemoryResponse
|
||||
from app.utils.permissions import check_memory_access_permissions
|
||||
|
||||
memory_client = get_memory_client()
|
||||
|
||||
router = APIRouter(prefix="/api/v1/memories", tags=["memories"])
|
||||
|
||||
|
||||
@@ -227,100 +226,79 @@ async def create_memory(
|
||||
# Log what we're about to do
|
||||
logging.info(f"Creating memory for user_id: {request.user_id} with app: {request.app}")
|
||||
|
||||
# Save to Qdrant via memory_client
|
||||
qdrant_response = memory_client.add(
|
||||
request.text,
|
||||
user_id=request.user_id, # Use string user_id to match search
|
||||
metadata={
|
||||
"source_app": "openmemory",
|
||||
"mcp_client": request.app,
|
||||
}
|
||||
)
|
||||
|
||||
# Log the response for debugging
|
||||
logging.info(f"Qdrant response: {qdrant_response}")
|
||||
|
||||
# Process Qdrant response
|
||||
if isinstance(qdrant_response, dict) and 'results' in qdrant_response:
|
||||
for result in qdrant_response['results']:
|
||||
if result['event'] == 'ADD':
|
||||
# Get the Qdrant-generated ID
|
||||
memory_id = UUID(result['id'])
|
||||
|
||||
# Check if memory already exists
|
||||
existing_memory = db.query(Memory).filter(Memory.id == memory_id).first()
|
||||
|
||||
if existing_memory:
|
||||
# Update existing memory
|
||||
existing_memory.state = MemoryState.active
|
||||
existing_memory.content = result['memory']
|
||||
memory = existing_memory
|
||||
else:
|
||||
# Create memory with the EXACT SAME ID from Qdrant
|
||||
memory = Memory(
|
||||
id=memory_id, # Use the same ID that Qdrant generated
|
||||
user_id=user.id,
|
||||
app_id=app_obj.id,
|
||||
content=result['memory'],
|
||||
metadata_=request.metadata,
|
||||
state=MemoryState.active
|
||||
)
|
||||
db.add(memory)
|
||||
|
||||
# Create history entry
|
||||
history = MemoryStatusHistory(
|
||||
memory_id=memory_id,
|
||||
changed_by=user.id,
|
||||
old_state=MemoryState.deleted if existing_memory else MemoryState.deleted,
|
||||
new_state=MemoryState.active
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
db.commit()
|
||||
db.refresh(memory)
|
||||
return memory
|
||||
|
||||
# Fallback to traditional DB-only approach if Qdrant integration fails
|
||||
# Generate a random UUID for the memory
|
||||
memory_id = uuid4()
|
||||
memory = Memory(
|
||||
id=memory_id,
|
||||
user_id=user.id,
|
||||
app_id=app_obj.id,
|
||||
content=request.text,
|
||||
metadata_=request.metadata
|
||||
)
|
||||
db.add(memory)
|
||||
|
||||
# Create history entry
|
||||
history = MemoryStatusHistory(
|
||||
memory_id=memory_id,
|
||||
changed_by=user.id,
|
||||
old_state=MemoryState.deleted,
|
||||
new_state=MemoryState.active
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
db.commit()
|
||||
db.refresh(memory)
|
||||
|
||||
# Attempt to add to Qdrant with the same ID we just created
|
||||
# Try to get memory client safely
|
||||
try:
|
||||
# Try to add with our specific ID
|
||||
memory_client.add(
|
||||
memory_client = get_memory_client()
|
||||
if not memory_client:
|
||||
raise Exception("Memory client is not available")
|
||||
except Exception as client_error:
|
||||
logging.warning(f"Memory client unavailable: {client_error}. Creating memory in database only.")
|
||||
# Return a json response with the error
|
||||
return {
|
||||
"error": str(client_error)
|
||||
}
|
||||
|
||||
# Try to save to Qdrant via memory_client
|
||||
try:
|
||||
qdrant_response = memory_client.add(
|
||||
request.text,
|
||||
memory_id=str(memory_id), # Specify the ID
|
||||
user_id=request.user_id,
|
||||
user_id=request.user_id, # Use string user_id to match search
|
||||
metadata={
|
||||
"source_app": "openmemory",
|
||||
"mcp_client": request.app,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to add to Qdrant in fallback path: {e}")
|
||||
# Continue anyway, the DB record is created
|
||||
|
||||
return memory
|
||||
|
||||
# Log the response for debugging
|
||||
logging.info(f"Qdrant response: {qdrant_response}")
|
||||
|
||||
# Process Qdrant response
|
||||
if isinstance(qdrant_response, dict) and 'results' in qdrant_response:
|
||||
for result in qdrant_response['results']:
|
||||
if result['event'] == 'ADD':
|
||||
# Get the Qdrant-generated ID
|
||||
memory_id = UUID(result['id'])
|
||||
|
||||
# Check if memory already exists
|
||||
existing_memory = db.query(Memory).filter(Memory.id == memory_id).first()
|
||||
|
||||
if existing_memory:
|
||||
# Update existing memory
|
||||
existing_memory.state = MemoryState.active
|
||||
existing_memory.content = result['memory']
|
||||
memory = existing_memory
|
||||
else:
|
||||
# Create memory with the EXACT SAME ID from Qdrant
|
||||
memory = Memory(
|
||||
id=memory_id, # Use the same ID that Qdrant generated
|
||||
user_id=user.id,
|
||||
app_id=app_obj.id,
|
||||
content=result['memory'],
|
||||
metadata_=request.metadata,
|
||||
state=MemoryState.active
|
||||
)
|
||||
db.add(memory)
|
||||
|
||||
# Create history entry
|
||||
history = MemoryStatusHistory(
|
||||
memory_id=memory_id,
|
||||
changed_by=user.id,
|
||||
old_state=MemoryState.deleted if existing_memory else MemoryState.deleted,
|
||||
new_state=MemoryState.active
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
db.commit()
|
||||
db.refresh(memory)
|
||||
return memory
|
||||
except Exception as qdrant_error:
|
||||
logging.warning(f"Qdrant operation failed: {qdrant_error}.")
|
||||
# Return a json response with the error
|
||||
return {
|
||||
"error": str(qdrant_error)
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
# Get memory by ID
|
||||
|
||||
Reference in New Issue
Block a user