Rename embedchain to mem0 and open sourcing code for long term memory (#1474)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>

embedchain/tests/__init__.py (new file, 0 lines)

embedchain/tests/chunkers/test_base_chunker.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import hashlib
from unittest.mock import MagicMock

import pytest

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.models.data_type import DataType


@pytest.fixture
def text_splitter_mock():
    return MagicMock()


@pytest.fixture
def loader_mock():
    return MagicMock()


@pytest.fixture
def app_id():
    return "test_app"


@pytest.fixture
def data_type():
    return DataType.TEXT


@pytest.fixture
def chunker(text_splitter_mock, data_type):
    text_splitter = text_splitter_mock
    chunker = BaseChunker(text_splitter)
    chunker.set_data_type(data_type)
    return chunker


def test_create_chunks_with_config(chunker, text_splitter_mock, loader_mock, app_id, data_type):
    text_splitter_mock.split_text.return_value = ["Chunk 1", "long chunk"]
    loader_mock.load_data.return_value = {
        "data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}],
        "doc_id": "DocID",
    }
    config = ChunkerConfig(chunk_size=50, chunk_overlap=0, length_function=len, min_chunk_size=10)
    result = chunker.create_chunks(loader_mock, "test_src", app_id, config)

    assert result["documents"] == ["long chunk"]


def test_create_chunks(chunker, text_splitter_mock, loader_mock, app_id, data_type):
    text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
    loader_mock.load_data.return_value = {
        "data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}],
        "doc_id": "DocID",
    }

    result = chunker.create_chunks(loader_mock, "test_src", app_id)
    expected_ids = [
        f"{app_id}--" + hashlib.sha256(("Chunk 1" + "URL 1").encode()).hexdigest(),
        f"{app_id}--" + hashlib.sha256(("Chunk 2" + "URL 1").encode()).hexdigest(),
    ]

    assert result["documents"] == ["Chunk 1", "Chunk 2"]
    assert result["ids"] == expected_ids
    assert result["metadatas"] == [
        {
            "url": "URL 1",
            "data_type": data_type.value,
            "doc_id": f"{app_id}--DocID",
        },
        {
            "url": "URL 1",
            "data_type": data_type.value,
            "doc_id": f"{app_id}--DocID",
        },
    ]
    assert result["doc_id"] == f"{app_id}--DocID"


def test_get_chunks(chunker, text_splitter_mock):
    text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]

    content = "This is a test content."
    result = chunker.get_chunks(content)

    assert len(result) == 2
    assert result == ["Chunk 1", "Chunk 2"]


def test_set_data_type(chunker):
    chunker.set_data_type(DataType.MDX)
    assert chunker.data_type == DataType.MDX


def test_get_word_count(chunker):
    documents = ["This is a test.", "Another test."]
    result = chunker.get_word_count(documents)
    assert result == 6

embedchain/tests/chunkers/test_chunkers.py (new file, 66 lines)
@@ -0,0 +1,66 @@
from embedchain.chunkers.audio import AudioChunker
from embedchain.chunkers.common_chunker import CommonChunker
from embedchain.chunkers.discourse import DiscourseChunker
from embedchain.chunkers.docs_site import DocsSiteChunker
from embedchain.chunkers.docx_file import DocxFileChunker
from embedchain.chunkers.excel_file import ExcelFileChunker
from embedchain.chunkers.gmail import GmailChunker
from embedchain.chunkers.google_drive import GoogleDriveChunker
from embedchain.chunkers.json import JSONChunker
from embedchain.chunkers.mdx import MdxChunker
from embedchain.chunkers.notion import NotionChunker
from embedchain.chunkers.openapi import OpenAPIChunker
from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.postgres import PostgresChunker
from embedchain.chunkers.qna_pair import QnaPairChunker
from embedchain.chunkers.sitemap import SitemapChunker
from embedchain.chunkers.slack import SlackChunker
from embedchain.chunkers.table import TableChunker
from embedchain.chunkers.text import TextChunker
from embedchain.chunkers.web_page import WebPageChunker
from embedchain.chunkers.xml import XmlChunker
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
from embedchain.config.add_config import ChunkerConfig

chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)

chunker_common_config = {
    DocsSiteChunker: {"chunk_size": 500, "chunk_overlap": 50, "length_function": len},
    DocxFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
    PdfFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
    TextChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
    MdxChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
    NotionChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
    QnaPairChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
    TableChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
    SitemapChunker: {"chunk_size": 500, "chunk_overlap": 0, "length_function": len},
    WebPageChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
    XmlChunker: {"chunk_size": 500, "chunk_overlap": 50, "length_function": len},
    YoutubeVideoChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
    JSONChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
    OpenAPIChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
    GmailChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
    PostgresChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
    SlackChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
    DiscourseChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
    CommonChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
    GoogleDriveChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
    ExcelFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
    AudioChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
}


def test_default_config_values():
    for chunker_class, config in chunker_common_config.items():
        chunker = chunker_class()
        assert chunker.text_splitter._chunk_size == config["chunk_size"]
        assert chunker.text_splitter._chunk_overlap == config["chunk_overlap"]
        assert chunker.text_splitter._length_function == config["length_function"]


def test_custom_config_values():
    for chunker_class, _ in chunker_common_config.items():
        chunker = chunker_class(config=chunker_config)
        assert chunker.text_splitter._chunk_size == 500
        assert chunker.text_splitter._chunk_overlap == 0
        assert chunker.text_splitter._length_function == len

embedchain/tests/chunkers/test_text.py (new file, 86 lines)
@@ -0,0 +1,86 @@
# ruff: noqa: E501

from embedchain.chunkers.text import TextChunker
from embedchain.config import ChunkerConfig
from embedchain.models.data_type import DataType


class TestTextChunker:
    def test_chunks_without_app_id(self):
        """
        Test the chunks generated by TextChunker.
        """
        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len, min_chunk_size=0)
        chunker = TextChunker(config=chunker_config)
        text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
        # Data type must be set manually in the test
        chunker.set_data_type(DataType.TEXT)
        result = chunker.create_chunks(MockLoader(), text, chunker_config)
        documents = result["documents"]
        assert len(documents) > 5

    def test_chunks_with_app_id(self):
        """
        Test the chunks generated by TextChunker with app_id
        """
        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len, min_chunk_size=0)
        chunker = TextChunker(config=chunker_config)
        text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
        chunker.set_data_type(DataType.TEXT)
        result = chunker.create_chunks(MockLoader(), text, chunker_config)
        documents = result["documents"]
        assert len(documents) > 5

    def test_big_chunksize(self):
        """
        Test that if an infinitely high chunk size is used, only one chunk is returned.
        """
        chunker_config = ChunkerConfig(chunk_size=9999999999, chunk_overlap=0, length_function=len, min_chunk_size=0)
        chunker = TextChunker(config=chunker_config)
        text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
        # Data type must be set manually in the test
        chunker.set_data_type(DataType.TEXT)
        result = chunker.create_chunks(MockLoader(), text, chunker_config)
        documents = result["documents"]
        assert len(documents) == 1

    def test_small_chunksize(self):
        """
        Test that if a chunk size of one is used, every character is a chunk.
        """
        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
        chunker = TextChunker(config=chunker_config)
        # We can't test with lorem ipsum because chunks are deduped, so would be recurring characters.
        text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"""
        # Data type must be set manually in the test
        chunker.set_data_type(DataType.TEXT)
        result = chunker.create_chunks(MockLoader(), text, chunker_config)
        documents = result["documents"]
        assert len(documents) == len(text)

    def test_word_count(self):
        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
        chunker = TextChunker(config=chunker_config)
        chunker.set_data_type(DataType.TEXT)

        document = ["ab cd", "ef gh"]
        result = chunker.get_word_count(document)
        assert result == 4


class MockLoader:
    @staticmethod
    def load_data(src) -> dict:
        """
        Mock loader that returns a list of data dictionaries.
        Adjust this method to return different data for testing.
        """
        return {
            "doc_id": "123",
            "data": [
                {
                    "content": src,
                    "meta_data": {"url": "none"},
                }
            ],
        }

embedchain/tests/conftest.py (new file, 35 lines)
@@ -0,0 +1,35 @@
import os

import pytest
from sqlalchemy import MetaData, create_engine
from sqlalchemy.orm import sessionmaker


@pytest.fixture(autouse=True)
def clean_db():
    db_path = os.path.expanduser("~/.embedchain/embedchain.db")
    db_url = f"sqlite:///{db_path}"
    engine = create_engine(db_url)
    metadata = MetaData()
    metadata.reflect(bind=engine)  # Reflect schema from the engine
    Session = sessionmaker(bind=engine)
    session = Session()

    try:
        # Iterate over all tables in reversed order to respect foreign keys
        for table in reversed(metadata.sorted_tables):
            if table.name != "alembic_version":  # Skip the Alembic version table
                session.execute(table.delete())
        session.commit()
    except Exception as e:
        session.rollback()
        print(f"Error cleaning database: {e}")
    finally:
        session.close()


@pytest.fixture(autouse=True)
def disable_telemetry():
    os.environ["EC_TELEMETRY"] = "false"
    yield
    del os.environ["EC_TELEMETRY"]

embedchain/tests/embedchain/test_add.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import os

import pytest

from embedchain import App
from embedchain.config import AddConfig, AppConfig, ChunkerConfig
from embedchain.models.data_type import DataType

os.environ["OPENAI_API_KEY"] = "test_key"


@pytest.fixture
def app(mocker):
    mocker.patch("chromadb.api.models.Collection.Collection.add")
    return App(config=AppConfig(collect_metrics=False))


def test_add(app):
    app.add("https://example.com", metadata={"foo": "bar"})
    assert app.user_asks == [["https://example.com", "web_page", {"foo": "bar"}]]


# TODO: Make this test faster by generating a sitemap locally rather than using a remote one
# def test_add_sitemap(app):
#     app.add("https://www.google.com/sitemap.xml", metadata={"foo": "bar"})
#     assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"foo": "bar"}]]


def test_add_forced_type(app):
    data_type = "text"
    app.add("https://example.com", data_type=data_type, metadata={"foo": "bar"})
    assert app.user_asks == [["https://example.com", data_type, {"foo": "bar"}]]


def test_dry_run(app):
    chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, min_chunk_size=0)
    text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""

    result = app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)

    chunks = result["chunks"]
    metadata = result["metadata"]
    count = result["count"]
    data_type = result["type"]

    assert len(chunks) == len(text)
    assert count == len(text)
    assert data_type == DataType.TEXT
    for item in metadata:
        assert isinstance(item, dict)
        assert "local" in item["url"]
        assert "text" in item["data_type"]

embedchain/tests/embedchain/test_embedchain.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import os

import pytest
from chromadb.api.models.Collection import Collection

from embedchain import App
from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.embedchain import EmbedChain
from embedchain.llm.base import BaseLlm
from embedchain.memory.base import ChatHistory
from embedchain.vectordb.chroma import ChromaDB

os.environ["OPENAI_API_KEY"] = "test-api-key"


@pytest.fixture
def app_instance():
    config = AppConfig(log_level="DEBUG", collect_metrics=False)
    return App(config=config)


def test_whole_app(app_instance, mocker):
    knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"

    mocker.patch.object(EmbedChain, "add")
    mocker.patch.object(EmbedChain, "_retrieve_from_database")
    mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
    mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
    mocker.patch.object(BaseLlm, "generate_prompt")
    mocker.patch.object(BaseLlm, "add_history")
    mocker.patch.object(ChatHistory, "delete", autospec=True)

    app_instance.add(knowledge, data_type="text")
    app_instance.query("What text did I give you?")
    app_instance.chat("What text did I give you?")

    assert BaseLlm.generate_prompt.call_count == 2
    app_instance.reset()


def test_add_after_reset(app_instance, mocker):
    mocker.patch("embedchain.vectordb.chroma.chromadb.Client")

    config = AppConfig(log_level="DEBUG", collect_metrics=False)
    chroma_config = ChromaDbConfig(allow_reset=True)
    db = ChromaDB(config=chroma_config)
    app_instance = App(config=config, db=db)

    # mock delete chat history
    mocker.patch.object(ChatHistory, "delete", autospec=True)

    app_instance.reset()

    app_instance.db.client.heartbeat()

    mocker.patch.object(Collection, "add")

    app_instance.db.collection.add(
        embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]],
        metadatas=[
            {"chapter": "3", "verse": "16"},
            {"chapter": "3", "verse": "5"},
            {"chapter": "29", "verse": "11"},
        ],
        ids=["id1", "id2", "id3"],
    )

    app_instance.reset()


def test_add_with_incorrect_content(app_instance, mocker):
    content = [{"foo": "bar"}]

    with pytest.raises(TypeError):
        app_instance.add(content, data_type="json")

embedchain/tests/embedchain/test_utils.py (new file, 133 lines)
@@ -0,0 +1,133 @@
import tempfile
import unittest
from unittest.mock import patch

from embedchain.models.data_type import DataType
from embedchain.utils.misc import detect_datatype


class TestApp(unittest.TestCase):
    """Test that the datatype detection is working, based on the input."""

    def test_detect_datatype_youtube(self):
        self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
        self.assertEqual(detect_datatype("https://m.youtube.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
        self.assertEqual(
            detect_datatype("https://www.youtube-nocookie.com/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO
        )
        self.assertEqual(detect_datatype("https://vid.plus/watch?v=dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)
        self.assertEqual(detect_datatype("https://youtu.be/dQw4w9WgXcQ"), DataType.YOUTUBE_VIDEO)

    def test_detect_datatype_local_file(self):
        self.assertEqual(detect_datatype("file:///home/user/file.txt"), DataType.WEB_PAGE)

    def test_detect_datatype_pdf(self):
        self.assertEqual(detect_datatype("https://www.example.com/document.pdf"), DataType.PDF_FILE)

    def test_detect_datatype_local_pdf(self):
        self.assertEqual(detect_datatype("file:///home/user/document.pdf"), DataType.PDF_FILE)

    def test_detect_datatype_xml(self):
        self.assertEqual(detect_datatype("https://www.example.com/sitemap.xml"), DataType.SITEMAP)

    def test_detect_datatype_local_xml(self):
        self.assertEqual(detect_datatype("file:///home/user/sitemap.xml"), DataType.SITEMAP)

    def test_detect_datatype_docx(self):
        self.assertEqual(detect_datatype("https://www.example.com/document.docx"), DataType.DOCX)

    def test_detect_datatype_local_docx(self):
        self.assertEqual(detect_datatype("file:///home/user/document.docx"), DataType.DOCX)

    def test_detect_data_type_json(self):
        self.assertEqual(detect_datatype("https://www.example.com/data.json"), DataType.JSON)

    def test_detect_data_type_local_json(self):
        self.assertEqual(detect_datatype("file:///home/user/data.json"), DataType.JSON)

    @patch("os.path.isfile")
    def test_detect_datatype_regular_filesystem_docx(self, mock_isfile):
        with tempfile.NamedTemporaryFile(suffix=".docx", delete=True) as tmp:
            mock_isfile.return_value = True
            self.assertEqual(detect_datatype(tmp.name), DataType.DOCX)

    def test_detect_datatype_docs_site(self):
        self.assertEqual(detect_datatype("https://docs.example.com"), DataType.DOCS_SITE)

    def test_detect_datatype_docs_sitein_path(self):
        self.assertEqual(detect_datatype("https://www.example.com/docs/index.html"), DataType.DOCS_SITE)
        self.assertNotEqual(detect_datatype("file:///var/www/docs/index.html"), DataType.DOCS_SITE)  # NOT equal

    def test_detect_datatype_web_page(self):
        self.assertEqual(detect_datatype("https://nav.al/agi"), DataType.WEB_PAGE)

    def test_detect_datatype_invalid_url(self):
        self.assertEqual(detect_datatype("not a url"), DataType.TEXT)

    def test_detect_datatype_qna_pair(self):
        self.assertEqual(
            detect_datatype(("Question?", "Answer. Content of the string is irrelevant.")), DataType.QNA_PAIR
        )

    def test_detect_datatype_qna_pair_types(self):
        """Test that a QnA pair needs to be a tuple of length two, and both items have to be strings."""
        with self.assertRaises(TypeError):
            self.assertNotEqual(
                detect_datatype(("How many planets are in our solar system?", 8)), DataType.QNA_PAIR
            )  # NOT equal

    def test_detect_datatype_text(self):
        self.assertEqual(detect_datatype("Just some text."), DataType.TEXT)

    def test_detect_datatype_non_string_error(self):
        """Test type error if the value passed is not a string, and not a valid non-string data_type"""
        with self.assertRaises(TypeError):
            detect_datatype(["foo", "bar"])

    @patch("os.path.isfile")
    def test_detect_datatype_regular_filesystem_file_txt(self, mock_isfile):
        with tempfile.NamedTemporaryFile(suffix=".txt", delete=True) as tmp:
            mock_isfile.return_value = True
            self.assertEqual(detect_datatype(tmp.name), DataType.TEXT_FILE)

    def test_detect_datatype_regular_filesystem_no_file(self):
        """Test that if a filepath is not actually an existing file, it is not handled as a file path."""
        self.assertEqual(detect_datatype("/var/not-an-existing-file.txt"), DataType.TEXT)

    def test_doc_examples_quickstart(self):
        """Test examples used in the documentation."""
        self.assertEqual(detect_datatype("https://en.wikipedia.org/wiki/Elon_Musk"), DataType.WEB_PAGE)
        self.assertEqual(detect_datatype("https://www.tesla.com/elon-musk"), DataType.WEB_PAGE)

    def test_doc_examples_introduction(self):
        """Test examples used in the documentation."""
        self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=3qHkcs3kG44"), DataType.YOUTUBE_VIDEO)
        self.assertEqual(
            detect_datatype(
                "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf"
            ),
            DataType.PDF_FILE,
        )
        self.assertEqual(detect_datatype("https://nav.al/feedback"), DataType.WEB_PAGE)

    def test_doc_examples_app_types(self):
        """Test examples used in the documentation."""
        self.assertEqual(detect_datatype("https://www.youtube.com/watch?v=Ff4fRgnuFgQ"), DataType.YOUTUBE_VIDEO)
        self.assertEqual(detect_datatype("https://en.wikipedia.org/wiki/Mark_Zuckerberg"), DataType.WEB_PAGE)

    def test_doc_examples_configuration(self):
        """Test examples used in the documentation."""
        import subprocess
        import sys

        subprocess.check_call([sys.executable, "-m", "pip", "install", "wikipedia"])
        import wikipedia

        page = wikipedia.page("Albert Einstein")
        # TODO: Add a wikipedia type, so wikipedia is a dependency and we don't need this slow test.
        # (timings: import: 1.4s, fetch wiki: 0.7s)
        self.assertEqual(detect_datatype(page.content), DataType.TEXT)


if __name__ == "__main__":
    unittest.main()

embedchain/tests/embedder/test_embedder.py (new file, 49 lines)
@@ -0,0 +1,49 @@
import pytest
from chromadb.api.types import Documents, Embeddings

from embedchain.config.embedder.base import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder


@pytest.fixture
def base_embedder():
    return BaseEmbedder()


def test_initialization(base_embedder):
    assert isinstance(base_embedder.config, BaseEmbedderConfig)
    # not initialized
    assert not hasattr(base_embedder, "embedding_fn")
    assert not hasattr(base_embedder, "vector_dimension")


def test_set_embedding_fn(base_embedder):
    def embedding_function(texts: Documents) -> Embeddings:
        return [f"Embedding for {text}" for text in texts]

    base_embedder.set_embedding_fn(embedding_function)
    assert hasattr(base_embedder, "embedding_fn")
    assert callable(base_embedder.embedding_fn)
    embeddings = base_embedder.embedding_fn(["text1", "text2"])
    assert embeddings == ["Embedding for text1", "Embedding for text2"]


def test_set_embedding_fn_when_not_a_function(base_embedder):
    with pytest.raises(ValueError):
        base_embedder.set_embedding_fn(None)


def test_set_vector_dimension(base_embedder):
    base_embedder.set_vector_dimension(256)
    assert hasattr(base_embedder, "vector_dimension")
    assert base_embedder.vector_dimension == 256


def test_set_vector_dimension_type_error(base_embedder):
    with pytest.raises(TypeError):
        base_embedder.set_vector_dimension(None)


def test_embedder_with_config():
    embedder = BaseEmbedder(BaseEmbedderConfig())
    assert isinstance(embedder.config, BaseEmbedderConfig)

embedchain/tests/embedder/test_huggingface_embedder.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from unittest.mock import patch

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.huggingface import HuggingFaceEmbedder


def test_huggingface_embedder_with_model(monkeypatch):
    config = BaseEmbedderConfig(model="test-model", model_kwargs={"param": "value"})
    with patch('embedchain.embedder.huggingface.HuggingFaceEmbeddings') as mock_embeddings:
        embedder = HuggingFaceEmbedder(config=config)
        assert embedder.config.model == "test-model"
        assert embedder.config.model_kwargs == {"param": "value"}
        mock_embeddings.assert_called_once_with(
            model_name="test-model",
            model_kwargs={"param": "value"}
        )

embedchain/tests/evaluation/test_answer_relevancy_metric.py (new file, 224 lines)
@@ -0,0 +1,224 @@
import numpy as np
import pytest

from embedchain.config.evaluation.base import AnswerRelevanceConfig
from embedchain.evaluation.metrics import AnswerRelevance
from embedchain.utils.evaluation import EvalData, EvalMetric


@pytest.fixture
def mock_data():
    return [
        EvalData(
            contexts=[
                "This is a test context 1.",
            ],
            question="This is a test question 1.",
            answer="This is a test answer 1.",
        ),
        EvalData(
            contexts=[
                "This is a test context 2-1.",
                "This is a test context 2-2.",
            ],
            question="This is a test question 2.",
            answer="This is a test answer 2.",
        ),
    ]


@pytest.fixture
def mock_answer_relevance_metric(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    monkeypatch.setenv("OPENAI_API_BASE", "test_api_base")
    metric = AnswerRelevance()
    return metric


def test_answer_relevance_init(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    metric = AnswerRelevance()
    assert metric.name == EvalMetric.ANSWER_RELEVANCY.value
    assert metric.config.model == "gpt-4"
    assert metric.config.embedder == "text-embedding-ada-002"
    assert metric.config.api_key is None
    assert metric.config.num_gen_questions == 1
    monkeypatch.delenv("OPENAI_API_KEY")


def test_answer_relevance_init_with_config():
    metric = AnswerRelevance(config=AnswerRelevanceConfig(api_key="test_api_key"))
    assert metric.name == EvalMetric.ANSWER_RELEVANCY.value
    assert metric.config.model == "gpt-4"
    assert metric.config.embedder == "text-embedding-ada-002"
    assert metric.config.api_key == "test_api_key"
    assert metric.config.num_gen_questions == 1


def test_answer_relevance_init_without_api_key(monkeypatch):
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    with pytest.raises(ValueError):
        AnswerRelevance()


def test_generate_prompt(mock_answer_relevance_metric, mock_data):
    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[0])
    assert "This is a test answer 1." in prompt

    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[1])
    assert "This is a test answer 2." in prompt


def test_generate_questions(mock_answer_relevance_metric, mock_data, monkeypatch):
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type(
                        "obj",
                        (object,),
                        {"message": type("obj", (object,), {"content": "This is a test question response.\n"})},
                    )
                ]
            },
        )(),
    )
    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[0])
    questions = mock_answer_relevance_metric._generate_questions(prompt)
    assert len(questions) == 1

    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type("obj", (object,), {"message": type("obj", (object,), {"content": "question 1?\nquestion2?"})})
                ]
            },
        )(),
    )
    prompt = mock_answer_relevance_metric._generate_prompt(mock_data[1])
    questions = mock_answer_relevance_metric._generate_questions(prompt)
    assert len(questions) == 2


def test_generate_embedding(mock_answer_relevance_metric, mock_data, monkeypatch):
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.embeddings,
        "create",
        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
    )
    embedding = mock_answer_relevance_metric._generate_embedding("This is a test question.")
    assert len(embedding) == 3


def test_compute_similarity(mock_answer_relevance_metric, mock_data):
    original = np.array([1, 2, 3])
    generated = np.array([[1, 2, 3], [1, 2, 3]])
    similarity = mock_answer_relevance_metric._compute_similarity(original, generated)
    assert len(similarity) == 2
    assert similarity[0] == 1.0
    assert similarity[1] == 1.0


def test_compute_score(mock_answer_relevance_metric, mock_data, monkeypatch):
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type(
                        "obj",
                        (object,),
                        {"message": type("obj", (object,), {"content": "This is a test question response.\n"})},
                    )
                ]
            },
        )(),
    )
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.embeddings,
        "create",
        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
    )
    score = mock_answer_relevance_metric._compute_score(mock_data[0])
    assert score == 1.0

    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type("obj", (object,), {"message": type("obj", (object,), {"content": "question 1?\nquestion2?"})})
                ]
            },
        )(),
    )
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.embeddings,
        "create",
        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
    )
    score = mock_answer_relevance_metric._compute_score(mock_data[1])
    assert score == 1.0


def test_evaluate(mock_answer_relevance_metric, mock_data, monkeypatch):
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type(
                        "obj",
                        (object,),
                        {"message": type("obj", (object,), {"content": "This is a test question response.\n"})},
                    )
                ]
            },
        )(),
    )
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.embeddings,
        "create",
        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
    )
    score = mock_answer_relevance_metric.evaluate(mock_data)
    assert score == 1.0

    monkeypatch.setattr(
        mock_answer_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type("obj", (object,), {"message": type("obj", (object,), {"content": "question 1?\nquestion2?"})})
                ]
            },
        )(),
    )
    monkeypatch.setattr(
        mock_answer_relevance_metric.client.embeddings,
        "create",
        lambda input, model: type("obj", (object,), {"data": [type("obj", (object,), {"embedding": [1, 2, 3]})]})(),
    )
    score = mock_answer_relevance_metric.evaluate(mock_data)
    assert score == 1.0

embedchain/tests/evaluation/test_context_relevancy_metric.py (new file, 100 lines)
@@ -0,0 +1,100 @@
import pytest

from embedchain.config.evaluation.base import ContextRelevanceConfig
from embedchain.evaluation.metrics import ContextRelevance
from embedchain.utils.evaluation import EvalData, EvalMetric


@pytest.fixture
def mock_data():
    return [
        EvalData(
            contexts=[
                "This is a test context 1.",
            ],
            question="This is a test question 1.",
            answer="This is a test answer 1.",
        ),
        EvalData(
            contexts=[
                "This is a test context 2-1.",
                "This is a test context 2-2.",
            ],
            question="This is a test question 2.",
            answer="This is a test answer 2.",
        ),
    ]


@pytest.fixture
def mock_context_relevance_metric(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    metric = ContextRelevance()
    return metric


def test_context_relevance_init(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    metric = ContextRelevance()
    assert metric.name == EvalMetric.CONTEXT_RELEVANCY.value
    assert metric.config.model == "gpt-4"
    assert metric.config.api_key is None
    assert metric.config.language == "en"
    monkeypatch.delenv("OPENAI_API_KEY")


def test_context_relevance_init_with_config():
    metric = ContextRelevance(config=ContextRelevanceConfig(api_key="test_api_key"))
    assert metric.name == EvalMetric.CONTEXT_RELEVANCY.value
    assert metric.config.model == "gpt-4"
    assert metric.config.api_key == "test_api_key"
    assert metric.config.language == "en"


def test_context_relevance_init_without_api_key(monkeypatch):
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    with pytest.raises(ValueError):
        ContextRelevance()


def test_sentence_segmenter(mock_context_relevance_metric):
    text = "This is a test sentence. This is another sentence."
    assert mock_context_relevance_metric._sentence_segmenter(text) == [
        "This is a test sentence. ",
        "This is another sentence.",
    ]


def test_compute_score(mock_context_relevance_metric, mock_data, monkeypatch):
    monkeypatch.setattr(
        mock_context_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type("obj", (object,), {"message": type("obj", (object,), {"content": "This is a test reponse."})})
                ]
            },
        )(),
    )
    assert mock_context_relevance_metric._compute_score(mock_data[0]) == 1.0
    assert mock_context_relevance_metric._compute_score(mock_data[1]) == 0.5


def test_evaluate(mock_context_relevance_metric, mock_data, monkeypatch):
    monkeypatch.setattr(
        mock_context_relevance_metric.client.chat.completions,
        "create",
        lambda model, messages: type(
            "obj",
            (object,),
            {
                "choices": [
                    type("obj", (object,), {"message": type("obj", (object,), {"content": "This is a test reponse."})})
                ]
            },
        )(),
    )
    assert mock_context_relevance_metric.evaluate(mock_data) == 0.75

embedchain/tests/evaluation/test_groundedness_metric.py (new file, 152 lines)
@@ -0,0 +1,152 @@
import numpy as np
import pytest

from embedchain.config.evaluation.base import GroundednessConfig
from embedchain.evaluation.metrics import Groundedness
from embedchain.utils.evaluation import EvalData, EvalMetric


@pytest.fixture
def mock_data():
    return [
        EvalData(
            contexts=[
                "This is a test context 1.",
            ],
            question="This is a test question 1.",
            answer="This is a test answer 1.",
        ),
        EvalData(
            contexts=[
                "This is a test context 2-1.",
                "This is a test context 2-2.",
            ],
            question="This is a test question 2.",
            answer="This is a test answer 2.",
        ),
    ]


@pytest.fixture
def mock_groundedness_metric(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    metric = Groundedness()
    return metric


def test_groundedness_init(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    metric = Groundedness()
    assert metric.name == EvalMetric.GROUNDEDNESS.value
    assert metric.config.model == "gpt-4"
    assert metric.config.api_key is None
    monkeypatch.delenv("OPENAI_API_KEY")


def test_groundedness_init_with_config():
    metric = Groundedness(config=GroundednessConfig(api_key="test_api_key"))
    assert metric.name == EvalMetric.GROUNDEDNESS.value
    assert metric.config.model == "gpt-4"
    assert metric.config.api_key == "test_api_key"


def test_groundedness_init_without_api_key(monkeypatch):
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    with pytest.raises(ValueError):
        Groundedness()


def test_generate_answer_claim_prompt(mock_groundedness_metric, mock_data):
    prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
    assert "This is a test question 1." in prompt
    assert "This is a test answer 1." in prompt


def test_get_claim_statements(mock_groundedness_metric, mock_data, monkeypatch):
    monkeypatch.setattr(
        mock_groundedness_metric.client.chat.completions,
        "create",
        lambda *args, **kwargs: type(
            "obj",
            (object,),
            {
                "choices": [
                    type(
                        "obj",
                        (object,),
                        {
                            "message": type(
                                "obj",
                                (object,),
                                {
                                    "content": """This is a test answer 1.
This is a test answer 2.
This is a test answer 3."""
                                },
                            )
                        },
                    )
                ]
            },
        )(),
    )
    prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
    claim_statements = mock_groundedness_metric._get_claim_statements(prompt=prompt)
    assert len(claim_statements) == 3
    assert "This is a test answer 1." in claim_statements


def test_generate_claim_inference_prompt(mock_groundedness_metric, mock_data):
    prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
    claim_statements = [
        "This is a test claim 1.",
        "This is a test claim 2.",
    ]
    prompt = mock_groundedness_metric._generate_claim_inference_prompt(
        data=mock_data[0], claim_statements=claim_statements
    )
    assert "This is a test context 1." in prompt
    assert "This is a test claim 1." in prompt


def test_get_claim_verdict_scores(mock_groundedness_metric, mock_data, monkeypatch):
    monkeypatch.setattr(
        mock_groundedness_metric.client.chat.completions,
        "create",
        lambda *args, **kwargs: type(
            "obj",
            (object,),
            {"choices": [type("obj", (object,), {"message": type("obj", (object,), {"content": "1\n0\n-1"})})]},
        )(),
    )
    prompt = mock_groundedness_metric._generate_answer_claim_prompt(data=mock_data[0])
    claim_statements = mock_groundedness_metric._get_claim_statements(prompt=prompt)
    prompt = mock_groundedness_metric._generate_claim_inference_prompt(
        data=mock_data[0], claim_statements=claim_statements
    )
    claim_verdict_scores = mock_groundedness_metric._get_claim_verdict_scores(prompt=prompt)
    assert len(claim_verdict_scores) == 3
    assert claim_verdict_scores[0] == 1
    assert claim_verdict_scores[1] == 0


def test_compute_score(mock_groundedness_metric, mock_data, monkeypatch):
    monkeypatch.setattr(
        mock_groundedness_metric,
        "_get_claim_statements",
        lambda *args, **kwargs: np.array(
            [
                "This is a test claim 1.",
                "This is a test claim 2.",
            ]
        ),
    )
    monkeypatch.setattr(mock_groundedness_metric, "_get_claim_verdict_scores", lambda *args, **kwargs: np.array([1, 0]))
    score = mock_groundedness_metric._compute_score(data=mock_data[0])
    assert score == 0.5


def test_evaluate(mock_groundedness_metric, mock_data, monkeypatch):
    monkeypatch.setattr(mock_groundedness_metric, "_compute_score", lambda *args, **kwargs: 0.5)
    score = mock_groundedness_metric.evaluate(dataset=mock_data)
    assert score == 0.5

embedchain/tests/helper_classes/test_json_serializable.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import random
import unittest
from string import Template

from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.helpers.json_serializable import (JSONSerializable,
                                                  register_deserializable)


class TestJsonSerializable(unittest.TestCase):
    """Test that JSON serialization and deserialization are working, based on the input."""

    def test_base_function(self):
        """Test that the base premise of serialization and deserialization is working"""

        @register_deserializable
        class TestClass(JSONSerializable):
            def __init__(self):
                self.rng = random.random()

        original_class = TestClass()
        serial = original_class.serialize()

        # Negative test to show that a new class does not have the same random number.
        negative_test_class = TestClass()
        self.assertNotEqual(original_class.rng, negative_test_class.rng)

        # Test to show that a deserialized class has the same random number.
        positive_test_class: TestClass = TestClass().deserialize(serial)
        self.assertEqual(original_class.rng, positive_test_class.rng)
        self.assertTrue(isinstance(positive_test_class, TestClass))

        # Test that it works as a static method too.
        positive_test_class: TestClass = TestClass.deserialize(serial)
        self.assertEqual(original_class.rng, positive_test_class.rng)

        # TODO: There's no reason it shouldn't work, but serialization to and from file should be tested too.

    def test_registration_required(self):
        """Test that registration is required, and that without registration the default class is returned."""

        class SecondTestClass(JSONSerializable):
            def __init__(self):
                self.default = True

        app = SecondTestClass()
        # Make not default
        app.default = False
        # Serialize
        serial = app.serialize()
        # Deserialize. Due to the way errors are handled, it will not fail but return a default class.
        app: SecondTestClass = SecondTestClass().deserialize(serial)
        self.assertTrue(app.default)
        # If we register and try again with the same serial, it should work
        SecondTestClass._register_class_as_deserializable(SecondTestClass)
        app: SecondTestClass = SecondTestClass().deserialize(serial)
        self.assertFalse(app.default)

    def test_recursive(self):
        """Test recursiveness with the real app"""
        random_id = str(random.random())
        config = AppConfig(id=random_id, collect_metrics=False)
        # config class is set under app.config.
        app = App(config=config)
        s = app.serialize()
        new_app: App = App.deserialize(s)
        # The id of the new app is the same as the first one.
        self.assertEqual(random_id, new_app.config.id)
        # We have proven that a nested class (app.config) can be serialized and deserialized just the same.
        # TODO: test deeper recursion

    def test_special_subclasses(self):
        """Test special subclasses that are not serializable by default."""
        # Template
        config = BaseLlmConfig(template=Template("My custom template with $query, $context and $history."))
        s = config.serialize()
        new_config: BaseLlmConfig = BaseLlmConfig.deserialize(s)
        self.assertEqual(config.prompt.template, new_config.prompt.template)

embedchain/tests/llm/test_anthrophic.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import os
from unittest.mock import patch

import pytest
from langchain.schema import HumanMessage, SystemMessage

from embedchain.config import BaseLlmConfig
from embedchain.llm.anthropic import AnthropicLlm


@pytest.fixture
def anthropic_llm():
    os.environ["ANTHROPIC_API_KEY"] = "test_api_key"
    config = BaseLlmConfig(temperature=0.5, model="claude-instant-1", token_usage=False)
    return AnthropicLlm(config)


def test_get_llm_model_answer(anthropic_llm):
    with patch.object(AnthropicLlm, "_get_answer", return_value="Test Response") as mock_method:
        prompt = "Test Prompt"
        response = anthropic_llm.get_llm_model_answer(prompt)
        assert response == "Test Response"
        mock_method.assert_called_once_with(prompt, anthropic_llm.config)


def test_get_messages(anthropic_llm):
    prompt = "Test Prompt"
    system_prompt = "Test System Prompt"
    messages = anthropic_llm._get_messages(prompt, system_prompt)
    assert messages == [
        SystemMessage(content="Test System Prompt", additional_kwargs={}),
        HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
    ]


def test_get_llm_model_answer_with_token_usage(anthropic_llm):
    test_config = BaseLlmConfig(
        temperature=anthropic_llm.config.temperature, model=anthropic_llm.config.model, token_usage=True
    )
    anthropic_llm.config = test_config
    with patch.object(
        AnthropicLlm, "_get_answer", return_value=("Test Response", {"input_tokens": 1, "output_tokens": 2})
    ) as mock_method:
        prompt = "Test Prompt"
        response, token_info = anthropic_llm.get_llm_model_answer(prompt)
        assert response == "Test Response"
        assert token_info == {
            "prompt_tokens": 1,
            "completion_tokens": 2,
            "total_tokens": 3,
            "total_cost": 1.265e-05,
            "cost_currency": "USD",
        }
        mock_method.assert_called_once_with(prompt, anthropic_llm.config)

embedchain/tests/llm/test_aws_bedrock.py (new file, 56 lines)
@@ -0,0 +1,56 @@
import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from embedchain.config import BaseLlmConfig
from embedchain.llm.aws_bedrock import AWSBedrockLlm


@pytest.fixture
def config(monkeypatch):
    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test_access_key_id")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test_secret_access_key")
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    config = BaseLlmConfig(
        model="amazon.titan-text-express-v1",
        model_kwargs={
            "temperature": 0.5,
            "topP": 1,
            "maxTokenCount": 1000,
        },
    )
    yield config
    monkeypatch.delenv("AWS_ACCESS_KEY_ID")
    monkeypatch.delenv("AWS_SECRET_ACCESS_KEY")
    monkeypatch.delenv("OPENAI_API_KEY")


def test_get_llm_model_answer(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")

    llm = AWSBedrockLlm(config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("Test query", config)


def test_get_llm_model_answer_empty_prompt(config, mocker):
    mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")

    llm = AWSBedrockLlm(config)
    answer = llm.get_llm_model_answer("")

    assert answer == "Test answer"
    mocked_get_answer.assert_called_once_with("", config)


def test_get_llm_model_answer_with_streaming(config, mocker):
    config.stream = True
    mocked_bedrock_chat = mocker.patch("embedchain.llm.aws_bedrock.Bedrock")

    llm = AWSBedrockLlm(config)
    llm.get_llm_model_answer("Test query")

    mocked_bedrock_chat.assert_called_once()
    callbacks = [callback[1]["callbacks"] for callback in mocked_bedrock_chat.call_args_list]
    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)

embedchain/tests/llm/test_azure_openai.py (new file, 87 lines)
@@ -0,0 +1,87 @@
from unittest.mock import MagicMock, patch

import pytest
from langchain.schema import HumanMessage, SystemMessage

from embedchain.config import BaseLlmConfig
from embedchain.llm.azure_openai import AzureOpenAILlm


@pytest.fixture
def azure_openai_llm():
    config = BaseLlmConfig(
        deployment_name="azure_deployment",
        temperature=0.7,
        model="gpt-3.5-turbo",
        max_tokens=50,
        system_prompt="System Prompt",
    )
    return AzureOpenAILlm(config)


def test_get_llm_model_answer(azure_openai_llm):
    with patch.object(AzureOpenAILlm, "_get_answer", return_value="Test Response") as mock_method:
        prompt = "Test Prompt"
        response = azure_openai_llm.get_llm_model_answer(prompt)
        assert response == "Test Response"
        mock_method.assert_called_once_with(prompt=prompt, config=azure_openai_llm.config)


def test_get_answer(azure_openai_llm):
    with patch("langchain_openai.AzureChatOpenAI") as mock_chat:
        mock_chat_instance = mock_chat.return_value
        mock_chat_instance.invoke.return_value = MagicMock(content="Test Response")

        prompt = "Test Prompt"
        response = azure_openai_llm._get_answer(prompt, azure_openai_llm.config)

        assert response == "Test Response"
        mock_chat.assert_called_once_with(
            deployment_name=azure_openai_llm.config.deployment_name,
            openai_api_version="2024-02-01",
            model_name=azure_openai_llm.config.model or "gpt-3.5-turbo",
            temperature=azure_openai_llm.config.temperature,
            max_tokens=azure_openai_llm.config.max_tokens,
            streaming=azure_openai_llm.config.stream,
        )


def test_get_messages(azure_openai_llm):
    prompt = "Test Prompt"
    system_prompt = "Test System Prompt"
    messages = azure_openai_llm._get_messages(prompt, system_prompt)
    assert messages == [
        SystemMessage(content="Test System Prompt", additional_kwargs={}),
        HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
    ]


def test_when_no_deployment_name_provided():
    config = BaseLlmConfig(temperature=0.7, model="gpt-3.5-turbo", max_tokens=50, system_prompt="System Prompt")
    with pytest.raises(ValueError):
        llm = AzureOpenAILlm(config)
        llm.get_llm_model_answer("Test Prompt")


def test_with_api_version():
    config = BaseLlmConfig(
        deployment_name="azure_deployment",
        temperature=0.7,
        model="gpt-3.5-turbo",
        max_tokens=50,
        system_prompt="System Prompt",
        api_version="2024-02-01",
    )

    with patch("langchain_openai.AzureChatOpenAI") as mock_chat:
        llm = AzureOpenAILlm(config)
        llm.get_llm_model_answer("Test Prompt")

        mock_chat.assert_called_once_with(
            deployment_name="azure_deployment",
            openai_api_version="2024-02-01",
            model_name="gpt-3.5-turbo",
            temperature=0.7,
            max_tokens=50,
            streaming=False,
        )
61
embedchain/tests/llm/test_base_llm.py
Normal file
61
embedchain/tests/llm/test_base_llm.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from string import Template
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.llm.base import BaseLlm, BaseLlmConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_llm():
|
||||
config = BaseLlmConfig()
|
||||
return BaseLlm(config=config)
|
||||
|
||||
|
||||
def test_is_get_llm_model_answer_not_implemented(base_llm):
|
||||
with pytest.raises(NotImplementedError):
|
||||
base_llm.get_llm_model_answer()
|
||||
|
||||
|
||||
def test_is_stream_bool():
|
||||
with pytest.raises(ValueError):
|
||||
config = BaseLlmConfig(stream="test value")
|
||||
BaseLlm(config=config)
|
||||
|
||||
|
||||
def test_template_string_gets_converted_to_Template_instance():
|
||||
config = BaseLlmConfig(template="test value $query $context")
|
||||
llm = BaseLlm(config=config)
|
||||
assert isinstance(llm.config.prompt, Template)
|
||||
|
||||
|
||||
def test_is_get_llm_model_answer_implemented():
|
||||
class TestLlm(BaseLlm):
|
||||
def get_llm_model_answer(self):
|
||||
return "Implemented"
|
||||
|
||||
config = BaseLlmConfig()
|
||||
llm = TestLlm(config=config)
|
||||
assert llm.get_llm_model_answer() == "Implemented"
|
||||
|
||||
|
||||
def test_stream_response(base_llm):
|
||||
answer = ["Chunk1", "Chunk2", "Chunk3"]
|
||||
result = list(base_llm._stream_response(answer))
|
||||
assert result == answer
|
||||
|
||||
|
||||
def test_append_search_and_context(base_llm):
|
||||
context = "Context"
|
||||
web_search_result = "Web Search Result"
|
||||
result = base_llm._append_search_and_context(context, web_search_result)
|
||||
expected_result = "Context\nWeb Search Result: Web Search Result"
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_access_search_and_get_results(base_llm, mocker):
|
||||
base_llm.access_search_and_get_results = mocker.patch.object(
|
||||
base_llm, "access_search_and_get_results", return_value="Search Results"
|
||||
)
|
||||
input_query = "Test query"
|
||||
result = base_llm.access_search_and_get_results(input_query)
|
||||
assert result == "Search Results"
|
||||
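A minimal illustration of the template handling covered above: a plain string with $context/$query placeholders is stored as a string.Template on config.prompt.

from string import Template

from embedchain.llm.base import BaseLlm, BaseLlmConfig

llm = BaseLlm(config=BaseLlmConfig(template="Context: $context | Query: $query"))
assert isinstance(llm.config.prompt, Template)
print(llm.config.prompt.substitute(context="retrieved text", query="a user question"))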
120
embedchain/tests/llm/test_chat.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.memory.base import ChatHistory
|
||||
from embedchain.memory.message import ChatMessage
|
||||
|
||||
|
||||
class TestApp(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["OPENAI_API_KEY"] = "test_key"
|
||||
self.app = App(config=AppConfig(collect_metrics=False))
|
||||
|
||||
@patch.object(App, "_retrieve_from_database", return_value=["Test context"])
|
||||
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
|
||||
def test_chat_with_memory(self, mock_get_answer, mock_retrieve):
|
||||
"""
|
||||
This test checks the functionality of the 'chat' method in the App class with respect to the chat history
|
||||
memory.
|
||||
The 'chat' method is called twice. The first call initializes the chat history memory.
|
||||
The second call is expected to use the chat history from the first call.
|
||||
|
||||
Key assumptions tested:
- After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
called with correct arguments, adding the correct chat history.
- During the second call, the 'chat' method uses the chat history from the first call.
|
||||
|
||||
The test isolates the 'chat' method behavior by mocking out '_retrieve_from_database', 'get_answer_from_llm' and
|
||||
'memory' methods.
|
||||
"""
|
||||
config = AppConfig(collect_metrics=False)
|
||||
app = App(config=config)
|
||||
with patch.object(BaseLlm, "add_history") as mock_history:
|
||||
first_answer = app.chat("Test query 1")
|
||||
self.assertEqual(first_answer, "Test answer")
|
||||
mock_history.assert_called_with(app.config.id, "Test query 1", "Test answer", session_id="default")
|
||||
|
||||
second_answer = app.chat("Test query 2", session_id="test_session")
|
||||
self.assertEqual(second_answer, "Test answer")
|
||||
mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer", session_id="test_session")
|
||||
|
||||
@patch.object(App, "_retrieve_from_database", return_value=["Test context"])
|
||||
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
|
||||
def test_template_replacement(self, mock_get_answer, mock_retrieve):
|
||||
"""
|
||||
Tests that if a default template is used and it doesn't contain history,
a history-aware default template is swapped in.

Also tests that a dry run does not change the history.
|
||||
"""
|
||||
with patch.object(ChatHistory, "get") as mock_memory:
|
||||
mock_message = ChatMessage()
|
||||
mock_message.add_user_message("Test query 1")
|
||||
mock_message.add_ai_message("Test answer")
|
||||
mock_memory.return_value = [mock_message]
|
||||
|
||||
config = AppConfig(collect_metrics=False)
|
||||
app = App(config=config)
|
||||
first_answer = app.chat("Test query 1")
|
||||
self.assertEqual(first_answer, "Test answer")
|
||||
self.assertEqual(len(app.llm.history), 1)
|
||||
history = app.llm.history
|
||||
dry_run = app.chat("Test query 2", dry_run=True)
|
||||
self.assertIn("Conversation history:", dry_run)
|
||||
self.assertEqual(history, app.llm.history)
|
||||
self.assertEqual(len(app.llm.history), 1)
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_chat_with_where_in_params(self):
|
||||
"""
|
||||
Test where filter
|
||||
"""
|
||||
with patch.object(self.app, "_retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
answer = self.app.chat("Test query", where={"attribute": "value"})
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
_args, kwargs = mock_retrieve.call_args
|
||||
self.assertEqual(kwargs.get("input_query"), "Test query")
|
||||
self.assertEqual(kwargs.get("where"), {"attribute": "value"})
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_chat_with_where_in_chat_config(self):
|
||||
"""
|
||||
This test checks the functionality of the 'chat' method in the App class.
|
||||
It simulates a scenario where the '_retrieve_from_database' method returns a context list based on
|
||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'chat' method is expected to call '_retrieve_from_database' with the where filter specified
|
||||
in the BaseLlmConfig and 'get_llm_model_answer' methods appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- '_retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
BaseLlmConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'chat' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'chat' method behavior by mocking out '_retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
with patch.object(self.app.db, "query") as mock_database_query:
|
||||
mock_database_query.return_value = ["Test context"]
|
||||
llm_config = BaseLlmConfig(where={"attribute": "value"})
|
||||
answer = self.app.chat("Test query", llm_config)
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
_args, kwargs = mock_database_query.call_args
|
||||
self.assertEqual(kwargs.get("input_query"), "Test query")
|
||||
where = kwargs.get("where")
|
||||
assert "app_id" in where
|
||||
assert "attribute" in where
|
||||
mock_answer.assert_called_once()
|
||||
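Outside the mocks, the session handling exercised above reduces to the optional session_id argument of App.chat; a hedged usage sketch (the API key is a placeholder and real calls also need data added to the app):

import os

from embedchain import App

os.environ["OPENAI_API_KEY"] = "sk-placeholder"  # placeholder, not a real key
app = App()
app.chat("First question")                            # history stored under the "default" session
app.chat("Follow-up question", session_id="support")  # history tracked separately per session_id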
23
embedchain/tests/llm/test_clarifai.py
Normal file
@@ -0,0 +1,23 @@

import pytest

from embedchain.config import BaseLlmConfig
from embedchain.llm.clarifai import ClarifaiLlm


@pytest.fixture
def clarifai_llm_config(monkeypatch):
    monkeypatch.setenv("CLARIFAI_PAT", "test_api_key")
    config = BaseLlmConfig(
        model="https://clarifai.com/openai/chat-completion/models/GPT-4",
        model_kwargs={"temperature": 0.7, "max_tokens": 100},
    )
    yield config
    monkeypatch.delenv("CLARIFAI_PAT")


def test_clarifai__llm_get_llm_model_answer(clarifai_llm_config, mocker):
    mocker.patch("embedchain.llm.clarifai.ClarifaiLlm._get_answer", return_value="Test answer")
    llm = ClarifaiLlm(clarifai_llm_config)
    answer = llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
73
embedchain/tests/llm/test_cohere.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.cohere import CohereLlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cohere_llm_config():
|
||||
os.environ["COHERE_API_KEY"] = "test_api_key"
|
||||
config = BaseLlmConfig(model="command-r", max_tokens=100, temperature=0.7, top_p=0.8, token_usage=False)
|
||||
yield config
|
||||
os.environ.pop("COHERE_API_KEY")
|
||||
|
||||
|
||||
def test_init_raises_value_error_without_api_key(mocker):
|
||||
mocker.patch.dict(os.environ, clear=True)
|
||||
with pytest.raises(ValueError):
|
||||
CohereLlm()
|
||||
|
||||
|
||||
def test_get_llm_model_answer_raises_value_error_for_system_prompt(cohere_llm_config):
|
||||
llm = CohereLlm(cohere_llm_config)
|
||||
llm.config.system_prompt = "system_prompt"
|
||||
with pytest.raises(ValueError):
|
||||
llm.get_llm_model_answer("prompt")
|
||||
|
||||
|
||||
def test_get_llm_model_answer(cohere_llm_config, mocker):
|
||||
mocker.patch("embedchain.llm.cohere.CohereLlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = CohereLlm(cohere_llm_config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_token_usage(cohere_llm_config, mocker):
|
||||
test_config = BaseLlmConfig(
|
||||
temperature=cohere_llm_config.temperature,
|
||||
max_tokens=cohere_llm_config.max_tokens,
|
||||
top_p=cohere_llm_config.top_p,
|
||||
model=cohere_llm_config.model,
|
||||
token_usage=True,
|
||||
)
|
||||
mocker.patch(
|
||||
"embedchain.llm.cohere.CohereLlm._get_answer",
|
||||
return_value=("Test answer", {"input_tokens": 1, "output_tokens": 2}),
|
||||
)
|
||||
|
||||
llm = CohereLlm(test_config)
|
||||
answer, token_info = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
assert token_info == {
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 2,
|
||||
"total_tokens": 3,
|
||||
"total_cost": 3.5e-06,
|
||||
"cost_currency": "USD",
|
||||
}
|
||||
|
||||
|
||||
def test_get_answer_mocked_cohere(cohere_llm_config, mocker):
|
||||
mocked_cohere = mocker.patch("embedchain.llm.cohere.ChatCohere")
|
||||
mocked_cohere.return_value.invoke.return_value.content = "Mocked answer"
|
||||
|
||||
llm = CohereLlm(cohere_llm_config)
|
||||
prompt = "Test query"
|
||||
answer = llm.get_llm_model_answer(prompt)
|
||||
|
||||
assert answer == "Mocked answer"
|
||||
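The total_cost asserted above is consistent with a per-token price table of roughly $0.50 per million input tokens and $1.50 per million output tokens for command-r (an assumption about the library's pricing data, inferred from the numbers in this test):

prompt_tokens, completion_tokens = 1, 2
input_price = 0.50 / 1_000_000    # assumed USD per input token
output_price = 1.50 / 1_000_000   # assumed USD per output token

total_cost = prompt_tokens * input_price + completion_tokens * output_price
assert round(total_cost, 10) == 3.5e-06   # matches token_info["total_cost"] above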
70
embedchain/tests/llm/test_generate_prompt.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import unittest
|
||||
from string import Template
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
|
||||
|
||||
class TestGeneratePrompt(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.app = App(config=AppConfig(collect_metrics=False))
|
||||
|
||||
def test_generate_prompt_with_template(self):
|
||||
"""
|
||||
Tests that the generate_prompt method correctly formats the prompt using
|
||||
a custom template provided in the BaseLlmConfig instance.
|
||||
|
||||
This test sets up a scenario with an input query and a list of contexts,
|
||||
and a custom template, and then calls generate_prompt. It checks that the
|
||||
returned prompt correctly incorporates all the contexts and the query into
|
||||
the format specified by the template.
|
||||
"""
|
||||
# Setup
|
||||
input_query = "Test query"
|
||||
contexts = ["Context 1", "Context 2", "Context 3"]
|
||||
template = "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:"
|
||||
config = BaseLlmConfig(template=Template(template))
|
||||
self.app.llm.config = config
|
||||
|
||||
# Execute
|
||||
result = self.app.llm.generate_prompt(input_query, contexts)
|
||||
|
||||
# Assert
|
||||
expected_result = (
|
||||
"You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:"
|
||||
)
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
def test_generate_prompt_with_contexts_list(self):
|
||||
"""
|
||||
Tests that the generate_prompt method correctly handles a list of contexts.
|
||||
|
||||
This test sets up a scenario with an input query and a list of contexts,
|
||||
and then calls generate_prompt. It checks that the returned prompt
|
||||
correctly includes all the contexts and the query.
|
||||
"""
|
||||
# Setup
|
||||
input_query = "Test query"
|
||||
contexts = ["Context 1", "Context 2", "Context 3"]
|
||||
config = BaseLlmConfig()
|
||||
|
||||
# Execute
|
||||
self.app.llm.config = config
|
||||
result = self.app.llm.generate_prompt(input_query, contexts)
|
||||
|
||||
# Assert
|
||||
expected_result = config.prompt.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
def test_generate_prompt_with_history(self):
|
||||
"""
|
||||
Test the 'generate_prompt' method with BaseLlmConfig containing a history attribute.
|
||||
"""
|
||||
config = BaseLlmConfig()
|
||||
config.prompt = Template("Context: $context | Query: $query | History: $history")
|
||||
self.app.llm.config = config
|
||||
self.app.llm.set_history(["Past context 1", "Past context 2"])
|
||||
prompt = self.app.llm.generate_prompt("Test query", ["Test context"])
|
||||
|
||||
expected_prompt = "Context: Test context | Query: Test query | History: Past context 1\nPast context 2"
|
||||
self.assertEqual(prompt, expected_prompt)
|
||||
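Taken together, the three cases above reduce generate_prompt to a Template substitution over the joined contexts (plus history when the template asks for it); a minimal restatement:

from string import Template

template = Template("Context: $context | Query: $query | History: $history")
prompt = template.substitute(
    context=" | ".join(["Test context"]),
    query="Test query",
    history="\n".join(["Past context 1", "Past context 2"]),
)
# -> "Context: Test context | Query: Test query | History: Past context 1\nPast context 2"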
43
embedchain/tests/llm/test_google.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.google import GoogleLlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def google_llm_config():
|
||||
return BaseLlmConfig(model="gemini-pro", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
|
||||
|
||||
|
||||
def test_google_llm_init_missing_api_key(monkeypatch):
|
||||
monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
|
||||
with pytest.raises(ValueError, match="Please set the GOOGLE_API_KEY environment variable."):
|
||||
GoogleLlm()
|
||||
|
||||
|
||||
def test_google_llm_init(monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
|
||||
with monkeypatch.context() as m:
|
||||
m.setattr("importlib.import_module", lambda x: None)
|
||||
google_llm = GoogleLlm()
|
||||
assert google_llm is not None
|
||||
|
||||
|
||||
def test_google_llm_get_llm_model_answer_with_system_prompt(monkeypatch):
|
||||
monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
|
||||
monkeypatch.setattr("importlib.import_module", lambda x: None)
|
||||
google_llm = GoogleLlm(config=BaseLlmConfig(system_prompt="system prompt"))
|
||||
with pytest.raises(ValueError, match="GoogleLlm does not support `system_prompt`"):
|
||||
google_llm.get_llm_model_answer("test prompt")
|
||||
|
||||
|
||||
def test_google_llm_get_llm_model_answer(monkeypatch, google_llm_config):
|
||||
def mock_get_answer(prompt, config):
|
||||
return "Generated Text"
|
||||
|
||||
monkeypatch.setenv("GOOGLE_API_KEY", "fake_api_key")
|
||||
monkeypatch.setattr(GoogleLlm, "_get_answer", mock_get_answer)
|
||||
google_llm = GoogleLlm(config=google_llm_config)
|
||||
result = google_llm.get_llm_model_answer("test prompt")
|
||||
|
||||
assert result == "Generated Text"
|
||||
60
embedchain/tests/llm/test_gpt4all.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import pytest
|
||||
from langchain_community.llms.gpt4all import GPT4All as LangchainGPT4All
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.gpt4all import GPT4ALLLlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config():
|
||||
config = BaseLlmConfig(
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
top_p=0.8,
|
||||
stream=False,
|
||||
system_prompt="System prompt",
|
||||
model="orca-mini-3b-gguf2-q4_0.gguf",
|
||||
)
|
||||
yield config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gpt4all_with_config(config):
|
||||
return GPT4ALLLlm(config=config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gpt4all_without_config():
|
||||
return GPT4ALLLlm()
|
||||
|
||||
|
||||
def test_gpt4all_init_with_config(config, gpt4all_with_config):
|
||||
assert gpt4all_with_config.config.temperature == config.temperature
|
||||
assert gpt4all_with_config.config.max_tokens == config.max_tokens
|
||||
assert gpt4all_with_config.config.top_p == config.top_p
|
||||
assert gpt4all_with_config.config.stream == config.stream
|
||||
assert gpt4all_with_config.config.system_prompt == config.system_prompt
|
||||
assert gpt4all_with_config.config.model == config.model
|
||||
|
||||
assert isinstance(gpt4all_with_config.instance, LangchainGPT4All)
|
||||
|
||||
|
||||
def test_gpt4all_init_without_config(gpt4all_without_config):
|
||||
assert gpt4all_without_config.config.model == "orca-mini-3b-gguf2-q4_0.gguf"
|
||||
assert isinstance(gpt4all_without_config.instance, LangchainGPT4All)
|
||||
|
||||
|
||||
def test_get_llm_model_answer(mocker, gpt4all_with_config):
|
||||
test_query = "Test query"
|
||||
test_answer = "Test answer"
|
||||
|
||||
mocked_get_answer = mocker.patch("embedchain.llm.gpt4all.GPT4ALLLlm._get_answer", return_value=test_answer)
|
||||
answer = gpt4all_with_config.get_llm_model_answer(test_query)
|
||||
|
||||
assert answer == test_answer
|
||||
mocked_get_answer.assert_called_once_with(prompt=test_query, config=gpt4all_with_config.config)
|
||||
|
||||
|
||||
def test_gpt4all_model_switching(gpt4all_with_config):
|
||||
with pytest.raises(RuntimeError, match="GPT4ALLLlm does not support switching models at runtime."):
|
||||
gpt4all_with_config._get_answer("Test prompt", BaseLlmConfig(model="new_model"))
|
||||
83
embedchain/tests/llm/test_huggingface.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.huggingface import HuggingFaceLlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def huggingface_llm_config():
|
||||
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
|
||||
config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)
|
||||
yield config
|
||||
os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def huggingface_endpoint_config():
|
||||
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
|
||||
config = BaseLlmConfig(endpoint="https://api-inference.huggingface.co/models/gpt2", model_kwargs={"device": "cpu"})
|
||||
yield config
|
||||
os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
|
||||
|
||||
|
||||
def test_init_raises_value_error_without_api_key(mocker):
|
||||
mocker.patch.dict(os.environ, clear=True)
|
||||
with pytest.raises(ValueError):
|
||||
HuggingFaceLlm()
|
||||
|
||||
|
||||
def test_get_llm_model_answer_raises_value_error_for_system_prompt(huggingface_llm_config):
|
||||
llm = HuggingFaceLlm(huggingface_llm_config)
|
||||
llm.config.system_prompt = "system_prompt"
|
||||
with pytest.raises(ValueError):
|
||||
llm.get_llm_model_answer("prompt")
|
||||
|
||||
|
||||
def test_top_p_value_within_range():
|
||||
config = BaseLlmConfig(top_p=1.0)
|
||||
with pytest.raises(ValueError):
|
||||
HuggingFaceLlm._get_answer("test_prompt", config)
|
||||
|
||||
|
||||
def test_dependency_is_imported():
|
||||
importlib_installed = True
|
||||
try:
|
||||
importlib.import_module("huggingface_hub")
|
||||
except ImportError:
|
||||
importlib_installed = False
|
||||
assert importlib_installed
|
||||
|
||||
|
||||
def test_get_llm_model_answer(huggingface_llm_config, mocker):
|
||||
mocker.patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = HuggingFaceLlm(huggingface_llm_config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
|
||||
|
||||
def test_hugging_face_mock(huggingface_llm_config, mocker):
|
||||
mock_llm_instance = mocker.Mock(return_value="Test answer")
|
||||
mock_hf_hub = mocker.patch("embedchain.llm.huggingface.HuggingFaceHub")
|
||||
mock_hf_hub.return_value.invoke = mock_llm_instance
|
||||
|
||||
llm = HuggingFaceLlm(huggingface_llm_config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
assert answer == "Test answer"
|
||||
mock_llm_instance.assert_called_once_with("Test query")
|
||||
|
||||
|
||||
def test_custom_endpoint(huggingface_endpoint_config, mocker):
|
||||
mock_llm_instance = mocker.Mock(return_value="Test answer")
|
||||
mock_hf_endpoint = mocker.patch("embedchain.llm.huggingface.HuggingFaceEndpoint")
|
||||
mock_hf_endpoint.return_value.invoke = mock_llm_instance
|
||||
|
||||
llm = HuggingFaceLlm(huggingface_endpoint_config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
mock_llm_instance.assert_called_once_with("Test query")
|
||||
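The two fixtures above capture the two ways HuggingFaceLlm is configured in these tests; a hedged sketch of both shapes (the access token is a placeholder, and the routing claim, hub model vs. inference endpoint, comes from the mocks above):

import os

from embedchain.config import BaseLlmConfig

os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "hf_placeholder"

# Hosted hub model: handled via HuggingFaceHub in test_hugging_face_mock.
hub_config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)

# Dedicated inference endpoint: handled via HuggingFaceEndpoint in test_custom_endpoint.
endpoint_config = BaseLlmConfig(
    endpoint="https://api-inference.huggingface.co/models/gpt2",
    model_kwargs={"device": "cpu"},
)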
79
embedchain/tests/llm/test_jina.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.jina import JinaLlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config():
|
||||
os.environ["JINACHAT_API_KEY"] = "test_api_key"
|
||||
config = BaseLlmConfig(temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt")
|
||||
yield config
|
||||
os.environ.pop("JINACHAT_API_KEY")
|
||||
|
||||
|
||||
def test_init_raises_value_error_without_api_key(mocker):
|
||||
mocker.patch.dict(os.environ, clear=True)
|
||||
with pytest.raises(ValueError):
|
||||
JinaLlm()
|
||||
|
||||
|
||||
def test_get_llm_model_answer(config, mocker):
|
||||
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = JinaLlm(config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
mocked_get_answer.assert_called_once_with("Test query", config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_system_prompt(config, mocker):
|
||||
config.system_prompt = "Custom system prompt"
|
||||
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = JinaLlm(config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
mocked_get_answer.assert_called_once_with("Test query", config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_empty_prompt(config, mocker):
|
||||
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = JinaLlm(config)
|
||||
answer = llm.get_llm_model_answer("")
|
||||
|
||||
assert answer == "Test answer"
|
||||
mocked_get_answer.assert_called_once_with("", config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_streaming(config, mocker):
|
||||
config.stream = True
|
||||
mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
|
||||
|
||||
llm = JinaLlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_jinachat.assert_called_once()
|
||||
callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
|
||||
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_without_system_prompt(config, mocker):
|
||||
config.system_prompt = None
|
||||
mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
|
||||
|
||||
llm = JinaLlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_jinachat.assert_called_once_with(
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
jinachat_api_key=os.environ["JINACHAT_API_KEY"],
|
||||
model_kwargs={"top_p": config.top_p},
|
||||
)
|
||||
40
embedchain/tests/llm/test_llama2.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.llm.llama2 import Llama2Llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llama2_llm():
|
||||
os.environ["REPLICATE_API_TOKEN"] = "test_api_token"
|
||||
llm = Llama2Llm()
|
||||
return llm
|
||||
|
||||
|
||||
def test_init_raises_value_error_without_api_key(mocker):
|
||||
mocker.patch.dict(os.environ, clear=True)
|
||||
with pytest.raises(ValueError):
|
||||
Llama2Llm()
|
||||
|
||||
|
||||
def test_get_llm_model_answer_raises_value_error_for_system_prompt(llama2_llm):
|
||||
llama2_llm.config.system_prompt = "system_prompt"
|
||||
with pytest.raises(ValueError):
|
||||
llama2_llm.get_llm_model_answer("prompt")
|
||||
|
||||
|
||||
def test_get_llm_model_answer(llama2_llm, mocker):
|
||||
mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate")
|
||||
mocked_replicate_instance = mocker.MagicMock()
|
||||
mocked_replicate.return_value = mocked_replicate_instance
|
||||
mocked_replicate_instance.invoke.return_value = "Test answer"
|
||||
|
||||
llama2_llm.config.model = "test_model"
|
||||
llama2_llm.config.max_tokens = 50
|
||||
llama2_llm.config.temperature = 0.7
|
||||
llama2_llm.config.top_p = 0.8
|
||||
|
||||
answer = llama2_llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
87
embedchain/tests/llm/test_mistralai.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.mistralai import MistralAILlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mistralai_llm_config(monkeypatch):
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
|
||||
yield BaseLlmConfig(model="mistral-tiny", max_tokens=100, temperature=0.7, top_p=0.5, stream=False)
|
||||
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
|
||||
|
||||
|
||||
def test_mistralai_llm_init_missing_api_key(monkeypatch):
|
||||
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
|
||||
with pytest.raises(ValueError, match="Please set the MISTRAL_API_KEY environment variable."):
|
||||
MistralAILlm()
|
||||
|
||||
|
||||
def test_mistralai_llm_init(monkeypatch):
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "fake_api_key")
|
||||
llm = MistralAILlm()
|
||||
assert llm is not None
|
||||
|
||||
|
||||
def test_get_llm_model_answer(monkeypatch, mistralai_llm_config):
|
||||
def mock_get_answer(self, prompt, config):
|
||||
return "Generated Text"
|
||||
|
||||
monkeypatch.setattr(MistralAILlm, "_get_answer", mock_get_answer)
|
||||
llm = MistralAILlm(config=mistralai_llm_config)
|
||||
result = llm.get_llm_model_answer("test prompt")
|
||||
|
||||
assert result == "Generated Text"
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_system_prompt(monkeypatch, mistralai_llm_config):
|
||||
mistralai_llm_config.system_prompt = "Test system prompt"
|
||||
monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
|
||||
llm = MistralAILlm(config=mistralai_llm_config)
|
||||
result = llm.get_llm_model_answer("test prompt")
|
||||
|
||||
assert result == "Generated Text"
|
||||
|
||||
|
||||
def test_get_llm_model_answer_empty_prompt(monkeypatch, mistralai_llm_config):
|
||||
monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
|
||||
llm = MistralAILlm(config=mistralai_llm_config)
|
||||
result = llm.get_llm_model_answer("")
|
||||
|
||||
assert result == "Generated Text"
|
||||
|
||||
|
||||
def test_get_llm_model_answer_without_system_prompt(monkeypatch, mistralai_llm_config):
|
||||
mistralai_llm_config.system_prompt = None
|
||||
monkeypatch.setattr(MistralAILlm, "_get_answer", lambda self, prompt, config: "Generated Text")
|
||||
llm = MistralAILlm(config=mistralai_llm_config)
|
||||
result = llm.get_llm_model_answer("test prompt")
|
||||
|
||||
assert result == "Generated Text"
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_token_usage(monkeypatch, mistralai_llm_config):
|
||||
test_config = BaseLlmConfig(
|
||||
temperature=mistralai_llm_config.temperature,
|
||||
max_tokens=mistralai_llm_config.max_tokens,
|
||||
top_p=mistralai_llm_config.top_p,
|
||||
model=mistralai_llm_config.model,
|
||||
token_usage=True,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
MistralAILlm,
|
||||
"_get_answer",
|
||||
lambda self, prompt, config: ("Generated Text", {"prompt_tokens": 1, "completion_tokens": 2}),
|
||||
)
|
||||
|
||||
llm = MistralAILlm(test_config)
|
||||
answer, token_info = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Generated Text"
|
||||
assert token_info == {
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 2,
|
||||
"total_tokens": 3,
|
||||
"total_cost": 7.5e-07,
|
||||
"cost_currency": "USD",
|
||||
}
|
||||
52
embedchain/tests/llm/test_ollama.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.ollama import OllamaLlm
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ollama_llm_config():
|
||||
config = BaseLlmConfig(model="llama2", temperature=0.7, top_p=0.8, stream=True, system_prompt=None)
|
||||
yield config
|
||||
|
||||
|
||||
def test_get_llm_model_answer(ollama_llm_config, mocker):
|
||||
mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
|
||||
mocker.patch("embedchain.llm.ollama.OllamaLlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = OllamaLlm(ollama_llm_config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
|
||||
|
||||
def test_get_answer_mocked_ollama(ollama_llm_config, mocker):
|
||||
mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
|
||||
mocked_ollama = mocker.patch("embedchain.llm.ollama.Ollama")
|
||||
mock_instance = mocked_ollama.return_value
|
||||
mock_instance.invoke.return_value = "Mocked answer"
|
||||
|
||||
llm = OllamaLlm(ollama_llm_config)
|
||||
prompt = "Test query"
|
||||
answer = llm.get_llm_model_answer(prompt)
|
||||
|
||||
assert answer == "Mocked answer"
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_streaming(ollama_llm_config, mocker):
|
||||
ollama_llm_config.stream = True
|
||||
ollama_llm_config.callbacks = [StreamingStdOutCallbackHandler()]
|
||||
mocker.patch("embedchain.llm.ollama.Client.list", return_value={"models": [{"name": "llama2"}]})
|
||||
mocked_ollama_chat = mocker.patch("embedchain.llm.ollama.OllamaLlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = OllamaLlm(ollama_llm_config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_ollama_chat.assert_called_once()
|
||||
call_args = mocked_ollama_chat.call_args
|
||||
config_arg = call_args[1]["config"]
|
||||
callbacks = config_arg.callbacks
|
||||
|
||||
assert len(callbacks) == 1
|
||||
assert isinstance(callbacks[0], StreamingStdOutCallbackHandler)
|
||||
261
embedchain/tests/llm/test_openai.py
Normal file
@@ -0,0 +1,261 @@
|
||||
import os
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def env_config():
|
||||
os.environ["OPENAI_API_KEY"] = "test_api_key"
|
||||
os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1/engines/"
|
||||
yield
|
||||
os.environ.pop("OPENAI_API_KEY")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config(env_config):
|
||||
config = BaseLlmConfig(
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
top_p=0.8,
|
||||
stream=False,
|
||||
system_prompt="System prompt",
|
||||
model="gpt-3.5-turbo",
|
||||
http_client_proxies=None,
|
||||
http_async_client_proxies=None,
|
||||
)
|
||||
yield config
|
||||
|
||||
|
||||
def test_get_llm_model_answer(config, mocker):
|
||||
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
mocked_get_answer.assert_called_once_with("Test query", config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_system_prompt(config, mocker):
|
||||
config.system_prompt = "Custom system prompt"
|
||||
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
mocked_get_answer.assert_called_once_with("Test query", config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_empty_prompt(config, mocker):
|
||||
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
answer = llm.get_llm_model_answer("")
|
||||
|
||||
assert answer == "Test answer"
|
||||
mocked_get_answer.assert_called_once_with("", config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_token_usage(config, mocker):
|
||||
test_config = BaseLlmConfig(
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
top_p=config.top_p,
|
||||
stream=config.stream,
|
||||
system_prompt=config.system_prompt,
|
||||
model=config.model,
|
||||
token_usage=True,
|
||||
)
|
||||
mocked_get_answer = mocker.patch(
|
||||
"embedchain.llm.openai.OpenAILlm._get_answer",
|
||||
return_value=("Test answer", {"prompt_tokens": 1, "completion_tokens": 2}),
|
||||
)
|
||||
|
||||
llm = OpenAILlm(test_config)
|
||||
answer, token_info = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
assert token_info == {
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 2,
|
||||
"total_tokens": 3,
|
||||
"total_cost": 5.5e-06,
|
||||
"cost_currency": "USD",
|
||||
}
|
||||
mocked_get_answer.assert_called_once_with("Test query", test_config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_streaming(config, mocker):
|
||||
config.stream = True
|
||||
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_openai_chat.assert_called_once()
|
||||
callbacks = [callback[1]["callbacks"] for callback in mocked_openai_chat.call_args_list]
|
||||
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_without_system_prompt(config, mocker):
|
||||
config.system_prompt = None
|
||||
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_openai_chat.assert_called_once_with(
|
||||
model=config.model,
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
model_kwargs={"top_p": config.top_p},
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
base_url=os.environ["OPENAI_API_BASE"],
|
||||
http_client=None,
|
||||
http_async_client=None,
|
||||
)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_special_headers(config, mocker):
|
||||
config.default_headers = {"test": "test"}
|
||||
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_openai_chat.assert_called_once_with(
|
||||
model=config.model,
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
model_kwargs={"top_p": config.top_p},
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
base_url=os.environ["OPENAI_API_BASE"],
|
||||
default_headers={"test": "test"},
|
||||
http_client=None,
|
||||
http_async_client=None,
|
||||
)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_model_kwargs(config, mocker):
|
||||
config.model_kwargs = {"response_format": {"type": "json_object"}}
|
||||
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_openai_chat.assert_called_once_with(
|
||||
model=config.model,
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
model_kwargs={"top_p": config.top_p, "response_format": {"type": "json_object"}},
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
base_url=os.environ["OPENAI_API_BASE"],
|
||||
http_client=None,
|
||||
http_async_client=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mock_return, expected",
|
||||
[
|
||||
([{"test": "test"}], '{"test": "test"}'),
|
||||
([], "Input could not be mapped to the function!"),
|
||||
],
|
||||
)
|
||||
def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
|
||||
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||
mocked_convert_to_openai_tool = mocker.patch("langchain_core.utils.function_calling.convert_to_openai_tool")
|
||||
mocked_json_output_tools_parser = mocker.patch("langchain.output_parsers.openai_tools.JsonOutputToolsParser")
|
||||
mocked_openai_chat.return_value.bind.return_value.pipe.return_value.invoke.return_value = mock_return
|
||||
|
||||
llm = OpenAILlm(config, tools={"test": "test"})
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_openai_chat.assert_called_once_with(
|
||||
model=config.model,
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
model_kwargs={"top_p": config.top_p},
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
base_url=os.environ["OPENAI_API_BASE"],
|
||||
http_client=None,
|
||||
http_async_client=None,
|
||||
)
|
||||
mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
|
||||
mocked_json_output_tools_parser.assert_called_once()
|
||||
|
||||
assert answer == expected
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_http_client_proxies(env_config, mocker):
|
||||
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||
mock_http_client = mocker.Mock(spec=httpx.Client)
|
||||
mock_http_client_instance = mocker.Mock(spec=httpx.Client)
|
||||
mock_http_client.return_value = mock_http_client_instance
|
||||
|
||||
mocker.patch("httpx.Client", new=mock_http_client)
|
||||
|
||||
config = BaseLlmConfig(
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
top_p=0.8,
|
||||
stream=False,
|
||||
system_prompt="System prompt",
|
||||
model="gpt-3.5-turbo",
|
||||
http_client_proxies="http://testproxy.mem0.net:8000",
|
||||
)
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_openai_chat.assert_called_once_with(
|
||||
model=config.model,
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
model_kwargs={"top_p": config.top_p},
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
base_url=os.environ["OPENAI_API_BASE"],
|
||||
http_client=mock_http_client_instance,
|
||||
http_async_client=None,
|
||||
)
|
||||
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_http_async_client_proxies(env_config, mocker):
|
||||
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||
mock_http_async_client = mocker.Mock(spec=httpx.AsyncClient)
|
||||
mock_http_async_client_instance = mocker.Mock(spec=httpx.AsyncClient)
|
||||
mock_http_async_client.return_value = mock_http_async_client_instance
|
||||
|
||||
mocker.patch("httpx.AsyncClient", new=mock_http_async_client)
|
||||
|
||||
config = BaseLlmConfig(
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
top_p=0.8,
|
||||
stream=False,
|
||||
system_prompt="System prompt",
|
||||
model="gpt-3.5-turbo",
|
||||
http_async_client_proxies={"http://": "http://testproxy.mem0.net:8000"},
|
||||
)
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_openai_chat.assert_called_once_with(
|
||||
model=config.model,
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
model_kwargs={"top_p": config.top_p},
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
base_url=os.environ["OPENAI_API_BASE"],
|
||||
http_client=None,
|
||||
http_async_client=mock_http_async_client_instance,
|
||||
)
|
||||
mock_http_async_client.assert_called_once_with(proxies={"http://": "http://testproxy.mem0.net:8000"})
|
||||
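The two proxy tests above assert that OpenAILlm turns the proxy settings into httpx clients and hands them to ChatOpenAI; a sketch of the corresponding config (proxy URL reused from the tests):

from embedchain.config import BaseLlmConfig

config = BaseLlmConfig(
    model="gpt-3.5-turbo",
    # Becomes httpx.Client(proxies=...) passed as http_client:
    http_client_proxies="http://testproxy.mem0.net:8000",
    # Becomes httpx.AsyncClient(proxies=...) passed as http_async_client:
    http_async_client_proxies={"http://": "http://testproxy.mem0.net:8000"},
)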
79
embedchain/tests/llm/test_query.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
os.environ["OPENAI_API_KEY"] = "test_api_key"
|
||||
app = App(config=AppConfig(collect_metrics=False))
|
||||
return app
|
||||
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query(app):
|
||||
with patch.object(app, "_retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
answer = app.query(input_query="Test query")
|
||||
assert answer == "Test answer"
|
||||
|
||||
mock_retrieve.assert_called_once()
|
||||
_, kwargs = mock_retrieve.call_args
|
||||
input_query_arg = kwargs.get("input_query")
|
||||
assert input_query_arg == "Test query"
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
|
||||
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
|
||||
def test_query_config_app_passing(mock_get_answer):
|
||||
mock_get_answer.return_value = "Test answer"
|
||||
|
||||
config = AppConfig(collect_metrics=False)
|
||||
chat_config = BaseLlmConfig(system_prompt="Test system prompt")
|
||||
llm = OpenAILlm(config=chat_config)
|
||||
app = App(config=config, llm=llm)
|
||||
answer = app.llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert app.llm.config.system_prompt == "Test system prompt"
|
||||
assert answer == "Test answer"
|
||||
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query_with_where_in_params(app):
|
||||
with patch.object(app, "_retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
answer = app.query("Test query", where={"attribute": "value"})
|
||||
|
||||
assert answer == "Test answer"
|
||||
_, kwargs = mock_retrieve.call_args
|
||||
assert kwargs.get("input_query") == "Test query"
|
||||
assert kwargs.get("where") == {"attribute": "value"}
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query_with_where_in_query_config(app):
|
||||
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
with patch.object(app.db, "query") as mock_database_query:
|
||||
mock_database_query.return_value = ["Test context"]
|
||||
llm_config = BaseLlmConfig(where={"attribute": "value"})
|
||||
answer = app.query("Test query", llm_config)
|
||||
|
||||
assert answer == "Test answer"
|
||||
_, kwargs = mock_database_query.call_args
|
||||
assert kwargs.get("input_query") == "Test query"
|
||||
where = kwargs.get("where")
|
||||
assert "app_id" in where
|
||||
assert "attribute" in where
|
||||
mock_answer.assert_called_once()
|
||||
74
embedchain/tests/llm/test_together.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.together import TogetherLlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def together_llm_config():
|
||||
os.environ["TOGETHER_API_KEY"] = "test_api_key"
|
||||
config = BaseLlmConfig(model="together-ai-up-to-3b", max_tokens=50, temperature=0.7, top_p=0.8)
|
||||
yield config
|
||||
os.environ.pop("TOGETHER_API_KEY")
|
||||
|
||||
|
||||
def test_init_raises_value_error_without_api_key(mocker):
|
||||
mocker.patch.dict(os.environ, clear=True)
|
||||
with pytest.raises(ValueError):
|
||||
TogetherLlm()
|
||||
|
||||
|
||||
def test_get_llm_model_answer_raises_value_error_for_system_prompt(together_llm_config):
|
||||
llm = TogetherLlm(together_llm_config)
|
||||
llm.config.system_prompt = "system_prompt"
|
||||
with pytest.raises(ValueError):
|
||||
llm.get_llm_model_answer("prompt")
|
||||
|
||||
|
||||
def test_get_llm_model_answer(together_llm_config, mocker):
|
||||
mocker.patch("embedchain.llm.together.TogetherLlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = TogetherLlm(together_llm_config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_token_usage(together_llm_config, mocker):
|
||||
test_config = BaseLlmConfig(
|
||||
temperature=together_llm_config.temperature,
|
||||
max_tokens=together_llm_config.max_tokens,
|
||||
top_p=together_llm_config.top_p,
|
||||
model=together_llm_config.model,
|
||||
token_usage=True,
|
||||
)
|
||||
mocker.patch(
|
||||
"embedchain.llm.together.TogetherLlm._get_answer",
|
||||
return_value=("Test answer", {"prompt_tokens": 1, "completion_tokens": 2}),
|
||||
)
|
||||
|
||||
llm = TogetherLlm(test_config)
|
||||
answer, token_info = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
assert token_info == {
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 2,
|
||||
"total_tokens": 3,
|
||||
"total_cost": 3e-07,
|
||||
"cost_currency": "USD",
|
||||
}
|
||||
|
||||
|
||||
def test_get_answer_mocked_together(together_llm_config, mocker):
|
||||
mocked_together = mocker.patch("embedchain.llm.together.ChatTogether")
|
||||
mock_instance = mocked_together.return_value
|
||||
mock_instance.invoke.return_value.content = "Mocked answer"
|
||||
|
||||
llm = TogetherLlm(together_llm_config)
|
||||
prompt = "Test query"
|
||||
answer = llm.get_llm_model_answer(prompt)
|
||||
|
||||
assert answer == "Mocked answer"
|
||||
76
embedchain/tests/llm/test_vertex_ai.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.core.db.database import database_manager
|
||||
from embedchain.llm.vertex_ai import VertexAILlm
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_database():
|
||||
database_manager.setup_engine()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vertexai_llm():
|
||||
config = BaseLlmConfig(temperature=0.6, model="chat-bison")
|
||||
return VertexAILlm(config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer(vertexai_llm):
|
||||
with patch.object(VertexAILlm, "_get_answer", return_value="Test Response") as mock_method:
|
||||
prompt = "Test Prompt"
|
||||
response = vertexai_llm.get_llm_model_answer(prompt)
|
||||
assert response == "Test Response"
|
||||
mock_method.assert_called_once_with(prompt, vertexai_llm.config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_token_usage(vertexai_llm):
|
||||
test_config = BaseLlmConfig(
|
||||
temperature=vertexai_llm.config.temperature,
|
||||
max_tokens=vertexai_llm.config.max_tokens,
|
||||
top_p=vertexai_llm.config.top_p,
|
||||
model=vertexai_llm.config.model,
|
||||
token_usage=True,
|
||||
)
|
||||
vertexai_llm.config = test_config
|
||||
with patch.object(
|
||||
VertexAILlm,
|
||||
"_get_answer",
|
||||
return_value=("Test Response", {"prompt_token_count": 1, "candidates_token_count": 2}),
|
||||
):
|
||||
response, token_info = vertexai_llm.get_llm_model_answer("Test Query")
|
||||
assert response == "Test Response"
|
||||
assert token_info == {
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 2,
|
||||
"total_tokens": 3,
|
||||
"total_cost": 3.75e-07,
|
||||
"cost_currency": "USD",
|
||||
}
|
||||
|
||||
|
||||
@patch("embedchain.llm.vertex_ai.ChatVertexAI")
|
||||
def test_get_answer(mock_chat_vertexai, vertexai_llm, caplog):
|
||||
mock_chat_vertexai.return_value.invoke.return_value = MagicMock(content="Test Response")
|
||||
|
||||
config = vertexai_llm.config
|
||||
prompt = "Test Prompt"
|
||||
messages = vertexai_llm._get_messages(prompt)
|
||||
response = vertexai_llm._get_answer(prompt, config)
|
||||
mock_chat_vertexai.return_value.invoke.assert_called_once_with(messages)
|
||||
|
||||
assert response == "Test Response" # Assertion corrected
|
||||
assert "Config option `top_p` is not supported by this model." not in caplog.text
|
||||
|
||||
|
||||
def test_get_messages(vertexai_llm):
|
||||
prompt = "Test Prompt"
|
||||
system_prompt = "Test System Prompt"
|
||||
messages = vertexai_llm._get_messages(prompt, system_prompt)
|
||||
assert messages == [
|
||||
SystemMessage(content="Test System Prompt", additional_kwargs={}),
|
||||
HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
|
||||
]
|
||||
100
embedchain/tests/loaders/test_audio.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
if sys.version_info > (3, 10): # as `match` statement was introduced in python 3.10
|
||||
from deepgram import PrerecordedOptions
|
||||
|
||||
from embedchain.loaders.audio import AudioLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_audio_loader(mocker):
|
||||
mock_deepgram = mocker.patch("deepgram.DeepgramClient")
mock_dbx = mocker.MagicMock()
mock_deepgram.return_value = mock_dbx
|
||||
|
||||
os.environ["DEEPGRAM_API_KEY"] = "test_key"
|
||||
loader = AudioLoader()
|
||||
loader.client = mock_dbx
|
||||
|
||||
yield loader, mock_dbx
|
||||
|
||||
if "DEEPGRAM_API_KEY" in os.environ:
|
||||
del os.environ["DEEPGRAM_API_KEY"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
|
||||
) # as `match` statement was introduced in python 3.10
|
||||
def test_initialization(setup_audio_loader):
|
||||
"""Test initialization of AudioLoader."""
|
||||
loader, _ = setup_audio_loader
|
||||
assert loader is not None
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
|
||||
) # as `match` statement was introduced in python 3.10
|
||||
def test_load_data_from_url(setup_audio_loader):
|
||||
loader, mock_dbx = setup_audio_loader
|
||||
url = "https://example.com/audio.mp3"
|
||||
expected_content = "This is a test audio transcript."
|
||||
|
||||
mock_response = {"results": {"channels": [{"alternatives": [{"transcript": expected_content}]}]}}
|
||||
mock_dbx.listen.prerecorded.v.return_value.transcribe_url.return_value = mock_response
|
||||
|
||||
result = loader.load_data(url)
|
||||
|
||||
doc_id = hashlib.sha256((expected_content + url).encode()).hexdigest()
|
||||
expected_result = {
|
||||
"doc_id": doc_id,
|
||||
"data": [
|
||||
{
|
||||
"content": expected_content,
|
||||
"meta_data": {"url": url},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
assert result == expected_result
|
||||
mock_dbx.listen.prerecorded.v.assert_called_once_with("1")
|
||||
mock_dbx.listen.prerecorded.v.return_value.transcribe_url.assert_called_once_with(
|
||||
{"url": url}, PrerecordedOptions(model="nova-2", smart_format=True)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 10), reason="Test skipped for Python 3.9 or lower"
|
||||
) # as `match` statement was introduced in python 3.10
|
||||
def test_load_data_from_file(setup_audio_loader):
|
||||
loader, mock_dbx = setup_audio_loader
|
||||
file_path = "local_audio.mp3"
|
||||
expected_content = "This is a test audio transcript."
|
||||
|
||||
mock_response = {"results": {"channels": [{"alternatives": [{"transcript": expected_content}]}]}}
|
||||
mock_dbx.listen.prerecorded.v.return_value.transcribe_file.return_value = mock_response
|
||||
|
||||
# Mock the file reading functionality
|
||||
with patch("builtins.open", mock_open(read_data=b"some data")) as mock_file:
|
||||
result = loader.load_data(file_path)
|
||||
|
||||
doc_id = hashlib.sha256((expected_content + file_path).encode()).hexdigest()
|
||||
expected_result = {
|
||||
"doc_id": doc_id,
|
||||
"data": [
|
||||
{
|
||||
"content": expected_content,
|
||||
"meta_data": {"url": file_path},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
assert result == expected_result
|
||||
mock_dbx.listen.prerecorded.v.assert_called_once_with("1")
|
||||
mock_dbx.listen.prerecorded.v.return_value.transcribe_file.assert_called_once_with(
|
||||
{"buffer": mock_file.return_value}, PrerecordedOptions(model="nova-2", smart_format=True)
|
||||
)
|
||||
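The doc_id checked in both audio tests is simply a SHA-256 over the transcript plus the source, exactly as computed in the tests:

import hashlib

transcript = "This is a test audio transcript."
source = "https://example.com/audio.mp3"
doc_id = hashlib.sha256((transcript + source).encode()).hexdigest()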
113
embedchain/tests/loaders/test_csv.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import csv
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.csv import CsvLoader
|
||||
|
||||
|
||||
@pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
|
||||
def test_load_data(delimiter):
|
||||
"""
|
||||
Test csv loader
|
||||
|
||||
Tests that file is loaded, metadata is correct and content is correct
|
||||
"""
|
||||
# Creating temporary CSV file
|
||||
with tempfile.NamedTemporaryFile(mode="w+", newline="", delete=False) as tmpfile:
|
||||
writer = csv.writer(tmpfile, delimiter=delimiter)
|
||||
writer.writerow(["Name", "Age", "Occupation"])
|
||||
writer.writerow(["Alice", "28", "Engineer"])
|
||||
writer.writerow(["Bob", "35", "Doctor"])
|
||||
writer.writerow(["Charlie", "22", "Student"])
|
||||
|
||||
tmpfile.seek(0)
|
||||
filename = tmpfile.name
|
||||
|
||||
# Loading CSV using CsvLoader
|
||||
loader = CsvLoader()
|
||||
result = loader.load_data(filename)
|
||||
data = result["data"]
|
||||
|
||||
# Assertions
|
||||
assert len(data) == 3
|
||||
assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
|
||||
assert data[0]["meta_data"]["url"] == filename
|
||||
assert data[0]["meta_data"]["row"] == 1
|
||||
assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
|
||||
assert data[1]["meta_data"]["url"] == filename
|
||||
assert data[1]["meta_data"]["row"] == 2
|
||||
assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
|
||||
assert data[2]["meta_data"]["url"] == filename
|
||||
assert data[2]["meta_data"]["row"] == 3
|
||||
|
||||
# Cleaning up the temporary file
|
||||
os.unlink(filename)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("delimiter", [",", "\t", ";", "|"])
|
||||
def test_load_data_with_file_uri(delimiter):
|
||||
"""
|
||||
Test csv loader with file URI
|
||||
|
||||
Tests that file is loaded, metadata is correct and content is correct
|
||||
"""
|
||||
# Creating temporary CSV file
|
||||
with tempfile.NamedTemporaryFile(mode="w+", newline="", delete=False) as tmpfile:
|
||||
writer = csv.writer(tmpfile, delimiter=delimiter)
|
||||
writer.writerow(["Name", "Age", "Occupation"])
|
||||
writer.writerow(["Alice", "28", "Engineer"])
|
||||
writer.writerow(["Bob", "35", "Doctor"])
|
||||
writer.writerow(["Charlie", "22", "Student"])
|
||||
|
||||
tmpfile.seek(0)
|
||||
filename = pathlib.Path(tmpfile.name).as_uri() # Convert path to file URI
|
||||
|
||||
# Loading CSV using CsvLoader
|
||||
loader = CsvLoader()
|
||||
result = loader.load_data(filename)
|
||||
data = result["data"]
|
||||
|
||||
# Assertions
|
||||
assert len(data) == 3
|
||||
assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
|
||||
assert data[0]["meta_data"]["url"] == filename
|
||||
assert data[0]["meta_data"]["row"] == 1
|
||||
assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
|
||||
assert data[1]["meta_data"]["url"] == filename
|
||||
assert data[1]["meta_data"]["row"] == 2
|
||||
assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
|
||||
assert data[2]["meta_data"]["url"] == filename
|
||||
assert data[2]["meta_data"]["row"] == 3
|
||||
|
||||
# Cleaning up the temporary file
|
||||
os.unlink(tmpfile.name)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("content", ["ftp://example.com", "sftp://example.com", "mailto://example.com"])
|
||||
def test_get_file_content(content):
|
||||
with pytest.raises(ValueError):
|
||||
loader = CsvLoader()
|
||||
loader._get_file_content(content)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("content", ["http://example.com", "https://example.com"])
|
||||
def test_get_file_content_http(content):
|
||||
"""
|
||||
Test _get_file_content method of CsvLoader for http and https URLs
|
||||
"""
|
||||
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Name,Age,Occupation\nAlice,28,Engineer\nBob,35,Doctor\nCharlie,22,Student"
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
loader = CsvLoader()
|
||||
file_content = loader._get_file_content(content)
|
||||
|
||||
mock_get.assert_called_once_with(content)
|
||||
mock_response.raise_for_status.assert_called_once()
|
||||
assert file_content.read() == mock_response.text
|
||||
104
embedchain/tests/loaders/test_discourse.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from embedchain.loaders.discourse import DiscourseLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def discourse_loader_config():
|
||||
return {
|
||||
"domain": "https://example.com/",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def discourse_loader(discourse_loader_config):
|
||||
return DiscourseLoader(config=discourse_loader_config)
|
||||
|
||||
|
||||
def test_discourse_loader_init_with_valid_config():
|
||||
config = {"domain": "https://example.com/"}
|
||||
loader = DiscourseLoader(config=config)
|
||||
assert loader.domain == "https://example.com/"
|
||||
|
||||
|
||||
def test_discourse_loader_init_with_missing_config():
|
||||
with pytest.raises(ValueError, match="DiscourseLoader requires a config"):
|
||||
DiscourseLoader()
|
||||
|
||||
|
||||
def test_discourse_loader_init_with_missing_domain():
|
||||
config = {"another_key": "value"}
|
||||
with pytest.raises(ValueError, match="DiscourseLoader requires a domain"):
|
||||
DiscourseLoader(config=config)
|
||||
|
||||
|
||||
def test_discourse_loader_check_query_with_valid_query(discourse_loader):
|
||||
discourse_loader._check_query("sample query")
|
||||
|
||||
|
||||
def test_discourse_loader_check_query_with_empty_query(discourse_loader):
|
||||
with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
|
||||
discourse_loader._check_query("")
|
||||
|
||||
|
||||
def test_discourse_loader_check_query_with_invalid_query_type(discourse_loader):
|
||||
with pytest.raises(ValueError, match="DiscourseLoader requires a query"):
|
||||
discourse_loader._check_query(123)
|
||||
|
||||
|
||||
def test_discourse_loader_load_post_with_valid_post_id(discourse_loader, monkeypatch):
|
||||
def mock_get(*args, **kwargs):
|
||||
class MockResponse:
|
||||
def json(self):
|
||||
return {"raw": "Sample post content"}
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
return MockResponse()
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
post_data = discourse_loader._load_post(123)
|
||||
|
||||
assert post_data["content"] == "Sample post content"
|
||||
assert "meta_data" in post_data
|
||||
|
||||
|
||||
def test_discourse_loader_load_data_with_valid_query(discourse_loader, monkeypatch):
|
||||
def mock_get(*args, **kwargs):
|
||||
class MockResponse:
|
||||
def json(self):
|
||||
return {"grouped_search_result": {"post_ids": [123, 456, 789]}}
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
return MockResponse()
|
||||
|
||||
monkeypatch.setattr(requests, "get", mock_get)
|
||||
|
||||
def mock_load_post(*args, **kwargs):
|
||||
return {
|
||||
"content": "Sample post content",
|
||||
"meta_data": {
|
||||
"url": "https://example.com/posts/123.json",
|
||||
"created_at": "2021-01-01",
|
||||
"username": "test_user",
|
||||
"topic_slug": "test_topic",
|
||||
"score": 10,
|
||||
},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(discourse_loader, "_load_post", mock_load_post)
|
||||
|
||||
data = discourse_loader.load_data("sample query")
|
||||
|
||||
assert len(data["data"]) == 3
|
||||
assert data["data"][0]["content"] == "Sample post content"
|
||||
assert data["data"][0]["meta_data"]["url"] == "https://example.com/posts/123.json"
|
||||
assert data["data"][0]["meta_data"]["created_at"] == "2021-01-01"
|
||||
assert data["data"][0]["meta_data"]["username"] == "test_user"
|
||||
assert data["data"][0]["meta_data"]["topic_slug"] == "test_topic"
|
||||
assert data["data"][0]["meta_data"]["score"] == 10
|
||||
130
embedchain/tests/loaders/test_docs_site.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import hashlib
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from requests import Response
|
||||
|
||||
from embedchain.loaders.docs_site_loader import DocsSiteLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_requests_get():
|
||||
with patch("requests.get") as mock_get:
|
||||
yield mock_get
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def docs_site_loader():
|
||||
return DocsSiteLoader()
|
||||
|
||||
|
||||
def test_get_child_links_recursive(mock_requests_get, docs_site_loader):
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = """
|
||||
<html>
|
||||
<a href="/page1">Page 1</a>
|
||||
<a href="/page2">Page 2</a>
|
||||
</html>
|
||||
"""
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
docs_site_loader._get_child_links_recursive("https://example.com")
|
||||
|
||||
assert len(docs_site_loader.visited_links) == 2
|
||||
assert "https://example.com/page1" in docs_site_loader.visited_links
|
||||
assert "https://example.com/page2" in docs_site_loader.visited_links
|
||||
|
||||
|
||||
def test_get_child_links_recursive_status_not_200(mock_requests_get, docs_site_loader):
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 404
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
docs_site_loader._get_child_links_recursive("https://example.com")
|
||||
|
||||
assert len(docs_site_loader.visited_links) == 0
|
||||
|
||||
|
||||
def test_get_all_urls(mock_requests_get, docs_site_loader):
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.text = """
|
||||
<html>
|
||||
<a href="/page1">Page 1</a>
|
||||
<a href="/page2">Page 2</a>
|
||||
<a href="https://example.com/external">External</a>
|
||||
</html>
|
||||
"""
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
all_urls = docs_site_loader._get_all_urls("https://example.com")
|
||||
|
||||
assert len(all_urls) == 3
|
||||
assert "https://example.com/page1" in all_urls
|
||||
assert "https://example.com/page2" in all_urls
|
||||
assert "https://example.com/external" in all_urls
|
||||
|
||||
|
||||
def test_load_data_from_url(mock_requests_get, docs_site_loader):
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = """
|
||||
<html>
|
||||
<nav>
|
||||
<h1>Navigation</h1>
|
||||
</nav>
|
||||
<article class="bd-article">
|
||||
<p>Article Content</p>
|
||||
</article>
|
||||
</html>
|
||||
""".encode()
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
data = docs_site_loader._load_data_from_url("https://example.com/page1")
|
||||
|
||||
assert len(data) == 1
|
||||
assert data[0]["content"] == "Article Content"
|
||||
assert data[0]["meta_data"]["url"] == "https://example.com/page1"
|
||||
|
||||
|
||||
def test_load_data_from_url_status_not_200(mock_requests_get, docs_site_loader):
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 404
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
data = docs_site_loader._load_data_from_url("https://example.com/page1")
|
||||
|
||||
assert data == []
|
||||
assert len(data) == 0
|
||||
|
||||
|
||||
def test_load_data(mock_requests_get, docs_site_loader):
|
||||
mock_response = Response()
|
||||
mock_response.status_code = 200
|
||||
mock_response._content = """
|
||||
<html>
|
||||
<a href="/page1">Page 1</a>
|
||||
<a href="/page2">Page 2</a>
|
||||
""".encode()
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
url = "https://example.com"
|
||||
data = docs_site_loader.load_data(url)
|
||||
expected_doc_id = hashlib.sha256((" ".join(docs_site_loader.visited_links) + url).encode()).hexdigest()
|
||||
|
||||
assert len(data["data"]) == 2
|
||||
assert data["doc_id"] == expected_doc_id
|
||||
|
||||
|
||||
def test_if_response_status_not_200(mock_requests_get, docs_site_loader):
|
||||
mock_response = Response()
|
||||
mock_response.status_code = 404
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
url = "https://example.com"
|
||||
data = docs_site_loader.load_data(url)
|
||||
expected_doc_id = hashlib.sha256((" ".join(docs_site_loader.visited_links) + url).encode()).hexdigest()
|
||||
|
||||
assert len(data["data"]) == 0
|
||||
assert data["doc_id"] == expected_doc_id
|
||||
218
embedchain/tests/loaders/test_docs_site_loader.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import pytest
|
||||
import responses
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ignored_tag",
|
||||
[
|
||||
"<nav>This is a navigation bar.</nav>",
|
||||
"<aside>This is an aside.</aside>",
|
||||
"<form>This is a form.</form>",
|
||||
"<header>This is a header.</header>",
|
||||
"<noscript>This is a noscript.</noscript>",
|
||||
"<svg>This is an SVG.</svg>",
|
||||
"<canvas>This is a canvas.</canvas>",
|
||||
"<footer>This is a footer.</footer>",
|
||||
"<script>This is a script.</script>",
|
||||
"<style>This is a style.</style>",
|
||||
],
|
||||
ids=["nav", "aside", "form", "header", "noscript", "svg", "canvas", "footer", "script", "style"],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"selectee",
|
||||
[
|
||||
"""
|
||||
<article class="bd-article">
|
||||
<h2>Article Title</h2>
|
||||
<p>Article content goes here.</p>
|
||||
{ignored_tag}
|
||||
</article>
|
||||
""",
|
||||
"""
|
||||
<article role="main">
|
||||
<h2>Main Article Title</h2>
|
||||
<p>Main article content goes here.</p>
|
||||
{ignored_tag}
|
||||
</article>
|
||||
""",
|
||||
"""
|
||||
<div class="md-content">
|
||||
<h2>Markdown Content</h2>
|
||||
<p>Markdown content goes here.</p>
|
||||
{ignored_tag}
|
||||
</div>
|
||||
""",
|
||||
"""
|
||||
<div role="main">
|
||||
<h2>Main Content</h2>
|
||||
<p>Main content goes here.</p>
|
||||
{ignored_tag}
|
||||
</div>
|
||||
""",
|
||||
"""
|
||||
<div class="container">
|
||||
<h2>Container</h2>
|
||||
<p>Container content goes here.</p>
|
||||
{ignored_tag}
|
||||
</div>
|
||||
""",
|
||||
"""
|
||||
<div class="section">
|
||||
<h2>Section</h2>
|
||||
<p>Section content goes here.</p>
|
||||
{ignored_tag}
|
||||
</div>
|
||||
""",
|
||||
"""
|
||||
<article>
|
||||
<h2>Generic Article</h2>
|
||||
<p>Generic article content goes here.</p>
|
||||
{ignored_tag}
|
||||
</article>
|
||||
""",
|
||||
"""
|
||||
<main>
|
||||
<h2>Main Content</h2>
|
||||
<p>Main content goes here.</p>
|
||||
{ignored_tag}
|
||||
</main>
|
||||
""",
|
||||
],
|
||||
ids=[
|
||||
"article.bd-article",
|
||||
'article[role="main"]',
|
||||
"div.md-content",
|
||||
'div[role="main"]',
|
||||
"div.container",
|
||||
"div.section",
|
||||
"article",
|
||||
"main",
|
||||
],
|
||||
)
|
||||
def test_load_data_gets_by_selectors_and_ignored_tags(selectee, ignored_tag, loader, mocked_responses, mocker):
|
||||
child_url = "https://docs.embedchain.ai/quickstart"
|
||||
selectee = selectee.format(ignored_tag=ignored_tag)
|
||||
html_body = """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<body>
|
||||
{selectee}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
html_body = html_body.format(selectee=selectee)
|
||||
mocked_responses.get(child_url, body=html_body, status=200, content_type="text/html")
|
||||
|
||||
url = "https://docs.embedchain.ai/"
|
||||
html_body = """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<body>
|
||||
<li><a href="/quickstart">Quickstart</a></li>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
mocked_responses.get(url, body=html_body, status=200, content_type="text/html")
|
||||
|
||||
mock_sha256 = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
|
||||
doc_id = "mocked_hash"
|
||||
mock_sha256.return_value.hexdigest.return_value = doc_id
|
||||
|
||||
result = loader.load_data(url)
|
||||
selector_soup = BeautifulSoup(selectee, "html.parser")
|
||||
expected_content = " ".join((selector_soup.select_one("h2").get_text(), selector_soup.select_one("p").get_text()))
|
||||
assert result["doc_id"] == doc_id
|
||||
assert result["data"] == [
|
||||
{
|
||||
"content": expected_content,
|
||||
"meta_data": {"url": "https://docs.embedchain.ai/quickstart"},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_load_data_gets_child_links_recursively(loader, mocked_responses, mocker):
|
||||
child_url = "https://docs.embedchain.ai/quickstart"
|
||||
html_body = """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<body>
|
||||
<li><a href="/">..</a></li>
|
||||
<li><a href="/quickstart">.</a></li>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
mocked_responses.get(child_url, body=html_body, status=200, content_type="text/html")
|
||||
|
||||
child_url = "https://docs.embedchain.ai/introduction"
|
||||
html_body = """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<body>
|
||||
<li><a href="/">..</a></li>
|
||||
<li><a href="/introduction">.</a></li>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
mocked_responses.get(child_url, body=html_body, status=200, content_type="text/html")
|
||||
|
||||
url = "https://docs.embedchain.ai/"
|
||||
html_body = """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<body>
|
||||
<li><a href="/quickstart">Quickstart</a></li>
|
||||
<li><a href="/introduction">Introduction</a></li>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
mocked_responses.get(url, body=html_body, status=200, content_type="text/html")
|
||||
|
||||
mock_sha256 = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
|
||||
doc_id = "mocked_hash"
|
||||
mock_sha256.return_value.hexdigest.return_value = doc_id
|
||||
|
||||
result = loader.load_data(url)
|
||||
assert result["doc_id"] == doc_id
|
||||
expected_data = [
|
||||
{"content": "..\n.", "meta_data": {"url": "https://docs.embedchain.ai/quickstart"}},
|
||||
{"content": "..\n.", "meta_data": {"url": "https://docs.embedchain.ai/introduction"}},
|
||||
]
|
||||
assert all(item in expected_data for item in result["data"])
|
||||
|
||||
|
||||
def test_load_data_fails_to_fetch_website(loader, mocked_responses, mocker):
|
||||
child_url = "https://docs.embedchain.ai/introduction"
|
||||
mocked_responses.get(child_url, status=404)
|
||||
|
||||
url = "https://docs.embedchain.ai/"
|
||||
html_body = """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<body>
|
||||
<li><a href="/introduction">Introduction</a></li>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
mocked_responses.get(url, body=html_body, status=200, content_type="text/html")
|
||||
|
||||
mock_sha256 = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
|
||||
doc_id = "mocked_hash"
|
||||
mock_sha256.return_value.hexdigest.return_value = doc_id
|
||||
|
||||
result = loader.load_data(url)
|
||||
assert result["doc_id"] is doc_id
|
||||
assert result["data"] == []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def loader():
|
||||
from embedchain.loaders.docs_site_loader import DocsSiteLoader
|
||||
|
||||
return DocsSiteLoader()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_responses():
|
||||
with responses.RequestsMock() as rsps:
|
||||
yield rsps
|
||||
39
embedchain/tests/loaders/test_docx_file.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import hashlib
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.docx_file import DocxFileLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_docx2txt_loader():
|
||||
with patch("embedchain.loaders.docx_file.Docx2txtLoader") as mock_loader:
|
||||
yield mock_loader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def docx_file_loader():
|
||||
return DocxFileLoader()
|
||||
|
||||
|
||||
def test_load_data(mock_docx2txt_loader, docx_file_loader):
|
||||
mock_url = "mock_docx_file.docx"
|
||||
|
||||
mock_loader = MagicMock()
|
||||
mock_loader.load.return_value = [MagicMock(page_content="Sample Docx Content", metadata={"url": "local"})]
|
||||
|
||||
mock_docx2txt_loader.return_value = mock_loader
|
||||
|
||||
result = docx_file_loader.load_data(mock_url)
|
||||
|
||||
assert "doc_id" in result
|
||||
assert "data" in result
|
||||
|
||||
expected_content = "Sample Docx Content"
|
||||
assert result["data"][0]["content"] == expected_content
|
||||
|
||||
assert result["data"][0]["meta_data"]["url"] == "local"
|
||||
|
||||
expected_doc_id = hashlib.sha256((expected_content + mock_url).encode()).hexdigest()
|
||||
assert result["doc_id"] == expected_doc_id
|
||||
85
embedchain/tests/loaders/test_dropbox.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from dropbox.files import FileMetadata
|
||||
|
||||
from embedchain.loaders.dropbox import DropboxLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_dropbox_loader(mocker):
|
||||
mock_dropbox = mocker.patch("dropbox.Dropbox")
|
||||
mock_dbx = mocker.MagicMock()
|
||||
mock_dropbox.return_value = mock_dbx
|
||||
|
||||
os.environ["DROPBOX_ACCESS_TOKEN"] = "test_token"
|
||||
loader = DropboxLoader()
|
||||
|
||||
yield loader, mock_dbx
|
||||
|
||||
if "DROPBOX_ACCESS_TOKEN" in os.environ:
|
||||
del os.environ["DROPBOX_ACCESS_TOKEN"]
|
||||
|
||||
|
||||
def test_initialization(setup_dropbox_loader):
|
||||
"""Test initialization of DropboxLoader."""
|
||||
loader, _ = setup_dropbox_loader
|
||||
assert loader is not None
|
||||
|
||||
|
||||
def test_download_folder(setup_dropbox_loader, mocker):
|
||||
"""Test downloading a folder."""
|
||||
loader, mock_dbx = setup_dropbox_loader
|
||||
mocker.patch("os.makedirs")
|
||||
mocker.patch("os.path.join", return_value="mock/path")
|
||||
|
||||
mock_file_metadata = mocker.MagicMock(spec=FileMetadata)
|
||||
mock_dbx.files_list_folder.return_value.entries = [mock_file_metadata]
|
||||
|
||||
entries = loader._download_folder("path/to/folder", "local_root")
|
||||
assert entries is not None
|
||||
|
||||
|
||||
def test_generate_dir_id_from_all_paths(setup_dropbox_loader, mocker):
|
||||
"""Test directory ID generation."""
|
||||
loader, mock_dbx = setup_dropbox_loader
|
||||
mock_file_metadata = mocker.MagicMock(spec=FileMetadata, name="file.txt")
|
||||
mock_dbx.files_list_folder.return_value.entries = [mock_file_metadata]
|
||||
|
||||
dir_id = loader._generate_dir_id_from_all_paths("path/to/folder")
|
||||
assert dir_id is not None
|
||||
assert len(dir_id) == 64
|
||||
|
||||
|
||||
def test_clean_directory(setup_dropbox_loader, mocker):
|
||||
"""Test cleaning up a directory."""
|
||||
loader, _ = setup_dropbox_loader
|
||||
mocker.patch("os.listdir", return_value=["file1", "file2"])
|
||||
mocker.patch("os.remove")
|
||||
mocker.patch("os.rmdir")
|
||||
|
||||
loader._clean_directory("path/to/folder")
|
||||
|
||||
|
||||
def test_load_data(mocker, setup_dropbox_loader, tmp_path):
|
||||
loader = setup_dropbox_loader[0]
|
||||
|
||||
mock_file_metadata = MagicMock(spec=FileMetadata, name="file.txt")
|
||||
mocker.patch.object(loader.dbx, "files_list_folder", return_value=MagicMock(entries=[mock_file_metadata]))
|
||||
mocker.patch.object(loader.dbx, "files_download_to_file")
|
||||
|
||||
# Mock DirectoryLoader
|
||||
mock_data = {"data": "test_data"}
|
||||
mocker.patch("embedchain.loaders.directory_loader.DirectoryLoader.load_data", return_value=mock_data)
|
||||
|
||||
test_dir = tmp_path / "dropbox_test"
|
||||
test_dir.mkdir()
|
||||
test_file = test_dir / "file.txt"
|
||||
test_file.write_text("dummy content")
|
||||
mocker.patch.object(loader, "_generate_dir_id_from_all_paths", return_value=str(test_dir))
|
||||
|
||||
result = loader.load_data("path/to/folder")
|
||||
|
||||
assert result == {"doc_id": mocker.ANY, "data": "test_data"}
|
||||
loader.dbx.files_list_folder.assert_called_once_with("path/to/folder")
|
||||
33
embedchain/tests/loaders/test_excel_file.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import hashlib
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.excel_file import ExcelFileLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def excel_file_loader():
|
||||
return ExcelFileLoader()
|
||||
|
||||
|
||||
def test_load_data(excel_file_loader):
|
||||
mock_url = "mock_excel_file.xlsx"
|
||||
expected_content = "Sample Excel Content"
|
||||
|
||||
# Mock the load_data method of the excel_file_loader instance
|
||||
with patch.object(
|
||||
excel_file_loader,
|
||||
"load_data",
|
||||
return_value={
|
||||
"doc_id": hashlib.sha256((expected_content + mock_url).encode()).hexdigest(),
|
||||
"data": [{"content": expected_content, "meta_data": {"url": mock_url}}],
|
||||
},
|
||||
):
|
||||
result = excel_file_loader.load_data(mock_url)
|
||||
|
||||
assert result["data"][0]["content"] == expected_content
|
||||
assert result["data"][0]["meta_data"]["url"] == mock_url
|
||||
|
||||
expected_doc_id = hashlib.sha256((expected_content + mock_url).encode()).hexdigest()
|
||||
assert result["doc_id"] == expected_doc_id
|
||||
33
embedchain/tests/loaders/test_github.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.github import GithubLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_github_loader_config():
|
||||
return {
|
||||
"token": "your_mock_token",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_github_loader(mocker, mock_github_loader_config):
|
||||
mock_github = mocker.patch("github.Github")
|
||||
_ = mock_github.return_value
|
||||
return GithubLoader(config=mock_github_loader_config)
|
||||
|
||||
|
||||
def test_github_loader_init(mocker, mock_github_loader_config):
|
||||
mock_github = mocker.patch("github.Github")
|
||||
GithubLoader(config=mock_github_loader_config)
|
||||
mock_github.assert_called_once_with("your_mock_token")
|
||||
|
||||
|
||||
def test_github_loader_init_empty_config(mocker):
|
||||
with pytest.raises(ValueError, match="requires a personal access token"):
|
||||
GithubLoader()
|
||||
|
||||
|
||||
def test_github_loader_init_missing_token():
|
||||
with pytest.raises(ValueError, match="requires a personal access token"):
|
||||
GithubLoader(config={})
|
||||
43
embedchain/tests/loaders/test_gmail.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.gmail import GmailLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_beautifulsoup(mocker):
|
||||
return mocker.patch("embedchain.loaders.gmail.BeautifulSoup", return_value=mocker.MagicMock())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gmail_loader(mock_beautifulsoup):
|
||||
return GmailLoader()
|
||||
|
||||
|
||||
def test_load_data_file_not_found(gmail_loader, mocker):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
with mocker.patch("os.path.isfile", return_value=False):
|
||||
gmail_loader.load_data("your_query")
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="TODO: Fix this test. Failing due to some googleapiclient import issue.")
|
||||
def test_load_data(gmail_loader, mocker):
|
||||
mock_gmail_reader_instance = mocker.MagicMock()
|
||||
text = "your_test_email_text"
|
||||
metadata = {
|
||||
"id": "your_test_id",
|
||||
"snippet": "your_test_snippet",
|
||||
}
|
||||
mock_gmail_reader_instance.load_data.return_value = [
|
||||
{
|
||||
"text": text,
|
||||
"extra_info": metadata,
|
||||
}
|
||||
]
|
||||
|
||||
with mocker.patch("os.path.isfile", return_value=True):
|
||||
response_data = gmail_loader.load_data("your_query")
|
||||
|
||||
assert "doc_id" in response_data
|
||||
assert "data" in response_data
|
||||
assert isinstance(response_data["doc_id"], str)
|
||||
assert isinstance(response_data["data"], list)
|
||||
37
embedchain/tests/loaders/test_google_drive.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.google_drive import GoogleDriveLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def google_drive_folder_loader():
|
||||
return GoogleDriveLoader()
|
||||
|
||||
|
||||
def test_load_data_invalid_drive_url(google_drive_folder_loader):
|
||||
mock_invalid_drive_url = "https://example.com"
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="The url provided https://example.com does not match a google drive folder url. Example "
|
||||
"drive url: https://drive.google.com/drive/u/0/folders/xxxx",
|
||||
):
|
||||
google_drive_folder_loader.load_data(mock_invalid_drive_url)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.")
|
||||
def test_load_data_incorrect_drive_url(google_drive_folder_loader):
|
||||
mock_invalid_drive_url = "https://drive.google.com/drive/u/0/folders/xxxx"
|
||||
with pytest.raises(
|
||||
FileNotFoundError, match="Unable to locate folder or files, check provided drive URL and try again"
|
||||
):
|
||||
google_drive_folder_loader.load_data(mock_invalid_drive_url)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.")
|
||||
def test_load_data(google_drive_folder_loader):
|
||||
mock_valid_url = "YOUR_VALID_URL"
|
||||
result = google_drive_folder_loader.load_data(mock_valid_url)
|
||||
assert "doc_id" in result
|
||||
assert "data" in result
|
||||
assert "content" in result["data"][0]
|
||||
assert "meta_data" in result["data"][0]
|
||||
131
embedchain/tests/loaders/test_json.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import hashlib
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.json import JSONLoader
|
||||
|
||||
|
||||
def test_load_data(mocker):
|
||||
content = "temp.json"
|
||||
|
||||
mock_document = {
|
||||
"doc_id": hashlib.sha256((content + ", ".join(["content1", "content2"])).encode()).hexdigest(),
|
||||
"data": [
|
||||
{"content": "content1", "meta_data": {"url": content}},
|
||||
{"content": "content2", "meta_data": {"url": content}},
|
||||
],
|
||||
}
|
||||
|
||||
mocker.patch("embedchain.loaders.json.JSONLoader.load_data", return_value=mock_document)
|
||||
|
||||
json_loader = JSONLoader()
|
||||
|
||||
result = json_loader.load_data(content)
|
||||
|
||||
assert "doc_id" in result
|
||||
assert "data" in result
|
||||
|
||||
expected_data = [
|
||||
{"content": "content1", "meta_data": {"url": content}},
|
||||
{"content": "content2", "meta_data": {"url": content}},
|
||||
]
|
||||
|
||||
assert result["data"] == expected_data
|
||||
|
||||
expected_doc_id = hashlib.sha256((content + ", ".join(["content1", "content2"])).encode()).hexdigest()
|
||||
assert result["doc_id"] == expected_doc_id
|
||||
|
||||
|
||||
def test_load_data_url(mocker):
|
||||
content = "https://example.com/posts.json"
|
||||
|
||||
mocker.patch("os.path.isfile", return_value=False)
|
||||
mocker.patch(
|
||||
"embedchain.loaders.json.JSONReader.load_data",
|
||||
return_value=[
|
||||
{
|
||||
"text": "content1",
|
||||
},
|
||||
{
|
||||
"text": "content2",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
mock_response = mocker.Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"document1": "content1", "document2": "content2"}
|
||||
|
||||
mocker.patch("requests.get", return_value=mock_response)
|
||||
|
||||
result = JSONLoader.load_data(content)
|
||||
|
||||
assert "doc_id" in result
|
||||
assert "data" in result
|
||||
|
||||
expected_data = [
|
||||
{"content": "content1", "meta_data": {"url": content}},
|
||||
{"content": "content2", "meta_data": {"url": content}},
|
||||
]
|
||||
|
||||
assert result["data"] == expected_data
|
||||
|
||||
expected_doc_id = hashlib.sha256((content + ", ".join(["content1", "content2"])).encode()).hexdigest()
|
||||
assert result["doc_id"] == expected_doc_id
|
||||
|
||||
|
||||
def test_load_data_invalid_string_content(mocker):
|
||||
mocker.patch("os.path.isfile", return_value=False)
|
||||
mocker.patch("requests.get")
|
||||
|
||||
content = "123: 345}"
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid content to load json data from"):
|
||||
JSONLoader.load_data(content)
|
||||
|
||||
|
||||
def test_load_data_invalid_url(mocker):
|
||||
mocker.patch("os.path.isfile", return_value=False)
|
||||
|
||||
mock_response = mocker.Mock()
|
||||
mock_response.status_code = 404
|
||||
mocker.patch("requests.get", return_value=mock_response)
|
||||
|
||||
content = "http://invalid-url.com/"
|
||||
|
||||
with pytest.raises(ValueError, match=f"Invalid content to load json data from: {content}"):
|
||||
JSONLoader.load_data(content)
|
||||
|
||||
|
||||
def test_load_data_from_json_string(mocker):
|
||||
content = '{"foo": "bar"}'
|
||||
|
||||
content_url_str = hashlib.sha256((content).encode("utf-8")).hexdigest()
|
||||
|
||||
mocker.patch("os.path.isfile", return_value=False)
|
||||
mocker.patch(
|
||||
"embedchain.loaders.json.JSONReader.load_data",
|
||||
return_value=[
|
||||
{
|
||||
"text": "content1",
|
||||
},
|
||||
{
|
||||
"text": "content2",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
result = JSONLoader.load_data(content)
|
||||
|
||||
assert "doc_id" in result
|
||||
assert "data" in result
|
||||
|
||||
expected_data = [
|
||||
{"content": "content1", "meta_data": {"url": content_url_str}},
|
||||
{"content": "content2", "meta_data": {"url": content_url_str}},
|
||||
]
|
||||
|
||||
assert result["data"] == expected_data
|
||||
|
||||
expected_doc_id = hashlib.sha256((content_url_str + ", ".join(["content1", "content2"])).encode()).hexdigest()
|
||||
assert result["doc_id"] == expected_doc_id
|
||||
32
embedchain/tests/loaders/test_local_qna_pair.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import hashlib
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qna_pair_loader():
|
||||
return LocalQnaPairLoader()
|
||||
|
||||
|
||||
def test_load_data(qna_pair_loader):
|
||||
question = "What is the capital of France?"
|
||||
answer = "The capital of France is Paris."
|
||||
|
||||
content = (question, answer)
|
||||
result = qna_pair_loader.load_data(content)
|
||||
|
||||
assert "doc_id" in result
|
||||
assert "data" in result
|
||||
url = "local"
|
||||
|
||||
expected_content = f"Q: {question}\nA: {answer}"
|
||||
assert result["data"][0]["content"] == expected_content
|
||||
|
||||
assert result["data"][0]["meta_data"]["url"] == url
|
||||
|
||||
assert result["data"][0]["meta_data"]["question"] == question
|
||||
|
||||
expected_doc_id = hashlib.sha256((expected_content + url).encode()).hexdigest()
|
||||
assert result["doc_id"] == expected_doc_id
|
||||
27
embedchain/tests/loaders/test_local_text.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import hashlib
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.local_text import LocalTextLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def text_loader():
|
||||
return LocalTextLoader()
|
||||
|
||||
|
||||
def test_load_data(text_loader):
|
||||
mock_content = "This is a sample text content."
|
||||
|
||||
result = text_loader.load_data(mock_content)
|
||||
|
||||
assert "doc_id" in result
|
||||
assert "data" in result
|
||||
|
||||
url = "local"
|
||||
assert result["data"][0]["content"] == mock_content
|
||||
|
||||
assert result["data"][0]["meta_data"]["url"] == url
|
||||
|
||||
expected_doc_id = hashlib.sha256((mock_content + url).encode()).hexdigest()
|
||||
assert result["doc_id"] == expected_doc_id
|
||||
30
embedchain/tests/loaders/test_mdx.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import hashlib
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.mdx import MdxLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mdx_loader():
|
||||
return MdxLoader()
|
||||
|
||||
|
||||
def test_load_data(mdx_loader):
|
||||
mock_content = "Sample MDX Content"
|
||||
|
||||
# Mock open function to simulate file reading
|
||||
with patch("builtins.open", mock_open(read_data=mock_content)):
|
||||
url = "mock_file.mdx"
|
||||
result = mdx_loader.load_data(url)
|
||||
|
||||
assert "doc_id" in result
|
||||
assert "data" in result
|
||||
|
||||
assert result["data"][0]["content"] == mock_content
|
||||
|
||||
assert result["data"][0]["meta_data"]["url"] == url
|
||||
|
||||
expected_doc_id = hashlib.sha256((mock_content + url).encode()).hexdigest()
|
||||
assert result["doc_id"] == expected_doc_id
|
||||
77
embedchain/tests/loaders/test_mysql.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import hashlib
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.mysql import MySQLLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mysql_loader(mocker):
|
||||
with mocker.patch("mysql.connector.connection.MySQLConnection"):
|
||||
config = {
|
||||
"host": "localhost",
|
||||
"port": "3306",
|
||||
"user": "your_username",
|
||||
"password": "your_password",
|
||||
"database": "your_database",
|
||||
}
|
||||
loader = MySQLLoader(config=config)
|
||||
yield loader
|
||||
|
||||
|
||||
def test_mysql_loader_initialization(mysql_loader):
|
||||
assert mysql_loader.config is not None
|
||||
assert mysql_loader.connection is not None
|
||||
assert mysql_loader.cursor is not None
|
||||
|
||||
|
||||
def test_mysql_loader_invalid_config():
|
||||
with pytest.raises(ValueError, match="Invalid sql config: None"):
|
||||
MySQLLoader(config=None)
|
||||
|
||||
|
||||
def test_mysql_loader_setup_loader_successful(mysql_loader):
|
||||
assert mysql_loader.connection is not None
|
||||
assert mysql_loader.cursor is not None
|
||||
|
||||
|
||||
def test_mysql_loader_setup_loader_connection_error(mysql_loader, mocker):
|
||||
mocker.patch("mysql.connector.connection.MySQLConnection", side_effect=IOError("Mocked connection error"))
|
||||
with pytest.raises(ValueError, match="Unable to connect with the given config:"):
|
||||
mysql_loader._setup_loader(config={})
|
||||
|
||||
|
||||
def test_mysql_loader_check_query_successful(mysql_loader):
|
||||
query = "SELECT * FROM table"
|
||||
mysql_loader._check_query(query=query)
|
||||
|
||||
|
||||
def test_mysql_loader_check_query_invalid(mysql_loader):
|
||||
with pytest.raises(ValueError, match="Invalid mysql query: 123"):
|
||||
mysql_loader._check_query(query=123)
|
||||
|
||||
|
||||
def test_mysql_loader_load_data_successful(mysql_loader, mocker):
|
||||
mock_cursor = MagicMock()
|
||||
mocker.patch.object(mysql_loader, "cursor", mock_cursor)
|
||||
mock_cursor.fetchall.return_value = [(1, "data1"), (2, "data2")]
|
||||
|
||||
query = "SELECT * FROM table"
|
||||
result = mysql_loader.load_data(query)
|
||||
|
||||
assert "doc_id" in result
|
||||
assert "data" in result
|
||||
assert len(result["data"]) == 2
|
||||
assert result["data"][0]["meta_data"]["url"] == query
|
||||
assert result["data"][1]["meta_data"]["url"] == query
|
||||
|
||||
doc_id = hashlib.sha256((query + ", ".join([d["content"] for d in result["data"]])).encode()).hexdigest()
|
||||
|
||||
assert result["doc_id"] == doc_id
|
||||
mock_cursor.execute.assert_called_with(query)
|
||||
|
||||
|
||||
def test_mysql_loader_load_data_invalid_query(mysql_loader):
|
||||
with pytest.raises(ValueError, match="Invalid mysql query: 123"):
|
||||
mysql_loader.load_data(query=123)
|
||||
36
embedchain/tests/loaders/test_notion.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import hashlib
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.notion import NotionLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notion_loader():
|
||||
with patch.dict(os.environ, {"NOTION_INTEGRATION_TOKEN": "test_notion_token"}):
|
||||
yield NotionLoader()
|
||||
|
||||
|
||||
def test_load_data(notion_loader):
|
||||
source = "https://www.notion.so/Test-Page-1234567890abcdef1234567890abcdef"
|
||||
mock_text = "This is a test page."
|
||||
expected_doc_id = hashlib.sha256((mock_text + source).encode()).hexdigest()
|
||||
expected_data = [
|
||||
{
|
||||
"content": mock_text,
|
||||
"meta_data": {"url": "notion-12345678-90ab-cdef-1234-567890abcdef"}, # formatted_id
|
||||
}
|
||||
]
|
||||
|
||||
mock_page = Mock()
|
||||
mock_page.text = mock_text
|
||||
mock_documents = [mock_page]
|
||||
|
||||
with patch("embedchain.loaders.notion.NotionPageLoader") as mock_reader:
|
||||
mock_reader.return_value.load_data.return_value = mock_documents
|
||||
result = notion_loader.load_data(source)
|
||||
|
||||
assert result["doc_id"] == expected_doc_id
|
||||
assert result["data"] == expected_data
|
||||
26
embedchain/tests/loaders/test_openapi.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.openapi import OpenAPILoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_loader():
|
||||
return OpenAPILoader()
|
||||
|
||||
|
||||
def test_load_data(openapi_loader, mocker):
|
||||
mocker.patch("builtins.open", mocker.mock_open(read_data="key1: value1\nkey2: value2"))
|
||||
|
||||
mocker.patch("hashlib.sha256", return_value=mocker.Mock(hexdigest=lambda: "mock_hash"))
|
||||
|
||||
file_path = "configs/openai_openapi.yaml"
|
||||
result = openapi_loader.load_data(file_path)
|
||||
|
||||
expected_doc_id = "mock_hash"
|
||||
expected_data = [
|
||||
{"content": "key1: value1", "meta_data": {"url": file_path, "row": 1}},
|
||||
{"content": "key2: value2", "meta_data": {"url": file_path, "row": 2}},
|
||||
]
|
||||
|
||||
assert result["doc_id"] == expected_doc_id
|
||||
assert result["data"] == expected_data
|
||||
36
embedchain/tests/loaders/test_pdf_file.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import pytest
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
def test_load_data(loader, mocker):
|
||||
mocked_pypdfloader = mocker.patch("embedchain.loaders.pdf_file.PyPDFLoader")
|
||||
mocked_pypdfloader.return_value.load_and_split.return_value = [
|
||||
Document(page_content="Page 0 Content", metadata={"source": "example.pdf", "page": 0}),
|
||||
Document(page_content="Page 1 Content", metadata={"source": "example.pdf", "page": 1}),
|
||||
]
|
||||
|
||||
mock_sha256 = mocker.patch("embedchain.loaders.docs_site_loader.hashlib.sha256")
|
||||
doc_id = "mocked_hash"
|
||||
mock_sha256.return_value.hexdigest.return_value = doc_id
|
||||
|
||||
result = loader.load_data("dummy_url")
|
||||
assert result["doc_id"] is doc_id
|
||||
assert result["data"] == [
|
||||
{"content": "Page 0 Content", "meta_data": {"source": "example.pdf", "page": 0, "url": "dummy_url"}},
|
||||
{"content": "Page 1 Content", "meta_data": {"source": "example.pdf", "page": 1, "url": "dummy_url"}},
|
||||
]
|
||||
|
||||
|
||||
def test_load_data_fails_to_find_data(loader, mocker):
|
||||
mocked_pypdfloader = mocker.patch("embedchain.loaders.pdf_file.PyPDFLoader")
|
||||
mocked_pypdfloader.return_value.load_and_split.return_value = []
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
loader.load_data("dummy_url")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def loader():
|
||||
from embedchain.loaders.pdf_file import PdfFileLoader
|
||||
|
||||
return PdfFileLoader()
|
||||
60
embedchain/tests/loaders/test_postgres.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import psycopg
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.postgres import PostgresLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def postgres_loader(mocker):
|
||||
with mocker.patch.object(psycopg, "connect"):
|
||||
config = {"url": "postgres://user:password@localhost:5432/database"}
|
||||
loader = PostgresLoader(config=config)
|
||||
yield loader
|
||||
|
||||
|
||||
def test_postgres_loader_initialization(postgres_loader):
|
||||
assert postgres_loader.connection is not None
|
||||
assert postgres_loader.cursor is not None
|
||||
|
||||
|
||||
def test_postgres_loader_invalid_config():
|
||||
with pytest.raises(ValueError, match="Must provide the valid config. Received: None"):
|
||||
PostgresLoader(config=None)
|
||||
|
||||
|
||||
def test_load_data(postgres_loader, monkeypatch):
|
||||
mock_cursor = MagicMock()
|
||||
monkeypatch.setattr(postgres_loader, "cursor", mock_cursor)
|
||||
|
||||
query = "SELECT * FROM table"
|
||||
mock_cursor.fetchall.return_value = [(1, "data1"), (2, "data2")]
|
||||
|
||||
result = postgres_loader.load_data(query)
|
||||
|
||||
assert "doc_id" in result
|
||||
assert "data" in result
|
||||
assert len(result["data"]) == 2
|
||||
assert result["data"][0]["meta_data"]["url"] == query
|
||||
assert result["data"][1]["meta_data"]["url"] == query
|
||||
mock_cursor.execute.assert_called_with(query)
|
||||
|
||||
|
||||
def test_load_data_exception(postgres_loader, monkeypatch):
|
||||
mock_cursor = MagicMock()
|
||||
monkeypatch.setattr(postgres_loader, "cursor", mock_cursor)
|
||||
|
||||
_ = "SELECT * FROM table"
|
||||
mock_cursor.execute.side_effect = Exception("Mocked exception")
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r"Failed to load data using query=SELECT \* FROM table with: Mocked exception"
|
||||
):
|
||||
postgres_loader.load_data("SELECT * FROM table")
|
||||
|
||||
|
||||
def test_close_connection(postgres_loader):
|
||||
postgres_loader.close_connection()
|
||||
assert postgres_loader.cursor is None
|
||||
assert postgres_loader.connection is None
|
||||
47
embedchain/tests/loaders/test_slack.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.slack import SlackLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def slack_loader(mocker, monkeypatch):
|
||||
# Mocking necessary dependencies
|
||||
mocker.patch("slack_sdk.WebClient")
|
||||
mocker.patch("ssl.create_default_context")
|
||||
mocker.patch("certifi.where")
|
||||
|
||||
monkeypatch.setenv("SLACK_USER_TOKEN", "slack_user_token")
|
||||
|
||||
return SlackLoader()
|
||||
|
||||
|
||||
def test_slack_loader_initialization(slack_loader):
|
||||
assert slack_loader.client is not None
|
||||
assert slack_loader.config == {"base_url": "https://www.slack.com/api/"}
|
||||
|
||||
|
||||
def test_slack_loader_setup_loader(slack_loader):
|
||||
slack_loader._setup_loader({"base_url": "https://custom.slack.api/"})
|
||||
|
||||
assert slack_loader.client is not None
|
||||
|
||||
|
||||
def test_slack_loader_check_query(slack_loader):
|
||||
valid_json_query = "test_query"
|
||||
invalid_query = 123
|
||||
|
||||
slack_loader._check_query(valid_json_query)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
slack_loader._check_query(invalid_query)
|
||||
|
||||
|
||||
def test_slack_loader_load_data(slack_loader, mocker):
|
||||
valid_json_query = "in:random"
|
||||
|
||||
mocker.patch.object(slack_loader.client, "search_messages", return_value={"messages": {}})
|
||||
|
||||
result = slack_loader.load_data(valid_json_query)
|
||||
|
||||
assert "doc_id" in result
|
||||
assert "data" in result
|
||||
117
embedchain/tests/loaders/test_web_page.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import hashlib
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.web_page import WebPageLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def web_page_loader():
|
||||
return WebPageLoader()
|
||||
|
||||
|
||||
def test_load_data(web_page_loader):
|
||||
page_url = "https://example.com/page"
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = """
|
||||
<html>
|
||||
<head>
|
||||
<title>Test Page</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="content">
|
||||
<p>This is some test content.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
with patch("embedchain.loaders.web_page.WebPageLoader._session.get", return_value=mock_response):
|
||||
result = web_page_loader.load_data(page_url)
|
||||
|
||||
content = web_page_loader._get_clean_content(mock_response.content, page_url)
|
||||
expected_doc_id = hashlib.sha256((content + page_url).encode()).hexdigest()
|
||||
assert result["doc_id"] == expected_doc_id
|
||||
|
||||
expected_data = [
|
||||
{
|
||||
"content": content,
|
||||
"meta_data": {
|
||||
"url": page_url,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
assert result["data"] == expected_data
|
||||
|
||||
|
||||
def test_get_clean_content_excludes_unnecessary_info(web_page_loader):
|
||||
mock_html = """
|
||||
<html>
|
||||
<head>
|
||||
<title>Sample HTML</title>
|
||||
<style>
|
||||
/* Stylesheet to be excluded */
|
||||
.elementor-location-header {
|
||||
background-color: #f0f0f0;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header id="header">Header Content</header>
|
||||
<nav class="nav">Nav Content</nav>
|
||||
<aside>Aside Content</aside>
|
||||
<form>Form Content</form>
|
||||
<main>Main Content</main>
|
||||
<footer class="footer">Footer Content</footer>
|
||||
<script>Some Script</script>
|
||||
<noscript>NoScript Content</noscript>
|
||||
<svg>SVG Content</svg>
|
||||
<canvas>Canvas Content</canvas>
|
||||
|
||||
<div id="sidebar">Sidebar Content</div>
|
||||
<div id="main-navigation">Main Navigation Content</div>
|
||||
<div id="menu-main-menu">Menu Main Menu Content</div>
|
||||
|
||||
<div class="header-sidebar-wrapper">Header Sidebar Wrapper Content</div>
|
||||
<div class="blog-sidebar-wrapper">Blog Sidebar Wrapper Content</div>
|
||||
<div class="related-posts">Related Posts Content</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
tags_to_exclude = [
|
||||
"nav",
|
||||
"aside",
|
||||
"form",
|
||||
"header",
|
||||
"noscript",
|
||||
"svg",
|
||||
"canvas",
|
||||
"footer",
|
||||
"script",
|
||||
"style",
|
||||
]
|
||||
ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
|
||||
classes_to_exclude = [
|
||||
"elementor-location-header",
|
||||
"navbar-header",
|
||||
"nav",
|
||||
"header-sidebar-wrapper",
|
||||
"blog-sidebar-wrapper",
|
||||
"related-posts",
|
||||
]
|
||||
|
||||
content = web_page_loader._get_clean_content(mock_html, "https://example.com/page")
|
||||
|
||||
for tag in tags_to_exclude:
|
||||
assert tag not in content
|
||||
|
||||
for id in ids_to_exclude:
|
||||
assert id not in content
|
||||
|
||||
for class_name in classes_to_exclude:
|
||||
assert class_name not in content
|
||||
|
||||
assert len(content) > 0
|
||||
62
embedchain/tests/loaders/test_xml.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.xml import XmlLoader
|
||||
|
||||
# Taken from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/tests/integration_tests/examples/factbook.xml
|
||||
SAMPLE_XML = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<factbook>
|
||||
<country>
|
||||
<name>United States</name>
|
||||
<capital>Washington, DC</capital>
|
||||
<leader>Joe Biden</leader>
|
||||
<sport>Baseball</sport>
|
||||
</country>
|
||||
<country>
|
||||
<name>Canada</name>
|
||||
<capital>Ottawa</capital>
|
||||
<leader>Justin Trudeau</leader>
|
||||
<sport>Hockey</sport>
|
||||
</country>
|
||||
<country>
|
||||
<name>France</name>
|
||||
<capital>Paris</capital>
|
||||
<leader>Emmanuel Macron</leader>
|
||||
<sport>Soccer</sport>
|
||||
</country>
|
||||
<country>
|
||||
<name>Trinidad & Tobado</name>
|
||||
<capital>Port of Spain</capital>
|
||||
<leader>Keith Rowley</leader>
|
||||
<sport>Track & Field</sport>
|
||||
</country>
|
||||
</factbook>"""
|
||||
|
||||
|
||||
@pytest.mark.parametrize("xml", [SAMPLE_XML])
|
||||
def test_load_data(xml: str):
|
||||
"""
|
||||
Test the XML loader
|
||||
|
||||
Tests that the XML file is loaded and that the metadata and content are correct
|
||||
"""
|
||||
# Creating temporary XML file
|
||||
with tempfile.NamedTemporaryFile(mode="w+") as tmpfile:
|
||||
tmpfile.write(xml)
|
||||
|
||||
tmpfile.seek(0)
|
||||
filename = tmpfile.name
|
||||
|
||||
# Loading XML using XmlLoader
|
||||
loader = XmlLoader()
|
||||
result = loader.load_data(filename)
|
||||
data = result["data"]
|
||||
|
||||
# Assertions
|
||||
assert len(data) == 1
|
||||
assert "United States Washington, DC Joe Biden" in data[0]["content"]
|
||||
assert "Canada Ottawa Justin Trudeau" in data[0]["content"]
|
||||
assert "France Paris Emmanuel Macron" in data[0]["content"]
|
||||
assert "Trinidad & Tobado Port of Spain Keith Rowley" in data[0]["content"]
|
||||
assert data[0]["meta_data"]["url"] == filename
|
||||
53
embedchain/tests/loaders/test_youtube_video.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import hashlib
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.loaders.youtube_video import YoutubeVideoLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def youtube_video_loader():
|
||||
return YoutubeVideoLoader()
|
||||
|
||||
|
||||
def test_load_data(youtube_video_loader):
|
||||
video_url = "https://www.youtube.com/watch?v=VIDEO_ID"
|
||||
mock_loader = Mock()
|
||||
mock_page_content = "This is a YouTube video content."
|
||||
mock_loader.load.return_value = [
|
||||
MagicMock(
|
||||
page_content=mock_page_content,
|
||||
metadata={"url": video_url, "title": "Test Video"},
|
||||
)
|
||||
]
|
||||
|
||||
mock_transcript = [{"text": "sample text", "start": 0.0, "duration": 5.0}]
|
||||
|
||||
with patch("embedchain.loaders.youtube_video.YoutubeLoader.from_youtube_url", return_value=mock_loader), patch(
|
||||
"embedchain.loaders.youtube_video.YouTubeTranscriptApi.get_transcript", return_value=mock_transcript
|
||||
):
|
||||
result = youtube_video_loader.load_data(video_url)
|
||||
|
||||
expected_doc_id = hashlib.sha256((mock_page_content + video_url).encode()).hexdigest()
|
||||
|
||||
assert result["doc_id"] == expected_doc_id
|
||||
|
||||
expected_data = [
|
||||
{
|
||||
"content": "This is a YouTube video content.",
|
||||
"meta_data": {"url": video_url, "title": "Test Video", "transcript": "Unavailable"},
|
||||
}
|
||||
]
|
||||
|
||||
assert result["data"] == expected_data
|
||||
|
||||
|
||||
def test_load_data_with_empty_doc(youtube_video_loader):
|
||||
video_url = "https://www.youtube.com/watch?v=VIDEO_ID"
|
||||
mock_loader = Mock()
|
||||
mock_loader.load.return_value = []
|
||||
|
||||
with patch("embedchain.loaders.youtube_video.YoutubeLoader.from_youtube_url", return_value=mock_loader):
|
||||
with pytest.raises(ValueError):
|
||||
youtube_video_loader.load_data(video_url)
|
||||
91
embedchain/tests/memory/test_chat_memory.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.memory.base import ChatHistory
|
||||
from embedchain.memory.message import ChatMessage
|
||||
|
||||
|
||||
# Fixture for creating an instance of ChatHistory
|
||||
@pytest.fixture
|
||||
def chat_memory_instance():
|
||||
return ChatHistory()
|
||||
|
||||
|
||||
def test_add_chat_memory(chat_memory_instance):
|
||||
app_id = "test_app"
|
||||
session_id = "test_session"
|
||||
human_message = "Hello, how are you?"
|
||||
ai_message = "I'm fine, thank you!"
|
||||
|
||||
chat_message = ChatMessage()
|
||||
chat_message.add_user_message(human_message)
|
||||
chat_message.add_ai_message(ai_message)
|
||||
|
||||
chat_memory_instance.add(app_id, session_id, chat_message)
|
||||
|
||||
assert chat_memory_instance.count(app_id, session_id) == 1
|
||||
chat_memory_instance.delete(app_id, session_id)
|
||||
|
||||
|
||||
def test_get(chat_memory_instance):
|
||||
app_id = "test_app"
|
||||
session_id = "test_session"
|
||||
|
||||
for i in range(1, 7):
|
||||
human_message = f"Question {i}"
|
||||
ai_message = f"Answer {i}"
|
||||
|
||||
chat_message = ChatMessage()
|
||||
chat_message.add_user_message(human_message)
|
||||
chat_message.add_ai_message(ai_message)
|
||||
|
||||
chat_memory_instance.add(app_id, session_id, chat_message)
|
||||
|
||||
recent_memories = chat_memory_instance.get(app_id, session_id, num_rounds=5)
|
||||
|
||||
assert len(recent_memories) == 5
|
||||
|
||||
all_memories = chat_memory_instance.get(app_id, fetch_all=True)
|
||||
|
||||
assert len(all_memories) == 6
|
||||
|
||||
|
||||
def test_delete_chat_history(chat_memory_instance):
|
||||
app_id = "test_app"
|
||||
session_id = "test_session"
|
||||
|
||||
for i in range(1, 6):
|
||||
human_message = f"Question {i}"
|
||||
ai_message = f"Answer {i}"
|
||||
|
||||
chat_message = ChatMessage()
|
||||
chat_message.add_user_message(human_message)
|
||||
chat_message.add_ai_message(ai_message)
|
||||
|
||||
chat_memory_instance.add(app_id, session_id, chat_message)
|
||||
|
||||
session_id_2 = "test_session_2"
|
||||
|
||||
for i in range(1, 6):
|
||||
human_message = f"Question {i}"
|
||||
ai_message = f"Answer {i}"
|
||||
|
||||
chat_message = ChatMessage()
|
||||
chat_message.add_user_message(human_message)
|
||||
chat_message.add_ai_message(ai_message)
|
||||
|
||||
chat_memory_instance.add(app_id, session_id_2, chat_message)
|
||||
|
||||
chat_memory_instance.delete(app_id, session_id)
|
||||
|
||||
assert chat_memory_instance.count(app_id, session_id) == 0
|
||||
assert chat_memory_instance.count(app_id) == 5
|
||||
|
||||
chat_memory_instance.delete(app_id)
|
||||
|
||||
assert chat_memory_instance.count(app_id) == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def close_connection(chat_memory_instance):
|
||||
yield
|
||||
chat_memory_instance.close_connection()
|
||||
37
embedchain/tests/memory/test_memory_messages.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from embedchain.memory.message import BaseMessage, ChatMessage
|
||||
|
||||
|
||||
def test_ec_base_message():
|
||||
content = "Hello, how are you?"
|
||||
created_by = "human"
|
||||
metadata = {"key": "value"}
|
||||
|
||||
message = BaseMessage(content=content, created_by=created_by, metadata=metadata)
|
||||
|
||||
assert message.content == content
|
||||
assert message.created_by == created_by
|
||||
assert message.metadata == metadata
|
||||
assert message.type is None
|
||||
assert message.is_lc_serializable() is True
|
||||
assert str(message) == f"{created_by}: {content}"
|
||||
|
||||
|
||||
def test_ec_base_chat_message():
|
||||
human_message_content = "Hello, how are you?"
|
||||
ai_message_content = "I'm fine, thank you!"
|
||||
human_metadata = {"user": "John"}
|
||||
ai_metadata = {"response_time": 0.5}
|
||||
|
||||
chat_message = ChatMessage()
|
||||
chat_message.add_user_message(human_message_content, metadata=human_metadata)
|
||||
chat_message.add_ai_message(ai_message_content, metadata=ai_metadata)
|
||||
|
||||
assert chat_message.human_message.content == human_message_content
|
||||
assert chat_message.human_message.created_by == "human"
|
||||
assert chat_message.human_message.metadata == human_metadata
|
||||
|
||||
assert chat_message.ai_message.content == ai_message_content
|
||||
assert chat_message.ai_message.created_by == "ai"
|
||||
assert chat_message.ai_message.metadata == ai_metadata
|
||||
|
||||
assert str(chat_message) == f"human: {human_message_content}\nai: {ai_message_content}"
|
||||
30
embedchain/tests/models/test_data_type.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from embedchain.models.data_type import (DataType, DirectDataType,
|
||||
IndirectDataType, SpecialDataType)
|
||||
|
||||
|
||||
def test_subclass_types_in_data_type():
|
||||
"""Test that all data type category subclasses are contained in the composite data type"""
|
||||
# Check if DirectDataType values are in DataType
|
||||
for data_type in DirectDataType:
|
||||
assert data_type.value in DataType._value2member_map_
|
||||
|
||||
# Check if IndirectDataType values are in DataType
|
||||
for data_type in IndirectDataType:
|
||||
assert data_type.value in DataType._value2member_map_
|
||||
|
||||
# Check if SpecialDataType values are in DataType
|
||||
for data_type in SpecialDataType:
|
||||
assert data_type.value in DataType._value2member_map_
|
||||
|
||||
|
||||
def test_data_type_in_subclasses():
|
||||
"""Test that all data types in the composite data type are categorized in a subclass"""
|
||||
for data_type in DataType:
|
||||
if data_type.value in DirectDataType._value2member_map_:
|
||||
assert data_type.value in DirectDataType._value2member_map_
|
||||
elif data_type.value in IndirectDataType._value2member_map_:
|
||||
assert data_type.value in IndirectDataType._value2member_map_
|
||||
elif data_type.value in SpecialDataType._value2member_map_:
|
||||
assert data_type.value in SpecialDataType._value2member_map_
|
||||
else:
|
||||
assert False, f"{data_type.value} not found in any subclass enums"
|
||||
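The two tests above check that DirectDataType, IndirectDataType and SpecialDataType partition the composite DataType enum by comparing _value2member_map_. The same membership check outside pytest, using only the enums imported at the top of the test file:

from embedchain.models.data_type import (DataType, DirectDataType,
                                          IndirectDataType, SpecialDataType)


def categorize(data_type: DataType) -> str:
    # _value2member_map_ maps each enum value string back to its member,
    # so "value in Enum._value2member_map_" is a cheap membership test.
    if data_type.value in DirectDataType._value2member_map_:
        return "direct"
    if data_type.value in IndirectDataType._value2member_map_:
        return "indirect"
    if data_type.value in SpecialDataType._value2member_map_:
        return "special"
    return "uncategorized"


print({dt.value: categorize(dt) for dt in DataType})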
65
embedchain/tests/telemetry/test_posthog.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||
|
||||
|
||||
class TestAnonymousTelemetry:
|
||||
def test_init(self, mocker):
|
||||
# Enable telemetry specifically for this test
|
||||
os.environ["EC_TELEMETRY"] = "true"
|
||||
mock_posthog = mocker.patch("embedchain.telemetry.posthog.Posthog")
|
||||
telemetry = AnonymousTelemetry()
|
||||
assert telemetry.project_api_key == "phc_PHQDA5KwztijnSojsxJ2c1DuJd52QCzJzT2xnSGvjN2"
|
||||
assert telemetry.host == "https://app.posthog.com"
|
||||
assert telemetry.enabled is True
|
||||
assert telemetry.user_id
|
||||
mock_posthog.assert_called_once_with(project_api_key=telemetry.project_api_key, host=telemetry.host)
|
||||
|
||||
def test_init_with_disabled_telemetry(self, mocker):
|
||||
mocker.patch("embedchain.telemetry.posthog.Posthog")
|
||||
telemetry = AnonymousTelemetry()
|
||||
assert telemetry.enabled is False
|
||||
assert telemetry.posthog.disabled is True
|
||||
|
||||
def test_get_user_id(self, mocker, tmpdir):
|
||||
mock_uuid = mocker.patch("embedchain.telemetry.posthog.uuid.uuid4")
|
||||
mock_uuid.return_value = "unique_user_id"
|
||||
config_file = tmpdir.join("config.json")
|
||||
mocker.patch("embedchain.telemetry.posthog.CONFIG_FILE", str(config_file))
|
||||
telemetry = AnonymousTelemetry()
|
||||
|
||||
user_id = telemetry._get_user_id()
|
||||
assert user_id == "unique_user_id"
|
||||
assert config_file.read() == '{"user_id": "unique_user_id"}'
|
||||
|
||||
def test_capture(self, mocker):
|
||||
# Enable telemetry specifically for this test
|
||||
os.environ["EC_TELEMETRY"] = "true"
|
||||
mock_posthog = mocker.patch("embedchain.telemetry.posthog.Posthog")
|
||||
telemetry = AnonymousTelemetry()
|
||||
event_name = "test_event"
|
||||
properties = {"key": "value"}
|
||||
telemetry.capture(event_name, properties)
|
||||
|
||||
mock_posthog.assert_called_once_with(
|
||||
project_api_key=telemetry.project_api_key,
|
||||
host=telemetry.host,
|
||||
)
|
||||
mock_posthog.return_value.capture.assert_called_once_with(
|
||||
telemetry.user_id,
|
||||
event_name,
|
||||
properties,
|
||||
)
|
||||
|
||||
def test_capture_with_exception(self, mocker, caplog):
|
||||
os.environ["EC_TELEMETRY"] = "true"
|
||||
mock_posthog = mocker.patch("embedchain.telemetry.posthog.Posthog")
|
||||
mock_posthog.return_value.capture.side_effect = Exception("Test Exception")
|
||||
telemetry = AnonymousTelemetry()
|
||||
event_name = "test_event"
|
||||
properties = {"key": "value"}
|
||||
with caplog.at_level(logging.ERROR):
|
||||
telemetry.capture(event_name, properties)
|
||||
assert "Failed to send telemetry event" in caplog.text
|
||||
caplog.clear()
|
||||
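These tests drive telemetry through the EC_TELEMETRY environment variable and a patched Posthog client. A hedged sketch of the same surface in plain code, using only calls the tests themselves make:

import os

from embedchain.telemetry.posthog import AnonymousTelemetry

os.environ["EC_TELEMETRY"] = "true"  # the flag the tests use to enable telemetry

telemetry = AnonymousTelemetry()
if telemetry.enabled:
    # capture(event_name, properties) forwards to the underlying Posthog client
    # and logs (rather than raises) if the network call fails.
    telemetry.capture("test_event", {"key": "value"})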
111
embedchain/tests/test_app.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import ChromaDbConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
os.environ["OPENAI_API_KEY"] = "test_api_key"
|
||||
os.environ["OPENAI_API_BASE"] = "test_api_base"
|
||||
return App()
|
||||
|
||||
|
||||
def test_app(app):
|
||||
assert isinstance(app.llm, BaseLlm)
|
||||
assert isinstance(app.db, BaseVectorDB)
|
||||
assert isinstance(app.embedding_model, BaseEmbedder)
|
||||
|
||||
|
||||
class TestConfigForAppComponents:
|
||||
def test_constructor_config(self):
|
||||
collection_name = "my-test-collection"
|
||||
db = ChromaDB(config=ChromaDbConfig(collection_name=collection_name))
|
||||
app = App(db=db)
|
||||
assert app.db.config.collection_name == collection_name
|
||||
|
||||
def test_component_config(self):
|
||||
collection_name = "my-test-collection"
|
||||
database = ChromaDB(config=ChromaDbConfig(collection_name=collection_name))
|
||||
app = App(db=database)
|
||||
assert app.db.config.collection_name == collection_name
|
||||
|
||||
|
||||
class TestAppFromConfig:
|
||||
def load_config_data(self, yaml_path):
|
||||
with open(yaml_path, "r") as file:
|
||||
return yaml.safe_load(file)
|
||||
|
||||
def test_from_chroma_config(self, mocker):
|
||||
mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
|
||||
|
||||
yaml_path = "configs/chroma.yaml"
|
||||
config_data = self.load_config_data(yaml_path)
|
||||
|
||||
app = App.from_config(config_path=yaml_path)
|
||||
|
||||
# Check if the App instance and its components were created correctly
|
||||
assert isinstance(app, App)
|
||||
|
||||
# Validate the AppConfig values
|
||||
assert app.config.id == config_data["app"]["config"]["id"]
|
||||
# Even though not present in the config, the default value is used
|
||||
assert app.config.collect_metrics is True
|
||||
|
||||
# Validate the LLM config values
|
||||
llm_config = config_data["llm"]["config"]
|
||||
assert app.llm.config.temperature == llm_config["temperature"]
|
||||
assert app.llm.config.max_tokens == llm_config["max_tokens"]
|
||||
assert app.llm.config.top_p == llm_config["top_p"]
|
||||
assert app.llm.config.stream == llm_config["stream"]
|
||||
|
||||
# Validate the VectorDB config values
|
||||
db_config = config_data["vectordb"]["config"]
|
||||
assert app.db.config.collection_name == db_config["collection_name"]
|
||||
assert app.db.config.dir == db_config["dir"]
|
||||
assert app.db.config.allow_reset == db_config["allow_reset"]
|
||||
|
||||
# Validate the Embedder config values
|
||||
embedder_config = config_data["embedder"]["config"]
|
||||
assert app.embedding_model.config.model == embedder_config["model"]
|
||||
assert app.embedding_model.config.deployment_name == embedder_config.get("deployment_name")
|
||||
|
||||
def test_from_opensource_config(self, mocker):
|
||||
mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
|
||||
|
||||
yaml_path = "configs/opensource.yaml"
|
||||
config_data = self.load_config_data(yaml_path)
|
||||
|
||||
app = App.from_config(yaml_path)
|
||||
|
||||
# Check if the App instance and its components were created correctly
|
||||
assert isinstance(app, App)
|
||||
|
||||
# Validate the AppConfig values
|
||||
assert app.config.id == config_data["app"]["config"]["id"]
|
||||
assert app.config.collect_metrics == config_data["app"]["config"]["collect_metrics"]
|
||||
|
||||
# Validate the LLM config values
|
||||
llm_config = config_data["llm"]["config"]
|
||||
assert app.llm.config.model == llm_config["model"]
|
||||
assert app.llm.config.temperature == llm_config["temperature"]
|
||||
assert app.llm.config.max_tokens == llm_config["max_tokens"]
|
||||
assert app.llm.config.top_p == llm_config["top_p"]
|
||||
assert app.llm.config.stream == llm_config["stream"]
|
||||
|
||||
# Validate the VectorDB config values
|
||||
db_config = config_data["vectordb"]["config"]
|
||||
assert app.db.config.collection_name == db_config["collection_name"]
|
||||
assert app.db.config.dir == db_config["dir"]
|
||||
assert app.db.config.allow_reset == db_config["allow_reset"]
|
||||
|
||||
# Validate the Embedder config values
|
||||
embedder_config = config_data["embedder"]["config"]
|
||||
assert app.embedding_model.config.deployment_name == embedder_config["deployment_name"]
|
||||
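TestAppFromConfig loads the YAML files under configs/ and checks that App.from_config wires the app, LLM, vector DB and embedder from them. A minimal sketch of that entry point, using one of the config files the tests reference:

import os

from embedchain import App

# As in the fixtures above, OpenAI-backed components expect these to be set.
os.environ["OPENAI_API_KEY"] = "test_api_key"

# configs/chroma.yaml is one of the files exercised by TestAppFromConfig.
app = App.from_config(config_path="configs/chroma.yaml")

print(type(app.llm).__name__, type(app.db).__name__, type(app.embedding_model).__name__)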
53
embedchain/tests/test_client.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import pytest
|
||||
|
||||
from embedchain import Client
|
||||
|
||||
|
||||
class TestClient:
|
||||
@pytest.fixture
|
||||
def mock_requests_post(self, mocker):
|
||||
return mocker.patch("embedchain.client.requests.post")
|
||||
|
||||
def test_valid_api_key(self, mock_requests_post):
|
||||
mock_requests_post.return_value.status_code = 200
|
||||
client = Client(api_key="valid_api_key")
|
||||
assert client.check("valid_api_key") is True
|
||||
|
||||
def test_invalid_api_key(self, mock_requests_post):
|
||||
mock_requests_post.return_value.status_code = 401
|
||||
with pytest.raises(ValueError):
|
||||
Client(api_key="invalid_api_key")
|
||||
|
||||
def test_update_valid_api_key(self, mock_requests_post):
|
||||
mock_requests_post.return_value.status_code = 200
|
||||
client = Client(api_key="valid_api_key")
|
||||
client.update("new_valid_api_key")
|
||||
assert client.get() == "new_valid_api_key"
|
||||
|
||||
def test_clear_api_key(self, mock_requests_post):
|
||||
mock_requests_post.return_value.status_code = 200
|
||||
client = Client(api_key="valid_api_key")
|
||||
client.clear()
|
||||
assert client.get() is None
|
||||
|
||||
def test_save_api_key(self, mock_requests_post):
|
||||
mock_requests_post.return_value.status_code = 200
|
||||
api_key_to_save = "valid_api_key"
|
||||
client = Client(api_key=api_key_to_save)
|
||||
client.save()
|
||||
assert client.get() == api_key_to_save
|
||||
|
||||
def test_load_api_key_from_config(self, mocker):
|
||||
mocker.patch("embedchain.Client.load_config", return_value={"api_key": "test_api_key"})
|
||||
client = Client()
|
||||
assert client.get() == "test_api_key"
|
||||
|
||||
def test_load_invalid_api_key_from_config(self, mocker):
|
||||
mocker.patch("embedchain.Client.load_config", return_value={})
|
||||
with pytest.raises(ValueError):
|
||||
Client()
|
||||
|
||||
def test_load_missing_api_key_from_config(self, mocker):
|
||||
mocker.patch("embedchain.Client.load_config", return_value={})
|
||||
with pytest.raises(ValueError):
|
||||
Client()
|
||||
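The Client tests exercise an API-key lifecycle: the constructor validates the key over HTTP (raising ValueError on a 401), and the instance can save, update, clear and return it. A sketch of that lifecycle, assuming a key the backend would accept:

from embedchain import Client

client = Client(api_key="valid_api_key")  # validated against the backend; ValueError if rejected
client.save()                             # persist the key to the local config
client.update("new_valid_api_key")        # replace the stored key
print(client.get())                       # -> "new_valid_api_key"
client.clear()                            # forget the key; get() now returns None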
66
embedchain/tests/test_factory.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
import embedchain
|
||||
import embedchain.embedder.gpt4all
|
||||
import embedchain.embedder.huggingface
|
||||
import embedchain.embedder.openai
|
||||
import embedchain.embedder.vertexai
|
||||
import embedchain.llm.anthropic
|
||||
import embedchain.llm.openai
|
||||
import embedchain.vectordb.chroma
|
||||
import embedchain.vectordb.elasticsearch
|
||||
import embedchain.vectordb.opensearch
|
||||
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
|
||||
|
||||
|
||||
class TestFactories:
|
||||
@pytest.mark.parametrize(
|
||||
"provider_name, config_data, expected_class",
|
||||
[
|
||||
("openai", {}, embedchain.llm.openai.OpenAILlm),
|
||||
("anthropic", {}, embedchain.llm.anthropic.AnthropicLlm),
|
||||
],
|
||||
)
|
||||
def test_llm_factory_create(self, provider_name, config_data, expected_class):
|
||||
os.environ["ANTHROPIC_API_KEY"] = "test_api_key"
|
||||
os.environ["OPENAI_API_KEY"] = "test_api_key"
|
||||
os.environ["OPENAI_API_BASE"] = "test_api_base"
|
||||
llm_instance = LlmFactory.create(provider_name, config_data)
|
||||
assert isinstance(llm_instance, expected_class)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_name, config_data, expected_class",
|
||||
[
|
||||
("gpt4all", {}, embedchain.embedder.gpt4all.GPT4AllEmbedder),
|
||||
(
|
||||
"huggingface",
|
||||
{"model": "sentence-transformers/all-mpnet-base-v2", "vector_dimension": 768},
|
||||
embedchain.embedder.huggingface.HuggingFaceEmbedder,
|
||||
),
|
||||
("vertexai", {"model": "textembedding-gecko"}, embedchain.embedder.vertexai.VertexAIEmbedder),
|
||||
("openai", {}, embedchain.embedder.openai.OpenAIEmbedder),
|
||||
],
|
||||
)
|
||||
def test_embedder_factory_create(self, mocker, provider_name, config_data, expected_class):
|
||||
mocker.patch("embedchain.embedder.vertexai.VertexAIEmbedder", autospec=True)
|
||||
embedder_instance = EmbedderFactory.create(provider_name, config_data)
|
||||
assert isinstance(embedder_instance, expected_class)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_name, config_data, expected_class",
|
||||
[
|
||||
("chroma", {}, embedchain.vectordb.chroma.ChromaDB),
|
||||
(
|
||||
"opensearch",
|
||||
{"opensearch_url": "http://localhost:9200", "http_auth": ("admin", "admin")},
|
||||
embedchain.vectordb.opensearch.OpenSearchDB,
|
||||
),
|
||||
("elasticsearch", {"es_url": "http://localhost:9200"}, embedchain.vectordb.elasticsearch.ElasticsearchDB),
|
||||
],
|
||||
)
|
||||
def test_vectordb_factory_create(self, mocker, provider_name, config_data, expected_class):
|
||||
mocker.patch("embedchain.vectordb.opensearch.OpenSearchDB", autospec=True)
|
||||
vectordb_instance = VectorDBFactory.create(provider_name, config_data)
|
||||
assert isinstance(vectordb_instance, expected_class)
|
||||
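The factory tests map a provider name plus a config dict to a concrete class. A short sketch restricted to providers and config keys that appear in the parametrize tables above; API keys are read from the environment, as in the tests:

import os

from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory

os.environ["OPENAI_API_KEY"] = "test_api_key"  # placeholder, mirroring the test setup

llm = LlmFactory.create("openai", {})
embedder = EmbedderFactory.create(
    "huggingface",
    {"model": "sentence-transformers/all-mpnet-base-v2", "vector_dimension": 768},
)
vectordb = VectorDBFactory.create("elasticsearch", {"es_url": "http://localhost:9200"})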
38
embedchain/tests/test_utils.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import yaml
|
||||
|
||||
from embedchain.utils.misc import validate_config
|
||||
|
||||
CONFIG_YAMLS = [
|
||||
"configs/anthropic.yaml",
|
||||
"configs/azure_openai.yaml",
|
||||
"configs/chroma.yaml",
|
||||
"configs/chunker.yaml",
|
||||
"configs/cohere.yaml",
|
||||
"configs/together.yaml",
|
||||
"configs/ollama.yaml",
|
||||
"configs/full-stack.yaml",
|
||||
"configs/gpt4.yaml",
|
||||
"configs/gpt4all.yaml",
|
||||
"configs/huggingface.yaml",
|
||||
"configs/jina.yaml",
|
||||
"configs/llama2.yaml",
|
||||
"configs/opensearch.yaml",
|
||||
"configs/opensource.yaml",
|
||||
"configs/pinecone.yaml",
|
||||
"configs/vertexai.yaml",
|
||||
"configs/weaviate.yaml",
|
||||
]
|
||||
|
||||
|
||||
def test_all_config_yamls():
|
||||
"""Test that all config yamls are valid."""
|
||||
for config_yaml in CONFIG_YAMLS:
|
||||
with open(config_yaml, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
assert config is not None
|
||||
|
||||
try:
|
||||
validate_config(config)
|
||||
except Exception as e:
|
||||
print(f"Error in {config_yaml}: {e}")
|
||||
raise e
|
||||
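test_all_config_yamls simply loads each YAML and passes it through validate_config, which raises on schema violations. The same check against a single file looks like this:

import yaml

from embedchain.utils.misc import validate_config

with open("configs/chroma.yaml", "r") as f:
    config = yaml.safe_load(f)

# Raises if the config does not match the expected schema; returns normally otherwise.
validate_config(config)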
253
embedchain/tests/vectordb/test_chroma_db.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import os
|
||||
import shutil
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from chromadb.config import Settings
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, ChromaDbConfig
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "test-api-key"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chroma_db():
|
||||
return ChromaDB(config=ChromaDbConfig(host="test-host", port="1234"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_settings():
|
||||
chroma_config = ChromaDbConfig(allow_reset=True, dir="test-db")
|
||||
chroma_db = ChromaDB(config=chroma_config)
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
return App(config=app_config, db=chroma_db)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def cleanup_db():
|
||||
yield
|
||||
try:
|
||||
shutil.rmtree("test-db")
|
||||
except OSError as e:
|
||||
print("Error: %s - %s." % (e.filename, e.strerror))
|
||||
|
||||
|
||||
@patch("embedchain.vectordb.chroma.chromadb.Client")
|
||||
def test_chroma_db_init_with_host_and_port(mock_client):
|
||||
chroma_db = ChromaDB(config=ChromaDbConfig(host="test-host", port="1234")) # noqa
|
||||
called_settings: Settings = mock_client.call_args[0][0]
|
||||
assert called_settings.chroma_server_host == "test-host"
|
||||
assert called_settings.chroma_server_http_port == "1234"
|
||||
|
||||
|
||||
@patch("embedchain.vectordb.chroma.chromadb.Client")
|
||||
def test_chroma_db_init_with_basic_auth(mock_client):
|
||||
chroma_config = {
|
||||
"host": "test-host",
|
||||
"port": "1234",
|
||||
"chroma_settings": {
|
||||
"chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
|
||||
"chroma_client_auth_credentials": "admin:admin",
|
||||
},
|
||||
}
|
||||
|
||||
ChromaDB(config=ChromaDbConfig(**chroma_config))
|
||||
called_settings: Settings = mock_client.call_args[0][0]
|
||||
assert called_settings.chroma_server_host == "test-host"
|
||||
assert called_settings.chroma_server_http_port == "1234"
|
||||
assert (
|
||||
called_settings.chroma_client_auth_provider == chroma_config["chroma_settings"]["chroma_client_auth_provider"]
|
||||
)
|
||||
assert (
|
||||
called_settings.chroma_client_auth_credentials
|
||||
== chroma_config["chroma_settings"]["chroma_client_auth_credentials"]
|
||||
)
|
||||
|
||||
|
||||
@patch("embedchain.vectordb.chroma.chromadb.Client")
|
||||
def test_app_init_with_host_and_port(mock_client):
|
||||
host = "test-host"
|
||||
port = "1234"
|
||||
config = AppConfig(collect_metrics=False)
|
||||
db_config = ChromaDbConfig(host=host, port=port)
|
||||
db = ChromaDB(config=db_config)
|
||||
_app = App(config=config, db=db)
|
||||
|
||||
called_settings: Settings = mock_client.call_args[0][0]
|
||||
assert called_settings.chroma_server_host == host
|
||||
assert called_settings.chroma_server_http_port == port
|
||||
|
||||
|
||||
@patch("embedchain.vectordb.chroma.chromadb.Client")
|
||||
def test_app_init_with_host_and_port_none(mock_client):
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
_app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
|
||||
called_settings: Settings = mock_client.call_args[0][0]
|
||||
assert called_settings.chroma_server_host is None
|
||||
assert called_settings.chroma_server_http_port is None
|
||||
|
||||
|
||||
def test_chroma_db_duplicates_throw_warning(caplog):
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
assert "Insert of existing embedding ID: 0" in caplog.text
|
||||
assert "Add of existing embedding ID: 0" in caplog.text
|
||||
app.db.reset()
|
||||
|
||||
|
||||
def test_chroma_db_duplicates_collections_no_warning(caplog):
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.set_collection_name("test_collection_2")
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
assert "Insert of existing embedding ID: 0" not in caplog.text
|
||||
assert "Add of existing embedding ID: 0" not in caplog.text
|
||||
app.db.reset()
|
||||
app.set_collection_name("test_collection_1")
|
||||
app.db.reset()
|
||||
|
||||
|
||||
def test_chroma_db_collection_init_with_default_collection():
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
assert app.db.collection.name == "embedchain_store"
|
||||
|
||||
|
||||
def test_chroma_db_collection_init_with_custom_collection():
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name(name="test_collection")
|
||||
assert app.db.collection.name == "test_collection"
|
||||
|
||||
|
||||
def test_chroma_db_collection_set_collection_name():
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection")
|
||||
assert app.db.collection.name == "test_collection"
|
||||
|
||||
|
||||
def test_chroma_db_collection_changes_encapsulated():
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
assert app.db.count() == 0
|
||||
|
||||
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
assert app.db.count() == 1
|
||||
|
||||
app.set_collection_name("test_collection_2")
|
||||
assert app.db.count() == 0
|
||||
|
||||
app.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
app.set_collection_name("test_collection_1")
|
||||
assert app.db.count() == 1
|
||||
app.db.reset()
|
||||
app.set_collection_name("test_collection_2")
|
||||
app.db.reset()
|
||||
|
||||
|
||||
def test_chroma_db_collection_collections_are_persistent():
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
del app
|
||||
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
assert app.db.count() == 1
|
||||
|
||||
app.db.reset()
|
||||
|
||||
|
||||
def test_chroma_db_collection_parallel_collections():
|
||||
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
|
||||
app1 = App(
|
||||
config=AppConfig(collect_metrics=False),
|
||||
db=db1,
|
||||
)
|
||||
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
|
||||
app2 = App(
|
||||
config=AppConfig(collect_metrics=False),
|
||||
db=db2,
|
||||
)
|
||||
|
||||
# cleanup if any previous tests failed or were interrupted
|
||||
app1.db.reset()
|
||||
app2.db.reset()
|
||||
|
||||
app1.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
assert app1.db.count() == 1
|
||||
assert app2.db.count() == 0
|
||||
|
||||
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
|
||||
app2.db.collection.add(embeddings=[0, 0, 0], ids=["0"])
|
||||
|
||||
app1.set_collection_name("test_collection_2")
|
||||
assert app1.db.count() == 1
|
||||
app2.set_collection_name("test_collection_1")
|
||||
assert app2.db.count() == 3
|
||||
|
||||
# cleanup
|
||||
app1.db.reset()
|
||||
app2.db.reset()
|
||||
|
||||
|
||||
def test_chroma_db_collection_ids_share_collections():
|
||||
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
|
||||
app1.set_collection_name("one_collection")
|
||||
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
|
||||
app2.set_collection_name("one_collection")
|
||||
|
||||
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
|
||||
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
|
||||
|
||||
assert app1.db.count() == 3
|
||||
assert app2.db.count() == 3
|
||||
|
||||
# cleanup
|
||||
app1.db.reset()
|
||||
app2.db.reset()
|
||||
|
||||
|
||||
def test_chroma_db_collection_reset():
|
||||
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
|
||||
app1.set_collection_name("one_collection")
|
||||
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
|
||||
app2.set_collection_name("two_collection")
|
||||
db3 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app3 = App(config=AppConfig(collect_metrics=False), db=db3)
|
||||
app3.set_collection_name("three_collection")
|
||||
db4 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app4 = App(config=AppConfig(collect_metrics=False), db=db4)
|
||||
app4.set_collection_name("four_collection")
|
||||
|
||||
app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
|
||||
app2.db.collection.add(embeddings=[0, 0, 0], ids=["2"])
|
||||
app3.db.collection.add(embeddings=[0, 0, 0], ids=["3"])
|
||||
app4.db.collection.add(embeddings=[0, 0, 0], ids=["4"])
|
||||
|
||||
app1.db.reset()
|
||||
|
||||
assert app1.db.count() == 0
|
||||
assert app2.db.count() == 1
|
||||
assert app3.db.count() == 1
|
||||
assert app4.db.count() == 1
|
||||
|
||||
# cleanup
|
||||
app2.db.reset()
|
||||
app3.db.reset()
|
||||
app4.db.reset()
|
||||
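The Chroma tests cover three configuration paths: a local persistent directory, a remote host/port, and basic auth forwarded through chroma_settings. A hedged sketch of the remote variant, built only from the ChromaDbConfig fields the tests above construct:

from embedchain import App
from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.vectordb.chroma import ChromaDB

# Remote Chroma server with basic auth, mirroring test_chroma_db_init_with_basic_auth.
db = ChromaDB(
    config=ChromaDbConfig(
        host="test-host",
        port="1234",
        chroma_settings={
            "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
            "chroma_client_auth_credentials": "admin:admin",
        },
    )
)
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("my_collection")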
86
embedchain/tests/vectordb/test_elasticsearch_db.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, ElasticsearchDBConfig
|
||||
from embedchain.embedder.gpt4all import GPT4AllEmbedder
|
||||
from embedchain.vectordb.elasticsearch import ElasticsearchDB
|
||||
|
||||
|
||||
class TestEsDB(unittest.TestCase):
|
||||
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
|
||||
def test_setUp(self, mock_client):
|
||||
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
|
||||
self.vector_dim = 384
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
self.app = App(config=app_config, db=self.db)
|
||||
|
||||
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
|
||||
self.assertEqual(self.db.client, mock_client.return_value)
|
||||
|
||||
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
|
||||
def test_query(self, mock_client):
|
||||
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
self.app = App(config=app_config, db=self.db, embedding_model=GPT4AllEmbedder())
|
||||
|
||||
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
|
||||
self.assertEqual(self.db.client, mock_client.return_value)
|
||||
|
||||
# Create some dummy data
|
||||
documents = ["This is a document.", "This is another document."]
|
||||
metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
|
||||
ids = ["doc_1", "doc_2"]
|
||||
|
||||
# Add the data to the database.
|
||||
self.db.add(documents, metadatas, ids)
|
||||
|
||||
search_response = {
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
|
||||
"_score": 0.9,
|
||||
},
|
||||
{
|
||||
"_source": {
|
||||
"text": "This is another document.",
|
||||
"metadata": {"url": "url_2", "doc_id": "doc_id_2"},
|
||||
},
|
||||
"_score": 0.8,
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# Configure the mock client to return the mocked response.
|
||||
mock_client.return_value.search.return_value = search_response
|
||||
|
||||
# Query the database for the documents that are most similar to the query "This is a document".
|
||||
query = "This is a document"
|
||||
results_without_citations = self.db.query(query, n_results=2, where={})
|
||||
expected_results_without_citations = ["This is a document.", "This is another document."]
|
||||
self.assertEqual(results_without_citations, expected_results_without_citations)
|
||||
|
||||
results_with_citations = self.db.query(query, n_results=2, where={}, citations=True)
|
||||
expected_results_with_citations = [
|
||||
("This is a document.", {"url": "url_1", "doc_id": "doc_id_1", "score": 0.9}),
|
||||
("This is another document.", {"url": "url_2", "doc_id": "doc_id_2", "score": 0.8}),
|
||||
]
|
||||
self.assertEqual(results_with_citations, expected_results_with_citations)
|
||||
|
||||
def test_init_without_url(self):
|
||||
# Make sure it's not loaded from env
|
||||
try:
|
||||
del os.environ["ELASTICSEARCH_URL"]
|
||||
except KeyError:
|
||||
pass
|
||||
# Test that an exception is raised when neither es_config nor ELASTICSEARCH_URL is provided
|
||||
with self.assertRaises(AttributeError):
|
||||
ElasticsearchDB()
|
||||
|
||||
def test_init_with_invalid_es_config(self):
|
||||
# Test if an exception is raised when an invalid es_config is provided
|
||||
with self.assertRaises(TypeError):
|
||||
ElasticsearchDB(es_config={"ES_URL": "some_url", "valid es_config": False})
|
||||
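test_query shows the two return shapes of db.query: a plain list of document texts, and (text, metadata) tuples whose metadata carries a score when citations=True. A sketch of that call pattern with the same ElasticsearchDBConfig and embedder wiring as the tests; a reachable Elasticsearch at that URL is assumed:

from embedchain import App
from embedchain.config import AppConfig, ElasticsearchDBConfig
from embedchain.embedder.gpt4all import GPT4AllEmbedder
from embedchain.vectordb.elasticsearch import ElasticsearchDB

db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
app = App(config=AppConfig(collect_metrics=False), db=db, embedding_model=GPT4AllEmbedder())

# Plain results: a list of matching document texts.
texts = db.query("This is a document", n_results=2, where={})

# With citations: (text, metadata) tuples, each metadata dict including a "score".
cited = db.query("This is a document", n_results=2, where={}, citations=True)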
215
embedchain/tests/vectordb/test_lancedb.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.config.vector_db.lancedb import LanceDBConfig
|
||||
from embedchain.vectordb.lancedb import LanceDB
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "test-api-key"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lancedb():
|
||||
return LanceDB(config=LanceDBConfig(dir="test-db", collection_name="test-coll"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_settings():
|
||||
lancedb_config = LanceDBConfig(allow_reset=True, dir="test-db-reset")
|
||||
lancedb = LanceDB(config=lancedb_config)
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
return App(config=app_config, db=lancedb)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def cleanup_db():
|
||||
yield
|
||||
try:
|
||||
shutil.rmtree("test-db.lance")
|
||||
shutil.rmtree("test-db-reset.lance")
|
||||
except OSError as e:
|
||||
print("Error: %s - %s." % (e.filename, e.strerror))
|
||||
|
||||
|
||||
def test_lancedb_duplicates_throw_warning(caplog):
|
||||
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
|
||||
app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
|
||||
assert "Insert of existing doc ID: 0" not in caplog.text
|
||||
assert "Add of existing doc ID: 0" not in caplog.text
|
||||
app.db.reset()
|
||||
|
||||
|
||||
def test_lancedb_duplicates_collections_no_warning(caplog):
|
||||
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
|
||||
app.set_collection_name("test_collection_2")
|
||||
app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
|
||||
assert "Insert of existing doc ID: 0" not in caplog.text
|
||||
assert "Add of existing doc ID: 0" not in caplog.text
|
||||
app.db.reset()
|
||||
app.set_collection_name("test_collection_1")
|
||||
app.db.reset()
|
||||
|
||||
|
||||
def test_lancedb_collection_init_with_default_collection():
|
||||
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
assert app.db.collection.name == "embedchain_store"
|
||||
|
||||
|
||||
def test_lancedb_collection_init_with_custom_collection():
|
||||
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name(name="test_collection")
|
||||
assert app.db.collection.name == "test_collection"
|
||||
|
||||
|
||||
def test_lancedb_collection_set_collection_name():
|
||||
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection")
|
||||
assert app.db.collection.name == "test_collection"
|
||||
|
||||
|
||||
def test_lancedb_collection_changes_encapsulated():
|
||||
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
assert app.db.count() == 0
|
||||
app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
|
||||
assert app.db.count() == 1
|
||||
|
||||
app.set_collection_name("test_collection_2")
|
||||
assert app.db.count() == 0
|
||||
|
||||
app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
|
||||
app.set_collection_name("test_collection_1")
|
||||
assert app.db.count() == 1
|
||||
app.db.reset()
|
||||
app.set_collection_name("test_collection_2")
|
||||
app.db.reset()
|
||||
|
||||
|
||||
def test_lancedb_collection_collections_are_persistent():
|
||||
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
|
||||
del app
|
||||
|
||||
db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
assert app.db.count() == 1
|
||||
|
||||
app.db.reset()
|
||||
|
||||
|
||||
def test_lancedb_collection_parallel_collections():
|
||||
db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
|
||||
app1 = App(
|
||||
config=AppConfig(collect_metrics=False),
|
||||
db=db1,
|
||||
)
|
||||
db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
|
||||
app2 = App(
|
||||
config=AppConfig(collect_metrics=False),
|
||||
db=db2,
|
||||
)
|
||||
|
||||
# cleanup if any previous tests failed or were interrupted
|
||||
app1.db.reset()
|
||||
app2.db.reset()
|
||||
|
||||
app1.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
|
||||
|
||||
assert app1.db.count() == 1
|
||||
assert app2.db.count() == 0
|
||||
|
||||
app1.db.add(ids=["1", "2"], documents=["doc1", "doc2"], metadatas=["test", "test"])
|
||||
app2.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
|
||||
|
||||
app1.set_collection_name("test_collection_2")
|
||||
assert app1.db.count() == 1
|
||||
app2.set_collection_name("test_collection_1")
|
||||
assert app2.db.count() == 3
|
||||
|
||||
# cleanup
|
||||
app1.db.reset()
|
||||
app2.db.reset()
|
||||
|
||||
|
||||
def test_lancedb_collection_ids_share_collections():
|
||||
db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
|
||||
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
|
||||
app1.set_collection_name("one_collection")
|
||||
db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
|
||||
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
|
||||
app2.set_collection_name("one_collection")
|
||||
|
||||
# cleanup
|
||||
app1.db.reset()
|
||||
app2.db.reset()
|
||||
|
||||
app1.db.add(ids=["0", "1"], documents=["doc1", "doc2"], metadatas=["test", "test"])
|
||||
app2.db.add(ids=["2"], documents=["doc3"], metadatas=["test"])
|
||||
|
||||
assert app1.db.count() == 2
|
||||
assert app2.db.count() == 3
|
||||
|
||||
# cleanup
|
||||
app1.db.reset()
|
||||
app2.db.reset()
|
||||
|
||||
|
||||
def test_lancedb_collection_reset():
|
||||
db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
|
||||
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
|
||||
app1.set_collection_name("one_collection")
|
||||
db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
|
||||
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
|
||||
app2.set_collection_name("two_collection")
|
||||
db3 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
|
||||
app3 = App(config=AppConfig(collect_metrics=False), db=db3)
|
||||
app3.set_collection_name("three_collection")
|
||||
db4 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
|
||||
app4 = App(config=AppConfig(collect_metrics=False), db=db4)
|
||||
app4.set_collection_name("four_collection")
|
||||
|
||||
# cleanup if any previous tests failed or were interrupted
|
||||
app1.db.reset()
|
||||
app2.db.reset()
|
||||
app3.db.reset()
|
||||
app4.db.reset()
|
||||
|
||||
app1.db.add(ids=["1"], documents=["doc1"], metadatas=["test"])
|
||||
app2.db.add(ids=["2"], documents=["doc2"], metadatas=["test"])
|
||||
app3.db.add(ids=["3"], documents=["doc3"], metadatas=["test"])
|
||||
app4.db.add(ids=["4"], documents=["doc4"], metadatas=["test"])
|
||||
|
||||
app1.db.reset()
|
||||
|
||||
assert app1.db.count() == 0
|
||||
assert app2.db.count() == 1
|
||||
assert app3.db.count() == 1
|
||||
assert app4.db.count() == 1
|
||||
|
||||
# cleanup
|
||||
app2.db.reset()
|
||||
app3.db.reset()
|
||||
app4.db.reset()
|
||||
|
||||
|
||||
def generate_embeddings(dummy_embed, embed_size):
|
||||
generated_embedding = []
|
||||
for i in range(embed_size):
|
||||
generated_embedding.append(dummy_embed)
|
||||
|
||||
return generated_embedding
|
||||
225
embedchain/tests/vectordb/test_pinecone.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import pytest
|
||||
|
||||
from embedchain.config.vector_db.pinecone import PineconeDBConfig
|
||||
from embedchain.vectordb.pinecone import PineconeDB
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pinecone_pod_config():
|
||||
return PineconeDBConfig(
|
||||
index_name="test_collection",
|
||||
api_key="test_api_key",
|
||||
vector_dimension=3,
|
||||
pod_config={"environment": "test_environment", "metadata_config": {"indexed": ["*"]}},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pinecone_serverless_config():
|
||||
return PineconeDBConfig(
|
||||
index_name="test_collection",
|
||||
api_key="test_api_key",
|
||||
vector_dimension=3,
|
||||
serverless_config={
|
||||
"cloud": "test_cloud",
|
||||
"region": "test_region",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_pinecone_init_without_config(monkeypatch):
|
||||
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
||||
pinecone_db = PineconeDB()
|
||||
|
||||
assert isinstance(pinecone_db, PineconeDB)
|
||||
assert isinstance(pinecone_db.config, PineconeDBConfig)
|
||||
assert pinecone_db.config.pod_config == {"environment": "gcp-starter", "metadata_config": {"indexed": ["*"]}}
|
||||
monkeypatch.delenv("PINECONE_API_KEY")
|
||||
|
||||
|
||||
def test_pinecone_init_with_config(pinecone_pod_config, monkeypatch):
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
||||
pinecone_db = PineconeDB(config=pinecone_pod_config)
|
||||
|
||||
assert isinstance(pinecone_db, PineconeDB)
|
||||
assert isinstance(pinecone_db.config, PineconeDBConfig)
|
||||
|
||||
assert pinecone_db.config.pod_config == pinecone_pod_config.pod_config
|
||||
|
||||
pinecone_db = PineconeDB(config=pinecone_pod_config)
|
||||
|
||||
assert isinstance(pinecone_db, PineconeDB)
|
||||
assert isinstance(pinecone_db.config, PineconeDBConfig)
|
||||
|
||||
assert pinecone_db.config.serverless_config == pinecone_pod_config.serverless_config
|
||||
|
||||
|
||||
class MockListIndexes:
|
||||
def names(self):
|
||||
return ["test_collection"]
|
||||
|
||||
|
||||
class MockPineconeIndex:
|
||||
db = []
|
||||
|
||||
def __init__(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def upsert(self, chunk, **kwargs):
|
||||
self.db.extend([c for c in chunk])
|
||||
return
|
||||
|
||||
def delete(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def query(self, *args, **kwargs):
|
||||
return {
|
||||
"matches": [
|
||||
{
|
||||
"metadata": {
|
||||
"key": "value",
|
||||
"text": "text_1",
|
||||
},
|
||||
"score": 0.1,
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"key": "value",
|
||||
"text": "text_2",
|
||||
},
|
||||
"score": 0.2,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
def fetch(self, *args, **kwargs):
|
||||
return {
|
||||
"vectors": {
|
||||
"key_1": {
|
||||
"metadata": {
|
||||
"source": "1",
|
||||
}
|
||||
},
|
||||
"key_2": {
|
||||
"metadata": {
|
||||
"source": "2",
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
def describe_index_stats(self, *args, **kwargs):
|
||||
return {"total_vector_count": len(self.db)}
|
||||
|
||||
|
||||
class MockPineconeClient:
|
||||
def __init__(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def list_indexes(self):
|
||||
return MockListIndexes()
|
||||
|
||||
def create_index(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def Index(self, *args, **kwargs):
|
||||
return MockPineconeIndex()
|
||||
|
||||
def delete_index(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class MockPinecone:
|
||||
def __init__(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def Pinecone(*args, **kwargs):
|
||||
return MockPineconeClient()
|
||||
|
||||
def PodSpec(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def ServerlessSpec(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class MockEmbedder:
|
||||
def embedding_fn(self, documents):
|
||||
return [[1, 1, 1] for d in documents]
|
||||
|
||||
|
||||
def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.pinecone", MockPinecone)
|
||||
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
||||
pinecone_db = PineconeDB(config=pinecone_pod_config)
|
||||
pinecone_db._setup_pinecone_index()
|
||||
|
||||
assert pinecone_db.client is not None
|
||||
assert pinecone_db.config.index_name == "test_collection"
|
||||
assert pinecone_db.client.list_indexes().names() == ["test_collection"]
|
||||
assert pinecone_db.pinecone_index is not None
|
||||
|
||||
pinecone_db = PineconeDB(config=pinecone_serverless_config)
|
||||
pinecone_db._setup_pinecone_index()
|
||||
|
||||
assert pinecone_db.client is not None
|
||||
assert pinecone_db.config.index_name == "test_collection"
|
||||
assert pinecone_db.client.list_indexes().names() == ["test_collection"]
|
||||
assert pinecone_db.pinecone_index is not None
|
||||
|
||||
|
||||
def test_get(monkeypatch):
|
||||
def mock_pinecone_db():
|
||||
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
||||
db = PineconeDB()
|
||||
db.pinecone_index = MockPineconeIndex()
|
||||
return db
|
||||
|
||||
pinecone_db = mock_pinecone_db()
|
||||
ids = pinecone_db.get(["key_1", "key_2"])
|
||||
assert ids == {"ids": ["key_1", "key_2"], "metadatas": [{"source": "1"}, {"source": "2"}]}
|
||||
|
||||
|
||||
def test_add(monkeypatch):
|
||||
def mock_pinecone_db():
|
||||
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
||||
db = PineconeDB()
|
||||
db.pinecone_index = MockPineconeIndex()
|
||||
db._set_embedder(MockEmbedder())
|
||||
return db
|
||||
|
||||
pinecone_db = mock_pinecone_db()
|
||||
pinecone_db.add(["text_1", "text_2"], [{"key_1": "value_1"}, {"key_2": "value_2"}], ["key_1", "key_2"])
|
||||
assert pinecone_db.count() == 2
|
||||
|
||||
pinecone_db.add(["text_3", "text_4"], [{"key_3": "value_3"}, {"key_4": "value_4"}], ["key_3", "key_4"])
|
||||
assert pinecone_db.count() == 4
|
||||
|
||||
|
||||
def test_query(monkeypatch):
|
||||
def mock_pinecone_db():
|
||||
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
||||
db = PineconeDB()
|
||||
db.pinecone_index = MockPineconeIndex()
|
||||
db._set_embedder(MockEmbedder())
|
||||
return db
|
||||
|
||||
pinecone_db = mock_pinecone_db()
|
||||
# without citations
|
||||
results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={})
|
||||
assert results == ["text_1", "text_2"]
|
||||
# with citations
|
||||
results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={}, citations=True)
|
||||
assert results == [
|
||||
("text_1", {"key": "value", "text": "text_1", "score": 0.1}),
|
||||
("text_2", {"key": "value", "text": "text_2", "score": 0.2}),
|
||||
]
|
||||
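PineconeDBConfig accepts either a pod_config or a serverless_config, and test_pinecone_init_without_config shows the default pod environment is gcp-starter. A sketch of both shapes, reusing the fixture values above; the API key is read from PINECONE_API_KEY, as in the tests, and the values here are placeholders:

import os

from embedchain.config.vector_db.pinecone import PineconeDBConfig
from embedchain.vectordb.pinecone import PineconeDB

os.environ["PINECONE_API_KEY"] = "test_api_key"  # placeholder value from the fixtures

# Pod-based index, as in the pinecone_pod_config fixture.
pod_style = PineconeDBConfig(
    index_name="test_collection",
    vector_dimension=3,
    pod_config={"environment": "test_environment", "metadata_config": {"indexed": ["*"]}},
)

# Serverless index, as in the pinecone_serverless_config fixture.
serverless_style = PineconeDBConfig(
    index_name="test_collection",
    vector_dimension=3,
    serverless_config={"cloud": "test_cloud", "region": "test_region"},
)

db = PineconeDB(config=serverless_style)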
167
embedchain/tests/vectordb/test_qdrant.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
from unittest.mock import patch
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.models import Batch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.config.vector_db.pinecone import PineconeDBConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.qdrant import QdrantDB
|
||||
|
||||
|
||||
def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
|
||||
"""A mock embedding function."""
|
||||
return [[1, 2, 3], [4, 5, 6]]
|
||||
|
||||
|
||||
class TestQdrantDB(unittest.TestCase):
|
||||
TEST_UUIDS = ["abc", "def", "ghi"]
|
||||
|
||||
def test_incorrect_config_throws_error(self):
|
||||
"""Test the init method of the Qdrant class throws error for incorrect config"""
|
||||
with self.assertRaises(TypeError):
|
||||
QdrantDB(config=PineconeDBConfig())
|
||||
|
||||
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
||||
def test_initialize(self, qdrant_client_mock):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1536)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
self.assertEqual(db.collection_name, "embedchain-store-1536")
|
||||
self.assertEqual(db.client, qdrant_client_mock.return_value)
|
||||
qdrant_client_mock.return_value.get_collections.assert_called_once()
|
||||
|
||||
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
||||
def test_get(self, qdrant_client_mock):
|
||||
qdrant_client_mock.return_value.scroll.return_value = ([], None)
|
||||
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1536)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
resp = db.get(ids=[], where={})
|
||||
self.assertEqual(resp, {"ids": [], "metadatas": []})
|
||||
resp2 = db.get(ids=["123", "456"], where={"url": "https://ai.ai"})
|
||||
self.assertEqual(resp2, {"ids": [], "metadatas": []})
|
||||
|
||||
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
||||
@patch.object(uuid, "uuid4", side_effect=TEST_UUIDS)
|
||||
def test_add(self, uuid_mock, qdrant_client_mock):
|
||||
qdrant_client_mock.return_value.scroll.return_value = ([], None)
|
||||
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1536)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
documents = ["This is a test document.", "This is another test document."]
|
||||
metadatas = [{}, {}]
|
||||
ids = ["123", "456"]
|
||||
db.add(documents, metadatas, ids)
|
||||
qdrant_client_mock.return_value.upsert.assert_called_once_with(
|
||||
collection_name="embedchain-store-1536",
|
||||
points=Batch(
|
||||
ids=["123", "456"],
|
||||
payloads=[
|
||||
{
|
||||
"identifier": "123",
|
||||
"text": "This is a test document.",
|
||||
"metadata": {"text": "This is a test document."},
|
||||
},
|
||||
{
|
||||
"identifier": "456",
|
||||
"text": "This is another test document.",
|
||||
"metadata": {"text": "This is another test document."},
|
||||
},
|
||||
],
|
||||
vectors=[[1, 2, 3], [4, 5, 6]],
|
||||
),
|
||||
)
|
||||
|
||||
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
||||
def test_query(self, qdrant_client_mock):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1536)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})
|
||||
|
||||
qdrant_client_mock.return_value.search.assert_called_once_with(
|
||||
collection_name="embedchain-store-1536",
|
||||
query_filter=models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(
|
||||
value="123",
|
||||
),
|
||||
)
|
||||
]
|
||||
),
|
||||
query_vector=[1, 2, 3],
|
||||
limit=1,
|
||||
)
|
||||
|
||||
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
||||
def test_count(self, qdrant_client_mock):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1536)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
db.count()
|
||||
qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1536")
|
||||
|
||||
@patch("embedchain.vectordb.qdrant.QdrantClient")
|
||||
def test_reset(self, qdrant_client_mock):
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
embedder.set_vector_dimension(1536)
|
||||
embedder.set_embedding_fn(mock_embedding_fn)
|
||||
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
db.reset()
|
||||
qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
|
||||
collection_name="embedchain-store-1536"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
237
embedchain/tests/vectordb/test_weaviate.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.config.vector_db.pinecone import PineconeDBConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.weaviate import WeaviateDB
|
||||
|
||||
|
||||
def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
    """A mock embedding function."""
    return [[1, 2, 3], [4, 5, 6]]


class TestWeaviateDb(unittest.TestCase):
    def test_incorrect_config_throws_error(self):
        """Test that the init method of the WeaviateDB class throws an error for an incorrect config."""
        with self.assertRaises(TypeError):
            WeaviateDB(config=PineconeDBConfig())

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_initialize(self, weaviate_mock):
        """Test the init method of the WeaviateDB class."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_schema_mock = weaviate_client_mock.schema

        # Mock that the schema doesn't already exist so that a new schema is created
        weaviate_client_schema_mock.exists.return_value = False
        # Set the embedder
        embedder = BaseEmbedder()
        embedder.set_vector_dimension(1536)
        embedder.set_embedding_fn(mock_embedding_fn)

        # Create a Weaviate instance
        db = WeaviateDB()
        app_config = AppConfig(collect_metrics=False)
        App(config=app_config, db=db, embedding_model=embedder)

        expected_class_obj = {
            "classes": [
                {
                    "class": "Embedchain_store_1536",
                    "vectorizer": "none",
                    "properties": [
                        {
                            "name": "identifier",
                            "dataType": ["text"],
                        },
                        {
                            "name": "text",
                            "dataType": ["text"],
                        },
                        {
                            "name": "metadata",
                            "dataType": ["Embedchain_store_1536_metadata"],
                        },
                    ],
                },
                {
                    "class": "Embedchain_store_1536_metadata",
                    "vectorizer": "none",
                    "properties": [
                        {
                            "name": "data_type",
                            "dataType": ["text"],
                        },
                        {
                            "name": "doc_id",
                            "dataType": ["text"],
                        },
                        {
                            "name": "url",
                            "dataType": ["text"],
                        },
                        {
                            "name": "hash",
                            "dataType": ["text"],
                        },
                        {
                            "name": "app_id",
                            "dataType": ["text"],
                        },
                    ],
                },
            ]
        }

        # Assert that the Weaviate client was initialized
        weaviate_mock.Client.assert_called_once()
        self.assertEqual(db.index_name, "Embedchain_store_1536")
        weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj)

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_get_or_create_db(self, weaviate_mock):
        """Test the _get_or_create_db method of the WeaviateDB class."""
        weaviate_client_mock = weaviate_mock.Client.return_value

        embedder = BaseEmbedder()
        embedder.set_vector_dimension(1536)
        embedder.set_embedding_fn(mock_embedding_fn)

        # Create a Weaviate instance
        db = WeaviateDB()
        app_config = AppConfig(collect_metrics=False)
        App(config=app_config, db=db, embedding_model=embedder)

        expected_client = db._get_or_create_db()
        self.assertEqual(expected_client, weaviate_client_mock)

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_add(self, weaviate_mock):
        """Test the add method of the WeaviateDB class."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_batch_mock = weaviate_client_mock.batch
        weaviate_client_batch_enter_mock = weaviate_client_mock.batch.__enter__.return_value

        # Set the embedder
        embedder = BaseEmbedder()
        embedder.set_vector_dimension(1536)
        embedder.set_embedding_fn(mock_embedding_fn)

        # Create a Weaviate instance
        db = WeaviateDB()
        app_config = AppConfig(collect_metrics=False)
        App(config=app_config, db=db, embedding_model=embedder)

        documents = ["This is test document"]
        metadatas = [None]
        ids = ["id_1"]
        db.add(documents, metadatas, ids)

        # Check if the document was added to the database.
        weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=100, timeout_retries=3)
        weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
            data_object={"text": documents[0]}, class_name="Embedchain_store_1536_metadata", vector=[1, 2, 3]
        )

        weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
            data_object={"text": documents[0]},
            class_name="Embedchain_store_1536_metadata",
            vector=[1, 2, 3],
        )

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_query_without_where(self, weaviate_mock):
        """Test the query method of the WeaviateDB class."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_query_mock = weaviate_client_mock.query
        weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value

        # Set the embedder
        embedder = BaseEmbedder()
        embedder.set_vector_dimension(1536)
        embedder.set_embedding_fn(mock_embedding_fn)

        # Create a Weaviate instance
        db = WeaviateDB()
        app_config = AppConfig(collect_metrics=False)
        App(config=app_config, db=db, embedding_model=embedder)

        # Query for the document.
        db.query(input_query="This is a test document.", n_results=1, where={})

        weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
        weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_query_with_where(self, weaviate_mock):
        """Test the query method of the WeaviateDB class."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_query_mock = weaviate_client_mock.query
        weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
        weaviate_client_query_get_where_mock = weaviate_client_query_get_mock.with_where.return_value

        # Set the embedder
        embedder = BaseEmbedder()
        embedder.set_vector_dimension(1536)
        embedder.set_embedding_fn(mock_embedding_fn)

        # Create a Weaviate instance
        db = WeaviateDB()
        app_config = AppConfig(collect_metrics=False)
        App(config=app_config, db=db, embedding_model=embedder)

        # Query for the document.
        db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})

        weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
        weaviate_client_query_get_mock.with_where.assert_called_once_with(
            {"operator": "Equal", "path": ["metadata", "Embedchain_store_1536_metadata", "doc_id"], "valueText": "123"}
        )
        weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_reset(self, weaviate_mock):
        """Test the reset method of the WeaviateDB class."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_batch_mock = weaviate_client_mock.batch

        # Set the embedder
        embedder = BaseEmbedder()
        embedder.set_vector_dimension(1536)
        embedder.set_embedding_fn(mock_embedding_fn)

        # Create a Weaviate instance
        db = WeaviateDB()
        app_config = AppConfig(collect_metrics=False)
        App(config=app_config, db=db, embedding_model=embedder)

        # Reset the database.
        db.reset()

        weaviate_client_batch_mock.delete_objects.assert_called_once_with(
            "Embedchain_store_1536", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
        )

    @patch("embedchain.vectordb.weaviate.weaviate")
    def test_count(self, weaviate_mock):
        """Test the count method of the WeaviateDB class."""
        weaviate_client_mock = weaviate_mock.Client.return_value
        weaviate_client_query = weaviate_client_mock.query

        # Set the embedder
        embedder = BaseEmbedder()
        embedder.set_vector_dimension(1536)
        embedder.set_embedding_fn(mock_embedding_fn)

        # Create a Weaviate instance
        db = WeaviateDB()
        app_config = AppConfig(collect_metrics=False)
        App(config=app_config, db=db, embedding_model=embedder)

        # Count the documents in the database.
        db.count()

        weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1536")
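For context, and not part of the commit itself, the sketch below shows how the objects mocked in the Weaviate tests above fit together outside a test run. It is a minimal, hypothetical example: only the embedchain.vectordb.weaviate path is confirmed by the @patch targets above; the import paths for App, AppConfig, and BaseEmbedder are assumptions, my_embedding_fn stands in for a real embedding model, and WeaviateDB is assumed to read its connection details from the environment.

# Hypothetical usage sketch (not part of this commit); see the note above.
from embedchain import App                          # assumed import path
from embedchain.config import AppConfig             # assumed import path
from embedchain.embedder.base import BaseEmbedder   # assumed import path
from embedchain.vectordb.weaviate import WeaviateDB


def my_embedding_fn(texts: list[str]) -> list[list[float]]:
    # Stand-in embedder: one 1536-dimensional vector per input text.
    return [[0.0] * 1536 for _ in texts]


embedder = BaseEmbedder()
embedder.set_vector_dimension(1536)          # index name becomes "Embedchain_store_1536"
embedder.set_embedding_fn(my_embedding_fn)

db = WeaviateDB()                            # assumed to pick up connection details from the environment
app = App(config=AppConfig(collect_metrics=False), db=db, embedding_model=embedder)

db.add(["This is test document"], [None], ["id_1"])
print(db.count())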
168
embedchain/tests/vectordb/test_zilliz_db.py
Normal file
@@ -0,0 +1,168 @@
# ruff: noqa: E501

import os
from unittest import mock
from unittest.mock import Mock, patch

import pytest

from embedchain.config import ZillizDBConfig
from embedchain.vectordb.zilliz import ZillizVectorDB


# to run tests, provide the URI and TOKEN in .env file
class TestZillizVectorDBConfig:
    @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
    def test_init_with_uri_and_token(self):
        """
        Test if the `ZillizDBConfig` instance is initialized with the correct uri and token values.
        """
        # Create a ZillizDBConfig instance with mocked values
        expected_uri = "mocked_uri"
        expected_token = "mocked_token"
        db_config = ZillizDBConfig()

        # Assert that the values in the ZillizDBConfig instance match the mocked values
        assert db_config.uri == expected_uri
        assert db_config.token == expected_token

    @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
    def test_init_without_uri(self):
        """
        Test if the `ZillizDBConfig` instance throws an error when no URI is found.
        """
        try:
            del os.environ["ZILLIZ_CLOUD_URI"]
        except KeyError:
            pass

        with pytest.raises(AttributeError):
            ZillizDBConfig()

    @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
    def test_init_without_token(self):
        """
        Test if the `ZillizDBConfig` instance throws an error when no token is found.
        """
        try:
            del os.environ["ZILLIZ_CLOUD_TOKEN"]
        except KeyError:
            pass
        # Test if an exception is raised when ZILLIZ_CLOUD_TOKEN is missing
        with pytest.raises(AttributeError):
            ZillizDBConfig()


class TestZillizVectorDB:
    @pytest.fixture
    @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
    def mock_config(self, mocker):
        return mocker.Mock(spec=ZillizDBConfig())

    @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
    @patch("embedchain.vectordb.zilliz.connections.connect", autospec=True)
    def test_zilliz_vector_db_setup(self, mock_connect, mock_client, mock_config):
        """
        Test if the `ZillizVectorDB` instance is initialized with the correct uri and token values.
        """
        # Create an instance of ZillizVectorDB with the mock config
        ZillizVectorDB(config=mock_config)

        # Assert that MilvusClient and connections.connect were called
        mock_client.assert_called_once_with(uri=mock_config.uri, token=mock_config.token)
        mock_connect.assert_called_once_with(uri=mock_config.uri, token=mock_config.token)


class TestZillizDBCollection:
    @pytest.fixture
    @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
    def mock_config(self, mocker):
        return mocker.Mock(spec=ZillizDBConfig())

    @pytest.fixture
    def mock_embedder(self, mocker):
        return mocker.Mock()

    @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
    def test_init_with_default_collection(self):
        """
        Test if the `ZillizDBConfig` instance is initialized with the correct default collection name.
        """
        # Create a ZillizDBConfig instance
        db_config = ZillizDBConfig()

        assert db_config.collection_name == "embedchain_store"

    @mock.patch.dict(os.environ, {"ZILLIZ_CLOUD_URI": "mocked_uri", "ZILLIZ_CLOUD_TOKEN": "mocked_token"})
    def test_init_with_custom_collection(self):
        """
        Test if the `ZillizDBConfig` instance is initialized with the correct custom collection name.
        """
        # Create a ZillizDBConfig instance with mocked values
        expected_collection = "test_collection"
        db_config = ZillizDBConfig(collection_name="test_collection")

        assert db_config.collection_name == expected_collection

    @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
    @patch("embedchain.vectordb.zilliz.connections", autospec=True)
    def test_query(self, mock_connect, mock_client, mock_embedder, mock_config):
        # Create an instance of ZillizVectorDB with the mock config
        zilliz_db = ZillizVectorDB(config=mock_config)

        # Add an 'embedder' attribute to the ZillizVectorDB instance for testing
        zilliz_db.embedder = mock_embedder  # Mock the 'embedder' object

        # Add a 'collection' attribute to the ZillizVectorDB instance for testing
        zilliz_db.collection = Mock(is_empty=False)  # Mock the 'collection' object

        assert zilliz_db.client == mock_client()

        # Mock the MilvusClient search method
        with patch.object(zilliz_db.client, "search") as mock_search:
            # Mock the embedding function
            mock_embedder.embedding_fn.return_value = ["query_vector"]

            # Mock the search result
            mock_search.return_value = [
                [
                    {
                        "distance": 0.0,
                        "entity": {
                            "text": "result_doc",
                            "embeddings": [1, 2, 3],
                            "metadata": {"url": "url_1", "doc_id": "doc_id_1"},
                        },
                    }
                ]
            ]

            query_result = zilliz_db.query(input_query="query_text", n_results=1, where={})

            # Assert that MilvusClient.search was called with the correct parameters
            mock_search.assert_called_with(
                collection_name=mock_config.collection_name,
                data=["query_vector"],
                filter="",
                limit=1,
                output_fields=["*"],
            )

            # Assert that the query result matches the expected result
            assert query_result == ["result_doc"]

            query_result_with_citations = zilliz_db.query(
                input_query="query_text", n_results=1, where={}, citations=True
            )

            mock_search.assert_called_with(
                collection_name=mock_config.collection_name,
                data=["query_vector"],
                filter="",
                limit=1,
                output_fields=["*"],
            )

            assert query_result_with_citations == [("result_doc", {"url": "url_1", "doc_id": "doc_id_1", "score": 0.0})]
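Similarly, and again not part of the commit, here is a minimal configuration sketch for the Zilliz suite above. The ZILLIZ_CLOUD_URI and ZILLIZ_CLOUD_TOKEN environment variable names, the ZillizDBConfig/ZillizVectorDB imports, the query signature, and the default collection name "embedchain_store" all come from the tests; the placeholder URI and token values are hypothetical.

# Hypothetical configuration sketch (not part of this commit); see the note above.
import os

from embedchain.config import ZillizDBConfig
from embedchain.vectordb.zilliz import ZillizVectorDB

# Placeholder credentials; in practice these would come from a .env file,
# as the comment at the top of the test module suggests.
os.environ.setdefault("ZILLIZ_CLOUD_URI", "https://<your-cluster>.zillizcloud.com")
os.environ.setdefault("ZILLIZ_CLOUD_TOKEN", "<your-api-token>")

config = ZillizDBConfig(collection_name="embedchain_store")  # default name asserted in the tests
db = ZillizVectorDB(config=config)
print(db.query(input_query="query_text", n_results=1, where={}))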