Lancedb Integration (#1411)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from embedchain.chunkers.audio import AudioChunker
|
||||
from embedchain.chunkers.common_chunker import CommonChunker
|
||||
from embedchain.chunkers.discourse import DiscourseChunker
|
||||
from embedchain.chunkers.docs_site import DocsSiteChunker
|
||||
@@ -19,7 +20,6 @@ from embedchain.chunkers.text import TextChunker
|
||||
from embedchain.chunkers.web_page import WebPageChunker
|
||||
from embedchain.chunkers.xml import XmlChunker
|
||||
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
|
||||
from embedchain.chunkers.audio import AudioChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
|
||||
chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
import hashlib
|
||||
import pytest
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
if sys.version_info > (3, 10): # as `match` statement was introduced in python 3.10
|
||||
from deepgram import PrerecordedOptions
|
||||
|
||||
from embedchain.loaders.audio import AudioLoader
|
||||
|
||||
|
||||
|
||||
215
tests/vectordb/test_lancedb.py
Normal file
215
tests/vectordb/test_lancedb.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.config.vectordb.lancedb import LanceDBConfig
|
||||
from embedchain.vectordb.lancedb import LanceDB
|
||||
|
||||
# Placeholder value set at import time, before any test in this module runs.
# NOTE(review): presumably lets App be constructed without a real OpenAI
# credential -- confirm against the embedder/LLM defaults.
os.environ["OPENAI_API_KEY"] = "test-api-key"
|
||||
|
||||
|
||||
@pytest.fixture
def lancedb():
    """Provide a ``LanceDB`` configured with dir ``test-db`` and collection ``test-coll``."""
    config = LanceDBConfig(dir="test-db", collection_name="test-coll")
    return LanceDB(config=config)
|
||||
|
||||
|
||||
@pytest.fixture
def app_with_settings():
    """Provide an ``App`` backed by a resettable ``LanceDB`` in ``test-db-reset``.

    Metrics collection is disabled on the app configuration.
    """
    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db-reset"))
    return App(db=store, config=AppConfig(collect_metrics=False))
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def cleanup_db():
    """Delete the on-disk test database directories after the session.

    Session-scoped and ``autouse``, so the code after ``yield`` runs once
    when the test session finishes. A failure to remove a directory is
    reported on stdout and never raised.
    """
    yield
    # Each directory gets its own try block: a failure to remove the first
    # one (for example because it was never created) must not prevent the
    # second one from being cleaned up.
    for directory in ("test-db.lance", "test-db-reset.lance"):
        try:
            shutil.rmtree(directory)
        except OSError as e:
            print(f"Error: {e.filename} - {e.strerror}.")
|
||||
|
||||
|
||||
def test_lancedb_duplicates_throw_warning(caplog):
    """Add the same document ID twice and inspect the captured log output.

    NOTE(review): the test name suggests a warning is expected, yet both
    assertions require the duplicate-ID messages to be absent from
    ``caplog.text`` -- confirm which behaviour is intended.
    """
    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    application = App(config=AppConfig(collect_metrics=False), db=store)

    # Two identical inserts under the same ID.
    for _ in range(2):
        application.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

    for message in ("Insert of existing doc ID: 0", "Add of existing doc ID: 0"):
        assert message not in caplog.text

    application.db.reset()
|
||||
|
||||
|
||||
def test_lancedb_duplicates_collections_no_warning(caplog):
    """Reusing a document ID across two collections must not log a duplicate message."""
    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    application = App(config=AppConfig(collect_metrics=False), db=store)

    # Insert the same ID once into each collection.
    for collection in ("test_collection_1", "test_collection_2"):
        application.set_collection_name(collection)
        application.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

    for message in ("Insert of existing doc ID: 0", "Add of existing doc ID: 0"):
        assert message not in caplog.text

    # The second collection is still active here; reset it, then the first.
    application.db.reset()
    application.set_collection_name("test_collection_1")
    application.db.reset()
|
||||
|
||||
|
||||
def test_lancedb_collection_init_with_default_collection():
    """An app created without a collection name uses ``embedchain_store``."""
    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    application = App(config=AppConfig(collect_metrics=False), db=store)

    expected_name = "embedchain_store"
    assert application.db.collection.name == expected_name
|
||||
|
||||
|
||||
def test_lancedb_collection_init_with_custom_collection():
    """Setting the collection via the ``name`` keyword is reflected on the db."""
    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    application = App(config=AppConfig(collect_metrics=False), db=store)

    application.set_collection_name(name="test_collection")

    active_collection = application.db.collection
    assert active_collection.name == "test_collection"
|
||||
|
||||
|
||||
def test_lancedb_collection_set_collection_name():
    """Setting the collection positionally is reflected on the db."""
    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    application = App(config=AppConfig(collect_metrics=False), db=store)

    application.set_collection_name("test_collection")

    active_collection = application.db.collection
    assert active_collection.name == "test_collection"
|
||||
|
||||
|
||||
def test_lancedb_collection_changes_encapsulated():
    """Documents added to one collection are not counted in another."""
    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    application = App(config=AppConfig(collect_metrics=False), db=store)

    # First collection: starts empty, holds one document after the add.
    application.set_collection_name("test_collection_1")
    assert application.db.count() == 0
    application.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
    assert application.db.count() == 1

    # Second collection: unaffected by the add above.
    application.set_collection_name("test_collection_2")
    assert application.db.count() == 0
    application.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

    # Back on the first collection the count is unchanged.
    application.set_collection_name("test_collection_1")
    assert application.db.count() == 1

    # Reset the first collection, then switch and reset the second.
    application.db.reset()
    application.set_collection_name("test_collection_2")
    application.db.reset()
|
||||
|
||||
|
||||
def test_lancedb_collection_collections_are_persistent():
    """A document added through one app is visible to a freshly built app."""
    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    application = App(config=AppConfig(collect_metrics=False), db=store)
    application.set_collection_name("test_collection_1")
    application.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

    # Drop the first app, then rebuild db and app against the same directory.
    del application
    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    application = App(config=AppConfig(collect_metrics=False), db=store)
    application.set_collection_name("test_collection_1")

    assert application.db.count() == 1

    application.db.reset()
|
||||
|
||||
|
||||
def test_lancedb_collection_parallel_collections():
    """Two apps on different collections of one directory keep separate counts."""
    first_store = LanceDB(
        config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1")
    )
    first_app = App(config=AppConfig(collect_metrics=False), db=first_store)
    second_store = LanceDB(
        config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2")
    )
    second_app = App(config=AppConfig(collect_metrics=False), db=second_store)

    # Clear anything left behind by a failed or interrupted earlier run.
    for application in (first_app, second_app):
        application.db.reset()

    first_app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

    assert first_app.db.count() == 1
    assert second_app.db.count() == 0

    first_app.db.add(ids=["1", "2"], documents=["doc1", "doc2"], metadatas=["test", "test"])
    second_app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

    # Swap the collections each app points at and check the other's contents.
    first_app.set_collection_name("test_collection_2")
    assert first_app.db.count() == 1
    second_app.set_collection_name("test_collection_1")
    assert second_app.db.count() == 3

    # Final cleanup of both collections.
    for application in (first_app, second_app):
        application.db.reset()
|
||||
|
||||
|
||||
def test_lancedb_collection_ids_share_collections():
    """Two apps pointed at the same collection name add documents to it.

    After the first app adds two documents and the second adds one, the
    test expects the first app to report 2 and the second to report 3.
    """
    first_store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    first_app = App(config=AppConfig(collect_metrics=False), db=first_store)
    first_app.set_collection_name("one_collection")

    second_store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    second_app = App(config=AppConfig(collect_metrics=False), db=second_store)
    second_app.set_collection_name("one_collection")

    # Start from a clean state.
    for application in (first_app, second_app):
        application.db.reset()

    first_app.db.add(ids=["0", "1"], documents=["doc1", "doc2"], metadatas=["test", "test"])
    second_app.db.add(ids=["2"], documents=["doc3"], metadatas=["test"])

    assert first_app.db.count() == 2
    assert second_app.db.count() == 3

    # Final cleanup.
    for application in (first_app, second_app):
        application.db.reset()
|
||||
|
||||
|
||||
def test_lancedb_collection_reset():
    """Resetting one collection leaves the other collections untouched."""
    collection_names = (
        "one_collection",
        "two_collection",
        "three_collection",
        "four_collection",
    )

    # Build one db + app per collection, in order.
    applications = []
    for collection in collection_names:
        store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
        application = App(config=AppConfig(collect_metrics=False), db=store)
        application.set_collection_name(collection)
        applications.append(application)

    # Clear anything left behind by a failed or interrupted earlier run.
    for application in applications:
        application.db.reset()

    # App N receives a single document with id "N" and text "docN".
    for number, application in enumerate(applications, start=1):
        application.db.add(ids=[str(number)], documents=[f"doc{number}"], metadatas=["test"])

    target, *untouched = applications
    target.db.reset()

    assert target.db.count() == 0
    for application in untouched:
        assert application.db.count() == 1

    # Final cleanup of the collections that were not reset above.
    for application in untouched:
        application.db.reset()
|
||||
|
||||
|
||||
def generate_embeddings(dummy_embed, embed_size):
    """Return a list holding ``dummy_embed`` repeated ``embed_size`` times.

    Args:
        dummy_embed: Value placed in every slot of the list. The same
            object is referenced in each position; it is not copied.
        embed_size: Number of entries to generate. Zero or a negative
            value yields an empty list.

    Returns:
        A new list of length ``max(embed_size, 0)``.
    """
    # List repetition replaces the manual append loop with an unused index.
    return [dummy_embed] * embed_size
|
||||
Reference in New Issue
Block a user