From dc0d8e09324c52b6c9e8af98b23edcf64ce0bf9a Mon Sep 17 00:00:00 2001
From: Aditya Veer Parmar
Date: Mon, 17 Jun 2024 21:14:52 +0530
Subject: [PATCH] Allow ollama llm to take custom callback for handling
 streaming (#1376)

---
 embedchain/llm/ollama.py |  7 +++++--
 tests/llm/test_ollama.py | 18 ++++++++++++++++++
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/embedchain/llm/ollama.py b/embedchain/llm/ollama.py
index 7de675cd..c21c522d 100644
--- a/embedchain/llm/ollama.py
+++ b/embedchain/llm/ollama.py
@@ -33,14 +33,17 @@ class OllamaLlm(BaseLlm):
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
-        callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
+        if config.stream:
+            callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
+        else:
+            callbacks = [StdOutCallbackHandler()]
 
         llm = Ollama(
             model=config.model,
             system=config.system_prompt,
             temperature=config.temperature,
             top_p=config.top_p,
-            callback_manager=CallbackManager(callback_manager),
+            callback_manager=CallbackManager(callbacks),
             base_url=config.base_url,
         )
 
diff --git a/tests/llm/test_ollama.py b/tests/llm/test_ollama.py
index ff16d126..36a62b63 100644
--- a/tests/llm/test_ollama.py
+++ b/tests/llm/test_ollama.py
@@ -2,6 +2,7 @@ import pytest
 
 from embedchain.config import BaseLlmConfig
 from embedchain.llm.ollama import OllamaLlm
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 
 @pytest.fixture
@@ -31,3 +32,20 @@ def test_get_answer_mocked_ollama(ollama_llm_config, mocker):
     answer = llm.get_llm_model_answer(prompt)
 
     assert answer == "Mocked answer"
+
+
+def test_get_llm_model_answer_with_streaming(ollama_llm_config, mocker):
+    ollama_llm_config.stream = True
+    ollama_llm_config.callbacks = [StreamingStdOutCallbackHandler()]
+    mocked_ollama_chat = mocker.patch("embedchain.llm.ollama.OllamaLlm._get_answer", return_value="Test answer")
+
+    llm = OllamaLlm(ollama_llm_config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_ollama_chat.assert_called_once()
+    call_args = mocked_ollama_chat.call_args
+    config_arg = call_args[1]["config"]
+    callbacks = config_arg.callbacks
+
+    assert len(callbacks) == 1
+    assert isinstance(callbacks[0], StreamingStdOutCallbackHandler)