Skip to content

Commit 2a065dc

Browse files
committed
pass tests
1 parent d8b2eca commit 2a065dc

File tree

4 files changed

+16
-12
lines changed

4 files changed

+16
-12
lines changed

tests/test_core.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,8 @@ def setup_class(cls):
159159

160160

161161

162-
def setup_zephyr(self, producer_step_name):
162+
def setup_zephyr(self, step_num):
163163
zephyr = Zephyr()
164-
step_num = zephyr.producer_to_step_map[producer_step_name]
165164

166165
for i, (setters, getters) in enumerate(zephyr.step_order):
167166
if i < step_num:
@@ -173,7 +172,7 @@ def setup_zephyr(self, producer_step_name):
173172
return zephyr
174173

175174
def test_initialize_class(self):
176-
zephyr = self.setup_zephyr(0)
175+
zephyr = self.setup_zephyr(1)
177176

178177
def test_create_entityset(self):
179178
zephyr = self.setup_zephyr(1)
@@ -200,11 +199,10 @@ def test_generate_train_test_split(self):
200199
train_test_split = zephyr.get_train_test_split()
201200
assert train_test_split is not None
202201

203-
def setup_zephyr_with_base_split(self, producer_step_name):
202+
def setup_zephyr_with_base_split(self, step_num):
204203
zephyr = self.setup_zephyr(4)
205-
zephyr.set_train_test_split(**self.base_train_test_split())
206-
final_step_num = zephyr.producer_to_step_map[producer_step_name]
207-
for i in range(4, final_step_num):
204+
zephyr.set_train_test_split(*self.base_train_test_split())
205+
for i in range(5, step_num):
208206
setters, getters = zephyr.step_order[i]
209207
setter = setters[0]
210208
kwargs = self.kwargs[setter.__name__]

tests/test_entityset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,14 +210,14 @@ def test_missing_time_indices(pidata_dfs):
210210
def test_default_create_pidata_entityset(pidata_dfs):
211211
es = create_pidata_entityset(pidata_dfs)
212212

213-
assert es.id == 'PI data'
213+
assert es.id == 'pidata'
214214
assert set(es.dataframe_dict.keys()) == set(
215215
['alarms', 'turbines', 'stoppages', 'work_orders', 'notifications', 'pidata'])
216216

217217

218218
def test_default_create_scada_entityset(scada_dfs):
219219
es = create_scada_entityset(scada_dfs)
220220

221-
assert es.id == 'SCADA data'
221+
assert es.id == 'scada'
222222
assert set(es.dataframe_dict.keys()) == set(
223223
['alarms', 'turbines', 'stoppages', 'work_orders', 'notifications', 'scada'])

tests/test_feature_engineering.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pandas as pd
22
import pytest
33

4-
from zephyr_ml import create_pidata_entityset, create_scada_entityset
4+
from zephyr_ml import _create_entityset
55
from zephyr_ml.feature_engineering import process_signals
66

77

@@ -122,12 +122,12 @@ def scada_dfs(base_dfs):
122122

123123
@pytest.fixture
124124
def pidata_es(pidata_dfs):
125-
return create_pidata_entityset(pidata_dfs)
125+
return _create_entityset(pidata_dfs, "pidata")
126126

127127

128128
@pytest.fixture
129129
def scada_es(scada_dfs):
130-
return create_scada_entityset(scada_dfs)
130+
return _create_entityset(scada_dfs, "scada")
131131

132132

133133
@pytest.fixture

zephyr_ml/core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,9 @@ def set_and_fit_pipeline(
332332
if y is None:
333333
y = self.y_train
334334

335+
print(X)
336+
print(y)
337+
335338
if visual:
336339
outputs_spec, visual_names = self._get_outputs_spec(False)
337340
else:
@@ -359,7 +362,10 @@ def predict(self, X=None, visual=False, **kwargs):
359362
else:
360363
outputs_spec = "default"
361364

365+
print(X)
366+
362367
outputs = self.pipeline.predict(X, output_=outputs_spec, **kwargs)
368+
print(outputs)
363369

364370
if visual and visual_names:
365371
prediction = outputs[0]

0 commit comments

Comments
 (0)