Add Ollama as a supported embedding provider (#1344)

Colin O'Brien
2024-05-02 01:08:47 -04:00
committed by GitHub
parent 1a66f961f4
commit a795798156
9 changed files with 46 additions and 2 deletions

View File

@@ -6,6 +6,7 @@ from .base_config import BaseConfig
from .cache_config import CacheConfig
from .embedder.base import BaseEmbedderConfig
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
from .embedder.ollama import OllamaEmbedderConfig
from .llm.base import BaseLlmConfig
from .vectordb.chroma import ChromaDbConfig
from .vectordb.elasticsearch import ElasticsearchDBConfig

View File

@@ -0,0 +1,15 @@
from typing import Optional

from embedchain.config.embedder.base import BaseEmbedderConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class OllamaEmbedderConfig(BaseEmbedderConfig):
    def __init__(
        self,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
    ):
        super().__init__(model)
        self.base_url = base_url or "http://127.0.0.1:11434"
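A quick usage sketch for the new config class (the model name and remote URL below are illustrative, not part of this commit; the default base_url comes from the fallback above):

from embedchain.config import OllamaEmbedderConfig

# With no base_url, the config falls back to the local Ollama default.
local = OllamaEmbedderConfig(model="nomic-embed-text")
assert local.base_url == "http://127.0.0.1:11434"

# An explicit base_url points the embedder at a non-default Ollama server.
remote = OllamaEmbedderConfig(model="nomic-embed-text", base_url="http://ollama.internal:11434")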

Binary file not shown.

View File

@@ -0,0 +1,19 @@
from typing import Optional

from langchain_community.embeddings import OllamaEmbeddings

from embedchain.config import OllamaEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions


class OllamaEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[OllamaEmbedderConfig] = None):
        super().__init__(config=config)

        embeddings = OllamaEmbeddings(model=self.config.model, base_url=self.config.base_url)
        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
        self.set_embedding_fn(embedding_fn=embedding_fn)

        vector_dimension = self.config.vector_dimension or VectorDimensions.OLLAMA.value
        self.set_vector_dimension(vector_dimension=vector_dimension)
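The embedder is a thin wrapper around langchain_community's OllamaEmbeddings: it builds the embedding function via BaseEmbedder._langchain_default_concept and registers the vector dimension. The snippet below shows the underlying call it delegates to; it is a sketch that assumes a running Ollama server with the example model pulled:

from langchain_community.embeddings import OllamaEmbeddings

embeddings = OllamaEmbeddings(model="nomic-embed-text", base_url="http://127.0.0.1:11434")
vector = embeddings.embed_query("hello world")
# The returned length depends on the chosen model; if it differs from the
# OLLAMA default of 384, set vector_dimension in the config to match.
print(len(vector))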

View File

@@ -58,6 +58,7 @@ class EmbedderFactory:
"mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
"nvidia": "embedchain.embedder.nvidia.NvidiaEmbedder",
"cohere": "embedchain.embedder.cohere.CohereEmbedder",
"ollama": "embedchain.embedder.ollama.OllamaEmbedder",
}
provider_to_config_class = {
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
@@ -65,6 +66,7 @@ class EmbedderFactory:
"gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
"huggingface": "embedchain.config.embedder.base.BaseEmbedderConfig",
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"ollama": "embedchain.config.embedder.ollama.OllamaEmbedderConfig",
}
@classmethod
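For context, the factory keeps provider strings mapped to dotted class paths and resolves them at creation time. The helper below is an illustrative sketch of that import-by-path pattern, not embedchain's actual implementation:

import importlib

def load_class(dotted_path: str):
    # Resolve "package.module.ClassName" to the class object.
    module_path, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)

# With this commit, the "ollama" key resolves to the new classes:
embedder_cls = load_class("embedchain.embedder.ollama.OllamaEmbedder")
config_cls = load_class("embedchain.config.embedder.ollama.OllamaEmbedderConfig")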

View File

@@ -6,3 +6,4 @@ class EmbeddingFunctions(Enum):
    HUGGING_FACE = "HUGGING_FACE"
    VERTEX_AI = "VERTEX_AI"
    GPT4ALL = "GPT4ALL"
    OLLAMA = "OLLAMA"

View File

@@ -11,3 +11,4 @@ class VectorDimensions(Enum):
    MISTRAL_AI = 1024
    NVIDIA_AI = 1024
    COHERE = 384
    OLLAMA = 384

View File

@@ -449,6 +449,7 @@ def validate_config(config_data):
"google",
"mistralai",
"nvidia",
"ollama",
),
Optional("config"): {
Optional("model"): Optional(str),
@@ -458,6 +459,7 @@ def validate_config(config_data):
Optional("title"): str,
Optional("task_type"): str,
Optional("vector_dimension"): int,
Optional("base_url"): str,
},
},
Optional("embedding_model"): {
@@ -470,6 +472,7 @@ def validate_config(config_data):
"google",
"mistralai",
"nvidia",
"ollama",
),
Optional("config"): {
Optional("model"): str,
@@ -478,6 +481,7 @@ def validate_config(config_data):
Optional("title"): str,
Optional("task_type"): str,
Optional("vector_dimension"): int,
Optional("base_url"): str,
},
},
Optional("chunker"): {