[Feature] Add support to use any sql database as the metadata storage for embedchain apps (#1273)
This commit is contained in:
0
embedchain/core/__init__.py
Normal file
0
embedchain/core/__init__.py
Normal file
0
embedchain/core/db/__init__.py
Normal file
0
embedchain/core/db/__init__.py
Normal file
83
embedchain/core/db/database.py
Normal file
83
embedchain/core/db/database.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import os
|
||||
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.orm import Session as SQLAlchemySession
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
|
||||
from .models import Base
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
def __init__(self, database_uri: str = "sqlite:///embedchain.db", echo: bool = False):
|
||||
self.database_uri = database_uri
|
||||
self.echo = echo
|
||||
self.engine: Engine = None
|
||||
self._session_factory = None
|
||||
|
||||
def setup_engine(self) -> None:
|
||||
"""Initializes the database engine and session factory."""
|
||||
self.engine = create_engine(self.database_uri, echo=self.echo, connect_args={"check_same_thread": False})
|
||||
self._session_factory = scoped_session(sessionmaker(bind=self.engine))
|
||||
Base.metadata.bind = self.engine
|
||||
|
||||
def init_db(self) -> None:
|
||||
"""Creates all tables defined in the Base metadata."""
|
||||
if not self.engine:
|
||||
raise RuntimeError("Database engine is not initialized. Call setup_engine() first.")
|
||||
Base.metadata.create_all(self.engine)
|
||||
|
||||
def get_session(self) -> SQLAlchemySession:
|
||||
"""Provides a session for database operations."""
|
||||
if not self._session_factory:
|
||||
raise RuntimeError("Session factory is not initialized. Call setup_engine() first.")
|
||||
return self._session_factory()
|
||||
|
||||
def close_session(self) -> None:
|
||||
"""Closes the current session."""
|
||||
if self._session_factory:
|
||||
self._session_factory.remove()
|
||||
|
||||
def execute_transaction(self, transaction_block):
|
||||
"""Executes a block of code within a database transaction."""
|
||||
session = self.get_session()
|
||||
try:
|
||||
transaction_block(session)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self.close_session()
|
||||
|
||||
|
||||
# Singleton pattern to use throughout the application
|
||||
database_manager = DatabaseManager()
|
||||
|
||||
|
||||
# Convenience functions for backward compatibility and ease of use
|
||||
def setup_engine(database_uri: str = "sqlite:///embedchain.db", echo: bool = False) -> None:
|
||||
database_manager.database_uri = database_uri
|
||||
database_manager.echo = echo
|
||||
database_manager.setup_engine()
|
||||
|
||||
|
||||
def alembic_upgrade() -> None:
|
||||
"""Upgrades the database to the latest version."""
|
||||
alembic_config_path = os.path.join(os.path.dirname(__file__), "..", "..", "alembic.ini")
|
||||
alembic_cfg = Config(alembic_config_path)
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
|
||||
def init_db() -> None:
|
||||
alembic_upgrade()
|
||||
|
||||
|
||||
def get_session() -> SQLAlchemySession:
|
||||
return database_manager.get_session()
|
||||
|
||||
|
||||
def execute_transaction(transaction_block):
|
||||
database_manager.execute_transaction(transaction_block)
|
||||
31
embedchain/core/db/models.py
Normal file
31
embedchain/core/db/models.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import TIMESTAMP, Column, Integer, String, Text, func
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
metadata = Base.metadata
|
||||
|
||||
|
||||
class DataSource(Base):
|
||||
__tablename__ = "ec_data_sources"
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
app_id = Column(Text, index=True)
|
||||
hash = Column(Text, index=True)
|
||||
type = Column(Text, index=True)
|
||||
value = Column(Text)
|
||||
meta_data = Column(Text, name="metadata")
|
||||
is_uploaded = Column(Integer, default=0)
|
||||
|
||||
|
||||
class ChatHistory(Base):
|
||||
__tablename__ = "ec_chat_history"
|
||||
|
||||
app_id = Column(String, primary_key=True)
|
||||
id = Column(String, primary_key=True)
|
||||
session_id = Column(String, primary_key=True, index=True)
|
||||
question = Column(Text)
|
||||
answer = Column(Text)
|
||||
meta_data = Column(Text, name="metadata")
|
||||
created_at = Column(TIMESTAMP, default=func.current_timestamp(), index=True)
|
||||
Reference in New Issue
Block a user