Support model config in LLMs (#1495)

Author: Dev Khant
Date: 2024-07-18 21:51:40 +05:30
Committed by: GitHub
Parent commit: c411dc294e
Commit: 40c9abe484
15 changed files with 172 additions and 41 deletions

mem0/configs/llms/base.py (new file)

@@ -0,0 +1,34 @@
from abc import ABC
from typing import Optional
class BaseLlmConfig(ABC):
"""
Config for LLMs.
"""
def __init__(
self,
model: Optional[str] = None,
temperature: float = 0,
max_tokens: int = 3000,
top_p: float = 1
):
"""
Initializes a configuration class instance for the LLM.
:param model: Controls the OpenAI model used, defaults to None
:type model: Optional[str], optional
:param temperature: Controls the randomness of the model's output.
Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
:type temperature: float, optional
:param max_tokens: Controls how many tokens are generated, defaults to 3000
:type max_tokens: int, optional
:param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
defaults to 1
:type top_p: float, optional
"""
self.model = model
self.temperature = temperature
self.max_tokens = max_tokens
self.top_p = top_p

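For orientation, here is a minimal usage sketch of the new config class; the values are illustrative, and omitting any argument falls back to the defaults above (model=None, temperature=0, max_tokens=3000, top_p=1):

    from mem0.configs.llms.base import BaseLlmConfig

    # Illustrative values only; BaseLlmConfig() with no arguments keeps every default.
    config = BaseLlmConfig(model="gpt-4o", temperature=0.2, max_tokens=1500, top_p=0.9)
    print(config.model, config.temperature, config.max_tokens, config.top_p)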

@@ -5,12 +5,16 @@ from typing import Dict, List, Optional, Any
import boto3

from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig


class AWSBedrockLLM(LLMBase):
-    def __init__(self, model="cohere.command-r-v1:0"):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+        if not self.config.model:
+            self.config.model="anthropic.claude-3-5-sonnet-20240620-v1:0"
        self.client = boto3.client("bedrock-runtime", region_name=os.environ.get("AWS_REGION"), aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"), aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"))
-        self.model = model
+        self.model_kwargs = {"temperature": self.config.temperature, "max_tokens_to_sample": self.config.max_tokens, "top_p": self.config.top_p}

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """
@@ -171,19 +175,20 @@ class AWSBedrockLLM(LLMBase):
        if tools:
            # Use converse method when tools are provided
            messages = [{"role": "user", "content": [{"text": message["content"]} for message in messages]}]
+            inference_config = {"temperature": self.model_kwargs["temperature"], "maxTokens": self.model_kwargs["max_tokens_to_sample"], "topP": self.model_kwargs["top_p"]}
            tools_config = {"tools": self._convert_tool_format(tools)}

            response = self.client.converse(
-                modelId=self.model,
+                modelId=self.config.model,
                messages=messages,
+                inferenceConfig=inference_config,
                toolConfig=tools_config
            )
+            print("Tools response: ", response)
        else:
            # Use invoke_model method when no tools are provided
            prompt = self._format_messages(messages)
            provider = self.model.split(".")[0]
-            input_body = self._prepare_input(provider, self.model, prompt)
+            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
            body = json.dumps(input_body)

            response = self.client.invoke_model(

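In the Bedrock provider the config is consumed in two shapes: the sampling fields are packed into model_kwargs for the invoke_model path, and re-keyed (maxTokens, topP) into inferenceConfig for the converse path. A rough sketch of what a given config resolves to, inferred from the hunk above rather than separate code in the commit:

    from mem0.configs.llms.base import BaseLlmConfig

    config = BaseLlmConfig(temperature=0.1, max_tokens=2000, top_p=0.9)
    # With no model set, AWSBedrockLLM falls back to "anthropic.claude-3-5-sonnet-20240620-v1:0".
    # model_kwargs (invoke_model path):  {"temperature": 0.1, "max_tokens_to_sample": 2000, "top_p": 0.9}
    # inference_config (converse path):  {"temperature": 0.1, "maxTokens": 2000, "topP": 0.9}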

@@ -1,7 +1,21 @@
+from typing import Optional
from abc import ABC, abstractmethod

+from mem0.configs.llms.base import BaseLlmConfig


class LLMBase(ABC):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """Initialize a base LLM class
+        :param config: LLM configuration option class, defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        """
+        if config is None:
+            self.config = BaseLlmConfig()
+        else:
+            self.config = config

    @abstractmethod
    def generate_response(self, messages):
        """

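The base class now owns the config fallback, so every provider follows the same shape: accept an optional BaseLlmConfig, call super().__init__, then fill in a provider-specific default model. A sketch with a hypothetical provider (not part of this commit) to show the contract:

    from typing import Optional

    from mem0.llms.base import LLMBase
    from mem0.configs.llms.base import BaseLlmConfig


    class MyProviderLLM(LLMBase):  # hypothetical provider, shown only to illustrate the pattern
        def __init__(self, config: Optional[BaseLlmConfig] = None):
            super().__init__(config)  # falls back to BaseLlmConfig() when config is None
            if not self.config.model:
                self.config.model = "my-default-model"  # placeholder name

        def generate_response(self, messages):
            # real providers build a params dict from self.config here (see the hunks below)
            return f"echo: {messages[-1]['content']}"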

@@ -8,7 +8,7 @@ class LlmConfig(BaseModel):
        description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
    )
    config: Optional[dict] = Field(
-        description="Configuration for the specific LLM", default=None
+        description="Configuration for the specific LLM", default={}
    )

    @field_validator("config")

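Switching the default from None to {} matters because the factory (further down) now calls BaseLlmConfig(**config); with an empty dict that call simply yields the defaults, whereas unpacking None would raise a TypeError. A small sketch:

    from mem0.configs.llms.base import BaseLlmConfig

    config = {}                     # what LlmConfig.config now defaults to
    base = BaseLlmConfig(**config)  # unpacking an empty dict leaves every field at its default
    print(base.model, base.temperature, base.max_tokens, base.top_p)  # None 0 3000 1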

@@ -4,12 +4,16 @@ from typing import Dict, List, Optional
from groq import Groq

from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig


class GroqLLM(LLMBase):
-    def __init__(self, model="llama3-70b-8192"):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+        if not self.config.model:
+            self.config.model="llama3-70b-8192"
        self.client = Groq()
-        self.model = model

    def _parse_response(self, response, tools):
        """
@@ -58,7 +62,13 @@ class GroqLLM(LLMBase):
        Returns:
            str: The generated response.
        """
-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p
+        }
        if response_format:
            params["response_format"] = response_format
        if tools:

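The completed params dict then goes to the Groq chat-completions call, and the same construction repeats in the LiteLLM, OpenAI, and Together hunks below. Roughly, the tail of generate_response continues like this (a fragment inferred from the test assertions further down, not shown in this hunk):

        response = self.client.chat.completions.create(**params)
        return self._parse_response(response, tools)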

@@ -4,11 +4,15 @@ from typing import Dict, List, Optional
import litellm

from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig


class LiteLLM(LLMBase):
-    def __init__(self, model="gpt-4o"):
-        self.model = model
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+        if not self.config.model:
+            self.config.model="gpt-4o"

    def _parse_response(self, response, tools):
        """
@@ -57,10 +61,16 @@ class LiteLLM(LLMBase):
        Returns:
            str: The generated response.
        """
-        if not litellm.supports_function_calling(self.model):
-            raise ValueError(f"Model '{self.model}' in litellm does not support function calling.")
+        if not litellm.supports_function_calling(self.config.model):
+            raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")

-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p
+        }
        if response_format:
            params["response_format"] = response_format
        if tools:


@@ -4,12 +4,15 @@ from typing import Dict, List, Optional
from openai import OpenAI

from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig


class OpenAILLM(LLMBase):
-    def __init__(self, model="gpt-4o"):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+        if not self.config.model:
+            self.config.model="gpt-4o"
        self.client = OpenAI()
-        self.model = model

    def _parse_response(self, response, tools):
        """
@@ -58,7 +61,13 @@ class OpenAILLM(LLMBase):
        Returns:
            str: The generated response.
        """
-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p
+        }
        if response_format:
            params["response_format"] = response_format
        if tools:


@@ -4,12 +4,15 @@ from typing import Dict, List, Optional
from together import Together

from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig


class TogetherLLM(LLMBase):
-    def __init__(self, model="mistralai/Mixtral-8x7B-Instruct-v0.1"):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+        if not self.config.model:
+            self.config.model="mistralai/Mixtral-8x7B-Instruct-v0.1"
        self.client = Together()
-        self.model = model

    def _parse_response(self, response, tools):
        """
@@ -58,7 +61,13 @@ class TogetherLLM(LLMBase):
        Returns:
            str: The generated response.
        """
-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p
+        }
        if response_format:
            params["response_format"] = response_format
        if tools:


@@ -82,7 +82,7 @@ class Memory(MemoryBase):
                f"Unsupported vector store type: {self.config.vector_store_type}"
            )

-        self.llm = LlmFactory.create(self.config.llm.provider)
+        self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
        self.db = SQLiteManager(self.config.history_db_path)
        self.collection_name = self.config.collection_name
        self.vector_store.create_col(

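End to end, this is the hook that lets a user-supplied config pick the model per provider. A sketch of the intended call path, assuming the usual Memory.from_config entry point (which is not part of this diff):

    from mem0 import Memory  # entry point assumed, not shown in this diff

    config = {
        "llm": {
            "provider": "groq",
            "config": {"model": "llama3-70b-8192", "temperature": 0.2, "max_tokens": 1500},
        }
    }
    m = Memory.from_config(config)
    # Internally: LlmFactory.create("groq", {...}) builds GroqLLM(BaseLlmConfig(model="llama3-70b-8192", ...))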

@@ -1,5 +1,7 @@
import importlib

+from mem0.configs.llms.base import BaseLlmConfig


def load_class(class_type):
    module_path, class_name = class_type.rsplit(".", 1)
@@ -18,11 +20,12 @@ class LlmFactory:
    }

    @classmethod
-    def create(cls, provider_name):
+    def create(cls, provider_name, config):
        class_type = cls.provider_to_class.get(provider_name)
        if class_type:
-            llm_instance = load_class(class_type)()
-            return llm_instance
+            llm_instance = load_class(class_type)
+            base_config = BaseLlmConfig(**config)
+            return llm_instance(base_config)
        else:
            raise ValueError(f"Unsupported Llm provider: {provider_name}")

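Taken on its own, the factory now turns the plain config dict into a typed BaseLlmConfig before instantiating the provider class. A minimal sketch (the factory's module path is assumed, since the file name is not shown in this view):

    from mem0.utils.factory import LlmFactory  # module path assumed

    llm = LlmFactory.create("openai", {"model": "gpt-4o", "temperature": 0.2})
    # roughly equivalent to OpenAILLM(BaseLlmConfig(model="gpt-4o", temperature=0.2))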

@@ -1,6 +1,7 @@
import pytest
from unittest.mock import Mock, patch

from mem0.llms.groq import GroqLLM
+from mem0.configs.llms.base import BaseLlmConfig


@pytest.fixture
def mock_groq_client():
@@ -11,7 +12,8 @@ def mock_groq_client():

def test_generate_response_without_tools(mock_groq_client):
-    llm = GroqLLM()
+    config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = GroqLLM(config)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"}
@@ -25,13 +27,17 @@ def test_generate_response_without_tools(mock_groq_client):
    mock_groq_client.chat.completions.create.assert_called_once_with(
        model="llama3-70b-8192",
-        messages=messages
+        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0
    )
    assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_groq_client):
-    llm = GroqLLM()
+    config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = GroqLLM(config)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Add a new memory: Today is a sunny day."}
@@ -70,6 +76,9 @@ def test_generate_response_with_tools(mock_groq_client):
    mock_groq_client.chat.completions.create.assert_called_once_with(
        model="llama3-70b-8192",
        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0,
        tools=tools,
        tool_choice="auto"
    )


@@ -2,6 +2,7 @@ import pytest
from unittest.mock import Mock, patch

from mem0.llms import litellm
+from mem0.configs.llms.base import BaseLlmConfig


@pytest.fixture
def mock_litellm():
@@ -9,7 +10,8 @@ def mock_litellm():
        yield mock_litellm

def test_generate_response_with_unsupported_model(mock_litellm):
-    llm = litellm.LiteLLM(model="unsupported-model")
+    config = BaseLlmConfig(model="unsupported-model", temperature=0.7, max_tokens=100, top_p=1)
+    llm = litellm.LiteLLM(config)
    messages = [{"role": "user", "content": "Hello"}]
    mock_litellm.supports_function_calling.return_value = False
@@ -19,7 +21,8 @@ def test_generate_response_with_unsupported_model(mock_litellm):

def test_generate_response_without_tools(mock_litellm):
-    llm = litellm.LiteLLM()
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1)
+    llm = litellm.LiteLLM(config)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"}
@@ -34,13 +37,17 @@ def test_generate_response_without_tools(mock_litellm):
    mock_litellm.completion.assert_called_once_with(
        model="gpt-4o",
-        messages=messages
+        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0
    )
    assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_litellm):
-    llm = litellm.LiteLLM()
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1)
+    llm = litellm.LiteLLM(config)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Add a new memory: Today is a sunny day."}
@@ -80,6 +87,9 @@ def test_generate_response_with_tools(mock_litellm):
    mock_litellm.completion.assert_called_once_with(
        model="gpt-4o",
        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1,
        tools=tools,
        tool_choice="auto"
    )


@@ -1,6 +1,7 @@
import pytest
from unittest.mock import Mock, patch

from mem0.llms.openai import OpenAILLM
+from mem0.configs.llms.base import BaseLlmConfig


@pytest.fixture
def mock_openai_client():
@@ -11,7 +12,8 @@ def mock_openai_client():

def test_generate_response_without_tools(mock_openai_client):
-    llm = OpenAILLM()
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = OpenAILLM(config)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"}
@@ -25,13 +27,17 @@ def test_generate_response_without_tools(mock_openai_client):
    mock_openai_client.chat.completions.create.assert_called_once_with(
        model="gpt-4o",
-        messages=messages
+        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0
    )
    assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_openai_client):
-    llm = OpenAILLM()
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = OpenAILLM(config)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Add a new memory: Today is a sunny day."}
@@ -70,6 +76,9 @@ def test_generate_response_with_tools(mock_openai_client):
    mock_openai_client.chat.completions.create.assert_called_once_with(
        model="gpt-4o",
        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0,
        tools=tools,
        tool_choice="auto"
    )


@@ -1,6 +1,7 @@
import pytest
from unittest.mock import Mock, patch

from mem0.llms.together import TogetherLLM
+from mem0.configs.llms.base import BaseLlmConfig


@pytest.fixture
def mock_together_client():
@@ -11,7 +12,8 @@ def mock_together_client():

def test_generate_response_without_tools(mock_together_client):
-    llm = TogetherLLM()
+    config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = TogetherLLM(config)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"}
@@ -25,13 +27,17 @@ def test_generate_response_without_tools(mock_together_client):
    mock_together_client.chat.completions.create.assert_called_once_with(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
-        messages=messages
+        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0
    )
    assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_together_client):
-    llm = TogetherLLM()
+    config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = TogetherLLM(config)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Add a new memory: Today is a sunny day."}
@@ -70,6 +76,9 @@ def test_generate_response_with_tools(mock_together_client):
    mock_together_client.chat.completions.create.assert_called_once_with(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0,
        tools=tools,
        tool_choice="auto"
    )