diff --git a/embedchain/embedder/google.py b/embedchain/embedder/google.py index 4d09f6be..c0be8350 100644 --- a/embedchain/embedder/google.py +++ b/embedchain/embedder/google.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import google.generativeai as genai from chromadb import EmbeddingFunction, Embeddings @@ -13,12 +13,19 @@ class GoogleAIEmbeddingFunction(EmbeddingFunction): super().__init__() self.config = config or GoogleAIEmbedderConfig() - def __call__(self, input_: str) -> Embeddings: + def __call__(self, input: Union[list[str], str]) -> Embeddings: model = self.config.model title = self.config.title task_type = self.config.task_type - embeddings = genai.embed_content(model=model, content=input_, task_type=task_type, title=title) - return embeddings["embedding"] + if isinstance(input, str): + input_ = [input] + else: + input_ = input + data = genai.embed_content(model=model, content=input_, task_type=task_type, title=title) + embeddings = data["embedding"] + if isinstance(input_, str): + embeddings = [embeddings] + return embeddings class GoogleAIEmbedder(BaseEmbedder):