Add support for loading api_key from config or env variable (#1421)
This commit is contained in:
@@ -17,18 +17,17 @@ logger = logging.getLogger(__name__)
|
|||||||
@register_deserializable
|
@register_deserializable
|
||||||
class AnthropicLlm(BaseLlm):
|
class AnthropicLlm(BaseLlm):
|
||||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||||
if "ANTHROPIC_API_KEY" not in os.environ:
|
|
||||||
raise ValueError("Please set the ANTHROPIC_API_KEY environment variable.")
|
|
||||||
super().__init__(config=config)
|
super().__init__(config=config)
|
||||||
|
if not self.config.api_key and "ANTHROPIC_API_KEY" not in os.environ:
|
||||||
|
raise ValueError("Please set the ANTHROPIC_API_KEY environment variable or pass it in the config.")
|
||||||
|
|
||||||
def get_llm_model_answer(self, prompt):
|
def get_llm_model_answer(self, prompt):
|
||||||
return AnthropicLlm._get_answer(prompt=prompt, config=self.config)
|
return AnthropicLlm._get_answer(prompt=prompt, config=self.config)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
||||||
chat = ChatAnthropic(
|
api_key = config.api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||||
anthropic_api_key=os.environ["ANTHROPIC_API_KEY"], temperature=config.temperature, model_name=config.model
|
chat = ChatAnthropic(anthropic_api_key=api_key, temperature=config.temperature, model_name=config.model)
|
||||||
)
|
|
||||||
|
|
||||||
if config.max_tokens and config.max_tokens != 1000:
|
if config.max_tokens and config.max_tokens != 1000:
|
||||||
logger.warning("Config option `max_tokens` is not supported by this model.")
|
logger.warning("Config option `max_tokens` is not supported by this model.")
|
||||||
|
|||||||
@@ -12,9 +12,6 @@ from embedchain.llm.base import BaseLlm
|
|||||||
@register_deserializable
|
@register_deserializable
|
||||||
class CohereLlm(BaseLlm):
|
class CohereLlm(BaseLlm):
|
||||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||||
if "COHERE_API_KEY" not in os.environ:
|
|
||||||
raise ValueError("Please set the COHERE_API_KEY environment variable.")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
importlib.import_module("cohere")
|
importlib.import_module("cohere")
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
@@ -24,6 +21,8 @@ class CohereLlm(BaseLlm):
|
|||||||
) from None
|
) from None
|
||||||
|
|
||||||
super().__init__(config=config)
|
super().__init__(config=config)
|
||||||
|
if not self.config.api_key and "COHERE_API_KEY" not in os.environ:
|
||||||
|
raise ValueError("Please set the COHERE_API_KEY environment variable or pass it in the config.")
|
||||||
|
|
||||||
def get_llm_model_answer(self, prompt):
|
def get_llm_model_answer(self, prompt):
|
||||||
if self.config.system_prompt:
|
if self.config.system_prompt:
|
||||||
@@ -32,8 +31,9 @@ class CohereLlm(BaseLlm):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
||||||
|
api_key = config.api_key or os.getenv("COHERE_API_KEY")
|
||||||
llm = Cohere(
|
llm = Cohere(
|
||||||
cohere_api_key=os.environ["COHERE_API_KEY"],
|
cohere_api_key=api_key,
|
||||||
model=config.model,
|
model=config.model,
|
||||||
max_tokens=config.max_tokens,
|
max_tokens=config.max_tokens,
|
||||||
temperature=config.temperature,
|
temperature=config.temperature,
|
||||||
|
|||||||
@@ -16,9 +16,6 @@ logger = logging.getLogger(__name__)
|
|||||||
@register_deserializable
|
@register_deserializable
|
||||||
class GoogleLlm(BaseLlm):
|
class GoogleLlm(BaseLlm):
|
||||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||||
if "GOOGLE_API_KEY" not in os.environ:
|
|
||||||
raise ValueError("Please set the GOOGLE_API_KEY environment variable.")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
importlib.import_module("google.generativeai")
|
importlib.import_module("google.generativeai")
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
@@ -28,7 +25,11 @@ class GoogleLlm(BaseLlm):
|
|||||||
) from None
|
) from None
|
||||||
|
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
|
if not self.config.api_key and "GOOGLE_API_KEY" not in os.environ:
|
||||||
|
raise ValueError("Please set the GOOGLE_API_KEY environment variable or pass it in the config.")
|
||||||
|
|
||||||
|
api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY")
|
||||||
|
genai.configure(api_key=api_key)
|
||||||
|
|
||||||
def get_llm_model_answer(self, prompt):
|
def get_llm_model_answer(self, prompt):
|
||||||
if self.config.system_prompt:
|
if self.config.system_prompt:
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ from embedchain.llm.base import BaseLlm
|
|||||||
class GroqLlm(BaseLlm):
|
class GroqLlm(BaseLlm):
|
||||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||||
super().__init__(config=config)
|
super().__init__(config=config)
|
||||||
|
if not self.config.api_key and "GROQ_API_KEY" not in os.environ:
|
||||||
|
raise ValueError("Please set the GROQ_API_KEY environment variable or pass it in the config.")
|
||||||
|
|
||||||
def get_llm_model_answer(self, prompt) -> str:
|
def get_llm_model_answer(self, prompt) -> str:
|
||||||
response = self._get_answer(prompt, self.config)
|
response = self._get_answer(prompt, self.config)
|
||||||
|
|||||||
@@ -17,9 +17,6 @@ logger = logging.getLogger(__name__)
|
|||||||
@register_deserializable
|
@register_deserializable
|
||||||
class HuggingFaceLlm(BaseLlm):
|
class HuggingFaceLlm(BaseLlm):
|
||||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||||
if "HUGGINGFACE_ACCESS_TOKEN" not in os.environ:
|
|
||||||
raise ValueError("Please set the HUGGINGFACE_ACCESS_TOKEN environment variable.")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
importlib.import_module("huggingface_hub")
|
importlib.import_module("huggingface_hub")
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
@@ -29,6 +26,8 @@ class HuggingFaceLlm(BaseLlm):
|
|||||||
) from None
|
) from None
|
||||||
|
|
||||||
super().__init__(config=config)
|
super().__init__(config=config)
|
||||||
|
if not self.config.api_key and "HUGGINGFACE_ACCESS_TOKEN" not in os.environ:
|
||||||
|
raise ValueError("Please set the HUGGINGFACE_ACCESS_TOKEN environment variable or pass it in the config.")
|
||||||
|
|
||||||
def get_llm_model_answer(self, prompt):
|
def get_llm_model_answer(self, prompt):
|
||||||
if self.config.system_prompt:
|
if self.config.system_prompt:
|
||||||
@@ -60,9 +59,10 @@ class HuggingFaceLlm(BaseLlm):
|
|||||||
raise ValueError("`top_p` must be > 0.0 and < 1.0")
|
raise ValueError("`top_p` must be > 0.0 and < 1.0")
|
||||||
|
|
||||||
model = config.model
|
model = config.model
|
||||||
|
api_key = config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN")
|
||||||
logger.info(f"Using HuggingFaceHub with model {model}")
|
logger.info(f"Using HuggingFaceHub with model {model}")
|
||||||
llm = HuggingFaceHub(
|
llm = HuggingFaceHub(
|
||||||
huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
|
huggingfacehub_api_token=api_key,
|
||||||
repo_id=model,
|
repo_id=model,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -70,8 +70,9 @@ class HuggingFaceLlm(BaseLlm):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _from_endpoint(prompt: str, config: BaseLlmConfig) -> str:
|
def _from_endpoint(prompt: str, config: BaseLlmConfig) -> str:
|
||||||
|
api_key = config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN")
|
||||||
llm = HuggingFaceEndpoint(
|
llm = HuggingFaceEndpoint(
|
||||||
huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
|
huggingfacehub_api_token=api_key,
|
||||||
endpoint_url=config.endpoint,
|
endpoint_url=config.endpoint,
|
||||||
task="text-generation",
|
task="text-generation",
|
||||||
model_kwargs=config.model_kwargs,
|
model_kwargs=config.model_kwargs,
|
||||||
|
|||||||
@@ -12,9 +12,9 @@ from embedchain.llm.base import BaseLlm
|
|||||||
@register_deserializable
|
@register_deserializable
|
||||||
class JinaLlm(BaseLlm):
|
class JinaLlm(BaseLlm):
|
||||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||||
if "JINACHAT_API_KEY" not in os.environ:
|
|
||||||
raise ValueError("Please set the JINACHAT_API_KEY environment variable.")
|
|
||||||
super().__init__(config=config)
|
super().__init__(config=config)
|
||||||
|
if not self.config.api_key and "JINACHAT_API_KEY" not in os.environ:
|
||||||
|
raise ValueError("Please set the JINACHAT_API_KEY environment variable or pass it in the config.")
|
||||||
|
|
||||||
def get_llm_model_answer(self, prompt):
|
def get_llm_model_answer(self, prompt):
|
||||||
response = JinaLlm._get_answer(prompt, self.config)
|
response = JinaLlm._get_answer(prompt, self.config)
|
||||||
@@ -29,13 +29,13 @@ class JinaLlm(BaseLlm):
|
|||||||
kwargs = {
|
kwargs = {
|
||||||
"temperature": config.temperature,
|
"temperature": config.temperature,
|
||||||
"max_tokens": config.max_tokens,
|
"max_tokens": config.max_tokens,
|
||||||
|
"jinachat_api_key": config.api_key or os.environ["JINACHAT_API_KEY"],
|
||||||
"model_kwargs": {},
|
"model_kwargs": {},
|
||||||
}
|
}
|
||||||
if config.top_p:
|
if config.top_p:
|
||||||
kwargs["model_kwargs"]["top_p"] = config.top_p
|
kwargs["model_kwargs"]["top_p"] = config.top_p
|
||||||
if config.stream:
|
if config.stream:
|
||||||
from langchain.callbacks.streaming_stdout import \
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||||
StreamingStdOutCallbackHandler
|
|
||||||
|
|
||||||
chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
|
chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -19,8 +19,6 @@ class Llama2Llm(BaseLlm):
|
|||||||
"The required dependencies for Llama2 are not installed."
|
"The required dependencies for Llama2 are not installed."
|
||||||
'Please install with `pip install --upgrade "embedchain[llama2]"`'
|
'Please install with `pip install --upgrade "embedchain[llama2]"`'
|
||||||
) from None
|
) from None
|
||||||
if "REPLICATE_API_TOKEN" not in os.environ:
|
|
||||||
raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
|
|
||||||
|
|
||||||
# Set default config values specific to this llm
|
# Set default config values specific to this llm
|
||||||
if not config:
|
if not config:
|
||||||
@@ -35,13 +33,17 @@ class Llama2Llm(BaseLlm):
|
|||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(config=config)
|
super().__init__(config=config)
|
||||||
|
if not self.config.api_key and "REPLICATE_API_TOKEN" not in os.environ:
|
||||||
|
raise ValueError("Please set the REPLICATE_API_TOKEN environment variable or pass it in the config.")
|
||||||
|
|
||||||
def get_llm_model_answer(self, prompt):
|
def get_llm_model_answer(self, prompt):
|
||||||
# TODO: Move the model and other inputs into config
|
# TODO: Move the model and other inputs into config
|
||||||
if self.config.system_prompt:
|
if self.config.system_prompt:
|
||||||
raise ValueError("Llama2 does not support `system_prompt`")
|
raise ValueError("Llama2 does not support `system_prompt`")
|
||||||
|
api_key = self.config.api_key or os.getenv("REPLICATE_API_TOKEN")
|
||||||
llm = Replicate(
|
llm = Replicate(
|
||||||
model=self.config.model,
|
model=self.config.model,
|
||||||
|
replicate_api_token=api_key,
|
||||||
input={
|
input={
|
||||||
"temperature": self.config.temperature,
|
"temperature": self.config.temperature,
|
||||||
"max_length": self.config.max_tokens,
|
"max_length": self.config.max_tokens,
|
||||||
|
|||||||
@@ -21,10 +21,9 @@ from embedchain.llm.base import BaseLlm
|
|||||||
@register_deserializable
|
@register_deserializable
|
||||||
class NvidiaLlm(BaseLlm):
|
class NvidiaLlm(BaseLlm):
|
||||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||||
if "NVIDIA_API_KEY" not in os.environ:
|
|
||||||
raise ValueError("NVIDIA_API_KEY environment variable must be set")
|
|
||||||
|
|
||||||
super().__init__(config=config)
|
super().__init__(config=config)
|
||||||
|
if not self.config.api_key and "NVIDIA_API_KEY" not in os.environ:
|
||||||
|
raise ValueError("Please set the NVIDIA_API_KEY environment variable or pass it in the config.")
|
||||||
|
|
||||||
def get_llm_model_answer(self, prompt):
|
def get_llm_model_answer(self, prompt):
|
||||||
return self._get_answer(prompt=prompt, config=self.config)
|
return self._get_answer(prompt=prompt, config=self.config)
|
||||||
@@ -34,7 +33,7 @@ class NvidiaLlm(BaseLlm):
|
|||||||
callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
|
callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
|
||||||
model_kwargs = config.model_kwargs or {}
|
model_kwargs = config.model_kwargs or {}
|
||||||
labels = model_kwargs.get("labels", None)
|
labels = model_kwargs.get("labels", None)
|
||||||
params = {"model": config.model}
|
params = {"model": config.model, "nvidia_api_key": config.api_key or os.getenv("NVIDIA_API_KEY")}
|
||||||
if config.system_prompt:
|
if config.system_prompt:
|
||||||
params["system_prompt"] = config.system_prompt
|
params["system_prompt"] = config.system_prompt
|
||||||
if config.temperature:
|
if config.temperature:
|
||||||
|
|||||||
@@ -12,9 +12,6 @@ from embedchain.llm.base import BaseLlm
|
|||||||
@register_deserializable
|
@register_deserializable
|
||||||
class TogetherLlm(BaseLlm):
|
class TogetherLlm(BaseLlm):
|
||||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||||
if "TOGETHER_API_KEY" not in os.environ:
|
|
||||||
raise ValueError("Please set the TOGETHER_API_KEY environment variable.")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
importlib.import_module("together")
|
importlib.import_module("together")
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
@@ -24,6 +21,8 @@ class TogetherLlm(BaseLlm):
|
|||||||
) from None
|
) from None
|
||||||
|
|
||||||
super().__init__(config=config)
|
super().__init__(config=config)
|
||||||
|
if not self.config.api_key and "TOGETHER_API_KEY" not in os.environ:
|
||||||
|
raise ValueError("Please set the TOGETHER_API_KEY environment variable or pass it in the config.")
|
||||||
|
|
||||||
def get_llm_model_answer(self, prompt):
|
def get_llm_model_answer(self, prompt):
|
||||||
if self.config.system_prompt:
|
if self.config.system_prompt:
|
||||||
@@ -32,8 +31,9 @@ class TogetherLlm(BaseLlm):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
||||||
|
api_key = config.api_key or os.getenv("TOGETHER_API_KEY")
|
||||||
llm = Together(
|
llm = Together(
|
||||||
together_api_key=os.environ["TOGETHER_API_KEY"],
|
together_api_key=api_key,
|
||||||
model=config.model,
|
model=config.model,
|
||||||
max_tokens=config.max_tokens,
|
max_tokens=config.max_tokens,
|
||||||
temperature=config.temperature,
|
temperature=config.temperature,
|
||||||
|
|||||||
@@ -74,5 +74,6 @@ def test_get_llm_model_answer_without_system_prompt(config, mocker):
|
|||||||
mocked_jinachat.assert_called_once_with(
|
mocked_jinachat.assert_called_once_with(
|
||||||
temperature=config.temperature,
|
temperature=config.temperature,
|
||||||
max_tokens=config.max_tokens,
|
max_tokens=config.max_tokens,
|
||||||
|
jinachat_api_key=os.environ["JINACHAT_API_KEY"],
|
||||||
model_kwargs={"top_p": config.top_p},
|
model_kwargs={"top_p": config.top_p},
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user