Files
t6_mem0/embedchain/core/db/database.py

89 lines
3.0 KiB
Python

import os
from typing import Optional

from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session as SQLAlchemySession
from sqlalchemy.orm import scoped_session, sessionmaker

from .models import Base
class DatabaseManager:
    """Owns the SQLAlchemy engine and the scoped session factory.

    The database URI is read from the ``EMBEDCHAIN_DB_URI`` environment
    variable when the manager is constructed. Neither the engine nor the
    session factory exists until :meth:`setup_engine` has been called.
    """

    def __init__(self, echo: bool = False):
        """Record the configuration without creating an engine.

        Args:
            echo: Forwarded to ``create_engine(echo=...)`` by
                :meth:`setup_engine`.
        """
        self.database_uri = os.environ.get("EMBEDCHAIN_DB_URI")
        self.echo = echo
        # Both remain None until setup_engine() has run.
        self.engine: Optional[Engine] = None
        self._session_factory: Optional[scoped_session] = None

    def setup_engine(self) -> None:
        """Initialize the database engine and the session factory.

        Raises:
            RuntimeError: If ``database_uri`` is unset or empty.
        """
        if not self.database_uri:
            raise RuntimeError("Database URI is not set. Set the EMBEDCHAIN_DB_URI environment variable.")

        connect_args = {}
        if self.database_uri.startswith("sqlite"):
            # Let the SQLite connection be used from threads other than the
            # one that created it.
            connect_args["check_same_thread"] = False

        self.engine = create_engine(self.database_uri, echo=self.echo, connect_args=connect_args)
        self._session_factory = scoped_session(sessionmaker(bind=self.engine))
        # NOTE(review): binding an engine to MetaData is a legacy pattern
        # (removed in SQLAlchemy 2.0) - kept for compatibility; confirm
        # whether anything still relies on it.
        Base.metadata.bind = self.engine

    def init_db(self) -> None:
        """Create all tables defined in the ``Base`` metadata.

        Raises:
            RuntimeError: If :meth:`setup_engine` has not been called.
        """
        if self.engine is None:
            raise RuntimeError("Database engine is not initialized. Call setup_engine() first.")
        Base.metadata.create_all(self.engine)

    def get_session(self) -> "SQLAlchemySession":
        """Return a session obtained from the scoped session factory.

        Raises:
            RuntimeError: If :meth:`setup_engine` has not been called.
        """
        if self._session_factory is None:
            raise RuntimeError("Session factory is not initialized. Call setup_engine() first.")
        return self._session_factory()

    def close_session(self) -> None:
        """Remove the current scoped session; a no-op before setup."""
        if self._session_factory is not None:
            self._session_factory.remove()

    def execute_transaction(self, transaction_block):
        """Run ``transaction_block(session)`` inside a transaction.

        The session is committed when the callable returns normally and
        rolled back when it (or the commit itself) raises an ``Exception``.
        The scoped session is removed in either case.

        Args:
            transaction_block: Callable that accepts the session as its only
                argument.

        Returns:
            Whatever ``transaction_block`` returned.

        Raises:
            RuntimeError: If :meth:`setup_engine` has not been called.
            Exception: Any exception raised by ``transaction_block`` or by
                the commit is re-raised after the rollback.
        """
        session = self.get_session()
        try:
            result = transaction_block(session)
            session.commit()
        except Exception:
            session.rollback()
            # Bare raise keeps the original traceback untouched.
            raise
        finally:
            self.close_session()
        return result
# Shared module-level instance used by the convenience functions in this
# module. Constructing it only reads EMBEDCHAIN_DB_URI from the environment;
# no engine or session factory exists until setup_engine() is called.
database_manager = DatabaseManager()
# Convenience functions for backward compatibility and ease of use
def setup_engine(database_uri: str, echo: bool = False) -> None:
    """Configure the shared ``database_manager`` and build its engine.

    Args:
        database_uri: Stored on the manager as its ``database_uri``,
            replacing whatever value it held before.
        echo: Stored on the manager as its ``echo`` flag.
    """
    manager = database_manager
    manager.echo = echo
    manager.database_uri = database_uri
    manager.setup_engine()
def alembic_upgrade() -> None:
    """Run the Alembic migrations up to the ``head`` revision.

    The ``alembic.ini`` file is looked up two directories above the
    directory that contains this module.
    """
    module_dir = os.path.dirname(__file__)
    ini_path = os.path.join(module_dir, "..", "..", "alembic.ini")
    command.upgrade(Config(ini_path), "head")
def init_db() -> None:
    """Bring the database schema up to date by running the migrations.

    Delegates to ``alembic_upgrade()``; unlike ``DatabaseManager.init_db``
    it does not call ``Base.metadata.create_all``.
    """
    alembic_upgrade()
def get_session() -> SQLAlchemySession:
    """Return a session from the shared ``database_manager``.

    Any exception raised by ``database_manager.get_session()`` propagates
    unchanged.
    """
    session = database_manager.get_session()
    return session
def execute_transaction(transaction_block):
    """Run ``transaction_block`` through the shared ``database_manager``.

    Args:
        transaction_block: Callable that accepts the session as its only
            argument; passed straight to
            ``database_manager.execute_transaction``.

    Returns:
        The value returned by ``database_manager.execute_transaction``
        (previously this wrapper discarded it and always returned ``None``).
    """
    return database_manager.execute_transaction(transaction_block)