[Feature] Add support for using any SQL database as the metadata storage for embedchain apps (#1273)

This commit is contained in:
Deshraj Yadav
2024-02-19 13:04:18 -08:00
committed by GitHub
parent 6c12bc9044
commit 5e2e7fb639
20 changed files with 601 additions and 202 deletions

View File

@@ -1,7 +1,6 @@
import hashlib
import json
import logging
import sqlite3
from typing import Any, Optional, Union
from dotenv import load_dotenv
@@ -13,7 +12,7 @@ from embedchain.cache import (adapt, get_gptcache_session,
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.base_app_config import BaseAppConfig
from embedchain.constants import SQLITE_PATH
from embedchain.core.db.models import DataSource
from embedchain.data_formatter import DataFormatter
from embedchain.embedder.base import BaseEmbedder
from embedchain.helpers.json_serializable import JSONSerializable
@@ -21,7 +20,6 @@ from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import (DataType, DirectDataType,
IndirectDataType, SpecialDataType)
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils.misc import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB
@@ -85,30 +83,6 @@ class EmbedChain(JSONSerializable):
self.user_asks = []
self.chunker: Optional[ChunkerConfig] = None
# Send anonymous telemetry
self._telemetry_props = {"class": self.__class__.__name__}
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
# Establish a connection to the SQLite database
self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
self.cursor = self.connection.cursor()
# Create the 'data_sources' table if it doesn't exist
self.cursor.execute(
"""
CREATE TABLE IF NOT EXISTS data_sources (
pipeline_id TEXT,
hash TEXT,
type TEXT,
value TEXT,
metadata TEXT,
is_uploaded INTEGER DEFAULT 0,
PRIMARY KEY (pipeline_id, hash)
)
"""
)
self.connection.commit()
# Send anonymous telemetry
self.telemetry.capture(event_name="init", properties=self._telemetry_props)
@property
def collect_metrics(self):
@@ -204,17 +178,21 @@ class EmbedChain(JSONSerializable):
if data_type in {DataType.DOCS_SITE}:
self.is_docs_site_instance = True
# Insert the data into the 'data' table
self.cursor.execute(
"""
INSERT OR REPLACE INTO data_sources (hash, pipeline_id, type, value, metadata)
VALUES (?, ?, ?, ?, ?)
""",
(source_hash, self.config.id, data_type.value, str(source), json.dumps(metadata)),
# Insert the data into the 'ec_data_sources' table
self.db_session.add(
DataSource(
hash=source_hash,
app_id=self.config.id,
type=data_type.value,
value=source,
metadata=json.dumps(metadata),
)
)
# Commit the transaction
self.connection.commit()
try:
self.db_session.commit()
except Exception as e:
logging.error(f"Error adding data source: {e}")
self.db_session.rollback()
if dry_run:
data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
@@ -666,9 +644,14 @@ class EmbedChain(JSONSerializable):
Resets the database. Deletes all embeddings irreversibly.
`App` does not have to be reinitialized after using this method.
"""
try:
self.db_session.query(DataSource).filter_by(app_id=self.config.id).delete()
self.db_session.commit()
except Exception as e:
logging.error(f"Error deleting chat history: {e}")
self.db_session.rollback()
return None
self.db.reset()
self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
self.connection.commit()
self.delete_all_chat_history(app_id=self.config.id)
# Send anonymous telemetry
self.telemetry.capture(event_name="reset", properties=self._telemetry_props)