Support async client (#1980)

Dev Khant
2024-10-22 12:42:55 +05:30
committed by GitHub
parent c5d298eec8
commit fbf1d8c372
11 changed files with 213 additions and 58 deletions

View File

@@ -38,6 +38,23 @@ const client = new MemoryClient('your-api-key');
</CodeGroup>
### 3.1 Instantiate Async Client (Python only)
For asynchronous operations in Python, you can use the AsyncMemoryClient:
```python Python
import asyncio

from mem0 import AsyncMemoryClient

client = AsyncMemoryClient(api_key="your-api-key")
async def main():
    response = await client.add("I'm travelling to SF", user_id="john")
    print(response)

asyncio.run(main())
```
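Since `AsyncMemoryClient` also implements the async context-manager protocol in this commit (`__aenter__`/`__aexit__` in the client code below), it can be used with `async with` so the underlying `httpx.AsyncClient` is closed automatically. A minimal sketch, reusing the placeholder API key and user id from above:
```python Python
import asyncio

from mem0 import AsyncMemoryClient

async def main():
    # __aexit__ closes the underlying httpx.AsyncClient when the block ends
    async with AsyncMemoryClient(api_key="your-api-key") as client:
        results = await client.search("travel plans", user_id="john")
        print(results)

asyncio.run(main())
```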
## 4. Memory Operations
Mem0 provides a simple and customizable interface for performing CRUD operations on memory.
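For orientation, here is a hedged sketch of those four CRUD operations chained through the async client added in this commit. The `add`/`get`/`update`/`delete` methods are shown in the diff below; the `results[0]["id"]` response shape used to recover the new memory's id is an assumption, not documented behavior:
```python Python
import asyncio

from mem0 import AsyncMemoryClient

async def crud_demo():
    client = AsyncMemoryClient(api_key="your-api-key")
    # Create
    created = await client.add("I prefer window seats", user_id="john")
    # Assumption: the add response exposes the new memory's id at results[0]["id"]
    memory_id = created["results"][0]["id"]
    # Read
    print(await client.get(memory_id))
    # Update
    await client.update(memory_id, "I prefer aisle seats")
    # Delete
    await client.delete(memory_id)

asyncio.run(crud_demo())
```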

View File

@@ -2,5 +2,5 @@ import importlib.metadata
__version__ = importlib.metadata.version("mem0ai")

-from mem0.client.main import MemoryClient  # noqa
+from mem0.client.main import MemoryClient, AsyncMemoryClient  # noqa
from mem0.memory.main import Memory  # noqa

View File

@@ -56,12 +56,12 @@ class MemoryClient:
""" """
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Optional[str] = None,
host: Optional[str] = None, host: Optional[str] = None,
organization: Optional[str] = None, organization: Optional[str] = None,
project: Optional[str] = None project: Optional[str] = None,
): ):
"""Initialize the MemoryClient. """Initialize the MemoryClient.
Args: Args:
@@ -275,9 +275,7 @@ class MemoryClient:
        params = {"org_name": self.organization, "project_name": self.project}
        entities = self.users()
        for entity in entities["results"]:
-            response = self.client.delete(
-                f"/v1/entities/{entity['type']}/{entity['id']}/", params=params
-            )
+            response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
            response.raise_for_status()
        capture_client_event("client.delete_users", self)
@@ -362,3 +360,120 @@ class MemoryClient:
        kwargs["run_id"] = kwargs.pop("session_id")
        return {k: v for k, v in kwargs.items() if v is not None}


class AsyncMemoryClient:
    """Asynchronous client for interacting with the Mem0 API."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        host: Optional[str] = None,
        organization: Optional[str] = None,
        project: Optional[str] = None,
    ):
        self.sync_client = MemoryClient(api_key, host, organization, project)
        self.async_client = httpx.AsyncClient(
            base_url=self.sync_client.host,
            headers=self.sync_client.client.headers,
            timeout=60,
        )

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.async_client.aclose()

    @api_error_handler
    async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
        payload = self.sync_client._prepare_payload(messages, kwargs)
        response = await self.async_client.post("/v1/memories/", json=payload)
        response.raise_for_status()
        capture_client_event("async_client.add", self.sync_client)
        return response.json()

    @api_error_handler
    async def get(self, memory_id: str) -> Dict[str, Any]:
        response = await self.async_client.get(f"/v1/memories/{memory_id}/")
        response.raise_for_status()
        capture_client_event("async_client.get", self.sync_client)
        return response.json()

    @api_error_handler
    async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
        params = self.sync_client._prepare_params(kwargs)
        if version == "v1":
            response = await self.async_client.get(f"/{version}/memories/", params=params)
        elif version == "v2":
            response = await self.async_client.post(f"/{version}/memories/", json=params)
        response.raise_for_status()
        capture_client_event(
            "async_client.get_all", self.sync_client, {"filters": len(params), "limit": kwargs.get("limit", 100)}
        )
        return response.json()

    @api_error_handler
    async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
        payload = {"query": query}
        payload.update(self.sync_client._prepare_params(kwargs))
        response = await self.async_client.post(f"/{version}/memories/search/", json=payload)
        response.raise_for_status()
        capture_client_event("async_client.search", self.sync_client, {"limit": kwargs.get("limit", 100)})
        return response.json()

    @api_error_handler
    async def update(self, memory_id: str, data: str) -> Dict[str, Any]:
        response = await self.async_client.put(f"/v1/memories/{memory_id}/", json={"text": data})
        response.raise_for_status()
        capture_client_event("async_client.update", self.sync_client)
        return response.json()

    @api_error_handler
    async def delete(self, memory_id: str) -> Dict[str, Any]:
        response = await self.async_client.delete(f"/v1/memories/{memory_id}/")
        response.raise_for_status()
        capture_client_event("async_client.delete", self.sync_client)
        return response.json()

    @api_error_handler
    async def delete_all(self, **kwargs) -> Dict[str, str]:
        params = self.sync_client._prepare_params(kwargs)
        response = await self.async_client.delete("/v1/memories/", params=params)
        response.raise_for_status()
        capture_client_event("async_client.delete_all", self.sync_client, {"params": len(params)})
        return response.json()

    @api_error_handler
    async def history(self, memory_id: str) -> List[Dict[str, Any]]:
        response = await self.async_client.get(f"/v1/memories/{memory_id}/history/")
        response.raise_for_status()
        capture_client_event("async_client.history", self.sync_client)
        return response.json()

    @api_error_handler
    async def users(self) -> Dict[str, Any]:
        params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
        response = await self.async_client.get("/v1/entities/", params=params)
        response.raise_for_status()
        capture_client_event("async_client.users", self.sync_client)
        return response.json()

    @api_error_handler
    async def delete_users(self) -> Dict[str, str]:
        params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
        entities = await self.users()
        for entity in entities["results"]:
            response = await self.async_client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
            response.raise_for_status()
        capture_client_event("async_client.delete_users", self.sync_client)
        return {"message": "All users, agents, and sessions deleted."}

    @api_error_handler
    async def reset(self) -> Dict[str, str]:
        await self.delete_users()
        capture_client_event("async_client.reset", self.sync_client)
        return {"message": "Client reset successful. All users and memories deleted."}

    async def chat(self):
        raise NotImplementedError("Chat is not implemented yet")
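The practical payoff of this class is that independent API calls can run concurrently over the shared `httpx.AsyncClient`. A minimal sketch using `asyncio.gather` and the methods defined above; the memory id passed to `history` is hypothetical:
```python
import asyncio

from mem0 import AsyncMemoryClient

async def fetch_many():
    async with AsyncMemoryClient(api_key="your-api-key") as client:
        # Independent requests are awaited concurrently on one connection pool
        results = await asyncio.gather(
            client.search("food preferences", user_id="john"),
            client.search("travel plans", user_id="john"),
            client.history("some-memory-id"),  # hypothetical memory id
        )
    return results

asyncio.run(fetch_many())
```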

View File

@@ -73,4 +73,6 @@ class AzureConfig(BaseModel):
    azure_deployment: str = Field(description="The name of the Azure deployment.", default=None)
    azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None)
    api_version: str = Field(description="The version of the Azure API being used.", default=None)
-    default_headers: Optional[Dict[str, str]] = Field(description="Headers to include in requests to the Azure API.", default=None)
+    default_headers: Optional[Dict[str, str]] = Field(
+        description="Headers to include in requests to the Azure API.", default=None
+    )

View File

@@ -1,5 +1,6 @@
import os
from typing import Optional

import google.generativeai as genai

from mem0.configs.embeddings.base import BaseEmbedderConfig
@@ -27,4 +28,4 @@ class GoogleGenAIEmbedding(EmbeddingBase):
""" """
text = text.replace("\n", " ") text = text.replace("\n", " ")
response = genai.embed_content(model=self.config.model, content=text) response = genai.embed_content(model=self.config.model, content=text)
return response['embedding'] return response["embedding"]

View File

@@ -14,7 +14,7 @@ class HuggingFaceEmbedding(EmbeddingBase):
        self.model = SentenceTransformer(self.config.model, **self.config.model_kwargs)
        self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension()

    def embed(self, text):
        """
@@ -26,4 +26,4 @@ class HuggingFaceEmbedding(EmbeddingBase):
        Returns:
            list: The embedding vector.
        """
-        return self.model.encode(text, convert_to_numpy = True).tolist()
+        return self.model.encode(text, convert_to_numpy=True).tolist()

View File

@@ -6,7 +6,9 @@ try:
    from google.generativeai import GenerativeModel
    from google.generativeai.types import content_types
except ImportError:
-    raise ImportError("The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'.")
+    raise ImportError(
+        "The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
+    )

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
@@ -44,8 +46,8 @@ class GeminiLLM(LLMBase):
                if fn := part.function_call:
                    processed_response["tool_calls"].append(
                        {
                            "name": fn.name,
-                            "arguments": {key:val for key, val in fn.args.items()},
+                            "arguments": {key: val for key, val in fn.args.items()},
                        }
                    )
@@ -53,7 +55,7 @@ class GeminiLLM(LLMBase):
        else:
            return response.candidates[0].content.parts[0].text

-    def _reformat_messages(self, messages : List[Dict[str, str]]):
+    def _reformat_messages(self, messages: List[Dict[str, str]]):
        """
        Reformat messages for Gemini.
@@ -72,8 +74,7 @@ class GeminiLLM(LLMBase):
            else:
                content = message["content"]

-            new_messages.append({"parts": content,
-                                 "role": "model" if message["role"] == "model" else "user"})
+            new_messages.append({"parts": content, "role": "model" if message["role"] == "model" else "user"})

        return new_messages
@@ -89,23 +90,23 @@ class GeminiLLM(LLMBase):
""" """
def remove_additional_properties(data): def remove_additional_properties(data):
"""Recursively removes 'additionalProperties' from nested dictionaries.""" """Recursively removes 'additionalProperties' from nested dictionaries."""
if isinstance(data, dict): if isinstance(data, dict):
filtered_dict = { filtered_dict = {
key: remove_additional_properties(value) key: remove_additional_properties(value)
for key, value in data.items() for key, value in data.items()
if not (key == "additionalProperties") if not (key == "additionalProperties")
} }
return filtered_dict return filtered_dict
else: else:
return data return data
new_tools = [] new_tools = []
if tools: if tools:
for tool in tools: for tool in tools:
func = tool['function'].copy() func = tool["function"].copy()
new_tools.append({"function_declarations":[remove_additional_properties(func)]}) new_tools.append({"function_declarations": [remove_additional_properties(func)]})
return new_tools return new_tools
else: else:
@@ -142,13 +143,21 @@ class GeminiLLM(LLMBase):
                params["response_schema"] = list[response_format]
        if tool_choice:
            tool_config = content_types.to_tool_config(
-                {"function_calling_config":
-                    {"mode": tool_choice, "allowed_function_names": [tool['function']['name'] for tool in tools] if tool_choice == "any" else None}
-                })
+                {
+                    "function_calling_config": {
+                        "mode": tool_choice,
+                        "allowed_function_names": [tool["function"]["name"] for tool in tools]
+                        if tool_choice == "any"
+                        else None,
+                    }
+                }
+            )

-        response = self.client.generate_content(contents = self._reformat_messages(messages),
-                                                tools = self._reformat_tools(tools),
-                                                generation_config = genai.GenerationConfig(**params),
-                                                tool_config = tool_config)
+        response = self.client.generate_content(
+            contents=self._reformat_messages(messages),
+            tools=self._reformat_tools(tools),
+            generation_config=genai.GenerationConfig(**params),
+            tool_config=tool_config,
+        )

        return self._parse_response(response, tools)
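The tool-conversion step reformatted above is easiest to see in isolation: `additionalProperties` keys (which Gemini's schema does not accept) are stripped recursively, and each function schema is wrapped in a `function_declarations` list. A standalone sketch with a hypothetical tool definition:
```python
def remove_additional_properties(data):
    """Recursively drop 'additionalProperties' keys from nested dicts."""
    if isinstance(data, dict):
        return {k: remove_additional_properties(v) for k, v in data.items() if k != "additionalProperties"}
    return data

# Hypothetical OpenAI-style tool definition, for illustration only
tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "additionalProperties": False,
        },
    },
}

func = tool["function"].copy()
gemini_tool = {"function_declarations": [remove_additional_properties(func)]}
print(gemini_tool)  # 'additionalProperties' is gone; the wrapper matches Gemini's expected shape
```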

View File

@@ -18,7 +18,9 @@ class OpenAILLM(LLMBase):
        if os.environ.get("OPENROUTER_API_KEY"):  # Use OpenRouter
            self.client = OpenAI(
                api_key=os.environ.get("OPENROUTER_API_KEY"),
-                base_url=self.config.openrouter_base_url or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1",
+                base_url=self.config.openrouter_base_url
+                or os.getenv("OPENROUTER_API_BASE")
+                or "https://openrouter.ai/api/v1",
            )
        else:
            api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
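The reflowed multi-line form makes the precedence explicit: an explicit `openrouter_base_url` in the config wins, then the `OPENROUTER_API_BASE` environment variable, then the public default. A tiny sketch of the same `or` chain; the config object here is a stand-in, not mem0's class:
```python
import os

class _Cfg:
    openrouter_base_url = None  # stand-in config object for illustration

config = _Cfg()
base_url = (
    config.openrouter_base_url
    or os.getenv("OPENROUTER_API_BASE")
    or "https://openrouter.ai/api/v1"
)
print(base_url)  # falls back to the public OpenRouter endpoint when nothing is set
```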

View File

@@ -4,7 +4,6 @@ import json
import logging
import uuid
import warnings
-from collections import defaultdict
from datetime import datetime
from typing import Any, Dict
@@ -186,7 +185,9 @@ class Memory(MemoryBase):
            logging.info(resp)
            try:
                if resp["event"] == "ADD":
-                    memory_id = self._create_memory(data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
+                    memory_id = self._create_memory(
+                        data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata
+                    )
                    returned_memories.append(
                        {
                            "id": memory_id,
@@ -195,7 +196,12 @@ class Memory(MemoryBase):
                        }
                    )
                elif resp["event"] == "UPDATE":
-                    self._update_memory(memory_id=resp["id"], data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
+                    self._update_memory(
+                        memory_id=resp["id"],
+                        data=resp["text"],
+                        existing_embeddings=new_message_embeddings,
+                        metadata=metadata,
+                    )
                    returned_memories.append(
                        {
                            "id": resp["id"],
@@ -304,10 +310,14 @@ class Memory(MemoryBase):
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
            future_graph_entities = (
-                executor.submit(self.graph.get_all, filters, limit) if self.version == "v1.1" and self.enable_graph else None
+                executor.submit(self.graph.get_all, filters, limit)
+                if self.version == "v1.1" and self.enable_graph
+                else None
            )

-            concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories])
+            concurrent.futures.wait(
+                [future_memories, future_graph_entities] if future_graph_entities else [future_memories]
+            )

            all_memories = future_memories.result()
            graph_entities = future_graph_entities.result() if future_graph_entities else None
@@ -399,7 +409,9 @@ class Memory(MemoryBase):
                else None
            )

-            concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories])
+            concurrent.futures.wait(
+                [future_memories, future_graph_entities] if future_graph_entities else [future_memories]
+            )

            original_memories = future_memories.result()
            graph_entities = future_graph_entities.result() if future_graph_entities else None
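Both hunks reflow the same pattern: submit the vector-store fetch, optionally submit the graph fetch, wait on whichever futures exist, then collect the results. A self-contained sketch of that pattern with stand-in functions (not mem0 APIs):
```python
import concurrent.futures

def fetch_memories():
    # Stand-in for _get_all_from_vector_store
    return ["memory-a", "memory-b"]

def fetch_graph():
    # Stand-in for graph.get_all
    return ["entity-x"]

enable_graph = True  # mirrors the self.version == "v1.1" and self.enable_graph gate

with concurrent.futures.ThreadPoolExecutor() as executor:
    future_memories = executor.submit(fetch_memories)
    future_graph = executor.submit(fetch_graph) if enable_graph else None
    # Wait only on the futures that were actually submitted
    concurrent.futures.wait([future_memories, future_graph] if future_graph else [future_memories])

memories = future_memories.result()
graph_entities = future_graph.result() if future_graph else None
print(memories, graph_entities)
```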

View File

@@ -183,7 +183,7 @@ class Completions:
        # Check if self.mem0_client is an instance of Memory or MemoryClient
        if isinstance(self.mem0_client, mem0.memory.main.Memory):
-            memories_text = "\n".join(memory["memory"] for memory in relevant_memories['results'])
+            memories_text = "\n".join(memory["memory"] for memory in relevant_memories["results"])
        elif isinstance(self.mem0_client, mem0.client.main.MemoryClient):
            memories_text = "\n".join(memory["memory"] for memory in relevant_memories)
        return f"- Relevant Memories/Facts: {memories_text}\n\n- User Question: {messages[-1]['content']}"

View File

@@ -76,11 +76,8 @@ class MilvusDB(VectorStoreBase):
        schema = CollectionSchema(fields, enable_dynamic_field=True)

-        index = self.client.prepare_index_params(
-            field_name="vectors",
-            metric_type=metric_type,
-            index_type="AUTOINDEX",
-            index_name="vector_index"
-        )
+        index = self.client.prepare_index_params(
+            field_name="vectors", metric_type=metric_type, index_type="AUTOINDEX", index_name="vector_index"
+        )
        self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)

    def insert(self, ids, vectors, payloads, **kwargs: Optional[dict[str, any]]):