From b80925e8579560a068aaba39356d0923db7e818e Mon Sep 17 00:00:00 2001 From: ParseDark Date: Sun, 25 Aug 2024 18:55:14 +0800 Subject: [PATCH] [openai_api_base support] - ft/Added openai OPENAI_API_BASE llm config support (#1737) --- mem0/configs/embeddings/base.py | 6 +++++- mem0/configs/llms/base.py | 5 +++++ mem0/embeddings/openai.py | 3 ++- mem0/llms/openai.py | 3 ++- 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/mem0/configs/embeddings/base.py b/mem0/configs/embeddings/base.py index 05ad9952..20083e1e 100644 --- a/mem0/configs/embeddings/base.py +++ b/mem0/configs/embeddings/base.py @@ -14,6 +14,8 @@ class BaseEmbedderConfig(ABC): embedding_dims: Optional[int] = None, # Ollama specific ollama_base_url: Optional[str] = None, + # Openai specific + openai_base_url: Optional[str] = None, # Huggingface specific model_kwargs: Optional[dict] = None, ): @@ -30,11 +32,13 @@ class BaseEmbedderConfig(ABC): :type ollama_base_url: Optional[str], optional :param model_kwargs: key-value arguments for the huggingface embedding model, defaults a dict inside init :type model_kwargs: Optional[Dict[str, Any]], defaults a dict inside init - + :param openai_base_url: Openai base URL to be use, defaults to "https://api.openai.com/v1" + :type openai_base_url: Optional[str], optional """ self.model = model self.api_key = api_key + self.openai_base_url = openai_base_url self.embedding_dims = embedding_dims # Ollama specific diff --git a/mem0/configs/llms/base.py b/mem0/configs/llms/base.py index c7b83411..3bc67b2e 100644 --- a/mem0/configs/llms/base.py +++ b/mem0/configs/llms/base.py @@ -19,6 +19,8 @@ class BaseLlmConfig(ABC): models: Optional[list[str]] = None, route: Optional[str] = "fallback", openrouter_base_url: Optional[str] = "https://openrouter.ai/api/v1", + # Openai specific + openai_base_url: Optional[str] = "https://api.openai.com/v1", site_url: Optional[str] = None, app_name: Optional[str] = None, # Ollama specific @@ -53,6 +55,8 @@ class BaseLlmConfig(ABC): :type app_name: Optional[str], optional :param ollama_base_url: The base URL of the LLM, defaults to None :type ollama_base_url: Optional[str], optional + :param openai_base_url: Openai base URL to be use, defaults to "https://api.openai.com/v1" + :type openai_base_url: Optional[str], optional """ self.model = model @@ -66,6 +70,7 @@ class BaseLlmConfig(ABC): self.models = models self.route = route self.openrouter_base_url = openrouter_base_url + self.openai_base_url = openai_base_url self.site_url = site_url self.app_name = app_name diff --git a/mem0/embeddings/openai.py b/mem0/embeddings/openai.py index a229fc8c..30a7c1df 100644 --- a/mem0/embeddings/openai.py +++ b/mem0/embeddings/openai.py @@ -15,7 +15,8 @@ class OpenAIEmbedding(EmbeddingBase): self.config.embedding_dims = self.config.embedding_dims or 1536 api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key - self.client = OpenAI(api_key=api_key) + base_url = os.getenv("OPENAI_API_BASE") or self.config.openai_base_url + self.client = OpenAI(api_key=api_key, base_url=base_url) def embed(self, text): """ diff --git a/mem0/llms/openai.py b/mem0/llms/openai.py index e2e33f18..1ed2dbaf 100644 --- a/mem0/llms/openai.py +++ b/mem0/llms/openai.py @@ -22,7 +22,8 @@ class OpenAILLM(LLMBase): ) else: api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key - self.client = OpenAI(api_key=api_key) + base_url = os.getenv("OPENAI_API_BASE") or self.config.openai_base_url + self.client = OpenAI(api_key=api_key, base_url=base_url) def _parse_response(self, response, tools): """