Rename embedchain to mem0 and open sourcing code for long term memory (#1474)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
52
embedchain/tests/embedchain/test_add.py
Normal file
52
embedchain/tests/embedchain/test_add.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AddConfig, AppConfig, ChunkerConfig
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
# Placeholder value (not a real credential), set at module import time before any test body runs.
os.environ["OPENAI_API_KEY"] = "test_key"
|
||||
|
||||
|
||||
@pytest.fixture
def app(mocker):
    """Provide an ``App`` with metrics collection disabled and Chroma's ``Collection.add`` patched."""
    mocker.patch("chromadb.api.models.Collection.Collection.add")
    app_config = AppConfig(collect_metrics=False)
    return App(config=app_config)
|
||||
|
||||
|
||||
def test_add(app):
    """Adding a URL with metadata must be recorded in ``user_asks`` with type ``web_page``."""
    url = "https://example.com"
    app.add(url, metadata={"foo": "bar"})

    # A separate dict literal is used on purpose: comparing against the very object
    # that was passed in would hide any in-place mutation.
    expected = [[url, "web_page", {"foo": "bar"}]]
    assert app.user_asks == expected
|
||||
|
||||
|
||||
# TODO: Make this test faster by generating a sitemap locally rather than using a remote one
|
||||
# def test_add_sitemap(app):
|
||||
# app.add("https://www.google.com/sitemap.xml", metadata={"foo": "bar"})
|
||||
# assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"foo": "bar"}]]
|
||||
|
||||
|
||||
def test_add_forced_type(app):
    """When ``data_type`` is passed explicitly, the recorded entry must carry that exact type."""
    url = "https://example.com"
    forced_type = "text"

    app.add(url, data_type=forced_type, metadata={"foo": "bar"})

    expected = [[url, forced_type, {"foo": "bar"}]]
    assert app.user_asks == expected
|
||||
|
||||
|
||||
def test_dry_run(app):
    """A dry-run ``add`` with a chunk size of one must report one chunk per input character."""
    one_char_chunker = ChunkerConfig(chunk_size=1, chunk_overlap=0, min_chunk_size=0)
    text = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

    result = app.add(source=text, config=AddConfig(chunker=one_char_chunker), dry_run=True)

    # Pull every field out first (same order as before) so a missing key fails before any assertion.
    chunks, metadata, count, data_type = (result[key] for key in ("chunks", "metadata", "count", "type"))

    expected_len = len(text)
    assert len(chunks) == expected_len
    assert count == expected_len
    assert data_type == DataType.TEXT
    for entry in metadata:
        assert isinstance(entry, dict)
        assert "local" in entry["url"]
        assert "text" in entry["data_type"]
|
||||
75
embedchain/tests/embedchain/test_embedchain.py
Normal file
75
embedchain/tests/embedchain/test_embedchain.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from chromadb.api.models.Collection import Collection
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, ChromaDbConfig
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.memory.base import ChatHistory
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
# Placeholder value (not a real credential), set at module import time before any test body runs.
os.environ["OPENAI_API_KEY"] = "test-api-key"
|
||||
|
||||
|
||||
@pytest.fixture
def app_instance():
    """Provide an ``App`` configured with DEBUG logging and metrics collection disabled."""
    return App(config=AppConfig(log_level="DEBUG", collect_metrics=False))
|
||||
|
||||
|
||||
def test_whole_app(app_instance, mocker):
    """Run ``add``, ``query`` and ``chat`` on one app with the database and LLM methods patched.

    ``EmbedChain.add``, ``EmbedChain._retrieve_from_database``, several ``BaseLlm``
    methods and ``ChatHistory.delete`` are replaced by mocks. The test asserts that
    ``BaseLlm.generate_prompt`` is called exactly twice across one ``query`` and one
    ``chat`` call.
    """
    knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"

    mocker.patch.object(EmbedChain, "add")
    mocker.patch.object(EmbedChain, "_retrieve_from_database")
    mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
    mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
    mocker.patch.object(BaseLlm, "generate_prompt")
    mocker.patch.object(BaseLlm, "add_history")
    mocker.patch.object(ChatHistory, "delete", autospec=True)

    try:
        app_instance.add(knowledge, data_type="text")
        app_instance.query("What text did I give you?")
        app_instance.chat("What text did I give you?")

        assert BaseLlm.generate_prompt.call_count == 2
    finally:
        # Reset even when a call or the assertion above fails, so the cleanup step is never skipped.
        app_instance.reset()
|
||||
|
||||
|
||||
def test_add_after_reset(app_instance, mocker):
    """Reset an app built on a mocked Chroma client, add to its collection, then reset again."""
    mocker.patch("embedchain.vectordb.chroma.chromadb.Client")

    app_config = AppConfig(log_level="DEBUG", collect_metrics=False)
    vector_db = ChromaDB(config=ChromaDbConfig(allow_reset=True))
    # A dedicated app is built on the mocked client; the object supplied by the
    # ``app_instance`` fixture is not used in the rest of this test.
    fresh_app = App(config=app_config, db=vector_db)

    # Stub out chat-history deletion (presumably triggered by ``reset`` - confirm).
    mocker.patch.object(ChatHistory, "delete", autospec=True)

    fresh_app.reset()
    fresh_app.db.client.heartbeat()

    mocker.patch.object(Collection, "add")

    embeddings = [[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]]
    metadatas = [
        {"chapter": "3", "verse": "16"},
        {"chapter": "3", "verse": "5"},
        {"chapter": "29", "verse": "11"},
    ]
    fresh_app.db.collection.add(embeddings=embeddings, metadatas=metadatas, ids=["id1", "id2", "id3"])

    fresh_app.reset()
|
||||
|
||||
|
||||
def test_add_with_incorrect_content(app_instance, mocker):
    """Passing a list of dicts to ``add`` with ``data_type="json"`` must raise ``TypeError``."""
    bad_content = [{"foo": "bar"}]

    with pytest.raises(TypeError):
        app_instance.add(bad_content, data_type="json")
|
||||
133
embedchain/tests/embedchain/test_utils.py
Normal file
133
embedchain/tests/embedchain/test_utils.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from embedchain.models.data_type import DataType
|
||||
from embedchain.utils.misc import detect_datatype
|
||||
|
||||
|
||||
class TestApp(unittest.TestCase):
    """Test that the datatype detection is working, based on the input."""

    def test_detect_datatype_youtube(self):
        """Several YouTube URL forms are detected as ``YOUTUBE_VIDEO``."""
        self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
        self.assertEqual(detect_datatype("https://m.youtube.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
        self.assertEqual(
            detect_datatype("https://www.youtube-nocookie.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO
        )
        self.assertEqual(detect_datatype("https://vid.plus/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
        self.assertEqual(detect_datatype("https://youtu.be/dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)

    def test_detect_datatype_local_file(self):
        """A ``file://`` URL ending in ``.txt`` is expected to be detected as ``WEB_PAGE``."""
        self.assertEqual(detect_datatype("file:///home/user/file.txt"), DataType.WEB_PAGE)

    def test_detect_datatype_pdf(self):
        """A remote ``.pdf`` URL is detected as ``PDF_FILE``."""
        self.assertEqual(detect_datatype("https://www.example.com/document.pdf"), DataType.PDF_FILE)

    def test_detect_datatype_local_pdf(self):
        """A ``file://`` URL ending in ``.pdf`` is detected as ``PDF_FILE``."""
        self.assertEqual(detect_datatype("file:///home/user/document.pdf"), DataType.PDF_FILE)

    def test_detect_datatype_xml(self):
        """A remote ``.xml`` URL is detected as ``SITEMAP``."""
        self.assertEqual(detect_datatype("https://www.example.com/sitemap.xml"), DataType.SITEMAP)

    def test_detect_datatype_local_xml(self):
        """A ``file://`` URL ending in ``.xml`` is detected as ``SITEMAP``."""
        self.assertEqual(detect_datatype("file:///home/user/sitemap.xml"), DataType.SITEMAP)

    def test_detect_datatype_docx(self):
        """A remote ``.docx`` URL is detected as ``DOCX``."""
        self.assertEqual(detect_datatype("https://www.example.com/document.docx"), DataType.DOCX)

    def test_detect_datatype_local_docx(self):
        """A ``file://`` URL ending in ``.docx`` is detected as ``DOCX``."""
        self.assertEqual(detect_datatype("file:///home/user/document.docx"), DataType.DOCX)

    def test_detect_data_type_json(self):
        """A remote ``.json`` URL is detected as ``JSON``."""
        self.assertEqual(detect_datatype("https://www.example.com/data.json"), DataType.JSON)

    def test_detect_data_type_local_json(self):
        """A ``file://`` URL ending in ``.json`` is detected as ``JSON``."""
        self.assertEqual(detect_datatype("file:///home/user/data.json"), DataType.JSON)

    @patch("os.path.isfile")
    def test_detect_datatype_regular_filesystem_docx(self, mock_isfile):
        """A plain filesystem path ending in ``.docx`` is detected as ``DOCX`` when ``isfile`` is true."""
        with tempfile.NamedTemporaryFile(suffix=".docx", delete=True) as tmp:
            mock_isfile.return_value = True
            self.assertEqual(detect_datatype(tmp.name), DataType.DOCX)

    def test_detect_datatype_docs_site(self):
        """A URL on a ``docs.`` host is detected as ``DOCS_SITE``."""
        self.assertEqual(detect_datatype("https://docs.example.com"), DataType.DOCS_SITE)

    def test_detect_datatype_docs_sitein_path(self):
        """``/docs/`` in an https path yields ``DOCS_SITE``; the same path under ``file://`` must not."""
        self.assertEqual(detect_datatype("https://www.example.com/docs/index.html"), DataType.DOCS_SITE)
        self.assertNotEqual(detect_datatype("file:///var/www/docs/index.html"), DataType.DOCS_SITE)  # NOT equal

    def test_detect_datatype_web_page(self):
        """A generic https URL is detected as ``WEB_PAGE``."""
        self.assertEqual(detect_datatype("https://nav.al/agi"), DataType.WEB_PAGE)

    def test_detect_datatype_invalid_url(self):
        """A string that is not a URL is detected as ``TEXT``."""
        self.assertEqual(detect_datatype("not a url"), DataType.TEXT)

    def test_detect_datatype_qna_pair(self):
        """A two-item tuple of strings is detected as ``QNA_PAIR``."""
        self.assertEqual(
            detect_datatype(("Question?", "Answer. Content of the string is irrelevant.")), DataType.QNA_PAIR
        )

    def test_detect_datatype_qna_pair_types(self):
        """Test that a QnA pair needs to be a tuple of length two, and both items have to be strings."""
        with self.assertRaises(TypeError):
            # The call itself must raise; an assertion wrapped around it could never complete on
            # the passing path, so none is used here.
            detect_datatype(("How many planets are in our solar system?", 8))

    def test_detect_datatype_text(self):
        """A plain sentence is detected as ``TEXT``."""
        self.assertEqual(detect_datatype("Just some text."), DataType.TEXT)

    def test_detect_datatype_non_string_error(self):
        """Test type error if the value passed is not a string, and not a valid non-string data_type"""
        with self.assertRaises(TypeError):
            detect_datatype(["foo", "bar"])

    @patch("os.path.isfile")
    def test_detect_datatype_regular_filesystem_file_txt(self, mock_isfile):
        """A plain filesystem path ending in ``.txt`` is detected as ``TEXT_FILE`` when ``isfile`` is true."""
        with tempfile.NamedTemporaryFile(suffix=".txt", delete=True) as tmp:
            mock_isfile.return_value = True
            self.assertEqual(detect_datatype(tmp.name), DataType.TEXT_FILE)

    def test_detect_datatype_regular_filesystem_no_file(self):
        """Test that if a filepath is not actually an existing file, it is not handled as a file path."""
        self.assertEqual(detect_datatype("/var/not-an-existing-file.txt"), DataType.TEXT)

    def test_doc_examples_quickstart(self):
        """Test examples used in the documentation."""
        self.assertEqual(detect_datatype("https://en.wikipedia.org/wiki/Elon_Musk"), DataType.WEB_PAGE)
        self.assertEqual(detect_datatype("https://www.tesla.com/elon-musk"), DataType.WEB_PAGE)

    def test_doc_examples_introduction(self):
        """Test examples used in the documentation."""
        self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=3qHkcs3kG44"), DataType.YOUTUBE_VIDEO)
        self.assertEqual(
            detect_datatype(
                "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf"
            ),
            DataType.PDF_FILE,
        )
        self.assertEqual(detect_datatype("https://nav.al/feedback"), DataType.WEB_PAGE)

    def test_doc_examples_app_types(self):
        """Test examples used in the documentation."""
        self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=Ff4fRgnuFgQ"), DataType.YOUTUBE_VIDEO)
        self.assertEqual(detect_datatype("https://en.wikipedia.org/wiki/Mark_Zuckerberg"), DataType.WEB_PAGE)

    def test_doc_examples_configuration(self):
        """Test examples used in the documentation.

        The content of a Wikipedia page, fetched through the ``wikipedia`` package,
        must be detected as ``TEXT``. The package is installed with pip only when it
        cannot be imported.
        """
        import importlib
        import subprocess
        import sys

        try:
            import wikipedia
        except ImportError:
            # Install on demand only, so pip is not spawned on every run when the package is present.
            subprocess.check_call([sys.executable, "-m", "pip", "install", "wikipedia"])
            # The package appeared while this process was running; refresh the import finders' caches.
            importlib.invalidate_caches()
            import wikipedia

        page = wikipedia.page("Albert Einstein")
        # TODO: Add a wikipedia type, so wikipedia is a dependency and we don't need this slow test.
        # (timings: import: 1.4s, fetch wiki: 0.7s)
        self.assertEqual(detect_datatype(page.content), DataType.TEXT)
|
||||
|
||||
|
||||
# Run the unittest runner when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user