This commit is contained in:
@@ -11,6 +11,7 @@ class BaseEmbedderConfig:
|
|||||||
deployment_name: Optional[str] = None,
|
deployment_name: Optional[str] = None,
|
||||||
vector_dimension: Optional[int] = None,
|
vector_dimension: Optional[int] = None,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
api_base: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize a new instance of an embedder config class.
|
Initialize a new instance of an embedder config class.
|
||||||
@@ -24,3 +25,4 @@ class BaseEmbedderConfig:
|
|||||||
self.deployment_name = deployment_name
|
self.deployment_name = deployment_name
|
||||||
self.vector_dimension = vector_dimension
|
self.vector_dimension = vector_dimension
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
self.api_base = api_base
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ class BaseLlmConfig(BaseConfig):
|
|||||||
query_type: Optional[str] = None,
|
query_type: Optional[str] = None,
|
||||||
callbacks: Optional[list] = None,
|
callbacks: Optional[list] = None,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
endpoint: Optional[str] = None,
|
endpoint: Optional[str] = None,
|
||||||
model_kwargs: Optional[dict[str, Any]] = None,
|
model_kwargs: Optional[dict[str, Any]] = None,
|
||||||
local: Optional[bool] = False,
|
local: Optional[bool] = False,
|
||||||
@@ -167,6 +168,7 @@ class BaseLlmConfig(BaseConfig):
|
|||||||
self.query_type = query_type
|
self.query_type = query_type
|
||||||
self.callbacks = callbacks
|
self.callbacks = callbacks
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
self.base_url = base_url
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.model_kwargs = model_kwargs
|
self.model_kwargs = model_kwargs
|
||||||
self.local = local
|
self.local = local
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class OpenAIEmbedder(BaseEmbedder):
|
|||||||
self.config.model = "text-embedding-ada-002"
|
self.config.model = "text-embedding-ada-002"
|
||||||
|
|
||||||
api_key = self.config.api_key or os.environ["OPENAI_API_KEY"]
|
api_key = self.config.api_key or os.environ["OPENAI_API_KEY"]
|
||||||
|
api_base = self.config.api_base or os.environ["OPENAI_API_BASE"]
|
||||||
|
|
||||||
if self.config.deployment_name:
|
if self.config.deployment_name:
|
||||||
embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
|
embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
|
||||||
@@ -28,6 +29,7 @@ class OpenAIEmbedder(BaseEmbedder):
|
|||||||
) # noqa:E501
|
) # noqa:E501
|
||||||
embedding_fn = OpenAIEmbeddingFunction(
|
embedding_fn = OpenAIEmbeddingFunction(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
organization_id=os.getenv("OPENAI_ORGANIZATION"),
|
organization_id=os.getenv("OPENAI_ORGANIZATION"),
|
||||||
model_name=self.config.model,
|
model_name=self.config.model,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -39,13 +39,20 @@ class OpenAILlm(BaseLlm):
|
|||||||
"model_kwargs": {},
|
"model_kwargs": {},
|
||||||
}
|
}
|
||||||
api_key = config.api_key or os.environ["OPENAI_API_KEY"]
|
api_key = config.api_key or os.environ["OPENAI_API_KEY"]
|
||||||
|
base_url = config.base_url or os.environ.get("OPENAI_API_BASE", None)
|
||||||
if config.top_p:
|
if config.top_p:
|
||||||
kwargs["model_kwargs"]["top_p"] = config.top_p
|
kwargs["model_kwargs"]["top_p"] = config.top_p
|
||||||
if config.stream:
|
if config.stream:
|
||||||
callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
|
callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
|
||||||
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
|
chat = ChatOpenAI(
|
||||||
|
**kwargs,
|
||||||
|
streaming=config.stream,
|
||||||
|
callbacks=callbacks,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
chat = ChatOpenAI(**kwargs, api_key=api_key)
|
chat = ChatOpenAI(**kwargs, api_key=api_key, base_url=base_url)
|
||||||
if self.tools:
|
if self.tools:
|
||||||
return self._query_function_call(chat, self.tools, messages)
|
return self._query_function_call(chat, self.tools, messages)
|
||||||
|
|
||||||
|
|||||||
@@ -424,6 +424,7 @@ def validate_config(config_data):
|
|||||||
Optional("where"): dict,
|
Optional("where"): dict,
|
||||||
Optional("query_type"): str,
|
Optional("query_type"): str,
|
||||||
Optional("api_key"): str,
|
Optional("api_key"): str,
|
||||||
|
Optional("base_url"): str,
|
||||||
Optional("endpoint"): str,
|
Optional("endpoint"): str,
|
||||||
Optional("model_kwargs"): dict,
|
Optional("model_kwargs"): dict,
|
||||||
Optional("local"): bool,
|
Optional("local"): bool,
|
||||||
@@ -451,6 +452,7 @@ def validate_config(config_data):
|
|||||||
Optional("model"): Optional(str),
|
Optional("model"): Optional(str),
|
||||||
Optional("deployment_name"): Optional(str),
|
Optional("deployment_name"): Optional(str),
|
||||||
Optional("api_key"): str,
|
Optional("api_key"): str,
|
||||||
|
Optional("api_base"): str,
|
||||||
Optional("title"): str,
|
Optional("title"): str,
|
||||||
Optional("task_type"): str,
|
Optional("task_type"): str,
|
||||||
Optional("vector_dimension"): int,
|
Optional("vector_dimension"): int,
|
||||||
|
|||||||
Reference in New Issue
Block a user