[Feature] Add support to use any sql database as the metadata storage for embedchain apps (#1273)

This commit is contained in:
Deshraj Yadav
2024-02-19 13:04:18 -08:00
committed by GitHub
parent 6c12bc9044
commit 5e2e7fb639
20 changed files with 601 additions and 202 deletions

View File

@@ -3,7 +3,6 @@ import concurrent.futures
import json
import logging
import os
import sqlite3
import uuid
from typing import Any, Optional, Union
@@ -16,7 +15,8 @@ from embedchain.cache import (Config, ExactMatchEvaluation,
gptcache_data_manager, gptcache_pre_function)
from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
from embedchain.constants import SQLITE_PATH
from embedchain.core.db.database import get_session
from embedchain.core.db.models import DataSource
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
@@ -33,9 +33,6 @@ from embedchain.utils.misc import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
# Set up the user directory if it doesn't exist already
Client.setup_dir()
@register_deserializable
class App(EmbedChain):
@@ -120,6 +117,9 @@ class App(EmbedChain):
self.llm = llm or OpenAILlm()
self._init_db()
# Session for the metadata db
self.db_session = get_session()
# If cache_config is provided, initializing the cache ...
if self.cache_config is not None:
self._init_cache()
@@ -127,27 +127,6 @@ class App(EmbedChain):
# Send anonymous telemetry
self._telemetry_props = {"class": self.__class__.__name__}
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
# Establish a connection to the SQLite database
self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
self.cursor = self.connection.cursor()
# Create the 'data_sources' table if it doesn't exist
self.cursor.execute(
"""
CREATE TABLE IF NOT EXISTS data_sources (
pipeline_id TEXT,
hash TEXT,
type TEXT,
value TEXT,
metadata TEXT,
is_uploaded INTEGER DEFAULT 0,
PRIMARY KEY (pipeline_id, hash)
)
"""
)
self.connection.commit()
# Send anonymous telemetry
self.telemetry.capture(event_name="init", properties=self._telemetry_props)
self.user_asks = []
@@ -307,20 +286,14 @@ class App(EmbedChain):
return False
def _mark_data_as_uploaded(self, data_hash):
self.cursor.execute(
"UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ?",
(data_hash, self.local_id),
)
self.connection.commit()
self.db_session.query(DataSource).filter_by(hash=data_hash, app_id=self.local_id).update({"is_uploaded": 1})
def get_data_sources(self):
db_data = self.cursor.execute("SELECT * FROM data_sources WHERE pipeline_id = ?", (self.local_id,)).fetchall()
data_sources = []
for data in db_data:
data_sources.append({"data_type": data[2], "data_value": data[3], "metadata": data[4]})
return data_sources
data_sources = self.db_session.query(DataSource).filter_by(app_id=self.local_id).all()
results = []
for row in data_sources:
results.append({"data_type": row.data_type, "data_value": row.data_value, "metadata": row.metadata})
return results
def deploy(self):
if self.client is None:
@@ -329,14 +302,11 @@ class App(EmbedChain):
pipeline_data = self._create_pipeline()
self.id = pipeline_data["id"]
results = self.cursor.execute(
"SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,) # noqa:E501
).fetchall()
results = self.db_session.query(DataSource).filter_by(app_id=self.local_id, is_uploaded=0).all()
if len(results) > 0:
print("🛠️ Adding data to your pipeline...")
for result in results:
data_hash, data_type, data_value = result[1], result[2], result[3]
data_hash, data_type, data_value = result.hash, result.data_type, result.data_value
self._process_and_upload_data(data_hash, data_type, data_value)
# Send anonymous telemetry
@@ -423,10 +393,6 @@ class App(EmbedChain):
else:
cache_config = None
# Send anonymous telemetry
event_properties = {"init_type": "config_data"}
AnonymousTelemetry().capture(event_name="init", properties=event_properties)
return cls(
config=app_config,
llm=llm,