Lancedb Integration (#1411)

This commit is contained in:
Prashant Dixit
2024-06-21 21:29:22 +05:30
committed by GitHub
parent f6ddd5ffc5
commit 48b24f6f12
16 changed files with 963 additions and 12 deletions

View File

@@ -0,0 +1,48 @@
---
title: LanceDB
---
## Install Embedchain with LanceDB
Install Embedchain, LanceDB and related dependencies using the following command:
```bash
pip install "embedchain[lancedb]"
```
LanceDB is a developer-friendly, open source database for AI. It covers everything from hyper-scalable vector search and advanced retrieval for RAG to streaming training data and interactive exploration of large-scale AI datasets.
To use LanceDB as the vector database, you do not need to set any API key for local use.
<CodeGroup>
```python main.py
import os
from embedchain import App
# set OPENAI_API_KEY as env variable
os.environ["OPENAI_API_KEY"] = "sk-xxx"
# Create Embedchain App and set config
app = App.from_config(config={
"vectordb": {
"provider": "lancedb",
"config": {
"collection_name": "lancedb-index"
}
}
}
)
# Add data source and start querying
app.add("https://www.forbes.com/profile/elon-musk")
# query continuously
while(True):
question = input("Enter question: ")
if question in ['q', 'exit', 'quit']:
break
answer = app.query(question)
print(answer)
```
</CodeGroup>
<Snippet file="missing-vector-db-tip.mdx" />

View File

@@ -0,0 +1,33 @@
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class LanceDBConfig(BaseVectorDbConfig):
    """Configuration values for the LanceDB vector database backend."""

    def __init__(
        self,
        collection_name: Optional[str] = None,
        dir: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[str] = None,
        allow_reset: bool = True,
    ):
        """
        Initializes a configuration class instance for LanceDB.

        :param collection_name: Default name for the collection, defaults to None
        :type collection_name: Optional[str], optional
        :param dir: Path to the database directory, where the database is stored, defaults to None
        :type dir: Optional[str], optional
        :param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None
        :type host: Optional[str], optional
        :param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
        :type port: Optional[str], optional
        :param allow_reset: Whether resetting the database is permitted, defaults to True
        :type allow_reset: bool
        """
        # Stored on the instance directly; it is not forwarded to the base config.
        self.allow_reset = allow_reset
        super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)

View File

@@ -6,7 +6,9 @@ from typing import Any, Optional, Union
from dotenv import load_dotenv
from langchain.docstore.document import Document
from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
from embedchain.cache import (adapt, get_gptcache_session,
gptcache_data_convert,
gptcache_update_cache_callback)
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.base_app_config import BaseAppConfig
@@ -16,7 +18,8 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
from embedchain.models.data_type import (DataType, DirectDataType,
IndirectDataType, SpecialDataType)
from embedchain.utils.misc import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB

View File

@@ -91,6 +91,7 @@ class VectorDBFactory:
"chroma": "embedchain.vectordb.chroma.ChromaDB",
"elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
"opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
"lancedb": "embedchain.vectordb.lancedb.LanceDB",
"pinecone": "embedchain.vectordb.pinecone.PineconeDB",
"qdrant": "embedchain.vectordb.qdrant.QdrantDB",
"weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
@@ -100,6 +101,7 @@ class VectorDBFactory:
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
"lancedb": "embedchain.config.vectordb.lancedb.LanceDBConfig",
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
"qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
"weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",

View File

@@ -5,7 +5,9 @@ from typing import Any, Optional
from langchain.schema import BaseMessage as LCBaseMessage
from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE
from embedchain.config.llm.base import (DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE)
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage

View File

@@ -35,7 +35,8 @@ class JinaLlm(BaseLlm):
if config.top_p:
kwargs["model_kwargs"]["top_p"] = config.top_p
if config.stream:
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.streaming_stdout import \
StreamingStdOutCallbackHandler
chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
else:

View File

@@ -1,6 +1,8 @@
import os
import hashlib
import os
import validators
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader

View File

@@ -11,7 +11,8 @@ class UnstructuredLoader(BaseLoader):
"""Load data from an Unstructured file."""
try:
import unstructured # noqa: F401
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_community.document_loaders import \
UnstructuredFileLoader
except ImportError:
raise ImportError(
'Unstructured file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`' # noqa: E501

View File

@@ -446,7 +446,7 @@ def validate_config(config_data):
},
Optional("vectordb"): {
Optional("provider"): Or(
"chroma", "elasticsearch", "opensearch", "pinecone", "qdrant", "weaviate", "zilliz"
"chroma", "elasticsearch", "opensearch", "lancedb", "pinecone", "qdrant", "weaviate", "zilliz"
),
Optional("config"): object, # TODO: add particular config schema for each provider
},

View File

@@ -0,0 +1,307 @@
from typing import Any, Dict, List, Optional, Union
import pyarrow as pa
try:
import lancedb
except ImportError:
raise ImportError('LanceDB is required. Install with pip install "embedchain[lancedb]"') from None
from embedchain.config.vectordb.lancedb import LanceDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB
@register_deserializable
class LanceDB(BaseVectorDB):
    """
    LanceDB as vector database
    """

    BATCH_SIZE = 100

    def __init__(
        self,
        config: Optional[LanceDBConfig] = None,
    ):
        """LanceDB as vector database.

        :param config: LanceDB database config, defaults to None
        :type config: LanceDBConfig, optional
        """
        self.config = config if config else LanceDBConfig()

        # Fall back to a default location when no directory is configured.
        self.client = lancedb.connect(self.config.dir or "~/.lancedb")
        # Set to False by `_initialize` when the embedder probe call fails; the
        # table is then created (and rows are written) without a vector column.
        self.embedder_check = True

        super().__init__(config=self.config)

    def _initialize(self):
        """
        This method is needed because `embedder` attribute needs to be set externally before it can be initialized.

        :raises ValueError: If no embedder has been set.
        """
        if not self.embedder:
            raise ValueError(
                "Embedder not set. Please set an embedder with `_set_embedder()` function before initialization."
            )

        # Probe the embedding function once. Any failure is deliberately tolerated
        # and only recorded in `embedder_check`. The probe passes a list of texts,
        # the same way `add` and `query` call the embedding function.
        try:
            self.embedder.embedding_fn(["Hello LanceDB"])
        except Exception:
            self.embedder_check = False

        self._get_or_create_collection(self.config.collection_name)

    def _get_or_create_db(self):
        """
        Called during initialization

        :return: The LanceDB connection created in `__init__`.
        """
        return self.client

    def _generate_where_clause(self, where: Dict[str, Any]) -> str:
        """
        Build a filter string from a dictionary of attributes and their values.

        Every entry becomes ``key = value`` and the entries are joined with
        ``AND``. Values are interpolated verbatim (no quoting is added), so an
        empty dictionary produces an empty string.

        :param where: Mapping of attribute name to the value it must equal
        :type where: Dict[str, Any]
        :return: The combined filter expression
        :rtype: str
        """
        # Joining avoids both a trailing " AND " and a repeated final condition.
        return " AND ".join(f"{key} = {value}" for key, value in where.items())

    def _get_or_create_collection(self, table_name: str, reset=False):
        """
        Get or create a named collection.

        :param table_name: Name of the collection (LanceDB table)
        :type table_name: str
        :param reset: When True, drop an existing table and create it again, defaults to False
        :type reset: bool
        :return: Created collection
        :rtype: Collection
        """
        fields = [
            pa.field("doc", pa.string()),
            pa.field("metadata", pa.string()),
            pa.field("id", pa.string()),
        ]
        if self.embedder_check:
            # The vector column is only part of the schema when the embedder probe succeeded.
            vector_type = pa.list_(pa.float32(), list_size=self.embedder.vector_dimension)
            fields.insert(0, pa.field("vector", vector_type))
        schema = pa.schema(fields)

        table_exists = table_name in self.client.table_names()
        if reset and table_exists:
            # Only drop what is actually there, so resetting a missing table simply creates it.
            self.client.drop_table(table_name)
            table_exists = False
        if not table_exists:
            self.collection = self.client.create_table(table_name, schema=schema)

        self.collection = self.client[table_name]

        return self.collection

    def get(
        self,
        ids: Optional[List[str]] = None,
        where: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = None,
    ):
        """
        Get existing doc ids present in vector database

        :param ids: list of doc ids to check for existence
        :type ids: List[str]
        :param where: Optional. to filter data
        :type where: Dict[str, Any]
        :param limit: Optional. maximum number of documents returned per looked-up id, defaults to 3
        :type limit: Optional[int]
        :return: Dictionary with the keys ``ids`` and ``metadatas``; both lists are empty when no ids are given.
        :rtype: Dict[str, List[str]]
        """
        max_limit = limit if limit is not None else 3
        results = {"ids": [], "metadatas": []}

        if not ids:
            # Nothing to look up; also avoids building an empty `IN ()` filter.
            return results

        where_clause = self._generate_where_clause(where) if where else None

        # Build the IN list explicitly: formatting a Python tuple yields `('x',)`
        # for a single id. Embedded single quotes are doubled to stay inside the literal.
        quoted_ids = ", ".join("'" + str(doc_id).replace("'", "''") + "'" for doc_id in ids)
        records = (
            self.collection.to_lance().scanner(filter=f"id IN ({quoted_ids})", columns=["id"]).to_table().to_pydict()
        )

        for doc_id in records["id"]:
            search = self.collection.search(query=doc_id, vector_column_name="id")
            if where_clause:
                search = search.where(where_clause)
            rows = search.limit(max_limit).to_list()
            # Accumulate across ids rather than keeping only the last id's rows.
            results["ids"].extend(row["id"] for row in rows)
            results["metadatas"].extend(row["metadata"] for row in rows)

        return results

    def add(
        self,
        documents: List[str],
        metadatas: List[object],
        ids: List[str],
    ) -> Any:
        """
        Add vectors to lancedb database

        :param documents: Documents
        :type documents: List[str]
        :param metadatas: Metadatas
        :type metadatas: List[object]
        :param ids: ids
        :type ids: List[str]
        """
        data = []
        for doc, meta, doc_id in zip(documents, metadatas, ids):
            row = {"doc": doc}
            if self.embedder_check:
                # Only embed when the table was created with a vector column.
                row["vector"] = self.embedder.embedding_fn([doc])[0]
            # Metadata is stored in its string form, matching the string-typed column.
            row["metadata"] = str(meta)
            row["id"] = doc_id
            data.append(row)

        self.collection.add(data=data)

    def _format_result(self, results) -> list:
        """
        Format LanceDB results

        :param results: LanceDB query results to format.
        :type results: QueryResult
        :return: Formatted results
        :rtype: list
        """
        return results.tolist()

    def query(
        self,
        input_query: str,
        n_results: int = 3,
        where: Optional[Dict[str, Any]] = None,
        raw_filter: Optional[Dict[str, Any]] = None,
        citations: bool = False,
        **kwargs: Optional[Dict[str, Any]],
    ) -> Union[List[tuple], List[str]]:
        """
        Query contents from vector database based on vector similarity

        `where` and `raw_filter` are only checked for mutual exclusion here;
        neither is applied to the search.

        :param input_query: query string
        :type input_query: str
        :param n_results: no of similar documents to fetch from database
        :type n_results: int
        :param where: to filter data
        :type where: Dict[str, Any]
        :param raw_filter: Raw filter to apply
        :type raw_filter: Dict[str, Any]
        :param citations: we use citations boolean param to return context along with the answer.
        :type citations: bool, default is False.
        :raises ValueError: If both `where` and `raw_filter` are given.
        :return: The content of the documents that matched the query; with
            `citations=True` each entry is a ``(doc, metadata)`` tuple instead.
        :rtype: list[str], if citations=False, otherwise list[tuple[str, str]]
        """
        if where and raw_filter:
            raise ValueError("Both `where` and `raw_filter` cannot be used together.")

        # Embedding or search failures propagate unchanged so the real error is
        # visible to the caller. The query is passed as a one-element list, the
        # same way `add` calls the embedding function.
        query_embedding = self.embedder.embedding_fn([input_query])[0]
        rows = self.collection.search(query_embedding).limit(n_results).to_list()

        if citations:
            return [(row["doc"], row["metadata"]) for row in rows]
        return [row["doc"] for row in rows]

    def set_collection_name(self, name: str):
        """
        Set the name of the collection. A collection is an isolated space for vectors.

        :param name: Name of the collection.
        :type name: str
        :raises TypeError: If `name` is not a string.
        """
        if not isinstance(name, str):
            raise TypeError("Collection name must be a string")
        self.config.collection_name = name
        self._get_or_create_collection(self.config.collection_name)

    def count(self) -> int:
        """
        Count number of documents/chunks embedded in the database.

        :return: number of documents
        :rtype: int
        """
        return self.collection.count_rows()

    def delete(self, where):
        """
        Delete rows from the collection.

        :param where: Filter forwarded unchanged to the collection's `delete`
        :return: Whatever the collection's `delete` returns
        """
        return self.collection.delete(where=where)

    def reset(self):
        """
        Resets the database. Deletes all embeddings irreversibly.

        When `allow_reset` is disabled in the config, nothing is deleted and a
        notice is printed instead.
        """
        if self.config.allow_reset:
            # Drop and recreate the collection. Errors propagate with their own
            # message instead of being reported as "resetting is disabled",
            # which cannot be the cause on this branch.
            self._get_or_create_collection(self.config.collection_name, reset=True)
        else:
            print(
                "For safety reasons, resetting is disabled. "
                "Please enable it by setting `allow_reset=True` in your LanceDbConfig"
            )

146
notebooks/lancedb.ipynb Normal file
View File

@@ -0,0 +1,146 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "b02n_zJ_hl3d"
},
"source": [
"## Cookbook for using LanceDB with Embedchain"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gyJ6ui2vhtMY"
},
"source": [
"### Step-1: Install embedchain package"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-NbXjAdlh0vJ"
},
"outputs": [],
"source": [
"! pip install embedchain lancedb"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nGnpSYAAh2bQ"
},
"source": [
"### Step-2: Set environment variables needed for LanceDB\n",
"\n",
"You can find this env variable on your [OpenAI](https://platform.openai.com/account/api-keys)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0fBdQ9GAiRvK"
},
"outputs": [],
"source": [
"import os\n",
"from embedchain import App\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PGt6uPLIi1CS"
},
"source": [
"### Step-3 Create embedchain app and define your config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Amzxk3m-i3tD"
},
"outputs": [],
"source": [
"app = App.from_config(config={\n",
" \"vectordb\": {\n",
" \"provider\": \"lancedb\",\n",
" \"config\": {\n",
" \"collection_name\": \"lancedb-index\"\n",
" }\n",
" }\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XNXv4yZwi7ef"
},
"source": [
"### Step-4: Add data sources to your app"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Sn_0rx9QjIY9"
},
"outputs": [],
"source": [
"app.add(\"https://www.forbes.com/profile/elon-musk\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_7W6fDeAjMAP"
},
"source": [
"### Step-5: All set. Now start asking questions related to your data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cvIK7dWRjN_f"
},
"outputs": [],
"source": [
"while(True):\n",
" question = input(\"Enter question: \")\n",
" if question in ['q', 'exit', 'quit']:\n",
" break\n",
" answer = app.query(question)\n",
" print(answer)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

191
poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand.
[[package]]
name = "aiohttp"
@@ -1142,6 +1142,17 @@ files = [
marshmallow = ">=3.18.0,<4.0.0"
typing-inspect = ">=0.4.0,<1"
[[package]]
name = "decorator"
version = "5.1.1"
description = "Decorators for Humans"
optional = true
python-versions = ">=3.5"
files = [
{file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"},
{file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
]
[[package]]
name = "deprecated"
version = "1.2.14"
@@ -1159,6 +1170,20 @@ wrapt = ">=1.10,<2"
[package.extras]
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
[[package]]
name = "deprecation"
version = "2.1.0"
description = "A library to handle automated deprecations"
optional = true
python-versions = "*"
files = [
{file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"},
{file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"},
]
[package.dependencies]
packaging = "*"
[[package]]
name = "discord"
version = "2.3.2"
@@ -2785,6 +2810,42 @@ websocket-client = ">=0.32.0,<0.40.0 || >0.40.0,<0.41.dev0 || >=0.43.dev0"
[package.extras]
adal = ["adal (>=1.0.2)"]
[[package]]
name = "lancedb"
version = "0.6.13"
description = "lancedb"
optional = true
python-versions = ">=3.8"
files = [
{file = "lancedb-0.6.13-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:4667353ca7fa187e94cb0ca4c5f9577d65eb5160f6f3fe9e57902d86312c3869"},
{file = "lancedb-0.6.13-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:2e22533fe6f6b2d7037dcdbbb4019a62402bbad4ce18395be68f4aa007bf8bc0"},
{file = "lancedb-0.6.13-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:837eaceafb87e3ae4c261eef45c4f73715f892a36165572c3da621dbdb45afcf"},
{file = "lancedb-0.6.13-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:61af2d72b2a2f0ea419874c3f32760fe5e51530da3be2d65251a0e6ded74419b"},
{file = "lancedb-0.6.13-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:31b24e57ee313f4ce6255e45d42e8bee19b90ddcd13a9e07030ac04f76e7dfde"},
{file = "lancedb-0.6.13-cp38-abi3-win_amd64.whl", hash = "sha256:b851182d8492b1e5b57a441af64c95da65ca30b045d6618dc7d203c6d60d70fa"},
]
[package.dependencies]
attrs = ">=21.3.0"
cachetools = "*"
deprecation = "*"
overrides = ">=0.7"
pydantic = ">=1.10"
pylance = "0.10.12"
ratelimiter = ">=1.0,<2.0"
requests = ">=2.31.0"
retry = ">=0.9.2"
semver = "*"
tqdm = ">=4.27.0"
[package.extras]
azure = ["adlfs (>=2024.2.0)"]
clip = ["open-clip", "pillow", "torch"]
dev = ["pre-commit", "ruff"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "instructorembedding", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"]
tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"]
[[package]]
name = "langchain"
version = "0.1.20"
@@ -4781,6 +4842,65 @@ all = ["apache-bookkeeper-client (>=4.16.1)", "fastavro (>=1.9.2)", "grpcio (>=1
avro = ["fastavro (>=1.9.2)"]
functions = ["apache-bookkeeper-client (>=4.16.1)", "grpcio (>=1.60.0)", "prometheus-client", "protobuf (>=3.6.1,<=3.20.3)", "ratelimit"]
[[package]]
name = "py"
version = "1.11.0"
description = "library with cross-python path, ini-parsing, io, code, log facilities"
optional = true
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
files = [
{file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"},
{file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
]
[[package]]
name = "pyarrow"
version = "15.0.0"
description = "Python library for Apache Arrow"
optional = true
python-versions = ">=3.8"
files = [
{file = "pyarrow-15.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:0a524532fd6dd482edaa563b686d754c70417c2f72742a8c990b322d4c03a15d"},
{file = "pyarrow-15.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:60a6bdb314affa9c2e0d5dddf3d9cbb9ef4a8dddaa68669975287d47ece67642"},
{file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:66958fd1771a4d4b754cd385835e66a3ef6b12611e001d4e5edfcef5f30391e2"},
{file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f500956a49aadd907eaa21d4fff75f73954605eaa41f61cb94fb008cf2e00c6"},
{file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:6f87d9c4f09e049c2cade559643424da84c43a35068f2a1c4653dc5b1408a929"},
{file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:85239b9f93278e130d86c0e6bb455dcb66fc3fd891398b9d45ace8799a871a1e"},
{file = "pyarrow-15.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5b8d43e31ca16aa6e12402fcb1e14352d0d809de70edd185c7650fe80e0769e3"},
{file = "pyarrow-15.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:fa7cd198280dbd0c988df525e50e35b5d16873e2cdae2aaaa6363cdb64e3eec5"},
{file = "pyarrow-15.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8780b1a29d3c8b21ba6b191305a2a607de2e30dab399776ff0aa09131e266340"},
{file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe0ec198ccc680f6c92723fadcb97b74f07c45ff3fdec9dd765deb04955ccf19"},
{file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:036a7209c235588c2f07477fe75c07e6caced9b7b61bb897c8d4e52c4b5f9555"},
{file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2bd8a0e5296797faf9a3294e9fa2dc67aa7f10ae2207920dbebb785c77e9dbe5"},
{file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e8ebed6053dbe76883a822d4e8da36860f479d55a762bd9e70d8494aed87113e"},
{file = "pyarrow-15.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:17d53a9d1b2b5bd7d5e4cd84d018e2a45bc9baaa68f7e6e3ebed45649900ba99"},
{file = "pyarrow-15.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9950a9c9df24090d3d558b43b97753b8f5867fb8e521f29876aa021c52fda351"},
{file = "pyarrow-15.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:003d680b5e422d0204e7287bb3fa775b332b3fce2996aa69e9adea23f5c8f970"},
{file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f75fce89dad10c95f4bf590b765e3ae98bcc5ba9f6ce75adb828a334e26a3d40"},
{file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ca9cb0039923bec49b4fe23803807e4ef39576a2bec59c32b11296464623dc2"},
{file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ed5a78ed29d171d0acc26a305a4b7f83c122d54ff5270810ac23c75813585e4"},
{file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:6eda9e117f0402dfcd3cd6ec9bfee89ac5071c48fc83a84f3075b60efa96747f"},
{file = "pyarrow-15.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a3a6180c0e8f2727e6f1b1c87c72d3254cac909e609f35f22532e4115461177"},
{file = "pyarrow-15.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:19a8918045993349b207de72d4576af0191beef03ea655d8bdb13762f0cd6eac"},
{file = "pyarrow-15.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d0ec076b32bacb6666e8813a22e6e5a7ef1314c8069d4ff345efa6246bc38593"},
{file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5db1769e5d0a77eb92344c7382d6543bea1164cca3704f84aa44e26c67e320fb"},
{file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2617e3bf9df2a00020dd1c1c6dce5cc343d979efe10bc401c0632b0eef6ef5b"},
{file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:d31c1d45060180131caf10f0f698e3a782db333a422038bf7fe01dace18b3a31"},
{file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:c8c287d1d479de8269398b34282e206844abb3208224dbdd7166d580804674b7"},
{file = "pyarrow-15.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:07eb7f07dc9ecbb8dace0f58f009d3a29ee58682fcdc91337dfeb51ea618a75b"},
{file = "pyarrow-15.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:47af7036f64fce990bb8a5948c04722e4e3ea3e13b1007ef52dfe0aa8f23cf7f"},
{file = "pyarrow-15.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93768ccfff85cf044c418bfeeafce9a8bb0cee091bd8fd19011aff91e58de540"},
{file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6ee87fd6892700960d90abb7b17a72a5abb3b64ee0fe8db6c782bcc2d0dc0b4"},
{file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:001fca027738c5f6be0b7a3159cc7ba16a5c52486db18160909a0831b063c4e4"},
{file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:d1c48648f64aec09accf44140dccb92f4f94394b8d79976c426a5b79b11d4fa7"},
{file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:972a0141be402bb18e3201448c8ae62958c9c7923dfaa3b3d4530c835ac81aed"},
{file = "pyarrow-15.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:f01fc5cf49081426429127aa2d427d9d98e1cb94a32cb961d583a70b7c4504e6"},
{file = "pyarrow-15.0.0.tar.gz", hash = "sha256:876858f549d540898f927eba4ef77cd549ad8d24baa3207cf1b72e5788b50e83"},
]
[package.dependencies]
numpy = ">=1.16.6,<2"
[[package]]
name = "pyasn1"
version = "0.6.0"
@@ -5019,6 +5139,32 @@ dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pyte
docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
[[package]]
name = "pylance"
version = "0.10.12"
description = "python wrapper for Lance columnar format"
optional = true
python-versions = ">=3.8"
files = [
{file = "pylance-0.10.12-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:30cbcca078edeb37e11ae86cf9287d81ce6c0c07ba77239284b369a4b361497b"},
{file = "pylance-0.10.12-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:e558163ff6035d518706cc66848497219ccc755e2972b8f3b1706a3e1fd800fd"},
{file = "pylance-0.10.12-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75afb39f71d7f12429f9b4d380eb6cf6aed179ae5a1c5d16cc768373a1521f87"},
{file = "pylance-0.10.12-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:3de391dfc3a99bdb245fd1e27ef242be769a94853f802ef57f246e9a21358d32"},
{file = "pylance-0.10.12-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:34a5278b90f4cbcf21261353976127aa2ffbbd7d068810f0a2b0c1aa0334022a"},
{file = "pylance-0.10.12-cp38-abi3-win_amd64.whl", hash = "sha256:6cef5975d513097fd2c22692296c9a5a138928f38d02cd34ab63a7369abc1463"},
]
[package.dependencies]
numpy = ">=1.22"
pyarrow = ">=12,<15.0.1"
[package.extras]
benchmarks = ["pytest-benchmark"]
dev = ["ruff (==0.2.2)"]
ray = ["ray[data]"]
tests = ["boto3", "datasets", "duckdb", "h5py (<3.11)", "ml-dtypes", "pandas", "pillow", "polars[pandas,pyarrow]", "pytest", "tensorflow", "tqdm"]
torch = ["torch"]
[[package]]
name = "pymilvus"
version = "2.4.3"
@@ -5540,6 +5686,20 @@ urllib3 = ">=1.26.14,<3"
[package.extras]
fastembed = ["fastembed (==0.2.6)"]
[[package]]
name = "ratelimiter"
version = "1.2.0.post0"
description = "Simple python rate limiting object"
optional = true
python-versions = "*"
files = [
{file = "ratelimiter-1.2.0.post0-py3-none-any.whl", hash = "sha256:a52be07bc0bb0b3674b4b304550f10c769bbb00fead3072e035904474259809f"},
{file = "ratelimiter-1.2.0.post0.tar.gz", hash = "sha256:5c395dcabdbbde2e5178ef3f89b568a3066454a6ddc223b76473dac22f89b4f7"},
]
[package.extras]
test = ["pytest (>=3.0)", "pytest-asyncio"]
[[package]]
name = "regex"
version = "2024.5.15"
@@ -5720,6 +5880,21 @@ urllib3 = ">=1.25.10,<3.0"
[package.extras]
tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "tomli", "tomli-w", "types-requests"]
[[package]]
name = "retry"
version = "0.9.2"
description = "Easy to use retry decorator."
optional = true
python-versions = "*"
files = [
{file = "retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606"},
{file = "retry-0.9.2.tar.gz", hash = "sha256:f8bfa8b99b69c4506d6f5bd3b0aabf77f98cdb17f3c9fc3f5ca820033336fba4"},
]
[package.dependencies]
decorator = ">=3.4.2"
py = ">=1.4.26,<2.0.0"
[[package]]
name = "rich"
version = "13.7.1"
@@ -6018,6 +6193,17 @@ dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pyde
doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"]
test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
[[package]]
name = "semver"
version = "3.0.2"
description = "Python helper for Semantic Versioning (https://semver.org)"
optional = true
python-versions = ">=3.7"
files = [
{file = "semver-3.0.2-py3-none-any.whl", hash = "sha256:b1ea4686fe70b981f85359eda33199d60c53964284e0cfb4977d243e37cf4bf4"},
{file = "semver-3.0.2.tar.gz", hash = "sha256:6253adb39c70f6e51afed2fa7152bcd414c411286088fb4b9effb133885ab4cc"},
]
[[package]]
name = "sentence-transformers"
version = "2.7.0"
@@ -7716,6 +7902,7 @@ gmail = ["google-api-core", "google-api-python-client", "google-auth", "google-a
google = ["google-generativeai"]
googledrive = ["google-api-python-client", "google-auth-httplib2", "google-auth-oauthlib"]
huggingface-hub = ["huggingface_hub"]
lancedb = ["lancedb"]
llama2 = ["replicate"]
milvus = ["pymilvus"]
mistralai = ["langchain-mistralai"]
@@ -7738,4 +7925,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<=3.13"
content-hash = "cb7d2794bc2f54e05b2f870843eccc4342f3f2a6531eaa50a3d0d77b358ac4d5"
content-hash = "f9e6357bd1b5f407368d3d52c3f728e12f41fe7d4836e321ec2e10413f58e8a1"

View File

@@ -122,6 +122,7 @@ slack-sdk = { version = "3.21.3", optional = true }
clarifai = { version = "^10.0.1", optional = true }
cohere = { version = "^5.3", optional = true }
together = { version = "^0.2.8", optional = true }
lancedb = { version = "^0.6.2", optional = true }
weaviate-client = { version = "^3.24.1", optional = true }
docx2txt = { version = "^0.8", optional = true }
qdrant-client = { version = "^1.6.3", optional = true }
@@ -173,6 +174,7 @@ pytest-asyncio = "^0.21.1"
[tool.poetry.extras]
streamlit = ["streamlit"]
opensource = ["sentence-transformers", "torch", "gpt4all"]
lancedb = ["lancedb"]
elasticsearch = ["elasticsearch"]
opensearch = ["opensearch-py"]
poe = ["fastapi-poe"]

View File

@@ -1,3 +1,4 @@
from embedchain.chunkers.audio import AudioChunker
from embedchain.chunkers.common_chunker import CommonChunker
from embedchain.chunkers.discourse import DiscourseChunker
from embedchain.chunkers.docs_site import DocsSiteChunker
@@ -19,7 +20,6 @@ from embedchain.chunkers.text import TextChunker
from embedchain.chunkers.web_page import WebPageChunker
from embedchain.chunkers.xml import XmlChunker
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
from embedchain.chunkers.audio import AudioChunker
from embedchain.config.add_config import ChunkerConfig
# Module-level chunker configuration: chunk_size=500, no overlap, lengths measured with the builtin ``len``.
chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)

View File

@@ -1,11 +1,13 @@
import hashlib
import os
import sys
import hashlib
import pytest
from unittest.mock import mock_open, patch
import pytest
if sys.version_info > (3, 10): # as `match` statement was introduced in python 3.10
from deepgram import PrerecordedOptions
from embedchain.loaders.audio import AudioLoader

View File

@@ -0,0 +1,215 @@
import os
import shutil
import pytest
from embedchain import App
from embedchain.config import AppConfig
from embedchain.config.vectordb.lancedb import LanceDBConfig
from embedchain.vectordb.lancedb import LanceDB
# Set a placeholder OpenAI key at import time; presumably so that building an App
# in these tests does not need a real key — TODO confirm.
os.environ["OPENAI_API_KEY"] = "test-api-key"
@pytest.fixture
def lancedb():
    """Return a LanceDB built with ``dir="test-db"`` and ``collection_name="test-coll"``."""
    db_config = LanceDBConfig(dir="test-db", collection_name="test-coll")
    return LanceDB(config=db_config)
@pytest.fixture
def app_with_settings():
    """Return an App backed by a resettable LanceDB stored under ``test-db-reset``.

    Metrics collection is switched off in the App configuration.
    """
    # The database is built before the App config, matching the original order.
    vector_db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db-reset"))
    return App(config=AppConfig(collect_metrics=False), db=vector_db)
@pytest.fixture(scope="session", autouse=True)
def cleanup_db():
    """Remove the on-disk test databases after the whole test session has run.

    Each directory is removed independently, so a failure on one of them
    (for example because it was never created) does not prevent the other
    from being cleaned up.

    NOTE(review): the tests in this module configure ``dir="test-db"`` and
    ``dir="test-db-reset"``; confirm that LanceDB really persists under the
    ``.lance``-suffixed paths removed here.
    """
    yield
    for db_path in ("test-db.lance", "test-db-reset.lance"):
        try:
            shutil.rmtree(db_path)
        except OSError as e:
            # Best-effort cleanup: report the problem and carry on with the next path.
            print("Error: %s - %s." % (e.filename, e.strerror))
def test_lancedb_duplicates_throw_warning(caplog):
    """Add the same document ID twice and check the duplicate-ID messages are not logged.

    NOTE(review): despite the ``throw_warning`` name, both assertions require
    the messages to be absent from the captured log — confirm the intent.
    """
    vector_db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=vector_db)

    # Insert the identical record twice in a row.
    for _ in range(2):
        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

    for message in ("Insert of existing doc ID: 0", "Add of existing doc ID: 0"):
        assert message not in caplog.text

    app.db.reset()
def test_lancedb_duplicates_collections_no_warning(caplog):
    """Add the same ID to two different collections and check no duplicate message is logged."""
    vector_db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=vector_db)

    # Same document ID, but stored in two separately named collections.
    for collection in ("test_collection_1", "test_collection_2"):
        app.set_collection_name(collection)
        app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

    for message in ("Insert of existing doc ID: 0", "Add of existing doc ID: 0"):
        assert message not in caplog.text

    # Reset the currently selected collection, then the first one.
    app.db.reset()
    app.set_collection_name("test_collection_1")
    app.db.reset()
def test_lancedb_collection_init_with_default_collection():
    """Check that an app built without a collection name uses ``embedchain_store``."""
    db_config = LanceDBConfig(allow_reset=True, dir="test-db")
    vector_db = LanceDB(config=db_config)
    app_config = AppConfig(collect_metrics=False)
    app = App(config=app_config, db=vector_db)

    assert app.db.collection.name == "embedchain_store"
def test_lancedb_collection_init_with_custom_collection():
    """Check that the collection name passed via ``name=`` is the one reported by the db."""
    db_config = LanceDBConfig(allow_reset=True, dir="test-db")
    vector_db = LanceDB(config=db_config)
    app_config = AppConfig(collect_metrics=False)
    app = App(config=app_config, db=vector_db)

    app.set_collection_name(name="test_collection")

    assert app.db.collection.name == "test_collection"
def test_lancedb_collection_set_collection_name():
    """Check that a positionally passed collection name is the one reported by the db."""
    db_config = LanceDBConfig(allow_reset=True, dir="test-db")
    vector_db = LanceDB(config=db_config)
    app_config = AppConfig(collect_metrics=False)
    app = App(config=app_config, db=vector_db)

    app.set_collection_name("test_collection")

    assert app.db.collection.name == "test_collection"
def test_lancedb_collection_changes_encapsulated():
    """Check that documents added to one collection do not show up in another."""
    first, second = "test_collection_1", "test_collection_2"

    vector_db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=vector_db)

    # The first collection starts empty and holds one document after the add.
    app.set_collection_name(first)
    assert app.db.count() == 0
    app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
    assert app.db.count() == 1

    # The second collection is also expected to start empty.
    app.set_collection_name(second)
    assert app.db.count() == 0
    app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

    # Going back to the first collection still shows exactly one document.
    app.set_collection_name(first)
    assert app.db.count() == 1

    # Reset the first collection, then the second.
    app.db.reset()
    app.set_collection_name(second)
    app.db.reset()
def test_lancedb_collection_collections_are_persistent():
    """Check that a document added by one app is counted by a newly built app."""
    collection = "test_collection_1"

    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=store)
    app.set_collection_name(collection)
    app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

    # Drop the first app before building a second one on the same directory.
    del app

    # ``store`` is deliberately rebound rather than given a new name.
    store = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    app = App(config=AppConfig(collect_metrics=False), db=store)
    app.set_collection_name(collection)
    assert app.db.count() == 1

    app.db.reset()
def test_lancedb_collection_parallel_collections():
    """Check document counts for two apps that use different collections in one directory."""
    first_db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
    first_app = App(config=AppConfig(collect_metrics=False), db=first_db)
    second_db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
    second_app = App(config=AppConfig(collect_metrics=False), db=second_db)

    # cleanup if any previous tests failed or were interrupted
    for app in (first_app, second_app):
        app.db.reset()

    first_app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])
    assert first_app.db.count() == 1
    assert second_app.db.count() == 0

    first_app.db.add(ids=["1", "2"], documents=["doc1", "doc2"], metadatas=["test", "test"])
    second_app.db.add(ids=["0"], documents=["doc1"], metadatas=["test"])

    # Point each app at the other's collection and check the counts found there.
    first_app.set_collection_name("test_collection_2")
    assert first_app.db.count() == 1
    second_app.set_collection_name("test_collection_1")
    assert second_app.db.count() == 3

    # cleanup
    for app in (first_app, second_app):
        app.db.reset()
def test_lancedb_collection_ids_share_collections():
    """Check the counts reported by two apps that are pointed at the same collection name.

    After the first app adds two documents and the second adds one, the first
    app is expected to report 2 and the second app 3.
    """
    shared_name = "one_collection"

    first_db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    first_app = App(config=AppConfig(collect_metrics=False), db=first_db)
    first_app.set_collection_name(shared_name)

    second_db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
    second_app = App(config=AppConfig(collect_metrics=False), db=second_db)
    second_app.set_collection_name(shared_name)

    # cleanup
    for app in (first_app, second_app):
        app.db.reset()

    first_app.db.add(ids=["0", "1"], documents=["doc1", "doc2"], metadatas=["test", "test"])
    second_app.db.add(ids=["2"], documents=["doc3"], metadatas=["test"])

    assert first_app.db.count() == 2
    assert second_app.db.count() == 3

    # cleanup
    for app in (first_app, second_app):
        app.db.reset()
def test_lancedb_collection_reset():
    """Check that resetting one collection leaves the other three collections' counts at 1."""
    collection_names = ("one_collection", "two_collection", "three_collection", "four_collection")

    # Build one app per collection, all on the same database directory.
    apps = []
    for collection in collection_names:
        vector_db = LanceDB(config=LanceDBConfig(allow_reset=True, dir="test-db"))
        app = App(config=AppConfig(collect_metrics=False), db=vector_db)
        app.set_collection_name(collection)
        apps.append(app)

    # cleanup if any previous tests failed or were interrupted
    for app in apps:
        app.db.reset()

    # Store one document per app: ids "1".."4" with documents "doc1".."doc4".
    for position, app in enumerate(apps, start=1):
        app.db.add(ids=[str(position)], documents=[f"doc{position}"], metadatas=["test"])

    reset_app, *untouched_apps = apps
    reset_app.db.reset()

    assert reset_app.db.count() == 0
    for app in untouched_apps:
        assert app.db.count() == 1

    # cleanup
    for app in untouched_apps:
        app.db.reset()
def generate_embeddings(dummy_embed, embed_size):
    """
    Build a list that contains ``dummy_embed`` repeated ``embed_size`` times.

    :param dummy_embed: value placed in every position of the result
    :param embed_size: number of entries to generate; zero or a negative value
        yields an empty list
    :type embed_size: int
    :return: a new list of length ``embed_size`` whose items are all ``dummy_embed``
    :rtype: list
    """
    # Every entry refers to the same ``dummy_embed`` object; nothing is copied.
    return [dummy_embed for _ in range(embed_size)]