[Feature] Add support for using any SQL database as the metadata storage for embedchain apps (#1273)
This commit is contained in:
@@ -3,7 +3,6 @@ import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import uuid
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
@@ -16,7 +15,8 @@ from embedchain.cache import (Config, ExactMatchEvaluation,
|
||||
gptcache_data_manager, gptcache_pre_function)
|
||||
from embedchain.client import Client
|
||||
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
|
||||
from embedchain.constants import SQLITE_PATH
|
||||
from embedchain.core.db.database import get_session
|
||||
from embedchain.core.db.models import DataSource
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
@@ -33,9 +33,6 @@ from embedchain.utils.misc import validate_config
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
# Set up the user directory if it doesn't exist already
|
||||
Client.setup_dir()
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class App(EmbedChain):
|
||||
@@ -120,6 +117,9 @@ class App(EmbedChain):
|
||||
self.llm = llm or OpenAILlm()
|
||||
self._init_db()
|
||||
|
||||
# Session for the metadata db
|
||||
self.db_session = get_session()
|
||||
|
||||
# If cache_config is provided, initializing the cache ...
|
||||
if self.cache_config is not None:
|
||||
self._init_cache()
|
||||
@@ -127,27 +127,6 @@ class App(EmbedChain):
|
||||
# Send anonymous telemetry
|
||||
self._telemetry_props = {"class": self.__class__.__name__}
|
||||
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
|
||||
|
||||
# Establish a connection to the SQLite database
|
||||
self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
|
||||
self.cursor = self.connection.cursor()
|
||||
|
||||
# Create the 'data_sources' table if it doesn't exist
|
||||
self.cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS data_sources (
|
||||
pipeline_id TEXT,
|
||||
hash TEXT,
|
||||
type TEXT,
|
||||
value TEXT,
|
||||
metadata TEXT,
|
||||
is_uploaded INTEGER DEFAULT 0,
|
||||
PRIMARY KEY (pipeline_id, hash)
|
||||
)
|
||||
"""
|
||||
)
|
||||
self.connection.commit()
|
||||
# Send anonymous telemetry
|
||||
self.telemetry.capture(event_name="init", properties=self._telemetry_props)
|
||||
|
||||
self.user_asks = []
|
||||
@@ -307,20 +286,14 @@ class App(EmbedChain):
|
||||
return False
|
||||
|
||||
def _mark_data_as_uploaded(self, data_hash):
|
||||
self.cursor.execute(
|
||||
"UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ?",
|
||||
(data_hash, self.local_id),
|
||||
)
|
||||
self.connection.commit()
|
||||
self.db_session.query(DataSource).filter_by(hash=data_hash, app_id=self.local_id).update({"is_uploaded": 1})
|
||||
|
||||
def get_data_sources(self):
|
||||
db_data = self.cursor.execute("SELECT * FROM data_sources WHERE pipeline_id = ?", (self.local_id,)).fetchall()
|
||||
|
||||
data_sources = []
|
||||
for data in db_data:
|
||||
data_sources.append({"data_type": data[2], "data_value": data[3], "metadata": data[4]})
|
||||
|
||||
return data_sources
|
||||
data_sources = self.db_session.query(DataSource).filter_by(app_id=self.local_id).all()
|
||||
results = []
|
||||
for row in data_sources:
|
||||
results.append({"data_type": row.data_type, "data_value": row.data_value, "metadata": row.metadata})
|
||||
return results
|
||||
|
||||
def deploy(self):
|
||||
if self.client is None:
|
||||
@@ -329,14 +302,11 @@ class App(EmbedChain):
|
||||
pipeline_data = self._create_pipeline()
|
||||
self.id = pipeline_data["id"]
|
||||
|
||||
results = self.cursor.execute(
|
||||
"SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,) # noqa:E501
|
||||
).fetchall()
|
||||
|
||||
results = self.db_session.query(DataSource).filter_by(app_id=self.local_id, is_uploaded=0).all()
|
||||
if len(results) > 0:
|
||||
print("🛠️ Adding data to your pipeline...")
|
||||
for result in results:
|
||||
data_hash, data_type, data_value = result[1], result[2], result[3]
|
||||
data_hash, data_type, data_value = result.hash, result.data_type, result.data_value
|
||||
self._process_and_upload_data(data_hash, data_type, data_value)
|
||||
|
||||
# Send anonymous telemetry
|
||||
@@ -423,10 +393,6 @@ class App(EmbedChain):
|
||||
else:
|
||||
cache_config = None
|
||||
|
||||
# Send anonymous telemetry
|
||||
event_properties = {"init_type": "config_data"}
|
||||
AnonymousTelemetry().capture(event_name="init", properties=event_properties)
|
||||
|
||||
return cls(
|
||||
config=app_config,
|
||||
llm=llm,
|
||||
|
||||
Reference in New Issue
Block a user