[Feature] Add support for deploying local pipelines to Embedchain platform (#847)

This commit is contained in:
Deshraj Yadav
2023-10-25 13:36:24 -07:00
committed by GitHub
parent 76f1993e7a
commit 3979480532
9 changed files with 467 additions and 43 deletions

View File

@@ -3,6 +3,7 @@ import importlib.metadata
import json
import logging
import os
import sqlite3
import threading
import uuid
from pathlib import Path
@@ -32,6 +33,7 @@ ABS_PATH = os.getcwd()
# Per-user filesystem locations, all rooted at the current user's home directory.
HOME_DIR = str(Path.home())
# Directory holding the config file and the SQLite database defined below.
CONFIG_DIR = str(Path(HOME_DIR, ".embedchain"))
# JSON configuration file inside CONFIG_DIR.
CONFIG_FILE = str(Path(CONFIG_DIR, "config.json"))
# SQLite database file inside CONFIG_DIR.
SQLITE_PATH = str(Path(CONFIG_DIR, "embedchain.db"))
class EmbedChain(JSONSerializable):
@@ -89,6 +91,27 @@ class EmbedChain(JSONSerializable):
# Send anonymous telemetry
self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
self.u_id = self._load_or_generate_user_id()
# Establish a connection to the SQLite database
self.connection = sqlite3.connect(SQLITE_PATH)
self.cursor = self.connection.cursor()
# Create the 'data_sources' table if it doesn't exist
self.cursor.execute(
"""
CREATE TABLE IF NOT EXISTS data_sources (
pipeline_id TEXT,
hash TEXT,
type TEXT,
value TEXT,
metadata TEXT,
is_uploaded INTEGER DEFAULT 0,
PRIMARY KEY (pipeline_id, hash)
)
"""
)
self.connection.commit()
# NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
# if (self.config.collect_metrics):
# raise ConnectionRefusedError("Collection of metrics should not be allowed.")
@@ -163,7 +186,7 @@ class EmbedChain(JSONSerializable):
:raises ValueError: Invalid data type
:param dry_run: Optional. A dry run displays the chunks to ensure that the loader and chunker work as intended.
defaults to False
:return: source_id, a md5-hash of the source, in hexadecimal representation.
:return: source_hash, a md5-hash of the source, in hexadecimal representation.
:rtype: str
"""
if config is None:
@@ -192,18 +215,40 @@ class EmbedChain(JSONSerializable):
if not data_type:
data_type = detect_datatype(source)
# `source_id` is the hash of the source argument
# `source_hash` is the md5 hash of the source argument
hash_object = hashlib.md5(str(source).encode("utf-8"))
source_id = hash_object.hexdigest()
source_hash = hash_object.hexdigest()
# Check if the data hash already exists, if so, skip the addition
self.cursor.execute(
"SELECT 1 FROM data_sources WHERE hash = ? AND pipeline_id = ?", (source_hash, self.config.id)
)
existing_data = self.cursor.fetchone()
if existing_data:
print(f"Data with hash {source_hash} already exists. Skipping addition.")
return source_hash
data_formatter = DataFormatter(data_type, config)
self.user_asks.append([source, data_type.value, metadata])
documents, metadatas, _ids, new_chunks = self.load_and_embed(
data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
)
if data_type in {DataType.DOCS_SITE}:
self.is_docs_site_instance = True
# Insert the data into the 'data_sources' table
self.cursor.execute(
"""
INSERT INTO data_sources (hash, pipeline_id, type, value, metadata)
VALUES (?, ?, ?, ?, ?)
""",
(source_hash, self.config.id, data_type.value, str(source), json.dumps(metadata)),
)
# Commit the transaction
self.connection.commit()
if dry_run:
data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
logging.debug(f"Dry run info : {data_chunks_info}")
@@ -218,7 +263,7 @@ class EmbedChain(JSONSerializable):
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
thread_telemetry.start()
return source_id
return source_hash
def add_local(
self,
@@ -245,7 +290,7 @@ class EmbedChain(JSONSerializable):
:param config: The `AddConfig` instance to use as configuration options., defaults to None
:type config: Optional[AddConfig], optional
:raises ValueError: Invalid data type
:return: source_id, a md5-hash of the source, in hexadecimal representation.
:return: source_hash, a md5-hash of the source, in hexadecimal representation.
:rtype: str
"""
logging.warning(
@@ -313,7 +358,7 @@ class EmbedChain(JSONSerializable):
chunker: BaseChunker,
src: Any,
metadata: Optional[Dict[str, Any]] = None,
source_id: Optional[str] = None,
source_hash: Optional[str] = None,
dry_run=False,
):
"""
@@ -324,7 +369,7 @@ class EmbedChain(JSONSerializable):
:param src: The data to be handled by the loader. Can be a URL for
remote sources or local content for local loaders.
:param metadata: Optional. Metadata associated with the data source.
:param source_id: Hexadecimal hash of the source.
:param source_hash: Hexadecimal hash of the source.
:param dry_run: Optional. A dry run returns chunks and doesn't update DB.
:type dry_run: bool, defaults to False
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
@@ -382,7 +427,7 @@ class EmbedChain(JSONSerializable):
m["app_id"] = self.config.id
# Add hashed source
m["hash"] = source_id
m["hash"] = source_hash
# Note: Metadata is the function argument
if metadata:
@@ -558,15 +603,14 @@ class EmbedChain(JSONSerializable):
"""
Resets the database. Deletes all embeddings irreversibly.
`App` does not have to be reinitialized after using this method.
DEPRECATED IN FAVOR OF `db.reset()`
"""
# Send anonymous telemetry
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
thread_telemetry.start()
logging.warning("DEPRECATION WARNING: Please use `app.db.reset()` instead of `App.reset()`.")
self.db.reset()
self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
self.connection.commit()
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):