@@ -11,6 +11,7 @@ class BaseEmbedderConfig:
         deployment_name: Optional[str] = None,
         vector_dimension: Optional[int] = None,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
     ):
         """
         Initialize a new instance of an embedder config class.

@@ -24,3 +25,4 @@ class BaseEmbedderConfig:
         self.deployment_name = deployment_name
         self.vector_dimension = vector_dimension
         self.api_key = api_key
+        self.api_base = api_base
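With api_base stored on the config, the embedder can be pointed at an OpenAI-compatible proxy or gateway. A minimal usage sketch; the import path is an assumption, not something shown in this diff:

    # Hypothetical usage of the new api_base field; adjust the import to the real module path.
    from embedchain.config import BaseEmbedderConfig

    embedder_config = BaseEmbedderConfig(
        api_key="sk-...",                        # if omitted, the embedder falls back to OPENAI_API_KEY
        api_base="https://my-proxy.example/v1",  # new in this commit; falls back to OPENAI_API_BASE
    )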
@@ -93,6 +93,7 @@ class BaseLlmConfig(BaseConfig):
         query_type: Optional[str] = None,
         callbacks: Optional[list] = None,
         api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
         endpoint: Optional[str] = None,
         model_kwargs: Optional[dict[str, Any]] = None,
         local: Optional[bool] = False,

@@ -167,6 +168,7 @@ class BaseLlmConfig(BaseConfig):
         self.query_type = query_type
         self.callbacks = callbacks
         self.api_key = api_key
+        self.base_url = base_url
         self.endpoint = endpoint
         self.model_kwargs = model_kwargs
         self.local = local
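The LLM config mirrors the embedder change, exposing base_url next to api_key. A sketch of constructing it against a custom endpoint; again, the import path is assumed:

    # Hypothetical usage of the new base_url field; adjust the import to the real module path.
    from embedchain.config import BaseLlmConfig

    llm_config = BaseLlmConfig(
        api_key="sk-...",                        # if omitted, the LLM falls back to OPENAI_API_KEY
        base_url="https://my-proxy.example/v1",  # new in this commit; falls back to OPENAI_API_BASE
    )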
@@ -17,6 +17,7 @@ class OpenAIEmbedder(BaseEmbedder):
             self.config.model = "text-embedding-ada-002"

         api_key = self.config.api_key or os.environ["OPENAI_API_KEY"]
+        api_base = self.config.api_base or os.environ["OPENAI_API_BASE"]

         if self.config.deployment_name:
             embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
@@ -28,6 +29,7 @@ class OpenAIEmbedder(BaseEmbedder):
             ) # noqa:E501
             embedding_fn = OpenAIEmbeddingFunction(
                 api_key=api_key,
+                api_base=api_base,
                 organization_id=os.getenv("OPENAI_ORGANIZATION"),
                 model_name=self.config.model,
             )
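The embedder resolves the endpoint the same way it already resolves the key: an explicit config value wins, otherwise the environment variable is read, and os.environ[...] raises KeyError when the variable is absent. The same resolution, isolated as a sketch (the helper name is hypothetical):

    import os

    # Hypothetical standalone version of the lookup above: config value first,
    # then the OPENAI_API_BASE environment variable (KeyError if neither is set).
    def resolve_api_base(configured_api_base):
        return configured_api_base or os.environ["OPENAI_API_BASE"]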
@@ -39,13 +39,20 @@ class OpenAILlm(BaseLlm):
             "model_kwargs": {},
         }
         api_key = config.api_key or os.environ["OPENAI_API_KEY"]
+        base_url = config.base_url or os.environ.get("OPENAI_API_BASE", None)
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
-            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
+            chat = ChatOpenAI(
+                **kwargs,
+                streaming=config.stream,
+                callbacks=callbacks,
+                api_key=api_key,
+                base_url=base_url,
+            )
         else:
-            chat = ChatOpenAI(**kwargs, api_key=api_key)
+            chat = ChatOpenAI(**kwargs, api_key=api_key, base_url=base_url)
         if self.tools:
             return self._query_function_call(chat, self.tools, messages)

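With this hunk, both branches now forward base_url to langchain's ChatOpenAI. A reduced, standalone sketch of the streaming call; the import paths and model name are assumptions, since the diff does not show them:

    import os

    # Import paths are assumptions based on current langchain packaging.
    from langchain_openai import ChatOpenAI
    from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

    chat = ChatOpenAI(
        model="gpt-3.5-turbo",                             # assumed model name
        streaming=True,
        callbacks=[StreamingStdOutCallbackHandler()],
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ.get("OPENAI_API_BASE", None),  # None keeps the default OpenAI endpoint
    )
    print(chat.invoke("Say hello.").content)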
@@ -424,6 +424,7 @@ def validate_config(config_data):
                 Optional("where"): dict,
                 Optional("query_type"): str,
                 Optional("api_key"): str,
+                Optional("base_url"): str,
                 Optional("endpoint"): str,
                 Optional("model_kwargs"): dict,
                 Optional("local"): bool,

@@ -451,6 +452,7 @@ def validate_config(config_data):
                 Optional("model"): Optional(str),
                 Optional("deployment_name"): Optional(str),
                 Optional("api_key"): str,
+                Optional("api_base"): str,
                 Optional("title"): str,
                 Optional("task_type"): str,
                 Optional("vector_dimension"): int,
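Finally, the validation schema accepts the two new keys, so a user-supplied config can declare them. A sketch of a fragment the extended schema should accept; the surrounding "llm"/"embedder"/"provider" layout is an assumption about the config file shape, not shown in this diff:

    # Hypothetical config fragment exercising the new keys.
    config_data = {
        "llm": {
            "provider": "openai",
            "config": {
                "api_key": "sk-...",
                "base_url": "https://my-proxy.example/v1",  # validated as str by the new schema entry
            },
        },
        "embedder": {
            "provider": "openai",
            "config": {
                "api_key": "sk-...",
                "api_base": "https://my-proxy.example/v1",  # validated as str by the new schema entry
            },
        },
    }
    validate_config(config_data)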
Block a user