From 48b24f6f12a9e5df5def4c7ddd7f301bc5881b17 Mon Sep 17 00:00:00 2001
From: Prashant Dixit <54981696+PrashantDixit0@users.noreply.github.com>
Date: Fri, 21 Jun 2024 21:29:22 +0530
Subject: [PATCH] Lancedb Integration (#1411)
---
docs/components/vector-databases/lancedb.mdx | 48 +++
embedchain/config/vectordb/lancedb.py | 33 ++
embedchain/embedchain.py | 7 +-
embedchain/factory.py | 2 +
embedchain/llm/base.py | 4 +-
embedchain/llm/jina.py | 3 +-
embedchain/loaders/audio.py | 4 +-
embedchain/loaders/unstructured_file.py | 3 +-
embedchain/utils/misc.py | 2 +-
embedchain/vectordb/lancedb.py | 307 +++++++++++++++++++
notebooks/lancedb.ipynb | 146 +++++++++
poetry.lock | 191 +++++++++++-
pyproject.toml | 2 +
tests/chunkers/test_chunkers.py | 2 +-
tests/loaders/test_audio.py | 6 +-
tests/vectordb/test_lancedb.py | 215 +++++++++++++
16 files changed, 963 insertions(+), 12 deletions(-)
create mode 100644 docs/components/vector-databases/lancedb.mdx
create mode 100644 embedchain/config/vectordb/lancedb.py
create mode 100644 embedchain/vectordb/lancedb.py
create mode 100644 notebooks/lancedb.ipynb
create mode 100644 tests/vectordb/test_lancedb.py
diff --git a/docs/components/vector-databases/lancedb.mdx b/docs/components/vector-databases/lancedb.mdx
new file mode 100644
index 00000000..f68d6de9
--- /dev/null
+++ b/docs/components/vector-databases/lancedb.mdx
@@ -0,0 +1,48 @@
+---
+title: LanceDB
+---
+
+## Install Embedchain with LanceDB
+
+Install Embedchain, LanceDB and related dependencies using the following command:
+
+```bash
+pip install "embedchain[lancedb]"
+```
+
+LanceDB is a developer-friendly, open source database for AI. It covers everything from hyper-scalable vector search and advanced retrieval for RAG to streaming training data and interactive exploration of large-scale AI datasets.
+To use LanceDB as the vector database locally, you do not need to set any API key.
+
+
+```python main.py
+import os
+from embedchain import App
+
+# set OPENAI_API_KEY as env variable
+os.environ["OPENAI_API_KEY"] = "sk-xxx"
+
+# Create Embedchain App and set config
+app = App.from_config(config={
+ "vectordb": {
+ "provider": "lancedb",
+ "config": {
+ "collection_name": "lancedb-index"
+ }
+ }
+ }
+)
+
+# Add data source and start querying
+app.add("https://www.forbes.com/profile/elon-musk")
+
+# query continuously
+while(True):
+ question = input("Enter question: ")
+ if question in ['q', 'exit', 'quit']:
+ break
+ answer = app.query(question)
+ print(answer)
+```
+
+
+
\ No newline at end of file
diff --git a/embedchain/config/vectordb/lancedb.py b/embedchain/config/vectordb/lancedb.py
new file mode 100644
index 00000000..2e53ccda
--- /dev/null
+++ b/embedchain/config/vectordb/lancedb.py
@@ -0,0 +1,33 @@
+from typing import Optional
+
+from embedchain.config.vectordb.base import BaseVectorDbConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class LanceDBConfig(BaseVectorDbConfig):
+    def __init__(
+        self,
+        collection_name: Optional[str] = None,
+        dir: Optional[str] = None,
+        host: Optional[str] = None,
+        port: Optional[str] = None,
+        allow_reset: bool = True,
+    ):
+        """
+        Initializes a configuration class instance for LanceDB.
+
+        :param collection_name: Default name for the collection, defaults to None
+        :type collection_name: Optional[str], optional
+        :param dir: Path to the database directory, where the database is stored, defaults to None
+        :type dir: Optional[str], optional
+        :param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None
+        :type host: Optional[str], optional
+        :param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
+        :type port: Optional[str], optional
+        :param allow_reset: Whether resetting the database is permitted, defaults to True
+        :type allow_reset: bool
+        """
+
+        self.allow_reset = allow_reset
+        super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index c1c658f6..ec1a92d8 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -6,7 +6,9 @@ from typing import Any, Optional, Union
from dotenv import load_dotenv
from langchain.docstore.document import Document
-from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
+from embedchain.cache import (adapt, get_gptcache_session,
+ gptcache_data_convert,
+ gptcache_update_cache_callback)
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.base_app_config import BaseAppConfig
@@ -16,7 +18,8 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+ IndirectDataType, SpecialDataType)
from embedchain.utils.misc import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB
diff --git a/embedchain/factory.py b/embedchain/factory.py
index 567bdb55..0ed7452e 100644
--- a/embedchain/factory.py
+++ b/embedchain/factory.py
@@ -91,6 +91,7 @@ class VectorDBFactory:
"chroma": "embedchain.vectordb.chroma.ChromaDB",
"elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
+ "lancedb": "embedchain.vectordb.lancedb.LanceDB",
"pinecone": "embedchain.vectordb.pinecone.PineconeDB",
"qdrant": "embedchain.vectordb.qdrant.QdrantDB",
"weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
@@ -100,6 +101,7 @@ class VectorDBFactory:
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
+ "lancedb": "embedchain.config.vectordb.lancedb.LanceDBConfig",
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
"qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
"weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
diff --git a/embedchain/llm/base.py b/embedchain/llm/base.py
index 339c78c1..1782dc4b 100644
--- a/embedchain/llm/base.py
+++ b/embedchain/llm/base.py
@@ -5,7 +5,9 @@ from typing import Any, Optional
from langchain.schema import BaseMessage as LCBaseMessage
from embedchain.config import BaseLlmConfig
-from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE
+from embedchain.config.llm.base import (DEFAULT_PROMPT,
+ DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
+ DOCS_SITE_PROMPT_TEMPLATE)
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
diff --git a/embedchain/llm/jina.py b/embedchain/llm/jina.py
index 782742cb..33796bc5 100644
--- a/embedchain/llm/jina.py
+++ b/embedchain/llm/jina.py
@@ -35,7 +35,8 @@ class JinaLlm(BaseLlm):
if config.top_p:
kwargs["model_kwargs"]["top_p"] = config.top_p
if config.stream:
- from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+ from langchain.callbacks.streaming_stdout import \
+ StreamingStdOutCallbackHandler
chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
else:
diff --git a/embedchain/loaders/audio.py b/embedchain/loaders/audio.py
index 44d7e9bd..6b2b69cf 100644
--- a/embedchain/loaders/audio.py
+++ b/embedchain/loaders/audio.py
@@ -1,6 +1,8 @@
-import os
import hashlib
+import os
+
import validators
+
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
diff --git a/embedchain/loaders/unstructured_file.py b/embedchain/loaders/unstructured_file.py
index 856ac888..62c20ae4 100644
--- a/embedchain/loaders/unstructured_file.py
+++ b/embedchain/loaders/unstructured_file.py
@@ -11,7 +11,8 @@ class UnstructuredLoader(BaseLoader):
"""Load data from an Unstructured file."""
try:
import unstructured # noqa: F401
- from langchain_community.document_loaders import UnstructuredFileLoader
+ from langchain_community.document_loaders import \
+ UnstructuredFileLoader
except ImportError:
raise ImportError(
'Unstructured file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`' # noqa: E501
diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py
index 1e966328..8213a601 100644
--- a/embedchain/utils/misc.py
+++ b/embedchain/utils/misc.py
@@ -446,7 +446,7 @@ def validate_config(config_data):
},
Optional("vectordb"): {
Optional("provider"): Or(
- "chroma", "elasticsearch", "opensearch", "pinecone", "qdrant", "weaviate", "zilliz"
+ "chroma", "elasticsearch", "opensearch", "lancedb", "pinecone", "qdrant", "weaviate", "zilliz"
),
Optional("config"): object, # TODO: add particular config schema for each provider
},
diff --git a/embedchain/vectordb/lancedb.py b/embedchain/vectordb/lancedb.py
new file mode 100644
index 00000000..a502db65
--- /dev/null
+++ b/embedchain/vectordb/lancedb.py
@@ -0,0 +1,307 @@
+from typing import Any, Dict, List, Optional, Union
+
+import pyarrow as pa
+
+try:
+ import lancedb
+except ImportError:
+ raise ImportError('LanceDB is required. Install with pip install "embedchain[lancedb]"') from None
+
+from embedchain.config.vectordb.lancedb import LanceDBConfig
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.vectordb.base import BaseVectorDB
+
+
+@register_deserializable
+class LanceDB(BaseVectorDB):
+ """
+ LanceDB as vector database
+ """
+
+ BATCH_SIZE = 100
+
+    def __init__(
+        self,
+        config: Optional[LanceDBConfig] = None,
+    ):
+        """
+        Connect to LanceDB and hand the configuration to the base class.
+
+        :param config: LanceDB database config; a default `LanceDBConfig` is
+            created when omitted
+        :type config: LanceDBConfig, optional
+        """
+        self.config = config if config else LanceDBConfig()
+
+        # Use the configured directory, or "~/.lancedb" when none is given.
+        db_path = self.config.dir or "~/.lancedb"
+        self.client = lancedb.connect(db_path)
+        self.embedder_check = True
+        super().__init__(config=self.config)
+
+    def _initialize(self):
+        """
+        Finish setup once the `embedder` attribute has been assigned externally.
+
+        :raises ValueError: if no embedder has been set yet
+        """
+        if not self.embedder:
+            raise ValueError(
+                "Embedder not set. Please set an embedder with `_set_embedder()` function before initialization."
+            )
+        # Probe the embedding function; a failure only clears `embedder_check`.
+        try:
+            self.embedder.embedding_fn("Hello LanceDB")
+        except Exception:
+            self.embedder_check = False
+        self._get_or_create_collection(self.config.collection_name)
+
+    def _get_or_create_db(self):
+        """Return the LanceDB connection held in `self.client`."""
+        db = self.client
+        # Nothing is created here; the stored connection is handed back.
+        return db
+
+    def _generate_where_clause(self, where: Dict[str, Any]) -> str:
+        """
+        Build a SQL filter string from a mapping of column names to values.
+
+        Each pair becomes `column = value` and the pairs are joined with
+        `AND`. String values are emitted as single-quoted SQL literals with
+        embedded quotes doubled; any other value is formatted with `str()`.
+        An empty mapping yields an empty string.
+
+        :param where: column names mapped to the values they must equal
+        :type where: Dict[str, Any]
+        :return: filter expression, e.g. `id = 'a' AND page = 2`
+        :rtype: str
+        """
+
+        def as_literal(value: Any) -> str:
+            # Unquoted text would be parsed as a column reference, not a value.
+            if isinstance(value, str):
+                return "'" + value.replace("'", "''") + "'"
+            return str(value)
+
+        # join() adds no trailing AND and never repeats the final condition.
+        return " AND ".join(f"{key} = {as_literal(value)}" for key, value in where.items())
+
+    def _get_or_create_collection(self, table_name: str, reset=False):
+        """
+        Open the named table, creating it first when needed.
+
+        The table holds `doc`, `metadata` and `id` string columns. A fixed-size
+        float32 `vector` column of `self.embedder.vector_dimension` entries is
+        added unless the embedder probe failed (`self.embedder_check` False).
+
+        :param table_name: Name of the table backing the collection
+        :type table_name: str
+        :param reset: Drop an existing table and recreate it empty, defaults to False
+        :type reset: bool
+        :return: The opened table, also stored in `self.collection`
+        """
+        fields = [
+            pa.field("doc", pa.string()),
+            pa.field("metadata", pa.string()),
+            pa.field("id", pa.string()),
+        ]
+        if self.embedder_check:
+            vector_type = pa.list_(pa.float32(), list_size=self.embedder.vector_dimension)
+            fields.insert(0, pa.field("vector", vector_type))
+        schema = pa.schema(fields)
+
+        exists = table_name in self.client.table_names()
+        if reset and exists:
+            # Only drop a table that is present; a reset issued before the
+            # table was ever created must not fail on the drop.
+            self.client.drop_table(table_name)
+            exists = False
+
+        if not exists:
+            self.client.create_table(table_name, schema=schema)
+
+        # Re-open through the connection so `self.collection` is set the same
+        # way whether the table was just created or already existed.
+        self.collection = self.client[table_name]
+
+        return self.collection
+
+    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, Any]] = None, limit: Optional[int] = None):
+        """
+        Get existing doc ids present in vector database
+
+        Nothing is looked up unless `ids` is given: the table is first scanned
+        for rows whose `id` is in `ids`, then each id found is searched again
+        (with the optional `where` filter) to collect its id and metadata.
+
+        :param ids: list of doc ids to check for existence
+        :type ids: List[str]
+        :param where: Optional. column/value pairs every returned row must match
+        :type where: Dict[str, Any]
+        :param limit: Optional. maximum rows fetched per id, defaults to 3
+        :type limit: Optional[int]
+        :return: dict with the parallel lists `ids` and `metadatas`
+        :rtype: Dict[str, list]
+        """
+        max_limit = limit if limit is not None else 3
+        results = {"ids": [], "metadatas": []}
+
+        # An empty list would render as the invalid filter `id IN ()`.
+        if not ids:
+            return results
+
+        where_clause = self._generate_where_clause(where) if where else ""
+
+        # Write the IN list out by hand: `tuple(ids)` renders one id as
+        # `('a',)`, and repr() switches to double quotes for ids containing `'`.
+        quoted_ids = ", ".join("'" + str(doc_id).replace("'", "''") + "'" for doc_id in ids)
+        dataset = self.collection.to_lance()
+        records = dataset.scanner(filter=f"id IN ({quoted_ids})", columns=["id"]).to_table().to_pydict()
+
+        for doc_id in records["id"]:
+            search = self.collection.search(query=doc_id, vector_column_name="id")
+            if where_clause:
+                search = search.where(where_clause)
+            rows = search.limit(max_limit).to_list()
+            # Accumulate across ids; assigning here kept only the last id's rows.
+            results["ids"].extend(row["id"] for row in rows)
+            results["metadatas"].extend(row["metadata"] for row in rows)
+        return results
+
+    def add(
+        self,
+        documents: List[str],
+        metadatas: List[object],
+        ids: List[str],
+    ) -> Any:
+        """
+        Add documents, with their metadata and ids, to the LanceDB table.
+
+        Metadata is stored as its `str()` form. Each row also gets a `vector`
+        embedding of its document unless `self.embedder_check` is False.
+
+        :param documents: Documents
+        :type documents: List[str]
+        :param metadatas: Metadatas
+        :type metadatas: List[object]
+        :param ids: ids
+        :type ids: List[str]
+        """
+        rows = []
+
+        for text, metadata, doc_id in zip(documents, metadatas, ids):
+            # Keys are inserted in the order doc, vector, metadata, id.
+            row = {"doc": text}
+
+            if self.embedder_check:
+                # One embedding call per document; the first vector is kept.
+                embeddings = self.embedder.embedding_fn([text])
+                row["vector"] = embeddings[0]
+
+            row["metadata"] = str(metadata)
+            row["id"] = doc_id
+
+            rows.append(row)
+
+        self.collection.add(data=rows)
+
+    def _format_result(self, results) -> list:
+        """
+        Convert LanceDB query results into a plain list.
+
+        :param results: Query results exposing a `tolist()` method.
+        :return: Whatever `results.tolist()` produces.
+        :rtype: list
+        """
+        formatted = results.tolist()
+        return formatted
+
+    def query(
+        self,
+        input_query: str,
+        n_results: int = 3,
+        where: Optional[dict[str, Any]] = None,
+        raw_filter: Optional[dict[str, Any]] = None,
+        citations: bool = False,
+        **kwargs: Optional[dict[str, Any]],
+    ) -> Union[list[tuple[str, str]], list[str]]:
+        """
+        Query contents from vector database based on vector similarity
+
+        `where` and `raw_filter` are only checked for being mutually exclusive;
+        neither is applied to the search.
+
+        :param input_query: query string
+        :type input_query: str
+        :param n_results: no of similar documents to fetch from database
+        :type n_results: int
+        :param where: to filter data
+        :type where: dict[str, Any]
+        :param raw_filter: Raw filter to apply
+        :type raw_filter: dict[str, Any]
+        :param citations: return each document together with its stored metadata
+        :type citations: bool, default is False.
+        :raises ValueError: if both `where` and `raw_filter` are given
+        :return: The content of the documents that matched your query,
+            paired with the stored metadata string when citations is true
+        :rtype: list[str], if citations=False, otherwise list[tuple[str, str]]
+        """
+        if where and raw_filter:
+            raise ValueError("Both `where` and `raw_filter` cannot be used together.")
+
+        # Embed the query as a one-element batch, matching how `add` embeds
+        # documents. Failures propagate as raised: the previous handler called
+        # `e.message()`, which exceptions do not define, so every failure
+        # surfaced as an unrelated AttributeError.
+        query_embedding = self.embedder.embedding_fn([input_query])[0]
+        rows = self.collection.search(query_embedding).limit(n_results).to_list()
+
+        # Each row is indexed by column name.
+        if citations:
+            return [(row["doc"], row["metadata"]) for row in rows]
+
+        return [row["doc"] for row in rows]
+
+    def set_collection_name(self, name: str):
+        """
+        Point this database at collection `name`, creating it when missing.
+
+        :param name: Name of the collection; must be a `str`, else TypeError.
+        """
+        if not isinstance(name, str):
+            raise TypeError("Collection name must be a string")
+        config = self.config
+        config.collection_name = name
+        self._get_or_create_collection(config.collection_name)
+
+    def count(self) -> int:
+        """
+        Report how many rows (documents/chunks) the collection holds.
+
+        :return: number of documents
+        """
+        total = self.collection.count_rows()
+        return total
+
+    def delete(self, where):  # `where` is forwarded unchanged; presumably a SQL filter string - TODO confirm
+        return self.collection.delete(where=where)
+
+    def reset(self):
+        """
+        Resets the database. Deletes all embeddings irreversibly.
+
+        The current collection is dropped and recreated empty. Nothing is
+        deleted when `allow_reset` is False in the config; a notice is
+        printed instead.
+        """
+        if not self.config.allow_reset:
+            print(
+                "For safety reasons, resetting is disabled. "
+                "Please enable it by setting `allow_reset=True` in your LanceDBConfig"
+            )
+            return
+
+        # Errors from the drop/recreate propagate as raised. They used to be
+        # re-raised as "resetting is disabled", which cannot be the cause on
+        # this branch because `allow_reset` is already True.
+        self._get_or_create_collection(self.config.collection_name, reset=True)
diff --git a/notebooks/lancedb.ipynb b/notebooks/lancedb.ipynb
new file mode 100644
index 00000000..08d99621
--- /dev/null
+++ b/notebooks/lancedb.ipynb
@@ -0,0 +1,146 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "b02n_zJ_hl3d"
+ },
+ "source": [
+ "## Cookbook for using LanceDB with Embedchain"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gyJ6ui2vhtMY"
+ },
+ "source": [
+ "### Step-1: Install embedchain package"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-NbXjAdlh0vJ"
+ },
+ "outputs": [],
+ "source": [
+ "! pip install embedchain lancedb"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nGnpSYAAh2bQ"
+ },
+ "source": [
+ "### Step-2: Set environment variables needed for LanceDB\n",
+ "\n",
+ "You can find this env variable on your [OpenAI](https://platform.openai.com/account/api-keys)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0fBdQ9GAiRvK"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from embedchain import App\n",
+ "\n",
+ "os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PGt6uPLIi1CS"
+ },
+ "source": [
+ "### Step-3 Create embedchain app and define your config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Amzxk3m-i3tD"
+ },
+ "outputs": [],
+ "source": [
+ "app = App.from_config(config={\n",
+ " \"vectordb\": {\n",
+ " \"provider\": \"lancedb\",\n",
+ " \"config\": {\n",
+ " \"collection_name\": \"lancedb-index\"\n",
+ " }\n",
+ " }\n",
+ " }\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XNXv4yZwi7ef"
+ },
+ "source": [
+ "### Step-4: Add data sources to your app"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Sn_0rx9QjIY9"
+ },
+ "outputs": [],
+ "source": [
+ "app.add(\"https://www.forbes.com/profile/elon-musk\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_7W6fDeAjMAP"
+ },
+ "source": [
+ "### Step-5: All set. Now start asking questions related to your data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "cvIK7dWRjN_f"
+ },
+ "outputs": [],
+ "source": [
+ "while(True):\n",
+ " question = input(\"Enter question: \")\n",
+ " if question in ['q', 'exit', 'quit']:\n",
+ " break\n",
+ " answer = app.query(question)\n",
+ " print(answer)"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/poetry.lock b/poetry.lock
index 85ccb80e..32a84a77 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand.
[[package]]
name = "aiohttp"
@@ -1142,6 +1142,17 @@ files = [
marshmallow = ">=3.18.0,<4.0.0"
typing-inspect = ">=0.4.0,<1"
+[[package]]
+name = "decorator"
+version = "5.1.1"
+description = "Decorators for Humans"
+optional = true
+python-versions = ">=3.5"
+files = [
+ {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"},
+ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
+]
+
[[package]]
name = "deprecated"
version = "1.2.14"
@@ -1159,6 +1170,20 @@ wrapt = ">=1.10,<2"
[package.extras]
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
+[[package]]
+name = "deprecation"
+version = "2.1.0"
+description = "A library to handle automated deprecations"
+optional = true
+python-versions = "*"
+files = [
+ {file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"},
+ {file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"},
+]
+
+[package.dependencies]
+packaging = "*"
+
[[package]]
name = "discord"
version = "2.3.2"
@@ -2785,6 +2810,42 @@ websocket-client = ">=0.32.0,<0.40.0 || >0.40.0,<0.41.dev0 || >=0.43.dev0"
[package.extras]
adal = ["adal (>=1.0.2)"]
+[[package]]
+name = "lancedb"
+version = "0.6.13"
+description = "lancedb"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "lancedb-0.6.13-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:4667353ca7fa187e94cb0ca4c5f9577d65eb5160f6f3fe9e57902d86312c3869"},
+ {file = "lancedb-0.6.13-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:2e22533fe6f6b2d7037dcdbbb4019a62402bbad4ce18395be68f4aa007bf8bc0"},
+ {file = "lancedb-0.6.13-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:837eaceafb87e3ae4c261eef45c4f73715f892a36165572c3da621dbdb45afcf"},
+ {file = "lancedb-0.6.13-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:61af2d72b2a2f0ea419874c3f32760fe5e51530da3be2d65251a0e6ded74419b"},
+ {file = "lancedb-0.6.13-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:31b24e57ee313f4ce6255e45d42e8bee19b90ddcd13a9e07030ac04f76e7dfde"},
+ {file = "lancedb-0.6.13-cp38-abi3-win_amd64.whl", hash = "sha256:b851182d8492b1e5b57a441af64c95da65ca30b045d6618dc7d203c6d60d70fa"},
+]
+
+[package.dependencies]
+attrs = ">=21.3.0"
+cachetools = "*"
+deprecation = "*"
+overrides = ">=0.7"
+pydantic = ">=1.10"
+pylance = "0.10.12"
+ratelimiter = ">=1.0,<2.0"
+requests = ">=2.31.0"
+retry = ">=0.9.2"
+semver = "*"
+tqdm = ">=4.27.0"
+
+[package.extras]
+azure = ["adlfs (>=2024.2.0)"]
+clip = ["open-clip", "pillow", "torch"]
+dev = ["pre-commit", "ruff"]
+docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
+embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "instructorembedding", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"]
+tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"]
+
[[package]]
name = "langchain"
version = "0.1.20"
@@ -4781,6 +4842,65 @@ all = ["apache-bookkeeper-client (>=4.16.1)", "fastavro (>=1.9.2)", "grpcio (>=1
avro = ["fastavro (>=1.9.2)"]
functions = ["apache-bookkeeper-client (>=4.16.1)", "grpcio (>=1.60.0)", "prometheus-client", "protobuf (>=3.6.1,<=3.20.3)", "ratelimit"]
+[[package]]
+name = "py"
+version = "1.11.0"
+description = "library with cross-python path, ini-parsing, io, code, log facilities"
+optional = true
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+files = [
+ {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"},
+ {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
+]
+
+[[package]]
+name = "pyarrow"
+version = "15.0.0"
+description = "Python library for Apache Arrow"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "pyarrow-15.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:0a524532fd6dd482edaa563b686d754c70417c2f72742a8c990b322d4c03a15d"},
+ {file = "pyarrow-15.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:60a6bdb314affa9c2e0d5dddf3d9cbb9ef4a8dddaa68669975287d47ece67642"},
+ {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:66958fd1771a4d4b754cd385835e66a3ef6b12611e001d4e5edfcef5f30391e2"},
+ {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f500956a49aadd907eaa21d4fff75f73954605eaa41f61cb94fb008cf2e00c6"},
+ {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:6f87d9c4f09e049c2cade559643424da84c43a35068f2a1c4653dc5b1408a929"},
+ {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:85239b9f93278e130d86c0e6bb455dcb66fc3fd891398b9d45ace8799a871a1e"},
+ {file = "pyarrow-15.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5b8d43e31ca16aa6e12402fcb1e14352d0d809de70edd185c7650fe80e0769e3"},
+ {file = "pyarrow-15.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:fa7cd198280dbd0c988df525e50e35b5d16873e2cdae2aaaa6363cdb64e3eec5"},
+ {file = "pyarrow-15.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8780b1a29d3c8b21ba6b191305a2a607de2e30dab399776ff0aa09131e266340"},
+ {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe0ec198ccc680f6c92723fadcb97b74f07c45ff3fdec9dd765deb04955ccf19"},
+ {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:036a7209c235588c2f07477fe75c07e6caced9b7b61bb897c8d4e52c4b5f9555"},
+ {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2bd8a0e5296797faf9a3294e9fa2dc67aa7f10ae2207920dbebb785c77e9dbe5"},
+ {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e8ebed6053dbe76883a822d4e8da36860f479d55a762bd9e70d8494aed87113e"},
+ {file = "pyarrow-15.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:17d53a9d1b2b5bd7d5e4cd84d018e2a45bc9baaa68f7e6e3ebed45649900ba99"},
+ {file = "pyarrow-15.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9950a9c9df24090d3d558b43b97753b8f5867fb8e521f29876aa021c52fda351"},
+ {file = "pyarrow-15.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:003d680b5e422d0204e7287bb3fa775b332b3fce2996aa69e9adea23f5c8f970"},
+ {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f75fce89dad10c95f4bf590b765e3ae98bcc5ba9f6ce75adb828a334e26a3d40"},
+ {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ca9cb0039923bec49b4fe23803807e4ef39576a2bec59c32b11296464623dc2"},
+ {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ed5a78ed29d171d0acc26a305a4b7f83c122d54ff5270810ac23c75813585e4"},
+ {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:6eda9e117f0402dfcd3cd6ec9bfee89ac5071c48fc83a84f3075b60efa96747f"},
+ {file = "pyarrow-15.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a3a6180c0e8f2727e6f1b1c87c72d3254cac909e609f35f22532e4115461177"},
+ {file = "pyarrow-15.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:19a8918045993349b207de72d4576af0191beef03ea655d8bdb13762f0cd6eac"},
+ {file = "pyarrow-15.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d0ec076b32bacb6666e8813a22e6e5a7ef1314c8069d4ff345efa6246bc38593"},
+ {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5db1769e5d0a77eb92344c7382d6543bea1164cca3704f84aa44e26c67e320fb"},
+ {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2617e3bf9df2a00020dd1c1c6dce5cc343d979efe10bc401c0632b0eef6ef5b"},
+ {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:d31c1d45060180131caf10f0f698e3a782db333a422038bf7fe01dace18b3a31"},
+ {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:c8c287d1d479de8269398b34282e206844abb3208224dbdd7166d580804674b7"},
+ {file = "pyarrow-15.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:07eb7f07dc9ecbb8dace0f58f009d3a29ee58682fcdc91337dfeb51ea618a75b"},
+ {file = "pyarrow-15.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:47af7036f64fce990bb8a5948c04722e4e3ea3e13b1007ef52dfe0aa8f23cf7f"},
+ {file = "pyarrow-15.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93768ccfff85cf044c418bfeeafce9a8bb0cee091bd8fd19011aff91e58de540"},
+ {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6ee87fd6892700960d90abb7b17a72a5abb3b64ee0fe8db6c782bcc2d0dc0b4"},
+ {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:001fca027738c5f6be0b7a3159cc7ba16a5c52486db18160909a0831b063c4e4"},
+ {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:d1c48648f64aec09accf44140dccb92f4f94394b8d79976c426a5b79b11d4fa7"},
+ {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:972a0141be402bb18e3201448c8ae62958c9c7923dfaa3b3d4530c835ac81aed"},
+ {file = "pyarrow-15.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:f01fc5cf49081426429127aa2d427d9d98e1cb94a32cb961d583a70b7c4504e6"},
+ {file = "pyarrow-15.0.0.tar.gz", hash = "sha256:876858f549d540898f927eba4ef77cd549ad8d24baa3207cf1b72e5788b50e83"},
+]
+
+[package.dependencies]
+numpy = ">=1.16.6,<2"
+
[[package]]
name = "pyasn1"
version = "0.6.0"
@@ -5019,6 +5139,32 @@ dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pyte
docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
+[[package]]
+name = "pylance"
+version = "0.10.12"
+description = "python wrapper for Lance columnar format"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "pylance-0.10.12-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:30cbcca078edeb37e11ae86cf9287d81ce6c0c07ba77239284b369a4b361497b"},
+ {file = "pylance-0.10.12-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:e558163ff6035d518706cc66848497219ccc755e2972b8f3b1706a3e1fd800fd"},
+ {file = "pylance-0.10.12-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75afb39f71d7f12429f9b4d380eb6cf6aed179ae5a1c5d16cc768373a1521f87"},
+ {file = "pylance-0.10.12-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:3de391dfc3a99bdb245fd1e27ef242be769a94853f802ef57f246e9a21358d32"},
+ {file = "pylance-0.10.12-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:34a5278b90f4cbcf21261353976127aa2ffbbd7d068810f0a2b0c1aa0334022a"},
+ {file = "pylance-0.10.12-cp38-abi3-win_amd64.whl", hash = "sha256:6cef5975d513097fd2c22692296c9a5a138928f38d02cd34ab63a7369abc1463"},
+]
+
+[package.dependencies]
+numpy = ">=1.22"
+pyarrow = ">=12,<15.0.1"
+
+[package.extras]
+benchmarks = ["pytest-benchmark"]
+dev = ["ruff (==0.2.2)"]
+ray = ["ray[data]"]
+tests = ["boto3", "datasets", "duckdb", "h5py (<3.11)", "ml-dtypes", "pandas", "pillow", "polars[pandas,pyarrow]", "pytest", "tensorflow", "tqdm"]
+torch = ["torch"]
+
[[package]]
name = "pymilvus"
version = "2.4.3"
@@ -5540,6 +5686,20 @@ urllib3 = ">=1.26.14,<3"
[package.extras]
fastembed = ["fastembed (==0.2.6)"]
+[[package]]
+name = "ratelimiter"
+version = "1.2.0.post0"
+description = "Simple python rate limiting object"
+optional = true
+python-versions = "*"
+files = [
+ {file = "ratelimiter-1.2.0.post0-py3-none-any.whl", hash = "sha256:a52be07bc0bb0b3674b4b304550f10c769bbb00fead3072e035904474259809f"},
+ {file = "ratelimiter-1.2.0.post0.tar.gz", hash = "sha256:5c395dcabdbbde2e5178ef3f89b568a3066454a6ddc223b76473dac22f89b4f7"},
+]
+
+[package.extras]
+test = ["pytest (>=3.0)", "pytest-asyncio"]
+
[[package]]
name = "regex"
version = "2024.5.15"
@@ -5720,6 +5880,21 @@ urllib3 = ">=1.25.10,<3.0"
[package.extras]
tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "tomli", "tomli-w", "types-requests"]
+[[package]]
+name = "retry"
+version = "0.9.2"
+description = "Easy to use retry decorator."
+optional = true
+python-versions = "*"
+files = [
+ {file = "retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606"},
+ {file = "retry-0.9.2.tar.gz", hash = "sha256:f8bfa8b99b69c4506d6f5bd3b0aabf77f98cdb17f3c9fc3f5ca820033336fba4"},
+]
+
+[package.dependencies]
+decorator = ">=3.4.2"
+py = ">=1.4.26,<2.0.0"
+
[[package]]
name = "rich"
version = "13.7.1"
@@ -6018,6 +6193,17 @@ dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pyde
doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"]
test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
+[[package]]
+name = "semver"
+version = "3.0.2"
+description = "Python helper for Semantic Versioning (https://semver.org)"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "semver-3.0.2-py3-none-any.whl", hash = "sha256:b1ea4686fe70b981f85359eda33199d60c53964284e0cfb4977d243e37cf4bf4"},
+ {file = "semver-3.0.2.tar.gz", hash = "sha256:6253adb39c70f6e51afed2fa7152bcd414c411286088fb4b9effb133885ab4cc"},
+]
+
[[package]]
name = "sentence-transformers"
version = "2.7.0"
@@ -7716,6 +7902,7 @@ gmail = ["google-api-core", "google-api-python-client", "google-auth", "google-a
google = ["google-generativeai"]
googledrive = ["google-api-python-client", "google-auth-httplib2", "google-auth-oauthlib"]
huggingface-hub = ["huggingface_hub"]
+lancedb = ["lancedb"]
llama2 = ["replicate"]
milvus = ["pymilvus"]
mistralai = ["langchain-mistralai"]
@@ -7738,4 +7925,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<=3.13"
-content-hash = "cb7d2794bc2f54e05b2f870843eccc4342f3f2a6531eaa50a3d0d77b358ac4d5"
+content-hash = "f9e6357bd1b5f407368d3d52c3f728e12f41fe7d4836e321ec2e10413f58e8a1"
diff --git a/pyproject.toml b/pyproject.toml
index 8a86a063..714463f9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -122,6 +122,7 @@ slack-sdk = { version = "3.21.3", optional = true }
clarifai = { version = "^10.0.1", optional = true }
cohere = { version = "^5.3", optional = true }
together = { version = "^0.2.8", optional = true }
+lancedb = { version = "^0.6.2", optional = true }
weaviate-client = { version = "^3.24.1", optional = true }
docx2txt = { version = "^0.8", optional = true }
qdrant-client = { version = "^1.6.3", optional = true }
@@ -173,6 +174,7 @@ pytest-asyncio = "^0.21.1"
[tool.poetry.extras]
streamlit = ["streamlit"]
opensource = ["sentence-transformers", "torch", "gpt4all"]
+lancedb = ["lancedb"]
elasticsearch = ["elasticsearch"]
opensearch = ["opensearch-py"]
poe = ["fastapi-poe"]
diff --git a/tests/chunkers/test_chunkers.py b/tests/chunkers/test_chunkers.py
index 067fa869..8258e776 100644
--- a/tests/chunkers/test_chunkers.py
+++ b/tests/chunkers/test_chunkers.py
@@ -1,3 +1,4 @@
+from embedchain.chunkers.audio import AudioChunker
from embedchain.chunkers.common_chunker import CommonChunker
from embedchain.chunkers.discourse import DiscourseChunker
from embedchain.chunkers.docs_site import DocsSiteChunker
@@ -19,7 +20,6 @@ from embedchain.chunkers.text import TextChunker
from embedchain.chunkers.web_page import WebPageChunker
from embedchain.chunkers.xml import XmlChunker
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
-from embedchain.chunkers.audio import AudioChunker
from embedchain.config.add_config import ChunkerConfig
chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
diff --git a/tests/loaders/test_audio.py b/tests/loaders/test_audio.py
index c365893b..c62ec139 100644
--- a/tests/loaders/test_audio.py
+++ b/tests/loaders/test_audio.py
@@ -1,11 +1,13 @@
+import hashlib
import os
import sys
-import hashlib
-import pytest
from unittest.mock import mock_open, patch
+import pytest
+
if sys.version_info > (3, 10): # as `match` statement was introduced in python 3.10
from deepgram import PrerecordedOptions
+
from embedchain.loaders.audio import AudioLoader
diff --git a/tests/vectordb/test_lancedb.py b/tests/vectordb/test_lancedb.py
new file mode 100644
index 00000000..f50660bb
--- /dev/null
+++ b/tests/vectordb/test_lancedb.py
@@ -0,0 +1,215 @@
+import os
+import shutil
+
+import pytest
+
+from embedchain import App
+from embedchain.config import AppConfig
+from embedchain.config.vectordb.lancedb import LanceDBConfig
+from embedchain.vectordb.lancedb import LanceDB
+
+# Placeholder value installed at import time, before any App is constructed below.
+# NOTE(review): presumably needed so App can be built without real credentials — confirm.
+os.environ["OPENAI_API_KEY"] = "test-api-key"
+
+
+@pytest.fixture
+def lancedb():
+    """Provide a LanceDB instance configured with dir ``test-db`` and collection ``test-coll``."""
+    store_config = LanceDBConfig(dir="test-db", collection_name="test-coll")
+    return LanceDB(config=store_config)
+
+
+@pytest.fixture
+def app_with_settings():
+    """Provide an App with ``collect_metrics=False`` on top of a LanceDB store.
+
+    The store is configured with ``allow_reset=True`` and dir ``test-db-reset``.
+    """
+    vector_store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db-reset"))
+    return App(config=AppConfig(collect_metrics=False), db=vector_store)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def cleanup_db():
+    """Remove the on-disk test databases once the whole test session has finished.
+
+    Cleanup is best effort: a directory that cannot be removed is reported on
+    stdout and never fails the run.
+    """
+    yield
+    # Each directory gets its own try/except so that a failure on the first
+    # one (for example because it was never created) cannot skip the second.
+    for path in ("test-db.lance", "test-db-reset.lance"):
+        try:
+            shutil.rmtree(path)
+        except OSError as e:
+            print("Error: %s - %s." % (e.filename, e.strerror))
+
+
+def test_lancedb_duplicates_throw_warning(caplog):
+    """Adding the same id twice must not leave either duplicate-id message in the log.
+
+    NOTE(review): the test name says "throw_warning" while the assertions require
+    the messages to be absent — confirm which one reflects the intended behavior.
+    """
+    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
+    app = App(config=AppConfig(collect_metrics=False), db=store)
+    # Insert the identical id/document/metadata twice in a row.
+    for _ in range(2):
+        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
+    for unexpected in ("Insert of existing doc ID: 0", "Add of existing doc ID: 0"):
+        assert unexpected not in caplog.text
+    app.db.reset()
+
+
+def test_lancedb_duplicates_collections_no_warning(caplog):
+    """The same id added to two different collections must not log a duplicate-id message."""
+    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
+    app = App(config=AppConfig(collect_metrics=False), db=store)
+    for collection in ("test_collection_1", "test_collection_2"):
+        app.set_collection_name(collection)
+        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
+    for unexpected in ("Insert of existing doc ID: 0", "Add of existing doc ID: 0"):
+        assert unexpected not in caplog.text
+    # Reset the currently selected collection (test_collection_2), then test_collection_1.
+    app.db.reset()
+    app.set_collection_name("test_collection_1")
+    app.db.reset()
+
+
+def test_lancedb_collection_init_with_default_collection():
+    """With no collection name configured, the store's collection is named ``embedchain_store``."""
+    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
+    app = App(config=AppConfig(collect_metrics=False), db=store)
+    expected_name = "embedchain_store"
+    assert app.db.collection.name == expected_name
+
+
+def test_lancedb_collection_init_with_custom_collection():
+    """``set_collection_name(name=...)`` makes the store's collection carry that name."""
+    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
+    app = App(config=AppConfig(collect_metrics=False), db=store)
+    custom_name = "test_collection"
+    app.set_collection_name(name=custom_name)
+    assert app.db.collection.name == custom_name
+
+
+def test_lancedb_collection_set_collection_name():
+    """Passing the name positionally to ``set_collection_name`` renames the store's collection."""
+    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
+    app = App(config=AppConfig(collect_metrics=False), db=store)
+    target_name = "test_collection"
+    app.set_collection_name(target_name)
+    assert app.db.collection.name == target_name
+
+
+def test_lancedb_collection_changes_encapsulated():
+    """A document added to one collection is not counted in another collection."""
+    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
+    app = App(config=AppConfig(collect_metrics=False), db=store)
+    first, second = "test_collection_1", "test_collection_2"
+
+    app.set_collection_name(first)
+    assert app.db.count() == 0
+    app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
+    assert app.db.count() == 1
+
+    # The second collection must start empty although the first holds a document.
+    app.set_collection_name(second)
+    assert app.db.count() == 0
+
+    # Adding to the second collection must leave the first collection's count at 1.
+    app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
+    app.set_collection_name(first)
+    assert app.db.count() == 1
+
+    # cleanup: reset the first collection, then the second
+    app.db.reset()
+    app.set_collection_name(second)
+    app.db.reset()
+
+
+def test_lancedb_collection_collections_are_persistent():
+    """A document added through one App is still counted by a new App built on the same dir."""
+    collection = "test_collection_1"
+
+    # First App: write one document, then drop the App object.
+    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
+    app = App(config=AppConfig(collect_metrics=False), db=store)
+    app.set_collection_name(collection)
+    app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
+    del app
+
+    # Second App: freshly constructed store and App over the same dir and collection.
+    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
+    app = App(config=AppConfig(collect_metrics=False), db=store)
+    app.set_collection_name(collection)
+    assert app.db.count() == 1
+
+    app.db.reset()
+
+
+def test_lancedb_collection_parallel_collections():
+    """Two Apps on the same dir but different collections keep separate document counts."""
+    store_1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
+    app_1 = App(config=AppConfig(collect_metrics=False), db=store_1)
+    store_2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
+    app_2 = App(config=AppConfig(collect_metrics=False), db=store_2)
+
+    # cleanup if any previous tests failed or were interrupted
+    for app in (app_1, app_2):
+        app.db.reset()
+
+    app_1.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
+
+    assert app_1.db.count() == 1
+    assert app_2.db.count() == 0
+
+    app_1.db.add(ids=["1", "2"], documents=["doc1", "doc2"], metadatas=["test", "test"])
+    app_2.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
+
+    # Swap which collection each App points at; the counts are expected to follow the collection.
+    app_1.set_collection_name("test_collection_2")
+    assert app_1.db.count() == 1
+    app_2.set_collection_name("test_collection_1")
+    assert app_2.db.count() == 3
+
+    # cleanup
+    for app in (app_1, app_2):
+        app.db.reset()
+
+
+def test_lancedb_collection_ids_share_collections():
+    """Two Apps are pointed at the same collection name and each adds documents.
+
+    After the first App adds two documents and the second adds one, the first
+    App is expected to count 2 and the second to count 3.
+    """
+    shared_name = "one_collection"
+    store_1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
+    app_1 = App(config=AppConfig(collect_metrics=False), db=store_1)
+    app_1.set_collection_name(shared_name)
+    store_2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
+    app_2 = App(config=AppConfig(collect_metrics=False), db=store_2)
+    app_2.set_collection_name(shared_name)
+
+    # cleanup
+    for app in (app_1, app_2):
+        app.db.reset()
+
+    app_1.db.add(ids=["0", "1"], documents=["doc1", "doc2"], metadatas=["test", "test"])
+    app_2.db.add(ids=["2"], documents=["doc3"], metadatas=["test"])
+
+    assert app_1.db.count() == 2
+    assert app_2.db.count() == 3
+
+    # cleanup
+    for app in (app_1, app_2):
+        app.db.reset()
+
+
+def test_lancedb_collection_reset():
+    """Resetting one collection empties it while three other collections keep their document."""
+    apps = []
+    for collection in ("one_collection", "two_collection", "three_collection", "four_collection"):
+        store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
+        app = App(config=AppConfig(collect_metrics=False), db=store)
+        app.set_collection_name(collection)
+        apps.append(app)
+    app_1, app_2, app_3, app_4 = apps
+
+    # cleanup if any previous tests failed or were interrupted
+    for app in apps:
+        app.db.reset()
+
+    # One document per collection, each with its own id.
+    seeds = ((app_1, "1", "doc1"), (app_2, "2", "doc2"), (app_3, "3", "doc3"), (app_4, "4", "doc4"))
+    for app, doc_id, text in seeds:
+        app.db.add(ids=[doc_id], documents=[text], metadatas=["test"])
+
+    app_1.db.reset()
+
+    # Only the collection that was reset may be empty.
+    assert app_1.db.count() == 0
+    assert app_2.db.count() == 1
+    assert app_3.db.count() == 1
+    assert app_4.db.count() == 1
+
+    # cleanup
+    for app in (app_2, app_3, app_4):
+        app.db.reset()
+
+
+def generate_embeddings(dummy_embed, embed_size):
+    """Return a list holding ``embed_size`` references to ``dummy_embed``.
+
+    A zero or negative ``embed_size`` yields an empty list.
+    """
+    return [dummy_embed] * embed_size