From 40c9abe48435730b8e855f68772b668a89bed426 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Thu, 18 Jul 2024 21:51:40 +0530 Subject: [PATCH] Support model config in LLMs (#1495) --- mem0/configs/llms/__init__.py | 0 mem0/configs/llms/base.py | 34 ++++++++++++++++++++++++++++++++++ mem0/llms/aws_bedrock.py | 17 +++++++++++------ mem0/llms/base.py | 14 ++++++++++++++ mem0/llms/configs.py | 2 +- mem0/llms/groq.py | 16 +++++++++++++--- mem0/llms/litellm.py | 20 +++++++++++++++----- mem0/llms/openai.py | 17 +++++++++++++---- mem0/llms/together.py | 17 +++++++++++++---- mem0/memory/main.py | 2 +- mem0/utils/factory.py | 9 ++++++--- tests/llms/test_groq.py | 15 ++++++++++++--- tests/llms/test_litellm.py | 18 ++++++++++++++---- tests/llms/test_openai.py | 17 +++++++++++++---- tests/llms/test_together.py | 15 ++++++++++++--- 15 files changed, 172 insertions(+), 41 deletions(-) create mode 100644 mem0/configs/llms/__init__.py create mode 100644 mem0/configs/llms/base.py diff --git a/mem0/configs/llms/__init__.py b/mem0/configs/llms/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mem0/configs/llms/base.py b/mem0/configs/llms/base.py new file mode 100644 index 00000000..81aea342 --- /dev/null +++ b/mem0/configs/llms/base.py @@ -0,0 +1,34 @@ +from abc import ABC +from typing import Optional + +class BaseLlmConfig(ABC): + """ + Config for LLMs. + """ + + def __init__( + self, + model: Optional[str] = None, + temperature: float = 0, + max_tokens: int = 3000, + top_p: float = 1 + ): + """ + Initializes a configuration class instance for the LLM. + + :param model: Controls the OpenAI model used, defaults to None + :type model: Optional[str], optional + :param temperature: Controls the randomness of the model's output. + Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0 + :type temperature: float, optional + :param max_tokens: Controls how many tokens are generated, defaults to 3000 + :type max_tokens: int, optional + :param top_p: Controls the diversity of words. 
Higher values (closer to 1) make word selection more diverse, + defaults to 1 + :type top_p: float, optional + """ + + self.model = model + self.temperature = temperature + self.max_tokens = max_tokens + self.top_p = top_p \ No newline at end of file diff --git a/mem0/llms/aws_bedrock.py b/mem0/llms/aws_bedrock.py index 9e1912e2..4afe31c1 100644 --- a/mem0/llms/aws_bedrock.py +++ b/mem0/llms/aws_bedrock.py @@ -5,12 +5,16 @@ from typing import Dict, List, Optional, Any import boto3 from mem0.llms.base import LLMBase +from mem0.configs.llms.base import BaseLlmConfig +class AWSBedrockLLM(LLMBase): + def __init__(self, config: Optional[BaseLlmConfig] = None): + super().__init__(config) -class AWSBedrockLLM(LLMBase): - def __init__(self, model="cohere.command-r-v1:0"): + if not self.config.model: + self.config.model="anthropic.claude-3-5-sonnet-20240620-v1:0" self.client = boto3.client("bedrock-runtime", region_name=os.environ.get("AWS_REGION"), aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"), aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY")) - self.model = model + self.model_kwargs = {"temperature": self.config.temperature, "max_tokens_to_sample": self.config.max_tokens, "top_p": self.config.top_p} def _format_messages(self, messages: List[Dict[str, str]]) -> str: """ @@ -171,19 +175,20 @@ class AWSBedrockLLM(LLMBase): if tools: # Use converse method when tools are provided messages = [{"role": "user", "content": [{"text": message["content"]} for message in messages]}] + inference_config = {"temperature": self.model_kwargs["temperature"], "maxTokens": self.model_kwargs["max_tokens_to_sample"], "topP": self.model_kwargs["top_p"]} tools_config = {"tools": self._convert_tool_format(tools)} response = self.client.converse( - modelId=self.model, + modelId=self.config.model, messages=messages, + inferenceConfig=inference_config, toolConfig=tools_config ) - print("Tools response: ", response) else: # Use invoke_model method when no tools are provided prompt = self._format_messages(messages) provider = self.model.split(".")[0] - input_body = self._prepare_input(provider, self.model, prompt) + input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs) body = json.dumps(input_body) response = self.client.invoke_model( diff --git a/mem0/llms/base.py b/mem0/llms/base.py index 421bde20..69e880b4 100644 --- a/mem0/llms/base.py +++ b/mem0/llms/base.py @@ -1,7 +1,21 @@ +from typing import Optional from abc import ABC, abstractmethod +from mem0.configs.llms.base import BaseLlmConfig + class LLMBase(ABC): + def __init__(self, config: Optional[BaseLlmConfig] = None): + """Initialize a base LLM class + + :param config: LLM configuration option class, defaults to None + :type config: Optional[BaseLlmConfig], optional + """ + if config is None: + self.config = BaseLlmConfig() + else: + self.config = config + @abstractmethod def generate_response(self, messages): """ diff --git a/mem0/llms/configs.py b/mem0/llms/configs.py index b28a8a5e..838e3877 100644 --- a/mem0/llms/configs.py +++ b/mem0/llms/configs.py @@ -8,7 +8,7 @@ class LlmConfig(BaseModel): description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai" ) config: Optional[dict] = Field( - description="Configuration for the specific LLM", default=None + description="Configuration for the specific LLM", default={} ) @field_validator("config") diff --git a/mem0/llms/groq.py b/mem0/llms/groq.py index 948625fd..1046c3d2 100644 --- a/mem0/llms/groq.py +++ b/mem0/llms/groq.py @@ -4,12 +4,16 @@ from 
typing import Dict, List, Optional from groq import Groq from mem0.llms.base import LLMBase +from mem0.configs.llms.base import BaseLlmConfig class GroqLLM(LLMBase): - def __init__(self, model="llama3-70b-8192"): + def __init__(self, config: Optional[BaseLlmConfig] = None): + super().__init__(config) + + if not self.config.model: + self.config.model="llama3-70b-8192" self.client = Groq() - self.model = model def _parse_response(self, response, tools): """ @@ -58,7 +62,13 @@ class GroqLLM(LLMBase): Returns: str: The generated response. """ - params = {"model": self.model, "messages": messages} + params = { + "model": self.config.model, + "messages": messages, + "temperature": self.config.temperature, + "max_tokens": self.config.max_tokens, + "top_p": self.config.top_p + } if response_format: params["response_format"] = response_format if tools: diff --git a/mem0/llms/litellm.py b/mem0/llms/litellm.py index 889b9e79..4cd2ddd3 100644 --- a/mem0/llms/litellm.py +++ b/mem0/llms/litellm.py @@ -4,11 +4,15 @@ from typing import Dict, List, Optional import litellm from mem0.llms.base import LLMBase +from mem0.configs.llms.base import BaseLlmConfig class LiteLLM(LLMBase): - def __init__(self, model="gpt-4o"): - self.model = model + def __init__(self, config: Optional[BaseLlmConfig] = None): + super().__init__(config) + + if not self.config.model: + self.config.model="gpt-4o" def _parse_response(self, response, tools): """ @@ -57,10 +61,16 @@ class LiteLLM(LLMBase): Returns: str: The generated response. """ - if not litellm.supports_function_calling(self.model): - raise ValueError(f"Model '{self.model}' in litellm does not support function calling.") + if not litellm.supports_function_calling(self.config.model): + raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.") - params = {"model": self.model, "messages": messages} + params = { + "model": self.config.model, + "messages": messages, + "temperature": self.config.temperature, + "max_tokens": self.config.max_tokens, + "top_p": self.config.top_p + } if response_format: params["response_format"] = response_format if tools: diff --git a/mem0/llms/openai.py b/mem0/llms/openai.py index f87f78d7..bedf0d8d 100644 --- a/mem0/llms/openai.py +++ b/mem0/llms/openai.py @@ -4,12 +4,15 @@ from typing import Dict, List, Optional from openai import OpenAI from mem0.llms.base import LLMBase - +from mem0.configs.llms.base import BaseLlmConfig class OpenAILLM(LLMBase): - def __init__(self, model="gpt-4o"): + def __init__(self, config: Optional[BaseLlmConfig] = None): + super().__init__(config) + + if not self.config.model: + self.config.model="gpt-4o" self.client = OpenAI() - self.model = model def _parse_response(self, response, tools): """ @@ -58,7 +61,13 @@ class OpenAILLM(LLMBase): Returns: str: The generated response. 
""" - params = {"model": self.model, "messages": messages} + params = { + "model": self.config.model, + "messages": messages, + "temperature": self.config.temperature, + "max_tokens": self.config.max_tokens, + "top_p": self.config.top_p + } if response_format: params["response_format"] = response_format if tools: diff --git a/mem0/llms/together.py b/mem0/llms/together.py index e497a80f..6750a7a0 100644 --- a/mem0/llms/together.py +++ b/mem0/llms/together.py @@ -4,12 +4,15 @@ from typing import Dict, List, Optional from together import Together from mem0.llms.base import LLMBase - +from mem0.configs.llms.base import BaseLlmConfig class TogetherLLM(LLMBase): - def __init__(self, model="mistralai/Mixtral-8x7B-Instruct-v0.1"): + def __init__(self, config: Optional[BaseLlmConfig] = None): + super().__init__(config) + + if not self.config.model: + self.config.model="mistralai/Mixtral-8x7B-Instruct-v0.1" self.client = Together() - self.model = model def _parse_response(self, response, tools): """ @@ -58,7 +61,13 @@ class TogetherLLM(LLMBase): Returns: str: The generated response. """ - params = {"model": self.model, "messages": messages} + params = { + "model": self.config.model, + "messages": messages, + "temperature": self.config.temperature, + "max_tokens": self.config.max_tokens, + "top_p": self.config.top_p + } if response_format: params["response_format"] = response_format if tools: diff --git a/mem0/memory/main.py b/mem0/memory/main.py index c9327a4b..4477fbea 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -82,7 +82,7 @@ class Memory(MemoryBase): f"Unsupported vector store type: {self.config.vector_store_type}" ) - self.llm = LlmFactory.create(self.config.llm.provider) + self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config) self.db = SQLiteManager(self.config.history_db_path) self.collection_name = self.config.collection_name self.vector_store.create_col( diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index 9f110d2d..50856b04 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -1,5 +1,7 @@ import importlib +from mem0.configs.llms.base import BaseLlmConfig + def load_class(class_type): module_path, class_name = class_type.rsplit(".", 1) @@ -18,11 +20,12 @@ class LlmFactory: } @classmethod - def create(cls, provider_name): + def create(cls, provider_name, config): class_type = cls.provider_to_class.get(provider_name) if class_type: - llm_instance = load_class(class_type)() - return llm_instance + llm_instance = load_class(class_type) + base_config = BaseLlmConfig(**config) + return llm_instance(base_config) else: raise ValueError(f"Unsupported Llm provider: {provider_name}") diff --git a/tests/llms/test_groq.py b/tests/llms/test_groq.py index c5820708..52c3b222 100644 --- a/tests/llms/test_groq.py +++ b/tests/llms/test_groq.py @@ -1,6 +1,7 @@ import pytest from unittest.mock import Mock, patch from mem0.llms.groq import GroqLLM +from mem0.configs.llms.base import BaseLlmConfig @pytest.fixture def mock_groq_client(): @@ -11,7 +12,8 @@ def mock_groq_client(): def test_generate_response_without_tools(mock_groq_client): - llm = GroqLLM() + config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0) + llm = GroqLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello, how are you?"} @@ -25,13 +27,17 @@ def test_generate_response_without_tools(mock_groq_client): mock_groq_client.chat.completions.create.assert_called_once_with( 
model="llama3-70b-8192", - messages=messages + messages=messages, + temperature=0.7, + max_tokens=100, + top_p=1.0 ) assert response == "I'm doing well, thank you for asking!" def test_generate_response_with_tools(mock_groq_client): - llm = GroqLLM() + config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0) + llm = GroqLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Add a new memory: Today is a sunny day."} @@ -70,6 +76,9 @@ def test_generate_response_with_tools(mock_groq_client): mock_groq_client.chat.completions.create.assert_called_once_with( model="llama3-70b-8192", messages=messages, + temperature=0.7, + max_tokens=100, + top_p=1.0, tools=tools, tool_choice="auto" ) diff --git a/tests/llms/test_litellm.py b/tests/llms/test_litellm.py index 99fa77c3..b611d5c2 100644 --- a/tests/llms/test_litellm.py +++ b/tests/llms/test_litellm.py @@ -2,6 +2,7 @@ import pytest from unittest.mock import Mock, patch from mem0.llms import litellm +from mem0.configs.llms.base import BaseLlmConfig @pytest.fixture def mock_litellm(): @@ -9,7 +10,8 @@ def mock_litellm(): yield mock_litellm def test_generate_response_with_unsupported_model(mock_litellm): - llm = litellm.LiteLLM(model="unsupported-model") + config = BaseLlmConfig(model="unsupported-model", temperature=0.7, max_tokens=100, top_p=1) + llm = litellm.LiteLLM(config) messages = [{"role": "user", "content": "Hello"}] mock_litellm.supports_function_calling.return_value = False @@ -19,7 +21,8 @@ def test_generate_response_with_unsupported_model(mock_litellm): def test_generate_response_without_tools(mock_litellm): - llm = litellm.LiteLLM() + config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1) + llm = litellm.LiteLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello, how are you?"} @@ -34,13 +37,17 @@ def test_generate_response_without_tools(mock_litellm): mock_litellm.completion.assert_called_once_with( model="gpt-4o", - messages=messages + messages=messages, + temperature=0.7, + max_tokens=100, + top_p=1.0 ) assert response == "I'm doing well, thank you for asking!" 
def test_generate_response_with_tools(mock_litellm): - llm = litellm.LiteLLM() + config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1) + llm = litellm.LiteLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Add a new memory: Today is a sunny day."} @@ -80,6 +87,9 @@ def test_generate_response_with_tools(mock_litellm): mock_litellm.completion.assert_called_once_with( model="gpt-4o", messages=messages, + temperature=0.7, + max_tokens=100, + top_p=1, tools=tools, tool_choice="auto" ) diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index 535ba355..81005079 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -1,6 +1,7 @@ import pytest from unittest.mock import Mock, patch from mem0.llms.openai import OpenAILLM +from mem0.configs.llms.base import BaseLlmConfig @pytest.fixture def mock_openai_client(): @@ -11,7 +12,8 @@ def mock_openai_client(): def test_generate_response_without_tools(mock_openai_client): - llm = OpenAILLM() + config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0) + llm = OpenAILLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello, how are you?"} @@ -25,13 +27,17 @@ def test_generate_response_without_tools(mock_openai_client): mock_openai_client.chat.completions.create.assert_called_once_with( model="gpt-4o", - messages=messages + messages=messages, + temperature=0.7, + max_tokens=100, + top_p=1.0 ) assert response == "I'm doing well, thank you for asking!" - + def test_generate_response_with_tools(mock_openai_client): - llm = OpenAILLM() + config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0) + llm = OpenAILLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Add a new memory: Today is a sunny day."} @@ -70,6 +76,9 @@ def test_generate_response_with_tools(mock_openai_client): mock_openai_client.chat.completions.create.assert_called_once_with( model="gpt-4o", messages=messages, + temperature=0.7, + max_tokens=100, + top_p=1.0, tools=tools, tool_choice="auto" ) diff --git a/tests/llms/test_together.py b/tests/llms/test_together.py index dad2bdca..1353416f 100644 --- a/tests/llms/test_together.py +++ b/tests/llms/test_together.py @@ -1,6 +1,7 @@ import pytest from unittest.mock import Mock, patch from mem0.llms.together import TogetherLLM +from mem0.configs.llms.base import BaseLlmConfig @pytest.fixture def mock_together_client(): @@ -11,7 +12,8 @@ def mock_together_client(): def test_generate_response_without_tools(mock_together_client): - llm = TogetherLLM() + config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0) + llm = TogetherLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello, how are you?"} @@ -25,13 +27,17 @@ def test_generate_response_without_tools(mock_together_client): mock_together_client.chat.completions.create.assert_called_once_with( model="mistralai/Mixtral-8x7B-Instruct-v0.1", - messages=messages + messages=messages, + temperature=0.7, + max_tokens=100, + top_p=1.0 ) assert response == "I'm doing well, thank you for asking!" 
def test_generate_response_with_tools(mock_together_client): - llm = TogetherLLM() + config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0) + llm = TogetherLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Add a new memory: Today is a sunny day."} @@ -70,6 +76,9 @@ def test_generate_response_with_tools(mock_together_client): mock_together_client.chat.completions.create.assert_called_once_with( model="mistralai/Mixtral-8x7B-Instruct-v0.1", messages=messages, + temperature=0.7, + max_tokens=100, + top_p=1.0, tools=tools, tool_choice="auto" )
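
Usage note (not part of the patch): a minimal sketch of the new per-provider config path, mirroring the updated tests above. GroqLLM falls back to "llama3-70b-8192" when config.model is unset; a GROQ_API_KEY in the environment is assumed.

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.groq import GroqLLM

# Sampling settings now live on the shared config object instead of per-class kwargs.
config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0)
llm = GroqLLM(config)  # passing None falls back to BaseLlmConfig() defaults

# generate_response forwards model/temperature/max_tokens/top_p to the Groq client.
reply = llm.generate_response([{"role": "user", "content": "Hello, how are you?"}])
print(reply)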
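
The factory change in mem0/utils/factory.py now takes a plain dict and expands it into BaseLlmConfig before instantiating the provider class. A sketch of that path, assuming "openai" is registered in provider_to_class (it is the LlmConfig default) and OPENAI_API_KEY is set:

from mem0.utils.factory import LlmFactory

# The dict keys must match BaseLlmConfig.__init__ parameters, since the factory
# calls BaseLlmConfig(**config) and hands the resulting object to the provider class.
llm = LlmFactory.create(
    "openai",
    {"model": "gpt-4o", "temperature": 0, "max_tokens": 3000, "top_p": 1},
)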
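
End to end, Memory now forwards self.config.llm.config into the factory (see mem0/memory/main.py above), so user-facing configuration could look like the dict below. Memory.from_config and the add() signature are assumed from the surrounding codebase; neither is part of this diff.

from mem0 import Memory

config = {
    "llm": {
        "provider": "openai",
        "config": {
            "model": "gpt-4o",
            "temperature": 0.2,
            "max_tokens": 1500,
            "top_p": 1.0,
        },
    }
}

m = Memory.from_config(config)  # hypothetical entry point, not shown in this patch
m.add("Today is a sunny day.", user_id="alice")  # assumed add() signature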