diff --git a/docs/components/llms/config.mdx b/docs/components/llms/config.mdx
index 5c471cbb..626304dc 100644
--- a/docs/components/llms/config.mdx
+++ b/docs/components/llms/config.mdx
@@ -9,6 +9,38 @@ The config is defined as a Python dictionary with two main keys:
 - `provider`: The name of the llm (e.g., "openai", "groq")
 - `config`: A nested dictionary containing provider-specific settings
 
+### Config Values Precedence
+
+Config values are applied in the following order of precedence (from highest to lowest):
+
+1. Values explicitly set in the `config` dictionary
+2. Environment variables (e.g., `OPENAI_API_KEY`, `OPENAI_API_BASE`)
+3. Default values defined in the LLM implementation
+
+This means that values specified in the `config` dictionary will override corresponding environment variables, which in turn override default values.
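+
+For example, a base URL supplied in the `config` dictionary wins even when `OPENAI_API_BASE` is set. Below is a minimal sketch of this behavior (the URLs are placeholders):
+
+```python
+import os
+
+from mem0 import Memory
+
+os.environ["OPENAI_API_BASE"] = "https://env.example.com/v1"  # used only if the config omits a base url
+
+config = {
+    "llm": {
+        "provider": "openai",
+        "config": {
+            "model": "gpt-4o",
+            "openai_base_url": "https://proxy.example.com/v1",  # overrides OPENAI_API_BASE
+        },
+    }
+}
+
+m = Memory.from_config(config)  # the OpenAI client is created with https://proxy.example.com/v1
+```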
+
 ## How to Use Config
 
 Here's a general example of how to use the config with mem0:
diff --git a/mem0/configs/llms/base.py b/mem0/configs/llms/base.py
index 94c8e278..f9d63485 100644
--- a/mem0/configs/llms/base.py
+++ b/mem0/configs/llms/base.py
@@ -22,9 +22,9 @@ class BaseLlmConfig(ABC):
         # Openrouter specific
         models: Optional[list[str]] = None,
         route: Optional[str] = "fallback",
-        openrouter_base_url: Optional[str] = "https://openrouter.ai/api/v1",
+        openrouter_base_url: Optional[str] = None,
         # Openai specific
-        openai_base_url: Optional[str] = "https://api.openai.com/v1",
+        openai_base_url: Optional[str] = None,
         site_url: Optional[str] = None,
         app_name: Optional[str] = None,
         # Ollama specific
diff --git a/mem0/llms/openai.py b/mem0/llms/openai.py
index 89bef986..c585162f 100644
--- a/mem0/llms/openai.py
+++ b/mem0/llms/openai.py
@@ -18,11 +18,11 @@ class OpenAILLM(LLMBase):
         if os.environ.get("OPENROUTER_API_KEY"):  # Use OpenRouter
             self.client = OpenAI(
                 api_key=os.environ.get("OPENROUTER_API_KEY"),
-                base_url=self.config.openrouter_base_url,
+                base_url=self.config.openrouter_base_url or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1",
             )
         else:
             api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
-            base_url = os.getenv("OPENAI_API_BASE") or self.config.openai_base_url
+            base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"
             self.client = OpenAI(api_key=api_key, base_url=base_url)
 
     def _parse_response(self, response, tools):
diff --git a/mem0/llms/openai_structured.py b/mem0/llms/openai_structured.py
index 4060afb8..1ccba28a 100644
--- a/mem0/llms/openai_structured.py
+++ b/mem0/llms/openai_structured.py
@@ -16,7 +16,7 @@ class OpenAIStructuredLLM(LLMBase):
             self.config.model = "gpt-4o-2024-08-06"
 
         api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
-        base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE")
+        base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"
         self.client = OpenAI(api_key=api_key, base_url=base_url)
 
     def _parse_response(self, response, tools):
diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py
index be2f6f95..9e62c6f2 100644
--- a/tests/llms/test_openai.py
+++ b/tests/llms/test_openai.py
@@ -1,5 +1,5 @@
 from unittest.mock import Mock, patch
-
+import os
 import pytest
 
 from mem0.configs.llms.base import BaseLlmConfig
@@ -14,6 +14,34 @@ def mock_openai_client():
         yield mock_client
 
 
+def test_openai_llm_base_url():
+    # make sure stale env vars cannot leak into this test (OpenRouter branch / custom base url)
+    os.environ.pop("OPENROUTER_API_KEY", None)
+    os.environ.pop("OPENAI_API_BASE", None)
+
+    # case1: default config, falling back to the official openai base url
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
+    llm = OpenAILLM(config)
+    # Note: the openai client parses the raw base_url into a URL object, which gains a trailing slash
+    assert str(llm.client.base_url) == "https://api.openai.com/v1/"
+
+    # case2: with env variable OPENAI_API_BASE
+    provider_base_url = "https://api.provider.com/v1"
+    os.environ["OPENAI_API_BASE"] = provider_base_url
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
+    llm = OpenAILLM(config)
+    assert str(llm.client.base_url) == provider_base_url + "/"
+
+    # case3: config.openai_base_url takes precedence over the env var set in case2
+    config_base_url = "https://api.config.com/v1"
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key",
+                           openai_base_url=config_base_url)
+    llm = OpenAILLM(config)
+    assert str(llm.client.base_url) == config_base_url + "/"
+
+    os.environ.pop("OPENAI_API_BASE", None)  # clean up so later tests see a pristine environment
+
+
 def test_generate_response_without_tools(mock_openai_client):
     config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
     llm = OpenAILLM(config)