[Feature] Add support for Google Gemini (#1009)
This commit is contained in:
8
configs/google.yaml
Normal file
8
configs/google.yaml
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
llm:
|
||||||
|
provider: google
|
||||||
|
config:
|
||||||
|
model: gemini-pro
|
||||||
|
max_tokens: 1000
|
||||||
|
temperature: 0.9
|
||||||
|
top_p: 1.0
|
||||||
|
stream: false
|
||||||
@@ -8,6 +8,7 @@ Embedchain comes with built-in support for various popular large language models
|
|||||||
|
|
||||||
<CardGroup cols={4}>
|
<CardGroup cols={4}>
|
||||||
<Card title="OpenAI" href="#openai"></Card>
|
<Card title="OpenAI" href="#openai"></Card>
|
||||||
|
<Card title="Google AI" href="#google-ai"></Card>
|
||||||
<Card title="Azure OpenAI" href="#azure-openai"></Card>
|
<Card title="Azure OpenAI" href="#azure-openai"></Card>
|
||||||
<Card title="Anthropic" href="#anthropic"></Card>
|
<Card title="Anthropic" href="#anthropic"></Card>
|
||||||
<Card title="Cohere" href="#cohere"></Card>
|
<Card title="Cohere" href="#cohere"></Card>
|
||||||
@@ -62,6 +63,41 @@ llm:
|
|||||||
|
|
||||||
</CodeGroup>
|
</CodeGroup>
|
||||||
|
|
||||||
|
## Google AI
|
||||||
|
|
||||||
|
To use Google AI model, you have to set the `GOOGLE_API_KEY` environment variable. You can obtain the Google API key from the [Google Maker Suite](https://makersuite.google.com/app/apikey)
|
||||||
|
|
||||||
|
<CodeGroup>
|
||||||
|
```python main.py
|
||||||
|
import os
|
||||||
|
from embedchain import Pipeline as App
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = "sk-xxxx"
|
||||||
|
os.environ["GOOGLE_API_KEY"] = "xxx"
|
||||||
|
|
||||||
|
app = App.from_config(config_path="config.yaml")
|
||||||
|
|
||||||
|
app.add("https://www.forbes.com/profile/elon-musk")
|
||||||
|
|
||||||
|
response = app.query("What is the net worth of Elon Musk?")
|
||||||
|
if app.llm.config.stream: # if stream is enabled, response is a generator
|
||||||
|
for chunk in response:
|
||||||
|
print(chunk)
|
||||||
|
else:
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
```yaml config.yaml
|
||||||
|
llm:
|
||||||
|
provider: google
|
||||||
|
config:
|
||||||
|
model: gemini-pro
|
||||||
|
max_tokens: 1000
|
||||||
|
temperature: 0.5
|
||||||
|
top_p: 1
|
||||||
|
stream: false
|
||||||
|
```
|
||||||
|
</CodeGroup>
|
||||||
|
|
||||||
## Azure OpenAI
|
## Azure OpenAI
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ class LlmFactory:
|
|||||||
"llama2": "embedchain.llm.llama2.Llama2Llm",
|
"llama2": "embedchain.llm.llama2.Llama2Llm",
|
||||||
"openai": "embedchain.llm.openai.OpenAILlm",
|
"openai": "embedchain.llm.openai.OpenAILlm",
|
||||||
"vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
|
"vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
|
||||||
|
"google": "embedchain.llm.google.GoogleLlm",
|
||||||
}
|
}
|
||||||
provider_to_config_class = {
|
provider_to_config_class = {
|
||||||
"embedchain": "embedchain.config.llm.base.BaseLlmConfig",
|
"embedchain": "embedchain.config.llm.base.BaseLlmConfig",
|
||||||
|
|||||||
@@ -217,7 +217,6 @@ class BaseLlm(JSONSerializable):
|
|||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
answer = self.get_answer_from_llm(prompt)
|
answer = self.get_answer_from_llm(prompt)
|
||||||
|
|
||||||
if isinstance(answer, str):
|
if isinstance(answer, str):
|
||||||
logging.info(f"Answer: {answer}")
|
logging.info(f"Answer: {answer}")
|
||||||
return answer
|
return answer
|
||||||
|
|||||||
64
embedchain/llm/google.py
Normal file
64
embedchain/llm/google.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import google.generativeai as genai
|
||||||
|
|
||||||
|
from embedchain.config import BaseLlmConfig
|
||||||
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
from embedchain.llm.base import BaseLlm
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
|
class GoogleLlm(BaseLlm):
|
||||||
|
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||||
|
if "GOOGLE_API_KEY" not in os.environ:
|
||||||
|
raise ValueError("Please set the GOOGLE_API_KEY environment variable.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
importlib.import_module("google.generativeai")
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
"The required dependencies for GoogleLlm are not installed."
|
||||||
|
'Please install with `pip install --upgrade "embedchain[google]"`'
|
||||||
|
) from None
|
||||||
|
|
||||||
|
super().__init__(config)
|
||||||
|
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
|
||||||
|
|
||||||
|
def get_llm_model_answer(self, prompt):
|
||||||
|
if self.config.system_prompt:
|
||||||
|
raise ValueError("GoogleLlm does not support `system_prompt`")
|
||||||
|
return GoogleLlm._get_answer(prompt, self.config)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_answer(prompt: str, config: BaseLlmConfig):
|
||||||
|
model_name = config.model or "gemini-pro"
|
||||||
|
logging.info(f"Using Google LLM model: {model_name}")
|
||||||
|
model = genai.GenerativeModel(model_name=model_name)
|
||||||
|
|
||||||
|
generation_config_params = {
|
||||||
|
"candidate_count": 1,
|
||||||
|
"max_output_tokens": config.max_tokens,
|
||||||
|
"temperature": config.temperature or 0.5,
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.top_p >= 0.0 and config.top_p <= 1.0:
|
||||||
|
generation_config_params["top_p"] = config.top_p
|
||||||
|
else:
|
||||||
|
raise ValueError("`top_p` must be > 0.0 and < 1.0")
|
||||||
|
|
||||||
|
generation_config = genai.types.GenerationConfig(**generation_config_params)
|
||||||
|
|
||||||
|
response = model.generate_content(
|
||||||
|
prompt,
|
||||||
|
generation_config=generation_config,
|
||||||
|
stream=config.stream,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.stream:
|
||||||
|
for chunk in response:
|
||||||
|
yield chunk.text
|
||||||
|
else:
|
||||||
|
return response.text
|
||||||
@@ -387,6 +387,7 @@ def validate_config(config_data):
|
|||||||
"jina",
|
"jina",
|
||||||
"llama2",
|
"llama2",
|
||||||
"vertexai",
|
"vertexai",
|
||||||
|
"google",
|
||||||
),
|
),
|
||||||
Optional("config"): {
|
Optional("config"): {
|
||||||
Optional("model"): str,
|
Optional("model"): str,
|
||||||
|
|||||||
63
poetry.lock
generated
63
poetry.lock
generated
@@ -1,4 +1,4 @@
|
|||||||
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aiofiles"
|
name = "aiofiles"
|
||||||
@@ -1762,6 +1762,22 @@ gitdb = ">=4.0.1,<5"
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-instafail", "pytest-subtests", "pytest-sugar"]
|
test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-instafail", "pytest-subtests", "pytest-sugar"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "google-ai-generativelanguage"
|
||||||
|
version = "0.4.0"
|
||||||
|
description = "Google Ai Generativelanguage API client library"
|
||||||
|
optional = true
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "google-ai-generativelanguage-0.4.0.tar.gz", hash = "sha256:c8199066c08f74c4e91290778329bb9f357ba1ea5d6f82de2bc0d10552bf4f8c"},
|
||||||
|
{file = "google_ai_generativelanguage-0.4.0-py3-none-any.whl", hash = "sha256:e4c425376c1ee26c78acbc49a24f735f90ebfa81bf1a06495fae509a2433232c"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
|
||||||
|
proto-plus = ">=1.22.3,<2.0.0dev"
|
||||||
|
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "google-api-core"
|
name = "google-api-core"
|
||||||
version = "2.12.0"
|
version = "2.12.0"
|
||||||
@@ -1777,12 +1793,12 @@ files = [
|
|||||||
google-auth = ">=2.14.1,<3.0.dev0"
|
google-auth = ">=2.14.1,<3.0.dev0"
|
||||||
googleapis-common-protos = ">=1.56.2,<2.0.dev0"
|
googleapis-common-protos = ">=1.56.2,<2.0.dev0"
|
||||||
grpcio = [
|
grpcio = [
|
||||||
|
{version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
|
||||||
{version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
|
{version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
|
||||||
{version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
|
|
||||||
]
|
]
|
||||||
grpcio-status = [
|
grpcio-status = [
|
||||||
|
{version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""},
|
||||||
{version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
|
{version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
|
||||||
{version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
|
|
||||||
]
|
]
|
||||||
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
|
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
|
||||||
requests = ">=2.18.0,<3.0.0.dev0"
|
requests = ">=2.18.0,<3.0.0.dev0"
|
||||||
@@ -1870,8 +1886,8 @@ google-api-core = {version = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0dev", extras =
|
|||||||
google-cloud-core = ">=1.6.0,<3.0.0dev"
|
google-cloud-core = ">=1.6.0,<3.0.0dev"
|
||||||
google-resumable-media = ">=0.6.0,<3.0dev"
|
google-resumable-media = ">=0.6.0,<3.0dev"
|
||||||
grpcio = [
|
grpcio = [
|
||||||
{version = ">=1.49.1,<2.0dev", markers = "python_version >= \"3.11\""},
|
|
||||||
{version = ">=1.47.0,<2.0dev", markers = "python_version < \"3.11\""},
|
{version = ">=1.47.0,<2.0dev", markers = "python_version < \"3.11\""},
|
||||||
|
{version = ">=1.49.1,<2.0dev", markers = "python_version >= \"3.11\""},
|
||||||
]
|
]
|
||||||
packaging = ">=20.0.0"
|
packaging = ">=20.0.0"
|
||||||
proto-plus = ">=1.15.0,<2.0.0dev"
|
proto-plus = ">=1.15.0,<2.0.0dev"
|
||||||
@@ -1922,8 +1938,8 @@ files = [
|
|||||||
google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
|
google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
|
||||||
grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev"
|
grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev"
|
||||||
proto-plus = [
|
proto-plus = [
|
||||||
{version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
|
|
||||||
{version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
|
{version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
|
||||||
|
{version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
|
||||||
]
|
]
|
||||||
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
|
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
|
||||||
|
|
||||||
@@ -2029,6 +2045,26 @@ files = [
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
testing = ["pytest"]
|
testing = ["pytest"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "google-generativeai"
|
||||||
|
version = "0.3.1"
|
||||||
|
description = "Google Generative AI High level API client library and tools."
|
||||||
|
optional = true
|
||||||
|
python-versions = ">=3.9"
|
||||||
|
files = [
|
||||||
|
{file = "google_generativeai-0.3.1-py3-none-any.whl", hash = "sha256:800ec6041ca537b897d7ba654f4125651c64b38506f2bfce3b464370e3333a1b"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
google-ai-generativelanguage = "0.4.0"
|
||||||
|
google-api-core = "*"
|
||||||
|
google-auth = "*"
|
||||||
|
protobuf = "*"
|
||||||
|
tqdm = "*"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "google-resumable-media"
|
name = "google-resumable-media"
|
||||||
version = "2.6.0"
|
version = "2.6.0"
|
||||||
@@ -4004,11 +4040,13 @@ files = [
|
|||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
numpy = [
|
numpy = [
|
||||||
|
{version = ">=1.21.0", markers = "python_version <= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
|
||||||
|
{version = ">=1.21.2", markers = "python_version >= \"3.10\""},
|
||||||
|
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""},
|
||||||
|
{version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""},
|
||||||
|
{version = ">=1.17.0", markers = "python_version >= \"3.7\""},
|
||||||
|
{version = ">=1.17.3", markers = "python_version >= \"3.8\""},
|
||||||
{version = ">=1.23.5", markers = "python_version >= \"3.11\""},
|
{version = ">=1.23.5", markers = "python_version >= \"3.11\""},
|
||||||
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
|
|
||||||
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
|
|
||||||
{version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
|
|
||||||
{version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -4202,8 +4240,8 @@ files = [
|
|||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
numpy = [
|
numpy = [
|
||||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
|
||||||
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
|
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
|
||||||
|
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||||
]
|
]
|
||||||
python-dateutil = ">=2.8.2"
|
python-dateutil = ">=2.8.2"
|
||||||
pytz = ">=2020.1"
|
pytz = ">=2020.1"
|
||||||
@@ -6344,7 +6382,7 @@ files = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
|
greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""}
|
||||||
typing-extensions = ">=4.2.0"
|
typing-extensions = ">=4.2.0"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
@@ -7820,6 +7858,7 @@ discord = ["discord"]
|
|||||||
elasticsearch = ["elasticsearch"]
|
elasticsearch = ["elasticsearch"]
|
||||||
github = ["PyGithub", "gitpython"]
|
github = ["PyGithub", "gitpython"]
|
||||||
gmail = ["llama-hub", "requests"]
|
gmail = ["llama-hub", "requests"]
|
||||||
|
google = ["google-generativeai"]
|
||||||
huggingface-hub = ["huggingface_hub"]
|
huggingface-hub = ["huggingface_hub"]
|
||||||
images = ["ftfy", "pillow", "regex", "torch", "torchvision"]
|
images = ["ftfy", "pillow", "regex", "torch", "torchvision"]
|
||||||
json = ["llama-hub"]
|
json = ["llama-hub"]
|
||||||
@@ -7843,4 +7882,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.9,<3.12"
|
python-versions = ">=3.9,<3.12"
|
||||||
content-hash = "776ae7f49adab8a5dc98f6fe7c2887d2e700fd2d7c447383ea81ef05a463c8f3"
|
content-hash = "846cca158ccd7a2ecc3d0a08218d273846aad15e0ac5c19ddeb1b5de00fa9a3f"
|
||||||
|
|||||||
@@ -143,6 +143,7 @@ PyGithub = { version = "^1.59.1", optional = true }
|
|||||||
feedparser = { version = "^6.0.10", optional = true }
|
feedparser = { version = "^6.0.10", optional = true }
|
||||||
newspaper3k = { version = "^0.2.8", optional = true }
|
newspaper3k = { version = "^0.2.8", optional = true }
|
||||||
listparser = { version = "^0.19", optional = true }
|
listparser = { version = "^0.19", optional = true }
|
||||||
|
google-generativeai = { version = "^0.3.0", optional = true }
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
black = "^23.3.0"
|
black = "^23.3.0"
|
||||||
@@ -204,7 +205,12 @@ youtube = [
|
|||||||
"yt_dlp",
|
"yt_dlp",
|
||||||
"youtube-transcript-api",
|
"youtube-transcript-api",
|
||||||
]
|
]
|
||||||
rss_feed = ["feedparser", "listparser", "newspaper3k"]
|
rss_feed = [
|
||||||
|
"feedparser",
|
||||||
|
"listparser",
|
||||||
|
"newspaper3k"
|
||||||
|
]
|
||||||
|
google = ["google-generativeai"]
|
||||||
|
|
||||||
[tool.poetry.group.docs.dependencies]
|
[tool.poetry.group.docs.dependencies]
|
||||||
|
|
||||||
|
|||||||
43
tests/llm/test_google.py
Normal file
43
tests/llm/test_google.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from embedchain.config import BaseLlmConfig
|
||||||
|
from embedchain.llm.google import GoogleLlm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def google_llm_config():
|
||||||
|
return BaseLlmConfig(model="gemini-pro", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_google_llm_init_missing_api_key(monkeypatch):
|
||||||
|
monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
|
||||||
|
with pytest.raises(ValueError, match="Please set the GOOGLE_API_KEY environment variable."):
|
||||||
|
GoogleLlm()
|
||||||
|
|
||||||
|
|
||||||
|
def test_google_llm_init(monkeypatch):
|
||||||
|
monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setattr("importlib.import_module", lambda x: None)
|
||||||
|
google_llm = GoogleLlm()
|
||||||
|
assert google_llm is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_google_llm_get_llm_model_answer_with_system_prompt(monkeypatch):
|
||||||
|
monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
|
||||||
|
monkeypatch.setattr("importlib.import_module", lambda x: None)
|
||||||
|
google_llm = GoogleLlm(config=BaseLlmConfig(system_prompt="system prompt"))
|
||||||
|
with pytest.raises(ValueError, match="GoogleLlm does not support `system_prompt`"):
|
||||||
|
google_llm.get_llm_model_answer("test prompt")
|
||||||
|
|
||||||
|
|
||||||
|
def test_google_llm_get_llm_model_answer(monkeypatch, google_llm_config):
|
||||||
|
def mock_get_answer(prompt, config):
|
||||||
|
return "Generated Text"
|
||||||
|
|
||||||
|
monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
|
||||||
|
monkeypatch.setattr(GoogleLlm, "_get_answer", mock_get_answer)
|
||||||
|
google_llm = GoogleLlm(config=google_llm_config)
|
||||||
|
result = google_llm.get_llm_model_answer("test prompt")
|
||||||
|
|
||||||
|
assert result == "Generated Text"
|
||||||
Reference in New Issue
Block a user