Description
Hello xinding-sys,
First, thank you for your great work on this project.
I am currently trying to fine-tune the DAMO-NLP-SG/VideoLLaMA2-7B model using the provided code. My setup uses clip-vit-large-patch14-336 as the vision_tower and the default Video_Mamba_seq module as the mm_projector.
When enabling mixed precision (bf16=True or fp16=True) for training, I've found that the loss becomes NaN on the very first forward pass (step 0).
After extensive debugging, I've traced the root cause to the weights of the Video_Mamba_seq module. The module's parameters already contain NaN values immediately after initialization, before any training begins. This appears to be caused by a numerical instability in the project's custom weight initialization logic when operating in a mixed-precision environment.
Below are the detailed environment settings, reproduction steps, and debugging logs to help locate the issue.
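For reference, the NaN check I used during debugging can be run as a self-contained snippet (the `find_nan_params` helper name is my own, not from the repo):

```python
import torch
import torch.nn as nn

def find_nan_params(module):
    """Return the names of parameters containing NaN or Inf values."""
    return [name for name, p in module.named_parameters()
            if not torch.isfinite(p).all()]

# Toy demonstration: corrupt one weight and confirm it is detected.
layer = nn.Linear(4, 4)
with torch.no_grad():
    layer.weight[0, 0] = float("nan")
print(find_nan_params(layer))  # ['weight']
```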
Environment
- OS: Ubuntu 24.04
- PyTorch: 2.10.0+cu130
- Transformers: 4.40.0
- mamba_ssm: 2.2.5
- CUDA Version: 13.0
- GPU: RTX PRO 6000
Steps to Reproduce
- Set up the training script with the following base models:

  ```
  --model_name_or_path DAMO-NLP-SG/VideoLLaMA2-7B \
  --vision_tower clip-vit-large-patch14-336 \
  ```

- Ensure the configuration uses the default `mm_projector_type='mamba'`.
- Enable mixed precision in `TrainingArguments` (e.g., `bf16=True`).
- Place a breakpoint in the `initialize_vision_modules` function, immediately after the `self.mm_projector = build_vision_projector(self.config)` line is called.
- Launch the training script.
- At the PDB breakpoint, run the following command to inspect the `mm_projector` weights:

  ```python
  for name, param in self.mm_projector.named_parameters():
      print(f"{name}: {torch.isnan(param).any()}")
  ```
Actual Behavior
1. High-Level Symptom: NaN Loss in Forward Pass
The training fails at step 0 with a NaN loss, as seen in the forward pass logs:
--- [FWD LOG] Received output from Core LLM ---
[FWD LOG] - LLM Loss: nan
[FWD LOG] - LLM Logits shape: torch.Size([1, 11105, 32000])
2. Root Cause: NaN Weights at Initialization
The debugging breakpoint reveals that multiple weights inside the mm_projector (Video_Mamba_seq) module are NaN immediately after initialization.
A. mamba_model's NaN Weights:
Key parameters within the core Mamba module are initialized as NaN. Here are some examples:
Checking mamba_model weight ssms.0.norm.bias: True
Checking mamba_model weight ssms.0.mixer.A_log: True
Checking mamba_model weight ssms.0.mixer.D: True
Checking mamba_model weight ssms.0.mixer.conv1d.bias: True
Checking mamba_model weight norm_fn.weight: True
Checking mamba_model weight norm_fn.bias: True
... (and others)
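As background on why `A_log` in particular is fragile: in the stock `mamba_ssm` mixer, `A_log` is computed as the log of a positive-valued `A` tensor. If a custom or mixed-precision init path ever lets an entry of `A` reach zero or below before the log is taken, the result is `-inf` or `NaN`. A minimal illustration (generic PyTorch, not the repo's code):

```python
import torch

# A healthy Mamba-style A is strictly positive, so log(A) is finite.
A = torch.arange(1, 5, dtype=torch.float32)  # [1., 2., 3., 4.]
print(torch.log(A))                          # finite values

# But log of a non-positive entry is -inf or NaN.
A_bad = A.clone()
A_bad[0] = 0.0
print(torch.log(A_bad)[0])                   # -inf
print(torch.log(torch.tensor(-1.0)))         # nan
```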
B. cls_net's NaN Weights:
Surprisingly, LayerNorm weights within the associated cls_net also exhibit the same issue.
Checking projector weight cls_model.model.layers.1.input_layernorm.weight: True
Checking projector weight cls_model.model.layers.1.post_attention_layernorm.weight: True
Checking projector weight cls_model.model.layers.3.post_attention_layernorm.weight: True
Checking projector weight cls_model.model.norm.weight: True
... (and others)
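To separate the two hypotheses (bad init math vs. the bf16 cast itself), a quick control experiment is to build a plain `LayerNorm` in fp32 and then cast it: a simple dtype cast of finite weights should never introduce NaN, which points the blame back at the init logic. (`nan_report` is my own helper name.)

```python
import torch
import torch.nn as nn

def nan_report(module, tag):
    """Print and return the names of parameters containing NaN."""
    bad = [n for n, p in module.named_parameters() if torch.isnan(p).any()]
    print(f"[{tag}] NaN params: {bad}")
    return bad

ln = nn.LayerNorm(16)          # constructed in the default fp32
fp32_bad = nan_report(ln, "fp32")    # expect: []
ln16 = ln.to(torch.bfloat16)   # casting finite weights is safe
bf16_bad = nan_report(ln16, "bf16")  # expect: []
```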
Expected Behavior
All model weights should be initialized with valid, non-NaN floating-point numbers. The training process should start with a valid, finite loss value.
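A cheap guard that would have surfaced this before step 0 is to assert that every parameter is finite right after model construction (a generic sketch; `assert_finite_params` is a name I made up):

```python
import torch
import torch.nn as nn

def assert_finite_params(model, where="after init"):
    """Fail fast if any parameter contains NaN or Inf."""
    for name, p in model.named_parameters():
        if not torch.isfinite(p).all():
            raise RuntimeError(f"Non-finite parameter '{name}' {where}")

# Passes silently for a healthy model:
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
assert_finite_params(model)
```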
Specific Questions
- Is this a known issue regarding the `Mamba` module's initialization under mixed precision?
- My analysis suggests the custom initialization logic (e.g., for `A_log` and `LayerNorm` layers) is the source of the `NaN` values. Could you please confirm whether this diagnosis is correct?
- What is the recommended and most stable way to initialize this Mamba module for fine-tuning to avoid these numerical issues?
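In case it helps others hitting the same thing, the workaround I am currently experimenting with is to force the projector's construction (and therefore all of its init math) to run in float32, then cast the finished module to the training dtype. This is a generic sketch, not the repo's code, and `build_in_fp32` is a name of my own invention:

```python
import torch
import torch.nn as nn

def build_in_fp32(factory, target_dtype=torch.bfloat16):
    """Run a module factory with float32 as the default dtype, then
    cast the finished module to the training dtype."""
    prev = torch.get_default_dtype()
    torch.set_default_dtype(torch.float32)
    try:
        module = factory()  # all init math happens in fp32
    finally:
        torch.set_default_dtype(prev)
    return module.to(target_dtype)

# e.g. wrap the projector factory instead of calling it directly:
proj = build_in_fp32(lambda: nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)))
print(next(proj.parameters()).dtype)  # torch.bfloat16
```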
Thank you for your time and any guidance you can provide!