Code formatting and doc update (#2130)
@@ -1,4 +1,3 @@
 import json
 from typing import Any, Dict, List, Optional
@@ -11,12 +10,12 @@ from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase


 class AWSBedrockLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)

         if not self.config.model:
-            self.config.model="anthropic.claude-3-5-sonnet-20240620-v1:0"
+            self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
         self.client = boto3.client("bedrock-runtime")
         self.model_kwargs = {
             "temperature": self.config.temperature,
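Note on the AWSBedrockLLM hunk above: only the whitespace around the `=` changes, so behavior is identical. For orientation, a minimal sketch of what this constructor ends up doing; the pieces not visible in the diff (the remaining model_kwargs entries, import error handling) are assumptions, not the file's actual contents.

import boto3
from typing import Optional

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class AWSBedrockLLM(LLMBase):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config)

        # Fall back to a default Bedrock model id when none is configured.
        if not self.config.model:
            self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"

        # Bedrock runtime client; region and credentials come from the usual AWS config.
        self.client = boto3.client("bedrock-runtime")

        # Sampling parameters forwarded to Bedrock; only temperature appears in the diff,
        # anything else (max tokens, top_p, ...) is assumed.
        self.model_kwargs = {
            "temperature": self.config.temperature,
        }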
@@ -1,5 +1,4 @@
-import os
 import json
 from typing import Dict, List, Optional

 try:
@@ -39,11 +38,7 @@ class GeminiLLM(LLMBase):
         """
         if tools:
             processed_response = {
-                "content": (
-                    content
-                    if (content := response.candidates[0].content.parts[0].text)
-                    else None
-                ),
+                "content": (content if (content := response.candidates[0].content.parts[0].text) else None),
                 "tool_calls": [],
             }
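The hunk above is purely cosmetic: a multi-line conditional expression is collapsed onto one line. The pattern being reflowed, a walrus assignment inside a conditional expression, works like this (a standalone toy example using a plain string instead of a Gemini response object):

def text_or_none(raw: str):
    # Same shape as the diff: bind `content`, return it if truthy, otherwise None.
    return content if (content := raw.strip()) else None

print(text_or_none("  hello  "))  # -> 'hello'
print(text_or_none("   "))        # -> None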
@@ -51,13 +46,9 @@ class GeminiLLM(LLMBase):
                 if fn := part.function_call:
                     if isinstance(fn, protos.FunctionCall):
                         fn_call = type(fn).to_dict(fn)
-                        processed_response["tool_calls"].append(
-                            {"name": fn_call["name"], "arguments": fn_call["args"]}
-                        )
+                        processed_response["tool_calls"].append({"name": fn_call["name"], "arguments": fn_call["args"]})
                         continue
-                    processed_response["tool_calls"].append(
-                        {"name": fn.name, "arguments": fn.args}
-                    )
+                    processed_response["tool_calls"].append({"name": fn.name, "arguments": fn.args})

             return processed_response
         else:
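Only line wrapping changes in this hunk as well. The logic it touches flattens Gemini function-call parts into a uniform list of {"name", "arguments"} dicts. A rough standalone sketch of that reshaping, using plain dataclasses as stand-ins for the google.generativeai proto objects (the real code goes through protos.FunctionCall and type(fn).to_dict(fn)):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class FakeFunctionCall:  # stand-in for protos.FunctionCall
    name: str
    args: dict = field(default_factory=dict)

@dataclass
class FakePart:  # stand-in for one part of response.candidates[0].content.parts
    function_call: Optional[FakeFunctionCall] = None

def collect_tool_calls(parts):
    tool_calls = []
    for part in parts:
        if fn := part.function_call:
            # The real implementation converts protos via type(fn).to_dict(fn);
            # the stand-in already exposes .name / .args directly.
            tool_calls.append({"name": fn.name, "arguments": fn.args})
    return tool_calls

parts = [FakePart(), FakePart(FakeFunctionCall("get_weather", {"city": "Paris"}))]
print(collect_tool_calls(parts))  # -> [{'name': 'get_weather', 'arguments': {'city': 'Paris'}}]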
@@ -77,9 +68,7 @@ class GeminiLLM(LLMBase):

         for message in messages:
             if message["role"] == "system":
-                content = (
-                    "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
-                )
+                content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]

             else:
                 content = message["content"]
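The reflowed line above prefixes system messages before they are handed to Gemini, presumably because this code path has no separate system role. In isolation the loop behaves like this:

messages = [
    {"role": "system", "content": "Answer in one word."},
    {"role": "user", "content": "What is the capital of France?"},
]

for message in messages:
    if message["role"] == "system":
        content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
    else:
        content = message["content"]
    print(content)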
@@ -121,9 +110,7 @@ class GeminiLLM(LLMBase):
         if tools:
             for tool in tools:
                 func = tool["function"].copy()
-                new_tools.append(
-                    {"function_declarations": [remove_additional_properties(func)]}
-                )
+                new_tools.append({"function_declarations": [remove_additional_properties(func)]})

         # TODO: temporarily ignore it to pass tests, will come back to update according to standards later.
         # return content_types.to_function_library(new_tools)
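The same single-line collapse is applied to the tool conversion: OpenAI-style tool specs are rewrapped as Gemini function_declarations. A sketch of that reshaping, with remove_additional_properties replaced by a no-op placeholder since its real implementation is not part of this diff:

def remove_additional_properties(schema: dict) -> dict:
    # Placeholder only; the real helper presumably strips schema fields Gemini rejects.
    return schema

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }
]

new_tools = []
if tools:
    for tool in tools:
        func = tool["function"].copy()
        new_tools.append({"function_declarations": [remove_additional_properties(func)]})

print(new_tools[0]["function_declarations"][0]["name"])  # -> get_weather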
@@ -168,9 +155,7 @@ class GeminiLLM(LLMBase):
                 "function_calling_config": {
                     "mode": tool_choice,
                     "allowed_function_names": (
-                        [tool["function"]["name"] for tool in tools]
-                        if tool_choice == "any"
-                        else None
+                        [tool["function"]["name"] for tool in tools] if tool_choice == "any" else None
                     ),
                 }
             }
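The final hunk folds the allowed_function_names conditional onto one line. In context this dict is the tool_config that steers Gemini's function calling; a hedged sketch of how it could be assembled from tools and tool_choice (the wrapper function here is assumed, only the dict shape mirrors the diff):

def build_tool_config(tools, tool_choice="auto"):
    # Only when tool_choice is "any" (force a function call) is the allow-list populated.
    return {
        "function_calling_config": {
            "mode": tool_choice,
            "allowed_function_names": (
                [tool["function"]["name"] for tool in tools] if tool_choice == "any" else None
            ),
        }
    }

tools = [{"type": "function", "function": {"name": "get_weather"}}]
print(build_tool_config(tools, "any"))
print(build_tool_config(tools, "auto"))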