NaN Weights in Mamba Projector During Initialization with Mixed Precision #10

Description

@aiden1020

Hello xinding-sys,

First, thank you for your great work on this project.

I am currently trying to fine-tune the DAMO-NLP-SG/VideoLLaMA2-7B model using the provided code. My setup uses clip-vit-large-patch14-336 as the vision_tower and the default Video_Mamba_seq module as the mm_projector.

When I enable mixed precision (bf16=True or fp16=True) for training, the loss becomes NaN on the very first forward pass (step 0).

After extensive debugging, I've traced the root cause to the weights of the Video_Mamba_seq module. The module's parameters already contain NaN values immediately after initialization, before any training begins. This appears to be caused by a numerical instability in the project's custom weight initialization logic when operating in a mixed-precision environment.

Below are the detailed environment settings, reproduction steps, and debugging logs to help locate the issue.

Environment

  • OS: Ubuntu 24.04
  • PyTorch: 2.10.0+cu130
  • Transformers: 4.40.0
  • mamba_ssm: 2.2.5
  • CUDA Version: 13.0
  • GPU: RTX PRO 6000

Steps to Reproduce

  1. Set up the training script with the following base models:
    --model_name_or_path DAMO-NLP-SG/VideoLLaMA2-7B \
    --vision_tower clip-vit-large-patch14-336 \
  2. Ensure the configuration uses the default mm_projector_type='mamba'.
  3. Enable mixed precision in TrainingArguments (e.g., bf16=True).
  4. Place a breakpoint in the initialize_vision_modules function, immediately after the self.mm_projector = build_vision_projector(self.config) line is called.
  5. Launch the training script.
  6. At the PDB breakpoint, run the following command to inspect the mm_projector weights:
    for name, param in self.mm_projector.named_parameters():
        print(f"{name}: {torch.isnan(param).any()}")
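For convenience, the step-6 check can be expanded into a small helper that also reports the dtype and NaN/Inf counts of each offending parameter (a generic sketch I used while debugging, not code from the repository):

```python
import torch
import torch.nn as nn

def audit_nan_params(module: nn.Module) -> list:
    """Return the names of parameters containing NaN or Inf values."""
    bad = []
    for name, param in module.named_parameters():
        if not torch.isfinite(param).all():
            bad.append(name)
            print(f"{name}: dtype={param.dtype}, "
                  f"nan={torch.isnan(param).sum().item()}, "
                  f"inf={torch.isinf(param).sum().item()}")
    return bad

# Usage on a toy module with one poisoned weight:
m = nn.Linear(4, 4)
with torch.no_grad():
    m.weight[0, 0] = float("nan")
print(audit_nan_params(m))  # ['weight']
```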

Actual Behavior

1. High-Level Symptom: NaN Loss in Forward Pass
The training fails at step 0 with a NaN loss, as seen in the forward pass logs:

--- [FWD LOG] Received output from Core LLM ---
[FWD LOG]   - LLM Loss: nan
[FWD LOG]   - LLM Logits shape: torch.Size([1, 11105, 32000])

2. Root Cause: NaN Weights at Initialization
The debugging breakpoint reveals that multiple weights inside the mm_projector (Video_Mamba_seq) module are NaN immediately after initialization.

A. mamba_model's NaN Weights:
Key parameters within the core Mamba module are initialized as NaN. Here are some examples:

Checking mamba_model weight ssms.0.norm.bias: True
Checking mamba_model weight ssms.0.mixer.A_log: True
Checking mamba_model weight ssms.0.mixer.D: True
Checking mamba_model weight ssms.0.mixer.conv1d.bias: True
Checking mamba_model weight norm_fn.weight: True
Checking mamba_model weight norm_fn.bias: True
... (and others)

B. cls_net's NaN Weights:
Surprisingly, LayerNorm weights within the associated cls_net also exhibit the same issue.

Checking projector weight cls_model.model.layers.1.input_layernorm.weight: True
Checking projector weight cls_model.model.layers.1.post_attention_layernorm.weight: True
Checking projector weight cls_model.model.layers.3.post_attention_layernorm.weight: True
Checking projector weight cls_model.model.norm.weight: True
... (and others)

Expected Behavior

All model weights should be initialized with valid, non-NaN floating-point numbers. The training process should start with a valid, finite loss value.

Specific Questions

  1. Is this a known issue regarding the Mamba module's initialization under mixed precision?
  2. My analysis suggests the custom initialization logic (e.g., for A_log and LayerNorm layers) is the source of the NaN values. Could you please confirm if this diagnosis is correct?
  3. What is the recommended and most stable way to initialize this Mamba module for fine-tuning to avoid these numerical issues?
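Regarding question 3, one workaround I am considering (a sketch only, not a confirmed fix; `build_fn` stands in for the repository's `build_vision_projector`) is to force all initialization math to run in float32 and cast the finished module down to the training dtype afterwards:

```python
import torch

def build_projector_fp32_then_cast(build_fn, config,
                                   target_dtype=torch.bfloat16):
    """Run weight initialization entirely in float32, then cast the
    finished module to the training dtype, so no init math is ever
    evaluated in a low-precision format."""
    default = torch.get_default_dtype()
    torch.set_default_dtype(torch.float32)
    try:
        projector = build_fn(config)  # init runs in full precision
    finally:
        torch.set_default_dtype(default)
    return projector.to(target_dtype)
```

Would this approach be safe here, or does the Mamba module require a dedicated init path?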

Thank you for your time and any guidance you can provide!
