[Bugfix] fix google ai embedding function (#1195)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
from chromadb import EmbeddingFunction, Embeddings
|
from chromadb import EmbeddingFunction, Embeddings
|
||||||
@@ -13,12 +13,19 @@ class GoogleAIEmbeddingFunction(EmbeddingFunction):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config or GoogleAIEmbedderConfig()
|
self.config = config or GoogleAIEmbedderConfig()
|
||||||
|
|
||||||
def __call__(self, input_: str) -> Embeddings:
|
def __call__(self, input: Union[list[str], str]) -> Embeddings:
|
||||||
model = self.config.model
|
model = self.config.model
|
||||||
title = self.config.title
|
title = self.config.title
|
||||||
task_type = self.config.task_type
|
task_type = self.config.task_type
|
||||||
embeddings = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
|
if isinstance(input, str):
|
||||||
return embeddings["embedding"]
|
input_ = [input]
|
||||||
|
else:
|
||||||
|
input_ = input
|
||||||
|
data = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
|
||||||
|
embeddings = data["embedding"]
|
||||||
|
if isinstance(input_, str):
|
||||||
|
embeddings = [embeddings]
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
class GoogleAIEmbedder(BaseEmbedder):
|
class GoogleAIEmbedder(BaseEmbedder):
|
||||||
|
|||||||
Reference in New Issue
Block a user