"""Tests for ``ffx.database.databaseContext`` schema bootstrap and version bookkeeping."""

from __future__ import annotations

from pathlib import Path
import sys
import tempfile
import unittest
from unittest.mock import patch

# Make the repository's ``src`` tree importable so the tests can run without
# the package being installed.
SRC_ROOT = Path(__file__).resolve().parents[2] / "src"
if str(SRC_ROOT) not in sys.path:
    sys.path.insert(0, str(SRC_ROOT))

from ffx.constants import DATABASE_VERSION  # noqa: E402
from ffx.database import DATABASE_VERSION_KEY, databaseContext, getDatabaseVersion  # noqa: E402
from ffx.model.property import Property  # noqa: E402
from ffx.model.show import Base  # noqa: E402


class DatabaseContextTests(unittest.TestCase):
    """Check how ``databaseContext`` creates, reuses and repairs a database file."""

    def setUp(self):
        # Each test gets its own throwaway directory. Registering the cleanup
        # immediately guarantees removal even if anything after this point
        # (in setUp or the test) fails.
        self.tempdir = tempfile.TemporaryDirectory()
        self.addCleanup(self.tempdir.cleanup)
        self.database_path = Path(self.tempdir.name) / "ffx-test.db"

    def test_database_context_bootstraps_new_database_with_current_version(self):
        """A brand-new file gets its schema created once and the current version stored."""
        # ``wraps`` keeps the real schema creation running while recording the call.
        with patch(
            "ffx.database.Base.metadata.create_all",
            wraps=Base.metadata.create_all,
        ) as mocked_create_all:
            context = databaseContext(str(self.database_path))
        try:
            self.assertTrue(self.database_path.exists())
            self.assertEqual(DATABASE_VERSION, getDatabaseVersion(context))
        finally:
            context["engine"].dispose()
        mocked_create_all.assert_called_once()

    def test_database_context_skips_create_all_when_schema_is_already_present(self):
        """Reopening an existing, fully initialised database must not re-run ``create_all``."""
        initial_context = databaseContext(str(self.database_path))
        initial_context["engine"].dispose()

        with patch("ffx.database.Base.metadata.create_all") as mocked_create_all:
            context = databaseContext(str(self.database_path))
        try:
            self.assertEqual(DATABASE_VERSION, getDatabaseVersion(context))
        finally:
            context["engine"].dispose()
        mocked_create_all.assert_not_called()

    def test_database_context_restores_missing_version_property_without_schema_bootstrap(self):
        """A missing version row is re-created on reopen without re-running ``create_all``."""
        context = databaseContext(str(self.database_path))
        Session = context["session"]
        try:
            session = Session()
            try:
                version_row = (
                    session.query(Property)
                    .filter(Property.key == DATABASE_VERSION_KEY)
                    .first()
                )
                # Guard the precondition explicitly: otherwise a missing row
                # surfaces as an unrelated ORM error from session.delete(None).
                self.assertIsNotNone(
                    version_row, "a freshly bootstrapped database should store its version"
                )
                session.delete(version_row)
                session.commit()

                # Make sure the row is really gone; otherwise the reopen below
                # would pass without exercising the restore path at all.
                remaining = (
                    session.query(Property)
                    .filter(Property.key == DATABASE_VERSION_KEY)
                    .count()
                )
                self.assertEqual(0, remaining)
            finally:
                session.close()
        finally:
            context["engine"].dispose()

        with patch("ffx.database.Base.metadata.create_all") as mocked_create_all:
            reopened_context = databaseContext(str(self.database_path))
        try:
            self.assertEqual(DATABASE_VERSION, getDatabaseVersion(reopened_context))
        finally:
            reopened_context["engine"].dispose()
        mocked_create_all.assert_not_called()


if __name__ == "__main__":
    unittest.main()