89 lines
3.0 KiB
Python
89 lines
3.0 KiB
Python
import os
|
|
|
|
from alembic import command
|
|
from alembic.config import Config
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.engine.base import Engine
|
|
from sqlalchemy.orm import Session as SQLAlchemySession
|
|
from sqlalchemy.orm import scoped_session, sessionmaker
|
|
|
|
from .models import Base
|
|
|
|
|
|
class DatabaseManager:
    """Owns the SQLAlchemy engine and scoped-session registry.

    A new manager is unconfigured: ``__init__`` only records the connection
    URI (read from the ``EMBEDCHAIN_DB_URI`` environment variable) and the
    ``echo`` flag. ``setup_engine()`` must be called before any method that
    needs the engine or a session.
    """

    def __init__(self, echo: bool = False):
        """Record connection settings; no engine is created here.

        Args:
            echo: Forwarded to ``create_engine(echo=...)`` by ``setup_engine()``.
        """
        self.database_uri = os.environ.get("EMBEDCHAIN_DB_URI")
        self.echo = echo
        # Both remain None until setup_engine() has run.
        self.engine: Engine = None
        self._session_factory = None

    def setup_engine(self) -> None:
        """Initializes the database engine and session factory.

        Safe to call more than once (for example after ``database_uri`` has
        been changed): the replacement engine is built first, and only then
        are the previous session registry and engine released, so a failed
        re-configuration leaves the existing setup untouched and a successful
        one does not leak the old connection pool.

        Raises:
            RuntimeError: If ``database_uri`` is unset or empty.
        """
        if not self.database_uri:
            raise RuntimeError("Database URI is not set. Set the EMBEDCHAIN_DB_URI environment variable.")

        connect_args = {}
        if self.database_uri.startswith("sqlite"):
            # Disable sqlite3's same-thread check so a connection may be used
            # from a thread other than the one that created it.
            connect_args["check_same_thread"] = False

        new_engine = create_engine(self.database_uri, echo=self.echo, connect_args=connect_args)
        new_factory = scoped_session(sessionmaker(bind=new_engine))

        previous_engine, previous_factory = self.engine, self._session_factory
        self.engine = new_engine
        self._session_factory = new_factory
        # NOTE(review): bound metadata is deprecated in SQLAlchemy 1.4 and
        # removed in 2.0 - confirm this assignment is still required.
        Base.metadata.bind = new_engine

        # Release whatever an earlier setup_engine() call left behind.
        if previous_factory is not None:
            previous_factory.remove()
        if previous_engine is not None:
            previous_engine.dispose()

    def init_db(self) -> None:
        """Creates all tables defined in the Base metadata.

        Raises:
            RuntimeError: If ``setup_engine()`` has not been called yet.
        """
        if self.engine is None:
            raise RuntimeError("Database engine is not initialized. Call setup_engine() first.")
        Base.metadata.create_all(self.engine)

    def get_session(self) -> SQLAlchemySession:
        """Provides a session for database operations.

        Returns:
            The object produced by calling the scoped-session registry.

        Raises:
            RuntimeError: If ``setup_engine()`` has not been called yet.
        """
        if self._session_factory is None:
            raise RuntimeError("Session factory is not initialized. Call setup_engine() first.")
        return self._session_factory()

    def close_session(self) -> None:
        """Closes the current session; does nothing if no factory is configured."""
        if self._session_factory is not None:
            self._session_factory.remove()

    def execute_transaction(self, transaction_block):
        """Executes a block of code within a database transaction.

        The session is committed when ``transaction_block`` returns normally
        and rolled back when it (or the commit) raises. In both cases the
        session is released before this method returns, so prefer returning
        plain values rather than objects that still need the session.

        Args:
            transaction_block: Callable invoked with the session as its only
                argument.

        Returns:
            Whatever ``transaction_block`` returned.

        Raises:
            RuntimeError: If ``setup_engine()`` has not been called yet.
            Exception: Any error raised by ``transaction_block`` or the commit
                is re-raised unchanged after the rollback.
        """
        session = self.get_session()
        try:
            result = transaction_block(session)
            session.commit()
        except Exception:
            session.rollback()
            # Bare raise keeps the original exception and traceback intact.
            raise
        finally:
            self.close_session()
        return result
|
|
|
|
|
|
# Shared module-level instance used throughout the application
|
|
# Created at import time. __init__ only records EMBEDCHAIN_DB_URI and the echo
# flag, so no engine exists until setup_engine() is called.
database_manager = DatabaseManager()
|
|
|
|
|
|
# Convenience functions for backward compatibility and ease of use
|
|
def setup_engine(database_uri: str, echo: bool = False) -> None:
    """Configure the shared manager with the given URI and build its engine.

    Overwrites the manager's ``database_uri`` and ``echo`` settings, then
    delegates to ``DatabaseManager.setup_engine()``.

    Args:
        database_uri: Connection URI to store on the shared manager.
        echo: Value stored as the manager's ``echo`` flag.
    """
    manager = database_manager
    manager.database_uri, manager.echo = database_uri, echo
    manager.setup_engine()
|
|
|
|
|
|
def alembic_upgrade() -> None:
    """Run the Alembic ``upgrade`` command up to the ``head`` revision.

    The configuration is loaded from ``alembic.ini`` located two directory
    levels above the directory containing this module.
    """
    module_dir = os.path.dirname(__file__)
    ini_path = os.path.join(module_dir, "..", "..", "alembic.ini")
    command.upgrade(Config(ini_path), "head")
|
|
|
|
|
|
def init_db() -> None:
    """Initialize the database by running the Alembic upgrade.

    Note that this delegates to ``alembic_upgrade()``; it does not call
    ``DatabaseManager.init_db()``.
    """
    alembic_upgrade()
|
|
|
|
|
|
def get_session() -> SQLAlchemySession:
    """Return a session obtained from the shared ``database_manager``."""
    session = database_manager.get_session()
    return session
|
|
|
|
|
|
def execute_transaction(transaction_block):
    """Run ``transaction_block`` in a transaction on the shared manager.

    Args:
        transaction_block: Callable forwarded unchanged to
            ``database_manager.execute_transaction``.

    Returns:
        Whatever ``database_manager.execute_transaction`` returns, forwarded
        instead of being discarded.
    """
    return database_manager.execute_transaction(transaction_block)
|