[Feature] Add support for Google Gemini (#1009)
This commit is contained in:
8
configs/google.yaml
Normal file
8
configs/google.yaml
Normal file
@@ -0,0 +1,8 @@
|
||||
llm:
|
||||
provider: google
|
||||
config:
|
||||
model: gemini-pro
|
||||
max_tokens: 1000
|
||||
temperature: 0.9
|
||||
top_p: 1.0
|
||||
stream: false
|
||||
@@ -8,6 +8,7 @@ Embedchain comes with built-in support for various popular large language models
|
||||
|
||||
<CardGroup cols={4}>
|
||||
<Card title="OpenAI" href="#openai"></Card>
|
||||
<Card title="Google AI" href="#google-ai"></Card>
|
||||
<Card title="Azure OpenAI" href="#azure-openai"></Card>
|
||||
<Card title="Anthropic" href="#anthropic"></Card>
|
||||
<Card title="Cohere" href="#cohere"></Card>
|
||||
@@ -62,6 +63,41 @@ llm:
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
## Google AI
|
||||
|
||||
To use the Google AI model, you have to set the `GOOGLE_API_KEY` environment variable. You can obtain the Google API key from the [Google Maker Suite](https://makersuite.google.com/app/apikey).
|
||||
|
||||
<CodeGroup>
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "sk-xxxx"
|
||||
os.environ["GOOGLE_API_KEY"] = "xxx"
|
||||
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
|
||||
app.add("https://www.forbes.com/profile/elon-musk")
|
||||
|
||||
response = app.query("What is the net worth of Elon Musk?")
|
||||
if app.llm.config.stream: # if stream is enabled, response is a generator
|
||||
for chunk in response:
|
||||
print(chunk)
|
||||
else:
|
||||
print(response)
|
||||
```
|
||||
|
||||
```yaml config.yaml
|
||||
llm:
|
||||
provider: google
|
||||
config:
|
||||
model: gemini-pro
|
||||
max_tokens: 1000
|
||||
temperature: 0.5
|
||||
top_p: 1
|
||||
stream: false
|
||||
```
|
||||
</CodeGroup>
|
||||
|
||||
## Azure OpenAI
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ class LlmFactory:
|
||||
"llama2": "embedchain.llm.llama2.Llama2Llm",
|
||||
"openai": "embedchain.llm.openai.OpenAILlm",
|
||||
"vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
|
||||
"google": "embedchain.llm.google.GoogleLlm",
|
||||
}
|
||||
provider_to_config_class = {
|
||||
"embedchain": "embedchain.config.llm.base.BaseLlmConfig",
|
||||
|
||||
@@ -217,7 +217,6 @@ class BaseLlm(JSONSerializable):
|
||||
return prompt
|
||||
|
||||
answer = self.get_answer_from_llm(prompt)
|
||||
|
||||
if isinstance(answer, str):
|
||||
logging.info(f"Answer: {answer}")
|
||||
return answer
|
||||
|
||||
64
embedchain/llm/google.py
Normal file
64
embedchain/llm/google.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import google.generativeai as genai
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
@register_deserializable
class GoogleLlm(BaseLlm):
    """LLM wrapper around Google's Generative AI (Gemini) models.

    Requires the ``GOOGLE_API_KEY`` environment variable and the
    ``google-generativeai`` package (install via ``pip install "embedchain[google]"``).
    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Validate environment and dependencies, then configure the genai client.

        :param config: optional LLM configuration; BaseLlm applies defaults.
        :raises ValueError: if ``GOOGLE_API_KEY`` is not set in the environment.
        :raises ModuleNotFoundError: if ``google-generativeai`` is not installed.
        """
        if "GOOGLE_API_KEY" not in os.environ:
            raise ValueError("Please set the GOOGLE_API_KEY environment variable.")

        try:
            importlib.import_module("google.generativeai")
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for GoogleLlm are not installed. "
                'Please install with `pip install --upgrade "embedchain[google]"`'
            ) from None

        super().__init__(config)
        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

    def get_llm_model_answer(self, prompt):
        """Return the model's answer for *prompt*.

        Returns a ``str`` when streaming is disabled, or a generator of text
        chunks when ``config.stream`` is true.

        :raises ValueError: GoogleLlm does not support a system prompt.
        """
        if self.config.system_prompt:
            raise ValueError("GoogleLlm does not support `system_prompt`")
        return GoogleLlm._get_answer(prompt, self.config)

    @staticmethod
    def _get_answer(prompt: str, config: BaseLlmConfig):
        """Call the Gemini API and return the answer text (or a chunk generator).

        :raises ValueError: if ``config.top_p`` is outside [0.0, 1.0].
        """
        model_name = config.model or "gemini-pro"
        logging.info(f"Using Google LLM model: {model_name}")
        model = genai.GenerativeModel(model_name=model_name)

        generation_config_params = {
            "candidate_count": 1,
            "max_output_tokens": config.max_tokens,
            "temperature": config.temperature or 0.5,
        }

        # top_p is a probability mass, so it must lie within [0.0, 1.0].
        if 0.0 <= config.top_p <= 1.0:
            generation_config_params["top_p"] = config.top_p
        else:
            raise ValueError("`top_p` must be between 0.0 and 1.0")

        generation_config = genai.types.GenerationConfig(**generation_config_params)

        response = model.generate_content(
            prompt,
            generation_config=generation_config,
            stream=config.stream,
        )

        if config.stream:
            # BUGFIX: streaming is delegated to a separate generator helper.
            # A `yield` directly in this function body would make _get_answer
            # itself a generator function, so the non-streaming
            # `return response.text` value would be swallowed into
            # StopIteration and callers would get an empty generator instead
            # of the answer text.
            return GoogleLlm._stream_response(response)
        return response.text

    @staticmethod
    def _stream_response(response):
        """Yield the text of each chunk from a streaming genai response."""
        for chunk in response:
            yield chunk.text
|
||||
@@ -387,6 +387,7 @@ def validate_config(config_data):
|
||||
"jina",
|
||||
"llama2",
|
||||
"vertexai",
|
||||
"google",
|
||||
),
|
||||
Optional("config"): {
|
||||
Optional("model"): str,
|
||||
|
||||
63
poetry.lock
generated
63
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiofiles"
|
||||
@@ -1762,6 +1762,22 @@ gitdb = ">=4.0.1,<5"
|
||||
[package.extras]
|
||||
test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-instafail", "pytest-subtests", "pytest-sugar"]
|
||||
|
||||
[[package]]
|
||||
name = "google-ai-generativelanguage"
|
||||
version = "0.4.0"
|
||||
description = "Google Ai Generativelanguage API client library"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "google-ai-generativelanguage-0.4.0.tar.gz", hash = "sha256:c8199066c08f74c4e91290778329bb9f357ba1ea5d6f82de2bc0d10552bf4f8c"},
|
||||
{file = "google_ai_generativelanguage-0.4.0-py3-none-any.whl", hash = "sha256:e4c425376c1ee26c78acbc49a24f735f90ebfa81bf1a06495fae509a2433232c"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
|
||||
proto-plus = ">=1.22.3,<2.0.0dev"
|
||||
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
|
||||
|
||||
[[package]]
|
||||
name = "google-api-core"
|
||||
version = "2.12.0"
|
||||
@@ -1777,12 +1793,12 @@ files = [
|
||||
google-auth = ">=2.14.1,<3.0.dev0"
|
||||
googleapis-common-protos = ">=1.56.2,<2.0.dev0"
|
||||
grpcio = [
|
||||
{version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
|
||||
{version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
|
||||
{version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
|
||||
]
|
||||
grpcio-status = [
|
||||
{version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""},
|
||||
{version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
|
||||
{version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
|
||||
]
|
||||
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
|
||||
requests = ">=2.18.0,<3.0.0.dev0"
|
||||
@@ -1870,8 +1886,8 @@ google-api-core = {version = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0dev", extras =
|
||||
google-cloud-core = ">=1.6.0,<3.0.0dev"
|
||||
google-resumable-media = ">=0.6.0,<3.0dev"
|
||||
grpcio = [
|
||||
{version = ">=1.49.1,<2.0dev", markers = "python_version >= \"3.11\""},
|
||||
{version = ">=1.47.0,<2.0dev", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.49.1,<2.0dev", markers = "python_version >= \"3.11\""},
|
||||
]
|
||||
packaging = ">=20.0.0"
|
||||
proto-plus = ">=1.15.0,<2.0.0dev"
|
||||
@@ -1922,8 +1938,8 @@ files = [
|
||||
google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
|
||||
grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev"
|
||||
proto-plus = [
|
||||
{version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
|
||||
{version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
|
||||
]
|
||||
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
|
||||
|
||||
@@ -2029,6 +2045,26 @@ files = [
|
||||
[package.extras]
|
||||
testing = ["pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "google-generativeai"
|
||||
version = "0.3.1"
|
||||
description = "Google Generative AI High level API client library and tools."
|
||||
optional = true
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "google_generativeai-0.3.1-py3-none-any.whl", hash = "sha256:800ec6041ca537b897d7ba654f4125651c64b38506f2bfce3b464370e3333a1b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
google-ai-generativelanguage = "0.4.0"
|
||||
google-api-core = "*"
|
||||
google-auth = "*"
|
||||
protobuf = "*"
|
||||
tqdm = "*"
|
||||
|
||||
[package.extras]
|
||||
dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]
|
||||
|
||||
[[package]]
|
||||
name = "google-resumable-media"
|
||||
version = "2.6.0"
|
||||
@@ -4004,11 +4040,13 @@ files = [
|
||||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.21.0", markers = "python_version <= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
|
||||
{version = ">=1.21.2", markers = "python_version >= \"3.10\""},
|
||||
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""},
|
||||
{version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""},
|
||||
{version = ">=1.17.0", markers = "python_version >= \"3.7\""},
|
||||
{version = ">=1.17.3", markers = "python_version >= \"3.8\""},
|
||||
{version = ">=1.23.5", markers = "python_version >= \"3.11\""},
|
||||
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
|
||||
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
{version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
|
||||
{version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4202,8 +4240,8 @@ files = [
|
||||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
@@ -6344,7 +6382,7 @@ files = [
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
|
||||
greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""}
|
||||
typing-extensions = ">=4.2.0"
|
||||
|
||||
[package.extras]
|
||||
@@ -7820,6 +7858,7 @@ discord = ["discord"]
|
||||
elasticsearch = ["elasticsearch"]
|
||||
github = ["PyGithub", "gitpython"]
|
||||
gmail = ["llama-hub", "requests"]
|
||||
google = ["google-generativeai"]
|
||||
huggingface-hub = ["huggingface_hub"]
|
||||
images = ["ftfy", "pillow", "regex", "torch", "torchvision"]
|
||||
json = ["llama-hub"]
|
||||
@@ -7843,4 +7882,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<3.12"
|
||||
content-hash = "776ae7f49adab8a5dc98f6fe7c2887d2e700fd2d7c447383ea81ef05a463c8f3"
|
||||
content-hash = "846cca158ccd7a2ecc3d0a08218d273846aad15e0ac5c19ddeb1b5de00fa9a3f"
|
||||
|
||||
@@ -143,6 +143,7 @@ PyGithub = { version = "^1.59.1", optional = true }
|
||||
feedparser = { version = "^6.0.10", optional = true }
|
||||
newspaper3k = { version = "^0.2.8", optional = true }
|
||||
listparser = { version = "^0.19", optional = true }
|
||||
google-generativeai = { version = "^0.3.0", optional = true }
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^23.3.0"
|
||||
@@ -204,7 +205,12 @@ youtube = [
|
||||
"yt_dlp",
|
||||
"youtube-transcript-api",
|
||||
]
|
||||
rss_feed = ["feedparser", "listparser", "newspaper3k"]
|
||||
rss_feed = [
|
||||
"feedparser",
|
||||
"listparser",
|
||||
"newspaper3k"
|
||||
]
|
||||
google = ["google-generativeai"]
|
||||
|
||||
[tool.poetry.group.docs.dependencies]
|
||||
|
||||
|
||||
43
tests/llm/test_google.py
Normal file
43
tests/llm/test_google.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.google import GoogleLlm
|
||||
|
||||
|
||||
@pytest.fixture
def google_llm_config():
    """Provide a baseline, non-streaming Gemini configuration for these tests."""
    config = BaseLlmConfig(
        model="gemini-pro",
        max_tokens=100,
        temperature=0.7,
        top_p=0.5,
        stream=False,
    )
    return config
|
||||
|
||||
|
||||
def test_google_llm_init_missing_api_key(monkeypatch):
    """Constructing GoogleLlm without GOOGLE_API_KEY must raise ValueError."""
    monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
    expected_message = "Please set the GOOGLE_API_KEY environment variable."
    with pytest.raises(ValueError, match=expected_message):
        GoogleLlm()
|
||||
|
||||
|
||||
def test_google_llm_init(monkeypatch):
    """With the API key set and the dependency import stubbed, init succeeds."""
    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
    with monkeypatch.context() as ctx:
        # Stub out the dependency check so the test does not need
        # google-generativeai installed.
        ctx.setattr("importlib.import_module", lambda _name: None)
        assert GoogleLlm() is not None
|
||||
|
||||
|
||||
def test_google_llm_get_llm_model_answer_with_system_prompt(monkeypatch):
    """GoogleLlm must reject configurations that set `system_prompt`."""
    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
    monkeypatch.setattr("importlib.import_module", lambda _name: None)
    llm = GoogleLlm(config=BaseLlmConfig(system_prompt="system prompt"))
    with pytest.raises(ValueError, match="GoogleLlm does not support `system_prompt`"):
        llm.get_llm_model_answer("test prompt")
|
||||
|
||||
|
||||
def test_google_llm_get_llm_model_answer(monkeypatch, google_llm_config):
    """get_llm_model_answer delegates to _get_answer and returns its result."""
    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
    # Replace the static API call with a canned answer.
    monkeypatch.setattr(GoogleLlm, "_get_answer", lambda prompt, config: "Generated Text")
    llm = GoogleLlm(config=google_llm_config)

    assert llm.get_llm_model_answer("test prompt") == "Generated Text"
|
||||
Reference in New Issue
Block a user