From 73e53aaff1922da11e91b6b0a3c7341bdaee7c5e Mon Sep 17 00:00:00 2001
From: Dev Khant
Date: Sun, 9 Jun 2024 12:13:03 +0530
Subject: [PATCH] Download Ollama model if not present (#1397)

---
 Makefile                 |  2 +-
 embedchain/llm/ollama.py | 10 ++++++++++
 tests/llm/test_ollama.py |  2 ++
 3 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index 571d4eb8..8ddb58e7 100644
--- a/Makefile
+++ b/Makefile
@@ -11,7 +11,7 @@ install:
 
 install_all:
 	poetry install --all-extras
-	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]"
+	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama
 
 install_es:
 	poetry install --extras elasticsearch
diff --git a/embedchain/llm/ollama.py b/embedchain/llm/ollama.py
index 221753d3..7de675cd 100644
--- a/embedchain/llm/ollama.py
+++ b/embedchain/llm/ollama.py
@@ -1,3 +1,4 @@
+import logging
 from collections.abc import Iterable
 from typing import Optional, Union
 
@@ -5,11 +6,14 @@ from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.stdout import StdOutCallbackHandler
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain_community.llms.ollama import Ollama
+from ollama import Client
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class OllamaLlm(BaseLlm):
@@ -18,6 +22,12 @@ class OllamaLlm(BaseLlm):
         if self.config.model is None:
             self.config.model = "llama2"
 
+        client = Client(host=config.base_url)
+        local_models = client.list()["models"]
+        if not any(model.get("name") == self.config.model for model in local_models):
+            logger.info(f"Pulling {self.config.model} from Ollama!")
+            client.pull(self.config.model)
+
     def get_llm_model_answer(self, prompt):
         return self._get_answer(prompt=prompt, config=self.config)
 
diff --git a/tests/llm/test_ollama.py b/tests/llm/test_ollama.py
index 62252783..ff16d126 100644
--- a/tests/llm/test_ollama.py
+++ b/tests/llm/test_ollama.py
@@ -11,6 +11,7 @@ def ollama_llm_config():
 
 
 def test_get_llm_model_answer(ollama_llm_config, mocker):
+    mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
     mocker.patch("embedchain.llm.ollama.OllamaLlm._get_answer", return_value="Test answer")
 
     llm = OllamaLlm(ollama_llm_config)
@@ -20,6 +21,7 @@ def test_get_llm_model_answer(ollama_llm_config, mocker):
 
 
 def test_get_answer_mocked_ollama(ollama_llm_config, mocker):
+    mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
     mocked_ollama = mocker.patch("embedchain.llm.ollama.Ollama")
     mock_instance = mocked_ollama.return_value
     mock_instance.invoke.return_value = "Mocked answer"
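
Taken on its own, the check this patch adds to OllamaLlm.__init__ is an
"ensure the model is present before first use" step: list the models the
Ollama server already has, and pull the configured one only if it is missing.
A minimal standalone sketch of that same flow, assuming the dict-shaped
Client.list() response the patch relies on; the helper name, host URL, and
model name below are illustrative, not part of the patch:

    import logging

    from ollama import Client

    logger = logging.getLogger(__name__)


    def ensure_model_available(host: str, model: str) -> None:
        # Client.list() reports the models already present on the server,
        # shaped like {"models": [{"name": "llama2:latest", ...}, ...]}.
        client = Client(host=host)
        local_models = client.list()["models"]
        # Pull from the Ollama registry only when no local entry matches.
        if not any(m.get("name") == model for m in local_models):
            logger.info(f"Pulling {model} from Ollama!")
            client.pull(model)


    if __name__ == "__main__":
        # Default Ollama host and the patch's default model, both illustrative.
        ensure_model_available("http://localhost:11434", "llama2")

One behavior note: Ollama usually reports installed models with an explicit
tag (e.g. "llama2:latest"), so an exact-name comparison against a bare
"llama2" can miss and trigger a pull even when the model is already local;
pulling an existing model mostly re-verifies its layers rather than
re-downloading them.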