From a79579815663304a66dc6782ccbf4228ead4eec2 Mon Sep 17 00:00:00 2001 From: Colin O'Brien Date: Thu, 2 May 2024 01:08:47 -0400 Subject: [PATCH] Add Ollama as a supported embedding provider (#1344) --- configs/ollama.yaml | 5 +++-- embedchain/config/__init__.py | 1 + embedchain/config/embedder/ollama.py | 15 +++++++++++++++ embedchain/embedder/.ollama.py.swp | Bin 0 -> 12288 bytes embedchain/embedder/ollama.py | 19 +++++++++++++++++++ embedchain/factory.py | 2 ++ embedchain/models/embedding_functions.py | 1 + embedchain/models/vector_dimensions.py | 1 + embedchain/utils/misc.py | 4 ++++ 9 files changed, 46 insertions(+), 2 deletions(-) create mode 100644 embedchain/config/embedder/ollama.py create mode 100644 embedchain/embedder/.ollama.py.swp create mode 100644 embedchain/embedder/ollama.py diff --git a/configs/ollama.yaml b/configs/ollama.yaml index da4be9c2..7ec5def5 100644 --- a/configs/ollama.yaml +++ b/configs/ollama.yaml @@ -8,6 +8,7 @@ llm: base_url: http://localhost:11434 embedder: - provider: huggingface + provider: ollama config: - model: 'BAAI/bge-small-en-v1.5' + model: 'mxbai-embed-large:latest' + base_url: http://localhost:11434 diff --git a/embedchain/config/__init__.py b/embedchain/config/__init__.py index 0e980c81..c6bd4f30 100644 --- a/embedchain/config/__init__.py +++ b/embedchain/config/__init__.py @@ -6,6 +6,7 @@ from .base_config import BaseConfig from .cache_config import CacheConfig from .embedder.base import BaseEmbedderConfig from .embedder.base import BaseEmbedderConfig as EmbedderConfig +from .embedder.ollama import OllamaEmbedderConfig from .llm.base import BaseLlmConfig from .vectordb.chroma import ChromaDbConfig from .vectordb.elasticsearch import ElasticsearchDBConfig diff --git a/embedchain/config/embedder/ollama.py b/embedchain/config/embedder/ollama.py new file mode 100644 index 00000000..b08b39d7 --- /dev/null +++ b/embedchain/config/embedder/ollama.py @@ -0,0 +1,15 @@ +from typing import Optional + +from embedchain.config.embedder.base import BaseEmbedderConfig +from embedchain.helpers.json_serializable import register_deserializable + + +@register_deserializable +class OllamaEmbedderConfig(BaseEmbedderConfig): + def __init__( + self, + model: Optional[str] = None, + base_url: Optional[str] = None, + ): + super().__init__(model) + self.base_url = base_url or "http://127.0.0.1:11434" diff --git a/embedchain/embedder/.ollama.py.swp b/embedchain/embedder/.ollama.py.swp new file mode 100644 index 0000000000000000000000000000000000000000..85ec23a1ed16e460ffa0c7265397b1fa19d90b0d GIT binary patch literal 12288 zcmeI2&2G~`5XYySDIX~hu%%L!NX0GXQdP>Y5g-CYfCvzQzmS0I1H8Wt@EcY9|6hIof4d9t5tXBk zQ4dfXs21w_4!{T073vbjP+!&o-l5*2j!;FuU#Op`AE@uB&!{)33lt3^Km>>Y5g-CY zfCvx)B0vQGF9NI_TFAiBBID=6%T&t!FcPs2RotF8bmt8nsFIj@X{L#1=`~rirC~D_ zEh%8WG}g6wJrxJx$Z2e=l%1J=pH}T^XK;Gj>-U^x(Iqx)sYgaK4gO= z!`H=?FYA?w^w^+(g%v+nvFIFvCta;)%EwddSkMO`NL9oPJHy8@y=WFjiAuA^LuWH}Lvx=%8(ulBaRIZD9up!e&=BX&kk<{3+rxj0OQG8*yvDrAm#l)us27dsFac(UD literal 0 HcmV?d00001 diff --git a/embedchain/embedder/ollama.py b/embedchain/embedder/ollama.py new file mode 100644 index 00000000..41001114 --- /dev/null +++ b/embedchain/embedder/ollama.py @@ -0,0 +1,19 @@ +from typing import Optional + +from langchain_community.embeddings import OllamaEmbeddings + +from embedchain.config import OllamaEmbedderConfig +from embedchain.embedder.base import BaseEmbedder +from embedchain.models import VectorDimensions + + +class OllamaEmbedder(BaseEmbedder): + def __init__(self, config: Optional[OllamaEmbedderConfig] = None): + super().__init__(config=config) + + embeddings = OllamaEmbeddings(model=self.config.model, base_url=self.config.base_url) + embedding_fn = BaseEmbedder._langchain_default_concept(embeddings) + self.set_embedding_fn(embedding_fn=embedding_fn) + + vector_dimension = self.config.vector_dimension or VectorDimensions.OLLAMA.value + self.set_vector_dimension(vector_dimension=vector_dimension) diff --git a/embedchain/factory.py b/embedchain/factory.py index fa46f3e9..81300ecc 100644 --- a/embedchain/factory.py +++ b/embedchain/factory.py @@ -58,6 +58,7 @@ class EmbedderFactory: "mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder", "nvidia": "embedchain.embedder.nvidia.NvidiaEmbedder", "cohere": "embedchain.embedder.cohere.CohereEmbedder", + "ollama": "embedchain.embedder.ollama.OllamaEmbedder", } provider_to_config_class = { "azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig", @@ -65,6 +66,7 @@ class EmbedderFactory: "gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig", "huggingface": "embedchain.config.embedder.base.BaseEmbedderConfig", "openai": "embedchain.config.embedder.base.BaseEmbedderConfig", + "ollama": "embedchain.config.embedder.ollama.OllamaEmbedderConfig", } @classmethod diff --git a/embedchain/models/embedding_functions.py b/embedchain/models/embedding_functions.py index 7967c45a..557e2ea5 100644 --- a/embedchain/models/embedding_functions.py +++ b/embedchain/models/embedding_functions.py @@ -6,3 +6,4 @@ class EmbeddingFunctions(Enum): HUGGING_FACE = "HUGGING_FACE" VERTEX_AI = "VERTEX_AI" GPT4ALL = "GPT4ALL" + OLLAMA = "OLLAMA" diff --git a/embedchain/models/vector_dimensions.py b/embedchain/models/vector_dimensions.py index 1e0c740a..3b5e33a3 100644 --- a/embedchain/models/vector_dimensions.py +++ b/embedchain/models/vector_dimensions.py @@ -11,3 +11,4 @@ class VectorDimensions(Enum): MISTRAL_AI = 1024 NVIDIA_AI = 1024 COHERE = 384 + OLLAMA = 384 diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py index 61aaea78..4b991116 100644 --- a/embedchain/utils/misc.py +++ b/embedchain/utils/misc.py @@ -449,6 +449,7 @@ def validate_config(config_data): "google", "mistralai", "nvidia", + "ollama", ), Optional("config"): { Optional("model"): Optional(str), @@ -458,6 +459,7 @@ def validate_config(config_data): Optional("title"): str, Optional("task_type"): str, Optional("vector_dimension"): int, + Optional("base_url"): str, }, }, Optional("embedding_model"): { @@ -470,6 +472,7 @@ def validate_config(config_data): "google", "mistralai", "nvidia", + "ollama", ), Optional("config"): { Optional("model"): str, @@ -478,6 +481,7 @@ def validate_config(config_data): Optional("title"): str, Optional("task_type"): str, Optional("vector_dimension"): int, + Optional("base_url"): str, }, }, Optional("chunker"): {