From e226a89637683791fb51968a742850753ca8c033 Mon Sep 17 00:00:00 2001
From: Sidharth Mohanty
Date: Tue, 10 Oct 2023 00:36:36 +0530
Subject: [PATCH] Add Jina LLM support (#760)

---
 docs/advanced/app_types.mdx |  1 +
 embedchain/llm/jina.py      | 42 +++++++++++++++++++++++++++++++++++++
 tests/llm/test_jina.py      | 40 +++++++++++++++++++++++++++++++++++
 3 files changed, 83 insertions(+)
 create mode 100644 embedchain/llm/jina.py
 create mode 100644 tests/llm/test_jina.py

diff --git a/docs/advanced/app_types.mdx b/docs/advanced/app_types.mdx
index 68ddfc61..d58a9dca 100644
--- a/docs/advanced/app_types.mdx
+++ b/docs/advanced/app_types.mdx
@@ -27,6 +27,7 @@ The following LLM providers are supported by Embedchain:
 - GPT4ALL
 - AZURE_OPENAI
 - LLAMA2
+- JINA
 - COHERE
 
 You can choose one by importing it from `embedchain.llm`. E.g.:
diff --git a/embedchain/llm/jina.py b/embedchain/llm/jina.py
new file mode 100644
index 00000000..2c906c7e
--- /dev/null
+++ b/embedchain/llm/jina.py
@@ -0,0 +1,42 @@
+import os
+from typing import Optional
+
+from langchain.chat_models import JinaChat
+from langchain.schema import HumanMessage, SystemMessage
+
+from embedchain.config import BaseLlmConfig
+from embedchain.helper.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
+
+
+@register_deserializable
+class JinaLlm(BaseLlm):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        if "JINACHAT_API_KEY" not in os.environ:
+            raise ValueError("Please set the JINACHAT_API_KEY environment variable.")
+        super().__init__(config=config)
+
+    def get_llm_model_answer(self, prompt):
+        response = JinaLlm._get_answer(prompt, self.config)
+        return response
+
+    @staticmethod
+    def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
+        messages = []
+        if config.system_prompt:
+            messages.append(SystemMessage(content=config.system_prompt))
+        messages.append(HumanMessage(content=prompt))
+        kwargs = {
+            "temperature": config.temperature,
+            "max_tokens": config.max_tokens,
+            "model_kwargs": {},
+        }
+        if config.top_p:
+            kwargs["model_kwargs"]["top_p"] = config.top_p
+        if config.stream:
+            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
+            chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
+        else:
+            chat = JinaChat(**kwargs)
+        return chat(messages).content
diff --git a/tests/llm/test_jina.py b/tests/llm/test_jina.py
new file mode 100644
index 00000000..49793d57
--- /dev/null
+++ b/tests/llm/test_jina.py
@@ -0,0 +1,40 @@
+import os
+import unittest
+from unittest.mock import patch
+
+from embedchain.config import BaseLlmConfig
+from embedchain.llm.jina import JinaLlm
+
+
+class TestJinaLlm(unittest.TestCase):
+    def setUp(self):
+        os.environ["JINACHAT_API_KEY"] = "test_api_key"
+        self.config = BaseLlmConfig(
+            temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt"
+        )
+
+    def test_init_raises_value_error_without_api_key(self):
+        os.environ.pop("JINACHAT_API_KEY")
+        with self.assertRaises(ValueError):
+            JinaLlm()
+
+    @patch("embedchain.llm.jina.JinaLlm._get_answer")
+    def test_get_llm_model_answer(self, mock_get_answer):
+        mock_get_answer.return_value = "Test answer"
+
+        llm = JinaLlm(self.config)
+        answer = llm.get_llm_model_answer("Test query")
+
+        self.assertEqual(answer, "Test answer")
+        mock_get_answer.assert_called_once()
+
+    @patch("embedchain.llm.jina.JinaLlm._get_answer")
+    def test_get_llm_model_answer_with_system_prompt(self, mock_get_answer):
+        self.config.system_prompt = "Custom system prompt"
+        mock_get_answer.return_value = "Test answer"
+
+        llm = JinaLlm(self.config)
+        answer = llm.get_llm_model_answer("Test query")
+
+        self.assertEqual(answer, "Test answer")
+        mock_get_answer.assert_called_once()
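
Note for reviewers (not part of the patch): a minimal usage sketch of the new
class. The API key value, config values, and prompt below are placeholders;
the constructor arguments are the same ones exercised in tests/llm/test_jina.py.

    import os

    from embedchain.config import BaseLlmConfig
    from embedchain.llm.jina import JinaLlm

    # JinaLlm.__init__ raises ValueError if this variable is unset.
    os.environ["JINACHAT_API_KEY"] = "your-jinachat-key"  # placeholder

    # stream=True would instead print tokens to stdout via
    # StreamingStdOutCallbackHandler; top_p is forwarded through model_kwargs.
    config = BaseLlmConfig(temperature=0.2, max_tokens=100, top_p=0.9, stream=False)
    llm = JinaLlm(config=config)
    print(llm.get_llm_model_answer("What does this patch add?"))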