|
2 | 2 | Test the DAG module. |
3 | 3 | """ |
4 | 4 | import re |
5 | | -import time |
6 | 5 |
|
7 | 6 | import numpy as np |
8 | 7 | import pandas as pd |
9 | 8 | import pytest |
10 | | -from skdag import DAG, DAGBuilder |
| 9 | +from skdag import DAGBuilder |
11 | 10 | from skdag.dag.tests.utils import FitParamT, Mult, NoFit, NoTrans, Transf |
12 | 11 | from sklearn import datasets |
13 | | -from sklearn import preprocessing |
14 | | -from sklearn.base import BaseEstimator, clone |
| 12 | +from sklearn.base import clone |
15 | 13 | from sklearn.compose import make_column_selector |
16 | 14 | from sklearn.decomposition import PCA |
17 | 15 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor |
@@ -294,9 +292,10 @@ def test_dag_draw(): |
294 | 292 | dag = ( |
295 | 293 | DAGBuilder() |
296 | 294 | .add_step("pca", pca) |
297 | | - .add_step("svc", svc, deps=["pca"]) |
298 | | - .add_step("rf", rf, deps=["pca"]) |
299 | | - .add_step("log", log, deps=["svc", "rf"]) |
| 295 | + .add_step("svc", svc, deps={"pca": slice(4)}) |
| 296 | + .add_step("rf1", rf, deps={"pca": [0, 1, 2]}) |
| 297 | + .add_step("rf2", rf, deps={"pca": make_column_selector(pattern="^pca.*")}) |
| 298 | + .add_step("log", log, deps=["svc", "rf1", "rf2"]) |
300 | 299 | .make_dag() |
301 | 300 | ) |
302 | 301 |
|
|
0 commit comments