From 11f4ce8fb6fef335e2642ab95021e7e7cc820517 Mon Sep 17 00:00:00 2001
From: Joe
Date: Mon, 4 Mar 2024 20:17:20 -0600
Subject: [PATCH] #1155: Add support for OpenAI-compatible endpoint in LLM and Embed (#1197)

---
 embedchain/config/embedder/base.py |  2 ++
 embedchain/config/llm/base.py      |  2 ++
 embedchain/embedder/openai.py      |  2 ++
 embedchain/llm/openai.py           | 11 +++++++++--
 embedchain/utils/misc.py           |  2 ++
 5 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/embedchain/config/embedder/base.py b/embedchain/config/embedder/base.py
index dc14cea2..f4229183 100644
--- a/embedchain/config/embedder/base.py
+++ b/embedchain/config/embedder/base.py
@@ -11,6 +11,7 @@ class BaseEmbedderConfig:
         deployment_name: Optional[str] = None,
         vector_dimension: Optional[int] = None,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ):
         """
         Initialize a new instance of an embedder config class.
@@ -24,3 +25,4 @@ class BaseEmbedderConfig:
         self.deployment_name = deployment_name
         self.vector_dimension = vector_dimension
         self.api_key = api_key
+        self.api_base = api_base
diff --git a/embedchain/config/llm/base.py b/embedchain/config/llm/base.py
index c6d3d640..a5a7364d 100644
--- a/embedchain/config/llm/base.py
+++ b/embedchain/config/llm/base.py
@@ -93,6 +93,7 @@ class BaseLlmConfig(BaseConfig):
         query_type: Optional[str] = None,
         callbacks: Optional[list] = None,
         api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
         endpoint: Optional[str] = None,
         model_kwargs: Optional[dict[str, Any]] = None,
         local: Optional[bool] = False,
@@ -167,6 +168,7 @@ class BaseLlmConfig(BaseConfig):
         self.query_type = query_type
         self.callbacks = callbacks
         self.api_key = api_key
+        self.base_url = base_url
         self.endpoint = endpoint
         self.model_kwargs = model_kwargs
         self.local = local
diff --git a/embedchain/embedder/openai.py b/embedchain/embedder/openai.py
index 29ca042e..ad2f068c 100644
--- a/embedchain/embedder/openai.py
+++ b/embedchain/embedder/openai.py
@@ -17,6 +17,7 @@ class OpenAIEmbedder(BaseEmbedder):
             self.config.model = "text-embedding-ada-002"
 
         api_key = self.config.api_key or os.environ["OPENAI_API_KEY"]
+        api_base = self.config.api_base or os.environ.get("OPENAI_API_BASE", None)
 
         if self.config.deployment_name:
             embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
@@ -28,6 +29,7 @@ class OpenAIEmbedder(BaseEmbedder):
             )  # noqa:E501
             embedding_fn = OpenAIEmbeddingFunction(
                 api_key=api_key,
+                api_base=api_base,
                 organization_id=os.getenv("OPENAI_ORGANIZATION"),
                 model_name=self.config.model,
             )
diff --git a/embedchain/llm/openai.py b/embedchain/llm/openai.py
index 709d2262..240c8070 100644
--- a/embedchain/llm/openai.py
+++ b/embedchain/llm/openai.py
@@ -39,13 +39,20 @@ class OpenAILlm(BaseLlm):
             "model_kwargs": {},
         }
         api_key = config.api_key or os.environ["OPENAI_API_KEY"]
+        base_url = config.base_url or os.environ.get("OPENAI_API_BASE", None)
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
-            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
+            chat = ChatOpenAI(
+                **kwargs,
+                streaming=config.stream,
+                callbacks=callbacks,
+                api_key=api_key,
+                base_url=base_url,
+            )
         else:
-            chat = ChatOpenAI(**kwargs, api_key=api_key)
+            chat = ChatOpenAI(**kwargs, api_key=api_key, base_url=base_url)
 
         if self.tools:
             return self._query_function_call(chat, self.tools, messages)
diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py
index 53802d81..7e1e6a6d 100644
--- a/embedchain/utils/misc.py
+++ b/embedchain/utils/misc.py
@@ -424,6 +424,7 @@ def validate_config(config_data):
                 Optional("where"): dict,
                 Optional("query_type"): str,
                 Optional("api_key"): str,
+                Optional("base_url"): str,
                 Optional("endpoint"): str,
                 Optional("model_kwargs"): dict,
                 Optional("local"): bool,
@@ -451,6 +452,7 @@ def validate_config(config_data):
                 Optional("model"): Optional(str),
                 Optional("deployment_name"): Optional(str),
                 Optional("api_key"): str,
+                Optional("api_base"): str,
                 Optional("title"): str,
                 Optional("task_type"): str,
                 Optional("vector_dimension"): int,
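
Usage sketch (not part of the patch): the snippet below shows how the new "base_url" / "api_base" options validated above might be wired up to point embedchain at an OpenAI-compatible server. The endpoint URL, the placeholder key, and the dict form of App.from_config are illustrative assumptions, not something this patch prescribes.

    import os

    from embedchain import App

    # An OpenAI-style API key is still read by the client; the value only has
    # to satisfy whatever server you are actually talking to.
    os.environ["OPENAI_API_KEY"] = "sk-placeholder"

    config = {
        "llm": {
            "provider": "openai",
            "config": {
                "model": "gpt-3.5-turbo",
                # New key accepted by the llm schema above; hypothetical local endpoint.
                "base_url": "http://localhost:11434/v1",
            },
        },
        "embedder": {
            "provider": "openai",
            "config": {
                "model": "text-embedding-ada-002",
                # New key accepted by the embedder schema above.
                "api_base": "http://localhost:11434/v1",
            },
        },
    }

    app = App.from_config(config=config)
    app.add("https://en.wikipedia.org/wiki/Artificial_intelligence")
    print(app.query("What is artificial intelligence?"))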