Skip to content

Commit dde1de3

Browse files
committed
consistency change
1 parent e32950d commit dde1de3

File tree

1 file changed

+2
-28
lines changed

1 file changed

+2
-28
lines changed

jaxkan/models/utils.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -497,39 +497,13 @@ def get_lbfgs(
497497
Returns:
498498
optax.GradientTransformationExtraArgs:
499499
Configured L-BFGS optimizer.
500-
501-
Example:
502-
>>> from jaxkan.models.utils import get_lbfgs
503-
>>> from jaxkan.models.KAN import KAN
504-
>>> from flax import nnx
505-
>>> import jax.numpy as jnp
506-
507-
>>> # Create model
508-
>>> model = KAN([2, 5, 1], 'spline', {'k': 3, 'G': 5}, 42)
509-
>>>
510-
>>> # Create L-BFGS optimizer
511-
>>> optimizer_tx = get_lbfgs(memory_size=10)
512-
>>> optimizer = nnx.Optimizer(model, optimizer_tx, wrt=nnx.Param)
513-
>>>
514-
>>> # Define loss function
515-
>>> def loss_fn(model):
516-
... # Your loss computation here
517-
... return jnp.sum(model(x) ** 2)
518-
>>>
519-
>>> # Training step with L-BFGS
520-
>>> def train_step(model, optimizer):
521-
... loss, grads = nnx.value_and_grad(loss_fn)(model)
522-
... # L-BFGS requires value and value_fn (model and grads are positional)
523-
... optimizer.update(model, grads, value=loss, value_fn=loss_fn)
524-
... return loss
525500
"""
526-
import optax
527501

528-
optimizer = optax.lbfgs(
502+
tx = optax.lbfgs(
529503
learning_rate=learning_rate,
530504
memory_size=memory_size,
531505
scale_init_precond=scale_init_precond,
532506
linesearch=linesearch
533507
)
534508

535-
return optimizer
509+
return tx

0 commit comments

Comments
 (0)