tools fix and formatting (#2441)

Dev Khant
2025-03-26 11:25:03 +05:30
committed by GitHub
parent 2517ccd489
commit 2004427acd
18 changed files with 536 additions and 1151 deletions


@@ -618,18 +618,16 @@ class MemoryClient:
         return response.json()
     @api_error_handler
-    def feedback(self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None) -> Dict[str, str]:
+    def feedback(
+        self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None
+    ) -> Dict[str, str]:
         VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"}
         feedback = feedback.upper() if feedback else None
         if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
             raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')
-        data = {
-            "memory_id": memory_id,
-            "feedback": feedback,
-            "feedback_reason": feedback_reason
-        }
+        data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}
         response = self.client.post("/v1/feedback/", json=data)
         response.raise_for_status()
@@ -1019,18 +1017,16 @@ class AsyncMemoryClient:
         return response.json()
     @api_error_handler
-    async def feedback(self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None) -> Dict[str, str]:
+    async def feedback(
+        self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None
+    ) -> Dict[str, str]:
         VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"}
         feedback = feedback.upper() if feedback else None
         if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
             raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')
-        data = {
-            "memory_id": memory_id,
-            "feedback": feedback,
-            "feedback_reason": feedback_reason
-        }
+        data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}
         response = await self.async_client.post("/v1/feedback/", json=data)
         response.raise_for_status()

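For reference, the reformatted feedback() API is called like this; a minimal sketch assuming a hosted Mem0 API key and an existing memory ID (both placeholders):

    from mem0 import MemoryClient

    client = MemoryClient(api_key="m0-...")  # placeholder key
    # feedback is upper-cased internally, so "positive" and "POSITIVE" behave the same
    client.feedback(
        memory_id="mem-123",  # hypothetical memory ID
        feedback="positive",  # must map to POSITIVE, NEGATIVE or VERY_NEGATIVE
        feedback_reason="Correctly recalled the user's dietary preference",
    )

The AsyncMemoryClient version below is identical apart from the await.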

@@ -208,6 +208,7 @@ Please note to return the IDs in the output from the input IDs only and do not g
}
"""
def get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt=None):
if custom_update_memory_prompt is None:
global DEFAULT_UPDATE_MEMORY_PROMPT

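The helper touched here builds the update-memory prompt from the existing memories, the newly extracted facts, and an optional prompt override. A rough usage sketch, assuming it lives in mem0.configs.prompts and with argument shapes inferred only from the parameter names:

    from mem0.configs.prompts import get_update_memory_messages

    old_memories = [{"id": "0", "text": "User likes tea"}]  # hypothetical shape
    new_facts = ["User switched to coffee in the mornings"]  # hypothetical shape

    # Falls back to DEFAULT_UPDATE_MEMORY_PROMPT when no override is supplied
    prompt = get_update_memory_messages(old_memories, new_facts)

    # Or pass a custom update prompt
    prompt = get_update_memory_messages(old_memories, new_facts, custom_update_memory_prompt="...")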

@@ -13,7 +13,16 @@ class EmbedderConfig(BaseModel):
     @field_validator("config")
     def validate_config(cls, v, values):
         provider = values.data.get("provider")
-        if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai", "together", "lmstudio"]:
+        if provider in [
+            "openai",
+            "ollama",
+            "huggingface",
+            "azure_openai",
+            "gemini",
+            "vertexai",
+            "together",
+            "lmstudio",
+        ]:
             return v
         else:
             raise ValueError(f"Unsupported embedding provider: {provider}")

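As a usage sketch, a config that satisfies this validator could look like the following (model name and dimensions are illustrative, not taken from this diff):

    from mem0 import Memory

    config = {
        "embedder": {
            "provider": "openai",  # must be one of the providers whitelisted above
            "config": {"model": "text-embedding-3-small", "embedding_dims": 1536},
        }
    }
    m = Memory.from_config(config)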

@@ -28,5 +28,7 @@ class GoogleGenAIEmbedding(EmbeddingBase):
             list: The embedding vector.
         """
         text = text.replace("\n", " ")
-        response = genai.embed_content(model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims)
+        response = genai.embed_content(
+            model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims
+        )
         return response["embedding"]


@@ -26,8 +26,4 @@ class LMStudioEmbedding(EmbeddingBase):
             list: The embedding vector.
         """
         text = text.replace("\n", " ")
-        return (
-            self.client.embeddings.create(input=[text], model=self.config.model)
-            .data[0]
-            .embedding
-        )
+        return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding


@@ -26,7 +26,7 @@ class OpenAIEmbedding(EmbeddingBase):
             warnings.warn(
                 "The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. "
                 "Please use 'OPENAI_BASE_URL' instead.",
-                DeprecationWarning
+                DeprecationWarning,
             )
         self.client = OpenAI(api_key=api_key, base_url=base_url)
@@ -42,4 +42,8 @@ class OpenAIEmbedding(EmbeddingBase):
             list: The embedding vector.
         """
         text = text.replace("\n", " ")
-        return self.client.embeddings.create(input=[text], model=self.config.model, dimensions = self.config.embedding_dims).data[0].embedding
+        return (
+            self.client.embeddings.create(input=[text], model=self.config.model, dimensions=self.config.embedding_dims)
+            .data[0]
+            .embedding
+        )

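The deprecation warning kept intact here points callers from OPENAI_API_BASE to OPENAI_BASE_URL; on the caller side the migration is just an environment variable rename:

    import os

    # Old, deprecated and slated for removal per the warning above
    os.environ.pop("OPENAI_API_BASE", None)

    # New
    os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1"  # or any compatible endpoint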

@@ -1,4 +1,3 @@
-import json
 import os
 from typing import Dict, List, Optional
@@ -36,6 +35,8 @@ class AzureOpenAIStructuredLLM(LLMBase):
         self,
         messages: List[Dict[str, str]],
         response_format: Optional[str] = None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
     ) -> str:
         """
         Generate a response based on the given messages using Azure OpenAI.


@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Dict, List, Optional
 from mem0.configs.llms.base import BaseLlmConfig
@@ -17,12 +17,14 @@ class LLMBase(ABC):
         self.config = config
     @abstractmethod
-    def generate_response(self, messages):
+    def generate_response(self, messages, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
         """
         Generate a response based on the given messages.
         Args:
             messages (list): List of message dicts containing 'role' and 'content'.
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".
         Returns:
             str: The generated response.

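Every provider subclass now has to accept the two extra arguments. A minimal stub illustrating the expected shape (EchoLLM is made up for illustration and is not part of the library; the import path for LLMBase is assumed to be mem0.llms.base):

    from typing import Dict, List, Optional

    from mem0.configs.llms.base import BaseLlmConfig
    from mem0.llms.base import LLMBase


    class EchoLLM(LLMBase):
        """Hypothetical provider used only to show the new signature."""

        def generate_response(self, messages, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
            # A real provider would forward `tools` and `tool_choice` to its SDK call.
            return messages[-1]["content"]

    llm = EchoLLM(BaseLlmConfig())
    print(llm.generate_response([{"role": "user", "content": "hi"}]))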

@@ -10,7 +10,10 @@ class LMStudioLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)
-        self.config.model = self.config.model or "lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF/Meta-Llama-3.1-70B-Instruct-IQ2_M.gguf"
+        self.config.model = (
+            self.config.model
+            or "lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF/Meta-Llama-3.1-70B-Instruct-IQ2_M.gguf"
+        )
         self.config.api_key = self.config.api_key or "lm-studio"
         self.client = OpenAI(base_url=self.config.lmstudio_base_url, api_key=self.config.api_key)
@@ -20,7 +23,7 @@ class LMStudioLLM(LLMBase):
         messages: List[Dict[str, str]],
         response_format: dict = {"type": "json_object"},
         tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto"
+        tool_choice: str = "auto",
     ):
         """
         Generate a response based on the given messages using LM Studio.
@@ -39,7 +42,7 @@
             "messages": messages,
             "temperature": self.config.temperature,
             "max_tokens": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "top_p": self.config.top_p,
         }
         if response_format:
             params["response_format"] = response_format

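With tools and tool_choice now threaded through, a call against any of these providers looks roughly like the sketch below; the tool schema follows the OpenAI function-calling format, and the add_memory function here is made up:

    messages = [{"role": "user", "content": "I love sci-fi movies."}]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "add_memory",  # hypothetical tool
                "description": "Store a fact about the user",
                "parameters": {
                    "type": "object",
                    "properties": {"data": {"type": "string"}},
                    "required": ["data"],
                },
            },
        }
    ]
    # llm is any LLMBase provider instance, e.g. LMStudioLLM(config)
    response = llm.generate_response(messages=messages, tools=tools, tool_choice="auto")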

@@ -1,3 +1,4 @@
+import json
 import os
 import warnings
 from typing import Dict, List, Optional
@@ -34,7 +35,7 @@ class OpenAILLM(LLMBase):
             warnings.warn(
                 "The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. "
                 "Please use 'OPENAI_BASE_URL' instead.",
-                DeprecationWarning
+                DeprecationWarning,
             )
         self.client = OpenAI(api_key=api_key, base_url=base_url)


@@ -1,4 +1,3 @@
-import json
 import os
 from typing import Dict, List, Optional
@@ -23,6 +22,8 @@ class OpenAIStructuredLLM(LLMBase):
         self,
         messages: List[Dict[str, str]],
         response_format: Optional[str] = None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
     ) -> str:
         """
         Generate a response based on the given messages using OpenAI.


@@ -18,13 +18,21 @@ class XAILLM(LLMBase):
         base_url = self.config.xai_base_url or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1"
         self.client = OpenAI(api_key=api_key, base_url=base_url)
-    def generate_response(self, messages: List[Dict[str, str]], response_format=None):
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
         """
         Generate a response based on the given messages using XAI.
         Args:
             messages (list): List of message dicts containing 'role' and 'content'.
             response_format (str or object, optional): Format of the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".
         Returns:
             str: The generated response.

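For completeness, an OSS config selecting this provider might look like the sketch below, assuming the factory registers the class under the "xai" key (the model name is illustrative):

    config = {
        "llm": {
            "provider": "xai",
            "config": {"model": "grok-2-latest", "temperature": 0.1},
        }
    }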

@@ -12,11 +12,14 @@ try:
 except ImportError:
     raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
-from mem0.graphs.tools import (DELETE_MEMORY_STRUCT_TOOL_GRAPH,
-                               DELETE_MEMORY_TOOL_GRAPH,
-                               EXTRACT_ENTITIES_STRUCT_TOOL,
-                               EXTRACT_ENTITIES_TOOL, RELATIONS_STRUCT_TOOL,
-                               RELATIONS_TOOL)
+from mem0.graphs.tools import (
+    DELETE_MEMORY_STRUCT_TOOL_GRAPH,
+    DELETE_MEMORY_TOOL_GRAPH,
+    EXTRACT_ENTITIES_STRUCT_TOOL,
+    EXTRACT_ENTITIES_TOOL,
+    RELATIONS_STRUCT_TOOL,
+    RELATIONS_TOOL,
+)
 from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
 from mem0.utils.factory import EmbedderFactory, LlmFactory

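This import block belongs to the graph memory layer; it is switched on through the graph_store section of the config, roughly as below (connection details are placeholders):

    config = {
        "graph_store": {
            "provider": "neo4j",
            "config": {
                "url": "neo4j+s://example.databases.neo4j.io",  # placeholder
                "username": "neo4j",
                "password": "***",
            },
        }
    }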

@@ -8,7 +8,9 @@ try:
     from pinecone import Pinecone, PodSpec, ServerlessSpec
     from pinecone.data.dataclasses.vector import Vector
 except ImportError:
-    raise ImportError("Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`") from None
+    raise ImportError(
+        "Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`"
+    ) from None
 from mem0.vector_stores.base import VectorStoreBase
@@ -34,7 +36,7 @@ class PineconeDB(VectorStoreBase):
         hybrid_search: bool,
         metric: str,
         batch_size: int,
-        extra_params: Optional[Dict[str, Any]]
+        extra_params: Optional[Dict[str, Any]],
     ):
         """
         Initialize the Pinecone vector store.
@@ -199,7 +201,9 @@ class PineconeDB(VectorStoreBase):
         return pinecone_filter
-    def search(self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
+    def search(
+        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
+    ) -> List[OutputData]:
         """
         Search for similar vectors.

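The reworked search() keeps the same keyword arguments; a rough call, assuming pinecone_db is an initialized PineconeDB and embedding is the query vector produced by the embedder:

    results = pinecone_db.search(
        query="coffee preferences",
        vectors=embedding,  # e.g. a 1536-float vector
        limit=5,
        filters={"user_id": "alice"},  # placeholder filter
    )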
poetry.lock (generated, 1558 changed lines): diff suppressed because it is too large.


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mem0ai"
-version = "0.1.76"
+version = "0.1.77"
 description = "Long-term memory for AI Agents"
 authors = ["Mem0 <founders@mem0.ai>"]
 exclude = [

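The release carrying these changes is therefore 0.1.77; consumers pick it up with the usual upgrade (pip install -U mem0ai==0.1.77).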

@@ -140,7 +140,7 @@ def test_search(memory_instance, version, enable_graph):
     assert result["results"][0]["score"] == 0.9
     memory_instance.vector_store.search.assert_called_once_with(
-        query=[0.1, 0.2, 0.3], limit=100, filters={"user_id": "test_user"}
+        query="test query", vectors=[0.1, 0.2, 0.3], limit=100, filters={"user_id": "test_user"}
     )
     memory_instance.embedding_model.embed.assert_called_once_with("test query", "search")
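In other words, the test now expects vector_store.search() to receive the raw query string alongside its embedding, matching the updated store signatures above.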