refactor: classes and configs (#528)
This commit is contained in:
0
embedchain/embedder/__init__.py
Normal file
0
embedchain/embedder/__init__.py
Normal file
45
embedchain/embedder/base_embedder.py
Normal file
45
embedchain/embedder/base_embedder.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from embedchain.config.embedder.BaseEmbedderConfig import BaseEmbedderConfig
|
||||
|
||||
try:
|
||||
from chromadb.api.types import Documents, Embeddings
|
||||
except RuntimeError:
|
||||
from embedchain.utils import use_pysqlite3
|
||||
|
||||
use_pysqlite3()
|
||||
from chromadb.api.types import Documents, Embeddings
|
||||
|
||||
|
||||
class BaseEmbedder:
    """
    Class that manages everything regarding embeddings. Including embedding function, loaders and chunkers.

    Embedding functions and vector dimensions are set based on the child class you choose.
    To manually overwrite you can use this classes `set_...` methods.
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """
        Initialize the embedder.

        :param config: optional embedder configuration; a default
            `BaseEmbedderConfig` is created when omitted.
        """
        # Bug fix: the default value used to be `FileNotFoundError` (an
        # exception class), which skipped the `is None` fallback below
        # whenever the caller omitted `config`, leaving `self.config` set to
        # that exception class. `None` is the intended sentinel.
        if config is None:
            self.config = BaseEmbedderConfig()
        else:
            self.config = config

    def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
        """
        Set the function used to embed documents.

        :param embedding_fn: callable mapping a list of texts to embeddings.
        :raises ValueError: if `embedding_fn` is not callable.
        """
        if not hasattr(embedding_fn, "__call__"):
            raise ValueError("Embedding function is not a function")
        self.embedding_fn = embedding_fn

    def set_vector_dimension(self, vector_dimension: int):
        """
        Set the dimensionality of the vectors the embedding function produces.

        :param vector_dimension: number of dimensions per embedding vector.
        """
        self.vector_dimension = vector_dimension

    @staticmethod
    def _langchain_default_concept(embeddings: Any):
        """
        Langchains default function layout for embeddings.

        :param embeddings: a langchain embeddings object exposing
            `embed_documents`.
        :return: a chroma-compatible embedding function wrapping it.
        """

        def embed_function(texts: Documents) -> Embeddings:
            return embeddings.embed_documents(texts)

        return embed_function
|
||||
21
embedchain/embedder/gpt4all_embedder.py
Normal file
21
embedchain/embedder/gpt4all_embedder.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Optional
|
||||
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base_embedder import BaseEmbedder
|
||||
from embedchain.models import EmbeddingFunctions
|
||||
|
||||
|
||||
class GPT4AllEmbedder(BaseEmbedder):
    """Embedder backed by a local SentenceTransformer model (GPT4All flavor)."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """
        :param config: optional embedder configuration; the model name falls
            back to "all-MiniLM-L6-v2" when unset.
        """
        super().__init__(config=config)
        # Note: We could use langchains GPT4ALL embedding, but it's not available in all versions.
        if self.config.model is None:
            self.config.model = "all-MiniLM-L6-v2"

        sentence_transformer_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=self.config.model
        )
        self.set_embedding_fn(embedding_fn=sentence_transformer_fn)
        self.set_vector_dimension(vector_dimension=EmbeddingFunctions.GPT4ALL.value)
|
||||
19
embedchain/embedder/huggingface_embedder.py
Normal file
19
embedchain/embedder/huggingface_embedder.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base_embedder import BaseEmbedder
|
||||
from embedchain.models import EmbeddingFunctions
|
||||
|
||||
|
||||
class HuggingFaceEmbedder(BaseEmbedder):
    """Embedder that delegates to langchain's HuggingFace embeddings."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """
        :param config: optional embedder configuration; `config.model` is
            passed straight to `HuggingFaceEmbeddings`.
        """
        super().__init__(config=config)

        hf_embeddings = HuggingFaceEmbeddings(model_name=self.config.model)
        self.set_embedding_fn(
            embedding_fn=BaseEmbedder._langchain_default_concept(hf_embeddings)
        )
        self.set_vector_dimension(vector_dimension=EmbeddingFunctions.HUGGING_FACE.value)
|
||||
40
embedchain/embedder/openai_embedder.py
Normal file
40
embedchain/embedder/openai_embedder.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base_embedder import BaseEmbedder
|
||||
from embedchain.models import EmbeddingFunctions
|
||||
|
||||
try:
|
||||
from chromadb.utils import embedding_functions
|
||||
except RuntimeError:
|
||||
from embedchain.utils import use_pysqlite3
|
||||
|
||||
use_pysqlite3()
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
|
||||
class OpenAiEmbedder(BaseEmbedder):
    """Embedder backed by the OpenAI embeddings API (or an Azure deployment)."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """
        :param config: optional embedder configuration; the model name falls
            back to "text-embedding-ada-002" when unset.
        :raises ValueError: when no deployment is configured and neither
            OPENAI_API_KEY nor OPENAI_ORGANIZATION is set in the environment.
        """
        super().__init__(config=config)
        if self.config.model is None:
            self.config.model = "text-embedding-ada-002"

        if self.config.deployment_name:
            # Azure-style deployment: route through langchain's wrapper.
            langchain_embeddings = OpenAIEmbeddings(deployment=self.config.deployment_name)
            embedding_fn = BaseEmbedder._langchain_default_concept(langchain_embeddings)
        else:
            api_key = os.getenv("OPENAI_API_KEY")
            organization = os.getenv("OPENAI_ORGANIZATION")
            # NOTE(review): raises only when BOTH variables are absent, while
            # the message says "or" — presumably either credential is enough
            # for chroma's client; confirm against its auth handling.
            if api_key is None and organization is None:
                raise ValueError(
                    "OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided"
                )  # noqa:E501
            embedding_fn = embedding_functions.OpenAIEmbeddingFunction(
                api_key=api_key,
                organization_id=organization,
                model_name=self.config.model,
            )

        self.set_embedding_fn(embedding_fn=embedding_fn)
        self.set_vector_dimension(vector_dimension=EmbeddingFunctions.OPENAI.value)
|
||||
19
embedchain/embedder/vertexai_embedder.py
Normal file
19
embedchain/embedder/vertexai_embedder.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.embeddings import VertexAIEmbeddings
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base_embedder import BaseEmbedder
|
||||
from embedchain.models import EmbeddingFunctions
|
||||
|
||||
|
||||
class VertexAiEmbedder(BaseEmbedder):
    """Embedder that delegates to langchain's VertexAI embeddings."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """
        :param config: optional embedder configuration; `config.model` is
            passed to `VertexAIEmbeddings`.
        """
        super().__init__(config=config)

        # Bug fix: this previously read `config.model` (the raw parameter),
        # which raises AttributeError when the caller relies on the default
        # `config=None` — `super().__init__` builds a fallback config as
        # `self.config` for exactly that case, so read it from there.
        embeddings = VertexAIEmbeddings(model_name=self.config.model)
        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
        self.set_embedding_fn(embedding_fn=embedding_fn)

        vector_dimension = EmbeddingFunctions.VERTEX_AI.value
        self.set_vector_dimension(vector_dimension=vector_dimension)
|
||||
Reference in New Issue
Block a user