@@ -780,7 +780,6 @@ def topfarm_sgd_solve(
780780 boundary : Boundary ,
781781 min_spacing : float ,
782782 settings : SGDSettings | None = None ,
783- jit : bool = True ,
784783 progress : Literal ["none" , "print" , "bar" ] = "none" ,
785784 progress_callback : Callable [[SGDRecord ], None ] | None = None ,
786785 record : bool = False ,
@@ -810,15 +809,14 @@ def topfarm_sgd_solve(
810809 multi-polygon boundaries.
811810 min_spacing: Minimum inter-turbine distance.
812811 settings: SGD configuration. Uses defaults if None.
813- jit: Whether to JIT compile the gradient computations.
814812 progress: Progress rendering mode: none, print, or bar.
815813 progress_callback: Optional host callback receiving per-iteration progress.
816814 record: Whether to record and return structured optimization diagnostics.
817815 wind: Optional :class:`~pixwake.wind_resource.SampledWind` sampler.
818- When provided, :meth:`~pixwake.wind_resource.SampledWind.inc_rng` is
819- called outside JIT before each gradient step so that every iteration
820- uses a fresh, independent sample batch. ``objective_fn`` must then
821- accept ``(x, y, wind) -> scalar``.
816+ When provided, a fresh PRNG subkey is split from the carry key on
817+ every gradient step so that every iteration uses an independent
818+ sample batch. ``objective_fn`` must then accept
819+ ``(x, y, wind) -> scalar``.
822820
823821 Returns:
824822 Tuple of (optimized_x, optimized_y), optionally with per-iteration `SGDRecord`.
@@ -854,9 +852,13 @@ def _constraint_penalty(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
854852 x , y , min_spacing
855853 )
856854
857- jit_wrap = jax .jit if jit else (lambda f : f )
858- grad_obj_fn = jit_wrap (jax .grad (objective_fn , argnums = (0 , 1 )))
859- grad_con_fn = jit_wrap (jax .grad (_constraint_penalty , argnums = (0 , 1 )))
855+ def _pen_per_turbine (x : jnp .ndarray , y : jnp .ndarray ) -> jnp .ndarray :
856+ return _boundary_penalty_per_turbine (
857+ x , y , boundary_polys
858+ ) + _spacing_penalty_per_turbine (x , y , min_spacing )
859+
860+ grad_obj_fn = jax .jit (jax .grad (objective_fn , argnums = (0 , 1 )))
861+ grad_con_fn = jax .jit (jax .grad (_constraint_penalty , argnums = (0 , 1 )))
860862
861863 if progress not in {"none" , "print" , "bar" }:
862864 raise ValueError (
@@ -865,74 +867,89 @@ def _constraint_penalty(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
865867 progress_enabled = progress != "none" or progress_callback is not None
866868 emit_progress = _make_emit_progress (progress , progress_callback , settings .max_iter )
867869
868- x , y = init_x , init_y
869- state = _init_sgd_state (settings )
870- current_wind = wind
871870 record_list : list [SGDRecord ] = []
872871
872+ def _side_effects (rec : SGDRecord ) -> None :
873+ if record :
874+ record_list .append (rec )
875+ if progress_enabled :
876+ emit_progress (rec )
877+
873878 def _compute_layout_change (
874879 x : jnp .ndarray , y : jnp .ndarray , prev_x : jnp .ndarray , prev_y : jnp .ndarray
875880 ) -> jnp .ndarray :
876881 return jnp .max (jnp .abs (x - prev_x )) + jnp .max (jnp .abs (y - prev_y ))
877882
878- def _is_done (change : jnp .ndarray , it : int | jnp .ndarray ) -> bool :
879- return float (change ) <= settings .tol or int (it ) >= settings .max_iter
883+ def _is_done (change : jnp .ndarray , it : jnp .ndarray ) -> jnp .ndarray :
884+ return (change <= settings .tol ) | (it >= settings .max_iter )
885+
886+ has_wind = wind is not None
887+ init_key = wind ._key if wind is not None else jax .random .PRNGKey (0 )
880888
881- done = _is_done (_compute_layout_change (x , y , x - 1.0 , y - 1.0 ), state .iteration )
882- while not done :
883- if current_wind is not None :
884- current_wind = current_wind .inc_rng ()
885- grad_obj_x , grad_obj_y = grad_obj_fn (x , y , current_wind )
889+ def body_fn (
890+ carry : tuple [jnp .ndarray , jnp .ndarray , SGDState , jnp .ndarray , jnp .ndarray ],
891+ ) -> tuple [jnp .ndarray , jnp .ndarray , SGDState , jnp .ndarray , jnp .ndarray ]:
892+ x , y , state , rng_key , _prev_change = carry
893+
894+ if has_wind :
895+ rng_key , subkey = jax .random .split (rng_key )
896+ assert wind is not None
897+ grad_obj_x , grad_obj_y = grad_obj_fn (x , y , wind .with_key (subkey ))
886898 else :
887899 grad_obj_x , grad_obj_y = grad_obj_fn (x , y )
888- grad_con_x , grad_con_y = grad_con_fn (x , y )
889900
901+ grad_con_x , grad_con_y = grad_con_fn (x , y )
890902 d_x , d_y = _project_gradient (grad_obj_x , grad_obj_y , grad_con_x , grad_con_y )
891903 delta_x , delta_y , new_state = _sgd_step (state , d_x , d_y , settings )
892904
893- # Feasibility-preserving step: per-turbine bisection.
894- # Each turbine keeps its own scale s_i, initialised to 1. On each
895- # bisection round only the turbines that are still infeasible have
896- # their scale halved, so a turbine far from any constraint is never
897- # held back by one that is close to the boundary.
898- # _project_gradient() already makes the gradient direction tangential
899- # to any active constraint at the current position.
900- step_x , step_y = delta_x , delta_y
901- s = jnp .ones (x .shape [0 ])
902- x_try = x - s * step_x
903- y_try = y - s * step_y
904- pen_i = _boundary_penalty_per_turbine (
905- x_try , y_try , boundary_polys
906- ) + _spacing_penalty_per_turbine (x_try , y_try , min_spacing )
907- while (
908- float (jnp .max (pen_i )) > settings .bisect_tol
909- and float (jnp .min (s )) > settings .bisect_s_min
910- ):
911- s = jnp .where (pen_i > settings .bisect_tol , s / 2.0 , s )
912- x_try = x - s * step_x
913- y_try = y - s * step_y
914- pen_i = _boundary_penalty_per_turbine (
915- x_try , y_try , boundary_polys
916- ) + _spacing_penalty_per_turbine (x_try , y_try , min_spacing )
917-
918- x_new , y_new = x_try , y_try
919-
920- iter_rec = SGDRecord (
921- x = x_new ,
922- y = y_new ,
923- iteration = new_state .iteration ,
924- penalty = _constraint_penalty (x_new , y_new ),
925- change = (change := _compute_layout_change (x_new , y_new , x , y )),
926- learning_rate = new_state .learning_rate ,
927- converged = float (change <= settings .tol ),
928- is_final = (done := _is_done (change , new_state .iteration )),
905+ # Feasibility-preserving per-turbine bisection inside compiled loop.
906+ def bisect_body (
907+ bc : tuple [jnp .ndarray , jnp .ndarray , jnp .ndarray , jnp .ndarray ],
908+ ) -> tuple [jnp .ndarray , jnp .ndarray , jnp .ndarray , jnp .ndarray ]:
909+ s , xt , yt , pen_i = bc
910+ s = jnp .where (pen_i > settings .bisect_tol , s / 2 , s )
911+ xt = x - s * delta_x
912+ yt = y - s * delta_y
913+ return s , xt , yt , _pen_per_turbine (xt , yt )
914+
915+ s0 = jnp .ones (x .shape [0 ])
916+ x0 = x - delta_x
917+ y0 = y - delta_y
918+ _ , x_new , y_new , _ = while_loop (
919+ lambda bc : (jnp .max (bc [3 ]) > settings .bisect_tol )
920+ & (jnp .min (bc [0 ]) > settings .bisect_s_min ),
921+ bisect_body ,
922+ (s0 , x0 , y0 , _pen_per_turbine (x0 , y0 )),
929923 )
930- if record :
931- record_list .append (iter_rec )
932- if progress_enabled :
933- emit_progress (iter_rec )
934924
935- x , y , state = x_new , y_new , new_state
925+ change = _compute_layout_change (x_new , y_new , x , y )
926+ is_final = _is_done (change , jnp .asarray (new_state .iteration ))
927+ if record or progress_enabled :
928+ iter_rec = SGDRecord (
929+ x = x_new ,
930+ y = y_new ,
931+ iteration = new_state .iteration ,
932+ penalty = _constraint_penalty (x_new , y_new ),
933+ change = change ,
934+ learning_rate = new_state .learning_rate ,
935+ converged = change <= settings .tol ,
936+ is_final = is_final ,
937+ )
938+ jax .debug .callback (_side_effects , iter_rec )
939+ return x_new , y_new , new_state , rng_key , change
940+
941+ def cond_fn (
942+ carry : tuple [jnp .ndarray , jnp .ndarray , SGDState , jnp .ndarray , jnp .ndarray ],
943+ ) -> jnp .ndarray :
944+ _ , _ , state , _ , change = carry
945+ return ~ _is_done (change , jnp .asarray (state .iteration ))
946+
947+ init_state = _init_sgd_state (settings )
948+ # Initial change > tol ensures the first iteration always runs.
949+ init_change = jnp .asarray (settings .tol + 1.0 )
950+ x , y , _ , _ , _ = while_loop (
951+ cond_fn , body_fn , (init_x , init_y , init_state , init_key , init_change )
952+ )
936953
937954 if record :
938955 return x , y , record_list
@@ -1150,13 +1167,13 @@ def create_layout_optimizer(
11501167 wind: Wind conditions — a :class:`~pixwake.wind_resource.WeibullWindRose`,
11511168 :class:`~pixwake.wind_resource.TimeSeriesWind`, or
11521169 :class:`~pixwake.wind_resource.SampledWind`. When a
1153- :class:`~pixwake.wind_resource.SampledWind` is provided the solver
1154- draws a fresh sample batch on every gradient step via
1155- :meth:`~pixwake.wind_resource.SampledWind.inc_rng` .
1170+ :class:`~pixwake.wind_resource.SampledWind` is provided a fresh PRNG
1171+ subkey is split on every gradient step, drawing an independent sample
1172+ batch each iteration .
11561173 settings: SGD configuration.
11571174 **kwargs: Additional keyword arguments forwarded to
11581175 :func:`topfarm_sgd_solve` (e.g. ``progress``, ``progress_callback``,
1159- ``record``, ``jit`` ).
1176+ ``record``).
11601177
11611178 Returns:
11621179 Function ``(init_x, init_y) -> (opt_x, opt_y)``. When ``record=True``
0 commit comments