From 09cdaff9a2e740f71c403c457499bdc2e8976f7f Mon Sep 17 00:00:00 2001 From: Deshraj Yadav Date: Tue, 27 Feb 2024 15:10:41 -0800 Subject: [PATCH] [Improvement] Fix deprecation warnings (#1288) --- embedchain/helpers/callbacks.py | 2 +- embedchain/llm/aws_bedrock.py | 5 ++--- embedchain/llm/cohere.py | 2 +- embedchain/llm/llama2.py | 2 +- embedchain/llm/ollama.py | 2 +- embedchain/llm/together.py | 2 +- embedchain/llm/vllm.py | 2 +- pyproject.toml | 2 +- tests/llm/test_cohere.py | 10 +--------- tests/llm/test_llama2.py | 11 +---------- tests/llm/test_ollama.py | 10 +--------- tests/llm/test_together.py | 10 +--------- 12 files changed, 13 insertions(+), 47 deletions(-) diff --git a/embedchain/helpers/callbacks.py b/embedchain/helpers/callbacks.py index 994e0fdc..4847e0fe 100644 --- a/embedchain/helpers/callbacks.py +++ b/embedchain/helpers/callbacks.py @@ -56,7 +56,7 @@ def generate(rq: queue.Queue): ``` def askQuestion(callback_fn: StreamingStdOutCallbackHandlerYield): llm = OpenAI(streaming=True, callbacks=[callback_fn]) - return llm(prompt="Write a poem about a tree.") + return llm.invoke(prompt="Write a poem about a tree.") @app.route("/", methods=["GET"]) def generate_output(): diff --git a/embedchain/llm/aws_bedrock.py b/embedchain/llm/aws_bedrock.py index de3da5d0..34170981 100644 --- a/embedchain/llm/aws_bedrock.py +++ b/embedchain/llm/aws_bedrock.py @@ -38,12 +38,11 @@ class AWSBedrockLlm(BaseLlm): } if config.stream: - from langchain.callbacks.streaming_stdout import \ - StreamingStdOutCallbackHandler + from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler callbacks = [StreamingStdOutCallbackHandler()] llm = Bedrock(**kwargs, streaming=config.stream, callbacks=callbacks) else: llm = Bedrock(**kwargs) - return llm(prompt) + return llm.invoke(prompt) diff --git a/embedchain/llm/cohere.py b/embedchain/llm/cohere.py index be4407fd..d755db0c 100644 --- a/embedchain/llm/cohere.py +++ b/embedchain/llm/cohere.py @@ -40,4 +40,4 @@ class CohereLlm(BaseLlm): p=config.top_p, ) - return llm(prompt) + return llm.invoke(prompt) diff --git a/embedchain/llm/llama2.py b/embedchain/llm/llama2.py index 486dbc71..426239a8 100644 --- a/embedchain/llm/llama2.py +++ b/embedchain/llm/llama2.py @@ -48,4 +48,4 @@ class Llama2Llm(BaseLlm): "top_p": self.config.top_p, }, ) - return llm(prompt) + return llm.invoke(prompt) diff --git a/embedchain/llm/ollama.py b/embedchain/llm/ollama.py index e1065f4e..6d7f802b 100644 --- a/embedchain/llm/ollama.py +++ b/embedchain/llm/ollama.py @@ -33,4 +33,4 @@ class OllamaLlm(BaseLlm): callback_manager=CallbackManager(callback_manager), ) - return llm(prompt) + return llm.invoke(prompt) diff --git a/embedchain/llm/together.py b/embedchain/llm/together.py index 23462b1e..17995ca5 100644 --- a/embedchain/llm/together.py +++ b/embedchain/llm/together.py @@ -40,4 +40,4 @@ class TogetherLlm(BaseLlm): top_p=config.top_p, ) - return llm(prompt) + return llm.invoke(prompt) diff --git a/embedchain/llm/vllm.py b/embedchain/llm/vllm.py index faac1f39..88a8e2ad 100644 --- a/embedchain/llm/vllm.py +++ b/embedchain/llm/vllm.py @@ -37,4 +37,4 @@ class VLLM(BaseLlm): llm_args.update(config.model_kwargs) llm = BaseVLLM(**llm_args) - return llm(prompt) + return llm.invoke(prompt) diff --git a/pyproject.toml b/pyproject.toml index f1362687..70c93a85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "embedchain" -version = "0.1.87" +version = "0.1.88" description = "Simplest open source retrieval(RAG) framework" authors = [ "Taranjeet Singh ", diff --git a/tests/llm/test_cohere.py b/tests/llm/test_cohere.py index 1bee4cff..2671d473 100644 --- a/tests/llm/test_cohere.py +++ b/tests/llm/test_cohere.py @@ -39,18 +39,10 @@ def test_get_llm_model_answer(cohere_llm_config, mocker): def test_get_answer_mocked_cohere(cohere_llm_config, mocker): mocked_cohere = mocker.patch("embedchain.llm.cohere.Cohere") mock_instance = mocked_cohere.return_value - mock_instance.return_value = "Mocked answer" + mock_instance.invoke.return_value = "Mocked answer" llm = CohereLlm(cohere_llm_config) prompt = "Test query" answer = llm.get_llm_model_answer(prompt) assert answer == "Mocked answer" - mocked_cohere.assert_called_once_with( - cohere_api_key="test_api_key", - model="gptd-instruct-tft", - max_tokens=50, - temperature=0.7, - p=0.8, - ) - mock_instance.assert_called_once_with(prompt) diff --git a/tests/llm/test_llama2.py b/tests/llm/test_llama2.py index 40885fd2..a9dd4049 100644 --- a/tests/llm/test_llama2.py +++ b/tests/llm/test_llama2.py @@ -28,7 +28,7 @@ def test_get_llm_model_answer(llama2_llm, mocker): mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate") mocked_replicate_instance = mocker.MagicMock() mocked_replicate.return_value = mocked_replicate_instance - mocked_replicate_instance.return_value = "Test answer" + mocked_replicate_instance.invoke.return_value = "Test answer" llama2_llm.config.model = "test_model" llama2_llm.config.max_tokens = 50 @@ -38,12 +38,3 @@ def test_get_llm_model_answer(llama2_llm, mocker): answer = llama2_llm.get_llm_model_answer("Test query") assert answer == "Test answer" - mocked_replicate.assert_called_once_with( - model="test_model", - input={ - "temperature": 0.7, - "max_length": 50, - "top_p": 0.8, - }, - ) - mocked_replicate_instance.assert_called_once_with("Test query") diff --git a/tests/llm/test_ollama.py b/tests/llm/test_ollama.py index 34ab8238..62252783 100644 --- a/tests/llm/test_ollama.py +++ b/tests/llm/test_ollama.py @@ -22,18 +22,10 @@ def test_get_llm_model_answer(ollama_llm_config, mocker): def test_get_answer_mocked_ollama(ollama_llm_config, mocker): mocked_ollama = mocker.patch("embedchain.llm.ollama.Ollama") mock_instance = mocked_ollama.return_value - mock_instance.return_value = "Mocked answer" + mock_instance.invoke.return_value = "Mocked answer" llm = OllamaLlm(ollama_llm_config) prompt = "Test query" answer = llm.get_llm_model_answer(prompt) assert answer == "Mocked answer" - mocked_ollama.assert_called_once_with( - model="llama2", - system=None, - temperature=0.7, - top_p=0.8, - callback_manager=mocker.ANY, # Use mocker.ANY to ignore the exact instance - ) - mock_instance.assert_called_once_with(prompt) diff --git a/tests/llm/test_together.py b/tests/llm/test_together.py index e5b668d5..a72cfe13 100644 --- a/tests/llm/test_together.py +++ b/tests/llm/test_together.py @@ -39,18 +39,10 @@ def test_get_llm_model_answer(together_llm_config, mocker): def test_get_answer_mocked_together(together_llm_config, mocker): mocked_together = mocker.patch("embedchain.llm.together.Together") mock_instance = mocked_together.return_value - mock_instance.return_value = "Mocked answer" + mock_instance.invoke.return_value = "Mocked answer" llm = TogetherLlm(together_llm_config) prompt = "Test query" answer = llm.get_llm_model_answer(prompt) assert answer == "Mocked answer" - mocked_together.assert_called_once_with( - together_api_key="test_api_key", - model="togethercomputer/RedPajama-INCITE-7B-Base", - max_tokens=50, - temperature=0.7, - top_p=0.8, - ) - mock_instance.assert_called_once_with(prompt)