Commit 1feb254
MLX: GPU-side argmax in decode loop, avoid 64KB logits transfer per step
decoder.step() now returns the argmax token ID (computed on GPU via
mlx_argmax_axis) instead of the full 16,384-element logits vector.
This eliminates a to_vec_f32() CPU round-trip per decode step, keeping
the full decoder graph (8 layers + head + argmax) as one fused Metal
dispatch.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>1 parent 0522513 commit 1feb254
3 files changed
+34
-22
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
463 | 463 | | |
464 | 464 | | |
465 | 465 | | |
466 | | - | |
| 466 | + | |
| 467 | + | |
| 468 | + | |
| 469 | + | |
| 470 | + | |
| 471 | + | |
| 472 | + | |
| 473 | + | |
| 474 | + | |
| 475 | + | |
| 476 | + | |
| 477 | + | |
| 478 | + | |
| 479 | + | |
| 480 | + | |
| 481 | + | |
| 482 | + | |
| 483 | + | |
| 484 | + | |
467 | 485 | | |
468 | 486 | | |
469 | 487 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
300 | 300 | | |
301 | 301 | | |
302 | 302 | | |
303 | | - | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
304 | 306 | | |
305 | 307 | | |
306 | 308 | | |
307 | 309 | | |
308 | 310 | | |
309 | 311 | | |
310 | | - | |
| 312 | + | |
311 | 313 | | |
312 | 314 | | |
313 | 315 | | |
| |||
339 | 341 | | |
340 | 342 | | |
341 | 343 | | |
342 | | - | |
343 | 344 | | |
344 | | - | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
345 | 349 | | |
346 | 350 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
38 | 38 | | |
39 | 39 | | |
40 | 40 | | |
41 | | - | |
| 41 | + | |
42 | 42 | | |
43 | | - | |
| 43 | + | |
44 | 44 | | |
45 | | - | |
| 45 | + | |
46 | 46 | | |
47 | 47 | | |
48 | 48 | | |
| 49 | + | |
| 50 | + | |
49 | 51 | | |
50 | 52 | | |
51 | 53 | | |
52 | | - | |
53 | | - | |
54 | 54 | | |
55 | 55 | | |
56 | 56 | | |
| |||
59 | 59 | | |
60 | 60 | | |
61 | 61 | | |
62 | | - | |
| 62 | + | |
63 | 63 | | |
64 | | - | |
| 64 | + | |
65 | 65 | | |
66 | | - | |
67 | 66 | | |
68 | 67 | | |
69 | 68 | | |
70 | 69 | | |
71 | 70 | | |
72 | 71 | | |
73 | 72 | | |
74 | | - | |
75 | | - | |
76 | | - | |
77 | | - | |
78 | | - | |
79 | | - | |
80 | | - | |
81 | | - | |
82 | | - | |
0 commit comments