Add Ollama as a supported embedding provider (#1344)
This commit is contained in:
@@ -8,6 +8,7 @@ llm:
|
|||||||
base_url: http://localhost:11434
|
base_url: http://localhost:11434
|
||||||
|
|
||||||
embedder:
|
embedder:
|
||||||
provider: huggingface
|
provider: ollama
|
||||||
config:
|
config:
|
||||||
model: 'BAAI/bge-small-en-v1.5'
|
model: 'mxbai-embed-large:latest'
|
||||||
|
base_url: http://localhost:11434
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from .base_config import BaseConfig
|
|||||||
from .cache_config import CacheConfig
|
from .cache_config import CacheConfig
|
||||||
from .embedder.base import BaseEmbedderConfig
|
from .embedder.base import BaseEmbedderConfig
|
||||||
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
|
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
|
||||||
|
from .embedder.ollama import OllamaEmbedderConfig
|
||||||
from .llm.base import BaseLlmConfig
|
from .llm.base import BaseLlmConfig
|
||||||
from .vectordb.chroma import ChromaDbConfig
|
from .vectordb.chroma import ChromaDbConfig
|
||||||
from .vectordb.elasticsearch import ElasticsearchDBConfig
|
from .vectordb.elasticsearch import ElasticsearchDBConfig
|
||||||
|
|||||||
15
embedchain/config/embedder/ollama.py
Normal file
15
embedchain/config/embedder/ollama.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from embedchain.config.embedder.base import BaseEmbedderConfig
|
||||||
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
class OllamaEmbedderConfig(BaseEmbedderConfig):
    """Configuration for the Ollama embedding provider.

    Extends the base embedder configuration with the URL of the Ollama
    server that the embedder talks to.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        vector_dimension: Optional[int] = None,
    ):
        """Initialize the Ollama embedder configuration.

        :param model: Name of the embedding model; forwarded to the base config.
        :param base_url: URL of the Ollama server. Falls back to
            ``http://127.0.0.1:11434`` when omitted or empty.
        :param vector_dimension: Optional size of the produced embedding
            vectors. The Ollama embedder reads ``config.vector_dimension`` and
            only uses its built-in default when this is unset, so exposing it
            here lets callers match the dimension of the model they chose.
        """
        super().__init__(model)
        # Only override when explicitly given, so whatever default the base
        # class assigns stays untouched otherwise.
        # NOTE(review): presumably BaseEmbedderConfig already defines
        # ``vector_dimension`` (the embedder reads it) - confirm against base.
        if vector_dimension is not None:
            self.vector_dimension = vector_dimension
        self.base_url = base_url or "http://127.0.0.1:11434"
|
||||||
BIN
embedchain/embedder/.ollama.py.swp
Normal file
BIN
embedchain/embedder/.ollama.py.swp
Normal file
Binary file not shown.
19
embedchain/embedder/ollama.py
Normal file
19
embedchain/embedder/ollama.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from langchain_community.embeddings import OllamaEmbeddings
|
||||||
|
|
||||||
|
from embedchain.config import OllamaEmbedderConfig
|
||||||
|
from embedchain.embedder.base import BaseEmbedder
|
||||||
|
from embedchain.models import VectorDimensions
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaEmbedder(BaseEmbedder):
    """Embedder that produces vectors through LangChain's ``OllamaEmbeddings``."""

    def __init__(self, config: Optional[OllamaEmbedderConfig] = None):
        """Set up the embedding function and vector dimension.

        :param config: Ollama embedder configuration; its ``model`` and
            ``base_url`` are passed to ``OllamaEmbeddings`` and its
            ``vector_dimension`` (when truthy) overrides the default
            ``VectorDimensions.OLLAMA`` value.
        """
        super().__init__(config=config)

        # Build the LangChain client from the resolved configuration.
        ollama_client = OllamaEmbeddings(
            model=self.config.model,
            base_url=self.config.base_url,
        )
        self.set_embedding_fn(
            embedding_fn=BaseEmbedder._langchain_default_concept(ollama_client),
        )

        # Fall back to the provider-wide default when no dimension is configured.
        dimension = self.config.vector_dimension
        if not dimension:
            dimension = VectorDimensions.OLLAMA.value
        self.set_vector_dimension(vector_dimension=dimension)
|
||||||
@@ -58,6 +58,7 @@ class EmbedderFactory:
|
|||||||
"mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
|
"mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
|
||||||
"nvidia": "embedchain.embedder.nvidia.NvidiaEmbedder",
|
"nvidia": "embedchain.embedder.nvidia.NvidiaEmbedder",
|
||||||
"cohere": "embedchain.embedder.cohere.CohereEmbedder",
|
"cohere": "embedchain.embedder.cohere.CohereEmbedder",
|
||||||
|
"ollama": "embedchain.embedder.ollama.OllamaEmbedder",
|
||||||
}
|
}
|
||||||
provider_to_config_class = {
|
provider_to_config_class = {
|
||||||
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||||
@@ -65,6 +66,7 @@ class EmbedderFactory:
|
|||||||
"gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
"gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||||
"huggingface": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
"huggingface": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||||
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||||
|
"ollama": "embedchain.config.embedder.ollama.OllamaEmbedderConfig",
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -6,3 +6,4 @@ class EmbeddingFunctions(Enum):
|
|||||||
HUGGING_FACE = "HUGGING_FACE"
|
HUGGING_FACE = "HUGGING_FACE"
|
||||||
VERTEX_AI = "VERTEX_AI"
|
VERTEX_AI = "VERTEX_AI"
|
||||||
GPT4ALL = "GPT4ALL"
|
GPT4ALL = "GPT4ALL"
|
||||||
|
OLLAMA = "OLLAMA"
|
||||||
|
|||||||
@@ -11,3 +11,4 @@ class VectorDimensions(Enum):
|
|||||||
MISTRAL_AI = 1024
|
MISTRAL_AI = 1024
|
||||||
NVIDIA_AI = 1024
|
NVIDIA_AI = 1024
|
||||||
COHERE = 384
|
COHERE = 384
|
||||||
|
OLLAMA = 384
|
||||||
|
|||||||
@@ -449,6 +449,7 @@ def validate_config(config_data):
|
|||||||
"google",
|
"google",
|
||||||
"mistralai",
|
"mistralai",
|
||||||
"nvidia",
|
"nvidia",
|
||||||
|
"ollama",
|
||||||
),
|
),
|
||||||
Optional("config"): {
|
Optional("config"): {
|
||||||
Optional("model"): Optional(str),
|
Optional("model"): Optional(str),
|
||||||
@@ -458,6 +459,7 @@ def validate_config(config_data):
|
|||||||
Optional("title"): str,
|
Optional("title"): str,
|
||||||
Optional("task_type"): str,
|
Optional("task_type"): str,
|
||||||
Optional("vector_dimension"): int,
|
Optional("vector_dimension"): int,
|
||||||
|
Optional("base_url"): str,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Optional("embedding_model"): {
|
Optional("embedding_model"): {
|
||||||
@@ -470,6 +472,7 @@ def validate_config(config_data):
|
|||||||
"google",
|
"google",
|
||||||
"mistralai",
|
"mistralai",
|
||||||
"nvidia",
|
"nvidia",
|
||||||
|
"ollama",
|
||||||
),
|
),
|
||||||
Optional("config"): {
|
Optional("config"): {
|
||||||
Optional("model"): str,
|
Optional("model"): str,
|
||||||
@@ -478,6 +481,7 @@ def validate_config(config_data):
|
|||||||
Optional("title"): str,
|
Optional("title"): str,
|
||||||
Optional("task_type"): str,
|
Optional("task_type"): str,
|
||||||
Optional("vector_dimension"): int,
|
Optional("vector_dimension"): int,
|
||||||
|
Optional("base_url"): str,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Optional("chunker"): {
|
Optional("chunker"): {
|
||||||
|
|||||||
Reference in New Issue
Block a user