Rename embedchain to mem0 and open sourcing code for long term memory (#1474)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
embedchain/tests/llm/test_anthrophic.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import os
from unittest.mock import patch

import pytest
from langchain.schema import HumanMessage, SystemMessage

from embedchain.config import BaseLlmConfig
from embedchain.llm.anthropic import AnthropicLlm


@pytest.fixture
def anthropic_llm():
    os.environ["ANTHROPIC_API_KEY"] = "test_api_key"
    config = BaseLlmConfig(temperature=0.5, model="claude-instant-1", token_usage=False)
    return AnthropicLlm(config)


def test_get_llm_model_answer(anthropic_llm):
    with patch.object(AnthropicLlm, "_get_answer", return_value="Test Response") as mock_method:
        prompt = "Test Prompt"
        response = anthropic_llm.get_llm_model_answer(prompt)
        assert response == "Test Response"
        mock_method.assert_called_once_with(prompt, anthropic_llm.config)


def test_get_messages(anthropic_llm):
    prompt = "Test Prompt"
    system_prompt = "Test System Prompt"
    messages = anthropic_llm._get_messages(prompt, system_prompt)
    assert messages == [
        SystemMessage(content="Test System Prompt", additional_kwargs={}),
        HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
    ]


def test_get_llm_model_answer_with_token_usage(anthropic_llm):
    test_config = BaseLlmConfig(
        temperature=anthropic_llm.config.temperature, model=anthropic_llm.config.model, token_usage=True
    )
    anthropic_llm.config = test_config
    with patch.object(
        AnthropicLlm, "_get_answer", return_value=("Test Response", {"input_tokens": 1, "output_tokens": 2})
    ) as mock_method:
        prompt = "Test Prompt"
        response, token_info = anthropic_llm.get_llm_model_answer(prompt)
        assert response == "Test Response"
        assert token_info == {
            "prompt_tokens": 1,
            "completion_tokens": 2,
            "total_tokens": 3,
            "total_cost": 1.265e-05,
            "cost_currency": "USD",
        }
        mock_method.assert_called_once_with(prompt, anthropic_llm.config)
embedchain/tests/llm/test_aws_bedrock.py (new file, 56 lines)
@@ -0,0 +1,56 @@
import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from embedchain.config import BaseLlmConfig
from embedchain.llm.aws_bedrock import AWSBedrockLlm


@pytest.fixture
def config(monkeypatch):
    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test_access_key_id")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test_secret_access_key")
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    config = BaseLlmConfig(
        model="amazon.titan-text-express-v1",
        model_kwargs={
            "temperature": 0.5,
            "topP": 1,
            "maxTokenCount": 1000,
        },
    )
    yield config
    monkeypatch.delenv("AWS_ACCESS_KEY_ID")
    monkeypatch.delenv("AWS_SECRET_ACCESS_KEY")
    monkeypatch.delenv("OPENAI_API_KEY")


def test_get_llm_model_answer(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")

    llm = AWSBedrockLlm(config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_empty_prompt(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")

    llm = AWSBedrockLlm(config)
    answer = llm.get_llm_model_answer("")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("", config)


def test_get_llm_model_answer_with_streaming(config, mocker):
    config.stream = True
    mocked_bedrock_chat = mocker.patch("embedchain.llm.aws_bedrock.Bedrock")

    llm = AWSBedrockLlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_bedrock_chat.assert_called_once()
    callbacks = [callback[1]["callbacks"] for callback in mocked_bedrock_chat.call_args_list]
    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
embedchain/tests/llm/test_azure_openai.py (new file, 87 lines)
@@ -0,0 +1,87 @@
from unittest.mock import MagicMock, patch

import pytest
from langchain.schema import HumanMessage, SystemMessage

from embedchain.config import BaseLlmConfig
from embedchain.llm.azure_openai import AzureOpenAILlm


@pytest.fixture
def azure_openai_llm():
    config = BaseLlmConfig(
        deployment_name="azure_deployment",
        temperature=0.7,
        model="gpt-3.5-turbo",
        max_tokens=50,
        system_prompt="System Prompt",
    )
    return AzureOpenAILlm(config)


def test_get_llm_model_answer(azure_openai_llm):
    with patch.object(AzureOpenAILlm, "_get_answer", return_value="Test Response") as mock_method:
        prompt = "Test Prompt"
        response = azure_openai_llm.get_llm_model_answer(prompt)
        assert response == "Test Response"
        mock_method.assert_called_once_with(prompt=prompt, config=azure_openai_llm.config)


def test_get_answer(azure_openai_llm):
    with patch("langchain_openai.AzureChatOpenAI") as mock_chat:
        mock_chat_instance = mock_chat.return_value
        mock_chat_instance.invoke.return_value = MagicMock(content="Test Response")

        prompt = "Test Prompt"
        response = azure_openai_llm._get_answer(prompt, azure_openai_llm.config)

        assert response == "Test Response"
        mock_chat.assert_called_once_with(
            deployment_name=azure_openai_llm.config.deployment_name,
            openai_api_version="2024-02-01",
            model_name=azure_openai_llm.config.model or "gpt-3.5-turbo",
            temperature=azure_openai_llm.config.temperature,
            max_tokens=azure_openai_llm.config.max_tokens,
            streaming=azure_openai_llm.config.stream,
        )


def test_get_messages(azure_openai_llm):
    prompt = "Test Prompt"
    system_prompt = "Test System Prompt"
    messages = azure_openai_llm._get_messages(prompt, system_prompt)
    assert messages == [
        SystemMessage(content="Test System Prompt", additional_kwargs={}),
        HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
    ]


def test_when_no_deployment_name_provided():
    config = BaseLlmConfig(temperature=0.7, model="gpt-3.5-turbo", max_tokens=50, system_prompt="System Prompt")
    with pytest.raises(ValueError):
        llm = AzureOpenAILlm(config)
        llm.get_llm_model_answer("Test Prompt")


def test_with_api_version():
    config = BaseLlmConfig(
        deployment_name="azure_deployment",
        temperature=0.7,
        model="gpt-3.5-turbo",
        max_tokens=50,
        system_prompt="System Prompt",
        api_version="2024-02-01",
    )

    with patch("langchain_openai.AzureChatOpenAI") as mock_chat:
        llm = AzureOpenAILlm(config)
        llm.get_llm_model_answer("Test Prompt")

        mock_chat.assert_called_once_with(
            deployment_name="azure_deployment",
            openai_api_version="2024-02-01",
            model_name="gpt-3.5-turbo",
            temperature=0.7,
            max_tokens=50,
            streaming=False,
        )
embedchain/tests/llm/test_base_llm.py (new file, 61 lines)
@@ -0,0 +1,61 @@
from string import Template

import pytest

from embedchain.llm.base import BaseLlm, BaseLlmConfig


@pytest.fixture
def base_llm():
    config = BaseLlmConfig()
    return BaseLlm(config=config)


def test_is_get_llm_model_answer_not_implemented(base_llm):
    with pytest.raises(NotImplementedError):
        base_llm.get_llm_model_answer()


def test_is_stream_bool():
    with pytest.raises(ValueError):
        config = BaseLlmConfig(stream="test value")
        BaseLlm(config=config)


def test_template_string_gets_converted_to_Template_instance():
    config = BaseLlmConfig(template="test value $query $context")
    llm = BaseLlm(config=config)
    assert isinstance(llm.config.prompt, Template)


def test_is_get_llm_model_answer_implemented():
    class TestLlm(BaseLlm):
        def get_llm_model_answer(self):
            return "Implemented"

    config = BaseLlmConfig()
    llm = TestLlm(config=config)
    assert llm.get_llm_model_answer() == "Implemented"


def test_stream_response(base_llm):
    answer = ["Chunk1", "Chunk2", "Chunk3"]
    result = list(base_llm._stream_response(answer))
    assert result == answer


def test_append_search_and_context(base_llm):
    context = "Context"
    web_search_result = "Web Search Result"
    result = base_llm._append_search_and_context(context, web_search_result)
    expected_result = "Context\nWeb Search Result: Web Search Result"
    assert result == expected_result


def test_access_search_and_get_results(base_llm, mocker):
    base_llm.access_search_and_get_results = mocker.patch.object(
        base_llm, "access_search_and_get_results", return_value="Search Results"
    )
    input_query = "Test query"
    result = base_llm.access_search_and_get_results(input_query)
    assert result == "Search Results"
embedchain/tests/llm/test_chat.py (new file, 120 lines)
@@ -0,0 +1,120 @@
import os
import unittest
from unittest.mock import MagicMock, patch

from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.llm.base import BaseLlm
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage


class TestApp(unittest.TestCase):
    def setUp(self):
        os.environ["OPENAI_API_KEY"] = "test_key"
        self.app = App(config=AppConfig(collect_metrics=False))

    @patch.object(App, "_retrieve_from_database", return_value=["Test context"])
    @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
    def test_chat_with_memory(self, mock_get_answer, mock_retrieve):
        """
        This test checks the functionality of the 'chat' method in the App class with respect to the chat history
        memory.
        The 'chat' method is called twice. The first call initializes the chat history memory.
        The second call is expected to use the chat history from the first call.

        Key assumptions tested:
        - After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
        called with correct arguments, adding the correct chat history.
        - During the second call, the 'chat' method uses the chat history from the first call.

        The test isolates the 'chat' method behavior by mocking out '_retrieve_from_database', 'get_answer_from_llm' and
        'memory' methods.
        """
        config = AppConfig(collect_metrics=False)
        app = App(config=config)
        with patch.object(BaseLlm, "add_history") as mock_history:
            first_answer = app.chat("Test query 1")
            self.assertEqual(first_answer, "Test answer")
            mock_history.assert_called_with(app.config.id, "Test query 1", "Test answer", session_id="default")

            second_answer = app.chat("Test query 2", session_id="test_session")
            self.assertEqual(second_answer, "Test answer")
            mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer", session_id="test_session")

    @patch.object(App, "_retrieve_from_database", return_value=["Test context"])
    @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
    def test_template_replacement(self, mock_get_answer, mock_retrieve):
        """
        Tests that if a default template is used and it doesn't contain history,
        the default template is swapped in.

        Also tests that a dry run does not change the history
        """
        with patch.object(ChatHistory, "get") as mock_memory:
            mock_message = ChatMessage()
            mock_message.add_user_message("Test query 1")
            mock_message.add_ai_message("Test answer")
            mock_memory.return_value = [mock_message]

            config = AppConfig(collect_metrics=False)
            app = App(config=config)
            first_answer = app.chat("Test query 1")
            self.assertEqual(first_answer, "Test answer")
            self.assertEqual(len(app.llm.history), 1)
            history = app.llm.history
            dry_run = app.chat("Test query 2", dry_run=True)
            self.assertIn("Conversation history:", dry_run)
            self.assertEqual(history, app.llm.history)
            self.assertEqual(len(app.llm.history), 1)

    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
    def test_chat_with_where_in_params(self):
        """
        Test where filter
        """
        with patch.object(self.app, "_retrieve_from_database") as mock_retrieve:
            mock_retrieve.return_value = ["Test context"]
            with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
                mock_answer.return_value = "Test answer"
                answer = self.app.chat("Test query", where={"attribute": "value"})

        self.assertEqual(answer, "Test answer")
        _args, kwargs = mock_retrieve.call_args
        self.assertEqual(kwargs.get("input_query"), "Test query")
        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
        mock_answer.assert_called_once()

    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
    def test_chat_with_where_in_chat_config(self):
        """
        This test checks the functionality of the 'chat' method in the App class.
        It simulates a scenario where the '_retrieve_from_database' method returns a context list based on
        a where filter and 'get_llm_model_answer' returns an expected answer string.

        The 'chat' method is expected to call '_retrieve_from_database' with the where filter specified
        in the BaseLlmConfig and 'get_llm_model_answer' methods appropriately and return the right answer.

        Key assumptions tested:
        - '_retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
        BaseLlmConfig.
        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
        - 'chat' method returns the value it received from 'get_llm_model_answer'.

        The test isolates the 'chat' method behavior by mocking out '_retrieve_from_database' and
        'get_llm_model_answer' methods.
        """
        with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
            mock_answer.return_value = "Test answer"
            with patch.object(self.app.db, "query") as mock_database_query:
                mock_database_query.return_value = ["Test context"]
                llm_config = BaseLlmConfig(where={"attribute": "value"})
                answer = self.app.chat("Test query", llm_config)

        self.assertEqual(answer, "Test answer")
        _args, kwargs = mock_database_query.call_args
        self.assertEqual(kwargs.get("input_query"), "Test query")
        where = kwargs.get("where")
        assert "app_id" in where
        assert "attribute" in where
        mock_answer.assert_called_once()
embedchain/tests/llm/test_clarifai.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import pytest

from embedchain.config import BaseLlmConfig
from embedchain.llm.clarifai import ClarifaiLlm


@pytest.fixture
def clarifai_llm_config(monkeypatch):
    monkeypatch.setenv("CLARIFAI_PAT", "test_api_key")
    config = BaseLlmConfig(
        model="https://clarifai.com/openai/chat-completion/models/GPT-4",
        model_kwargs={"temperature": 0.7, "max_tokens": 100},
    )
    yield config
    monkeypatch.delenv("CLARIFAI_PAT")


def test_clarifai__llm_get_llm_model_answer(clarifai_llm_config, mocker):
    mocker.patch("embedchain.llm.clarifai.ClarifaiLlm._get_answer", return_value="Test answer")
    llm = ClarifaiLlm(clarifai_llm_config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
embedchain/tests/llm/test_cohere.py (new file, 73 lines)
@@ -0,0 +1,73 @@
import os

import pytest

from embedchain.config import BaseLlmConfig
from embedchain.llm.cohere import CohereLlm


@pytest.fixture
def cohere_llm_config():
    os.environ["COHERE_API_KEY"] = "test_api_key"
    config = BaseLlmConfig(model="command-r", max_tokens=100, temperature=0.7, top_p=0.8, token_usage=False)
    yield config
    os.environ.pop("COHERE_API_KEY")


def test_init_raises_value_error_without_api_key(mocker):
    mocker.patch.dict(os.environ, clear=True)
    with pytest.raises(ValueError):
        CohereLlm()


def test_get_llm_model_answer_raises_value_error_for_system_prompt(cohere_llm_config):
    llm = CohereLlm(cohere_llm_config)
    llm.config.system_prompt = "system_prompt"
    with pytest.raises(ValueError):
        llm.get_llm_model_answer("prompt")


def test_get_llm_model_answer(cohere_llm_config, mocker):
    mocker.patch("embedchain.llm.cohere.CohereLlm._get_answer", return_value="Test answer")

    llm = CohereLlm(cohere_llm_config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"


def test_get_llm_model_answer_with_token_usage(cohere_llm_config, mocker):
    test_config = BaseLlmConfig(
        temperature=cohere_llm_config.temperature,
        max_tokens=cohere_llm_config.max_tokens,
        top_p=cohere_llm_config.top_p,
        model=cohere_llm_config.model,
        token_usage=True,
    )
    mocker.patch(
        "embedchain.llm.cohere.CohereLlm._get_answer",
        return_value=("Test answer", {"input_tokens": 1, "output_tokens": 2}),
    )

    llm = CohereLlm(test_config)
    answer, token_info = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    assert token_info == {
        "prompt_tokens": 1,
        "completion_tokens": 2,
        "total_tokens": 3,
        "total_cost": 3.5e-06,
        "cost_currency": "USD",
    }


def test_get_answer_mocked_cohere(cohere_llm_config, mocker):
    mocked_cohere = mocker.patch("embedchain.llm.cohere.ChatCohere")
    mocked_cohere.return_value.invoke.return_value.content = "Mocked answer"

    llm = CohereLlm(cohere_llm_config)
    prompt = "Test query"
    answer = llm.get_llm_model_answer(prompt)

    assert answer == "Mocked answer"
embedchain/tests/llm/test_generate_prompt.py (new file, 70 lines)
@@ -0,0 +1,70 @@
import unittest
from string import Template

from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig


class TestGeneratePrompt(unittest.TestCase):
    def setUp(self):
        self.app = App(config=AppConfig(collect_metrics=False))

    def test_generate_prompt_with_template(self):
        """
        Tests that the generate_prompt method correctly formats the prompt using
        a custom template provided in the BaseLlmConfig instance.

        This test sets up a scenario with an input query and a list of contexts,
        and a custom template, and then calls generate_prompt. It checks that the
        returned prompt correctly incorporates all the contexts and the query into
        the format specified by the template.
        """
        # Setup
        input_query = "Test query"
        contexts = ["Context 1", "Context 2", "Context 3"]
        template = "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:"
        config = BaseLlmConfig(template=Template(template))
        self.app.llm.config = config

        # Execute
        result = self.app.llm.generate_prompt(input_query, contexts)

        # Assert
        expected_result = (
            "You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:"
        )
        self.assertEqual(result, expected_result)

    def test_generate_prompt_with_contexts_list(self):
        """
        Tests that the generate_prompt method correctly handles a list of contexts.

        This test sets up a scenario with an input query and a list of contexts,
        and then calls generate_prompt. It checks that the returned prompt
        correctly includes all the contexts and the query.
        """
        # Setup
        input_query = "Test query"
        contexts = ["Context 1", "Context 2", "Context 3"]
        config = BaseLlmConfig()

        # Execute
        self.app.llm.config = config
        result = self.app.llm.generate_prompt(input_query, contexts)

        # Assert
        expected_result = config.prompt.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
        self.assertEqual(result, expected_result)

    def test_generate_prompt_with_history(self):
        """
        Test the 'generate_prompt' method with BaseLlmConfig containing a history attribute.
        """
        config = BaseLlmConfig()
        config.prompt = Template("Context: $context | Query: $query | History: $history")
        self.app.llm.config = config
        self.app.llm.set_history(["Past context 1", "Past context 2"])
        prompt = self.app.llm.generate_prompt("Test query", ["Test context"])

        expected_prompt = "Context: Test context | Query: Test query | History: Past context 1\nPast context 2"
        self.assertEqual(prompt, expected_prompt)
embedchain/tests/llm/test_google.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import pytest

from embedchain.config import BaseLlmConfig
from embedchain.llm.google import GoogleLlm


@pytest.fixture
def google_llm_config():
    return BaseLlmConfig(model="gemini-pro", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)


def test_google_llm_init_missing_api_key(monkeypatch):
    monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
    with pytest.raises(ValueError, match="Please set the GOOGLE_API_KEY environment variable."):
        GoogleLlm()


def test_google_llm_init(monkeypatch):
    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
    with monkeypatch.context() as m:
        m.setattr("importlib.import_module", lambda x: None)
        google_llm = GoogleLlm()
        assert google_llm is not None


def test_google_llm_get_llm_model_answer_with_system_prompt(monkeypatch):
    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
    monkeypatch.setattr("importlib.import_module", lambda x: None)
    google_llm = GoogleLlm(config=BaseLlmConfig(system_prompt="system prompt"))
    with pytest.raises(ValueError, match="GoogleLlm does not support `system_prompt`"):
        google_llm.get_llm_model_answer("test prompt")


def test_google_llm_get_llm_model_answer(monkeypatch, google_llm_config):
    def mock_get_answer(prompt, config):
        return "Generated Text"

    monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
    monkeypatch.setattr(GoogleLlm, "_get_answer", mock_get_answer)
    google_llm = GoogleLlm(config=google_llm_config)
    result = google_llm.get_llm_model_answer("test prompt")

    assert result == "Generated Text"
embedchain/tests/llm/test_gpt4all.py (new file, 60 lines)
@@ -0,0 +1,60 @@
import pytest
from langchain_community.llms.gpt4all import GPT4All as LangchainGPT4All

from embedchain.config import BaseLlmConfig
from embedchain.llm.gpt4all import GPT4ALLLlm


@pytest.fixture
def config():
    config = BaseLlmConfig(
        temperature=0.7,
        max_tokens=50,
        top_p=0.8,
        stream=False,
        system_prompt="System prompt",
        model="orca-mini-3b-gguf2-q4_0.gguf",
    )
    yield config


@pytest.fixture
def gpt4all_with_config(config):
    return GPT4ALLLlm(config=config)


@pytest.fixture
def gpt4all_without_config():
    return GPT4ALLLlm()


def test_gpt4all_init_with_config(config, gpt4all_with_config):
    assert gpt4all_with_config.config.temperature == config.temperature
    assert gpt4all_with_config.config.max_tokens == config.max_tokens
    assert gpt4all_with_config.config.top_p == config.top_p
    assert gpt4all_with_config.config.stream == config.stream
    assert gpt4all_with_config.config.system_prompt == config.system_prompt
    assert gpt4all_with_config.config.model == config.model

    assert isinstance(gpt4all_with_config.instance, LangchainGPT4All)


def test_gpt4all_init_without_config(gpt4all_without_config):
    assert gpt4all_without_config.config.model == "orca-mini-3b-gguf2-q4_0.gguf"
    assert isinstance(gpt4all_without_config.instance, LangchainGPT4All)


def test_get_llm_model_answer(mocker, gpt4all_with_config):
    test_query = "Test query"
    test_answer = "Test answer"

    mocked_get_answer = mocker.patch("embedchain.llm.gpt4all.GPT4ALLLlm._get_answer", return_value=test_answer)
    answer = gpt4all_with_config.get_llm_model_answer(test_query)

    assert answer == test_answer
    mocked_get_answer.assert_called_once_with(prompt=test_query, config=gpt4all_with_config.config)


def test_gpt4all_model_switching(gpt4all_with_config):
    with pytest.raises(RuntimeError, match="GPT4ALLLlm does not support switching models at runtime."):
        gpt4all_with_config._get_answer("Test prompt", BaseLlmConfig(model="new_model"))
embedchain/tests/llm/test_huggingface.py (new file, 83 lines)
@@ -0,0 +1,83 @@
import importlib
import os

import pytest

from embedchain.config import BaseLlmConfig
from embedchain.llm.huggingface import HuggingFaceLlm


@pytest.fixture
def huggingface_llm_config():
    os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
    config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)
    yield config
    os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")


@pytest.fixture
def huggingface_endpoint_config():
    os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
    config = BaseLlmConfig(endpoint="https://api-inference.huggingface.co/models/gpt2", model_kwargs={"device": "cpu"})
    yield config
    os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")


def test_init_raises_value_error_without_api_key(mocker):
    mocker.patch.dict(os.environ, clear=True)
    with pytest.raises(ValueError):
        HuggingFaceLlm()


def test_get_llm_model_answer_raises_value_error_for_system_prompt(huggingface_llm_config):
    llm = HuggingFaceLlm(huggingface_llm_config)
    llm.config.system_prompt = "system_prompt"
    with pytest.raises(ValueError):
        llm.get_llm_model_answer("prompt")


def test_top_p_value_within_range():
    config = BaseLlmConfig(top_p=1.0)
    with pytest.raises(ValueError):
        HuggingFaceLlm._get_answer("test_prompt", config)


def test_dependency_is_imported():
    importlib_installed = True
    try:
        importlib.import_module("huggingface_hub")
    except ImportError:
        importlib_installed = False
    assert importlib_installed


def test_get_llm_model_answer(huggingface_llm_config, mocker):
    mocker.patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer", return_value="Test answer")

    llm = HuggingFaceLlm(huggingface_llm_config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"


def test_hugging_face_mock(huggingface_llm_config, mocker):
    mock_llm_instance = mocker.Mock(return_value="Test answer")
    mock_hf_hub = mocker.patch("embedchain.llm.huggingface.HuggingFaceHub")
    mock_hf_hub.return_value.invoke = mock_llm_instance

    llm = HuggingFaceLlm(huggingface_llm_config)
    answer = llm.get_llm_model_answer("Test query")
    assert answer == "Test answer"
    mock_llm_instance.assert_called_once_with("Test query")


def test_custom_endpoint(huggingface_endpoint_config, mocker):
    mock_llm_instance = mocker.Mock(return_value="Test answer")
    mock_hf_endpoint = mocker.patch("embedchain.llm.huggingface.HuggingFaceEndpoint")
    mock_hf_endpoint.return_value.invoke = mock_llm_instance

    llm = HuggingFaceLlm(huggingface_endpoint_config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    mock_llm_instance.assert_called_once_with("Test query")
embedchain/tests/llm/test_jina.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import os

import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from embedchain.config import BaseLlmConfig
from embedchain.llm.jina import JinaLlm


@pytest.fixture
def config():
    os.environ["JINACHAT_API_KEY"] = "test_api_key"
    config = BaseLlmConfig(temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt")
    yield config
    os.environ.pop("JINACHAT_API_KEY")


def test_init_raises_value_error_without_api_key(mocker):
    mocker.patch.dict(os.environ, clear=True)
    with pytest.raises(ValueError):
        JinaLlm()


def test_get_llm_model_answer(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")

    llm = JinaLlm(config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_with_system_prompt(config, mocker):
    config.system_prompt = "Custom system prompt"
    mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")

    llm = JinaLlm(config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_empty_prompt(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")

    llm = JinaLlm(config)
    answer = llm.get_llm_model_answer("")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("", config)


def test_get_llm_model_answer_with_streaming(config, mocker):
    config.stream = True
    mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")

    llm = JinaLlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_jinachat.assert_called_once()
    callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)


def test_get_llm_model_answer_without_system_prompt(config, mocker):
    config.system_prompt = None
    mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")

    llm = JinaLlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_jinachat.assert_called_once_with(
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        jinachat_api_key=os.environ["JINACHAT_API_KEY"],
        model_kwargs={"top_p": config.top_p},
    )
embedchain/tests/llm/test_llama2.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import os

import pytest

from embedchain.llm.llama2 import Llama2Llm


@pytest.fixture
def llama2_llm():
    os.environ["REPLICATE_API_TOKEN"] = "test_api_token"
    llm = Llama2Llm()
    return llm


def test_init_raises_value_error_without_api_key(mocker):
    mocker.patch.dict(os.environ, clear=True)
    with pytest.raises(ValueError):
        Llama2Llm()


def test_get_llm_model_answer_raises_value_error_for_system_prompt(llama2_llm):
    llama2_llm.config.system_prompt = "system_prompt"
    with pytest.raises(ValueError):
        llama2_llm.get_llm_model_answer("prompt")


def test_get_llm_model_answer(llama2_llm, mocker):
    mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate")
    mocked_replicate_instance = mocker.MagicMock()
    mocked_replicate.return_value = mocked_replicate_instance
    mocked_replicate_instance.invoke.return_value = "Test answer"

    llama2_llm.config.model = "test_model"
    llama2_llm.config.max_tokens = 50
    llama2_llm.config.temperature = 0.7
    llama2_llm.config.top_p = 0.8

    answer = llama2_llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
embedchain/tests/llm/test_mistralai.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import pytest

from embedchain.config import BaseLlmConfig
from embedchain.llm.mistralai import MistralAILlm


@pytest.fixture
def mistralai_llm_config(monkeypatch):
    monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
    yield BaseLlmConfig(model="mistral-tiny", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
    monkeypatch.delenv("MISTRAL_API_KEY", raising=False)


def test_mistralai_llm_init_missing_api_key(monkeypatch):
    monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
    with pytest.raises(ValueError, match="Please set the MISTRAL_API_KEY environment variable."):
        MistralAILlm()


def test_mistralai_llm_init(monkeypatch):
    monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
    llm = MistralAILlm()
    assert llm is not None


def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
    def mock_get_answer(self, prompt, config):
        return "Generated Text"

    monkeypatch.setattr(MistralAILlm, "_get_answer", mock_get_answer)
    llm = MistralAILlm(config=mistralai_llm_config)
    result = llm.get_llm_model_answer("test prompt")

    assert result == "Generated Text"


def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_config):
    mistralai_llm_config.system_prompt = "Test system prompt"
    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
    llm = MistralAILlm(config=mistralai_llm_config)
    result = llm.get_llm_model_answer("test prompt")

    assert result == "Generated Text"


def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
    llm = MistralAILlm(config=mistralai_llm_config)
    result = llm.get_llm_model_answer("")

    assert result == "Generated Text"


def test_get_llm_model_answer_without_system_prompt(monkeypatch, mistralai_llm_config):
    mistralai_llm_config.system_prompt = None
    monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
    llm = MistralAILlm(config=mistralai_llm_config)
    result = llm.get_llm_model_answer("test prompt")

    assert result == "Generated Text"


def test_get_llm_model_answer_with_token_usage(monkeypatch, mistralai_llm_config):
    test_config = BaseLlmConfig(
        temperature=mistralai_llm_config.temperature,
        max_tokens=mistralai_llm_config.max_tokens,
        top_p=mistralai_llm_config.top_p,
        model=mistralai_llm_config.model,
        token_usage=True,
    )
    monkeypatch.setattr(
        MistralAILlm,
        "_get_answer",
        lambda self, prompt, config: ("Generated Text", {"prompt_tokens": 1, "completion_tokens": 2}),
    )

    llm = MistralAILlm(test_config)
    answer, token_info = llm.get_llm_model_answer("Test query")

    assert answer == "Generated Text"
    assert token_info == {
        "prompt_tokens": 1,
        "completion_tokens": 2,
        "total_tokens": 3,
        "total_cost": 7.5e-07,
        "cost_currency": "USD",
    }
embedchain/tests/llm/test_ollama.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import pytest

from embedchain.config import BaseLlmConfig
from embedchain.llm.ollama import OllamaLlm
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler


@pytest.fixture
def ollama_llm_config():
    config = BaseLlmConfig(model="llama2", temperature=0.7, top_p=0.8, stream=True, system_prompt=None)
    yield config


def test_get_llm_model_answer(ollama_llm_config, mocker):
    mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
    mocker.patch("embedchain.llm.ollama.OllamaLlm._get_answer", return_value="Test answer")

    llm = OllamaLlm(ollama_llm_config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"


def test_get_answer_mocked_ollama(ollama_llm_config, mocker):
    mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
    mocked_ollama = mocker.patch("embedchain.llm.ollama.Ollama")
    mock_instance = mocked_ollama.return_value
    mock_instance.invoke.return_value = "Mocked answer"

    llm = OllamaLlm(ollama_llm_config)
    prompt = "Test query"
    answer = llm.get_llm_model_answer(prompt)

    assert answer == "Mocked answer"


def test_get_llm_model_answer_with_streaming(ollama_llm_config, mocker):
    ollama_llm_config.stream = True
    ollama_llm_config.callbacks = [StreamingStdOutCallbackHandler()]
    mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
    mocked_ollama_chat = mocker.patch("embedchain.llm.ollama.OllamaLlm._get_answer", return_value="Test answer")

    llm = OllamaLlm(ollama_llm_config)
    llm.get_llm_model_answer("Test query")

    mocked_ollama_chat.assert_called_once()
    call_args = mocked_ollama_chat.call_args
    config_arg = call_args[1]["config"]
    callbacks = config_arg.callbacks

    assert len(callbacks) == 1
    assert isinstance(callbacks[0], StreamingStdOutCallbackHandler)
embedchain/tests/llm/test_openai.py (new file, 261 lines)
@@ -0,0 +1,261 @@
import os

import httpx
import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm


@pytest.fixture()
def env_config():
    os.environ["OPENAI_API_KEY"] = "test_api_key"
    os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1/engines/"
    yield
    os.environ.pop("OPENAI_API_KEY")


@pytest.fixture
def config(env_config):
    config = BaseLlmConfig(
        temperature=0.7,
        max_tokens=50,
        top_p=0.8,
        stream=False,
        system_prompt="System prompt",
        model="gpt-3.5-turbo",
        http_client_proxies=None,
        http_async_client_proxies=None,
    )
    yield config


def test_get_llm_model_answer(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")

    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_with_system_prompt(config, mocker):
    config.system_prompt = "Custom system prompt"
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")

    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_empty_prompt(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")

    llm = OpenAILlm(config)
    answer = llm.get_llm_model_answer("")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("", config)


def test_get_llm_model_answer_with_token_usage(config, mocker):
    test_config = BaseLlmConfig(
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        top_p=config.top_p,
        stream=config.stream,
        system_prompt=config.system_prompt,
        model=config.model,
        token_usage=True,
    )
    mocked_get_answer = mocker.patch(
        "embedchain.llm.openai.OpenAILlm._get_answer",
        return_value=("Test answer", {"prompt_tokens": 1, "completion_tokens": 2}),
    )

    llm = OpenAILlm(test_config)
    answer, token_info = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    assert token_info == {
        "prompt_tokens": 1,
        "completion_tokens": 2,
        "total_tokens": 3,
        "total_cost": 5.5e-06,
        "cost_currency": "USD",
    }
    mocked_get_answer.assert_called_once_with("Test query", test_config)


def test_get_llm_model_answer_with_streaming(config, mocker):
    config.stream = True
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")

    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once()
    callbacks = [callback[1]["callbacks"] for callback in mocked_openai_chat.call_args_list]
    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)


def test_get_llm_model_answer_without_system_prompt(config, mocker):
    config.system_prompt = None
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")

    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        http_client=None,
        http_async_client=None,
    )


def test_get_llm_model_answer_with_special_headers(config, mocker):
    config.default_headers = {"test": "test"}
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")

    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        default_headers={"test": "test"},
        http_client=None,
        http_async_client=None,
    )


def test_get_llm_model_answer_with_model_kwargs(config, mocker):
    config.model_kwargs = {"response_format": {"type": "json_object"}}
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")

    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p, "response_format": {"type": "json_object"}},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        http_client=None,
        http_async_client=None,
    )


@pytest.mark.parametrize(
    "mock_return, expected",
    [
        ([{"test": "test"}], '{"test": "test"}'),
        ([], "Input could not be mapped to the function!"),
    ],
)
def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    mocked_convert_to_openai_tool = mocker.patch("langchain_core.utils.function_calling.convert_to_openai_tool")
    mocked_json_output_tools_parser = mocker.patch("langchain.output_parsers.openai_tools.JsonOutputToolsParser")
    mocked_openai_chat.return_value.bind.return_value.pipe.return_value.invoke.return_value = mock_return

    llm = OpenAILlm(config, tools={"test": "test"})
    answer = llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        http_client=None,
        http_async_client=None,
    )
    mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
    mocked_json_output_tools_parser.assert_called_once()

    assert answer == expected


def test_get_llm_model_answer_with_http_client_proxies(env_config, mocker):
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    mock_http_client = mocker.Mock(spec=httpx.Client)
    mock_http_client_instance = mocker.Mock(spec=httpx.Client)
    mock_http_client.return_value = mock_http_client_instance

    mocker.patch("httpx.Client", new=mock_http_client)

    config = BaseLlmConfig(
        temperature=0.7,
        max_tokens=50,
        top_p=0.8,
        stream=False,
        system_prompt="System prompt",
        model="gpt-3.5-turbo",
        http_client_proxies="http://testproxy.mem0.net:8000",
    )

    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        http_client=mock_http_client_instance,
        http_async_client=None,
    )
    mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")


def test_get_llm_model_answer_with_http_async_client_proxies(env_config, mocker):
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    mock_http_async_client = mocker.Mock(spec=httpx.AsyncClient)
    mock_http_async_client_instance = mocker.Mock(spec=httpx.AsyncClient)
    mock_http_async_client.return_value = mock_http_async_client_instance

    mocker.patch("httpx.AsyncClient", new=mock_http_async_client)

    config = BaseLlmConfig(
        temperature=0.7,
        max_tokens=50,
        top_p=0.8,
        stream=False,
        system_prompt="System prompt",
        model="gpt-3.5-turbo",
        http_async_client_proxies={"http://": "http://testproxy.mem0.net:8000"},
    )

    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
        http_client=None,
        http_async_client=mock_http_async_client_instance,
    )
    mock_http_async_client.assert_called_once_with(proxies={"http://": "http://testproxy.mem0.net:8000"})
embedchain/tests/llm/test_query.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import os
from unittest.mock import MagicMock, patch

import pytest

from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.llm.openai import OpenAILlm


@pytest.fixture
def app():
    os.environ["OPENAI_API_KEY"] = "test_api_key"
    app = App(config=AppConfig(collect_metrics=False))
    return app


@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query(app):
    with patch.object(app, "_retrieve_from_database") as mock_retrieve:
        mock_retrieve.return_value = ["Test context"]
        with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
            mock_answer.return_value = "Test answer"
            answer = app.query(input_query="Test query")
            assert answer == "Test answer"

    mock_retrieve.assert_called_once()
    _, kwargs = mock_retrieve.call_args
    input_query_arg = kwargs.get("input_query")
    assert input_query_arg == "Test query"
    mock_answer.assert_called_once()


@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_query_config_app_passing(mock_get_answer):
    mock_get_answer.return_value = MagicMock()
    mock_get_answer.return_value = "Test answer"

    config = AppConfig(collect_metrics=False)
    chat_config = BaseLlmConfig(system_prompt="Test system prompt")
    llm = OpenAILlm(config=chat_config)
    app = App(config=config, llm=llm)
    answer = app.llm.get_llm_model_answer("Test query")

    assert app.llm.config.system_prompt == "Test system prompt"
    assert answer == "Test answer"


@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_params(app):
    with patch.object(app, "_retrieve_from_database") as mock_retrieve:
        mock_retrieve.return_value = ["Test context"]
        with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
            mock_answer.return_value = "Test answer"
            answer = app.query("Test query", where={"attribute": "value"})

    assert answer == "Test answer"
    _, kwargs = mock_retrieve.call_args
    assert kwargs.get("input_query") == "Test query"
    assert kwargs.get("where") == {"attribute": "value"}
    mock_answer.assert_called_once()


@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_query_config(app):
    with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
        mock_answer.return_value = "Test answer"
        with patch.object(app.db, "query") as mock_database_query:
            mock_database_query.return_value = ["Test context"]
            llm_config = BaseLlmConfig(where={"attribute": "value"})
            answer = app.query("Test query", llm_config)

    assert answer == "Test answer"
    _, kwargs = mock_database_query.call_args
    assert kwargs.get("input_query") == "Test query"
    where = kwargs.get("where")
    assert "app_id" in where
    assert "attribute" in where
    mock_answer.assert_called_once()
embedchain/tests/llm/test_together.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import os

import pytest

from embedchain.config import BaseLlmConfig
from embedchain.llm.together import TogetherLlm


@pytest.fixture
def together_llm_config():
    os.environ["TOGETHER_API_KEY"] = "test_api_key"
    config = BaseLlmConfig(model="together-ai-up-to-3b", max_tokens=50, temperature=0.7, top_p=0.8)
    yield config
    os.environ.pop("TOGETHER_API_KEY")


def test_init_raises_value_error_without_api_key(mocker):
    mocker.patch.dict(os.environ, clear=True)
    with pytest.raises(ValueError):
        TogetherLlm()


def test_get_llm_model_answer_raises_value_error_for_system_prompt(together_llm_config):
    llm = TogetherLlm(together_llm_config)
    llm.config.system_prompt = "system_prompt"
    with pytest.raises(ValueError):
        llm.get_llm_model_answer("prompt")


def test_get_llm_model_answer(together_llm_config, mocker):
    mocker.patch("embedchain.llm.together.TogetherLlm._get_answer", return_value="Test answer")

    llm = TogetherLlm(together_llm_config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"


def test_get_llm_model_answer_with_token_usage(together_llm_config, mocker):
    test_config = BaseLlmConfig(
        temperature=together_llm_config.temperature,
        max_tokens=together_llm_config.max_tokens,
        top_p=together_llm_config.top_p,
        model=together_llm_config.model,
        token_usage=True,
    )
    mocker.patch(
        "embedchain.llm.together.TogetherLlm._get_answer",
        return_value=("Test answer", {"prompt_tokens": 1, "completion_tokens": 2}),
    )

    llm = TogetherLlm(test_config)
    answer, token_info = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    assert token_info == {
        "prompt_tokens": 1,
        "completion_tokens": 2,
        "total_tokens": 3,
        "total_cost": 3e-07,
        "cost_currency": "USD",
    }


def test_get_answer_mocked_together(together_llm_config, mocker):
    mocked_together = mocker.patch("embedchain.llm.together.ChatTogether")
    mock_instance = mocked_together.return_value
    mock_instance.invoke.return_value.content = "Mocked answer"

    llm = TogetherLlm(together_llm_config)
    prompt = "Test query"
    answer = llm.get_llm_model_answer(prompt)

    assert answer == "Mocked answer"
embedchain/tests/llm/test_vertex_ai.py (new file, 76 lines)
@@ -0,0 +1,76 @@
from unittest.mock import MagicMock, patch

import pytest
from langchain.schema import HumanMessage, SystemMessage

from embedchain.config import BaseLlmConfig
from embedchain.core.db.database import database_manager
from embedchain.llm.vertex_ai import VertexAILlm


@pytest.fixture(autouse=True)
def setup_database():
    database_manager.setup_engine()


@pytest.fixture
def vertexai_llm():
    config = BaseLlmConfig(temperature=0.6, model="chat-bison")
    return VertexAILlm(config)


def test_get_llm_model_answer(vertexai_llm):
    with patch.object(VertexAILlm, "_get_answer", return_value="Test Response") as mock_method:
        prompt = "Test Prompt"
        response = vertexai_llm.get_llm_model_answer(prompt)
        assert response == "Test Response"
        mock_method.assert_called_once_with(prompt, vertexai_llm.config)


def test_get_llm_model_answer_with_token_usage(vertexai_llm):
    test_config = BaseLlmConfig(
        temperature=vertexai_llm.config.temperature,
        max_tokens=vertexai_llm.config.max_tokens,
        top_p=vertexai_llm.config.top_p,
        model=vertexai_llm.config.model,
        token_usage=True,
    )
    vertexai_llm.config = test_config
    with patch.object(
        VertexAILlm,
        "_get_answer",
        return_value=("Test Response", {"prompt_token_count": 1, "candidates_token_count": 2}),
    ):
        response, token_info = vertexai_llm.get_llm_model_answer("Test Query")
        assert response == "Test Response"
        assert token_info == {
            "prompt_tokens": 1,
            "completion_tokens": 2,
            "total_tokens": 3,
            "total_cost": 3.75e-07,
            "cost_currency": "USD",
        }


@patch("embedchain.llm.vertex_ai.ChatVertexAI")
def test_get_answer(mock_chat_vertexai, vertexai_llm, caplog):
    mock_chat_vertexai.return_value.invoke.return_value = MagicMock(content="Test Response")

    config = vertexai_llm.config
    prompt = "Test Prompt"
    messages = vertexai_llm._get_messages(prompt)
    response = vertexai_llm._get_answer(prompt, config)
    mock_chat_vertexai.return_value.invoke.assert_called_once_with(messages)

    assert response == "Test Response"
    assert "Config option `top_p` is not supported by this model." not in caplog.text


def test_get_messages(vertexai_llm):
    prompt = "Test Prompt"
    system_prompt = "Test System Prompt"
    messages = vertexai_llm._get_messages(prompt, system_prompt)
    assert messages == [
        SystemMessage(content="Test System Prompt", additional_kwargs={}),
        HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
    ]
Block a user