fix(llm): consume llm base url config in a better way (#1861)

Mathew Shen
2024-09-24 12:35:09 +08:00
committed by GitHub
parent 56ceecb4e3
commit 8511eca03b
5 changed files with 40 additions and 6 deletions


@@ -9,6 +9,16 @@ The config is defined as a Python dictionary with two main keys:
 - `provider`: The name of the llm (e.g., "openai", "groq")
 - `config`: A nested dictionary containing provider-specific settings
 
+### Config Values Precedence
+
+Config values are applied in the following order of precedence (from highest to lowest):
+
+1. Values explicitly set in the `config` dictionary
+2. Environment variables (e.g., `OPENAI_API_KEY`, `OPENAI_API_BASE`)
+3. Default values defined in the LLM implementation
+
+This means that values specified in the `config` dictionary will override corresponding environment variables, which in turn override default values.
+
 ## How to Use Config
 
 Here's a general example of how to use the config with mem0:
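
For concreteness, a minimal sketch of such a config, assuming mem0's `Memory.from_config` entry point (the model, key, and URL values here are placeholders):

```python
import os

from mem0 import Memory

# Precedence in action: the explicit openai_base_url below wins over
# OPENAI_API_BASE in the environment, which in turn wins over the
# built-in default of https://api.openai.com/v1.
os.environ["OPENAI_API_BASE"] = "https://api.provider.com/v1"  # would apply only if the config omitted the URL

config = {
    "llm": {
        "provider": "openai",
        "config": {
            "model": "gpt-4o",
            "temperature": 0.7,
            "api_key": "sk-placeholder",
            "openai_base_url": "https://api.config.com/v1",  # highest precedence
        },
    }
}

m = Memory.from_config(config)
```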


@@ -22,9 +22,9 @@ class BaseLlmConfig(ABC):
         # Openrouter specific
         models: Optional[list[str]] = None,
         route: Optional[str] = "fallback",
-        openrouter_base_url: Optional[str] = "https://openrouter.ai/api/v1",
+        openrouter_base_url: Optional[str] = None,
         # Openai specific
-        openai_base_url: Optional[str] = "https://api.openai.com/v1",
+        openai_base_url: Optional[str] = None,
         site_url: Optional[str] = None,
         app_name: Optional[str] = None,
         # Ollama specific
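
Switching these defaults to `None` is what enables the precedence chain: the client code can now distinguish "not configured" from "explicitly configured" and consult an environment variable before falling back to the hard-coded default. A minimal sketch of the pattern (the `resolve_base_url` helper is hypothetical, not part of mem0):

```python
import os
from typing import Optional

def resolve_base_url(config_value: Optional[str], env_var: str, default: str) -> str:
    # Hypothetical helper: explicit config value beats env var beats built-in default.
    return config_value or os.getenv(env_var) or default

# With the dataclass default now None, an unset config no longer shadows the env var:
# resolve_base_url(None, "OPENAI_API_BASE", "https://api.openai.com/v1")
# -> os.environ["OPENAI_API_BASE"] if set, else "https://api.openai.com/v1"
```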


@@ -18,11 +18,11 @@ class OpenAILLM(LLMBase):
         if os.environ.get("OPENROUTER_API_KEY"):  # Use OpenRouter
             self.client = OpenAI(
                 api_key=os.environ.get("OPENROUTER_API_KEY"),
-                base_url=self.config.openrouter_base_url,
+                base_url=self.config.openrouter_base_url or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1",
             )
         else:
             api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
-            base_url = os.getenv("OPENAI_API_BASE") or self.config.openai_base_url
+            base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"
             self.client = OpenAI(api_key=api_key, base_url=base_url)
 
     def _parse_response(self, response, tools):
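
Note that the `else` branch also reverses the old lookup order, so an explicit `openai_base_url` in the config now beats `OPENAI_API_BASE`, matching the documented precedence. The OpenRouter branch gains the same chain, so that endpoint can now come from `OPENROUTER_API_BASE` as well. For example (the `mem0.llms.openai` import path is assumed; the environment values are placeholders):

```python
import os

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.openai import OpenAILLM

# Presence of OPENROUTER_API_KEY routes requests through OpenRouter.
os.environ["OPENROUTER_API_KEY"] = "or-placeholder"
os.environ["OPENROUTER_API_BASE"] = "https://proxy.example.com/api/v1"

llm = OpenAILLM(BaseLlmConfig(model="openai/gpt-4o"))
# llm.client.base_url -> https://proxy.example.com/api/v1/
# (an explicit openrouter_base_url in the config would override the env var)
```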


@@ -16,7 +16,7 @@ class OpenAIStructuredLLM(LLMBase):
             self.config.model = "gpt-4o-2024-08-06"
 
         api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
-        base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE")
+        base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"
         self.client = OpenAI(api_key=api_key, base_url=base_url)
 
     def _parse_response(self, response, tools):


@@ -1,5 +1,5 @@
 from unittest.mock import Mock, patch
-
+import os
 import pytest
 
 from mem0.configs.llms.base import BaseLlmConfig
@@ -14,6 +14,30 @@ def mock_openai_client():
         yield mock_client
 
 
+def test_openai_llm_base_url():
+    # case1: default config: with openai official base url
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
+    llm = OpenAILLM(config)
+    # Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash
+    assert str(llm.client.base_url) == "https://api.openai.com/v1/"
+
+    # case2: with env variable OPENAI_API_BASE
+    provider_base_url = "https://api.provider.com/v1"
+    os.environ["OPENAI_API_BASE"] = provider_base_url
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
+    llm = OpenAILLM(config)
+    # Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash
+    assert str(llm.client.base_url) == provider_base_url + "/"
+
+    # case3: with config.openai_base_url
+    config_base_url = "https://api.config.com/v1"
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key",
+                           openai_base_url=config_base_url)
+    llm = OpenAILLM(config)
+    # Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash
+    assert str(llm.client.base_url) == config_base_url + "/"
+
+
 def test_generate_response_without_tools(mock_openai_client):
     config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
     llm = OpenAILLM(config)
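
One caveat in the new test: it sets `OPENAI_API_BASE` directly on `os.environ` and never restores it, so the value can leak into tests that run afterwards (case3 still passes only because the config value takes precedence). A sketch of a leak-free variant using pytest's `monkeypatch` fixture, assuming the same module imports:

```python
def test_openai_llm_base_url_env(monkeypatch):
    # monkeypatch.setenv undoes itself at teardown, unlike a bare os.environ assignment
    provider_base_url = "https://api.provider.com/v1"
    monkeypatch.setenv("OPENAI_API_BASE", provider_base_url)
    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
    llm = OpenAILLM(config)
    # the OpenAI client normalizes base_url to include a trailing slash
    assert str(llm.client.base_url) == provider_base_url + "/"
```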