fix: bedrock llm, embeddings, tools, temporary creds (#3023)

This commit is contained in:
Laith Al-Saadoon
2025-06-24 10:16:06 -05:00
committed by GitHub
parent b4b27f099e
commit 8139b5887f
2 changed files with 10 additions and 14 deletions

View File

@@ -11,7 +11,6 @@ import numpy as np
from mem0.configs.embeddings.base import BaseEmbedderConfig from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase from mem0.embeddings.base import EmbeddingBase
from mem0.memory.utils import extract_json
class AWSBedrockEmbedding(EmbeddingBase): class AWSBedrockEmbedding(EmbeddingBase):
@@ -28,6 +27,7 @@ class AWSBedrockEmbedding(EmbeddingBase):
# Get AWS config from environment variables or use defaults # Get AWS config from environment variables or use defaults
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "") aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "") aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
aws_session_token = os.environ.get("AWS_SESSION_TOKEN", "")
aws_region = os.environ.get("AWS_REGION", "us-west-2") aws_region = os.environ.get("AWS_REGION", "us-west-2")
# Check if AWS config is provided in the config # Check if AWS config is provided in the config
@@ -43,6 +43,7 @@ class AWSBedrockEmbedding(EmbeddingBase):
region_name=aws_region, region_name=aws_region,
aws_access_key_id=aws_access_key if aws_access_key else None, aws_access_key_id=aws_access_key if aws_access_key else None,
aws_secret_access_key=aws_secret_key if aws_secret_key else None, aws_secret_access_key=aws_secret_key if aws_secret_key else None,
aws_session_token=aws_session_token if aws_session_token else None,
) )
def _normalize_vector(self, embeddings): def _normalize_vector(self, embeddings):
@@ -75,7 +76,7 @@ class AWSBedrockEmbedding(EmbeddingBase):
contentType="application/json", contentType="application/json",
) )
response_body = json.loads(extract_json(response.get("body").read())) response_body = json.loads(response.get("body").read())
if provider == "cohere": if provider == "cohere":
embeddings = response_body.get("embeddings")[0] embeddings = response_body.get("embeddings")[0]

View File

@@ -10,7 +10,7 @@ except ImportError:
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json
PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"] PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]
@@ -92,17 +92,15 @@ class AWSBedrockLLM(LLMBase):
if response["output"]["message"]["content"]: if response["output"]["message"]["content"]:
for item in response["output"]["message"]["content"]: for item in response["output"]["message"]["content"]:
if "toolUse" in item: if "toolUse" in item:
processed_response["tool_calls"].append( processed_response["tool_calls"].append({
{ "name": item["toolUse"]["name"],
"name": item["toolUse"]["name"], "arguments": item["toolUse"]["input"],
"arguments": item["toolUse"]["input"], })
}
)
return processed_response return processed_response
response_body = response.get("body").read().decode() response_body = response.get("body").read().decode()
response_json = json.loads(extract_json(response_body)) response_json = json.loads(response_body)
return response_json.get("content", [{"text": ""}])[0].get("text", "") return response_json.get("content", [{"text": ""}])[0].get("text", "")
def _prepare_input( def _prepare_input(
@@ -190,10 +188,7 @@ class AWSBedrockLLM(LLMBase):
} }
for prop, details in function["parameters"].get("properties", {}).items(): for prop, details in function["parameters"].get("properties", {}).items():
new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = { new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = details
"type": details.get("type", "string"),
"description": details.get("description", ""),
}
new_tools.append(new_tool) new_tools.append(new_tool)