Skip to content

Commit 51a997a

Browse files
authored
Merge pull request #25 from big-o/develop
column prefix bug fix in input passthroughs
2 parents a972932 + b470a8d commit 51a997a

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

skdag/dag/_dag.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@
3333
__all__ = ["DAG", "DAGStep"]
3434

3535

36-
def _get_columns(X, dep, cols, is_root, axis=1):
36+
def _get_columns(X, dep, cols, is_root, dep_is_passthrough, axis=1):
3737
if callable(cols):
3838
# sklearn.compose.make_column_selector
3939
cols = cols(X)
4040

41-
if not is_root:
41+
if not is_root and not dep_is_passthrough:
4242
# The DAG will prepend output columns with the step name, so add this in to any
4343
# dep columns if missing. This helps keep user-provided deps readable.
4444
if isinstance(cols, str):
@@ -60,7 +60,14 @@ def _stack_inputs(dag, X, node):
6060
deps = {node.name: None} if node.is_root else node.deps
6161

6262
cols = [
63-
_get_columns(X[dep], dep, cols, node.is_root, axis=1)
63+
_get_columns(
64+
X[dep],
65+
dep,
66+
cols,
67+
node.is_root,
68+
_is_passthrough(dag.graph_.nodes[dep]["step"].estimator),
69+
axis=1,
70+
)
6471
for dep, cols in deps.items()
6572
]
6673

skdag/dag/tests/test_dag.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -407,13 +407,20 @@ def test_pandas(X, y, steps):
407407
assert np.allclose(y_pred_np, y_pred_pd)
408408

409409

410-
def test_pandas_indexing():
410+
@pytest.mark.parametrize("input_passthrough", [False, True])
411+
def test_pandas_indexing(input_passthrough):
411412
X, y = datasets.load_diabetes(return_X_y=True, as_frame=True)
412413

413414
passcols = ["age", "sex", "bmi", "bp"]
415+
416+
builder = DAGBuilder(infer_dataframe=True)
417+
if input_passthrough:
418+
builder.add_step("inp", "passthrough")
419+
414420
preprocessing = (
415-
DAGBuilder(infer_dataframe=True)
416-
.add_step("imp", SimpleImputer())
421+
builder.add_step(
422+
"imp", SimpleImputer(), deps=["inp"] if input_passthrough else None
423+
)
417424
.add_step("vitals", "passthrough", deps={"imp": passcols})
418425
.add_step(
419426
"blood",

0 commit comments

Comments
 (0)