From 85f3ac428b71ceebec653c66a01fffd2fd753b61 Mon Sep 17 00:00:00 2001 From: Sidharth Mohanty Date: Tue, 21 Nov 2023 23:12:11 +0530 Subject: [PATCH] Update embedding_fn signature to newest chroma db's (#969) --- embedchain/embedder/base.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/embedchain/embedder/base.py b/embedchain/embedder/base.py index dc024a3f..50ed475b 100644 --- a/embedchain/embedder/base.py +++ b/embedchain/embedder/base.py @@ -3,12 +3,20 @@ from typing import Any, Callable, Optional from embedchain.config.embedder.base import BaseEmbedderConfig try: - from chromadb.api.types import Documents, Embeddings + from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction except RuntimeError: from embedchain.utils import use_pysqlite3 use_pysqlite3() - from chromadb.api.types import Documents, Embeddings + from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction + + +class EmbeddingFunc(EmbeddingFunction): + def __init__(self, embedding_fn: Callable[[list[str]], list[str]]): + self.embedding_fn = embedding_fn + + def __call__(self, input: Embeddable) -> Embeddings: + return self.embedding_fn(input) class BaseEmbedder: @@ -66,7 +74,4 @@ class BaseEmbedder: :rtype: Callable """ - def embed_function(texts: Documents) -> Embeddings: - return embeddings.embed_documents(texts) - - return embed_function + return EmbeddingFunc(embeddings.embed_documents)