226 lines
7.3 KiB
Python
226 lines
7.3 KiB
Python
import os, shutil, click
|
|
|
|
from sqlalchemy import create_engine, inspect, text
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
# Import the full model package so SQLAlchemy registers every mapped class
|
|
# before metadata creation and the first ORM query.
|
|
import ffx.model
|
|
from ffx.model.show import Base
|
|
|
|
from ffx.model.property import Property
|
|
from ffx.model.migration import (
|
|
DatabaseVersionException,
|
|
getMigrationPlan,
|
|
migrateDatabase,
|
|
)
|
|
|
|
from ffx.constants import DATABASE_VERSION
|
|
|
|
|
|
# Key of the Property row under which the database schema version is stored.
DATABASE_VERSION_KEY = 'database_version'
|
|
# Table names registered on Base.metadata at import time; compared against the
# tables actually present to decide whether the database needs bootstrapping.
EXPECTED_TABLE_NAMES = set(Base.metadata.tables.keys())
|
|
|
|
def databaseContext(databasePath: str = ''):
    """Build the context dict for an SQLite database and make it ready for use.

    Args:
        databasePath: Location of the SQLite file. ``None`` selects an
            in-memory database (``:memory:``). An empty string selects
            ``~/.local/var/ffx/ffx.db``; that directory is created when it
            does not exist. Any other value is user-expanded and, unless it
            is ``:memory:``, converted to an absolute path.

    Returns:
        A dict with the keys ``'path'`` (resolved database path), ``'url'``
        (``sqlite:///`` URL), ``'engine'`` (result of ``create_engine``) and
        ``'session'`` (``sessionmaker`` bound to that engine).

    Raises:
        Any exception raised by ``bootstrapDatabaseIfNeeded`` or
        ``ensureDatabaseVersion`` propagates to the caller.
    """
    if databasePath is None:
        databasePath = ':memory:'
    elif not databasePath:
        ffxVarDir = os.path.join(os.path.expanduser("~"), '.local', 'var', 'ffx')
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test when two processes start at the same time.
        os.makedirs(ffxVarDir, exist_ok=True)
        databasePath = os.path.join(ffxVarDir, 'ffx.db')
    else:
        databasePath = os.path.expanduser(databasePath)

    if databasePath != ':memory:':
        databasePath = os.path.abspath(databasePath)

    databaseUrl = f"sqlite:///{databasePath}"
    engine = create_engine(databaseUrl)

    # Named `context` so the local does not shadow this function's own name.
    context = {
        'path': databasePath,
        'url': databaseUrl,
        'engine': engine,
        'session': sessionmaker(bind=engine),
    }

    # NOTE(review): this function does not issue `PRAGMA foreign_keys=ON`
    # itself; confirm whether foreign-key enforcement is configured elsewhere.
    bootstrapDatabaseIfNeeded(context)
    ensureDatabaseVersion(context)

    return context
|
|
|
|
|
|
def databaseNeedsBootstrap(databaseContext) -> bool:
    """Return True when at least one expected table is absent from the database.

    The tables reported by the engine's inspector are compared with
    ``EXPECTED_TABLE_NAMES``; extra tables in the database are ignored.
    """
    presentTables = set(inspect(databaseContext['engine']).get_table_names())
    missingTables = EXPECTED_TABLE_NAMES - presentTables
    return bool(missingTables)
|
|
|
|
|
|
def bootstrapDatabaseIfNeeded(databaseContext):
    """Run ``Base.metadata.create_all`` when an expected table is missing.

    Does nothing when ``databaseNeedsBootstrap`` reports that every expected
    table is already present.
    """
    if databaseNeedsBootstrap(databaseContext):
        Base.metadata.create_all(databaseContext['engine'])
|
|
|
|
|
|
def ensureDatabaseVersion(databaseContext):
    """Bring the stored database version in line with ``DATABASE_VERSION``.

    * No stored version (falsy): stamp the database with ``DATABASE_VERSION``
      and return.
    * Stored version newer than required: raise ``DatabaseVersionException``.
    * Stored version older than required: prompt the user, migrate, then
      re-read the stored version.
    * If the version still differs from ``DATABASE_VERSION``: raise
      ``DatabaseVersionException``.

    Finishes by running ``ensureCurrentSchemaCompatibility``.
    """

    def versionMismatch(foundVersion):
        # Both failure paths report the same message.
        return DatabaseVersionException(
            f"Current database version ({foundVersion}) does not match required ({DATABASE_VERSION})"
        )

    storedVersion = getDatabaseVersion(databaseContext)

    if not storedVersion:
        setDatabaseVersion(databaseContext, DATABASE_VERSION)
        return

    if storedVersion > DATABASE_VERSION:
        raise versionMismatch(storedVersion)

    if storedVersion < DATABASE_VERSION:
        promptForDatabaseMigration(databaseContext, storedVersion, DATABASE_VERSION)
        migrateDatabase(databaseContext, storedVersion, DATABASE_VERSION, setDatabaseVersion)
        # Re-read so the check below sees what the migration actually stored.
        storedVersion = getDatabaseVersion(databaseContext)

    if storedVersion != DATABASE_VERSION:
        raise versionMismatch(storedVersion)

    ensureCurrentSchemaCompatibility(databaseContext)
|
|
|
|
|
|
def ensureCurrentSchemaCompatibility(databaseContext):
    """Add the ``quality`` and ``notes`` columns to ``shows`` when absent.

    The existing columns are read through the engine's inspector. Missing
    columns are added with ``ALTER TABLE`` inside a single ``engine.begin()``
    block. Nothing is executed when both columns already exist.
    """
    dbEngine = databaseContext['engine']
    presentColumns = {
        columnInfo['name']
        for columnInfo in inspect(dbEngine).get_columns('shows')
    }

    # (column name, statement adding it) in the order they must be executed.
    columnAdditions = (
        ('quality', "ALTER TABLE shows ADD COLUMN quality INTEGER DEFAULT 0"),
        ('notes', "ALTER TABLE shows ADD COLUMN notes TEXT DEFAULT ''"),
    )
    pendingStatements = [
        statement
        for columnName, statement in columnAdditions
        if columnName not in presentColumns
    ]

    if pendingStatements:
        with dbEngine.begin() as connection:
            for statement in pendingStatements:
                connection.execute(text(statement))
|
|
|
|
|
|
def promptForDatabaseMigration(databaseContext, currentDatabaseVersion: int, targetDatabaseVersion: int):
    """Show the migration plan, ask for confirmation and back up the database.

    Every step of the plan returned by ``getMigrationPlan`` is printed with a
    ``present``/``missing`` marker for its module.

    Raises:
        DatabaseVersionException: a step's module is missing; the message
            names the first such step.
        click.ClickException: the user declined the confirmation prompt.
    """
    plan = getMigrationPlan(currentDatabaseVersion, targetDatabaseVersion)

    for headerLine in (
        "Database migration required.",
        f"Current version: {currentDatabaseVersion}",
        f"Target version: {targetDatabaseVersion}",
        "Steps required:",
    ):
        click.echo(headerLine)

    unavailableSteps = []
    for step in plan:
        if step.modulePresent:
            status = "present"
        else:
            status = "missing"
            unavailableSteps.append(step)
        click.echo(
            f" {step.versionFrom} -> {step.versionTo}: {step.moduleName} [{status}]"
        )

    # The whole plan is printed before failing so the user sees every step.
    if unavailableSteps:
        blockingStep = unavailableSteps[0]
        raise DatabaseVersionException(
            f"No migration path from database version {blockingStep.versionFrom} to {blockingStep.versionTo}"
        )

    userAgreed = click.confirm(
        "Create a backup and continue with database migration?",
        default=True,
    )
    if not userAgreed:
        raise click.ClickException("Database migration aborted by user.")

    backupPath = backupDatabaseBeforeMigration(
        databaseContext,
        currentDatabaseVersion,
        targetDatabaseVersion,
    )
    click.echo(f"Database backup created: {backupPath}")
|
|
|
|
|
|
def backupDatabaseBeforeMigration(databaseContext, currentDatabaseVersion: int, targetDatabaseVersion: int) -> str:
    """Copy the database file to a fresh ``.bak`` file and return its path.

    The backup is named ``<db>.v<current>-to-v<target>.bak``. If that name is
    taken, a numeric index (``.1.bak``, ``.2.bak``, ...) is inserted until an
    unused name is found. The engine is disposed before the file is copied
    with ``shutil.copy2``.

    Raises:
        click.ClickException: the context has no path, the database is
            in-memory, or the database file does not exist.
    """
    sourcePath = databaseContext.get('path', '')

    if sourcePath == ':memory:' or not sourcePath:
        raise click.ClickException("Database migration backup requires a file-backed SQLite database.")

    if not os.path.isfile(sourcePath):
        raise click.ClickException(f"Database file not found for backup: {sourcePath}")

    backupStem = f"{sourcePath}.v{currentDatabaseVersion}-to-v{targetDatabaseVersion}"
    candidatePath = f"{backupStem}.bak"
    attempt = 0
    while os.path.exists(candidatePath):
        attempt += 1
        candidatePath = f"{backupStem}.{attempt}.bak"

    # Dispose the engine before the file is copied.
    databaseContext['engine'].dispose()
    shutil.copy2(sourcePath, candidatePath)

    return candidatePath
|
|
|
|
|
|
def getDatabaseVersion(databaseContext):
    """Return the database version stored in the Property table.

    Args:
        databaseContext: Context dict; ``databaseContext['session']`` is
            called to obtain a session.

    Returns:
        The stored version as ``int``, or ``0`` when no Property row with
        the key ``DATABASE_VERSION_KEY`` exists.

    Raises:
        click.ClickException: any failure while opening the session,
            querying, or converting the stored value to ``int``. The
            original exception is attached as the cause.
    """
    # Bound before the try block so `finally` never touches an unbound name
    # when obtaining the session itself fails.
    session = None
    try:
        session = databaseContext['session']()
        versionProperty = (
            session.query(Property)
            .filter(Property.key == DATABASE_VERSION_KEY)
            .first()
        )
        return int(versionProperty.value) if versionProperty is not None else 0
    except Exception as ex:
        raise click.ClickException(f"getDatabaseVersion(): {repr(ex)}") from ex
    finally:
        if session is not None:
            session.close()
|
|
|
|
|
|
def setDatabaseVersion(databaseContext, databaseVersion: int):
    """Store ``databaseVersion`` in the Property table and commit.

    The Property row keyed by ``DATABASE_VERSION_KEY`` is updated when it
    exists and created otherwise. The value is stored as the string form of
    ``int(databaseVersion)``.

    Args:
        databaseContext: Context dict; ``databaseContext['session']`` is
            called to obtain a session.
        databaseVersion: Version to store; must be convertible with ``int``.

    Raises:
        click.ClickException: any failure while opening the session,
            converting the version, querying or committing. The original
            exception is attached as the cause.
    """
    # Bound before the try block so `finally` never touches an unbound name
    # when obtaining the session itself fails.
    session = None
    try:
        session = databaseContext['session']()
        versionValue = str(int(databaseVersion))

        versionProperty = (
            session.query(Property)
            .filter(Property.key == DATABASE_VERSION_KEY)
            .first()
        )
        if versionProperty:
            versionProperty.value = versionValue
        else:
            session.add(Property(key=DATABASE_VERSION_KEY, value=versionValue))
        session.commit()
    except Exception as ex:
        raise click.ClickException(f"setDatabaseVersion(): {repr(ex)}") from ex
    finally:
        if session is not None:
            session.close()
|