Provide openai-key support from config (#1052)
This commit is contained in:
@@ -5,7 +5,9 @@ from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
@register_deserializable
|
||||
class BaseEmbedderConfig:
|
||||
def __init__(self, model: Optional[str] = None, deployment_name: Optional[str] = None):
|
||||
def __init__(
|
||||
self, model: Optional[str] = None, deployment_name: Optional[str] = None, api_key: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize a new instance of an embedder config class.
|
||||
|
||||
@@ -16,3 +18,4 @@ class BaseEmbedderConfig:
|
||||
"""
|
||||
self.model = model
|
||||
self.deployment_name = deployment_name
|
||||
self.api_key = api_key
|
||||
|
||||
@@ -69,6 +69,7 @@ class BaseLlmConfig(BaseConfig):
|
||||
where: Dict[str, Any] = None,
|
||||
query_type: Optional[str] = None,
|
||||
callbacks: Optional[List] = None,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for the LLM.
|
||||
@@ -117,6 +118,7 @@ class BaseLlmConfig(BaseConfig):
|
||||
self.system_prompt = system_prompt
|
||||
self.query_type = query_type
|
||||
self.callbacks = callbacks
|
||||
self.api_key = api_key
|
||||
|
||||
if type(template) is str:
|
||||
template = Template(template)
|
||||
|
||||
@@ -16,16 +16,18 @@ class OpenAIEmbedder(BaseEmbedder):
|
||||
if self.config.model is None:
|
||||
self.config.model = "text-embedding-ada-002"
|
||||
|
||||
api_key = self.config.api_key or os.environ["OPENAI_API_KEY"]
|
||||
|
||||
if self.config.deployment_name:
|
||||
embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
|
||||
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
|
||||
else:
|
||||
if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
|
||||
if api_key is None and os.getenv("OPENAI_ORGANIZATION") is None:
|
||||
raise ValueError(
|
||||
"OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided"
|
||||
) # noqa:E501
|
||||
embedding_fn = OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
api_key=api_key,
|
||||
organization_id=os.getenv("OPENAI_ORGANIZATION"),
|
||||
model_name=self.config.model,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
@@ -30,6 +31,7 @@ class OpenAILlm(BaseLlm):
|
||||
"max_tokens": config.max_tokens,
|
||||
"model_kwargs": {},
|
||||
}
|
||||
api_key = config.api_key or os.environ["OPENAI_API_KEY"]
|
||||
if config.top_p:
|
||||
kwargs["model_kwargs"]["top_p"] = config.top_p
|
||||
if config.stream:
|
||||
@@ -37,9 +39,9 @@ class OpenAILlm(BaseLlm):
|
||||
StreamingStdOutCallbackHandler
|
||||
|
||||
callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
|
||||
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks)
|
||||
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
|
||||
else:
|
||||
chat = ChatOpenAI(**kwargs)
|
||||
chat = ChatOpenAI(**kwargs, api_key=api_key)
|
||||
if self.functions is not None:
|
||||
from langchain.chains.openai_functions import \
|
||||
create_openai_fn_runnable
|
||||
|
||||
@@ -403,6 +403,7 @@ def validate_config(config_data):
|
||||
Optional("deployment_name"): str,
|
||||
Optional("where"): dict,
|
||||
Optional("query_type"): str,
|
||||
Optional("api_key"): str,
|
||||
},
|
||||
},
|
||||
Optional("vectordb"): {
|
||||
@@ -416,6 +417,7 @@ def validate_config(config_data):
|
||||
Optional("config"): {
|
||||
Optional("model"): Optional(str),
|
||||
Optional("deployment_name"): Optional(str),
|
||||
Optional("api_key"): str,
|
||||
},
|
||||
},
|
||||
Optional("embedding_model"): {
|
||||
@@ -423,6 +425,7 @@ def validate_config(config_data):
|
||||
Optional("config"): {
|
||||
Optional("model"): str,
|
||||
Optional("deployment_name"): str,
|
||||
Optional("api_key"): str,
|
||||
},
|
||||
},
|
||||
Optional("chunker"): {
|
||||
|
||||
Reference in New Issue
Block a user