Provide openai-key support from config (#1052)

Sidharth Mohanty
2023-12-23 14:42:18 +05:30
committed by GitHub
parent e90673ae5b
commit 11f0d719f5
9 changed files with 31 additions and 43 deletions

View File

@@ -5,7 +5,9 @@ from embedchain.helpers.json_serializable import register_deserializable
 @register_deserializable
 class BaseEmbedderConfig:
-    def __init__(self, model: Optional[str] = None, deployment_name: Optional[str] = None):
+    def __init__(
+        self, model: Optional[str] = None, deployment_name: Optional[str] = None, api_key: Optional[str] = None
+    ):
         """
         Initialize a new instance of an embedder config class.
@@ -16,3 +18,4 @@ class BaseEmbedderConfig:
         """
         self.model = model
         self.deployment_name = deployment_name
+        self.api_key = api_key
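
With this change an OpenAI key can be carried on the embedder config itself instead of being read only from the environment. A minimal sketch of the new keyword (the import path and the placeholder key are assumptions, not part of this diff):

from embedchain.config import BaseEmbedderConfig  # assumed re-export path

# Hypothetical placeholder key; nothing is read from the environment here.
embedder_config = BaseEmbedderConfig(model="text-embedding-ada-002", api_key="sk-placeholder")
print(embedder_config.api_key)  # -> "sk-placeholder"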

View File

@@ -69,6 +69,7 @@ class BaseLlmConfig(BaseConfig):
         where: Dict[str, Any] = None,
         query_type: Optional[str] = None,
         callbacks: Optional[List] = None,
+        api_key: Optional[str] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -117,6 +118,7 @@ class BaseLlmConfig(BaseConfig):
         self.system_prompt = system_prompt
         self.query_type = query_type
         self.callbacks = callbacks
+        self.api_key = api_key
         if type(template) is str:
             template = Template(template)
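
The LLM config gains the same field. A short hedged example (the import path is an assumption; temperature is just an arbitrary existing option used for illustration):

from embedchain.config import BaseLlmConfig  # assumed re-export path

# The key is stored on the config like any other option and picked up later by the LLM class.
llm_config = BaseLlmConfig(temperature=0.2, api_key="sk-placeholder")  # placeholder key
print(llm_config.api_key)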

View File

@@ -16,16 +16,18 @@ class OpenAIEmbedder(BaseEmbedder):
         if self.config.model is None:
             self.config.model = "text-embedding-ada-002"

+        api_key = self.config.api_key or os.environ["OPENAI_API_KEY"]
         if self.config.deployment_name:
             embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
             embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
         else:
-            if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
+            if api_key is None and os.getenv("OPENAI_ORGANIZATION") is None:
                 raise ValueError(
                     "OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided"
                 ) # noqa:E501
             embedding_fn = OpenAIEmbeddingFunction(
-                api_key=os.getenv("OPENAI_API_KEY"),
+                api_key=api_key,
                 organization_id=os.getenv("OPENAI_ORGANIZATION"),
                 model_name=self.config.model,
             )
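
The resolution order introduced here is: a key set on the config wins, and the OPENAI_API_KEY environment variable is only a fallback. A standalone sketch of that pattern (resolve_openai_key is a hypothetical helper, not part of the diff); note that os.environ[...] raises KeyError when the variable is missing, so one of the two sources must be set:

import os

def resolve_openai_key(config_api_key=None):
    # Explicit config value first, environment variable second.
    return config_api_key or os.environ["OPENAI_API_KEY"]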

View File

@@ -1,4 +1,5 @@
 import json
+import os
 from typing import Any, Dict, Optional

 from langchain.chat_models import ChatOpenAI
@@ -30,6 +31,7 @@ class OpenAILlm(BaseLlm):
             "max_tokens": config.max_tokens,
             "model_kwargs": {},
         }
+        api_key = config.api_key or os.environ["OPENAI_API_KEY"]
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
@@ -37,9 +39,9 @@ class OpenAILlm(BaseLlm):
             from langchain.callbacks.streaming_stdout import \
                 StreamingStdOutCallbackHandler
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
-            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks)
+            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
         else:
-            chat = ChatOpenAI(**kwargs)
+            chat = ChatOpenAI(**kwargs, api_key=api_key)
         if self.functions is not None:
             from langchain.chains.openai_functions import \
                 create_openai_fn_runnable
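
Putting the pieces together on the LLM side, a hedged usage sketch (module paths are assumptions; the key only needs to be present in one of the two places):

from embedchain.config import BaseLlmConfig   # assumed re-export path
from embedchain.llm.openai import OpenAILlm   # assumed module path

# The key now travels config -> OpenAILlm -> ChatOpenAI(api_key=...), so no
# OPENAI_API_KEY environment variable is required when it is set on the config.
llm = OpenAILlm(config=BaseLlmConfig(api_key="sk-placeholder"))  # placeholder key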

View File

@@ -403,6 +403,7 @@ def validate_config(config_data):
                     Optional("deployment_name"): str,
                     Optional("where"): dict,
                     Optional("query_type"): str,
+                    Optional("api_key"): str,
                 },
             },
             Optional("vectordb"): {
@@ -416,6 +417,7 @@ def validate_config(config_data):
                 Optional("config"): {
                     Optional("model"): Optional(str),
                     Optional("deployment_name"): Optional(str),
+                    Optional("api_key"): str,
                 },
             },
             Optional("embedding_model"): {
@@ -423,6 +425,7 @@ def validate_config(config_data):
                 Optional("config"): {
                     Optional("model"): str,
                     Optional("deployment_name"): str,
+                    Optional("api_key"): str,
                 },
             },
             Optional("chunker"): {
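
For reference, a config that exercises the new "api_key" entries and should pass the extended validator. The top-level section names and "provider" values are the conventional embedchain ones and are assumptions here, as is App.from_config accepting an in-memory dict (neither is shown in this diff):

from embedchain import App  # assumed entry point

config = {
    "llm": {
        "provider": "openai",
        "config": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},  # placeholder key
    },
    "embedder": {
        "provider": "openai",
        "config": {"model": "text-embedding-ada-002", "api_key": "sk-placeholder"},
    },
}
app = App.from_config(config=config)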