tools fix and formatting (#2441)

This commit is contained in:
Dev Khant
2025-03-26 11:25:03 +05:30
committed by GitHub
parent 2517ccd489
commit 2004427acd
18 changed files with 536 additions and 1151 deletions

View File

@@ -618,18 +618,16 @@ class MemoryClient:
return response.json() return response.json()
@api_error_handler @api_error_handler
def feedback(self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None) -> Dict[str, str]: def feedback(
self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None
) -> Dict[str, str]:
VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"} VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"}
feedback = feedback.upper() if feedback else None feedback = feedback.upper() if feedback else None
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES: if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None') raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')
data = { data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}
"memory_id": memory_id,
"feedback": feedback,
"feedback_reason": feedback_reason
}
response = self.client.post("/v1/feedback/", json=data) response = self.client.post("/v1/feedback/", json=data)
response.raise_for_status() response.raise_for_status()
@@ -1019,20 +1017,18 @@ class AsyncMemoryClient:
return response.json() return response.json()
@api_error_handler @api_error_handler
async def feedback(self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None) -> Dict[str, str]: async def feedback(
self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None
) -> Dict[str, str]:
VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"} VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"}
feedback = feedback.upper() if feedback else None feedback = feedback.upper() if feedback else None
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES: if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None') raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')
data = { data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}
"memory_id": memory_id,
"feedback": feedback,
"feedback_reason": feedback_reason
}
response = await self.async_client.post("/v1/feedback/", json=data) response = await self.async_client.post("/v1/feedback/", json=data)
response.raise_for_status() response.raise_for_status()
capture_client_event("async_client.feedback", self.sync_client, data) capture_client_event("async_client.feedback", self.sync_client, data)
return response.json() return response.json()

View File

@@ -208,11 +208,12 @@ Please note to return the IDs in the output from the input IDs only and do not g
} }
""" """
def get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt=None): def get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt=None):
if custom_update_memory_prompt is None: if custom_update_memory_prompt is None:
global DEFAULT_UPDATE_MEMORY_PROMPT global DEFAULT_UPDATE_MEMORY_PROMPT
custom_update_memory_prompt = DEFAULT_UPDATE_MEMORY_PROMPT custom_update_memory_prompt = DEFAULT_UPDATE_MEMORY_PROMPT
return f"""{custom_update_memory_prompt} return f"""{custom_update_memory_prompt}
Below is the current content of my memory which I have collected till now. You have to update it in the following format only: Below is the current content of my memory which I have collected till now. You have to update it in the following format only:
@@ -250,4 +251,4 @@ def get_update_memory_messages(retrieved_old_memory_dict, response_content, cust
- If there is an update, the ID key should remain the same and only the value needs to be updated. - If there is an update, the ID key should remain the same and only the value needs to be updated.
Do not return anything except the JSON format. Do not return anything except the JSON format.
""" """

View File

@@ -13,7 +13,16 @@ class EmbedderConfig(BaseModel):
@field_validator("config") @field_validator("config")
def validate_config(cls, v, values): def validate_config(cls, v, values):
provider = values.data.get("provider") provider = values.data.get("provider")
if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai", "together", "lmstudio"]: if provider in [
"openai",
"ollama",
"huggingface",
"azure_openai",
"gemini",
"vertexai",
"together",
"lmstudio",
]:
return v return v
else: else:
raise ValueError(f"Unsupported embedding provider: {provider}") raise ValueError(f"Unsupported embedding provider: {provider}")

View File

@@ -28,5 +28,7 @@ class GoogleGenAIEmbedding(EmbeddingBase):
list: The embedding vector. list: The embedding vector.
""" """
text = text.replace("\n", " ") text = text.replace("\n", " ")
response = genai.embed_content(model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims) response = genai.embed_content(
model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims
)
return response["embedding"] return response["embedding"]

View File

@@ -26,8 +26,4 @@ class LMStudioEmbedding(EmbeddingBase):
list: The embedding vector. list: The embedding vector.
""" """
text = text.replace("\n", " ") text = text.replace("\n", " ")
return ( return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding
self.client.embeddings.create(input=[text], model=self.config.model)
.data[0]
.embedding
)

View File

@@ -17,16 +17,16 @@ class OpenAIEmbedding(EmbeddingBase):
api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
base_url = ( base_url = (
self.config.openai_base_url self.config.openai_base_url
or os.getenv("OPENAI_API_BASE") or os.getenv("OPENAI_API_BASE")
or os.getenv("OPENAI_BASE_URL") or os.getenv("OPENAI_BASE_URL")
or "https://api.openai.com/v1" or "https://api.openai.com/v1"
) )
if os.environ.get("OPENAI_API_BASE"): if os.environ.get("OPENAI_API_BASE"):
warnings.warn( warnings.warn(
"The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. " "The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. "
"Please use 'OPENAI_BASE_URL' instead.", "Please use 'OPENAI_BASE_URL' instead.",
DeprecationWarning DeprecationWarning,
) )
self.client = OpenAI(api_key=api_key, base_url=base_url) self.client = OpenAI(api_key=api_key, base_url=base_url)
@@ -42,4 +42,8 @@ class OpenAIEmbedding(EmbeddingBase):
list: The embedding vector. list: The embedding vector.
""" """
text = text.replace("\n", " ") text = text.replace("\n", " ")
return self.client.embeddings.create(input=[text], model=self.config.model, dimensions = self.config.embedding_dims).data[0].embedding return (
self.client.embeddings.create(input=[text], model=self.config.model, dimensions=self.config.embedding_dims)
.data[0]
.embedding
)

View File

@@ -1,4 +1,3 @@
import json
import os import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
@@ -36,6 +35,8 @@ class AzureOpenAIStructuredLLM(LLMBase):
self, self,
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
response_format: Optional[str] = None, response_format: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
) -> str: ) -> str:
""" """
Generate a response based on the given messages using Azure OpenAI. Generate a response based on the given messages using Azure OpenAI.

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Dict, List, Optional
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
@@ -17,12 +17,14 @@ class LLMBase(ABC):
self.config = config self.config = config
@abstractmethod @abstractmethod
def generate_response(self, messages): def generate_response(self, messages, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
""" """
Generate a response based on the given messages. Generate a response based on the given messages.
Args: Args:
messages (list): List of message dicts containing 'role' and 'content'. messages (list): List of message dicts containing 'role' and 'content'.
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns: Returns:
str: The generated response. str: The generated response.

View File

@@ -84,4 +84,4 @@ class GroqLLM(LLMBase):
params["tool_choice"] = tool_choice params["tool_choice"] = tool_choice
response = self.client.chat.completions.create(**params) response = self.client.chat.completions.create(**params)
return self._parse_response(response, tools) return self._parse_response(response, tools)

View File

@@ -10,7 +10,10 @@ class LMStudioLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None): def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config) super().__init__(config)
self.config.model = self.config.model or "lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF/Meta-Llama-3.1-70B-Instruct-IQ2_M.gguf" self.config.model = (
self.config.model
or "lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF/Meta-Llama-3.1-70B-Instruct-IQ2_M.gguf"
)
self.config.api_key = self.config.api_key or "lm-studio" self.config.api_key = self.config.api_key or "lm-studio"
self.client = OpenAI(base_url=self.config.lmstudio_base_url, api_key=self.config.api_key) self.client = OpenAI(base_url=self.config.lmstudio_base_url, api_key=self.config.api_key)
@@ -20,7 +23,7 @@ class LMStudioLLM(LLMBase):
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
response_format: dict = {"type": "json_object"}, response_format: dict = {"type": "json_object"},
tools: Optional[List[Dict]] = None, tools: Optional[List[Dict]] = None,
tool_choice: str = "auto" tool_choice: str = "auto",
): ):
""" """
Generate a response based on the given messages using LM Studio. Generate a response based on the given messages using LM Studio.
@@ -39,10 +42,10 @@ class LMStudioLLM(LLMBase):
"messages": messages, "messages": messages,
"temperature": self.config.temperature, "temperature": self.config.temperature,
"max_tokens": self.config.max_tokens, "max_tokens": self.config.max_tokens,
"top_p": self.config.top_p "top_p": self.config.top_p,
} }
if response_format: if response_format:
params["response_format"] = response_format params["response_format"] = response_format
response = self.client.chat.completions.create(**params) response = self.client.chat.completions.create(**params)
return response.choices[0].message.content return response.choices[0].message.content

View File

@@ -1,3 +1,4 @@
import json
import os import os
import warnings import warnings
from typing import Dict, List, Optional from typing import Dict, List, Optional
@@ -34,7 +35,7 @@ class OpenAILLM(LLMBase):
warnings.warn( warnings.warn(
"The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. " "The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. "
"Please use 'OPENAI_BASE_URL' instead.", "Please use 'OPENAI_BASE_URL' instead.",
DeprecationWarning DeprecationWarning,
) )
self.client = OpenAI(api_key=api_key, base_url=base_url) self.client = OpenAI(api_key=api_key, base_url=base_url)

View File

@@ -1,4 +1,3 @@
import json
import os import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
@@ -23,6 +22,8 @@ class OpenAIStructuredLLM(LLMBase):
self, self,
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
response_format: Optional[str] = None, response_format: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
) -> str: ) -> str:
""" """
Generate a response based on the given messages using OpenAI. Generate a response based on the given messages using OpenAI.

View File

@@ -18,13 +18,21 @@ class XAILLM(LLMBase):
base_url = self.config.xai_base_url or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1" base_url = self.config.xai_base_url or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1"
self.client = OpenAI(api_key=api_key, base_url=base_url) self.client = OpenAI(api_key=api_key, base_url=base_url)
def generate_response(self, messages: List[Dict[str, str]], response_format=None): def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
""" """
Generate a response based on the given messages using XAI. Generate a response based on the given messages using XAI.
Args: Args:
messages (list): List of message dicts containing 'role' and 'content'. messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text". response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns: Returns:
str: The generated response. str: The generated response.

View File

@@ -12,11 +12,14 @@ try:
except ImportError: except ImportError:
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25") raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
from mem0.graphs.tools import (DELETE_MEMORY_STRUCT_TOOL_GRAPH, from mem0.graphs.tools import (
DELETE_MEMORY_TOOL_GRAPH, DELETE_MEMORY_STRUCT_TOOL_GRAPH,
EXTRACT_ENTITIES_STRUCT_TOOL, DELETE_MEMORY_TOOL_GRAPH,
EXTRACT_ENTITIES_TOOL, RELATIONS_STRUCT_TOOL, EXTRACT_ENTITIES_STRUCT_TOOL,
RELATIONS_TOOL) EXTRACT_ENTITIES_TOOL,
RELATIONS_STRUCT_TOOL,
RELATIONS_TOOL,
)
from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory from mem0.utils.factory import EmbedderFactory, LlmFactory

View File

@@ -8,7 +8,9 @@ try:
from pinecone import Pinecone, PodSpec, ServerlessSpec from pinecone import Pinecone, PodSpec, ServerlessSpec
from pinecone.data.dataclasses.vector import Vector from pinecone.data.dataclasses.vector import Vector
except ImportError: except ImportError:
raise ImportError("Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`") from None raise ImportError(
"Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`"
) from None
from mem0.vector_stores.base import VectorStoreBase from mem0.vector_stores.base import VectorStoreBase
@@ -34,7 +36,7 @@ class PineconeDB(VectorStoreBase):
hybrid_search: bool, hybrid_search: bool,
metric: str, metric: str,
batch_size: int, batch_size: int,
extra_params: Optional[Dict[str, Any]] extra_params: Optional[Dict[str, Any]],
): ):
""" """
Initialize the Pinecone vector store. Initialize the Pinecone vector store.
@@ -199,7 +201,9 @@ class PineconeDB(VectorStoreBase):
return pinecone_filter return pinecone_filter
def search(self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]: def search(
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
""" """
Search for similar vectors. Search for similar vectors.

1558
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "mem0ai" name = "mem0ai"
version = "0.1.76" version = "0.1.77"
description = "Long-term memory for AI Agents" description = "Long-term memory for AI Agents"
authors = ["Mem0 <founders@mem0.ai>"] authors = ["Mem0 <founders@mem0.ai>"]
exclude = [ exclude = [

View File

@@ -140,7 +140,7 @@ def test_search(memory_instance, version, enable_graph):
assert result["results"][0]["score"] == 0.9 assert result["results"][0]["score"] == 0.9
memory_instance.vector_store.search.assert_called_once_with( memory_instance.vector_store.search.assert_called_once_with(
query=[0.1, 0.2, 0.3], limit=100, filters={"user_id": "test_user"} query="test query", vectors=[0.1, 0.2, 0.3], limit=100, filters={"user_id": "test_user"}
) )
memory_instance.embedding_model.embed.assert_called_once_with("test query", "search") memory_instance.embedding_model.embed.assert_called_once_with("test query", "search")