
Commit 0688e96

docs: add Week 2 Day 2-3 Quantized Matmul chapter GPU part (#89)
* docs: add week2 quantized matmul GPU part

Signed-off-by: Connor1996 <zbk602423539@gmail.com>
1 parent b2393a2 commit 0688e96

2 files changed (+59, -5 lines)

.cspell.json

Lines changed: 3 additions & 1 deletion
````diff
@@ -27,7 +27,9 @@
     "GFLOPS",
     "TFLOPS",
     "dequantized",
-    "dequantization"
+    "dequantization",
+    "dequantizes",
+    "dtype",
   ],
   "ignoreRegExpList": [
     "`[^`]*`",
````

book/src/week2-02-quantized-matmul.md

Lines changed: 56 additions & 4 deletions
````diff
@@ -2,7 +2,7 @@
 
 In this chapter, we will implement the quantized matrix multiplication. Quantization compresses model weights from 16-bit floating point to 4-bit integers, which is critical for efficient LLM serving on devices with limited memory bandwidth.
 
-## Readings
+**📚 Readings**
 
 - [Model Compression and Quantization](https://huggingface.co/blog/hf-bitsandbytes-integration)
 - [MLX Extensions Development Guide](https://ml-explore.github.io/mlx/build/html/dev/extensions.html)
````
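For a rough sense of what the 4-bit format buys, here is the storage arithmetic under the group-wise layout used later in this chapter, assuming one 16-bit scale and one 16-bit bias per group of 64 weights (that per-group layout is an assumption of this sketch):

```python
# Back-of-the-envelope bits per weight for 4-bit group quantization,
# assuming one 16-bit scale and one 16-bit bias per group of 64 weights.
group_size = 64
bits_per_weight = (group_size * 4 + 16 + 16) / group_size
print(bits_per_weight)  # 4.5 bits per weight, versus 16 for float16
```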
````diff
@@ -258,9 +258,61 @@ pdm run build-ext
 pdm run test --week 2 --day 2 -- -k task_2
 ```
 
-## Task 3: Implement Metal Kernel
+## Task 3: Implement `quantized_matmul` (GPU version)
 
-TBD...
+```
+src/extensions/src/quantized_matmul.metal
+src/extensions/src/quantized_matmul.cpp
+```
 
-{{#include copyright.md}}
+In this task, you will write the Metal kernel for quantized matmul **and** wire up the `eval_gpu` method to dispatch it. Keep the math exactly the same as Task 2 (CPU); only the execution model changes.
+
+### Metal Kernel
+
+You need to implement one kernel entry in `quantized_matmul.metal`:
+
+- Use a **one-thread-per-output-element** mapping: each thread computes `out[i, k]`.
+- The kernel should be templated on the data type (to support both `half` and `bfloat16_t`).
+- Apply the same group-wise dequantization loop as the CPU version:
+  - Iterate over groups (`group_size = 64`)
+  - Unpack int4 values from packed `uint32`
+  - Dequantize with `q * scale + bias`
+  - Accumulate in `float` and cast to the output dtype at the end
+- Add boundary checks (`i < M`, `k < K`) before writing output.
+
````
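To make the per-thread loop in the bullets above concrete, here is a small NumPy reference for what one thread would compute for a single `out[i, k]`. The array shapes and the low-nibble-first int4 packing order are assumptions of this sketch, not something the chapter spells out here.

```python
import numpy as np

GROUP_SIZE = 64  # matches the group_size used in the chapter

def dequant_dot(x, w_packed, scales, biases, i, k):
    """Reference for the work of one GPU thread: out[i, k]."""
    N = x.shape[1]
    acc = np.float32(0.0)
    for g in range(N // GROUP_SIZE):                      # iterate over groups
        scale = np.float32(scales[k, g])
        bias = np.float32(biases[k, g])
        for j in range(GROUP_SIZE):
            n = g * GROUP_SIZE + j
            word = int(w_packed[k, n // 8])               # one uint32 holds 8 int4 values
            q = (word >> (4 * (n % 8))) & 0xF             # unpack one 4-bit value
            acc += np.float32(x[i, n]) * (q * scale + bias)  # dequantize and accumulate
    return acc  # the Metal kernel would cast this to the output dtype

# Smoke test with random data; assumed shapes: x is (M, N), w_packed is
# (K, N // 8) uint32, scales/biases are (K, N // GROUP_SIZE).
M, N, K = 2, 128, 3
rng = np.random.default_rng(0)
x = rng.standard_normal((M, N), dtype=np.float32)
w_packed = rng.integers(0, 2**32, size=(K, N // 8), dtype=np.uint32)
scales = rng.standard_normal((K, N // GROUP_SIZE)).astype(np.float32)
biases = rng.standard_normal((K, N // GROUP_SIZE)).astype(np.float32)
print(dequant_dot(x, w_packed, scales, biases, 0, 0))
```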
````diff
+### GPU Dispatch
 
+Complete the `eval_gpu` method in `quantized_matmul.cpp` to dispatch your Metal kernel. Follow the same pattern as `axpby`'s GPU dispatch:
+
+1. Get the Metal device and command encoder from the stream.
+2. Select the correct kernel name based on the activation dtype (`float16` → `half`, `bfloat16` → `bfloat16_t`).
+3. Set input/output buffers and dimension constants (`M`, `N`, `K`) on the encoder — make sure the buffer order matches your kernel signature.
+4. Calculate a 2D thread group configuration: use `kernel->maxTotalThreadsPerThreadgroup()` to determine the total threads, then split between the M and K dimensions (e.g., 32 threads for M, the rest for K).
+5. Dispatch with `dispatchThreadgroups`.
+
````
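Step 4 above is plain integer arithmetic; a sketch with made-up numbers, where 1024 stands in for whatever `kernel->maxTotalThreadsPerThreadgroup()` reports:

```python
def ceil_div(a, b):
    return (a + b - 1) // b

max_threads = 1024     # stand-in for kernel->maxTotalThreadsPerThreadgroup()
M, K = 7, 2048         # example output shape (M rows, K output columns), made up

threads_m = 32                        # e.g. 32 threads along M...
threads_k = max_threads // threads_m  # ...and the rest along K (here 32)

# One thread per output element, so round up to cover the whole M x K output.
grid_m = ceil_div(M, threads_m)
grid_k = ceil_div(K, threads_k)

print(f"threadgroup size ({threads_m}, {threads_k}), "
      f"dispatch ({grid_m}, {grid_k}) threadgroups")
```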
````diff
+You can test your implementation by running:
+
+```bash
+pdm run build-ext
+pdm run test --week 2 --day 2 -- -k task_3
+```
+
+## Task 4: Model Integration
+
+```
+src/tiny_llm/qwen2_week2.py
+```
+
+Integrate your quantized matmul into the Week 2 Qwen2 model so that inference runs on quantized weights end-to-end.
+
+Change the weight type from `mx.array` to `QuantizedWeights` for all linear layers in attention (`wq/wk/wv/wo`) and MLP (`w_gate/w_up/w_down`). Replace every `linear(x, w)` call with `quantized_linear(x, w)`. In the model loading code, use `QuantizedWeights.from_mlx_layer(...)` to extract quantized weight information from each MLX linear layer, instead of calling `mx.dequantize` to get a full float16 matrix. Make sure the Week 1 loader still dequantizes (since Week 1 layers expect plain `mx.array`), while the Week 2 loader does **not** dequantize.
+
+Note that MLX loads quantized models with `scales` and `biases` stored in **bfloat16** by default, while the activation tensors are typically **float16**. Since we have not implemented bfloat16 support in our kernel, you will need to convert the scales and biases to float16 with `mx.astype` before calling the kernel. If you see `nan` or garbage output, a dtype mismatch is the most likely cause.
+
````
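For the loader and dtype changes described above, here is a sketch using only generic MLX calls on a throwaway `mlx.nn.QuantizedLinear`; the course-side pieces (`QuantizedWeights.from_mlx_layer`, `quantized_linear`) are not reproduced here, so treat this as an outline rather than the actual loader code.

```python
import mlx.core as mx
import mlx.nn as nn

# A throwaway 4-bit quantized layer standing in for one loaded from the checkpoint.
layer = nn.QuantizedLinear(64, 128, bias=False, group_size=64, bits=4)

# Week 1 style: materialize a plain float16 weight matrix via dequantization.
w_week1 = mx.dequantize(
    layer.weight, layer.scales, layer.biases,
    group_size=layer.group_size, bits=layer.bits,
).astype(mx.float16)

# Week 2 style: keep the packed weight as-is, but convert scales/biases to
# float16 so they match the float16 activations expected by the custom kernel.
w_packed = layer.weight                      # packed integers, not dequantized
scales = layer.scales.astype(mx.float16)
biases = layer.biases.astype(mx.float16)
print(w_packed.dtype, scales.dtype, biases.dtype)
```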
````diff
+You can test your implementation by running:
+
+```bash
+pdm run main --solution tiny_llm --loader week2 --model qwen2-0.5b
+```
+
+{{#include copyright.md}}
````
