What Is Gradient Estimation?

Gradient estimation (also called sensitivity analysis) ranks the attention modules inside the ACE-Step decoder by how much they respond to your specific dataset. Instead of blindly training every q_proj, k_proj, v_proj, and o_proj layer equally, estimation tells you which ones matter most.

Think of it like an X-ray of the model: it shows where the gradients concentrate when your audio is passed through the network.
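
The core idea can be sketched in a few lines of PyTorch. This is an illustration of the general technique (accumulate per-module gradient norms over a few batches, average, and rank), not ACE-Step's actual implementation; model, dataloader, and loss_fn are placeholders you would supply.

import torch

def estimate_sensitivity(model: torch.nn.Module, dataloader, loss_fn, num_batches=5):
    # Collect the decoder's attention projections by name.
    targets = {
        name: module
        for name, module in model.named_modules()
        if name.endswith(("q_proj", "k_proj", "v_proj", "o_proj"))
    }
    scores = {name: 0.0 for name in targets}

    model.train()
    for i, batch in enumerate(dataloader):
        if i >= num_batches:
            break
        model.zero_grad()
        loss = loss_fn(model, batch)   # forward pass on a preprocessed batch
        loss.backward()                # gradients now sit on each parameter
        for name, module in targets.items():
            for p in module.parameters():
                if p.grad is not None:
                    scores[name] += p.grad.norm().item()

    # Average over batches and rank: larger gradient norm = more sensitive module.
    ranked = sorted(
        ((name, total / num_batches) for name, total in scores.items()),
        key=lambda item: item[1],
        reverse=True,
    )
    return ranked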


Why It Matters

  • Targeted training: Focus the adapter on the layers that actually learn from your data.
  • Fewer wasted parameters: If layer 22 barely responds to your dataset, you don't need to train it.
  • Better results at lower rank: By selecting only the top-K most sensitive modules, a rank-32 adapter trained on 16 carefully chosen modules can outperform a rank-64 adapter spread across all 80+ modules.
  • Dataset comparison: Run estimation on two different datasets and compare the rankings -- you'll see which modules each dataset drives hardest.

How To Run Estimation

Via the Wizard (Recommended)

uv run python train.py
  1. From the main menu, select Estimate gradient sensitivity
  2. Point it to your checkpoint directory and preprocessed dataset
  3. Adjust the parameters (or press Enter for defaults)
  4. Review the results and save the JSON

Via CLI

uv run python train.py estimate \
    --checkpoint-dir ../ACE-Step-1.5/checkpoints \
    --model-variant base \
    --dataset-dir ./my_tensors \
    --estimate-batches 5 \
    --top-k 16 \
    --granularity module \
    --estimate-output ./estimate_results.json

Reading the Output

Estimation produces a JSON file with a ranked list:

[
  {"module": "decoder.layers.22.self_attn.q_proj", "sensitivity": 0.04231},
  {"module": "decoder.layers.22.self_attn.v_proj", "sensitivity": 0.03894},
  {"module": "decoder.layers.18.cross_attn.k_proj", "sensitivity": 0.03512},
  ...
]

What Each Field Means

| Field | Meaning |
| --- | --- |
| module | Full dot-path name of the attention projection inside the decoder |
| sensitivity | Average gradient norm across estimation batches (higher = more responsive) |

Higher sensitivity = more important for your dataset. The modules at the top of the list are where the model "wants" to change the most when it sees your audio.
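
Because the output is plain JSON, it is easy to inspect programmatically -- a minimal sketch, assuming the file was saved as ./estimate_results.json as in the CLI example above:

import json

with open("estimate_results.json") as f:
    results = json.load(f)

# The list is already ranked; show the five most sensitive modules.
for entry in results[:5]:
    print(f'{entry["sensitivity"]:.5f}  {entry["module"]}')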


Understanding Module Names

ACE-Step's decoder is a stack of transformer layers. Each layer has attention blocks, and each attention block has four linear projections:

| Projection | Role |
| --- | --- |
| q_proj | Query -- what the model is looking for |
| k_proj | Key -- what each position offers |
| v_proj | Value -- the actual content to read |
| o_proj | Output -- projects the attention result back |
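
For intuition, here is a simplified, single-head PyTorch sketch of such a block -- generic attention, not ACE-Step's actual source -- showing where the four projections sit in the computation:

import torch.nn as nn

class AttentionBlock(nn.Module):
    """Simplified attention block showing the four projections."""

    def __init__(self, dim):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)  # query: what the model is looking for
        self.k_proj = nn.Linear(dim, dim)  # key: what each position offers
        self.v_proj = nn.Linear(dim, dim)  # value: the content to read
        self.o_proj = nn.Linear(dim, dim)  # output: projects the result back

    def forward(self, x, context=None):
        # Self-attention attends within x; cross-attention attends to a
        # separate context (e.g., text conditioning).
        context = x if context is None else context
        q, k, v = self.q_proj(x), self.k_proj(context), self.v_proj(context)
        attn = (q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5).softmax(dim=-1)
        return self.o_proj(attn @ v)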

Self-Attention vs Cross-Attention

| Type | Path Pattern | What It Does |
| --- | --- | --- |
| Self-attention | decoder.layers.N.self_attn.* | Relates audio positions to each other (rhythm, structure, patterns) |
| Cross-attention | decoder.layers.N.cross_attn.* | Connects audio to text conditioning (lyrics, genre, prompt) |

Interpretation tips (a small aggregation sketch follows this list):

  • If self-attention modules rank high, your dataset has distinctive audio patterns (rhythms, timbres, structures) the model wants to learn.
  • If cross-attention modules rank high, the text/lyrics conditioning is strongly tied to the audio -- the model is learning text-to-audio alignment.
  • If a specific layer number dominates (e.g., layers 18-22), those are the layers where your dataset diverges most from the pre-trained weights.
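
To check which way your dataset leans, you can total the sensitivities by attention type. A minimal sketch, assuming the results file from the CLI example above:

import json
from collections import defaultdict

with open("estimate_results.json") as f:
    results = json.load(f)

totals = defaultdict(float)
for entry in results:
    kind = "self_attn" if ".self_attn." in entry["module"] else "cross_attn"
    totals[kind] += entry["sensitivity"]

for kind, total in sorted(totals.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{kind}: {total:.4f}")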

Module-Level vs Layer-Level Granularity

| Granularity | --granularity | What It Ranks | When To Use |
| --- | --- | --- | --- |
| Module | module (default) | Individual projections (q_proj, k_proj, etc.) | Fine-grained selection, small datasets, precise control |
| Layer | layer | Entire attention blocks (self_attn, cross_attn) | Quick overview, large datasets, coarse selection |

Module-level is almost always the better choice. It lets you pick exactly which projections to target. Layer-level is useful as a quick first pass to see which depth regions of the decoder are most active.
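
Even if you ran module-level estimation, you can recover the layer-level view afterwards by summing per layer (an illustrative sketch, same assumed results file):

import json
import re
from collections import defaultdict

with open("estimate_results.json") as f:
    results = json.load(f)

per_layer = defaultdict(float)
for entry in results:
    match = re.search(r"layers\.(\d+)\.", entry["module"])
    if match:
        per_layer[int(match.group(1))] += entry["sensitivity"]

# Most active depth regions first.
for layer, total in sorted(per_layer.items(), key=lambda kv: kv[1], reverse=True):
    print(f"layer {layer:2d}: {total:.4f}")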


Using Results for Training

Selecting Target Modules

After estimation, the top-K modules tell you which projections to target. For example, if the top 8 modules are all q_proj and v_proj in layers 18-24:

  • You might set --target-modules "q_proj v_proj" (skip k_proj and o_proj) -- see the sketch after this list
  • Or concentrate the adapter's rank on those specific layers
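
A minimal sketch of deriving that flag value from the results JSON (the file name and the top-8 cutoff are illustrative assumptions):

import json

with open("estimate_results.json") as f:
    results = json.load(f)

# Keep the projection type (q_proj, v_proj, ...) of each of the top-8 modules.
top_projections = {entry["module"].rsplit(".", 1)[-1] for entry in results[:8]}
print('--target-modules "' + " ".join(sorted(top_projections)) + '"')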

Practical Example

Suppose estimation returns:

#1  decoder.layers.22.self_attn.q_proj   0.042
#2  decoder.layers.22.self_attn.v_proj   0.039
#3  decoder.layers.18.cross_attn.k_proj  0.035
#4  decoder.layers.18.cross_attn.v_proj  0.033
#5  decoder.layers.20.self_attn.q_proj   0.031
...
#12 decoder.layers.5.self_attn.o_proj    0.008
#16 decoder.layers.2.cross_attn.k_proj   0.002

What this tells you:

  • Layers 18-22 are the most sensitive -- your dataset is "different" from the pre-trained model at those depths
  • Self-attention dominates -- the model wants to learn audio patterns more than text alignment
  • Layer 2 barely responds -- it's already general enough and doesn't need fine-tuning
  • q_proj and v_proj rank higher than k_proj and o_proj -- queries and values carry the signal

Action: You could train with --target-modules "q_proj v_proj" and expect strong results even at lower rank, since you're focusing on what matters.


Parameter Guide

| Parameter | Default | Guidance |
| --- | --- | --- |
| --estimate-batches | 5 | More batches = more stable ranking. 3-5 is enough for small datasets; 10+ for large/diverse ones. |
| --top-k | 16 | How many modules to highlight. 8-16 is a good range. Beyond 32 you're training most of the model anyway. |
| --granularity | module | Use module unless you want a quick layer-level overview first. |

VRAM Considerations

Estimation loads the full model and runs forward + backward passes, similar to training. Budget the same VRAM you would for training:

| GPU VRAM | Recommended --estimate-batches |
| --- | --- |
| 8 GB | 3 |
| 12 GB | 5 |
| 24 GB | 10 |
| 48 GB | 10-20 |

Estimation is fast -- typically 1-3 minutes regardless of batch count.


See Also

  • [[Training Guide]] -- Full training workflow and hyperparameter guide
  • [[Model Management]] -- Checkpoint structure and model selection