Gradient estimation (also called sensitivity analysis) ranks the attention modules inside the ACE-Step decoder by how much they respond to your specific dataset. Instead of blindly training every q_proj, k_proj, v_proj, and o_proj layer equally, estimation tells you which ones matter most.
Think of it like an X-ray of the model: it shows where the gradients concentrate when your audio is passed through the network.
- Targeted training: Focus the adapter on the layers that actually learn from your data.
- Fewer wasted parameters: If layer 22 barely responds to your dataset, you don't need to train it.
- Better results at lower rank: By selecting only the top-K most sensitive modules, a rank-32 adapter trained on 16 carefully chosen modules can outperform a rank-64 adapter spread across all 80+ modules.
- Dataset comparison: Run estimation on two different datasets and compare -- you'll see where they differ.
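
The dataset comparison can be done by hand, but a few lines of Python make the differences easier to spot. The sketch below is only an illustration: it assumes each result uses the `module` / `sensitivity` record format that estimation writes to JSON, the helper names `load_scores` and `compare` are hypothetical, and the numbers are invented.

```python
import json

def load_scores(text):
    """Parse an estimation result (JSON text) into {module: sensitivity}."""
    return {row["module"]: row["sensitivity"] for row in json.loads(text)}

def compare(scores_a, scores_b):
    """Return (module, a, b, b - a) tuples, largest absolute change first."""
    rows = []
    for name in set(scores_a) | set(scores_b):
        # Assumption: a module absent from one file is treated as 0.0,
        # although it may simply have fallen outside that run's top-K.
        a = scores_a.get(name, 0.0)
        b = scores_b.get(name, 0.0)
        rows.append((name, a, b, b - a))
    # Sort by size of change, then by name so the order is deterministic.
    return sorted(rows, key=lambda r: (-abs(r[3]), r[0]))

# Hypothetical results for two datasets (values invented for illustration).
DATASET_A = """[
  {"module": "decoder.layers.22.self_attn.q_proj", "sensitivity": 0.042},
  {"module": "decoder.layers.18.cross_attn.k_proj", "sensitivity": 0.035},
  {"module": "decoder.layers.2.cross_attn.k_proj", "sensitivity": 0.002}
]"""
DATASET_B = """[
  {"module": "decoder.layers.22.self_attn.q_proj", "sensitivity": 0.020},
  {"module": "decoder.layers.18.cross_attn.k_proj", "sensitivity": 0.036},
  {"module": "decoder.layers.5.self_attn.o_proj", "sensitivity": 0.015}
]"""

diff = compare(load_scores(DATASET_A), load_scores(DATASET_B))
for name, a, b, delta in diff:
    print(f"{name:45s} {a:.3f} -> {b:.3f} ({delta:+.3f})")
```

With real runs you would read each saved file (for example with `pathlib.Path(...).read_text()`) instead of the inline strings. Modules at the top of the output are where the two datasets pull the model in the most different directions.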

To run estimation from the interactive menu, start the trainer:

```bash
uv run python train.py
```

- From the main menu, select Estimate gradient sensitivity
- Point it to your checkpoint directory and preprocessed dataset
- Adjust the parameters (or press Enter for defaults)
- Review the results and save the JSON

To run it directly from the command line:

```bash
uv run python train.py estimate \
    --checkpoint-dir ../ACE-Step-1.5/checkpoints \
    --model-variant base \
    --dataset-dir ./my_tensors \
    --estimate-batches 5 \
    --top-k 16 \
    --granularity module \
    --estimate-output ./estimate_results.json
```

Estimation produces a JSON file with a ranked list:

```json
[
  {"module": "decoder.layers.22.self_attn.q_proj", "sensitivity": 0.04231},
  {"module": "decoder.layers.22.self_attn.v_proj", "sensitivity": 0.03894},
  {"module": "decoder.layers.18.cross_attn.k_proj", "sensitivity": 0.03512},
  ...
]
```

| Field | Meaning |
|---|---|
| `module` | Full dot-path name of the attention projection inside the decoder |
| `sensitivity` | Average gradient norm across estimation batches (higher = more responsive) |

Higher sensitivity = more important for your dataset. The modules at the top of the list are where the model "wants" to change the most when it sees your audio.
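
To make "average gradient norm across estimation batches" concrete, here is a minimal sketch in plain Python. It is not the tool's actual code: it assumes the norm is the L2 norm of each module's flattened gradient, that every module appears in every batch, and the gradient values are toy numbers. The function names are hypothetical.

```python
import math

def l2_norm(values):
    """L2 norm of a flattened gradient (assumed norm type)."""
    return math.sqrt(sum(v * v for v in values))

def average_gradient_norms(batches):
    """batches: list of {module name: flattened gradient as a list of floats}.

    Returns {module name: mean gradient norm over all batches}.
    """
    totals = {}
    for grads in batches:
        for name, grad in grads.items():
            totals[name] = totals.get(name, 0.0) + l2_norm(grad)
    return {name: total / len(batches) for name, total in totals.items()}

def rank(scores, top_k):
    """Return the top_k (module, score) pairs, most sensitive first."""
    ordered = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    return ordered[:top_k]

# Toy gradients for two modules over two batches (invented values).
toy_batches = [
    {"decoder.layers.22.self_attn.q_proj": [3.0, 4.0],
     "decoder.layers.2.cross_attn.k_proj": [0.0, 1.0]},
    {"decoder.layers.22.self_attn.q_proj": [6.0, 8.0],
     "decoder.layers.2.cross_attn.k_proj": [1.0, 0.0]},
]

scores = average_gradient_norms(toy_batches)
for name, score in rank(scores, top_k=2):
    print(f"{name}: {score:.3f}")
```

Averaging over several batches is what smooths out batch-to-batch noise, which is why a larger `--estimate-batches` value gives a more stable ranking.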
ACE-Step's decoder is a stack of transformer layers. Each layer has attention blocks, and each attention block has four linear projections:

| Projection | Role |
|---|---|
| `q_proj` | Query -- what the model is looking for |
| `k_proj` | Key -- what each position offers |
| `v_proj` | Value -- the actual content to read |
| `o_proj` | Output -- projects the attention result back |

The attention blocks come in two types:

| Type | Path Pattern | What It Does |
|---|---|---|
| Self-attention | `decoder.layers.N.self_attn.*` | Relates audio positions to each other (rhythm, structure, patterns) |
| Cross-attention | `decoder.layers.N.cross_attn.*` | Connects audio to text conditioning (lyrics, genre, prompt) |

Interpretation tips:
- If self-attention modules rank high, your dataset has distinctive audio patterns (rhythms, timbres, structures) the model wants to learn.
- If cross-attention modules rank high, the text/lyrics conditioning is strongly tied to the audio -- the model is learning text-to-audio alignment.
- If a specific range of layers dominates (e.g., layers 18-22), those are the depths where your dataset diverges most from the pre-trained weights.
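
The interpretation tips above can be applied mechanically by totalling sensitivity per attention type and per layer. The following is a rough sketch, assuming module names follow the `decoder.layers.N.self_attn.*` / `decoder.layers.N.cross_attn.*` pattern; `summarize` is a hypothetical helper, and the three rows are the ones from the sample output.

```python
import re
from collections import defaultdict

# Matches e.g. "decoder.layers.22.self_attn.q_proj"
MODULE_RE = re.compile(r"decoder\.layers\.(\d+)\.(self_attn|cross_attn)\.(\w+)$")

def summarize(ranked):
    """Sum sensitivity per attention type and per layer index."""
    by_type = defaultdict(float)
    by_layer = defaultdict(float)
    for row in ranked:
        match = MODULE_RE.match(row["module"])
        if match is None:
            continue  # skip names that do not follow the assumed pattern
        layer, attn_type, _proj = match.groups()
        by_type[attn_type] += row["sensitivity"]
        by_layer[int(layer)] += row["sensitivity"]
    return dict(by_type), dict(by_layer)

ranked = [
    {"module": "decoder.layers.22.self_attn.q_proj", "sensitivity": 0.04231},
    {"module": "decoder.layers.22.self_attn.v_proj", "sensitivity": 0.03894},
    {"module": "decoder.layers.18.cross_attn.k_proj", "sensitivity": 0.03512},
]

by_type, by_layer = summarize(ranked)
print("dominant attention type:", max(by_type, key=by_type.get))
print("most sensitive layer:", max(by_layer, key=by_layer.get))
```

Summed totals favour whichever group has more entries in the list, so treat the output as a quick orientation and not a precise measurement.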

| Granularity | `--granularity` | What It Ranks | When To Use |
|---|---|---|---|
| Module | `module` (default) | Individual projections (`q_proj`, `k_proj`, etc.) | Fine-grained selection, small datasets, precise control |
| Layer | `layer` | Entire attention blocks (`self_attn`, `cross_attn`) | Quick overview, large datasets, coarse selection |

Module-level is almost always the better choice. It lets you pick exactly which projections to target. Layer-level is useful as a quick first pass to see which depth regions of the decoder are most active.
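
If you only have module-level results, a coarse block-level view can be approximated by grouping on the path prefix. This sketch assumes a block's score is the mean of its listed projections; how `--granularity layer` actually aggregates is not specified here, so the two may not match. `roll_up_to_blocks` is a hypothetical helper.

```python
from collections import defaultdict

def roll_up_to_blocks(ranked):
    """Group module scores by attention block and average them."""
    groups = defaultdict(list)
    for row in ranked:
        # "decoder.layers.22.self_attn.q_proj" -> "decoder.layers.22.self_attn"
        block, _, _proj = row["module"].rpartition(".")
        groups[block].append(row["sensitivity"])
    blocks = [
        {"block": block, "sensitivity": sum(vals) / len(vals)}
        for block, vals in groups.items()
    ]
    return sorted(blocks, key=lambda r: r["sensitivity"], reverse=True)

# Illustrative module-level scores.
ranked = [
    {"module": "decoder.layers.22.self_attn.q_proj", "sensitivity": 0.042},
    {"module": "decoder.layers.22.self_attn.v_proj", "sensitivity": 0.039},
    {"module": "decoder.layers.18.cross_attn.k_proj", "sensitivity": 0.035},
    {"module": "decoder.layers.18.cross_attn.v_proj", "sensitivity": 0.033},
    {"module": "decoder.layers.20.self_attn.q_proj", "sensitivity": 0.031},
]

blocks = roll_up_to_blocks(ranked)
for entry in blocks:
    print(f'{entry["block"]:32s} {entry["sensitivity"]:.4f}')
```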
After estimation, the top-K modules tell you which projections to target. For example, if the top 8 modules are all q_proj and v_proj in layers 18-24:
- You might set `--target-modules "q_proj v_proj"` (skip `k_proj` and `o_proj`)
- Or focus rank on those specific layers
Suppose estimation returns:
#1 decoder.layers.22.self_attn.q_proj 0.042
#2 decoder.layers.22.self_attn.v_proj 0.039
#3 decoder.layers.18.cross_attn.k_proj 0.035
#4 decoder.layers.18.cross_attn.v_proj 0.033
#5 decoder.layers.20.self_attn.q_proj 0.031
...
#12 decoder.layers.5.self_attn.o_proj 0.008
#16 decoder.layers.2.cross_attn.k_proj 0.002
What this tells you:
- Layers 18-22 are the most sensitive -- your dataset is "different" from the pre-trained model at those depths
- Self-attention dominates -- the model wants to learn audio patterns more than text alignment
- Layer 2 barely responds -- it's already general enough and doesn't need fine-tuning
- `q_proj` and `v_proj` take four of the top five spots -- queries and values carry most of the signal
Action: You could train with --target-modules "q_proj v_proj" and expect strong results even at lower rank, since you're focusing on what matters.
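
The step from ranking to `--target-modules` can also be scripted. The sketch below is one possible heuristic, not something the tool does for you: it keeps any projection name that appears at least `min_count` times among the `top_k` entries. Both thresholds and the helper name `suggest_target_modules` are assumptions for illustration; the scores are taken from the example ranking.

```python
from collections import Counter

def suggest_target_modules(ranked, top_k=5, min_count=2):
    """Return projection names that recur in the top_k most sensitive modules."""
    top = sorted(ranked, key=lambda r: r["sensitivity"], reverse=True)[:top_k]
    # The projection name is the last path component, e.g. "q_proj".
    counts = Counter(row["module"].rsplit(".", 1)[-1] for row in top)
    keep = sorted(name for name, n in counts.items() if n >= min_count)
    return " ".join(keep)

ranked = [
    {"module": "decoder.layers.22.self_attn.q_proj", "sensitivity": 0.042},
    {"module": "decoder.layers.22.self_attn.v_proj", "sensitivity": 0.039},
    {"module": "decoder.layers.18.cross_attn.k_proj", "sensitivity": 0.035},
    {"module": "decoder.layers.18.cross_attn.v_proj", "sensitivity": 0.033},
    {"module": "decoder.layers.20.self_attn.q_proj", "sensitivity": 0.031},
    {"module": "decoder.layers.5.self_attn.o_proj", "sensitivity": 0.008},
    {"module": "decoder.layers.2.cross_attn.k_proj", "sensitivity": 0.002},
]

flag = f'--target-modules "{suggest_target_modules(ranked)}"'
print(flag)
```

A frequency threshold ignores how large each score is, so a single very sensitive `k_proj` would be dropped. Check the suggestion against the ranked list before training.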

| Parameter | Default | Guidance |
|---|---|---|
| `--estimate-batches` | 5 | More batches = more stable ranking. 3-5 is enough for small datasets; 10+ for large/diverse ones. |
| `--top-k` | 16 | How many modules to highlight. 8-16 is a good range. Beyond 32 you're training most of the model anyway. |
| `--granularity` | `module` | Use `module` unless you want a quick layer-level overview first. |

Estimation loads the full model and runs forward + backward passes, similar to training. Budget the same VRAM you would for training:
| GPU VRAM | Recommended --estimate-batches |
|---|---|
| 8 GB | 3 |
| 12 GB | 5 |
| 24 GB | 10 |
| 48 GB | 10-20 |
Estimation is fast -- typically 1-3 minutes regardless of batch count.
- [[Training Guide]] -- Full training workflow and hyperparameter guide
- [[Model Management]] -- Checkpoint structure and model selection