@@ -497,39 +497,13 @@ def get_lbfgs(
     Returns:
         optax.GradientTransformationExtraArgs:
             Configured L-BFGS optimizer.
-
-    Example:
-        >>> from jaxkan.models.utils import get_lbfgs
-        >>> from jaxkan.models.KAN import KAN
-        >>> from flax import nnx
-        >>> import jax.numpy as jnp
-
-        >>> # Create model
-        >>> model = KAN([2, 5, 1], 'spline', {'k': 3, 'G': 5}, 42)
-        >>>
-        >>> # Create L-BFGS optimizer
-        >>> optimizer_tx = get_lbfgs(memory_size=10)
-        >>> optimizer = nnx.Optimizer(model, optimizer_tx, wrt=nnx.Param)
-        >>>
-        >>> # Define loss function
-        >>> def loss_fn(model):
-        ...     # Your loss computation here
-        ...     return jnp.sum(model(x) ** 2)
-        >>>
-        >>> # Training step with L-BFGS
-        >>> def train_step(model, optimizer):
-        ...     loss, grads = nnx.value_and_grad(loss_fn)(model)
-        ...     # L-BFGS requires value and value_fn (model and grads are positional)
-        ...     optimizer.update(model, grads, value=loss, value_fn=loss_fn)
-        ...     return loss
     """
-    import optax

-    optimizer = optax.lbfgs(
+    tx = optax.lbfgs(
         learning_rate=learning_rate,
         memory_size=memory_size,
         scale_init_precond=scale_init_precond,
         linesearch=linesearch
     )

-    return optimizer
+    return tx
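
Since this change drops the doctest from the docstring, a minimal standalone sketch of driving the returned transformation through optax's extra-args protocol may help; the quadratic `loss_fn`, parameter shape, and step count below are illustrative assumptions, not part of jaxkan:

```python
import jax.numpy as jnp
import optax

# Illustrative objective; any differentiable scalar loss works.
def loss_fn(params):
    return jnp.sum((params - 1.0) ** 2)

tx = optax.lbfgs(memory_size=10)  # stands in for get_lbfgs(memory_size=10)
params = jnp.zeros(3)
state = tx.init(params)

# Reuses the value/gradient already computed by the linesearch, avoiding a
# redundant forward/backward pass per step.
value_and_grad = optax.value_and_grad_from_state(loss_fn)

for _ in range(20):
    value, grad = value_and_grad(params, state=state)
    # L-BFGS is a GradientTransformationExtraArgs: its update needs the
    # current value, gradient, and the loss function for the linesearch.
    updates, state = tx.update(
        grad, state, params, value=value, grad=grad, value_fn=loss_fn
    )
    params = optax.apply_updates(params, updates)
```

Using `optax.value_and_grad_from_state` rather than a plain `jax.value_and_grad` is the pattern optax recommends for L-BFGS, since the zoom linesearch typically leaves a fresh value and gradient in the optimizer state.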