Add support for loading api_key from config or env variable (#1421)
This commit is contained in:
@@ -17,18 +17,17 @@ logger = logging.getLogger(__name__)
|
||||
@register_deserializable
|
||||
class AnthropicLlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||
if "ANTHROPIC_API_KEY" not in os.environ:
|
||||
raise ValueError("Please set the ANTHROPIC_API_KEY environment variable.")
|
||||
super().__init__(config=config)
|
||||
if not self.config.api_key and "ANTHROPIC_API_KEY" not in os.environ:
|
||||
raise ValueError("Please set the ANTHROPIC_API_KEY environment variable or pass it in the config.")
|
||||
|
||||
def get_llm_model_answer(self, prompt):
|
||||
return AnthropicLlm._get_answer(prompt=prompt, config=self.config)
|
||||
|
||||
@staticmethod
|
||||
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
||||
chat = ChatAnthropic(
|
||||
anthropic_api_key=os.environ["ANTHROPIC_API_KEY"], temperature=config.temperature, model_name=config.model
|
||||
)
|
||||
api_key = config.api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
chat = ChatAnthropic(anthropic_api_key=api_key, temperature=config.temperature, model_name=config.model)
|
||||
|
||||
if config.max_tokens and config.max_tokens != 1000:
|
||||
logger.warning("Config option `max_tokens` is not supported by this model.")
|
||||
|
||||
@@ -12,9 +12,6 @@ from embedchain.llm.base import BaseLlm
|
||||
@register_deserializable
|
||||
class CohereLlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||
if "COHERE_API_KEY" not in os.environ:
|
||||
raise ValueError("Please set the COHERE_API_KEY environment variable.")
|
||||
|
||||
try:
|
||||
importlib.import_module("cohere")
|
||||
except ModuleNotFoundError:
|
||||
@@ -24,6 +21,8 @@ class CohereLlm(BaseLlm):
|
||||
) from None
|
||||
|
||||
super().__init__(config=config)
|
||||
if not self.config.api_key and "COHERE_API_KEY" not in os.environ:
|
||||
raise ValueError("Please set the COHERE_API_KEY environment variable or pass it in the config.")
|
||||
|
||||
def get_llm_model_answer(self, prompt):
|
||||
if self.config.system_prompt:
|
||||
@@ -32,8 +31,9 @@ class CohereLlm(BaseLlm):
|
||||
|
||||
@staticmethod
|
||||
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
||||
api_key = config.api_key or os.getenv("COHERE_API_KEY")
|
||||
llm = Cohere(
|
||||
cohere_api_key=os.environ["COHERE_API_KEY"],
|
||||
cohere_api_key=api_key,
|
||||
model=config.model,
|
||||
max_tokens=config.max_tokens,
|
||||
temperature=config.temperature,
|
||||
|
||||
@@ -16,9 +16,6 @@ logger = logging.getLogger(__name__)
|
||||
@register_deserializable
|
||||
class GoogleLlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||
if "GOOGLE_API_KEY" not in os.environ:
|
||||
raise ValueError("Please set the GOOGLE_API_KEY environment variable.")
|
||||
|
||||
try:
|
||||
importlib.import_module("google.generativeai")
|
||||
except ModuleNotFoundError:
|
||||
@@ -28,7 +25,11 @@ class GoogleLlm(BaseLlm):
|
||||
) from None
|
||||
|
||||
super().__init__(config)
|
||||
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
|
||||
if not self.config.api_key and "GOOGLE_API_KEY" not in os.environ:
|
||||
raise ValueError("Please set the GOOGLE_API_KEY environment variable or pass it in the config.")
|
||||
|
||||
api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY")
|
||||
genai.configure(api_key=api_key)
|
||||
|
||||
def get_llm_model_answer(self, prompt):
|
||||
if self.config.system_prompt:
|
||||
|
||||
@@ -19,6 +19,8 @@ from embedchain.llm.base import BaseLlm
|
||||
class GroqLlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||
super().__init__(config=config)
|
||||
if not self.config.api_key and "GROQ_API_KEY" not in os.environ:
|
||||
raise ValueError("Please set the GROQ_API_KEY environment variable or pass it in the config.")
|
||||
|
||||
def get_llm_model_answer(self, prompt) -> str:
|
||||
response = self._get_answer(prompt, self.config)
|
||||
|
||||
@@ -17,9 +17,6 @@ logger = logging.getLogger(__name__)
|
||||
@register_deserializable
|
||||
class HuggingFaceLlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||
if "HUGGINGFACE_ACCESS_TOKEN" not in os.environ:
|
||||
raise ValueError("Please set the HUGGINGFACE_ACCESS_TOKEN environment variable.")
|
||||
|
||||
try:
|
||||
importlib.import_module("huggingface_hub")
|
||||
except ModuleNotFoundError:
|
||||
@@ -29,6 +26,8 @@ class HuggingFaceLlm(BaseLlm):
|
||||
) from None
|
||||
|
||||
super().__init__(config=config)
|
||||
if not self.config.api_key and "HUGGINGFACE_ACCESS_TOKEN" not in os.environ:
|
||||
raise ValueError("Please set the HUGGINGFACE_ACCESS_TOKEN environment variable or pass it in the config.")
|
||||
|
||||
def get_llm_model_answer(self, prompt):
|
||||
if self.config.system_prompt:
|
||||
@@ -60,9 +59,10 @@ class HuggingFaceLlm(BaseLlm):
|
||||
raise ValueError("`top_p` must be > 0.0 and < 1.0")
|
||||
|
||||
model = config.model
|
||||
api_key = config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN")
|
||||
logger.info(f"Using HuggingFaceHub with model {model}")
|
||||
llm = HuggingFaceHub(
|
||||
huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
|
||||
huggingfacehub_api_token=api_key,
|
||||
repo_id=model,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
@@ -70,8 +70,9 @@ class HuggingFaceLlm(BaseLlm):
|
||||
|
||||
@staticmethod
|
||||
def _from_endpoint(prompt: str, config: BaseLlmConfig) -> str:
|
||||
api_key = config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN")
|
||||
llm = HuggingFaceEndpoint(
|
||||
huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
|
||||
huggingfacehub_api_token=api_key,
|
||||
endpoint_url=config.endpoint,
|
||||
task="text-generation",
|
||||
model_kwargs=config.model_kwargs,
|
||||
|
||||
@@ -12,9 +12,9 @@ from embedchain.llm.base import BaseLlm
|
||||
@register_deserializable
|
||||
class JinaLlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||
if "JINACHAT_API_KEY" not in os.environ:
|
||||
raise ValueError("Please set the JINACHAT_API_KEY environment variable.")
|
||||
super().__init__(config=config)
|
||||
if not self.config.api_key and "JINACHAT_API_KEY" not in os.environ:
|
||||
raise ValueError("Please set the JINACHAT_API_KEY environment variable or pass it in the config.")
|
||||
|
||||
def get_llm_model_answer(self, prompt):
|
||||
response = JinaLlm._get_answer(prompt, self.config)
|
||||
@@ -29,13 +29,13 @@ class JinaLlm(BaseLlm):
|
||||
kwargs = {
|
||||
"temperature": config.temperature,
|
||||
"max_tokens": config.max_tokens,
|
||||
"jinachat_api_key": config.api_key or os.environ["JINACHAT_API_KEY"],
|
||||
"model_kwargs": {},
|
||||
}
|
||||
if config.top_p:
|
||||
kwargs["model_kwargs"]["top_p"] = config.top_p
|
||||
if config.stream:
|
||||
from langchain.callbacks.streaming_stdout import \
|
||||
StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
|
||||
else:
|
||||
|
||||
@@ -19,8 +19,6 @@ class Llama2Llm(BaseLlm):
|
||||
"The required dependencies for Llama2 are not installed."
|
||||
'Please install with `pip install --upgrade "embedchain[llama2]"`'
|
||||
) from None
|
||||
if "REPLICATE_API_TOKEN" not in os.environ:
|
||||
raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
|
||||
|
||||
# Set default config values specific to this llm
|
||||
if not config:
|
||||
@@ -35,13 +33,17 @@ class Llama2Llm(BaseLlm):
|
||||
)
|
||||
|
||||
super().__init__(config=config)
|
||||
if not self.config.api_key and "REPLICATE_API_TOKEN" not in os.environ:
|
||||
raise ValueError("Please set the REPLICATE_API_TOKEN environment variable or pass it in the config.")
|
||||
|
||||
def get_llm_model_answer(self, prompt):
|
||||
# TODO: Move the model and other inputs into config
|
||||
if self.config.system_prompt:
|
||||
raise ValueError("Llama2 does not support `system_prompt`")
|
||||
api_key = self.config.api_key or os.getenv("REPLICATE_API_TOKEN")
|
||||
llm = Replicate(
|
||||
model=self.config.model,
|
||||
replicate_api_token=api_key,
|
||||
input={
|
||||
"temperature": self.config.temperature,
|
||||
"max_length": self.config.max_tokens,
|
||||
|
||||
@@ -21,10 +21,9 @@ from embedchain.llm.base import BaseLlm
|
||||
@register_deserializable
|
||||
class NvidiaLlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||
if "NVIDIA_API_KEY" not in os.environ:
|
||||
raise ValueError("NVIDIA_API_KEY environment variable must be set")
|
||||
|
||||
super().__init__(config=config)
|
||||
if not self.config.api_key and "NVIDIA_API_KEY" not in os.environ:
|
||||
raise ValueError("Please set the NVIDIA_API_KEY environment variable or pass it in the config.")
|
||||
|
||||
def get_llm_model_answer(self, prompt):
|
||||
return self._get_answer(prompt=prompt, config=self.config)
|
||||
@@ -34,7 +33,7 @@ class NvidiaLlm(BaseLlm):
|
||||
callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
|
||||
model_kwargs = config.model_kwargs or {}
|
||||
labels = model_kwargs.get("labels", None)
|
||||
params = {"model": config.model}
|
||||
params = {"model": config.model, "nvidia_api_key": config.api_key or os.getenv("NVIDIA_API_KEY")}
|
||||
if config.system_prompt:
|
||||
params["system_prompt"] = config.system_prompt
|
||||
if config.temperature:
|
||||
|
||||
@@ -12,9 +12,6 @@ from embedchain.llm.base import BaseLlm
|
||||
@register_deserializable
|
||||
class TogetherLlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||
if "TOGETHER_API_KEY" not in os.environ:
|
||||
raise ValueError("Please set the TOGETHER_API_KEY environment variable.")
|
||||
|
||||
try:
|
||||
importlib.import_module("together")
|
||||
except ModuleNotFoundError:
|
||||
@@ -24,6 +21,8 @@ class TogetherLlm(BaseLlm):
|
||||
) from None
|
||||
|
||||
super().__init__(config=config)
|
||||
if not self.config.api_key and "TOGETHER_API_KEY" not in os.environ:
|
||||
raise ValueError("Please set the TOGETHER_API_KEY environment variable or pass it in the config.")
|
||||
|
||||
def get_llm_model_answer(self, prompt):
|
||||
if self.config.system_prompt:
|
||||
@@ -32,8 +31,9 @@ class TogetherLlm(BaseLlm):
|
||||
|
||||
@staticmethod
|
||||
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
||||
api_key = config.api_key or os.getenv("TOGETHER_API_KEY")
|
||||
llm = Together(
|
||||
together_api_key=os.environ["TOGETHER_API_KEY"],
|
||||
together_api_key=api_key,
|
||||
model=config.model,
|
||||
max_tokens=config.max_tokens,
|
||||
temperature=config.temperature,
|
||||
|
||||
@@ -74,5 +74,6 @@ def test_get_llm_model_answer_without_system_prompt(config, mocker):
|
||||
mocked_jinachat.assert_called_once_with(
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
jinachat_api_key=os.environ["JINACHAT_API_KEY"],
|
||||
model_kwargs={"top_p": config.top_p},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user