fix: Pass deployment name as param for azure api (#406)

This commit is contained in:
Taranjeet Singh
2023-08-08 23:55:26 -07:00
committed by GitHub
parent 030e3521a9
commit 1f0f0c93b7
4 changed files with 20 additions and 5 deletions

View File

@@ -33,6 +33,7 @@ class ChatConfig(QueryConfig):
max_tokens=None,
top_p=None,
stream: bool = False,
deployment_name=None,
):
"""
Initializes the ChatConfig instance.
@@ -68,6 +69,7 @@ class ChatConfig(QueryConfig):
top_p=top_p,
history=[0],
stream=stream,
deployment_name=deployment_name,
)
def set_history(self, history):

View File

@@ -62,6 +62,7 @@ class QueryConfig(BaseConfig):
top_p=None,
history=None,
stream: bool = False,
deployment_name=None,
):
"""
Initializes the QueryConfig instance.
@@ -106,6 +107,7 @@ class QueryConfig(BaseConfig):
self.max_tokens = max_tokens if max_tokens else 1000
self.model = model
self.top_p = top_p if top_p else 1
self.deployment_name = deployment_name
if self.validate_template(template):
self.template = template

View File

@@ -27,6 +27,7 @@ class CustomAppConfig(BaseAppConfig):
provider: Providers = None,
model=None,
open_source_app_config=None,
deployment_name=None,
):
"""
:param log_level: Optional. (String) Debug level
@@ -49,7 +50,10 @@ class CustomAppConfig(BaseAppConfig):
super().__init__(
log_level=log_level,
embedding_fn=CustomAppConfig.embedding_function(embedding_function=embedding_fn, model=embedding_fn_model),
embedding_fn=CustomAppConfig.embedding_function(
embedding_function=embedding_fn, model=embedding_fn_model,
deployment_name=deployment_name
),
db=db,
host=host,
port=port,
@@ -68,7 +72,7 @@ class CustomAppConfig(BaseAppConfig):
return embed_function
@staticmethod
def embedding_function(embedding_function: EmbeddingFunctions, model: str = None):
def embedding_function(embedding_function: EmbeddingFunctions, model: str = None, deployment_name: str = None):
if not isinstance(embedding_function, EmbeddingFunctions):
raise ValueError(
f"Invalid option: '{embedding_function}'. Expecting one of the following options: {list(map(lambda x: x.value, EmbeddingFunctions))}" # noqa: E501
@@ -80,7 +84,10 @@ class CustomAppConfig(BaseAppConfig):
if model:
embeddings = OpenAIEmbeddings(model=model)
else:
embeddings = OpenAIEmbeddings()
if deployment_name:
embeddings = OpenAIEmbeddings(deployment=deployment_name)
else:
embeddings = OpenAIEmbeddings()
return CustomAppConfig.langchain_default_concept(embeddings)
elif embedding_function == EmbeddingFunctions.HUGGING_FACE: