Support async client (#1980)

This commit is contained in:
Dev Khant
2024-10-22 12:42:55 +05:30
committed by GitHub
parent c5d298eec8
commit fbf1d8c372
11 changed files with 213 additions and 58 deletions

View File

@@ -38,6 +38,23 @@ const client = new MemoryClient('your-api-key');
</CodeGroup>
### 3.1 Instantiate Async Client (Python only)
For asynchronous operations in Python, you can use the AsyncMemoryClient:
```python Python
from mem0 import AsyncMemoryClient
client = AsyncMemoryClient(api_key="your-api-key")
async def main():
response = await client.add("I'm travelling to SF", user_id="john")
print(response)
await main()
```
## 4. Memory Operations
Mem0 provides a simple and customizable interface for performing CRUD operations on memory.

View File

@@ -2,5 +2,5 @@ import importlib.metadata
__version__ = importlib.metadata.version("mem0ai")
from mem0.client.main import MemoryClient # noqa
from mem0.client.main import MemoryClient, AsyncMemoryClient # noqa
from mem0.memory.main import Memory # noqa

View File

@@ -56,12 +56,12 @@ class MemoryClient:
"""
def __init__(
self,
api_key: Optional[str] = None,
host: Optional[str] = None,
organization: Optional[str] = None,
project: Optional[str] = None
):
self,
api_key: Optional[str] = None,
host: Optional[str] = None,
organization: Optional[str] = None,
project: Optional[str] = None,
):
"""Initialize the MemoryClient.
Args:
@@ -275,9 +275,7 @@ class MemoryClient:
params = {"org_name": self.organization, "project_name": self.project}
entities = self.users()
for entity in entities["results"]:
response = self.client.delete(
f"/v1/entities/{entity['type']}/{entity['id']}/", params=params
)
response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
response.raise_for_status()
capture_client_event("client.delete_users", self)
@@ -362,3 +360,120 @@ class MemoryClient:
kwargs["run_id"] = kwargs.pop("session_id")
return {k: v for k, v in kwargs.items() if v is not None}
class AsyncMemoryClient:
"""Asynchronous client for interacting with the Mem0 API."""
def __init__(
self,
api_key: Optional[str] = None,
host: Optional[str] = None,
organization: Optional[str] = None,
project: Optional[str] = None,
):
self.sync_client = MemoryClient(api_key, host, organization, project)
self.async_client = httpx.AsyncClient(
base_url=self.sync_client.host,
headers=self.sync_client.client.headers,
timeout=60,
)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.async_client.aclose()
@api_error_handler
async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
payload = self.sync_client._prepare_payload(messages, kwargs)
response = await self.async_client.post("/v1/memories/", json=payload)
response.raise_for_status()
capture_client_event("async_client.add", self.sync_client)
return response.json()
@api_error_handler
async def get(self, memory_id: str) -> Dict[str, Any]:
response = await self.async_client.get(f"/v1/memories/{memory_id}/")
response.raise_for_status()
capture_client_event("async_client.get", self.sync_client)
return response.json()
@api_error_handler
async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
params = self.sync_client._prepare_params(kwargs)
if version == "v1":
response = await self.async_client.get(f"/{version}/memories/", params=params)
elif version == "v2":
response = await self.async_client.post(f"/{version}/memories/", json=params)
response.raise_for_status()
capture_client_event(
"async_client.get_all", self.sync_client, {"filters": len(params), "limit": kwargs.get("limit", 100)}
)
return response.json()
@api_error_handler
async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
payload = {"query": query}
payload.update(self.sync_client._prepare_params(kwargs))
response = await self.async_client.post(f"/{version}/memories/search/", json=payload)
response.raise_for_status()
capture_client_event("async_client.search", self.sync_client, {"limit": kwargs.get("limit", 100)})
return response.json()
@api_error_handler
async def update(self, memory_id: str, data: str) -> Dict[str, Any]:
response = await self.async_client.put(f"/v1/memories/{memory_id}/", json={"text": data})
response.raise_for_status()
capture_client_event("async_client.update", self.sync_client)
return response.json()
@api_error_handler
async def delete(self, memory_id: str) -> Dict[str, Any]:
response = await self.async_client.delete(f"/v1/memories/{memory_id}/")
response.raise_for_status()
capture_client_event("async_client.delete", self.sync_client)
return response.json()
@api_error_handler
async def delete_all(self, **kwargs) -> Dict[str, str]:
params = self.sync_client._prepare_params(kwargs)
response = await self.async_client.delete("/v1/memories/", params=params)
response.raise_for_status()
capture_client_event("async_client.delete_all", self.sync_client, {"params": len(params)})
return response.json()
@api_error_handler
async def history(self, memory_id: str) -> List[Dict[str, Any]]:
response = await self.async_client.get(f"/v1/memories/{memory_id}/history/")
response.raise_for_status()
capture_client_event("async_client.history", self.sync_client)
return response.json()
@api_error_handler
async def users(self) -> Dict[str, Any]:
params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
response = await self.async_client.get("/v1/entities/", params=params)
response.raise_for_status()
capture_client_event("async_client.users", self.sync_client)
return response.json()
@api_error_handler
async def delete_users(self) -> Dict[str, str]:
params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
entities = await self.users()
for entity in entities["results"]:
response = await self.async_client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
response.raise_for_status()
capture_client_event("async_client.delete_users", self.sync_client)
return {"message": "All users, agents, and sessions deleted."}
@api_error_handler
async def reset(self) -> Dict[str, str]:
await self.delete_users()
capture_client_event("async_client.reset", self.sync_client)
return {"message": "Client reset successful. All users and memories deleted."}
async def chat(self):
raise NotImplementedError("Chat is not implemented yet")

View File

@@ -73,4 +73,6 @@ class AzureConfig(BaseModel):
azure_deployment: str = Field(description="The name of the Azure deployment.", default=None)
azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None)
api_version: str = Field(description="The version of the Azure API being used.", default=None)
default_headers: Optional[Dict[str, str]] = Field(description="Headers to include in requests to the Azure API.", default=None)
default_headers: Optional[Dict[str, str]] = Field(
description="Headers to include in requests to the Azure API.", default=None
)

View File

@@ -1,5 +1,6 @@
import os
from typing import Optional
import google.generativeai as genai
from mem0.configs.embeddings.base import BaseEmbedderConfig
@@ -27,4 +28,4 @@ class GoogleGenAIEmbedding(EmbeddingBase):
"""
text = text.replace("\n", " ")
response = genai.embed_content(model=self.config.model, content=text)
return response['embedding']
return response["embedding"]

View File

@@ -14,7 +14,7 @@ class HuggingFaceEmbedding(EmbeddingBase):
self.model = SentenceTransformer(self.config.model, **self.config.model_kwargs)
self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension()
self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension()
def embed(self, text):
"""
@@ -26,4 +26,4 @@ class HuggingFaceEmbedding(EmbeddingBase):
Returns:
list: The embedding vector.
"""
return self.model.encode(text, convert_to_numpy = True).tolist()
return self.model.encode(text, convert_to_numpy=True).tolist()

View File

@@ -6,7 +6,9 @@ try:
from google.generativeai import GenerativeModel
from google.generativeai.types import content_types
except ImportError:
raise ImportError("The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'.")
raise ImportError(
"The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
)
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
@@ -44,8 +46,8 @@ class GeminiLLM(LLMBase):
if fn := part.function_call:
processed_response["tool_calls"].append(
{
"name": fn.name,
"arguments": {key:val for key, val in fn.args.items()},
"name": fn.name,
"arguments": {key: val for key, val in fn.args.items()},
}
)
@@ -53,7 +55,7 @@ class GeminiLLM(LLMBase):
else:
return response.candidates[0].content.parts[0].text
def _reformat_messages(self, messages : List[Dict[str, str]]):
def _reformat_messages(self, messages: List[Dict[str, str]]):
"""
Reformat messages for Gemini.
@@ -72,8 +74,7 @@ class GeminiLLM(LLMBase):
else:
content = message["content"]
new_messages.append({"parts": content,
"role": "model" if message["role"] == "model" else "user"})
new_messages.append({"parts": content, "role": "model" if message["role"] == "model" else "user"})
return new_messages
@@ -89,23 +90,23 @@ class GeminiLLM(LLMBase):
"""
def remove_additional_properties(data):
"""Recursively removes 'additionalProperties' from nested dictionaries."""
"""Recursively removes 'additionalProperties' from nested dictionaries."""
if isinstance(data, dict):
filtered_dict = {
key: remove_additional_properties(value)
for key, value in data.items()
if not (key == "additionalProperties")
}
return filtered_dict
else:
return data
if isinstance(data, dict):
filtered_dict = {
key: remove_additional_properties(value)
for key, value in data.items()
if not (key == "additionalProperties")
}
return filtered_dict
else:
return data
new_tools = []
if tools:
for tool in tools:
func = tool['function'].copy()
new_tools.append({"function_declarations":[remove_additional_properties(func)]})
func = tool["function"].copy()
new_tools.append({"function_declarations": [remove_additional_properties(func)]})
return new_tools
else:
@@ -142,13 +143,21 @@ class GeminiLLM(LLMBase):
params["response_schema"] = list[response_format]
if tool_choice:
tool_config = content_types.to_tool_config(
{"function_calling_config":
{"mode": tool_choice, "allowed_function_names": [tool['function']['name'] for tool in tools] if tool_choice == "any" else None}
})
{
"function_calling_config": {
"mode": tool_choice,
"allowed_function_names": [tool["function"]["name"] for tool in tools]
if tool_choice == "any"
else None,
}
}
)
response = self.client.generate_content(contents = self._reformat_messages(messages),
tools = self._reformat_tools(tools),
generation_config = genai.GenerationConfig(**params),
tool_config = tool_config)
response = self.client.generate_content(
contents=self._reformat_messages(messages),
tools=self._reformat_tools(tools),
generation_config=genai.GenerationConfig(**params),
tool_config=tool_config,
)
return self._parse_response(response, tools)

View File

@@ -18,7 +18,9 @@ class OpenAILLM(LLMBase):
if os.environ.get("OPENROUTER_API_KEY"): # Use OpenRouter
self.client = OpenAI(
api_key=os.environ.get("OPENROUTER_API_KEY"),
base_url=self.config.openrouter_base_url or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1",
base_url=self.config.openrouter_base_url
or os.getenv("OPENROUTER_API_BASE")
or "https://openrouter.ai/api/v1",
)
else:
api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")

View File

@@ -4,7 +4,6 @@ import json
import logging
import uuid
import warnings
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict
@@ -186,7 +185,9 @@ class Memory(MemoryBase):
logging.info(resp)
try:
if resp["event"] == "ADD":
memory_id = self._create_memory(data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
memory_id = self._create_memory(
data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata
)
returned_memories.append(
{
"id": memory_id,
@@ -195,7 +196,12 @@ class Memory(MemoryBase):
}
)
elif resp["event"] == "UPDATE":
self._update_memory(memory_id=resp["id"], data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
self._update_memory(
memory_id=resp["id"],
data=resp["text"],
existing_embeddings=new_message_embeddings,
metadata=metadata,
)
returned_memories.append(
{
"id": resp["id"],
@@ -304,10 +310,14 @@ class Memory(MemoryBase):
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
future_graph_entities = (
executor.submit(self.graph.get_all, filters, limit) if self.version == "v1.1" and self.enable_graph else None
executor.submit(self.graph.get_all, filters, limit)
if self.version == "v1.1" and self.enable_graph
else None
)
concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories])
concurrent.futures.wait(
[future_memories, future_graph_entities] if future_graph_entities else [future_memories]
)
all_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None
@@ -399,7 +409,9 @@ class Memory(MemoryBase):
else None
)
concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories])
concurrent.futures.wait(
[future_memories, future_graph_entities] if future_graph_entities else [future_memories]
)
original_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None

View File

@@ -183,7 +183,7 @@ class Completions:
# Check if self.mem0_client is an instance of Memory or MemoryClient
if isinstance(self.mem0_client, mem0.memory.main.Memory):
memories_text = "\n".join(memory["memory"] for memory in relevant_memories['results'])
memories_text = "\n".join(memory["memory"] for memory in relevant_memories["results"])
elif isinstance(self.mem0_client, mem0.client.main.MemoryClient):
memories_text = "\n".join(memory["memory"] for memory in relevant_memories)
return f"- Relevant Memories/Facts: {memories_text}\n\n- User Question: {messages[-1]['content']}"

View File

@@ -76,11 +76,8 @@ class MilvusDB(VectorStoreBase):
schema = CollectionSchema(fields, enable_dynamic_field=True)
index = self.client.prepare_index_params(
field_name="vectors",
metric_type=metric_type,
index_type="AUTOINDEX",
index_name="vector_index"
)
field_name="vectors", metric_type=metric_type, index_type="AUTOINDEX", index_name="vector_index"
)
self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)
def insert(self, ids, vectors, payloads, **kwargs: Optional[dict[str, any]]):