diff --git a/configs/google.yaml b/configs/google.yaml
new file mode 100644
index 00000000..e3d1a72a
--- /dev/null
+++ b/configs/google.yaml
@@ -0,0 +1,8 @@
+llm:
+  provider: google
+  config:
+    model: gemini-pro
+    max_tokens: 1000
+    temperature: 0.9
+    top_p: 1.0
+    stream: false
diff --git a/docs/components/llms.mdx b/docs/components/llms.mdx
index 84a5dada..c042ca3f 100644
--- a/docs/components/llms.mdx
+++ b/docs/components/llms.mdx
@@ -8,6 +8,7 @@ Embedchain comes with built-in support for various popular large language models
+
@@ -62,6 +63,41 @@ llm:
+## Google AI
+
+To use a Google AI model, set the `GOOGLE_API_KEY` environment variable. You can obtain the API key from [Google MakerSuite](https://makersuite.google.com/app/apikey).
+
+```python main.py
+import os
+from embedchain import Pipeline as App
+
+os.environ["OPENAI_API_KEY"] = "sk-xxxx"
+os.environ["GOOGLE_API_KEY"] = "xxx"
+
+app = App.from_config(config_path="config.yaml")
+
+app.add("https://www.forbes.com/profile/elon-musk")
+
+response = app.query("What is the net worth of Elon Musk?")
+if app.llm.config.stream:  # if stream is enabled, response is a generator
+    for chunk in response:
+        print(chunk)
+else:
+    print(response)
+```
+
+```yaml config.yaml
+llm:
+  provider: google
+  config:
+    model: gemini-pro
+    max_tokens: 1000
+    temperature: 0.5
+    top_p: 1
+    stream: false
+```
+
 ## Azure OpenAI
diff --git a/embedchain/factory.py b/embedchain/factory.py
index 8e8d7b50..2731cf2f 100644
--- a/embedchain/factory.py
+++ b/embedchain/factory.py
@@ -18,6 +18,7 @@ class LlmFactory:
         "llama2": "embedchain.llm.llama2.Llama2Llm",
         "openai": "embedchain.llm.openai.OpenAILlm",
         "vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
+        "google": "embedchain.llm.google.GoogleLlm",
     }
     provider_to_config_class = {
         "embedchain": "embedchain.config.llm.base.BaseLlmConfig",
diff --git a/embedchain/llm/base.py b/embedchain/llm/base.py
index dc2471e4..b5daec30 100644
--- a/embedchain/llm/base.py
+++ b/embedchain/llm/base.py
@@ -217,7 +217,6 @@ class BaseLlm(JSONSerializable):
             return prompt
 
         answer = self.get_answer_from_llm(prompt)
-
         if isinstance(answer, str):
             logging.info(f"Answer: {answer}")
             return answer
diff --git a/embedchain/llm/google.py b/embedchain/llm/google.py
new file mode 100644
index 00000000..f54826f2
--- /dev/null
+++ b/embedchain/llm/google.py
@@ -0,0 +1,65 @@
+import importlib
+import logging
+import os
+from typing import Optional
+
+import google.generativeai as genai
+
+from embedchain.config import BaseLlmConfig
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
+
+
+@register_deserializable
+class GoogleLlm(BaseLlm):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        if "GOOGLE_API_KEY" not in os.environ:
+            raise ValueError("Please set the GOOGLE_API_KEY environment variable.")
+
+        try:
+            importlib.import_module("google.generativeai")
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "The required dependencies for GoogleLlm are not installed. "
+                'Please install with `pip install --upgrade "embedchain[google]"`'
+            ) from None
+
+        super().__init__(config)
+        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
+
+    def get_llm_model_answer(self, prompt):
+        if self.config.system_prompt:
+            raise ValueError("GoogleLlm does not support `system_prompt`")
+        return GoogleLlm._get_answer(prompt, self.config)
+
+    @staticmethod
+    def _get_answer(prompt: str, config: BaseLlmConfig):
+        model_name = config.model or "gemini-pro"
+        logging.info(f"Using Google LLM model: {model_name}")
+        model = genai.GenerativeModel(model_name=model_name)
+
+        generation_config_params = {
+            "candidate_count": 1,
+            "max_output_tokens": config.max_tokens,
+            "temperature": config.temperature or 0.5,
+        }
+
+        if 0.0 <= config.top_p <= 1.0:
+            generation_config_params["top_p"] = config.top_p
+        else:
+            raise ValueError("`top_p` must be between 0.0 and 1.0")
+
+        generation_config = genai.types.GenerationConfig(**generation_config_params)
+
+        response = model.generate_content(
+            prompt,
+            generation_config=generation_config,
+            stream=config.stream,
+        )
+
+        if config.stream:
+            # Return a generator expression instead of using `yield` here: a bare
+            # yield would turn this whole function into a generator and break the
+            # non-streaming branch, which must return a plain string.
+            return (chunk.text for chunk in response)
+        else:
+            return response.text
diff --git a/embedchain/utils.py b/embedchain/utils.py
index af8a57b8..7396fb5c 100644
--- a/embedchain/utils.py
+++ b/embedchain/utils.py
@@ -387,6 +387,7 @@ def validate_config(config_data):
                 "jina",
                 "llama2",
                 "vertexai",
+                "google",
             ),
             Optional("config"): {
                 Optional("model"): str,
diff --git a/poetry.lock b/poetry.lock
index 781c2ed4..d0bcd986 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
 
 [[package]]
 name = "aiofiles"
@@ -1762,6 +1762,22 @@ gitdb = ">=4.0.1,<5"
 
 [package.extras]
 test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-instafail", "pytest-subtests", "pytest-sugar"]
+
+[[package]]
+name = "google-ai-generativelanguage"
+version = "0.4.0"
+description = "Google Ai Generativelanguage API client library"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "google-ai-generativelanguage-0.4.0.tar.gz", hash = "sha256:c8199066c08f74c4e91290778329bb9f357ba1ea5d6f82de2bc0d10552bf4f8c"},
+    {file = "google_ai_generativelanguage-0.4.0-py3-none-any.whl", hash = "sha256:e4c425376c1ee26c78acbc49a24f735f90ebfa81bf1a06495fae509a2433232c"},
+]
+
+[package.dependencies]
+google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
+proto-plus = ">=1.22.3,<2.0.0dev"
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
+
 [[package]]
 name = "google-api-core"
 version = "2.12.0"
@@ -1777,12 +1793,12 @@ files = [
 google-auth = ">=2.14.1,<3.0.dev0"
 googleapis-common-protos = ">=1.56.2,<2.0.dev0"
 grpcio = [
+    {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
-    {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
 ]
 grpcio-status = [
+    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
-    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
 requests = ">=2.18.0,<3.0.0.dev0"
@@ -1870,8 +1886,8 @@ google-api-core = {version = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0dev", extras = 
 google-cloud-core = ">=1.6.0,<3.0.0dev"
 google-resumable-media = ">=0.6.0,<3.0dev"
 grpcio = [
-    {version = ">=1.49.1,<2.0dev", markers = "python_version >= \"3.11\""},
     {version = ">=1.47.0,<2.0dev", markers = "python_version < \"3.11\""},
+    {version = ">=1.49.1,<2.0dev", markers = "python_version >= \"3.11\""},
 ]
 packaging = ">=20.0.0"
 proto-plus = ">=1.15.0,<2.0.0dev"
@@ -1922,8 +1938,8 @@ files = [
 google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
 grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev"
 proto-plus = [
-    {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
     {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
+    {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
@@ -2029,6 +2045,26 @@ files = [
 
 [package.extras]
 testing = ["pytest"]
+
+[[package]]
+name = "google-generativeai"
+version = "0.3.1"
+description = "Google Generative AI High level API client library and tools."
+optional = true
+python-versions = ">=3.9"
+files = [
+    {file = "google_generativeai-0.3.1-py3-none-any.whl", hash = "sha256:800ec6041ca537b897d7ba654f4125651c64b38506f2bfce3b464370e3333a1b"},
+]
+
+[package.dependencies]
+google-ai-generativelanguage = "0.4.0"
+google-api-core = "*"
+google-auth = "*"
+protobuf = "*"
+tqdm = "*"
+
+[package.extras]
+dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]
+
 [[package]]
 name = "google-resumable-media"
 version = "2.6.0"
@@ -4004,11 +4040,13 @@ files = [
 
 [package.dependencies]
 numpy = [
+    {version = ">=1.21.0", markers = "python_version <= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
+    {version = ">=1.21.2", markers = "python_version >= \"3.10\""},
+    {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""},
+    {version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""},
+    {version = ">=1.17.0", markers = "python_version >= \"3.7\""},
+    {version = ">=1.17.3", markers = "python_version >= \"3.8\""},
     {version = ">=1.23.5", markers = "python_version >= \"3.11\""},
-    {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
-    {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
-    {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
-    {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
 ]
 
 [[package]]
@@ -4202,8 +4240,8 @@ files = [
 
 [package.dependencies]
 numpy = [
-    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
     {version = ">=1.22.4", markers = "python_version < \"3.11\""},
+    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
 ]
 python-dateutil = ">=2.8.2"
 pytz = ">=2020.1"
@@ -6344,7 +6382,7 @@ files = [
 ]
 
 [package.dependencies]
-greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
+greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""}
 typing-extensions = ">=4.2.0"
 
 [package.extras]
@@ -7820,6 +7858,7 @@ discord = ["discord"]
 elasticsearch = ["elasticsearch"]
 github = ["PyGithub", "gitpython"]
 gmail = ["llama-hub", "requests"]
+google = ["google-generativeai"]
 huggingface-hub = ["huggingface_hub"]
 images = ["ftfy", "pillow", "regex", "torch", "torchvision"]
 json = ["llama-hub"]
@@ -7843,4 +7882,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.12"
-content-hash = "776ae7f49adab8a5dc98f6fe7c2887d2e700fd2d7c447383ea81ef05a463c8f3"
+content-hash = "846cca158ccd7a2ecc3d0a08218d273846aad15e0ac5c19ddeb1b5de00fa9a3f"
diff --git a/pyproject.toml b/pyproject.toml
index b3f850af..9ffabcd0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -143,6 +143,7 @@ PyGithub = { version = "^1.59.1", optional = true }
 feedparser = { version = "^6.0.10", optional = true }
 newspaper3k = { version = "^0.2.8", optional = true }
 listparser = { version = "^0.19", optional = true }
+google-generativeai = { version = "^0.3.0", optional = true }
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"
@@ -204,7 +205,12 @@ youtube = [
     "yt_dlp",
     "youtube-transcript-api",
 ]
-rss_feed = ["feedparser", "listparser", "newspaper3k"]
+rss_feed = [
+    "feedparser",
+    "listparser",
+    "newspaper3k"
+]
+google = ["google-generativeai"]
 
 [tool.poetry.group.docs.dependencies]
diff --git a/tests/llm/test_google.py b/tests/llm/test_google.py
new file mode 100644
index 00000000..d2ba301e
--- /dev/null
+++ b/tests/llm/test_google.py
@@ -0,0 +1,43 @@
+import pytest
+
+from embedchain.config import BaseLlmConfig
+from embedchain.llm.google import GoogleLlm
+
+
+@pytest.fixture
+def google_llm_config():
+    return BaseLlmConfig(model="gemini-pro", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
+
+
+def test_google_llm_init_missing_api_key(monkeypatch):
+    monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
+    with pytest.raises(ValueError, match="Please set the GOOGLE_API_KEY environment variable."):
+        GoogleLlm()
+
+
+def test_google_llm_init(monkeypatch):
+    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
+    with monkeypatch.context() as m:
+        m.setattr("importlib.import_module", lambda x: None)
+        google_llm = GoogleLlm()
+    assert google_llm is not None
+
+
+def test_google_llm_get_llm_model_answer_with_system_prompt(monkeypatch):
+    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
+    monkeypatch.setattr("importlib.import_module", lambda x: None)
+    google_llm = GoogleLlm(config=BaseLlmConfig(system_prompt="system prompt"))
+    with pytest.raises(ValueError, match="GoogleLlm does not support `system_prompt`"):
+        google_llm.get_llm_model_answer("test prompt")
+
+
+def test_google_llm_get_llm_model_answer(monkeypatch, google_llm_config):
+    def mock_get_answer(prompt, config):
+        return "Generated Text"
+
+    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
+    monkeypatch.setattr(GoogleLlm, "_get_answer", mock_get_answer)
+    google_llm = GoogleLlm(config=google_llm_config)
+    result = google_llm.get_llm_model_answer("test prompt")
+
+    assert result == "Generated Text"