diff --git a/embedchain/embedder/base.py b/embedchain/embedder/base.py index dc024a3f..50ed475b 100644 --- a/embedchain/embedder/base.py +++ b/embedchain/embedder/base.py @@ -3,12 +3,20 @@ from typing import Any, Callable, Optional from embedchain.config.embedder.base import BaseEmbedderConfig try: - from chromadb.api.types import Documents, Embeddings + from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction except RuntimeError: from embedchain.utils import use_pysqlite3 use_pysqlite3() - from chromadb.api.types import Documents, Embeddings + from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction + + +class EmbeddingFunc(EmbeddingFunction): + def __init__(self, embedding_fn: Callable[[list[str]], list[str]]): + self.embedding_fn = embedding_fn + + def __call__(self, input: Embeddable) -> Embeddings: + return self.embedding_fn(input) class BaseEmbedder: @@ -66,7 +74,4 @@ class BaseEmbedder: :rtype: Callable """ - def embed_function(texts: Documents) -> Embeddings: - return embeddings.embed_documents(texts) - - return embed_function + return EmbeddingFunc(embeddings.embed_documents)