tools fix and formatting (#2441)
@@ -618,18 +618,16 @@ class MemoryClient:
         return response.json()

     @api_error_handler
-    def feedback(self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None) -> Dict[str, str]:
+    def feedback(
+        self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None
+    ) -> Dict[str, str]:
         VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"}

         feedback = feedback.upper() if feedback else None
         if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
             raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')

-        data = {
-            "memory_id": memory_id,
-            "feedback": feedback,
-            "feedback_reason": feedback_reason
-        }
+        data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}

         response = self.client.post("/v1/feedback/", json=data)
         response.raise_for_status()
@@ -1019,20 +1017,18 @@ class AsyncMemoryClient:
         return response.json()

     @api_error_handler
-    async def feedback(self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None) -> Dict[str, str]:
+    async def feedback(
+        self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None
+    ) -> Dict[str, str]:
         VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"}

         feedback = feedback.upper() if feedback else None
         if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
             raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')

-        data = {
-            "memory_id": memory_id,
-            "feedback": feedback,
-            "feedback_reason": feedback_reason
-        }
+        data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}

         response = await self.async_client.post("/v1/feedback/", json=data)
         response.raise_for_status()
         capture_client_event("async_client.feedback", self.sync_client, data)
         return response.json()
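
For context, a minimal usage sketch of the feedback call whose formatting changed above; the API key, memory ID, and reason are placeholders, not values from this commit:

    from mem0 import MemoryClient

    client = MemoryClient(api_key="your-api-key")            # placeholder key
    result = client.feedback(
        memory_id="mem-abc123",                              # hypothetical memory ID
        feedback="positive",                                 # normalised to POSITIVE by the client
        feedback_reason="Recalled the right preference",     # optional free-text reason
    )
    print(result)
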
@@ -208,11 +208,12 @@ Please note to return the IDs in the output from the input IDs only and do not g
 }
 """

+
 def get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt=None):
     if custom_update_memory_prompt is None:
         global DEFAULT_UPDATE_MEMORY_PROMPT
         custom_update_memory_prompt = DEFAULT_UPDATE_MEMORY_PROMPT

     return f"""{custom_update_memory_prompt}

     Below is the current content of my memory which I have collected till now. You have to update it in the following format only:
@@ -250,4 +251,4 @@ def get_update_memory_messages(retrieved_old_memory_dict, response_content, cust
     - If there is an update, the ID key should remain the same and only the value needs to be updated.

     Do not return anything except the JSON format.
     """
@@ -13,7 +13,16 @@ class EmbedderConfig(BaseModel):
     @field_validator("config")
     def validate_config(cls, v, values):
         provider = values.data.get("provider")
-        if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai", "together", "lmstudio"]:
+        if provider in [
+            "openai",
+            "ollama",
+            "huggingface",
+            "azure_openai",
+            "gemini",
+            "vertexai",
+            "together",
+            "lmstudio",
+        ]:
             return v
         else:
             raise ValueError(f"Unsupported embedding provider: {provider}")
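
A small sketch of how the validator above behaves; the module path is assumed (it is not shown in this diff) and the config values are illustrative:

    from mem0.embeddings.configs import EmbedderConfig  # assumed module path

    # "openai" is in the supported-provider list, so validation passes.
    EmbedderConfig(provider="openai", config={"model": "text-embedding-3-small"})

    # Any provider outside the list trips the ValueError raised in validate_config
    # (surfaced through pydantic's validation error).
    try:
        EmbedderConfig(provider="not-a-provider", config={})
    except ValueError as exc:
        print(exc)  # message includes: Unsupported embedding provider: not-a-provider
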
@@ -28,5 +28,7 @@ class GoogleGenAIEmbedding(EmbeddingBase):
             list: The embedding vector.
         """
         text = text.replace("\n", " ")
-        response = genai.embed_content(model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims)
+        response = genai.embed_content(
+            model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims
+        )
         return response["embedding"]
@@ -26,8 +26,4 @@ class LMStudioEmbedding(EmbeddingBase):
             list: The embedding vector.
         """
         text = text.replace("\n", " ")
-        return (
-            self.client.embeddings.create(input=[text], model=self.config.model)
-            .data[0]
-            .embedding
-        )
+        return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding
@@ -17,16 +17,16 @@ class OpenAIEmbedding(EmbeddingBase):

         api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
         base_url = (
             self.config.openai_base_url
             or os.getenv("OPENAI_API_BASE")
             or os.getenv("OPENAI_BASE_URL")
             or "https://api.openai.com/v1"
         )
         if os.environ.get("OPENAI_API_BASE"):
             warnings.warn(
                 "The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. "
                 "Please use 'OPENAI_BASE_URL' instead.",
-                DeprecationWarning
+                DeprecationWarning,
             )

         self.client = OpenAI(api_key=api_key, base_url=base_url)
@@ -42,4 +42,8 @@ class OpenAIEmbedding(EmbeddingBase):
             list: The embedding vector.
         """
         text = text.replace("\n", " ")
-        return self.client.embeddings.create(input=[text], model=self.config.model, dimensions = self.config.embedding_dims).data[0].embedding
+        return (
+            self.client.embeddings.create(input=[text], model=self.config.model, dimensions=self.config.embedding_dims)
+            .data[0]
+            .embedding
+        )
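
For reference, a sketch of calling the embedder whose return statement was reformatted above; it assumes OPENAI_API_KEY is set in the environment and that BaseEmbedderConfig lives at mem0.configs.embeddings.base (not shown in this diff):

    from mem0.configs.embeddings.base import BaseEmbedderConfig  # assumed module path
    from mem0.embeddings.openai import OpenAIEmbedding

    config = BaseEmbedderConfig(model="text-embedding-3-small", embedding_dims=1536)
    embedder = OpenAIEmbedding(config)
    vector = embedder.embed("Mem0 stores long-term memories for agents")
    print(len(vector))  # expected to match embedding_dims
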
@@ -1,4 +1,3 @@
-import json
 import os
 from typing import Dict, List, Optional

@@ -36,6 +35,8 @@ class AzureOpenAIStructuredLLM(LLMBase):
         self,
         messages: List[Dict[str, str]],
         response_format: Optional[str] = None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
     ) -> str:
         """
         Generate a response based on the given messages using Azure OpenAI.
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Dict, List, Optional

 from mem0.configs.llms.base import BaseLlmConfig

@@ -17,12 +17,14 @@ class LLMBase(ABC):
         self.config = config

     @abstractmethod
-    def generate_response(self, messages):
+    def generate_response(self, messages, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
         """
         Generate a response based on the given messages.

         Args:
             messages (list): List of message dicts containing 'role' and 'content'.
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".

         Returns:
             str: The generated response.
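
To illustrate the extended abstract signature, a hedged sketch of calling a concrete subclass with a tool definition; the weather tool is hypothetical and an OpenAI API key is assumed to be configured in the environment:

    from mem0.configs.llms.base import BaseLlmConfig
    from mem0.llms.openai import OpenAILLM

    llm = OpenAILLM(BaseLlmConfig(model="gpt-4o-mini"))
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
    response = llm.generate_response(
        messages=[{"role": "user", "content": "What's the weather in Paris?"}],
        tools=[weather_tool],
        tool_choice="auto",
    )
    print(response)  # with tools supplied, the parsed response also carries any tool calls
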
@@ -84,4 +84,4 @@ class GroqLLM(LLMBase):
             params["tool_choice"] = tool_choice

         response = self.client.chat.completions.create(**params)
         return self._parse_response(response, tools)
@@ -10,7 +10,10 @@ class LMStudioLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)

-        self.config.model = self.config.model or "lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF/Meta-Llama-3.1-70B-Instruct-IQ2_M.gguf"
+        self.config.model = (
+            self.config.model
+            or "lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF/Meta-Llama-3.1-70B-Instruct-IQ2_M.gguf"
+        )
         self.config.api_key = self.config.api_key or "lm-studio"

         self.client = OpenAI(base_url=self.config.lmstudio_base_url, api_key=self.config.api_key)
@@ -20,7 +23,7 @@ class LMStudioLLM(LLMBase):
         messages: List[Dict[str, str]],
         response_format: dict = {"type": "json_object"},
         tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto"
+        tool_choice: str = "auto",
     ):
         """
         Generate a response based on the given messages using LM Studio.
@@ -39,10 +42,10 @@ class LMStudioLLM(LLMBase):
             "messages": messages,
             "temperature": self.config.temperature,
             "max_tokens": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "top_p": self.config.top_p,
         }
         if response_format:
             params["response_format"] = response_format

         response = self.client.chat.completions.create(**params)
         return response.choices[0].message.content
@@ -1,3 +1,4 @@
+import json
 import os
 import warnings
 from typing import Dict, List, Optional
@@ -34,7 +35,7 @@ class OpenAILLM(LLMBase):
             warnings.warn(
                 "The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. "
                 "Please use 'OPENAI_BASE_URL' instead.",
-                DeprecationWarning
+                DeprecationWarning,
             )

         self.client = OpenAI(api_key=api_key, base_url=base_url)
@@ -1,4 +1,3 @@
-import json
 import os
 from typing import Dict, List, Optional

@@ -23,6 +22,8 @@ class OpenAIStructuredLLM(LLMBase):
         self,
         messages: List[Dict[str, str]],
         response_format: Optional[str] = None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
     ) -> str:
         """
         Generate a response based on the given messages using OpenAI.
@@ -18,13 +18,21 @@ class XAILLM(LLMBase):
         base_url = self.config.xai_base_url or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1"
         self.client = OpenAI(api_key=api_key, base_url=base_url)

-    def generate_response(self, messages: List[Dict[str, str]], response_format=None):
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
         """
         Generate a response based on the given messages using XAI.

         Args:
             messages (list): List of message dicts containing 'role' and 'content'.
             response_format (str or object, optional): Format of the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".

         Returns:
             str: The generated response.
@@ -12,11 +12,14 @@ try:
 except ImportError:
     raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")

-from mem0.graphs.tools import (DELETE_MEMORY_STRUCT_TOOL_GRAPH,
-                               DELETE_MEMORY_TOOL_GRAPH,
-                               EXTRACT_ENTITIES_STRUCT_TOOL,
-                               EXTRACT_ENTITIES_TOOL, RELATIONS_STRUCT_TOOL,
-                               RELATIONS_TOOL)
+from mem0.graphs.tools import (
+    DELETE_MEMORY_STRUCT_TOOL_GRAPH,
+    DELETE_MEMORY_TOOL_GRAPH,
+    EXTRACT_ENTITIES_STRUCT_TOOL,
+    EXTRACT_ENTITIES_TOOL,
+    RELATIONS_STRUCT_TOOL,
+    RELATIONS_TOOL,
+)
 from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
 from mem0.utils.factory import EmbedderFactory, LlmFactory

@@ -8,7 +8,9 @@ try:
     from pinecone import Pinecone, PodSpec, ServerlessSpec
     from pinecone.data.dataclasses.vector import Vector
 except ImportError:
-    raise ImportError("Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`") from None
+    raise ImportError(
+        "Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`"
+    ) from None

 from mem0.vector_stores.base import VectorStoreBase

@@ -34,7 +36,7 @@ class PineconeDB(VectorStoreBase):
         hybrid_search: bool,
         metric: str,
         batch_size: int,
-        extra_params: Optional[Dict[str, Any]]
+        extra_params: Optional[Dict[str, Any]],
     ):
         """
         Initialize the Pinecone vector store.
@@ -199,7 +201,9 @@ class PineconeDB(VectorStoreBase):

         return pinecone_filter

-    def search(self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
+    def search(
+        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
+    ) -> List[OutputData]:
         """
         Search for similar vectors.

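
A brief sketch of the call shape after the signature reflow; `db` stands for an already initialised PineconeDB instance, and the vector values and filter are placeholders:

    # db: an existing PineconeDB instance (initialisation omitted here)
    hits = db.search(
        query="coffee preferences",
        vectors=[0.1, 0.2, 0.3],           # embedding of the query text (placeholder values)
        limit=5,
        filters={"user_id": "alice"},      # hypothetical metadata filter
    )
    for hit in hits:
        print(hit.id, hit.score)
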
poetry.lock (generated, 1558 lines changed): file diff suppressed because it is too large.
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mem0ai"
-version = "0.1.76"
+version = "0.1.77"
 description = "Long-term memory for AI Agents"
 authors = ["Mem0 <founders@mem0.ai>"]
 exclude = [
@@ -140,7 +140,7 @@ def test_search(memory_instance, version, enable_graph):
     assert result["results"][0]["score"] == 0.9

     memory_instance.vector_store.search.assert_called_once_with(
-        query=[0.1, 0.2, 0.3], limit=100, filters={"user_id": "test_user"}
+        query="test query", vectors=[0.1, 0.2, 0.3], limit=100, filters={"user_id": "test_user"}
     )
     memory_instance.embedding_model.embed.assert_called_once_with("test query", "search")
