Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions pytorch_forecasting/tests/test_all_v2/test_all_estimators_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from pathlib import Path
import shutil

import pytest
from skbase.utils.dependencies import _check_soft_dependencies
import torch

from pytorch_forecasting.tests.test_all_estimators import (
Expand All @@ -18,6 +20,14 @@
ONLY_CHANGED_MODULES = False


def _skip_if_missing_dependencies(obj):
"""Helper to skip test if soft dependencies defined in tags are missing."""
deps = obj.get_class_tag("python_dependencies", None)
if deps is not None:
if not _check_soft_dependencies(deps, severity="none"):
pytest.skip(f"Skipping test due to missing soft dependencies: {deps}")


class TestAllPtForecastersV2(EstimatorPackageConfig, EstimatorFixtureGenerator):
"""Generic tests for all objects in the mini package."""

Expand All @@ -35,6 +45,7 @@ def test_integration(
trainer_kwargs,
tmp_path,
):
_skip_if_missing_dependencies(object_pkg)
pkg, test_data, dm_cfg = _setup_pkg_and_data(
object_pkg, trainer_kwargs, tmp_path
)
Expand All @@ -45,6 +56,7 @@ def test_integration(

def test_checkpointing(self, object_pkg, trainer_kwargs, tmp_path):
"""Test that the package can save a checkpoint and reload from it."""
_skip_if_missing_dependencies(object_pkg)
pkg, test_data, _ = _setup_pkg_and_data(object_pkg, trainer_kwargs, tmp_path)

ckpt_dir = Path(tmp_path) / "checkpoints"
Expand Down Expand Up @@ -73,6 +85,7 @@ def test_checkpointing(self, object_pkg, trainer_kwargs, tmp_path):

def test_predict_modes(self, object_pkg, trainer_kwargs, tmp_path):
"""Test different prediction modes and return_info."""
_skip_if_missing_dependencies(object_pkg)
pkg, test_data, _ = _setup_pkg_and_data(object_pkg, trainer_kwargs, tmp_path)

pkg.fit(test_data["train"], save_ckpt=False)
Expand Down Expand Up @@ -116,6 +129,7 @@ def test_predict_modes(self, object_pkg, trainer_kwargs, tmp_path):
def test_pkg_linkage(self, object_pkg, object_class):
"""Test that the package is linked correctly."""
# check name method
_skip_if_missing_dependencies(object_pkg)
msg = (
f"Package {object_pkg}.name() does not match class "
f"name {object_class.__name__}. "
Expand Down
Loading