Update embedding_fn signature to newest chroma db's (#969)
This commit is contained in:
@@ -3,12 +3,20 @@ from typing import Any, Callable, Optional
|
|||||||
from embedchain.config.embedder.base import BaseEmbedderConfig
|
from embedchain.config.embedder.base import BaseEmbedderConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from chromadb.api.types import Documents, Embeddings
|
from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
from embedchain.utils import use_pysqlite3
|
from embedchain.utils import use_pysqlite3
|
||||||
|
|
||||||
use_pysqlite3()
|
use_pysqlite3()
|
||||||
from chromadb.api.types import Documents, Embeddings
|
from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingFunc(EmbeddingFunction):
|
||||||
|
def __init__(self, embedding_fn: Callable[[list[str]], list[str]]):
|
||||||
|
self.embedding_fn = embedding_fn
|
||||||
|
|
||||||
|
def __call__(self, input: Embeddable) -> Embeddings:
|
||||||
|
return self.embedding_fn(input)
|
||||||
|
|
||||||
|
|
||||||
class BaseEmbedder:
|
class BaseEmbedder:
|
||||||
@@ -66,7 +74,4 @@ class BaseEmbedder:
|
|||||||
:rtype: Callable
|
:rtype: Callable
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def embed_function(texts: Documents) -> Embeddings:
|
return EmbeddingFunc(embeddings.embed_documents)
|
||||||
return embeddings.embed_documents(texts)
|
|
||||||
|
|
||||||
return embed_function
|
|
||||||
|
|||||||
Reference in New Issue
Block a user