From 03a84daf9d2401af6d2a4d940f14fd2982fa8278 Mon Sep 17 00:00:00 2001
From: Sidharth Mohanty
Date: Tue, 10 Oct 2023 00:24:24 +0530
Subject: [PATCH] Add Cohere LLM support (#751)

---
 docs/advanced/app_types.mdx |  1 +
 embedchain/llm/cohere.py    | 43 +++++++++++++++++++++++++++++++++++++
 pyproject.toml              |  2 ++
 tests/llm/test_cohere.py    | 33 ++++++++++++++++++++++++++++
 4 files changed, 79 insertions(+)
 create mode 100644 embedchain/llm/cohere.py
 create mode 100644 tests/llm/test_cohere.py

diff --git a/docs/advanced/app_types.mdx b/docs/advanced/app_types.mdx
index d8a8f54f..68ddfc61 100644
--- a/docs/advanced/app_types.mdx
+++ b/docs/advanced/app_types.mdx
@@ -27,6 +27,7 @@ The following LLM providers are supported by Embedchain:
 - GPT4ALL
 - AZURE_OPENAI
 - LLAMA2
+- COHERE
 
 You can choose one by importing it from `embedchain.llm`. E.g.:
 
diff --git a/embedchain/llm/cohere.py b/embedchain/llm/cohere.py
new file mode 100644
index 00000000..0811c067
--- /dev/null
+++ b/embedchain/llm/cohere.py
@@ -0,0 +1,43 @@
+import importlib
+import os
+from typing import Optional
+
+from langchain.llms import Cohere
+
+from embedchain.config import BaseLlmConfig
+from embedchain.helper.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
+
+
+@register_deserializable
+class CohereLlm(BaseLlm):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        if "COHERE_API_KEY" not in os.environ:
+            raise ValueError("Please set the COHERE_API_KEY environment variable.")
+
+        try:
+            importlib.import_module("cohere")
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "The required dependencies for Cohere are not installed. "
+                'Please install with `pip install --upgrade "embedchain[cohere]"`'
+            ) from None
+
+        super().__init__(config=config)
+
+    def get_llm_model_answer(self, prompt):
+        if self.config.system_prompt:
+            raise ValueError("CohereLlm does not support `system_prompt`")
+        return CohereLlm._get_answer(prompt=prompt, config=self.config)
+
+    @staticmethod
+    def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
+        llm = Cohere(
+            cohere_api_key=os.environ["COHERE_API_KEY"],
+            model=config.model,
+            max_tokens=config.max_tokens,
+            temperature=config.temperature,
+            p=config.top_p,
+        )
+
+        return llm(prompt)
diff --git a/pyproject.toml b/pyproject.toml
index 758d3705..7bfc09d2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -105,6 +105,7 @@ twilio = { version = "^8.5.0", optional = true }
 fastapi-poe = { version = "0.0.16", optional = true }
 discord = { version = "^2.3.2", optional = true }
 slack-sdk = { version = "3.21.3", optional = true }
+cohere = { version = "^4.27", optional = true }
 docx2txt = "^0.8"
 unstructured = {extras = ["local-inference"], version = "^0.10.18"}
 pillow = { version = "10.0.1", optional = true }
@@ -134,6 +135,7 @@ discord = ["discord"]
 slack = ["slack-sdk", "flask"]
 whatsapp = ["twilio", "flask"]
 images = ["torch", "ftfy", "regex", "pillow", "torchvision"]
+cohere = ["cohere"]
 
 [tool.poetry.group.docs.dependencies]
 
diff --git a/tests/llm/test_cohere.py b/tests/llm/test_cohere.py
new file mode 100644
index 00000000..c7372447
--- /dev/null
+++ b/tests/llm/test_cohere.py
@@ -0,0 +1,33 @@
+import os
+import unittest
+from unittest.mock import patch
+
+from embedchain.config import BaseLlmConfig
+from embedchain.llm.cohere import CohereLlm
+
+
+class TestCohereLlm(unittest.TestCase):
+    def setUp(self):
+        os.environ["COHERE_API_KEY"] = "test_api_key"
+        self.config = BaseLlmConfig(model="gptd-instruct-tft", max_tokens=50, temperature=0.7, top_p=0.8)
+
+    def test_init_raises_value_error_without_api_key(self):
+        os.environ.pop("COHERE_API_KEY")
+        with self.assertRaises(ValueError):
+            CohereLlm()
+
+    def test_get_llm_model_answer_raises_value_error_for_system_prompt(self):
+        llm = CohereLlm(self.config)
+        llm.config.system_prompt = "system_prompt"
+        with self.assertRaises(ValueError):
+            llm.get_llm_model_answer("prompt")
+
+    @patch("embedchain.llm.cohere.CohereLlm._get_answer")
+    def test_get_llm_model_answer(self, mock_get_answer):
+        mock_get_answer.return_value = "Test answer"
+
+        llm = CohereLlm(self.config)
+        answer = llm.get_llm_model_answer("Test query")
+
+        self.assertEqual(answer, "Test answer")
+        mock_get_answer.assert_called_once()
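
--
Note for reviewers: a minimal usage sketch for the new class, exercising only what
this patch adds. CohereLlm requires COHERE_API_KEY in the environment and forwards
model, max_tokens, temperature, and top_p from BaseLlmConfig to langchain's Cohere
wrapper (top_p becomes Cohere's `p`). The key value is a placeholder, and the model
name "command" is an assumption; this patch does not pin a default model.

    import os

    from embedchain.config import BaseLlmConfig
    from embedchain.llm.cohere import CohereLlm

    # Placeholder key: CohereLlm raises ValueError if this is unset.
    os.environ["COHERE_API_KEY"] = "your-cohere-api-key"

    # These four fields mirror what _get_answer() passes to langchain's Cohere.
    config = BaseLlmConfig(
        model="command",  # assumed Cohere model name, not set by this patch
        max_tokens=50,
        temperature=0.7,
        top_p=0.8,  # mapped to Cohere's `p` parameter
    )

    llm = CohereLlm(config=config)
    print(llm.get_llm_model_answer("What does Embedchain do?"))

Leave config.system_prompt unset: get_llm_model_answer() rejects it with a
ValueError, as covered by the tests above.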