Skip to content

Commit 4677398

Browse files
committed
Add bundle permutation importance exports
1 parent f94f6a4 commit 4677398

File tree

4 files changed

+384
-0
lines changed

4 files changed

+384
-0
lines changed

src/geoluck/cli.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,11 @@
159159
"Prediction target to use. Available: income, life_expectancy, inequality, wealth."
160160
),
161161
)
162+
PERMUTATION_IMPORTANCE_OPTION = typer.Option(
163+
False,
164+
"--with-permutation-importance/--no-with-permutation-importance",
165+
help="Compute held-out latest-decade permutation importance. Heavier than standard exports.",
166+
)
162167

163168

164169
def _echo_train_level_result(result: object) -> None:
@@ -171,6 +176,7 @@ def _echo_train_level_result(result: object) -> None:
171176
typer.echo(f"feature_importance={result.feature_importance_path}")
172177
typer.echo(f"coefficients={result.coefficients_path}")
173178
typer.echo(f"contributions={result.contributions_path}")
179+
typer.echo(f"permutation_importance={result.permutation_importance_path}")
174180
typer.echo(f"feature_coverage={result.feature_coverage_path}")
175181
typer.echo(f"target_correlations={result.target_correlations_path}")
176182
typer.echo(f"prediction_rows={result.row_count}")
@@ -1268,6 +1274,7 @@ def train_level_models(
12681274
model_name: list[str] | None = MODEL_NAME_OPTION,
12691275
model_family: list[str] | None = MODEL_FAMILY_OPTION,
12701276
output_suffix: str | None = OUTPUT_SUFFIX_OPTION,
1277+
with_permutation_importance: bool = PERMUTATION_IMPORTANCE_OPTION,
12711278
) -> None:
12721279
"""Train baseline and ML level models by decade."""
12731280
result = export_level_model_outputs(
@@ -1277,6 +1284,7 @@ def train_level_models(
12771284
model_names=model_name,
12781285
model_families=model_family,
12791286
output_suffix=output_suffix,
1287+
with_permutation_importance=with_permutation_importance,
12801288
)
12811289
_echo_train_level_result(result)
12821290

@@ -1379,6 +1387,11 @@ def export_web_data() -> None:
13791387
typer.echo(f"bundle_summary={result.bundle_summary_path}")
13801388
if result.bundle_feature_effects_path is not None:
13811389
typer.echo(f"bundle_feature_effects={result.bundle_feature_effects_path}")
1390+
if result.bundle_permutation_importance_path is not None:
1391+
typer.echo(
1392+
"bundle_permutation_importance="
1393+
f"{result.bundle_permutation_importance_path}"
1394+
)
13821395
if result.bundle_country_contributions_index_path is not None:
13831396
typer.echo(
13841397
"bundle_country_contributions_index="

src/geoluck/models/train_levels.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,29 @@
164164
"abs_contribution",
165165
"contribution_rank",
166166
]
167+
# Output schema for the latest-decade permutation-importance export.
PERMUTATION_IMPORTANCE_COLUMNS = [
    # Identity: which decade/target/model produced the row.
    "decade",
    "target_name",
    "target_column",
    "spec_name",
    "model_name",
    "model_family",
    "feature_set",
    # The permuted feature and its feature block.
    "feature_name",
    "feature_block",
    # Sample-size bookkeeping for the aggregated deltas.
    "repeat_count",
    "fold_count",
    "row_count",
    # Mean/std of the metric deltas caused by permuting the feature.
    "delta_r2_mean",
    "delta_r2_std",
    "delta_rmse_mean",
    "delta_rmse_std",
    "delta_mae_mean",
    "delta_mae_std",
    "delta_spearman_mean",
    "delta_spearman_std",
    # Dense rank of delta_r2_mean within a spec (1 = most important).
    "importance_rank",
]
167190
ROBUSTNESS_PREDICTION_COLUMNS = [
168191
*MODEL_OUTPUT_COLUMNS,
169192
"robustness_strategy",
@@ -253,6 +276,7 @@ class TrainLevelsResult:
253276
feature_importance_path: Path
254277
coefficients_path: Path
255278
contributions_path: Path
279+
permutation_importance_path: Path
256280
feature_coverage_path: Path
257281
target_correlations_path: Path
258282
row_count: int
@@ -1673,6 +1697,10 @@ def empty_contribution_frame() -> pd.DataFrame:
16731697
return pd.DataFrame(columns=CONTRIBUTION_COLUMNS)
16741698

16751699

1700+
def empty_permutation_importance_frame() -> pd.DataFrame:
    """Return a zero-row DataFrame carrying the permutation-importance schema."""
    empty = pd.DataFrame(columns=PERMUTATION_IMPORTANCE_COLUMNS)
    return empty
1702+
1703+
16761704
def contribution_frame_for_rows(
16771705
rows_frame: pd.DataFrame,
16781706
*,
@@ -2025,6 +2053,135 @@ def build_latest_decade_country_contributions(
20252053
)
20262054

20272055

2056+
def build_latest_decade_permutation_importance(
    frame: pd.DataFrame,
    feature_sets: list[FeatureSetSpec],
    model_specs: list[ModelSpec],
    *,
    target_spec: TargetSpec,
    repeat_count: int = 2,
    random_state: int = 42,
) -> pd.DataFrame:
    """Compute held-out permutation importance for the latest usable decade.

    For the most recent decade with at least 8 non-null target rows, each
    non-baseline model spec is re-fit under K-fold cross-validation and every
    feature column is shuffled ``repeat_count`` times on the held-out fold.
    The change in held-out metrics relative to the unpermuted predictions is
    aggregated per feature, then ranked by mean R^2 drop.

    Args:
        frame: Training frame with a ``decade`` column plus feature/target
            columns.
        feature_sets: Available feature-set specs, keyed by ``feature_set``.
        model_specs: Candidate model specs; baseline specs and specs without a
            pipeline builder are skipped.
        target_spec: Target to evaluate.
        repeat_count: Number of independent shuffles per feature per fold.
        random_state: Seed for both the fold splitter and the shuffler.

    Returns:
        A DataFrame with ``PERMUTATION_IMPORTANCE_COLUMNS``; empty when no
        decade has enough labelled rows.
    """
    # A decade qualifies only when it has at least 8 labelled rows.
    diagnostic_decades = sorted(
        int(decade)
        for decade, decade_frame in frame.groupby("decade", sort=True)
        if int(decade_frame[target_spec.target_column].notna().sum()) >= 8
    )
    if not diagnostic_decades:
        return empty_permutation_importance_frame()

    latest_decade = diagnostic_decades[-1]
    valid = frame.loc[
        (frame["decade"] == latest_decade) & frame[target_spec.target_column].notna()
    ].copy()
    if len(valid) < 8:
        return empty_permutation_importance_frame()

    feature_lookup = {feature_set.feature_set: feature_set for feature_set in feature_sets}
    # BUG FIX: the splitter previously hard-coded random_state=42, silently
    # ignoring the ``random_state`` parameter (the RNG below already used it).
    splitter = KFold(
        n_splits=n_splits_for_rows(len(valid)),
        shuffle=True,
        random_state=random_state,
    )
    rng = np.random.default_rng(random_state)
    rows: list[dict[str, object]] = []

    for spec in model_specs:
        # Baselines have no fitted pipeline to permute against.
        if spec.is_baseline or spec.build_pipeline is None:
            continue
        feature_set = feature_lookup[spec.feature_set]
        # Skip specs whose feature set does not cover the latest decade.
        if feature_set.min_decade is not None and latest_decade < feature_set.min_decade:
            continue
        feature_columns = [*feature_set.numeric_columns, *feature_set.categorical_columns]
        fold_records: list[dict[str, object]] = []
        for train_idx, test_idx in splitter.split(valid):
            train = valid.iloc[train_idx].copy()
            test = valid.iloc[test_idx].copy()
            model = spec.build_pipeline()
            model = fit_pipeline(model, train[feature_columns], train[target_spec.target_column])
            baseline_predictions = np.asarray(
                model.predict(test[feature_columns]),
                dtype="float64",
            )
            baseline_summary = metric_summary(
                test[target_spec.target_column],
                baseline_predictions,
            )
            for feature_name in feature_columns:
                for _ in range(repeat_count):
                    permuted = test[feature_columns].copy()
                    shuffled = permuted[feature_name].to_numpy(copy=True)
                    rng.shuffle(shuffled)
                    permuted[feature_name] = shuffled
                    permuted_predictions = np.asarray(
                        model.predict(permuted),
                        dtype="float64",
                    )
                    permuted_summary = metric_summary(
                        test[target_spec.target_column],
                        permuted_predictions,
                    )
                    # Deltas are oriented so that a positive value means the
                    # model relied on the feature: quality metrics (r2,
                    # spearman) drop; error metrics (rmse, mae) rise.
                    fold_records.append(
                        {
                            "feature_name": feature_name,
                            "feature_block": feature_block_name(feature_name),
                            "row_count": int(len(test)),
                            "delta_r2": float(
                                baseline_summary["r2"] - permuted_summary["r2"]
                            ),
                            "delta_rmse": float(
                                permuted_summary["rmse"] - baseline_summary["rmse"]
                            ),
                            "delta_mae": float(
                                permuted_summary["mae"] - baseline_summary["mae"]
                            ),
                            "delta_spearman": float(
                                baseline_summary["spearman"]
                                - permuted_summary["spearman"]
                            ),
                        }
                    )
        if not fold_records:
            continue
        fold_frame = pd.DataFrame(fold_records)
        # NOTE(review): "repeat_count" and "fold_count" both count all records
        # per feature (folds x repeats); kept as-is for schema stability.
        aggregated = (
            fold_frame.groupby(["feature_name", "feature_block"], as_index=False)
            .agg(
                repeat_count=("feature_name", "size"),
                fold_count=("row_count", "count"),
                row_count=("row_count", "sum"),
                delta_r2_mean=("delta_r2", "mean"),
                delta_r2_std=("delta_r2", "std"),
                delta_rmse_mean=("delta_rmse", "mean"),
                delta_rmse_std=("delta_rmse", "std"),
                delta_mae_mean=("delta_mae", "mean"),
                delta_mae_std=("delta_mae", "std"),
                delta_spearman_mean=("delta_spearman", "mean"),
                delta_spearman_std=("delta_spearman", "std"),
            )
            # std is NaN for a single record; report it as 0.0 instead.
            .fillna(0.0)
        )
        aggregated["decade"] = latest_decade
        aggregated["target_name"] = target_spec.target_name
        aggregated["target_column"] = target_spec.target_column
        aggregated["spec_name"] = f"{spec.model_name}__{spec.feature_set}"
        aggregated["model_name"] = spec.model_name
        aggregated["model_family"] = spec.model_family
        aggregated["feature_set"] = spec.feature_set
        # Dense rank on mean R^2 drop: 1 = most important feature.
        aggregated["importance_rank"] = (
            aggregated["delta_r2_mean"]
            .rank(method="dense", ascending=False)
            .astype("int64")
        )
        rows.extend(
            aggregated.loc[:, PERMUTATION_IMPORTANCE_COLUMNS].to_dict("records")
        )

    permutation_frame = pd.DataFrame(rows, columns=PERMUTATION_IMPORTANCE_COLUMNS)
    if permutation_frame.empty:
        return permutation_frame
    return permutation_frame.sort_values(
        ["spec_name", "importance_rank", "feature_name"],
        kind="stable",
    ).reset_index(drop=True)
2183+
2184+
20282185
def feature_block_name(feature_name: str) -> str:
20292186
if feature_name in BASE_FEATURE_COLUMNS_NUMERIC:
20302187
return "deep_geo"
@@ -2623,6 +2780,7 @@ def export_level_model_outputs(
26232780
model_families: Sequence[str] | None = None,
26242781
output_suffix: str | None = None,
26252782
allow_canonical_outputs: bool = False,
2783+
with_permutation_importance: bool = False,
26262784
) -> TrainLevelsResult:
26272785
resolved_paths = paths or get_paths()
26282786
budget = build_train_levels_budget(
@@ -2671,6 +2829,16 @@ def export_level_model_outputs(
26712829
selected_model_specs,
26722830
target_spec=target_spec,
26732831
)
2832+
permutation_importance_frame = (
2833+
build_latest_decade_permutation_importance(
2834+
training_frame,
2835+
selected_feature_sets,
2836+
selected_model_specs,
2837+
target_spec=target_spec,
2838+
)
2839+
if with_permutation_importance
2840+
else empty_permutation_importance_frame()
2841+
)
26742842
target_correlation_frame = build_target_correlation_frame(training_frame)
26752843
resolved_suffix = resolved_output_suffix(budget)
26762844

@@ -2702,6 +2870,10 @@ def export_level_model_outputs(
27022870
resolved_paths.data_final / "model_contributions.parquet",
27032871
resolved_suffix,
27042872
)
2873+
permutation_importance_path = output_path_for_budget(
2874+
resolved_paths.data_final / "model_permutation_importance.parquet",
2875+
resolved_suffix,
2876+
)
27052877
feature_coverage_path = output_path_for_budget(
27062878
resolved_paths.data_final / "feature_coverage.parquet",
27072879
resolved_suffix,
@@ -2717,6 +2889,7 @@ def export_level_model_outputs(
27172889
feature_importance_frame.to_parquet(feature_importance_path, index=False)
27182890
coefficients_frame.to_parquet(coefficients_path, index=False)
27192891
contributions_frame.to_parquet(contributions_path, index=False)
2892+
permutation_importance_frame.to_parquet(permutation_importance_path, index=False)
27202893
feature_coverage_frame.to_parquet(feature_coverage_path, index=False)
27212894
target_correlation_frame.to_parquet(target_correlations_path, index=False)
27222895
specs_path.write_text(
@@ -2734,6 +2907,7 @@ def export_level_model_outputs(
27342907
feature_importance_path=feature_importance_path,
27352908
coefficients_path=coefficients_path,
27362909
contributions_path=contributions_path,
2910+
permutation_importance_path=permutation_importance_path,
27372911
feature_coverage_path=feature_coverage_path,
27382912
target_correlations_path=target_correlations_path,
27392913
row_count=len(predictions_frame),

0 commit comments

Comments (0)