Rename embedchain to mem0 and open sourcing code for long term memory (#1474)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
Author: Taranjeet Singh
Date: 2024-07-12 07:51:33 -07:00 (committed by GitHub)
Parent: 83e8c97295
Commit: f842a92e25
665 changed files with 9427 additions and 6592 deletions

View File

@@ -0,0 +1,99 @@
import hashlib
from unittest.mock import MagicMock
import pytest
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.models.data_type import DataType
@pytest.fixture
def text_splitter_mock():
return MagicMock()
@pytest.fixture
def loader_mock():
return MagicMock()
@pytest.fixture
def app_id():
return "test_app"
@pytest.fixture
def data_type():
return DataType.TEXT
@pytest.fixture
def chunker(text_splitter_mock, data_type):
text_splitter = text_splitter_mock
chunker = BaseChunker(text_splitter)
chunker.set_data_type(data_type)
return chunker
def test_create_chunks_with_config(chunker, text_splitter_mock, loader_mock, app_id, data_type):
text_splitter_mock.split_text.return_value = ["Chunk 1", "long chunk"]
loader_mock.load_data.return_value = {
"data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}],
"doc_id": "DocID",
}
config = ChunkerConfig(chunk_size=50, chunk_overlap=0, length_function=len, min_chunk_size=10)
result = chunker.create_chunks(loader_mock, "test_src", app_id, config)
assert result["documents"] == ["long chunk"]
def test_create_chunks(chunker, text_splitter_mock, loader_mock, app_id, data_type):
text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
loader_mock.load_data.return_value = {
"data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}],
"doc_id": "DocID",
}
result = chunker.create_chunks(loader_mock, "test_src", app_id)
expected_ids = [
f"{app_id}--" + hashlib.sha256(("Chunk 1" + "URL 1").encode()).hexdigest(),
f"{app_id}--" + hashlib.sha256(("Chunk 2" + "URL 1").encode()).hexdigest(),
]
assert result["documents"] == ["Chunk 1", "Chunk 2"]
assert result["ids"] == expected_ids
assert result["metadatas"] == [
{
"url": "URL 1",
"data_type": data_type.value,
"doc_id": f"{app_id}--DocID",
},
{
"url": "URL 1",
"data_type": data_type.value,
"doc_id": f"{app_id}--DocID",
},
]
assert result["doc_id"] == f"{app_id}--DocID"
def test_get_chunks(chunker, text_splitter_mock):
text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
content = "This is a test content."
result = chunker.get_chunks(content)
assert len(result) == 2
assert result == ["Chunk 1", "Chunk 2"]
def test_set_data_type(chunker):
chunker.set_data_type(DataType.MDX)
assert chunker.data_type == DataType.MDX
def test_get_word_count(chunker):
documents = ["This is a test.", "Another test."]
result = chunker.get_word_count(documents)
assert result == 6
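
The expected_ids above encode how chunk ids are assumed to be derived; a minimal standalone sketch of that scheme follows (the helper name is illustrative, not part of the library):

import hashlib

def make_chunk_id(app_id: str, chunk: str, url: str) -> str:
    # Assumed scheme: "<app_id>--" + sha256(chunk text + source url), mirroring expected_ids above.
    return f"{app_id}--" + hashlib.sha256((chunk + url).encode()).hexdigest()

# e.g. make_chunk_id("test_app", "Chunk 1", "URL 1") reproduces the first expected id.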

View File

@@ -0,0 +1,66 @@
from embedchain.chunkers.audio import AudioChunker
from embedchain.chunkers.common_chunker import CommonChunker
from embedchain.chunkers.discourse import DiscourseChunker
from embedchain.chunkers.docs_site import DocsSiteChunker
from embedchain.chunkers.docx_file import DocxFileChunker
from embedchain.chunkers.excel_file import ExcelFileChunker
from embedchain.chunkers.gmail import GmailChunker
from embedchain.chunkers.google_drive import GoogleDriveChunker
from embedchain.chunkers.json import JSONChunker
from embedchain.chunkers.mdx import MdxChunker
from embedchain.chunkers.notion import NotionChunker
from embedchain.chunkers.openapi import OpenAPIChunker
from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.postgres import PostgresChunker
from embedchain.chunkers.qna_pair import QnaPairChunker
from embedchain.chunkers.sitemap import SitemapChunker
from embedchain.chunkers.slack import SlackChunker
from embedchain.chunkers.table import TableChunker
from embedchain.chunkers.text import TextChunker
from embedchain.chunkers.web_page import WebPageChunker
from embedchain.chunkers.xml import XmlChunker
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
from embedchain.config.add_config import ChunkerConfig
chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
chunker_common_config = {
DocsSiteChunker: {"chunk_size": 500, "chunk_overlap": 50, "length_function": len},
DocxFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
PdfFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
TextChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
MdxChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
NotionChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
QnaPairChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
TableChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
SitemapChunker: {"chunk_size": 500, "chunk_overlap": 0, "length_function": len},
WebPageChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
XmlChunker: {"chunk_size": 500, "chunk_overlap": 50, "length_function": len},
YoutubeVideoChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
JSONChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
OpenAPIChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
GmailChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
PostgresChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
SlackChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
DiscourseChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
CommonChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
GoogleDriveChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
ExcelFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
AudioChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
}
def test_default_config_values():
for chunker_class, config in chunker_common_config.items():
chunker = chunker_class()
assert chunker.text_splitter._chunk_size == config["chunk_size"]
assert chunker.text_splitter._chunk_overlap == config["chunk_overlap"]
assert chunker.text_splitter._length_function == config["length_function"]
def test_custom_config_values():
for chunker_class, _ in chunker_common_config.items():
chunker = chunker_class(config=chunker_config)
assert chunker.text_splitter._chunk_size == 500
assert chunker.text_splitter._chunk_overlap == 0
assert chunker.text_splitter._length_function == len
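
These attribute checks suggest each chunker wraps a langchain RecursiveCharacterTextSplitter built from its ChunkerConfig; a rough sketch under that assumption (the helper name is illustrative):

from langchain.text_splitter import RecursiveCharacterTextSplitter

def build_text_splitter(chunk_size=2000, chunk_overlap=0, length_function=len):
    # Illustrative only: constructs the splitter whose private attributes
    # (_chunk_size, _chunk_overlap, _length_function) the tests above inspect.
    return RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=length_function,
    )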

View File

@@ -0,0 +1,86 @@
# ruff: noqa: E501
from embedchain.chunkers.text import TextChunker
from embedchain.config import ChunkerConfig
from embedchain.models.data_type import DataType
class TestTextChunker:
def test_chunks_without_app_id(self):
"""
Test the chunks generated by TextChunker.
"""
chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len, min_chunk_size=0)
chunker = TextChunker(config=chunker_config)
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
# Data type must be set manually in the test
chunker.set_data_type(DataType.TEXT)
result = chunker.create_chunks(MockLoader(), text, chunker_config)
documents = result["documents"]
assert len(documents) > 5
def test_chunks_with_app_id(self):
"""
Test the chunks generated by TextChunker with app_id
"""
chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len, min_chunk_size=0)
chunker = TextChunker(config=chunker_config)
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
chunker.set_data_type(DataType.TEXT)
result = chunker.create_chunks(MockLoader(), text, chunker_config)
documents = result["documents"]
assert len(documents) > 5
def test_big_chunksize(self):
"""
Test that if an infinitely high chunk size is used, only one chunk is returned.
"""
chunker_config = ChunkerConfig(chunk_size=9999999999, chunk_overlap=0, length_function=len, min_chunk_size=0)
chunker = TextChunker(config=chunker_config)
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
# Data type must be set manually in the test
chunker.set_data_type(DataType.TEXT)
result = chunker.create_chunks(MockLoader(), text, chunker_config)
documents = result["documents"]
assert len(documents) == 1
def test_small_chunksize(self):
"""
Test that if a chunk size of one is used, every character is a chunk.
"""
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
chunker = TextChunker(config=chunker_config)
# We can't test with lorem ipsum because chunks are deduplicated, and its recurring characters would collapse into fewer chunks.
text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"""
# Data type must be set manually in the test
chunker.set_data_type(DataType.TEXT)
result = chunker.create_chunks(MockLoader(), text, chunker_config)
documents = result["documents"]
assert len(documents) == len(text)
def test_word_count(self):
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
chunker = TextChunker(config=chunker_config)
chunker.set_data_type(DataType.TEXT)
document = ["ab cd", "ef gh"]
result = chunker.get_word_count(document)
assert result == 4
class MockLoader:
@staticmethod
def load_data(src) -> dict:
"""
Mock loader that returns a list of data dictionaries.
Adjust this method to return different data for testing.
"""
return {
"doc_id": "123",
"data": [
{
"content": src,
"meta_data": {"url": "none"},
}
],
}

View File

@@ -0,0 +1,35 @@
import os
import pytest
from sqlalchemy import MetaData, create_engine
from sqlalchemy.orm import sessionmaker
@pytest.fixture(autouse=True)
def clean_db():
db_path = os.path.expanduser("~/.embedchain/embedchain.db")
db_url = f"sqlite:///{db_path}"
engine = create_engine(db_url)
metadata = MetaData()
metadata.reflect(bind=engine) # Reflect schema from the engine
Session = sessionmaker(bind=engine)
session = Session()
try:
# Iterate over all tables in reversed order to respect foreign keys
for table in reversed(metadata.sorted_tables):
if table.name != "alembic_version": # Skip the Alembic version table
session.execute(table.delete())
session.commit()
except Exception as e:
session.rollback()
print(f"Error cleaning database: {e}")
finally:
session.close()
@pytest.fixture(autouse=True)
def disable_telemetry():
os.environ["EC_TELEMETRY"] = "false"
yield
del os.environ["EC_TELEMETRY"]

View File

@@ -0,0 +1,52 @@
import os
import pytest
from embedchain import App
from embedchain.config import AddConfig, AppConfig, ChunkerConfig
from embedchain.models.data_type import DataType
os.environ["OPENAI_API_KEY"] = "test_key"
@pytest.fixture
def app(mocker):
mocker.patch("chromadb.api.models.Collection.Collection.add")
return App(config=AppConfig(collect_metrics=False))
def test_add(app):
app.add("https://example.com", metadata={"foo": "bar"})
assert app.user_asks == [["https://example.com", "web_page", {"foo": "bar"}]]
# TODO: Make this test faster by generating a sitemap locally rather than using a remote one
# def test_add_sitemap(app):
# app.add("https://www.google.com/sitemap.xml", metadata={"foo": "bar"})
# assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"foo": "bar"}]]
def test_add_forced_type(app):
data_type = "text"
app.add("https://example.com", data_type=data_type, metadata={"foo": "bar"})
assert app.user_asks == [["https://example.com", data_type, {"foo": "bar"}]]
def test_dry_run(app):
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, min_chunk_size=0)
text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""
result = app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)
chunks = result["chunks"]
metadata = result["metadata"]
count = result["count"]
data_type = result["type"]
assert len(chunks) == len(text)
assert count == len(text)
assert data_type == DataType.TEXT
for item in metadata:
assert isinstance(item, dict)
assert "local" in item["url"]
assert "text" in item["data_type"]

View File

@@ -0,0 +1,75 @@
import os
import pytest
from chromadb.api.models.Collection import Collection
from embedchain import App
from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.embedchain import EmbedChain
from embedchain.llm.base import BaseLlm
from embedchain.memory.base import ChatHistory
from embedchain.vectordb.chroma import ChromaDB
os.environ["OPENAI_API_KEY"] = "test-api-key"
@pytest.fixture
def app_instance():
config = AppConfig(log_level="DEBUG", collect_metrics=False)
return App(config=config)
def test_whole_app(app_instance, mocker):
knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"
mocker.patch.object(EmbedChain, "add")
mocker.patch.object(EmbedChain, "_retrieve_from_database")
mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
mocker.patch.object(BaseLlm, "generate_prompt")
mocker.patch.object(BaseLlm, "add_history")
mocker.patch.object(ChatHistory, "delete", autospec=True)
app_instance.add(knowledge, data_type="text")
app_instance.query("What text did I give you?")
app_instance.chat("What text did I give you?")
assert BaseLlm.generate_prompt.call_count == 2
app_instance.reset()
def test_add_after_reset(app_instance, mocker):
mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
config = AppConfig(log_level="DEBUG", collect_metrics=False)
chroma_config = ChromaDbConfig(allow_reset=True)
db = ChromaDB(config=chroma_config)
app_instance = App(config=config, db=db)
# mock delete chat history
mocker.patch.object(ChatHistory, "delete", autospec=True)
app_instance.reset()
app_instance.db.client.heartbeat()
mocker.patch.object(Collection, "add")
app_instance.db.collection.add(
embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]],
metadatas=[
{"chapter": "3", "verse": "16"},
{"chapter": "3", "verse": "5"},
{"chapter": "29", "verse": "11"},
],
ids=["id1", "id2", "id3"],
)
app_instance.reset()
def test_add_with_incorrect_content(app_instance, mocker):
content = [{"foo": "bar"}]
with pytest.raises(TypeError):
app_instance.add(content, data_type="json")

View File

@@ -0,0 +1,133 @@
import tempfile
import unittest
from unittest.mock import patch
from embedchain.models.data_type import DataType
from embedchain.utils.misc import detect_datatype
class TestApp(unittest.TestCase):
"""Test that the datatype detection is working, based on the input."""
def test_detect_datatype_youtube(self):
self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
self.assertEqual(detect_datatype("https://m.youtube.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
self.assertEqual(
detect_datatype("https://www.youtube-nocookie.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO
)
self.assertEqual(detect_datatype("https://vid.plus/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
self.assertEqual(detect_datatype("https://youtu.be/dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
def test_detect_datatype_local_file(self):
self.assertEqual(detect_datatype("file:///home/user/file.txt"), DataType.WEB_PAGE)
def test_detect_datatype_pdf(self):
self.assertEqual(detect_datatype("https://www.example.com/document.pdf"), DataType.PDF_FILE)
def test_detect_datatype_local_pdf(self):
self.assertEqual(detect_datatype("file:///home/user/document.pdf"), DataType.PDF_FILE)
def test_detect_datatype_xml(self):
self.assertEqual(detect_datatype("https://www.example.com/sitemap.xml"), DataType.SITEMAP)
def test_detect_datatype_local_xml(self):
self.assertEqual(detect_datatype("file:///home/user/sitemap.xml"), DataType.SITEMAP)
def test_detect_datatype_docx(self):
self.assertEqual(detect_datatype("https://www.example.com/document.docx"), DataType.DOCX)
def test_detect_datatype_local_docx(self):
self.assertEqual(detect_datatype("file:///home/user/document.docx"), DataType.DOCX)
def test_detect_data_type_json(self):
self.assertEqual(detect_datatype("https://www.example.com/data.json"), DataType.JSON)
def test_detect_data_type_local_json(self):
self.assertEqual(detect_datatype("file:///home/user/data.json"), DataType.JSON)
@patch("os.path.isfile")
def test_detect_datatype_regular_filesystem_docx(self, mock_isfile):
with tempfile.NamedTemporaryFile(suffix=".docx", delete=True) as tmp:
mock_isfile.return_value = True
self.assertEqual(detect_datatype(tmp.name), DataType.DOCX)
def test_detect_datatype_docs_site(self):
self.assertEqual(detect_datatype("https://docs.example.com"), DataType.DOCS_SITE)
def test_detect_datatype_docs_site_in_path(self):
self.assertEqual(detect_datatype("https://www.example.com/docs/index.html"), DataType.DOCS_SITE)
self.assertNotEqual(detect_datatype("file:///var/www/docs/index.html"), DataType.DOCS_SITE) # NOT equal
def test_detect_datatype_web_page(self):
self.assertEqual(detect_datatype("https://nav.al/agi"), DataType.WEB_PAGE)
def test_detect_datatype_invalid_url(self):
self.assertEqual(detect_datatype("not a url"), DataType.TEXT)
def test_detect_datatype_qna_pair(self):
self.assertEqual(
detect_datatype(("Question?", "Answer. Content of the string is irrelevant.")), DataType.QNA_PAIR
)
def test_detect_datatype_qna_pair_types(self):
"""Test that a QnA pair needs to be a tuple of length two, and both items have to be strings."""
with self.assertRaises(TypeError):
self.assertNotEqual(
detect_datatype(("How many planets are in our solar system?", 8)), DataType.QNA_PAIR
) # NOT equal
def test_detect_datatype_text(self):
self.assertEqual(detect_datatype("Just some text."), DataType.TEXT)
def test_detect_datatype_non_string_error(self):
"""Test type error if the value passed is not a string, and not a valid non-string data_type"""
with self.assertRaises(TypeError):
detect_datatype(["foo", "bar"])
@patch("os.path.isfile")
def test_detect_datatype_regular_filesystem_file_txt(self, mock_isfile):
with tempfile.NamedTemporaryFile(suffix=".txt", delete=True) as tmp:
mock_isfile.return_value = True
self.assertEqual(detect_datatype(tmp.name), DataType.TEXT_FILE)
def test_detect_datatype_regular_filesystem_no_file(self):
"""Test that if a filepath is not actually an existing file, it is not handled as a file path."""
self.assertEqual(detect_datatype("/var/not-an-existing-file.txt"), DataType.TEXT)
def test_doc_examples_quickstart(self):
"""Test examples used in the documentation."""
self.assertEqual(detect_datatype("https://en.wikipedia.org/wiki/Elon_Musk"), DataType.WEB_PAGE)
self.assertEqual(detect_datatype("https://www.tesla.com/elon-musk"), DataType.WEB_PAGE)
def test_doc_examples_introduction(self):
"""Test examples used in the documentation."""
self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=3qHkcs3kG44"), DataType.YOUTUBE_VIDEO)
self.assertEqual(
detect_datatype(
"https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf"
),
DataType.PDF_FILE,
)
self.assertEqual(detect_datatype("https://nav.al/feedback"), DataType.WEB_PAGE)
def test_doc_examples_app_types(self):
"""Test examples used in the documentation."""
self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=Ff4fRgnuFgQ"), DataType.YOUTUBE_VIDEO)
self.assertEqual(detect_datatype("https://en.wikipedia.org/wiki/Mark_Zuckerberg"), DataType.WEB_PAGE)
def test_doc_examples_configuration(self):
"""Test examples used in the documentation."""
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "wikipedia"])
import wikipedia
page = wikipedia.page("Albert Einstein")
# TODO: Add a wikipedia type, so wikipedia is a dependency and we don't need this slow test.
# (timings: import: 1.4s, fetch wiki: 0.7s)
self.assertEqual(detect_datatype(page.content), DataType.TEXT)
if __name__ == "__main__":
unittest.main()
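
Taken together, these cases outline a rough detection order; the sketch below condenses it (an illustration inferred from the assertions, not the actual detect_datatype implementation — local filesystem paths and some video hosts are omitted):

from urllib.parse import urlparse

def detect_datatype_sketch(source):
    # Condensed illustration of the branching exercised by the tests above.
    if isinstance(source, tuple):
        if len(source) == 2 and all(isinstance(part, str) for part in source):
            return "qna_pair"
        raise TypeError("A QnA pair must be a tuple of two strings")
    if not isinstance(source, str):
        raise TypeError("Source must be a string or a two-string tuple")
    url = urlparse(source)
    if url.scheme in ("http", "https", "file"):
        if any(host in url.netloc for host in ("youtube.com", "youtu.be")):
            return "youtube_video"
        if url.path.endswith(".pdf"):
            return "pdf_file"
        if url.path.endswith(".xml"):
            return "sitemap"
        if url.path.endswith(".docx"):
            return "docx"
        if url.path.endswith(".json"):
            return "json"
        if url.scheme != "file" and (url.netloc.startswith("docs.") or "/docs/" in url.path):
            return "docs_site"
        return "web_page"
    return "text"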

View File

@@ -0,0 +1,49 @@
import pytest
from chromadb.api.types import Documents, Embeddings
from embedchain.config.embedder.base import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
@pytest.fixture
def base_embedder():
return BaseEmbedder()
def test_initialization(base_embedder):
assert isinstance(base_embedder.config, BaseEmbedderConfig)
# not initialized
assert not hasattr(base_embedder, "embedding_fn")
assert not hasattr(base_embedder, "vector_dimension")
def test_set_embedding_fn(base_embedder):
def embedding_function(texts: Documents) -> Embeddings:
return [f"Embedding for {text}" for text in texts]
base_embedder.set_embedding_fn(embedding_function)
assert hasattr(base_embedder, "embedding_fn")
assert callable(base_embedder.embedding_fn)
embeddings = base_embedder.embedding_fn(["text1", "text2"])
assert embeddings == ["Embedding for text1", "Embedding for text2"]
def test_set_embedding_fn_when_not_a_function(base_embedder):
with pytest.raises(ValueError):
base_embedder.set_embedding_fn(None)
def test_set_vector_dimension(base_embedder):
base_embedder.set_vector_dimension(256)
assert hasattr(base_embedder, "vector_dimension")
assert base_embedder.vector_dimension == 256
def test_set_vector_dimension_type_error(base_embedder):
with pytest.raises(TypeError):
base_embedder.set_vector_dimension(None)
def test_embedder_with_config():
embedder = BaseEmbedder(BaseEmbedderConfig())
assert isinstance(embedder.config, BaseEmbedderConfig)

View File

@@ -0,0 +1,18 @@
from unittest.mock import patch
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.huggingface import HuggingFaceEmbedder
def test_huggingface_embedder_with_model(monkeypatch):
config = BaseEmbedderConfig(model="test-model", model_kwargs={"param": "value"})
with patch('embedchain.embedder.huggingface.HuggingFaceEmbeddings') as mock_embeddings:
embedder = HuggingFaceEmbedder(config=config)
assert embedder.config.model == "test-model"
assert embedder.config.model_kwargs == {"param": "value"}
mock_embeddings.assert_called_once_with(
model_name="test-model",
model_kwargs={"param": "value"}
)

View File

@@ -0,0 +1,224 @@
import numpy as np
import pytest
from embedchain.config.evaluation.base import AnswerRelevanceConfig
from embedchain.evaluation.metrics import AnswerRelevance
from embedchain.utils.evaluation import EvalData, EvalMetric
@pytest.fixture
def mock_data():
return [
EvalData(
contexts=[
"This is a test context 1.",
],
question="This is a test question 1.",
answer="This is a test answer 1.",
),
EvalData(
contexts=[
"This is a test context 2-1.",
"This is a test context 2-2.",
],
question="This is a test question 2.",
answer="This is a test answer 2.",
),
]
@pytest.fixture
def mock_answer_relevance_metric(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
monkeypatch.setenv("OPENAI_API_BASE", "test_api_base")
metric = AnswerRelevance()
return metric
def test_answer_relevance_init(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
metric = AnswerRelevance()
assert metric.name == EvalMetric.ANSWER_RELEVANCY.value
assert metric.config.model == "gpt-4"
assert metric.config.embedder == "text-embedding-ada-002"
assert metric.config.api_key is None
assert metric.config.num_gen_questions == 1
monkeypatch.delenv("OPENAI_API_KEY")
def test_answer_relevance_init_with_config():
metric = AnswerRelevance(config=AnswerRelevanceConfig(api_key="test_api_key"))
assert metric.name == EvalMetric.ANSWER_RELEVANCY.value
assert metric.config.model == "gpt-4"
assert metric.config.embedder == "text-embedding-ada-002"
assert metric.config.api_key == "test_api_key"
assert metric.config.num_gen_questions == 1
def test_answer_relevance_init_without_api_key(monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with pytest.raises(ValueError):
AnswerRelevance()
def test_generate_prompt(mock_answer_relevance_metric, mock_data):
prompt = mock_answer_relevance_metric._generate_prompt(mock_data[0])
assert "This is a test answer 1." in prompt
prompt = mock_answer_relevance_metric._generate_prompt(mock_data[1])
assert "This is a test answer 2." in prompt
def test_generate_questions(mock_answer_relevance_metric, mock_data, monkeypatch):
monkeypatch.setattr(
mock_answer_relevance_metric.client.chat.completions,
"create",
lambda model, messages: type(
"obj",
(object,),
{
"choices": [
type(
"obj",
(object,),
{"message": type("obj", (object,), {"content": "This is a test question response.\n"})},
)
]
},
)(),
)
prompt = mock_answer_relevance_metric._generate_prompt(mock_data[0])
questions = mock_answer_relevance_metric._generate_questions(prompt)
assert len(questions) == 1
monkeypatch.setattr(
mock_answer_relevance_metric.client.chat.completions,
"create",
lambda model, messages: type(
"obj",
(object,),
{
"choices": [
type("obj", (object,), {"message": type("obj", (object,), {"content": "question 1?\nquestion2?"})})
]
},
)(),
)
prompt = mock_answer_relevance_metric._generate_prompt(mock_data[1])
questions = mock_answer_relevance_metric._generate_questions(prompt)
assert len(questions) == 2
def test_generate_embedding(mock_answer_relevance_metric, mock_data, monkeypatch):
monkeypatch.setattr(
mock_answer_relevance_metric.client.embeddings,
"create",
lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
)
embedding = mock_answer_relevance_metric._generate_embedding("This is a test question.")
assert len(embedding) == 3
def test_compute_similarity(mock_answer_relevance_metric, mock_data):
original = np.array([1, 2, 3])
generated = np.array([[1, 2, 3], [1, 2, 3]])
similarity = mock_answer_relevance_metric._compute_similarity(original, generated)
assert len(similarity) == 2
assert similarity[0] == 1.0
assert similarity[1] == 1.0
def test_compute_score(mock_answer_relevance_metric, mock_data, monkeypatch):
monkeypatch.setattr(
mock_answer_relevance_metric.client.chat.completions,
"create",
lambda model, messages: type(
"obj",
(object,),
{
"choices": [
type(
"obj",
(object,),
{"message": type("obj", (object,), {"content": "This is a test question response.\n"})},
)
]
},
)(),
)
monkeypatch.setattr(
mock_answer_relevance_metric.client.embeddings,
"create",
lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
)
score = mock_answer_relevance_metric._compute_score(mock_data[0])
assert score == 1.0
monkeypatch.setattr(
mock_answer_relevance_metric.client.chat.completions,
"create",
lambda model, messages: type(
"obj",
(object,),
{
"choices": [
type("obj", (object,), {"message": type("obj", (object,), {"content": "question 1?\nquestion2?"})})
]
},
)(),
)
monkeypatch.setattr(
mock_answer_relevance_metric.client.embeddings,
"create",
lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
)
score = mock_answer_relevance_metric._compute_score(mock_data[1])
assert score == 1.0
def test_evaluate(mock_answer_relevance_metric, mock_data, monkeypatch):
monkeypatch.setattr(
mock_answer_relevance_metric.client.chat.completions,
"create",
lambda model, messages: type(
"obj",
(object,),
{
"choices": [
type(
"obj",
(object,),
{"message": type("obj", (object,), {"content": "This is a test question response.\n"})},
)
]
},
)(),
)
monkeypatch.setattr(
mock_answer_relevance_metric.client.embeddings,
"create",
lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
)
score = mock_answer_relevance_metric.evaluate(mock_data)
assert score == 1.0
monkeypatch.setattr(
mock_answer_relevance_metric.client.chat.completions,
"create",
lambda model, messages: type(
"obj",
(object,),
{
"choices": [
type("obj", (object,), {"message": type("obj", (object,), {"content": "question 1?\nquestion2?"})})
]
},
)(),
)
monkeypatch.setattr(
mock_answer_relevance_metric.client.embeddings,
"create",
lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
)
score = mock_answer_relevance_metric.evaluate(mock_data)
assert score == 1.0
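
With identical mocked embeddings, every similarity in these tests is 1.0; a minimal sketch of the cosine-similarity step the metric presumably performs (an assumption, not the library's exact code):

import numpy as np

def cosine_similarity_sketch(original: np.ndarray, generated: np.ndarray) -> np.ndarray:
    # Cosine similarity of the original-question embedding against each generated-question embedding.
    original = original / np.linalg.norm(original)
    generated = generated / np.linalg.norm(generated, axis=1, keepdims=True)
    return generated @ original

# With identical embeddings [1, 2, 3], every similarity is 1.0, matching the assertions above.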

View File

@@ -0,0 +1,100 @@
import pytest
from embedchain.config.evaluation.base import ContextRelevanceConfig
from embedchain.evaluation.metrics import ContextRelevance
from embedchain.utils.evaluation import EvalData, EvalMetric
@pytest.fixture
def mock_data():
return [
EvalData(
contexts=[
"This is a test context 1.",
],
question="This is a test question 1.",
answer="This is a test answer 1.",
),
EvalData(
contexts=[
"This is a test context 2-1.",
"This is a test context 2-2.",
],
question="This is a test question 2.",
answer="This is a test answer 2.",
),
]
@pytest.fixture
def mock_context_relevance_metric(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
metric = ContextRelevance()
return metric
def test_context_relevance_init(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
metric = ContextRelevance()
assert metric.name == EvalMetric.CONTEXT_RELEVANCY.value
assert metric.config.model == "gpt-4"
assert metric.config.api_key is None
assert metric.config.language == "en"
monkeypatch.delenv("OPENAI_API_KEY")
def test_context_relevance_init_with_config():
metric = ContextRelevance(config=ContextRelevanceConfig(api_key="test_api_key"))
assert metric.name == EvalMetric.CONTEXT_RELEVANCY.value
assert metric.config.model == "gpt-4"
assert metric.config.api_key == "test_api_key"
assert metric.config.language == "en"
def test_context_relevance_init_without_api_key(monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with pytest.raises(ValueError):
ContextRelevance()
def test_sentence_segmenter(mock_context_relevance_metric):
text = "This is a test sentence. This is another sentence."
assert mock_context_relevance_metric._sentence_segmenter(text) == [
"This is a test sentence. ",
"This is another sentence.",
]
def test_compute_score(mock_context_relevance_metric, mock_data, monkeypatch):
monkeypatch.setattr(
mock_context_relevance_metric.client.chat.completions,
"create",
lambda model, messages: type(
"obj",
(object,),
{
"choices": [
type("obj", (object,), {"message": type("obj", (object,), {"content": "This is a test reponse."})})
]
},
)(),
)
assert mock_context_relevance_metric._compute_score(mock_data[0]) == 1.0
assert mock_context_relevance_metric._compute_score(mock_data[1]) == 0.5
def test_evaluate(mock_context_relevance_metric, mock_data, monkeypatch):
monkeypatch.setattr(
mock_context_relevance_metric.client.chat.completions,
"create",
lambda model, messages: type(
"obj",
(object,),
{
"choices": [
type("obj", (object,), {"message": type("obj", (object,), {"content": "This is a test reponse."})})
]
},
)(),
)
assert mock_context_relevance_metric.evaluate(mock_data) == 0.75
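
The expected scores are consistent with counting relevant response sentences against total context sentences and averaging over the dataset (inferred from the assertions, not a documented formula):

# Assumed scoring: one relevant sentence extracted from the mocked response per data point.
score_1 = 1 / 1               # one context sentence  -> 1.0
score_2 = 1 / 2               # two context sentences -> 0.5
dataset_score = (score_1 + score_2) / 2
assert dataset_score == 0.75  # matches the evaluate() assertion above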

View File

@@ -0,0 +1,152 @@
import numpy as np
import pytest
from embedchain.config.evaluation.base import GroundednessConfig
from embedchain.evaluation.metrics import Groundedness
from embedchain.utils.evaluation import EvalData, EvalMetric
@pytest.fixture
def mock_data():
return [
EvalData(
contexts=[
"This is a test context 1.",
],
question="This is a test question 1.",
answer="This is a test answer 1.",
),
EvalData(
contexts=[
"This is a test context 2-1.",
"This is a test context 2-2.",
],
question="This is a test question 2.",
answer="This is a test answer 2.",
),
]
@pytest.fixture
def mock_groundedness_metric(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
metric = Groundedness()
return metric
def test_groundedness_init(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
metric = Groundedness()
assert metric.name == EvalMetric.GROUNDEDNESS.value
assert metric.config.model == "gpt-4"
assert metric.config.api_key is None
monkeypatch.delenv("OPENAI_API_KEY")
def test_groundedness_init_with_config():
metric = Groundedness(config=GroundednessConfig(api_key="test_api_key"))
assert metric.name == EvalMetric.GROUNDEDNESS.value
assert metric.config.model == "gpt-4"
assert metric.config.api_key == "test_api_key"
def test_groundedness_init_without_api_key(monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with pytest.raises(ValueError):
Groundedness()
def test_generate_answer_claim_prompt(mock_groundedness_metric, mock_data):
prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
assert "This is a test question 1." in prompt
assert "This is a test answer 1." in prompt
def test_get_claim_statements(mock_groundedness_metric, mock_data, monkeypatch):
monkeypatch.setattr(
mock_groundedness_metric.client.chat.completions,
"create",
lambda *args, **kwargs: type(
"obj",
(object,),
{
"choices": [
type(
"obj",
(object,),
{
"message": type(
"obj",
(object,),
{
"content": """This is a test answer 1.
This is a test answer 2.
This is a test answer 3."""
},
)
},
)
]
},
)(),
)
prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
claim_statements = mock_groundedness_metric._get_claim_statements(prompt=prompt)
assert len(claim_statements) == 3
assert "This is a test answer 1." in claim_statements
def test_generate_claim_inference_prompt(mock_groundedness_metric, mock_data):
prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
claim_statements = [
"This is a test claim 1.",
"This is a test claim 2.",
]
prompt = mock_groundedness_metric._generate_claim_inference_prompt(
data=mock_data[0], claim_statements=claim_statements
)
assert "This is a test context 1." in prompt
assert "This is a test claim 1." in prompt
def test_get_claim_verdict_scores(mock_groundedness_metric, mock_data, monkeypatch):
monkeypatch.setattr(
mock_groundedness_metric.client.chat.completions,
"create",
lambda *args, **kwargs: type(
"obj",
(object,),
{"choices": [type("obj", (object,), {"message": type("obj", (object,), {"content": "1\n0\n-1"})})]},
)(),
)
prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
claim_statements = mock_groundedness_metric._get_claim_statements(prompt=prompt)
prompt = mock_groundedness_metric._generate_claim_inference_prompt(
data=mock_data[0], claim_statements=claim_statements
)
claim_verdict_scores = mock_groundedness_metric._get_claim_verdict_scores(prompt=prompt)
assert len(claim_verdict_scores) == 3
assert claim_verdict_scores[0] == 1
assert claim_verdict_scores[1] == 0
def test_compute_score(mock_groundedness_metric, mock_data, monkeypatch):
monkeypatch.setattr(
mock_groundedness_metric,
"_get_claim_statements",
lambda *args, **kwargs: np.array(
[
"This is a test claim 1.",
"This is a test claim 2.",
]
),
)
monkeypatch.setattr(mock_groundedness_metric, "_get_claim_verdict_scores", lambda *args, **kwargs: np.array([1, 0]))
score = mock_groundedness_metric._compute_score(data=mock_data[0])
assert score == 0.5
def test_evaluate(mock_groundedness_metric, mock_data, monkeypatch):
monkeypatch.setattr(mock_groundedness_metric, "_compute_score", lambda *args, **kwargs: 0.5)
score = mock_groundedness_metric.evaluate(dataset=mock_data)
assert score == 0.5

View File

@@ -0,0 +1,79 @@
import random
import unittest
from string import Template
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.helpers.json_serializable import JSONSerializable, register_deserializable
class TestJsonSerializable(unittest.TestCase):
"""Test that the datatype detection is working, based on the input."""
def test_base_function(self):
"""Test that the base premise of serialization and deserealization is working"""
@register_deserializable
class TestClass(JSONSerializable):
def __init__(self):
self.rng = random.random()
original_class = TestClass()
serial = original_class.serialize()
# Negative test to show that a new class does not have the same random number.
negative_test_class = TestClass()
self.assertNotEqual(original_class.rng, negative_test_class.rng)
# Test to show that a deserialized class has the same random number.
positive_test_class: TestClass = TestClass().deserialize(serial)
self.assertEqual(original_class.rng, positive_test_class.rng)
self.assertTrue(isinstance(positive_test_class, TestClass))
# Test that it works as a static method too.
positive_test_class: TestClass = TestClass.deserialize(serial)
self.assertEqual(original_class.rng, positive_test_class.rng)
# TODO: There's no reason it shouldn't work, but serialization to and from file should be tested too.
def test_registration_required(self):
"""Test that registration is required, and that without registration the default class is returned."""
class SecondTestClass(JSONSerializable):
def __init__(self):
self.default = True
app = SecondTestClass()
# Make not default
app.default = False
# Serialize
serial = app.serialize()
# Deserialize. Due to the way errors are handled, it will not fail but return a default class.
app: SecondTestClass = SecondTestClass().deserialize(serial)
self.assertTrue(app.default)
# If we register and try again with the same serial, it should work
SecondTestClass._register_class_as_deserializable(SecondTestClass)
app: SecondTestClass = SecondTestClass().deserialize(serial)
self.assertFalse(app.default)
def test_recursive(self):
"""Test recursiveness with the real app"""
random_id = str(random.random())
config = AppConfig(id=random_id, collect_metrics=False)
# config class is set under app.config.
app = App(config=config)
s = app.serialize()
new_app: App = App.deserialize(s)
# The id of the new app is the same as the first one.
self.assertEqual(random_id, new_app.config.id)
# We have proven that a nested class (app.config) can be serialized and deserialized just the same.
# TODO: test deeper recursion
def test_special_subclasses(self):
"""Test special subclasses that are not serializable by default."""
# Template
config = BaseLlmConfig(template=Template("My custom template with $query, $context and $history."))
s = config.serialize()
new_config: BaseLlmConfig = BaseLlmConfig.deserialize(s)
self.assertEqual(config.prompt.template, new_config.prompt.template)
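
The registration behaviour tested here matches a common decorator-registry pattern; a minimal self-contained sketch under that assumption (not the embedchain implementation):

import json

_DESERIALIZABLE_CLASSES = {}

def register_deserializable(cls):
    # Record the class by name so deserialize() can find it later.
    _DESERIALIZABLE_CLASSES[cls.__name__] = cls
    return cls

class JSONSerializableSketch:
    def serialize(self) -> str:
        return json.dumps({"__class__": type(self).__name__, "state": vars(self)})

    @classmethod
    def deserialize(cls, serial: str):
        payload = json.loads(serial)
        target = _DESERIALIZABLE_CLASSES.get(payload["__class__"])
        if target is None:
            # Unregistered classes fall back to a default instance instead of raising,
            # mirroring test_registration_required above.
            return cls()
        obj = target.__new__(target)
        obj.__dict__.update(payload["state"])
        return obj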

View File

@@ -0,0 +1,54 @@
import os
from unittest.mock import patch
import pytest
from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig
from embedchain.llm.anthropic import AnthropicLlm
@pytest.fixture
def anthropic_llm():
os.environ["ANTHROPIC_API_KEY"] = "test_api_key"
config = BaseLlmConfig(temperature=0.5, model="claude-instant-1", token_usage=False)
return AnthropicLlm(config)
def test_get_llm_model_answer(anthropic_llm):
with patch.object(AnthropicLlm, "_get_answer", return_value="Test Response") as mock_method:
prompt = "Test Prompt"
response = anthropic_llm.get_llm_model_answer(prompt)
assert response == "Test Response"
mock_method.assert_called_once_with(prompt, anthropic_llm.config)
def test_get_messages(anthropic_llm):
prompt = "Test Prompt"
system_prompt = "Test System Prompt"
messages = anthropic_llm._get_messages(prompt, system_prompt)
assert messages == [
SystemMessage(content="Test System Prompt", additional_kwargs={}),
HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
]
def test_get_llm_model_answer_with_token_usage(anthropic_llm):
test_config = BaseLlmConfig(
temperature=anthropic_llm.config.temperature, model=anthropic_llm.config.model, token_usage=True
)
anthropic_llm.config = test_config
with patch.object(
AnthropicLlm, "_get_answer", return_value=("Test Response", {"input_tokens": 1, "output_tokens": 2})
) as mock_method:
prompt = "Test Prompt"
response, token_info = anthropic_llm.get_llm_model_answer(prompt)
assert response == "Test Response"
assert token_info == {
"prompt_tokens": 1,
"completion_tokens": 2,
"total_tokens": 3,
"total_cost": 1.265e-05,
"cost_currency": "USD",
}
mock_method.assert_called_once_with(prompt, anthropic_llm.config)
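
The asserted total_cost is consistent with claude-instant-1 pricing of roughly $1.63 per million prompt tokens and $5.51 per million completion tokens (inferred from the numbers above, not a quoted price list):

# Hypothetical per-token prices inferred from the asserted total_cost.
prompt_price_per_token = 1.63e-06      # USD per prompt token (assumed)
completion_price_per_token = 5.51e-06  # USD per completion token (assumed)
total_cost = 1 * prompt_price_per_token + 2 * completion_price_per_token
assert abs(total_cost - 1.265e-05) < 1e-12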

View File

@@ -0,0 +1,56 @@
import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from embedchain.config import BaseLlmConfig
from embedchain.llm.aws_bedrock import AWSBedrockLlm
@pytest.fixture
def config(monkeypatch):
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test_access_key_id")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test_secret_access_key")
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
config = BaseLlmConfig(
model="amazon.titan-text-express-v1",
model_kwargs={
"temperature": 0.5,
"topP": 1,
"maxTokenCount": 1000,
},
)
yield config
monkeypatch.delenv("AWS_ACCESS_KEY_ID")
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY")
monkeypatch.delenv("OPENAI_API_KEY")
def test_get_llm_model_answer(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")
llm = AWSBedrockLlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_empty_prompt(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")
llm = AWSBedrockLlm(config)
answer = llm.get_llm_model_answer("")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("", config)
def test_get_llm_model_answer_with_streaming(config, mocker):
config.stream = True
mocked_bedrock_chat = mocker.patch("embedchain.llm.aws_bedrock.Bedrock")
llm = AWSBedrockLlm(config)
llm.get_llm_model_answer("Test query")
mocked_bedrock_chat.assert_called_once()
callbacks = [callback[1]["callbacks"] for callback in mocked_bedrock_chat.call_args_list]
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)

View File

@@ -0,0 +1,87 @@
from unittest.mock import MagicMock, patch
import pytest
from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig
from embedchain.llm.azure_openai import AzureOpenAILlm
@pytest.fixture
def azure_openai_llm():
config = BaseLlmConfig(
deployment_name="azure_deployment",
temperature=0.7,
model="gpt-3.5-turbo",
max_tokens=50,
system_prompt="System Prompt",
)
return AzureOpenAILlm(config)
def test_get_llm_model_answer(azure_openai_llm):
with patch.object(AzureOpenAILlm, "_get_answer", return_value="Test Response") as mock_method:
prompt = "Test Prompt"
response = azure_openai_llm.get_llm_model_answer(prompt)
assert response == "Test Response"
mock_method.assert_called_once_with(prompt=prompt, config=azure_openai_llm.config)
def test_get_answer(azure_openai_llm):
with patch("langchain_openai.AzureChatOpenAI") as mock_chat:
mock_chat_instance = mock_chat.return_value
mock_chat_instance.invoke.return_value = MagicMock(content="Test Response")
prompt = "Test Prompt"
response = azure_openai_llm._get_answer(prompt, azure_openai_llm.config)
assert response == "Test Response"
mock_chat.assert_called_once_with(
deployment_name=azure_openai_llm.config.deployment_name,
openai_api_version="2024-02-01",
model_name=azure_openai_llm.config.model or "gpt-3.5-turbo",
temperature=azure_openai_llm.config.temperature,
max_tokens=azure_openai_llm.config.max_tokens,
streaming=azure_openai_llm.config.stream,
)
def test_get_messages(azure_openai_llm):
prompt = "Test Prompt"
system_prompt = "Test System Prompt"
messages = azure_openai_llm._get_messages(prompt, system_prompt)
assert messages == [
SystemMessage(content="Test System Prompt", additional_kwargs={}),
HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
]
def test_when_no_deployment_name_provided():
config = BaseLlmConfig(temperature=0.7, model="gpt-3.5-turbo", max_tokens=50, system_prompt="System Prompt")
with pytest.raises(ValueError):
llm = AzureOpenAILlm(config)
llm.get_llm_model_answer("Test Prompt")
def test_with_api_version():
config = BaseLlmConfig(
deployment_name="azure_deployment",
temperature=0.7,
model="gpt-3.5-turbo",
max_tokens=50,
system_prompt="System Prompt",
api_version="2024-02-01",
)
with patch("langchain_openai.AzureChatOpenAI") as mock_chat:
llm = AzureOpenAILlm(config)
llm.get_llm_model_answer("Test Prompt")
mock_chat.assert_called_once_with(
deployment_name="azure_deployment",
openai_api_version="2024-02-01",
model_name="gpt-3.5-turbo",
temperature=0.7,
max_tokens=50,
streaming=False,
)

View File

@@ -0,0 +1,61 @@
from string import Template
import pytest
from embedchain.llm.base import BaseLlm, BaseLlmConfig
@pytest.fixture
def base_llm():
config = BaseLlmConfig()
return BaseLlm(config=config)
def test_is_get_llm_model_answer_not_implemented(base_llm):
with pytest.raises(NotImplementedError):
base_llm.get_llm_model_answer()
def test_is_stream_bool():
with pytest.raises(ValueError):
config = BaseLlmConfig(stream="test value")
BaseLlm(config=config)
def test_template_string_gets_converted_to_Template_instance():
config = BaseLlmConfig(template="test value $query $context")
llm = BaseLlm(config=config)
assert isinstance(llm.config.prompt, Template)
def test_is_get_llm_model_answer_implemented():
class TestLlm(BaseLlm):
def get_llm_model_answer(self):
return "Implemented"
config = BaseLlmConfig()
llm = TestLlm(config=config)
assert llm.get_llm_model_answer() == "Implemented"
def test_stream_response(base_llm):
answer = ["Chunk1", "Chunk2", "Chunk3"]
result = list(base_llm._stream_response(answer))
assert result == answer
def test_append_search_and_context(base_llm):
context = "Context"
web_search_result = "Web Search Result"
result = base_llm._append_search_and_context(context, web_search_result)
expected_result = "Context\nWeb Search Result: Web Search Result"
assert result == expected_result
def test_access_search_and_get_results(base_llm, mocker):
base_llm.access_search_and_get_results = mocker.patch.object(
base_llm, "access_search_and_get_results", return_value="Search Results"
)
input_query = "Test query"
result = base_llm.access_search_and_get_results(input_query)
assert result == "Search Results"

View File

@@ -0,0 +1,120 @@
import os
import unittest
from unittest.mock import MagicMock, patch
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.llm.base import BaseLlm
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
class TestApp(unittest.TestCase):
def setUp(self):
os.environ["OPENAI_API_KEY"] = "test_key"
self.app = App(config=AppConfig(collect_metrics=False))
@patch.object(App, "_retrieve_from_database", return_value=["Test context"])
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
def test_chat_with_memory(self, mock_get_answer, mock_retrieve):
"""
This test checks the functionality of the 'chat' method in the App class with respect to the chat history
memory.
The 'chat' method is called twice. The first call initializes the chat history memory.
The second call is expected to use the chat history from the first call.
Key assumptions tested:
- After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
  called with correct arguments, adding the correct chat history.
- During the second call, the 'chat' method uses the chat history from the first call.
The test isolates the 'chat' method behavior by mocking out '_retrieve_from_database', 'get_answer_from_llm' and
'memory' methods.
"""
config = AppConfig(collect_metrics=False)
app = App(config=config)
with patch.object(BaseLlm, "add_history") as mock_history:
first_answer = app.chat("Test query 1")
self.assertEqual(first_answer, "Test answer")
mock_history.assert_called_with(app.config.id, "Test query 1", "Test answer", session_id="default")
second_answer = app.chat("Test query 2", session_id="test_session")
self.assertEqual(second_answer, "Test answer")
mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer", session_id="test_session")
@patch.object(App, "_retrieve_from_database", return_value=["Test context"])
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
def test_template_replacement(self, mock_get_answer, mock_retrieve):
"""
Tests that if the default prompt template is used and it does not contain a history placeholder,
a history-aware default template is swapped in.
Also tests that a dry run does not change the history.
"""
with patch.object(ChatHistory, "get") as mock_memory:
mock_message = ChatMessage()
mock_message.add_user_message("Test query 1")
mock_message.add_ai_message("Test answer")
mock_memory.return_value = [mock_message]
config = AppConfig(collect_metrics=False)
app = App(config=config)
first_answer = app.chat("Test query 1")
self.assertEqual(first_answer, "Test answer")
self.assertEqual(len(app.llm.history), 1)
history = app.llm.history
dry_run = app.chat("Test query 2", dry_run=True)
self.assertIn("Conversation history:", dry_run)
self.assertEqual(history, app.llm.history)
self.assertEqual(len(app.llm.history), 1)
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_chat_with_where_in_params(self):
"""
Test where filter
"""
with patch.object(self.app, "_retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = self.app.chat("Test query", where={"attribute": "value"})
self.assertEqual(answer, "Test answer")
_args, kwargs = mock_retrieve.call_args
self.assertEqual(kwargs.get("input_query"), "Test query")
self.assertEqual(kwargs.get("where"), {"attribute": "value"})
mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_chat_with_where_in_chat_config(self):
"""
This test checks the functionality of the 'chat' method in the App class.
It simulates a scenario where the '_retrieve_from_database' method returns a context list based on
a where filter and 'get_llm_model_answer' returns an expected answer string.
The 'chat' method is expected to call '_retrieve_from_database' with the where filter specified
in the BaseLlmConfig and 'get_llm_model_answer' methods appropriately and return the right answer.
Key assumptions tested:
- '_retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
BaseLlmConfig.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'chat' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'chat' method behavior by mocking out '_retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
with patch.object(self.app.db, "query") as mock_database_query:
mock_database_query.return_value = ["Test context"]
llm_config = BaseLlmConfig(where={"attribute": "value"})
answer = self.app.chat("Test query", llm_config)
self.assertEqual(answer, "Test answer")
_args, kwargs = mock_database_query.call_args
self.assertEqual(kwargs.get("input_query"), "Test query")
where = kwargs.get("where")
assert "app_id" in where
assert "attribute" in where
mock_answer.assert_called_once()

View File

@@ -0,0 +1,23 @@
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.clarifai import ClarifaiLlm
@pytest.fixture
def clarifai_llm_config(monkeypatch):
monkeypatch.setenv("CLARIFAI_PAT","test_api_key")
config = BaseLlmConfig(
model="https://clarifai.com/openai/chat-completion/models/GPT-4",
model_kwargs={"temperature": 0.7, "max_tokens": 100},
)
yield config
monkeypatch.delenv("CLARIFAI_PAT")
def test_clarifai_llm_get_llm_model_answer(clarifai_llm_config, mocker):
mocker.patch("embedchain.llm.clarifai.ClarifaiLlm._get_answer", return_value="Test answer")
llm = ClarifaiLlm(clarifai_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"

View File

@@ -0,0 +1,73 @@
import os
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.cohere import CohereLlm
@pytest.fixture
def cohere_llm_config():
os.environ["COHERE_API_KEY"] = "test_api_key"
config = BaseLlmConfig(model="command-r", max_tokens=100, temperature=0.7, top_p=0.8, token_usage=False)
yield config
os.environ.pop("COHERE_API_KEY")
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
CohereLlm()
def test_get_llm_model_answer_raises_value_error_for_system_prompt(cohere_llm_config):
llm = CohereLlm(cohere_llm_config)
llm.config.system_prompt = "system_prompt"
with pytest.raises(ValueError):
llm.get_llm_model_answer("prompt")
def test_get_llm_model_answer(cohere_llm_config, mocker):
mocker.patch("embedchain.llm.cohere.CohereLlm._get_answer", return_value="Test answer")
llm = CohereLlm(cohere_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
def test_get_llm_model_answer_with_token_usage(cohere_llm_config, mocker):
test_config = BaseLlmConfig(
temperature=cohere_llm_config.temperature,
max_tokens=cohere_llm_config.max_tokens,
top_p=cohere_llm_config.top_p,
model=cohere_llm_config.model,
token_usage=True,
)
mocker.patch(
"embedchain.llm.cohere.CohereLlm._get_answer",
return_value=("Test answer", {"input_tokens": 1, "output_tokens": 2}),
)
llm = CohereLlm(test_config)
answer, token_info = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
assert token_info == {
"prompt_tokens": 1,
"completion_tokens": 2,
"total_tokens": 3,
"total_cost": 3.5e-06,
"cost_currency": "USD",
}
def test_get_answer_mocked_cohere(cohere_llm_config, mocker):
mocked_cohere = mocker.patch("embedchain.llm.cohere.ChatCohere")
mocked_cohere.return_value.invoke.return_value.content = "Mocked answer"
llm = CohereLlm(cohere_llm_config)
prompt = "Test query"
answer = llm.get_llm_model_answer(prompt)
assert answer == "Mocked answer"

View File

@@ -0,0 +1,70 @@
import unittest
from string import Template
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
class TestGeneratePrompt(unittest.TestCase):
def setUp(self):
self.app = App(config=AppConfig(collect_metrics=False))
def test_generate_prompt_with_template(self):
"""
Tests that the generate_prompt method correctly formats the prompt using
a custom template provided in the BaseLlmConfig instance.
This test sets up a scenario with an input query and a list of contexts,
and a custom template, and then calls generate_prompt. It checks that the
returned prompt correctly incorporates all the contexts and the query into
the format specified by the template.
"""
# Setup
input_query = "Test query"
contexts = ["Context 1", "Context 2", "Context 3"]
template = "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:"
config = BaseLlmConfig(template=Template(template))
self.app.llm.config = config
# Execute
result = self.app.llm.generate_prompt(input_query, contexts)
# Assert
expected_result = (
"You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:"
)
self.assertEqual(result, expected_result)
def test_generate_prompt_with_contexts_list(self):
"""
Tests that the generate_prompt method correctly handles a list of contexts.
This test sets up a scenario with an input query and a list of contexts,
and then calls generate_prompt. It checks that the returned prompt
correctly includes all the contexts and the query.
"""
# Setup
input_query = "Test query"
contexts = ["Context 1", "Context 2", "Context 3"]
config = BaseLlmConfig()
# Execute
self.app.llm.config = config
result = self.app.llm.generate_prompt(input_query, contexts)
# Assert
expected_result = config.prompt.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
self.assertEqual(result, expected_result)
def test_generate_prompt_with_history(self):
"""
Test the 'generate_prompt' method with BaseLlmConfig containing a history attribute.
"""
config = BaseLlmConfig()
config.prompt = Template("Context: $context | Query: $query | History: $history")
self.app.llm.config = config
self.app.llm.set_history(["Past context 1", "Past context 2"])
prompt = self.app.llm.generate_prompt("Test query", ["Test context"])
expected_prompt = "Context: Test context | Query: Test query | History: Past context 1\nPast context 2"
self.assertEqual(prompt, expected_prompt)

View File

@@ -0,0 +1,43 @@
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.google import GoogleLlm
@pytest.fixture
def google_llm_config():
return BaseLlmConfig(model="gemini-pro", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
def test_google_llm_init_missing_api_key(monkeypatch):
monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
with pytest.raises(ValueError, match="Please set the GOOGLE_API_KEY environment variable."):
GoogleLlm()
def test_google_llm_init(monkeypatch):
monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
with monkeypatch.context() as m:
m.setattr("importlib.import_module", lambda x: None)
google_llm = GoogleLlm()
assert google_llm is not None
def test_google_llm_get_llm_model_answer_with_system_prompt(monkeypatch):
monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
monkeypatch.setattr("importlib.import_module", lambda x: None)
google_llm = GoogleLlm(config=BaseLlmConfig(system_prompt="system prompt"))
with pytest.raises(ValueError, match="GoogleLlm does not support `system_prompt`"):
google_llm.get_llm_model_answer("test prompt")
def test_google_llm_get_llm_model_answer(monkeypatch, google_llm_config):
def mock_get_answer(prompt, config):
return "Generated Text"
monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
monkeypatch.setattr(GoogleLlm, "_get_answer", mock_get_answer)
google_llm = GoogleLlm(config=google_llm_config)
result = google_llm.get_llm_model_answer("test prompt")
assert result == "Generated Text"

View File

@@ -0,0 +1,60 @@
import pytest
from langchain_community.llms.gpt4all import GPT4All as LangchainGPT4All
from embedchain.config import BaseLlmConfig
from embedchain.llm.gpt4all import GPT4ALLLlm
@pytest.fixture
def config():
config = BaseLlmConfig(
temperature=0.7,
max_tokens=50,
top_p=0.8,
stream=False,
system_prompt="System prompt",
model="orca-mini-3b-gguf2-q4_0.gguf",
)
yield config
@pytest.fixture
def gpt4all_with_config(config):
return GPT4ALLLlm(config=config)
@pytest.fixture
def gpt4all_without_config():
return GPT4ALLLlm()
def test_gpt4all_init_with_config(config, gpt4all_with_config):
assert gpt4all_with_config.config.temperature == config.temperature
assert gpt4all_with_config.config.max_tokens == config.max_tokens
assert gpt4all_with_config.config.top_p == config.top_p
assert gpt4all_with_config.config.stream == config.stream
assert gpt4all_with_config.config.system_prompt == config.system_prompt
assert gpt4all_with_config.config.model == config.model
assert isinstance(gpt4all_with_config.instance, LangchainGPT4All)
def test_gpt4all_init_without_config(gpt4all_without_config):
assert gpt4all_without_config.config.model == "orca-mini-3b-gguf2-q4_0.gguf"
assert isinstance(gpt4all_without_config.instance, LangchainGPT4All)
def test_get_llm_model_answer(mocker, gpt4all_with_config):
test_query = "Test query"
test_answer = "Test answer"
mocked_get_answer = mocker.patch("embedchain.llm.gpt4all.GPT4ALLLlm._get_answer", return_value=test_answer)
answer = gpt4all_with_config.get_llm_model_answer(test_query)
assert answer == test_answer
mocked_get_answer.assert_called_once_with(prompt=test_query, config=gpt4all_with_config.config)
def test_gpt4all_model_switching(gpt4all_with_config):
with pytest.raises(RuntimeError, match="GPT4ALLLlm does not support switching models at runtime."):
gpt4all_with_config._get_answer("Test prompt", BaseLlmConfig(model="new_model"))

View File

@@ -0,0 +1,83 @@
import importlib
import os
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.huggingface import HuggingFaceLlm
@pytest.fixture
def huggingface_llm_config():
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)
yield config
os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
@pytest.fixture
def huggingface_endpoint_config():
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
config = BaseLlmConfig(endpoint="https://api-inference.huggingface.co/models/gpt2", model_kwargs={"device": "cpu"})
yield config
os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
HuggingFaceLlm()
def test_get_llm_model_answer_raises_value_error_for_system_prompt(huggingface_llm_config):
llm = HuggingFaceLlm(huggingface_llm_config)
llm.config.system_prompt = "system_prompt"
with pytest.raises(ValueError):
llm.get_llm_model_answer("prompt")
def test_top_p_value_within_range():
config = BaseLlmConfig(top_p=1.0)
with pytest.raises(ValueError):
HuggingFaceLlm._get_answer("test_prompt", config)
def test_dependency_is_imported():
    hf_hub_installed = True
    try:
        importlib.import_module("huggingface_hub")
    except ImportError:
        hf_hub_installed = False
    assert hf_hub_installed
def test_get_llm_model_answer(huggingface_llm_config, mocker):
mocker.patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer", return_value="Test answer")
llm = HuggingFaceLlm(huggingface_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
def test_hugging_face_mock(huggingface_llm_config, mocker):
mock_llm_instance = mocker.Mock(return_value="Test answer")
mock_hf_hub = mocker.patch("embedchain.llm.huggingface.HuggingFaceHub")
mock_hf_hub.return_value.invoke = mock_llm_instance
llm = HuggingFaceLlm(huggingface_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mock_llm_instance.assert_called_once_with("Test query")
def test_custom_endpoint(huggingface_endpoint_config, mocker):
mock_llm_instance = mocker.Mock(return_value="Test answer")
mock_hf_endpoint = mocker.patch("embedchain.llm.huggingface.HuggingFaceEndpoint")
mock_hf_endpoint.return_value.invoke = mock_llm_instance
llm = HuggingFaceLlm(huggingface_endpoint_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mock_llm_instance.assert_called_once_with("Test query")

View File

@@ -0,0 +1,79 @@
import os
import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from embedchain.config import BaseLlmConfig
from embedchain.llm.jina import JinaLlm
@pytest.fixture
def config():
os.environ["JINACHAT_API_KEY"] = "test_api_key"
config = BaseLlmConfig(temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt")
yield config
os.environ.pop("JINACHAT_API_KEY")
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
JinaLlm()
def test_get_llm_model_answer(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
llm = JinaLlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_with_system_prompt(config, mocker):
config.system_prompt = "Custom system prompt"
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
llm = JinaLlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_empty_prompt(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
llm = JinaLlm(config)
answer = llm.get_llm_model_answer("")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("", config)
def test_get_llm_model_answer_with_streaming(config, mocker):
config.stream = True
mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
llm = JinaLlm(config)
llm.get_llm_model_answer("Test query")
mocked_jinachat.assert_called_once()
callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
def test_get_llm_model_answer_without_system_prompt(config, mocker):
config.system_prompt = None
mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
llm = JinaLlm(config)
llm.get_llm_model_answer("Test query")
mocked_jinachat.assert_called_once_with(
temperature=config.temperature,
max_tokens=config.max_tokens,
jinachat_api_key=os.environ["JINACHAT_API_KEY"],
model_kwargs={"top_p": config.top_p},
)

View File

@@ -0,0 +1,40 @@
import os
import pytest
from embedchain.llm.llama2 import Llama2Llm
@pytest.fixture
def llama2_llm():
os.environ["REPLICATE_API_TOKEN"] = "test_api_token"
llm = Llama2Llm()
return llm
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
Llama2Llm()
def test_get_llm_model_answer_raises_value_error_for_system_prompt(llama2_llm):
llama2_llm.config.system_prompt = "system_prompt"
with pytest.raises(ValueError):
llama2_llm.get_llm_model_answer("prompt")
def test_get_llm_model_answer(llama2_llm, mocker):
mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate")
mocked_replicate_instance = mocker.MagicMock()
mocked_replicate.return_value = mocked_replicate_instance
mocked_replicate_instance.invoke.return_value = "Test answer"
llama2_llm.config.model = "test_model"
llama2_llm.config.max_tokens = 50
llama2_llm.config.temperature = 0.7
llama2_llm.config.top_p = 0.8
answer = llama2_llm.get_llm_model_answer("Test query")
assert answer == "Test answer"

View File

@@ -0,0 +1,87 @@
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.mistralai import MistralAILlm
@pytest.fixture
def mistralai_llm_config(monkeypatch):
monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
yield BaseLlmConfig(model="mistral-tiny", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
def test_mistralai_llm_init_missing_api_key(monkeypatch):
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
with pytest.raises(ValueError, match="Please set the MISTRAL_API_KEY environment variable."):
MistralAILlm()
def test_mistralai_llm_init(monkeypatch):
monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
llm = MistralAILlm()
assert llm is not None
def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
def mock_get_answer(self, prompt, config):
return "Generated Text"
monkeypatch.setattr(MistralAILlm, "_get_answer", mock_get_answer)
llm = MistralAILlm(config=mistralai_llm_config)
result = llm.get_llm_model_answer("test prompt")
assert result == "Generated Text"
def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_config):
mistralai_llm_config.system_prompt = "Test system prompt"
monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
llm = MistralAILlm(config=mistralai_llm_config)
result = llm.get_llm_model_answer("test prompt")
assert result == "Generated Text"
def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
llm = MistralAILlm(config=mistralai_llm_config)
result = llm.get_llm_model_answer("")
assert result == "Generated Text"
def test_get_llm_model_answer_without_system_prompt(monkeypatch, mistralai_llm_config):
mistralai_llm_config.system_prompt = None
monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
llm = MistralAILlm(config=mistralai_llm_config)
result = llm.get_llm_model_answer("test prompt")
assert result == "Generated Text"
def test_get_llm_model_answer_with_token_usage(monkeypatch, mistralai_llm_config):
test_config = BaseLlmConfig(
temperature=mistralai_llm_config.temperature,
max_tokens=mistralai_llm_config.max_tokens,
top_p=mistralai_llm_config.top_p,
model=mistralai_llm_config.model,
token_usage=True,
)
monkeypatch.setattr(
MistralAILlm,
"_get_answer",
lambda self, prompt, config: ("Generated Text", {"prompt_tokens": 1, "completion_tokens": 2}),
)
llm = MistralAILlm(test_config)
answer, token_info = llm.get_llm_model_answer("Test query")
assert answer == "Generated Text"
assert token_info == {
"prompt_tokens": 1,
"completion_tokens": 2,
"total_tokens": 3,
"total_cost": 7.5e-07,
"cost_currency": "USD",
}

View File

@@ -0,0 +1,52 @@
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.ollama import OllamaLlm
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@pytest.fixture
def ollama_llm_config():
config = BaseLlmConfig(model="llama2", temperature=0.7, top_p=0.8, stream=True, system_prompt=None)
yield config
def test_get_llm_model_answer(ollama_llm_config, mocker):
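    # Client.list is stubbed below so the test does not require a running local Ollama server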
mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
mocker.patch("embedchain.llm.ollama.OllamaLlm._get_answer", return_value="Test answer")
llm = OllamaLlm(ollama_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
def test_get_answer_mocked_ollama(ollama_llm_config, mocker):
mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
mocked_ollama = mocker.patch("embedchain.llm.ollama.Ollama")
mock_instance = mocked_ollama.return_value
mock_instance.invoke.return_value = "Mocked answer"
llm = OllamaLlm(ollama_llm_config)
prompt = "Test query"
answer = llm.get_llm_model_answer(prompt)
assert answer == "Mocked answer"
def test_get_llm_model_answer_with_streaming(ollama_llm_config, mocker):
ollama_llm_config.stream = True
ollama_llm_config.callbacks = [StreamingStdOutCallbackHandler()]
mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
mocked_ollama_chat = mocker.patch("embedchain.llm.ollama.OllamaLlm._get_answer", return_value="Test answer")
llm = OllamaLlm(ollama_llm_config)
llm.get_llm_model_answer("Test query")
mocked_ollama_chat.assert_called_once()
call_args = mocked_ollama_chat.call_args
config_arg = call_args[1]["config"]
callbacks = config_arg.callbacks
assert len(callbacks) == 1
assert isinstance(callbacks[0], StreamingStdOutCallbackHandler)

View File

@@ -0,0 +1,261 @@
import os
import httpx
import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm
@pytest.fixture()
def env_config():
os.environ["OPENAI_API_KEY"] = "test_api_key"
os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1/engines/"
yield
os.environ.pop("OPENAI_API_KEY")
@pytest.fixture
def config(env_config):
config = BaseLlmConfig(
temperature=0.7,
max_tokens=50,
top_p=0.8,
stream=False,
system_prompt="System prompt",
model="gpt-3.5-turbo",
http_client_proxies=None,
http_async_client_proxies=None,
)
yield config
def test_get_llm_model_answer(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
llm = OpenAILlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_with_system_prompt(config, mocker):
config.system_prompt = "Custom system prompt"
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
llm = OpenAILlm(config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("Test query", config)
def test_get_llm_model_answer_empty_prompt(config, mocker):
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
llm = OpenAILlm(config)
answer = llm.get_llm_model_answer("")
assert answer == "Test answer"
mocked_get_answer.assert_called_once_with("", config)
def test_get_llm_model_answer_with_token_usage(config, mocker):
test_config = BaseLlmConfig(
temperature=config.temperature,
max_tokens=config.max_tokens,
top_p=config.top_p,
stream=config.stream,
system_prompt=config.system_prompt,
model=config.model,
token_usage=True,
)
mocked_get_answer = mocker.patch(
"embedchain.llm.openai.OpenAILlm._get_answer",
return_value=("Test answer", {"prompt_tokens": 1, "completion_tokens": 2}),
)
llm = OpenAILlm(test_config)
answer, token_info = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
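    # 5.5e-06 matches gpt-3.5-turbo pricing of $0.0015/1K prompt tokens and $0.002/1K completion tokens (1 and 2 tokens here)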
assert token_info == {
"prompt_tokens": 1,
"completion_tokens": 2,
"total_tokens": 3,
"total_cost": 5.5e-06,
"cost_currency": "USD",
}
mocked_get_answer.assert_called_once_with("Test query", test_config)
def test_get_llm_model_answer_with_streaming(config, mocker):
config.stream = True
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_openai_chat.assert_called_once()
callbacks = [callback[1]["callbacks"] for callback in mocked_openai_chat.call_args_list]
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
def test_get_llm_model_answer_without_system_prompt(config, mocker):
config.system_prompt = None
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_openai_chat.assert_called_once_with(
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
http_client=None,
http_async_client=None,
)
def test_get_llm_model_answer_with_special_headers(config, mocker):
config.default_headers = {"test": "test"}
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_openai_chat.assert_called_once_with(
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
default_headers={"test": "test"},
http_client=None,
http_async_client=None,
)
def test_get_llm_model_answer_with_model_kwargs(config, mocker):
config.model_kwargs = {"response_format": {"type": "json_object"}}
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_openai_chat.assert_called_once_with(
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p, "response_format": {"type": "json_object"}},
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
http_client=None,
http_async_client=None,
)
@pytest.mark.parametrize(
"mock_return, expected",
[
([{"test": "test"}], '{"test": "test"}'),
([], "Input could not be mapped to the function!"),
],
)
def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
mocked_convert_to_openai_tool = mocker.patch("langchain_core.utils.function_calling.convert_to_openai_tool")
mocked_json_output_tools_parser = mocker.patch("langchain.output_parsers.openai_tools.JsonOutputToolsParser")
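    # The mock chain below mirrors the tool-calling call path ChatOpenAI(...).bind(...).pipe(...).invoke(...)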
mocked_openai_chat.return_value.bind.return_value.pipe.return_value.invoke.return_value = mock_return
llm = OpenAILlm(config, tools={"test": "test"})
answer = llm.get_llm_model_answer("Test query")
mocked_openai_chat.assert_called_once_with(
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
http_client=None,
http_async_client=None,
)
mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
mocked_json_output_tools_parser.assert_called_once()
assert answer == expected
def test_get_llm_model_answer_with_http_client_proxies(env_config, mocker):
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
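    # httpx.Client is replaced so the test can assert that the proxy-configured client instance is handed to ChatOpenAI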
mock_http_client = mocker.Mock(spec=httpx.Client)
mock_http_client_instance = mocker.Mock(spec=httpx.Client)
mock_http_client.return_value = mock_http_client_instance
mocker.patch("httpx.Client", new=mock_http_client)
config = BaseLlmConfig(
temperature=0.7,
max_tokens=50,
top_p=0.8,
stream=False,
system_prompt="System prompt",
model="gpt-3.5-turbo",
http_client_proxies="http://testproxy.mem0.net:8000",
)
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_openai_chat.assert_called_once_with(
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
http_client=mock_http_client_instance,
http_async_client=None,
)
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
def test_get_llm_model_answer_with_http_async_client_proxies(env_config, mocker):
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
mock_http_async_client = mocker.Mock(spec=httpx.AsyncClient)
mock_http_async_client_instance = mocker.Mock(spec=httpx.AsyncClient)
mock_http_async_client.return_value = mock_http_async_client_instance
mocker.patch("httpx.AsyncClient", new=mock_http_async_client)
config = BaseLlmConfig(
temperature=0.7,
max_tokens=50,
top_p=0.8,
stream=False,
system_prompt="System prompt",
model="gpt-3.5-turbo",
http_async_client_proxies={"http://": "http://testproxy.mem0.net:8000"},
)
llm = OpenAILlm(config)
llm.get_llm_model_answer("Test query")
mocked_openai_chat.assert_called_once_with(
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
model_kwargs={"top_p": config.top_p},
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
http_client=None,
http_async_client=mock_http_async_client_instance,
)
mock_http_async_client.assert_called_once_with(proxies={"http://": "http://testproxy.mem0.net:8000"})

View File

@@ -0,0 +1,79 @@
import os
from unittest.mock import MagicMock, patch
import pytest
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.llm.openai import OpenAILlm
@pytest.fixture
def app():
os.environ["OPENAI_API_KEY"] = "test_api_key"
app = App(config=AppConfig(collect_metrics=False))
return app
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query(app):
with patch.object(app, "_retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = app.query(input_query="Test query")
assert answer == "Test answer"
mock_retrieve.assert_called_once()
_, kwargs = mock_retrieve.call_args
input_query_arg = kwargs.get("input_query")
assert input_query_arg == "Test query"
mock_answer.assert_called_once()
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_query_config_app_passing(mock_get_answer):
    mock_get_answer.return_value = "Test answer"
config = AppConfig(collect_metrics=False)
chat_config = BaseLlmConfig(system_prompt="Test system prompt")
llm = OpenAILlm(config=chat_config)
app = App(config=config, llm=llm)
answer = app.llm.get_llm_model_answer("Test query")
assert app.llm.config.system_prompt == "Test system prompt"
assert answer == "Test answer"
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_params(app):
with patch.object(app, "_retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = app.query("Test query", where={"attribute": "value"})
assert answer == "Test answer"
_, kwargs = mock_retrieve.call_args
assert kwargs.get("input_query") == "Test query"
assert kwargs.get("where") == {"attribute": "value"}
mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_query_config(app):
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
with patch.object(app.db, "query") as mock_database_query:
mock_database_query.return_value = ["Test context"]
llm_config = BaseLlmConfig(where={"attribute": "value"})
answer = app.query("Test query", llm_config)
assert answer == "Test answer"
_, kwargs = mock_database_query.call_args
assert kwargs.get("input_query") == "Test query"
where = kwargs.get("where")
assert "app_id" in where
assert "attribute" in where
mock_answer.assert_called_once()

View File

@@ -0,0 +1,74 @@
import os
import pytest
from embedchain.config import BaseLlmConfig
from embedchain.llm.together import TogetherLlm
@pytest.fixture
def together_llm_config():
os.environ["TOGETHER_API_KEY"] = "test_api_key"
config = BaseLlmConfig(model="together-ai-up-to-3b", max_tokens=50, temperature=0.7, top_p=0.8)
yield config
os.environ.pop("TOGETHER_API_KEY")
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
TogetherLlm()
def test_get_llm_model_answer_raises_value_error_for_system_prompt(together_llm_config):
llm = TogetherLlm(together_llm_config)
llm.config.system_prompt = "system_prompt"
with pytest.raises(ValueError):
llm.get_llm_model_answer("prompt")
def test_get_llm_model_answer(together_llm_config, mocker):
mocker.patch("embedchain.llm.together.TogetherLlm._get_answer", return_value="Test answer")
llm = TogetherLlm(together_llm_config)
answer = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
def test_get_llm_model_answer_with_token_usage(together_llm_config, mocker):
test_config = BaseLlmConfig(
temperature=together_llm_config.temperature,
max_tokens=together_llm_config.max_tokens,
top_p=together_llm_config.top_p,
model=together_llm_config.model,
token_usage=True,
)
mocker.patch(
"embedchain.llm.together.TogetherLlm._get_answer",
return_value=("Test answer", {"prompt_tokens": 1, "completion_tokens": 2}),
)
llm = TogetherLlm(test_config)
answer, token_info = llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
assert token_info == {
"prompt_tokens": 1,
"completion_tokens": 2,
"total_tokens": 3,
"total_cost": 3e-07,
"cost_currency": "USD",
}
def test_get_answer_mocked_together(together_llm_config, mocker):
mocked_together = mocker.patch("embedchain.llm.together.ChatTogether")
mock_instance = mocked_together.return_value
mock_instance.invoke.return_value.content = "Mocked answer"
llm = TogetherLlm(together_llm_config)
prompt = "Test query"
answer = llm.get_llm_model_answer(prompt)
assert answer == "Mocked answer"

View File

@@ -0,0 +1,76 @@
from unittest.mock import MagicMock, patch
import pytest
from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig
from embedchain.core.db.database import database_manager
from embedchain.llm.vertex_ai import VertexAILlm
@pytest.fixture(autouse=True)
def setup_database():
database_manager.setup_engine()
@pytest.fixture
def vertexai_llm():
config = BaseLlmConfig(temperature=0.6, model="chat-bison")
return VertexAILlm(config)
def test_get_llm_model_answer(vertexai_llm):
with patch.object(VertexAILlm, "_get_answer", return_value="Test Response") as mock_method:
prompt = "Test Prompt"
response = vertexai_llm.get_llm_model_answer(prompt)
assert response == "Test Response"
mock_method.assert_called_once_with(prompt, vertexai_llm.config)
def test_get_llm_model_answer_with_token_usage(vertexai_llm):
test_config = BaseLlmConfig(
temperature=vertexai_llm.config.temperature,
max_tokens=vertexai_llm.config.max_tokens,
top_p=vertexai_llm.config.top_p,
model=vertexai_llm.config.model,
token_usage=True,
)
vertexai_llm.config = test_config
with patch.object(
VertexAILlm,
"_get_answer",
return_value=("Test Response", {"prompt_token_count": 1, "candidates_token_count": 2}),
):
response, token_info = vertexai_llm.get_llm_model_answer("Test Query")
assert response == "Test Response"
assert token_info == {
"prompt_tokens": 1,
"completion_tokens": 2,
"total_tokens": 3,
"total_cost": 3.75e-07,
"cost_currency": "USD",
}
@patch("embedchain.llm.vertex_ai.ChatVertexAI")
def test_get_answer(mock_chat_vertexai, vertexai_llm, caplog):
mock_chat_vertexai.return_value.invoke.return_value = MagicMock(content="Test Response")
config = vertexai_llm.config
prompt = "Test Prompt"
messages = vertexai_llm._get_messages(prompt)
response = vertexai_llm._get_answer(prompt, config)
mock_chat_vertexai.return_value.invoke.assert_called_once_with(messages)
    assert response == "Test Response"
assert "Config option `top_p` is not supported by this model." not in caplog.text
def test_get_messages(vertexai_llm):
prompt = "Test Prompt"
system_prompt = "Test System Prompt"
messages = vertexai_llm._get_messages(prompt, system_prompt)
assert messages == [
SystemMessage(content="Test System Prompt", additional_kwargs={}),
HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
]

View File

@@ -0,0 +1,100 @@
import hashlib
import os
import sys
from unittest.mock import mock_open, patch
import pytest
if sys.version_info > (3, 10): # as `match` statement was introduced in python 3.10
from deepgram import PrerecordedOptions
from embedchain.loaders.audio import AudioLoader
@pytest.fixture
def setup_audio_loader(mocker):
    mock_deepgram = mocker.patch("deepgram.DeepgramClient")
    mock_client = mocker.MagicMock()
    mock_deepgram.return_value = mock_client
    os.environ["DEEPGRAM_API_KEY"] = "test_key"
    loader = AudioLoader()
    loader.client = mock_client
    yield loader, mock_client
if "DEEPGRAM_API_KEY" in os.environ:
del os.environ["DEEPGRAM_API_KEY"]
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
) # as `match` statement was introduced in python 3.10
def test_initialization(setup_audio_loader):
"""Test initialization of AudioLoader."""
loader, _ = setup_audio_loader
assert loader is not None
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
) # as `match` statement was introduced in python 3.10
def test_load_data_from_url(setup_audio_loader):
    loader, mock_client = setup_audio_loader
url = "https://example.com/audio.mp3"
expected_content = "This is a test audio transcript."
mock_response = {"results": {"channels": [{"alternatives": [{"transcript": expected_content}]}]}}
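    # Mirror the Deepgram SDK call chain used by the loader: client.listen.prerecorded.v("1").transcribe_url(...)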
    mock_client.listen.prerecorded.v.return_value.transcribe_url.return_value = mock_response
result = loader.load_data(url)
doc_id = hashlib.sha256((expected_content + url).encode()).hexdigest()
expected_result = {
"doc_id": doc_id,
"data": [
{
"content": expected_content,
"meta_data": {"url": url},
}
],
}
assert result == expected_result
    mock_client.listen.prerecorded.v.assert_called_once_with("1")
    mock_client.listen.prerecorded.v.return_value.transcribe_url.assert_called_once_with(
{"url": url}, PrerecordedOptions(model="nova-2", smart_format=True)
)
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
) # as `match` statement was introduced in python 3.10
def test_load_data_from_file(setup_audio_loader):
    loader, mock_client = setup_audio_loader
file_path = "local_audio.mp3"
expected_content = "This is a test audio transcript."
mock_response = {"results": {"channels": [{"alternatives": [{"transcript": expected_content}]}]}}
    mock_client.listen.prerecorded.v.return_value.transcribe_file.return_value = mock_response
# Mock the file reading functionality
with patch("builtins.open", mock_open(read_data=b"some data")) as mock_file:
result = loader.load_data(file_path)
doc_id = hashlib.sha256((expected_content + file_path).encode()).hexdigest()
expected_result = {
"doc_id": doc_id,
"data": [
{
"content": expected_content,
"meta_data": {"url": file_path},
}
],
}
assert result == expected_result
    mock_client.listen.prerecorded.v.assert_called_once_with("1")
    mock_client.listen.prerecorded.v.return_value.transcribe_file.assert_called_once_with(
{"buffer": mock_file.return_value}, PrerecordedOptions(model="nova-2", smart_format=True)
)

View File

@@ -0,0 +1,113 @@
import csv
import os
import pathlib
import tempfile
from unittest.mock import MagicMock, patch
import pytest
from embedchain.loaders.csv import CsvLoader
@pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
def test_load_data(delimiter):
"""
    Test the CSV loader.
    Tests that the file is loaded and that the content and metadata are correct.
"""
# Creating temporary CSV file
with tempfile.NamedTemporaryFile(mode="w+", newline="", delete=False) as tmpfile:
writer = csv.writer(tmpfile, delimiter=delimiter)
writer.writerow(["Name", "Age", "Occupation"])
writer.writerow(["Alice", "28", "Engineer"])
writer.writerow(["Bob", "35", "Doctor"])
writer.writerow(["Charlie", "22", "Student"])
tmpfile.seek(0)
filename = tmpfile.name
# Loading CSV using CsvLoader
loader = CsvLoader()
result = loader.load_data(filename)
data = result["data"]
# Assertions
assert len(data) == 3
assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
assert data[0]["meta_data"]["url"] == filename
assert data[0]["meta_data"]["row"] == 1
assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
assert data[1]["meta_data"]["url"] == filename
assert data[1]["meta_data"]["row"] == 2
assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
assert data[2]["meta_data"]["url"] == filename
assert data[2]["meta_data"]["row"] == 3
# Cleaning up the temporary file
os.unlink(filename)
@pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
def test_load_data_with_file_uri(delimiter):
"""
    Test the CSV loader with a file URI.
    Tests that the file is loaded and that the content and metadata are correct.
"""
# Creating temporary CSV file
with tempfile.NamedTemporaryFile(mode="w+", newline="", delete=False) as tmpfile:
writer = csv.writer(tmpfile, delimiter=delimiter)
writer.writerow(["Name", "Age", "Occupation"])
writer.writerow(["Alice", "28", "Engineer"])
writer.writerow(["Bob", "35", "Doctor"])
writer.writerow(["Charlie", "22", "Student"])
tmpfile.seek(0)
filename = pathlib.Path(tmpfile.name).as_uri() # Convert path to file URI
# Loading CSV using CsvLoader
loader = CsvLoader()
result = loader.load_data(filename)
data = result["data"]
# Assertions
assert len(data) == 3
assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
assert data[0]["meta_data"]["url"] == filename
assert data[0]["meta_data"]["row"] == 1
assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
assert data[1]["meta_data"]["url"] == filename
assert data[1]["meta_data"]["row"] == 2
assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
assert data[2]["meta_data"]["url"] == filename
assert data[2]["meta_data"]["row"] == 3
# Cleaning up the temporary file
os.unlink(tmpfile.name)
@pytest.mark.parametrize("content", ["ftp://example.com", "sftp://example.com", "mailto://example.com"])
def test_get_file_content(content):
with pytest.raises(ValueError):
loader = CsvLoader()
loader._get_file_content(content)
@pytest.mark.parametrize("content", ["http://example.com", "https://example.com"])
def test_get_file_content_http(content):
"""
Test _get_file_content method of CsvLoader for http and https URLs
"""
with patch("requests.get") as mock_get:
mock_response = MagicMock()
mock_response.text = "Name,Age,Occupation\nAlice,28,Engineer\nBob,35,Doctor\nCharlie,22,Student"
mock_get.return_value = mock_response
loader = CsvLoader()
file_content = loader._get_file_content(content)
mock_get.assert_called_once_with(content)
mock_response.raise_for_status.assert_called_once()
assert file_content.read() == mock_response.text

View File

@@ -0,0 +1,104 @@
import pytest
import requests
from embedchain.loaders.discourse import DiscourseLoader
@pytest.fixture
def discourse_loader_config():
return {
"domain": "https://example.com/",
}
@pytest.fixture
def discourse_loader(discourse_loader_config):
return DiscourseLoader(config=discourse_loader_config)
def test_discourse_loader_init_with_valid_config():
config = {"domain": "https://example.com/"}
loader = DiscourseLoader(config=config)
assert loader.domain == "https://example.com/"
def test_discourse_loader_init_with_missing_config():
with pytest.raises(ValueError, match="DiscourseLoader requires a config"):
DiscourseLoader()
def test_discourse_loader_init_with_missing_domain():
config = {"another_key": "value"}
with pytest.raises(ValueError, match="DiscourseLoader requires a domain"):
DiscourseLoader(config=config)
def test_discourse_loader_check_query_with_valid_query(discourse_loader):
discourse_loader._check_query("sample query")
def test_discourse_loader_check_query_with_empty_query(discourse_loader):
with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
discourse_loader._check_query("")
def test_discourse_loader_check_query_with_invalid_query_type(discourse_loader):
with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
discourse_loader._check_query(123)
def test_discourse_loader_load_post_with_valid_post_id(discourse_loader, monkeypatch):
def mock_get(*args, **kwargs):
class MockResponse:
def json(self):
return {"raw": "Sample post content"}
def raise_for_status(self):
pass
return MockResponse()
monkeypatch.setattr(requests, "get", mock_get)
post_data = discourse_loader._load_post(123)
assert post_data["content"] == "Sample post content"
assert "meta_data" in post_data
def test_discourse_loader_load_data_with_valid_query(discourse_loader, monkeypatch):
def mock_get(*args, **kwargs):
class MockResponse:
def json(self):
return {"grouped_search_result": {"post_ids": [123, 456, 789]}}
def raise_for_status(self):
pass
return MockResponse()
monkeypatch.setattr(requests, "get", mock_get)
def mock_load_post(*args, **kwargs):
return {
"content": "Sample post content",
"meta_data": {
"url": "https://example.com/posts/123.json",
"created_at": "2021-01-01",
"username": "test_user",
"topic_slug": "test_topic",
"score": 10,
},
}
monkeypatch.setattr(discourse_loader, "_load_post", mock_load_post)
data = discourse_loader.load_data("sample query")
assert len(data["data"]) == 3
assert data["data"][0]["content"] == "Sample post content"
assert data["data"][0]["meta_data"]["url"] == "https://example.com/posts/123.json"
assert data["data"][0]["meta_data"]["created_at"] == "2021-01-01"
assert data["data"][0]["meta_data"]["username"] == "test_user"
assert data["data"][0]["meta_data"]["topic_slug"] == "test_topic"
assert data["data"][0]["meta_data"]["score"] == 10

View File

@@ -0,0 +1,130 @@
import hashlib
from unittest.mock import Mock, patch
import pytest
from requests import Response
from embedchain.loaders.docs_site_loader import DocsSiteLoader
@pytest.fixture
def mock_requests_get():
with patch("requests.get") as mock_get:
yield mock_get
@pytest.fixture
def docs_site_loader():
return DocsSiteLoader()
def test_get_child_links_recursive(mock_requests_get, docs_site_loader):
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = """
<html>
<a href="/page1">Page 1</a>
<a href="/page2">Page 2</a>
</html>
"""
mock_requests_get.return_value = mock_response
docs_site_loader._get_child_links_recursive("https://example.com")
assert len(docs_site_loader.visited_links) == 2
assert "https://example.com/page1" in docs_site_loader.visited_links
assert "https://example.com/page2" in docs_site_loader.visited_links
def test_get_child_links_recursive_status_not_200(mock_requests_get, docs_site_loader):
mock_response = Mock()
mock_response.status_code = 404
mock_requests_get.return_value = mock_response
docs_site_loader._get_child_links_recursive("https://example.com")
assert len(docs_site_loader.visited_links) == 0
def test_get_all_urls(mock_requests_get, docs_site_loader):
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = """
<html>
<a href="/page1">Page 1</a>
<a href="/page2">Page 2</a>
<a href="https://example.com/external">External</a>
</html>
"""
mock_requests_get.return_value = mock_response
all_urls = docs_site_loader._get_all_urls("https://example.com")
assert len(all_urls) == 3
assert "https://example.com/page1" in all_urls
assert "https://example.com/page2" in all_urls
assert "https://example.com/external" in all_urls
def test_load_data_from_url(mock_requests_get, docs_site_loader):
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = """
<html>
<nav>
<h1>Navigation</h1>
</nav>
<article class="bd-article">
<p>Article Content</p>
</article>
</html>
""".encode()
mock_requests_get.return_value = mock_response
data = docs_site_loader._load_data_from_url("https://example.com/page1")
assert len(data) == 1
assert data[0]["content"] == "Article Content"
assert data[0]["meta_data"]["url"] == "https://example.com/page1"
def test_load_data_from_url_status_not_200(mock_requests_get, docs_site_loader):
mock_response = Mock()
mock_response.status_code = 404
mock_requests_get.return_value = mock_response
data = docs_site_loader._load_data_from_url("https://example.com/page1")
    assert data == []
def test_load_data(mock_requests_get, docs_site_loader):
mock_response = Response()
mock_response.status_code = 200
mock_response._content = """
<html>
<a href="/page1">Page 1</a>
<a href="/page2">Page 2</a>
""".encode()
mock_requests_get.return_value = mock_response
url = "https://example.com"
data = docs_site_loader.load_data(url)
expected_doc_id = hashlib.sha256((" ".join(docs_site_loader.visited_links) + url).encode()).hexdigest()
assert len(data["data"]) == 2
assert data["doc_id"] == expected_doc_id
def test_if_response_status_not_200(mock_requests_get, docs_site_loader):
mock_response = Response()
mock_response.status_code = 404
mock_requests_get.return_value = mock_response
url = "https://example.com"
data = docs_site_loader.load_data(url)
expected_doc_id = hashlib.sha256((" ".join(docs_site_loader.visited_links) + url).encode()).hexdigest()
assert len(data["data"]) == 0
assert data["doc_id"] == expected_doc_id

View File

@@ -0,0 +1,218 @@
import pytest
import responses
from bs4 import BeautifulSoup
@pytest.mark.parametrize(
"ignored_tag",
[
"<nav>This is a navigation bar.</nav>",
"<aside>This is an aside.</aside>",
"<form>This is a form.</form>",
"<header>This is a header.</header>",
"<noscript>This is a noscript.</noscript>",
"<svg>This is an SVG.</svg>",
"<canvas>This is a canvas.</canvas>",
"<footer>This is a footer.</footer>",
"<script>This is a script.</script>",
"<style>This is a style.</style>",
],
ids=["nav", "aside", "form", "header", "noscript", "svg", "canvas", "footer", "script", "style"],
)
@pytest.mark.parametrize(
"selectee",
[
"""
<article class="bd-article">
<h2>Article Title</h2>
<p>Article content goes here.</p>
{ignored_tag}
</article>
""",
"""
<article role="main">
<h2>Main Article Title</h2>
<p>Main article content goes here.</p>
{ignored_tag}
</article>
""",
"""
<div class="md-content">
<h2>Markdown Content</h2>
<p>Markdown content goes here.</p>
{ignored_tag}
</div>
""",
"""
<div role="main">
<h2>Main Content</h2>
<p>Main content goes here.</p>
{ignored_tag}
</div>
""",
"""
<div class="container">
<h2>Container</h2>
<p>Container content goes here.</p>
{ignored_tag}
</div>
""",
"""
<div class="section">
<h2>Section</h2>
<p>Section content goes here.</p>
{ignored_tag}
</div>
""",
"""
<article>
<h2>Generic Article</h2>
<p>Generic article content goes here.</p>
{ignored_tag}
</article>
""",
"""
<main>
<h2>Main Content</h2>
<p>Main content goes here.</p>
{ignored_tag}
</main>
""",
],
ids=[
"article.bd-article",
'article[role="main"]',
"div.md-content",
'div[role="main"]',
"div.container",
"div.section",
"article",
"main",
],
)
def test_load_data_gets_by_selectors_and_ignored_tags(selectee, ignored_tag, loader, mocked_responses, mocker):
child_url = "https://docs.embedchain.ai/quickstart"
selectee = selectee.format(ignored_tag=ignored_tag)
html_body = """
<!DOCTYPE html>
<html lang="en">
<body>
{selectee}
</body>
</html>
"""
html_body = html_body.format(selectee=selectee)
mocked_responses.get(child_url, body=html_body, status=200, content_type="text/html")
url = "https://docs.embedchain.ai/"
html_body = """
<!DOCTYPE html>
<html lang="en">
<body>
<li><a href="/quickstart">Quickstart</a></li>
</body>
</html>
"""
mocked_responses.get(url, body=html_body, status=200, content_type="text/html")
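    # Pin hashlib.sha256 so the resulting doc_id is deterministic for the assertions below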
mock_sha256 = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
doc_id = "mocked_hash"
mock_sha256.return_value.hexdigest.return_value = doc_id
result = loader.load_data(url)
selector_soup = BeautifulSoup(selectee, "html.parser")
expected_content = " ".join((selector_soup.select_one("h2").get_text(), selector_soup.select_one("p").get_text()))
assert result["doc_id"] == doc_id
assert result["data"] == [
{
"content": expected_content,
"meta_data": {"url": "https://docs.embedchain.ai/quickstart"},
}
]
def test_load_data_gets_child_links_recursively(loader, mocked_responses, mocker):
child_url = "https://docs.embedchain.ai/quickstart"
html_body = """
<!DOCTYPE html>
<html lang="en">
<body>
<li><a href="/">..</a></li>
<li><a href="/quickstart">.</a></li>
</body>
</html>
"""
mocked_responses.get(child_url, body=html_body, status=200, content_type="text/html")
child_url = "https://docs.embedchain.ai/introduction"
html_body = """
<!DOCTYPE html>
<html lang="en">
<body>
<li><a href="/">..</a></li>
<li><a href="/introduction">.</a></li>
</body>
</html>
"""
mocked_responses.get(child_url, body=html_body, status=200, content_type="text/html")
url = "https://docs.embedchain.ai/"
html_body = """
<!DOCTYPE html>
<html lang="en">
<body>
<li><a href="/quickstart">Quickstart</a></li>
<li><a href="/introduction">Introduction</a></li>
</body>
</html>
"""
mocked_responses.get(url, body=html_body, status=200, content_type="text/html")
mock_sha256 = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
doc_id = "mocked_hash"
mock_sha256.return_value.hexdigest.return_value = doc_id
result = loader.load_data(url)
assert result["doc_id"] == doc_id
expected_data = [
{"content": "..\n.", "meta_data": {"url": "https://docs.embedchain.ai/quickstart"}},
{"content": "..\n.", "meta_data": {"url": "https://docs.embedchain.ai/introduction"}},
]
assert all(item in expected_data for item in result["data"])
def test_load_data_fails_to_fetch_website(loader, mocked_responses, mocker):
child_url = "https://docs.embedchain.ai/introduction"
mocked_responses.get(child_url, status=404)
url = "https://docs.embedchain.ai/"
html_body = """
<!DOCTYPE html>
<html lang="en">
<body>
<li><a href="/introduction">Introduction</a></li>
</body>
</html>
"""
mocked_responses.get(url, body=html_body, status=200, content_type="text/html")
mock_sha256 = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
doc_id = "mocked_hash"
mock_sha256.return_value.hexdigest.return_value = doc_id
result = loader.load_data(url)
    assert result["doc_id"] == doc_id
assert result["data"] == []
@pytest.fixture
def loader():
from embedchain.loaders.docs_site_loader import DocsSiteLoader
return DocsSiteLoader()
@pytest.fixture
def mocked_responses():
with responses.RequestsMock() as rsps:
yield rsps

View File

@@ -0,0 +1,39 @@
import hashlib
from unittest.mock import MagicMock, patch
import pytest
from embedchain.loaders.docx_file import DocxFileLoader
@pytest.fixture
def mock_docx2txt_loader():
with patch("embedchain.loaders.docx_file.Docx2txtLoader") as mock_loader:
yield mock_loader
@pytest.fixture
def docx_file_loader():
return DocxFileLoader()
def test_load_data(mock_docx2txt_loader, docx_file_loader):
mock_url = "mock_docx_file.docx"
mock_loader = MagicMock()
mock_loader.load.return_value = [MagicMock(page_content="Sample Docx Content", metadata={"url": "local"})]
mock_docx2txt_loader.return_value = mock_loader
result = docx_file_loader.load_data(mock_url)
assert "doc_id" in result
assert "data" in result
expected_content = "Sample Docx Content"
assert result["data"][0]["content"] == expected_content
assert result["data"][0]["meta_data"]["url"] == "local"
expected_doc_id = hashlib.sha256((expected_content + mock_url).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id

View File

@@ -0,0 +1,85 @@
import os
from unittest.mock import MagicMock
import pytest
from dropbox.files import FileMetadata
from embedchain.loaders.dropbox import DropboxLoader
@pytest.fixture
def setup_dropbox_loader(mocker):
mock_dropbox = mocker.patch("dropbox.Dropbox")
mock_dbx = mocker.MagicMock()
mock_dropbox.return_value = mock_dbx
os.environ["DROPBOX_ACCESS_TOKEN"] = "test_token"
loader = DropboxLoader()
yield loader, mock_dbx
if "DROPBOX_ACCESS_TOKEN" in os.environ:
del os.environ["DROPBOX_ACCESS_TOKEN"]
def test_initialization(setup_dropbox_loader):
"""Test initialization of DropboxLoader."""
loader, _ = setup_dropbox_loader
assert loader is not None
def test_download_folder(setup_dropbox_loader, mocker):
"""Test downloading a folder."""
loader, mock_dbx = setup_dropbox_loader
mocker.patch("os.makedirs")
mocker.patch("os.path.join", return_value="mock/path")
mock_file_metadata = mocker.MagicMock(spec=FileMetadata)
mock_dbx.files_list_folder.return_value.entries = [mock_file_metadata]
entries = loader._download_folder("path/to/folder", "local_root")
assert entries is not None
def test_generate_dir_id_from_all_paths(setup_dropbox_loader, mocker):
"""Test directory ID generation."""
loader, mock_dbx = setup_dropbox_loader
mock_file_metadata = mocker.MagicMock(spec=FileMetadata, name="file.txt")
mock_dbx.files_list_folder.return_value.entries = [mock_file_metadata]
dir_id = loader._generate_dir_id_from_all_paths("path/to/folder")
assert dir_id is not None
assert len(dir_id) == 64
def test_clean_directory(setup_dropbox_loader, mocker):
"""Test cleaning up a directory."""
loader, _ = setup_dropbox_loader
mocker.patch("os.listdir", return_value=["file1", "file2"])
mocker.patch("os.remove")
mocker.patch("os.rmdir")
loader._clean_directory("path/to/folder")
def test_load_data(mocker, setup_dropbox_loader, tmp_path):
loader = setup_dropbox_loader[0]
mock_file_metadata = MagicMock(spec=FileMetadata, name="file.txt")
mocker.patch.object(loader.dbx, "files_list_folder", return_value=MagicMock(entries=[mock_file_metadata]))
mocker.patch.object(loader.dbx, "files_download_to_file")
# Mock DirectoryLoader
mock_data = {"data": "test_data"}
mocker.patch("embedchain.loaders.directory_loader.DirectoryLoader.load_data", return_value=mock_data)
test_dir = tmp_path / "dropbox_test"
test_dir.mkdir()
test_file = test_dir / "file.txt"
test_file.write_text("dummy content")
mocker.patch.object(loader, "_generate_dir_id_from_all_paths", return_value=str(test_dir))
result = loader.load_data("path/to/folder")
assert result == {"doc_id": mocker.ANY, "data": "test_data"}
loader.dbx.files_list_folder.assert_called_once_with("path/to/folder")

View File

@@ -0,0 +1,33 @@
import hashlib
from unittest.mock import patch
import pytest
from embedchain.loaders.excel_file import ExcelFileLoader
@pytest.fixture
def excel_file_loader():
return ExcelFileLoader()
def test_load_data(excel_file_loader):
mock_url = "mock_excel_file.xlsx"
expected_content = "Sample Excel Content"
# Mock the load_data method of the excel_file_loader instance
with patch.object(
excel_file_loader,
"load_data",
return_value={
"doc_id": hashlib.sha256((expected_content + mock_url).encode()).hexdigest(),
"data": [{"content": expected_content, "meta_data": {"url": mock_url}}],
},
):
result = excel_file_loader.load_data(mock_url)
assert result["data"][0]["content"] == expected_content
assert result["data"][0]["meta_data"]["url"] == mock_url
expected_doc_id = hashlib.sha256((expected_content + mock_url).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id

View File

@@ -0,0 +1,33 @@
import pytest
from embedchain.loaders.github import GithubLoader
@pytest.fixture
def mock_github_loader_config():
return {
"token": "your_mock_token",
}
@pytest.fixture
def mock_github_loader(mocker, mock_github_loader_config):
mock_github = mocker.patch("github.Github")
_ = mock_github.return_value
return GithubLoader(config=mock_github_loader_config)
def test_github_loader_init(mocker, mock_github_loader_config):
mock_github = mocker.patch("github.Github")
GithubLoader(config=mock_github_loader_config)
mock_github.assert_called_once_with("your_mock_token")
def test_github_loader_init_empty_config(mocker):
with pytest.raises(ValueError, match="requires a personal access token"):
GithubLoader()
def test_github_loader_init_missing_token():
with pytest.raises(ValueError, match="requires a personal access token"):
GithubLoader(config={})

View File

@@ -0,0 +1,43 @@
import pytest
from embedchain.loaders.gmail import GmailLoader
@pytest.fixture
def mock_beautifulsoup(mocker):
return mocker.patch("embedchain.loaders.gmail.BeautifulSoup", return_value=mocker.MagicMock())
@pytest.fixture
def gmail_loader(mock_beautifulsoup):
return GmailLoader()
def test_load_data_file_not_found(gmail_loader, mocker):
with pytest.raises(FileNotFoundError):
with mocker.patch("os.path.isfile", return_value=False):
gmail_loader.load_data("your_query")
@pytest.mark.skip(reason="TODO: Fix this test. Failing due to some googleapiclient import issue.")
def test_load_data(gmail_loader, mocker):
mock_gmail_reader_instance = mocker.MagicMock()
text = "your_test_email_text"
metadata = {
"id": "your_test_id",
"snippet": "your_test_snippet",
}
mock_gmail_reader_instance.load_data.return_value = [
{
"text": text,
"extra_info": metadata,
}
]
with mocker.patch("os.path.isfile", return_value=True):
response_data = gmail_loader.load_data("your_query")
assert "doc_id" in response_data
assert "data" in response_data
assert isinstance(response_data["doc_id"], str)
assert isinstance(response_data["data"], list)

View File

@@ -0,0 +1,37 @@
import pytest
from embedchain.loaders.google_drive import GoogleDriveLoader
@pytest.fixture
def google_drive_folder_loader():
return GoogleDriveLoader()
def test_load_data_invalid_drive_url(google_drive_folder_loader):
mock_invalid_drive_url = "https://example.com"
with pytest.raises(
ValueError,
match="The url provided https://example.com does not match a google drive folder url. Example "
"drive url: https://drive.google.com/drive/u/0/folders/xxxx",
):
google_drive_folder_loader.load_data(mock_invalid_drive_url)
@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.")
def test_load_data_incorrect_drive_url(google_drive_folder_loader):
mock_invalid_drive_url = "https://drive.google.com/drive/u/0/folders/xxxx"
with pytest.raises(
FileNotFoundError, match="Unable to locate folder or files, check provided drive URL and try again"
):
google_drive_folder_loader.load_data(mock_invalid_drive_url)
@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.")
def test_load_data(google_drive_folder_loader):
mock_valid_url = "YOUR_VALID_URL"
result = google_drive_folder_loader.load_data(mock_valid_url)
assert "doc_id" in result
assert "data" in result
assert "content" in result["data"][0]
assert "meta_data" in result["data"][0]

View File

@@ -0,0 +1,131 @@
import hashlib
import pytest
from embedchain.loaders.json import JSONLoader
def test_load_data(mocker):
content = "temp.json"
mock_document = {
"doc_id": hashlib.sha256((content + ", ".join(["content1", "content2"])).encode()).hexdigest(),
"data": [
{"content": "content1", "meta_data": {"url": content}},
{"content": "content2", "meta_data": {"url": content}},
],
}
mocker.patch("embedchain.loaders.json.JSONLoader.load_data", return_value=mock_document)
json_loader = JSONLoader()
result = json_loader.load_data(content)
assert "doc_id" in result
assert "data" in result
expected_data = [
{"content": "content1", "meta_data": {"url": content}},
{"content": "content2", "meta_data": {"url": content}},
]
assert result["data"] == expected_data
expected_doc_id = hashlib.sha256((content + ", ".join(["content1", "content2"])).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id
def test_load_data_url(mocker):
content = "https://example.com/posts.json"
mocker.patch("os.path.isfile", return_value=False)
mocker.patch(
"embedchain.loaders.json.JSONReader.load_data",
return_value=[
{
"text": "content1",
},
{
"text": "content2",
},
],
)
mock_response = mocker.Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"document1": "content1", "document2": "content2"}
mocker.patch("requests.get", return_value=mock_response)
result = JSONLoader.load_data(content)
assert "doc_id" in result
assert "data" in result
expected_data = [
{"content": "content1", "meta_data": {"url": content}},
{"content": "content2", "meta_data": {"url": content}},
]
assert result["data"] == expected_data
expected_doc_id = hashlib.sha256((content + ", ".join(["content1", "content2"])).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id
def test_load_data_invalid_string_content(mocker):
mocker.patch("os.path.isfile", return_value=False)
mocker.patch("requests.get")
content = "123: 345}"
with pytest.raises(ValueError, match="Invalid content to load json data from"):
JSONLoader.load_data(content)
def test_load_data_invalid_url(mocker):
mocker.patch("os.path.isfile", return_value=False)
mock_response = mocker.Mock()
mock_response.status_code = 404
mocker.patch("requests.get", return_value=mock_response)
content = "http://invalid-url.com/"
with pytest.raises(ValueError, match=f"Invalid content to load json data from: {content}"):
JSONLoader.load_data(content)
def test_load_data_from_json_string(mocker):
content = '{"foo": "bar"}'
content_url_str = hashlib.sha256((content).encode("utf-8")).hexdigest()
mocker.patch("os.path.isfile", return_value=False)
mocker.patch(
"embedchain.loaders.json.JSONReader.load_data",
return_value=[
{
"text": "content1",
},
{
"text": "content2",
},
],
)
result = JSONLoader.load_data(content)
assert "doc_id" in result
assert "data" in result
expected_data = [
{"content": "content1", "meta_data": {"url": content_url_str}},
{"content": "content2", "meta_data": {"url": content_url_str}},
]
assert result["data"] == expected_data
expected_doc_id = hashlib.sha256((content_url_str + ", ".join(["content1", "content2"])).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id

View File

@@ -0,0 +1,32 @@
import hashlib
import pytest
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
@pytest.fixture
def qna_pair_loader():
return LocalQnaPairLoader()
def test_load_data(qna_pair_loader):
question = "What is the capital of France?"
answer = "The capital of France is Paris."
content = (question, answer)
result = qna_pair_loader.load_data(content)
assert "doc_id" in result
assert "data" in result
url = "local"
expected_content = f"Q: {question}\nA: {answer}"
assert result["data"][0]["content"] == expected_content
assert result["data"][0]["meta_data"]["url"] == url
assert result["data"][0]["meta_data"]["question"] == question
expected_doc_id = hashlib.sha256((expected_content + url).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id

View File

@@ -0,0 +1,27 @@
import hashlib
import pytest
from embedchain.loaders.local_text import LocalTextLoader
@pytest.fixture
def text_loader():
return LocalTextLoader()
def test_load_data(text_loader):
mock_content = "This is a sample text content."
result = text_loader.load_data(mock_content)
assert "doc_id" in result
assert "data" in result
url = "local"
assert result["data"][0]["content"] == mock_content
assert result["data"][0]["meta_data"]["url"] == url
expected_doc_id = hashlib.sha256((mock_content + url).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id

View File

@@ -0,0 +1,30 @@
import hashlib
from unittest.mock import mock_open, patch
import pytest
from embedchain.loaders.mdx import MdxLoader
@pytest.fixture
def mdx_loader():
return MdxLoader()
def test_load_data(mdx_loader):
mock_content = "Sample MDX Content"
# Mock open function to simulate file reading
with patch("builtins.open", mock_open(read_data=mock_content)):
url = "mock_file.mdx"
result = mdx_loader.load_data(url)
assert "doc_id" in result
assert "data" in result
assert result["data"][0]["content"] == mock_content
assert result["data"][0]["meta_data"]["url"] == url
expected_doc_id = hashlib.sha256((mock_content + url).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id

View File

@@ -0,0 +1,77 @@
import hashlib
from unittest.mock import MagicMock
import pytest
from embedchain.loaders.mysql import MySQLLoader
@pytest.fixture
def mysql_loader(mocker):
with mocker.patch("mysql.connector.connection.MySQLConnection"):
config = {
"host": "localhost",
"port": "3306",
"user": "your_username",
"password": "your_password",
"database": "your_database",
}
loader = MySQLLoader(config=config)
yield loader
def test_mysql_loader_initialization(mysql_loader):
assert mysql_loader.config is not None
assert mysql_loader.connection is not None
assert mysql_loader.cursor is not None
def test_mysql_loader_invalid_config():
with pytest.raises(ValueError, match="Invalid sql config: None"):
MySQLLoader(config=None)
def test_mysql_loader_setup_loader_successful(mysql_loader):
assert mysql_loader.connection is not None
assert mysql_loader.cursor is not None
def test_mysql_loader_setup_loader_connection_error(mysql_loader, mocker):
mocker.patch("mysql.connector.connection.MySQLConnection", side_effect=IOError("Mocked connection error"))
with pytest.raises(ValueError, match="Unable to connect with the given config:"):
mysql_loader._setup_loader(config={})
def test_mysql_loader_check_query_successful(mysql_loader):
query = "SELECT * FROM table"
mysql_loader._check_query(query=query)
def test_mysql_loader_check_query_invalid(mysql_loader):
with pytest.raises(ValueError, match="Invalid mysql query: 123"):
mysql_loader._check_query(query=123)
def test_mysql_loader_load_data_successful(mysql_loader, mocker):
mock_cursor = MagicMock()
mocker.patch.object(mysql_loader, "cursor", mock_cursor)
mock_cursor.fetchall.return_value = [(1, "data1"), (2, "data2")]
query = "SELECT * FROM table"
result = mysql_loader.load_data(query)
assert "doc_id" in result
assert "data" in result
assert len(result["data"]) == 2
assert result["data"][0]["meta_data"]["url"] == query
assert result["data"][1]["meta_data"]["url"] == query
doc_id = hashlib.sha256((query + ", ".join([d["content"] for d in result["data"]])).encode()).hexdigest()
assert result["doc_id"] == doc_id
    mock_cursor.execute.assert_called_with(query)
def test_mysql_loader_load_data_invalid_query(mysql_loader):
with pytest.raises(ValueError, match="Invalid mysql query: 123"):
mysql_loader.load_data(query=123)

View File

@@ -0,0 +1,36 @@
import hashlib
import os
from unittest.mock import Mock, patch
import pytest
from embedchain.loaders.notion import NotionLoader
@pytest.fixture
def notion_loader():
with patch.dict(os.environ, {"NOTION_INTEGRATION_TOKEN": "test_notion_token"}):
yield NotionLoader()
def test_load_data(notion_loader):
source = "https://www.notion.so/Test-Page-1234567890abcdef1234567890abcdef"
mock_text = "This is a test page."
expected_doc_id = hashlib.sha256((mock_text + source).encode()).hexdigest()
expected_data = [
{
"content": mock_text,
"meta_data": {"url": "notion-12345678-90ab-cdef-1234-567890abcdef"}, # formatted_id
}
]
mock_page = Mock()
mock_page.text = mock_text
mock_documents = [mock_page]
with patch("embedchain.loaders.notion.NotionPageLoader") as mock_reader:
mock_reader.return_value.load_data.return_value = mock_documents
result = notion_loader.load_data(source)
assert result["doc_id"] == expected_doc_id
assert result["data"] == expected_data

View File

@@ -0,0 +1,26 @@
import pytest
from embedchain.loaders.openapi import OpenAPILoader
@pytest.fixture
def openapi_loader():
return OpenAPILoader()
def test_load_data(openapi_loader, mocker):
mocker.patch("builtins.open", mocker.mock_open(read_data="key1: value1\nkey2: value2"))
mocker.patch("hashlib.sha256", return_value=mocker.Mock(hexdigest=lambda: "mock_hash"))
file_path = "configs/openai_openapi.yaml"
result = openapi_loader.load_data(file_path)
expected_doc_id = "mock_hash"
expected_data = [
{"content": "key1: value1", "meta_data": {"url": file_path, "row": 1}},
{"content": "key2: value2", "meta_data": {"url": file_path, "row": 2}},
]
assert result["doc_id"] == expected_doc_id
assert result["data"] == expected_data

View File

@@ -0,0 +1,36 @@
import pytest
from langchain.schema import Document
def test_load_data(loader, mocker):
mocked_pypdfloader = mocker.patch("embedchain.loaders.pdf_file.PyPDFLoader")
mocked_pypdfloader.return_value.load_and_split.return_value = [
Document(page_content="Page 0 Content", metadata={"source": "example.pdf", "page": 0}),
Document(page_content="Page 1 Content", metadata={"source": "example.pdf", "page": 1}),
]
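    # Patching sha256 through docs_site_loader swaps the attribute on the shared hashlib
    # module, so the pdf_file loader picks up the mocked hash as well.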
mock_sha256 = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
doc_id = "mocked_hash"
mock_sha256.return_value.hexdigest.return_value = doc_id
result = loader.load_data("dummy_url")
assert result["doc_id"] is doc_id
assert result["data"] == [
{"content": "Page 0 Content", "meta_data": {"source": "example.pdf", "page": 0, "url": "dummy_url"}},
{"content": "Page 1 Content", "meta_data": {"source": "example.pdf", "page": 1, "url": "dummy_url"}},
]
def test_load_data_fails_to_find_data(loader, mocker):
mocked_pypdfloader = mocker.patch("embedchain.loaders.pdf_file.PyPDFLoader")
mocked_pypdfloader.return_value.load_and_split.return_value = []
with pytest.raises(ValueError):
loader.load_data("dummy_url")
@pytest.fixture
def loader():
from embedchain.loaders.pdf_file import PdfFileLoader
return PdfFileLoader()

View File

@@ -0,0 +1,60 @@
from unittest.mock import MagicMock
import psycopg
import pytest
from embedchain.loaders.postgres import PostgresLoader
@pytest.fixture
def postgres_loader(mocker):
    # mocker.patch.object stays active for the whole test, so no `with` block is needed.
    mocker.patch.object(psycopg, "connect")
    config = {"url": "postgres://user:password@localhost:5432/database"}
    loader = PostgresLoader(config=config)
    yield loader
def test_postgres_loader_initialization(postgres_loader):
assert postgres_loader.connection is not None
assert postgres_loader.cursor is not None
def test_postgres_loader_invalid_config():
with pytest.raises(ValueError, match="Must provide the valid config. Received: None"):
PostgresLoader(config=None)
def test_load_data(postgres_loader, monkeypatch):
mock_cursor = MagicMock()
monkeypatch.setattr(postgres_loader, "cursor", mock_cursor)
query = "SELECT * FROM table"
mock_cursor.fetchall.return_value = [(1, "data1"), (2, "data2")]
result = postgres_loader.load_data(query)
assert "doc_id" in result
assert "data" in result
assert len(result["data"]) == 2
assert result["data"][0]["meta_data"]["url"] == query
assert result["data"][1]["meta_data"]["url"] == query
    mock_cursor.execute.assert_called_with(query)
def test_load_data_exception(postgres_loader, monkeypatch):
mock_cursor = MagicMock()
monkeypatch.setattr(postgres_loader, "cursor", mock_cursor)
_ = "SELECT * FROM table"
mock_cursor.execute.side_effect = Exception("Mocked exception")
with pytest.raises(
ValueError, match=r"Failed to load data using query=SELECT \* FROM table with: Mocked exception"
):
        postgres_loader.load_data(query)
def test_close_connection(postgres_loader):
postgres_loader.close_connection()
assert postgres_loader.cursor is None
assert postgres_loader.connection is None

View File

@@ -0,0 +1,47 @@
import pytest
from embedchain.loaders.slack import SlackLoader
@pytest.fixture
def slack_loader(mocker, monkeypatch):
# Mocking necessary dependencies
mocker.patch("slack_sdk.WebClient")
mocker.patch("ssl.create_default_context")
mocker.patch("certifi.where")
monkeypatch.setenv("SLACK_USER_TOKEN", "slack_user_token")
return SlackLoader()
def test_slack_loader_initialization(slack_loader):
assert slack_loader.client is not None
assert slack_loader.config == {"base_url": "https://www.slack.com/api/"}
def test_slack_loader_setup_loader(slack_loader):
slack_loader._setup_loader({"base_url": "https://custom.slack.api/"})
assert slack_loader.client is not None
def test_slack_loader_check_query(slack_loader):
valid_json_query = "test_query"
invalid_query = 123
slack_loader._check_query(valid_json_query)
with pytest.raises(ValueError):
slack_loader._check_query(invalid_query)
def test_slack_loader_load_data(slack_loader, mocker):
valid_json_query = "in:random"
mocker.patch.object(slack_loader.client, "search_messages", return_value={"messages": {}})
result = slack_loader.load_data(valid_json_query)
assert "doc_id" in result
assert "data" in result

View File

@@ -0,0 +1,117 @@
import hashlib
from unittest.mock import Mock, patch
import pytest
from embedchain.loaders.web_page import WebPageLoader
@pytest.fixture
def web_page_loader():
return WebPageLoader()
def test_load_data(web_page_loader):
page_url = "https://example.com/page"
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = """
<html>
<head>
<title>Test Page</title>
</head>
<body>
<div id="content">
<p>This is some test content.</p>
</div>
</body>
</html>
"""
with patch("embedchain.loaders.web_page.WebPageLoader._session.get", return_value=mock_response):
result = web_page_loader.load_data(page_url)
content = web_page_loader._get_clean_content(mock_response.content, page_url)
expected_doc_id = hashlib.sha256((content + page_url).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id
expected_data = [
{
"content": content,
"meta_data": {
"url": page_url,
},
}
]
assert result["data"] == expected_data
def test_get_clean_content_excludes_unnecessary_info(web_page_loader):
mock_html = """
<html>
<head>
<title>Sample HTML</title>
<style>
/* Stylesheet to be excluded */
.elementor-location-header {
background-color: #f0f0f0;
}
</style>
</head>
<body>
<header id="header">Header Content</header>
<nav class="nav">Nav Content</nav>
<aside>Aside Content</aside>
<form>Form Content</form>
<main>Main Content</main>
<footer class="footer">Footer Content</footer>
<script>Some Script</script>
<noscript>NoScript Content</noscript>
<svg>SVG Content</svg>
<canvas>Canvas Content</canvas>
<div id="sidebar">Sidebar Content</div>
<div id="main-navigation">Main Navigation Content</div>
<div id="menu-main-menu">Menu Main Menu Content</div>
<div class="header-sidebar-wrapper">Header Sidebar Wrapper Content</div>
<div class="blog-sidebar-wrapper">Blog Sidebar Wrapper Content</div>
<div class="related-posts">Related Posts Content</div>
</body>
</html>
"""
tags_to_exclude = [
"nav",
"aside",
"form",
"header",
"noscript",
"svg",
"canvas",
"footer",
"script",
"style",
]
ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
classes_to_exclude = [
"elementor-location-header",
"navbar-header",
"nav",
"header-sidebar-wrapper",
"blog-sidebar-wrapper",
"related-posts",
]
content = web_page_loader._get_clean_content(mock_html, "https://example.com/page")
for tag in tags_to_exclude:
assert tag not in content
for id in ids_to_exclude:
assert id not in content
for class_name in classes_to_exclude:
assert class_name not in content
assert len(content) > 0

View File

@@ -0,0 +1,62 @@
import tempfile
import pytest
from embedchain.loaders.xml import XmlLoader
# Taken from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/tests/integration_tests/examples/factbook.xml
SAMPLE_XML = """<?xml version="1.0" encoding="UTF-8"?>
<factbook>
<country>
<name>United States</name>
<capital>Washington, DC</capital>
<leader>Joe Biden</leader>
<sport>Baseball</sport>
</country>
<country>
<name>Canada</name>
<capital>Ottawa</capital>
<leader>Justin Trudeau</leader>
<sport>Hockey</sport>
</country>
<country>
<name>France</name>
<capital>Paris</capital>
<leader>Emmanuel Macron</leader>
<sport>Soccer</sport>
</country>
<country>
<name>Trinidad &amp; Tobado</name>
<capital>Port of Spain</capital>
<leader>Keith Rowley</leader>
<sport>Track &amp; Field</sport>
</country>
</factbook>"""
@pytest.mark.parametrize("xml", [SAMPLE_XML])
def test_load_data(xml: str):
"""
Test XML loader
Tests that XML file is loaded, metadata is correct and content is correct
"""
# Creating temporary XML file
with tempfile.NamedTemporaryFile(mode="w+") as tmpfile:
tmpfile.write(xml)
tmpfile.seek(0)
filename = tmpfile.name
        # Loading XML using XmlLoader
loader = XmlLoader()
result = loader.load_data(filename)
data = result["data"]
# Assertions
assert len(data) == 1
assert "United States Washington, DC Joe Biden" in data[0]["content"]
assert "Canada Ottawa Justin Trudeau" in data[0]["content"]
assert "France Paris Emmanuel Macron" in data[0]["content"]
assert "Trinidad & Tobado Port of Spain Keith Rowley" in data[0]["content"]
assert data[0]["meta_data"]["url"] == filename

View File

@@ -0,0 +1,53 @@
import hashlib
from unittest.mock import MagicMock, Mock, patch
import pytest
from embedchain.loaders.youtube_video import YoutubeVideoLoader
@pytest.fixture
def youtube_video_loader():
return YoutubeVideoLoader()
def test_load_data(youtube_video_loader):
video_url = "https://www.youtube.com/watch?v=VIDEO_ID"
mock_loader = Mock()
mock_page_content = "This is a YouTube video content."
mock_loader.load.return_value = [
MagicMock(
page_content=mock_page_content,
metadata={"url": video_url, "title": "Test Video"},
)
]
mock_transcript = [{"text": "sample text", "start": 0.0, "duration": 5.0}]
with patch("embedchain.loaders.youtube_video.YoutubeLoader.from_youtube_url", return_value=mock_loader), patch(
"embedchain.loaders.youtube_video.YouTubeTranscriptApi.get_transcript", return_value=mock_transcript
):
result = youtube_video_loader.load_data(video_url)
expected_doc_id = hashlib.sha256((mock_page_content + video_url).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id
expected_data = [
{
"content": "This is a YouTube video content.",
"meta_data": {"url": video_url, "title": "Test Video", "transcript": "Unavailable"},
}
]
assert result["data"] == expected_data
def test_load_data_with_empty_doc(youtube_video_loader):
video_url = "https://www.youtube.com/watch?v=VIDEO_ID"
mock_loader = Mock()
mock_loader.load.return_value = []
with patch("embedchain.loaders.youtube_video.YoutubeLoader.from_youtube_url", return_value=mock_loader):
with pytest.raises(ValueError):
youtube_video_loader.load_data(video_url)

View File

@@ -0,0 +1,91 @@
import pytest
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
# Fixture for creating an instance of ChatHistory
@pytest.fixture
def chat_memory_instance():
return ChatHistory()
def test_add_chat_memory(chat_memory_instance):
app_id = "test_app"
session_id = "test_session"
human_message = "Hello, how are you?"
ai_message = "I'm fine, thank you!"
chat_message = ChatMessage()
chat_message.add_user_message(human_message)
chat_message.add_ai_message(ai_message)
chat_memory_instance.add(app_id, session_id, chat_message)
assert chat_memory_instance.count(app_id, session_id) == 1
chat_memory_instance.delete(app_id, session_id)
def test_get(chat_memory_instance):
app_id = "test_app"
session_id = "test_session"
for i in range(1, 7):
human_message = f"Question {i}"
ai_message = f"Answer {i}"
chat_message = ChatMessage()
chat_message.add_user_message(human_message)
chat_message.add_ai_message(ai_message)
chat_memory_instance.add(app_id, session_id, chat_message)
recent_memories = chat_memory_instance.get(app_id, session_id, num_rounds=5)
assert len(recent_memories) == 5
all_memories = chat_memory_instance.get(app_id, fetch_all=True)
assert len(all_memories) == 6
def test_delete_chat_history(chat_memory_instance):
app_id = "test_app"
session_id = "test_session"
for i in range(1, 6):
human_message = f"Question {i}"
ai_message = f"Answer {i}"
chat_message = ChatMessage()
chat_message.add_user_message(human_message)
chat_message.add_ai_message(ai_message)
chat_memory_instance.add(app_id, session_id, chat_message)
session_id_2 = "test_session_2"
for i in range(1, 6):
human_message = f"Question {i}"
ai_message = f"Answer {i}"
chat_message = ChatMessage()
chat_message.add_user_message(human_message)
chat_message.add_ai_message(ai_message)
chat_memory_instance.add(app_id, session_id_2, chat_message)
chat_memory_instance.delete(app_id, session_id)
assert chat_memory_instance.count(app_id, session_id) == 0
assert chat_memory_instance.count(app_id) == 5
chat_memory_instance.delete(app_id)
assert chat_memory_instance.count(app_id) == 0
@pytest.fixture
def close_connection(chat_memory_instance):
yield
chat_memory_instance.close_connection()

View File

@@ -0,0 +1,37 @@
from embedchain.memory.message import BaseMessage, ChatMessage
def test_ec_base_message():
content = "Hello, how are you?"
created_by = "human"
metadata = {"key": "value"}
message = BaseMessage(content=content, created_by=created_by, metadata=metadata)
assert message.content == content
assert message.created_by == created_by
assert message.metadata == metadata
assert message.type is None
assert message.is_lc_serializable() is True
assert str(message) == f"{created_by}: {content}"
def test_ec_base_chat_message():
human_message_content = "Hello, how are you?"
ai_message_content = "I'm fine, thank you!"
human_metadata = {"user": "John"}
ai_metadata = {"response_time": 0.5}
chat_message = ChatMessage()
chat_message.add_user_message(human_message_content, metadata=human_metadata)
chat_message.add_ai_message(ai_message_content, metadata=ai_metadata)
assert chat_message.human_message.content == human_message_content
assert chat_message.human_message.created_by == "human"
assert chat_message.human_message.metadata == human_metadata
assert chat_message.ai_message.content == ai_message_content
assert chat_message.ai_message.created_by == "ai"
assert chat_message.ai_message.metadata == ai_metadata
assert str(chat_message) == f"human: {human_message_content}\nai: {ai_message_content}"

View File

@@ -0,0 +1,30 @@
from embedchain.models.data_type import (DataType, DirectDataType,
IndirectDataType, SpecialDataType)
def test_subclass_types_in_data_type():
"""Test that all data type category subclasses are contained in the composite data type"""
# Check if DirectDataType values are in DataType
for data_type in DirectDataType:
assert data_type.value in DataType._value2member_map_
# Check if IndirectDataType values are in DataType
for data_type in IndirectDataType:
assert data_type.value in DataType._value2member_map_
# Check if SpecialDataType values are in DataType
for data_type in SpecialDataType:
assert data_type.value in DataType._value2member_map_
def test_data_type_in_subclasses():
"""Test that all data types in the composite data type are categorized in a subclass"""
for data_type in DataType:
if data_type.value in DirectDataType._value2member_map_:
assert data_type.value in DirectDataType._value2member_map_
elif data_type.value in IndirectDataType._value2member_map_:
assert data_type.value in IndirectDataType._value2member_map_
elif data_type.value in SpecialDataType._value2member_map_:
assert data_type.value in SpecialDataType._value2member_map_
else:
assert False, f"{data_type.value} not found in any subclass enums"

View File

@@ -0,0 +1,65 @@
import logging
import os
from embedchain.telemetry.posthog import AnonymousTelemetry
class TestAnonymousTelemetry:
def test_init(self, mocker):
# Enable telemetry specifically for this test
os.environ["EC_TELEMETRY"] = "true"
mock_posthog = mocker.patch("embedchain.telemetry.posthog.Posthog")
telemetry = AnonymousTelemetry()
assert telemetry.project_api_key == "phc_PHQDA5KwztijnSojsxJ2c1DuJd52QCzJzT2xnSGvjN2"
assert telemetry.host == "https://app.posthog.com"
assert telemetry.enabled is True
assert telemetry.user_id
mock_posthog.assert_called_once_with(project_api_key=telemetry.project_api_key, host=telemetry.host)
    def test_init_with_disabled_telemetry(self, mocker):
        # Disable telemetry explicitly so the env var set by test_init does not leak into this test
        os.environ["EC_TELEMETRY"] = "false"
        mocker.patch("embedchain.telemetry.posthog.Posthog")
        telemetry = AnonymousTelemetry()
        assert telemetry.enabled is False
        assert telemetry.posthog.disabled is True
def test_get_user_id(self, mocker, tmpdir):
mock_uuid = mocker.patch("embedchain.telemetry.posthog.uuid.uuid4")
mock_uuid.return_value = "unique_user_id"
config_file = tmpdir.join("config.json")
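        # Point CONFIG_FILE at a temp path so the generated user id is written there instead of the real config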
mocker.patch("embedchain.telemetry.posthog.CONFIG_FILE", str(config_file))
telemetry = AnonymousTelemetry()
user_id = telemetry._get_user_id()
assert user_id == "unique_user_id"
assert config_file.read() == '{"user_id": "unique_user_id"}'
def test_capture(self, mocker):
# Enable telemetry specifically for this test
os.environ["EC_TELEMETRY"] = "true"
mock_posthog = mocker.patch("embedchain.telemetry.posthog.Posthog")
telemetry = AnonymousTelemetry()
event_name = "test_event"
properties = {"key": "value"}
telemetry.capture(event_name, properties)
mock_posthog.assert_called_once_with(
project_api_key=telemetry.project_api_key,
host=telemetry.host,
)
mock_posthog.return_value.capture.assert_called_once_with(
telemetry.user_id,
event_name,
properties,
)
def test_capture_with_exception(self, mocker, caplog):
os.environ["EC_TELEMETRY"] = "true"
mock_posthog = mocker.patch("embedchain.telemetry.posthog.Posthog")
mock_posthog.return_value.capture.side_effect = Exception("Test Exception")
telemetry = AnonymousTelemetry()
event_name = "test_event"
properties = {"key": "value"}
with caplog.at_level(logging.ERROR):
telemetry.capture(event_name, properties)
assert "Failed to send telemetry event" in caplog.text
caplog.clear()

View File

@@ -0,0 +1,111 @@
import os
import pytest
import yaml
from embedchain import App
from embedchain.config import ChromaDbConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.llm.base import BaseLlm
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
@pytest.fixture
def app():
os.environ["OPENAI_API_KEY"] = "test_api_key"
os.environ["OPENAI_API_BASE"] = "test_api_base"
return App()
def test_app(app):
assert isinstance(app.llm, BaseLlm)
assert isinstance(app.db, BaseVectorDB)
assert isinstance(app.embedding_model, BaseEmbedder)
class TestConfigForAppComponents:
def test_constructor_config(self):
collection_name = "my-test-collection"
db = ChromaDB(config=ChromaDbConfig(collection_name=collection_name))
app = App(db=db)
assert app.db.config.collection_name == collection_name
def test_component_config(self):
collection_name = "my-test-collection"
database = ChromaDB(config=ChromaDbConfig(collection_name=collection_name))
app = App(db=database)
assert app.db.config.collection_name == collection_name
class TestAppFromConfig:
def load_config_data(self, yaml_path):
with open(yaml_path, "r") as file:
return yaml.safe_load(file)
def test_from_chroma_config(self, mocker):
mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
yaml_path = "configs/chroma.yaml"
config_data = self.load_config_data(yaml_path)
app = App.from_config(config_path=yaml_path)
# Check if the App instance and its components were created correctly
assert isinstance(app, App)
# Validate the AppConfig values
assert app.config.id == config_data["app"]["config"]["id"]
# Even though not present in the config, the default value is used
assert app.config.collect_metrics is True
# Validate the LLM config values
llm_config = config_data["llm"]["config"]
assert app.llm.config.temperature == llm_config["temperature"]
assert app.llm.config.max_tokens == llm_config["max_tokens"]
assert app.llm.config.top_p == llm_config["top_p"]
assert app.llm.config.stream == llm_config["stream"]
# Validate the VectorDB config values
db_config = config_data["vectordb"]["config"]
assert app.db.config.collection_name == db_config["collection_name"]
assert app.db.config.dir == db_config["dir"]
assert app.db.config.allow_reset == db_config["allow_reset"]
# Validate the Embedder config values
embedder_config = config_data["embedder"]["config"]
assert app.embedding_model.config.model == embedder_config["model"]
assert app.embedding_model.config.deployment_name == embedder_config.get("deployment_name")
def test_from_opensource_config(self, mocker):
mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
yaml_path = "configs/opensource.yaml"
config_data = self.load_config_data(yaml_path)
app = App.from_config(yaml_path)
# Check if the App instance and its components were created correctly
assert isinstance(app, App)
# Validate the AppConfig values
assert app.config.id == config_data["app"]["config"]["id"]
assert app.config.collect_metrics == config_data["app"]["config"]["collect_metrics"]
# Validate the LLM config values
llm_config = config_data["llm"]["config"]
assert app.llm.config.model == llm_config["model"]
assert app.llm.config.temperature == llm_config["temperature"]
assert app.llm.config.max_tokens == llm_config["max_tokens"]
assert app.llm.config.top_p == llm_config["top_p"]
assert app.llm.config.stream == llm_config["stream"]
# Validate the VectorDB config values
db_config = config_data["vectordb"]["config"]
assert app.db.config.collection_name == db_config["collection_name"]
assert app.db.config.dir == db_config["dir"]
assert app.db.config.allow_reset == db_config["allow_reset"]
# Validate the Embedder config values
embedder_config = config_data["embedder"]["config"]
assert app.embedding_model.config.deployment_name == embedder_config["deployment_name"]

View File

@@ -0,0 +1,53 @@
import pytest
from embedchain import Client
class TestClient:
@pytest.fixture
def mock_requests_post(self, mocker):
return mocker.patch("embedchain.client.requests.post")
def test_valid_api_key(self, mock_requests_post):
mock_requests_post.return_value.status_code = 200
client = Client(api_key="valid_api_key")
assert client.check("valid_api_key") is True
def test_invalid_api_key(self, mock_requests_post):
mock_requests_post.return_value.status_code = 401
with pytest.raises(ValueError):
Client(api_key="invalid_api_key")
def test_update_valid_api_key(self, mock_requests_post):
mock_requests_post.return_value.status_code = 200
client = Client(api_key="valid_api_key")
client.update("new_valid_api_key")
assert client.get() == "new_valid_api_key"
def test_clear_api_key(self, mock_requests_post):
mock_requests_post.return_value.status_code = 200
client = Client(api_key="valid_api_key")
client.clear()
assert client.get() is None
def test_save_api_key(self, mock_requests_post):
mock_requests_post.return_value.status_code = 200
api_key_to_save = "valid_api_key"
client = Client(api_key=api_key_to_save)
client.save()
assert client.get() == api_key_to_save
def test_load_api_key_from_config(self, mocker):
mocker.patch("embedchain.Client.load_config", return_value={"api_key": "test_api_key"})
client = Client()
assert client.get() == "test_api_key"
def test_load_invalid_api_key_from_config(self, mocker):
mocker.patch("embedchain.Client.load_config", return_value={})
with pytest.raises(ValueError):
Client()
def test_load_missing_api_key_from_config(self, mocker):
mocker.patch("embedchain.Client.load_config", return_value={})
with pytest.raises(ValueError):
Client()

View File

@@ -0,0 +1,66 @@
import os
import pytest
import embedchain
import embedchain.embedder.gpt4all
import embedchain.embedder.huggingface
import embedchain.embedder.openai
import embedchain.embedder.vertexai
import embedchain.llm.anthropic
import embedchain.llm.openai
import embedchain.vectordb.chroma
import embedchain.vectordb.elasticsearch
import embedchain.vectordb.opensearch
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
class TestFactories:
@pytest.mark.parametrize(
"provider_name, config_data, expected_class",
[
("openai", {}, embedchain.llm.openai.OpenAILlm),
("anthropic", {}, embedchain.llm.anthropic.AnthropicLlm),
],
)
def test_llm_factory_create(self, provider_name, config_data, expected_class):
os.environ["ANTHROPIC_API_KEY"] = "test_api_key"
os.environ["OPENAI_API_KEY"] = "test_api_key"
os.environ["OPENAI_API_BASE"] = "test_api_base"
llm_instance = LlmFactory.create(provider_name, config_data)
assert isinstance(llm_instance, expected_class)
@pytest.mark.parametrize(
"provider_name, config_data, expected_class",
[
("gpt4all", {}, embedchain.embedder.gpt4all.GPT4AllEmbedder),
(
"huggingface",
{"model": "sentence-transformers/all-mpnet-base-v2", "vector_dimension": 768},
embedchain.embedder.huggingface.HuggingFaceEmbedder,
),
("vertexai", {"model": "textembedding-gecko"}, embedchain.embedder.vertexai.VertexAIEmbedder),
("openai", {}, embedchain.embedder.openai.OpenAIEmbedder),
],
)
def test_embedder_factory_create(self, mocker, provider_name, config_data, expected_class):
mocker.patch("embedchain.embedder.vertexai.VertexAIEmbedder", autospec=True)
embedder_instance = EmbedderFactory.create(provider_name, config_data)
assert isinstance(embedder_instance, expected_class)
@pytest.mark.parametrize(
"provider_name, config_data, expected_class",
[
("chroma", {}, embedchain.vectordb.chroma.ChromaDB),
(
"opensearch",
{"opensearch_url": "http://localhost:9200", "http_auth": ("admin", "admin")},
embedchain.vectordb.opensearch.OpenSearchDB,
),
("elasticsearch", {"es_url": "http://localhost:9200"}, embedchain.vectordb.elasticsearch.ElasticsearchDB),
],
)
def test_vectordb_factory_create(self, mocker, provider_name, config_data, expected_class):
mocker.patch("embedchain.vectordb.opensearch.OpenSearchDB", autospec=True)
vectordb_instance = VectorDBFactory.create(provider_name, config_data)
assert isinstance(vectordb_instance, expected_class)

View File

@@ -0,0 +1,38 @@
import yaml
from embedchain.utils.misc import validate_config
CONFIG_YAMLS = [
"configs/anthropic.yaml",
"configs/azure_openai.yaml",
"configs/chroma.yaml",
"configs/chunker.yaml",
"configs/cohere.yaml",
"configs/together.yaml",
"configs/ollama.yaml",
"configs/full-stack.yaml",
"configs/gpt4.yaml",
"configs/gpt4all.yaml",
"configs/huggingface.yaml",
"configs/jina.yaml",
"configs/llama2.yaml",
"configs/opensearch.yaml",
"configs/opensource.yaml",
"configs/pinecone.yaml",
"configs/vertexai.yaml",
"configs/weaviate.yaml",
]
def test_all_config_yamls():
"""Test that all config yamls are valid."""
for config_yaml in CONFIG_YAMLS:
with open(config_yaml, "r") as f:
config = yaml.safe_load(f)
assert config is not None
try:
validate_config(config)
except Exception as e:
print(f"Error in {config_yaml}: {e}")
raise e

View File

@@ -0,0 +1,253 @@
import os
import shutil
from unittest.mock import patch
import pytest
from chromadb.config import Settings
from embedchain import App
from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.vectordb.chroma import ChromaDB
os.environ["OPENAI_API_KEY"] = "test-api-key"
@pytest.fixture
def chroma_db():
return ChromaDB(config=ChromaDbConfig(host="test-host", port="1234"))
@pytest.fixture
def app_with_settings():
chroma_config = ChromaDbConfig(allow_reset=True, dir="test-db")
chroma_db = ChromaDB(config=chroma_config)
app_config = AppConfig(collect_metrics=False)
return App(config=app_config, db=chroma_db)
@pytest.fixture(scope="session", autouse=True)
def cleanup_db():
yield
try:
shutil.rmtree("test-db")
except OSError as e:
print("Error: %s - %s." % (e.filename, e.strerror))
@patch("embedchain.vectordb.chroma.chromadb.Client")
def test_chroma_db_init_with_host_and_port(mock_client):
chroma_db = ChromaDB(config=ChromaDbConfig(host="test-host", port="1234")) # noqa
called_settings: Settings = mock_client.call_args[0][0]
assert called_settings.chroma_server_host == "test-host"
assert called_settings.chroma_server_http_port == "1234"
@patch("embedchain.vectordb.chroma.chromadb.Client")
def test_chroma_db_init_with_basic_auth(mock_client):
chroma_config = {
"host": "test-host",
"port": "1234",
"chroma_settings": {
"chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
"chroma_client_auth_credentials": "admin:admin",
},
}
ChromaDB(config=ChromaDbConfig(**chroma_config))
called_settings: Settings = mock_client.call_args[0][0]
assert called_settings.chroma_server_host == "test-host"
assert called_settings.chroma_server_http_port == "1234"
assert (
called_settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"]
)
assert (
called_settings.chroma_client_auth_credentials
== chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
)
@patch("embedchain.vectordb.chroma.chromadb.Client")
def test_app_init_with_host_and_port(mock_client):
host = "test-host"
port = "1234"
config = AppConfig(collect_metrics=False)
db_config = ChromaDbConfig(host=host, port=port)
db = ChromaDB(config=db_config)
_app = App(config=config, db=db)
called_settings: Settings = mock_client.call_args[0][0]
assert called_settings.chroma_server_host == host
assert called_settings.chroma_server_http_port == port
@patch("embedchain.vectordb.chroma.chromadb.Client")
def test_app_init_with_host_and_port_none(mock_client):
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
_app = App(config=AppConfig(collect_metrics=False), db=db)
called_settings: Settings = mock_client.call_args[0][0]
assert called_settings.chroma_server_host is None
assert called_settings.chroma_server_http_port is None
def test_chroma_db_duplicates_throw_warning(caplog):
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" in caplog.text
assert "Add of existing embedding ID: 0" in caplog.text
app.db.reset()
def test_chroma_db_duplicates_collections_no_warning(caplog):
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.set_collection_name("test_collection_2")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" not in caplog.text
assert "Add of existing embedding ID: 0" not in caplog.text
app.db.reset()
app.set_collection_name("test_collection_1")
app.db.reset()
def test_chroma_db_collection_init_with_default_collection():
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
assert app.db.collection.name == "embedchain_store"
def test_chroma_db_collection_init_with_custom_collection():
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name(name="test_collection")
assert app.db.collection.name == "test_collection"
def test_chroma_db_collection_set_collection_name():
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection")
assert app.db.collection.name == "test_collection"
def test_chroma_db_collection_changes_encapsulated():
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
assert app.db.count() == 0
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
assert app.db.count() == 1
app.set_collection_name("test_collection_2")
assert app.db.count() == 0
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
app.set_collection_name("test_collection_1")
assert app.db.count() == 1
app.db.reset()
app.set_collection_name("test_collection_2")
app.db.reset()
def test_chroma_db_collection_collections_are_persistent():
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
del app
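    # Re-open the same on-disk directory with a fresh App to check that the collection persisted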
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
assert app.db.count() == 1
app.db.reset()
def test_chroma_db_collection_parallel_collections():
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
app1 = App(
config=AppConfig(collect_metrics=False),
db=db1,
)
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
app2 = App(
config=AppConfig(collect_metrics=False),
db=db2,
)
# cleanup if any previous tests failed or were interrupted
app1.db.reset()
app2.db.reset()
app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
assert app1.db.count() == 1
assert app2.db.count() == 0
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
app1.set_collection_name("test_collection_2")
assert app1.db.count() == 1
app2.set_collection_name("test_collection_1")
assert app2.db.count() == 3
# cleanup
app1.db.reset()
app2.db.reset()
def test_chroma_db_collection_ids_share_collections():
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
app1.set_collection_name("one_collection")
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
app2.set_collection_name("one_collection")
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
assert app1.db.count() == 3
assert app2.db.count() == 3
# cleanup
app1.db.reset()
app2.db.reset()
def test_chroma_db_collection_reset():
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
app1.set_collection_name("one_collection")
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
app2.set_collection_name("two_collection")
db3 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app3 = App(config=AppConfig(collect_metrics=False), db=db3)
app3.set_collection_name("three_collection")
db4 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app4 = App(config=AppConfig(collect_metrics=False), db=db4)
app4.set_collection_name("four_collection")
app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
app1.db.reset()
assert app1.db.count() == 0
assert app2.db.count() == 1
assert app3.db.count() == 1
assert app4.db.count() == 1
# cleanup
app2.db.reset()
app3.db.reset()
app4.db.reset()

View File

@@ -0,0 +1,86 @@
import os
import unittest
from unittest.mock import patch
from embedchain import App
from embedchain.config import AppConfig, ElasticsearchDBConfig
from embedchain.embedder.gpt4all import GPT4AllEmbedder
from embedchain.vectordb.elasticsearch import ElasticsearchDB
class TestEsDB(unittest.TestCase):
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
def test_setUp(self, mock_client):
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
self.vector_dim = 384
app_config = AppConfig(collect_metrics=False)
self.app = App(config=app_config, db=self.db)
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
self.assertEqual(self.db.client, mock_client.return_value)
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
def test_query(self, mock_client):
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
app_config = AppConfig(collect_metrics=False)
self.app = App(config=app_config, db=self.db, embedding_model=GPT4AllEmbedder())
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
self.assertEqual(self.db.client, mock_client.return_value)
# Create some dummy data
documents = ["This is a document.", "This is another document."]
metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
ids = ["doc_1", "doc_2"]
# Add the data to the database.
self.db.add(documents, metadatas, ids)
search_response = {
"hits": {
"hits": [
{
"_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
"_score": 0.9,
},
{
"_source": {
"text": "This is another document.",
"metadata": {"url": "url_2", "doc_id": "doc_id_2"},
},
"_score": 0.8,
},
]
}
}
# Configure the mock client to return the mocked response.
mock_client.return_value.search.return_value = search_response
# Query the database for the documents that are most similar to the query "This is a document".
query = "This is a document"
results_without_citations = self.db.query(query, n_results=2, where={})
expected_results_without_citations = ["This is a document.", "This is another document."]
self.assertEqual(results_without_citations, expected_results_without_citations)
results_with_citations = self.db.query(query, n_results=2, where={}, citations=True)
expected_results_with_citations = [
("This is a document.", {"url": "url_1", "doc_id": "doc_id_1", "score": 0.9}),
("This is another document.", {"url": "url_2", "doc_id": "doc_id_2", "score": 0.8}),
]
self.assertEqual(results_with_citations, expected_results_with_citations)
def test_init_without_url(self):
# Make sure it's not loaded from env
try:
del os.environ["ELASTICSEARCH_URL"]
except KeyError:
pass
# Test if an exception is raised when an invalid es_config is provided
with self.assertRaises(AttributeError):
ElasticsearchDB()
def test_init_with_invalid_es_config(self):
# Test if an exception is raised when an invalid es_config is provided
with self.assertRaises(TypeError):
ElasticsearchDB(es_config={"ES_URL": "some_url", "valid es_config": False})

View File

@@ -0,0 +1,215 @@
import os
import shutil
import pytest
from embedchain import App
from embedchain.config import AppConfig
from embedchain.config.vector_db.lancedb import LanceDBConfig
from embedchain.vectordb.lancedb import LanceDB
os.environ["OPENAI_API_KEY"] = "test-api-key"
@pytest.fixture
def lancedb():
return LanceDB(config=LanceDBConfig(dir="test-db", collection_name="test-coll"))
@pytest.fixture
def app_with_settings():
lancedb_config = LanceDBConfig(allow_reset=True, dir="test-db-reset")
lancedb = LanceDB(config=lancedb_config)
app_config = AppConfig(collect_metrics=False)
return App(config=app_config, db=lancedb)
@pytest.fixture(scope="session", autouse=True)
def cleanup_db():
yield
try:
shutil.rmtree("test-db.lance")
shutil.rmtree("test-db-reset.lance")
except OSError as e:
print("Error: %s - %s." % (e.filename, e.strerror))
def test_lancedb_duplicates_throw_warning(caplog):
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
assert "Insert of existing doc ID: 0" not in caplog.text
assert "Add of existing doc ID: 0" not in caplog.text
app.db.reset()
def test_lancedb_duplicates_collections_no_warning(caplog):
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
app.set_collection_name("test_collection_2")
app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
assert "Insert of existing doc ID: 0" not in caplog.text
assert "Add of existing doc ID: 0" not in caplog.text
app.db.reset()
app.set_collection_name("test_collection_1")
app.db.reset()
def test_lancedb_collection_init_with_default_collection():
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
assert app.db.collection.name == "embedchain_store"
def test_lancedb_collection_init_with_custom_collection():
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name(name="test_collection")
assert app.db.collection.name == "test_collection"
def test_lancedb_collection_set_collection_name():
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection")
assert app.db.collection.name == "test_collection"
def test_lancedb_collection_changes_encapsulated():
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
assert app.db.count() == 0
app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
assert app.db.count() == 1
app.set_collection_name("test_collection_2")
assert app.db.count() == 0
app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
app.set_collection_name("test_collection_1")
assert app.db.count() == 1
app.db.reset()
app.set_collection_name("test_collection_2")
app.db.reset()
def test_lancedb_collection_collections_are_persistent():
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
del app
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
assert app.db.count() == 1
app.db.reset()
def test_lancedb_collection_parallel_collections():
db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
app1 = App(
config=AppConfig(collect_metrics=False),
db=db1,
)
db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
app2 = App(
config=AppConfig(collect_metrics=False),
db=db2,
)
# cleanup if any previous tests failed or were interrupted
app1.db.reset()
app2.db.reset()
app1.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
assert app1.db.count() == 1
assert app2.db.count() == 0
app1.db.add(ids=["1", "2"], documents=["doc1", "doc2"], metadatas=["test", "test"])
app2.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
app1.set_collection_name("test_collection_2")
assert app1.db.count() == 1
app2.set_collection_name("test_collection_1")
assert app2.db.count() == 3
# cleanup
app1.db.reset()
app2.db.reset()
def test_lancedb_collection_ids_share_collections():
db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
app1.set_collection_name("one_collection")
db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
app2.set_collection_name("one_collection")
# cleanup
app1.db.reset()
app2.db.reset()
app1.db.add(ids=["0", "1"], documents=["doc1", "doc2"], metadatas=["test", "test"])
app2.db.add(ids=["2"], documents=["doc3"], metadatas=["test"])
assert app1.db.count() == 2
assert app2.db.count() == 3
# cleanup
app1.db.reset()
app2.db.reset()
def test_lancedb_collection_reset():
db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
app1.set_collection_name("one_collection")
db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
app2.set_collection_name("two_collection")
db3 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
app3 = App(config=AppConfig(collect_metrics=False), db=db3)
app3.set_collection_name("three_collection")
db4 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
app4 = App(config=AppConfig(collect_metrics=False), db=db4)
app4.set_collection_name("four_collection")
# cleanup if any previous tests failed or were interrupted
app1.db.reset()
app2.db.reset()
app3.db.reset()
app4.db.reset()
app1.db.add(ids=["1"], documents=["doc1"], metadatas=["test"])
app2.db.add(ids=["2"], documents=["doc2"], metadatas=["test"])
app3.db.add(ids=["3"], documents=["doc3"], metadatas=["test"])
app4.db.add(ids=["4"], documents=["doc4"], metadatas=["test"])
app1.db.reset()
assert app1.db.count() == 0
assert app2.db.count() == 1
assert app3.db.count() == 1
assert app4.db.count() == 1
# cleanup
app2.db.reset()
app3.db.reset()
app4.db.reset()
def generate_embeddings(dummy_embed, embed_size):
generated_embedding = []
for i in range(embed_size):
generated_embedding.append(dummy_embed)
return generated_embedding

View File

@@ -0,0 +1,225 @@
import pytest
from embedchain.config.vector_db.pinecone import PineconeDBConfig
from embedchain.vectordb.pinecone import PineconeDB
@pytest.fixture
def pinecone_pod_config():
return PineconeDBConfig(
index_name="test_collection",
api_key="test_api_key",
vector_dimension=3,
pod_config={"environment": "test_environment", "metadata_config": {"indexed": ["*"]}},
)
@pytest.fixture
def pinecone_serverless_config():
return PineconeDBConfig(
index_name="test_collection",
api_key="test_api_key",
vector_dimension=3,
serverless_config={
"cloud": "test_cloud",
"region": "test_region",
},
)
def test_pinecone_init_without_config(monkeypatch):
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
pinecone_db = PineconeDB()
assert isinstance(pinecone_db, PineconeDB)
assert isinstance(pinecone_db.config, PineconeDBConfig)
assert pinecone_db.config.pod_config == {"environment": "gcp-starter", "metadata_config": {"indexed": ["*"]}}
monkeypatch.delenv("PINECONE_API_KEY")
def test_pinecone_init_with_config(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
    monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
    monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
    pinecone_db = PineconeDB(config=pinecone_pod_config)
    assert isinstance(pinecone_db, PineconeDB)
    assert isinstance(pinecone_db.config, PineconeDBConfig)
    assert pinecone_db.config.pod_config == pinecone_pod_config.pod_config
    pinecone_db = PineconeDB(config=pinecone_serverless_config)
    assert isinstance(pinecone_db, PineconeDB)
    assert isinstance(pinecone_db.config, PineconeDBConfig)
    assert pinecone_db.config.serverless_config == pinecone_serverless_config.serverless_config
class MockListIndexes:
def names(self):
return ["test_collection"]
class MockPineconeIndex:
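    # Class-level list shared by every MockPineconeIndex instance, so upserts accumulate across calls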
db = []
def __init__(*args, **kwargs):
pass
def upsert(self, chunk, **kwargs):
self.db.extend([c for c in chunk])
return
def delete(self, *args, **kwargs):
pass
def query(self, *args, **kwargs):
return {
"matches": [
{
"metadata": {
"key": "value",
"text": "text_1",
},
"score": 0.1,
},
{
"metadata": {
"key": "value",
"text": "text_2",
},
"score": 0.2,
},
]
}
def fetch(self, *args, **kwargs):
return {
"vectors": {
"key_1": {
"metadata": {
"source": "1",
}
},
"key_2": {
"metadata": {
"source": "2",
}
},
}
}
def describe_index_stats(self, *args, **kwargs):
return {"total_vector_count": len(self.db)}
class MockPineconeClient:
def __init__(*args, **kwargs):
pass
def list_indexes(self):
return MockListIndexes()
def create_index(self, *args, **kwargs):
pass
def Index(self, *args, **kwargs):
return MockPineconeIndex()
def delete_index(self, *args, **kwargs):
pass
class MockPinecone:
def __init__(*args, **kwargs):
pass
def Pinecone(*args, **kwargs):
return MockPineconeClient()
def PodSpec(*args, **kwargs):
pass
def ServerlessSpec(*args, **kwargs):
pass
class MockEmbedder:
def embedding_fn(self, documents):
return [[1, 1, 1] for d in documents]
def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
monkeypatch.setattr("embedchain.vectordb.pinecone.pinecone", MockPinecone)
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
pinecone_db = PineconeDB(config=pinecone_pod_config)
pinecone_db._setup_pinecone_index()
assert pinecone_db.client is not None
assert pinecone_db.config.index_name == "test_collection"
assert pinecone_db.client.list_indexes().names() == ["test_collection"]
assert pinecone_db.pinecone_index is not None
pinecone_db = PineconeDB(config=pinecone_serverless_config)
pinecone_db._setup_pinecone_index()
assert pinecone_db.client is not None
assert pinecone_db.config.index_name == "test_collection"
assert pinecone_db.client.list_indexes().names() == ["test_collection"]
assert pinecone_db.pinecone_index is not None
def test_get(monkeypatch):
def mock_pinecone_db():
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
db = PineconeDB()
db.pinecone_index = MockPineconeIndex()
return db
pinecone_db = mock_pinecone_db()
ids = pinecone_db.get(["key_1", "key_2"])
assert ids == {"ids": ["key_1", "key_2"], "metadatas": [{"source": "1"}, {"source": "2"}]}
def test_add(monkeypatch):
def mock_pinecone_db():
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
db = PineconeDB()
db.pinecone_index = MockPineconeIndex()
db._set_embedder(MockEmbedder())
return db
pinecone_db = mock_pinecone_db()
pinecone_db.add(["text_1", "text_2"], [{"key_1": "value_1"}, {"key_2": "value_2"}], ["key_1", "key_2"])
assert pinecone_db.count() == 2
pinecone_db.add(["text_3", "text_4"], [{"key_3": "value_3"}, {"key_4": "value_4"}], ["key_3", "key_4"])
assert pinecone_db.count() == 4
def test_query(monkeypatch):
def mock_pinecone_db():
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
db = PineconeDB()
db.pinecone_index = MockPineconeIndex()
db._set_embedder(MockEmbedder())
return db
pinecone_db = mock_pinecone_db()
# without citations
results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={})
assert results == ["text_1", "text_2"]
# with citations
results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={}, citations=True)
assert results == [
("text_1", {"key": "value", "text": "text_1", "score": 0.1}),
("text_2", {"key": "value", "text": "text_2", "score": 0.2}),
]

View File

@@ -0,0 +1,167 @@
import unittest
import uuid
from unittest.mock import patch
from qdrant_client.http import models
from qdrant_client.http.models import Batch
from embedchain import App
from embedchain.config import AppConfig
from embedchain.config.vector_db.pinecone import PineconeDBConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.qdrant import QdrantDB
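# Embedding stub: always returns the same two fixed vectors, keeping upsert and search
# assertions deterministic.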
def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
"""A mock embedding function."""
return [[1, 2, 3], [4, 5, 6]]
class TestQdrantDB(unittest.TestCase):
TEST_UUIDS = ["abc", "def", "ghi"]
def test_incorrect_config_throws_error(self):
"""Test the init method of the Qdrant class throws error for incorrect config"""
with self.assertRaises(TypeError):
QdrantDB(config=PineconeDBConfig())
@patch("embedchain.vectordb.qdrant.QdrantClient")
def test_initialize(self, qdrant_client_mock):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
self.assertEqual(db.collection_name, "embedchain-store-1536")
self.assertEqual(db.client, qdrant_client_mock.return_value)
qdrant_client_mock.return_value.get_collections.assert_called_once()
@patch("embedchain.vectordb.qdrant.QdrantClient")
def test_get(self, qdrant_client_mock):
qdrant_client_mock.return_value.scroll.return_value = ([], None)
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
resp = db.get(ids=[], where={})
self.assertEqual(resp, {"ids": [], "metadatas": []})
resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
self.assertEqual(resp2, {"ids": [], "metadatas": []})
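# uuid4 is patched with deterministic side effects so that any point ids QdrantDB
# generates internally stay predictable; the explicit ids passed below are used as-is.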
@patch("embedchain.vectordb.qdrant.QdrantClient")
@patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
def test_add(self, uuid_mock, qdrant_client_mock):
qdrant_client_mock.return_value.scroll.return_value = ([], None)
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
documents = ["This is a test document.", "This is another test document."]
metadatas = [{}, {}]
ids = ["123", "456"]
db.add(documents, metadatas, ids)
qdrant_client_mock.return_value.upsert.assert_called_once_with(
collection_name="embedchain-store-1536",
points=Batch(
ids=["123", "456"],
payloads=[
{
"identifier": "123",
"text": "This is a test document.",
"metadata": {"text": "This is a test document."},
},
{
"identifier": "456",
"text": "This is another test document.",
"metadata": {"text": "This is another test document."},
},
],
vectors=[[1, 2, 3], [4, 5, 6]],
),
)
@patch("embedchain.vectordb.qdrant.QdrantClient")
def test_query(self, qdrant_client_mock):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
# Query for the document.
db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})
qdrant_client_mock.return_value.search.assert_called_once_with(
collection_name="embedchain-store-1536",
query_filter=models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(
value="123",
),
)
]
),
query_vector=[1, 2, 3],
limit=1,
)
@patch("embedchain.vectordb.qdrant.QdrantClient")
def test_count(self, qdrant_client_mock):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
db.count()
qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1536")
@patch("embedchain.vectordb.qdrant.QdrantClient")
def test_reset(self, qdrant_client_mock):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
db.reset()
qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
collection_name="embedchain-store-1536"
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,237 @@
import unittest
from unittest.mock import patch
from embedchain import App
from embedchain.config import AppConfig
from embedchain.config.vector_db.pinecone import PineconeDBConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.weaviate import WeaviateDB
def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
"""A mock embedding function."""
return [[1, 2, 3], [4, 5, 6]]
class TestWeaviateDb(unittest.TestCase):
def test_incorrect_config_throws_error(self):
"""Test the init method of the WeaviateDb class throws error for incorrect config"""
with self.assertRaises(TypeError):
WeaviateDB(config=PineconeDBConfig())
@patch("embedchain.vectordb.weaviate.weaviate")
def test_initialize(self, weaviate_mock):
"""Test the init method of the WeaviateDb class."""
weaviate_client_mock = weaviate_mock.Client.return_value
weaviate_client_schema_mock = weaviate_client_mock.schema
# Mock that schema doesn't already exist so that a new schema is created
weaviate_client_schema_mock.exists.return_value = False
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
expected_class_obj = {
"classes": [
{
"class": "Embedchain_store_1536",
"vectorizer": "none",
"properties": [
{
"name": "identifier",
"dataType": ["text"],
},
{
"name": "text",
"dataType": ["text"],
},
{
"name": "metadata",
"dataType": ["Embedchain_store_1536_metadata"],
},
],
},
{
"class": "Embedchain_store_1536_metadata",
"vectorizer": "none",
"properties": [
{
"name": "data_type",
"dataType": ["text"],
},
{
"name": "doc_id",
"dataType": ["text"],
},
{
"name": "url",
"dataType": ["text"],
},
{
"name": "hash",
"dataType": ["text"],
},
{
"name": "app_id",
"dataType": ["text"],
},
],
},
]
}
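# The class names encode the embedder's vector dimension (1536), with a companion
# *_metadata class holding the per-document metadata fields.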
# Assert that the Weaviate client was initialized
weaviate_mock.Client.assert_called_once()
self.assertEqual(db.index_name, "Embedchain_store_1536")
weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj)
@patch("embedchain.vectordb.weaviate.weaviate")
def test_get_or_create_db(self, weaviate_mock):
"""Test the _get_or_create_db method of the WeaviateDb class."""
weaviate_client_mock = weaviate_mock.Client.return_value
embedder = BaseEmbedder()
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
expected_client = db._get_or_create_db()
self.assertEqual(expected_client, weaviate_client_mock)
@patch("embedchain.vectordb.weaviate.weaviate")
def test_add(self, weaviate_mock):
"""Test the add method of the WeaviateDb class."""
weaviate_client_mock = weaviate_mock.Client.return_value
weaviate_client_batch_mock = weaviate_client_mock.batch
weaviate_client_batch_enter_mock = weaviate_client_mock.batch.__enter__.return_value
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
documents = ["This is test document"]
metadatas = [None]
ids = ["id_1"]
db.add(documents, metadatas, ids)
# Check if the document was added to the database.
weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=100, timeout_retries=3)
weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
data_object={"text": documents[0]}, class_name="Embedchain_store_1536_metadata", vector=[1, 2, 3]
)
weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
data_object={"text": documents[0]},
class_name="Embedchain_store_1536_metadata",
vector=[1, 2, 3],
)
@patch("embedchain.vectordb.weaviate.weaviate")
def test_query_without_where(self, weaviate_mock):
"""Test the query method of the WeaviateDb class."""
weaviate_client_mock = weaviate_mock.Client.return_value
weaviate_client_query_mock = weaviate_client_mock.query
weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
# Query for the document.
db.query(input_query="This is a test document.", n_results=1, where={})
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
@patch("embedchain.vectordb.weaviate.weaviate")
def test_query_with_where(self, weaviate_mock):
"""Test the query method of the WeaviateDb class."""
weaviate_client_mock = weaviate_mock.Client.return_value
weaviate_client_query_mock = weaviate_client_mock.query
weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
weaviate_client_query_get_where_mock = weaviate_client_query_get_mock.with_where.return_value
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
# Query for the document.
db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
weaviate_client_query_get_mock.with_where.assert_called_once_with(
{"operator": "Equal", "path": ["metadata", "Embedchain_store_1536_metadata", "doc_id"], "valueText": "123"}
)
weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
@patch("embedchain.vectordb.weaviate.weaviate")
def test_reset(self, weaviate_mock):
"""Test the reset method of the WeaviateDb class."""
weaviate_client_mock = weaviate_mock.Client.return_value
weaviate_client_batch_mock = weaviate_client_mock.batch
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
# Reset the database.
db.reset()
weaviate_client_batch_mock.delete_objects.assert_called_once_with(
"Embedchain_store_1536", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
)
@patch("embedchain.vectordb.weaviate.weaviate")
def test_count(self, weaviate_mock):
"""Test the reset method of the WeaviateDb class."""
weaviate_client_mock = weaviate_mock.Client.return_value
weaviate_client_query = weaviate_client_mock.query
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
# Count the entries in the database.
db.count()
weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1536")

View File

@@ -0,0 +1,168 @@
# ruff: noqa: E501
import os
from unittest import mock
from unittest.mock import Mock, patch
import pytest
from embedchain.config import ZillizDBConfig
from embedchain.vectordb.zilliz import ZillizVectorDB
# Credentials come from the ZILLIZ_CLOUD_URI and ZILLIZ_CLOUD_TOKEN environment variables;
# the tests below mock them, so no real .env values are needed.
class TestZillizVectorDBConfig:
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
def test_init_with_uri_and_token(self):
"""
Test that the `ZillizDBConfig` instance is initialized with the correct uri and token values.
"""
# Create a ZillizDBConfig instance with mocked values
expected_uri = "mocked_uri"
expected_token = "mocked_token"
db_config = ZillizDBConfig()
# Assert that the values in the ZillizDBConfig instance match the mocked values
assert db_config.uri == expected_uri
assert db_config.token == expected_token
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
def test_init_without_uri(self):
"""
Test that `ZillizDBConfig` raises an error when no URI is found.
"""
try:
del os.environ["ZILLIZ_CLOUD_URI"]
except KeyError:
pass
with pytest.raises(AttributeError):
ZillizDBConfig()
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
def test_init_without_token(self):
"""
Test that `ZillizDBConfig` raises an error when no token is found.
"""
try:
del os.environ["ZILLIZ_CLOUD_TOKEN"]
except KeyError:
pass
# Test if an exception is raised when ZILLIZ_CLOUD_TOKEN is missing
with pytest.raises(AttributeError):
ZillizDBConfig()
class TestZillizVectorDB:
@pytest.fixture
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
def mock_config(self, mocker):
return mocker.Mock(spec=ZillizDBConfig())
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
@patch("embedchain.vectordb.zilliz.connections.connect", autospec=True)
def test_zilliz_vector_db_setup(self, mock_connect, mock_client, mock_config):
"""
Test if the `ZillizVectorDB` instance is initialized with the correct uri and token values.
"""
# Create an instance of ZillizVectorDB with the mock config; only the
# constructor's side effects are asserted, so the instance itself is discarded.
ZillizVectorDB(config=mock_config)
# Assert that the MilvusClient and connections.connect were called
mock_client.assert_called_once_with(uri=mock_config.uri, token=mock_config.token)
mock_connect.assert_called_once_with(uri=mock_config.uri, token=mock_config.token)
class TestZillizDBCollection:
@pytest.fixture
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
def mock_config(self, mocker):
return mocker.Mock(spec=ZillizDBConfig())
@pytest.fixture
def mock_embedder(self, mocker):
return mocker.Mock()
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
def test_init_with_default_collection(self):
"""
Test that `ZillizDBConfig` defaults to the expected collection name.
"""
# Create a ZillizDBConfig instance
db_config = ZillizDBConfig()
assert db_config.collection_name == "embedchain_store"
@mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
def test_init_with_custom_collection(self):
"""
Test that `ZillizDBConfig` honors a custom collection name.
"""
# Create a ZillizDBConfig instance with mocked values
expected_collection = "test_collection"
db_config = ZillizDBConfig(collection_name="test_collection")
assert db_config.collection_name == expected_collection
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
def test_query(self, mock_connect, mock_client, mock_embedder, mock_config):
# Create an instance of ZillizVectorDB with mock config
zilliz_db = ZillizVectorDB(config=mock_config)
# Add an 'embedder' attribute to the ZillizVectorDB instance for testing
zilliz_db.embedder = mock_embedder  # Mock the 'embedder' object
# Add a 'collection' attribute to the ZillizVectorDB instance for testing
zilliz_db.collection = Mock(is_empty=False) # Mock the 'collection' object
assert zilliz_db.client == mock_client()
# Mock the MilvusClient search method
with patch.object(zilliz_db.client, "search") as mock_search:
# Mock the embedding function
mock_embedder.embedding_fn.return_value = ["query_vector"]
# Mock the search result
mock_search.return_value = [
[
{
"distance": 0.0,
"entity": {
"text": "result_doc",
"embeddings": [1, 2, 3],
"metadata": {"url": "url_1", "doc_id": "doc_id_1"},
},
}
]
]
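# The mocked hit mirrors a Zilliz/Milvus search result: one list per query, each hit
# carrying a distance plus the stored entity fields.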
query_result = zilliz_db.query(input_query="query_text", n_results=1, where={})
# Assert that MilvusClient.search was called with the correct parameters
mock_search.assert_called_with(
collection_name=mock_config.collection_name,
data=["query_vector"],
filter="",
limit=1,
output_fields=["*"],
)
# Assert that the query result matches the expected result
assert query_result == ["result_doc"]
query_result_with_citations = zilliz_db.query(
input_query="query_text", n_results=1, where={}, citations=True
)
mock_search.assert_called_with(
collection_name=mock_config.collection_name,
data=["query_vector"],
filter="",
limit=1,
output_fields=["*"],
)
assert query_result_with_citations == [("result_doc", {"url": "url_1", "doc_id": "doc_id_1", "score": 0.0})]