[Feature] Add support to use any sql database as the metadata storage for embedchain apps (#1273)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -13,7 +12,7 @@ from embedchain.cache import (adapt, get_gptcache_session,
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
|
||||
from embedchain.config.base_app_config import BaseAppConfig
|
||||
from embedchain.constants import SQLITE_PATH
|
||||
from embedchain.core.db.models import DataSource
|
||||
from embedchain.data_formatter import DataFormatter
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
@@ -21,7 +20,6 @@ from embedchain.llm.base import BaseLlm
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.models.data_type import (DataType, DirectDataType,
|
||||
IndirectDataType, SpecialDataType)
|
||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||
from embedchain.utils.misc import detect_datatype, is_valid_json_string
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
@@ -85,30 +83,6 @@ class EmbedChain(JSONSerializable):
|
||||
self.user_asks = []
|
||||
|
||||
self.chunker: Optional[ChunkerConfig] = None
|
||||
# Send anonymous telemetry
|
||||
self._telemetry_props = {"class": self.__class__.__name__}
|
||||
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
|
||||
# Establish a connection to the SQLite database
|
||||
self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
|
||||
self.cursor = self.connection.cursor()
|
||||
|
||||
# Create the 'data_sources' table if it doesn't exist
|
||||
self.cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS data_sources (
|
||||
pipeline_id TEXT,
|
||||
hash TEXT,
|
||||
type TEXT,
|
||||
value TEXT,
|
||||
metadata TEXT,
|
||||
is_uploaded INTEGER DEFAULT 0,
|
||||
PRIMARY KEY (pipeline_id, hash)
|
||||
)
|
||||
"""
|
||||
)
|
||||
self.connection.commit()
|
||||
# Send anonymous telemetry
|
||||
self.telemetry.capture(event_name="init", properties=self._telemetry_props)
|
||||
|
||||
@property
|
||||
def collect_metrics(self):
|
||||
@@ -204,17 +178,21 @@ class EmbedChain(JSONSerializable):
|
||||
if data_type in {DataType.DOCS_SITE}:
|
||||
self.is_docs_site_instance = True
|
||||
|
||||
# Insert the data into the 'data' table
|
||||
self.cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO data_sources (hash, pipeline_id, type, value, metadata)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(source_hash, self.config.id, data_type.value, str(source), json.dumps(metadata)),
|
||||
# Insert the data into the 'ec_data_sources' table
|
||||
self.db_session.add(
|
||||
DataSource(
|
||||
hash=source_hash,
|
||||
app_id=self.config.id,
|
||||
type=data_type.value,
|
||||
value=source,
|
||||
metadata=json.dumps(metadata),
|
||||
)
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
self.connection.commit()
|
||||
try:
|
||||
self.db_session.commit()
|
||||
except Exception as e:
|
||||
logging.error(f"Error adding data source: {e}")
|
||||
self.db_session.rollback()
|
||||
|
||||
if dry_run:
|
||||
data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
|
||||
@@ -666,9 +644,14 @@ class EmbedChain(JSONSerializable):
|
||||
Resets the database. Deletes all embeddings irreversibly.
|
||||
`App` does not have to be reinitialized after using this method.
|
||||
"""
|
||||
try:
|
||||
self.db_session.query(DataSource).filter_by(app_id=self.config.id).delete()
|
||||
self.db_session.commit()
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting chat history: {e}")
|
||||
self.db_session.rollback()
|
||||
return None
|
||||
self.db.reset()
|
||||
self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
|
||||
self.connection.commit()
|
||||
self.delete_all_chat_history(app_id=self.config.id)
|
||||
# Send anonymous telemetry
|
||||
self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
|
||||
|
||||
Reference in New Issue
Block a user