Rename embedchain to mem0 and open sourcing code for long term memory (#1474)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
Author: Taranjeet Singh
Date: 2024-07-12 07:51:33 -07:00
Committed by: GitHub
Parent: 83e8c97295
Commit: f842a92e25
665 changed files with 9427 additions and 6592 deletions


@@ -0,0 +1,54 @@
import os
from unittest.mock import patch
import pytest
from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig
from embedchain.llm.anthropic import AnthropicLlm
@pytest.fixture
def anthropic_llm():
os.environ["ANTHROPIC_API_KEY"] = "test_api_key"
config = BaseLlmConfig(temperature=0.5, model="claude-instant-1", token_usage=False)
return AnthropicLlm(config)
def test_get_llm_model_answer(anthropic_llm):
with patch.object(AnthropicLlm, "_get_answer", return_value="Test Response") as mock_method:
prompt = "Test Prompt"
response = anthropic_llm.get_llm_model_answer(prompt)
assert response == "Test Response"
mock_method.assert_called_once_with(prompt, anthropic_llm.config)
def test_get_messages(anthropic_llm):
prompt = "Test Prompt"
system_prompt = "Test System Prompt"
messages = anthropic_llm._get_messages(prompt, system_prompt)
assert messages == [
SystemMessage(content="Test System Prompt", additional_kwargs={}),
HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
]
def test_get_llm_model_answer_with_token_usage(anthropic_llm):
test_config = BaseLlmConfig(
temperature=anthropic_llm.config.temperature, model=anthropic_llm.config.model, token_usage=True
)
anthropic_llm.config = test_config
with patch.object(
AnthropicLlm, "_get_answer", return_value=("Test Response", {"input_tokens": 1, "output_tokens": 2})
) as mock_method:
prompt = "Test Prompt"
response, token_info = anthropic_llm.get_llm_model_answer(prompt)
assert response == "Test Response"
assert token_info == {
"prompt_tokens": 1,
"completion_tokens": 2,
"total_tokens": 3,
"total_cost": 1.265e-05,
"cost_currency": "USD",
}
mock_method.assert_called_once_with(prompt, anthropic_llm.config)
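
Note: the expected total_cost above is consistent with a simple per-token pricing calculation. The sketch below reproduces that arithmetic; the per-token prices are assumptions inferred from the asserted 1.265e-05, not authoritative Anthropic pricing.
# Sketch only: prices are inferred from the asserted value, not Anthropic's price list.
input_price = 1.63 / 1_000_000    # assumed USD per input token
output_price = 5.51 / 1_000_000   # assumed USD per output token
usage = {"input_tokens": 1, "output_tokens": 2}
total_cost = usage["input_tokens"] * input_price + usage["output_tokens"] * output_price
assert abs(total_cost - 1.265e-05) < 1e-12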


@@ -0,0 +1,56 @@
import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from embedchain.config import BaseLlmConfig
from embedchain.llm.aws_bedrock import AWSBedrockLlm
@pytest.fixture
def config(monkeypatch):
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test_access_key_id")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test_secret_access_key")
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
config = BaseLlmConfig(
model="amazon.titan-text-express-v1",
model_kwargs={
"temperature": 0.5,
"topP": 1,
"maxTokenCount": 1000,
},
)
yield config
monkeypatch.delenv("AWS_ACCESS_KEY_ID")
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY")
monkeypatch.delenv("OPENAI_API_KEY")
def test_get_llm_model_answer(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")
llm = AWSBedrockLlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_empty_prompt(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")
llm = AWSBedrockLlm(config)
answer = llm.get_llm_model_answer("")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("", config)
def test_get_llm_model_answer_with_streaming(config, mocker):
config.stream = True
mocked_bedrock_chat = mocker.patch("embedchain.llm.aws_bedrock.Bedrock")
llm = AWSBedrockLlm(config)
llm.get_llm_model_answer("Test query")
mocked_bedrock_chat.assert_called_once()
callbacks = [callback[1]["callbacks"] for callback in mocked_bedrock_chat.call_args_list]
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)


@@ -0,0 +1,87 @@
from unittest.mock import MagicMock, patch
import pytest
from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig
from embedchain.llm.azure_openai import AzureOpenAILlm
@pytest.fixture
def azure_openai_llm():
config = BaseLlmConfig(
deployment_name="azure_deployment",
temperature=0.7,
model="gpt-3.5-turbo",
max_tokens=50,
system_prompt="System Prompt",
)
return AzureOpenAILlm(config)
def test_get_llm_model_answer(azure_openai_llm):
with patch.object(AzureOpenAILlm, "_get_answer", return_value="Test Response") as mock_method:
prompt = "Test Prompt"
response = azure_openai_llm.get_llm_model_answer(prompt)
assert response == "Test Response"
mock_method.assert_called_once_with(prompt=prompt, config=azure_openai_llm.config)
def test_get_answer(azure_openai_llm):
with patch("langchain_openai.AzureChatOpenAI") as mock_chat:
mock_chat_instance = mock_chat.return_value
mock_chat_instance.invoke.return_value = MagicMock(content="Test Response")
prompt = "Test Prompt"
response = azure_openai_llm._get_answer(prompt, azure_openai_llm.config)
assert response == "Test Response"
mock_chat.assert_called_once_with(
deployment_name=azure_openai_llm.config.deployment_name,
openai_api_version="2024-02-01",
model_name=azure_openai_llm.config.model or "gpt-3.5-turbo",
temperature=azure_openai_llm.config.temperature,
max_tokens=azure_openai_llm.config.max_tokens,
streaming=azure_openai_llm.config.stream,
)
def test_get_messages(azure_openai_llm):
prompt = "Test Prompt"
system_prompt = "Test System Prompt"
messages = azure_openai_llm._get_messages(prompt, system_prompt)
assert messages == [
SystemMessage(content="Test System Prompt", additional_kwargs={}),
HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
]
def test_when_no_deployment_name_provided():
config = BaseLlmConfig(temperature=0.7, model="gpt-3.5-turbo", max_tokens=50, system_prompt="System Prompt")
with pytest.raises(ValueError):
llm = AzureOpenAILlm(config)
llm.get_llm_model_answer("Test Prompt")
def test_with_api_version():
config = BaseLlmConfig(
deployment_name="azure_deployment",
temperature=0.7,
model="gpt-3.5-turbo",
max_tokens=50,
system_prompt="System Prompt",
api_version="2024-02-01",
)
with patch("langchain_openai.AzureChatOpenAI") as mock_chat:
llm = AzureOpenAILlm(config)
llm.get_llm_model_answer("Test Prompt")
mock_chat.assert_called_once_with(
deployment_name="azure_deployment",
openai_api_version="2024-02-01",
model_name="gpt-3.5-turbo",
temperature=0.7,
max_tokens=50,
streaming=False,
)
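
For readability, here is a minimal sketch of the client construction these assertions imply. It is inferred from the mocked AzureChatOpenAI call above and is not the actual AzureOpenAILlm implementation; the helper name build_azure_chat is hypothetical.
# Inferred from the assertions above; not the real implementation.
from langchain_openai import AzureChatOpenAI

def build_azure_chat(config):
    return AzureChatOpenAI(
        deployment_name=config.deployment_name,
        openai_api_version=getattr(config, "api_version", None) or "2024-02-01",
        model_name=config.model or "gpt-3.5-turbo",
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        streaming=config.stream,
    )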


@@ -0,0 +1,61 @@
from string import Template
import pytest
from embedchain.llm.base import BaseLlm, BaseLlmConfig
@pytest.fixture
def base_llm():
config = BaseLlmConfig()
return BaseLlm(config=config)
def test_is_get_llm_model_answer_not_implemented(base_llm):
with pytest.raises(NotImplementedError):
base_llm.get_llm_model_answer()
def test_is_stream_bool():
with pytest.raises(ValueError):
config = BaseLlmConfig(stream="test value")
BaseLlm(config=config)
def test_template_string_gets_converted_to_Template_instance():
config = BaseLlmConfig(template="test value $query $context")
llm = BaseLlm(config=config)
assert isinstance(llm.config.prompt, Template)
def test_is_get_llm_model_answer_implemented():
class TestLlm(BaseLlm):
def get_llm_model_answer(self):
return "Implemented"
config = BaseLlmConfig()
llm = TestLlm(config=config)
assert llm.get_llm_model_answer() == "Implemented"
def test_stream_response(base_llm):
answer = ["Chunk1", "Chunk2", "Chunk3"]
result = list(base_llm._stream_response(answer))
assert result == answer
def test_append_search_and_context(base_llm):
context = "Context"
web_search_result = "Web Search Result"
result = base_llm._append_search_and_context(context, web_search_result)
expected_result = "Context\nWeb Search Result: Web Search Result"
assert result == expected_result
def test_access_search_and_get_results(base_llm, mocker):
base_llm.access_search_and_get_results = mocker.patch.object(
base_llm, "access_search_and_get_results", return_value="Search Results"
)
input_query = "Test query"
result = base_llm.access_search_and_get_results(input_query)
assert result == "Search Results"


@@ -0,0 +1,120 @@
import os
import unittest
from unittest.mock import MagicMock, patch
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.llm.base import BaseLlm
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
class TestApp(unittest.TestCase):
def setUp(self):
os.environ["OPENAI_API_KEY"] = "test_key"
self.app = App(config=AppConfig(collect_metrics=False))
@patch.object(App, "_retrieve_from_database", return_value=["Test context"])
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
def test_chat_with_memory(self, mock_get_answer, mock_retrieve):
"""
This test checks the functionality of the 'chat' method in the App class with respect to the chat history
memory.
The 'chat' method is called twice. The first call initializes the chat history memory.
The second call is expected to use the chat history from the first call.
Key assumptions tested:
- After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
  called with correct arguments, adding the correct chat history.
- During the second call, the 'chat' method uses the chat history from the first call.
The test isolates the 'chat' method behavior by mocking out '_retrieve_from_database', 'get_answer_from_llm' and
'memory' methods.
"""
config = AppConfig(collect_metrics=False)
app = App(config=config)
with patch.object(BaseLlm, "add_history") as mock_history:
first_answer = app.chat("Test query 1")
self.assertEqual(first_answer, "Test answer")
mock_history.assert_called_with(app.config.id, "Test query 1", "Test answer", session_id="default")
second_answer = app.chat("Test query 2", session_id="test_session")
self.assertEqual(second_answer, "Test answer")
mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer", session_id="test_session")
@patch.object(App, "_retrieve_from_database", return_value=["Test context"])
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
def test_template_replacement(self, mock_get_answer, mock_retrieve):
"""
Tests that if the default template is used and it doesn't contain history,
a history-aware template is swapped in.
Also tests that a dry run does not change the history.
"""
with patch.object(ChatHistory, "get") as mock_memory:
mock_message = ChatMessage()
mock_message.add_user_message("Test query 1")
mock_message.add_ai_message("Test answer")
mock_memory.return_value = [mock_message]
config = AppConfig(collect_metrics=False)
app = App(config=config)
first_answer = app.chat("Test query 1")
self.assertEqual(first_answer, "Test answer")
self.assertEqual(len(app.llm.history), 1)
history = app.llm.history
dry_run = app.chat("Test query 2", dry_run=True)
self.assertIn("Conversation history:", dry_run)
self.assertEqual(history, app.llm.history)
self.assertEqual(len(app.llm.history), 1)
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_chat_with_where_in_params(self):
"""
Test where filter
"""
with patch.object(self.app, "_retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = self.app.chat("Test query", where={"attribute": "value"})
self.assertEqual(answer, "Test answer")
_args, kwargs = mock_retrieve.call_args
self.assertEqual(kwargs.get("input_query"), "Test query")
self.assertEqual(kwargs.get("where"), {"attribute": "value"})
mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_chat_with_where_in_chat_config(self):
"""
This test checks the functionality of the 'chat' method in the App class.
It simulates a scenario where the '_retrieve_from_database' method returns a context list based on
a where filter and 'get_llm_model_answer' returns an expected answer string.
The 'chat' method is expected to call '_retrieve_from_database' with the where filter specified
in the BaseLlmConfig and 'get_llm_model_answer' methods appropriately and return the right answer.
Key assumptions tested:
- '_retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
BaseLlmConfig.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'chat' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'chat' method behavior by mocking out '_retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
with patch.object(self.app.db, "query") as mock_database_query:
mock_database_query.return_value = ["Test context"]
llm_config = BaseLlmConfig(where={"attribute": "value"})
answer = self.app.chat("Test query", llm_config)
self.assertEqual(answer, "Test answer")
_args, kwargs = mock_database_query.call_args
self.assertEqual(kwargs.get("input_query"), "Test query")
where = kwargs.get("where")
assert "app_id" in where
assert "attribute" in where
mock_answer.assert_called_once()
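
The last two assertions only check membership, which implies the caller-supplied where filter is merged with an app_id constraint before the vector store query. A hedged sketch of that merge follows; the shape and the helper name merge_where are assumptions inferred from the test, not the actual implementation.
# Shape inferred from the assertions above; hypothetical helper.
def merge_where(app_id, where=None):
    merged = dict(where or {})
    merged["app_id"] = app_id
    return merged

assert merge_where("my-app", {"attribute": "value"}) == {"attribute": "value", "app_id": "my-app"}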


@@ -0,0 +1,23 @@
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.clarifai import ClarifaiLlm
@pytest.fixture
def clarifai_llm_config(monkeypatch):
monkeypatch.setenv("CLARIFAI_PAT","test_api_key")
config = BaseLlmConfig(
model="https://clarifai.com/openai/chat-completion/models/GPT-4",
model_kwargs={"temperature": 0.7, "max_tokens": 100},
)
yield config
monkeypatch.delenv("CLARIFAI_PAT")
def test_clarifai_llm_get_llm_model_answer(clarifai_llm_config, mocker):
mocker.patch("embedchain.llm.clarifai.ClarifaiLlm._get_answer", return_value="Test answer")
llm = ClarifaiLlm(clarifai_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"


@@ -0,0 +1,73 @@
import os
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.cohere import CohereLlm
@pytest.fixture
def cohere_llm_config():
os.environ["COHERE_API_KEY"] = "test_api_key"
config = BaseLlmConfig(model="command-r", max_tokens=100, temperature=0.7, top_p=0.8, token_usage=False)
yield config
os.environ.pop("COHERE_API_KEY")
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
CohereLlm()
def test_get_llm_model_answer_raises_value_error_for_system_prompt(cohere_llm_config):
llm = CohereLlm(cohere_llm_config)
llm.config.system_prompt = "system_prompt"
with pytest.raises(ValueError):
llm.get_llm_model_answer("prompt")
def test_get_llm_model_answer(cohere_llm_config, mocker):
mocker.patch("embedchain.llm.cohere.CohereLlm._get_answer", return_value="Test answer")
llm = CohereLlm(cohere_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
def test_get_llm_model_answer_with_token_usage(cohere_llm_config, mocker):
test_config = BaseLlmConfig(
temperature=cohere_llm_config.temperature,
max_tokens=cohere_llm_config.max_tokens,
top_p=cohere_llm_config.top_p,
model=cohere_llm_config.model,
token_usage=True,
)
mocker.patch(
"embedchain.llm.cohere.CohereLlm._get_answer",
return_value=("Test answer", {"input_tokens": 1, "output_tokens": 2}),
)
llm = CohereLlm(test_config)
answer, token_info = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
assert token_info == {
"prompt_tokens": 1,
"completion_tokens": 2,
"total_tokens": 3,
"total_cost": 3.5e-06,
"cost_currency": "USD",
}
def test_get_answer_mocked_cohere(cohere_llm_config, mocker):
mocked_cohere = mocker.patch("embedchain.llm.cohere.ChatCohere")
mocked_cohere.return_value.invoke.return_value.content = "Mocked answer"
llm = CohereLlm(cohere_llm_config)
prompt = "Test query"
answer = llm.get_llm_model_answer(prompt)
assert answer == "Mocked answer"


@@ -0,0 +1,70 @@
import unittest
from string import Template
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
class TestGeneratePrompt(unittest.TestCase):
def setUp(self):
self.app = App(config=AppConfig(collect_metrics=False))
def test_generate_prompt_with_template(self):
"""
Tests that the generate_prompt method correctly formats the prompt using
a custom template provided in the BaseLlmConfig instance.
This test sets up a scenario with an input query and a list of contexts,
and a custom template, and then calls generate_prompt. It checks that the
returned prompt correctly incorporates all the contexts and the query into
the format specified by the template.
"""
# Setup
input_query = "Test query"
contexts = ["Context 1", "Context 2", "Context 3"]
template = "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:"
config = BaseLlmConfig(template=Template(template))
self.app.llm.config = config
# Execute
result = self.app.llm.generate_prompt(input_query, contexts)
# Assert
expected_result = (
"You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:"
)
self.assertEqual(result, expected_result)
def test_generate_prompt_with_contexts_list(self):
"""
Tests that the generate_prompt method correctly handles a list of contexts.
This test sets up a scenario with an input query and a list of contexts,
and then calls generate_prompt. It checks that the returned prompt
correctly includes all the contexts and the query.
"""
# Setup
input_query = "Test query"
contexts = ["Context 1", "Context 2", "Context 3"]
config = BaseLlmConfig()
# Execute
self.app.llm.config = config
result = self.app.llm.generate_prompt(input_query, contexts)
# Assert
expected_result = config.prompt.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
self.assertEqual(result, expected_result)
def test_generate_prompt_with_history(self):
"""
Test the 'generate_prompt' method with BaseLlmConfig containing a history attribute.
"""
config = BaseLlmConfig()
config.prompt = Template("Context: $context | Query: $query | History: $history")
self.app.llm.config = config
self.app.llm.set_history(["Past context 1", "Past context 2"])
prompt = self.app.llm.generate_prompt("Test query", ["Test context"])
expected_prompt = "Context: Test context | Query: Test query | History: Past context 1\nPast context 2"
self.assertEqual(prompt, expected_prompt)
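
These tests lean on Python's string.Template substitution for prompt assembly; a standalone illustration of that mechanism, independent of embedchain:
from string import Template

tmpl = Template("Context: $context | Query: $query")
print(tmpl.substitute(context="Context 1 | Context 2 | Context 3", query="Test query"))
# Context: Context 1 | Context 2 | Context 3 | Query: Test query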


@@ -0,0 +1,43 @@
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.google import GoogleLlm
@pytest.fixture
def google_llm_config():
return BaseLlmConfig(model="gemini-pro", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
def test_google_llm_init_missing_api_key(monkeypatch):
monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
with pytest.raises(ValueError, match="Please set the GOOGLE_API_KEY environment variable."):
GoogleLlm()
def test_google_llm_init(monkeypatch):
monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
with monkeypatch.context() as m:
m.setattr("importlib.import_module", lambda x: None)
google_llm = GoogleLlm()
assert google_llm is not None
def test_google_llm_get_llm_model_answer_with_system_prompt(monkeypatch):
monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
monkeypatch.setattr("importlib.import_module", lambda x: None)
google_llm = GoogleLlm(config=BaseLlmConfig(system_prompt="system prompt"))
with pytest.raises(ValueError, match="GoogleLlm does not support `system_prompt`"):
google_llm.get_llm_model_answer("test prompt")
def test_google_llm_get_llm_model_answer(monkeypatch, google_llm_config):
def mock_get_answer(prompt, config):
return "Generated Text"
monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
monkeypatch.setattr(GoogleLlm, "_get_answer", mock_get_answer)
google_llm = GoogleLlm(config=google_llm_config)
result = google_llm.get_llm_model_answer("test prompt")
assert result == "Generated Text"


@@ -0,0 +1,60 @@
import pytest
from langchain_community.llms.gpt4all import GPT4All as LangchainGPT4All
from embedchain.config import BaseLlmConfig
from embedchain.llm.gpt4all import GPT4ALLLlm
@pytest.fixture
def config():
config = BaseLlmConfig(
temperature=0.7,
max_tokens=50,
top_p=0.8,
stream=False,
system_prompt="System prompt",
model="orca-mini-3b-gguf2-q4_0.gguf",
)
yield config
@pytest.fixture
def gpt4all_with_config(config):
return GPT4ALLLlm(config=config)
@pytest.fixture
def gpt4all_without_config():
return GPT4ALLLlm()
def test_gpt4all_init_with_config(config, gpt4all_with_config):
assert gpt4all_with_config.config.temperature == config.temperature
assert gpt4all_with_config.config.max_tokens == config.max_tokens
assert gpt4all_with_config.config.top_p == config.top_p
assert gpt4all_with_config.config.stream == config.stream
assert gpt4all_with_config.config.system_prompt == config.system_prompt
assert gpt4all_with_config.config.model == config.model
assert isinstance(gpt4all_with_config.instance, LangchainGPT4All)
def test_gpt4all_init_without_config(gpt4all_without_config):
assert gpt4all_without_config.config.model == "orca-mini-3b-gguf2-q4_0.gguf"
assert isinstance(gpt4all_without_config.instance, LangchainGPT4All)
def test_get_llm_model_answer(mocker, gpt4all_with_config):
test_query = "Test query"
test_answer = "Test answer"
mocked_get_answer = mocker.patch("embedchain.llm.gpt4all.GPT4ALLLlm._get_answer", return_value=test_answer)
answer = gpt4all_with_config.get_llm_model_answer(test_query)
assert answer == test_answer
mocked_get_answer.assert_called_once_with(prompt=test_query, config=gpt4all_with_config.config)
def test_gpt4all_model_switching(gpt4all_with_config):
with pytest.raises(RuntimeError, match="GPT4ALLLlm does not support switching models at runtime."):
gpt4all_with_config._get_answer("Test prompt", BaseLlmConfig(model="new_model"))


@@ -0,0 +1,83 @@
import importlib
import os
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.huggingface import HuggingFaceLlm
@pytest.fixture
def huggingface_llm_config():
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)
yield config
os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
@pytest.fixture
def huggingface_endpoint_config():
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
config = BaseLlmConfig(endpoint="https://api-inference.huggingface.co/models/gpt2", model_kwargs={"device": "cpu"})
yield config
os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
HuggingFaceLlm()
def test_get_llm_model_answer_raises_value_error_for_system_prompt(huggingface_llm_config):
llm = HuggingFaceLlm(huggingface_llm_config)
llm.config.system_prompt = "system_prompt"
with pytest.raises(ValueError):
llm.get_llm_model_answer("prompt")
def test_top_p_value_within_range():
config = BaseLlmConfig(top_p=1.0)
with pytest.raises(ValueError):
HuggingFaceLlm._get_answer("test_prompt", config)
def test_dependency_is_imported():
importlib_installed = True
try:
importlib.import_module("huggingface_hub")
except ImportError:
importlib_installed = False
assert importlib_installed
def test_get_llm_model_answer(huggingface_llm_config, mocker):
mocker.patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer", return_value="Test answer")
llm = HuggingFaceLlm(huggingface_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
def test_hugging_face_mock(huggingface_llm_config, mocker):
mock_llm_instance = mocker.Mock(return_value="Test answer")
mock_hf_hub = mocker.patch("embedchain.llm.huggingface.HuggingFaceHub")
mock_hf_hub.return_value.invoke = mock_llm_instance
llm = HuggingFaceLlm(huggingface_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mock_llm_instance.assert_called_once_with("Test query")
def test_custom_endpoint(huggingface_endpoint_config, mocker):
mock_llm_instance = mocker.Mock(return_value="Test answer")
mock_hf_endpoint = mocker.patch("embedchain.llm.huggingface.HuggingFaceEndpoint")
mock_hf_endpoint.return_value.invoke = mock_llm_instance
llm = HuggingFaceLlm(huggingface_endpoint_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mock_llm_instance.assert_called_once_with("Test query")


@@ -0,0 +1,79 @@
import os
import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from embedchain.config import BaseLlmConfig
from embedchain.llm.jina import JinaLlm
@pytest.fixture
def config():
os.environ["JINACHAT_API_KEY"] = "test_api_key"
config = BaseLlmConfig(temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt")
yield config
os.environ.pop("JINACHAT_API_KEY")
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
JinaLlm()
def test_get_llm_model_answer(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
llm = JinaLlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_with_system_prompt(config, mocker):
config.system_prompt = "Custom system prompt"
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
llm = JinaLlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_empty_prompt(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
llm = JinaLlm(config)
answer = llm.get_llm_model_answer("")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("", config)
def test_get_llm_model_answer_with_streaming(config, mocker):
config.stream = True
mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
llm = JinaLlm(config)
llm.get_llm_model_answer("Test query")
mocked_jinachat.assert_called_once()
callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
def test_get_llm_model_answer_without_system_prompt(config, mocker):
config.system_prompt = None
mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
llm = JinaLlm(config)
llm.get_llm_model_answer("Test query")
mocked_jinachat.assert_called_once_with(
temperature=config.temperature,
max_tokens=config.max_tokens,
jinachat_api_key=os.environ["JINACHAT_API_KEY"],
model_kwargs={"top_p": config.top_p},
)


@@ -0,0 +1,40 @@
import os
import pytest
from embedchain.llm.llama2 import Llama2Llm
@pytest.fixture
def llama2_llm():
os.environ["REPLICATE_API_TOKEN"] = "test_api_token"
llm = Llama2Llm()
return llm
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
Llama2Llm()
def test_get_llm_model_answer_raises_value_error_for_system_prompt(llama2_llm):
llama2_llm.config.system_prompt = "system_prompt"
with pytest.raises(ValueError):
llama2_llm.get_llm_model_answer("prompt")
def test_get_llm_model_answer(llama2_llm, mocker):
mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate")
mocked_replicate_instance = mocker.MagicMock()
mocked_replicate.return_value = mocked_replicate_instance
mocked_replicate_instance.invoke.return_value = "Test answer"
llama2_llm.config.model = "test_model"
llama2_llm.config.max_tokens = 50
llama2_llm.config.temperature = 0.7
llama2_llm.config.top_p = 0.8
answer = llama2_llm.get_llm_model_answer("Test query")
assert answer == "Test answer"


@@ -0,0 +1,87 @@
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.mistralai import MistralAILlm
@pytest.fixture
def mistralai_llm_config(monkeypatch):
monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
yield BaseLlmConfig(model="mistral-tiny", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
def test_mistralai_llm_init_missing_api_key(monkeypatch):
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
with pytest.raises(ValueError, match="Please set the MISTRAL_API_KEY environment variable."):
MistralAILlm()
def test_mistralai_llm_init(monkeypatch):
monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
llm = MistralAILlm()
assert llm is not None
def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
def mock_get_answer(self, prompt, config):
return "Generated Text"
monkeypatch.setattr(MistralAILlm, "_get_answer", mock_get_answer)
llm = MistralAILlm(config=mistralai_llm_config)
result = llm.get_llm_model_answer("test prompt")
assert result == "Generated Text"
def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_config):
mistralai_llm_config.system_prompt = "Test system prompt"
monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
llm = MistralAILlm(config=mistralai_llm_config)
result = llm.get_llm_model_answer("test prompt")
assert result == "Generated Text"
def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
llm = MistralAILlm(config=mistralai_llm_config)
result = llm.get_llm_model_answer("")
assert result == "Generated Text"
def test_get_llm_model_answer_without_system_prompt(monkeypatch, mistralai_llm_config):
mistralai_llm_config.system_prompt = None
monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
llm = MistralAILlm(config=mistralai_llm_config)
result = llm.get_llm_model_answer("test prompt")
assert result == "Generated Text"
def test_get_llm_model_answer_with_token_usage(monkeypatch, mistralai_llm_config):
test_config = BaseLlmConfig(
temperature=mistralai_llm_config.temperature,
max_tokens=mistralai_llm_config.max_tokens,
top_p=mistralai_llm_config.top_p,
model=mistralai_llm_config.model,
token_usage=True,
)
monkeypatch.setattr(
MistralAILlm,
"_get_answer",
lambda self, prompt, config: ("Generated Text", {"prompt_tokens": 1, "completion_tokens": 2}),
)
llm = MistralAILlm(test_config)
answer, token_info = llm.get_llm_model_answer("Test query")
assert answer == "Generated Text"
assert token_info == {
"prompt_tokens": 1,
"completion_tokens": 2,
"total_tokens": 3,
"total_cost": 7.5e-07,
"cost_currency": "USD",
}


@@ -0,0 +1,52 @@
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.ollama import OllamaLlm
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@pytest.fixture
def ollama_llm_config():
config = BaseLlmConfig(model="llama2", temperature=0.7, top_p=0.8, stream=True, system_prompt=None)
yield config
def test_get_llm_model_answer(ollama_llm_config, mocker):
mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
mocker.patch("embedchain.llm.ollama.OllamaLlm._get_answer", return_value="Test answer")
llm = OllamaLlm(ollama_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
def test_get_answer_mocked_ollama(ollama_llm_config, mocker):
mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
mocked_ollama = mocker.patch("embedchain.llm.ollama.Ollama")
mock_instance = mocked_ollama.return_value
mock_instance.invoke.return_value = "Mocked answer"
llm = OllamaLlm(ollama_llm_config)
prompt = "Test query"
answer = llm.get_llm_model_answer(prompt)
assert answer == "Mocked answer"
def test_get_llm_model_answer_with_streaming(ollama_llm_config, mocker):
ollama_llm_config.stream = True
ollama_llm_config.callbacks = [StreamingStdOutCallbackHandler()]
mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
mocked_ollama_chat = mocker.patch("embedchain.llm.ollama.OllamaLlm._get_answer", return_value="Test answer")
llm = OllamaLlm(ollama_llm_config)
llm.get_llm_model_answer("Test query")
mocked_ollama_chat.assert_called_once()
call_args = mocked_ollama_chat.call_args
config_arg = call_args[1]["config"]
callbacks = config_arg.callbacks
assert len(callbacks) == 1
assert isinstance(callbacks[0], StreamingStdOutCallbackHandler)


@@ -0,0 +1,261 @@
import os
import httpx
import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm
@pytest.fixture()
def env_config():
os.environ["OPENAI_API_KEY"] = "test_api_key"
os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1/engines/"
yield
os.environ.pop("OPENAI_API_KEY")
@pytest.fixture
def config(env_config):
config = BaseLlmConfig(
temperature=0.7,
max_tokens=50,
top_p=0.8,
stream=False,
system_prompt="System prompt",
model="gpt-3.5-turbo",
http_client_proxies=None,
http_async_client_proxies=None,
)
yield config
def test_get_llm_model_answer(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
llm = OpenAILlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_with_system_prompt(config, mocker):
config.system_prompt = "Custom system prompt"
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
llm = OpenAILlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_empty_prompt(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
llm = OpenAILlm(config)
answer = llm.get_llm_model_answer("")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("", config)
def test_get_llm_model_answer_with_token_usage(config, mocker):
test_config = BaseLlmConfig(
temperature=config.temperature,
max_tokens=config.max_tokens,
top_p=config.top_p,
stream=config.stream,
system_prompt=config.system_prompt,
model=config.model,
token_usage=True,
)
mocked_get_answer = mocker.patch(
"embedchain.llm.openai.OpenAILlm._get_answer",
return_value=("Test answer", {"prompt_tokens": 1, "completion_tokens": 2}),
)
llm = OpenAILlm(test_config)
answer, token_info = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
assert token_info == {
"prompt_tokens": 1,
"completion_tokens": 2,
"total_tokens": 3,
"total_cost": 5.5e-06,
"cost_currency": "USD",
}
mocked_get_answer.assert_called_once_with("Test query", test_config)
def test_get_llm_model_answer_with_streaming(config, mocker):
config.stream = True
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_openai_chat.assert_called_once()
callbacks = [callback[1]["callbacks"] for callback in mocked_openai_chat.call_args_list]
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
def test_get_llm_model_answer_without_system_prompt(config, mocker):
config.system_prompt = None
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_openai_chat.assert_called_once_with(
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
http_client=None,
http_async_client=None,
)
def test_get_llm_model_answer_with_special_headers(config, mocker):
config.default_headers = {"test": "test"}
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_openai_chat.assert_called_once_with(
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
default_headers={"test": "test"},
http_client=None,
http_async_client=None,
)
def test_get_llm_model_answer_with_model_kwargs(config, mocker):
config.model_kwargs = {"response_format": {"type": "json_object"}}
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_openai_chat.assert_called_once_with(
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p, "response_format": {"type": "json_object"}},
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
http_client=None,
http_async_client=None,
)
@pytest.mark.parametrize(
"mock_return, expected",
[
([{"test": "test"}], '{"test": "test"}'),
([], "Input could not be mapped to the function!"),
],
)
def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
mocked_convert_to_openai_tool = mocker.patch("langchain_core.utils.function_calling.convert_to_openai_tool")
mocked_json_output_tools_parser = mocker.patch("langchain.output_parsers.openai_tools.JsonOutputToolsParser")
mocked_openai_chat.return_value.bind.return_value.pipe.return_value.invoke.return_value = mock_return
llm = OpenAILlm(config, tools={"test": "test"})
answer = llm.get_llm_model_answer("Test query")
mocked_openai_chat.assert_called_once_with(
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
http_client=None,
http_async_client=None,
)
mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
mocked_json_output_tools_parser.assert_called_once()
assert answer == expected
def test_get_llm_model_answer_with_http_client_proxies(env_config, mocker):
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
mock_http_client = mocker.Mock(spec=httpx.Client)
mock_http_client_instance = mocker.Mock(spec=httpx.Client)
mock_http_client.return_value = mock_http_client_instance
mocker.patch("httpx.Client", new=mock_http_client)
config = BaseLlmConfig(
temperature=0.7,
max_tokens=50,
top_p=0.8,
stream=False,
system_prompt="System prompt",
model="gpt-3.5-turbo",
http_client_proxies="http://testproxy.mem0.net:8000",
)
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_openai_chat.assert_called_once_with(
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
http_client=mock_http_client_instance,
http_async_client=None,
)
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
def test_get_llm_model_answer_with_http_async_client_proxies(env_config, mocker):
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
mock_http_async_client = mocker.Mock(spec=httpx.AsyncClient)
mock_http_async_client_instance = mocker.Mock(spec=httpx.AsyncClient)
mock_http_async_client.return_value = mock_http_async_client_instance
mocker.patch("httpx.AsyncClient", new=mock_http_async_client)
config = BaseLlmConfig(
temperature=0.7,
max_tokens=50,
top_p=0.8,
stream=False,
system_prompt="System prompt",
model="gpt-3.5-turbo",
http_async_client_proxies={"http://": "http://testproxy.mem0.net:8000"},
)
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_openai_chat.assert_called_once_with(
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
http_client=None,
http_async_client=mock_http_async_client_instance,
)
mock_http_async_client.assert_called_once_with(proxies={"http://": "http://testproxy.mem0.net:8000"})
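
As with the other providers, the expected total_cost in the token-usage test is consistent with per-token pricing. The figures below are inferred from the asserted 5.5e-06 for gpt-3.5-turbo and are assumptions, not a statement of OpenAI's current price list.
# Sketch only: prices inferred from the asserted 5.5e-06.
prompt_price = 1.5 / 1_000_000       # assumed USD per prompt token
completion_price = 2.0 / 1_000_000   # assumed USD per completion token
total_cost = 1 * prompt_price + 2 * completion_price
assert abs(total_cost - 5.5e-06) < 1e-12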


@@ -0,0 +1,79 @@
import os
from unittest.mock import MagicMock, patch
import pytest
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.llm.openai import OpenAILlm
@pytest.fixture
def app():
os.environ["OPENAI_API_KEY"] = "test_api_key"
app = App(config=AppConfig(collect_metrics=False))
return app
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query(app):
with patch.object(app, "_retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = app.query(input_query="Test query")
assert answer == "Test answer"
mock_retrieve.assert_called_once()
_, kwargs = mock_retrieve.call_args
input_query_arg = kwargs.get("input_query")
assert input_query_arg == "Test query"
mock_answer.assert_called_once()
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_query_config_app_passing(mock_get_answer):
mock_get_answer.return_value = MagicMock()
mock_get_answer.return_value = "Test answer"
config = AppConfig(collect_metrics=False)
chat_config = BaseLlmConfig(system_prompt="Test system prompt")
llm = OpenAILlm(config=chat_config)
app = App(config=config, llm=llm)
answer = app.llm.get_llm_model_answer("Test query")
assert app.llm.config.system_prompt == "Test system prompt"
assert answer == "Test answer"
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_params(app):
with patch.object(app, "_retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = app.query("Test query", where={"attribute": "value"})
assert answer == "Test answer"
_, kwargs = mock_retrieve.call_args
assert kwargs.get("input_query") == "Test query"
assert kwargs.get("where") == {"attribute": "value"}
mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_query_config(app):
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
with patch.object(app.db, "query") as mock_database_query:
mock_database_query.return_value = ["Test context"]
llm_config = BaseLlmConfig(where={"attribute": "value"})
answer = app.query("Test query", llm_config)
assert answer == "Test answer"
_, kwargs = mock_database_query.call_args
assert kwargs.get("input_query") == "Test query"
where = kwargs.get("where")
assert "app_id" in where
assert "attribute" in where
mock_answer.assert_called_once()


@@ -0,0 +1,74 @@
import os
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.together import TogetherLlm
@pytest.fixture
def together_llm_config():
os.environ["TOGETHER_API_KEY"] = "test_api_key"
config = BaseLlmConfig(model="together-ai-up-to-3b", max_tokens=50, temperature=0.7, top_p=0.8)
yield config
os.environ.pop("TOGETHER_API_KEY")
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
TogetherLlm()
def test_get_llm_model_answer_raises_value_error_for_system_prompt(together_llm_config):
llm = TogetherLlm(together_llm_config)
llm.config.system_prompt = "system_prompt"
with pytest.raises(ValueError):
llm.get_llm_model_answer("prompt")
def test_get_llm_model_answer(together_llm_config, mocker):
mocker.patch("embedchain.llm.together.TogetherLlm._get_answer", return_value="Test answer")
llm = TogetherLlm(together_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
def test_get_llm_model_answer_with_token_usage(together_llm_config, mocker):
test_config = BaseLlmConfig(
temperature=together_llm_config.temperature,
max_tokens=together_llm_config.max_tokens,
top_p=together_llm_config.top_p,
model=together_llm_config.model,
token_usage=True,
)
mocker.patch(
"embedchain.llm.together.TogetherLlm._get_answer",
return_value=("Test answer", {"prompt_tokens": 1, "completion_tokens": 2}),
)
llm = TogetherLlm(test_config)
answer, token_info = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
assert token_info == {
"prompt_tokens": 1,
"completion_tokens": 2,
"total_tokens": 3,
"total_cost": 3e-07,
"cost_currency": "USD",
}
def test_get_answer_mocked_together(together_llm_config, mocker):
mocked_together = mocker.patch("embedchain.llm.together.ChatTogether")
mock_instance = mocked_together.return_value
mock_instance.invoke.return_value.content = "Mocked answer"
llm = TogetherLlm(together_llm_config)
prompt = "Test query"
answer = llm.get_llm_model_answer(prompt)
assert answer == "Mocked answer"


@@ -0,0 +1,76 @@
from unittest.mock import MagicMock, patch
import pytest
from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig
from embedchain.core.db.database import database_manager
from embedchain.llm.vertex_ai import VertexAILlm
@pytest.fixture(autouse=True)
def setup_database():
database_manager.setup_engine()
@pytest.fixture
def vertexai_llm():
config = BaseLlmConfig(temperature=0.6, model="chat-bison")
return VertexAILlm(config)
def test_get_llm_model_answer(vertexai_llm):
with patch.object(VertexAILlm, "_get_answer", return_value="Test Response") as mock_method:
prompt = "Test Prompt"
response = vertexai_llm.get_llm_model_answer(prompt)
assert response == "Test Response"
mock_method.assert_called_once_with(prompt, vertexai_llm.config)
def test_get_llm_model_answer_with_token_usage(vertexai_llm):
test_config = BaseLlmConfig(
temperature=vertexai_llm.config.temperature,
max_tokens=vertexai_llm.config.max_tokens,
top_p=vertexai_llm.config.top_p,
model=vertexai_llm.config.model,
token_usage=True,
)
vertexai_llm.config = test_config
with patch.object(
VertexAILlm,
"_get_answer",
return_value=("Test Response", {"prompt_token_count": 1, "candidates_token_count": 2}),
):
response, token_info = vertexai_llm.get_llm_model_answer("Test Query")
assert response == "Test Response"
assert token_info == {
"prompt_tokens": 1,
"completion_tokens": 2,
"total_tokens": 3,
"total_cost": 3.75e-07,
"cost_currency": "USD",
}
@patch("embedchain.llm.vertex_ai.ChatVertexAI")
def test_get_answer(mock_chat_vertexai, vertexai_llm, caplog):
mock_chat_vertexai.return_value.invoke.return_value = MagicMock(content="Test Response")
config = vertexai_llm.config
prompt = "Test Prompt"
messages = vertexai_llm._get_messages(prompt)
response = vertexai_llm._get_answer(prompt, config)
mock_chat_vertexai.return_value.invoke.assert_called_once_with(messages)
assert response == "Test Response" # Assertion corrected
assert "Config option `top_p` is not supported by this model." not in caplog.text
def test_get_messages(vertexai_llm):
prompt = "Test Prompt"
system_prompt = "Test System Prompt"
messages = vertexai_llm._get_messages(prompt, system_prompt)
assert messages == [
SystemMessage(content="Test System Prompt", additional_kwargs={}),
HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
]