Rename embedchain to mem0 and open sourcing code for long term memory (#1474)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
Taranjeet Singh
2024-07-12 07:51:33 -07:00
committed by GitHub
parent 83e8c97295
commit f842a92e25
665 changed files with 9427 additions and 6592 deletions

View File

@@ -0,0 +1,100 @@
import hashlib
import os
import sys
from unittest.mock import mock_open, patch
import pytest
if sys.version_info > (3, 10): # as `match` statement was introduced in python 3.10
from deepgram import PrerecordedOptions
from embedchain.loaders.audio import AudioLoader
@pytest.fixture
def setup_audio_loader(mocker):
    """Provide an ``AudioLoader`` whose Deepgram client is a mock.

    Returns:
        A ``(loader, client_mock)`` tuple; ``client_mock`` is the object
        assigned to ``loader.client``.
    """
    deepgram_client_cls = mocker.patch("deepgram.DeepgramClient")
    client_mock = mocker.MagicMock()
    deepgram_client_cls.return_value = client_mock

    # patch.dict restores the previous environment on teardown, so a key that
    # was already set before the test is not lost, and nothing leaks if the
    # loader constructor raises.
    mocker.patch.dict(os.environ, {"DEEPGRAM_API_KEY": "test_key"})

    loader = AudioLoader()
    loader.client = client_mock
    return loader, client_mock
@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
)  # as `match` statement was introduced in python 3.10
def test_initialization(setup_audio_loader):
    """The fixture must hand back a constructed AudioLoader."""
    audio_loader = setup_audio_loader[0]
    assert audio_loader is not None
@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
)  # as `match` statement was introduced in python 3.10
def test_load_data_from_url(setup_audio_loader):
    """Loading a URL returns the mocked transcript and calls ``transcribe_url`` once."""
    loader, client = setup_audio_loader
    url = "https://example.com/audio.mp3"
    transcript = "This is a test audio transcript."

    # The loader is expected to call client.listen.prerecorded.v("1").transcribe_url(...)
    versioned = client.listen.prerecorded.v
    versioned.return_value.transcribe_url.return_value = {
        "results": {"channels": [{"alternatives": [{"transcript": transcript}]}]}
    }

    result = loader.load_data(url)

    assert result == {
        "doc_id": hashlib.sha256((transcript + url).encode()).hexdigest(),
        "data": [
            {
                "content": transcript,
                "meta_data": {"url": url},
            }
        ],
    }
    versioned.assert_called_once_with("1")
    versioned.return_value.transcribe_url.assert_called_once_with(
        {"url": url}, PrerecordedOptions(model="nova-2", smart_format=True)
    )
@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
)  # as `match` statement was introduced in python 3.10
def test_load_data_from_file(setup_audio_loader):
    """Loading a local path returns the mocked transcript and calls ``transcribe_file`` once."""
    loader, client = setup_audio_loader
    file_path = "local_audio.mp3"
    transcript = "This is a test audio transcript."

    versioned = client.listen.prerecorded.v
    versioned.return_value.transcribe_file.return_value = {
        "results": {"channels": [{"alternatives": [{"transcript": transcript}]}]}
    }

    # Replace open() so no real file is needed
    with patch("builtins.open", mock_open(read_data=b"some data")) as opened:
        result = loader.load_data(file_path)

    assert result == {
        "doc_id": hashlib.sha256((transcript + file_path).encode()).hexdigest(),
        "data": [
            {
                "content": transcript,
                "meta_data": {"url": file_path},
            }
        ],
    }
    versioned.assert_called_once_with("1")
    versioned.return_value.transcribe_file.assert_called_once_with(
        {"buffer": opened.return_value}, PrerecordedOptions(model="nova-2", smart_format=True)
    )

View File

@@ -0,0 +1,113 @@
import csv
import os
import pathlib
import tempfile
from unittest.mock import MagicMock, patch
import pytest
from embedchain.loaders.csv import CsvLoader
@pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
def test_load_data(delimiter):
    """
    Test csv loader

    Tests that file is loaded, metadata is correct and content is correct
    for each supported delimiter.
    """
    # Create the temporary CSV file and close it so its content is flushed
    # to disk before the loader reads it by name.
    with tempfile.NamedTemporaryFile(mode="w+", newline="", delete=False) as tmpfile:
        writer = csv.writer(tmpfile, delimiter=delimiter)
        writer.writerow(["Name", "Age", "Occupation"])
        writer.writerow(["Alice", "28", "Engineer"])
        writer.writerow(["Bob", "35", "Doctor"])
        writer.writerow(["Charlie", "22", "Student"])
        filename = tmpfile.name

    expected_contents = [
        "Name: Alice, Age: 28, Occupation: Engineer",
        "Name: Bob, Age: 35, Occupation: Doctor",
        "Name: Charlie, Age: 22, Occupation: Student",
    ]

    try:
        # Loading CSV using CsvLoader
        loader = CsvLoader()
        result = loader.load_data(filename)
        data = result["data"]

        # Assertions
        assert len(data) == 3
        for row_number, (entry, content) in enumerate(zip(data, expected_contents), start=1):
            assert entry["content"] == content
            assert entry["meta_data"]["url"] == filename
            assert entry["meta_data"]["row"] == row_number
    finally:
        # Always remove the temporary file, even when an assertion fails
        os.unlink(filename)
@pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
def test_load_data_with_file_uri(delimiter):
    """
    Test csv loader with file URI

    Tests that file is loaded, metadata is correct and content is correct
    when the source is given as a ``file://`` URI.
    """
    # Create the temporary CSV file and close it so its content is flushed
    # to disk before the loader reads it.
    with tempfile.NamedTemporaryFile(mode="w+", newline="", delete=False) as tmpfile:
        writer = csv.writer(tmpfile, delimiter=delimiter)
        writer.writerow(["Name", "Age", "Occupation"])
        writer.writerow(["Alice", "28", "Engineer"])
        writer.writerow(["Bob", "35", "Doctor"])
        writer.writerow(["Charlie", "22", "Student"])
        filename = pathlib.Path(tmpfile.name).as_uri()  # Convert path to file URI

    expected_contents = [
        "Name: Alice, Age: 28, Occupation: Engineer",
        "Name: Bob, Age: 35, Occupation: Doctor",
        "Name: Charlie, Age: 22, Occupation: Student",
    ]

    try:
        # Loading CSV using CsvLoader
        loader = CsvLoader()
        result = loader.load_data(filename)
        data = result["data"]

        # Assertions
        assert len(data) == 3
        for row_number, (entry, content) in enumerate(zip(data, expected_contents), start=1):
            assert entry["content"] == content
            assert entry["meta_data"]["url"] == filename
            assert entry["meta_data"]["row"] == row_number
    finally:
        # Always remove the temporary file, even when an assertion fails
        os.unlink(tmpfile.name)
@pytest.mark.parametrize("content", ["ftp://example.com", "sftp://example.com", "mailto://example.com"])
def test_get_file_content(content):
    """``_get_file_content`` must raise ``ValueError`` for these URL schemes."""
    # Build the loader outside the raises-block so that only the call under
    # test can satisfy the expectation.
    loader = CsvLoader()
    with pytest.raises(ValueError):
        loader._get_file_content(content)
@pytest.mark.parametrize("content", ["http://example.com", "https://example.com"])
def test_get_file_content_http(content):
    """
    For http and https URLs, ``_get_file_content`` should fetch via
    ``requests.get``, check the status, and return a readable object
    holding the response text.
    """
    csv_text = "Name,Age,Occupation\nAlice,28,Engineer\nBob,35,Doctor\nCharlie,22,Student"
    with patch("requests.get") as fake_get:
        fake_response = MagicMock()
        fake_response.text = csv_text
        fake_get.return_value = fake_response

        stream = CsvLoader()._get_file_content(content)

        fake_get.assert_called_once_with(content)
        fake_response.raise_for_status.assert_called_once()
        assert stream.read() == fake_response.text

View File

@@ -0,0 +1,104 @@
import pytest
import requests
from embedchain.loaders.discourse import DiscourseLoader
@pytest.fixture
def discourse_loader_config():
    """Configuration dict (domain only) used to build the loader under test."""
    config = {"domain": "https://example.com/"}
    return config
@pytest.fixture
def discourse_loader(discourse_loader_config):
    """A DiscourseLoader built from the shared test configuration."""
    loader = DiscourseLoader(config=discourse_loader_config)
    return loader
def test_discourse_loader_init_with_valid_config():
    """A config with a domain is accepted and exposed as ``loader.domain``."""
    domain = "https://example.com/"
    loader = DiscourseLoader(config={"domain": domain})
    assert loader.domain == domain
def test_discourse_loader_init_with_missing_config():
    """Constructing the loader without a config raises ``ValueError``."""
    expected_message = "DiscourseLoader requires a config"
    with pytest.raises(ValueError, match=expected_message):
        DiscourseLoader()
def test_discourse_loader_init_with_missing_domain():
    """A config lacking the ``domain`` key raises ``ValueError``."""
    expected_message = "DiscourseLoader requires a domain"
    with pytest.raises(ValueError, match=expected_message):
        DiscourseLoader(config={"another_key": "value"})
def test_discourse_loader_check_query_with_valid_query(discourse_loader):
    """A non-empty string query passes ``_check_query`` without raising."""
    query = "sample query"
    discourse_loader._check_query(query)
def test_discourse_loader_check_query_with_empty_query(discourse_loader):
    """An empty string query is rejected with ``ValueError``."""
    expected_message = "DiscourseLoader requires a query"
    with pytest.raises(ValueError, match=expected_message):
        discourse_loader._check_query("")
def test_discourse_loader_check_query_with_invalid_query_type(discourse_loader):
    """A non-string query (here an int) is rejected with ``ValueError``."""
    expected_message = "DiscourseLoader requires a query"
    with pytest.raises(ValueError, match=expected_message):
        discourse_loader._check_query(123)
def test_discourse_loader_load_post_with_valid_post_id(discourse_loader, monkeypatch):
    """``_load_post`` maps the response's ``raw`` field to ``content`` and adds ``meta_data``."""

    class FakePostResponse:
        """Stand-in for the HTTP response returned by ``requests.get``."""

        def raise_for_status(self):
            pass

        def json(self):
            return {"raw": "Sample post content"}

    def fake_get(*args, **kwargs):
        return FakePostResponse()

    monkeypatch.setattr(requests, "get", fake_get)

    post = discourse_loader._load_post(123)

    assert post["content"] == "Sample post content"
    assert "meta_data" in post
def test_discourse_loader_load_data_with_valid_query(discourse_loader, monkeypatch):
    """``load_data`` returns one entry per post id found by the search request."""
    expected_meta = {
        "url": "https://example.com/posts/123.json",
        "created_at": "2021-01-01",
        "username": "test_user",
        "topic_slug": "test_topic",
        "score": 10,
    }

    class FakeSearchResponse:
        """Stand-in for the search response returned by ``requests.get``."""

        def raise_for_status(self):
            pass

        def json(self):
            return {"grouped_search_result": {"post_ids": [123, 456, 789]}}

    def fake_get(*args, **kwargs):
        return FakeSearchResponse()

    def fake_load_post(*args, **kwargs):
        # Every post id resolves to the same canned post
        return {"content": "Sample post content", "meta_data": dict(expected_meta)}

    monkeypatch.setattr(requests, "get", fake_get)
    monkeypatch.setattr(discourse_loader, "_load_post", fake_load_post)

    data = discourse_loader.load_data("sample query")

    assert len(data["data"]) == 3
    first = data["data"][0]
    assert first["content"] == "Sample post content"
    for key in ("url", "created_at", "username", "topic_slug", "score"):
        assert first["meta_data"][key] == expected_meta[key]

View File

@@ -0,0 +1,130 @@
import hashlib
from unittest.mock import Mock, patch
import pytest
from requests import Response
from embedchain.loaders.docs_site_loader import DocsSiteLoader
@pytest.fixture
def mock_requests_get():
    """Patch ``requests.get`` for the duration of a test and yield the mock."""
    with patch("requests.get") as patched_get:
        yield patched_get
@pytest.fixture
def docs_site_loader():
    """A fresh DocsSiteLoader for each test."""
    loader = DocsSiteLoader()
    return loader
def test_get_child_links_recursive(mock_requests_get, docs_site_loader):
    """Both relative links on the page end up in ``visited_links`` as absolute URLs."""
    page = Mock()
    page.status_code = 200
    page.text = """
<html>
<a href="/page1">Page 1</a>
<a href="/page2">Page 2</a>
</html>
"""
    mock_requests_get.return_value = page

    docs_site_loader._get_child_links_recursive("https://example.com")

    visited = docs_site_loader.visited_links
    assert len(visited) == 2
    for expected_link in ("https://example.com/page1", "https://example.com/page2"):
        assert expected_link in visited
def test_get_child_links_recursive_status_not_200(mock_requests_get, docs_site_loader):
    """A 404 response leaves ``visited_links`` empty."""
    not_found = Mock()
    not_found.status_code = 404
    mock_requests_get.return_value = not_found

    docs_site_loader._get_child_links_recursive("https://example.com")

    assert len(docs_site_loader.visited_links) == 0
def test_get_all_urls(mock_requests_get, docs_site_loader):
    """``_get_all_urls`` returns the two relative links and the absolute link from the page."""
    page = Mock()
    page.status_code = 200
    page.text = """
<html>
<a href="/page1">Page 1</a>
<a href="/page2">Page 2</a>
<a href="https://example.com/external">External</a>
</html>
"""
    mock_requests_get.return_value = page

    all_urls = docs_site_loader._get_all_urls("https://example.com")

    assert len(all_urls) == 3
    for expected_url in (
        "https://example.com/page1",
        "https://example.com/page2",
        "https://example.com/external",
    ):
        assert expected_url in all_urls
def test_load_data_from_url(mock_requests_get, docs_site_loader):
    """Only the article text is returned as content; the nav text is not part of it."""
    page_url = "https://example.com/page1"
    page = Mock()
    page.status_code = 200
    page.content = """
<html>
<nav>
<h1>Navigation</h1>
</nav>
<article class="bd-article">
<p>Article Content</p>
</article>
</html>
""".encode()
    mock_requests_get.return_value = page

    data = docs_site_loader._load_data_from_url("https://example.com/page1")

    assert len(data) == 1
    entry = data[0]
    assert entry["content"] == "Article Content"
    assert entry["meta_data"]["url"] == page_url
def test_load_data_from_url_status_not_200(mock_requests_get, docs_site_loader):
    """A 404 response yields an empty list from ``_load_data_from_url``."""
    not_found = Mock()
    not_found.status_code = 404
    mock_requests_get.return_value = not_found

    data = docs_site_loader._load_data_from_url("https://example.com/page1")

    assert data == []
    assert len(data) == 0
def test_load_data(mock_requests_get, docs_site_loader):
    """``load_data`` returns two entries and a doc id hashed from visited links plus the url."""
    url = "https://example.com"
    page = Response()
    page.status_code = 200
    page._content = """
<html>
<a href="/page1">Page 1</a>
<a href="/page2">Page 2</a>
""".encode()
    mock_requests_get.return_value = page

    data = docs_site_loader.load_data(url)

    hashed_input = " ".join(docs_site_loader.visited_links) + url
    expected_doc_id = hashlib.sha256(hashed_input.encode()).hexdigest()
    assert len(data["data"]) == 2
    assert data["doc_id"] == expected_doc_id
def test_if_response_status_not_200(mock_requests_get, docs_site_loader):
    """With a 404 response ``load_data`` returns no entries but still computes a doc id."""
    url = "https://example.com"
    not_found = Response()
    not_found.status_code = 404
    mock_requests_get.return_value = not_found

    data = docs_site_loader.load_data(url)

    hashed_input = " ".join(docs_site_loader.visited_links) + url
    expected_doc_id = hashlib.sha256(hashed_input.encode()).hexdigest()
    assert len(data["data"]) == 0
    assert data["doc_id"] == expected_doc_id

View File

@@ -0,0 +1,218 @@
import pytest
import responses
from bs4 import BeautifulSoup
@pytest.mark.parametrize(
    "ignored_tag",
    [
        "<nav>This is a navigation bar.</nav>",
        "<aside>This is an aside.</aside>",
        "<form>This is a form.</form>",
        "<header>This is a header.</header>",
        "<noscript>This is a noscript.</noscript>",
        "<svg>This is an SVG.</svg>",
        "<canvas>This is a canvas.</canvas>",
        "<footer>This is a footer.</footer>",
        "<script>This is a script.</script>",
        "<style>This is a style.</style>",
    ],
    ids=["nav", "aside", "form", "header", "noscript", "svg", "canvas", "footer", "script", "style"],
)
@pytest.mark.parametrize(
    "selectee",
    [
        """
<article class="bd-article">
<h2>Article Title</h2>
<p>Article content goes here.</p>
{ignored_tag}
</article>
""",
        """
<article role="main">
<h2>Main Article Title</h2>
<p>Main article content goes here.</p>
{ignored_tag}
</article>
""",
        """
<div class="md-content">
<h2>Markdown Content</h2>
<p>Markdown content goes here.</p>
{ignored_tag}
</div>
""",
        """
<div role="main">
<h2>Main Content</h2>
<p>Main content goes here.</p>
{ignored_tag}
</div>
""",
        """
<div class="container">
<h2>Container</h2>
<p>Container content goes here.</p>
{ignored_tag}
</div>
""",
        """
<div class="section">
<h2>Section</h2>
<p>Section content goes here.</p>
{ignored_tag}
</div>
""",
        """
<article>
<h2>Generic Article</h2>
<p>Generic article content goes here.</p>
{ignored_tag}
</article>
""",
        """
<main>
<h2>Main Content</h2>
<p>Main content goes here.</p>
{ignored_tag}
</main>
""",
    ],
    ids=[
        "article.bd-article",
        'article[role="main"]',
        "div.md-content",
        'div[role="main"]',
        "div.container",
        "div.section",
        "article",
        "main",
    ],
)
def test_load_data_gets_by_selectors_and_ignored_tags(selectee, ignored_tag, loader, mocked_responses, mocker):
    """
    For every (container, ignored tag) pair, the loaded content must equal
    the container's heading and paragraph text, without the ignored tag's text.
    """
    root_url = "https://docs.embedchain.ai/"
    child_url = "https://docs.embedchain.ai/quickstart"

    # Child page: the parametrized container with the ignored tag embedded
    selectee = selectee.format(ignored_tag=ignored_tag)
    child_page = """
<!DOCTYPE html>
<html lang="en">
<body>
{selectee}
</body>
</html>
""".format(selectee=selectee)
    mocked_responses.get(child_url, body=child_page, status=200, content_type="text/html")

    # Root page: links to the child page only
    root_page = """
<!DOCTYPE html>
<html lang="en">
<body>
<li><a href="/quickstart">Quickstart</a></li>
</body>
</html>
"""
    mocked_responses.get(root_url, body=root_page, status=200, content_type="text/html")

    # Pin the hash so the doc id is predictable
    doc_id = "mocked_hash"
    sha256_mock = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
    sha256_mock.return_value.hexdigest.return_value = doc_id

    result = loader.load_data(root_url)

    soup = BeautifulSoup(selectee, "html.parser")
    heading_text = soup.select_one("h2").get_text()
    paragraph_text = soup.select_one("p").get_text()
    expected_content = f"{heading_text} {paragraph_text}"

    assert result["doc_id"] == doc_id
    assert result["data"] == [
        {
            "content": expected_content,
            "meta_data": {"url": "https://docs.embedchain.ai/quickstart"},
        }
    ]
def test_load_data_gets_child_links_recursively(loader, mocked_responses, mocker):
    """Every entry returned for the two linked child pages must be one of the expected entries."""
    quickstart_page = """
<!DOCTYPE html>
<html lang="en">
<body>
<li><a href="/">..</a></li>
<li><a href="/quickstart">.</a></li>
</body>
</html>
"""
    introduction_page = """
<!DOCTYPE html>
<html lang="en">
<body>
<li><a href="/">..</a></li>
<li><a href="/introduction">.</a></li>
</body>
</html>
"""
    root_page = """
<!DOCTYPE html>
<html lang="en">
<body>
<li><a href="/quickstart">Quickstart</a></li>
<li><a href="/introduction">Introduction</a></li>
</body>
</html>
"""
    url = "https://docs.embedchain.ai/"
    registered_pages = (
        ("https://docs.embedchain.ai/quickstart", quickstart_page),
        ("https://docs.embedchain.ai/introduction", introduction_page),
        (url, root_page),
    )
    for page_url, page_body in registered_pages:
        mocked_responses.get(page_url, body=page_body, status=200, content_type="text/html")

    # Pin the hash so the doc id is predictable
    doc_id = "mocked_hash"
    sha256_mock = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
    sha256_mock.return_value.hexdigest.return_value = doc_id

    result = loader.load_data(url)

    assert result["doc_id"] == doc_id
    expected_data = [
        {"content": "..\n.", "meta_data": {"url": "https://docs.embedchain.ai/quickstart"}},
        {"content": "..\n.", "meta_data": {"url": "https://docs.embedchain.ai/introduction"}},
    ]
    # NOTE(review): this check passes vacuously when result["data"] is empty.
    assert all(item in expected_data for item in result["data"])
def test_load_data_fails_to_fetch_website(loader, mocked_responses, mocker):
    """When the only child page answers 404, ``load_data`` returns no entries."""
    child_url = "https://docs.embedchain.ai/introduction"
    mocked_responses.get(child_url, status=404)

    url = "https://docs.embedchain.ai/"
    html_body = """
<!DOCTYPE html>
<html lang="en">
<body>
<li><a href="/introduction">Introduction</a></li>
</body>
</html>
"""
    mocked_responses.get(url, body=html_body, status=200, content_type="text/html")

    # Pin the hash so the doc id is predictable
    mock_sha256 = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
    doc_id = "mocked_hash"
    mock_sha256.return_value.hexdigest.return_value = doc_id

    result = loader.load_data(url)

    # Compare by value: string identity (`is`) is an implementation detail
    assert result["doc_id"] == doc_id
    assert result["data"] == []
@pytest.fixture
def loader():
    """A fresh DocsSiteLoader; the class is imported when the fixture runs."""
    from embedchain.loaders.docs_site_loader import DocsSiteLoader

    instance = DocsSiteLoader()
    return instance
@pytest.fixture
def mocked_responses():
    """Yield an active ``responses.RequestsMock`` for registering fake HTTP replies."""
    with responses.RequestsMock() as requests_mock:
        yield requests_mock

View File

@@ -0,0 +1,39 @@
import hashlib
from unittest.mock import MagicMock, patch
import pytest
from embedchain.loaders.docx_file import DocxFileLoader
@pytest.fixture
def mock_docx2txt_loader():
    """Patch ``Docx2txtLoader`` as referenced by the docx loader module and yield the mock."""
    with patch("embedchain.loaders.docx_file.Docx2txtLoader") as patched_cls:
        yield patched_cls
@pytest.fixture
def docx_file_loader():
    """A fresh DocxFileLoader for each test."""
    loader = DocxFileLoader()
    return loader
def test_load_data(mock_docx2txt_loader, docx_file_loader):
    """The loader returns the mocked page content, its metadata url, and a content+url hash."""
    mock_url = "mock_docx_file.docx"
    expected_content = "Sample Docx Content"

    # The patched Docx2txtLoader class returns an instance whose load()
    # yields a single fake document.
    fake_document = MagicMock(page_content="Sample Docx Content", metadata={"url": "local"})
    fake_instance = MagicMock()
    fake_instance.load.return_value = [fake_document]
    mock_docx2txt_loader.return_value = fake_instance

    result = docx_file_loader.load_data(mock_url)

    assert "doc_id" in result
    assert "data" in result

    first_entry = result["data"][0]
    assert first_entry["content"] == expected_content
    assert first_entry["meta_data"]["url"] == "local"

    assert result["doc_id"] == hashlib.sha256((expected_content + mock_url).encode()).hexdigest()

View File

@@ -0,0 +1,85 @@
import os
from unittest.mock import MagicMock
import pytest
from dropbox.files import FileMetadata
from embedchain.loaders.dropbox import DropboxLoader
@pytest.fixture
def setup_dropbox_loader(mocker):
    """Provide a ``DropboxLoader`` built against a mocked ``dropbox.Dropbox``.

    Returns:
        A ``(loader, dbx_mock)`` tuple; ``dbx_mock`` is what the patched
        ``dropbox.Dropbox`` class returns when called.
    """
    mock_dropbox = mocker.patch("dropbox.Dropbox")
    mock_dbx = mocker.MagicMock()
    mock_dropbox.return_value = mock_dbx

    # patch.dict restores the previous environment on teardown, so a key that
    # was already set before the test is not lost, and nothing leaks if the
    # loader constructor raises.
    mocker.patch.dict(os.environ, {"DROPBOX_ACCESS_TOKEN": "test_token"})

    loader = DropboxLoader()
    return loader, mock_dbx
def test_initialization(setup_dropbox_loader):
    """The fixture must hand back a constructed DropboxLoader."""
    dropbox_loader = setup_dropbox_loader[0]
    assert dropbox_loader is not None
def test_download_folder(setup_dropbox_loader, mocker):
    """``_download_folder`` returns a non-None value when the folder lists one file entry."""
    loader, dbx = setup_dropbox_loader

    # Keep the filesystem untouched
    mocker.patch("os.makedirs")
    mocker.patch("os.path.join", return_value="mock/path")

    file_entry = mocker.MagicMock(spec=FileMetadata)
    dbx.files_list_folder.return_value.entries = [file_entry]

    entries = loader._download_folder("path/to/folder", "local_root")

    assert entries is not None
def test_generate_dir_id_from_all_paths(setup_dropbox_loader, mocker):
    """The generated directory id is a 64-character string."""
    loader, dbx = setup_dropbox_loader

    # NOTE(review): `name=` given to MagicMock names the mock itself rather
    # than setting a `.name` attribute - confirm this is what is intended.
    file_entry = mocker.MagicMock(spec=FileMetadata, name="file.txt")
    dbx.files_list_folder.return_value.entries = [file_entry]

    dir_id = loader._generate_dir_id_from_all_paths("path/to/folder")

    assert dir_id is not None
    assert len(dir_id) == 64
def test_clean_directory(setup_dropbox_loader, mocker):
    """``_clean_directory`` runs without raising when the os calls are mocked."""
    loader = setup_dropbox_loader[0]

    # Keep the filesystem untouched
    mocker.patch("os.listdir", return_value=["file1", "file2"])
    mocker.patch("os.remove")
    mocker.patch("os.rmdir")

    # NOTE(review): no assertions follow; the test only checks nothing raises.
    loader._clean_directory("path/to/folder")
def test_load_data(mocker, setup_dropbox_loader, tmp_path):
    """``load_data`` returns the mocked DirectoryLoader data and lists the folder once."""
    loader, _ = setup_dropbox_loader

    # NOTE(review): `name=` given to MagicMock names the mock itself rather
    # than setting a `.name` attribute - confirm this is what is intended.
    file_entry = MagicMock(spec=FileMetadata, name="file.txt")
    mocker.patch.object(loader.dbx, "files_list_folder", return_value=MagicMock(entries=[file_entry]))
    mocker.patch.object(loader.dbx, "files_download_to_file")

    # Mock DirectoryLoader
    mocker.patch(
        "embedchain.loaders.directory_loader.DirectoryLoader.load_data",
        return_value={"data": "test_data"},
    )

    # A real directory with one file, used as the value the id helper returns
    local_dir = tmp_path / "dropbox_test"
    local_dir.mkdir()
    (local_dir / "file.txt").write_text("dummy content")
    mocker.patch.object(loader, "_generate_dir_id_from_all_paths", return_value=str(local_dir))

    result = loader.load_data("path/to/folder")

    assert result == {"doc_id": mocker.ANY, "data": "test_data"}
    loader.dbx.files_list_folder.assert_called_once_with("path/to/folder")

View File

@@ -0,0 +1,33 @@
import hashlib
from unittest.mock import patch
import pytest
from embedchain.loaders.excel_file import ExcelFileLoader
@pytest.fixture
def excel_file_loader():
    """A fresh ExcelFileLoader for each test."""
    loader = ExcelFileLoader()
    return loader
def test_load_data(excel_file_loader):
    """Check the shape of the value returned by a patched ``load_data``."""
    mock_url = "mock_excel_file.xlsx"
    expected_content = "Sample Excel Content"
    expected_doc_id = hashlib.sha256((expected_content + mock_url).encode()).hexdigest()

    canned_result = {
        "doc_id": hashlib.sha256((expected_content + mock_url).encode()).hexdigest(),
        "data": [{"content": expected_content, "meta_data": {"url": mock_url}}],
    }

    # NOTE(review): load_data itself is patched on the instance, so the real
    # loader implementation is not exercised by this test.
    with patch.object(excel_file_loader, "load_data", return_value=canned_result):
        result = excel_file_loader.load_data(mock_url)

    first_entry = result["data"][0]
    assert first_entry["content"] == expected_content
    assert first_entry["meta_data"]["url"] == mock_url
    assert result["doc_id"] == expected_doc_id

View File

@@ -0,0 +1,33 @@
import pytest
from embedchain.loaders.github import GithubLoader
@pytest.fixture
def mock_github_loader_config():
    """Configuration dict carrying a placeholder token."""
    config = {"token": "your_mock_token"}
    return config
@pytest.fixture
def mock_github_loader(mocker, mock_github_loader_config):
    """A GithubLoader constructed while ``github.Github`` is patched."""
    github_cls = mocker.patch("github.Github")
    _ = github_cls.return_value
    loader = GithubLoader(config=mock_github_loader_config)
    return loader
def test_github_loader_init(mocker, mock_github_loader_config):
    """Constructing the loader calls ``github.Github`` once with the configured token."""
    github_cls = mocker.patch("github.Github")

    GithubLoader(config=mock_github_loader_config)

    github_cls.assert_called_once_with("your_mock_token")
def test_github_loader_init_empty_config(mocker):
    """Constructing the loader with no config raises ``ValueError``."""
    expected_message = "requires a personal access token"
    with pytest.raises(ValueError, match=expected_message):
        GithubLoader()
def test_github_loader_init_missing_token():
    """An empty config dict (no token) raises ``ValueError``."""
    expected_message = "requires a personal access token"
    with pytest.raises(ValueError, match=expected_message):
        GithubLoader(config={})

View File

@@ -0,0 +1,43 @@
import pytest
from embedchain.loaders.gmail import GmailLoader
@pytest.fixture
def mock_beautifulsoup(mocker):
    """Patch ``BeautifulSoup`` as referenced by the gmail loader module."""
    soup_result = mocker.MagicMock()
    return mocker.patch("embedchain.loaders.gmail.BeautifulSoup", return_value=soup_result)
@pytest.fixture
def gmail_loader(mock_beautifulsoup):
    """A GmailLoader constructed while BeautifulSoup is patched."""
    loader = GmailLoader()
    return loader
def test_load_data_file_not_found(gmail_loader, mocker):
    """``load_data`` raises ``FileNotFoundError`` when ``os.path.isfile`` reports False."""
    # mocker.patch stays active until test teardown; pytest-mock does not
    # support using its return value as a context manager.
    mocker.patch("os.path.isfile", return_value=False)
    with pytest.raises(FileNotFoundError):
        gmail_loader.load_data("your_query")
@pytest.mark.skip(reason="TODO: Fix this test. Failing due to some googleapiclient import issue.")
def test_load_data(gmail_loader, mocker):
    """``load_data`` returns a dict with a string ``doc_id`` and a list ``data`` (currently skipped)."""
    # NOTE(review): this reader mock is configured but never installed into
    # the loader; wire it up when the skip reason above is resolved.
    mock_gmail_reader_instance = mocker.MagicMock()
    text = "your_test_email_text"
    metadata = {
        "id": "your_test_id",
        "snippet": "your_test_snippet",
    }
    mock_gmail_reader_instance.load_data.return_value = [
        {
            "text": text,
            "extra_info": metadata,
        }
    ]

    # mocker.patch stays active until test teardown; pytest-mock does not
    # support using its return value as a context manager.
    mocker.patch("os.path.isfile", return_value=True)
    response_data = gmail_loader.load_data("your_query")

    assert "doc_id" in response_data
    assert "data" in response_data
    assert isinstance(response_data["doc_id"], str)
    assert isinstance(response_data["data"], list)

View File

@@ -0,0 +1,37 @@
import pytest
from embedchain.loaders.google_drive import GoogleDriveLoader
@pytest.fixture
def google_drive_folder_loader():
    """A fresh GoogleDriveLoader for each test."""
    loader = GoogleDriveLoader()
    return loader
def test_load_data_invalid_drive_url(google_drive_folder_loader):
    """A URL that is not a Google Drive folder URL raises ``ValueError`` with guidance."""
    mock_invalid_drive_url = "https://example.com"
    expected_message = (
        "The url provided https://example.com does not match a google drive folder url. Example "
        "drive url: https://drive.google.com/drive/u/0/folders/xxxx"
    )
    with pytest.raises(ValueError, match=expected_message):
        google_drive_folder_loader.load_data(mock_invalid_drive_url)
@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.")
def test_load_data_incorrect_drive_url(google_drive_folder_loader):
    """A well-formed but nonexistent folder URL should raise ``FileNotFoundError`` (skipped)."""
    mock_invalid_drive_url = "https://drive.google.com/drive/u/0/folders/xxxx"
    expected_message = "Unable to locate folder or files, check provided drive URL and try again"
    with pytest.raises(FileNotFoundError, match=expected_message):
        google_drive_folder_loader.load_data(mock_invalid_drive_url)
@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.")
def test_load_data(google_drive_folder_loader):
    """The result carries ``doc_id`` and ``data`` whose first entry has content and meta_data (skipped)."""
    mock_valid_url = "YOUR_VALID_URL"

    result = google_drive_folder_loader.load_data(mock_valid_url)

    for top_level_key in ("doc_id", "data"):
        assert top_level_key in result
    first_entry = result["data"][0]
    for entry_key in ("content", "meta_data"):
        assert entry_key in first_entry

View File

@@ -0,0 +1,131 @@
import hashlib
import pytest
from embedchain.loaders.json import JSONLoader
def test_load_data(mocker):
    """Check the shape of the value returned by a patched ``JSONLoader.load_data``."""
    content = "temp.json"
    texts = ["content1", "content2"]
    expected_doc_id = hashlib.sha256((content + ", ".join(texts)).encode()).hexdigest()
    expected_data = [
        {"content": "content1", "meta_data": {"url": content}},
        {"content": "content2", "meta_data": {"url": content}},
    ]

    mock_document = {
        "doc_id": hashlib.sha256((content + ", ".join(["content1", "content2"])).encode()).hexdigest(),
        "data": [
            {"content": "content1", "meta_data": {"url": content}},
            {"content": "content2", "meta_data": {"url": content}},
        ],
    }
    # NOTE(review): load_data itself is patched, so the real loader
    # implementation is not exercised by this test.
    mocker.patch("embedchain.loaders.json.JSONLoader.load_data", return_value=mock_document)

    result = JSONLoader().load_data(content)

    assert "doc_id" in result
    assert "data" in result
    assert result["data"] == expected_data
    assert result["doc_id"] == expected_doc_id
def test_load_data_url(mocker):
    """A URL source yields the reader's texts tagged with that URL and the matching doc id."""
    content = "https://example.com/posts.json"

    # Not a local file; the reader and the HTTP call are both mocked
    mocker.patch("os.path.isfile", return_value=False)
    mocker.patch(
        "embedchain.loaders.json.JSONReader.load_data",
        return_value=[{"text": "content1"}, {"text": "content2"}],
    )
    fake_response = mocker.Mock()
    fake_response.status_code = 200
    fake_response.json.return_value = {"document1": "content1", "document2": "content2"}
    mocker.patch("requests.get", return_value=fake_response)

    result = JSONLoader.load_data(content)

    assert "doc_id" in result
    assert "data" in result
    assert result["data"] == [
        {"content": "content1", "meta_data": {"url": content}},
        {"content": "content2", "meta_data": {"url": content}},
    ]
    hashed_input = content + ", ".join(["content1", "content2"])
    assert result["doc_id"] == hashlib.sha256(hashed_input.encode()).hexdigest()
def test_load_data_invalid_string_content(mocker):
    """A string that is neither a file, a usable URL, nor valid JSON raises ``ValueError``."""
    mocker.patch("os.path.isfile", return_value=False)
    mocker.patch("requests.get")
    content = "123: 345}"
    expected_message = "Invalid content to load json data from"

    with pytest.raises(ValueError, match=expected_message):
        JSONLoader.load_data(content)
def test_load_data_invalid_url(mocker):
    """A URL answering 404 raises ``ValueError`` naming the offending content."""
    content = "http://invalid-url.com/"
    mocker.patch("os.path.isfile", return_value=False)

    not_found = mocker.Mock()
    not_found.status_code = 404
    mocker.patch("requests.get", return_value=not_found)

    expected_message = f"Invalid content to load json data from: {content}"
    with pytest.raises(ValueError, match=expected_message):
        JSONLoader.load_data(content)
def test_load_data_from_json_string(mocker):
    """A raw JSON string is loaded with its sha256 hex digest used as the entry url."""
    content = '{"foo": "bar"}'
    content_url_str = hashlib.sha256((content).encode("utf-8")).hexdigest()

    mocker.patch("os.path.isfile", return_value=False)
    mocker.patch(
        "embedchain.loaders.json.JSONReader.load_data",
        return_value=[{"text": "content1"}, {"text": "content2"}],
    )

    result = JSONLoader.load_data(content)

    assert "doc_id" in result
    assert "data" in result
    assert result["data"] == [
        {"content": "content1", "meta_data": {"url": content_url_str}},
        {"content": "content2", "meta_data": {"url": content_url_str}},
    ]
    hashed_input = content_url_str + ", ".join(["content1", "content2"])
    assert result["doc_id"] == hashlib.sha256(hashed_input.encode()).hexdigest()

View File

@@ -0,0 +1,32 @@
import hashlib
import pytest
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
@pytest.fixture
def qna_pair_loader():
    """A fresh LocalQnaPairLoader for each test."""
    loader = LocalQnaPairLoader()
    return loader
def test_load_data(qna_pair_loader):
    """A (question, answer) tuple becomes one "Q: ... A: ..." entry with url "local"."""
    question = "What is the capital of France?"
    answer = "The capital of France is Paris."
    url = "local"
    expected_content = f"Q: {question}\nA: {answer}"

    result = qna_pair_loader.load_data((question, answer))

    assert "doc_id" in result
    assert "data" in result

    entry = result["data"][0]
    assert entry["content"] == expected_content
    assert entry["meta_data"]["url"] == url
    assert entry["meta_data"]["question"] == question

    assert result["doc_id"] == hashlib.sha256((expected_content + url).encode()).hexdigest()

View File

@@ -0,0 +1,27 @@
import hashlib
import pytest
from embedchain.loaders.local_text import LocalTextLoader
@pytest.fixture
def text_loader():
    """A fresh LocalTextLoader for each test."""
    loader = LocalTextLoader()
    return loader
def test_load_data(text_loader):
    """Plain text is returned unchanged with url "local" and a content+url hash as doc id."""
    mock_content = "This is a sample text content."
    url = "local"

    result = text_loader.load_data(mock_content)

    assert "doc_id" in result
    assert "data" in result

    entry = result["data"][0]
    assert entry["content"] == mock_content
    assert entry["meta_data"]["url"] == url

    assert result["doc_id"] == hashlib.sha256((mock_content + url).encode()).hexdigest()

View File

@@ -0,0 +1,30 @@
import hashlib
from unittest.mock import mock_open, patch
import pytest
from embedchain.loaders.mdx import MdxLoader
@pytest.fixture
def mdx_loader():
    """A fresh MdxLoader for each test."""
    loader = MdxLoader()
    return loader
def test_load_data(mdx_loader):
    """The file's text is returned as content, tagged with the path, with a content+url hash."""
    mock_content = "Sample MDX Content"
    url = "mock_file.mdx"

    # Simulate file reading so no real file is needed
    with patch("builtins.open", mock_open(read_data=mock_content)):
        result = mdx_loader.load_data(url)

    assert "doc_id" in result
    assert "data" in result

    entry = result["data"][0]
    assert entry["content"] == mock_content
    assert entry["meta_data"]["url"] == url

    assert result["doc_id"] == hashlib.sha256((mock_content + url).encode()).hexdigest()

View File

@@ -0,0 +1,77 @@
import hashlib
from unittest.mock import MagicMock
import pytest
from embedchain.loaders.mysql import MySQLLoader
@pytest.fixture
def mysql_loader(mocker):
    """Provide a ``MySQLLoader`` built while the MySQL connection class is patched."""
    # mocker.patch stays active until test teardown; pytest-mock does not
    # support using its return value as a context manager.
    mocker.patch("mysql.connector.connection.MySQLConnection")
    config = {
        "host": "localhost",
        "port": "3306",
        "user": "your_username",
        "password": "your_password",
        "database": "your_database",
    }
    return MySQLLoader(config=config)
def test_mysql_loader_initialization(mysql_loader):
    """A loader built from a config exposes ``config``, ``connection`` and ``cursor``."""
    for attribute in ("config", "connection", "cursor"):
        assert getattr(mysql_loader, attribute) is not None
def test_mysql_loader_invalid_config():
    """Constructing the loader with ``config=None`` raises ``ValueError``."""
    expected_message = "Invalid sql config: None"
    with pytest.raises(ValueError, match=expected_message):
        MySQLLoader(config=None)
def test_mysql_loader_setup_loader_successful(mysql_loader):
    """After setup, both the connection and the cursor are populated."""
    for attribute in ("connection", "cursor"):
        assert getattr(mysql_loader, attribute) is not None
def test_mysql_loader_setup_loader_connection_error(mysql_loader, mocker):
    """A connection attempt that raises is reported as a ``ValueError``."""
    connection_failure = IOError("Mocked connection error")
    mocker.patch("mysql.connector.connection.MySQLConnection", side_effect=connection_failure)

    expected_message = "Unable to connect with the given config:"
    with pytest.raises(ValueError, match=expected_message):
        mysql_loader._setup_loader(config={})
def test_mysql_loader_check_query_successful(mysql_loader):
    """A string query passes validation without raising."""
    mysql_loader._check_query(query="SELECT * FROM table")
def test_mysql_loader_check_query_invalid(mysql_loader):
    """A non-string query is rejected with ``ValueError``."""
    bad_query = 123
    expected_message = "Invalid mysql query: 123"
    with pytest.raises(ValueError, match=expected_message):
        mysql_loader._check_query(query=bad_query)
def test_mysql_loader_load_data_successful(mysql_loader, mocker):
    """Two fetched rows produce two entries, each tagged with the query as URL."""
    mock_cursor = MagicMock()
    mocker.patch.object(mysql_loader, "cursor", mock_cursor)
    mock_cursor.fetchall.return_value = [(1, "data1"), (2, "data2")]
    query = "SELECT * FROM table"

    result = mysql_loader.load_data(query)

    assert "doc_id" in result
    assert "data" in result
    assert len(result["data"]) == 2
    assert result["data"][0]["meta_data"]["url"] == query
    assert result["data"][1]["meta_data"]["url"] == query
    # The doc id is expected to hash the query followed by the joined contents.
    doc_id = hashlib.sha256((query + ", ".join([d["content"] for d in result["data"]])).encode()).hexdigest()
    assert result["doc_id"] == doc_id
    # ``called_with`` is not a Mock assertion; ``assert_called_with`` is the
    # method that fails the test when the cursor was not executed with the query.
    mock_cursor.execute.assert_called_with(query)
def test_mysql_loader_load_data_invalid_query(mysql_loader):
    """``load_data`` rejects a non-string query with ``ValueError``."""
    bad_query = 123
    expected_message = "Invalid mysql query: 123"
    with pytest.raises(ValueError, match=expected_message):
        mysql_loader.load_data(query=bad_query)

View File

@@ -0,0 +1,36 @@
import hashlib
import os
from unittest.mock import Mock, patch
import pytest
from embedchain.loaders.notion import NotionLoader
@pytest.fixture
def notion_loader():
    """Yield a ``NotionLoader`` while a fake integration token is in the environment."""
    fake_env = {"NOTION_INTEGRATION_TOKEN": "test_notion_token"}
    # patch.dict restores os.environ once the fixture is torn down.
    with patch.dict(os.environ, fake_env):
        loader = NotionLoader()
        yield loader
def test_load_data(notion_loader):
    """A Notion page link yields one entry whose URL is the formatted page id."""
    source = "https://www.notion.so/Test-Page-1234567890abcdef1234567890abcdef"
    mock_text = "This is a test page."

    # Stand-in for the single document returned by the patched page reader.
    mock_page = Mock(text=mock_text)

    with patch("embedchain.loaders.notion.NotionPageLoader") as mock_reader:
        mock_reader.return_value.load_data.return_value = [mock_page]
        result = notion_loader.load_data(source)

    expected_doc_id = hashlib.sha256((mock_text + source).encode()).hexdigest()
    assert result["doc_id"] == expected_doc_id

    # Expected URL: the trailing id of the link, hyphenated and prefixed with "notion-".
    formatted_id = "notion-12345678-90ab-cdef-1234-567890abcdef"
    assert result["data"] == [
        {
            "content": mock_text,
            "meta_data": {"url": formatted_id},
        }
    ]

View File

@@ -0,0 +1,26 @@
import pytest
from embedchain.loaders.openapi import OpenAPILoader
@pytest.fixture
def openapi_loader():
    """Provide a fresh ``OpenAPILoader`` instance to each test."""
    loader = OpenAPILoader()
    return loader
def test_load_data(openapi_loader, mocker):
    """Each line of the mocked YAML file becomes one entry with a 1-based row number."""
    fake_open = mocker.mock_open(read_data="key1: value1\nkey2: value2")
    mocker.patch("builtins.open", fake_open)
    # Pin the digest so the expected doc id is a fixed string.
    fake_digest = mocker.Mock(hexdigest=lambda: "mock_hash")
    mocker.patch("hashlib.sha256", return_value=fake_digest)

    file_path = "configs/openai_openapi.yaml"
    result = openapi_loader.load_data(file_path)

    expected_data = [
        {"content": line, "meta_data": {"url": file_path, "row": row}}
        for row, line in enumerate(["key1: value1", "key2: value2"], start=1)
    ]
    assert result["doc_id"] == "mock_hash"
    assert result["data"] == expected_data

View File

@@ -0,0 +1,36 @@
import pytest
from langchain.schema import Document
def test_load_data(loader, mocker):
    """Pages from the mocked ``PyPDFLoader`` become one entry each, tagged with the URL."""
    mocked_pypdfloader = mocker.patch("embedchain.loaders.pdf_file.PyPDFLoader")
    mocked_pypdfloader.return_value.load_and_split.return_value = [
        Document(page_content="Page 0 Content", metadata={"source": "example.pdf", "page": 0}),
        Document(page_content="Page 1 Content", metadata={"source": "example.pdf", "page": 1}),
    ]

    # NOTE(review): the target names ``docs_site_loader`` rather than ``pdf_file``;
    # presumably it works because both resolve to the shared ``hashlib`` module — confirm.
    mock_sha256 = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
    doc_id = "mocked_hash"
    mock_sha256.return_value.hexdigest.return_value = doc_id

    result = loader.load_data("dummy_url")

    # Compare by value: whether two equal strings are the same object is an
    # implementation detail and must not decide the test outcome.
    assert result["doc_id"] == doc_id
    assert result["data"] == [
        {"content": "Page 0 Content", "meta_data": {"source": "example.pdf", "page": 0, "url": "dummy_url"}},
        {"content": "Page 1 Content", "meta_data": {"source": "example.pdf", "page": 1, "url": "dummy_url"}},
    ]
def test_load_data_fails_to_find_data(loader, mocker):
    """An empty page list from the mocked ``PyPDFLoader`` raises ``ValueError``."""
    pdf_loader_cls = mocker.patch("embedchain.loaders.pdf_file.PyPDFLoader")
    no_pages = []
    pdf_loader_cls.return_value.load_and_split.return_value = no_pages

    with pytest.raises(ValueError):
        loader.load_data("dummy_url")
@pytest.fixture
def loader():
    """Provide a ``PdfFileLoader``; the class is imported when the fixture runs."""
    from embedchain.loaders.pdf_file import PdfFileLoader

    pdf_loader = PdfFileLoader()
    return pdf_loader

View File

@@ -0,0 +1,60 @@
from unittest.mock import MagicMock
import psycopg
import pytest
from embedchain.loaders.postgres import PostgresLoader
@pytest.fixture
def postgres_loader(mocker):
    """Yield a ``PostgresLoader`` constructed while ``psycopg.connect`` is mocked.

    ``mocker.patch.object`` activates the patch immediately and pytest-mock
    reverts it at test teardown, so it is called directly rather than wrapped
    in a ``with`` block (pytest-mock treats context-manager usage as a mistake).
    """
    mocker.patch.object(psycopg, "connect")
    config = {"url": "postgres://user:password@localhost:5432/database"}
    loader = PostgresLoader(config=config)
    yield loader
def test_postgres_loader_initialization(postgres_loader):
    """A loader built from a config has both a connection and a cursor."""
    for attribute in ("connection", "cursor"):
        assert getattr(postgres_loader, attribute) is not None
def test_postgres_loader_invalid_config():
    """Constructing the loader with ``config=None`` raises ``ValueError``."""
    expected_message = "Must provide the valid config. Received: None"
    with pytest.raises(ValueError, match=expected_message):
        PostgresLoader(config=None)
def test_load_data(postgres_loader, monkeypatch):
    """Two fetched rows produce two entries, each tagged with the query as URL."""
    mock_cursor = MagicMock()
    monkeypatch.setattr(postgres_loader, "cursor", mock_cursor)
    query = "SELECT * FROM table"
    mock_cursor.fetchall.return_value = [(1, "data1"), (2, "data2")]

    result = postgres_loader.load_data(query)

    assert "doc_id" in result
    assert "data" in result
    assert len(result["data"]) == 2
    assert result["data"][0]["meta_data"]["url"] == query
    assert result["data"][1]["meta_data"]["url"] == query
    # ``called_with`` is not a Mock assertion; ``assert_called_with`` is the
    # method that fails the test when the cursor was not executed with the query.
    mock_cursor.execute.assert_called_with(query)
def test_load_data_exception(postgres_loader, monkeypatch):
    """An exception raised by ``cursor.execute`` is re-raised as ``ValueError``."""
    mock_cursor = MagicMock()
    monkeypatch.setattr(postgres_loader, "cursor", mock_cursor)
    query = "SELECT * FROM table"
    mock_cursor.execute.side_effect = Exception("Mocked exception")

    # ``match`` is a regular expression, hence the escaped ``*`` in the query text.
    with pytest.raises(
        ValueError, match=r"Failed to load data using query=SELECT \* FROM table with: Mocked exception"
    ):
        postgres_loader.load_data(query)
def test_close_connection(postgres_loader):
    """Closing the connection resets both ``cursor`` and ``connection`` to ``None``."""
    postgres_loader.close_connection()

    for attribute in ("cursor", "connection"):
        assert getattr(postgres_loader, attribute) is None

View File

@@ -0,0 +1,47 @@
import pytest
from embedchain.loaders.slack import SlackLoader
@pytest.fixture
def slack_loader(mocker, monkeypatch):
    """Provide a ``SlackLoader`` built with its client and SSL dependencies mocked."""
    # Stub out everything the loader would otherwise reach for at construction.
    for target in ("slack_sdk.WebClient", "ssl.create_default_context", "certifi.where"):
        mocker.patch(target)

    monkeypatch.setenv("SLACK_USER_TOKEN", "slack_user_token")

    loader = SlackLoader()
    return loader
def test_slack_loader_initialization(slack_loader):
    """A default loader has a client and the default Slack API base URL."""
    expected_config = {"base_url": "https://www.slack.com/api/"}

    assert slack_loader.client is not None
    assert slack_loader.config == expected_config
def test_slack_loader_setup_loader(slack_loader):
    """Re-running setup with a custom base URL still leaves a client in place."""
    custom_config = {"base_url": "https://custom.slack.api/"}
    slack_loader._setup_loader(custom_config)

    assert slack_loader.client is not None
def test_slack_loader_check_query(slack_loader):
    """String queries pass validation; a non-string query raises ``ValueError``."""
    # Accepted silently.
    slack_loader._check_query("test_query")

    # Rejected.
    with pytest.raises(ValueError):
        slack_loader._check_query(123)
def test_slack_loader_load_data(slack_loader, mocker):
    """``load_data`` returns ``doc_id`` and ``data`` keys for an empty search result."""
    empty_search = {"messages": {}}
    mocker.patch.object(slack_loader.client, "search_messages", return_value=empty_search)

    result = slack_loader.load_data("in:random")

    for key in ("doc_id", "data"):
        assert key in result

View File

@@ -0,0 +1,117 @@
import hashlib
from unittest.mock import Mock, patch
import pytest
from embedchain.loaders.web_page import WebPageLoader
@pytest.fixture
def web_page_loader():
    """Provide a fresh ``WebPageLoader`` instance to each test."""
    loader = WebPageLoader()
    return loader
def test_load_data(web_page_loader):
    """A fetched page yields one entry holding the cleaned content and the page URL."""
    page_url = "https://example.com/page"
    html = """
<html>
<head>
<title>Test Page</title>
</head>
<body>
<div id="content">
<p>This is some test content.</p>
</div>
</body>
</html>
"""
    mock_response = Mock(status_code=200, content=html)

    # Serve the canned response instead of performing a real HTTP GET.
    with patch("embedchain.loaders.web_page.WebPageLoader._session.get", return_value=mock_response):
        result = web_page_loader.load_data(page_url)

    # Derive the expectation with the loader's own cleaning routine.
    content = web_page_loader._get_clean_content(mock_response.content, page_url)
    expected_doc_id = hashlib.sha256((content + page_url).encode()).hexdigest()
    assert result["doc_id"] == expected_doc_id
    assert result["data"] == [
        {
            "content": content,
            "meta_data": {
                "url": page_url,
            },
        }
    ]
def test_get_clean_content_excludes_unnecessary_info(web_page_loader):
    """Cleaned content contains none of the excluded tag, id or class names.

    The checks are substring checks on the cleaned text: each excluded tag
    name, element id and class name must be absent, and some content must remain.
    """
    mock_html = """
<html>
<head>
<title>Sample HTML</title>
<style>
/* Stylesheet to be excluded */
.elementor-location-header {
background-color: #f0f0f0;
}
</style>
</head>
<body>
<header id="header">Header Content</header>
<nav class="nav">Nav Content</nav>
<aside>Aside Content</aside>
<form>Form Content</form>
<main>Main Content</main>
<footer class="footer">Footer Content</footer>
<script>Some Script</script>
<noscript>NoScript Content</noscript>
<svg>SVG Content</svg>
<canvas>Canvas Content</canvas>
<div id="sidebar">Sidebar Content</div>
<div id="main-navigation">Main Navigation Content</div>
<div id="menu-main-menu">Menu Main Menu Content</div>
<div class="header-sidebar-wrapper">Header Sidebar Wrapper Content</div>
<div class="blog-sidebar-wrapper">Blog Sidebar Wrapper Content</div>
<div class="related-posts">Related Posts Content</div>
</body>
</html>
"""
    tags_to_exclude = [
        "nav",
        "aside",
        "form",
        "header",
        "noscript",
        "svg",
        "canvas",
        "footer",
        "script",
        "style",
    ]
    ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
    classes_to_exclude = [
        "elementor-location-header",
        "navbar-header",
        "nav",
        "header-sidebar-wrapper",
        "blog-sidebar-wrapper",
        "related-posts",
    ]

    content = web_page_loader._get_clean_content(mock_html, "https://example.com/page")

    for tag in tags_to_exclude:
        assert tag not in content

    # ``element_id`` rather than ``id`` so the builtin is not shadowed.
    for element_id in ids_to_exclude:
        assert element_id not in content

    for class_name in classes_to_exclude:
        assert class_name not in content

    # Something (e.g. the <main> text) must survive the cleaning.
    assert len(content) > 0

View File

@@ -0,0 +1,62 @@
import tempfile
import pytest
from embedchain.loaders.xml import XmlLoader
# Sample XML document used as the parametrized input of the loader test below.
# Taken from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/tests/integration_tests/examples/factbook.xml
# NOTE: the "Tobado" spelling is part of the fixture; the test asserts on it verbatim.
SAMPLE_XML = """<?xml version="1.0" encoding="UTF-8"?>
<factbook>
<country>
<name>United States</name>
<capital>Washington, DC</capital>
<leader>Joe Biden</leader>
<sport>Baseball</sport>
</country>
<country>
<name>Canada</name>
<capital>Ottawa</capital>
<leader>Justin Trudeau</leader>
<sport>Hockey</sport>
</country>
<country>
<name>France</name>
<capital>Paris</capital>
<leader>Emmanuel Macron</leader>
<sport>Soccer</sport>
</country>
<country>
<name>Trinidad &amp; Tobado</name>
<capital>Port of Spain</capital>
<leader>Keith Rowley</leader>
<sport>Track &amp; Field</sport>
</country>
</factbook>"""
@pytest.mark.parametrize("xml", [SAMPLE_XML])
def test_load_data(xml: str):
    """
    Test XML loader

    Writes the XML sample to a temporary file and checks that ``XmlLoader``
    returns a single entry containing each country's text, with the file
    path recorded as the entry URL.
    """
    expected_fragments = (
        "United States Washington, DC Joe Biden",
        "Canada Ottawa Justin Trudeau",
        "France Paris Emmanuel Macron",
        "Trinidad & Tobado Port of Spain Keith Rowley",
    )

    # Creating temporary XML file
    with tempfile.NamedTemporaryFile(mode="w+") as xml_file:
        xml_file.write(xml)
        # Rewind before handing the file's path to the loader.
        xml_file.seek(0)
        path = xml_file.name

        # Loading XML using XmlLoader while the temporary file still exists
        entries = XmlLoader().load_data(path)["data"]

        # Assertions
        assert len(entries) == 1
        body = entries[0]["content"]
        for fragment in expected_fragments:
            assert fragment in body
        assert entries[0]["meta_data"]["url"] == path
View File

@@ -0,0 +1,53 @@
import hashlib
from unittest.mock import MagicMock, Mock, patch
import pytest
from embedchain.loaders.youtube_video import YoutubeVideoLoader
@pytest.fixture
def youtube_video_loader():
    """Provide a fresh ``YoutubeVideoLoader`` instance to each test."""
    loader = YoutubeVideoLoader()
    return loader
def test_load_data(youtube_video_loader):
    """A loaded video yields one entry whose doc id hashes page content + URL."""
    video_url = "https://www.youtube.com/watch?v=VIDEO_ID"
    mock_page_content = "This is a YouTube video content."

    # Single document handed back by the patched YouTube loader.
    document = MagicMock(
        page_content=mock_page_content,
        metadata={"url": video_url, "title": "Test Video"},
    )
    mock_loader = Mock()
    mock_loader.load.return_value = [document]
    mock_transcript = [{"text": "sample text", "start": 0.0, "duration": 5.0}]

    with patch("embedchain.loaders.youtube_video.YoutubeLoader.from_youtube_url", return_value=mock_loader):
        with patch(
            "embedchain.loaders.youtube_video.YouTubeTranscriptApi.get_transcript", return_value=mock_transcript
        ):
            result = youtube_video_loader.load_data(video_url)

    expected_doc_id = hashlib.sha256((mock_page_content + video_url).encode()).hexdigest()
    assert result["doc_id"] == expected_doc_id

    # NOTE(review): the expected transcript is "Unavailable" even though
    # get_transcript is patched — presumably the loader falls back before
    # using the mocked value; confirm against the loader.
    assert result["data"] == [
        {
            "content": mock_page_content,
            "meta_data": {"url": video_url, "title": "Test Video", "transcript": "Unavailable"},
        }
    ]
def test_load_data_with_empty_doc(youtube_video_loader):
    """An empty document list from the patched YouTube loader raises ``ValueError``."""
    video_url = "https://www.youtube.com/watch?v=VIDEO_ID"
    empty_loader = Mock()
    empty_loader.load.return_value = []

    with patch(
        "embedchain.loaders.youtube_video.YoutubeLoader.from_youtube_url", return_value=empty_loader
    ), pytest.raises(ValueError):
        youtube_video_loader.load_data(video_url)