Add OpenAI proxy (#1503)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>

Author: Dev Khant
Date: 2024-08-02 20:14:27 +05:30
Committed by: GitHub
Parent: 51092b0b64
Commit: 419dc6598c

18 changed files with 637 additions and 135 deletions

@@ -4,3 +4,4 @@ __version__ = importlib.metadata.version("mem0ai")
 from mem0.memory.main import Memory # noqa
 from mem0.client.main import MemoryClient # noqa
+from mem0.proxy.main import Mem0 # noqa
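With this export the proxy becomes importable from the package root (from mem0 import Mem0), which is what the usage sketch further down assumes.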

@@ -0,0 +1,39 @@
import os
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from mem0.memory.setup import mem0_dir
from mem0.vector_stores.configs import VectorStoreConfig
from mem0.llms.configs import LlmConfig
from mem0.embeddings.configs import EmbedderConfig


class MemoryItem(BaseModel):
    id: str = Field(..., description="The unique identifier for the text data")
    memory: str = Field(..., description="The memory deduced from the text data")  # TODO After prompt changes from platform, update this
    hash: Optional[str] = Field(None, description="The hash of the memory")
    # The metadata value can be anything and not just string. Fix it
    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the text data")
    score: Optional[float] = Field(
        None, description="The score associated with the text data"
    )
    created_at: Optional[str] = Field(None, description="The timestamp when the memory was created")
    updated_at: Optional[str] = Field(None, description="The timestamp when the memory was updated")


class MemoryConfig(BaseModel):
    vector_store: VectorStoreConfig = Field(
        description="Configuration for the vector store",
        default_factory=VectorStoreConfig,
    )
    llm: LlmConfig = Field(
        description="Configuration for the language model",
        default_factory=LlmConfig,
    )
    embedder: EmbedderConfig = Field(
        description="Configuration for the embedding model",
        default_factory=EmbedderConfig,
    )
    history_db_path: str = Field(
        description="Path to the history database",
        default=os.path.join(mem0_dir, "history.db"),
    )
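For orientation (not part of the diff), a small sketch of how these defaults compose; the override value is illustrative:

from mem0.configs.base import MemoryConfig

config = MemoryConfig()  # every field is filled by its default_factory / default
print(config.history_db_path)  # a file named history.db under mem0_dir
config.history_db_path = "/tmp/mem0/history.db"  # illustrative override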

@@ -0,0 +1,32 @@
from abc import ABC
from typing import Optional


class BaseEmbedderConfig(ABC):
    """
    Config for Embeddings.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        embedding_dims: Optional[int] = None,
        # Ollama specific
        base_url: Optional[str] = None,
    ):
        """
        Initializes a configuration class instance for the Embeddings.

        :param model: Embedding model to use, defaults to None
        :type model: Optional[str], optional
        :param embedding_dims: The number of dimensions in the embedding, defaults to None
        :type embedding_dims: Optional[int], optional
        :param base_url: Base URL for the Ollama API, defaults to None
        :type base_url: Optional[str], optional
        """
        self.model = model
        self.embedding_dims = embedding_dims
        # Ollama specific
        self.base_url = base_url
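Since the class declares no abstract methods, Python allows instantiating it directly; a sketch with Ollama-flavored values (all illustrative):

config = BaseEmbedderConfig(
    model="nomic-embed-text",           # illustrative model name
    embedding_dims=768,                 # illustrative dimensionality
    base_url="http://localhost:11434",  # the usual local Ollama endpoint
)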

@@ -29,3 +29,14 @@ Constraint for deducing facts, preferences, and memories:
 Deduced facts, preferences, and memories:
 """
+
+MEMORY_ANSWER_PROMPT = """
+You are an expert at answering questions based on the provided memories. Your task is to provide accurate and concise answers to the questions by leveraging the information given in the memories.
+
+Guidelines:
+- Extract relevant information from the memories based on the question.
+- If no relevant information is found, make sure you don't say no information is found. Instead, accept the question and provide a general response.
+- Ensure that the answers are clear, concise, and directly address the question.
+
+Here are the details of the task:
+"""

@@ -15,51 +15,17 @@ from mem0.llms.utils.tools import (
 )
 from mem0.configs.prompts import MEMORY_DEDUCTION_PROMPT
 from mem0.memory.base import MemoryBase
-from mem0.memory.setup import mem0_dir, setup_config
+from mem0.memory.setup import setup_config
 from mem0.memory.storage import SQLiteManager
 from mem0.memory.telemetry import capture_event
 from mem0.memory.utils import get_update_memory_messages
-from mem0.vector_stores.configs import VectorStoreConfig
-from mem0.llms.configs import LlmConfig
-from mem0.embeddings.configs import EmbedderConfig
 from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
+from mem0.configs.base import MemoryItem, MemoryConfig

 # Setup user config
 setup_config()

-class MemoryItem(BaseModel):
-    id: str = Field(..., description="The unique identifier for the text data")
-    memory: str = Field(..., description="The memory deduced from the text data")  # TODO After prompt changes from platform, update this
-    hash: Optional[str] = Field(None, description="The hash of the memory")
-    # The metadata value can be anything and not just string. Fix it
-    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the text data")
-    score: Optional[float] = Field(
-        None, description="The score associated with the text data"
-    )
-    created_at: Optional[str] = Field(None, description="The timestamp when the memory was created")
-    updated_at: Optional[str] = Field(None, description="The timestamp when the memory was updated")
-
-class MemoryConfig(BaseModel):
-    vector_store: VectorStoreConfig = Field(
-        description="Configuration for the vector store",
-        default_factory=VectorStoreConfig,
-    )
-    llm: LlmConfig = Field(
-        description="Configuration for the language model",
-        default_factory=LlmConfig,
-    )
-    embedder: EmbedderConfig = Field(
-        description="Configuration for the embedding model",
-        default_factory=EmbedderConfig,
-    )
-    history_db_path: str = Field(
-        description="Path to the history database",
-        default=os.path.join(mem0_dir, "history.db"),
-    )
-
 class Memory(MemoryBase):
     def __init__(self, config: MemoryConfig = MemoryConfig()):
         self.config = config

@@ -17,33 +17,52 @@ class SQLiteManager:
         table_exists = cursor.fetchone() is not None

         if table_exists:
-            # Rename the old table
-            cursor.execute("ALTER TABLE history RENAME TO old_history")
-
-            cursor.execute("""
-                CREATE TABLE IF NOT EXISTS history (
-                    id TEXT PRIMARY KEY,
-                    memory_id TEXT,
-                    old_memory TEXT,
-                    new_memory TEXT,
-                    new_value TEXT,
-                    event TEXT,
-                    created_at DATETIME,
-                    updated_at DATETIME,
-                    is_deleted INTEGER
-                )
-            """)
-
-            # Copy data from the old table to the new table
-            cursor.execute("""
-                INSERT INTO history (id, memory_id, old_memory, new_memory, new_value, event, created_at, updated_at, is_deleted)
-                SELECT id, memory_id, prev_value, new_value, new_value, event, timestamp, timestamp, is_deleted
-                FROM old_history
-            """)
-
-            cursor.execute("DROP TABLE old_history")
-
-        self.connection.commit()
+            # Get the current schema of the history table
+            cursor.execute("PRAGMA table_info(history)")
+            current_schema = {row[1]: row[2] for row in cursor.fetchall()}
+
+            # Define the expected schema
+            expected_schema = {
+                'id': 'TEXT',
+                'memory_id': 'TEXT',
+                'old_memory': 'TEXT',
+                'new_memory': 'TEXT',
+                'new_value': 'TEXT',
+                'event': 'TEXT',
+                'created_at': 'DATETIME',
+                'updated_at': 'DATETIME',
+                'is_deleted': 'INTEGER'
+            }
+
+            # Check if the schemas are the same
+            if current_schema != expected_schema:
+                # Rename the old table
+                cursor.execute("ALTER TABLE history RENAME TO old_history")
+
+                cursor.execute("""
+                    CREATE TABLE IF NOT EXISTS history (
+                        id TEXT PRIMARY KEY,
+                        memory_id TEXT,
+                        old_memory TEXT,
+                        new_memory TEXT,
+                        new_value TEXT,
+                        event TEXT,
+                        created_at DATETIME,
+                        updated_at DATETIME,
+                        is_deleted INTEGER
+                    )
+                """)
+
+                # Copy data from the old table to the new table
+                cursor.execute("""
+                    INSERT INTO history (id, memory_id, old_memory, new_memory, new_value, event, created_at, updated_at, is_deleted)
+                    SELECT id, memory_id, prev_value, new_value, new_value, event, timestamp, timestamp, is_deleted
+                    FROM old_history
+                """)
+
+                cursor.execute("DROP TABLE old_history")
+                self.connection.commit()

     def _create_history_table(self):
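For reference, a minimal standalone sketch of the schema check the migration now performs; the table and columns here are illustrative, not mem0's actual schema:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE history (id TEXT PRIMARY KEY, event TEXT)")

# PRAGMA table_info yields one row per column: (cid, name, type, notnull, dflt_value, pk)
rows = conn.execute("PRAGMA table_info(history)").fetchall()
schema = {row[1]: row[2] for row in rows}
print(schema)  # {'id': 'TEXT', 'event': 'TEXT'}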

mem0/proxy/__init__.py (new file, 0 lines added)

mem0/proxy/main.py (new file, 155 lines added)

@@ -0,0 +1,155 @@
import httpx
from typing import Optional, List, Union
import threading

import litellm

from mem0 import Memory, MemoryClient
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT


class Mem0:
    def __init__(
        self,
        config: Optional[dict] = None,
        api_key: Optional[str] = None,
        host: Optional[str] = None,
    ):
        if api_key:
            self.mem0_client = MemoryClient(api_key, host)
        else:
            self.mem0_client = Memory.from_config(config) if config else Memory()
        self.chat = Chat(self.mem0_client)


class Chat:
    def __init__(self, mem0_client):
        self.completions = Completions(mem0_client)


class Completions:
    def __init__(self, mem0_client):
        self.mem0_client = mem0_client

    def create(
        self,
        model: str,
        messages: List = [],
        # Mem0 arguments
        user_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        run_id: Optional[str] = None,
        metadata: Optional[dict] = None,
        filters: Optional[dict] = None,
        limit: Optional[int] = 10,
        # LLM arguments
        timeout: Optional[Union[float, str, httpx.Timeout]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[dict] = None,
        stop=None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[dict] = None,
        user: Optional[str] = None,
        # openai v1.0+ new params
        response_format: Optional[dict] = None,
        seed: Optional[int] = None,
        tools: Optional[List] = None,
        tool_choice: Optional[Union[str, dict]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        parallel_tool_calls: Optional[bool] = None,
        deployment_id=None,
        extra_headers: Optional[dict] = None,
        # soon to be deprecated params by OpenAI
        functions: Optional[List] = None,
        function_call: Optional[str] = None,
        # set api_base, api_version, api_key
        base_url: Optional[str] = None,
        api_version: Optional[str] = None,
        api_key: Optional[str] = None,
        model_list: Optional[list] = None,  # pass in a list of api_base, keys, etc.
    ):
        if not any([user_id, agent_id, run_id]):
            raise ValueError("One of user_id, agent_id, run_id must be provided")

        if not litellm.supports_function_calling(model):
            raise ValueError(
                f"Model '{model}' does not support function calling. Please use a model that supports function calling."
            )

        prepared_messages = self._prepare_messages(messages)
        if prepared_messages[-1]["role"] == "user":
            self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters)
            relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit)
            prepared_messages[-1]["content"] = self._format_query_with_memories(messages, relevant_memories)

        response = litellm.completion(
            model=model,
            messages=prepared_messages,
            temperature=temperature,
            top_p=top_p,
            n=n,
            timeout=timeout,
            stream=stream,
            stream_options=stream_options,
            stop=stop,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            user=user,
            response_format=response_format,
            seed=seed,
            tools=tools,
            tool_choice=tool_choice,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            parallel_tool_calls=parallel_tool_calls,
            deployment_id=deployment_id,
            extra_headers=extra_headers,
            functions=functions,
            function_call=function_call,
            base_url=base_url,
            api_version=api_version,
            api_key=api_key,
            model_list=model_list,
        )
        return response

    def _prepare_messages(self, messages: List[dict]) -> List[dict]:
        if not messages or messages[0]["role"] != "system":
            return [{"role": "system", "content": MEMORY_ANSWER_PROMPT}] + messages
        messages[0]["content"] = MEMORY_ANSWER_PROMPT
        return messages

    def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters):
        def add_task():
            self.mem0_client.add(
                messages=messages,
                user_id=user_id,
                agent_id=agent_id,
                run_id=run_id,
                metadata=metadata,
                filters=filters,
            )

        threading.Thread(target=add_task, daemon=True).start()

    def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):
        # Currently, only pass the last 6 messages to the search API to prevent long query
        message_input = [f"{message['role']}: {message['content']}" for message in messages][-6:]
        # TODO: Make it better by summarizing the past conversation
        return self.mem0_client.search(
            query="\n".join(message_input),
            user_id=user_id,
            agent_id=agent_id,
            run_id=run_id,
            filters=filters,
            limit=limit,
        )

    def _format_query_with_memories(self, messages, relevant_memories):
        memories_text = "\n".join(memory["memory"] for memory in relevant_memories)
        return f"- Relevant Memories/Facts: {memories_text}\n\n- User Question: {messages[-1]['content']}"

@@ -55,6 +55,8 @@ class VectorStoreFactory:
     def create(cls, provider_name, config):
         class_type = cls.provider_to_class.get(provider_name)
         if class_type:
+            if not isinstance(config, dict):
+                config = config.model_dump()
             vector_store_instance = load_class(class_type)
             return vector_store_instance(**config)
         else:
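In effect, create() now accepts a pydantic config as well as a plain dict; a sketch, assuming the QdrantConfig from this PR and a local qdrant-client install. Note that the config model supplies defaults for every field the now default-less store constructors require, so passing a complete config object is the safe path:

from mem0.utils.factory import VectorStoreFactory
from mem0.vector_stores.configs import QdrantConfig

# The pydantic model is model_dump()-ed to a dict before being splatted into Qdrant(**config)
store = VectorStoreFactory.create("qdrant", QdrantConfig(path="/tmp/qdrant"))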

@@ -21,20 +21,20 @@ class OutputData(BaseModel):
 class ChromaDB(VectorStoreBase):
     def __init__(
         self,
-        collection_name="mem0",
-        client=None,
-        host=None,
-        port=None,
-        path=None
+        collection_name,
+        client,
+        host,
+        port,
+        path
     ):
         """
         Initialize the Qdrant vector store.

         Args:
-            client (QdrantClient, optional): Existing Qdrant client instance. Defaults to None.
-            host (str, optional): Host address for Qdrant server. Defaults to None.
-            port (int, optional): Port for Qdrant server. Defaults to None.
-            path (str, optional): Path for local Qdrant database. Defaults to None.
+            client (QdrantClient, optional): Existing Qdrant client instance.
+            host (str, optional): Host address for Qdrant server.
+            port (int, optional): Port for Qdrant server.
+            path (str, optional): Path for local Qdrant database.
         """
         if client:
             self.client = client
@@ -95,7 +95,7 @@ class ChromaDB(VectorStoreBase):
         Args:
             name (str): Name of the collection.
-            embedding_fn (function): Embedding function to use.
+            embedding_fn (function): Embedding function to use. Defaults to None.
         """
         # Skip creating collection if already exists
         collections = self.list_cols()
@@ -213,7 +213,7 @@ class ChromaDB(VectorStoreBase):
         Args:
             name (str): Name of the collection.
-            filters (dict, optional): Filters to apply to the list. Defaults to None.
+            filters (dict, optional): Filters to apply to the list.
             limit (int, optional): Number of vectors to return. Defaults to 100.

         Returns:

@@ -3,14 +3,23 @@ from typing import Optional
 from pydantic import BaseModel, Field, field_validator, model_validator

+def create_default_config(provider: str):
+    """Create a default configuration based on the provider."""
+    if provider == "qdrant":
+        return QdrantConfig(path="/tmp/qdrant")
+    elif provider == "chromadb":
+        return ChromaDbConfig(path="/tmp/chromadb")
+    else:
+        raise ValueError(f"Unsupported vector store provider: {provider}")

 class QdrantConfig(BaseModel):
-    collection_name: str = Field(default="mem0", description="Name of the collection")
-    embedding_model_dims: Optional[int] = Field(
-        default=1536, description="Dimensions of the embedding model"
-    )
+    collection_name: str = Field("mem0", description="Name of the collection")
+    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
     client: Optional[str] = Field(None, description="Existing Qdrant client instance")
     host: Optional[str] = Field(None, description="Host address for Qdrant server")
     port: Optional[int] = Field(None, description="Port for Qdrant server")
-    path: Optional[str] = Field(None, description="Path for local Qdrant database")
+    path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
     url: Optional[str] = Field(None, description="Full URL for Qdrant server")
     api_key: Optional[str] = Field(None, description="API key for Qdrant server")
@@ -31,18 +40,11 @@ class QdrantConfig(BaseModel):
 class ChromaDbConfig(BaseModel):
-    collection_name: str = Field(
-        default="mem0", description="Default name for the collection"
-    )
-    path: Optional[str] = Field(
-        default=None, description="Path to the database directory"
-    )
-    host: Optional[str] = Field(
-        default=None, description="Database connection remote host"
-    )
-    port: Optional[str] = Field(
-        default=None, description="Database connection remote port"
-    )
+    collection_name: str = Field("mem0", description="Default name for the collection")
+    client: Optional[str] = Field(None, description="Existing ChromaDB client instance")
+    path: Optional[str] = Field(None, description="Path to the database directory")
+    host: Optional[str] = Field(None, description="Database connection remote host")
+    port: Optional[str] = Field(None, description="Database connection remote port")

     @model_validator(mode="before")
     def check_host_port_or_path(cls, values):
@@ -59,15 +61,37 @@ class VectorStoreConfig(BaseModel):
     )
     config: Optional[dict] = Field(
         description="Configuration for the specific vector store",
-        default={},
+        default=None,
     )

     @field_validator("config")
     def validate_config(cls, v, values):
         provider = values.data.get("provider")
-        if provider == "qdrant":
-            return QdrantConfig(**v.model_dump())
-        elif provider == "chromadb":
-            return ChromaDbConfig(**v.model_dump())
-        else:
-            raise ValueError(f"Unsupported vector store provider: {provider}")
+        if v is None:
+            return create_default_config(provider)
+        if isinstance(v, dict):
+            if provider == "qdrant":
+                return QdrantConfig(**v)
+            elif provider == "chromadb":
+                return ChromaDbConfig(**v)
+        return v

+    @model_validator(mode="after")
+    def ensure_config_type(cls, values):
+        provider = values.provider
+        config = values.config
+        if config is None:
+            values.config = create_default_config(provider)
+        elif isinstance(config, dict):
+            if provider == "qdrant":
+                values.config = QdrantConfig(**config)
+            elif provider == "chromadb":
+                values.config = ChromaDbConfig(**config)
+        elif not isinstance(config, (QdrantConfig, ChromaDbConfig)):
+            raise ValueError(f"Invalid config type for provider {provider}")
+        return values
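Taken together, the two validators make config self-filling; a sketch of the resulting behavior, assuming the provider field defaults to "qdrant" (that default is not shown in this hunk):

from mem0.vector_stores.configs import VectorStoreConfig, ChromaDbConfig

cfg = VectorStoreConfig()  # config=None -> create_default_config(provider)
# with the assumed default provider, cfg.config is QdrantConfig(path="/tmp/qdrant")

cfg = VectorStoreConfig(provider="chromadb", config={"path": "/tmp/chromadb"})
assert isinstance(cfg.config, ChromaDbConfig)  # dict coerced by the validators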

@@ -20,25 +20,25 @@ from mem0.vector_stores.base import VectorStoreBase
 class Qdrant(VectorStoreBase):
     def __init__(
         self,
-        collection_name="mem0",
-        embedding_model_dims=1536,
-        client=None,
-        host="localhost",
-        port=6333,
-        path=None,
-        url=None,
-        api_key=None,
+        collection_name,
+        embedding_model_dims,
+        client,
+        host,
+        port,
+        path,
+        url,
+        api_key,
     ):
         """
         Initialize the Qdrant vector store.

         Args:
-            client (QdrantClient, optional): Existing Qdrant client instance. Defaults to None.
-            host (str, optional): Host address for Qdrant server. Defaults to "localhost".
-            port (int, optional): Port for Qdrant server. Defaults to 6333.
-            path (str, optional): Path for local Qdrant database. Defaults to None.
-            url (str, optional): Full URL for Qdrant server. Defaults to None.
-            api_key (str, optional): API key for Qdrant server. Defaults to None.
+            client (QdrantClient, optional): Existing Qdrant client instance.
+            host (str, optional): Host address for Qdrant server.
+            port (int, optional): Port for Qdrant server.
+            path (str, optional): Path for local Qdrant database.
+            url (str, optional): Full URL for Qdrant server.
+            api_key (str, optional): API key for Qdrant server.
         """
         if client:
             self.client = client