102 lines
3.0 KiB
Python
102 lines
3.0 KiB
Python
import os, click
|
|
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
from ffx.model.show import Base
|
|
|
|
from ffx.model.property import Property
|
|
|
|
from ffx.constants import DATABASE_VERSION
|
|
|
|
|
|
# Value of ``Property.key`` for the row that holds the database version.
DATABASE_VERSION_KEY: str = 'database_version'
|
|
|
|
class DatabaseVersionException(Exception):
    """Signals that the database's stored version is not the one this code expects.

    Raised by ``ensureDatabaseVersion`` on a version mismatch.
    """

    def __init__(self, errorMessage):
        # The message is forwarded unchanged, so str(exc) and exc.args expose it.
        super(DatabaseVersionException, self).__init__(errorMessage)
|
|
|
|
def databaseContext(databasePath: str = ''):
    """Build the SQLAlchemy URL, engine and session factory for the ffx database.

    Args:
        databasePath: Location of the SQLite database file.
            ``None`` selects an in-memory database (``sqlite:///:memory:``).
            An empty string (the default) selects ``~/.local/var/ffx/ffx.db``
            and creates ``~/.local/var/ffx`` if it does not exist yet.
            Any other value is used verbatim as the SQLite file path.

    Returns:
        dict: ``'url'`` -> the SQLAlchemy URL string,
        ``'engine'`` -> the engine created from that URL,
        ``'session'`` -> a ``sessionmaker`` bound to the engine.

    Raises:
        DatabaseVersionException: if the stored database version differs from
            ``DATABASE_VERSION`` (raised via ``ensureDatabaseVersion``).
        click.ClickException: if reading or writing the version property fails.
    """
    context = {}

    if databasePath is None:
        # An in-memory database yields the URL sqlite:///:memory:
        databasePath = ':memory:'
    elif not databasePath:
        ffxVarDir = os.path.join(os.path.expanduser("~"), '.local', 'var', 'ffx')
        # exist_ok avoids the race between checking for the directory and
        # creating it (e.g. two processes starting at the same time).
        os.makedirs(ffxVarDir, exist_ok=True)
        databasePath = os.path.join(ffxVarDir, 'ffx.db')

    context['url'] = f"sqlite:///{databasePath}"
    context['engine'] = create_engine(context['url'])
    context['session'] = sessionmaker(bind=context['engine'])

    # Create any tables registered on the model metadata that do not exist yet.
    Base.metadata.create_all(context['engine'])

    # NOTE(review): SQLite foreign-key enforcement (PRAGMA foreign_keys=ON) is
    # not enabled in this function. A stale, commented-out retry loop for it
    # was removed here. If enforcement is needed, a SQLAlchemy "connect" event
    # listener on the engine is the usual place — confirm against the models.

    # Stamp a fresh database with the current version, or refuse a mismatched one.
    ensureDatabaseVersion(context)

    return context
|
|
|
|
def ensureDatabaseVersion(databaseContext):
    """Check the stored database version, or record it if none is stored.

    A stored version of 0 (what ``getDatabaseVersion`` reports when no version
    property exists) causes ``DATABASE_VERSION`` to be written. Any other
    stored value must equal ``DATABASE_VERSION``.

    Raises:
        DatabaseVersionException: if the stored version differs from
            ``DATABASE_VERSION``.
    """
    storedVersion = getDatabaseVersion(databaseContext)

    # Nothing stored yet: stamp the database with the current version.
    if not storedVersion:
        setDatabaseVersion(databaseContext, DATABASE_VERSION)
        return

    if storedVersion != DATABASE_VERSION:
        raise DatabaseVersionException(f"Current database version ({storedVersion}) does not match required ({DATABASE_VERSION})")
|
|
|
|
|
|
def getDatabaseVersion(databaseContext):
    """Return the database version stored in the ``Property`` table.

    Args:
        databaseContext: Mapping whose ``'session'`` entry is a session
            factory, such as the dict returned by ``databaseContext()``.

    Returns:
        int: The stored version, or 0 when no row with key
        ``DATABASE_VERSION_KEY`` exists.

    Raises:
        click.ClickException: Wraps any error raised while opening the
            session, querying, or converting the stored value to ``int``.
    """
    # Pre-bind so the finally clause is safe even when opening the session fails;
    # otherwise an UnboundLocalError would mask the real error.
    s = None
    try:
        Session = databaseContext['session']
        s = Session()

        # A single query; first() returns None when no matching row exists.
        versionProperty = (s.query(Property)
                           .filter(Property.key == DATABASE_VERSION_KEY)
                           .first())

        return int(versionProperty.value) if versionProperty is not None else 0

    except Exception as ex:
        raise click.ClickException(f"getDatabaseVersion(): {repr(ex)}") from ex

    finally:
        if s is not None:
            s.close()
|
|
|
|
|
|
def setDatabaseVersion(databaseContext, databaseVersion: int):
    """Store *databaseVersion* in the ``Property`` table and commit.

    If a row keyed by ``DATABASE_VERSION_KEY`` exists, its value is updated;
    otherwise a new row is inserted.

    Args:
        databaseContext: Mapping whose ``'session'`` entry is a session
            factory (as in the dict returned by ``databaseContext()``).
        databaseVersion: Version to store. It is passed through ``int()``
            and saved as its string form.

    Raises:
        click.ClickException: wrapping any error raised while opening the
            session, converting the version, querying, or committing.
    """
    # Pre-bind so the finally clause is safe even when opening the session fails;
    # otherwise an UnboundLocalError would mask the real error.
    s = None
    try:
        Session = databaseContext['session']
        s = Session()

        # Convert before hitting the database so bad input fails early.
        dbVersion = int(databaseVersion)

        versionProperty = (s.query(Property)
                           .filter(Property.key == DATABASE_VERSION_KEY)
                           .first())

        if versionProperty is not None:
            versionProperty.value = str(dbVersion)
        else:
            s.add(Property(key=DATABASE_VERSION_KEY, value=str(dbVersion)))

        s.commit()

    except Exception as ex:
        raise click.ClickException(f"setDatabaseVersion(): {repr(ex)}") from ex

    finally:
        if s is not None:
            s.close()