|
2 | 2 |
|
3 | 3 | In this chapter, we will implement the quantized matrix multiplication. Quantization compresses model weights from 16-bit floating point to 4-bit integers, which is critical for efficient LLM serving on devices with limited memory bandwidth. |
4 | 4 |
|
5 | | -## Readings |
| 5 | +**📚 Readings** |
6 | 6 |
|
7 | 7 | - [Model Compression and Quantization](https://huggingface.co/blog/hf-bitsandbytes-integration) |
8 | 8 | - [MLX Extensions Development Guide](https://ml-explore.github.io/mlx/build/html/dev/extensions.html) |
@@ -258,9 +258,61 @@ pdm run build-ext |
258 | 258 | pdm run test --week 2 --day 2 -- -k task_2 |
259 | 259 | ``` |
260 | 260 |
|
261 | | -## Task 3: Implement Metal Kernel |
| 261 | +## Task 3: Implement `quantized_matmul` (GPU version) |
262 | 262 |
|
263 | | -TBD... |
| 263 | +``` |
| 264 | +src/extensions/src/quantized_matmul.metal |
| 265 | +src/extensions/src/quantized_matmul.cpp |
| 266 | +``` |
264 | 267 |
|
265 | | -{{#include copyright.md}} |
| 268 | +In this task, you will write the Metal kernel for quantized matmul **and** wire up the `eval_gpu` method to dispatch it. Keep the math exactly the same as Task 2 (CPU); only the execution model changes. |
| 269 | + |
| 270 | +### Metal Kernel |
| 271 | + |
| 272 | +You need to implement one kernel entry in `quantized_matmul.metal`: |
| 273 | + |
| 274 | +- Use a **one-thread-per-output-element** mapping: each thread computes `out[i, k]`. |
| 275 | +- The kernel should be templated on the data type (to support both `half` and `bfloat16_t`). |
| 276 | +- Apply the same group-wise dequantization loop as the CPU version: |
| 277 | + - Iterate over groups (`group_size = 64`) |
| 278 | + - Unpack int4 values from packed `uint32` |
| 279 | + - Dequantize with `q * scale + bias` |
| 280 | + - Accumulate in `float` and cast to the output dtype at the end |
| 281 | +- Add boundary checks (`i < M`, `k < K`) before writing output. |
| 282 | + |
| 283 | +### GPU Dispatch |
266 | 284 |
|
| 285 | +Complete the `eval_gpu` method in `quantized_matmul.cpp` to dispatch your Metal kernel. Follow the same pattern as `axpby`'s GPU dispatch: |
| 286 | + |
| 287 | +1. Get the Metal device and command encoder from the stream. |
| 288 | +2. Select the correct kernel name based on the activation dtype (`float16` → `half`, `bfloat16` → `bfloat16_t`). |
| 289 | +3. Set input/output buffers and dimension constants (`M`, `N`, `K`) on the encoder — make sure the buffer order matches your kernel signature. |
| 290 | +4. Calculate a 2D thread group configuration: use `kernel->maxTotalThreadsPerThreadgroup()` to determine the total threads, then split between the M and K dimensions (e.g., 32 threads for M, the rest for K). |
| 291 | +5. Dispatch with `dispatchThreadgroups`. |
| 292 | + |
| 293 | +You can test your implementation by running: |
| 294 | + |
| 295 | +```bash |
| 296 | +pdm run build-ext |
| 297 | +pdm run test --week 2 --day 2 -- -k task_3 |
| 298 | +``` |
| 299 | + |
| 300 | +## Task 4: Model Integration |
| 301 | + |
| 302 | +``` |
| 303 | +src/tiny_llm/qwen2_week2.py |
| 304 | +``` |
| 305 | + |
| 306 | +Integrate your quantized matmul into the Week 2 Qwen2 model so that inference runs on quantized weights end-to-end. |
| 307 | + |
| 308 | +Change the weight type from `mx.array` to `QuantizedWeights` for all linear layers in attention (`wq/wk/wv/wo`) and MLP (`w_gate/w_up/w_down`). Replace every `linear(x, w)` call with `quantized_linear(x, w)`. In the model loading code, use `QuantizedWeights.from_mlx_layer(...)` to extract quantized weight information from each MLX linear layer, instead of calling `mx.dequantize` to get a full float16 matrix. Make sure the Week 1 loader still dequantizes (since Week 1 layers expect plain `mx.array`), while the Week 2 loader does **not** dequantize. |
| 309 | + |
| 310 | +Note that MLX loads quantized models with `scales` and `biases` stored in **bfloat16** by default, while the activation tensors are typically **float16**. Since we have not implemented bfloat16 support in our kernel, you will need to convert the scales and biases to float16 with `mx.astype` before calling the kernel. If you see `nan` or garbage output, a dtype mismatch is the most likely cause. |
| 311 | + |
| 312 | +You can test your implementation by running: |
| 313 | + |
| 314 | +```bash |
| 315 | +pdm run main --solution tiny_llm --loader week2 --model qwen2-0.5b |
| 316 | +``` |
| 317 | + |
| 318 | +{{#include copyright.md}} |
0 commit comments