diff --git a/docs/components/vector-databases/lancedb.mdx b/docs/components/vector-databases/lancedb.mdx new file mode 100644 index 00000000..f68d6de9 --- /dev/null +++ b/docs/components/vector-databases/lancedb.mdx @@ -0,0 +1,48 @@ +--- +title: LanceDB +--- + +## Install Embedchain with LanceDB + +Install Embedchain, LanceDB and related dependencies using the following command: + +```bash +pip install "embedchain[lancedb]" +``` + +LanceDB is a developer-friendly, open source database for AI. From hyper scalable vector search and advanced retrieval for RAG, to streaming training data and interactive exploration of large scale AI datasets. +In order to use LanceDB as vector database, not need to set any key for local use. + + +```python main.py +import os +from embedchain import App + +# set OPENAI_API_KEY as env variable +os.environ["OPENAI_API_KEY"] = "sk-xxx" + +# Create Embedchain App and set config +app = App.from_config(config={ + "vectordb": { + "provider": "lancedb", + "config": { + "collection_name": "lancedb-index" + } + } + } +) + +# Add data source and start queryin +app.add("https://www.forbes.com/profile/elon-musk") + +# query continuously +while(True): + question = input("Enter question: ") + if question in ['q', 'exit', 'quit']: + break + answer = app.query(question) + print(answer) +``` + + + \ No newline at end of file diff --git a/embedchain/config/vectordb/lancedb.py b/embedchain/config/vectordb/lancedb.py new file mode 100644 index 00000000..2e53ccda --- /dev/null +++ b/embedchain/config/vectordb/lancedb.py @@ -0,0 +1,33 @@ +from typing import Optional + +from embedchain.config.vectordb.base import BaseVectorDbConfig +from embedchain.helpers.json_serializable import register_deserializable + + +@register_deserializable +class LanceDBConfig(BaseVectorDbConfig): + def __init__( + self, + collection_name: Optional[str] = None, + dir: Optional[str] = None, + host: Optional[str] = None, + port: Optional[str] = None, + allow_reset=True, + ): + """ + Initializes a configuration class instance for LanceDB. + + :param collection_name: Default name for the collection, defaults to None + :type collection_name: Optional[str], optional + :param dir: Path to the database directory, where the database is stored, defaults to None + :type dir: Optional[str], optional + :param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None + :type host: Optional[str], optional + :param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None + :type port: Optional[str], optional + :param allow_reset: Resets the database. defaults to False + :type allow_reset: bool + """ + + self.allow_reset = allow_reset + super().__init__(collection_name=collection_name, dir=dir, host=host, port=port) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index c1c658f6..ec1a92d8 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -6,7 +6,9 @@ from typing import Any, Optional, Union from dotenv import load_dotenv from langchain.docstore.document import Document -from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback +from embedchain.cache import (adapt, get_gptcache_session, + gptcache_data_convert, + gptcache_update_cache_callback) from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig from embedchain.config.base_app_config import BaseAppConfig @@ -16,7 +18,8 @@ from embedchain.embedder.base import BaseEmbedder from embedchain.helpers.json_serializable import JSONSerializable from embedchain.llm.base import BaseLlm from embedchain.loaders.base_loader import BaseLoader -from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType +from embedchain.models.data_type import (DataType, DirectDataType, + IndirectDataType, SpecialDataType) from embedchain.utils.misc import detect_datatype, is_valid_json_string from embedchain.vectordb.base import BaseVectorDB diff --git a/embedchain/factory.py b/embedchain/factory.py index 567bdb55..0ed7452e 100644 --- a/embedchain/factory.py +++ b/embedchain/factory.py @@ -91,6 +91,7 @@ class VectorDBFactory: "chroma": "embedchain.vectordb.chroma.ChromaDB", "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB", "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB", + "lancedb": "embedchain.vectordb.lancedb.LanceDB", "pinecone": "embedchain.vectordb.pinecone.PineconeDB", "qdrant": "embedchain.vectordb.qdrant.QdrantDB", "weaviate": "embedchain.vectordb.weaviate.WeaviateDB", @@ -100,6 +101,7 @@ class VectorDBFactory: "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig", "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig", "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig", + "lancedb": "embedchain.config.vectordb.lancedb.LanceDBConfig", "pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig", "qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig", "weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig", diff --git a/embedchain/llm/base.py b/embedchain/llm/base.py index 339c78c1..1782dc4b 100644 --- a/embedchain/llm/base.py +++ b/embedchain/llm/base.py @@ -5,7 +5,9 @@ from typing import Any, Optional from langchain.schema import BaseMessage as LCBaseMessage from embedchain.config import BaseLlmConfig -from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE +from embedchain.config.llm.base import (DEFAULT_PROMPT, + DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, + DOCS_SITE_PROMPT_TEMPLATE) from embedchain.helpers.json_serializable import JSONSerializable from embedchain.memory.base import ChatHistory from embedchain.memory.message import ChatMessage diff --git a/embedchain/llm/jina.py b/embedchain/llm/jina.py index 782742cb..33796bc5 100644 --- a/embedchain/llm/jina.py +++ b/embedchain/llm/jina.py @@ -35,7 +35,8 @@ class JinaLlm(BaseLlm): if config.top_p: kwargs["model_kwargs"]["top_p"] = config.top_p if config.stream: - from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler + from langchain.callbacks.streaming_stdout import \ + StreamingStdOutCallbackHandler chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()]) else: diff --git a/embedchain/loaders/audio.py b/embedchain/loaders/audio.py index 44d7e9bd..6b2b69cf 100644 --- a/embedchain/loaders/audio.py +++ b/embedchain/loaders/audio.py @@ -1,6 +1,8 @@ -import os import hashlib +import os + import validators + from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader diff --git a/embedchain/loaders/unstructured_file.py b/embedchain/loaders/unstructured_file.py index 856ac888..62c20ae4 100644 --- a/embedchain/loaders/unstructured_file.py +++ b/embedchain/loaders/unstructured_file.py @@ -11,7 +11,8 @@ class UnstructuredLoader(BaseLoader): """Load data from an Unstructured file.""" try: import unstructured # noqa: F401 - from langchain_community.document_loaders import UnstructuredFileLoader + from langchain_community.document_loaders import \ + UnstructuredFileLoader except ImportError: raise ImportError( 'Unstructured file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`' # noqa: E501 diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py index 1e966328..8213a601 100644 --- a/embedchain/utils/misc.py +++ b/embedchain/utils/misc.py @@ -446,7 +446,7 @@ def validate_config(config_data): }, Optional("vectordb"): { Optional("provider"): Or( - "chroma", "elasticsearch", "opensearch", "pinecone", "qdrant", "weaviate", "zilliz" + "chroma", "elasticsearch", "opensearch", "lancedb", "pinecone", "qdrant", "weaviate", "zilliz" ), Optional("config"): object, # TODO: add particular config schema for each provider }, diff --git a/embedchain/vectordb/lancedb.py b/embedchain/vectordb/lancedb.py new file mode 100644 index 00000000..a502db65 --- /dev/null +++ b/embedchain/vectordb/lancedb.py @@ -0,0 +1,307 @@ +from typing import Any, Dict, List, Optional, Union + +import pyarrow as pa + +try: + import lancedb +except ImportError: + raise ImportError('LanceDB is required. Install with pip install "embedchain[lancedb]"') from None + +from embedchain.config.vectordb.lancedb import LanceDBConfig +from embedchain.helpers.json_serializable import register_deserializable +from embedchain.vectordb.base import BaseVectorDB + + +@register_deserializable +class LanceDB(BaseVectorDB): + """ + LanceDB as vector database + """ + + BATCH_SIZE = 100 + + def __init__( + self, + config: Optional[LanceDBConfig] = None, + ): + """LanceDB as vector database. + + :param config: LanceDB database config, defaults to None + :type config: LanceDBConfig, optional + """ + if config: + self.config = config + else: + self.config = LanceDBConfig() + + self.client = lancedb.connect(self.config.dir or "~/.lancedb") + self.embedder_check = True + + super().__init__(config=self.config) + + def _initialize(self): + """ + This method is needed because `embedder` attribute needs to be set externally before it can be initialized. + """ + if not self.embedder: + raise ValueError( + "Embedder not set. Please set an embedder with `_set_embedder()` function before initialization." + ) + else: + # check embedder function is working or not + try: + self.embedder.embedding_fn("Hello LanceDB") + except Exception: + self.embedder_check = False + + self._get_or_create_collection(self.config.collection_name) + + def _get_or_create_db(self): + """ + Called during initialization + """ + return self.client + + def _generate_where_clause(self, where: Dict[str, any]) -> str: + """ + This method generate where clause using dictionary containing attributes and their values + """ + + where_filters = "" + + if len(list(where.keys())) == 1: + where_filters = f"{list(where.keys())[0]} = {list(where.values())[0]}" + return where_filters + + where_items = list(where.items()) + where_count = len(where_items) + + for i, (key, value) in enumerate(where_items, start=1): + condition = f"{key} = {value} AND " + where_filters += condition + + if i == where_count: + condition = f"{key} = {value}" + where_filters += condition + + return where_filters + + def _get_or_create_collection(self, table_name: str, reset=False): + """ + Get or create a named collection. + + :param name: Name of the collection + :type name: str + :return: Created collection + :rtype: Collection + """ + if not self.embedder_check: + schema = pa.schema( + [ + pa.field("doc", pa.string()), + pa.field("metadata", pa.string()), + pa.field("id", pa.string()), + ] + ) + + else: + schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), list_size=self.embedder.vector_dimension)), + pa.field("doc", pa.string()), + pa.field("metadata", pa.string()), + pa.field("id", pa.string()), + ] + ) + + if not reset: + if table_name not in self.client.table_names(): + self.collection = self.client.create_table(table_name, schema=schema) + + else: + self.client.drop_table(table_name) + self.collection = self.client.create_table(table_name, schema=schema) + + self.collection = self.client[table_name] + + return self.collection + + def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None): + """ + Get existing doc ids present in vector database + + :param ids: list of doc ids to check for existence + :type ids: List[str] + :param where: Optional. to filter data + :type where: Dict[str, Any] + :param limit: Optional. maximum number of documents + :type limit: Optional[int] + :return: Existing documents. + :rtype: List[str] + """ + if limit is not None: + max_limit = limit + else: + max_limit = 3 + results = {"ids": [], "metadatas": []} + + where_clause = {} + if where: + where_clause = self._generate_where_clause(where) + + if ids is not None: + records = ( + self.collection.to_lance().scanner(filter=f"id IN {tuple(ids)}", columns=["id"]).to_table().to_pydict() + ) + for id in records["id"]: + if where is not None: + result = ( + self.collection.search(query=id, vector_column_name="id") + .where(where_clause) + .limit(max_limit) + .to_list() + ) + else: + result = self.collection.search(query=id, vector_column_name="id").limit(max_limit).to_list() + results["ids"] = [r["id"] for r in result] + results["metadatas"] = [r["metadata"] for r in result] + + return results + + def add( + self, + documents: List[str], + metadatas: List[object], + ids: List[str], + ) -> Any: + """ + Add vectors to lancedb database + + :param documents: Documents + :type documents: List[str] + :param metadatas: Metadatas + :type metadatas: List[object] + :param ids: ids + :type ids: List[str] + """ + data = [] + to_ingest = list(zip(documents, metadatas, ids)) + + if not self.embedder_check: + for doc, meta, id in to_ingest: + temp = {} + temp["doc"] = doc + temp["metadata"] = str(meta) + temp["id"] = id + data.append(temp) + else: + for doc, meta, id in to_ingest: + temp = {} + temp["doc"] = doc + temp["vector"] = self.embedder.embedding_fn([doc])[0] + temp["metadata"] = str(meta) + temp["id"] = id + data.append(temp) + + self.collection.add(data=data) + + def _format_result(self, results) -> list: + """ + Format LanceDB results + + :param results: LanceDB query results to format. + :type results: QueryResult + :return: Formatted results + :rtype: list[tuple[Document, float]] + """ + return results.tolist() + + def query( + self, + input_query: str, + n_results: int = 3, + where: Optional[dict[str, any]] = None, + raw_filter: Optional[dict[str, any]] = None, + citations: bool = False, + **kwargs: Optional[dict[str, any]], + ) -> Union[list[tuple[str, dict]], list[str]]: + """ + Query contents from vector database based on vector similarity + + :param input_query: query string + :type input_query: str + :param n_results: no of similar documents to fetch from database + :type n_results: int + :param where: to filter data + :type where: dict[str, Any] + :param raw_filter: Raw filter to apply + :type raw_filter: dict[str, Any] + :param citations: we use citations boolean param to return context along with the answer. + :type citations: bool, default is False. + :raises InvalidDimensionException: Dimensions do not match. + :return: The content of the document that matched your query, + along with url of the source and doc_id (if citations flag is true) + :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]] + """ + if where and raw_filter: + raise ValueError("Both `where` and `raw_filter` cannot be used together.") + try: + query_embedding = self.embedder.embedding_fn(input_query)[0] + result = self.collection.search(query_embedding).limit(n_results).to_list() + except Exception as e: + e.message() + + results_formatted = result + + contexts = [] + for result in results_formatted: + if citations: + metadata = result["metadata"] + contexts.append((result["doc"], metadata)) + else: + contexts.append(result["doc"]) + return contexts + + def set_collection_name(self, name: str): + """ + Set the name of the collection. A collection is an isolated space for vectors. + + :param name: Name of the collection. + :type name: str + """ + if not isinstance(name, str): + raise TypeError("Collection name must be a string") + self.config.collection_name = name + self._get_or_create_collection(self.config.collection_name) + + def count(self) -> int: + """ + Count number of documents/chunks embedded in the database. + + :return: number of documents + :rtype: int + """ + return self.collection.count_rows() + + def delete(self, where): + return self.collection.delete(where=where) + + def reset(self): + """ + Resets the database. Deletes all embeddings irreversibly. + """ + # Delete all data from the collection and recreate collection + if self.config.allow_reset: + try: + self._get_or_create_collection(self.config.collection_name, reset=True) + except ValueError: + raise ValueError( + "For safety reasons, resetting is disabled. " + "Please enable it by setting `allow_reset=True` in your LanceDbConfig" + ) from None + # Recreate + else: + print( + "For safety reasons, resetting is disabled. " + "Please enable it by setting `allow_reset=True` in your LanceDbConfig" + ) diff --git a/notebooks/lancedb.ipynb b/notebooks/lancedb.ipynb new file mode 100644 index 00000000..08d99621 --- /dev/null +++ b/notebooks/lancedb.ipynb @@ -0,0 +1,146 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "b02n_zJ_hl3d" + }, + "source": [ + "## Cookbook for using LanceDB with Embedchain" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gyJ6ui2vhtMY" + }, + "source": [ + "### Step-1: Install embedchain package" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-NbXjAdlh0vJ" + }, + "outputs": [], + "source": [ + "! pip install embedchain lancedb" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nGnpSYAAh2bQ" + }, + "source": [ + "### Step-2: Set environment variables needed for LanceDB\n", + "\n", + "You can find this env variable on your [OpenAI](https://platform.openai.com/account/api-keys)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0fBdQ9GAiRvK" + }, + "outputs": [], + "source": [ + "import os\n", + "from embedchain import App\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PGt6uPLIi1CS" + }, + "source": [ + "### Step-3 Create embedchain app and define your config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Amzxk3m-i3tD" + }, + "outputs": [], + "source": [ + "app = App.from_config(config={\n", + " \"vectordb\": {\n", + " \"provider\": \"lancedb\",\n", + " \"config\": {\n", + " \"collection_name\": \"lancedb-index\"\n", + " }\n", + " }\n", + " }\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XNXv4yZwi7ef" + }, + "source": [ + "### Step-4: Add data sources to your app" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Sn_0rx9QjIY9" + }, + "outputs": [], + "source": [ + "app.add(\"https://www.forbes.com/profile/elon-musk\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_7W6fDeAjMAP" + }, + "source": [ + "### Step-5: All set. Now start asking questions related to your data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cvIK7dWRjN_f" + }, + "outputs": [], + "source": [ + "while(True):\n", + " question = input(\"Enter question: \")\n", + " if question in ['q', 'exit', 'quit']:\n", + " break\n", + " answer = app.query(question)\n", + " print(answer)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/poetry.lock b/poetry.lock index 85ccb80e..32a84a77 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1142,6 +1142,17 @@ files = [ marshmallow = ">=3.18.0,<4.0.0" typing-inspect = ">=0.4.0,<1" +[[package]] +name = "decorator" +version = "5.1.1" +description = "Decorators for Humans" +optional = true +python-versions = ">=3.5" +files = [ + {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, + {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, +] + [[package]] name = "deprecated" version = "1.2.14" @@ -1159,6 +1170,20 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "deprecation" +version = "2.1.0" +description = "A library to handle automated deprecations" +optional = true +python-versions = "*" +files = [ + {file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"}, + {file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"}, +] + +[package.dependencies] +packaging = "*" + [[package]] name = "discord" version = "2.3.2" @@ -2785,6 +2810,42 @@ websocket-client = ">=0.32.0,<0.40.0 || >0.40.0,<0.41.dev0 || >=0.43.dev0" [package.extras] adal = ["adal (>=1.0.2)"] +[[package]] +name = "lancedb" +version = "0.6.13" +description = "lancedb" +optional = true +python-versions = ">=3.8" +files = [ + {file = "lancedb-0.6.13-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:4667353ca7fa187e94cb0ca4c5f9577d65eb5160f6f3fe9e57902d86312c3869"}, + {file = "lancedb-0.6.13-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:2e22533fe6f6b2d7037dcdbbb4019a62402bbad4ce18395be68f4aa007bf8bc0"}, + {file = "lancedb-0.6.13-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:837eaceafb87e3ae4c261eef45c4f73715f892a36165572c3da621dbdb45afcf"}, + {file = "lancedb-0.6.13-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:61af2d72b2a2f0ea419874c3f32760fe5e51530da3be2d65251a0e6ded74419b"}, + {file = "lancedb-0.6.13-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:31b24e57ee313f4ce6255e45d42e8bee19b90ddcd13a9e07030ac04f76e7dfde"}, + {file = "lancedb-0.6.13-cp38-abi3-win_amd64.whl", hash = "sha256:b851182d8492b1e5b57a441af64c95da65ca30b045d6618dc7d203c6d60d70fa"}, +] + +[package.dependencies] +attrs = ">=21.3.0" +cachetools = "*" +deprecation = "*" +overrides = ">=0.7" +pydantic = ">=1.10" +pylance = "0.10.12" +ratelimiter = ">=1.0,<2.0" +requests = ">=2.31.0" +retry = ">=0.9.2" +semver = "*" +tqdm = ">=4.27.0" + +[package.extras] +azure = ["adlfs (>=2024.2.0)"] +clip = ["open-clip", "pillow", "torch"] +dev = ["pre-commit", "ruff"] +docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] +embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "instructorembedding", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"] +tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"] + [[package]] name = "langchain" version = "0.1.20" @@ -4781,6 +4842,65 @@ all = ["apache-bookkeeper-client (>=4.16.1)", "fastavro (>=1.9.2)", "grpcio (>=1 avro = ["fastavro (>=1.9.2)"] functions = ["apache-bookkeeper-client (>=4.16.1)", "grpcio (>=1.60.0)", "prometheus-client", "protobuf (>=3.6.1,<=3.20.3)", "ratelimit"] +[[package]] +name = "py" +version = "1.11.0" +description = "library with cross-python path, ini-parsing, io, code, log facilities" +optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, + {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, +] + +[[package]] +name = "pyarrow" +version = "15.0.0" +description = "Python library for Apache Arrow" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pyarrow-15.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:0a524532fd6dd482edaa563b686d754c70417c2f72742a8c990b322d4c03a15d"}, + {file = "pyarrow-15.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:60a6bdb314affa9c2e0d5dddf3d9cbb9ef4a8dddaa68669975287d47ece67642"}, + {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:66958fd1771a4d4b754cd385835e66a3ef6b12611e001d4e5edfcef5f30391e2"}, + {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f500956a49aadd907eaa21d4fff75f73954605eaa41f61cb94fb008cf2e00c6"}, + {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:6f87d9c4f09e049c2cade559643424da84c43a35068f2a1c4653dc5b1408a929"}, + {file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:85239b9f93278e130d86c0e6bb455dcb66fc3fd891398b9d45ace8799a871a1e"}, + {file = "pyarrow-15.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5b8d43e31ca16aa6e12402fcb1e14352d0d809de70edd185c7650fe80e0769e3"}, + {file = "pyarrow-15.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:fa7cd198280dbd0c988df525e50e35b5d16873e2cdae2aaaa6363cdb64e3eec5"}, + {file = "pyarrow-15.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8780b1a29d3c8b21ba6b191305a2a607de2e30dab399776ff0aa09131e266340"}, + {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe0ec198ccc680f6c92723fadcb97b74f07c45ff3fdec9dd765deb04955ccf19"}, + {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:036a7209c235588c2f07477fe75c07e6caced9b7b61bb897c8d4e52c4b5f9555"}, + {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2bd8a0e5296797faf9a3294e9fa2dc67aa7f10ae2207920dbebb785c77e9dbe5"}, + {file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e8ebed6053dbe76883a822d4e8da36860f479d55a762bd9e70d8494aed87113e"}, + {file = "pyarrow-15.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:17d53a9d1b2b5bd7d5e4cd84d018e2a45bc9baaa68f7e6e3ebed45649900ba99"}, + {file = "pyarrow-15.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9950a9c9df24090d3d558b43b97753b8f5867fb8e521f29876aa021c52fda351"}, + {file = "pyarrow-15.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:003d680b5e422d0204e7287bb3fa775b332b3fce2996aa69e9adea23f5c8f970"}, + {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f75fce89dad10c95f4bf590b765e3ae98bcc5ba9f6ce75adb828a334e26a3d40"}, + {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ca9cb0039923bec49b4fe23803807e4ef39576a2bec59c32b11296464623dc2"}, + {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ed5a78ed29d171d0acc26a305a4b7f83c122d54ff5270810ac23c75813585e4"}, + {file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:6eda9e117f0402dfcd3cd6ec9bfee89ac5071c48fc83a84f3075b60efa96747f"}, + {file = "pyarrow-15.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a3a6180c0e8f2727e6f1b1c87c72d3254cac909e609f35f22532e4115461177"}, + {file = "pyarrow-15.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:19a8918045993349b207de72d4576af0191beef03ea655d8bdb13762f0cd6eac"}, + {file = "pyarrow-15.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d0ec076b32bacb6666e8813a22e6e5a7ef1314c8069d4ff345efa6246bc38593"}, + {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5db1769e5d0a77eb92344c7382d6543bea1164cca3704f84aa44e26c67e320fb"}, + {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2617e3bf9df2a00020dd1c1c6dce5cc343d979efe10bc401c0632b0eef6ef5b"}, + {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:d31c1d45060180131caf10f0f698e3a782db333a422038bf7fe01dace18b3a31"}, + {file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:c8c287d1d479de8269398b34282e206844abb3208224dbdd7166d580804674b7"}, + {file = "pyarrow-15.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:07eb7f07dc9ecbb8dace0f58f009d3a29ee58682fcdc91337dfeb51ea618a75b"}, + {file = "pyarrow-15.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:47af7036f64fce990bb8a5948c04722e4e3ea3e13b1007ef52dfe0aa8f23cf7f"}, + {file = "pyarrow-15.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93768ccfff85cf044c418bfeeafce9a8bb0cee091bd8fd19011aff91e58de540"}, + {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6ee87fd6892700960d90abb7b17a72a5abb3b64ee0fe8db6c782bcc2d0dc0b4"}, + {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:001fca027738c5f6be0b7a3159cc7ba16a5c52486db18160909a0831b063c4e4"}, + {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:d1c48648f64aec09accf44140dccb92f4f94394b8d79976c426a5b79b11d4fa7"}, + {file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:972a0141be402bb18e3201448c8ae62958c9c7923dfaa3b3d4530c835ac81aed"}, + {file = "pyarrow-15.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:f01fc5cf49081426429127aa2d427d9d98e1cb94a32cb961d583a70b7c4504e6"}, + {file = "pyarrow-15.0.0.tar.gz", hash = "sha256:876858f549d540898f927eba4ef77cd549ad8d24baa3207cf1b72e5788b50e83"}, +] + +[package.dependencies] +numpy = ">=1.16.6,<2" + [[package]] name = "pyasn1" version = "0.6.0" @@ -5019,6 +5139,32 @@ dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pyte docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] +[[package]] +name = "pylance" +version = "0.10.12" +description = "python wrapper for Lance columnar format" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pylance-0.10.12-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:30cbcca078edeb37e11ae86cf9287d81ce6c0c07ba77239284b369a4b361497b"}, + {file = "pylance-0.10.12-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:e558163ff6035d518706cc66848497219ccc755e2972b8f3b1706a3e1fd800fd"}, + {file = "pylance-0.10.12-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75afb39f71d7f12429f9b4d380eb6cf6aed179ae5a1c5d16cc768373a1521f87"}, + {file = "pylance-0.10.12-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:3de391dfc3a99bdb245fd1e27ef242be769a94853f802ef57f246e9a21358d32"}, + {file = "pylance-0.10.12-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:34a5278b90f4cbcf21261353976127aa2ffbbd7d068810f0a2b0c1aa0334022a"}, + {file = "pylance-0.10.12-cp38-abi3-win_amd64.whl", hash = "sha256:6cef5975d513097fd2c22692296c9a5a138928f38d02cd34ab63a7369abc1463"}, +] + +[package.dependencies] +numpy = ">=1.22" +pyarrow = ">=12,<15.0.1" + +[package.extras] +benchmarks = ["pytest-benchmark"] +dev = ["ruff (==0.2.2)"] +ray = ["ray[data]"] +tests = ["boto3", "datasets", "duckdb", "h5py (<3.11)", "ml-dtypes", "pandas", "pillow", "polars[pandas,pyarrow]", "pytest", "tensorflow", "tqdm"] +torch = ["torch"] + [[package]] name = "pymilvus" version = "2.4.3" @@ -5540,6 +5686,20 @@ urllib3 = ">=1.26.14,<3" [package.extras] fastembed = ["fastembed (==0.2.6)"] +[[package]] +name = "ratelimiter" +version = "1.2.0.post0" +description = "Simple python rate limiting object" +optional = true +python-versions = "*" +files = [ + {file = "ratelimiter-1.2.0.post0-py3-none-any.whl", hash = "sha256:a52be07bc0bb0b3674b4b304550f10c769bbb00fead3072e035904474259809f"}, + {file = "ratelimiter-1.2.0.post0.tar.gz", hash = "sha256:5c395dcabdbbde2e5178ef3f89b568a3066454a6ddc223b76473dac22f89b4f7"}, +] + +[package.extras] +test = ["pytest (>=3.0)", "pytest-asyncio"] + [[package]] name = "regex" version = "2024.5.15" @@ -5720,6 +5880,21 @@ urllib3 = ">=1.25.10,<3.0" [package.extras] tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "tomli", "tomli-w", "types-requests"] +[[package]] +name = "retry" +version = "0.9.2" +description = "Easy to use retry decorator." +optional = true +python-versions = "*" +files = [ + {file = "retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606"}, + {file = "retry-0.9.2.tar.gz", hash = "sha256:f8bfa8b99b69c4506d6f5bd3b0aabf77f98cdb17f3c9fc3f5ca820033336fba4"}, +] + +[package.dependencies] +decorator = ">=3.4.2" +py = ">=1.4.26,<2.0.0" + [[package]] name = "rich" version = "13.7.1" @@ -6018,6 +6193,17 @@ dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pyde doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +[[package]] +name = "semver" +version = "3.0.2" +description = "Python helper for Semantic Versioning (https://semver.org)" +optional = true +python-versions = ">=3.7" +files = [ + {file = "semver-3.0.2-py3-none-any.whl", hash = "sha256:b1ea4686fe70b981f85359eda33199d60c53964284e0cfb4977d243e37cf4bf4"}, + {file = "semver-3.0.2.tar.gz", hash = "sha256:6253adb39c70f6e51afed2fa7152bcd414c411286088fb4b9effb133885ab4cc"}, +] + [[package]] name = "sentence-transformers" version = "2.7.0" @@ -7716,6 +7902,7 @@ gmail = ["google-api-core", "google-api-python-client", "google-auth", "google-a google = ["google-generativeai"] googledrive = ["google-api-python-client", "google-auth-httplib2", "google-auth-oauthlib"] huggingface-hub = ["huggingface_hub"] +lancedb = ["lancedb"] llama2 = ["replicate"] milvus = ["pymilvus"] mistralai = ["langchain-mistralai"] @@ -7738,4 +7925,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<=3.13" -content-hash = "cb7d2794bc2f54e05b2f870843eccc4342f3f2a6531eaa50a3d0d77b358ac4d5" +content-hash = "f9e6357bd1b5f407368d3d52c3f728e12f41fe7d4836e321ec2e10413f58e8a1" diff --git a/pyproject.toml b/pyproject.toml index 8a86a063..714463f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,6 +122,7 @@ slack-sdk = { version = "3.21.3", optional = true } clarifai = { version = "^10.0.1", optional = true } cohere = { version = "^5.3", optional = true } together = { version = "^0.2.8", optional = true } +lancedb = { version = "^0.6.2", optional = true } weaviate-client = { version = "^3.24.1", optional = true } docx2txt = { version = "^0.8", optional = true } qdrant-client = { version = "^1.6.3", optional = true } @@ -173,6 +174,7 @@ pytest-asyncio = "^0.21.1" [tool.poetry.extras] streamlit = ["streamlit"] opensource = ["sentence-transformers", "torch", "gpt4all"] +lancedb = ["lancedb"] elasticsearch = ["elasticsearch"] opensearch = ["opensearch-py"] poe = ["fastapi-poe"] diff --git a/tests/chunkers/test_chunkers.py b/tests/chunkers/test_chunkers.py index 067fa869..8258e776 100644 --- a/tests/chunkers/test_chunkers.py +++ b/tests/chunkers/test_chunkers.py @@ -1,3 +1,4 @@ +from embedchain.chunkers.audio import AudioChunker from embedchain.chunkers.common_chunker import CommonChunker from embedchain.chunkers.discourse import DiscourseChunker from embedchain.chunkers.docs_site import DocsSiteChunker @@ -19,7 +20,6 @@ from embedchain.chunkers.text import TextChunker from embedchain.chunkers.web_page import WebPageChunker from embedchain.chunkers.xml import XmlChunker from embedchain.chunkers.youtube_video import YoutubeVideoChunker -from embedchain.chunkers.audio import AudioChunker from embedchain.config.add_config import ChunkerConfig chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len) diff --git a/tests/loaders/test_audio.py b/tests/loaders/test_audio.py index c365893b..c62ec139 100644 --- a/tests/loaders/test_audio.py +++ b/tests/loaders/test_audio.py @@ -1,11 +1,13 @@ +import hashlib import os import sys -import hashlib -import pytest from unittest.mock import mock_open, patch +import pytest + if sys.version_info > (3, 10): # as `match` statement was introduced in python 3.10 from deepgram import PrerecordedOptions + from embedchain.loaders.audio import AudioLoader diff --git a/tests/vectordb/test_lancedb.py b/tests/vectordb/test_lancedb.py new file mode 100644 index 00000000..f50660bb --- /dev/null +++ b/tests/vectordb/test_lancedb.py @@ -0,0 +1,215 @@ +import os +import shutil + +import pytest + +from embedchain import App +from embedchain.config import AppConfig +from embedchain.config.vectordb.lancedb import LanceDBConfig +from embedchain.vectordb.lancedb import LanceDB + +os.environ["OPENAI_API_KEY"] = "test-api-key" + + +@pytest.fixture +def lancedb(): + return LanceDB(config=LanceDBConfig(dir="test-db", collection_name="test-coll")) + + +@pytest.fixture +def app_with_settings(): + lancedb_config = LanceDBConfig(allow_reset=True, dir="test-db-reset") + lancedb = LanceDB(config=lancedb_config) + app_config = AppConfig(collect_metrics=False) + return App(config=app_config, db=lancedb) + + +@pytest.fixture(scope="session", autouse=True) +def cleanup_db(): + yield + try: + shutil.rmtree("test-db.lance") + shutil.rmtree("test-db-reset.lance") + except OSError as e: + print("Error: %s - %s." % (e.filename, e.strerror)) + + +def test_lancedb_duplicates_throw_warning(caplog): + db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")) + app = App(config=AppConfig(collect_metrics=False), db=db) + app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"]) + app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"]) + assert "Insert of existing doc ID: 0" not in caplog.text + assert "Add of existing doc ID: 0" not in caplog.text + app.db.reset() + + +def test_lancedb_duplicates_collections_no_warning(caplog): + db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")) + app = App(config=AppConfig(collect_metrics=False), db=db) + app.set_collection_name("test_collection_1") + app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"]) + app.set_collection_name("test_collection_2") + app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"]) + assert "Insert of existing doc ID: 0" not in caplog.text + assert "Add of existing doc ID: 0" not in caplog.text + app.db.reset() + app.set_collection_name("test_collection_1") + app.db.reset() + + +def test_lancedb_collection_init_with_default_collection(): + db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")) + app = App(config=AppConfig(collect_metrics=False), db=db) + assert app.db.collection.name == "embedchain_store" + + +def test_lancedb_collection_init_with_custom_collection(): + db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")) + app = App(config=AppConfig(collect_metrics=False), db=db) + app.set_collection_name(name="test_collection") + assert app.db.collection.name == "test_collection" + + +def test_lancedb_collection_set_collection_name(): + db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")) + app = App(config=AppConfig(collect_metrics=False), db=db) + app.set_collection_name("test_collection") + assert app.db.collection.name == "test_collection" + + +def test_lancedb_collection_changes_encapsulated(): + db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")) + app = App(config=AppConfig(collect_metrics=False), db=db) + app.set_collection_name("test_collection_1") + assert app.db.count() == 0 + app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"]) + assert app.db.count() == 1 + + app.set_collection_name("test_collection_2") + assert app.db.count() == 0 + + app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"]) + app.set_collection_name("test_collection_1") + assert app.db.count() == 1 + app.db.reset() + app.set_collection_name("test_collection_2") + app.db.reset() + + +def test_lancedb_collection_collections_are_persistent(): + db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")) + app = App(config=AppConfig(collect_metrics=False), db=db) + app.set_collection_name("test_collection_1") + app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"]) + del app + + db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")) + app = App(config=AppConfig(collect_metrics=False), db=db) + app.set_collection_name("test_collection_1") + assert app.db.count() == 1 + + app.db.reset() + + +def test_lancedb_collection_parallel_collections(): + db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1")) + app1 = App( + config=AppConfig(collect_metrics=False), + db=db1, + ) + db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2")) + app2 = App( + config=AppConfig(collect_metrics=False), + db=db2, + ) + + # cleanup if any previous tests failed or were interrupted + app1.db.reset() + app2.db.reset() + + app1.db.add(ids=["0"], documents=["doc1"], metadatas=["test"]) + + assert app1.db.count() == 1 + assert app2.db.count() == 0 + + app1.db.add(ids=["1", "2"], documents=["doc1", "doc2"], metadatas=["test", "test"]) + app2.db.add(ids=["0"], documents=["doc1"], metadatas=["test"]) + + app1.set_collection_name("test_collection_2") + assert app1.db.count() == 1 + app2.set_collection_name("test_collection_1") + assert app2.db.count() == 3 + + # cleanup + app1.db.reset() + app2.db.reset() + + +def test_lancedb_collection_ids_share_collections(): + db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")) + app1 = App(config=AppConfig(collect_metrics=False), db=db1) + app1.set_collection_name("one_collection") + db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")) + app2 = App(config=AppConfig(collect_metrics=False), db=db2) + app2.set_collection_name("one_collection") + + # cleanup + app1.db.reset() + app2.db.reset() + + app1.db.add(ids=["0", "1"], documents=["doc1", "doc2"], metadatas=["test", "test"]) + app2.db.add(ids=["2"], documents=["doc3"], metadatas=["test"]) + + assert app1.db.count() == 2 + assert app2.db.count() == 3 + + # cleanup + app1.db.reset() + app2.db.reset() + + +def test_lancedb_collection_reset(): + db1 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")) + app1 = App(config=AppConfig(collect_metrics=False), db=db1) + app1.set_collection_name("one_collection") + db2 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")) + app2 = App(config=AppConfig(collect_metrics=False), db=db2) + app2.set_collection_name("two_collection") + db3 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")) + app3 = App(config=AppConfig(collect_metrics=False), db=db3) + app3.set_collection_name("three_collection") + db4 = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db")) + app4 = App(config=AppConfig(collect_metrics=False), db=db4) + app4.set_collection_name("four_collection") + + # cleanup if any previous tests failed or were interrupted + app1.db.reset() + app2.db.reset() + app3.db.reset() + app4.db.reset() + + app1.db.add(ids=["1"], documents=["doc1"], metadatas=["test"]) + app2.db.add(ids=["2"], documents=["doc2"], metadatas=["test"]) + app3.db.add(ids=["3"], documents=["doc3"], metadatas=["test"]) + app4.db.add(ids=["4"], documents=["doc4"], metadatas=["test"]) + + app1.db.reset() + + assert app1.db.count() == 0 + assert app2.db.count() == 1 + assert app3.db.count() == 1 + assert app4.db.count() == 1 + + # cleanup + app2.db.reset() + app3.db.reset() + app4.db.reset() + + +def generate_embeddings(dummy_embed, embed_size): + generated_embedding = [] + for i in range(embed_size): + generated_embedding.append(dummy_embed) + + return generated_embedding