[Feature] Add support for deploying local pipelines to Embedchain platform (#847)
@@ -3,6 +3,7 @@ import importlib.metadata
 import json
 import logging
 import os
+import sqlite3
 import threading
 import uuid
 from pathlib import Path
@@ -32,6 +33,7 @@ ABS_PATH = os.getcwd()
 HOME_DIR = str(Path.home())
 CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
 CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
+SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")


 class EmbedChain(JSONSerializable):
@@ -89,6 +91,27 @@ class EmbedChain(JSONSerializable):
         # Send anonymous telemetry
         self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
         self.u_id = self._load_or_generate_user_id()
+
+        # Establish a connection to the SQLite database
+        self.connection = sqlite3.connect(SQLITE_PATH)
+        self.cursor = self.connection.cursor()
+
+        # Create the 'data_sources' table if it doesn't exist
+        self.cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS data_sources (
+                pipeline_id TEXT,
+                hash TEXT,
+                type TEXT,
+                value TEXT,
+                metadata TEXT,
+                is_uploaded INTEGER DEFAULT 0,
+                PRIMARY KEY (pipeline_id, hash)
+            )
+            """
+        )
+        self.connection.commit()
+
         # NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
         # if (self.config.collect_metrics):
         #     raise ConnectionRefusedError("Collection of metrics should not be allowed.")
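For context when reviewing: the table above lives in the per-user SQLite file at `~/.embedchain/embedchain.db` (per `SQLITE_PATH` in this diff). A minimal sketch of inspecting it outside of embedchain, with a made-up pipeline id for illustration:

```python
import os
import sqlite3
from pathlib import Path

# Same default location as SQLITE_PATH in this commit: ~/.embedchain/embedchain.db
db_path = os.path.join(str(Path.home()), ".embedchain", "embedchain.db")

connection = sqlite3.connect(db_path)
cursor = connection.cursor()

# List every source tracked for one pipeline and whether it has been uploaded yet.
cursor.execute(
    "SELECT hash, type, value, is_uploaded FROM data_sources WHERE pipeline_id = ?",
    ("my-pipeline-id",),  # hypothetical pipeline id
)
for source_hash, data_type, value, is_uploaded in cursor.fetchall():
    print(source_hash, data_type, value, "uploaded" if is_uploaded else "local only")

connection.close()
```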
@@ -163,7 +186,7 @@ class EmbedChain(JSONSerializable):
         :raises ValueError: Invalid data type
         :param dry_run: Optional. A dry run displays the chunks to ensure that the loader and chunker work as intended.
         defaults to False
-        :return: source_id, a md5-hash of the source, in hexadecimal representation.
+        :return: source_hash, a md5-hash of the source, in hexadecimal representation.
         :rtype: str
         """
         if config is None:
@@ -192,18 +215,40 @@ class EmbedChain(JSONSerializable):
         if not data_type:
             data_type = detect_datatype(source)

-        # `source_id` is the hash of the source argument
+        # `source_hash` is the md5 hash of the source argument
         hash_object = hashlib.md5(str(source).encode("utf-8"))
-        source_id = hash_object.hexdigest()
+        source_hash = hash_object.hexdigest()

+        # Check if the data hash already exists, if so, skip the addition
+        self.cursor.execute(
+            "SELECT 1 FROM data_sources WHERE hash = ? AND pipeline_id = ?", (source_hash, self.config.id)
+        )
+        existing_data = self.cursor.fetchone()
+
+        if existing_data:
+            print(f"Data with hash {source_hash} already exists. Skipping addition.")
+            return source_hash
+
         data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([source, data_type.value, metadata])
         documents, metadatas, _ids, new_chunks = self.load_and_embed(
-            data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
+            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
         )
         if data_type in {DataType.DOCS_SITE}:
             self.is_docs_site_instance = True

+        # Insert the data into the 'data_sources' table
+        self.cursor.execute(
+            """
+            INSERT INTO data_sources (hash, pipeline_id, type, value, metadata)
+            VALUES (?, ?, ?, ?, ?)
+            """,
+            (source_hash, self.config.id, data_type.value, str(source), json.dumps(metadata)),
+        )
+
+        # Commit the transaction
+        self.connection.commit()
+
         if dry_run:
             data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
             logging.debug(f"Dry run info : {data_chunks_info}")
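The add-path above dedupes on `(pipeline_id, hash)`: the source is md5-hashed, skipped if the hash is already tracked, and recorded otherwise. A standalone sketch of the same pattern, assuming the `data_sources` schema from this diff (the `add_once` helper is hypothetical, not part of the commit):

```python
import hashlib
import json
import sqlite3
from typing import Any, Optional


def add_once(connection: sqlite3.Connection, pipeline_id: str, source: Any,
             data_type: str, metadata: Optional[dict] = None) -> str:
    """Insert a source into data_sources only if its hash isn't already tracked."""
    source_hash = hashlib.md5(str(source).encode("utf-8")).hexdigest()
    cursor = connection.cursor()
    cursor.execute(
        "SELECT 1 FROM data_sources WHERE hash = ? AND pipeline_id = ?",
        (source_hash, pipeline_id),
    )
    if cursor.fetchone():
        # Already added for this pipeline; skip re-embedding.
        return source_hash

    cursor.execute(
        "INSERT INTO data_sources (hash, pipeline_id, type, value, metadata) VALUES (?, ?, ?, ?, ?)",
        (source_hash, pipeline_id, data_type, str(source), json.dumps(metadata)),
    )
    connection.commit()
    return source_hash
```

Since `(pipeline_id, hash)` is the primary key, `INSERT OR IGNORE` could achieve the same dedup without the separate `SELECT`; the explicit check keeps the skip visible to the caller.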
@@ -218,7 +263,7 @@ class EmbedChain(JSONSerializable):
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
         thread_telemetry.start()

-        return source_id
+        return source_hash

     def add_local(
         self,
@@ -245,7 +290,7 @@ class EmbedChain(JSONSerializable):
         :param config: The `AddConfig` instance to use as configuration options., defaults to None
         :type config: Optional[AddConfig], optional
         :raises ValueError: Invalid data type
-        :return: source_id, a md5-hash of the source, in hexadecimal representation.
+        :return: source_hash, a md5-hash of the source, in hexadecimal representation.
         :rtype: str
         """
         logging.warning(
@@ -313,7 +358,7 @@ class EmbedChain(JSONSerializable):
         chunker: BaseChunker,
         src: Any,
         metadata: Optional[Dict[str, Any]] = None,
-        source_id: Optional[str] = None,
+        source_hash: Optional[str] = None,
         dry_run=False,
     ):
         """
@@ -324,7 +369,7 @@ class EmbedChain(JSONSerializable):
         :param src: The data to be handled by the loader. Can be a URL for
         remote sources or local content for local loaders.
         :param metadata: Optional. Metadata associated with the data source.
-        :param source_id: Hexadecimal hash of the source.
+        :param source_hash: Hexadecimal hash of the source.
         :param dry_run: Optional. A dry run returns chunks and doesn't update DB.
         :type dry_run: bool, defaults to False
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
@@ -382,7 +427,7 @@ class EmbedChain(JSONSerializable):
             m["app_id"] = self.config.id

             # Add hashed source
-            m["hash"] = source_id
+            m["hash"] = source_hash

             # Note: Metadata is the function argument
             if metadata:
@@ -558,15 +603,14 @@ class EmbedChain(JSONSerializable):
         """
         Resets the database. Deletes all embeddings irreversibly.
         `App` does not have to be reinitialized after using this method.

         DEPRECATED IN FAVOR OF `db.reset()`
         """
         # Send anonymous telemetry
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
         thread_telemetry.start()

         logging.warning("DEPRECATION WARNING: Please use `app.db.reset()` instead of `App.reset()`.")
         self.db.reset()
+        self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
+        self.connection.commit()

     @retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
     def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):
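End to end, the effect of this change is that `add()` now returns the source hash and becomes idempotent per pipeline, while `reset()` also clears the local bookkeeping rows. A rough usage sketch against the public `App` API (the URL is just an example source):

```python
from embedchain import App

app = App()

# `add()` returns the md5 hash of the source; the second call is skipped
# because the hash is already recorded in the local data_sources table.
first_hash = app.add("https://www.forbes.com/profile/elon-musk")
second_hash = app.add("https://www.forbes.com/profile/elon-musk")
assert first_hash == second_hash

# Deprecated in favor of `app.db.reset()`, but still resets the vector DB
# and deletes this pipeline's rows from data_sources.
app.reset()
```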