From 4965d66f4f62a74a698f1a14841be300367b3e18 Mon Sep 17 00:00:00 2001 From: CedricHwong Date: Fri, 26 Dec 2025 05:36:28 +0000 Subject: [PATCH 1/2] feat: add SOAP optimizer support --- docs/examples/config.rst | 3 + .../config/test_build_optimizer_soap.py | 40 ++ verl/optimizers/__init__.py | 17 + verl/optimizers/soap.py | 479 ++++++++++++++++++ verl/trainer/config/optim/fsdp.yaml | 11 + verl/trainer/config/optim/soap.yaml | 72 +++ verl/workers/config/optimizer.py | 2 +- 7 files changed, 623 insertions(+), 1 deletion(-) create mode 100644 tests/workers/config/test_build_optimizer_soap.py create mode 100644 verl/optimizers/__init__.py create mode 100644 verl/optimizers/soap.py create mode 100644 verl/trainer/config/optim/soap.yaml diff --git a/docs/examples/config.rst b/docs/examples/config.rst index 906270aa04e..14847afd90c 100644 --- a/docs/examples/config.rst +++ b/docs/examples/config.rst @@ -680,6 +680,9 @@ Optim - ``wsd``: Warmup-Stable-Decay scheduler that provides a stable learning rate phase between warmup and decay phases. - ``override_optimizer_config``: Dictionary of additional optimizer-specific keyword arguments. For example, to use ``torchao.optim``'s ``_AdamW`` with BF16 stochastic rounding: ``{"bf16_stochastic_round": true}`` +- SOAP example (paper defaults): set ``optimizer_impl: verl.optimizers.soap``, ``optimizer: SOAP``, ``lr: 3e-3``, and + ``betas: [0.95, 0.95]``; pass SOAP-specific arguments (e.g., ``precondition_frequency``, ``max_precond_dim``) via + ``override_optimizer_config``. See ``trainer/config/optim/soap.yaml`` for a full example. Model ~~~~~~~~~~~~ diff --git a/tests/workers/config/test_build_optimizer_soap.py b/tests/workers/config/test_build_optimizer_soap.py new file mode 100644 index 00000000000..0662389230d --- /dev/null +++ b/tests/workers/config/test_build_optimizer_soap.py @@ -0,0 +1,40 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from verl.workers.config.optimizer import FSDPOptimizerConfig, build_optimizer + + +def test_build_optimizer_with_soap(): + model = torch.nn.Linear(2, 2, bias=False) + config = FSDPOptimizerConfig( + lr=3e-3, + betas=(0.95, 0.95), + optimizer="SOAP", + optimizer_impl="verl.optimizers.soap", + override_optimizer_config={ + "precondition_frequency": 2, + "max_precond_dim": 8, + "merge_dims": False, + "precondition_1d": True, + }, + ) + + optimizer = build_optimizer(model.parameters(), config) + assert optimizer.__class__.__name__ == "SOAP" + + loss = model(torch.randn(4, 2)).sum() + loss.backward() + optimizer.step() diff --git a/verl/optimizers/__init__.py b/verl/optimizers/__init__.py new file mode 100644 index 00000000000..a0eb60966e1 --- /dev/null +++ b/verl/optimizers/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .soap import SOAP + +__all__ = ["SOAP"] diff --git a/verl/optimizers/soap.py b/verl/optimizers/soap.py new file mode 100644 index 00000000000..2bfee6625fb --- /dev/null +++ b/verl/optimizers/soap.py @@ -0,0 +1,479 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT +# +# This file is adapted from the SOAP optimizer implementation: +# https://github.com/nikhilvyas/SOAP (soap.py) +# +# MIT License +# +# Copyright (c) 2024 Nikhil Vyas +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from itertools import chain + +import torch +import torch.optim as optim + + +# Parts of the code are modifications of Pytorch's AdamW optimizer +# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py +class SOAP(optim.Optimizer): + """ + Implements SOAP algorithm (https://arxiv.org/abs/2409.11321). + + Parameters: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*, defaults to 0.003): + The learning rate to use. + betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`): + Adam's betas parameters (b1, b2). + shampoo_beta (`float`, *optional*, defaults to -1): + If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead + of betas[1]. + eps (`float`, *optional*, defaults to 1e-08): + Adam's epsilon for numerical stability. + weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient. + precondition_frequency (`int`, *optional*, defaults to 10): + How often to update the preconditioner. + max_precond_dim (`int`, *optional*, defaults to 10000): + Maximum dimension of the preconditioner. + Set to 10000, so that we exclude most common vocab sizes while including layers. + merge_dims (`bool`, *optional*, defaults to `False`): + Whether or not to merge dimensions of the preconditioner. + precondition_1d (`bool`, *optional*, defaults to `False`): + Whether or not to precondition 1D gradients. 
+ normalize_grads (`bool`, *optional*, defaults to `False`): + Whether or not to normalize gradients per layer. + Helps at large precondition_frequency (~100 in our experiments), + but hurts performance at small precondition_frequency (~10 in our experiments). + data_format (`str`, *optional*, defaults to `channels_first`): + Data format of the input for convolutional layers. + Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW. + correct_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias correction in Adam. + """ + + def __init__( + self, + params, + lr: float = 3e-3, + betas=(0.95, 0.95), + shampoo_beta: float = -1, + eps: float = 1e-8, + weight_decay: float = 0.01, + precondition_frequency: int = 10, + max_precond_dim: int = 10000, + merge_dims: bool = False, + precondition_1d: bool = False, + normalize_grads: bool = False, + data_format: str = "channels_first", + correct_bias: bool = True, + ): + defaults = { + "lr": lr, + "betas": betas, + "shampoo_beta": shampoo_beta, + "eps": eps, + "weight_decay": weight_decay, + "precondition_frequency": precondition_frequency, + "max_precond_dim": max_precond_dim, + "merge_dims": merge_dims, + "precondition_1d": precondition_1d, + "normalize_grads": normalize_grads, + "correct_bias": correct_bias, + } + super().__init__(params, defaults) + self._data_format = data_format + + def merge_dims(self, grad, max_precond_dim): + """ + Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to + max_precond_dim. + """ + assert self._data_format in ["channels_first", "channels_last"] + if self._data_format == "channels_last" and grad.dim() == 4: + grad = grad.permute(0, 3, 1, 2) + shape = grad.shape + new_shape = [] + + curr_shape = 1 + for sh in shape: + temp_shape = curr_shape * sh + if temp_shape > max_precond_dim: + if curr_shape > 1: + new_shape.append(curr_shape) + curr_shape = sh + else: + new_shape.append(sh) + curr_shape = 1 + else: + curr_shape = temp_shape + + if curr_shape > 1 or len(new_shape) == 0: + new_shape.append(curr_shape) + + new_grad = grad.reshape(new_shape) + return new_grad + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step. + + Arguments: + closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. + """ + if closure is None: + loss = None + else: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + + state = self.state[p] + + if "step" not in state: + state["step"] = 0 + + # State initialization + if "exp_avg" not in state: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(grad) + + if "Q" not in state: + self.init_preconditioner( + grad, + state, + precondition_frequency=group["precondition_frequency"], + precondition_1d=group["precondition_1d"], + shampoo_beta=(group["shampoo_beta"] if group["shampoo_beta"] >= 0 else group["betas"][1]), + max_precond_dim=group["max_precond_dim"], + merge_dims=group["merge_dims"], + ) + self.update_preconditioner( + grad, + state, + max_precond_dim=group["max_precond_dim"], + merge_dims=group["merge_dims"], + precondition_1d=group["precondition_1d"], + ) + continue # First step is skipped so that we never use the current gradients in the projection. 
+ + # Projecting gradients to the eigenbases of Shampoo's preconditioner + # i.e. projecting to the eigenbases of matrices in state['GG'] + grad_projected = self.project( + grad, + state, + merge_dims=group["merge_dims"], + max_precond_dim=group["max_precond_dim"], + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1)) + exp_avg_sq.mul_(beta2).add_(grad_projected.square(), alpha=(1.0 - beta2)) + + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner + # i.e. projecting to the eigenbases of matrices in state['GG'] + # exp_avg_projected = self.project(exp_avg, state, merge_dims=group["merge_dims"], + # max_precond_dim=group['max_precond_dim']) + exp_avg_projected = exp_avg + + step_size = group["lr"] + if group["correct_bias"]: + bias_correction1 = 1.0 - beta1 ** (state["step"]) + bias_correction2 = 1.0 - beta2 ** (state["step"]) + step_size = step_size * (bias_correction2**0.5) / bias_correction1 + + # Projecting back the preconditioned (by Adam) exponential moving average of gradients + # to the original space + norm_grad = self.project_back( + exp_avg_projected / denom, + state, + merge_dims=group["merge_dims"], + max_precond_dim=group["max_precond_dim"], + ) + + if group["normalize_grads"]: + norm_grad = norm_grad / (1e-30 + torch.mean(norm_grad**2) ** 0.5) + + p.add_(norm_grad, alpha=-step_size) + + # From AdamW code: Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. + # Add weight decay at the end (fixed version) + if group["weight_decay"] > 0.0: + p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) + + # Update is done after the gradient step to avoid using current gradients in the projection. + self.update_preconditioner( + grad, + state, + max_precond_dim=group["max_precond_dim"], + merge_dims=group["merge_dims"], + precondition_1d=group["precondition_1d"], + ) + + return loss + + def init_preconditioner( + self, + grad, + state, + precondition_frequency=10, + shampoo_beta=0.95, + max_precond_dim=10000, + precondition_1d=False, + merge_dims=False, + ): + """ + Initializes the preconditioner matrices (L and R in the paper). + """ + state["GG"] = [] # Will hold all the preconditioner matrices (L and R in the paper). + if grad.dim() == 1: + if not precondition_1d or grad.shape[0] > max_precond_dim: + state["GG"].append([]) + else: + state["GG"].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device)) + else: + if merge_dims: + grad = self.merge_dims(grad, max_precond_dim) + + for sh in grad.shape: + if sh > max_precond_dim: + state["GG"].append([]) + else: + state["GG"].append(torch.zeros(sh, sh, device=grad.device)) + + state["Q"] = None # Will hold all the eigenbases of the preconditioner. 
+ state["precondition_frequency"] = precondition_frequency + state["shampoo_beta"] = shampoo_beta + + def project(self, grad, state, merge_dims=False, max_precond_dim=10000): + """ + Projects the gradient to the eigenbases of the preconditioner. + """ + original_shape = grad.shape + if merge_dims: + if grad.dim() == 4 and self._data_format == "channels_last": + permuted_shape = grad.permute(0, 3, 1, 2).shape + grad = self.merge_dims(grad, max_precond_dim) + + for mat in state["Q"]: + if len(mat) > 0: + grad = torch.tensordot( + grad, + mat, + dims=[[0], [0]], + ) + else: + permute_order = list(range(1, len(grad.shape))) + [0] + grad = grad.permute(permute_order) + + if merge_dims: + if self._data_format == "channels_last" and len(original_shape) == 4: + grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1) + else: + grad = grad.reshape(original_shape) + return grad + + def update_preconditioner(self, grad, state, max_precond_dim=10000, merge_dims=False, precondition_1d=False): + """ + Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper). + """ + if state["Q"] is not None: + state["exp_avg"] = self.project_back( + state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim + ) + if grad.dim() == 1: + if precondition_1d and grad.shape[0] <= max_precond_dim: + state["GG"][0].lerp_(grad.unsqueeze(1) @ grad.unsqueeze(0), 1 - state["shampoo_beta"]) + else: + if merge_dims: + new_grad = self.merge_dims(grad, max_precond_dim) + for idx, sh in enumerate(new_grad.shape): + if sh <= max_precond_dim: + outer_product = torch.tensordot( + new_grad, + new_grad, + dims=[[*chain(range(idx), range(idx + 1, len(new_grad.shape)))]] * 2, + ) + state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"]) + else: + for idx, sh in enumerate(grad.shape): + if sh <= max_precond_dim: + outer_product = torch.tensordot( + grad, + grad, + # Contracts across all dimensions except for k. + dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2, + ) + state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"]) + + if state["Q"] is None: + state["Q"] = self.get_orthogonal_matrix(state["GG"]) + if state["step"] > 0 and state["step"] % state["precondition_frequency"] == 0: + state["Q"] = self.get_orthogonal_matrix_QR(state, max_precond_dim, merge_dims) + # state['Q'] = self.get_fast_QR(state, max_precond_dim, merge_dims) + + if state["step"] > 0: + state["exp_avg"] = self.project( + state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim + ) + + def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000): + """ + Projects the gradient back to the original space. + """ + original_shape = grad.shape + if merge_dims: + if self._data_format == "channels_last" and grad.dim() == 4: + permuted_shape = grad.permute(0, 3, 1, 2).shape + grad = self.merge_dims(grad, max_precond_dim) + for mat in state["Q"]: + if len(mat) > 0: + grad = torch.tensordot( + grad, + mat, + dims=[[0], [1]], + ) + else: + permute_order = list(range(1, len(grad.shape))) + [0] + grad = grad.permute(permute_order) + + if merge_dims: + if self._data_format == "channels_last" and len(original_shape) == 4: + grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1) + else: + grad = grad.reshape(original_shape) + return grad + + def get_orthogonal_matrix(self, mat): + """ + Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition. 
+ """ + matrix = [] + for m in mat: + if len(m) == 0: + matrix.append([]) + continue + if m.data.dtype != torch.float: + float_data = False + original_type = m.data.dtype + original_device = m.data.device + matrix.append(m.data.float()) + else: + float_data = True + matrix.append(m.data) + + final = [] + for m in matrix: + if len(m) == 0: + final.append([]) + continue + try: + _, Q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device)) + except: + _, Q = torch.linalg.eigh(m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device)) + Q = Q.to(m.dtype) + Q = torch.flip(Q, [1]) + + if not float_data: + Q = Q.to(original_device).type(original_type) + final.append(Q) + return final + + def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False): + """ + Computes the eigenbases of the preconditioner using one round of power iteration + followed by torch.linalg.qr decomposition. + """ + precond_list = state["GG"] + orth_list = state["Q"] + + matrix = [] + orth_matrix = [] + for m, o in zip(precond_list, orth_list): + if len(m) == 0: + matrix.append([]) + orth_matrix.append([]) + continue + if m.data.dtype != torch.float: + float_data = False + original_type = m.data.dtype + original_device = m.data.device + matrix.append(m.data.float()) + orth_matrix.append(o.data.float()) + else: + float_data = True + matrix.append(m.data.float()) + orth_matrix.append(o.data.float()) + + orig_shape = state["exp_avg_sq"].shape + if self._data_format == "channels_last" and len(orig_shape) == 4: + permuted_shape = state["exp_avg_sq"].permute(0, 3, 1, 2).shape + if merge_dims: + exp_avg_sq = self.merge_dims(state["exp_avg_sq"], max_precond_dim) + else: + exp_avg_sq = state["exp_avg_sq"] + + final = [] + for ind, (m, o) in enumerate(zip(matrix, orth_matrix)): + if len(m) == 0: + final.append([]) + continue + est_eig = torch.diag(o.T @ m @ o) + sort_idx = torch.argsort(est_eig, descending=True) + exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx) + o = o[:, sort_idx] + power_iter = m @ o + Q, _ = torch.linalg.qr(power_iter) + + if not float_data: + Q = Q.to(original_device).type(original_type) + final.append(Q) + + if merge_dims: + if self._data_format == "channels_last" and len(orig_shape) == 4: + exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1) + else: + exp_avg_sq = exp_avg_sq.reshape(orig_shape) + + state["exp_avg_sq"] = exp_avg_sq + return final diff --git a/verl/trainer/config/optim/fsdp.yaml b/verl/trainer/config/optim/fsdp.yaml index a7dd99b1ee2..f7dbba4bca4 100644 --- a/verl/trainer/config/optim/fsdp.yaml +++ b/verl/trainer/config/optim/fsdp.yaml @@ -47,4 +47,15 @@ warmup_style: null # optimizer: _AdamW # override_optimizer_config: # bf16_stochastic_round: true +# Example for SOAP optimizer (paper defaults): +# optimizer_impl: verl.optimizers.soap +# optimizer: SOAP +# lr: 3e-3 +# betas: [0.95, 0.95] +# override_optimizer_config: +# precondition_frequency: 10 +# max_precond_dim: 10000 +# merge_dims: false +# precondition_1d: false +# For a full SOAP config, see trainer/config/optim/soap.yaml. 
override_optimizer_config: null diff --git a/verl/trainer/config/optim/soap.yaml b/verl/trainer/config/optim/soap.yaml new file mode 100644 index 00000000000..8f3f5d77517 --- /dev/null +++ b/verl/trainer/config/optim/soap.yaml @@ -0,0 +1,72 @@ +# Target class for this configuration +_target_: verl.workers.config.FSDPOptimizerConfig + +# Optimizer class name (e.g., "AdamW", "AdamW8bit", "_AdamW", "Adam") +optimizer: SOAP + +# Module path to import optimizer +# Examples: "torch.optim", "torchao.optim", "bitsandbytes.optim" +optimizer_impl: verl.optimizers.soap + +# Learning rate (paper default) +lr: 3e-3 + +# LR warmup steps ratio +lr_warmup_steps_ratio: 0.0 + +# Total training steps +total_training_steps: -1 + +# Weight decay (paper default) +weight_decay: 0.01 + +# LR warmup steps +lr_warmup_steps: -1 + +# Betas for Adam optimizer (paper default) +betas: [0.95, 0.95] + +# Clip gradient +clip_grad: 1.0 + +# Minimum LR ratio for cosine schedule +min_lr_ratio: 0.0 + +# Number of cosine cycles in LR schedule +num_cycles: 0.5 + +# LR scheduler type: "constant" or "cosine" +lr_scheduler_type: constant + +# deprecated +warmup_style: null + +# Additional optimizer-specific keyword arguments (paper defaults) +override_optimizer_config: + + # Use betas[1] when < 0; otherwise use the explicit value + shampoo_beta: -1 + + # Adam epsilon for numerical stability + eps: 1e-8 + + # How often to update the preconditioner + precondition_frequency: 10 + + # Maximum dimension of the preconditioner + max_precond_dim: 10000 + + # Merge dimensions when product is <= max_precond_dim + merge_dims: false + + # Whether to precondition 1D gradients + precondition_1d: false + + # Normalize gradients per layer + normalize_grads: false + + # Data format for convolutional layers: channels_first | channels_last + data_format: channels_first + + # Whether to use bias correction in Adam + correct_bias: true diff --git a/verl/workers/config/optimizer.py b/verl/workers/config/optimizer.py index bdb87667c25..ec1bffddf03 100644 --- a/verl/workers/config/optimizer.py +++ b/verl/workers/config/optimizer.py @@ -177,7 +177,7 @@ def build_optimizer(parameters, config: FSDPOptimizerConfig): } optimizer_name_lower = config.optimizer.lower() - if "adam" in optimizer_name_lower or "ademamix" in optimizer_name_lower: + if "adam" in optimizer_name_lower or "ademamix" in optimizer_name_lower or optimizer_name_lower == "soap": optimizer_args["betas"] = config.betas if config.override_optimizer_config is not None: From 44682a3c00ad9a40b98f05fe31567231318faa56 Mon Sep 17 00:00:00 2001 From: CedricHwong Date: Fri, 26 Dec 2025 16:06:08 +0000 Subject: [PATCH 2/2] Fix SOAP eigh fallback exception --- verl/optimizers/soap.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/verl/optimizers/soap.py b/verl/optimizers/soap.py index 2bfee6625fb..4e889efebe4 100644 --- a/verl/optimizers/soap.py +++ b/verl/optimizers/soap.py @@ -409,7 +409,7 @@ def get_orthogonal_matrix(self, mat): continue try: _, Q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device)) - except: + except Exception: _, Q = torch.linalg.eigh(m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device)) Q = Q.to(m.dtype) Q = torch.flip(Q, [1]) @@ -429,7 +429,7 @@ def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=Fals matrix = [] orth_matrix = [] - for m, o in zip(precond_list, orth_list): + for m, o in zip(precond_list, orth_list, strict=False): if len(m) == 0: matrix.append([]) orth_matrix.append([]) @@ -454,7 
+454,7 @@ def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=Fals exp_avg_sq = state["exp_avg_sq"] final = [] - for ind, (m, o) in enumerate(zip(matrix, orth_matrix)): + for ind, (m, o) in enumerate(zip(matrix, orth_matrix, strict=False)): if len(m) == 0: final.append([]) continue
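
For anyone who wants to exercise the new optimizer outside of the config plumbing, below is a minimal sketch that drives `verl.optimizers.SOAP` directly with the paper-default hyperparameters used in `soap.yaml`; the toy model, batch shapes, loss, and step count are illustrative only and not part of the patch.

```python
# Minimal sketch: run a few steps of verl.optimizers.SOAP on a toy model.
# The model, data shapes, and loss below are illustrative assumptions.
import torch

from verl.optimizers import SOAP

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))
optimizer = SOAP(
    model.parameters(),
    lr=3e-3,                    # paper default, matches soap.yaml
    betas=(0.95, 0.95),         # paper default, matches soap.yaml
    weight_decay=0.01,
    precondition_frequency=10,  # refresh the preconditioner eigenbases every 10 steps
)

for _ in range(20):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()  # the very first step only initializes the preconditioner
```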
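
The `merge_dims` helper is the least obvious piece of the port: it greedily folds adjacent gradient dimensions together while their running product stays within `max_precond_dim`, so large multi-dimensional weights get fewer, larger preconditioners. A small sketch of the expected behavior, using a made-up gradient shape for illustration:

```python
# Sketch of SOAP.merge_dims behavior; the (64, 32, 128) shape is illustrative.
# With max_precond_dim=2048 the first two dims merge (64 * 32 = 2048 <= 2048)
# while the last stays separate, so two preconditioners (2048x2048 and 128x128)
# would be tracked instead of three.
import torch

from verl.optimizers import SOAP

opt = SOAP(torch.nn.Linear(4, 4).parameters(), merge_dims=True)
grad = torch.randn(64, 32, 128)
merged = opt.merge_dims(grad, max_precond_dim=2048)
print(merged.shape)  # torch.Size([2048, 128])
```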