tools fix and formatting (#2441)

This commit is contained in:
Dev Khant
2025-03-26 11:25:03 +05:30
committed by GitHub
parent 2517ccd489
commit 2004427acd
18 changed files with 536 additions and 1151 deletions

View File

@@ -618,18 +618,16 @@ class MemoryClient:
return response.json() return response.json()
@api_error_handler @api_error_handler
def feedback(self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None) -> Dict[str, str]: def feedback(
self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None
) -> Dict[str, str]:
VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"} VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"}
feedback = feedback.upper() if feedback else None feedback = feedback.upper() if feedback else None
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES: if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None') raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')
data = { data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}
"memory_id": memory_id,
"feedback": feedback,
"feedback_reason": feedback_reason
}
response = self.client.post("/v1/feedback/", json=data) response = self.client.post("/v1/feedback/", json=data)
response.raise_for_status() response.raise_for_status()
@@ -1019,18 +1017,16 @@ class AsyncMemoryClient:
return response.json() return response.json()
@api_error_handler @api_error_handler
async def feedback(self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None) -> Dict[str, str]: async def feedback(
self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None
) -> Dict[str, str]:
VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"} VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"}
feedback = feedback.upper() if feedback else None feedback = feedback.upper() if feedback else None
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES: if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None') raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')
data = { data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}
"memory_id": memory_id,
"feedback": feedback,
"feedback_reason": feedback_reason
}
response = await self.async_client.post("/v1/feedback/", json=data) response = await self.async_client.post("/v1/feedback/", json=data)
response.raise_for_status() response.raise_for_status()

View File

@@ -208,6 +208,7 @@ Please note to return the IDs in the output from the input IDs only and do not g
} }
""" """
def get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt=None): def get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt=None):
if custom_update_memory_prompt is None: if custom_update_memory_prompt is None:
global DEFAULT_UPDATE_MEMORY_PROMPT global DEFAULT_UPDATE_MEMORY_PROMPT

View File

@@ -13,7 +13,16 @@ class EmbedderConfig(BaseModel):
@field_validator("config") @field_validator("config")
def validate_config(cls, v, values): def validate_config(cls, v, values):
provider = values.data.get("provider") provider = values.data.get("provider")
if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai", "together", "lmstudio"]: if provider in [
"openai",
"ollama",
"huggingface",
"azure_openai",
"gemini",
"vertexai",
"together",
"lmstudio",
]:
return v return v
else: else:
raise ValueError(f"Unsupported embedding provider: {provider}") raise ValueError(f"Unsupported embedding provider: {provider}")

View File

@@ -28,5 +28,7 @@ class GoogleGenAIEmbedding(EmbeddingBase):
list: The embedding vector. list: The embedding vector.
""" """
text = text.replace("\n", " ") text = text.replace("\n", " ")
response = genai.embed_content(model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims) response = genai.embed_content(
model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims
)
return response["embedding"] return response["embedding"]

View File

@@ -26,8 +26,4 @@ class LMStudioEmbedding(EmbeddingBase):
list: The embedding vector. list: The embedding vector.
""" """
text = text.replace("\n", " ") text = text.replace("\n", " ")
return ( return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding
self.client.embeddings.create(input=[text], model=self.config.model)
.data[0]
.embedding
)

View File

@@ -17,16 +17,16 @@ class OpenAIEmbedding(EmbeddingBase):
api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
base_url = ( base_url = (
self.config.openai_base_url self.config.openai_base_url
or os.getenv("OPENAI_API_BASE") or os.getenv("OPENAI_API_BASE")
or os.getenv("OPENAI_BASE_URL") or os.getenv("OPENAI_BASE_URL")
or "https://api.openai.com/v1" or "https://api.openai.com/v1"
) )
if os.environ.get("OPENAI_API_BASE"): if os.environ.get("OPENAI_API_BASE"):
warnings.warn( warnings.warn(
"The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. " "The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. "
"Please use 'OPENAI_BASE_URL' instead.", "Please use 'OPENAI_BASE_URL' instead.",
DeprecationWarning DeprecationWarning,
) )
self.client = OpenAI(api_key=api_key, base_url=base_url) self.client = OpenAI(api_key=api_key, base_url=base_url)
@@ -42,4 +42,8 @@ class OpenAIEmbedding(EmbeddingBase):
list: The embedding vector. list: The embedding vector.
""" """
text = text.replace("\n", " ") text = text.replace("\n", " ")
return self.client.embeddings.create(input=[text], model=self.config.model, dimensions = self.config.embedding_dims).data[0].embedding return (
self.client.embeddings.create(input=[text], model=self.config.model, dimensions=self.config.embedding_dims)
.data[0]
.embedding
)

View File

@@ -1,4 +1,3 @@
import json
import os import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
@@ -36,6 +35,8 @@ class AzureOpenAIStructuredLLM(LLMBase):
self, self,
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
response_format: Optional[str] = None, response_format: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
) -> str: ) -> str:
""" """
Generate a response based on the given messages using Azure OpenAI. Generate a response based on the given messages using Azure OpenAI.

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Dict, List, Optional
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
@@ -17,12 +17,14 @@ class LLMBase(ABC):
self.config = config self.config = config
@abstractmethod @abstractmethod
def generate_response(self, messages): def generate_response(self, messages, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
""" """
Generate a response based on the given messages. Generate a response based on the given messages.
Args: Args:
messages (list): List of message dicts containing 'role' and 'content'. messages (list): List of message dicts containing 'role' and 'content'.
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns: Returns:
str: The generated response. str: The generated response.

View File

@@ -10,7 +10,10 @@ class LMStudioLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None): def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config) super().__init__(config)
self.config.model = self.config.model or "lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF/Meta-Llama-3.1-70B-Instruct-IQ2_M.gguf" self.config.model = (
self.config.model
or "lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF/Meta-Llama-3.1-70B-Instruct-IQ2_M.gguf"
)
self.config.api_key = self.config.api_key or "lm-studio" self.config.api_key = self.config.api_key or "lm-studio"
self.client = OpenAI(base_url=self.config.lmstudio_base_url, api_key=self.config.api_key) self.client = OpenAI(base_url=self.config.lmstudio_base_url, api_key=self.config.api_key)
@@ -20,7 +23,7 @@ class LMStudioLLM(LLMBase):
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
response_format: dict = {"type": "json_object"}, response_format: dict = {"type": "json_object"},
tools: Optional[List[Dict]] = None, tools: Optional[List[Dict]] = None,
tool_choice: str = "auto" tool_choice: str = "auto",
): ):
""" """
Generate a response based on the given messages using LM Studio. Generate a response based on the given messages using LM Studio.
@@ -39,7 +42,7 @@ class LMStudioLLM(LLMBase):
"messages": messages, "messages": messages,
"temperature": self.config.temperature, "temperature": self.config.temperature,
"max_tokens": self.config.max_tokens, "max_tokens": self.config.max_tokens,
"top_p": self.config.top_p "top_p": self.config.top_p,
} }
if response_format: if response_format:
params["response_format"] = response_format params["response_format"] = response_format

View File

@@ -1,3 +1,4 @@
import json
import os import os
import warnings import warnings
from typing import Dict, List, Optional from typing import Dict, List, Optional
@@ -34,7 +35,7 @@ class OpenAILLM(LLMBase):
warnings.warn( warnings.warn(
"The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. " "The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. "
"Please use 'OPENAI_BASE_URL' instead.", "Please use 'OPENAI_BASE_URL' instead.",
DeprecationWarning DeprecationWarning,
) )
self.client = OpenAI(api_key=api_key, base_url=base_url) self.client = OpenAI(api_key=api_key, base_url=base_url)

View File

@@ -1,4 +1,3 @@
import json
import os import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
@@ -23,6 +22,8 @@ class OpenAIStructuredLLM(LLMBase):
self, self,
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
response_format: Optional[str] = None, response_format: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
) -> str: ) -> str:
""" """
Generate a response based on the given messages using OpenAI. Generate a response based on the given messages using OpenAI.

View File

@@ -18,13 +18,21 @@ class XAILLM(LLMBase):
base_url = self.config.xai_base_url or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1" base_url = self.config.xai_base_url or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1"
self.client = OpenAI(api_key=api_key, base_url=base_url) self.client = OpenAI(api_key=api_key, base_url=base_url)
def generate_response(self, messages: List[Dict[str, str]], response_format=None): def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
""" """
Generate a response based on the given messages using XAI. Generate a response based on the given messages using XAI.
Args: Args:
messages (list): List of message dicts containing 'role' and 'content'. messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text". response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns: Returns:
str: The generated response. str: The generated response.

View File

@@ -12,11 +12,14 @@ try:
except ImportError: except ImportError:
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25") raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
from mem0.graphs.tools import (DELETE_MEMORY_STRUCT_TOOL_GRAPH, from mem0.graphs.tools import (
DELETE_MEMORY_TOOL_GRAPH, DELETE_MEMORY_STRUCT_TOOL_GRAPH,
EXTRACT_ENTITIES_STRUCT_TOOL, DELETE_MEMORY_TOOL_GRAPH,
EXTRACT_ENTITIES_TOOL, RELATIONS_STRUCT_TOOL, EXTRACT_ENTITIES_STRUCT_TOOL,
RELATIONS_TOOL) EXTRACT_ENTITIES_TOOL,
RELATIONS_STRUCT_TOOL,
RELATIONS_TOOL,
)
from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory from mem0.utils.factory import EmbedderFactory, LlmFactory

View File

@@ -8,7 +8,9 @@ try:
from pinecone import Pinecone, PodSpec, ServerlessSpec from pinecone import Pinecone, PodSpec, ServerlessSpec
from pinecone.data.dataclasses.vector import Vector from pinecone.data.dataclasses.vector import Vector
except ImportError: except ImportError:
raise ImportError("Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`") from None raise ImportError(
"Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`"
) from None
from mem0.vector_stores.base import VectorStoreBase from mem0.vector_stores.base import VectorStoreBase
@@ -34,7 +36,7 @@ class PineconeDB(VectorStoreBase):
hybrid_search: bool, hybrid_search: bool,
metric: str, metric: str,
batch_size: int, batch_size: int,
extra_params: Optional[Dict[str, Any]] extra_params: Optional[Dict[str, Any]],
): ):
""" """
Initialize the Pinecone vector store. Initialize the Pinecone vector store.
@@ -199,7 +201,9 @@ class PineconeDB(VectorStoreBase):
return pinecone_filter return pinecone_filter
def search(self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]: def search(
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
""" """
Search for similar vectors. Search for similar vectors.

1558
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "mem0ai" name = "mem0ai"
version = "0.1.76" version = "0.1.77"
description = "Long-term memory for AI Agents" description = "Long-term memory for AI Agents"
authors = ["Mem0 <founders@mem0.ai>"] authors = ["Mem0 <founders@mem0.ai>"]
exclude = [ exclude = [

View File

@@ -140,7 +140,7 @@ def test_search(memory_instance, version, enable_graph):
assert result["results"][0]["score"] == 0.9 assert result["results"][0]["score"] == 0.9
memory_instance.vector_store.search.assert_called_once_with( memory_instance.vector_store.search.assert_called_once_with(
query=[0.1, 0.2, 0.3], limit=100, filters={"user_id": "test_user"} query="test query", vectors=[0.1, 0.2, 0.3], limit=100, filters={"user_id": "test_user"}
) )
memory_instance.embedding_model.embed.assert_called_once_with("test query", "search") memory_instance.embedding_model.embed.assert_called_once_with("test query", "search")