From fbf1d8c372ac889777fae0cbe52345430953ee48 Mon Sep 17 00:00:00 2001
From: Dev Khant
Date: Tue, 22 Oct 2024 12:42:55 +0530
Subject: [PATCH] Support async client (#1980)

---
 docs/platform/quickstart.mdx   |  17 +++++
 mem0/__init__.py               |   2 +-
 mem0/client/main.py            | 133 ++++++++++++++++++++++++++++++---
 mem0/configs/base.py           |   4 +-
 mem0/embeddings/gemini.py      |   5 +-
 mem0/embeddings/huggingface.py |   4 +-
 mem0/llms/gemini.py            |  67 ++++++++++-------
 mem0/llms/openai.py            |   4 +-
 mem0/memory/main.py            |  24 ++++--
 mem0/proxy/main.py             |   4 +-
 mem0/vector_stores/milvus.py   |   7 +-
 11 files changed, 213 insertions(+), 58 deletions(-)

diff --git a/docs/platform/quickstart.mdx b/docs/platform/quickstart.mdx
index 606aff3d..5ad0c558 100644
--- a/docs/platform/quickstart.mdx
+++ b/docs/platform/quickstart.mdx
@@ -38,6 +38,23 @@ const client = new MemoryClient('your-api-key');
 ```
 
+### 3.1 Instantiate Async Client (Python only)
+
+For asynchronous operations in Python, you can use the AsyncMemoryClient:
+
+```python Python
+import asyncio
+from mem0 import AsyncMemoryClient
+
+client = AsyncMemoryClient(api_key="your-api-key")
+
+async def main():
+    response = await client.add("I'm travelling to SF", user_id="john")
+    print(response)
+
+asyncio.run(main())
+```
+
 ## 4. Memory Operations
 
 Mem0 provides a simple and customizable interface for performing CRUD operations on memory.

diff --git a/mem0/__init__.py b/mem0/__init__.py
index 66632ea0..ad44287b 100644
--- a/mem0/__init__.py
+++ b/mem0/__init__.py
@@ -2,5 +2,5 @@ import importlib.metadata
 
 __version__ = importlib.metadata.version("mem0ai")
 
-from mem0.client.main import MemoryClient  # noqa
+from mem0.client.main import MemoryClient, AsyncMemoryClient  # noqa
 from mem0.memory.main import Memory  # noqa

diff --git a/mem0/client/main.py b/mem0/client/main.py
index 99500eff..ef90183f 100644
--- a/mem0/client/main.py
+++ b/mem0/client/main.py
@@ -56,12 +56,12 @@ class MemoryClient:
     """
 
     def __init__(
-            self,
-            api_key: Optional[str] = None,
-            host: Optional[str] = None,
-            organization: Optional[str] = None,
-            project: Optional[str] = None
-        ):
+        self,
+        api_key: Optional[str] = None,
+        host: Optional[str] = None,
+        organization: Optional[str] = None,
+        project: Optional[str] = None,
+    ):
         """Initialize the MemoryClient.
Args: @@ -275,9 +275,7 @@ class MemoryClient: params = {"org_name": self.organization, "project_name": self.project} entities = self.users() for entity in entities["results"]: - response = self.client.delete( - f"/v1/entities/{entity['type']}/{entity['id']}/", params=params - ) + response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params) response.raise_for_status() capture_client_event("client.delete_users", self) @@ -362,3 +360,120 @@ class MemoryClient: kwargs["run_id"] = kwargs.pop("session_id") return {k: v for k, v in kwargs.items() if v is not None} + + +class AsyncMemoryClient: + """Asynchronous client for interacting with the Mem0 API.""" + + def __init__( + self, + api_key: Optional[str] = None, + host: Optional[str] = None, + organization: Optional[str] = None, + project: Optional[str] = None, + ): + self.sync_client = MemoryClient(api_key, host, organization, project) + self.async_client = httpx.AsyncClient( + base_url=self.sync_client.host, + headers=self.sync_client.client.headers, + timeout=60, + ) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.async_client.aclose() + + @api_error_handler + async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]: + payload = self.sync_client._prepare_payload(messages, kwargs) + response = await self.async_client.post("/v1/memories/", json=payload) + response.raise_for_status() + capture_client_event("async_client.add", self.sync_client) + return response.json() + + @api_error_handler + async def get(self, memory_id: str) -> Dict[str, Any]: + response = await self.async_client.get(f"/v1/memories/{memory_id}/") + response.raise_for_status() + capture_client_event("async_client.get", self.sync_client) + return response.json() + + @api_error_handler + async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]: + params = self.sync_client._prepare_params(kwargs) + if version == "v1": + response = await self.async_client.get(f"/{version}/memories/", params=params) + elif version == "v2": + response = await self.async_client.post(f"/{version}/memories/", json=params) + response.raise_for_status() + capture_client_event( + "async_client.get_all", self.sync_client, {"filters": len(params), "limit": kwargs.get("limit", 100)} + ) + return response.json() + + @api_error_handler + async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]: + payload = {"query": query} + payload.update(self.sync_client._prepare_params(kwargs)) + response = await self.async_client.post(f"/{version}/memories/search/", json=payload) + response.raise_for_status() + capture_client_event("async_client.search", self.sync_client, {"limit": kwargs.get("limit", 100)}) + return response.json() + + @api_error_handler + async def update(self, memory_id: str, data: str) -> Dict[str, Any]: + response = await self.async_client.put(f"/v1/memories/{memory_id}/", json={"text": data}) + response.raise_for_status() + capture_client_event("async_client.update", self.sync_client) + return response.json() + + @api_error_handler + async def delete(self, memory_id: str) -> Dict[str, Any]: + response = await self.async_client.delete(f"/v1/memories/{memory_id}/") + response.raise_for_status() + capture_client_event("async_client.delete", self.sync_client) + return response.json() + + @api_error_handler + async def delete_all(self, **kwargs) -> Dict[str, str]: + params = self.sync_client._prepare_params(kwargs) 
+ response = await self.async_client.delete("/v1/memories/", params=params) + response.raise_for_status() + capture_client_event("async_client.delete_all", self.sync_client, {"params": len(params)}) + return response.json() + + @api_error_handler + async def history(self, memory_id: str) -> List[Dict[str, Any]]: + response = await self.async_client.get(f"/v1/memories/{memory_id}/history/") + response.raise_for_status() + capture_client_event("async_client.history", self.sync_client) + return response.json() + + @api_error_handler + async def users(self) -> Dict[str, Any]: + params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project} + response = await self.async_client.get("/v1/entities/", params=params) + response.raise_for_status() + capture_client_event("async_client.users", self.sync_client) + return response.json() + + @api_error_handler + async def delete_users(self) -> Dict[str, str]: + params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project} + entities = await self.users() + for entity in entities["results"]: + response = await self.async_client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params) + response.raise_for_status() + capture_client_event("async_client.delete_users", self.sync_client) + return {"message": "All users, agents, and sessions deleted."} + + @api_error_handler + async def reset(self) -> Dict[str, str]: + await self.delete_users() + capture_client_event("async_client.reset", self.sync_client) + return {"message": "Client reset successful. All users and memories deleted."} + + async def chat(self): + raise NotImplementedError("Chat is not implemented yet") diff --git a/mem0/configs/base.py b/mem0/configs/base.py index c9293c25..55d6b2e9 100644 --- a/mem0/configs/base.py +++ b/mem0/configs/base.py @@ -73,4 +73,6 @@ class AzureConfig(BaseModel): azure_deployment: str = Field(description="The name of the Azure deployment.", default=None) azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None) api_version: str = Field(description="The version of the Azure API being used.", default=None) - default_headers: Optional[Dict[str, str]] = Field(description="Headers to include in requests to the Azure API.", default=None) + default_headers: Optional[Dict[str, str]] = Field( + description="Headers to include in requests to the Azure API.", default=None + ) diff --git a/mem0/embeddings/gemini.py b/mem0/embeddings/gemini.py index 7ef429a9..210848e3 100644 --- a/mem0/embeddings/gemini.py +++ b/mem0/embeddings/gemini.py @@ -1,5 +1,6 @@ import os from typing import Optional + import google.generativeai as genai from mem0.configs.embeddings.base import BaseEmbedderConfig @@ -9,7 +10,7 @@ from mem0.embeddings.base import EmbeddingBase class GoogleGenAIEmbedding(EmbeddingBase): def __init__(self, config: Optional[BaseEmbedderConfig] = None): super().__init__(config) - + self.config.model = self.config.model or "models/text-embedding-004" self.config.embedding_dims = self.config.embedding_dims or 768 @@ -27,4 +28,4 @@ class GoogleGenAIEmbedding(EmbeddingBase): """ text = text.replace("\n", " ") response = genai.embed_content(model=self.config.model, content=text) - return response['embedding'] \ No newline at end of file + return response["embedding"] diff --git a/mem0/embeddings/huggingface.py b/mem0/embeddings/huggingface.py index d2bf5b82..b8641cac 100644 --- a/mem0/embeddings/huggingface.py +++ b/mem0/embeddings/huggingface.py @@ -14,7 +14,7 @@ 
class HuggingFaceEmbedding(EmbeddingBase): self.model = SentenceTransformer(self.config.model, **self.config.model_kwargs) - self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension() + self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension() def embed(self, text): """ @@ -26,4 +26,4 @@ class HuggingFaceEmbedding(EmbeddingBase): Returns: list: The embedding vector. """ - return self.model.encode(text, convert_to_numpy = True).tolist() + return self.model.encode(text, convert_to_numpy=True).tolist() diff --git a/mem0/llms/gemini.py b/mem0/llms/gemini.py index a475226c..7fdf5e4e 100644 --- a/mem0/llms/gemini.py +++ b/mem0/llms/gemini.py @@ -6,7 +6,9 @@ try: from google.generativeai import GenerativeModel from google.generativeai.types import content_types except ImportError: - raise ImportError("The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'.") + raise ImportError( + "The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'." + ) from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase @@ -44,8 +46,8 @@ class GeminiLLM(LLMBase): if fn := part.function_call: processed_response["tool_calls"].append( { - "name": fn.name, - "arguments": {key:val for key, val in fn.args.items()}, + "name": fn.name, + "arguments": {key: val for key, val in fn.args.items()}, } ) @@ -53,7 +55,7 @@ class GeminiLLM(LLMBase): else: return response.candidates[0].content.parts[0].text - def _reformat_messages(self, messages : List[Dict[str, str]]): + def _reformat_messages(self, messages: List[Dict[str, str]]): """ Reformat messages for Gemini. @@ -71,9 +73,8 @@ class GeminiLLM(LLMBase): else: content = message["content"] - - new_messages.append({"parts": content, - "role": "model" if message["role"] == "model" else "user"}) + + new_messages.append({"parts": content, "role": "model" if message["role"] == "model" else "user"}) return new_messages @@ -89,24 +90,24 @@ class GeminiLLM(LLMBase): """ def remove_additional_properties(data): - """Recursively removes 'additionalProperties' from nested dictionaries.""" - - if isinstance(data, dict): - filtered_dict = { - key: remove_additional_properties(value) - for key, value in data.items() - if not (key == "additionalProperties") - } - return filtered_dict - else: - return data - + """Recursively removes 'additionalProperties' from nested dictionaries.""" + + if isinstance(data, dict): + filtered_dict = { + key: remove_additional_properties(value) + for key, value in data.items() + if not (key == "additionalProperties") + } + return filtered_dict + else: + return data + new_tools = [] if tools: for tool in tools: - func = tool['function'].copy() - new_tools.append({"function_declarations":[remove_additional_properties(func)]}) - + func = tool["function"].copy() + new_tools.append({"function_declarations": [remove_additional_properties(func)]}) + return new_tools else: return None @@ -142,13 +143,21 @@ class GeminiLLM(LLMBase): params["response_schema"] = list[response_format] if tool_choice: tool_config = content_types.to_tool_config( - {"function_calling_config": - {"mode": tool_choice, "allowed_function_names": [tool['function']['name'] for tool in tools] if tool_choice == "any" else None} - }) + { + "function_calling_config": { + "mode": tool_choice, + "allowed_function_names": [tool["function"]["name"] for tool in tools] + if tool_choice == "any" + 
else None, + } + } + ) - response = self.client.generate_content(contents = self._reformat_messages(messages), - tools = self._reformat_tools(tools), - generation_config = genai.GenerationConfig(**params), - tool_config = tool_config) + response = self.client.generate_content( + contents=self._reformat_messages(messages), + tools=self._reformat_tools(tools), + generation_config=genai.GenerationConfig(**params), + tool_config=tool_config, + ) return self._parse_response(response, tools) diff --git a/mem0/llms/openai.py b/mem0/llms/openai.py index c585162f..a9c302f8 100644 --- a/mem0/llms/openai.py +++ b/mem0/llms/openai.py @@ -18,7 +18,9 @@ class OpenAILLM(LLMBase): if os.environ.get("OPENROUTER_API_KEY"): # Use OpenRouter self.client = OpenAI( api_key=os.environ.get("OPENROUTER_API_KEY"), - base_url=self.config.openrouter_base_url or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1", + base_url=self.config.openrouter_base_url + or os.getenv("OPENROUTER_API_BASE") + or "https://openrouter.ai/api/v1", ) else: api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") diff --git a/mem0/memory/main.py b/mem0/memory/main.py index c8c3c229..c4e7b212 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -4,7 +4,6 @@ import json import logging import uuid import warnings -from collections import defaultdict from datetime import datetime from typing import Any, Dict @@ -186,7 +185,9 @@ class Memory(MemoryBase): logging.info(resp) try: if resp["event"] == "ADD": - memory_id = self._create_memory(data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata) + memory_id = self._create_memory( + data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata + ) returned_memories.append( { "id": memory_id, @@ -195,7 +196,12 @@ class Memory(MemoryBase): } ) elif resp["event"] == "UPDATE": - self._update_memory(memory_id=resp["id"], data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata) + self._update_memory( + memory_id=resp["id"], + data=resp["text"], + existing_embeddings=new_message_embeddings, + metadata=metadata, + ) returned_memories.append( { "id": resp["id"], @@ -304,10 +310,14 @@ class Memory(MemoryBase): with concurrent.futures.ThreadPoolExecutor() as executor: future_memories = executor.submit(self._get_all_from_vector_store, filters, limit) future_graph_entities = ( - executor.submit(self.graph.get_all, filters, limit) if self.version == "v1.1" and self.enable_graph else None + executor.submit(self.graph.get_all, filters, limit) + if self.version == "v1.1" and self.enable_graph + else None ) - concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories]) + concurrent.futures.wait( + [future_memories, future_graph_entities] if future_graph_entities else [future_memories] + ) all_memories = future_memories.result() graph_entities = future_graph_entities.result() if future_graph_entities else None @@ -399,7 +409,9 @@ class Memory(MemoryBase): else None ) - concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories]) + concurrent.futures.wait( + [future_memories, future_graph_entities] if future_graph_entities else [future_memories] + ) original_memories = future_memories.result() graph_entities = future_graph_entities.result() if future_graph_entities else None diff --git a/mem0/proxy/main.py b/mem0/proxy/main.py index d1db29b2..f52177d0 100644 --- a/mem0/proxy/main.py +++ b/mem0/proxy/main.py @@ 
-181,9 +181,9 @@ class Completions: def _format_query_with_memories(self, messages, relevant_memories): # Check if self.mem0_client is an instance of Memory or MemoryClient - + if isinstance(self.mem0_client, mem0.memory.main.Memory): - memories_text = "\n".join(memory["memory"] for memory in relevant_memories['results']) + memories_text = "\n".join(memory["memory"] for memory in relevant_memories["results"]) elif isinstance(self.mem0_client, mem0.client.main.MemoryClient): memories_text = "\n".join(memory["memory"] for memory in relevant_memories) return f"- Relevant Memories/Facts: {memories_text}\n\n- User Question: {messages[-1]['content']}" diff --git a/mem0/vector_stores/milvus.py b/mem0/vector_stores/milvus.py index a2fb8002..013fc0e3 100644 --- a/mem0/vector_stores/milvus.py +++ b/mem0/vector_stores/milvus.py @@ -76,11 +76,8 @@ class MilvusDB(VectorStoreBase): schema = CollectionSchema(fields, enable_dynamic_field=True) index = self.client.prepare_index_params( - field_name="vectors", - metric_type=metric_type, - index_type="AUTOINDEX", - index_name="vector_index" - ) + field_name="vectors", metric_type=metric_type, index_type="AUTOINDEX", index_name="vector_index" + ) self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index) def insert(self, ids, vectors, payloads, **kwargs: Optional[dict[str, any]]):
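
A quick usage note on the new client: because `AsyncMemoryClient` implements `__aenter__` and `__aexit__`, it can be used as an async context manager that closes the underlying `httpx.AsyncClient` on exit. A minimal sketch of that pattern, assuming a valid API key (the user ID and message text are illustrative):

```python
import asyncio

from mem0 import AsyncMemoryClient


async def main():
    # __aexit__ closes the shared httpx.AsyncClient when the block exits.
    async with AsyncMemoryClient(api_key="your-api-key") as client:
        await client.add("I'm travelling to SF", user_id="john")
        results = await client.search("Where is john travelling?", user_id="john")
        print(results)


asyncio.run(main())
```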
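
Since every method awaits the single `httpx.AsyncClient` created in `__init__`, independent requests can also be dispatched concurrently rather than serially. A sketch of that pattern with `asyncio.gather` (the user IDs and texts are made up for illustration):

```python
import asyncio

from mem0 import AsyncMemoryClient


async def main():
    async with AsyncMemoryClient(api_key="your-api-key") as client:
        # The three requests share one connection pool and run concurrently,
        # which is the main practical win over the synchronous MemoryClient.
        results = await asyncio.gather(
            client.add("I like sushi", user_id="alice"),
            client.add("I am vegetarian", user_id="bob"),
            client.get_all(user_id="carol"),
        )
        print(results)


asyncio.run(main())
```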
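
One detail of `get_all` worth calling out: with `version="v1"` (the default) the prepared filters travel as query parameters on a GET, while `version="v2"` sends the same prepared params as the JSON body of a POST. A hedged sketch of both paths (whether the v2 endpoint accepts these exact keyword filters is an assumption here):

```python
import asyncio

from mem0 import AsyncMemoryClient


async def main():
    async with AsyncMemoryClient(api_key="your-api-key") as client:
        # v1 (default): kwargs become query parameters on GET /v1/memories/
        v1_memories = await client.get_all(user_id="john", limit=50)

        # v2: the prepared params are POSTed as JSON to /v2/memories/
        # (assumed to accept the same keyword filters as v1)
        v2_memories = await client.get_all(version="v2", user_id="john", limit=50)
        print(v1_memories, v2_memories)


asyncio.run(main())
```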