[Bugfix] fix google ai embedding function (#1195)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import google.generativeai as genai
|
||||
from chromadb import EmbeddingFunction, Embeddings
|
||||
@@ -13,12 +13,19 @@ class GoogleAIEmbeddingFunction(EmbeddingFunction):
|
||||
super().__init__()
|
||||
self.config = config or GoogleAIEmbedderConfig()
|
||||
|
||||
def __call__(self, input_: str) -> Embeddings:
|
||||
def __call__(self, input: Union[list[str], str]) -> Embeddings:
|
||||
model = self.config.model
|
||||
title = self.config.title
|
||||
task_type = self.config.task_type
|
||||
embeddings = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
|
||||
return embeddings["embedding"]
|
||||
if isinstance(input, str):
|
||||
input_ = [input]
|
||||
else:
|
||||
input_ = input
|
||||
data = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
|
||||
embeddings = data["embedding"]
|
||||
if isinstance(input_, str):
|
||||
embeddings = [embeddings]
|
||||
return embeddings
|
||||
|
||||
|
||||
class GoogleAIEmbedder(BaseEmbedder):
|
||||
|
||||
Reference in New Issue
Block a user