[docs]: Revamp embedchain docs (#799)
@@ -85,11 +85,11 @@ class TestApp(unittest.TestCase):
        a where filter and 'get_llm_model_answer' returns an expected answer string.

        The 'chat' method is expected to call 'retrieve_from_database' with the where filter specified
-       in the LlmConfig and 'get_llm_model_answer' methods appropriately and return the right answer.
+       in the BaseLlmConfig and 'get_llm_model_answer' methods appropriately and return the right answer.

        Key assumptions tested:
        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
-         LLmConfig.
+         BaseLlmConfig.
        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
        - 'chat' method returns the value it received from 'get_llm_model_answer'.
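The docstring above describes a mock-and-assert pattern that is easy to reproduce outside embedchain. Below is a minimal, self-contained sketch of that pattern; FakeApp and its method signatures are made up for illustration and only mirror the chat -> retrieve_from_database -> get_llm_model_answer flow named in the hunk.

import unittest
from unittest.mock import MagicMock


class FakeApp:
    """Made-up stand-in that mirrors the chat -> retrieve -> answer flow."""

    def retrieve_from_database(self, query, config):
        raise NotImplementedError  # replaced by a mock in the test

    def get_llm_model_answer(self, prompt):
        raise NotImplementedError  # replaced by a mock in the test

    def chat(self, query, config):
        contexts = self.retrieve_from_database(query, config)
        return self.get_llm_model_answer(f"{contexts} {query}")


class TestChatMockPattern(unittest.TestCase):
    def test_chat_delegates_to_mocks(self):
        app = FakeApp()
        app.retrieve_from_database = MagicMock(return_value=["some context"])
        app.get_llm_model_answer = MagicMock(return_value="Test answer")
        config = object()  # stands in for the BaseLlmConfig carrying the where filter

        answer = app.chat("Test query", config)

        # The three assumptions listed in the docstring above, verified directly.
        app.retrieve_from_database.assert_called_once_with("Test query", config)
        app.get_llm_model_answer.assert_called_once()
        self.assertEqual(answer, "Test answer")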
@@ -12,7 +12,7 @@ class TestGeneratePrompt(unittest.TestCase):
    def test_generate_prompt_with_template(self):
        """
        Tests that the generate_prompt method correctly formats the prompt using
-       a custom template provided in the LlmConfig instance.
+       a custom template provided in the BaseLlmConfig instance.

        This test sets up a scenario with an input query and a list of contexts,
        and a custom template, and then calls generate_prompt. It checks that the
@@ -58,7 +58,7 @@ class TestGeneratePrompt(unittest.TestCase):

    def test_generate_prompt_with_history(self):
        """
-       Test the 'generate_prompt' method with LlmConfig containing a history attribute.
+       Test the 'generate_prompt' method with BaseLlmConfig containing a history attribute.
        """
        config = BaseLlmConfig()
        config.template = Template("Context: $context | Query: $query | History: $history")
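The template assigned here is a standard-library string.Template, so the formatting that generate_prompt is expected to produce can be sketched with nothing but the standard library. How embedchain actually joins the contexts and the history is an assumption in the snippet below; only the template string itself is taken from the test.

from string import Template

# Same template string as in the test above.
template = Template("Context: $context | Query: $query | History: $history")

contexts = ["Paris is the capital of France.", "France is in Europe."]
history = ["Q: Hi", "A: Hello!"]

# Plausible rendering step: join the pieces and substitute them into the template.
prompt = template.substitute(
    context=" ".join(contexts),
    query="What is the capital of France?",
    history=" ".join(history),
)
print(prompt)
# Context: Paris is the capital of France. France is in Europe. | Query: What is the capital of France? | History: Q: Hi A: Hello!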
@@ -4,21 +4,21 @@ import unittest
from unittest.mock import MagicMock, patch

from embedchain.config import BaseLlmConfig
-from embedchain.llm.hugging_face_hub import HuggingFaceHubLlm
+from embedchain.llm.huggingface import HuggingFaceLlm


-class TestHuggingFaceHubLlm(unittest.TestCase):
+class TestHuggingFaceLlm(unittest.TestCase):
    def setUp(self):
-       os.environ["HUGGINGFACEHUB_ACCESS_TOKEN"] = "test_access_token"
+       os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
        self.config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)

    def test_init_raises_value_error_without_api_key(self):
-       os.environ.pop("HUGGINGFACEHUB_ACCESS_TOKEN")
+       os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
        with self.assertRaises(ValueError):
-           HuggingFaceHubLlm()
+           HuggingFaceLlm()

    def test_get_llm_model_answer_raises_value_error_for_system_prompt(self):
-       llm = HuggingFaceHubLlm(self.config)
+       llm = HuggingFaceLlm(self.config)
        llm.config.system_prompt = "system_prompt"
        with self.assertRaises(ValueError):
            llm.get_llm_model_answer("prompt")
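Taken together, this hunk renames the module, the class, and the access-token environment variable. For code that used the old names, the before/after looks roughly like this; the snippet is assembled from the removed and added lines above, not from an official migration note, and the token value is a placeholder.

import os

# Before (removed lines in the hunk above):
#   from embedchain.llm.hugging_face_hub import HuggingFaceHubLlm
#   os.environ["HUGGINGFACEHUB_ACCESS_TOKEN"] = "..."

# After (added lines):
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "hf_xxx"  # placeholder token

from embedchain.config import BaseLlmConfig
from embedchain.llm.huggingface import HuggingFaceLlm

config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)
llm = HuggingFaceLlm(config)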
@@ -26,7 +26,7 @@ class TestHuggingFaceHubLlm(unittest.TestCase):
    def test_top_p_value_within_range(self):
        config = BaseLlmConfig(top_p=1.0)
        with self.assertRaises(ValueError):
-           HuggingFaceHubLlm._get_answer("test_prompt", config)
+           HuggingFaceLlm._get_answer("test_prompt", config)

    def test_dependency_is_imported(self):
        importlib_installed = True
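The top_p test implies that _get_answer validates the nucleus-sampling parameter before calling the model. The guard below is only a guess at that validation: the test guarantees no more than that top_p=1.0 raises ValueError, so the exact bounds and message are assumptions.

def _validate_top_p(top_p: float) -> None:
    # Hypothetical check; an exclusive (0, 1) range is a common convention
    # for nucleus sampling and is consistent with the test above.
    if not 0.0 < top_p < 1.0:
        raise ValueError("top_p must be strictly between 0 and 1")


_validate_top_p(0.8)    # passes
# _validate_top_p(1.0)  # would raise ValueError, matching the test's expectation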
@@ -36,27 +36,27 @@ class TestHuggingFaceHubLlm(unittest.TestCase):
            importlib_installed = False
        self.assertTrue(importlib_installed)

-   @patch("embedchain.llm.hugging_face_hub.HuggingFaceHubLlm._get_answer")
+   @patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer")
    def test_get_llm_model_answer(self, mock_get_answer):
        mock_get_answer.return_value = "Test answer"

-       llm = HuggingFaceHubLlm(self.config)
+       llm = HuggingFaceLlm(self.config)
        answer = llm.get_llm_model_answer("Test query")

        self.assertEqual(answer, "Test answer")
        mock_get_answer.assert_called_once()

-   @patch("embedchain.llm.hugging_face_hub.HuggingFaceHub")
-   def test_hugging_face_mock(self, mock_hugging_face_hub):
+   @patch("embedchain.llm.huggingface.HuggingFaceHub")
+   def test_hugging_face_mock(self, mock_huggingface):
        mock_llm_instance = MagicMock()
        mock_llm_instance.return_value = "Test answer"
-       mock_hugging_face_hub.return_value = mock_llm_instance
+       mock_huggingface.return_value = mock_llm_instance

-       llm = HuggingFaceHubLlm(self.config)
+       llm = HuggingFaceLlm(self.config)
        answer = llm.get_llm_model_answer("Test query")

        self.assertEqual(answer, "Test answer")
-       mock_hugging_face_hub.assert_called_once_with(
+       mock_huggingface.assert_called_once_with(
            huggingfacehub_api_token="test_access_token",
            repo_id="google/flan-t5-xxl",
            model_kwargs={"temperature": 0.7, "max_new_tokens": 50, "top_p": 0.8},
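The second test above patches HuggingFaceHub where it is looked up (the embedchain.llm.huggingface module), not where it is defined, which is why the patch target has to change together with the module rename. Below is a self-contained sketch of that patch-where-it-is-used rule; the Client class and get_answer function are invented for illustration.

import unittest
from unittest.mock import MagicMock, patch


class Client:
    """Made-up stand-in for a third-party class such as HuggingFaceHub."""

    def __call__(self, prompt):
        raise RuntimeError("would hit the network")


def get_answer(prompt):
    # Looks the Client name up in *this* module, so tests must patch it here.
    client = Client()
    return client(prompt)


class TestPatchTarget(unittest.TestCase):
    @patch(f"{__name__}.Client")  # patch where the name is used, not where it is defined
    def test_get_answer(self, mock_client):
        mock_client.return_value = MagicMock(return_value="Test answer")

        self.assertEqual(get_answer("Test query"), "Test answer")
        mock_client.assert_called_once_with()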
@@ -24,7 +24,7 @@ class TestApp(unittest.TestCase):

        Key assumptions tested:
        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
-         LlmConfig.
+         BaseLlmConfig.
        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
        - 'query' method returns the value it received from 'get_llm_model_answer'.

@@ -82,7 +82,7 @@ class TestApp(unittest.TestCase):

        Key assumptions tested:
        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
-         LlmConfig.
+         BaseLlmConfig.
        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
        - 'query' method returns the value it received from 'get_llm_model_answer'.

@@ -113,7 +113,7 @@ class TestApp(unittest.TestCase):

        Key assumptions tested:
        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
-         LlmConfig.
+         BaseLlmConfig.
        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
        - 'query' method returns the value it received from 'get_llm_model_answer'.
@@ -4,17 +4,17 @@ import pytest
from langchain.schema import HumanMessage, SystemMessage

from embedchain.config import BaseLlmConfig
-from embedchain.llm.vertex_ai import VertexAiLlm
+from embedchain.llm.vertex_ai import VertexAILlm


@pytest.fixture
def vertexai_llm():
    config = BaseLlmConfig(temperature=0.6, model="vertexai_model", system_prompt="System Prompt")
-   return VertexAiLlm(config)
+   return VertexAILlm(config)


def test_get_llm_model_answer(vertexai_llm):
-   with patch.object(VertexAiLlm, "_get_answer", return_value="Test Response") as mock_method:
+   with patch.object(VertexAILlm, "_get_answer", return_value="Test Response") as mock_method:
        prompt = "Test Prompt"
        response = vertexai_llm.get_llm_model_answer(prompt)
        assert response == "Test Response"
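The VertexAI test combines a pytest fixture with patch.object. A self-contained sketch of that combination follows; the Greeter class is a made-up stand-in, not part of embedchain.

from unittest.mock import patch

import pytest


class Greeter:
    """Made-up stand-in for a class like VertexAILlm."""

    def _get_answer(self, prompt):
        raise RuntimeError("would call an external service")

    def get_llm_model_answer(self, prompt):
        return self._get_answer(prompt)


@pytest.fixture
def greeter():
    # The fixture builds a fresh object for every test that requests it.
    return Greeter()


def test_get_llm_model_answer(greeter):
    # patch.object swaps the method on the class for the duration of the block.
    with patch.object(Greeter, "_get_answer", return_value="Test Response") as mock_method:
        assert greeter.get_llm_model_answer("Test Prompt") == "Test Response"
        mock_method.assert_called_once_with("Test Prompt")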