diff --git a/tests/llm/test_vertex_ai.py b/tests/llm/test_vertex_ai.py index 143b4d07..7b799a96 100644 --- a/tests/llm/test_vertex_ai.py +++ b/tests/llm/test_vertex_ai.py @@ -2,12 +2,17 @@ from unittest.mock import MagicMock, patch import pytest from langchain.schema import HumanMessage, SystemMessage -from langchain_google_vertexai import ChatVertexAI from embedchain.config import BaseLlmConfig +from embedchain.core.db.database import database_manager from embedchain.llm.vertex_ai import VertexAILlm +@pytest.fixture(autouse=True) +def setup_database(): + database_manager.setup_engine() + + @pytest.fixture def vertexai_llm(): config = BaseLlmConfig(temperature=0.6, model="chat-bison") @@ -22,19 +27,18 @@ def test_get_llm_model_answer(vertexai_llm): mock_method.assert_called_once_with(prompt=prompt, config=vertexai_llm.config) -@pytest.mark.skip( - reason="Requires mocking of Google Console Auth. Revisit later since don't want to block users right now." -) -def test_get_answer(vertexai_llm, caplog): - with patch.object(ChatVertexAI, "invoke", return_value=MagicMock(content="Test Response")) as mock_method: - config = vertexai_llm.config - prompt = "Test Prompt" - messages = vertexai_llm._get_messages(prompt) - response = vertexai_llm._get_answer(prompt, config) - mock_method.assert_called_once_with(messages) +@patch("embedchain.llm.vertex_ai.ChatVertexAI") +def test_get_answer(mock_chat_vertexai, vertexai_llm, caplog): + mock_chat_vertexai.return_value.invoke.return_value = MagicMock(content="Test Response") - assert response == "Test Response" # Assertion corrected - assert "Config option `top_p` is not supported by this model." not in caplog.text + config = vertexai_llm.config + prompt = "Test Prompt" + messages = vertexai_llm._get_messages(prompt) + response = vertexai_llm._get_answer(prompt, config) + mock_chat_vertexai.return_value.invoke.assert_called_once_with(messages) + + assert response == "Test Response" # Assertion corrected + assert "Config option `top_p` is not supported by this model." not in caplog.text def test_get_messages(vertexai_llm): diff --git a/tests/telemetry/test_posthog.py b/tests/telemetry/test_posthog.py index 5ef127b2..8efd150e 100644 --- a/tests/telemetry/test_posthog.py +++ b/tests/telemetry/test_posthog.py @@ -1,8 +1,6 @@ import logging import os -import pytest - from embedchain.telemetry.posthog import AnonymousTelemetry @@ -54,7 +52,6 @@ class TestAnonymousTelemetry: properties, ) - @pytest.mark.skip(reason="Logging setup needs to be fixed to make this test to work") def test_capture_with_exception(self, mocker, caplog): os.environ["EC_TELEMETRY"] = "true" mock_posthog = mocker.patch("embedchain.telemetry.posthog.Posthog") @@ -65,3 +62,4 @@ class TestAnonymousTelemetry: with caplog.at_level(logging.ERROR): telemetry.capture(event_name, properties) assert "Failed to send telemetry event" in caplog.text + caplog.clear() diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index e827a22d..1e2659e3 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -34,15 +34,16 @@ def cleanup_db(): print("Error: %s - %s." % (e.filename, e.strerror)) -@pytest.mark.skip(reason="ChromaDB client needs to be mocked") -def test_chroma_db_init_with_host_and_port(chroma_db): - settings = chroma_db.client.get_settings() - assert settings.chroma_server_host == "test-host" - assert settings.chroma_server_http_port == "1234" +@patch("embedchain.vectordb.chroma.chromadb.Client") +def test_chroma_db_init_with_host_and_port(mock_client): + chroma_db = ChromaDB(config=ChromaDbConfig(host="test-host", port="1234")) # noqa + called_settings: Settings = mock_client.call_args[0][0] + assert called_settings.chroma_server_host == "test-host" + assert called_settings.chroma_server_http_port == "1234" -@pytest.mark.skip(reason="ChromaDB client needs to be mocked") -def test_chroma_db_init_with_basic_auth(): +@patch("embedchain.vectordb.chroma.chromadb.Client") +def test_chroma_db_init_with_basic_auth(mock_client): chroma_config = { "host": "test-host", "port": "1234", @@ -52,12 +53,17 @@ def test_chroma_db_init_with_basic_auth(): }, } - db = ChromaDB(config=ChromaDbConfig(**chroma_config)) - settings = db.client.get_settings() - assert settings.chroma_server_host == "test-host" - assert settings.chroma_server_http_port == "1234" - assert settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"] - assert settings.chroma_client_auth_credentials == chroma_config["chroma_settings"]["chroma_client_auth_credentials"] + ChromaDB(config=ChromaDbConfig(**chroma_config)) + called_settings: Settings = mock_client.call_args[0][0] + assert called_settings.chroma_server_host == "test-host" + assert called_settings.chroma_server_http_port == "1234" + assert ( + called_settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"] + ) + assert ( + called_settings.chroma_client_auth_credentials + == chroma_config["chroma_settings"]["chroma_client_auth_credentials"] + ) @patch("embedchain.vectordb.chroma.chromadb.Client") @@ -84,7 +90,6 @@ def test_app_init_with_host_and_port_none(mock_client): assert called_settings.chroma_server_http_port is None -@pytest.mark.skip(reason="Logging setup needs to be fixed to make this test to work") def test_chroma_db_duplicates_throw_warning(caplog): db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db")) app = App(config=AppConfig(collect_metrics=False), db=db) diff --git a/tests/vectordb/test_qdrant.py b/tests/vectordb/test_qdrant.py index 288874e1..47326952 100644 --- a/tests/vectordb/test_qdrant.py +++ b/tests/vectordb/test_qdrant.py @@ -1,7 +1,6 @@ import unittest import uuid -import pytest from mock import patch from qdrant_client.http import models from qdrant_client.http.models import Batch @@ -61,7 +60,6 @@ class TestQdrantDB(unittest.TestCase): resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"}) self.assertEqual(resp2, {"ids": [], "metadatas": []}) - @pytest.mark.skip(reason="Investigate the issue with the test case.") @patch("embedchain.vectordb.qdrant.QdrantClient") @patch.object(uuid, "uuid4", side_effect=TEST_UUIDS) def test_add(self, uuid_mock, qdrant_client_mock): @@ -84,7 +82,7 @@ class TestQdrantDB(unittest.TestCase): qdrant_client_mock.return_value.upsert.assert_called_once_with( collection_name="embedchain-store-1536", points=Batch( - ids=["abc", "def"], + ids=["123", "456"], payloads=[ { "identifier": "123",