Add AWS Bedrock support (#1482)
@@ -10,6 +10,7 @@ Mem0 includes built-in support for various popular large language models. Memory
<Card title="OpenAI" href="#openai"></Card>
<Card title="Groq" href="#groq"></Card>
<Card title="Together" href="#together"></Card>
<Card title="AWS Bedrock" href="#aws_bedrock"></Card>
</CardGroup>

## OpenAI
@@ -92,3 +93,33 @@ config = {
m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
```

## AWS Bedrock

### Setup
- Before using the AWS Bedrock LLM, make sure you have the appropriate model access from the [Bedrock Console](https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess).
- You will also need to authenticate the `boto3` client using one of the methods described in the [AWS documentation](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#configuring-credentials).
- Export `AWS_REGION`, `AWS_ACCESS_KEY`, and `AWS_SECRET_ACCESS_KEY` as environment variables.

```python
import os

from mem0 import Memory

os.environ["AWS_REGION"] = "us-east-1"
os.environ["AWS_ACCESS_KEY"] = "xx"
os.environ["AWS_SECRET_ACCESS_KEY"] = "xx"

config = {
    "llm": {
        "provider": "aws_bedrock",
        "config": {
            "model": "arn:aws:bedrock:us-east-1:123456789012:model/your-model-name",
            "temperature": 0.2,
            "max_tokens": 1500,
        }
    }
}

m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
```
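Once memories are added, they can be read back through the same `Memory` instance. A minimal sketch, assuming mem0's standard `search` and `get_all` retrieval methods:

```python
# Semantic lookup scoped to a user.
related = m.search("What does Alice do on weekends?", user_id="alice")

# Everything stored for that user.
all_memories = m.get_all(user_id="alice")
```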
mem0/llms/aws_bedrock.py (new file, 196 lines)
@@ -0,0 +1,196 @@
import json
import os
from typing import Any, Dict, List, Optional

import boto3

from mem0.llms.base import LLMBase


class AWSBedrockLLM(LLMBase):
    def __init__(self, model="cohere.command-r-v1:0"):
        self.client = boto3.client(
            "bedrock-runtime",
            region_name=os.environ.get("AWS_REGION"),
            aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"),
            aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
        )
        self.model = model

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """
        Formats a list of messages into the required prompt structure for the model.

        Args:
            messages (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a message.
                Each dictionary contains 'role' and 'content' keys.

        Returns:
            str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
        """
        formatted_messages = []
        for message in messages:
            role = message["role"].capitalize()
            content = message["content"]
            formatted_messages.append(f"\n\n{role}: {content}")

        return "".join(formatted_messages) + "\n\nAssistant:"

    def _parse_response(self, response, tools) -> str:
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from the API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The processed response.
        """
        if tools:
            processed_response = {"tool_calls": []}

            if response["output"]["message"]["content"]:
                for item in response["output"]["message"]["content"]:
                    if "toolUse" in item:
                        processed_response["tool_calls"].append(
                            {
                                "name": item["toolUse"]["name"],
                                "arguments": item["toolUse"]["input"],
                            }
                        )

            return processed_response

        response_body = json.loads(response["body"].read().decode())
        return response_body.get("completion", "")

    def _prepare_input(
        self,
        provider: str,
        model: str,
        prompt: str,
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Prepares the input dictionary for the specified provider's model by mapping and renaming
        keys in the input based on the provider's requirements.

        Args:
            provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
            model (str): The name or identifier of the model being used.
            prompt (str): The text prompt to be processed by the model.
            model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements.

        Returns:
            Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider.
        """
        # Avoid a mutable default argument; treat a missing mapping as empty.
        model_kwargs = model_kwargs or {}
        input_body = {"prompt": prompt, **model_kwargs}

        provider_mappings = {
            "meta": {"max_tokens_to_sample": "max_gen_len"},
            "ai21": {"max_tokens_to_sample": "maxTokens", "top_p": "topP"},
            "mistral": {"max_tokens_to_sample": "max_tokens"},
            "cohere": {"max_tokens_to_sample": "max_tokens", "top_p": "p"},
        }

        if provider in provider_mappings:
            for old_key, new_key in provider_mappings[provider].items():
                if old_key in input_body:
                    input_body[new_key] = input_body.pop(old_key)

        if provider == "cohere" and "cohere.command-r" in model:
            input_body["message"] = input_body.pop("prompt")

        if provider == "amazon":
            input_body = {
                "inputText": prompt,
                "textGenerationConfig": {
                    "maxTokenCount": model_kwargs.get("max_tokens_to_sample"),
                    "topP": model_kwargs.get("top_p"),
                    "temperature": model_kwargs.get("temperature"),
                },
            }
            input_body["textGenerationConfig"] = {
                k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
            }

        return input_body

    def _convert_tool_format(self, original_tools):
        """
        Converts a list of tools from their original format to a new standardized format.

        Args:
            original_tools (list): A list of dictionaries representing the original tools,
                each containing a 'type' key and corresponding details.

        Returns:
            list: A list of dictionaries representing the tools in the new standardized format.
        """
        new_tools = []

        for tool in original_tools:
            if tool["type"] == "function":
                function = tool["function"]
                new_tool = {
                    "toolSpec": {
                        "name": function["name"],
                        "description": function["description"],
                        "inputSchema": {
                            "json": {
                                "type": "object",
                                "properties": {},
                                "required": function["parameters"].get("required", []),
                            }
                        },
                    }
                }

                for prop, details in function["parameters"].get("properties", {}).items():
                    new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
                        "type": details.get("type", "string"),
                        "description": details.get("description", ""),
                    }

                new_tools.append(new_tool)

        return new_tools

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using AWS Bedrock.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
            str: The generated response.
        """
        if tools:
            # Use the converse method when tools are provided.
            messages = [
                {
                    "role": "user",
                    "content": [{"text": message["content"]} for message in messages],
                }
            ]
            tools_config = {"tools": self._convert_tool_format(tools)}

            response = self.client.converse(
                modelId=self.model,
                messages=messages,
                toolConfig=tools_config,
            )
        else:
            # Use the invoke_model method when no tools are provided.
            prompt = self._format_messages(messages)
            provider = self.model.split(".")[0]
            input_body = self._prepare_input(provider, self.model, prompt)
            body = json.dumps(input_body)

            response = self.client.invoke_model(
                body=body,
                modelId=self.model,
                accept="application/json",
                contentType="application/json",
            )

        return self._parse_response(response, tools)
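To make the helpers above concrete, here is a small illustrative walkthrough of `_format_messages` and `_prepare_input`. The values are examples only, and constructing `AWSBedrockLLM` assumes the `AWS_*` environment variables from the setup section are exported (boto3 needs a region):

```python
from mem0.llms.aws_bedrock import AWSBedrockLLM

llm = AWSBedrockLLM(model="cohere.command-r-v1:0")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi!"},
]

# The chat is flattened into a single prompt string:
# "\n\nSystem: You are a helpful assistant.\n\nUser: Hi!\n\nAssistant:"
prompt = llm._format_messages(messages)

# Keys are renamed per provider: for Cohere, max_tokens_to_sample -> max_tokens
# and top_p -> p; command-r models additionally expect the prompt under "message".
body = llm._prepare_input(
    "cohere", llm.model, prompt, {"max_tokens_to_sample": 100, "top_p": 0.9}
)
assert body == {"message": prompt, "max_tokens": 100, "p": 0.9}
```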
mem0/llms/groq.py
@@ -1,3 +1,4 @@
+import json
from typing import Dict, List, Optional

from groq import Groq
@@ -10,6 +11,34 @@ class GroqLLM(LLMBase):
        self.client = Groq()
        self.model = model

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from the API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The processed response.
        """
        if tools:
            processed_response = {
                "content": response.choices[0].message.content,
                "tool_calls": [],
            }

            if response.choices[0].message.tool_calls:
                for tool_call in response.choices[0].message.tool_calls:
                    processed_response["tool_calls"].append(
                        {
                            "name": tool_call.function.name,
                            "arguments": json.loads(tool_call.function.arguments),
                        }
                    )

            return processed_response
        else:
            return response.choices[0].message.content

    def generate_response(
        self,
        messages: List[Dict[str, str]],
@@ -37,4 +66,4 @@ class GroqLLM(LLMBase):
        params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
-       return response
+       return self._parse_response(response, tools)
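With `_parse_response` in place, `generate_response` now returns plain Python data instead of the SDK's response object. The same contract applies to the OpenAI and Together classes changed below; a sketch of the two shapes, reusing the `llm`, `messages`, and `tools` names from the surrounding diff and the values from the updated tests:

```python
# Without tools: a plain string.
reply = llm.generate_response(messages)
# "I'm doing well, thank you for asking!"

# With tools: a dict whose tool-call arguments are already json.loads-decoded.
reply = llm.generate_response(messages, tools=tools)
# {
#     "content": "I've added the memory for you.",
#     "tool_calls": [
#         {"name": "add_memory", "arguments": {"data": "Today is a sunny day."}},
#     ],
# }
```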
mem0/llms/openai.py
@@ -1,3 +1,4 @@
+import json
from typing import Dict, List, Optional

from openai import OpenAI
@@ -9,6 +10,34 @@ class OpenAILLM(LLMBase):
    def __init__(self, model="gpt-4o"):
        self.client = OpenAI()
        self.model = model

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from the API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The processed response.
        """
        if tools:
            processed_response = {
                "content": response.choices[0].message.content,
                "tool_calls": [],
            }

            if response.choices[0].message.tool_calls:
                for tool_call in response.choices[0].message.tool_calls:
                    processed_response["tool_calls"].append(
                        {
                            "name": tool_call.function.name,
                            "arguments": json.loads(tool_call.function.arguments),
                        }
                    )

            return processed_response
        else:
            return response.choices[0].message.content

    def generate_response(
        self,
@@ -37,4 +66,4 @@ class OpenAILLM(LLMBase):
        params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
-       return response
+       return self._parse_response(response, tools)
mem0/llms/together.py
@@ -1,3 +1,4 @@
+import json
from typing import Dict, List, Optional

from together import Together
@@ -9,6 +10,34 @@ class TogetherLLM(LLMBase):
    def __init__(self, model="mistralai/Mixtral-8x7B-Instruct-v0.1"):
        self.client = Together()
        self.model = model

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from the API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The processed response.
        """
        if tools:
            processed_response = {
                "content": response.choices[0].message.content,
                "tool_calls": [],
            }

            if response.choices[0].message.tool_calls:
                for tool_call in response.choices[0].message.tool_calls:
                    processed_response["tool_calls"].append(
                        {
                            "name": tool_call.function.name,
                            "arguments": json.loads(tool_call.function.arguments),
                        }
                    )

            return processed_response
        else:
            return response.choices[0].message.content

    def generate_response(
        self,
@@ -37,4 +66,4 @@ class TogetherLLM(LLMBase):
        params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
-       return response
+       return self._parse_response(response, tools)
mem0/memory/main.py
@@ -149,7 +149,6 @@ class Memory(MemoryBase):
                {"role": "user", "content": prompt},
            ]
        )
-       extracted_memories = extracted_memories.choices[0].message.content
        existing_memories = self.vector_store.search(
            name=self.collection_name,
            query=embeddings,
@@ -176,8 +175,7 @@ class Memory(MemoryBase):
        # Add tools for noop, add, update, delete memory.
        tools = [ADD_MEMORY_TOOL, UPDATE_MEMORY_TOOL, DELETE_MEMORY_TOOL]
        response = self.llm.generate_response(messages=messages, tools=tools)
-       response_message = response.choices[0].message
-       tool_calls = response_message.tool_calls
+       tool_calls = response["tool_calls"]

        response = []
        if tool_calls:
@@ -188,9 +186,9 @@ class Memory(MemoryBase):
                "delete_memory": self._delete_memory_tool,
            }
            for tool_call in tool_calls:
-               function_name = tool_call.function.name
+               function_name = tool_call["name"]
                function_to_call = available_functions[function_name]
-               function_args = json.loads(tool_call.function.arguments)
+               function_args = tool_call["arguments"]
                logging.info(
                    f"[openai_func] func: {function_name}, args: {function_args}"
                )
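This dict is exactly what each provider's `_parse_response` returns, so the dispatch loop is now provider-agnostic. A condensed sketch of the flow; the keyword-argument invocation at the end is assumed, since the actual call site falls outside the hunk shown:

```python
response = self.llm.generate_response(messages=messages, tools=tools)
for tool_call in response["tool_calls"]:
    func = available_functions[tool_call["name"]]
    # Assumed invocation style; not shown in the hunk above.
    result = func(**tool_call["arguments"])
```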
mem0/utils/factory.py
@@ -12,7 +12,8 @@ class LlmFactory:
        "ollama": "mem0.llms.ollama.py.OllamaLLM",
        "openai": "mem0.llms.openai.OpenAILLM",
        "groq": "mem0.llms.groq.GroqLLM",
-       "together": "mem0.llms.together.TogetherLLM"
+       "together": "mem0.llms.together.TogetherLLM",
+       "aws_bedrock": "mem0.llms.aws_bedrock.AWSBedrockLLM"
    }

    @classmethod
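Registering the `aws_bedrock` key is what makes the docs config above resolve to the new class. A hypothetical sketch, assuming the factory exposes a `create` classmethod that imports and instantiates the mapped class (only the `@classmethod` decorator is visible in this hunk):

```python
from mem0.utils.factory import LlmFactory

# Hypothetical call; the method name is assumed, not shown in the diff.
llm = LlmFactory.create("aws_bedrock")  # -> AWSBedrockLLM instance
```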
poetry.lock (generated, 87 lines)
@@ -227,6 +227,47 @@ files = [
    {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"},
]

[[package]]
name = "boto3"
version = "1.34.144"
description = "The AWS SDK for Python"
optional = false
python-versions = ">=3.8"
files = [
    {file = "boto3-1.34.144-py3-none-any.whl", hash = "sha256:b8433d481d50b68a0162c0379c0dd4aabfc3d1ad901800beb5b87815997511c1"},
    {file = "boto3-1.34.144.tar.gz", hash = "sha256:2f3e88b10b8fcc5f6100a9d74cd28230edc9d4fa226d99dd40a3ab38ac213673"},
]

[package.dependencies]
botocore = ">=1.34.144,<1.35.0"
jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.10.0,<0.11.0"

[package.extras]
crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]

[[package]]
name = "botocore"
version = "1.34.144"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">=3.8"
files = [
    {file = "botocore-1.34.144-py3-none-any.whl", hash = "sha256:a2cf26e1bf10d5917a2285e50257bc44e94a1d16574f282f3274f7a5d8d1f08b"},
    {file = "botocore-1.34.144.tar.gz", hash = "sha256:4215db28d25309d59c99507f1f77df9089e5bebbad35f6e19c7c44ec5383a3e8"},
]

[package.dependencies]
jmespath = ">=0.7.1,<2.0.0"
python-dateutil = ">=2.1,<3.0.0"
urllib3 = [
    {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""},
    {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""},
]

[package.extras]
crt = ["awscrt (==0.20.11)"]

[[package]]
name = "certifi"
version = "2024.7.4"
@@ -1017,6 +1058,17 @@ docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alab
qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"]
testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"]

[[package]]
name = "jmespath"
version = "1.0.1"
description = "JSON Matching Expressions"
optional = false
python-versions = ">=3.7"
files = [
    {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"},
    {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"},
]

[[package]]
name = "jupyter-client"
version = "8.6.2"
@@ -2105,6 +2157,23 @@ files = [
    {file = "ruff-0.4.10.tar.gz", hash = "sha256:3aa4f2bc388a30d346c56524f7cacca85945ba124945fe489952aadb6b5cd804"},
]

[[package]]
name = "s3transfer"
version = "0.10.2"
description = "An Amazon S3 Transfer Manager"
optional = false
python-versions = ">=3.8"
files = [
    {file = "s3transfer-0.10.2-py3-none-any.whl", hash = "sha256:eca1c20de70a39daee580aef4986996620f365c4e0fda6a86100231d62f1bf69"},
    {file = "s3transfer-0.10.2.tar.gz", hash = "sha256:0711534e9356d3cc692fdde846b4a1e4b0cb6519971860796e6bc4c7aea00ef6"},
]

[package.dependencies]
botocore = ">=1.33.2,<2.0a.0"

[package.extras]
crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"]

[[package]]
name = "setuptools"
version = "70.3.0"
@@ -2308,6 +2377,22 @@ files = [
    {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
]

[[package]]
name = "urllib3"
version = "1.26.19"
description = "HTTP library with thread-safe connection pooling, file post, and more."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
files = [
    {file = "urllib3-1.26.19-py2.py3-none-any.whl", hash = "sha256:37a0344459b199fce0e80b0d3569837ec6b6937435c5244e7fd73fa6006830f3"},
    {file = "urllib3-1.26.19.tar.gz", hash = "sha256:3e3d753a8618b86d7de333b4223005f68720bcd6a7d2bcb9fbd2229ec7c1e429"},
]

[package.extras]
brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]

[[package]]
name = "urllib3"
version = "2.2.2"
@@ -2457,4 +2542,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
[metadata]
lock-version = "2.0"
python-versions = "^3.8"
-content-hash = "29b68f540e0567d310cbf2f8f3137a0bbd7ecaefeaafc95273d1bdaddfeac1bd"
+content-hash = "619f45c245c60ed6e534c177eeb6e1335d515e8b17ae760f867abc1d611258c5"
pyproject.toml
@@ -22,6 +22,7 @@ openai = "^1.33.0"
posthog = "^3.5.0"
groq = "^0.9.0"
together = "^1.2.1"
+boto3 = "^1.34.144"

[tool.poetry.group.test.dependencies]
pytest = "^8.2.2"
tests/llms/test_groq.py
@@ -27,7 +27,7 @@ def test_generate_response_without_tools(mock_groq_client):
        model="llama3-70b-8192",
        messages=messages
    )
-   assert response.choices[0].message.content == "I'm doing well, thank you for asking!"
+   assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_groq_client):
@@ -54,7 +54,15 @@ def test_generate_response_with_tools(mock_groq_client):
    ]

    mock_response = Mock()
-   mock_response.choices = [Mock(message=Mock(content="Memory added successfully."))]
+   mock_message = Mock()
+   mock_message.content = "I've added the memory for you."
+
+   mock_tool_call = Mock()
+   mock_tool_call.function.name = "add_memory"
+   mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
+
+   mock_message.tool_calls = [mock_tool_call]
+   mock_response.choices = [Mock(message=mock_message)]
    mock_groq_client.chat.completions.create.return_value = mock_response

    response = llm.generate_response(messages, tools=tools)
@@ -65,5 +73,9 @@ def test_generate_response_with_tools(mock_groq_client):
        tools=tools,
        tool_choice="auto"
    )
-   assert response.choices[0].message.content == "Memory added successfully."
+
+   assert response["content"] == "I've added the memory for you."
+   assert len(response["tool_calls"]) == 1
+   assert response["tool_calls"][0]["name"] == "add_memory"
+   assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'}
tests/llms/test_openai.py
@@ -27,7 +27,7 @@ def test_generate_response_without_tools(mock_openai_client):
        model="gpt-4o",
        messages=messages
    )
-   assert response.choices[0].message.content == "I'm doing well, thank you for asking!"
+   assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_openai_client):
@@ -54,7 +54,15 @@ def test_generate_response_with_tools(mock_openai_client):
    ]

    mock_response = Mock()
-   mock_response.choices = [Mock(message=Mock(content="Memory added successfully."))]
+   mock_message = Mock()
+   mock_message.content = "I've added the memory for you."
+
+   mock_tool_call = Mock()
+   mock_tool_call.function.name = "add_memory"
+   mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
+
+   mock_message.tool_calls = [mock_tool_call]
+   mock_response.choices = [Mock(message=mock_message)]
    mock_openai_client.chat.completions.create.return_value = mock_response

    response = llm.generate_response(messages, tools=tools)
@@ -65,5 +73,9 @@ def test_generate_response_with_tools(mock_openai_client):
        tools=tools,
        tool_choice="auto"
    )
-   assert response.choices[0].message.content == "Memory added successfully."
+
+   assert response["content"] == "I've added the memory for you."
+   assert len(response["tool_calls"]) == 1
+   assert response["tool_calls"][0]["name"] == "add_memory"
+   assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'}
tests/llms/test_together.py
@@ -27,7 +27,7 @@ def test_generate_response_without_tools(mock_together_client):
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        messages=messages
    )
-   assert response.choices[0].message.content == "I'm doing well, thank you for asking!"
+   assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_together_client):
@@ -54,7 +54,15 @@ def test_generate_response_with_tools(mock_together_client):
    ]

    mock_response = Mock()
-   mock_response.choices = [Mock(message=Mock(content="Memory added successfully."))]
+   mock_message = Mock()
+   mock_message.content = "I've added the memory for you."
+
+   mock_tool_call = Mock()
+   mock_tool_call.function.name = "add_memory"
+   mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
+
+   mock_message.tool_calls = [mock_tool_call]
+   mock_response.choices = [Mock(message=mock_message)]
    mock_together_client.chat.completions.create.return_value = mock_response

    response = llm.generate_response(messages, tools=tools)
@@ -65,5 +73,9 @@ def test_generate_response_with_tools(mock_together_client):
        tools=tools,
        tool_choice="auto"
    )
-   assert response.choices[0].message.content == "Memory added successfully."
+
+   assert response["content"] == "I've added the memory for you."
+   assert len(response["tool_calls"]) == 1
+   assert response["tool_calls"][0]["name"] == "add_memory"
+   assert response["tool_calls"][0]["arguments"] == {'data': 'Today is a sunny day.'}