fix: bedrock llm, embeddings, tools, temporary creds (#3023)
This commit is contained in:
@@ -11,7 +11,6 @@ import numpy as np
|
|||||||
|
|
||||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||||
from mem0.embeddings.base import EmbeddingBase
|
from mem0.embeddings.base import EmbeddingBase
|
||||||
from mem0.memory.utils import extract_json
|
|
||||||
|
|
||||||
|
|
||||||
class AWSBedrockEmbedding(EmbeddingBase):
|
class AWSBedrockEmbedding(EmbeddingBase):
|
||||||
@@ -28,6 +27,7 @@ class AWSBedrockEmbedding(EmbeddingBase):
|
|||||||
# Get AWS config from environment variables or use defaults
|
# Get AWS config from environment variables or use defaults
|
||||||
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
|
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
|
||||||
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
|
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
|
||||||
|
aws_session_token = os.environ.get("AWS_SESSION_TOKEN", "")
|
||||||
aws_region = os.environ.get("AWS_REGION", "us-west-2")
|
aws_region = os.environ.get("AWS_REGION", "us-west-2")
|
||||||
|
|
||||||
# Check if AWS config is provided in the config
|
# Check if AWS config is provided in the config
|
||||||
@@ -43,6 +43,7 @@ class AWSBedrockEmbedding(EmbeddingBase):
|
|||||||
region_name=aws_region,
|
region_name=aws_region,
|
||||||
aws_access_key_id=aws_access_key if aws_access_key else None,
|
aws_access_key_id=aws_access_key if aws_access_key else None,
|
||||||
aws_secret_access_key=aws_secret_key if aws_secret_key else None,
|
aws_secret_access_key=aws_secret_key if aws_secret_key else None,
|
||||||
|
aws_session_token=aws_session_token if aws_session_token else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _normalize_vector(self, embeddings):
|
def _normalize_vector(self, embeddings):
|
||||||
@@ -75,7 +76,7 @@ class AWSBedrockEmbedding(EmbeddingBase):
|
|||||||
contentType="application/json",
|
contentType="application/json",
|
||||||
)
|
)
|
||||||
|
|
||||||
response_body = json.loads(extract_json(response.get("body").read()))
|
response_body = json.loads(response.get("body").read())
|
||||||
|
|
||||||
if provider == "cohere":
|
if provider == "cohere":
|
||||||
embeddings = response_body.get("embeddings")[0]
|
embeddings = response_body.get("embeddings")[0]
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ except ImportError:
|
|||||||
|
|
||||||
from mem0.configs.llms.base import BaseLlmConfig
|
from mem0.configs.llms.base import BaseLlmConfig
|
||||||
from mem0.llms.base import LLMBase
|
from mem0.llms.base import LLMBase
|
||||||
from mem0.memory.utils import extract_json
|
|
||||||
|
|
||||||
PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]
|
PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]
|
||||||
|
|
||||||
@@ -92,17 +92,15 @@ class AWSBedrockLLM(LLMBase):
|
|||||||
if response["output"]["message"]["content"]:
|
if response["output"]["message"]["content"]:
|
||||||
for item in response["output"]["message"]["content"]:
|
for item in response["output"]["message"]["content"]:
|
||||||
if "toolUse" in item:
|
if "toolUse" in item:
|
||||||
processed_response["tool_calls"].append(
|
processed_response["tool_calls"].append({
|
||||||
{
|
|
||||||
"name": item["toolUse"]["name"],
|
"name": item["toolUse"]["name"],
|
||||||
"arguments": item["toolUse"]["input"],
|
"arguments": item["toolUse"]["input"],
|
||||||
}
|
})
|
||||||
)
|
|
||||||
|
|
||||||
return processed_response
|
return processed_response
|
||||||
|
|
||||||
response_body = response.get("body").read().decode()
|
response_body = response.get("body").read().decode()
|
||||||
response_json = json.loads(extract_json(response_body))
|
response_json = json.loads(response_body)
|
||||||
return response_json.get("content", [{"text": ""}])[0].get("text", "")
|
return response_json.get("content", [{"text": ""}])[0].get("text", "")
|
||||||
|
|
||||||
def _prepare_input(
|
def _prepare_input(
|
||||||
@@ -190,10 +188,7 @@ class AWSBedrockLLM(LLMBase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
for prop, details in function["parameters"].get("properties", {}).items():
|
for prop, details in function["parameters"].get("properties", {}).items():
|
||||||
new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
|
new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = details
|
||||||
"type": details.get("type", "string"),
|
|
||||||
"description": details.get("description", ""),
|
|
||||||
}
|
|
||||||
|
|
||||||
new_tools.append(new_tool)
|
new_tools.append(new_tool)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user