Add support for loading api_key from config or env variable (#1421)

Author: Dev Khant
Date: 2024-06-13 23:49:54 +05:30
Committed by: GitHub
Parent: 08b67b4a78
Commit: 2855f1635b
10 changed files with 37 additions and 32 deletions

embedchain/llm/anthropic.py

@@ -17,18 +17,17 @@ logger = logging.getLogger(__name__)
 @register_deserializable
 class AnthropicLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        if "ANTHROPIC_API_KEY" not in os.environ:
-            raise ValueError("Please set the ANTHROPIC_API_KEY environment variable.")
         super().__init__(config=config)
+        if not self.config.api_key and "ANTHROPIC_API_KEY" not in os.environ:
+            raise ValueError("Please set the ANTHROPIC_API_KEY environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt):
         return AnthropicLlm._get_answer(prompt=prompt, config=self.config)
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
-        chat = ChatAnthropic(
-            anthropic_api_key=os.environ["ANTHROPIC_API_KEY"], temperature=config.temperature, model_name=config.model
-        )
+        api_key = config.api_key or os.getenv("ANTHROPIC_API_KEY")
+        chat = ChatAnthropic(anthropic_api_key=api_key, temperature=config.temperature, model_name=config.model)
 
         if config.max_tokens and config.max_tokens != 1000:
             logger.warning("Config option `max_tokens` is not supported by this model.")
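
With this change the Anthropic key can come from the config object rather than only from the environment. A minimal usage sketch; the BaseLlmConfig import path and its api_key constructor parameter are inferred from this diff, and the model name is a placeholder:

import os

from embedchain.config import BaseLlmConfig
from embedchain.llm.anthropic import AnthropicLlm

# Option 1: pass the key explicitly through the config (takes precedence).
config = BaseLlmConfig(model="claude-instant-1", api_key="sk-ant-...")
llm = AnthropicLlm(config=config)

# Option 2: rely on the environment variable, which remains a valid fallback.
os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."
llm = AnthropicLlm()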

embedchain/llm/cohere.py

@@ -12,9 +12,6 @@ from embedchain.llm.base import BaseLlm
 @register_deserializable
 class CohereLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        if "COHERE_API_KEY" not in os.environ:
-            raise ValueError("Please set the COHERE_API_KEY environment variable.")
-
         try:
             importlib.import_module("cohere")
         except ModuleNotFoundError:
@@ -24,6 +21,8 @@ class CohereLlm(BaseLlm):
             ) from None
 
         super().__init__(config=config)
+        if not self.config.api_key and "COHERE_API_KEY" not in os.environ:
+            raise ValueError("Please set the COHERE_API_KEY environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt):
         if self.config.system_prompt:
@@ -32,8 +31,9 @@ class CohereLlm(BaseLlm):
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
+        api_key = config.api_key or os.getenv("COHERE_API_KEY")
         llm = Cohere(
-            cohere_api_key=os.environ["COHERE_API_KEY"],
+            cohere_api_key=api_key,
             model=config.model,
             max_tokens=config.max_tokens,
             temperature=config.temperature,
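
Every provider touched by this commit resolves the key with the same two-step idiom. A standalone sketch of that pattern; resolve_api_key is a hypothetical helper, not a function in the codebase:

import os
from typing import Optional

def resolve_api_key(config_api_key: Optional[str], env_var: str) -> str:
    # Config wins over the environment; fail with a clear message if neither is set.
    api_key = config_api_key or os.getenv(env_var)
    if not api_key:
        raise ValueError(f"Please set the {env_var} environment variable or pass it in the config.")
    return api_key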

embedchain/llm/google.py

@@ -16,9 +16,6 @@ logger = logging.getLogger(__name__)
 @register_deserializable
 class GoogleLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        if "GOOGLE_API_KEY" not in os.environ:
-            raise ValueError("Please set the GOOGLE_API_KEY environment variable.")
-
         try:
             importlib.import_module("google.generativeai")
         except ModuleNotFoundError:
@@ -28,7 +25,11 @@ class GoogleLlm(BaseLlm):
             ) from None
 
         super().__init__(config)
-        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
+        if not self.config.api_key and "GOOGLE_API_KEY" not in os.environ:
+            raise ValueError("Please set the GOOGLE_API_KEY environment variable or pass it in the config.")
+
+        api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY")
+        genai.configure(api_key=api_key)
 
     def get_llm_model_answer(self, prompt):
         if self.config.system_prompt:
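
GoogleLlm differs from the others: it resolves the key once in __init__ because genai.configure() stores it as module-level state, so the key must be available at construction time rather than at query time. A sketch of the consequence; config shape as assumed above, model name a placeholder:

from embedchain.config import BaseLlmConfig
from embedchain.llm.google import GoogleLlm

# The key is bound via genai.configure() during construction; changing
# GOOGLE_API_KEY afterwards does not affect an already-built instance.
llm = GoogleLlm(config=BaseLlmConfig(model="gemini-pro", api_key="AIza..."))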

embedchain/llm/groq.py

@@ -19,6 +19,8 @@ from embedchain.llm.base import BaseLlm
 class GroqLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config=config)
+        if not self.config.api_key and "GROQ_API_KEY" not in os.environ:
+            raise ValueError("Please set the GROQ_API_KEY environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt) -> str:
         response = self._get_answer(prompt, self.config)
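
Note the ordering in GroqLlm: super().__init__(config=config) runs first so that self.config exists before the key check reads it (the base class presumably supplies a default config when config is None). A sketch of the guarded failure, assuming no key is available from either source:

import os

from embedchain.llm.groq import GroqLlm

os.environ.pop("GROQ_API_KEY", None)
try:
    GroqLlm()  # no config key, no env var
except ValueError as err:
    print(err)  # Please set the GROQ_API_KEY environment variable or pass it in the config.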

embedchain/llm/huggingface.py

@@ -17,9 +17,6 @@ logger = logging.getLogger(__name__)
 @register_deserializable
 class HuggingFaceLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        if "HUGGINGFACE_ACCESS_TOKEN" not in os.environ:
-            raise ValueError("Please set the HUGGINGFACE_ACCESS_TOKEN environment variable.")
-
         try:
             importlib.import_module("huggingface_hub")
         except ModuleNotFoundError:
@@ -29,6 +26,8 @@ class HuggingFaceLlm(BaseLlm):
             ) from None
 
         super().__init__(config=config)
+        if not self.config.api_key and "HUGGINGFACE_ACCESS_TOKEN" not in os.environ:
+            raise ValueError("Please set the HUGGINGFACE_ACCESS_TOKEN environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt):
         if self.config.system_prompt:
@@ -60,9 +59,10 @@ class HuggingFaceLlm(BaseLlm):
             raise ValueError("`top_p` must be > 0.0 and < 1.0")
 
         model = config.model
+        api_key = config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN")
         logger.info(f"Using HuggingFaceHub with model {model}")
         llm = HuggingFaceHub(
-            huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
+            huggingfacehub_api_token=api_key,
             repo_id=model,
             model_kwargs=model_kwargs,
         )
@@ -70,8 +70,9 @@ class HuggingFaceLlm(BaseLlm):
     @staticmethod
     def _from_endpoint(prompt: str, config: BaseLlmConfig) -> str:
+        api_key = config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN")
         llm = HuggingFaceEndpoint(
-            huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
+            huggingfacehub_api_token=api_key,
             endpoint_url=config.endpoint,
             task="text-generation",
             model_kwargs=config.model_kwargs,
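
HuggingFaceLlm repeats the resolution line in both client-building paths because each constructs its own langchain wrapper. A sketch of driving both paths; the constructor keywords are inferred from config attributes visible in this diff, and the model id and endpoint URL are placeholders:

from embedchain.config import BaseLlmConfig
from embedchain.llm.huggingface import HuggingFaceLlm

# Hub-hosted model: repo id plus an explicit token in the config.
hub_llm = HuggingFaceLlm(config=BaseLlmConfig(model="google/flan-t5-xxl", api_key="hf_..."))

# Dedicated inference endpoint: same token resolution, different client.
endpoint_llm = HuggingFaceLlm(
    config=BaseLlmConfig(endpoint="https://example.endpoints.huggingface.cloud", api_key="hf_...")
)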

embedchain/llm/jina.py

@@ -12,9 +12,9 @@ from embedchain.llm.base import BaseLlm
 @register_deserializable
 class JinaLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        if "JINACHAT_API_KEY" not in os.environ:
-            raise ValueError("Please set the JINACHAT_API_KEY environment variable.")
         super().__init__(config=config)
+        if not self.config.api_key and "JINACHAT_API_KEY" not in os.environ:
+            raise ValueError("Please set the JINACHAT_API_KEY environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt):
         response = JinaLlm._get_answer(prompt, self.config)
@@ -29,13 +29,13 @@ class JinaLlm(BaseLlm):
         kwargs = {
             "temperature": config.temperature,
             "max_tokens": config.max_tokens,
+            "jinachat_api_key": config.api_key or os.environ["JINACHAT_API_KEY"],
             "model_kwargs": {},
         }
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
-            from langchain.callbacks.streaming_stdout import \
-                StreamingStdOutCallbackHandler
+            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
             chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
         else:
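
The kwargs dict now always carries jinachat_api_key, so JinaChat never has to read the environment itself. An illustrative expansion of the streaming branch; the values are examples, and the JinaChat import path varies across langchain versions:

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import JinaChat  # import path differs in newer langchain releases

chat = JinaChat(
    temperature=0.7,
    max_tokens=1000,
    jinachat_api_key="jina-...",
    model_kwargs={"top_p": 0.9},
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
)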

embedchain/llm/llama2.py

@@ -19,8 +19,6 @@ class Llama2Llm(BaseLlm):
                 "The required dependencies for Llama2 are not installed."
                 'Please install with `pip install --upgrade "embedchain[llama2]"`'
             ) from None
-        if "REPLICATE_API_TOKEN" not in os.environ:
-            raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
 
         # Set default config values specific to this llm
         if not config:
@@ -35,13 +33,17 @@ class Llama2Llm(BaseLlm):
             )
 
         super().__init__(config=config)
+        if not self.config.api_key and "REPLICATE_API_TOKEN" not in os.environ:
+            raise ValueError("Please set the REPLICATE_API_TOKEN environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt):
         # TODO: Move the model and other inputs into config
         if self.config.system_prompt:
             raise ValueError("Llama2 does not support `system_prompt`")
+        api_key = self.config.api_key or os.getenv("REPLICATE_API_TOKEN")
         llm = Replicate(
             model=self.config.model,
+            replicate_api_token=api_key,
             input={
                 "temperature": self.config.temperature,
                 "max_length": self.config.max_tokens,

embedchain/llm/nvidia.py

@@ -21,10 +21,9 @@ from embedchain.llm.base import BaseLlm
 @register_deserializable
 class NvidiaLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        if "NVIDIA_API_KEY" not in os.environ:
-            raise ValueError("NVIDIA_API_KEY environment variable must be set")
-
         super().__init__(config=config)
+        if not self.config.api_key and "NVIDIA_API_KEY" not in os.environ:
+            raise ValueError("Please set the NVIDIA_API_KEY environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt):
         return self._get_answer(prompt=prompt, config=self.config)
@@ -34,7 +33,7 @@ class NvidiaLlm(BaseLlm):
         callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
         model_kwargs = config.model_kwargs or {}
         labels = model_kwargs.get("labels", None)
-        params = {"model": config.model}
+        params = {"model": config.model, "nvidia_api_key": config.api_key or os.getenv("NVIDIA_API_KEY")}
         if config.system_prompt:
             params["system_prompt"] = config.system_prompt
         if config.temperature:
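
Here the resolved key travels inside the params dict instead of a dedicated local variable. A sketch of how the dict resolves when only the environment variable is set; the model name and values are illustrative:

import os

config_api_key = None  # nothing set in the config
os.environ["NVIDIA_API_KEY"] = "nvapi-..."

# Falls back to the environment because the config value is falsy.
params = {"model": "mixtral_8x7b", "nvidia_api_key": config_api_key or os.getenv("NVIDIA_API_KEY")}
assert params["nvidia_api_key"] == "nvapi-..."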

embedchain/llm/together.py

@@ -12,9 +12,6 @@ from embedchain.llm.base import BaseLlm
 @register_deserializable
 class TogetherLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        if "TOGETHER_API_KEY" not in os.environ:
-            raise ValueError("Please set the TOGETHER_API_KEY environment variable.")
-
         try:
             importlib.import_module("together")
         except ModuleNotFoundError:
@@ -24,6 +21,8 @@ class TogetherLlm(BaseLlm):
             ) from None
 
         super().__init__(config=config)
+        if not self.config.api_key and "TOGETHER_API_KEY" not in os.environ:
+            raise ValueError("Please set the TOGETHER_API_KEY environment variable or pass it in the config.")
 
     def get_llm_model_answer(self, prompt):
         if self.config.system_prompt:
@@ -32,8 +31,9 @@ class TogetherLlm(BaseLlm):
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
+        api_key = config.api_key or os.getenv("TOGETHER_API_KEY")
         llm = Together(
-            together_api_key=os.environ["TOGETHER_API_KEY"],
+            together_api_key=api_key,
             model=config.model,
             max_tokens=config.max_tokens,
             temperature=config.temperature,
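
One subtlety of the `config.api_key or os.getenv(...)` idiom used throughout: when both sources are set, the config value silently wins. A tiny demonstration:

import os

os.environ["TOGETHER_API_KEY"] = "env-key"
config_api_key = "config-key"

# The config key takes precedence over the environment when both exist.
assert (config_api_key or os.getenv("TOGETHER_API_KEY")) == "config-key"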

tests/llm/test_jina.py

@@ -74,5 +74,6 @@ def test_get_llm_model_answer_without_system_prompt(config, mocker):
     mocked_jinachat.assert_called_once_with(
         temperature=config.temperature,
         max_tokens=config.max_tokens,
+        jinachat_api_key=os.environ["JINACHAT_API_KEY"],
         model_kwargs={"top_p": config.top_p},
     )
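
The updated assertion pins only the env-var path. A companion test for the new config path might look like the sketch below; the mocker patch target and fixture reuse mirror the surrounding file but are assumptions:

def test_get_llm_model_answer_with_config_api_key(config, mocker):
    mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
    config.api_key = "test-key"

    llm = JinaLlm(config)
    llm.get_llm_model_answer("dummy prompt")

    # The explicit config key should win over JINACHAT_API_KEY from the environment.
    assert mocked_jinachat.call_args.kwargs["jinachat_api_key"] == "test-key"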