Rename embedchain to mem0 and open-source code for long-term memory (#1474)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
embedchain/tests/loaders/test_audio.py (new file)
@@ -0,0 +1,100 @@
import hashlib
import os
import sys
from unittest.mock import mock_open, patch

import pytest

if sys.version_info > (3, 10):  # as `match` statement was introduced in python 3.10
    from deepgram import PrerecordedOptions

from embedchain.loaders.audio import AudioLoader


@pytest.fixture
def setup_audio_loader(mocker):
    mock_deepgram = mocker.patch("deepgram.DeepgramClient")
    mock_client = mocker.MagicMock()
    mock_deepgram.return_value = mock_client

    os.environ["DEEPGRAM_API_KEY"] = "test_key"
    loader = AudioLoader()
    loader.client = mock_client

    yield loader, mock_client

    if "DEEPGRAM_API_KEY" in os.environ:
        del os.environ["DEEPGRAM_API_KEY"]


@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
)  # as `match` statement was introduced in python 3.10
def test_initialization(setup_audio_loader):
    """Test initialization of AudioLoader."""
    loader, _ = setup_audio_loader
    assert loader is not None


@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
)  # as `match` statement was introduced in python 3.10
def test_load_data_from_url(setup_audio_loader):
    loader, mock_client = setup_audio_loader
    url = "https://example.com/audio.mp3"
    expected_content = "This is a test audio transcript."

    mock_response = {"results": {"channels": [{"alternatives": [{"transcript": expected_content}]}]}}
    mock_client.listen.prerecorded.v.return_value.transcribe_url.return_value = mock_response

    result = loader.load_data(url)

    doc_id = hashlib.sha256((expected_content + url).encode()).hexdigest()
    expected_result = {
        "doc_id": doc_id,
        "data": [
            {
                "content": expected_content,
                "meta_data": {"url": url},
            }
        ],
    }

    assert result == expected_result
    mock_client.listen.prerecorded.v.assert_called_once_with("1")
    mock_client.listen.prerecorded.v.return_value.transcribe_url.assert_called_once_with(
        {"url": url}, PrerecordedOptions(model="nova-2", smart_format=True)
    )


@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
)  # as `match` statement was introduced in python 3.10
def test_load_data_from_file(setup_audio_loader):
    loader, mock_client = setup_audio_loader
    file_path = "local_audio.mp3"
    expected_content = "This is a test audio transcript."

    mock_response = {"results": {"channels": [{"alternatives": [{"transcript": expected_content}]}]}}
    mock_client.listen.prerecorded.v.return_value.transcribe_file.return_value = mock_response

    # Mock the file reading functionality
    with patch("builtins.open", mock_open(read_data=b"some data")) as mock_file:
        result = loader.load_data(file_path)

    doc_id = hashlib.sha256((expected_content + file_path).encode()).hexdigest()
    expected_result = {
        "doc_id": doc_id,
        "data": [
            {
                "content": expected_content,
                "meta_data": {"url": file_path},
            }
        ],
    }

    assert result == expected_result
    mock_client.listen.prerecorded.v.assert_called_once_with("1")
    mock_client.listen.prerecorded.v.return_value.transcribe_file.assert_called_once_with(
        {"buffer": mock_file.return_value}, PrerecordedOptions(model="nova-2", smart_format=True)
    )
embedchain/tests/loaders/test_csv.py (new file)
@@ -0,0 +1,113 @@
import csv
import os
import pathlib
import tempfile
from unittest.mock import MagicMock, patch

import pytest

from embedchain.loaders.csv import CsvLoader


@pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
def test_load_data(delimiter):
    """
    Test csv loader

    Tests that file is loaded, metadata is correct and content is correct
    """
    # Creating temporary CSV file
    with tempfile.NamedTemporaryFile(mode="w+", newline="", delete=False) as tmpfile:
        writer = csv.writer(tmpfile, delimiter=delimiter)
        writer.writerow(["Name", "Age", "Occupation"])
        writer.writerow(["Alice", "28", "Engineer"])
        writer.writerow(["Bob", "35", "Doctor"])
        writer.writerow(["Charlie", "22", "Student"])

        tmpfile.seek(0)
        filename = tmpfile.name

        # Loading CSV using CsvLoader
        loader = CsvLoader()
        result = loader.load_data(filename)
        data = result["data"]

        # Assertions
        assert len(data) == 3
        assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
        assert data[0]["meta_data"]["url"] == filename
        assert data[0]["meta_data"]["row"] == 1
        assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
        assert data[1]["meta_data"]["url"] == filename
        assert data[1]["meta_data"]["row"] == 2
        assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
        assert data[2]["meta_data"]["url"] == filename
        assert data[2]["meta_data"]["row"] == 3

        # Cleaning up the temporary file
        os.unlink(filename)


@pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
def test_load_data_with_file_uri(delimiter):
    """
    Test csv loader with file URI

    Tests that file is loaded, metadata is correct and content is correct
    """
    # Creating temporary CSV file
    with tempfile.NamedTemporaryFile(mode="w+", newline="", delete=False) as tmpfile:
        writer = csv.writer(tmpfile, delimiter=delimiter)
        writer.writerow(["Name", "Age", "Occupation"])
        writer.writerow(["Alice", "28", "Engineer"])
        writer.writerow(["Bob", "35", "Doctor"])
        writer.writerow(["Charlie", "22", "Student"])

        tmpfile.seek(0)
        filename = pathlib.Path(tmpfile.name).as_uri()  # Convert path to file URI

        # Loading CSV using CsvLoader
        loader = CsvLoader()
        result = loader.load_data(filename)
        data = result["data"]

        # Assertions
        assert len(data) == 3
        assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
        assert data[0]["meta_data"]["url"] == filename
        assert data[0]["meta_data"]["row"] == 1
        assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
        assert data[1]["meta_data"]["url"] == filename
        assert data[1]["meta_data"]["row"] == 2
        assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
        assert data[2]["meta_data"]["url"] == filename
        assert data[2]["meta_data"]["row"] == 3

        # Cleaning up the temporary file
        os.unlink(tmpfile.name)


@pytest.mark.parametrize("content", ["ftp://example.com", "sftp://example.com", "mailto://example.com"])
def test_get_file_content(content):
    with pytest.raises(ValueError):
        loader = CsvLoader()
        loader._get_file_content(content)


@pytest.mark.parametrize("content", ["http://example.com", "https://example.com"])
def test_get_file_content_http(content):
    """
    Test _get_file_content method of CsvLoader for http and https URLs
    """

    with patch("requests.get") as mock_get:
        mock_response = MagicMock()
        mock_response.text = "Name,Age,Occupation\nAlice,28,Engineer\nBob,35,Doctor\nCharlie,22,Student"
        mock_get.return_value = mock_response

        loader = CsvLoader()
        file_content = loader._get_file_content(content)

        mock_get.assert_called_once_with(content)
        mock_response.raise_for_status.assert_called_once()
        assert file_content.read() == mock_response.text
embedchain/tests/loaders/test_discourse.py (new file)
@@ -0,0 +1,104 @@
import pytest
import requests

from embedchain.loaders.discourse import DiscourseLoader


@pytest.fixture
def discourse_loader_config():
    return {
        "domain": "https://example.com/",
    }


@pytest.fixture
def discourse_loader(discourse_loader_config):
    return DiscourseLoader(config=discourse_loader_config)


def test_discourse_loader_init_with_valid_config():
    config = {"domain": "https://example.com/"}
    loader = DiscourseLoader(config=config)
    assert loader.domain == "https://example.com/"


def test_discourse_loader_init_with_missing_config():
    with pytest.raises(ValueError, match="DiscourseLoader requires a config"):
        DiscourseLoader()


def test_discourse_loader_init_with_missing_domain():
    config = {"another_key": "value"}
    with pytest.raises(ValueError, match="DiscourseLoader requires a domain"):
        DiscourseLoader(config=config)


def test_discourse_loader_check_query_with_valid_query(discourse_loader):
    discourse_loader._check_query("sample query")


def test_discourse_loader_check_query_with_empty_query(discourse_loader):
    with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
        discourse_loader._check_query("")


def test_discourse_loader_check_query_with_invalid_query_type(discourse_loader):
    with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
        discourse_loader._check_query(123)


def test_discourse_loader_load_post_with_valid_post_id(discourse_loader, monkeypatch):
    def mock_get(*args, **kwargs):
        class MockResponse:
            def json(self):
                return {"raw": "Sample post content"}

            def raise_for_status(self):
                pass

        return MockResponse()

    monkeypatch.setattr(requests, "get", mock_get)

    post_data = discourse_loader._load_post(123)

    assert post_data["content"] == "Sample post content"
    assert "meta_data" in post_data


def test_discourse_loader_load_data_with_valid_query(discourse_loader, monkeypatch):
    def mock_get(*args, **kwargs):
        class MockResponse:
            def json(self):
                return {"grouped_search_result": {"post_ids": [123, 456, 789]}}

            def raise_for_status(self):
                pass

        return MockResponse()

    monkeypatch.setattr(requests, "get", mock_get)

    def mock_load_post(*args, **kwargs):
        return {
            "content": "Sample post content",
            "meta_data": {
                "url": "https://example.com/posts/123.json",
                "created_at": "2021-01-01",
                "username": "test_user",
                "topic_slug": "test_topic",
                "score": 10,
            },
        }

    monkeypatch.setattr(discourse_loader, "_load_post", mock_load_post)

    data = discourse_loader.load_data("sample query")

    assert len(data["data"]) == 3
    assert data["data"][0]["content"] == "Sample post content"
    assert data["data"][0]["meta_data"]["url"] == "https://example.com/posts/123.json"
    assert data["data"][0]["meta_data"]["created_at"] == "2021-01-01"
    assert data["data"][0]["meta_data"]["username"] == "test_user"
    assert data["data"][0]["meta_data"]["topic_slug"] == "test_topic"
    assert data["data"][0]["meta_data"]["score"] == 10
embedchain/tests/loaders/test_docs_site.py (new file)
@@ -0,0 +1,130 @@
import hashlib
from unittest.mock import Mock, patch

import pytest
from requests import Response

from embedchain.loaders.docs_site_loader import DocsSiteLoader


@pytest.fixture
def mock_requests_get():
    with patch("requests.get") as mock_get:
        yield mock_get


@pytest.fixture
def docs_site_loader():
    return DocsSiteLoader()


def test_get_child_links_recursive(mock_requests_get, docs_site_loader):
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.text = """
        <html>
            <a href="/page1">Page 1</a>
            <a href="/page2">Page 2</a>
        </html>
    """
    mock_requests_get.return_value = mock_response

    docs_site_loader._get_child_links_recursive("https://example.com")

    assert len(docs_site_loader.visited_links) == 2
    assert "https://example.com/page1" in docs_site_loader.visited_links
    assert "https://example.com/page2" in docs_site_loader.visited_links


def test_get_child_links_recursive_status_not_200(mock_requests_get, docs_site_loader):
    mock_response = Mock()
    mock_response.status_code = 404
    mock_requests_get.return_value = mock_response

    docs_site_loader._get_child_links_recursive("https://example.com")

    assert len(docs_site_loader.visited_links) == 0


def test_get_all_urls(mock_requests_get, docs_site_loader):
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.text = """
        <html>
            <a href="/page1">Page 1</a>
            <a href="/page2">Page 2</a>
            <a href="https://example.com/external">External</a>
        </html>
    """
    mock_requests_get.return_value = mock_response

    all_urls = docs_site_loader._get_all_urls("https://example.com")

    assert len(all_urls) == 3
    assert "https://example.com/page1" in all_urls
    assert "https://example.com/page2" in all_urls
    assert "https://example.com/external" in all_urls


def test_load_data_from_url(mock_requests_get, docs_site_loader):
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.content = """
        <html>
            <nav>
                <h1>Navigation</h1>
            </nav>
            <article class="bd-article">
                <p>Article Content</p>
            </article>
        </html>
    """.encode()
    mock_requests_get.return_value = mock_response

    data = docs_site_loader._load_data_from_url("https://example.com/page1")

    assert len(data) == 1
    assert data[0]["content"] == "Article Content"
    assert data[0]["meta_data"]["url"] == "https://example.com/page1"


def test_load_data_from_url_status_not_200(mock_requests_get, docs_site_loader):
    mock_response = Mock()
    mock_response.status_code = 404
    mock_requests_get.return_value = mock_response

    data = docs_site_loader._load_data_from_url("https://example.com/page1")

    assert data == []
    assert len(data) == 0


def test_load_data(mock_requests_get, docs_site_loader):
    mock_response = Response()
    mock_response.status_code = 200
    mock_response._content = """
        <html>
            <a href="/page1">Page 1</a>
            <a href="/page2">Page 2</a>
    """.encode()
    mock_requests_get.return_value = mock_response

    url = "https://example.com"
    data = docs_site_loader.load_data(url)
    expected_doc_id = hashlib.sha256((" ".join(docs_site_loader.visited_links) + url).encode()).hexdigest()

    assert len(data["data"]) == 2
    assert data["doc_id"] == expected_doc_id


def test_if_response_status_not_200(mock_requests_get, docs_site_loader):
    mock_response = Response()
    mock_response.status_code = 404
    mock_requests_get.return_value = mock_response

    url = "https://example.com"
    data = docs_site_loader.load_data(url)
    expected_doc_id = hashlib.sha256((" ".join(docs_site_loader.visited_links) + url).encode()).hexdigest()

    assert len(data["data"]) == 0
    assert data["doc_id"] == expected_doc_id
embedchain/tests/loaders/test_docs_site_loader.py (new file)
@@ -0,0 +1,218 @@
import pytest
import responses
from bs4 import BeautifulSoup


@pytest.mark.parametrize(
    "ignored_tag",
    [
        "<nav>This is a navigation bar.</nav>",
        "<aside>This is an aside.</aside>",
        "<form>This is a form.</form>",
        "<header>This is a header.</header>",
        "<noscript>This is a noscript.</noscript>",
        "<svg>This is an SVG.</svg>",
        "<canvas>This is a canvas.</canvas>",
        "<footer>This is a footer.</footer>",
        "<script>This is a script.</script>",
        "<style>This is a style.</style>",
    ],
    ids=["nav", "aside", "form", "header", "noscript", "svg", "canvas", "footer", "script", "style"],
)
@pytest.mark.parametrize(
    "selectee",
    [
        """
        <article class="bd-article">
            <h2>Article Title</h2>
            <p>Article content goes here.</p>
            {ignored_tag}
        </article>
        """,
        """
        <article role="main">
            <h2>Main Article Title</h2>
            <p>Main article content goes here.</p>
            {ignored_tag}
        </article>
        """,
        """
        <div class="md-content">
            <h2>Markdown Content</h2>
            <p>Markdown content goes here.</p>
            {ignored_tag}
        </div>
        """,
        """
        <div role="main">
            <h2>Main Content</h2>
            <p>Main content goes here.</p>
            {ignored_tag}
        </div>
        """,
        """
        <div class="container">
            <h2>Container</h2>
            <p>Container content goes here.</p>
            {ignored_tag}
        </div>
        """,
        """
        <div class="section">
            <h2>Section</h2>
            <p>Section content goes here.</p>
            {ignored_tag}
        </div>
        """,
        """
        <article>
            <h2>Generic Article</h2>
            <p>Generic article content goes here.</p>
            {ignored_tag}
        </article>
        """,
        """
        <main>
            <h2>Main Content</h2>
            <p>Main content goes here.</p>
            {ignored_tag}
        </main>
        """,
    ],
    ids=[
        "article.bd-article",
        'article[role="main"]',
        "div.md-content",
        'div[role="main"]',
        "div.container",
        "div.section",
        "article",
        "main",
    ],
)
def test_load_data_gets_by_selectors_and_ignored_tags(selectee, ignored_tag, loader, mocked_responses, mocker):
    child_url = "https://docs.embedchain.ai/quickstart"
    selectee = selectee.format(ignored_tag=ignored_tag)
    html_body = """
    <!DOCTYPE html>
    <html lang="en">
    <body>
        {selectee}
    </body>
    </html>
    """
    html_body = html_body.format(selectee=selectee)
    mocked_responses.get(child_url, body=html_body, status=200, content_type="text/html")

    url = "https://docs.embedchain.ai/"
    html_body = """
    <!DOCTYPE html>
    <html lang="en">
    <body>
        <li><a href="/quickstart">Quickstart</a></li>
    </body>
    </html>
    """
    mocked_responses.get(url, body=html_body, status=200, content_type="text/html")

    mock_sha256 = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
    doc_id = "mocked_hash"
    mock_sha256.return_value.hexdigest.return_value = doc_id

    result = loader.load_data(url)
    selector_soup = BeautifulSoup(selectee, "html.parser")
    expected_content = " ".join((selector_soup.select_one("h2").get_text(), selector_soup.select_one("p").get_text()))
    assert result["doc_id"] == doc_id
    assert result["data"] == [
        {
            "content": expected_content,
            "meta_data": {"url": "https://docs.embedchain.ai/quickstart"},
        }
    ]


def test_load_data_gets_child_links_recursively(loader, mocked_responses, mocker):
    child_url = "https://docs.embedchain.ai/quickstart"
    html_body = """
    <!DOCTYPE html>
    <html lang="en">
    <body>
        <li><a href="/">..</a></li>
        <li><a href="/quickstart">.</a></li>
    </body>
    </html>
    """
    mocked_responses.get(child_url, body=html_body, status=200, content_type="text/html")

    child_url = "https://docs.embedchain.ai/introduction"
    html_body = """
    <!DOCTYPE html>
    <html lang="en">
    <body>
        <li><a href="/">..</a></li>
        <li><a href="/introduction">.</a></li>
    </body>
    </html>
    """
    mocked_responses.get(child_url, body=html_body, status=200, content_type="text/html")

    url = "https://docs.embedchain.ai/"
    html_body = """
    <!DOCTYPE html>
    <html lang="en">
    <body>
        <li><a href="/quickstart">Quickstart</a></li>
        <li><a href="/introduction">Introduction</a></li>
    </body>
    </html>
    """
    mocked_responses.get(url, body=html_body, status=200, content_type="text/html")

    mock_sha256 = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
    doc_id = "mocked_hash"
    mock_sha256.return_value.hexdigest.return_value = doc_id

    result = loader.load_data(url)
    assert result["doc_id"] == doc_id
    expected_data = [
        {"content": "..\n.", "meta_data": {"url": "https://docs.embedchain.ai/quickstart"}},
        {"content": "..\n.", "meta_data": {"url": "https://docs.embedchain.ai/introduction"}},
    ]
    assert all(item in expected_data for item in result["data"])


def test_load_data_fails_to_fetch_website(loader, mocked_responses, mocker):
    child_url = "https://docs.embedchain.ai/introduction"
    mocked_responses.get(child_url, status=404)

    url = "https://docs.embedchain.ai/"
    html_body = """
    <!DOCTYPE html>
    <html lang="en">
    <body>
        <li><a href="/introduction">Introduction</a></li>
    </body>
    </html>
    """
    mocked_responses.get(url, body=html_body, status=200, content_type="text/html")

    mock_sha256 = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
    doc_id = "mocked_hash"
    mock_sha256.return_value.hexdigest.return_value = doc_id

    result = loader.load_data(url)
    assert result["doc_id"] is doc_id
    assert result["data"] == []


@pytest.fixture
def loader():
    from embedchain.loaders.docs_site_loader import DocsSiteLoader

    return DocsSiteLoader()


@pytest.fixture
def mocked_responses():
    with responses.RequestsMock() as rsps:
        yield rsps
embedchain/tests/loaders/test_docx_file.py (new file)
@@ -0,0 +1,39 @@
import hashlib
from unittest.mock import MagicMock, patch

import pytest

from embedchain.loaders.docx_file import DocxFileLoader


@pytest.fixture
def mock_docx2txt_loader():
    with patch("embedchain.loaders.docx_file.Docx2txtLoader") as mock_loader:
        yield mock_loader


@pytest.fixture
def docx_file_loader():
    return DocxFileLoader()


def test_load_data(mock_docx2txt_loader, docx_file_loader):
    mock_url = "mock_docx_file.docx"

    mock_loader = MagicMock()
    mock_loader.load.return_value = [MagicMock(page_content="Sample Docx Content", metadata={"url": "local"})]

    mock_docx2txt_loader.return_value = mock_loader

    result = docx_file_loader.load_data(mock_url)

    assert "doc_id" in result
    assert "data" in result

    expected_content = "Sample Docx Content"
    assert result["data"][0]["content"] == expected_content

    assert result["data"][0]["meta_data"]["url"] == "local"

    expected_doc_id = hashlib.sha256((expected_content + mock_url).encode()).hexdigest()
    assert result["doc_id"] == expected_doc_id
embedchain/tests/loaders/test_dropbox.py (new file)
@@ -0,0 +1,85 @@
import os
from unittest.mock import MagicMock

import pytest
from dropbox.files import FileMetadata

from embedchain.loaders.dropbox import DropboxLoader


@pytest.fixture
def setup_dropbox_loader(mocker):
    mock_dropbox = mocker.patch("dropbox.Dropbox")
    mock_dbx = mocker.MagicMock()
    mock_dropbox.return_value = mock_dbx

    os.environ["DROPBOX_ACCESS_TOKEN"] = "test_token"
    loader = DropboxLoader()

    yield loader, mock_dbx

    if "DROPBOX_ACCESS_TOKEN" in os.environ:
        del os.environ["DROPBOX_ACCESS_TOKEN"]


def test_initialization(setup_dropbox_loader):
    """Test initialization of DropboxLoader."""
    loader, _ = setup_dropbox_loader
    assert loader is not None


def test_download_folder(setup_dropbox_loader, mocker):
    """Test downloading a folder."""
    loader, mock_dbx = setup_dropbox_loader
    mocker.patch("os.makedirs")
    mocker.patch("os.path.join", return_value="mock/path")

    mock_file_metadata = mocker.MagicMock(spec=FileMetadata)
    mock_dbx.files_list_folder.return_value.entries = [mock_file_metadata]

    entries = loader._download_folder("path/to/folder", "local_root")
    assert entries is not None


def test_generate_dir_id_from_all_paths(setup_dropbox_loader, mocker):
    """Test directory ID generation."""
    loader, mock_dbx = setup_dropbox_loader
    mock_file_metadata = mocker.MagicMock(spec=FileMetadata, name="file.txt")
    mock_dbx.files_list_folder.return_value.entries = [mock_file_metadata]

    dir_id = loader._generate_dir_id_from_all_paths("path/to/folder")
    assert dir_id is not None
    assert len(dir_id) == 64


def test_clean_directory(setup_dropbox_loader, mocker):
    """Test cleaning up a directory."""
    loader, _ = setup_dropbox_loader
    mocker.patch("os.listdir", return_value=["file1", "file2"])
    mocker.patch("os.remove")
    mocker.patch("os.rmdir")

    loader._clean_directory("path/to/folder")


def test_load_data(mocker, setup_dropbox_loader, tmp_path):
    loader = setup_dropbox_loader[0]

    mock_file_metadata = MagicMock(spec=FileMetadata, name="file.txt")
    mocker.patch.object(loader.dbx, "files_list_folder", return_value=MagicMock(entries=[mock_file_metadata]))
    mocker.patch.object(loader.dbx, "files_download_to_file")

    # Mock DirectoryLoader
    mock_data = {"data": "test_data"}
    mocker.patch("embedchain.loaders.directory_loader.DirectoryLoader.load_data", return_value=mock_data)

    test_dir = tmp_path / "dropbox_test"
    test_dir.mkdir()
    test_file = test_dir / "file.txt"
    test_file.write_text("dummy content")
    mocker.patch.object(loader, "_generate_dir_id_from_all_paths", return_value=str(test_dir))

    result = loader.load_data("path/to/folder")

    assert result == {"doc_id": mocker.ANY, "data": "test_data"}
    loader.dbx.files_list_folder.assert_called_once_with("path/to/folder")
embedchain/tests/loaders/test_excel_file.py (new file)
@@ -0,0 +1,33 @@
import hashlib
from unittest.mock import patch

import pytest

from embedchain.loaders.excel_file import ExcelFileLoader


@pytest.fixture
def excel_file_loader():
    return ExcelFileLoader()


def test_load_data(excel_file_loader):
    mock_url = "mock_excel_file.xlsx"
    expected_content = "Sample Excel Content"

    # Mock the load_data method of the excel_file_loader instance
    with patch.object(
        excel_file_loader,
        "load_data",
        return_value={
            "doc_id": hashlib.sha256((expected_content + mock_url).encode()).hexdigest(),
            "data": [{"content": expected_content, "meta_data": {"url": mock_url}}],
        },
    ):
        result = excel_file_loader.load_data(mock_url)

    assert result["data"][0]["content"] == expected_content
    assert result["data"][0]["meta_data"]["url"] == mock_url

    expected_doc_id = hashlib.sha256((expected_content + mock_url).encode()).hexdigest()
    assert result["doc_id"] == expected_doc_id
embedchain/tests/loaders/test_github.py (new file)
@@ -0,0 +1,33 @@
import pytest

from embedchain.loaders.github import GithubLoader


@pytest.fixture
def mock_github_loader_config():
    return {
        "token": "your_mock_token",
    }


@pytest.fixture
def mock_github_loader(mocker, mock_github_loader_config):
    mock_github = mocker.patch("github.Github")
    _ = mock_github.return_value
    return GithubLoader(config=mock_github_loader_config)


def test_github_loader_init(mocker, mock_github_loader_config):
    mock_github = mocker.patch("github.Github")
    GithubLoader(config=mock_github_loader_config)
    mock_github.assert_called_once_with("your_mock_token")


def test_github_loader_init_empty_config(mocker):
    with pytest.raises(ValueError, match="requires a personal access token"):
        GithubLoader()


def test_github_loader_init_missing_token():
    with pytest.raises(ValueError, match="requires a personal access token"):
        GithubLoader(config={})
embedchain/tests/loaders/test_gmail.py (new file)
@@ -0,0 +1,43 @@
import pytest

from embedchain.loaders.gmail import GmailLoader


@pytest.fixture
def mock_beautifulsoup(mocker):
    return mocker.patch("embedchain.loaders.gmail.BeautifulSoup", return_value=mocker.MagicMock())


@pytest.fixture
def gmail_loader(mock_beautifulsoup):
    return GmailLoader()


def test_load_data_file_not_found(gmail_loader, mocker):
    with pytest.raises(FileNotFoundError):
        with mocker.patch("os.path.isfile", return_value=False):
            gmail_loader.load_data("your_query")


@pytest.mark.skip(reason="TODO: Fix this test. Failing due to some googleapiclient import issue.")
def test_load_data(gmail_loader, mocker):
    mock_gmail_reader_instance = mocker.MagicMock()
    text = "your_test_email_text"
    metadata = {
        "id": "your_test_id",
        "snippet": "your_test_snippet",
    }
    mock_gmail_reader_instance.load_data.return_value = [
        {
            "text": text,
            "extra_info": metadata,
        }
    ]

    with mocker.patch("os.path.isfile", return_value=True):
        response_data = gmail_loader.load_data("your_query")

        assert "doc_id" in response_data
        assert "data" in response_data
        assert isinstance(response_data["doc_id"], str)
        assert isinstance(response_data["data"], list)
embedchain/tests/loaders/test_google_drive.py (new file)
@@ -0,0 +1,37 @@
import pytest

from embedchain.loaders.google_drive import GoogleDriveLoader


@pytest.fixture
def google_drive_folder_loader():
    return GoogleDriveLoader()


def test_load_data_invalid_drive_url(google_drive_folder_loader):
    mock_invalid_drive_url = "https://example.com"
    with pytest.raises(
        ValueError,
        match="The url provided https://example.com does not match a google drive folder url. Example "
        "drive url: https://drive.google.com/drive/u/0/folders/xxxx",
    ):
        google_drive_folder_loader.load_data(mock_invalid_drive_url)


@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.")
def test_load_data_incorrect_drive_url(google_drive_folder_loader):
    mock_invalid_drive_url = "https://drive.google.com/drive/u/0/folders/xxxx"
    with pytest.raises(
        FileNotFoundError, match="Unable to locate folder or files, check provided drive URL and try again"
    ):
        google_drive_folder_loader.load_data(mock_invalid_drive_url)


@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.")
def test_load_data(google_drive_folder_loader):
    mock_valid_url = "YOUR_VALID_URL"
    result = google_drive_folder_loader.load_data(mock_valid_url)
    assert "doc_id" in result
    assert "data" in result
    assert "content" in result["data"][0]
    assert "meta_data" in result["data"][0]
embedchain/tests/loaders/test_json.py (new file)
@@ -0,0 +1,131 @@
import hashlib

import pytest

from embedchain.loaders.json import JSONLoader


def test_load_data(mocker):
    content = "temp.json"

    mock_document = {
        "doc_id": hashlib.sha256((content + ", ".join(["content1", "content2"])).encode()).hexdigest(),
        "data": [
            {"content": "content1", "meta_data": {"url": content}},
            {"content": "content2", "meta_data": {"url": content}},
        ],
    }

    mocker.patch("embedchain.loaders.json.JSONLoader.load_data", return_value=mock_document)

    json_loader = JSONLoader()

    result = json_loader.load_data(content)

    assert "doc_id" in result
    assert "data" in result

    expected_data = [
        {"content": "content1", "meta_data": {"url": content}},
        {"content": "content2", "meta_data": {"url": content}},
    ]

    assert result["data"] == expected_data

    expected_doc_id = hashlib.sha256((content + ", ".join(["content1", "content2"])).encode()).hexdigest()
    assert result["doc_id"] == expected_doc_id


def test_load_data_url(mocker):
    content = "https://example.com/posts.json"

    mocker.patch("os.path.isfile", return_value=False)
    mocker.patch(
        "embedchain.loaders.json.JSONReader.load_data",
        return_value=[
            {
                "text": "content1",
            },
            {
                "text": "content2",
            },
        ],
    )

    mock_response = mocker.Mock()
    mock_response.status_code = 200
    mock_response.json.return_value = {"document1": "content1", "document2": "content2"}

    mocker.patch("requests.get", return_value=mock_response)

    result = JSONLoader.load_data(content)

    assert "doc_id" in result
    assert "data" in result

    expected_data = [
        {"content": "content1", "meta_data": {"url": content}},
        {"content": "content2", "meta_data": {"url": content}},
    ]

    assert result["data"] == expected_data

    expected_doc_id = hashlib.sha256((content + ", ".join(["content1", "content2"])).encode()).hexdigest()
    assert result["doc_id"] == expected_doc_id


def test_load_data_invalid_string_content(mocker):
    mocker.patch("os.path.isfile", return_value=False)
    mocker.patch("requests.get")

    content = "123: 345}"

    with pytest.raises(ValueError, match="Invalid content to load json data from"):
        JSONLoader.load_data(content)


def test_load_data_invalid_url(mocker):
    mocker.patch("os.path.isfile", return_value=False)

    mock_response = mocker.Mock()
    mock_response.status_code = 404
    mocker.patch("requests.get", return_value=mock_response)

    content = "http://invalid-url.com/"

    with pytest.raises(ValueError, match=f"Invalid content to load json data from: {content}"):
        JSONLoader.load_data(content)


def test_load_data_from_json_string(mocker):
    content = '{"foo": "bar"}'

    content_url_str = hashlib.sha256((content).encode("utf-8")).hexdigest()

    mocker.patch("os.path.isfile", return_value=False)
    mocker.patch(
        "embedchain.loaders.json.JSONReader.load_data",
        return_value=[
            {
                "text": "content1",
            },
            {
                "text": "content2",
            },
        ],
    )

    result = JSONLoader.load_data(content)

    assert "doc_id" in result
    assert "data" in result

    expected_data = [
        {"content": "content1", "meta_data": {"url": content_url_str}},
        {"content": "content2", "meta_data": {"url": content_url_str}},
    ]

    assert result["data"] == expected_data

    expected_doc_id = hashlib.sha256((content_url_str + ", ".join(["content1", "content2"])).encode()).hexdigest()
    assert result["doc_id"] == expected_doc_id
embedchain/tests/loaders/test_local_qna_pair.py (new file)
@@ -0,0 +1,32 @@
import hashlib

import pytest

from embedchain.loaders.local_qna_pair import LocalQnaPairLoader


@pytest.fixture
def qna_pair_loader():
    return LocalQnaPairLoader()


def test_load_data(qna_pair_loader):
    question = "What is the capital of France?"
    answer = "The capital of France is Paris."

    content = (question, answer)
    result = qna_pair_loader.load_data(content)

    assert "doc_id" in result
    assert "data" in result
    url = "local"

    expected_content = f"Q: {question}\nA: {answer}"
    assert result["data"][0]["content"] == expected_content

    assert result["data"][0]["meta_data"]["url"] == url

    assert result["data"][0]["meta_data"]["question"] == question

    expected_doc_id = hashlib.sha256((expected_content + url).encode()).hexdigest()
    assert result["doc_id"] == expected_doc_id
embedchain/tests/loaders/test_local_text.py (new file)
@@ -0,0 +1,27 @@
import hashlib

import pytest

from embedchain.loaders.local_text import LocalTextLoader


@pytest.fixture
def text_loader():
    return LocalTextLoader()


def test_load_data(text_loader):
    mock_content = "This is a sample text content."

    result = text_loader.load_data(mock_content)

    assert "doc_id" in result
    assert "data" in result

    url = "local"
    assert result["data"][0]["content"] == mock_content

    assert result["data"][0]["meta_data"]["url"] == url

    expected_doc_id = hashlib.sha256((mock_content + url).encode()).hexdigest()
    assert result["doc_id"] == expected_doc_id
embedchain/tests/loaders/test_mdx.py (new file)
@@ -0,0 +1,30 @@
import hashlib
from unittest.mock import mock_open, patch

import pytest

from embedchain.loaders.mdx import MdxLoader


@pytest.fixture
def mdx_loader():
    return MdxLoader()


def test_load_data(mdx_loader):
    mock_content = "Sample MDX Content"

    # Mock open function to simulate file reading
    with patch("builtins.open", mock_open(read_data=mock_content)):
        url = "mock_file.mdx"
        result = mdx_loader.load_data(url)

        assert "doc_id" in result
        assert "data" in result

        assert result["data"][0]["content"] == mock_content

        assert result["data"][0]["meta_data"]["url"] == url

        expected_doc_id = hashlib.sha256((mock_content + url).encode()).hexdigest()
        assert result["doc_id"] == expected_doc_id
embedchain/tests/loaders/test_mysql.py (new file)
@@ -0,0 +1,77 @@
import hashlib
from unittest.mock import MagicMock

import pytest

from embedchain.loaders.mysql import MySQLLoader


@pytest.fixture
def mysql_loader(mocker):
    with mocker.patch("mysql.connector.connection.MySQLConnection"):
        config = {
            "host": "localhost",
            "port": "3306",
            "user": "your_username",
            "password": "your_password",
            "database": "your_database",
        }
        loader = MySQLLoader(config=config)
        yield loader


def test_mysql_loader_initialization(mysql_loader):
    assert mysql_loader.config is not None
    assert mysql_loader.connection is not None
    assert mysql_loader.cursor is not None


def test_mysql_loader_invalid_config():
    with pytest.raises(ValueError, match="Invalid sql config: None"):
        MySQLLoader(config=None)


def test_mysql_loader_setup_loader_successful(mysql_loader):
    assert mysql_loader.connection is not None
    assert mysql_loader.cursor is not None


def test_mysql_loader_setup_loader_connection_error(mysql_loader, mocker):
    mocker.patch("mysql.connector.connection.MySQLConnection", side_effect=IOError("Mocked connection error"))
    with pytest.raises(ValueError, match="Unable to connect with the given config:"):
        mysql_loader._setup_loader(config={})


def test_mysql_loader_check_query_successful(mysql_loader):
    query = "SELECT * FROM table"
    mysql_loader._check_query(query=query)


def test_mysql_loader_check_query_invalid(mysql_loader):
    with pytest.raises(ValueError, match="Invalid mysql query: 123"):
        mysql_loader._check_query(query=123)


def test_mysql_loader_load_data_successful(mysql_loader, mocker):
    mock_cursor = MagicMock()
    mocker.patch.object(mysql_loader, "cursor", mock_cursor)
    mock_cursor.fetchall.return_value = [(1, "data1"), (2, "data2")]

    query = "SELECT * FROM table"
    result = mysql_loader.load_data(query)

    assert "doc_id" in result
    assert "data" in result
    assert len(result["data"]) == 2
    assert result["data"][0]["meta_data"]["url"] == query
    assert result["data"][1]["meta_data"]["url"] == query

    doc_id = hashlib.sha256((query + ", ".join([d["content"] for d in result["data"]])).encode()).hexdigest()

    assert result["doc_id"] == doc_id
    mock_cursor.execute.assert_called_with(query)


def test_mysql_loader_load_data_invalid_query(mysql_loader):
    with pytest.raises(ValueError, match="Invalid mysql query: 123"):
        mysql_loader.load_data(query=123)
embedchain/tests/loaders/test_notion.py (new file)
@@ -0,0 +1,36 @@
import hashlib
import os
from unittest.mock import Mock, patch

import pytest

from embedchain.loaders.notion import NotionLoader


@pytest.fixture
def notion_loader():
    with patch.dict(os.environ, {"NOTION_INTEGRATION_TOKEN": "test_notion_token"}):
        yield NotionLoader()


def test_load_data(notion_loader):
    source = "https://www.notion.so/Test-Page-1234567890abcdef1234567890abcdef"
    mock_text = "This is a test page."
    expected_doc_id = hashlib.sha256((mock_text + source).encode()).hexdigest()
    expected_data = [
        {
            "content": mock_text,
            "meta_data": {"url": "notion-12345678-90ab-cdef-1234-567890abcdef"},  # formatted_id
        }
    ]

    mock_page = Mock()
    mock_page.text = mock_text
    mock_documents = [mock_page]

    with patch("embedchain.loaders.notion.NotionPageLoader") as mock_reader:
        mock_reader.return_value.load_data.return_value = mock_documents
        result = notion_loader.load_data(source)

    assert result["doc_id"] == expected_doc_id
    assert result["data"] == expected_data
embedchain/tests/loaders/test_openapi.py (new file)
@@ -0,0 +1,26 @@
import pytest

from embedchain.loaders.openapi import OpenAPILoader


@pytest.fixture
def openapi_loader():
    return OpenAPILoader()


def test_load_data(openapi_loader, mocker):
    mocker.patch("builtins.open", mocker.mock_open(read_data="key1: value1\nkey2: value2"))

    mocker.patch("hashlib.sha256", return_value=mocker.Mock(hexdigest=lambda: "mock_hash"))

    file_path = "configs/openai_openapi.yaml"
    result = openapi_loader.load_data(file_path)

    expected_doc_id = "mock_hash"
    expected_data = [
        {"content": "key1: value1", "meta_data": {"url": file_path, "row": 1}},
        {"content": "key2: value2", "meta_data": {"url": file_path, "row": 2}},
    ]

    assert result["doc_id"] == expected_doc_id
    assert result["data"] == expected_data
embedchain/tests/loaders/test_pdf_file.py (new file)
@@ -0,0 +1,36 @@
import pytest
from langchain.schema import Document


def test_load_data(loader, mocker):
    mocked_pypdfloader = mocker.patch("embedchain.loaders.pdf_file.PyPDFLoader")
    mocked_pypdfloader.return_value.load_and_split.return_value = [
        Document(page_content="Page 0 Content", metadata={"source": "example.pdf", "page": 0}),
        Document(page_content="Page 1 Content", metadata={"source": "example.pdf", "page": 1}),
    ]

    mock_sha256 = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
    doc_id = "mocked_hash"
    mock_sha256.return_value.hexdigest.return_value = doc_id

    result = loader.load_data("dummy_url")
    assert result["doc_id"] is doc_id
    assert result["data"] == [
        {"content": "Page 0 Content", "meta_data": {"source": "example.pdf", "page": 0, "url": "dummy_url"}},
        {"content": "Page 1 Content", "meta_data": {"source": "example.pdf", "page": 1, "url": "dummy_url"}},
    ]


def test_load_data_fails_to_find_data(loader, mocker):
    mocked_pypdfloader = mocker.patch("embedchain.loaders.pdf_file.PyPDFLoader")
    mocked_pypdfloader.return_value.load_and_split.return_value = []

    with pytest.raises(ValueError):
        loader.load_data("dummy_url")


@pytest.fixture
def loader():
    from embedchain.loaders.pdf_file import PdfFileLoader

    return PdfFileLoader()
embedchain/tests/loaders/test_postgres.py (new file)
@@ -0,0 +1,60 @@
from unittest.mock import MagicMock

import psycopg
import pytest

from embedchain.loaders.postgres import PostgresLoader


@pytest.fixture
def postgres_loader(mocker):
    with mocker.patch.object(psycopg, "connect"):
        config = {"url": "postgres://user:password@localhost:5432/database"}
        loader = PostgresLoader(config=config)
        yield loader


def test_postgres_loader_initialization(postgres_loader):
    assert postgres_loader.connection is not None
    assert postgres_loader.cursor is not None


def test_postgres_loader_invalid_config():
    with pytest.raises(ValueError, match="Must provide the valid config. Received: None"):
        PostgresLoader(config=None)


def test_load_data(postgres_loader, monkeypatch):
    mock_cursor = MagicMock()
    monkeypatch.setattr(postgres_loader, "cursor", mock_cursor)

    query = "SELECT * FROM table"
    mock_cursor.fetchall.return_value = [(1, "data1"), (2, "data2")]

    result = postgres_loader.load_data(query)

    assert "doc_id" in result
    assert "data" in result
    assert len(result["data"]) == 2
    assert result["data"][0]["meta_data"]["url"] == query
    assert result["data"][1]["meta_data"]["url"] == query
    mock_cursor.execute.assert_called_with(query)


def test_load_data_exception(postgres_loader, monkeypatch):
    mock_cursor = MagicMock()
    monkeypatch.setattr(postgres_loader, "cursor", mock_cursor)

    _ = "SELECT * FROM table"
    mock_cursor.execute.side_effect = Exception("Mocked exception")

    with pytest.raises(
        ValueError, match=r"Failed to load data using query=SELECT \* FROM table with: Mocked exception"
    ):
        postgres_loader.load_data("SELECT * FROM table")


def test_close_connection(postgres_loader):
    postgres_loader.close_connection()
    assert postgres_loader.cursor is None
    assert postgres_loader.connection is None
embedchain/tests/loaders/test_slack.py (new file)
@@ -0,0 +1,47 @@
import pytest

from embedchain.loaders.slack import SlackLoader


@pytest.fixture
def slack_loader(mocker, monkeypatch):
    # Mocking necessary dependencies
    mocker.patch("slack_sdk.WebClient")
    mocker.patch("ssl.create_default_context")
    mocker.patch("certifi.where")

    monkeypatch.setenv("SLACK_USER_TOKEN", "slack_user_token")

    return SlackLoader()


def test_slack_loader_initialization(slack_loader):
    assert slack_loader.client is not None
    assert slack_loader.config == {"base_url": "https://www.slack.com/api/"}


def test_slack_loader_setup_loader(slack_loader):
    slack_loader._setup_loader({"base_url": "https://custom.slack.api/"})

    assert slack_loader.client is not None


def test_slack_loader_check_query(slack_loader):
    valid_json_query = "test_query"
    invalid_query = 123

    slack_loader._check_query(valid_json_query)

    with pytest.raises(ValueError):
        slack_loader._check_query(invalid_query)


def test_slack_loader_load_data(slack_loader, mocker):
    valid_json_query = "in:random"

    mocker.patch.object(slack_loader.client, "search_messages", return_value={"messages": {}})

    result = slack_loader.load_data(valid_json_query)

    assert "doc_id" in result
    assert "data" in result
embedchain/tests/loaders/test_web_page.py (new file)
@@ -0,0 +1,117 @@
import hashlib
from unittest.mock import Mock, patch

import pytest

from embedchain.loaders.web_page import WebPageLoader


@pytest.fixture
def web_page_loader():
    return WebPageLoader()


def test_load_data(web_page_loader):
    page_url = "https://example.com/page"
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.content = """
        <html>
            <head>
                <title>Test Page</title>
            </head>
            <body>
                <div id="content">
                    <p>This is some test content.</p>
                </div>
            </body>
        </html>
    """
    with patch("embedchain.loaders.web_page.WebPageLoader._session.get", return_value=mock_response):
        result = web_page_loader.load_data(page_url)

    content = web_page_loader._get_clean_content(mock_response.content, page_url)
    expected_doc_id = hashlib.sha256((content + page_url).encode()).hexdigest()
    assert result["doc_id"] == expected_doc_id

    expected_data = [
        {
            "content": content,
            "meta_data": {
                "url": page_url,
            },
        }
    ]

    assert result["data"] == expected_data


def test_get_clean_content_excludes_unnecessary_info(web_page_loader):
    mock_html = """
        <html>
            <head>
                <title>Sample HTML</title>
                <style>
                    /* Stylesheet to be excluded */
                    .elementor-location-header {
                        background-color: #f0f0f0;
                    }
                </style>
            </head>
            <body>
                <header id="header">Header Content</header>
                <nav class="nav">Nav Content</nav>
                <aside>Aside Content</aside>
                <form>Form Content</form>
                <main>Main Content</main>
                <footer class="footer">Footer Content</footer>
                <script>Some Script</script>
                <noscript>NoScript Content</noscript>
                <svg>SVG Content</svg>
                <canvas>Canvas Content</canvas>

                <div id="sidebar">Sidebar Content</div>
                <div id="main-navigation">Main Navigation Content</div>
                <div id="menu-main-menu">Menu Main Menu Content</div>

                <div class="header-sidebar-wrapper">Header Sidebar Wrapper Content</div>
                <div class="blog-sidebar-wrapper">Blog Sidebar Wrapper Content</div>
                <div class="related-posts">Related Posts Content</div>
            </body>
        </html>
    """

    tags_to_exclude = [
        "nav",
        "aside",
        "form",
        "header",
        "noscript",
        "svg",
        "canvas",
        "footer",
        "script",
        "style",
    ]
    ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
    classes_to_exclude = [
        "elementor-location-header",
        "navbar-header",
        "nav",
        "header-sidebar-wrapper",
        "blog-sidebar-wrapper",
        "related-posts",
    ]

    content = web_page_loader._get_clean_content(mock_html, "https://example.com/page")

    for tag in tags_to_exclude:
        assert tag not in content

    for id in ids_to_exclude:
        assert id not in content

    for class_name in classes_to_exclude:
        assert class_name not in content

    assert len(content) > 0
embedchain/tests/loaders/test_xml.py (new file)
@@ -0,0 +1,62 @@
import tempfile

import pytest

from embedchain.loaders.xml import XmlLoader

# Taken from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/tests/integration_tests/examples/factbook.xml
SAMPLE_XML = """<?xml version="1.0" encoding="UTF-8"?>
<factbook>
  <country>
    <name>United States</name>
    <capital>Washington, DC</capital>
    <leader>Joe Biden</leader>
    <sport>Baseball</sport>
  </country>
  <country>
    <name>Canada</name>
    <capital>Ottawa</capital>
    <leader>Justin Trudeau</leader>
    <sport>Hockey</sport>
  </country>
  <country>
    <name>France</name>
    <capital>Paris</capital>
    <leader>Emmanuel Macron</leader>
    <sport>Soccer</sport>
  </country>
  <country>
    <name>Trinidad & Tobado</name>
    <capital>Port of Spain</capital>
    <leader>Keith Rowley</leader>
    <sport>Track & Field</sport>
  </country>
</factbook>"""


@pytest.mark.parametrize("xml", [SAMPLE_XML])
def test_load_data(xml: str):
    """
    Test XML loader

    Tests that XML file is loaded, metadata is correct and content is correct
    """
    # Creating temporary XML file
    with tempfile.NamedTemporaryFile(mode="w+") as tmpfile:
        tmpfile.write(xml)

        tmpfile.seek(0)
        filename = tmpfile.name

        # Loading XML using XmlLoader
        loader = XmlLoader()
        result = loader.load_data(filename)
        data = result["data"]

        # Assertions
        assert len(data) == 1
        assert "United States Washington, DC Joe Biden" in data[0]["content"]
        assert "Canada Ottawa Justin Trudeau" in data[0]["content"]
        assert "France Paris Emmanuel Macron" in data[0]["content"]
        assert "Trinidad & Tobado Port of Spain Keith Rowley" in data[0]["content"]
        assert data[0]["meta_data"]["url"] == filename
embedchain/tests/loaders/test_youtube_video.py (new file)
@@ -0,0 +1,53 @@
import hashlib
from unittest.mock import MagicMock, Mock, patch

import pytest

from embedchain.loaders.youtube_video import YoutubeVideoLoader


@pytest.fixture
def youtube_video_loader():
    return YoutubeVideoLoader()


def test_load_data(youtube_video_loader):
    video_url = "https://www.youtube.com/watch?v=VIDEO_ID"
    mock_loader = Mock()
    mock_page_content = "This is a YouTube video content."
    mock_loader.load.return_value = [
        MagicMock(
            page_content=mock_page_content,
            metadata={"url": video_url, "title": "Test Video"},
        )
    ]

    mock_transcript = [{"text": "sample text", "start": 0.0, "duration": 5.0}]

    with patch("embedchain.loaders.youtube_video.YoutubeLoader.from_youtube_url", return_value=mock_loader), patch(
        "embedchain.loaders.youtube_video.YouTubeTranscriptApi.get_transcript", return_value=mock_transcript
    ):
        result = youtube_video_loader.load_data(video_url)

    expected_doc_id = hashlib.sha256((mock_page_content + video_url).encode()).hexdigest()

    assert result["doc_id"] == expected_doc_id

    expected_data = [
        {
            "content": "This is a YouTube video content.",
            "meta_data": {"url": video_url, "title": "Test Video", "transcript": "Unavailable"},
        }
    ]

    assert result["data"] == expected_data


def test_load_data_with_empty_doc(youtube_video_loader):
    video_url = "https://www.youtube.com/watch?v=VIDEO_ID"
    mock_loader = Mock()
    mock_loader.load.return_value = []

    with patch("embedchain.loaders.youtube_video.YoutubeLoader.from_youtube_url", return_value=mock_loader):
        with pytest.raises(ValueError):
            youtube_video_loader.load_data(video_url)