From 0b5b12575af7315c381aba975c0a9c5ab01cf579 Mon Sep 17 00:00:00 2001 From: Deven Patel Date: Fri, 19 Jan 2024 21:35:44 +0530 Subject: [PATCH] [Bugfix] fix google ai embedding function (#1195) Co-authored-by: Deven Patel --- embedchain/embedder/google.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/embedchain/embedder/google.py b/embedchain/embedder/google.py index 4d09f6be..c0be8350 100644 --- a/embedchain/embedder/google.py +++ b/embedchain/embedder/google.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import google.generativeai as genai from chromadb import EmbeddingFunction, Embeddings @@ -13,12 +13,19 @@ class GoogleAIEmbeddingFunction(EmbeddingFunction): super().__init__() self.config = config or GoogleAIEmbedderConfig() - def __call__(self, input_: str) -> Embeddings: + def __call__(self, input: Union[list[str], str]) -> Embeddings: model = self.config.model title = self.config.title task_type = self.config.task_type - embeddings = genai.embed_content(model=model, content=input_, task_type=task_type, title=title) - return embeddings["embedding"] + if isinstance(input, str): + input_ = [input] + else: + input_ = input + data = genai.embed_content(model=model, content=input_, task_type=task_type, title=title) + embeddings = data["embedding"] + if isinstance(input_, str): + embeddings = [embeddings] + return embeddings class GoogleAIEmbedder(BaseEmbedder):