diff --git a/pytorch_forecasting/models/base/_base_model.py b/pytorch_forecasting/models/base/_base_model.py index c4f30509f..8635a7c49 100644 --- a/pytorch_forecasting/models/base/_base_model.py +++ b/pytorch_forecasting/models/base/_base_model.py @@ -1755,7 +1755,15 @@ def predict( logging.getLogger("lightning").setLevel(log_level_lighting) logging.getLogger("pytorch_lightning").setLevel(log_level_pytorch_lightning) - return predict_callback.result + trainer_predict_callback = next( + ( + callback + for callback in reversed(trainer.callbacks) + if isinstance(callback, PredictCallback) + ), + predict_callback, + ) + return trainer_predict_callback.result def predict_dependency( self, diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index 8473dc2e7..3029e30d3 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -358,6 +358,19 @@ def test_prediction_with_dataloder(model, dataloaders_with_covariates, kwargs): model.predict(val_dataloader, fast_dev_run=True, **kwargs) +def test_prediction_with_dataloader_twice(model, dataloaders_with_covariates): + val_dataloader = dataloaders_with_covariates["val"] + first_prediction = model.predict( + val_dataloader, fast_dev_run=True, return_index=True + ) + second_prediction = model.predict( + val_dataloader, fast_dev_run=True, return_index=True + ) + + assert len(first_prediction.index) > 0 + assert len(second_prediction.index) > 0 + + def test_prediction_with_dataloder_raw(data_with_covariates, tmp_path): # tests correct concatenation of raw output test_data = data_with_covariates.copy()