From a9f38d22f149539a59b0694532503531520395a0 Mon Sep 17 00:00:00 2001 From: Rustam Zhumagambetov Date: Tue, 14 May 2024 12:17:04 +0200 Subject: [PATCH 1/2] disable hpc checkpoint if ckpt_path == 'last' --- .../pytorch/trainer/connectors/signal_connector.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py index 728d8b6b6ee43..be896c4b1f6a0 100644 --- a/src/lightning/pytorch/trainer/connectors/signal_connector.py +++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py @@ -72,8 +72,10 @@ def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None: for logger in self.trainer.loggers: logger.finalize("finished") - hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.default_root_dir) - self.trainer.save_checkpoint(hpc_save_path) + if self.trainer._checkpoint_connector._ckpt_path != "last": + log.info(f"Detecting SLURM environment and ckpt_path=='last'. Disabling hpc ckpt") + hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.default_root_dir) + self.trainer.save_checkpoint(hpc_save_path) if self.trainer.is_global_zero: # find job id From 42a255853782424123860328d034294a0f661b62 Mon Sep 17 00:00:00 2001 From: Rustam Zhumagambetov Date: Tue, 14 May 2024 12:17:15 +0200 Subject: [PATCH 2/2] test hpc and ckpt_path='last' --- .../trainer/connectors/test_signal_connector.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py index 8825db3727e86..aa7aa280eefa8 100644 --- a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py @@ -70,6 +70,20 @@ def training_step(self, batch, batch_idx): signal.signal(signal.SIGTERM, signal.SIG_DFL) +@RunIf(skip_windows=True) +@mock.patch("lightning.fabric.plugins.environments.slurm.SLURMEnvironment.detect", return_value=True) +@mock.patch("lightning.pytorch.trainer.Trainer.save_checkpoint", mock.MagicMock()) +def test_hpc_ckpt_last_not_created(_, save_checkpoint, tmp_path): + """Test that when ckpt_path='last' hpc_ckpt is not created""" + model = BoringModel() + trainer = Trainer(default_root_dir=tmp_path, max_steps=1, logger=False) + trainer.fit(model, ckpt_path="last") + + connector = _SignalConnector(trainer) + connector._slurm_sigusr_handler_fn(None, None) + save_checkpoint.assert_not_called() + + @RunIf(skip_windows=True) @pytest.mark.parametrize("auto_requeue", [True, False]) @pytest.mark.parametrize("requeue_signal", [signal.SIGUSR1, signal.SIGUSR2, signal.SIGHUP] if not _IS_WINDOWS else [])