[Mem0] Update dependencies and make the package lighter (#1708)
Co-authored-by: Dev-Khant <devkhant24@gmail.com>
This commit is contained in:
@@ -1,18 +1,26 @@
|
||||
import httpx
|
||||
from typing import Optional, List, Union
|
||||
import threading
|
||||
import litellm
|
||||
|
||||
try:
|
||||
import litellm
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"litellm requires extra dependencies. Install with `pip install litellm`"
|
||||
) from None
|
||||
|
||||
from mem0.memory.telemetry import capture_client_event
|
||||
from mem0 import Memory, MemoryClient
|
||||
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
|
||||
|
||||
|
||||
class Mem0:
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[dict] = None,
|
||||
api_key: Optional[str] = None,
|
||||
host: Optional[str] = None
|
||||
):
|
||||
self,
|
||||
config: Optional[dict] = None,
|
||||
api_key: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
):
|
||||
if api_key:
|
||||
self.mem0_client = MemoryClient(api_key, host)
|
||||
else:
|
||||
@@ -77,13 +85,21 @@ class Completions:
|
||||
raise ValueError("One of user_id, agent_id, run_id must be provided")
|
||||
|
||||
if not litellm.supports_function_calling(model):
|
||||
raise ValueError(f"Model '{model}' does not support function calling. Please use a model that supports function calling.")
|
||||
raise ValueError(
|
||||
f"Model '{model}' does not support function calling. Please use a model that supports function calling."
|
||||
)
|
||||
|
||||
prepared_messages = self._prepare_messages(messages)
|
||||
if prepared_messages[-1]["role"] == "user":
|
||||
self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters)
|
||||
relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit)
|
||||
prepared_messages[-1]["content"] = self._format_query_with_memories(messages, relevant_memories)
|
||||
self._async_add_to_memory(
|
||||
messages, user_id, agent_id, run_id, metadata, filters
|
||||
)
|
||||
relevant_memories = self._fetch_relevant_memories(
|
||||
messages, user_id, agent_id, run_id, filters, limit
|
||||
)
|
||||
prepared_messages[-1]["content"] = self._format_query_with_memories(
|
||||
messages, relevant_memories
|
||||
)
|
||||
|
||||
response = litellm.completion(
|
||||
model=model,
|
||||
@@ -114,9 +130,9 @@ class Completions:
|
||||
base_url=base_url,
|
||||
api_version=api_version,
|
||||
api_key=api_key,
|
||||
model_list=model_list
|
||||
model_list=model_list,
|
||||
)
|
||||
|
||||
capture_client_event("mem0.chat.create", self)
|
||||
return response
|
||||
|
||||
def _prepare_messages(self, messages: List[dict]) -> List[dict]:
|
||||
@@ -125,7 +141,9 @@ class Completions:
|
||||
messages[0]["content"] = MEMORY_ANSWER_PROMPT
|
||||
return messages
|
||||
|
||||
def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters):
|
||||
def _async_add_to_memory(
|
||||
self, messages, user_id, agent_id, run_id, metadata, filters
|
||||
):
|
||||
def add_task():
|
||||
self.mem0_client.add(
|
||||
messages=messages,
|
||||
@@ -135,11 +153,16 @@ class Completions:
|
||||
metadata=metadata,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
threading.Thread(target=add_task, daemon=True).start()
|
||||
|
||||
def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):
|
||||
def _fetch_relevant_memories(
|
||||
self, messages, user_id, agent_id, run_id, filters, limit
|
||||
):
|
||||
# Currently, only pass the last 6 messages to the search API to prevent long query
|
||||
message_input = [f"{message['role']}: {message['content']}" for message in messages][-6:]
|
||||
message_input = [
|
||||
f"{message['role']}: {message['content']}" for message in messages
|
||||
][-6:]
|
||||
# TODO: Make it better by summarizing the past conversation
|
||||
return self.mem0_client.search(
|
||||
query="\n".join(message_input),
|
||||
|
||||
Reference in New Issue
Block a user