fix: Pass deployment name as param for azure api (#406)

This commit is contained in:
Taranjeet Singh
2023-08-08 23:55:26 -07:00
committed by GitHub
parent 030e3521a9
commit 1f0f0c93b7
4 changed files with 20 additions and 5 deletions

View File

@@ -27,6 +27,7 @@ class CustomAppConfig(BaseAppConfig):
provider: Providers = None,
model=None,
open_source_app_config=None,
deployment_name=None,
):
"""
:param log_level: Optional. (String) Debug level
@@ -49,7 +50,10 @@ class CustomAppConfig(BaseAppConfig):
super().__init__(
log_level=log_level,
embedding_fn=CustomAppConfig.embedding_function(embedding_function=embedding_fn, model=embedding_fn_model),
embedding_fn=CustomAppConfig.embedding_function(
embedding_function=embedding_fn, model=embedding_fn_model,
deployment_name=deployment_name
),
db=db,
host=host,
port=port,
@@ -68,7 +72,7 @@ class CustomAppConfig(BaseAppConfig):
return embed_function
@staticmethod
def embedding_function(embedding_function: EmbeddingFunctions, model: str = None):
def embedding_function(embedding_function: EmbeddingFunctions, model: str = None, deployment_name: str = None):
if not isinstance(embedding_function, EmbeddingFunctions):
raise ValueError(
f"Invalid option: '{embedding_function}'. Expecting one of the following options: {list(map(lambda x: x.value, EmbeddingFunctions))}" # noqa: E501
@@ -80,7 +84,10 @@ class CustomAppConfig(BaseAppConfig):
if model:
embeddings = OpenAIEmbeddings(model=model)
else:
embeddings = OpenAIEmbeddings()
if deployment_name:
embeddings = OpenAIEmbeddings(deployment=deployment_name)
else:
embeddings = OpenAIEmbeddings()
return CustomAppConfig.langchain_default_concept(embeddings)
elif embedding_function == EmbeddingFunctions.HUGGING_FACE: