from __future__ import annotations from pathlib import Path import sqlite3 import sys import tempfile import unittest from unittest.mock import patch import click SRC_ROOT = Path(__file__).resolve().parents[2] / "src" if str(SRC_ROOT) not in sys.path: sys.path.insert(0, str(SRC_ROOT)) from ffx.constants import DATABASE_VERSION # noqa: E402 from ffx.database import DATABASE_VERSION_KEY, databaseContext, getDatabaseVersion # noqa: E402 from ffx.model.shifted_season import ShiftedSeason # noqa: E402 from ffx.model.property import Property # noqa: E402 from ffx.model.show import Show # noqa: E402 from ffx.model.show import Base # noqa: E402 from ffx.show_controller import ShowController # noqa: E402 from ffx.show_descriptor import ShowDescriptor # noqa: E402 from ffx.shifted_season_controller import ShiftedSeasonController # noqa: E402 class StaticConfig: def getData(self): return {} class DatabaseContextTests(unittest.TestCase): def setUp(self): self.tempdir = tempfile.TemporaryDirectory() self.database_path = Path(self.tempdir.name) / "ffx-test.db" def tearDown(self): self.tempdir.cleanup() def create_demo_show_with_shift(self): database_context = databaseContext(str(self.database_path)) context = { "database": database_context, "config": StaticConfig(), "logger": object(), } try: ShowController(context).updateShow( ShowDescriptor(id=1, name="Demo", year=2000) ) shifted_season_id = ShiftedSeasonController(context).addShiftedSeason( showId=1, shiftedSeasonObj={ "original_season": 1, "first_episode": 1, "last_episode": 10, "season_offset": 1, "episode_offset": -10, }, ) finally: database_context["engine"].dispose() return shifted_season_id def rewrite_shows_table_without_show_fields(self, cursor): cursor.execute("ALTER TABLE shows RENAME TO shows_current") cursor.execute( """ CREATE TABLE shows ( id INTEGER PRIMARY KEY, name VARCHAR, year INTEGER, index_season_digits INTEGER, index_episode_digits INTEGER, indicator_season_digits INTEGER, indicator_episode_digits INTEGER ) """ ) cursor.execute( """ INSERT INTO shows ( id, name, year, index_season_digits, index_episode_digits, indicator_season_digits, indicator_episode_digits ) SELECT id, name, year, index_season_digits, index_episode_digits, indicator_season_digits, indicator_episode_digits FROM shows_current """ ) cursor.execute("DROP TABLE shows_current") def rewrite_shifted_seasons_table_without_pattern_owner(self, cursor): cursor.execute("DROP INDEX IF EXISTS ix_shifted_seasons_show_id") cursor.execute("DROP INDEX IF EXISTS ix_shifted_seasons_pattern_id") cursor.execute( "ALTER TABLE shifted_seasons RENAME TO shifted_seasons_current" ) cursor.execute( """ CREATE TABLE shifted_seasons ( id INTEGER PRIMARY KEY, show_id INTEGER, original_season INTEGER, first_episode INTEGER DEFAULT -1, last_episode INTEGER DEFAULT -1, season_offset INTEGER DEFAULT 0, episode_offset INTEGER DEFAULT 0, FOREIGN KEY(show_id) REFERENCES shows(id) ON DELETE CASCADE ) """ ) cursor.execute( """ INSERT INTO shifted_seasons ( id, show_id, original_season, first_episode, last_episode, season_offset, episode_offset ) SELECT id, show_id, original_season, first_episode, last_episode, season_offset, episode_offset FROM shifted_seasons_current """ ) cursor.execute("DROP TABLE shifted_seasons_current") def test_database_context_bootstraps_new_database_with_current_version(self): with patch("ffx.database.Base.metadata.create_all", wraps=Base.metadata.create_all) as mocked_create_all: context = databaseContext(str(self.database_path)) try: self.assertTrue(self.database_path.exists()) self.assertEqual(DATABASE_VERSION, getDatabaseVersion(context)) finally: context["engine"].dispose() mocked_create_all.assert_called_once() def test_database_context_skips_create_all_when_schema_is_already_present(self): initial_context = databaseContext(str(self.database_path)) initial_context["engine"].dispose() with patch("ffx.database.Base.metadata.create_all") as mocked_create_all: context = databaseContext(str(self.database_path)) try: self.assertEqual(DATABASE_VERSION, getDatabaseVersion(context)) finally: context["engine"].dispose() mocked_create_all.assert_not_called() def test_database_context_restores_missing_version_property_without_schema_bootstrap(self): context = databaseContext(str(self.database_path)) Session = context["session"] try: session = Session() try: version_row = ( session.query(Property) .filter(Property.key == DATABASE_VERSION_KEY) .first() ) session.delete(version_row) session.commit() finally: session.close() finally: context["engine"].dispose() with patch("ffx.database.Base.metadata.create_all") as mocked_create_all: reopened_context = databaseContext(str(self.database_path)) try: self.assertEqual(DATABASE_VERSION, getDatabaseVersion(reopened_context)) finally: reopened_context["engine"].dispose() mocked_create_all.assert_not_called() def test_database_context_migrates_v2_shifted_seasons_schema_to_v3(self): shifted_season_id = self.create_demo_show_with_shift() connection = sqlite3.connect(self.database_path) try: cursor = connection.cursor() cursor.execute("PRAGMA foreign_keys=OFF") self.rewrite_shifted_seasons_table_without_pattern_owner(cursor) self.rewrite_shows_table_without_show_fields(cursor) cursor.execute( "UPDATE properties SET value = '2' WHERE key = ?", (DATABASE_VERSION_KEY,), ) connection.commit() finally: connection.close() with patch("ffx.database.click.confirm", return_value=True) as mocked_confirm, patch( "ffx.database.click.echo" ) as mocked_echo: reopened_context = databaseContext(str(self.database_path)) try: self.assertEqual(DATABASE_VERSION, getDatabaseVersion(reopened_context)) mocked_confirm.assert_called_once() backup_path = Path(f"{self.database_path}.v2-to-v{DATABASE_VERSION}.bak") self.assertTrue(backup_path.exists()) Session = reopened_context["session"] session = Session() try: migrated_shifted_season = ( session.query(ShiftedSeason) .filter(ShiftedSeason.id == shifted_season_id) .first() ) self.assertIsNotNone(migrated_shifted_season) self.assertEqual(1, migrated_shifted_season.getShowId()) self.assertIsNone(migrated_shifted_season.getPatternId()) self.assertEqual(1, migrated_shifted_season.getOriginalSeason()) self.assertEqual(1, migrated_shifted_season.getFirstEpisode()) self.assertEqual(10, migrated_shifted_season.getLastEpisode()) migrated_show = session.query(Show).filter(Show.id == 1).first() self.assertIsNotNone(migrated_show) self.assertEqual(0, int(migrated_show.quality or 0)) self.assertEqual('', str(migrated_show.notes or '')) finally: session.close() finally: reopened_context["engine"].dispose() echoedLines = [call.args[0] for call in mocked_echo.call_args_list] self.assertIn("Database migration required.", echoedLines) self.assertIn("Current version: 2", echoedLines) self.assertIn(f"Target version: {DATABASE_VERSION}", echoedLines) self.assertIn( " 2 -> 3: ffx.model.migration.step_2_3 [present]", echoedLines, ) def test_database_context_aborts_migration_when_confirmation_is_declined(self): context = databaseContext(str(self.database_path)) try: Session = context["session"] session = Session() try: version_row = ( session.query(Property) .filter(Property.key == DATABASE_VERSION_KEY) .first() ) version_row.value = "2" session.commit() finally: session.close() finally: context["engine"].dispose() with patch("ffx.database.click.confirm", return_value=False), patch( "ffx.database.click.echo" ): with self.assertRaises(click.ClickException) as raisedContext: databaseContext(str(self.database_path)) self.assertEqual("Database migration aborted by user.", str(raisedContext.exception)) self.assertFalse(Path(f"{self.database_path}.v2-to-v{DATABASE_VERSION}.bak").exists()) def test_database_context_repairs_current_show_schema_without_version_bump(self): self.create_demo_show_with_shift() connection = sqlite3.connect(self.database_path) try: cursor = connection.cursor() cursor.execute("PRAGMA foreign_keys=OFF") self.rewrite_shows_table_without_show_fields(cursor) connection.commit() finally: connection.close() with patch("ffx.database.click.confirm") as mocked_confirm, patch( "ffx.database.click.echo" ) as mocked_echo: reopened_context = databaseContext(str(self.database_path)) try: self.assertEqual(DATABASE_VERSION, getDatabaseVersion(reopened_context)) Session = reopened_context["session"] session = Session() try: repaired_show = session.query(Show).filter(Show.id == 1).first() self.assertIsNotNone(repaired_show) self.assertEqual(0, int(repaired_show.quality or 0)) self.assertEqual('', str(repaired_show.notes or '')) finally: session.close() finally: reopened_context["engine"].dispose() mocked_confirm.assert_not_called() mocked_echo.assert_not_called() if __name__ == "__main__": unittest.main()