Skip to content

Commit de1d28b

Browse files
sgd hot loop compiled
1 parent 390acff commit de1d28b

2 files changed

Lines changed: 489 additions & 65 deletions

File tree

src/pixwake/optim/sgd.py

Lines changed: 82 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,6 @@ def topfarm_sgd_solve(
780780
boundary: Boundary,
781781
min_spacing: float,
782782
settings: SGDSettings | None = None,
783-
jit: bool = True,
784783
progress: Literal["none", "print", "bar"] = "none",
785784
progress_callback: Callable[[SGDRecord], None] | None = None,
786785
record: bool = False,
@@ -810,15 +809,14 @@ def topfarm_sgd_solve(
810809
multi-polygon boundaries.
811810
min_spacing: Minimum inter-turbine distance.
812811
settings: SGD configuration. Uses defaults if None.
813-
jit: Whether to JIT compile the gradient computations.
814812
progress: Progress rendering mode: none, print, or bar.
815813
progress_callback: Optional host callback receiving per-iteration progress.
816814
record: Whether to record and return structured optimization diagnostics.
817815
wind: Optional :class:`~pixwake.wind_resource.SampledWind` sampler.
818-
When provided, :meth:`~pixwake.wind_resource.SampledWind.inc_rng` is
819-
called outside JIT before each gradient step so that every iteration
820-
uses a fresh, independent sample batch. ``objective_fn`` must then
821-
accept ``(x, y, wind) -> scalar``.
816+
When provided, a fresh PRNG subkey is split from the carry key on
817+
every gradient step so that every iteration uses an independent
818+
sample batch. ``objective_fn`` must then accept
819+
``(x, y, wind) -> scalar``.
822820
823821
Returns:
824822
Tuple of (optimized_x, optimized_y), optionally with per-iteration `SGDRecord`.
@@ -854,9 +852,13 @@ def _constraint_penalty(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
854852
x, y, min_spacing
855853
)
856854

857-
jit_wrap = jax.jit if jit else (lambda f: f)
858-
grad_obj_fn = jit_wrap(jax.grad(objective_fn, argnums=(0, 1)))
859-
grad_con_fn = jit_wrap(jax.grad(_constraint_penalty, argnums=(0, 1)))
855+
def _pen_per_turbine(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
856+
return _boundary_penalty_per_turbine(
857+
x, y, boundary_polys
858+
) + _spacing_penalty_per_turbine(x, y, min_spacing)
859+
860+
grad_obj_fn = jax.jit(jax.grad(objective_fn, argnums=(0, 1)))
861+
grad_con_fn = jax.jit(jax.grad(_constraint_penalty, argnums=(0, 1)))
860862

861863
if progress not in {"none", "print", "bar"}:
862864
raise ValueError(
@@ -865,74 +867,89 @@ def _constraint_penalty(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
865867
progress_enabled = progress != "none" or progress_callback is not None
866868
emit_progress = _make_emit_progress(progress, progress_callback, settings.max_iter)
867869

868-
x, y = init_x, init_y
869-
state = _init_sgd_state(settings)
870-
current_wind = wind
871870
record_list: list[SGDRecord] = []
872871

872+
def _side_effects(rec: SGDRecord) -> None:
873+
if record:
874+
record_list.append(rec)
875+
if progress_enabled:
876+
emit_progress(rec)
877+
873878
def _compute_layout_change(
874879
x: jnp.ndarray, y: jnp.ndarray, prev_x: jnp.ndarray, prev_y: jnp.ndarray
875880
) -> jnp.ndarray:
876881
return jnp.max(jnp.abs(x - prev_x)) + jnp.max(jnp.abs(y - prev_y))
877882

878-
def _is_done(change: jnp.ndarray, it: int | jnp.ndarray) -> bool:
879-
return float(change) <= settings.tol or int(it) >= settings.max_iter
883+
def _is_done(change: jnp.ndarray, it: jnp.ndarray) -> jnp.ndarray:
884+
return (change <= settings.tol) | (it >= settings.max_iter)
885+
886+
has_wind = wind is not None
887+
init_key = wind._key if wind is not None else jax.random.PRNGKey(0)
880888

881-
done = _is_done(_compute_layout_change(x, y, x - 1.0, y - 1.0), state.iteration)
882-
while not done:
883-
if current_wind is not None:
884-
current_wind = current_wind.inc_rng()
885-
grad_obj_x, grad_obj_y = grad_obj_fn(x, y, current_wind)
889+
def body_fn(
890+
carry: tuple[jnp.ndarray, jnp.ndarray, SGDState, jnp.ndarray, jnp.ndarray],
891+
) -> tuple[jnp.ndarray, jnp.ndarray, SGDState, jnp.ndarray, jnp.ndarray]:
892+
x, y, state, rng_key, _prev_change = carry
893+
894+
if has_wind:
895+
rng_key, subkey = jax.random.split(rng_key)
896+
assert wind is not None
897+
grad_obj_x, grad_obj_y = grad_obj_fn(x, y, wind.with_key(subkey))
886898
else:
887899
grad_obj_x, grad_obj_y = grad_obj_fn(x, y)
888-
grad_con_x, grad_con_y = grad_con_fn(x, y)
889900

901+
grad_con_x, grad_con_y = grad_con_fn(x, y)
890902
d_x, d_y = _project_gradient(grad_obj_x, grad_obj_y, grad_con_x, grad_con_y)
891903
delta_x, delta_y, new_state = _sgd_step(state, d_x, d_y, settings)
892904

893-
# Feasibility-preserving step: per-turbine bisection.
894-
# Each turbine keeps its own scale s_i, initialised to 1. On each
895-
# bisection round only the turbines that are still infeasible have
896-
# their scale halved, so a turbine far from any constraint is never
897-
# held back by one that is close to the boundary.
898-
# _project_gradient() already makes the gradient direction tangential
899-
# to any active constraint at the current position.
900-
step_x, step_y = delta_x, delta_y
901-
s = jnp.ones(x.shape[0])
902-
x_try = x - s * step_x
903-
y_try = y - s * step_y
904-
pen_i = _boundary_penalty_per_turbine(
905-
x_try, y_try, boundary_polys
906-
) + _spacing_penalty_per_turbine(x_try, y_try, min_spacing)
907-
while (
908-
float(jnp.max(pen_i)) > settings.bisect_tol
909-
and float(jnp.min(s)) > settings.bisect_s_min
910-
):
911-
s = jnp.where(pen_i > settings.bisect_tol, s / 2.0, s)
912-
x_try = x - s * step_x
913-
y_try = y - s * step_y
914-
pen_i = _boundary_penalty_per_turbine(
915-
x_try, y_try, boundary_polys
916-
) + _spacing_penalty_per_turbine(x_try, y_try, min_spacing)
917-
918-
x_new, y_new = x_try, y_try
919-
920-
iter_rec = SGDRecord(
921-
x=x_new,
922-
y=y_new,
923-
iteration=new_state.iteration,
924-
penalty=_constraint_penalty(x_new, y_new),
925-
change=(change := _compute_layout_change(x_new, y_new, x, y)),
926-
learning_rate=new_state.learning_rate,
927-
converged=float(change <= settings.tol),
928-
is_final=(done := _is_done(change, new_state.iteration)),
905+
# Feasibility-preserving per-turbine bisection inside compiled loop.
906+
def bisect_body(
907+
bc: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray],
908+
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
909+
s, xt, yt, pen_i = bc
910+
s = jnp.where(pen_i > settings.bisect_tol, s / 2, s)
911+
xt = x - s * delta_x
912+
yt = y - s * delta_y
913+
return s, xt, yt, _pen_per_turbine(xt, yt)
914+
915+
s0 = jnp.ones(x.shape[0])
916+
x0 = x - delta_x
917+
y0 = y - delta_y
918+
_, x_new, y_new, _ = while_loop(
919+
lambda bc: (jnp.max(bc[3]) > settings.bisect_tol)
920+
& (jnp.min(bc[0]) > settings.bisect_s_min),
921+
bisect_body,
922+
(s0, x0, y0, _pen_per_turbine(x0, y0)),
929923
)
930-
if record:
931-
record_list.append(iter_rec)
932-
if progress_enabled:
933-
emit_progress(iter_rec)
934924

935-
x, y, state = x_new, y_new, new_state
925+
change = _compute_layout_change(x_new, y_new, x, y)
926+
is_final = _is_done(change, jnp.asarray(new_state.iteration))
927+
if record or progress_enabled:
928+
iter_rec = SGDRecord(
929+
x=x_new,
930+
y=y_new,
931+
iteration=new_state.iteration,
932+
penalty=_constraint_penalty(x_new, y_new),
933+
change=change,
934+
learning_rate=new_state.learning_rate,
935+
converged=change <= settings.tol,
936+
is_final=is_final,
937+
)
938+
jax.debug.callback(_side_effects, iter_rec)
939+
return x_new, y_new, new_state, rng_key, change
940+
941+
def cond_fn(
942+
carry: tuple[jnp.ndarray, jnp.ndarray, SGDState, jnp.ndarray, jnp.ndarray],
943+
) -> jnp.ndarray:
944+
_, _, state, _, change = carry
945+
return ~_is_done(change, jnp.asarray(state.iteration))
946+
947+
init_state = _init_sgd_state(settings)
948+
# Initial change > tol ensures the first iteration always runs.
949+
init_change = jnp.asarray(settings.tol + 1.0)
950+
x, y, _, _, _ = while_loop(
951+
cond_fn, body_fn, (init_x, init_y, init_state, init_key, init_change)
952+
)
936953

937954
if record:
938955
return x, y, record_list
@@ -1150,13 +1167,13 @@ def create_layout_optimizer(
11501167
wind: Wind conditions — a :class:`~pixwake.wind_resource.WeibullWindRose`,
11511168
:class:`~pixwake.wind_resource.TimeSeriesWind`, or
11521169
:class:`~pixwake.wind_resource.SampledWind`. When a
1153-
:class:`~pixwake.wind_resource.SampledWind` is provided the solver
1154-
draws a fresh sample batch on every gradient step via
1155-
:meth:`~pixwake.wind_resource.SampledWind.inc_rng`.
1170+
:class:`~pixwake.wind_resource.SampledWind` is provided a fresh PRNG
1171+
subkey is split on every gradient step, drawing an independent sample
1172+
batch each iteration.
11561173
settings: SGD configuration.
11571174
**kwargs: Additional keyword arguments forwarded to
11581175
:func:`topfarm_sgd_solve` (e.g. ``progress``, ``progress_callback``,
1159-
``record``, ``jit``).
1176+
``record``).
11601177
11611178
Returns:
11621179
Function ``(init_x, init_y) -> (opt_x, opt_y)``. When ``record=True``

0 commit comments

Comments
 (0)