Skip to content

Commit f4397be

Browse files
authored
[MNT] Isolate cpflow package, towards fixing readthedocs build (#1775)
The readthedocs build is failing due to attempted imports of the `cpflow` package. This PR fixes this by isolating the `cpflow` soft dependency imports. For this, the `_safe_import` utility from `sktime` is copied over - mid-term, we may want to move this to `scikit-base`.
1 parent 0f08ecc commit f4397be

File tree

4 files changed

+106
-1
lines changed

4 files changed

+106
-1
lines changed

pytorch_forecasting/metrics/_mqf2_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from typing import List, Optional, Tuple
44

5-
from cpflows.flows import DeepConvexFlow, SequentialFlow
65
import torch
76
from torch.distributions import (
87
AffineTransform,
@@ -12,6 +11,11 @@
1211
)
1312
import torch.nn.functional as F
1413

14+
from pytorch_forecasting.utils._dependencies import _safe_import
15+
16+
DeepConvexFlow = _safe_import("cpflows.flows.DeepConvexFlow")
17+
SequentialFlow = _safe_import("cpflows.flows.SequentialFlow")
18+
1519

1620
class DeepConvexNet(DeepConvexFlow):
1721
r"""
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Utilities for managing dependencies."""
2+
3+
from pytorch_forecasting.utils._dependencies._dependencies import (
4+
_check_matplotlib,
5+
_get_installed_packages,
6+
)
7+
from pytorch_forecasting.utils._dependencies._safe_import import _safe_import
8+
9+
__all__ = [
10+
"_get_installed_packages",
11+
"_check_matplotlib",
12+
"_safe_import",
13+
]
File renamed without changes.
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Import a module/class, return a Mock object if import fails.
2+
3+
Copied from sktime/skbase.
4+
"""
5+
6+
import importlib
7+
from unittest.mock import MagicMock
8+
9+
from pytorch_forecasting.utils._dependencies import _get_installed_packages
10+
11+
12+
def _safe_import(import_path, pkg_name=None):
13+
"""Import a module/class, return a Mock object if import fails.
14+
15+
The function supports importing both top-level modules and nested attributes:
16+
17+
- Top-level module: "torch" -> imports torch
18+
- Nested module: "torch.nn" -> imports torch.nn
19+
- Class/function: "torch.nn.Linear" -> imports Linear class from torch.nn
20+
21+
Parameters
22+
----------
23+
import_path : str
24+
The path to the module/class to import. Can be:
25+
26+
- Single module: "torch"
27+
- Nested module: "torch.nn"
28+
- Class/attribute: "torch.nn.ReLU"
29+
30+
Note: The dots in the path determine the import behavior:
31+
32+
- No dots: Imports as a single module
33+
- One dot: Imports as a submodule
34+
- Multiple dots: Last part is treated as an attribute to import
35+
36+
pkg_name : str, default=None
37+
The name of the package to check for installation. This is useful when
38+
the import name differs from the package name, for example:
39+
40+
- import: "sklearn" -> pkg_name="scikit-learn"
41+
- import: "cv2" -> pkg_name="opencv-python"
42+
43+
If None, uses the first part of import_path before the dot.
44+
45+
Returns
46+
-------
47+
object
48+
One of the following:
49+
50+
- The imported module if import_path has no dots
51+
- The imported submodule if import_path has one dot
52+
- The imported class/function if import_path has multiple dots
53+
- A MagicMock object that returns an installation message if the
54+
package is not found
55+
56+
Examples
57+
--------
58+
>>> # Import a top-level module
59+
>>> torch = safe_import("torch")
60+
61+
>>> # Import a submodule
62+
>>> nn = safe_import("torch.nn")
63+
64+
>>> # Import a specific class
65+
>>> Linear = safe_import("torch.nn.Linear")
66+
67+
>>> # Import with different package name
68+
>>> cv2 = safe_import("cv2", pkg_name="opencv-python")
69+
"""
70+
if pkg_name is None:
71+
path_list = import_path.split(".")
72+
pkg_name = path_list[0]
73+
74+
if pkg_name in _get_installed_packages():
75+
try:
76+
if len(path_list) == 1:
77+
return importlib.import_module(pkg_name)
78+
module_name, attr_name = import_path.rsplit(".", 1)
79+
module = importlib.import_module(module_name)
80+
return getattr(module, attr_name)
81+
except (ImportError, AttributeError):
82+
return importlib.import_module(import_path)
83+
else:
84+
mock_obj = MagicMock()
85+
mock_obj.__call__ = MagicMock(
86+
return_value=f"Please install {pkg_name} to use this functionality."
87+
)
88+
return mock_obj

0 commit comments

Comments
 (0)