refactor: classes and configs (#528)

cachho committed 2023-09-05 10:12:58 +02:00 (committed by GitHub)
parent 387b042a49
commit 344e7470f6
50 changed files with 1221 additions and 997 deletions


@@ -0,0 +1,45 @@
from typing import Any, Callable, Optional

from embedchain.config.embedder.BaseEmbedderConfig import BaseEmbedderConfig

try:
    from chromadb.api.types import Documents, Embeddings
except RuntimeError:
    # chromadb needs a recent sqlite3; fall back to the bundled pysqlite3 if the system build is too old.
    from embedchain.utils import use_pysqlite3

    use_pysqlite3()
    from chromadb.api.types import Documents, Embeddings


class BaseEmbedder:
    """
    Class that manages everything regarding embeddings, including the embedding
    function, loaders and chunkers.

    The embedding function and vector dimension are set by the child class you
    choose. To overwrite them manually, use this class's `set_...` methods.
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        if config is None:
            self.config = BaseEmbedderConfig()
        else:
            self.config = config

    def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[list[float]]]):
        if not callable(embedding_fn):
            raise ValueError("Embedding function is not a function")
        self.embedding_fn = embedding_fn

    def set_vector_dimension(self, vector_dimension: int):
        self.vector_dimension = vector_dimension

    @staticmethod
    def _langchain_default_concept(embeddings: Any):
        """
        LangChain's default function layout for embeddings.
        """

        def embed_function(texts: Documents) -> Embeddings:
            return embeddings.embed_documents(texts)

        return embed_function
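
For context, a minimal sketch of how a subclass is expected to wire these pieces together. `MyEmbedder` and its toy embedding function are hypothetical, not part of this commit:

from typing import Optional

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base_embedder import BaseEmbedder


class MyEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config=config)

        # Any callable mapping list[str] -> list of vectors works here.
        def embedding_fn(texts: list[str]) -> list[list[float]]:
            return [[float(len(text))] for text in texts]  # toy 1-d embedding

        self.set_embedding_fn(embedding_fn=embedding_fn)
        self.set_vector_dimension(vector_dimension=1)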


@@ -0,0 +1,21 @@
from typing import Optional

from chromadb.utils import embedding_functions

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.models import EmbeddingFunctions


class GPT4AllEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        # Note: We could use LangChain's GPT4All embedding, but it's not available in all versions.
        super().__init__(config=config)

        if self.config.model is None:
            self.config.model = "all-MiniLM-L6-v2"

        embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.config.model)
        self.set_embedding_fn(embedding_fn=embedding_fn)

        vector_dimension = EmbeddingFunctions.GPT4ALL.value
        self.set_vector_dimension(vector_dimension=vector_dimension)
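
A usage sketch; the model name below is simply this file's default, and the import path is an assumption (only base_embedder's path is visible in this diff):

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.gpt4all import GPT4AllEmbedder  # path assumed

embedder = GPT4AllEmbedder(config=BaseEmbedderConfig(model="all-MiniLM-L6-v2"))
vectors = embedder.embedding_fn(["hello world"])  # set by set_embedding_fn above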


@@ -0,0 +1,19 @@
from typing import Optional

from langchain.embeddings import HuggingFaceEmbeddings

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.models import EmbeddingFunctions


class HuggingFaceEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config=config)

        embeddings = HuggingFaceEmbeddings(model_name=self.config.model)
        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
        self.set_embedding_fn(embedding_fn=embedding_fn)

        vector_dimension = EmbeddingFunctions.HUGGING_FACE.value
        self.set_vector_dimension(vector_dimension=vector_dimension)
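
A usage sketch. Note that unlike the GPT4All and OpenAI embedders, this class sets no default model, so passing one via the config is the safe path; the import path is an assumption:

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.huggingface import HuggingFaceEmbedder  # path assumed

config = BaseEmbedderConfig(model="sentence-transformers/all-MiniLM-L6-v2")
embedder = HuggingFaceEmbedder(config=config)
print(embedder.vector_dimension)  # EmbeddingFunctions.HUGGING_FACE.value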


@@ -0,0 +1,40 @@
import os
from typing import Optional

from langchain.embeddings import OpenAIEmbeddings

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.models import EmbeddingFunctions

try:
    from chromadb.utils import embedding_functions
except RuntimeError:
    # chromadb needs a recent sqlite3; fall back to the bundled pysqlite3 if the system build is too old.
    from embedchain.utils import use_pysqlite3

    use_pysqlite3()
    from chromadb.utils import embedding_functions


class OpenAiEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config=config)

        if self.config.model is None:
            self.config.model = "text-embedding-ada-002"

        if self.config.deployment_name:
            # A deployment name indicates an Azure OpenAI deployment; route through LangChain in that case.
            embeddings = OpenAIEmbeddings(deployment=self.config.deployment_name)
            embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
        else:
            if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
                raise ValueError(
                    "Neither OPENAI_API_KEY nor OPENAI_ORGANIZATION environment variable is provided"
                )
            embedding_fn = embedding_functions.OpenAIEmbeddingFunction(
                api_key=os.getenv("OPENAI_API_KEY"),
                organization_id=os.getenv("OPENAI_ORGANIZATION"),
                model_name=self.config.model,
            )
        self.set_embedding_fn(embedding_fn=embedding_fn)

        self.set_vector_dimension(vector_dimension=EmbeddingFunctions.OPENAI.value)
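
A sketch of both code paths; the env-var names are the ones checked above, while the key and deployment values are placeholders and the import path is an assumption:

import os

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.openai import OpenAiEmbedder  # path assumed

# Plain OpenAI: the key (and/or organization) comes from the environment.
os.environ["OPENAI_API_KEY"] = "sk-..."
embedder = OpenAiEmbedder()  # model falls back to "text-embedding-ada-002"

# Deployment branch: routed through LangChain's OpenAIEmbeddings instead.
azure = OpenAiEmbedder(config=BaseEmbedderConfig(deployment_name="my-deployment"))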


@@ -0,0 +1,19 @@
from typing import Optional

from langchain.embeddings import VertexAIEmbeddings

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.models import EmbeddingFunctions


class VertexAiEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config=config)

        # Use self.config (never None after super().__init__) rather than the raw config argument.
        embeddings = VertexAIEmbeddings(model_name=self.config.model)
        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
        self.set_embedding_fn(embedding_fn=embedding_fn)

        vector_dimension = EmbeddingFunctions.VERTEX_AI.value
        self.set_vector_dimension(vector_dimension=vector_dimension)
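
Usage follows the same pattern. LangChain's VertexAIEmbeddings authenticates via Google application-default credentials, so no key is read here; the import path and model name are assumptions:

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.vertexai import VertexAiEmbedder  # path assumed

# A model should be supplied explicitly; this class sets no default.
embedder = VertexAiEmbedder(config=BaseEmbedderConfig(model="textembedding-gecko"))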