Fix skipped tests (#1385)
This commit is contained in:
@@ -2,12 +2,17 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain.schema import HumanMessage, SystemMessage
|
from langchain.schema import HumanMessage, SystemMessage
|
||||||
from langchain_google_vertexai import ChatVertexAI
|
|
||||||
|
|
||||||
from embedchain.config import BaseLlmConfig
|
from embedchain.config import BaseLlmConfig
|
||||||
|
from embedchain.core.db.database import database_manager
|
||||||
from embedchain.llm.vertex_ai import VertexAILlm
|
from embedchain.llm.vertex_ai import VertexAILlm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_database():
|
||||||
|
database_manager.setup_engine()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def vertexai_llm():
|
def vertexai_llm():
|
||||||
config = BaseLlmConfig(temperature=0.6, model="chat-bison")
|
config = BaseLlmConfig(temperature=0.6, model="chat-bison")
|
||||||
@@ -22,19 +27,18 @@ def test_get_llm_model_answer(vertexai_llm):
|
|||||||
mock_method.assert_called_once_with(prompt=prompt, config=vertexai_llm.config)
|
mock_method.assert_called_once_with(prompt=prompt, config=vertexai_llm.config)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@patch("embedchain.llm.vertex_ai.ChatVertexAI")
|
||||||
reason="Requires mocking of Google Console Auth. Revisit later since don't want to block users right now."
|
def test_get_answer(mock_chat_vertexai, vertexai_llm, caplog):
|
||||||
)
|
mock_chat_vertexai.return_value.invoke.return_value = MagicMock(content="Test Response")
|
||||||
def test_get_answer(vertexai_llm, caplog):
|
|
||||||
with patch.object(ChatVertexAI, "invoke", return_value=MagicMock(content="Test Response")) as mock_method:
|
|
||||||
config = vertexai_llm.config
|
|
||||||
prompt = "Test Prompt"
|
|
||||||
messages = vertexai_llm._get_messages(prompt)
|
|
||||||
response = vertexai_llm._get_answer(prompt, config)
|
|
||||||
mock_method.assert_called_once_with(messages)
|
|
||||||
|
|
||||||
assert response == "Test Response" # Assertion corrected
|
config = vertexai_llm.config
|
||||||
assert "Config option `top_p` is not supported by this model." not in caplog.text
|
prompt = "Test Prompt"
|
||||||
|
messages = vertexai_llm._get_messages(prompt)
|
||||||
|
response = vertexai_llm._get_answer(prompt, config)
|
||||||
|
mock_chat_vertexai.return_value.invoke.assert_called_once_with(messages)
|
||||||
|
|
||||||
|
assert response == "Test Response" # Assertion corrected
|
||||||
|
assert "Config option `top_p` is not supported by this model." not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
def test_get_messages(vertexai_llm):
|
def test_get_messages(vertexai_llm):
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||||
|
|
||||||
|
|
||||||
@@ -54,7 +52,6 @@ class TestAnonymousTelemetry:
|
|||||||
properties,
|
properties,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Logging setup needs to be fixed to make this test to work")
|
|
||||||
def test_capture_with_exception(self, mocker, caplog):
|
def test_capture_with_exception(self, mocker, caplog):
|
||||||
os.environ["EC_TELEMETRY"] = "true"
|
os.environ["EC_TELEMETRY"] = "true"
|
||||||
mock_posthog = mocker.patch("embedchain.telemetry.posthog.Posthog")
|
mock_posthog = mocker.patch("embedchain.telemetry.posthog.Posthog")
|
||||||
@@ -65,3 +62,4 @@ class TestAnonymousTelemetry:
|
|||||||
with caplog.at_level(logging.ERROR):
|
with caplog.at_level(logging.ERROR):
|
||||||
telemetry.capture(event_name, properties)
|
telemetry.capture(event_name, properties)
|
||||||
assert "Failed to send telemetry event" in caplog.text
|
assert "Failed to send telemetry event" in caplog.text
|
||||||
|
caplog.clear()
|
||||||
|
|||||||
@@ -34,15 +34,16 @@ def cleanup_db():
|
|||||||
print("Error: %s - %s." % (e.filename, e.strerror))
|
print("Error: %s - %s." % (e.filename, e.strerror))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="ChromaDB client needs to be mocked")
|
@patch("embedchain.vectordb.chroma.chromadb.Client")
|
||||||
def test_chroma_db_init_with_host_and_port(chroma_db):
|
def test_chroma_db_init_with_host_and_port(mock_client):
|
||||||
settings = chroma_db.client.get_settings()
|
chroma_db = ChromaDB(config=ChromaDbConfig(host="test-host", port="1234")) # noqa
|
||||||
assert settings.chroma_server_host == "test-host"
|
called_settings: Settings = mock_client.call_args[0][0]
|
||||||
assert settings.chroma_server_http_port == "1234"
|
assert called_settings.chroma_server_host == "test-host"
|
||||||
|
assert called_settings.chroma_server_http_port == "1234"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="ChromaDB client needs to be mocked")
|
@patch("embedchain.vectordb.chroma.chromadb.Client")
|
||||||
def test_chroma_db_init_with_basic_auth():
|
def test_chroma_db_init_with_basic_auth(mock_client):
|
||||||
chroma_config = {
|
chroma_config = {
|
||||||
"host": "test-host",
|
"host": "test-host",
|
||||||
"port": "1234",
|
"port": "1234",
|
||||||
@@ -52,12 +53,17 @@ def test_chroma_db_init_with_basic_auth():
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
db = ChromaDB(config=ChromaDbConfig(**chroma_config))
|
ChromaDB(config=ChromaDbConfig(**chroma_config))
|
||||||
settings = db.client.get_settings()
|
called_settings: Settings = mock_client.call_args[0][0]
|
||||||
assert settings.chroma_server_host == "test-host"
|
assert called_settings.chroma_server_host == "test-host"
|
||||||
assert settings.chroma_server_http_port == "1234"
|
assert called_settings.chroma_server_http_port == "1234"
|
||||||
assert settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"]
|
assert (
|
||||||
assert settings.chroma_client_auth_credentials == chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
|
called_settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"]
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
called_settings.chroma_client_auth_credentials
|
||||||
|
== chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@patch("embedchain.vectordb.chroma.chromadb.Client")
|
@patch("embedchain.vectordb.chroma.chromadb.Client")
|
||||||
@@ -84,7 +90,6 @@ def test_app_init_with_host_and_port_none(mock_client):
|
|||||||
assert called_settings.chroma_server_http_port is None
|
assert called_settings.chroma_server_http_port is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Logging setup needs to be fixed to make this test to work")
|
|
||||||
def test_chroma_db_duplicates_throw_warning(caplog):
|
def test_chroma_db_duplicates_throw_warning(caplog):
|
||||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import pytest
|
|
||||||
from mock import patch
|
from mock import patch
|
||||||
from qdrant_client.http import models
|
from qdrant_client.http import models
|
||||||
from qdrant_client.http.models import Batch
|
from qdrant_client.http.models import Batch
|
||||||
@@ -61,7 +60,6 @@ class TestQdrantDB(unittest.TestCase):
|
|||||||
resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
|
resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
|
||||||
self.assertEqual(resp2, {"ids": [], "metadatas": []})
|
self.assertEqual(resp2, {"ids": [], "metadatas": []})
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Investigate the issue with the test case.")
|
|
||||||
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
||||||
@patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
|
@patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
|
||||||
def test_add(self, uuid_mock, qdrant_client_mock):
|
def test_add(self, uuid_mock, qdrant_client_mock):
|
||||||
@@ -84,7 +82,7 @@ class TestQdrantDB(unittest.TestCase):
|
|||||||
qdrant_client_mock.return_value.upsert.assert_called_once_with(
|
qdrant_client_mock.return_value.upsert.assert_called_once_with(
|
||||||
collection_name="embedchain-store-1536",
|
collection_name="embedchain-store-1536",
|
||||||
points=Batch(
|
points=Batch(
|
||||||
ids=["abc", "def"],
|
ids=["123", "456"],
|
||||||
payloads=[
|
payloads=[
|
||||||
{
|
{
|
||||||
"identifier": "123",
|
"identifier": "123",
|
||||||
|
|||||||
Reference in New Issue
Block a user