Ahmed Heakl, Martin Gubri, Salman Khan, Sangdoo Yun, Seong Joon Oh
Parameter Lab · MBZUAI · NAVER AI Lab · University of Tübingen · Tübingen AI Center
- 24 April 2026: Full code released!
- 3 February 2026: Paper accepted at TTT @ ICLR 2026!
- 25 January 2026: Paper accepted at ICLR 2026!
- 15 October 2025: Paper arXived!
Large Language Models (LLMs) process every token through all layers of a transformer stack, wasting compute on simple queries and lacking flexibility for harder ones that need deeper reasoning.
Dr.LLM (Dynamic Routing of Layers for LLMs) is a retrofittable framework that adds lightweight per-layer routers to pretrained models.
Each router decides whether to skip, execute, or repeat a layer, enabling adaptive depth without retraining or architectural changes.
Routers are trained with explicit supervision from Monte Carlo Tree Search (MCTS), generating high-quality layer configurations that preserve or improve accuracy under a compute budget.
Stabilized with windowed pooling, focal loss, and bottleneck MLPs, Dr.LLM maintains robustness under class imbalance and long sequences.
Results
- On ARC (logic) and DART (math), Dr.LLM improves accuracy by +3.4%p while saving ~5 layers per input.
- Routers generalize to MMLU, GSM8k, AIME, TruthfulQA, SQuADv2, GPQA, PIQA, and AGIEval with only 0.85% accuracy drop.
- Outperforms prior routing methods (LayerSkip, FlexiDepth, MindSkip) by up to +7.7%p.
Dr.LLM equips frozen LLMs for budget-aware, accuracy-driven inference, with no modification to the base weights required.
Our layer routing is based on hidden states. Dr.LLM augments a frozen decoder-only LLM with per-layer routers that decide whether to skip, execute, or repeat each block once. Routers read windowed summaries of hidden states and are trained from MCTS-derived targets.
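The windowed summaries the routers read can be sketched as follows (a minimal numpy sketch; `windowed_summary` is an illustrative name, and the exact pooling scheme is assumed from the description above):

```python
import numpy as np

def windowed_summary(hidden, num_windows=8):
    """Mean-pool a (seq_len, d_model) hidden-state matrix into
    num_windows windows and concatenate them into one router input."""
    seq_len, d_model = hidden.shape
    # Split token positions into num_windows nearly equal chunks.
    chunks = np.array_split(np.arange(seq_len), num_windows)
    pooled = [hidden[idx].mean(axis=0) if len(idx) else np.zeros(d_model)
              for idx in chunks]
    return np.concatenate(pooled)  # shape: (num_windows * d_model,)

hidden = np.random.default_rng(0).normal(size=(37, 64))  # 37 tokens, d_model=64
summary = windowed_summary(hidden)  # 8 windows by default
print(summary.shape)  # (512,)
```

This keeps the router input size fixed regardless of sequence length, which is what lets the same small router handle long sequences.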
Length-aware MCTS is used to collect the supervised training dataset of per-layer routing configurations (skip/execute/repeat). For each input, MCTS explores modified layer paths and retains those that preserve or improve accuracy under a compute budget.
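The retention rule can be illustrated with a toy filter (hypothetical sketch; `keep_path` and its arguments are illustrative names, not the repo's API):

```python
def keep_path(path, is_correct, baseline_correct, num_layers, budget=2):
    """Retain an MCTS-explored layer path only if it fits the compute
    budget (at most budget * L blocks executed) and does not hurt
    correctness relative to the full-depth baseline."""
    if len(path) > budget * num_layers:
        return False  # over budget, discard regardless of accuracy
    return is_correct >= baseline_correct

# 6-layer toy model: a short correct path is kept, an over-budget one is not
print(keep_path([0, 1, 3, 3, 5], True, True, num_layers=6))   # True
print(keep_path(list(range(13)), True, True, num_layers=6))   # False
```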
We evaluate Dr.LLM using lm-eval-harness across in-domain and out-of-domain benchmarks.
Routers are trained and evaluated on ARC-Easy/Challenge (logic) and DART-Math levels 1–5 (multi-step math reasoning), using 4K MCTS-derived execution paths.
| Dataset | Domain | Metric |
|---|---|---|
| ARC-Easy / Challenge | Logic Reasoning | Accuracy |
| DART (levels 1–5) | Math Reasoning | Accuracy |
We test zero-shot transfer on MMLU, GSM8k, AIME24, TruthfulQA, GPQA Diamond, AGIEval, SQuADv2, and PIQA.
All evaluations follow default lm-eval-harness settings (2048 max tokens, greedy decoding).
```bash
git clone https://github.com/parameterlab/dr-llm
cd dr-llm
pip install -r requirements.txt
```

The data generation pipeline uses length-aware MCTS to discover optimal per-layer routing configurations (skip/execute/repeat) for each training example.
Modified model files compatible with the data generation pipeline are provided in data_models/:
```
data_models/
├── modeling_llama.py
├── modeling_qwen2.py
├── modeling_qwen3.py
└── ...
```
These files expose a layer_indices attribute on the base model class, which the MCTS search manipulates at runtime to explore different execution paths; no weight modification is required. See data_models/README.md for instructions on adapting a new model architecture.
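To illustrate how per-layer actions map onto a flat execution path of the kind layer_indices holds, here is a toy sketch (the helper name and the exact list format are assumptions; the action encoding follows this README):

```python
def path_from_actions(actions):
    """Convert per-layer actions {0=skip, 1=execute, 2=repeat} into the
    flat list of layer indices a modified model would run in order."""
    path = []
    for layer, action in enumerate(actions):
        if action == 1:
            path.append(layer)            # execute the block once
        elif action == 2:
            path.extend([layer, layer])   # repeat the block once
        # action == 0: skip, the layer is omitted entirely
    return path

# 6-layer toy model: execute 0, skip 1, repeat 2, execute the rest
actions = [1, 0, 2, 1, 1, 1]
print(path_from_actions(actions))  # [0, 2, 2, 3, 4, 5]
# e.g. model.layer_indices = path_from_actions(actions), then run forward
```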
```bash
python data_generation.py \
    --model meta-llama/Llama-3.2-3B-Instruct \
    --dataset arc,dart \
    --output_dir data/mcts_paths \
    --num_simulations 50 \
    --budget 2
```

| Argument | Default | Description |
|---|---|---|
| `--model` | — | HuggingFace model path or ID |
| `--dataset` | `arc,dart` | Comma-separated list of datasets |
| `--num_simulations` | `50` | MCTS simulations per example |
| `--budget` | `2` | Max path length factor (cap at 2L) |
| `--output_dir` | `data/` | Where to save routing configurations |
Each output file contains MCTS-derived tuples (question, optimal_layer_config, answer) where optimal_layer_config is a vector of {0=skip, 1=execute, 2=repeat} labels of length L (number of layers). These are used directly as supervision targets for router training.
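A record might be serialized and consumed like this (illustrative JSON-lines sketch; the field names are assumptions, only the tuple contents come from the description above):

```python
import json

# Illustrative record following the tuple described above:
# (question, optimal_layer_config, answer)
record = {
    "question": "What is 2 + 2?",
    "optimal_layer_config": [1, 1, 0, 2, 1, 1],  # 0=skip, 1=execute, 2=repeat
    "answer": "4",
}
line = json.dumps(record)          # one JSON line per training example

loaded = json.loads(line)
labels = loaded["optimal_layer_config"]
assert set(labels) <= {0, 1, 2}    # every label is a valid action
# Effective depth: each non-skipped layer once, plus one extra per repeat.
executed = sum(1 for a in labels if a != 0) + labels.count(2)
print(executed)  # 6
```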
Training uses AdamW, 25 epochs, 1×10⁻³ LR, bf16 precision, and a single A100 GPU (40GB), taking under 4 hours with only 4K MCTS-derived examples.
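The warmup-plus-cosine schedule implied by `--warmup_steps` and the cosine LR note can be sketched in plain Python (the exact schedule shape is an assumption):

```python
import math

def lr_at_step(step, total_steps, base_lr=1e-3, warmup_steps=500):
    """Linear warmup to base_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))

print(lr_at_step(250, 10_000))     # halfway through warmup -> 0.0005
print(lr_at_step(10_000, 10_000))  # end of schedule -> 0.0
```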
Modified model files compatible with router training are provided in train_models/:
```
train_models/
├── modeling_llama.py
├── modeling_qwen2.py
├── modeling_qwen3.py
└── ...
```
These files insert a RouterBlock (Linear-GELU-Linear, hidden dim 128) after each transformer block and expose init_routers(), num_windows, and is_static_routing on the base model class. See train_models/README.md for instructions on adapting a new model architecture.
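The RouterBlock shape described above (Linear-GELU-Linear with hidden dim 128 and three output actions) can be sketched in numpy (weights here are randomly initialized for illustration only; the real block is trained):

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

class RouterBlock:
    """Bottleneck MLP: Linear(d_in -> 128) -> GELU -> Linear(128 -> 3).
    Dimensions follow the description above; the init is illustrative."""
    def __init__(self, d_in, hidden=128, n_actions=3, seed=0):
        rng = np.random.default_rng(seed)
        self.w1 = rng.normal(0, 0.02, (d_in, hidden))
        self.b1 = np.zeros(hidden)
        self.w2 = rng.normal(0, 0.02, (hidden, n_actions))
        self.b2 = np.zeros(n_actions)

    def __call__(self, x):
        logits = gelu(x @ self.w1 + self.b1) @ self.w2 + self.b2
        return int(np.argmax(logits))  # 0=skip, 1=execute, 2=repeat

router = RouterBlock(d_in=8 * 64)  # 8 pooling windows x d_model=64
decision = router(np.random.default_rng(1).normal(size=8 * 64))
print(decision in (0, 1, 2))  # True
```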
```bash
python train.py \
    --model_id meta-llama/Llama-3.2-3B-Instruct \
    --run_name drllm-llama-3b \
    --num_epochs 25 \
    --learning_rate 1e-3 \
    --weight_decay 0.01 \
    --warmup_steps 500 \
    --gradient_accumulation 16
```

| Argument | Default | Description |
|---|---|---|
| `--model_id` | — | HuggingFace model path or ID |
| `--run_name` | — | Run name for checkpoints and W&B logging |
| `--num_epochs` | `15` | Number of training epochs |
| `--learning_rate` | `1e-3` | Learning rate |
| `--weight_decay` | `0.01` | AdamW weight decay |
| `--warmup_steps` | `500` | LR warmup steps |
| `--gradient_accumulation` | `16` | Gradient accumulation steps |
| `--num_windows` | `8` | Number of pooling windows for router input |
| `--with_squad` | `False` | Include SQuADv2 data in training |
| `--with_commonsense` | `False` | Include commonsense data in training |
- Frozen base: all base model parameters are frozen; only `model.model.routers` is trained (11M params for 3B models, 0.14% of total weights)
- Loss: focal loss with effective-number class rebalancing (`β=0.999`, `γ=2`) to handle the heavy skip/execute/repeat class imbalance
- Router input: windowed mean-pooled hidden states from the previous layer (default 8 windows)
- Teacher forcing: ground-truth routing labels are used during training to avoid inter-router dependency
- Optimizer: AdamW with cosine LR schedule
- Precision: bf16
- Logging: Weights & Biases (`--report_to wandb`)
- Checkpoints: saved to `checkpoints/{run_name}/`
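The loss above can be sketched as follows (focal loss with effective-number class weights in the style of Cui et al.; function names and the weight normalization are illustrative):

```python
import numpy as np

def effective_number_weights(class_counts, beta=0.999):
    """Class-balanced weights: effective number E_n = (1 - beta^n)/(1 - beta),
    weight proportional to 1/E_n, normalized to mean 1."""
    counts = np.asarray(class_counts, dtype=float)
    eff = (1.0 - beta ** counts) / (1.0 - beta)
    w = 1.0 / eff
    return w * len(counts) / w.sum()

def focal_loss(probs, target, weights, gamma=2.0):
    """Weighted focal loss for one example; probs are softmax outputs.
    The (1 - p_t)^gamma factor down-weights easy, confident examples."""
    p_t = probs[target]
    return -weights[target] * (1 - p_t) ** gamma * np.log(p_t)

# Illustrative counts where 'execute' dominates 'skip' and 'repeat'
w = effective_number_weights([500, 9000, 200], beta=0.999)
loss = focal_loss(np.array([0.1, 0.85, 0.05]), target=1, weights=w, gamma=2.0)
print(w[2] > w[1])  # True: the rarest class ('repeat') gets the largest weight
```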
During training, the following metrics are logged per step:
| Metric | Description |
|---|---|
| `macro_f1` | Macro F1 across skip/execute/repeat |
| `f1_skip` / `f1_execute` / `f1_repeat` | Per-class F1 scores |
| `acc_skip` / `acc_repeat` | Per-class accuracy for minority classes |
| `avg_layers` | Average number of layers executed per example |
| `routers_loss` | Focal loss on routing decisions |
We evaluate Dr.LLM using lm-evaluation-harness. The same modified model files from train_models/ are used for evaluation; no additional changes are needed.
To evaluate the vanilla model without routing:
```bash
accelerate launch --multi_gpu --num_processes 4 lm_eval \
    --model hf \
    --model_args pretrained="meta-llama/Llama-3.2-3B-Instruct" \
    --tasks arc_challenge,arc_easy,mmlu,aime24,truthfulqa,gsm8k,piqa \
    --batch_size 1
```

To evaluate a trained Dr.LLM checkpoint:
```bash
accelerate launch --num_processes 2 lm_eval \
    --model hf \
    --model_args pretrained=checkpoints/drllm-llama-3b-instruct,dtype=bfloat16,num_windows=8 \
    --tasks arc_challenge,arc_easy,mmlu,aime24,truthfulqa,gsm8k,piqa \
    --batch_size 1 \
    --gen_kwargs max_new_tokens=256 \
    --cache_requests true
```

| Argument | Description |
|---|---|
| `pretrained` | HuggingFace model ID or path to a local Dr.LLM checkpoint |
| `dtype` | Model precision; use `bfloat16` |
| `num_windows` | Number of pooling windows; must match the value used during training |
| `--tasks` | Comma-separated list of benchmarks |
| `--batch_size` | Batch size per device; use `1` for stability |
| `--gen_kwargs` | Generation kwargs, e.g. `max_new_tokens=256` |
| `--cache_requests` | Cache tokenized requests to speed up repeated runs |
| Benchmark | Domain | Split |
|---|---|---|
| `arc_easy` / `arc_challenge` | Logic Reasoning | In-domain |
| `dart` (levels 1–5) | Math Reasoning | In-domain |
| `mmlu` | Factual Knowledge | Out-of-domain |
| `gsm8k` | Grade-school Math | Out-of-domain |
| `aime24` | Competition Math | Out-of-domain |
| `truthfulqa` | Adversarial Factuality | Out-of-domain |
| `gpqa_diamond` | Graduate Reasoning | Out-of-domain |
| `agieval` | Exam Reasoning | Out-of-domain |
| `squadv2` | Reading Comprehension | Out-of-domain |
| `piqa` | Commonsense Reasoning | Out-of-domain |
- All results in the paper use greedy decoding, 2048 max tokens, and default `lm-eval-harness` settings.
- `num_windows` must match the value used during training (default: `8`). Mismatches will silently produce incorrect routing decisions.
- Set `is_static_routing=True` in the model to force all decisions to `execute`; useful for sanity-checking that the base model is loaded correctly before evaluating routing.
If you find this work useful, please cite:
```bibtex
@article{heakl2025drllm,
  title={Dr.LLM: Dynamic Layer Routing in LLMs},
  author={Ahmed Heakl and Martin Gubri and Salman Khan and Sangdoo Yun and Seong Joon Oh},
  journal={arXiv preprint arXiv:2510.12773},
  year={2025}
}
```

