Skip to content

Commit 1340f25

Browse files
bnaulclaude
andcommitted
PERF: vectorize _status aggregation in remove_interstitial_nodes
Instead of calling the _status callable per group via pandas' slow pure-Python dispatch, compute the result for all groups at once using fast groupby primitives (size, first, sum). Combined with the previous _first_non_null fix, Aleppo (78K edges) drops from ~508s to ~109s (4.7x). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ae34c35 commit 1340f25

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

neatnet/nodes.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,26 @@ def _status(x: pd.Series) -> str:
137137
return "changed"
138138

139139

140+
def _vectorized_status(status_col: pd.Series, labels) -> pd.Series:
141+
"""Vectorized equivalent of ``_status`` for groupby aggregation.
142+
143+
Instead of calling ``_status`` per group (slow Python dispatch),
144+
this computes the same result for all groups at once using fast
145+
groupby primitives (size, first, sum).
146+
"""
147+
g = status_col.groupby(labels)
148+
sizes = g.size()
149+
first_vals = g.first()
150+
new_counts = (status_col == "new").groupby(labels).sum()
151+
152+
result = pd.Series("changed", index=sizes.index)
153+
single = sizes == 1
154+
result.loc[single] = first_vals.loc[single]
155+
all_new_multi = (new_counts == sizes) & ~single
156+
result.loc[all_new_multi] = "new"
157+
return result
158+
159+
140160
def _first_non_null(x: pd.Series):
141161
"""Return first observation that is not missing, unless all are."""
142162
non_null = x[~x.isna()]
@@ -478,7 +498,23 @@ def merge_geometries(block: gpd.GeoSeries) -> shapely.LineString:
478498

479499
# Process non-spatial component
480500
data = gdf.drop(labels=gdf.geometry.name, axis=1)
481-
aggregated_data = data.groupby(by=labels).agg(aggfunc, **kwargs)
501+
if (
502+
isinstance(aggfunc, dict)
503+
and "_status" in aggfunc
504+
and aggfunc["_status"] is _status
505+
):
506+
# Vectorize _status separately to avoid the slow per-group callable path
507+
status_result = _vectorized_status(data["_status"], labels)
508+
rest_agg = {k: v for k, v in aggfunc.items() if k != "_status"}
509+
if rest_agg:
510+
aggregated_data = (
511+
data.drop(columns=["_status"]).groupby(by=labels).agg(rest_agg, **kwargs)
512+
)
513+
else:
514+
aggregated_data = pd.DataFrame(index=status_result.index)
515+
aggregated_data["_status"] = status_result
516+
else:
517+
aggregated_data = data.groupby(by=labels).agg(aggfunc, **kwargs)
482518
aggregated_data.columns = aggregated_data.columns.to_flat_index()
483519

484520
# Process spatial component

0 commit comments

Comments
 (0)