Add Ollama as a supported embedding provider (#1344)
This commit is contained in:
@@ -8,6 +8,7 @@ llm:
|
||||
base_url: http://localhost:11434
|
||||
|
||||
embedder:
|
||||
provider: huggingface
|
||||
provider: ollama
|
||||
config:
|
||||
model: 'BAAI/bge-small-en-v1.5'
|
||||
model: 'mxbai-embed-large:latest'
|
||||
base_url: http://localhost:11434
|
||||
|
||||
@@ -6,6 +6,7 @@ from .base_config import BaseConfig
|
||||
from .cache_config import CacheConfig
|
||||
from .embedder.base import BaseEmbedderConfig
|
||||
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
|
||||
from .embedder.ollama import OllamaEmbedderConfig
|
||||
from .llm.base import BaseLlmConfig
|
||||
from .vectordb.chroma import ChromaDbConfig
|
||||
from .vectordb.elasticsearch import ElasticsearchDBConfig
|
||||
|
||||
15
embedchain/config/embedder/ollama.py
Normal file
15
embedchain/config/embedder/ollama.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.embedder.base import BaseEmbedderConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class OllamaEmbedderConfig(BaseEmbedderConfig):
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
super().__init__(model)
|
||||
self.base_url = base_url or "http://127.0.0.1:11434"
|
||||
BIN
embedchain/embedder/.ollama.py.swp
Normal file
BIN
embedchain/embedder/.ollama.py.swp
Normal file
Binary file not shown.
19
embedchain/embedder/ollama.py
Normal file
19
embedchain/embedder/ollama.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
|
||||
from embedchain.config import OllamaEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.models import VectorDimensions
|
||||
|
||||
|
||||
class OllamaEmbedder(BaseEmbedder):
|
||||
def __init__(self, config: Optional[OllamaEmbedderConfig] = None):
|
||||
super().__init__(config=config)
|
||||
|
||||
embeddings = OllamaEmbeddings(model=self.config.model, base_url=self.config.base_url)
|
||||
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
|
||||
self.set_embedding_fn(embedding_fn=embedding_fn)
|
||||
|
||||
vector_dimension = self.config.vector_dimension or VectorDimensions.OLLAMA.value
|
||||
self.set_vector_dimension(vector_dimension=vector_dimension)
|
||||
@@ -58,6 +58,7 @@ class EmbedderFactory:
|
||||
"mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
|
||||
"nvidia": "embedchain.embedder.nvidia.NvidiaEmbedder",
|
||||
"cohere": "embedchain.embedder.cohere.CohereEmbedder",
|
||||
"ollama": "embedchain.embedder.ollama.OllamaEmbedder",
|
||||
}
|
||||
provider_to_config_class = {
|
||||
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||
@@ -65,6 +66,7 @@ class EmbedderFactory:
|
||||
"gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||
"huggingface": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||
"ollama": "embedchain.config.embedder.ollama.OllamaEmbedderConfig",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -6,3 +6,4 @@ class EmbeddingFunctions(Enum):
|
||||
HUGGING_FACE = "HUGGING_FACE"
|
||||
VERTEX_AI = "VERTEX_AI"
|
||||
GPT4ALL = "GPT4ALL"
|
||||
OLLAMA = "OLLAMA"
|
||||
|
||||
@@ -11,3 +11,4 @@ class VectorDimensions(Enum):
|
||||
MISTRAL_AI = 1024
|
||||
NVIDIA_AI = 1024
|
||||
COHERE = 384
|
||||
OLLAMA = 384
|
||||
|
||||
@@ -449,6 +449,7 @@ def validate_config(config_data):
|
||||
"google",
|
||||
"mistralai",
|
||||
"nvidia",
|
||||
"ollama",
|
||||
),
|
||||
Optional("config"): {
|
||||
Optional("model"): Optional(str),
|
||||
@@ -458,6 +459,7 @@ def validate_config(config_data):
|
||||
Optional("title"): str,
|
||||
Optional("task_type"): str,
|
||||
Optional("vector_dimension"): int,
|
||||
Optional("base_url"): str,
|
||||
},
|
||||
},
|
||||
Optional("embedding_model"): {
|
||||
@@ -470,6 +472,7 @@ def validate_config(config_data):
|
||||
"google",
|
||||
"mistralai",
|
||||
"nvidia",
|
||||
"ollama",
|
||||
),
|
||||
Optional("config"): {
|
||||
Optional("model"): str,
|
||||
@@ -478,6 +481,7 @@ def validate_config(config_data):
|
||||
Optional("title"): str,
|
||||
Optional("task_type"): str,
|
||||
Optional("vector_dimension"): int,
|
||||
Optional("base_url"): str,
|
||||
},
|
||||
},
|
||||
Optional("chunker"): {
|
||||
|
||||
Reference in New Issue
Block a user