tools fix and formatting (#2441)

This commit is contained in:
Dev Khant
2025-03-26 11:25:03 +05:30
committed by GitHub
parent 2517ccd489
commit 2004427acd
18 changed files with 536 additions and 1151 deletions

View File

@@ -618,18 +618,16 @@ class MemoryClient:
return response.json()
@api_error_handler
def feedback(self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None) -> Dict[str, str]:
def feedback(
self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None
) -> Dict[str, str]:
VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"}
feedback = feedback.upper() if feedback else None
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')
data = {
"memory_id": memory_id,
"feedback": feedback,
"feedback_reason": feedback_reason
}
data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}
response = self.client.post("/v1/feedback/", json=data)
response.raise_for_status()
@@ -1019,20 +1017,18 @@ class AsyncMemoryClient:
return response.json()
@api_error_handler
async def feedback(self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None) -> Dict[str, str]:
async def feedback(
self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None
) -> Dict[str, str]:
VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"}
feedback = feedback.upper() if feedback else None
if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')
data = {
"memory_id": memory_id,
"feedback": feedback,
"feedback_reason": feedback_reason
}
data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}
response = await self.async_client.post("/v1/feedback/", json=data)
response.raise_for_status()
capture_client_event("async_client.feedback", self.sync_client, data)
return response.json()
return response.json()

View File

@@ -208,11 +208,12 @@ Please note to return the IDs in the output from the input IDs only and do not g
}
"""
def get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt=None):
if custom_update_memory_prompt is None:
global DEFAULT_UPDATE_MEMORY_PROMPT
custom_update_memory_prompt = DEFAULT_UPDATE_MEMORY_PROMPT
return f"""{custom_update_memory_prompt}
Below is the current content of my memory which I have collected till now. You have to update it in the following format only:
@@ -250,4 +251,4 @@ def get_update_memory_messages(retrieved_old_memory_dict, response_content, cust
- If there is an update, the ID key should remain the same and only the value needs to be updated.
Do not return anything except the JSON format.
"""
"""

View File

@@ -13,7 +13,16 @@ class EmbedderConfig(BaseModel):
@field_validator("config")
def validate_config(cls, v, values):
provider = values.data.get("provider")
if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai", "together", "lmstudio"]:
if provider in [
"openai",
"ollama",
"huggingface",
"azure_openai",
"gemini",
"vertexai",
"together",
"lmstudio",
]:
return v
else:
raise ValueError(f"Unsupported embedding provider: {provider}")

View File

@@ -28,5 +28,7 @@ class GoogleGenAIEmbedding(EmbeddingBase):
list: The embedding vector.
"""
text = text.replace("\n", " ")
response = genai.embed_content(model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims)
response = genai.embed_content(
model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims
)
return response["embedding"]

View File

@@ -26,8 +26,4 @@ class LMStudioEmbedding(EmbeddingBase):
list: The embedding vector.
"""
text = text.replace("\n", " ")
return (
self.client.embeddings.create(input=[text], model=self.config.model)
.data[0]
.embedding
)
return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding

View File

@@ -17,16 +17,16 @@ class OpenAIEmbedding(EmbeddingBase):
api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
base_url = (
self.config.openai_base_url
or os.getenv("OPENAI_API_BASE")
or os.getenv("OPENAI_BASE_URL")
or "https://api.openai.com/v1"
self.config.openai_base_url
or os.getenv("OPENAI_API_BASE")
or os.getenv("OPENAI_BASE_URL")
or "https://api.openai.com/v1"
)
if os.environ.get("OPENAI_API_BASE"):
warnings.warn(
"The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. "
"Please use 'OPENAI_BASE_URL' instead.",
DeprecationWarning
DeprecationWarning,
)
self.client = OpenAI(api_key=api_key, base_url=base_url)
@@ -42,4 +42,8 @@ class OpenAIEmbedding(EmbeddingBase):
list: The embedding vector.
"""
text = text.replace("\n", " ")
return self.client.embeddings.create(input=[text], model=self.config.model, dimensions = self.config.embedding_dims).data[0].embedding
return (
self.client.embeddings.create(input=[text], model=self.config.model, dimensions=self.config.embedding_dims)
.data[0]
.embedding
)

View File

@@ -1,4 +1,3 @@
import json
import os
from typing import Dict, List, Optional
@@ -36,6 +35,8 @@ class AzureOpenAIStructuredLLM(LLMBase):
self,
messages: List[Dict[str, str]],
response_format: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
) -> str:
"""
Generate a response based on the given messages using Azure OpenAI.

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import Dict, List, Optional
from mem0.configs.llms.base import BaseLlmConfig
@@ -17,12 +17,14 @@ class LLMBase(ABC):
self.config = config
@abstractmethod
def generate_response(self, messages):
def generate_response(self, messages, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
"""
Generate a response based on the given messages.
Args:
messages (list): List of message dicts containing 'role' and 'content'.
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns:
str: The generated response.

View File

@@ -84,4 +84,4 @@ class GroqLLM(LLMBase):
params["tool_choice"] = tool_choice
response = self.client.chat.completions.create(**params)
return self._parse_response(response, tools)
return self._parse_response(response, tools)

View File

@@ -10,7 +10,10 @@ class LMStudioLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)
self.config.model = self.config.model or "lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF/Meta-Llama-3.1-70B-Instruct-IQ2_M.gguf"
self.config.model = (
self.config.model
or "lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF/Meta-Llama-3.1-70B-Instruct-IQ2_M.gguf"
)
self.config.api_key = self.config.api_key or "lm-studio"
self.client = OpenAI(base_url=self.config.lmstudio_base_url, api_key=self.config.api_key)
@@ -20,7 +23,7 @@ class LMStudioLLM(LLMBase):
messages: List[Dict[str, str]],
response_format: dict = {"type": "json_object"},
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto"
tool_choice: str = "auto",
):
"""
Generate a response based on the given messages using LM Studio.
@@ -39,10 +42,10 @@ class LMStudioLLM(LLMBase):
"messages": messages,
"temperature": self.config.temperature,
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p
"top_p": self.config.top_p,
}
if response_format:
params["response_format"] = response_format
response = self.client.chat.completions.create(**params)
return response.choices[0].message.content
return response.choices[0].message.content

View File

@@ -1,3 +1,4 @@
import json
import os
import warnings
from typing import Dict, List, Optional
@@ -34,7 +35,7 @@ class OpenAILLM(LLMBase):
warnings.warn(
"The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. "
"Please use 'OPENAI_BASE_URL' instead.",
DeprecationWarning
DeprecationWarning,
)
self.client = OpenAI(api_key=api_key, base_url=base_url)

View File

@@ -1,4 +1,3 @@
import json
import os
from typing import Dict, List, Optional
@@ -23,6 +22,8 @@ class OpenAIStructuredLLM(LLMBase):
self,
messages: List[Dict[str, str]],
response_format: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
) -> str:
"""
Generate a response based on the given messages using OpenAI.

View File

@@ -18,13 +18,21 @@ class XAILLM(LLMBase):
base_url = self.config.xai_base_url or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1"
self.client = OpenAI(api_key=api_key, base_url=base_url)
def generate_response(self, messages: List[Dict[str, str]], response_format=None):
def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generate a response based on the given messages using XAI.
Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns:
str: The generated response.

View File

@@ -12,11 +12,14 @@ try:
except ImportError:
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
from mem0.graphs.tools import (DELETE_MEMORY_STRUCT_TOOL_GRAPH,
DELETE_MEMORY_TOOL_GRAPH,
EXTRACT_ENTITIES_STRUCT_TOOL,
EXTRACT_ENTITIES_TOOL, RELATIONS_STRUCT_TOOL,
RELATIONS_TOOL)
from mem0.graphs.tools import (
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
DELETE_MEMORY_TOOL_GRAPH,
EXTRACT_ENTITIES_STRUCT_TOOL,
EXTRACT_ENTITIES_TOOL,
RELATIONS_STRUCT_TOOL,
RELATIONS_TOOL,
)
from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory

View File

@@ -8,7 +8,9 @@ try:
from pinecone import Pinecone, PodSpec, ServerlessSpec
from pinecone.data.dataclasses.vector import Vector
except ImportError:
raise ImportError("Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`") from None
raise ImportError(
"Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`"
) from None
from mem0.vector_stores.base import VectorStoreBase
@@ -34,7 +36,7 @@ class PineconeDB(VectorStoreBase):
hybrid_search: bool,
metric: str,
batch_size: int,
extra_params: Optional[Dict[str, Any]]
extra_params: Optional[Dict[str, Any]],
):
"""
Initialize the Pinecone vector store.
@@ -199,7 +201,9 @@ class PineconeDB(VectorStoreBase):
return pinecone_filter
def search(self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
def search(
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
"""
Search for similar vectors.

1558
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "mem0ai"
version = "0.1.76"
version = "0.1.77"
description = "Long-term memory for AI Agents"
authors = ["Mem0 <founders@mem0.ai>"]
exclude = [

View File

@@ -140,7 +140,7 @@ def test_search(memory_instance, version, enable_graph):
assert result["results"][0]["score"] == 0.9
memory_instance.vector_store.search.assert_called_once_with(
query=[0.1, 0.2, 0.3], limit=100, filters={"user_id": "test_user"}
query="test query", vectors=[0.1, 0.2, 0.3], limit=100, filters={"user_id": "test_user"}
)
memory_instance.embedding_model.embed.assert_called_once_with("test query", "search")