Skip to content

Commit 222d32b

Browse files
committed
Conditionally importing torch at top
1 parent b5b98d2 commit 222d32b

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

sdmetrics/single_table/bayesian_network.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
from sdmetrics.goal import Goal
99
from sdmetrics.single_table.base import SingleTableMetric
1010

11+
try:
12+
import torch
13+
except ModuleNotFoundError:
14+
torch = None
15+
1116
LOGGER = logging.getLogger(__name__)
1217

1318

@@ -17,8 +22,9 @@ class BNLikelihoodBase(SingleTableMetric):
1722
@classmethod
1823
def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None):
1924
try:
20-
import torch
2125
from pomegranate.bayesian_network import BayesianNetwork
26+
if torch is None:
27+
raise ImportError
2228
except ImportError:
2329
raise ImportError(
2430
'Please install pomegranate with `pip install sdmetrics[pomegranate]`.'

tests/unit/single_table/test_bayesian_network.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,22 @@ def metadata():
4444

4545

4646
class TestBNLikelihood:
47-
@patch.dict('sys.modules', {'pomegranate.bayesian_network': None, 'torch': None})
47+
@patch.dict('sys.modules', {'pomegranate.bayesian_network': None})
4848
def test_compute_error(self):
49-
"""Test that an `ImportError` is raised."""
49+
"""Test that an `ImportError` is raised when pomegranate isn't installed."""
50+
# Setup
51+
metric = BNLikelihood()
52+
53+
# Run and Assert
54+
expected_message = re.escape(
55+
'Please install pomegranate with `pip install sdmetrics[pomegranate]`.'
56+
)
57+
with pytest.raises(ImportError, match=expected_message):
58+
metric.compute(Mock(), Mock())
59+
60+
@patch.dict('sys.modules', {'torch': None})
61+
def test_compute_error_torch_is_none(self):
62+
"""Test that an `ImportError` is raised when torch isn't installed."""
5063
# Setup
5164
metric = BNLikelihood()
5265

@@ -71,7 +84,7 @@ def test_compute(self, real_data, synthetic_data, metadata):
7184

7285

7386
class TestBNLogLikelihood:
74-
@patch.dict('sys.modules', {'pomegranate.bayesian_network': None, 'torch': None})
87+
@patch.dict('sys.modules', {'pomegranate.bayesian_network': None})
7588
def test_compute_error(self):
7689
"""Test that an `ImportError` is raised."""
7790
# Setup

0 commit comments

Comments
 (0)