fix: Pass deployment name as param for azure api (#406)
This commit is contained in:
@@ -33,6 +33,7 @@ class ChatConfig(QueryConfig):
|
||||
max_tokens=None,
|
||||
top_p=None,
|
||||
stream: bool = False,
|
||||
deployment_name=None,
|
||||
):
|
||||
"""
|
||||
Initializes the ChatConfig instance.
|
||||
@@ -68,6 +69,7 @@ class ChatConfig(QueryConfig):
|
||||
top_p=top_p,
|
||||
history=[0],
|
||||
stream=stream,
|
||||
deployment_name=deployment_name,
|
||||
)
|
||||
|
||||
def set_history(self, history):
|
||||
|
||||
@@ -62,6 +62,7 @@ class QueryConfig(BaseConfig):
|
||||
top_p=None,
|
||||
history=None,
|
||||
stream: bool = False,
|
||||
deployment_name=None,
|
||||
):
|
||||
"""
|
||||
Initializes the QueryConfig instance.
|
||||
@@ -106,6 +107,7 @@ class QueryConfig(BaseConfig):
|
||||
self.max_tokens = max_tokens if max_tokens else 1000
|
||||
self.model = model
|
||||
self.top_p = top_p if top_p else 1
|
||||
self.deployment_name = deployment_name
|
||||
|
||||
if self.validate_template(template):
|
||||
self.template = template
|
||||
|
||||
@@ -27,6 +27,7 @@ class CustomAppConfig(BaseAppConfig):
|
||||
provider: Providers = None,
|
||||
model=None,
|
||||
open_source_app_config=None,
|
||||
deployment_name=None,
|
||||
):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
@@ -49,7 +50,10 @@ class CustomAppConfig(BaseAppConfig):
|
||||
|
||||
super().__init__(
|
||||
log_level=log_level,
|
||||
embedding_fn=CustomAppConfig.embedding_function(embedding_function=embedding_fn, model=embedding_fn_model),
|
||||
embedding_fn=CustomAppConfig.embedding_function(
|
||||
embedding_function=embedding_fn, model=embedding_fn_model,
|
||||
deployment_name=deployment_name
|
||||
),
|
||||
db=db,
|
||||
host=host,
|
||||
port=port,
|
||||
@@ -68,7 +72,7 @@ class CustomAppConfig(BaseAppConfig):
|
||||
return embed_function
|
||||
|
||||
@staticmethod
|
||||
def embedding_function(embedding_function: EmbeddingFunctions, model: str = None):
|
||||
def embedding_function(embedding_function: EmbeddingFunctions, model: str = None, deployment_name: str = None):
|
||||
if not isinstance(embedding_function, EmbeddingFunctions):
|
||||
raise ValueError(
|
||||
f"Invalid option: '{embedding_function}'. Expecting one of the following options: {list(map(lambda x: x.value, EmbeddingFunctions))}" # noqa: E501
|
||||
@@ -80,7 +84,10 @@ class CustomAppConfig(BaseAppConfig):
|
||||
if model:
|
||||
embeddings = OpenAIEmbeddings(model=model)
|
||||
else:
|
||||
embeddings = OpenAIEmbeddings()
|
||||
if deployment_name:
|
||||
embeddings = OpenAIEmbeddings(deployment=deployment_name)
|
||||
else:
|
||||
embeddings = OpenAIEmbeddings()
|
||||
return CustomAppConfig.langchain_default_concept(embeddings)
|
||||
|
||||
elif embedding_function == EmbeddingFunctions.HUGGING_FACE:
|
||||
|
||||
Reference in New Issue
Block a user