fix: Pass deployment name as param for azure api (#406)
This commit is contained in:
@@ -118,9 +118,13 @@ class CustomApp(EmbedChain):
|
|||||||
def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
|
def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
|
||||||
from langchain.chat_models import AzureChatOpenAI
|
from langchain.chat_models import AzureChatOpenAI
|
||||||
|
|
||||||
|
if not config.deployment_name:
|
||||||
|
raise ValueError("Deployment name must be provided for Azure OpenAI")
|
||||||
|
|
||||||
chat = AzureChatOpenAI(
|
chat = AzureChatOpenAI(
|
||||||
deployment_name="td2",
|
deployment_name=config.deployment_name,
|
||||||
model_name=config.model or "text-davinci-002",
|
openai_api_version="2023-05-15",
|
||||||
|
model_name=config.model or "gpt-3.5-turbo",
|
||||||
temperature=config.temperature,
|
temperature=config.temperature,
|
||||||
max_tokens=config.max_tokens,
|
max_tokens=config.max_tokens,
|
||||||
streaming=config.stream,
|
streaming=config.stream,
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ class ChatConfig(QueryConfig):
|
|||||||
max_tokens=None,
|
max_tokens=None,
|
||||||
top_p=None,
|
top_p=None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
deployment_name=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the ChatConfig instance.
|
Initializes the ChatConfig instance.
|
||||||
@@ -68,6 +69,7 @@ class ChatConfig(QueryConfig):
|
|||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
history=[0],
|
history=[0],
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
deployment_name=deployment_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_history(self, history):
|
def set_history(self, history):
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ class QueryConfig(BaseConfig):
|
|||||||
top_p=None,
|
top_p=None,
|
||||||
history=None,
|
history=None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
deployment_name=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the QueryConfig instance.
|
Initializes the QueryConfig instance.
|
||||||
@@ -106,6 +107,7 @@ class QueryConfig(BaseConfig):
|
|||||||
self.max_tokens = max_tokens if max_tokens else 1000
|
self.max_tokens = max_tokens if max_tokens else 1000
|
||||||
self.model = model
|
self.model = model
|
||||||
self.top_p = top_p if top_p else 1
|
self.top_p = top_p if top_p else 1
|
||||||
|
self.deployment_name = deployment_name
|
||||||
|
|
||||||
if self.validate_template(template):
|
if self.validate_template(template):
|
||||||
self.template = template
|
self.template = template
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class CustomAppConfig(BaseAppConfig):
|
|||||||
provider: Providers = None,
|
provider: Providers = None,
|
||||||
model=None,
|
model=None,
|
||||||
open_source_app_config=None,
|
open_source_app_config=None,
|
||||||
|
deployment_name=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param log_level: Optional. (String) Debug level
|
:param log_level: Optional. (String) Debug level
|
||||||
@@ -49,7 +50,10 @@ class CustomAppConfig(BaseAppConfig):
|
|||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
log_level=log_level,
|
log_level=log_level,
|
||||||
embedding_fn=CustomAppConfig.embedding_function(embedding_function=embedding_fn, model=embedding_fn_model),
|
embedding_fn=CustomAppConfig.embedding_function(
|
||||||
|
embedding_function=embedding_fn, model=embedding_fn_model,
|
||||||
|
deployment_name=deployment_name
|
||||||
|
),
|
||||||
db=db,
|
db=db,
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
@@ -68,7 +72,7 @@ class CustomAppConfig(BaseAppConfig):
|
|||||||
return embed_function
|
return embed_function
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def embedding_function(embedding_function: EmbeddingFunctions, model: str = None):
|
def embedding_function(embedding_function: EmbeddingFunctions, model: str = None, deployment_name: str = None):
|
||||||
if not isinstance(embedding_function, EmbeddingFunctions):
|
if not isinstance(embedding_function, EmbeddingFunctions):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid option: '{embedding_function}'. Expecting one of the following options: {list(map(lambda x: x.value, EmbeddingFunctions))}" # noqa: E501
|
f"Invalid option: '{embedding_function}'. Expecting one of the following options: {list(map(lambda x: x.value, EmbeddingFunctions))}" # noqa: E501
|
||||||
@@ -79,6 +83,9 @@ class CustomAppConfig(BaseAppConfig):
|
|||||||
|
|
||||||
if model:
|
if model:
|
||||||
embeddings = OpenAIEmbeddings(model=model)
|
embeddings = OpenAIEmbeddings(model=model)
|
||||||
|
else:
|
||||||
|
if deployment_name:
|
||||||
|
embeddings = OpenAIEmbeddings(deployment=deployment_name)
|
||||||
else:
|
else:
|
||||||
embeddings = OpenAIEmbeddings()
|
embeddings = OpenAIEmbeddings()
|
||||||
return CustomAppConfig.langchain_default_concept(embeddings)
|
return CustomAppConfig.langchain_default_concept(embeddings)
|
||||||
|
|||||||
Reference in New Issue
Block a user