Files
ffx/src/ffx/database.py
Javanaut 0e51d6337f ff
2026-04-12 18:35:13 +02:00

226 lines
7.3 KiB
Python

import os, shutil, click
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.orm import sessionmaker
# Import the full model package so SQLAlchemy registers every mapped class
# before metadata creation and the first ORM query.
import ffx.model
from ffx.model.show import Base
from ffx.model.property import Property
from ffx.model.migration import (
DatabaseVersionException,
getMigrationPlan,
migrateDatabase,
)
from ffx.constants import DATABASE_VERSION
# Key of the Property row under which the database schema version is stored.
DATABASE_VERSION_KEY = "database_version"

# Names of every table registered on the ORM metadata (populated by the model
# imports above). A database missing any of them gets bootstrapped.
EXPECTED_TABLE_NAMES = set(Base.metadata.tables)
def databaseContext(databasePath: "str | None" = ''):
    """Open the ffx SQLite database and bring its schema up to date.

    Args:
        databasePath: ``None`` selects an in-memory database. An empty string
            selects the default ``~/.local/var/ffx/ffx.db``; its directory is
            created if needed. Any other value is treated as a filesystem path:
            ``~`` is expanded and the path is made absolute, except for the
            literal ``':memory:'``.

    Returns:
        A dict with the keys:
            ``path``: the resolved path or ``':memory:'``.
            ``url``: the SQLAlchemy URL.
            ``engine``: the SQLAlchemy engine.
            ``session``: a ``sessionmaker`` bound to the engine.

    Raises:
        Propagates ``DatabaseVersionException`` and ``click.ClickException``
        raised by the version check and migration helpers.
    """
    context = {}

    if databasePath is None:
        # Resulting URL: sqlite:///:memory:
        databasePath = ':memory:'
    elif not databasePath:
        ffxVarDir = os.path.join(os.path.expanduser("~"), '.local', 'var', 'ffx')
        # exist_ok avoids the race of a separate exists() check before makedirs().
        os.makedirs(ffxVarDir, exist_ok=True)
        databasePath = os.path.join(ffxVarDir, 'ffx.db')
    else:
        databasePath = os.path.expanduser(databasePath)

    if databasePath != ':memory:':
        databasePath = os.path.abspath(databasePath)

    context['path'] = databasePath
    context['url'] = f"sqlite:///{databasePath}"
    context['engine'] = create_engine(context['url'])
    context['session'] = sessionmaker(bind=context['engine'])

    # Create any missing tables before the stored version is read.
    bootstrapDatabaseIfNeeded(context)

    # NOTE: foreign-key enforcement (PRAGMA foreign_keys=ON) is not enabled here.
    ensureDatabaseVersion(context)
    return context
def databaseNeedsBootstrap(databaseContext) -> bool:
    """Return True when at least one ORM-declared table is absent from the database."""
    presentTables = set(inspect(databaseContext['engine']).get_table_names())
    missingTables = EXPECTED_TABLE_NAMES - presentTables
    return len(missingTables) > 0
def bootstrapDatabaseIfNeeded(databaseContext):
    """Create the ORM tables when any expected table is missing.

    Does nothing when every table in EXPECTED_TABLE_NAMES already exists.
    """
    if databaseNeedsBootstrap(databaseContext):
        Base.metadata.create_all(databaseContext['engine'])
def ensureDatabaseVersion(databaseContext):
    """Make sure the database schema version matches DATABASE_VERSION.

    How each case is handled:
        * No stored version (reported as 0): the database is stamped with
          DATABASE_VERSION.
        * Older version: the user is asked to confirm, a backup is made, and
          the database is migrated.
        * Newer version: rejected.

    In every successful case, ensureCurrentSchemaCompatibility() runs
    afterwards.

    Raises:
        DatabaseVersionException: the stored version is newer than
            DATABASE_VERSION, or migration did not end at DATABASE_VERSION.
        click.ClickException: propagated from the prompt, the backup or the
            version property helpers.
    """
    currentDatabaseVersion = getDatabaseVersion(databaseContext)

    if not currentDatabaseVersion:
        # No version recorded yet: stamp it. Do not return early, because a
        # pre-existing unversioned database may still lack the columns that
        # ensureCurrentSchemaCompatibility() adds. That check only adds
        # missing columns, so it is harmless on a freshly created schema.
        setDatabaseVersion(databaseContext, DATABASE_VERSION)
    else:
        if currentDatabaseVersion > DATABASE_VERSION:
            raise DatabaseVersionException(
                f"Current database version ({currentDatabaseVersion}) does not match required ({DATABASE_VERSION})"
            )

        if currentDatabaseVersion < DATABASE_VERSION:
            promptForDatabaseMigration(databaseContext, currentDatabaseVersion, DATABASE_VERSION)
            migrateDatabase(databaseContext, currentDatabaseVersion, DATABASE_VERSION, setDatabaseVersion)
            # Re-read to confirm the migration actually reached the target.
            currentDatabaseVersion = getDatabaseVersion(databaseContext)

        if currentDatabaseVersion != DATABASE_VERSION:
            raise DatabaseVersionException(
                f"Current database version ({currentDatabaseVersion}) does not match required ({DATABASE_VERSION})"
            )

    ensureCurrentSchemaCompatibility(databaseContext)
def ensureCurrentSchemaCompatibility(databaseContext):
    """Add the ``quality`` and ``notes`` columns to ``shows`` if they are missing.

    Only missing columns are added, all inside one transaction. When nothing
    is missing, no transaction is opened.
    """
    engine = databaseContext['engine']
    existingColumns = {column['name'] for column in inspect(engine).get_columns('shows')}

    # Column name -> DDL that adds it. Dict order is the execution order.
    columnDdl = {
        'quality': "ALTER TABLE shows ADD COLUMN quality INTEGER DEFAULT 0",
        'notes': "ALTER TABLE shows ADD COLUMN notes TEXT DEFAULT ''",
    }
    pendingDdl = [ddl for columnName, ddl in columnDdl.items() if columnName not in existingColumns]

    if pendingDdl:
        with engine.begin() as connection:
            for ddl in pendingDdl:
                connection.execute(text(ddl))
def promptForDatabaseMigration(databaseContext, currentDatabaseVersion: int, targetDatabaseVersion: int):
    """Show the migration plan, ask for confirmation and back up the database.

    Each planned step is printed with whether its migration module is present.

    Raises:
        DatabaseVersionException: some step in the plan has no migration
            module. The first such step is reported.
        click.ClickException: the user declined, or the backup could not be
            created.
    """
    plan = getMigrationPlan(currentDatabaseVersion, targetDatabaseVersion)

    for headerLine in (
        "Database migration required.",
        f"Current version: {currentDatabaseVersion}",
        f"Target version: {targetDatabaseVersion}",
        "Steps required:",
    ):
        click.echo(headerLine)

    # Print every step before failing, so the user sees the whole plan.
    firstGap = None
    for step in plan:
        status = "present" if step.modulePresent else "missing"
        click.echo(f" {step.versionFrom} -> {step.versionTo}: {step.moduleName} [{status}]")
        if firstGap is None and not step.modulePresent:
            firstGap = step

    if firstGap is not None:
        raise DatabaseVersionException(
            "No migration path from database version "
            f"{firstGap.versionFrom} to {firstGap.versionTo}"
        )

    confirmed = click.confirm(
        "Create a backup and continue with database migration?",
        default=True,
    )
    if not confirmed:
        raise click.ClickException("Database migration aborted by user.")

    backupPath = backupDatabaseBeforeMigration(databaseContext, currentDatabaseVersion, targetDatabaseVersion)
    click.echo(f"Database backup created: {backupPath}")
def backupDatabaseBeforeMigration(databaseContext, currentDatabaseVersion: int, targetDatabaseVersion: int) -> str:
    """Write a copy of the SQLite database next to it before migrating.

    Backup naming:
        * First choice: ``<db>.v<from>-to-v<to>.bak``.
        * If that exists: ``<db>.v<from>-to-v<to>.<n>.bak``, using the first
          free ``n`` starting at 1.

    The copy is made through SQLite's online backup API rather than a raw
    file copy. This means the backup is a consistent snapshot:
        * Content still held in a ``-wal`` or journal file is included.
        * Another connection to the database cannot produce a torn copy.

    The source file's permission bits are applied to the backup.

    Returns:
        The path of the backup file.

    Raises:
        click.ClickException: the context is not file-backed, or the
            database file does not exist.
        sqlite3.Error / OSError: the backup itself failed. Any partially
            written backup file is removed first.
    """
    import sqlite3  # local import: only needed on this rarely used path

    databasePath = databaseContext.get('path', '')
    if not databasePath or databasePath == ':memory:':
        raise click.ClickException("Database migration backup requires a file-backed SQLite database.")
    if not os.path.isfile(databasePath):
        raise click.ClickException(f"Database file not found for backup: {databasePath}")

    backupStem = f"{databasePath}.v{currentDatabaseVersion}-to-v{targetDatabaseVersion}"
    backupPath = f"{backupStem}.bak"
    backupIndex = 1
    while os.path.exists(backupPath):
        backupPath = f"{backupStem}.{backupIndex}.bak"
        backupIndex += 1

    # Close the engine's pooled connections so this process holds no idle handle on the file.
    databaseContext['engine'].dispose()

    source = sqlite3.connect(databasePath)
    try:
        target = sqlite3.connect(backupPath)
        try:
            source.backup(target)
        finally:
            target.close()
    except BaseException:
        # Don't leave a truncated file behind that looks like a valid backup.
        if os.path.exists(backupPath):
            os.remove(backupPath)
        raise
    finally:
        source.close()

    shutil.copymode(databasePath, backupPath)
    return backupPath
def getDatabaseVersion(databaseContext):
    """Return the schema version stored in the ``Property`` table.

    Returns:
        The stored version as ``int``, or 0 when no version row exists.

    Raises:
        click.ClickException: wrapping any error raised while creating the
            session, querying, or converting the stored value to ``int``.
            The original error is chained as ``__cause__``.
    """
    session = None
    try:
        session = databaseContext['session']()
        versionProperty = (
            session.query(Property)
            .filter(Property.key == DATABASE_VERSION_KEY)
            .first()
        )
        return int(versionProperty.value) if versionProperty is not None else 0
    except Exception as ex:
        raise click.ClickException(f"getDatabaseVersion(): {repr(ex)}") from ex
    finally:
        # session is still None if creating it failed. Closing it
        # unconditionally would raise and hide the real error.
        if session is not None:
            session.close()
def setDatabaseVersion(databaseContext, databaseVersion: int):
    """Store ``databaseVersion`` as the schema version in the ``Property`` table.

    Updates the existing version row, or inserts one, then commits.

    Raises:
        click.ClickException: wrapping any error raised while creating the
            session, converting ``databaseVersion`` to ``int``, querying, or
            committing. The original error is chained as ``__cause__``.
    """
    session = None
    try:
        session = databaseContext['session']()
        dbVersion = int(databaseVersion)
        versionProperty = (
            session.query(Property)
            .filter(Property.key == DATABASE_VERSION_KEY)
            .first()
        )
        if versionProperty is not None:
            versionProperty.value = str(dbVersion)
        else:
            session.add(Property(key=DATABASE_VERSION_KEY, value=str(dbVersion)))
        session.commit()
    except Exception as ex:
        raise click.ClickException(f"setDatabaseVersion(): {repr(ex)}") from ex
    finally:
        # session is still None if creating it failed. Closing it
        # unconditionally would raise and hide the real error.
        if session is not None:
            session.close()