diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py
index 728d8b6b6ee43..be896c4b1f6a0 100644
--- a/src/lightning/pytorch/trainer/connectors/signal_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py
@@ -72,8 +72,12 @@ def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
         for logger in self.trainer.loggers:
             logger.finalize("finished")
 
-        hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.default_root_dir)
-        self.trainer.save_checkpoint(hpc_save_path)
+        if self.trainer._checkpoint_connector._ckpt_path == "last":
+            # skip the hpc checkpoint when the checkpoint connector's ckpt_path is "last"
+            log.info("Detecting SLURM environment and ckpt_path=='last'. Disabling hpc ckpt")
+        else:
+            hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.default_root_dir)
+            self.trainer.save_checkpoint(hpc_save_path)
 
         if self.trainer.is_global_zero:
             # find job id
diff --git a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py
index 8825db3727e86..aa7aa280eefa8 100644
--- a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py
@@ -70,6 +70,24 @@ def training_step(self, batch, batch_idx):
     signal.signal(signal.SIGTERM, signal.SIG_DFL)
 
 
+@RunIf(skip_windows=True)
+@mock.patch("lightning.fabric.plugins.environments.slurm.SLURMEnvironment.detect", return_value=True)
+@mock.patch("lightning.pytorch.trainer.Trainer.save_checkpoint")
+def test_hpc_ckpt_last_not_created(save_checkpoint, _, tmp_path):
+    """Test that the SLURM requeue handler does not save an hpc checkpoint when ``ckpt_path='last'``."""
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmp_path, max_steps=1, logger=False)
+    trainer.fit(model, ckpt_path="last")
+
+    # ``fit()`` may already have called the patched method (e.g. through checkpoint callbacks);
+    # only calls made by the signal handler are under test.
+    save_checkpoint.reset_mock()
+
+    connector = _SignalConnector(trainer)
+    connector._slurm_sigusr_handler_fn(None, None)
+    save_checkpoint.assert_not_called()
+
+
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize("auto_requeue", [True, False])
 @pytest.mark.parametrize("requeue_signal", [signal.SIGUSR1, signal.SIGUSR2, signal.SIGHUP] if not _IS_WINDOWS else [])