diff --git a/configs/google.yaml b/configs/google.yaml
index e3d1a72a..4f6a4655 100644
--- a/configs/google.yaml
+++ b/configs/google.yaml
@@ -6,3 +6,8 @@ llm:
     temperature: 0.9
     top_p: 1.0
     stream: false
+
+embedder:
+  provider: google
+  config:
+    model: models/embedding-001
diff --git a/docs/components/embedding-models.mdx b/docs/components/embedding-models.mdx
index d15851d9..7e460378 100644
--- a/docs/components/embedding-models.mdx
+++ b/docs/components/embedding-models.mdx
@@ -8,6 +8,7 @@ Embedchain supports several embedding models from the following providers:
+
@@ -44,6 +45,34 @@ embedder:
+
+## Google AI
+
+To use the Google AI embedding function, you have to set the `GOOGLE_API_KEY` environment variable. You can obtain the Google API key from [Google MakerSuite](https://makersuite.google.com/app/apikey).
+
+```python main.py
+import os
+from embedchain import Pipeline as App
+
+os.environ["GOOGLE_API_KEY"] = "xxx"
+
+app = App.from_config(config_path="config.yaml")
+```
+
+```yaml config.yaml
+embedder:
+  provider: google
+  config:
+    model: 'models/embedding-001'
+    task_type: "retrieval_document"
+    title: "Embeddings for Embedchain"
+```
+
+
+For more details regarding the Google AI embedding model, please refer to the [Google AI documentation](https://ai.google.dev/tutorials/python_quickstart#use_embeddings).
+
+
 ## Azure OpenAI
 
 To use Azure OpenAI embedding model, you have to set some of the azure openai related environment variables as given in the code block below:
diff --git a/docs/components/llms.mdx b/docs/components/llms.mdx
index c042ca3f..802ba270 100644
--- a/docs/components/llms.mdx
+++ b/docs/components/llms.mdx
@@ -72,7 +72,6 @@ To use Google AI model, you have to set the `GOOGLE_API_KEY` environment variabl
 import os
 from embedchain import Pipeline as App
 
-os.environ["OPENAI_API_KEY"] = "sk-xxxx"
 os.environ["GOOGLE_API_KEY"] = "xxx"
 
 app = App.from_config(config_path="config.yaml")
@@ -96,6 +95,13 @@ llm:
     temperature: 0.5
     top_p: 1
     stream: false
+
+embedder:
+  provider: google
+  config:
+    model: 'models/embedding-001'
+    task_type: "retrieval_document"
+    title: "Embeddings for Embedchain"
 ```
diff --git a/embedchain/config/embedder/google.py b/embedchain/config/embedder/google.py
new file mode 100644
index 00000000..f42e5b53
--- /dev/null
+++ b/embedchain/config/embedder/google.py
@@ -0,0 +1,18 @@
+from typing import Optional
+
+from embedchain.config.embedder.base import BaseEmbedderConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class GoogleAIEmbedderConfig(BaseEmbedderConfig):
+    def __init__(
+        self,
+        model: Optional[str] = None,
+        deployment_name: Optional[str] = None,
+        task_type: Optional[str] = None,
+        title: Optional[str] = None,
+    ):
+        super().__init__(model, deployment_name)
+        self.task_type = task_type or "retrieval_document"
+        self.title = title or "Embeddings for Embedchain"
diff --git a/embedchain/embedder/google.py b/embedchain/embedder/google.py
new file mode 100644
index 00000000..28be7a84
--- /dev/null
+++ b/embedchain/embedder/google.py
@@ -0,0 +1,31 @@
+from typing import Optional
+
+import google.generativeai as genai
+from chromadb import EmbeddingFunction, Embeddings
+
+from embedchain.config.embedder.google import GoogleAIEmbedderConfig
+from embedchain.embedder.base import BaseEmbedder
+from embedchain.models import VectorDimensions
+
+
+class GoogleAIEmbeddingFunction(EmbeddingFunction):
+    def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None) -> None:
+        super().__init__()
+        self.config = config or GoogleAIEmbedderConfig()
+
+    def __call__(self, input: str) -> Embeddings:
+        model = self.config.model
+        title = self.config.title
+        task_type = self.config.task_type
+        embeddings = genai.embed_content(model=model, content=input, task_type=task_type, title=title)
+        return embeddings["embedding"]
+
+
+class GoogleAIEmbedder(BaseEmbedder):
+    def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None):
+        super().__init__(config)
+        embedding_fn = GoogleAIEmbeddingFunction(config=config)
+        self.set_embedding_fn(embedding_fn=embedding_fn)
+
+        vector_dimension = VectorDimensions.GOOGLE_AI.value
+        self.set_vector_dimension(vector_dimension=vector_dimension)
diff --git a/embedchain/factory.py b/embedchain/factory.py
index 2731cf2f..6ec1479b 100644
--- a/embedchain/factory.py
+++ b/embedchain/factory.py
@@ -47,11 +47,13 @@ class EmbedderFactory:
         "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
         "openai": "embedchain.embedder.openai.OpenAIEmbedder",
         "vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
+        "google": "embedchain.embedder.google.GoogleAIEmbedder",
     }
     provider_to_config_class = {
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig", "openai": "embedchain.config.embedder.base.BaseEmbedderConfig", "gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig", + "google": "embedchain.config.embedder.google.GoogleAIEmbedderConfig", } @classmethod diff --git a/embedchain/llm/base.py b/embedchain/llm/base.py index 523bbf82..d01c7b64 100644 --- a/embedchain/llm/base.py +++ b/embedchain/llm/base.py @@ -146,21 +146,7 @@ class BaseLlm(JSONSerializable): logging.info(f"Access search to get answers for {input_query}") return search.run(input_query) - def _stream_query_response(self, answer: Any) -> Generator[Any, Any, None]: - """Generator to be used as streaming response - - :param answer: Answer chunk from llm - :type answer: Any - :yield: Answer chunk from llm - :rtype: Generator[Any, Any, None] - """ - streamed_answer = "" - for chunk in answer: - streamed_answer = streamed_answer + chunk - yield chunk - logging.info(f"Answer: {streamed_answer}") - - def _stream_chat_response(self, answer: Any) -> Generator[Any, Any, None]: + def _stream_response(self, answer: Any) -> Generator[Any, Any, None]: """Generator to be used as streaming response :param answer: Answer chunk from llm @@ -220,7 +206,7 @@ class BaseLlm(JSONSerializable): logging.info(f"Answer: {answer}") return answer else: - return self._stream_query_response(answer) + return self._stream_response(answer) finally: if config: # Restore previous config @@ -269,14 +255,12 @@ class BaseLlm(JSONSerializable): return prompt answer = self.get_answer_from_llm(prompt) - if isinstance(answer, str): logging.info(f"Answer: {answer}") - return answer else: # this is a streamed response and needs to be handled differently. - return self._stream_chat_response(answer) + return self._stream_response(answer) finally: if config: # Restore previous config diff --git a/embedchain/llm/google.py b/embedchain/llm/google.py index f54826f2..d1bf3d8e 100644 --- a/embedchain/llm/google.py +++ b/embedchain/llm/google.py @@ -1,7 +1,7 @@ import importlib import logging import os -from typing import Optional +from typing import Any, Generator, Optional, Union import google.generativeai as genai @@ -30,22 +30,22 @@ class GoogleLlm(BaseLlm): def get_llm_model_answer(self, prompt): if self.config.system_prompt: raise ValueError("GoogleLlm does not support `system_prompt`") - return GoogleLlm._get_answer(prompt, self.config) + response = self._get_answer(prompt) + return response - @staticmethod - def _get_answer(prompt: str, config: BaseLlmConfig): - model_name = config.model or "gemini-pro" + def _get_answer(self, prompt: str) -> Union[str, Generator[Any, Any, None]]: + model_name = self.config.model or "gemini-pro" logging.info(f"Using Google LLM model: {model_name}") model = genai.GenerativeModel(model_name=model_name) generation_config_params = { "candidate_count": 1, - "max_output_tokens": config.max_tokens, - "temperature": config.temperature or 0.5, + "max_output_tokens": self.config.max_tokens, + "temperature": self.config.temperature or 0.5, } - if config.top_p >= 0.0 and config.top_p <= 1.0: - generation_config_params["top_p"] = config.top_p + if self.config.top_p >= 0.0 and self.config.top_p <= 1.0: + generation_config_params["top_p"] = self.config.top_p else: raise ValueError("`top_p` must be > 0.0 and < 1.0") @@ -54,11 +54,11 @@ class GoogleLlm(BaseLlm): response = model.generate_content( prompt, generation_config=generation_config, - stream=config.stream, + stream=self.config.stream, ) - - if config.stream: - for 
-                yield chunk.text
+        if self.config.stream:
+            # TODO: Implement streaming
+            response.resolve()
+            return response.text
         else:
             return response.text
diff --git a/embedchain/models/vector_dimensions.py b/embedchain/models/vector_dimensions.py
index 9be1f304..2bdaa0fa 100644
--- a/embedchain/models/vector_dimensions.py
+++ b/embedchain/models/vector_dimensions.py
@@ -7,3 +7,4 @@ class VectorDimensions(Enum):
     OPENAI = 1536
     VERTEX_AI = 768
     HUGGING_FACE = 384
+    GOOGLE_AI = 768
diff --git a/embedchain/utils.py b/embedchain/utils.py
index 4db82d60..ddb7d2d4 100644
--- a/embedchain/utils.py
+++ b/embedchain/utils.py
@@ -411,14 +411,14 @@ def validate_config(config_data):
                 Optional("config"): object,  # TODO: add particular config schema for each provider
             },
             Optional("embedder"): {
-                Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai"),
+                Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
                 Optional("config"): {
                     Optional("model"): Optional(str),
                     Optional("deployment_name"): Optional(str),
                 },
             },
             Optional("embedding_model"): {
-                Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai"),
+                Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
                 Optional("config"): {
                     Optional("model"): str,
                     Optional("deployment_name"): str,
diff --git a/pyproject.toml b/pyproject.toml
index 9ffabcd0..be0f3bb5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.33"
+version = "0.1.34"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh ",
diff --git a/tests/llm/test_base_llm.py b/tests/llm/test_base_llm.py
index ddbc4747..e2e56bdc 100644
--- a/tests/llm/test_base_llm.py
+++ b/tests/llm/test_base_llm.py
@@ -38,15 +38,9 @@ def test_is_get_llm_model_answer_implemented():
     assert llm.get_llm_model_answer() == "Implemented"
 
 
-def test_stream_query_response(base_llm):
+def test_stream_response(base_llm):
     answer = ["Chunk1", "Chunk2", "Chunk3"]
-    result = list(base_llm._stream_query_response(answer))
-    assert result == answer
-
-
-def test_stream_chat_response(base_llm):
-    answer = ["Chunk1", "Chunk2", "Chunk3"]
-    result = list(base_llm._stream_chat_response(answer))
+    result = list(base_llm._stream_response(answer))
     assert result == answer
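
Taken together, the config, factory, schema, and docs changes above mean the Google embedder (and the Google LLM) can be selected purely from YAML. Below is a minimal end-to-end sketch assembled from the snippets in this diff; the API key value, the source URL, and the query string are placeholders, and the LLM settings simply mirror the docs rather than anything required by the code.

```python
import os

from embedchain import Pipeline as App

os.environ["GOOGLE_API_KEY"] = "xxx"  # placeholder; create a real key in Google MakerSuite

# Write the YAML shown in docs/components/llms.mdx to disk, then load it.
CONFIG_YAML = """\
llm:
  provider: google
  config:
    model: gemini-pro
    temperature: 0.5
    top_p: 1
    stream: false

embedder:
  provider: google
  config:
    model: 'models/embedding-001'
    task_type: "retrieval_document"
    title: "Embeddings for Embedchain"
"""

with open("config.yaml", "w") as f:
    f.write(CONFIG_YAML)

# EmbedderFactory now resolves provider "google" to embedchain.embedder.google.GoogleAIEmbedder.
app = App.from_config(config_path="config.yaml")
app.add("https://en.wikipedia.org/wiki/Word_embedding")  # any supported data source; URL is arbitrary
print(app.query("What is this page about?"))
```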
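
For unit-level experimentation, the new embedding function can also be exercised directly, bypassing the factory. This is only a sketch under the assumption that `GOOGLE_API_KEY` is set and `google-generativeai` is installed; the sample sentence is arbitrary, and the single-string call simply follows the `__call__` signature added in this diff.

```python
import os

import google.generativeai as genai

from embedchain.config.embedder.google import GoogleAIEmbedderConfig
from embedchain.embedder.google import GoogleAIEmbeddingFunction

# Configure the client explicitly with the key from the environment.
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

config = GoogleAIEmbedderConfig(
    model="models/embedding-001",
    task_type="retrieval_document",
    title="Embeddings for Embedchain",
)
embed_fn = GoogleAIEmbeddingFunction(config=config)

vector = embed_fn("Embedchain is a data platform for LLMs.")
# GoogleAIEmbedder registers VectorDimensions.GOOGLE_AI (768) with the vector store,
# so the returned embedding is expected to have 768 components.
print(len(vector))
```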