fix: Pass deployment name as param for azure api (#406)
This commit is contained in:
@@ -27,6 +27,7 @@ class CustomAppConfig(BaseAppConfig):
|
||||
provider: Providers = None,
|
||||
model=None,
|
||||
open_source_app_config=None,
|
||||
deployment_name=None,
|
||||
):
|
||||
"""
|
||||
:param log_level: Optional. (String) Debug level
|
||||
@@ -49,7 +50,10 @@ class CustomAppConfig(BaseAppConfig):
|
||||
|
||||
super().__init__(
|
||||
log_level=log_level,
|
||||
embedding_fn=CustomAppConfig.embedding_function(embedding_function=embedding_fn, model=embedding_fn_model),
|
||||
embedding_fn=CustomAppConfig.embedding_function(
|
||||
embedding_function=embedding_fn, model=embedding_fn_model,
|
||||
deployment_name=deployment_name
|
||||
),
|
||||
db=db,
|
||||
host=host,
|
||||
port=port,
|
||||
@@ -68,7 +72,7 @@ class CustomAppConfig(BaseAppConfig):
|
||||
return embed_function
|
||||
|
||||
@staticmethod
|
||||
def embedding_function(embedding_function: EmbeddingFunctions, model: str = None):
|
||||
def embedding_function(embedding_function: EmbeddingFunctions, model: str = None, deployment_name: str = None):
|
||||
if not isinstance(embedding_function, EmbeddingFunctions):
|
||||
raise ValueError(
|
||||
f"Invalid option: '{embedding_function}'. Expecting one of the following options: {list(map(lambda x: x.value, EmbeddingFunctions))}" # noqa: E501
|
||||
@@ -80,7 +84,10 @@ class CustomAppConfig(BaseAppConfig):
|
||||
if model:
|
||||
embeddings = OpenAIEmbeddings(model=model)
|
||||
else:
|
||||
embeddings = OpenAIEmbeddings()
|
||||
if deployment_name:
|
||||
embeddings = OpenAIEmbeddings(deployment=deployment_name)
|
||||
else:
|
||||
embeddings = OpenAIEmbeddings()
|
||||
return CustomAppConfig.langchain_default_concept(embeddings)
|
||||
|
||||
elif embedding_function == EmbeddingFunctions.HUGGING_FACE:
|
||||
|
||||
Reference in New Issue
Block a user