[Feature] Add support for Google Gemini (#1009)

Deven Patel
2023-12-15 06:10:55 +05:30
committed by GitHub
parent c0ee680546
commit 151746beec
9 changed files with 211 additions and 14 deletions

configs/google.yaml Normal file

@@ -0,0 +1,8 @@
llm:
  provider: google
  config:
    model: gemini-pro
    max_tokens: 1000
    temperature: 0.9
    top_p: 1.0
    stream: false
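For reference, a minimal sketch of loading this new config file (the `App.from_config` entry point is taken from the docs change below; the `GOOGLE_API_KEY` value is a placeholder):

```python
import os

from embedchain import Pipeline as App

os.environ["GOOGLE_API_KEY"] = "xxx"  # placeholder; GoogleLlm raises ValueError if unset

# Wires up the "google" provider with the gemini-pro settings above.
app = App.from_config(config_path="configs/google.yaml")
```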


@@ -8,6 +8,7 @@ Embedchain comes with built-in support for various popular large language models
<CardGroup cols={4}>
<Card title="OpenAI" href="#openai"></Card>
<Card title="Google AI" href="#google-ai"></Card>
<Card title="Azure OpenAI" href="#azure-openai"></Card>
<Card title="Anthropic" href="#anthropic"></Card>
<Card title="Cohere" href="#cohere"></Card>
@@ -62,6 +63,41 @@ llm:
</CodeGroup>
## Google AI
To use a Google AI model, you must set the `GOOGLE_API_KEY` environment variable. You can obtain the Google API key from [Google Maker Suite](https://makersuite.google.com/app/apikey).
<CodeGroup>
```python main.py
import os
from embedchain import Pipeline as App
os.environ["OPENAI_API_KEY"] = "sk-xxxx"
os.environ["GOOGLE_API_KEY"] = "xxx"
app = App.from_config(config_path="config.yaml")
app.add("https://www.forbes.com/profile/elon-musk")
response = app.query("What is the net worth of Elon Musk?")
if app.llm.config.stream: # if stream is enabled, response is a generator
for chunk in response:
print(chunk)
else:
print(response)
```
```yaml config.yaml
llm:
  provider: google
  config:
    model: gemini-pro
    max_tokens: 1000
    temperature: 0.5
    top_p: 1
    stream: false
```
</CodeGroup>
## Azure OpenAI


@@ -18,6 +18,7 @@ class LlmFactory:
        "llama2": "embedchain.llm.llama2.Llama2Llm",
        "openai": "embedchain.llm.openai.OpenAILlm",
        "vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
        "google": "embedchain.llm.google.GoogleLlm",
    }
    provider_to_config_class = {
        "embedchain": "embedchain.config.llm.base.BaseLlmConfig",


@@ -217,7 +217,6 @@ class BaseLlm(JSONSerializable):
            return prompt
        answer = self.get_answer_from_llm(prompt)
        if isinstance(answer, str):
            logging.info(f"Answer: {answer}")
            return answer

embedchain/llm/google.py Normal file

@@ -0,0 +1,64 @@
import importlib
import logging
import os
from typing import Optional

import google.generativeai as genai

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class GoogleLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        if "GOOGLE_API_KEY" not in os.environ:
            raise ValueError("Please set the GOOGLE_API_KEY environment variable.")

        try:
            importlib.import_module("google.generativeai")
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for GoogleLlm are not installed. "
                'Please install with `pip install --upgrade "embedchain[google]"`'
            ) from None

        super().__init__(config)
        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

    def get_llm_model_answer(self, prompt):
        if self.config.system_prompt:
            raise ValueError("GoogleLlm does not support `system_prompt`")
        return GoogleLlm._get_answer(prompt, self.config)

    @staticmethod
    def _get_answer(prompt: str, config: BaseLlmConfig):
        model_name = config.model or "gemini-pro"
        logging.info(f"Using Google LLM model: {model_name}")
        model = genai.GenerativeModel(model_name=model_name)

        generation_config_params = {
            "candidate_count": 1,
            "max_output_tokens": config.max_tokens,
            # Use `is None` so an explicit temperature of 0.0 is preserved.
            "temperature": 0.5 if config.temperature is None else config.temperature,
        }
        if 0.0 <= config.top_p <= 1.0:
            generation_config_params["top_p"] = config.top_p
        else:
            raise ValueError("`top_p` must be between 0.0 and 1.0")

        generation_config = genai.types.GenerationConfig(**generation_config_params)
        response = model.generate_content(
            prompt,
            generation_config=generation_config,
            stream=config.stream,
        )
        if config.stream:
            # A generator expression keeps this a plain function; mixing a bare
            # `yield` with `return response.text` would turn the whole method
            # into a generator and make the non-streaming branch return nothing.
            return (chunk.text for chunk in response)
        return response.text
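A hedged usage sketch of the new class on its own (assumes the `google` extra is installed; mirrors the stream handling from the docs example above):

```python
import os

from embedchain.config import BaseLlmConfig
from embedchain.llm.google import GoogleLlm

os.environ["GOOGLE_API_KEY"] = "xxx"  # must be set before construction

config = BaseLlmConfig(model="gemini-pro", max_tokens=1000, temperature=0.5, top_p=1.0, stream=False)
llm = GoogleLlm(config=config)

answer = llm.get_llm_model_answer("What is the capital of France?")
print(answer)  # with stream=True, this would instead be a generator of text chunks
```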


@@ -387,6 +387,7 @@ def validate_config(config_data):
                    "jina",
                    "llama2",
                    "vertexai",
                    "google",
                ),
                Optional("config"): {
                    Optional("model"): str,

poetry.lock generated

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
[[package]]
name = "aiofiles"
@@ -1762,6 +1762,22 @@ gitdb = ">=4.0.1,<5"
[package.extras]
test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-instafail", "pytest-subtests", "pytest-sugar"]
[[package]]
name = "google-ai-generativelanguage"
version = "0.4.0"
description = "Google Ai Generativelanguage API client library"
optional = true
python-versions = ">=3.7"
files = [
{file = "google-ai-generativelanguage-0.4.0.tar.gz", hash = "sha256:c8199066c08f74c4e91290778329bb9f357ba1ea5d6f82de2bc0d10552bf4f8c"},
{file = "google_ai_generativelanguage-0.4.0-py3-none-any.whl", hash = "sha256:e4c425376c1ee26c78acbc49a24f735f90ebfa81bf1a06495fae509a2433232c"},
]
[package.dependencies]
google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
proto-plus = ">=1.22.3,<2.0.0dev"
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
[[package]]
name = "google-api-core"
version = "2.12.0"
@@ -1777,12 +1793,12 @@ files = [
google-auth = ">=2.14.1,<3.0.dev0"
googleapis-common-protos = ">=1.56.2,<2.0.dev0"
grpcio = [
{version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
{version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
{version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
]
grpcio-status = [
{version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""},
{version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
{version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
]
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
requests = ">=2.18.0,<3.0.0.dev0"
@@ -1870,8 +1886,8 @@ google-api-core = {version = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0dev", extras =
google-cloud-core = ">=1.6.0,<3.0.0dev"
google-resumable-media = ">=0.6.0,<3.0dev"
grpcio = [
{version = ">=1.49.1,<2.0dev", markers = "python_version >= \"3.11\""},
{version = ">=1.47.0,<2.0dev", markers = "python_version < \"3.11\""},
{version = ">=1.49.1,<2.0dev", markers = "python_version >= \"3.11\""},
]
packaging = ">=20.0.0"
proto-plus = ">=1.15.0,<2.0.0dev"
@@ -1922,8 +1938,8 @@ files = [
google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev"
proto-plus = [
{version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
{version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
{version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
]
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
@@ -2029,6 +2045,26 @@ files = [
[package.extras]
testing = ["pytest"]
[[package]]
name = "google-generativeai"
version = "0.3.1"
description = "Google Generative AI High level API client library and tools."
optional = true
python-versions = ">=3.9"
files = [
{file = "google_generativeai-0.3.1-py3-none-any.whl", hash = "sha256:800ec6041ca537b897d7ba654f4125651c64b38506f2bfce3b464370e3333a1b"},
]
[package.dependencies]
google-ai-generativelanguage = "0.4.0"
google-api-core = "*"
google-auth = "*"
protobuf = "*"
tqdm = "*"
[package.extras]
dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]
[[package]]
name = "google-resumable-media"
version = "2.6.0"
@@ -4004,11 +4040,13 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.21.0", markers = "python_version <= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
{version = ">=1.21.2", markers = "python_version >= \"3.10\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""},
{version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""},
{version = ">=1.17.0", markers = "python_version >= \"3.7\""},
{version = ">=1.17.3", markers = "python_version >= \"3.8\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
{version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
]
[[package]]
@@ -4202,8 +4240,8 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@@ -6344,7 +6382,7 @@ files = [
]
[package.dependencies]
greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""}
typing-extensions = ">=4.2.0"
[package.extras]
@@ -7820,6 +7858,7 @@ discord = ["discord"]
elasticsearch = ["elasticsearch"]
github = ["PyGithub", "gitpython"]
gmail = ["llama-hub", "requests"]
google = ["google-generativeai"]
huggingface-hub = ["huggingface_hub"]
images = ["ftfy", "pillow", "regex", "torch", "torchvision"]
json = ["llama-hub"]
@@ -7843,4 +7882,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.12"
content-hash = "776ae7f49adab8a5dc98f6fe7c2887d2e700fd2d7c447383ea81ef05a463c8f3"
content-hash = "846cca158ccd7a2ecc3d0a08218d273846aad15e0ac5c19ddeb1b5de00fa9a3f"


@@ -143,6 +143,7 @@ PyGithub = { version = "^1.59.1", optional = true }
feedparser = { version = "^6.0.10", optional = true }
newspaper3k = { version = "^0.2.8", optional = true }
listparser = { version = "^0.19", optional = true }
google-generativeai = { version = "^0.3.0", optional = true }
[tool.poetry.group.dev.dependencies]
black = "^23.3.0"
@@ -204,7 +205,12 @@ youtube = [
"yt_dlp",
"youtube-transcript-api",
]
rss_feed = ["feedparser", "listparser", "newspaper3k"]
rss_feed = [
"feedparser",
"listparser",
"newspaper3k"
]
google = ["google-generativeai"]
[tool.poetry.group.docs.dependencies]

tests/llm/test_google.py Normal file

@@ -0,0 +1,43 @@
import pytest

from embedchain.config import BaseLlmConfig
from embedchain.llm.google import GoogleLlm


@pytest.fixture
def google_llm_config():
    return BaseLlmConfig(model="gemini-pro", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)


def test_google_llm_init_missing_api_key(monkeypatch):
    monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
    with pytest.raises(ValueError, match="Please set the GOOGLE_API_KEY environment variable."):
        GoogleLlm()


def test_google_llm_init(monkeypatch):
    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
    with monkeypatch.context() as m:
        m.setattr("importlib.import_module", lambda x: None)
        google_llm = GoogleLlm()
    assert google_llm is not None


def test_google_llm_get_llm_model_answer_with_system_prompt(monkeypatch):
    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
    monkeypatch.setattr("importlib.import_module", lambda x: None)
    google_llm = GoogleLlm(config=BaseLlmConfig(system_prompt="system prompt"))
    with pytest.raises(ValueError, match="GoogleLlm does not support `system_prompt`"):
        google_llm.get_llm_model_answer("test prompt")


def test_google_llm_get_llm_model_answer(monkeypatch, google_llm_config):
    def mock_get_answer(prompt, config):
        return "Generated Text"

    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
    monkeypatch.setattr(GoogleLlm, "_get_answer", mock_get_answer)
    google_llm = GoogleLlm(config=google_llm_config)
    result = google_llm.get_llm_model_answer("test prompt")

    assert result == "Generated Text"
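The streaming path is not covered here; a hedged sketch of such a test, following the same mocking pattern (the chunked generator stands in for the behavior of the fixed `_get_answer` above):

```python
def test_google_llm_get_llm_model_answer_streaming(monkeypatch):
    def mock_get_answer(prompt, config):
        # Stand-in for the streaming branch, which returns a generator of chunks.
        return (chunk for chunk in ["Gener", "ated ", "Text"])

    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
    monkeypatch.setattr(GoogleLlm, "_get_answer", mock_get_answer)
    google_llm = GoogleLlm(config=BaseLlmConfig(model="gemini-pro", stream=True))

    result = google_llm.get_llm_model_answer("test prompt")

    assert "".join(result) == "Generated Text"
```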