Improve tests (#795)

This commit is contained in:
Sidharth Mohanty
2023-10-13 01:45:22 +05:30
committed by GitHub
parent b5de605e2b
commit 4820ea15d6
7 changed files with 373 additions and 19 deletions

View File

@@ -1,11 +1,57 @@
import unittest
import pytest
from unittest.mock import MagicMock
from embedchain.embedder.base import BaseEmbedder
from embedchain.config.embedder.base import BaseEmbedderConfig
from chromadb.api.types import Documents, Embeddings
class TestEmbedder(unittest.TestCase):
def test_init_with_invalid_vector_dim(self):
# Test if an exception is raised when an invalid vector_dim is provided
embedder = BaseEmbedder()
with self.assertRaises(TypeError):
embedder.set_vector_dimension(None)
@pytest.fixture
def base_embedder():
return BaseEmbedder()
def test_initialization(base_embedder):
assert isinstance(base_embedder.config, BaseEmbedderConfig)
# not initialized
assert not hasattr(base_embedder, "embedding_fn")
assert not hasattr(base_embedder, "vector_dimension")
def test_set_embedding_fn(base_embedder):
def embedding_function(texts: Documents) -> Embeddings:
return [f"Embedding for {text}" for text in texts]
base_embedder.set_embedding_fn(embedding_function)
assert hasattr(base_embedder, "embedding_fn")
assert callable(base_embedder.embedding_fn)
embeddings = base_embedder.embedding_fn(["text1", "text2"])
assert embeddings == ["Embedding for text1", "Embedding for text2"]
def test_set_embedding_fn_when_not_a_function(base_embedder):
with pytest.raises(ValueError):
base_embedder.set_embedding_fn(None)
def test_set_vector_dimension(base_embedder):
base_embedder.set_vector_dimension(256)
assert hasattr(base_embedder, "vector_dimension")
assert base_embedder.vector_dimension == 256
def test_set_vector_dimension_type_error(base_embedder):
with pytest.raises(TypeError):
base_embedder.set_vector_dimension(None)
def test_langchain_default_concept():
embeddings = MagicMock()
embeddings.embed_documents.return_value = ["Embedding1", "Embedding2"]
embed_function = BaseEmbedder._langchain_default_concept(embeddings)
result = embed_function(["text1", "text2"])
assert result == ["Embedding1", "Embedding2"]
def test_embedder_with_config():
embedder = BaseEmbedder(BaseEmbedderConfig())
assert isinstance(embedder.config, BaseEmbedderConfig)