Skip to content

Commit b5b98d2

Browse files
committed
Conditionally importing pomegranate
1 parent fa2580b commit b5b98d2

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

sdmetrics/single_table/bayesian_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import numpy as np
66
import pandas as pd
7-
import torch
87

98
from sdmetrics.goal import Goal
109
from sdmetrics.single_table.base import SingleTableMetric
@@ -18,6 +17,7 @@ class BNLikelihoodBase(SingleTableMetric):
1817
@classmethod
1918
def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None):
2019
try:
20+
import torch
2121
from pomegranate.bayesian_network import BayesianNetwork
2222
except ImportError:
2323
raise ImportError(

tests/unit/single_table/test_bayesian_network.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def metadata():
4444

4545

4646
class TestBNLikelihood:
47-
@patch.dict('sys.modules', {'pomegranate.bayesian_network': None})
47+
@patch.dict('sys.modules', {'pomegranate.bayesian_network': None, 'torch': None})
4848
def test_compute_error(self):
4949
"""Test that an `ImportError` is raised."""
5050
# Setup
@@ -71,7 +71,7 @@ def test_compute(self, real_data, synthetic_data, metadata):
7171

7272

7373
class TestBNLogLikelihood:
74-
@patch.dict('sys.modules', {'pomegranate.bayesian_network': None})
74+
@patch.dict('sys.modules', {'pomegranate.bayesian_network': None, 'torch': None})
7575
def test_compute_error(self):
7676
"""Test that an `ImportError` is raised."""
7777
# Setup

0 commit comments

Comments
 (0)