Skip to content

Commit e7ddae5

Browse files
committed
Update tests for model parameters to use temporary files and remove unneeded import checks
1 parent bbbf4e7 commit e7ddae5

File tree

1 file changed

+11
-22
lines changed

1 file changed

+11
-22
lines changed

tests/unit/test_model_parameters.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import pytest
22
import warnings
33
import os
4+
import pandas as pd
5+
import tempfile
6+
from pathlib import Path
7+
48
from sasctl.pzmm import ModelParameters as mp
59

610

@@ -16,17 +20,7 @@ def bad_model():
1620
@pytest.fixture
1721
def train_data():
1822
"""Returns the Iris data set as (X, y)"""
19-
20-
try:
21-
import pandas as pd
22-
except ImportError:
23-
pytest.skip('Package `pandas` not found.')
24-
25-
try:
26-
from sklearn import datasets
27-
except ImportError:
28-
pytest.skip('Package `sklearn` not found.')
29-
23+
from sklearn import datasets
3024
raw = datasets.load_iris()
3125
iris = pd.DataFrame(raw.data, columns=raw.feature_names)
3226
iris = iris.join(pd.DataFrame(raw.target))
@@ -39,12 +33,7 @@ def train_data():
3933
@pytest.fixture
4034
def sklearn_model(train_data):
4135
"""Returns a simple Scikit-Learn model"""
42-
43-
try:
44-
from sklearn.linear_model import LogisticRegression
45-
except ImportError:
46-
pytest.skip('Package `sklearn` not found.')
47-
36+
from sklearn.linear_model import LogisticRegression
4837
X, y = train_data
4938
with warnings.catch_warnings():
5039
warnings.simplefilter('ignore')
@@ -57,13 +46,13 @@ def sklearn_model(train_data):
5746
class TestSKLearnModel:
5847
PROJECT_NAME = 'PZMM SKLearn Test Project'
5948
MODEL_NAME = 'SKLearnModel'
60-
PATH = '.'
6149

6250
def test_generate_hyperparameters(self, sklearn_model):
63-
mp.generate_hyperparameters(sklearn_model, self.MODEL_NAME, self.PATH)
64-
assert os.path.exists('./PythonModelHyperparameters.json')
65-
os.remove('./PythonModelHyperparameters.json')
51+
tmp_dir = tempfile.TemporaryDirectory()
52+
mp.generate_hyperparameters(sklearn_model, self.MODEL_NAME, Path(tmp_dir.name))
53+
assert Path(Path(tmp_dir.name) / f'./{self.MODEL_NAME}Hyperparameters.json').exists()
6654

6755
def test_bad_model_hyperparameters(self, bad_model):
56+
tmp_dir = tempfile.TemporaryDirectory()
6857
with pytest.raises(ValueError):
69-
mp.generate_hyperparameters(bad_model, self.MODEL_NAME, self.PATH)
58+
mp.generate_hyperparameters(bad_model, self.MODEL_NAME, Path(tmp_dir.name))

0 commit comments

Comments
 (0)