1
1
import pytest
2
2
import warnings
3
3
import os
4
+ import pandas as pd
5
+ import tempfile
6
+ from pathlib import Path
7
+
4
8
from sasctl .pzmm import ModelParameters as mp
5
9
6
10
@@ -16,17 +20,7 @@ def bad_model():
16
20
@pytest .fixture
17
21
def train_data ():
18
22
"""Returns the Iris data set as (X, y)"""
19
-
20
- try :
21
- import pandas as pd
22
- except ImportError :
23
- pytest .skip ('Package `pandas` not found.' )
24
-
25
- try :
26
- from sklearn import datasets
27
- except ImportError :
28
- pytest .skip ('Package `sklearn` not found.' )
29
-
23
+ from sklearn import datasets
30
24
raw = datasets .load_iris ()
31
25
iris = pd .DataFrame (raw .data , columns = raw .feature_names )
32
26
iris = iris .join (pd .DataFrame (raw .target ))
@@ -39,12 +33,7 @@ def train_data():
39
33
@pytest .fixture
40
34
def sklearn_model (train_data ):
41
35
"""Returns a simple Scikit-Learn model"""
42
-
43
- try :
44
- from sklearn .linear_model import LogisticRegression
45
- except ImportError :
46
- pytest .skip ('Package `sklearn` not found.' )
47
-
36
+ from sklearn .linear_model import LogisticRegression
48
37
X , y = train_data
49
38
with warnings .catch_warnings ():
50
39
warnings .simplefilter ('ignore' )
@@ -57,13 +46,13 @@ def sklearn_model(train_data):
57
46
class TestSKLearnModel :
58
47
PROJECT_NAME = 'PZMM SKLearn Test Project'
59
48
MODEL_NAME = 'SKLearnModel'
60
- PATH = '.'
61
49
62
50
def test_generate_hyperparameters (self , sklearn_model ):
63
- mp . generate_hyperparameters ( sklearn_model , self . MODEL_NAME , self . PATH )
64
- assert os . path . exists ( './PythonModelHyperparameters.json' )
65
- os . remove ( './PythonModelHyperparameters. json' )
51
+ tmp_dir = tempfile . TemporaryDirectory ( )
52
+ mp . generate_hyperparameters ( sklearn_model , self . MODEL_NAME , Path ( tmp_dir . name ) )
53
+ assert Path ( Path ( tmp_dir . name ) / f './{ self . MODEL_NAME } Hyperparameters. json'). exists ( )
66
54
67
55
def test_bad_model_hyperparameters (self , bad_model ):
56
+ tmp_dir = tempfile .TemporaryDirectory ()
68
57
with pytest .raises (ValueError ):
69
- mp .generate_hyperparameters (bad_model , self .MODEL_NAME , self . PATH )
58
+ mp .generate_hyperparameters (bad_model , self .MODEL_NAME , Path ( tmp_dir . name ) )
0 commit comments