Update embedding_fn signature to newest chroma db's (#969)
This commit is contained in:
@@ -3,12 +3,20 @@ from typing import Any, Callable, Optional
|
||||
from embedchain.config.embedder.base import BaseEmbedderConfig
|
||||
|
||||
try:
|
||||
from chromadb.api.types import Documents, Embeddings
|
||||
from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction
|
||||
except RuntimeError:
|
||||
from embedchain.utils import use_pysqlite3
|
||||
|
||||
use_pysqlite3()
|
||||
from chromadb.api.types import Documents, Embeddings
|
||||
from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction
|
||||
|
||||
|
||||
class EmbeddingFunc(EmbeddingFunction):
|
||||
def __init__(self, embedding_fn: Callable[[list[str]], list[str]]):
|
||||
self.embedding_fn = embedding_fn
|
||||
|
||||
def __call__(self, input: Embeddable) -> Embeddings:
|
||||
return self.embedding_fn(input)
|
||||
|
||||
|
||||
class BaseEmbedder:
|
||||
@@ -66,7 +74,4 @@ class BaseEmbedder:
|
||||
:rtype: Callable
|
||||
"""
|
||||
|
||||
def embed_function(texts: Documents) -> Embeddings:
|
||||
return embeddings.embed_documents(texts)
|
||||
|
||||
return embed_function
|
||||
return EmbeddingFunc(embeddings.embed_documents)
|
||||
|
||||
Reference in New Issue
Block a user