[Mem0] Update dependencies and make the package lighter (#1708)

Co-authored-by: Dev-Khant <devkhant24@gmail.com>
This commit is contained in:
Deshraj Yadav
2024-08-14 23:28:07 -07:00
committed by GitHub
parent e35786e567
commit a8ba7abb7d
35 changed files with 634 additions and 1594 deletions

View File

@@ -1,18 +1,26 @@
import httpx
from typing import Optional, List, Union
import threading
import litellm
try:
import litellm
except ImportError:
raise ImportError(
"litellm requires extra dependencies. Install with `pip install litellm`"
) from None
from mem0.memory.telemetry import capture_client_event
from mem0 import Memory, MemoryClient
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
class Mem0:
def __init__(
self,
config: Optional[dict] = None,
api_key: Optional[str] = None,
host: Optional[str] = None
):
self,
config: Optional[dict] = None,
api_key: Optional[str] = None,
host: Optional[str] = None,
):
if api_key:
self.mem0_client = MemoryClient(api_key, host)
else:
@@ -77,13 +85,21 @@ class Completions:
raise ValueError("One of user_id, agent_id, run_id must be provided")
if not litellm.supports_function_calling(model):
raise ValueError(f"Model '{model}' does not support function calling. Please use a model that supports function calling.")
raise ValueError(
f"Model '{model}' does not support function calling. Please use a model that supports function calling."
)
prepared_messages = self._prepare_messages(messages)
if prepared_messages[-1]["role"] == "user":
self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters)
relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit)
prepared_messages[-1]["content"] = self._format_query_with_memories(messages, relevant_memories)
self._async_add_to_memory(
messages, user_id, agent_id, run_id, metadata, filters
)
relevant_memories = self._fetch_relevant_memories(
messages, user_id, agent_id, run_id, filters, limit
)
prepared_messages[-1]["content"] = self._format_query_with_memories(
messages, relevant_memories
)
response = litellm.completion(
model=model,
@@ -114,9 +130,9 @@ class Completions:
base_url=base_url,
api_version=api_version,
api_key=api_key,
model_list=model_list
model_list=model_list,
)
capture_client_event("mem0.chat.create", self)
return response
def _prepare_messages(self, messages: List[dict]) -> List[dict]:
@@ -125,7 +141,9 @@ class Completions:
messages[0]["content"] = MEMORY_ANSWER_PROMPT
return messages
def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters):
def _async_add_to_memory(
self, messages, user_id, agent_id, run_id, metadata, filters
):
def add_task():
self.mem0_client.add(
messages=messages,
@@ -135,11 +153,16 @@ class Completions:
metadata=metadata,
filters=filters,
)
threading.Thread(target=add_task, daemon=True).start()
def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):
def _fetch_relevant_memories(
self, messages, user_id, agent_id, run_id, filters, limit
):
# Currently, only pass the last 6 messages to the search API to prevent long query
message_input = [f"{message['role']}: {message['content']}" for message in messages][-6:]
message_input = [
f"{message['role']}: {message['content']}" for message in messages
][-6:]
# TODO: Make it better by summarizing the past conversation
return self.mem0_client.search(
query="\n".join(message_input),