Files
t6_mem0/embedchain/embedder/google.py
2024-01-19 10:31:41 +05:30

32 lines
1.2 KiB
Python

from typing import List, Optional, Union

import google.generativeai as genai
from chromadb import EmbeddingFunction, Embeddings

from embedchain.config.embedder.google import GoogleAIEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions
class GoogleAIEmbeddingFunction(EmbeddingFunction):
    """Embedding function that embeds text via ``genai.embed_content``.

    The model, title and task type sent with each request are read from a
    ``GoogleAIEmbedderConfig``.
    """

    def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None) -> None:
        """Store the embedder configuration.

        Args:
            config: Settings for the Google AI embedding call. When ``None``
                (or otherwise falsy), a default ``GoogleAIEmbedderConfig()``
                is created.
        """
        super().__init__()
        self.config = config or GoogleAIEmbedderConfig()

    def __call__(self, input_: Union[str, List[str]]) -> Embeddings:
        """Embed one or more texts.

        Args:
            input_: A single text or a list of texts to embed.

        Returns:
            A list containing one embedding vector per input text. A single
            string produces a one-element list, so the result has the same
            shape whether one text or many are passed in.
        """
        # NOTE(review): chromadb's EmbeddingFunction protocol names this
        # parameter ``input``; some chromadb releases compare parameter names
        # and may reject ``input_``. Kept as ``input_`` for backward
        # compatibility -- confirm against the pinned chromadb version.

        # Always send a list. genai.embed_content is understood to return one
        # flat vector for ``str`` content but a list of vectors for a list,
        # and callers expect the latter (TODO confirm for the pinned SDK).
        documents = [input_] if isinstance(input_, str) else input_
        response = genai.embed_content(
            model=self.config.model,
            content=documents,
            task_type=self.config.task_type,
            title=self.config.title,
        )
        return response["embedding"]
class GoogleAIEmbedder(BaseEmbedder):
    """Embedder that produces vectors with Google AI embedding models."""

    def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None):
        """Set up the Google AI embedding function and the vector dimension.

        Args:
            config: Google AI embedder settings. When ``None`` (or otherwise
                falsy), a default ``GoogleAIEmbedderConfig()`` is created and
                shared by this embedder and its embedding function.
        """
        # Resolve the default once, so BaseEmbedder and the embedding function
        # always hold the same config object. Passing ``None`` through would
        # let the embedding function build its own, separate default.
        config = config or GoogleAIEmbedderConfig()
        super().__init__(config)
        self.set_embedding_fn(embedding_fn=GoogleAIEmbeddingFunction(config=config))

        # An unset (falsy) vector_dimension falls back to the Google AI default.
        vector_dimension = self.config.vector_dimension or VectorDimensions.GOOGLE_AI.value
        self.set_vector_dimension(vector_dimension=vector_dimension)