Update embedding_fn signature to newest chroma db's (#969)

This commit is contained in:
Sidharth Mohanty
2023-11-21 23:12:11 +05:30
committed by GitHub
parent 9fcf2130b5
commit 85f3ac428b

View File

@@ -3,12 +3,20 @@ from typing import Any, Callable, Optional
from embedchain.config.embedder.base import BaseEmbedderConfig
try:
from chromadb.api.types import Documents, Embeddings
from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction
except RuntimeError:
from embedchain.utils import use_pysqlite3
use_pysqlite3()
from chromadb.api.types import Documents, Embeddings
from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction
class EmbeddingFunc(EmbeddingFunction):
def __init__(self, embedding_fn: Callable[[list[str]], list[str]]):
self.embedding_fn = embedding_fn
def __call__(self, input: Embeddable) -> Embeddings:
return self.embedding_fn(input)
class BaseEmbedder:
@@ -66,7 +74,4 @@ class BaseEmbedder:
:rtype: Callable
"""
def embed_function(texts: Documents) -> Embeddings:
return embeddings.embed_documents(texts)
return embed_function
return EmbeddingFunc(embeddings.embed_documents)