Skip to content

Commit 1feb254

Browse files
juntao and claude committed
MLX: GPU-side argmax in decode loop, avoid 64KB logits transfer per step
decoder.step() now returns the argmax token ID (computed on GPU via mlx_argmax_axis) instead of the full 16,384-element logits vector. This eliminates a to_vec_f32() CPU round-trip per decode step, keeping the full decoder graph (8 layers + head + argmax) as one fused Metal dispatch. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0522513 commit 1feb254

File tree

3 files changed

+34
-22
lines changed

3 files changed

+34
-22
lines changed

CLAUDE.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,25 @@ release zip so users need zero configuration:
463463
- **vocab.json** is generated once in a separate job and included in every platform zip,
464464
so users only need to copy it into their model directory.
465465

466-
### 18. MLX weight count differs from tch (2104 vs 2152)
466+
### 18. MLX eval() placement and GPU-side argmax
467+
468+
MLX lazy evaluation builds a computation graph that should be evaluated at outer loop
469+
boundaries, not per-layer. Our encoder correctly runs all 48 conformer layers as one lazy
470+
graph with a single `eval()` after. The decoder runs 8 layers per step — also fine.
471+
472+
The decode loop originally called `to_vec_f32()` on the logits (shape: 16,384) at every
473+
step to perform argmax on the CPU. This transferred 64 KB per token and broke the lazy graph.
474+
475+
**Fix:** use `Array::argmax_flat()` which calls `mlx_argmax_axis` on GPU and transfers a
476+
single i32 to CPU. The full graph (8 decoder layers + layer norm + linear head + argmax) is
477+
now evaluated as one fused Metal dispatch per step.
478+
479+
**Rule of thumb:**
480+
- `eval()` after encoder forward (1 call)
481+
- `argmax_flat()` after each decoder step (1 call per token, transfers 4 bytes not 64 KB)
482+
- Never `eval()` or `to_vec_f32()` per-layer or mid-graph
483+
484+
### 19. MLX weight count differs from tch (2104 vs 2152)
467485

468486
The MLX weight loader skips `num_batches_tracked` tensors (I64 dtype, used only during
469487
PyTorch training). This results in 2104 loaded tensors vs 2152 for the tch backend

src/mlx/decoder.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,16 @@ impl TransformerDecoder {
300300

301301
/// One greedy-decoding step.
302302
///
303-
/// Returns (logits: Vec<f32> of shape vocab_size, updated self_kv_cache).
303+
/// Returns (next_token_id, updated self_kv_cache).
304+
/// Argmax is computed on GPU — only a single i32 is transferred to CPU,
305+
/// avoiding a 16,384-element logits transfer per step.
304306
pub fn step(
305307
&self,
306308
token_id: i32,
307309
position: i32,
308310
self_kv_cache: &[(Option<Array>, Option<Array>)],
309311
cross_kv: &[(Array, Array)],
310-
) -> (Vec<f32>, Vec<(Option<Array>, Option<Array>)>) {
312+
) -> (i32, Vec<(Option<Array>, Option<Array>)>) {
311313
// Token embedding lookup
312314
let idx = Array::from_slice_i32(&[token_id]);
313315
let emb = ops::take(&self.token_emb, &idx, 0); // (1, hidden)
@@ -339,8 +341,10 @@ impl TransformerDecoder {
339341
let hidden = ops::squeeze(&hidden, &[1]);
340342
let logits = ops::linear(&hidden, &self.head_w, &self.head_b); // (1, vocab)
341343
let logits = ops::squeeze(&logits, &[0]); // (vocab,)
342-
let logits_vec = logits.to_vec_f32();
343344

344-
(logits_vec, new_kv)
345+
// Argmax on GPU — transfers a single i32 instead of 16,384 floats
346+
let next_token = logits.argmax_flat() as i32;
347+
348+
(next_token, new_kv)
345349
}
346350
}

src/mlx/inference.rs

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,19 @@ pub fn transcribe(
3838
(0..decoder.layers.len()).map(|_| (None, None)).collect();
3939

4040
// 5. Prime decoder with prompt tokens
41-
let mut last_logits: Vec<f32> = Vec::new();
41+
let mut next_token = 0i32;
4242
for (i, &token_id) in prompt.iter().enumerate() {
43-
let (logits, new_kv) = decoder.step(token_id as i32, i as i32, &self_kv_cache, &cross_kv);
43+
let (token, new_kv) = decoder.step(token_id as i32, i as i32, &self_kv_cache, &cross_kv);
4444
self_kv_cache = new_kv;
45-
last_logits = logits;
45+
next_token = token;
4646
}
4747

4848
// 6. Greedy decode until EOS or max_new_tokens
49+
// Argmax is computed on GPU inside decoder.step() — only a single i32
50+
// is transferred per step instead of the full 16,384-element logits vector.
4951
let eos_id = tokenizer.special.eos as i32;
5052
let nospeech_id = tokenizer.special.nospeech as i32;
5153
let mut generated: Vec<i64> = Vec::new();
52-
53-
let mut next_token = argmax(&last_logits) as i32;
5454
let mut position = n_prompt as i32;
5555

5656
while generated.len() < max_new_tokens {
@@ -59,24 +59,14 @@ pub fn transcribe(
5959
}
6060
generated.push(next_token as i64);
6161

62-
let (logits, new_kv) = decoder.step(next_token, position, &self_kv_cache, &cross_kv);
62+
let (token, new_kv) = decoder.step(next_token, position, &self_kv_cache, &cross_kv);
6363
self_kv_cache = new_kv;
64-
last_logits = logits;
64+
next_token = token;
6565
position += 1;
66-
next_token = argmax(&last_logits) as i32;
6766
}
6867

6968
tracing::debug!("Generated token IDs: {:?}", generated);
7069

7170
// 7. Decode tokens to text
7271
Ok(tokenizer.decode(&generated))
7372
}
74-
75-
fn argmax(logits: &[f32]) -> usize {
76-
logits
77-
.iter()
78-
.enumerate()
79-
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
80-
.map(|(i, _)| i)
81-
.unwrap_or(0)
82-
}

0 commit comments

Comments (0)