diff --git a/docs/llms.mdx b/docs/llms.mdx
index cab9475b..278c25b1 100644
--- a/docs/llms.mdx
+++ b/docs/llms.mdx
@@ -10,6 +10,7 @@ Mem0 includes built-in support for various popular large language models. Memory
+
 ## OpenAI
@@ -92,3 +93,33 @@ config = {
 m = Memory.from_config(config)
 m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
 ```
+
+## AWS Bedrock
+
+### Setup
+- Before using the AWS Bedrock LLM, make sure you have the appropriate model access enabled in the [Bedrock Console](https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess).
+- You will also need to authenticate the `boto3` client using one of the methods described in the [AWS documentation](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#configuring-credentials).
+- Set the `AWS_REGION`, `AWS_ACCESS_KEY`, and `AWS_SECRET_ACCESS_KEY` environment variables, either in your shell or in code as shown below.
+
+```python
+import os
+from mem0 import Memory
+
+os.environ['AWS_REGION'] = 'us-east-1'
+os.environ["AWS_ACCESS_KEY"] = "xx"
+os.environ["AWS_SECRET_ACCESS_KEY"] = "xx"
+
+config = {
+    "llm": {
+        "provider": "aws_bedrock",
+        "config": {
+            "model": "arn:aws:bedrock:us-east-1:123456789012:model/your-model-name",
+            "temperature": 0.2,
+            "max_tokens": 1500,
+        }
+    }
+}
+
+m = Memory.from_config(config)
+m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
+```
diff --git a/mem0/llms/aws_bedrock.py b/mem0/llms/aws_bedrock.py
new file mode 100644
index 00000000..9e1912e2
--- /dev/null
+++ b/mem0/llms/aws_bedrock.py
@@ -0,0 +1,196 @@
+import os
+import json
+from typing import Dict, List, Optional, Any
+
+import boto3
+
+from mem0.llms.base import LLMBase
+
+
+class AWSBedrockLLM(LLMBase):
+    def __init__(self, model="cohere.command-r-v1:0"):
+        self.client = boto3.client("bedrock-runtime", region_name=os.environ.get("AWS_REGION"), aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"), aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"))
+        self.model = model
+
+    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
+        """
+        Formats a list of messages into the required prompt structure for the model.
+
+        Args:
+            messages (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a message.
+                                             Each dictionary contains 'role' and 'content' keys.
+
+        Returns:
+            str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
+        """
+        formatted_messages = []
+        for message in messages:
+            role = message['role'].capitalize()
+            content = message['content']
+            formatted_messages.append(f"\n\n{role}: {content}")
+
+        return "".join(formatted_messages) + "\n\nAssistant:"
+
+    def _parse_response(self, response, tools) -> str:
+        """
+        Process the response based on whether tools are used or not.
+
+        Args:
+            response: The raw response from API.
+            tools: The list of tools provided in the request.
+
+        Returns:
+            str or dict: The processed response.
+ """ + if tools: + processed_response = { + "tool_calls": [] + } + + if response["output"]["message"]["content"]: + for item in response["output"]["message"]["content"]: + if "toolUse" in item: + processed_response["tool_calls"].append({ + "name": item["toolUse"]["name"], + "arguments": item["toolUse"]["input"] + }) + + return processed_response + + response_body = json.loads(response['body'].read().decode()) + return response_body.get('completion', '') + + def _prepare_input( + self, + provider: str, + model: str, + prompt: str, + model_kwargs: Optional[Dict[str, Any]] = {}, + ) -> Dict[str, Any]: + """ + Prepares the input dictionary for the specified provider's model by mapping and renaming + keys in the input based on the provider's requirements. + + Args: + provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon"). + model (str): The name or identifier of the model being used. + prompt (str): The text prompt to be processed by the model. + model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements. + + Returns: + Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider. + """ + + input_body = {"prompt": prompt, **model_kwargs} + + provider_mappings = { + "meta": {"max_tokens_to_sample": "max_gen_len"}, + "ai21": {"max_tokens_to_sample": "maxTokens", "top_p": "topP"}, + "mistral": {"max_tokens_to_sample": "max_tokens"}, + "cohere": {"max_tokens_to_sample": "max_tokens", "top_p": "p"}, + } + + if provider in provider_mappings: + for old_key, new_key in provider_mappings[provider].items(): + if old_key in input_body: + input_body[new_key] = input_body.pop(old_key) + + if provider == "cohere" and "cohere.command-r" in model: + input_body["message"] = input_body.pop("prompt") + + if provider == "amazon": + input_body = { + "inputText": prompt, + "textGenerationConfig": { + "maxTokenCount": model_kwargs.get("max_tokens_to_sample"), + "topP": model_kwargs.get("top_p"), + "temperature": model_kwargs.get("temperature") + } + } + input_body["textGenerationConfig"] = {k: v for k, v in input_body["textGenerationConfig"].items() if v is not None} + + return input_body + + def _convert_tool_format(self, original_tools): + """ + Converts a list of tools from their original format to a new standardized format. + + Args: + original_tools (list): A list of dictionaries representing the original tools, each containing a 'type' key and corresponding details. + + Returns: + list: A list of dictionaries representing the tools in the new standardized format. + """ + new_tools = [] + + for tool in original_tools: + if tool['type'] == 'function': + function = tool['function'] + new_tool = { + "toolSpec": { + "name": function['name'], + "description": function['description'], + "inputSchema": { + "json": { + "type": "object", + "properties": {}, + "required": function['parameters'].get('required', []) + } + } + } + } + + for prop, details in function['parameters'].get('properties', {}).items(): + new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = { + "type": details.get('type', 'string'), + "description": details.get('description', '') + } + + new_tools.append(new_tool) + + return new_tools + + def generate_response( + self, + messages: List[Dict[str, str]], + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", + ): + """ + Generate a response based on the given messages using AWS Bedrock. 
+ + Args: + messages (list): List of message dicts containing 'role' and 'content'. + tools (list, optional): List of tools that the model can call. Defaults to None. + tool_choice (str, optional): Tool choice method. Defaults to "auto". + + Returns: + str: The generated response. + """ + + if tools: + # Use converse method when tools are provided + messages = [{"role": "user", "content": [{"text": message["content"]} for message in messages]}] + tools_config = {"tools": self._convert_tool_format(tools)} + + response = self.client.converse( + modelId=self.model, + messages=messages, + toolConfig=tools_config + ) + print("Tools response: ", response) + else: + # Use invoke_model method when no tools are provided + prompt = self._format_messages(messages) + provider = self.model.split(".")[0] + input_body = self._prepare_input(provider, self.model, prompt) + body = json.dumps(input_body) + + response = self.client.invoke_model( + body=body, + modelId=self.model, + accept='application/json', + contentType='application/json' + ) + + return self._parse_response(response, tools) diff --git a/mem0/llms/groq.py b/mem0/llms/groq.py index 9d662899..948625fd 100644 --- a/mem0/llms/groq.py +++ b/mem0/llms/groq.py @@ -1,3 +1,4 @@ +import json from typing import Dict, List, Optional from groq import Groq @@ -10,6 +11,34 @@ class GroqLLM(LLMBase): self.client = Groq() self.model = model + def _parse_response(self, response, tools): + """ + Process the response based on whether tools are used or not. + + Args: + response: The raw response from API. + tools: The list of tools provided in the request. + + Returns: + str or dict: The processed response. + """ + if tools: + processed_response = { + "content": response.choices[0].message.content, + "tool_calls": [] + } + + if response.choices[0].message.tool_calls: + for tool_call in response.choices[0].message.tool_calls: + processed_response["tool_calls"].append({ + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments) + }) + + return processed_response + else: + return response.choices[0].message.content + def generate_response( self, messages: List[Dict[str, str]], @@ -37,4 +66,4 @@ class GroqLLM(LLMBase): params["tool_choice"] = tool_choice response = self.client.chat.completions.create(**params) - return response + return self._parse_response(response, tools) diff --git a/mem0/llms/openai.py b/mem0/llms/openai.py index 2f614c34..f87f78d7 100644 --- a/mem0/llms/openai.py +++ b/mem0/llms/openai.py @@ -1,3 +1,4 @@ +import json from typing import Dict, List, Optional from openai import OpenAI @@ -9,6 +10,34 @@ class OpenAILLM(LLMBase): def __init__(self, model="gpt-4o"): self.client = OpenAI() self.model = model + + def _parse_response(self, response, tools): + """ + Process the response based on whether tools are used or not. + + Args: + response: The raw response from API. + tools: The list of tools provided in the request. + + Returns: + str or dict: The processed response. 
+ """ + if tools: + processed_response = { + "content": response.choices[0].message.content, + "tool_calls": [] + } + + if response.choices[0].message.tool_calls: + for tool_call in response.choices[0].message.tool_calls: + processed_response["tool_calls"].append({ + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments) + }) + + return processed_response + else: + return response.choices[0].message.content def generate_response( self, @@ -37,4 +66,4 @@ class OpenAILLM(LLMBase): params["tool_choice"] = tool_choice response = self.client.chat.completions.create(**params) - return response + return self._parse_response(response, tools) diff --git a/mem0/llms/together.py b/mem0/llms/together.py index 5868d9c1..e497a80f 100644 --- a/mem0/llms/together.py +++ b/mem0/llms/together.py @@ -1,3 +1,4 @@ +import json from typing import Dict, List, Optional from together import Together @@ -9,6 +10,34 @@ class TogetherLLM(LLMBase): def __init__(self, model="mistralai/Mixtral-8x7B-Instruct-v0.1"): self.client = Together() self.model = model + + def _parse_response(self, response, tools): + """ + Process the response based on whether tools are used or not. + + Args: + response: The raw response from API. + tools: The list of tools provided in the request. + + Returns: + str or dict: The processed response. + """ + if tools: + processed_response = { + "content": response.choices[0].message.content, + "tool_calls": [] + } + + if response.choices[0].message.tool_calls: + for tool_call in response.choices[0].message.tool_calls: + processed_response["tool_calls"].append({ + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments) + }) + + return processed_response + else: + return response.choices[0].message.content def generate_response( self, @@ -37,4 +66,4 @@ class TogetherLLM(LLMBase): params["tool_choice"] = tool_choice response = self.client.chat.completions.create(**params) - return response + return self._parse_response(response, tools) diff --git a/mem0/memory/main.py b/mem0/memory/main.py index 89e1e0ae..c9327a4b 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -149,7 +149,6 @@ class Memory(MemoryBase): {"role": "user", "content": prompt}, ] ) - extracted_memories = extracted_memories.choices[0].message.content existing_memories = self.vector_store.search( name=self.collection_name, query=embeddings, @@ -176,8 +175,7 @@ class Memory(MemoryBase): # Add tools for noop, add, update, delete memory. 
tools = [ADD_MEMORY_TOOL, UPDATE_MEMORY_TOOL, DELETE_MEMORY_TOOL] response = self.llm.generate_response(messages=messages, tools=tools) - response_message = response.choices[0].message - tool_calls = response_message.tool_calls + tool_calls = response["tool_calls"] response = [] if tool_calls: @@ -188,9 +186,9 @@ class Memory(MemoryBase): "delete_memory": self._delete_memory_tool, } for tool_call in tool_calls: - function_name = tool_call.function.name + function_name = tool_call["name"] function_to_call = available_functions[function_name] - function_args = json.loads(tool_call.function.arguments) + function_args = tool_call["arguments"] logging.info( f"[openai_func] func: {function_name}, args: {function_args}" ) diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index 45772361..1dc571f8 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -12,7 +12,8 @@ class LlmFactory: "ollama": "mem0.llms.ollama.py.OllamaLLM", "openai": "mem0.llms.openai.OpenAILLM", "groq": "mem0.llms.groq.GroqLLM", - "together": "mem0.llms.together.TogetherLLM" + "together": "mem0.llms.together.TogetherLLM", + "aws_bedrock": "mem0.llms.aws_bedrock.AWSBedrockLLM" } @classmethod diff --git a/poetry.lock b/poetry.lock index ea224c2a..8a3a4db0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -227,6 +227,47 @@ files = [ {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, ] +[[package]] +name = "boto3" +version = "1.34.144" +description = "The AWS SDK for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "boto3-1.34.144-py3-none-any.whl", hash = "sha256:b8433d481d50b68a0162c0379c0dd4aabfc3d1ad901800beb5b87815997511c1"}, + {file = "boto3-1.34.144.tar.gz", hash = "sha256:2f3e88b10b8fcc5f6100a9d74cd28230edc9d4fa226d99dd40a3ab38ac213673"}, +] + +[package.dependencies] +botocore = ">=1.34.144,<1.35.0" +jmespath = ">=0.7.1,<2.0.0" +s3transfer = ">=0.10.0,<0.11.0" + +[package.extras] +crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] + +[[package]] +name = "botocore" +version = "1.34.144" +description = "Low-level, data-driven core of boto 3." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "botocore-1.34.144-py3-none-any.whl", hash = "sha256:a2cf26e1bf10d5917a2285e50257bc44e94a1d16574f282f3274f7a5d8d1f08b"}, + {file = "botocore-1.34.144.tar.gz", hash = "sha256:4215db28d25309d59c99507f1f77df9089e5bebbad35f6e19c7c44ec5383a3e8"}, +] + +[package.dependencies] +jmespath = ">=0.7.1,<2.0.0" +python-dateutil = ">=2.1,<3.0.0" +urllib3 = [ + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, + {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, +] + +[package.extras] +crt = ["awscrt (==0.20.11)"] + [[package]] name = "certifi" version = "2024.7.4" @@ -1017,6 +1058,17 @@ docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alab qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] +[[package]] +name = "jmespath" +version = "1.0.1" +description = "JSON Matching Expressions" +optional = false +python-versions = ">=3.7" +files = [ + {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, + {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, +] + [[package]] name = "jupyter-client" version = "8.6.2" @@ -2105,6 +2157,23 @@ files = [ {file = "ruff-0.4.10.tar.gz", hash = "sha256:3aa4f2bc388a30d346c56524f7cacca85945ba124945fe489952aadb6b5cd804"}, ] +[[package]] +name = "s3transfer" +version = "0.10.2" +description = "An Amazon S3 Transfer Manager" +optional = false +python-versions = ">=3.8" +files = [ + {file = "s3transfer-0.10.2-py3-none-any.whl", hash = "sha256:eca1c20de70a39daee580aef4986996620f365c4e0fda6a86100231d62f1bf69"}, + {file = "s3transfer-0.10.2.tar.gz", hash = "sha256:0711534e9356d3cc692fdde846b4a1e4b0cb6519971860796e6bc4c7aea00ef6"}, +] + +[package.dependencies] +botocore = ">=1.33.2,<2.0a.0" + +[package.extras] +crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] + [[package]] name = "setuptools" version = "70.3.0" @@ -2308,6 +2377,22 @@ files = [ {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] +[[package]] +name = "urllib3" +version = "1.26.19" +description = "HTTP library with thread-safe connection pooling, file post, and more." 
+optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "urllib3-1.26.19-py2.py3-none-any.whl", hash = "sha256:37a0344459b199fce0e80b0d3569837ec6b6937435c5244e7fd73fa6006830f3"}, + {file = "urllib3-1.26.19.tar.gz", hash = "sha256:3e3d753a8618b86d7de333b4223005f68720bcd6a7d2bcb9fbd2229ec7c1e429"}, +] + +[package.extras] +brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] +secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] +socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] + [[package]] name = "urllib3" version = "2.2.2" @@ -2457,4 +2542,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "29b68f540e0567d310cbf2f8f3137a0bbd7ecaefeaafc95273d1bdaddfeac1bd" +content-hash = "619f45c245c60ed6e534c177eeb6e1335d515e8b17ae760f867abc1d611258c5" diff --git a/pyproject.toml b/pyproject.toml index 32b95354..83bce0ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ openai = "^1.33.0" posthog = "^3.5.0" groq = "^0.9.0" together = "^1.2.1" +boto3 = "^1.34.144" [tool.poetry.group.test.dependencies] pytest = "^8.2.2" diff --git a/tests/llms/test_groq.py b/tests/llms/test_groq.py index 6c6ede3e..c5820708 100644 --- a/tests/llms/test_groq.py +++ b/tests/llms/test_groq.py @@ -27,7 +27,7 @@ def test_generate_response_without_tools(mock_groq_client): model="llama3-70b-8192", messages=messages ) - assert response.choices[0].message.content == "I'm doing well, thank you for asking!" + assert response == "I'm doing well, thank you for asking!" def test_generate_response_with_tools(mock_groq_client): @@ -54,7 +54,15 @@ def test_generate_response_with_tools(mock_groq_client): ] mock_response = Mock() - mock_response.choices = [Mock(message=Mock(content="Memory added successfully."))] + mock_message = Mock() + mock_message.content = "I've added the memory for you." + + mock_tool_call = Mock() + mock_tool_call.function.name = "add_memory" + mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}' + + mock_message.tool_calls = [mock_tool_call] + mock_response.choices = [Mock(message=mock_message)] mock_groq_client.chat.completions.create.return_value = mock_response response = llm.generate_response(messages, tools=tools) @@ -65,5 +73,9 @@ def test_generate_response_with_tools(mock_groq_client): tools=tools, tool_choice="auto" ) - assert response.choices[0].message.content == "Memory added successfully." + + assert response["content"] == "I've added the memory for you." + assert len(response["tool_calls"]) == 1 + assert response["tool_calls"][0]["name"] == "add_memory" + assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'} \ No newline at end of file diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index e63f7abb..535ba355 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -27,7 +27,7 @@ def test_generate_response_without_tools(mock_openai_client): model="gpt-4o", messages=messages ) - assert response.choices[0].message.content == "I'm doing well, thank you for asking!" + assert response == "I'm doing well, thank you for asking!" 
def test_generate_response_with_tools(mock_openai_client): @@ -54,7 +54,15 @@ def test_generate_response_with_tools(mock_openai_client): ] mock_response = Mock() - mock_response.choices = [Mock(message=Mock(content="Memory added successfully."))] + mock_message = Mock() + mock_message.content = "I've added the memory for you." + + mock_tool_call = Mock() + mock_tool_call.function.name = "add_memory" + mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}' + + mock_message.tool_calls = [mock_tool_call] + mock_response.choices = [Mock(message=mock_message)] mock_openai_client.chat.completions.create.return_value = mock_response response = llm.generate_response(messages, tools=tools) @@ -65,5 +73,9 @@ def test_generate_response_with_tools(mock_openai_client): tools=tools, tool_choice="auto" ) - assert response.choices[0].message.content == "Memory added successfully." + + assert response["content"] == "I've added the memory for you." + assert len(response["tool_calls"]) == 1 + assert response["tool_calls"][0]["name"] == "add_memory" + assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'} \ No newline at end of file diff --git a/tests/llms/test_together.py b/tests/llms/test_together.py index 27107e20..dad2bdca 100644 --- a/tests/llms/test_together.py +++ b/tests/llms/test_together.py @@ -27,7 +27,7 @@ def test_generate_response_without_tools(mock_together_client): model="mistralai/Mixtral-8x7B-Instruct-v0.1", messages=messages ) - assert response.choices[0].message.content == "I'm doing well, thank you for asking!" + assert response == "I'm doing well, thank you for asking!" def test_generate_response_with_tools(mock_together_client): @@ -54,7 +54,15 @@ def test_generate_response_with_tools(mock_together_client): ] mock_response = Mock() - mock_response.choices = [Mock(message=Mock(content="Memory added successfully."))] + mock_message = Mock() + mock_message.content = "I've added the memory for you." + + mock_tool_call = Mock() + mock_tool_call.function.name = "add_memory" + mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}' + + mock_message.tool_calls = [mock_tool_call] + mock_response.choices = [Mock(message=mock_message)] mock_together_client.chat.completions.create.return_value = mock_response response = llm.generate_response(messages, tools=tools) @@ -65,5 +73,9 @@ def test_generate_response_with_tools(mock_together_client): tools=tools, tool_choice="auto" ) - assert response.choices[0].message.content == "Memory added successfully." + + assert response["content"] == "I've added the memory for you." + assert len(response["tool_calls"]) == 1 + assert response["tool_calls"][0]["name"] == "add_memory" + assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'} \ No newline at end of file
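
Across the providers touched by this change, `generate_response` now returns a plain string when no tools are passed, and a normalized dict (`content` plus `tool_calls` with already-parsed `arguments`) when tools are passed, which is what the updated `mem0/memory/main.py` consumes. The sketch below is a minimal, illustrative example of consuming that normalized shape; the handler functions and the example payload are hypothetical stand-ins, not part of the mem0 API.

```python
# Minimal sketch of consuming the normalized tool-call response shape.
# The handlers and the example payload below are hypothetical stand-ins.

def add_memory(data: str) -> str:
    # Placeholder for the real add-memory tool handler.
    return f"added: {data}"

def update_memory(memory_id: str, data: str) -> str:
    # Placeholder for the real update-memory tool handler.
    return f"updated {memory_id}: {data}"

available_functions = {
    "add_memory": add_memory,
    "update_memory": update_memory,
}

# Example of what generate_response(messages, tools=...) now returns:
# a plain dict rather than a provider-specific response object.
response = {
    "content": "I've added the memory for you.",
    "tool_calls": [
        {"name": "add_memory", "arguments": {"data": "Likes to play cricket on weekends"}},
    ],
}

for tool_call in response.get("tool_calls", []):
    handler = available_functions[tool_call["name"]]
    # 'arguments' is already a parsed dict (json.loads happens inside _parse_response),
    # so it can be splatted straight into the handler.
    result = handler(**tool_call["arguments"])
    print(result)
```

Without tools, the same methods return the message content directly (or the decoded body for Bedrock's `invoke_model`), so callers can treat the result as a plain string.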