diff --git a/configs/ollama.yaml b/configs/ollama.yaml index da4be9c2..7ec5def5 100644 --- a/configs/ollama.yaml +++ b/configs/ollama.yaml @@ -8,6 +8,7 @@ llm: base_url: http://localhost:11434 embedder: - provider: huggingface + provider: ollama config: - model: 'BAAI/bge-small-en-v1.5' + model: 'mxbai-embed-large:latest' + base_url: http://localhost:11434 diff --git a/embedchain/config/__init__.py b/embedchain/config/__init__.py index 0e980c81..c6bd4f30 100644 --- a/embedchain/config/__init__.py +++ b/embedchain/config/__init__.py @@ -6,6 +6,7 @@ from .base_config import BaseConfig from .cache_config import CacheConfig from .embedder.base import BaseEmbedderConfig from .embedder.base import BaseEmbedderConfig as EmbedderConfig +from .embedder.ollama import OllamaEmbedderConfig from .llm.base import BaseLlmConfig from .vectordb.chroma import ChromaDbConfig from .vectordb.elasticsearch import ElasticsearchDBConfig diff --git a/embedchain/config/embedder/ollama.py b/embedchain/config/embedder/ollama.py new file mode 100644 index 00000000..b08b39d7 --- /dev/null +++ b/embedchain/config/embedder/ollama.py @@ -0,0 +1,15 @@ +from typing import Optional + +from embedchain.config.embedder.base import BaseEmbedderConfig +from embedchain.helpers.json_serializable import register_deserializable + + +@register_deserializable +class OllamaEmbedderConfig(BaseEmbedderConfig): + def __init__( + self, + model: Optional[str] = None, + base_url: Optional[str] = None, + ): + super().__init__(model) + self.base_url = base_url or "http://127.0.0.1:11434" diff --git a/embedchain/embedder/.ollama.py.swp b/embedchain/embedder/.ollama.py.swp new file mode 100644 index 00000000..85ec23a1 Binary files /dev/null and b/embedchain/embedder/.ollama.py.swp differ diff --git a/embedchain/embedder/ollama.py b/embedchain/embedder/ollama.py new file mode 100644 index 00000000..41001114 --- /dev/null +++ b/embedchain/embedder/ollama.py @@ -0,0 +1,19 @@ +from typing import Optional + +from langchain_community.embeddings import OllamaEmbeddings + +from embedchain.config import OllamaEmbedderConfig +from embedchain.embedder.base import BaseEmbedder +from embedchain.models import VectorDimensions + + +class OllamaEmbedder(BaseEmbedder): + def __init__(self, config: Optional[OllamaEmbedderConfig] = None): + super().__init__(config=config) + + embeddings = OllamaEmbeddings(model=self.config.model, base_url=self.config.base_url) + embedding_fn = BaseEmbedder._langchain_default_concept(embeddings) + self.set_embedding_fn(embedding_fn=embedding_fn) + + vector_dimension = self.config.vector_dimension or VectorDimensions.OLLAMA.value + self.set_vector_dimension(vector_dimension=vector_dimension) diff --git a/embedchain/factory.py b/embedchain/factory.py index fa46f3e9..81300ecc 100644 --- a/embedchain/factory.py +++ b/embedchain/factory.py @@ -58,6 +58,7 @@ class EmbedderFactory: "mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder", "nvidia": "embedchain.embedder.nvidia.NvidiaEmbedder", "cohere": "embedchain.embedder.cohere.CohereEmbedder", + "ollama": "embedchain.embedder.ollama.OllamaEmbedder", } provider_to_config_class = { "azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig", @@ -65,6 +66,7 @@ class EmbedderFactory: "gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig", "huggingface": "embedchain.config.embedder.base.BaseEmbedderConfig", "openai": "embedchain.config.embedder.base.BaseEmbedderConfig", + "ollama": "embedchain.config.embedder.ollama.OllamaEmbedderConfig", } @classmethod diff --git a/embedchain/models/embedding_functions.py b/embedchain/models/embedding_functions.py index 7967c45a..557e2ea5 100644 --- a/embedchain/models/embedding_functions.py +++ b/embedchain/models/embedding_functions.py @@ -6,3 +6,4 @@ class EmbeddingFunctions(Enum): HUGGING_FACE = "HUGGING_FACE" VERTEX_AI = "VERTEX_AI" GPT4ALL = "GPT4ALL" + OLLAMA = "OLLAMA" diff --git a/embedchain/models/vector_dimensions.py b/embedchain/models/vector_dimensions.py index 1e0c740a..3b5e33a3 100644 --- a/embedchain/models/vector_dimensions.py +++ b/embedchain/models/vector_dimensions.py @@ -11,3 +11,4 @@ class VectorDimensions(Enum): MISTRAL_AI = 1024 NVIDIA_AI = 1024 COHERE = 384 + OLLAMA = 384 diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py index 61aaea78..4b991116 100644 --- a/embedchain/utils/misc.py +++ b/embedchain/utils/misc.py @@ -449,6 +449,7 @@ def validate_config(config_data): "google", "mistralai", "nvidia", + "ollama", ), Optional("config"): { Optional("model"): Optional(str), @@ -458,6 +459,7 @@ def validate_config(config_data): Optional("title"): str, Optional("task_type"): str, Optional("vector_dimension"): int, + Optional("base_url"): str, }, }, Optional("embedding_model"): { @@ -470,6 +472,7 @@ def validate_config(config_data): "google", "mistralai", "nvidia", + "ollama", ), Optional("config"): { Optional("model"): str, @@ -478,6 +481,7 @@ def validate_config(config_data): Optional("title"): str, Optional("task_type"): str, Optional("vector_dimension"): int, + Optional("base_url"): str, }, }, Optional("chunker"): {