Support async client (#1980)
@@ -38,6 +38,23 @@ const client = new MemoryClient('your-api-key');
 
 </CodeGroup>
 
+### 3.1 Instantiate Async Client (Python only)
+
+For asynchronous operations in Python, you can use the AsyncMemoryClient:
+
+```python Python
+from mem0 import AsyncMemoryClient
+
+client = AsyncMemoryClient(api_key="your-api-key")
+
+
+async def main():
+    response = await client.add("I'm travelling to SF", user_id="john")
+    print(response)
+
+
+await main()
+```
+
 ## 4. Memory Operations
 
 Mem0 provides a simple and customizable interface for performing CRUD operations on memory.
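The async client added below mirrors this CRUD surface. Here is a minimal end-to-end sketch, assuming only the method names visible in this commit (`add`, `search`, `update`, `delete`); it uses `asyncio.run` because a bare top-level `await`, as in the docs snippet above, only works in notebook-style REPLs, and `"mem-123"` is a placeholder memory id:

```python
import asyncio

from mem0 import AsyncMemoryClient


async def main():
    client = AsyncMemoryClient(api_key="your-api-key")

    # Create: store a new memory for a user
    await client.add("I'm travelling to SF", user_id="john")

    # Read: search memories relevant to a query
    results = await client.search("travel plans", user_id="john")
    print(results)

    # Update and Delete take a memory id; "mem-123" is a placeholder here
    await client.update("mem-123", "I'm travelling to San Francisco")
    await client.delete("mem-123")


asyncio.run(main())
```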
@@ -2,5 +2,5 @@ import importlib.metadata
 
 __version__ = importlib.metadata.version("mem0ai")
 
-from mem0.client.main import MemoryClient  # noqa
+from mem0.client.main import MemoryClient, AsyncMemoryClient  # noqa
 from mem0.memory.main import Memory  # noqa
@@ -56,12 +56,12 @@ class MemoryClient:
     """
 
     def __init__(
-        self,
-        api_key: Optional[str] = None,
-        host: Optional[str] = None,
-        organization: Optional[str] = None,
-        project: Optional[str] = None
-    ):
+        self,
+        api_key: Optional[str] = None,
+        host: Optional[str] = None,
+        organization: Optional[str] = None,
+        project: Optional[str] = None,
+    ):
         """Initialize the MemoryClient.
 
         Args:
@@ -275,9 +275,7 @@ class MemoryClient:
         params = {"org_name": self.organization, "project_name": self.project}
         entities = self.users()
         for entity in entities["results"]:
-            response = self.client.delete(
-                f"/v1/entities/{entity['type']}/{entity['id']}/", params=params
-            )
+            response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
             response.raise_for_status()
 
         capture_client_event("client.delete_users", self)
@@ -362,3 +360,120 @@ class MemoryClient:
             kwargs["run_id"] = kwargs.pop("session_id")
 
         return {k: v for k, v in kwargs.items() if v is not None}
+
+
+class AsyncMemoryClient:
+    """Asynchronous client for interacting with the Mem0 API."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        host: Optional[str] = None,
+        organization: Optional[str] = None,
+        project: Optional[str] = None,
+    ):
+        self.sync_client = MemoryClient(api_key, host, organization, project)
+        self.async_client = httpx.AsyncClient(
+            base_url=self.sync_client.host,
+            headers=self.sync_client.client.headers,
+            timeout=60,
+        )
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.async_client.aclose()
+
+    @api_error_handler
+    async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
+        payload = self.sync_client._prepare_payload(messages, kwargs)
+        response = await self.async_client.post("/v1/memories/", json=payload)
+        response.raise_for_status()
+        capture_client_event("async_client.add", self.sync_client)
+        return response.json()
+
+    @api_error_handler
+    async def get(self, memory_id: str) -> Dict[str, Any]:
+        response = await self.async_client.get(f"/v1/memories/{memory_id}/")
+        response.raise_for_status()
+        capture_client_event("async_client.get", self.sync_client)
+        return response.json()
+
+    @api_error_handler
+    async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
+        params = self.sync_client._prepare_params(kwargs)
+        if version == "v1":
+            response = await self.async_client.get(f"/{version}/memories/", params=params)
+        elif version == "v2":
+            response = await self.async_client.post(f"/{version}/memories/", json=params)
+        response.raise_for_status()
+        capture_client_event(
+            "async_client.get_all", self.sync_client, {"filters": len(params), "limit": kwargs.get("limit", 100)}
+        )
+        return response.json()
+
+    @api_error_handler
+    async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
+        payload = {"query": query}
+        payload.update(self.sync_client._prepare_params(kwargs))
+        response = await self.async_client.post(f"/{version}/memories/search/", json=payload)
+        response.raise_for_status()
+        capture_client_event("async_client.search", self.sync_client, {"limit": kwargs.get("limit", 100)})
+        return response.json()
+
+    @api_error_handler
+    async def update(self, memory_id: str, data: str) -> Dict[str, Any]:
+        response = await self.async_client.put(f"/v1/memories/{memory_id}/", json={"text": data})
+        response.raise_for_status()
+        capture_client_event("async_client.update", self.sync_client)
+        return response.json()
+
+    @api_error_handler
+    async def delete(self, memory_id: str) -> Dict[str, Any]:
+        response = await self.async_client.delete(f"/v1/memories/{memory_id}/")
+        response.raise_for_status()
+        capture_client_event("async_client.delete", self.sync_client)
+        return response.json()
+
+    @api_error_handler
+    async def delete_all(self, **kwargs) -> Dict[str, str]:
+        params = self.sync_client._prepare_params(kwargs)
+        response = await self.async_client.delete("/v1/memories/", params=params)
+        response.raise_for_status()
+        capture_client_event("async_client.delete_all", self.sync_client, {"params": len(params)})
+        return response.json()
+
+    @api_error_handler
+    async def history(self, memory_id: str) -> List[Dict[str, Any]]:
+        response = await self.async_client.get(f"/v1/memories/{memory_id}/history/")
+        response.raise_for_status()
+        capture_client_event("async_client.history", self.sync_client)
+        return response.json()
+
+    @api_error_handler
+    async def users(self) -> Dict[str, Any]:
+        params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
+        response = await self.async_client.get("/v1/entities/", params=params)
+        response.raise_for_status()
+        capture_client_event("async_client.users", self.sync_client)
+        return response.json()
+
+    @api_error_handler
+    async def delete_users(self) -> Dict[str, str]:
+        params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
+        entities = await self.users()
+        for entity in entities["results"]:
+            response = await self.async_client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
+            response.raise_for_status()
+        capture_client_event("async_client.delete_users", self.sync_client)
+        return {"message": "All users, agents, and sessions deleted."}
+
+    @api_error_handler
+    async def reset(self) -> Dict[str, str]:
+        await self.delete_users()
+        capture_client_event("async_client.reset", self.sync_client)
+        return {"message": "Client reset successful. All users and memories deleted."}
+
+    async def chat(self):
+        raise NotImplementedError("Chat is not implemented yet")
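Because the new class defines `__aenter__` and `__aexit__`, it can also be used as an async context manager, which guarantees the underlying `httpx.AsyncClient` is closed on exit. A minimal usage sketch (the API key and `user_id` are placeholders):

```python
import asyncio

from mem0 import AsyncMemoryClient


async def main():
    # The context manager closes the underlying httpx.AsyncClient on exit
    async with AsyncMemoryClient(api_key="your-api-key") as client:
        await client.add("I'm travelling to SF", user_id="john")
        print(await client.get_all(user_id="john"))


asyncio.run(main())
```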
@@ -73,4 +73,6 @@ class AzureConfig(BaseModel):
     azure_deployment: str = Field(description="The name of the Azure deployment.", default=None)
     azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None)
     api_version: str = Field(description="The version of the Azure API being used.", default=None)
-    default_headers: Optional[Dict[str, str]] = Field(description="Headers to include in requests to the Azure API.", default=None)
+    default_headers: Optional[Dict[str, str]] = Field(
+        description="Headers to include in requests to the Azure API.", default=None
+    )
@@ -1,5 +1,6 @@
 import os
 from typing import Optional
+
 import google.generativeai as genai
 
 from mem0.configs.embeddings.base import BaseEmbedderConfig
@@ -9,7 +10,7 @@ from mem0.embeddings.base import EmbeddingBase
 class GoogleGenAIEmbedding(EmbeddingBase):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config)
 
 
         self.config.model = self.config.model or "models/text-embedding-004"
         self.config.embedding_dims = self.config.embedding_dims or 768
@@ -27,4 +28,4 @@ class GoogleGenAIEmbedding(EmbeddingBase):
         """
         text = text.replace("\n", " ")
         response = genai.embed_content(model=self.config.model, content=text)
-        return response['embedding']
+        return response["embedding"]
@@ -14,7 +14,7 @@ class HuggingFaceEmbedding(EmbeddingBase):
 
         self.model = SentenceTransformer(self.config.model, **self.config.model_kwargs)
 
         self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension()
 
     def embed(self, text):
         """
@@ -26,4 +26,4 @@ class HuggingFaceEmbedding(EmbeddingBase):
         Returns:
             list: The embedding vector.
         """
-        return self.model.encode(text, convert_to_numpy = True).tolist()
+        return self.model.encode(text, convert_to_numpy=True).tolist()
@@ -6,7 +6,9 @@ try:
     from google.generativeai import GenerativeModel
     from google.generativeai.types import content_types
 except ImportError:
-    raise ImportError("The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'.")
+    raise ImportError(
+        "The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
+    )
 
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
@@ -44,8 +46,8 @@ class GeminiLLM(LLMBase):
             if fn := part.function_call:
                 processed_response["tool_calls"].append(
                     {
-                        "name": fn.name,
-                        "arguments": {key:val for key, val in fn.args.items()},
+                        "name": fn.name,
+                        "arguments": {key: val for key, val in fn.args.items()},
                     }
                 )
@@ -53,7 +55,7 @@ class GeminiLLM(LLMBase):
         else:
             return response.candidates[0].content.parts[0].text
 
-    def _reformat_messages(self, messages : List[Dict[str, str]]):
+    def _reformat_messages(self, messages: List[Dict[str, str]]):
         """
         Reformat messages for Gemini.
 
@@ -71,9 +73,8 @@ class GeminiLLM(LLMBase):
 
         else:
             content = message["content"]
 
-            new_messages.append({"parts": content,
-                                 "role": "model" if message["role"] == "model" else "user"})
+            new_messages.append({"parts": content, "role": "model" if message["role"] == "model" else "user"})
 
         return new_messages
@@ -89,24 +90,24 @@ class GeminiLLM(LLMBase):
         """
 
         def remove_additional_properties(data):
            """Recursively removes 'additionalProperties' from nested dictionaries."""
 
            if isinstance(data, dict):
                filtered_dict = {
                    key: remove_additional_properties(value)
                    for key, value in data.items()
                    if not (key == "additionalProperties")
                }
                return filtered_dict
            else:
                return data
 
         new_tools = []
         if tools:
             for tool in tools:
-                func = tool['function'].copy()
-                new_tools.append({"function_declarations":[remove_additional_properties(func)]})
+                func = tool["function"].copy()
+                new_tools.append({"function_declarations": [remove_additional_properties(func)]})
 
             return new_tools
         else:
             return None
@@ -142,13 +143,21 @@ class GeminiLLM(LLMBase):
             params["response_schema"] = list[response_format]
         if tool_choice:
             tool_config = content_types.to_tool_config(
-                {"function_calling_config":
-                    {"mode": tool_choice, "allowed_function_names": [tool['function']['name'] for tool in tools] if tool_choice == "any" else None}
-                })
+                {
+                    "function_calling_config": {
+                        "mode": tool_choice,
+                        "allowed_function_names": [tool["function"]["name"] for tool in tools]
+                        if tool_choice == "any"
+                        else None,
+                    }
+                }
+            )
 
-        response = self.client.generate_content(contents = self._reformat_messages(messages),
-                                                tools = self._reformat_tools(tools),
-                                                generation_config = genai.GenerationConfig(**params),
-                                                tool_config = tool_config)
+        response = self.client.generate_content(
+            contents=self._reformat_messages(messages),
+            tools=self._reformat_tools(tools),
+            generation_config=genai.GenerationConfig(**params),
+            tool_config=tool_config,
+        )
 
         return self._parse_response(response, tools)
@@ -18,7 +18,9 @@ class OpenAILLM(LLMBase):
         if os.environ.get("OPENROUTER_API_KEY"):  # Use OpenRouter
             self.client = OpenAI(
                 api_key=os.environ.get("OPENROUTER_API_KEY"),
-                base_url=self.config.openrouter_base_url or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1",
+                base_url=self.config.openrouter_base_url
+                or os.getenv("OPENROUTER_API_BASE")
+                or "https://openrouter.ai/api/v1",
             )
         else:
             api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
@@ -4,7 +4,6 @@ import json
 import logging
 import uuid
 import warnings
 from collections import defaultdict
 from datetime import datetime
 from typing import Any, Dict
@@ -186,7 +185,9 @@ class Memory(MemoryBase):
                 logging.info(resp)
                 try:
                     if resp["event"] == "ADD":
-                        memory_id = self._create_memory(data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
+                        memory_id = self._create_memory(
+                            data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata
+                        )
                         returned_memories.append(
                             {
                                 "id": memory_id,
@@ -195,7 +196,12 @@ class Memory(MemoryBase):
                             }
                         )
                     elif resp["event"] == "UPDATE":
-                        self._update_memory(memory_id=resp["id"], data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
+                        self._update_memory(
+                            memory_id=resp["id"],
+                            data=resp["text"],
+                            existing_embeddings=new_message_embeddings,
+                            metadata=metadata,
+                        )
                         returned_memories.append(
                             {
                                 "id": resp["id"],
@@ -304,10 +310,14 @@ class Memory(MemoryBase):
         with concurrent.futures.ThreadPoolExecutor() as executor:
             future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
             future_graph_entities = (
-                executor.submit(self.graph.get_all, filters, limit) if self.version == "v1.1" and self.enable_graph else None
+                executor.submit(self.graph.get_all, filters, limit)
+                if self.version == "v1.1" and self.enable_graph
+                else None
             )
 
-            concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories])
+            concurrent.futures.wait(
+                [future_memories, future_graph_entities] if future_graph_entities else [future_memories]
+            )
 
             all_memories = future_memories.result()
             graph_entities = future_graph_entities.result() if future_graph_entities else None
@@ -399,7 +409,9 @@ class Memory(MemoryBase):
                 else None
             )
 
-            concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories])
+            concurrent.futures.wait(
+                [future_memories, future_graph_entities] if future_graph_entities else [future_memories]
+            )
 
             original_memories = future_memories.result()
             graph_entities = future_graph_entities.result() if future_graph_entities else None
@@ -181,9 +181,9 @@ class Completions:
 
     def _format_query_with_memories(self, messages, relevant_memories):
         # Check if self.mem0_client is an instance of Memory or MemoryClient
 
         if isinstance(self.mem0_client, mem0.memory.main.Memory):
-            memories_text = "\n".join(memory["memory"] for memory in relevant_memories['results'])
+            memories_text = "\n".join(memory["memory"] for memory in relevant_memories["results"])
         elif isinstance(self.mem0_client, mem0.client.main.MemoryClient):
             memories_text = "\n".join(memory["memory"] for memory in relevant_memories)
         return f"- Relevant Memories/Facts: {memories_text}\n\n- User Question: {messages[-1]['content']}"
@@ -76,11 +76,8 @@ class MilvusDB(VectorStoreBase):
         schema = CollectionSchema(fields, enable_dynamic_field=True)
 
         index = self.client.prepare_index_params(
-            field_name="vectors",
-            metric_type=metric_type,
-            index_type="AUTOINDEX",
-            index_name="vector_index"
+            field_name="vectors", metric_type=metric_type, index_type="AUTOINDEX", index_name="vector_index"
         )
         self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)
 
     def insert(self, ids, vectors, payloads, **kwargs: Optional[dict[str, any]]):