[Feature] add google ai embedder (#1019)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel
2023-12-18 13:58:01 +05:30
committed by GitHub
parent 6983ebba49
commit c0b5e93967
12 changed files with 115 additions and 45 deletions

View File

@@ -6,3 +6,8 @@ llm:
    temperature: 0.9
    top_p: 1.0
    stream: false
embedder:
  provider: google
  config:
    model: models/embedding-001

View File

@@ -8,6 +8,7 @@ Embedchain supports several embedding models from the following providers:
<CardGroup cols={4}>
  <Card title="OpenAI" href="#openai"></Card>
  <Card title="GoogleAI" href="#google-ai"></Card>
  <Card title="Azure OpenAI" href="#azure-openai"></Card>
  <Card title="GPT4All" href="#gpt4all"></Card>
  <Card title="Hugging Face" href="#hugging-face"></Card>
@@ -44,6 +45,34 @@ embedder:
</CodeGroup>
## Google AI
To use the Google AI embedding model, set the `GOOGLE_API_KEY` environment variable. You can obtain the API key from [Google MakerSuite](https://makersuite.google.com/app/apikey).
<CodeGroup>
```python main.py
import os
from embedchain import Pipeline as App
os.environ["GOOGLE_API_KEY"] = "xxx"
app = App.from_config(config_path="config.yaml")
```
```yaml config.yaml
embedder:
  provider: google
  config:
    model: 'models/embedding-001'
    task_type: "retrieval_document"
    title: "Embeddings for Embedchain"
```
</CodeGroup>
<br/>
<Note>
For more details regarding the Google AI embedding model, please refer to the [Google AI documentation](https://ai.google.dev/tutorials/python_quickstart#use_embeddings).
</Note>
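A brief end-to-end sketch of using the embedder configured above (the data source and question are illustrative only, and assume the `config.yaml` shown here):

```python
import os

from embedchain import Pipeline as App

os.environ["GOOGLE_API_KEY"] = "xxx"

app = App.from_config(config_path="config.yaml")
# Hypothetical source and question, purely for illustration
app.add("https://en.wikipedia.org/wiki/Large_language_model")
print(app.query("What is a large language model?"))
```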
## Azure OpenAI
To use the Azure OpenAI embedding model, you have to set some of the Azure OpenAI-related environment variables, as shown in the code block below:

View File

@@ -72,7 +72,6 @@ To use Google AI model, you have to set the `GOOGLE_API_KEY` environment variabl
import os
from embedchain import Pipeline as App

-os.environ["OPENAI_API_KEY"] = "sk-xxxx"
os.environ["GOOGLE_API_KEY"] = "xxx"
app = App.from_config(config_path="config.yaml")
@@ -96,6 +95,13 @@ llm:
    temperature: 0.5
    top_p: 1
    stream: false
embedder:
  provider: google
  config:
    model: 'models/embedding-001'
    task_type: "retrieval_document"
    title: "Embeddings for Embedchain"
```
</CodeGroup>

View File

@@ -0,0 +1,18 @@
from typing import Optional

from embedchain.config.embedder.base import BaseEmbedderConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class GoogleAIEmbedderConfig(BaseEmbedderConfig):
    def __init__(
        self,
        model: Optional[str] = None,
        deployment_name: Optional[str] = None,
        task_type: Optional[str] = None,
        title: Optional[str] = None,
    ):
        super().__init__(model, deployment_name)
        self.task_type = task_type or "retrieval_document"
        self.title = title or "Embeddings for Embedchain"
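For reference, a minimal sketch of constructing this config directly; omitted arguments fall back to the defaults set above:

```python
from embedchain.config.embedder.google import GoogleAIEmbedderConfig

# task_type and title default to "retrieval_document" / "Embeddings for Embedchain"
config = GoogleAIEmbedderConfig(model="models/embedding-001", title="My docs")
```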

View File

@@ -0,0 +1,31 @@
from typing import Optional

import google.generativeai as genai
from chromadb import EmbeddingFunction, Embeddings

from embedchain.config.embedder.google import GoogleAIEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions


class GoogleAIEmbeddingFunction(EmbeddingFunction):
    def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None) -> None:
        super().__init__()
        self.config = config or GoogleAIEmbedderConfig()

    def __call__(self, input: str) -> Embeddings:
        model = self.config.model
        title = self.config.title
        task_type = self.config.task_type
        embeddings = genai.embed_content(model=model, content=input, task_type=task_type, title=title)
        return embeddings["embedding"]


class GoogleAIEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None):
        super().__init__(config)
        embedding_fn = GoogleAIEmbeddingFunction(config=config)
        self.set_embedding_fn(embedding_fn=embedding_fn)

        vector_dimension = VectorDimensions.GOOGLE_AI.value
        self.set_vector_dimension(vector_dimension=vector_dimension)
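A hedged sketch of calling the embedding function on its own (assumes a valid API key; the sample text is illustrative):

```python
import google.generativeai as genai

from embedchain.config.embedder.google import GoogleAIEmbedderConfig
from embedchain.embedder.google import GoogleAIEmbeddingFunction

genai.configure(api_key="xxx")  # or rely on the GOOGLE_API_KEY environment variable

embed_fn = GoogleAIEmbeddingFunction(config=GoogleAIEmbedderConfig())
vector = embed_fn("Embedchain is a data platform for LLMs")
print(len(vector))  # expected to match VectorDimensions.GOOGLE_AI, i.e. 768
```

Since `__call__` returns the raw `embedding` list from `genai.embed_content`, the 768-dimension value registered via `VectorDimensions.GOOGLE_AI` has to match the model's output size.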

View File

@@ -47,11 +47,13 @@ class EmbedderFactory:
"huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder", "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
"openai": "embedchain.embedder.openai.OpenAIEmbedder", "openai": "embedchain.embedder.openai.OpenAIEmbedder",
"vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder", "vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
"google": "embedchain.embedder.google.GoogleAIEmbedder",
} }
provider_to_config_class = { provider_to_config_class = {
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig", "azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig", "openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig", "gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
"google": "embedchain.config.embedder.google.GoogleAIEmbedderConfig",
} }
@classmethod @classmethod
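For context, a hedged sketch of how a provider key is typically resolved against string maps like these via `importlib` (the factory's actual classmethod is not shown in this diff and may differ):

```python
import importlib


def load_class(dotted_path: str):
    # "embedchain.embedder.google.GoogleAIEmbedder" -> the GoogleAIEmbedder class
    module_path, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)


# Illustrative resolution of the newly registered "google" provider
embedder_cls = load_class("embedchain.embedder.google.GoogleAIEmbedder")
config_cls = load_class("embedchain.config.embedder.google.GoogleAIEmbedderConfig")
embedder = embedder_cls(config=config_cls(model="models/embedding-001"))
```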

View File

@@ -146,21 +146,7 @@ class BaseLlm(JSONSerializable):
logging.info(f"Access search to get answers for {input_query}") logging.info(f"Access search to get answers for {input_query}")
return search.run(input_query) return search.run(input_query)
def _stream_query_response(self, answer: Any) -> Generator[Any, Any, None]: def _stream_response(self, answer: Any) -> Generator[Any, Any, None]:
"""Generator to be used as streaming response
:param answer: Answer chunk from llm
:type answer: Any
:yield: Answer chunk from llm
:rtype: Generator[Any, Any, None]
"""
streamed_answer = ""
for chunk in answer:
streamed_answer = streamed_answer + chunk
yield chunk
logging.info(f"Answer: {streamed_answer}")
def _stream_chat_response(self, answer: Any) -> Generator[Any, Any, None]:
"""Generator to be used as streaming response """Generator to be used as streaming response
:param answer: Answer chunk from llm :param answer: Answer chunk from llm
@@ -220,7 +206,7 @@ class BaseLlm(JSONSerializable):
logging.info(f"Answer: {answer}") logging.info(f"Answer: {answer}")
return answer return answer
else: else:
return self._stream_query_response(answer) return self._stream_response(answer)
finally: finally:
if config: if config:
# Restore previous config # Restore previous config
@@ -269,14 +255,12 @@ class BaseLlm(JSONSerializable):
                return prompt
            answer = self.get_answer_from_llm(prompt)
            if isinstance(answer, str):
                logging.info(f"Answer: {answer}")
                return answer
            else:
                # this is a streamed response and needs to be handled differently.
-                return self._stream_chat_response(answer)
+                return self._stream_response(answer)
        finally:
            if config:
                # Restore previous config
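The removed `_stream_query_response` body above and the consolidated `_stream_response` behave the same way (the updated test at the end of this commit checks exactly this); as a standalone sketch:

```python
import logging
from typing import Any, Generator


def _stream_response(answer: Any) -> Generator[Any, Any, None]:
    # Yield each chunk to the caller while accumulating the full answer for logging
    streamed_answer = ""
    for chunk in answer:
        streamed_answer = streamed_answer + chunk
        yield chunk
    logging.info(f"Answer: {streamed_answer}")


# e.g. list(_stream_response(["Chunk1", "Chunk2"])) == ["Chunk1", "Chunk2"]
```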

View File

@@ -1,7 +1,7 @@
import importlib
import logging
import os
-from typing import Optional
+from typing import Any, Generator, Optional, Union

import google.generativeai as genai
@@ -30,22 +30,22 @@ class GoogleLlm(BaseLlm):
    def get_llm_model_answer(self, prompt):
        if self.config.system_prompt:
            raise ValueError("GoogleLlm does not support `system_prompt`")
-        return GoogleLlm._get_answer(prompt, self.config)
+        response = self._get_answer(prompt)
+        return response

-    @staticmethod
-    def _get_answer(prompt: str, config: BaseLlmConfig):
-        model_name = config.model or "gemini-pro"
+    def _get_answer(self, prompt: str) -> Union[str, Generator[Any, Any, None]]:
+        model_name = self.config.model or "gemini-pro"
        logging.info(f"Using Google LLM model: {model_name}")
        model = genai.GenerativeModel(model_name=model_name)

        generation_config_params = {
            "candidate_count": 1,
-            "max_output_tokens": config.max_tokens,
-            "temperature": config.temperature or 0.5,
+            "max_output_tokens": self.config.max_tokens,
+            "temperature": self.config.temperature or 0.5,
        }

-        if config.top_p >= 0.0 and config.top_p <= 1.0:
-            generation_config_params["top_p"] = config.top_p
+        if self.config.top_p >= 0.0 and self.config.top_p <= 1.0:
+            generation_config_params["top_p"] = self.config.top_p
        else:
            raise ValueError("`top_p` must be > 0.0 and < 1.0")
@@ -54,11 +54,11 @@ class GoogleLlm(BaseLlm):
        response = model.generate_content(
            prompt,
            generation_config=generation_config,
-            stream=config.stream,
+            stream=self.config.stream,
        )

-        if config.stream:
-            for chunk in response:
-                yield chunk.text
+        if self.config.stream:
+            # TODO: Implement streaming
+            response.resolve()
+            return response.text
        else:
            return response.text
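With the streaming TODO in place, `stream: true` now resolves the response and returns plain text. The chunk-by-chunk behavior this replaces would look roughly like the sketch below (a hypothetical helper, not part of this commit):

```python
from typing import Any, Generator

import google.generativeai as genai


def stream_gemini(prompt: str, model_name: str = "gemini-pro") -> Generator[str, Any, None]:
    # Hypothetical: yield text chunks as they arrive instead of resolving the response
    model = genai.GenerativeModel(model_name=model_name)
    response = model.generate_content(prompt, stream=True)
    for chunk in response:
        yield chunk.text
```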

View File

@@ -7,3 +7,4 @@ class VectorDimensions(Enum):
    OPENAI = 1536
    VERTEX_AI = 768
    HUGGING_FACE = 384
    GOOGLE_AI = 768

View File

@@ -411,14 +411,14 @@ def validate_config(config_data):
Optional("config"): object, # TODO: add particular config schema for each provider Optional("config"): object, # TODO: add particular config schema for each provider
}, },
Optional("embedder"): { Optional("embedder"): {
Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai"), Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
Optional("config"): { Optional("config"): {
Optional("model"): Optional(str), Optional("model"): Optional(str),
Optional("deployment_name"): Optional(str), Optional("deployment_name"): Optional(str),
}, },
}, },
Optional("embedding_model"): { Optional("embedding_model"): {
Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai"), Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
Optional("config"): { Optional("config"): {
Optional("model"): str, Optional("model"): str,
Optional("deployment_name"): str, Optional("deployment_name"): str,

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "embedchain"
-version = "0.1.33"
+version = "0.1.34"
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
authors = [
    "Taranjeet Singh <taranjeet@embedchain.ai>",

View File

@@ -38,15 +38,9 @@ def test_is_get_llm_model_answer_implemented():
    assert llm.get_llm_model_answer() == "Implemented"


-def test_stream_query_response(base_llm):
+def test_stream_response(base_llm):
    answer = ["Chunk1", "Chunk2", "Chunk3"]
-    result = list(base_llm._stream_query_response(answer))
-    assert result == answer
-
-
-def test_stream_chat_response(base_llm):
-    answer = ["Chunk1", "Chunk2", "Chunk3"]
-    result = list(base_llm._stream_chat_response(answer))
+    result = list(base_llm._stream_response(answer))
    assert result == answer