Support model config in LLMs (#1495)
0   mem0/configs/llms/__init__.py   Normal file
34  mem0/configs/llms/base.py       Normal file
@@ -0,0 +1,34 @@
+from abc import ABC
+from typing import Optional
+
+class BaseLlmConfig(ABC):
+    """
+    Config for LLMs.
+    """
+
+    def __init__(
+        self,
+        model: Optional[str] = None,
+        temperature: float = 0,
+        max_tokens: int = 3000,
+        top_p: float = 1
+    ):
+        """
+        Initializes a configuration class instance for the LLM.
+
+        :param model: Controls the OpenAI model used, defaults to None
+        :type model: Optional[str], optional
+        :param temperature: Controls the randomness of the model's output.
+            Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
+        :type temperature: float, optional
+        :param max_tokens: Controls how many tokens are generated, defaults to 3000
+        :type max_tokens: int, optional
+        :param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
+            defaults to 1
+        :type top_p: float, optional
+        """
+
+        self.model = model
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.top_p = top_p
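The new config class only carries the four generation settings used by the providers changed below. A minimal usage sketch, assuming mem0 is installed at a revision that contains this file; the values shown are simply the defaults declared above.

    from mem0.configs.llms.base import BaseLlmConfig

    # No arguments: model stays None so each provider can substitute its own default.
    defaults = BaseLlmConfig()
    print(defaults.model, defaults.temperature, defaults.max_tokens, defaults.top_p)
    # -> None 0 3000 1

    # Explicit settings for a particular provider model.
    custom = BaseLlmConfig(model="gpt-4o", temperature=0.2, max_tokens=1000, top_p=0.9)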
@@ -5,12 +5,16 @@ from typing import Dict, List, Optional, Any
 import boto3
 
 from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig
 
 
 class AWSBedrockLLM(LLMBase):
-    def __init__(self, model="cohere.command-r-v1:0"):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model="anthropic.claude-3-5-sonnet-20240620-v1:0"
         self.client = boto3.client("bedrock-runtime", region_name=os.environ.get("AWS_REGION"), aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"), aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"))
-        self.model = model
+        self.model_kwargs = {"temperature": self.config.temperature, "max_tokens_to_sample": self.config.max_tokens, "top_p": self.config.top_p}
 
     def _format_messages(self, messages: List[Dict[str, str]]) -> str:
         """
@@ -171,19 +175,20 @@ class AWSBedrockLLM(LLMBase):
         if tools:
             # Use converse method when tools are provided
             messages = [{"role": "user", "content": [{"text": message["content"]} for message in messages]}]
+            inference_config = {"temperature": self.model_kwargs["temperature"], "maxTokens": self.model_kwargs["max_tokens_to_sample"], "topP": self.model_kwargs["top_p"]}
             tools_config = {"tools": self._convert_tool_format(tools)}
 
             response = self.client.converse(
-                modelId=self.model,
+                modelId=self.config.model,
                 messages=messages,
+                inferenceConfig=inference_config,
                 toolConfig=tools_config
             )
-            print("Tools response: ", response)
         else:
             # Use invoke_model method when no tools are provided
             prompt = self._format_messages(messages)
             provider = self.model.split(".")[0]
-            input_body = self._prepare_input(provider, self.model, prompt)
+            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
             body = json.dumps(input_body)
 
             response = self.client.invoke_model(
@@ -1,7 +1,21 @@
+from typing import Optional
 from abc import ABC, abstractmethod
 
+from mem0.configs.llms.base import BaseLlmConfig
+
 
 class LLMBase(ABC):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """Initialize a base LLM class
+
+        :param config: LLM configuration option class, defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        """
+        if config is None:
+            self.config = BaseLlmConfig()
+        else:
+            self.config = config
+
     @abstractmethod
     def generate_response(self, messages):
         """
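Every provider touched by this commit (Bedrock above, then Groq, LiteLLM, OpenAI and Together below) layers the same pattern on top of this base class. A sketch with a hypothetical EchoLLM provider, shown only to make the pattern explicit; the class name and its default model are invented for illustration and are not part of the change.

    from typing import Optional

    from mem0.llms.base import LLMBase
    from mem0.configs.llms.base import BaseLlmConfig


    class EchoLLM(LLMBase):
        def __init__(self, config: Optional[BaseLlmConfig] = None):
            # LLMBase stores the config, falling back to a default BaseLlmConfig().
            super().__init__(config)

            # Each provider fills in its own default model when none was configured.
            if not self.config.model:
                self.config.model = "echo-1"  # hypothetical model name

        def generate_response(self, messages):
            # A real provider would call its SDK here with self.config.model,
            # self.config.temperature, self.config.max_tokens and self.config.top_p.
            return messages[-1]["content"]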
@@ -8,7 +8,7 @@ class LlmConfig(BaseModel):
         description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
     )
     config: Optional[dict] = Field(
-        description="Configuration for the specific LLM", default=None
+        description="Configuration for the specific LLM", default={}
     )
 
     @field_validator("config")
@@ -4,12 +4,16 @@ from typing import Dict, List, Optional
 from groq import Groq
 
 from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig
 
 
 class GroqLLM(LLMBase):
-    def __init__(self, model="llama3-70b-8192"):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model="llama3-70b-8192"
         self.client = Groq()
-        self.model = model
 
     def _parse_response(self, response, tools):
         """
@@ -58,7 +62,13 @@ class GroqLLM(LLMBase):
         Returns:
             str: The generated response.
         """
-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p
+        }
         if response_format:
             params["response_format"] = response_format
         if tools:
@@ -4,11 +4,15 @@ from typing import Dict, List, Optional
 import litellm
 
 from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig
 
 
 class LiteLLM(LLMBase):
-    def __init__(self, model="gpt-4o"):
-        self.model = model
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model="gpt-4o"
 
     def _parse_response(self, response, tools):
         """
@@ -57,10 +61,16 @@ class LiteLLM(LLMBase):
         Returns:
             str: The generated response.
         """
-        if not litellm.supports_function_calling(self.model):
-            raise ValueError(f"Model '{self.model}' in litellm does not support function calling.")
+        if not litellm.supports_function_calling(self.config.model):
+            raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")
 
-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p
+        }
         if response_format:
             params["response_format"] = response_format
         if tools:
@@ -4,12 +4,15 @@ from typing import Dict, List, Optional
 from openai import OpenAI
 
 from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig
 
 
 class OpenAILLM(LLMBase):
-    def __init__(self, model="gpt-4o"):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model="gpt-4o"
         self.client = OpenAI()
-        self.model = model
 
     def _parse_response(self, response, tools):
         """
@@ -58,7 +61,13 @@ class OpenAILLM(LLMBase):
         Returns:
             str: The generated response.
         """
-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p
+        }
         if response_format:
             params["response_format"] = response_format
         if tools:
@@ -4,12 +4,15 @@ from typing import Dict, List, Optional
 from together import Together
 
 from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig
 
 
 class TogetherLLM(LLMBase):
-    def __init__(self, model="mistralai/Mixtral-8x7B-Instruct-v0.1"):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model="mistralai/Mixtral-8x7B-Instruct-v0.1"
         self.client = Together()
-        self.model = model
 
     def _parse_response(self, response, tools):
         """
@@ -58,7 +61,13 @@ class TogetherLLM(LLMBase):
         Returns:
             str: The generated response.
         """
-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p
+        }
         if response_format:
             params["response_format"] = response_format
         if tools:
@@ -82,7 +82,7 @@ class Memory(MemoryBase):
                 f"Unsupported vector store type: {self.config.vector_store_type}"
             )
 
-        self.llm = LlmFactory.create(self.config.llm.provider)
+        self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
         self.db = SQLiteManager(self.config.history_db_path)
         self.collection_name = self.config.collection_name
         self.vector_store.create_col(
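For context, a sketch of how a user-facing configuration now reaches a provider. The `Memory.from_config` entry point is assumed here; the keys mirror the `LlmConfig` fields shown earlier (a provider name plus a free-form `config` dict), and any omitted key falls back to the `BaseLlmConfig` defaults.

    from mem0 import Memory

    config = {
        "llm": {
            "provider": "openai",
            "config": {
                "model": "gpt-4o",
                "temperature": 0.2,
                "max_tokens": 1500,
                "top_p": 1,
            },
        }
    }

    # Internally this ends up as LlmFactory.create("openai", {...}) per the change above.
    m = Memory.from_config(config)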
@@ -1,5 +1,7 @@
 import importlib
 
+from mem0.configs.llms.base import BaseLlmConfig
+
 
 def load_class(class_type):
     module_path, class_name = class_type.rsplit(".", 1)
@@ -18,11 +20,12 @@ class LlmFactory:
     }
 
     @classmethod
-    def create(cls, provider_name):
+    def create(cls, provider_name, config):
         class_type = cls.provider_to_class.get(provider_name)
         if class_type:
-            llm_instance = load_class(class_type)()
-            return llm_instance
+            llm_instance = load_class(class_type)
+            base_config = BaseLlmConfig(**config)
+            return llm_instance(base_config)
         else:
             raise ValueError(f"Unsupported Llm provider: {provider_name}")
 
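The factory can also be called directly with the new two-argument signature; a sketch assuming the factory lives at mem0.utils.factory (the diff does not name the module). Passing an empty dict reproduces the old behaviour, since each provider falls back to its own default model.

    from mem0.utils.factory import LlmFactory

    # New signature: provider name plus the raw config dict, unpacked into BaseLlmConfig.
    llm = LlmFactory.create("groq", {"model": "llama3-70b-8192", "temperature": 0.7, "max_tokens": 100})

    # An empty dict is also valid: BaseLlmConfig(**{}) yields all defaults.
    default_llm = LlmFactory.create("openai", {})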
@@ -1,6 +1,7 @@
 import pytest
 from unittest.mock import Mock, patch
 from mem0.llms.groq import GroqLLM
+from mem0.configs.llms.base import BaseLlmConfig
 
 @pytest.fixture
 def mock_groq_client():
@@ -11,7 +12,8 @@ def mock_groq_client():
 
 
 def test_generate_response_without_tools(mock_groq_client):
-    llm = GroqLLM()
+    config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = GroqLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Hello, how are you?"}
@@ -25,13 +27,17 @@ def test_generate_response_without_tools(mock_groq_client):
 
     mock_groq_client.chat.completions.create.assert_called_once_with(
         model="llama3-70b-8192",
-        messages=messages
+        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0
     )
     assert response == "I'm doing well, thank you for asking!"
 
 
 def test_generate_response_with_tools(mock_groq_client):
-    llm = GroqLLM()
+    config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = GroqLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Add a new memory: Today is a sunny day."}
@@ -70,6 +76,9 @@ def test_generate_response_with_tools(mock_groq_client):
     mock_groq_client.chat.completions.create.assert_called_once_with(
         model="llama3-70b-8192",
         messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0,
         tools=tools,
         tool_choice="auto"
     )
@@ -2,6 +2,7 @@ import pytest
 from unittest.mock import Mock, patch
 
 from mem0.llms import litellm
+from mem0.configs.llms.base import BaseLlmConfig
 
 @pytest.fixture
 def mock_litellm():
@@ -9,7 +10,8 @@ def mock_litellm():
     yield mock_litellm
 
 def test_generate_response_with_unsupported_model(mock_litellm):
-    llm = litellm.LiteLLM(model="unsupported-model")
+    config = BaseLlmConfig(model="unsupported-model", temperature=0.7, max_tokens=100, top_p=1)
+    llm = litellm.LiteLLM(config)
     messages = [{"role": "user", "content": "Hello"}]
 
     mock_litellm.supports_function_calling.return_value = False
@@ -19,7 +21,8 @@ def test_generate_response_with_unsupported_model(mock_litellm):
 
 
 def test_generate_response_without_tools(mock_litellm):
-    llm = litellm.LiteLLM()
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1)
+    llm = litellm.LiteLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Hello, how are you?"}
@@ -34,13 +37,17 @@ def test_generate_response_without_tools(mock_litellm):
 
     mock_litellm.completion.assert_called_once_with(
         model="gpt-4o",
-        messages=messages
+        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0
     )
     assert response == "I'm doing well, thank you for asking!"
 
 
 def test_generate_response_with_tools(mock_litellm):
-    llm = litellm.LiteLLM()
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1)
+    llm = litellm.LiteLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Add a new memory: Today is a sunny day."}
@@ -80,6 +87,9 @@ def test_generate_response_with_tools(mock_litellm):
     mock_litellm.completion.assert_called_once_with(
         model="gpt-4o",
         messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1,
         tools=tools,
         tool_choice="auto"
     )
@@ -1,6 +1,7 @@
 import pytest
 from unittest.mock import Mock, patch
 from mem0.llms.openai import OpenAILLM
+from mem0.configs.llms.base import BaseLlmConfig
 
 @pytest.fixture
 def mock_openai_client():
@@ -11,7 +12,8 @@ def mock_openai_client():
 
 
 def test_generate_response_without_tools(mock_openai_client):
-    llm = OpenAILLM()
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = OpenAILLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Hello, how are you?"}
@@ -25,13 +27,17 @@ def test_generate_response_without_tools(mock_openai_client):
 
     mock_openai_client.chat.completions.create.assert_called_once_with(
         model="gpt-4o",
-        messages=messages
+        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0
     )
     assert response == "I'm doing well, thank you for asking!"
 
 
 def test_generate_response_with_tools(mock_openai_client):
-    llm = OpenAILLM()
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = OpenAILLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Add a new memory: Today is a sunny day."}
@@ -70,6 +76,9 @@ def test_generate_response_with_tools(mock_openai_client):
     mock_openai_client.chat.completions.create.assert_called_once_with(
         model="gpt-4o",
         messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0,
         tools=tools,
         tool_choice="auto"
     )
@@ -1,6 +1,7 @@
 import pytest
 from unittest.mock import Mock, patch
 from mem0.llms.together import TogetherLLM
+from mem0.configs.llms.base import BaseLlmConfig
 
 @pytest.fixture
 def mock_together_client():
@@ -11,7 +12,8 @@ def mock_together_client():
 
 
 def test_generate_response_without_tools(mock_together_client):
-    llm = TogetherLLM()
+    config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = TogetherLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Hello, how are you?"}
@@ -25,13 +27,17 @@ def test_generate_response_without_tools(mock_together_client):
 
     mock_together_client.chat.completions.create.assert_called_once_with(
         model="mistralai/Mixtral-8x7B-Instruct-v0.1",
-        messages=messages
+        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0
    )
     assert response == "I'm doing well, thank you for asking!"
 
 
 def test_generate_response_with_tools(mock_together_client):
-    llm = TogetherLLM()
+    config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = TogetherLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Add a new memory: Today is a sunny day."}
@@ -70,6 +76,9 @@ def test_generate_response_with_tools(mock_together_client):
     mock_together_client.chat.completions.create.assert_called_once_with(
         model="mistralai/Mixtral-8x7B-Instruct-v0.1",
         messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0,
         tools=tools,
         tool_choice="auto"
     )