#1155: Add support for OpenAI-compatible endpoint in LLM and Embed (#1197)

Authored by Joe on 2024-03-04 20:17:20 -06:00, committed by GitHub
parent 6078738d34
commit 11f4ce8fb6
5 changed files with 17 additions and 2 deletions
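
With this change both the LLM and the embedder can be pointed at any OpenAI-compatible endpoint: the LLM config gains a base_url field, the embedder config gains api_base, and both fall back to the OPENAI_API_BASE environment variable. A minimal sketch of the environment-variable path, with an illustrative local server URL:

import os

# Illustrative values -- any OpenAI-compatible server and key format will do.
os.environ["OPENAI_API_KEY"] = "sk-not-a-real-key"
os.environ["OPENAI_API_BASE"] = "http://localhost:8000/v1"

# With these set, OpenAILlm and OpenAIEmbedder (patched below) pick up the custom
# endpoint even when base_url / api_base are not passed in their configs.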

@@ -11,6 +11,7 @@ class BaseEmbedderConfig:
         deployment_name: Optional[str] = None,
         vector_dimension: Optional[int] = None,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ):
         """
         Initialize a new instance of an embedder config class.
@@ -24,3 +25,4 @@ class BaseEmbedderConfig:
         self.deployment_name = deployment_name
         self.vector_dimension = vector_dimension
         self.api_key = api_key
+        self.api_base = api_base
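
A minimal usage sketch for the new field; the import path is an assumption (the diff does not show module paths) and the endpoint URL is illustrative:

from embedchain.config import BaseEmbedderConfig  # assumed import path

# api_base is the field added above; when it is omitted, OpenAIEmbedder (below)
# falls back to the OPENAI_API_BASE environment variable.
embedder_config = BaseEmbedderConfig(
    model="text-embedding-ada-002",
    api_key="sk-not-a-real-key",          # or rely on OPENAI_API_KEY
    api_base="http://localhost:8000/v1",  # any OpenAI-compatible endpoint
)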

@@ -93,6 +93,7 @@ class BaseLlmConfig(BaseConfig):
         query_type: Optional[str] = None,
         callbacks: Optional[list] = None,
         api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
         endpoint: Optional[str] = None,
         model_kwargs: Optional[dict[str, Any]] = None,
         local: Optional[bool] = False,
@@ -167,6 +168,7 @@ class BaseLlmConfig(BaseConfig):
         self.query_type = query_type
         self.callbacks = callbacks
         self.api_key = api_key
+        self.base_url = base_url
         self.endpoint = endpoint
         self.model_kwargs = model_kwargs
         self.local = local
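
The LLM-side counterpart, with the same caveats (assumed import path, illustrative URL):

from embedchain.config import BaseLlmConfig  # assumed import path

# base_url is the field added above; OpenAILlm (below) prefers it over the
# OPENAI_API_BASE environment variable when both are present.
llm_config = BaseLlmConfig(
    api_key="sk-not-a-real-key",
    base_url="http://localhost:8000/v1",  # any OpenAI-compatible endpoint
)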

@@ -17,6 +17,7 @@ class OpenAIEmbedder(BaseEmbedder):
             self.config.model = "text-embedding-ada-002"
         api_key = self.config.api_key or os.environ["OPENAI_API_KEY"]
+        api_base = self.config.api_base or os.environ["OPENAI_API_BASE"]
         if self.config.deployment_name:
             embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
@@ -28,6 +29,7 @@ class OpenAIEmbedder(BaseEmbedder):
             ) # noqa:E501
             embedding_fn = OpenAIEmbeddingFunction(
                 api_key=api_key,
+                api_base=api_base,
                 organization_id=os.getenv("OPENAI_ORGANIZATION"),
                 model_name=self.config.model,
             )
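
Note that the embedder resolves api_base with a direct os.environ lookup rather than .get, so the variable must exist whenever config.api_base is not supplied. A sketch of that path; the OpenAIEmbedder import and constructor signature are assumptions, and the values are illustrative:

import os

from embedchain.config import BaseEmbedderConfig       # assumed import path
from embedchain.embedder.openai import OpenAIEmbedder  # assumed import path

os.environ["OPENAI_API_KEY"] = "sk-not-a-real-key"
os.environ["OPENAI_API_BASE"] = "http://localhost:8000/v1"  # illustrative endpoint

# config.api_base is left unset here, so the lookup in the hunk above falls back to
# OPENAI_API_BASE; if neither were set, os.environ["OPENAI_API_BASE"] raises KeyError.
embedder = OpenAIEmbedder(config=BaseEmbedderConfig(model="text-embedding-ada-002"))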

@@ -39,13 +39,20 @@ class OpenAILlm(BaseLlm):
             "model_kwargs": {},
         }
         api_key = config.api_key or os.environ["OPENAI_API_KEY"]
+        base_url = config.base_url or os.environ.get("OPENAI_API_BASE", None)
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
-            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
+            chat = ChatOpenAI(
+                **kwargs,
+                streaming=config.stream,
+                callbacks=callbacks,
+                api_key=api_key,
+                base_url=base_url,
+            )
         else:
-            chat = ChatOpenAI(**kwargs, api_key=api_key)
+            chat = ChatOpenAI(**kwargs, api_key=api_key, base_url=base_url)
         if self.tools:
             return self._query_function_call(chat, self.tools, messages)
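
Here base_url is resolved as config.base_url, then OPENAI_API_BASE, then None, and passed straight through to ChatOpenAI. A standalone sketch of the equivalent call, assuming the ChatOpenAI in use is langchain_openai's and using an illustrative local endpoint:

from langchain_openai import ChatOpenAI  # assumption about the ChatOpenAI imported above

chat = ChatOpenAI(
    model="gpt-3.5-turbo",                # illustrative model name
    api_key="sk-not-a-real-key",
    base_url="http://localhost:8000/v1",  # any OpenAI-compatible endpoint
)
# chat.invoke("Hello") now talks to the local server instead of api.openai.com.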

@@ -424,6 +424,7 @@ def validate_config(config_data):
                     Optional("where"): dict,
                     Optional("query_type"): str,
                     Optional("api_key"): str,
+                    Optional("base_url"): str,
                     Optional("endpoint"): str,
                     Optional("model_kwargs"): dict,
                     Optional("local"): bool,
@@ -451,6 +452,7 @@ def validate_config(config_data):
                     Optional("model"): Optional(str),
                     Optional("deployment_name"): Optional(str),
                     Optional("api_key"): str,
+                    Optional("api_base"): str,
                     Optional("title"): str,
                     Optional("task_type"): str,
                     Optional("vector_dimension"): int,