#1155: Add support for OpenAI-compatible endpoint in LLM and Embed (#1197)

This commit is contained in:
Joe
2024-03-04 20:17:20 -06:00
committed by GitHub
parent 6078738d34
commit 11f4ce8fb6
5 changed files with 17 additions and 2 deletions

View File

@@ -11,6 +11,7 @@ class BaseEmbedderConfig:
deployment_name: Optional[str] = None,
vector_dimension: Optional[int] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
):
"""
Initialize a new instance of an embedder config class.
@@ -24,3 +25,4 @@ class BaseEmbedderConfig:
self.deployment_name = deployment_name
self.vector_dimension = vector_dimension
self.api_key = api_key
self.api_base = api_base

View File

@@ -93,6 +93,7 @@ class BaseLlmConfig(BaseConfig):
query_type: Optional[str] = None,
callbacks: Optional[list] = None,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
endpoint: Optional[str] = None,
model_kwargs: Optional[dict[str, Any]] = None,
local: Optional[bool] = False,
@@ -167,6 +168,7 @@ class BaseLlmConfig(BaseConfig):
self.query_type = query_type
self.callbacks = callbacks
self.api_key = api_key
self.base_url = base_url
self.endpoint = endpoint
self.model_kwargs = model_kwargs
self.local = local

View File

@@ -17,6 +17,7 @@ class OpenAIEmbedder(BaseEmbedder):
self.config.model = "text-embedding-ada-002"
api_key = self.config.api_key or os.environ["OPENAI_API_KEY"]
api_base = self.config.api_base or os.environ["OPENAI_API_BASE"]
if self.config.deployment_name:
embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
@@ -28,6 +29,7 @@ class OpenAIEmbedder(BaseEmbedder):
) # noqa:E501
embedding_fn = OpenAIEmbeddingFunction(
api_key=api_key,
api_base=api_base,
organization_id=os.getenv("OPENAI_ORGANIZATION"),
model_name=self.config.model,
)

View File

@@ -39,13 +39,20 @@ class OpenAILlm(BaseLlm):
"model_kwargs": {},
}
api_key = config.api_key or os.environ["OPENAI_API_KEY"]
base_url = config.base_url or os.environ.get("OPENAI_API_BASE", None)
if config.top_p:
kwargs["model_kwargs"]["top_p"] = config.top_p
if config.stream:
callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
chat = ChatOpenAI(
**kwargs,
streaming=config.stream,
callbacks=callbacks,
api_key=api_key,
base_url=base_url,
)
else:
chat = ChatOpenAI(**kwargs, api_key=api_key)
chat = ChatOpenAI(**kwargs, api_key=api_key, base_url=base_url)
if self.tools:
return self._query_function_call(chat, self.tools, messages)

View File

@@ -424,6 +424,7 @@ def validate_config(config_data):
Optional("where"): dict,
Optional("query_type"): str,
Optional("api_key"): str,
Optional("base_url"): str,
Optional("endpoint"): str,
Optional("model_kwargs"): dict,
Optional("local"): bool,
@@ -451,6 +452,7 @@ def validate_config(config_data):
Optional("model"): Optional(str),
Optional("deployment_name"): Optional(str),
Optional("api_key"): str,
Optional("api_base"): str,
Optional("title"): str,
Optional("task_type"): str,
Optional("vector_dimension"): int,