[Bugfix] fix google ai embedding function (#1195)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2024-01-19 21:35:44 +05:30
committed by GitHub
parent d79d30bf0c
commit 0b5b12575a

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union
import google.generativeai as genai
from chromadb import EmbeddingFunction, Embeddings
@@ -13,12 +13,19 @@ class GoogleAIEmbeddingFunction(EmbeddingFunction):
super().__init__()
self.config = config or GoogleAIEmbedderConfig()
def __call__(self, input_: str) -> Embeddings:
def __call__(self, input: Union[list[str], str]) -> Embeddings:
model = self.config.model
title = self.config.title
task_type = self.config.task_type
embeddings = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
return embeddings["embedding"]
if isinstance(input, str):
input_ = [input]
else:
input_ = input
data = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
embeddings = data["embedding"]
if isinstance(input_, str):
embeddings = [embeddings]
return embeddings
class GoogleAIEmbedder(BaseEmbedder):