[Feature] add google ai embedder (#1019)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel authored on 2023-12-18 13:58:01 +05:30, committed by GitHub
parent 6983ebba49
commit c0b5e93967
12 changed files with 115 additions and 45 deletions

View File

@@ -6,3 +6,8 @@ llm:
    temperature: 0.9
    top_p: 1.0
    stream: false
embedder:
  provider: google
  config:
    model: models/embedding-001

View File

@@ -8,6 +8,7 @@ Embedchain supports several embedding models from the following providers:
<CardGroup cols={4}>
<Card title="OpenAI" href="#openai"></Card>
<Card title="GoogleAI" href="#google-ai"></Card>
<Card title="Azure OpenAI" href="#azure-openai"></Card>
<Card title="GPT4All" href="#gpt4all"></Card>
<Card title="Hugging Face" href="#hugging-face"></Card>
@@ -44,6 +45,34 @@ embedder:
</CodeGroup>
## Google AI
To use the Google AI embedding function, set the `GOOGLE_API_KEY` environment variable. You can obtain a Google API key from [Google Maker Suite](https://makersuite.google.com/app/apikey).
<CodeGroup>
```python main.py
import os
from embedchain import Pipeline as App
os.environ["GOOGLE_API_KEY"] = "xxx"
app = App.from_config(config_path="config.yaml")
```
```yaml config.yaml
embedder:
  provider: google
  config:
    model: 'models/embedding-001'
    task_type: "retrieval_document"
    title: "Embeddings for Embedchain"
```
</CodeGroup>
<br/>
<Note>
For more details regarding the Google AI embedding model, please refer to the [Google AI documentation](https://ai.google.dev/tutorials/python_quickstart#use_embeddings).
</Note>
## Azure OpenAI
To use the Azure OpenAI embedding model, set the Azure OpenAI-related environment variables shown in the code block below:

View File

@@ -72,7 +72,6 @@ To use Google AI model, you have to set the `GOOGLE_API_KEY` environment variabl
import os
from embedchain import Pipeline as App
os.environ["OPENAI_API_KEY"] = "sk-xxxx"
os.environ["GOOGLE_API_KEY"] = "xxx"
app = App.from_config(config_path="config.yaml")
@@ -96,6 +95,13 @@ llm:
    temperature: 0.5
    top_p: 1
    stream: false
embedder:
  provider: google
  config:
    model: 'models/embedding-001'
    task_type: "retrieval_document"
    title: "Embeddings for Embedchain"
```
</CodeGroup>

View File

@@ -0,0 +1,18 @@
from typing import Optional

from embedchain.config.embedder.base import BaseEmbedderConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class GoogleAIEmbedderConfig(BaseEmbedderConfig):
    def __init__(
        self,
        model: Optional[str] = None,
        deployment_name: Optional[str] = None,
        task_type: Optional[str] = None,
        title: Optional[str] = None,
    ):
        super().__init__(model, deployment_name)
        self.task_type = task_type or "retrieval_document"
        self.title = title or "Embeddings for Embedchain"
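
A quick usage sketch (not part of this commit) showing the fallback defaults in action:

```python
# Usage sketch: constructing the config directly; defaults mirror __init__.
from embedchain.config.embedder.google import GoogleAIEmbedderConfig

config = GoogleAIEmbedderConfig(model="models/embedding-001")
assert config.task_type == "retrieval_document"      # default fallback
assert config.title == "Embeddings for Embedchain"   # default fallback
```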

View File

@@ -0,0 +1,31 @@
from typing import Optional

import google.generativeai as genai
from chromadb import EmbeddingFunction, Embeddings

from embedchain.config.embedder.google import GoogleAIEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions


class GoogleAIEmbeddingFunction(EmbeddingFunction):
    def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None) -> None:
        super().__init__()
        self.config = config or GoogleAIEmbedderConfig()

    def __call__(self, input: str) -> Embeddings:
        model = self.config.model
        title = self.config.title
        task_type = self.config.task_type
        embeddings = genai.embed_content(model=model, content=input, task_type=task_type, title=title)
        return embeddings["embedding"]


class GoogleAIEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None):
        super().__init__(config)

        embedding_fn = GoogleAIEmbeddingFunction(config=config)
        self.set_embedding_fn(embedding_fn=embedding_fn)

        vector_dimension = VectorDimensions.GOOGLE_AI.value
        self.set_vector_dimension(vector_dimension=vector_dimension)
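
For context, a minimal sketch (not part of the commit) of exercising the new embedding function directly; it assumes `GOOGLE_API_KEY` is set and the `google-generativeai` package is installed:

```python
# Minimal sketch: call the Google embedding function on its own and check
# the vector size against the new enum value.
import os

import google.generativeai as genai

from embedchain.config.embedder.google import GoogleAIEmbedderConfig
from embedchain.embedder.google import GoogleAIEmbeddingFunction

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
fn = GoogleAIEmbeddingFunction(config=GoogleAIEmbedderConfig(model="models/embedding-001"))
vector = fn("Embedchain supports Google AI embeddings.")
print(len(vector))  # expected: 768, per VectorDimensions.GOOGLE_AI
```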

View File

@@ -47,11 +47,13 @@ class EmbedderFactory:
"huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder",
"openai": "embedchain.embedder.openai.OpenAIEmbedder",
"vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
"google": "embedchain.embedder.google.GoogleAIEmbedder",
}
provider_to_config_class = {
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
"google": "embedchain.config.embedder.google.GoogleAIEmbedderConfig",
}
@classmethod
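
The factory resolves these provider strings to classes by importing the dotted paths; a hedged sketch of that indirection (the `load_class` helper below is illustrative, not the factory's actual method):

```python
# Illustrative sketch of the string-to-class lookup the mappings enable.
import importlib

def load_class(dotted_path: str):
    module_path, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

EmbedderClass = load_class("embedchain.embedder.google.GoogleAIEmbedder")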

View File

@@ -146,21 +146,7 @@ class BaseLlm(JSONSerializable):
logging.info(f"Access search to get answers for {input_query}")
return search.run(input_query)
def _stream_query_response(self, answer: Any) -> Generator[Any, Any, None]:
"""Generator to be used as streaming response
:param answer: Answer chunk from llm
:type answer: Any
:yield: Answer chunk from llm
:rtype: Generator[Any, Any, None]
"""
streamed_answer = ""
for chunk in answer:
streamed_answer = streamed_answer + chunk
yield chunk
logging.info(f"Answer: {streamed_answer}")
def _stream_chat_response(self, answer: Any) -> Generator[Any, Any, None]:
def _stream_response(self, answer: Any) -> Generator[Any, Any, None]:
"""Generator to be used as streaming response
:param answer: Answer chunk from llm
@@ -220,7 +206,7 @@ class BaseLlm(JSONSerializable):
logging.info(f"Answer: {answer}")
return answer
else:
return self._stream_query_response(answer)
return self._stream_response(answer)
finally:
if config:
# Restore previous config
@@ -269,14 +255,12 @@ class BaseLlm(JSONSerializable):
return prompt
answer = self.get_answer_from_llm(prompt)
if isinstance(answer, str):
logging.info(f"Answer: {answer}")
return answer
else:
# this is a streamed response and needs to be handled differently.
return self._stream_chat_response(answer)
return self._stream_response(answer)
finally:
if config:
# Restore previous config
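
Both query and chat now funnel streamed chunks through the one generator; a standalone sketch of its contract, mirroring the body shown above:

```python
# Standalone sketch: re-yield each chunk while accumulating the full answer
# for a final log line, exactly as the consolidated generator does.
import logging
from typing import Any, Generator

def _stream_response(answer: Any) -> Generator[Any, Any, None]:
    streamed_answer = ""
    for chunk in answer:
        streamed_answer = streamed_answer + chunk
        yield chunk
    logging.info(f"Answer: {streamed_answer}")

print("".join(_stream_response(["Hello", ", ", "world"])))  # Hello, world
```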

View File

@@ -1,7 +1,7 @@
import importlib
import logging
import os
from typing import Optional
from typing import Any, Generator, Optional, Union

import google.generativeai as genai
@@ -30,22 +30,22 @@ class GoogleLlm(BaseLlm):
    def get_llm_model_answer(self, prompt):
        if self.config.system_prompt:
            raise ValueError("GoogleLlm does not support `system_prompt`")
        return GoogleLlm._get_answer(prompt, self.config)
        response = self._get_answer(prompt)
        return response

    @staticmethod
    def _get_answer(prompt: str, config: BaseLlmConfig):
        model_name = config.model or "gemini-pro"
    def _get_answer(self, prompt: str) -> Union[str, Generator[Any, Any, None]]:
        model_name = self.config.model or "gemini-pro"
        logging.info(f"Using Google LLM model: {model_name}")
        model = genai.GenerativeModel(model_name=model_name)

        generation_config_params = {
            "candidate_count": 1,
            "max_output_tokens": config.max_tokens,
            "temperature": config.temperature or 0.5,
            "max_output_tokens": self.config.max_tokens,
            "temperature": self.config.temperature or 0.5,
        }

        if config.top_p >= 0.0 and config.top_p <= 1.0:
            generation_config_params["top_p"] = config.top_p
        if self.config.top_p >= 0.0 and self.config.top_p <= 1.0:
            generation_config_params["top_p"] = self.config.top_p
        else:
            raise ValueError("`top_p` must be > 0.0 and < 1.0")
@@ -54,11 +54,11 @@ class GoogleLlm(BaseLlm):
        response = model.generate_content(
            prompt,
            generation_config=generation_config,
            stream=config.stream,
            stream=self.config.stream,
        )

        if config.stream:
            for chunk in response:
                yield chunk.text
        if self.config.stream:
            # TODO: Implement streaming
            response.resolve()
            return response.text
        else:
            return response.text
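
Until streaming is implemented, the `stream=True` branch drains the response with `resolve()`; a standalone sketch of that `google.generativeai` behavior (API key handling, model name, and prompt are illustrative):

```python
# Sketch of the non-streaming fallback: resolve() consumes the chunk
# iterator so response.text holds the complete answer.
import os

import google.generativeai as genai

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
model = genai.GenerativeModel(model_name="gemini-pro")
response = model.generate_content("Say hello.", stream=True)
response.resolve()    # drain all streamed chunks
print(response.text)  # full concatenated answer
```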

View File

@@ -7,3 +7,4 @@ class VectorDimensions(Enum):
    OPENAI = 1536
    VERTEX_AI = 768
    HUGGING_FACE = 384
    GOOGLE_AI = 768
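
A quick, hypothetical way to sanity-check the new constant (requires a configured `GOOGLE_API_KEY`):

```python
# Hypothetical check that models/embedding-001 really emits 768-dim vectors.
import os

import google.generativeai as genai

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
result = genai.embed_content(model="models/embedding-001", content="hello")
assert len(result["embedding"]) == 768  # matches VectorDimensions.GOOGLE_AI
```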

View File

@@ -411,14 +411,14 @@ def validate_config(config_data):
Optional("config"): object, # TODO: add particular config schema for each provider
},
Optional("embedder"): {
Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai"),
Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
Optional("config"): {
Optional("model"): Optional(str),
Optional("deployment_name"): Optional(str),
},
},
Optional("embedding_model"): {
Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai"),
Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai", "azure_openai", "google"),
Optional("config"): {
Optional("model"): str,
Optional("deployment_name"): str,

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "embedchain"
version = "0.1.33"
version = "0.1.34"
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
authors = [
"Taranjeet Singh <taranjeet@embedchain.ai>",

View File

@@ -38,15 +38,9 @@ def test_is_get_llm_model_answer_implemented():
    assert llm.get_llm_model_answer() == "Implemented"


def test_stream_query_response(base_llm):
def test_stream_response(base_llm):
    answer = ["Chunk1", "Chunk2", "Chunk3"]
    result = list(base_llm._stream_query_response(answer))
    assert result == answer


def test_stream_chat_response(base_llm):
    answer = ["Chunk1", "Chunk2", "Chunk3"]
    result = list(base_llm._stream_chat_response(answer))
    result = list(base_llm._stream_response(answer))
    assert result == answer