Commit b2393a2

docs: add Week 2 Day 2-3 Quantized Matmul chapter CPU part (#88)
* docs: add Week 2 Day 2-3 Quantized Matmul chapter

  Add quantized matmul documentation (week2-02-quantized-matmul.md)

  Signed-off-by: Connor1996 <zbk602423539@gmail.com>
1 parent 0c95267 commit b2393a2

File tree

5 files changed, +284 −14 lines changed

.cspell.json

Lines changed: 6 additions & 2 deletions

```diff
@@ -23,9 +23,13 @@
     "bfloat",
     "multihead",
     "vllm",
-    "silu"
+    "silu",
+    "GFLOPS",
+    "TFLOPS",
+    "dequantized",
+    "dequantization"
   ],
   "ignoreRegExpList": [
     "`[^`]*`",
   ]
-}
+}
```

README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -38,7 +38,7 @@ Week 1 is complete. Week 2 is in progress.
 | 1.6 | Generate Responses (aka Decoding) ||||
 | 1.7 | Sampling ||||
 | 2.1 | Key-Value Cache ||||
-| 2.2 | Quantized Matmul and Linear - CPU ||| 🚧 |
+| 2.2 | Quantized Matmul and Linear - CPU ||| |
 | 2.3 | Quantized Matmul and Linear - GPU ||| 🚧 |
 | 2.4 | Flash Attention 2 - CPU ||| 🚧 |
 | 2.5 | Flash Attention 2 - GPU ||| 🚧 |
```

book/src/SUMMARY.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -15,7 +15,7 @@
 - [Sampling and Preparing for Week 2](./week1-07-sampling-prepare.md)
 - [Week 2: Tiny vLLM](./week2-overview.md)
 - [Key-Value Cache](./week2-01-kv-cache.md)
-- [Quantized Matmul (2 Days)]()
+- [Quantized Matmul (2 Days)](./week2-02-quantized-matmul.md)
 - [Flash Attention (2 Days)]()
 - [Continuous Batching (2 Days)](./week2-06-prefill-and-batch.md)
 - [Week 3: Serving]()
```
book/src/week2-02-quantized-matmul.md

Lines changed: 266 additions & 0 deletions
# Week 2 Day 2-3: Quantized Matmul

In this chapter, we will implement quantized matrix multiplication. Quantization compresses model weights from 16-bit floating point to 4-bit integers, which is critical for efficient LLM serving on devices with limited memory bandwidth.

## Readings

- [Model Compression and Quantization](https://huggingface.co/blog/hf-bitsandbytes-integration)
- [MLX Extensions Development Guide](https://ml-explore.github.io/mlx/build/html/dev/extensions.html)
- [Quantized Matmul on CPU (Video)](https://www.youtube.com/watch?v=es6s6T1bTtI)
- [Quantized Matmul on GPU (Video)](https://www.youtube.com/watch?v=jYCxVirq4d0)

## Why Quantization?

As we learned in the KV Cache chapter, the decode phase of LLM inference is **memory-bandwidth bound**. Let's revisit the arithmetic intensity calculation for the Qwen2-0.5B model:
```plain
Per-token computation in decode phase:
- Input: 1 token × 896 dimensions = 896 float16 values = 1.792 KB
- MLP weights: 896 × 4864 × 3 matrices × 2 bytes = ~25 MB per layer
- Attention weights: 896 × 896 × 4 matrices × 2 bytes = ~6 MB per layer
- Total weights per layer: ~31 MB
- Total for 24 layers: ~750 MB

FLOPs (2 per multiply-accumulate):
- MLP per layer: 2 × 3 × 896 × 4864 ≈ 26M
- Attention per layer: 2 × 4 × 896 × 896 ≈ 6.4M
- 24 layers: ~780 million per token

Memory access: ~750 MB
Arithmetic intensity: 780M FLOPs / 750 MB ≈ 1.0 FLOPs/Byte
```
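These numbers can be reproduced with a few lines of Python. The dimensions are hard-coded from the Qwen2-0.5B figures quoted above:

```python
# Back-of-the-envelope check of the decode-phase arithmetic above.
# Dimensions quoted for Qwen2-0.5B: hidden=896, intermediate=4864, 24 layers.
hidden, intermediate, layers = 896, 4864, 24

# Weight bytes per layer at 2 bytes (float16) per value.
mlp_bytes = hidden * intermediate * 3 * 2   # gate/up/down projections
attn_bytes = hidden * hidden * 4 * 2        # q/k/v/o projections
total_bytes = (mlp_bytes + attn_bytes) * layers

# FLOPs per decoded token (2 per multiply-accumulate).
mlp_flops = 2 * 3 * hidden * intermediate
attn_flops = 2 * 4 * hidden * hidden
total_flops = (mlp_flops + attn_flops) * layers

print(f"weights:   {total_bytes / 2**20:.0f} MB")   # ~750 MB
print(f"FLOPs:     {total_flops / 1e6:.0f} M")      # ~780 M
print(f"intensity: {total_flops / total_bytes:.2f} FLOPs/Byte")
```

Note that with 2 bytes per weight and 2 FLOPs per weight, the intensity comes out to exactly 1.0 FLOPs/Byte in the decode phase.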

With M3 Max's 400 GB/s memory bandwidth and ~10 TFLOPS compute:

```plain
Memory-bound throughput: 400 GB/s × 1.0 FLOPs/Byte = 400 GFLOPS
Compute-bound throughput: 10 TFLOPS

We're using only ~4% of available compute!
```

### The Solution: Quantization

By compressing weights from 16 bits (float16/bfloat16) to 4 bits (int4), we:

- **Reduce memory bandwidth by 4×**: 750 MB → ~190 MB per token
- **Improve arithmetic intensity by 4×**: 1.0 → ~4.0 FLOPs/Byte
- **Increase throughput by ~4×**: 400 GFLOPS → ~1.6 TFLOPS

The tradeoff is a minimal accuracy loss when proper quantization techniques are used.

### Group-wise Quantization

Instead of quantizing all weights uniformly, we divide them into **groups** and quantize each group independently. This preserves more information about the weight distribution.

For a weight matrix $W$ of shape $(K, N)$, we divide each row into groups of size $G$ (typically 64 or 128):

```plain
Original weight matrix W: K × N (float16/bfloat16)

Group size G = 64
Number of groups per row = N / G

For each group of 64 consecutive values in a row:
1. Find min and max values
2. Compute scale and bias to map [min, max] → [0, 15] (4-bit range)
3. Quantize each value using: quantized = round((value - bias) / scale)
```

### Affine Quantization

We use **affine (asymmetric) quantization**, which maps a floating-point range to the full integer range:

$$
\text{quantized} = \text{round}\left(\frac{\text{value} - \text{bias}}{\text{scale}}\right)
$$

$$
\text{dequantized} = \text{quantized} \times \text{scale} + \text{bias}
$$

For 4-bit quantization, the quantized values are in the range $[0, 15]$.

Given a group with minimum value $v_{min}$ and maximum value $v_{max}$:

$$
\text{scale} = \frac{v_{max} - v_{min}}{2^{\text{bits}} - 1} = \frac{v_{max} - v_{min}}{15}
$$

$$
\text{bias} = v_{min}
$$

**Example:**

```plain
Group values: [-0.5, -0.3, 0.1, 0.4, 0.8]
min = -0.5, max = 0.8

scale = (0.8 - (-0.5)) / 15 = 1.3 / 15 ≈ 0.0867
bias = -0.5

Quantization:
-0.5 → round((-0.5 - (-0.5)) / 0.0867) = 0
-0.3 → round((-0.3 - (-0.5)) / 0.0867) = 2
 0.1 → round(( 0.1 - (-0.5)) / 0.0867) = 7
 0.4 → round(( 0.4 - (-0.5)) / 0.0867) = 10
 0.8 → round(( 0.8 - (-0.5)) / 0.0867) = 15

Quantized: [0, 2, 7, 10, 15] (4 bits each)
```
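A minimal NumPy sketch of this per-group affine quantization (a standalone illustration, not the repo's `quantize.py`):

```python
import numpy as np

def quantize_group(values, bits=4):
    """Affine-quantize one group: map [min, max] onto [0, 2**bits - 1]."""
    v_min, v_max = values.min(), values.max()
    scale = (v_max - v_min) / (2**bits - 1)
    bias = v_min
    q = np.round((values - bias) / scale).astype(np.uint8)
    return q, scale, bias

def dequantize_group(q, scale, bias):
    return q * scale + bias

group = np.array([-0.5, -0.3, 0.1, 0.4, 0.8])
q, scale, bias = quantize_group(group)
print(q)                                 # [ 0  2  7 10 15]
print(dequantize_group(q, scale, bias))  # close to, but not exactly, the originals
```

Dequantization recovers each value to within half a quantization step (scale/2), which is where the "minimal accuracy loss" comes from.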

### Storage Format

For efficient storage and computation, quantized weights are packed:

```plain
Original:  K × N float16 (2 bytes each)  = 2KN bytes
Quantized: K × N int4    (0.5 bytes each) = 0.5KN bytes

Packing: 8 × 4-bit values fit in one uint32 (32 bits)

Weight matrix shape:     K × N
Quantized storage shape: K × (N / 8)  uint32
Scales shape:            K × (N / 64) float16
Biases shape:            K × (N / 64) float16
```

Example packing for 8 consecutive 4-bit values `[a, b, c, d, e, f, g, h]`:

```plain
uint32_value = (h << 28) | (g << 24) | (f << 20) | (e << 16) |
               (d << 12) | (c << 8)  | (b << 4)  | a

Unpacking:
a = (uint32_value >> 0) & 0xF
b = (uint32_value >> 4) & 0xF
c = (uint32_value >> 8) & 0xF
...
h = (uint32_value >> 28) & 0xF
```
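The packing scheme above can be sketched in Python as two small helpers (illustrative only; the real packing lives in the quantized weight files):

```python
def pack8(vals):
    """Pack 8 4-bit values into one uint32, lowest nibble first."""
    assert len(vals) == 8 and all(0 <= v <= 0xF for v in vals)
    packed = 0
    for i, v in enumerate(vals):
        packed |= v << (4 * i)
    return packed

def unpack8(packed):
    """Recover the 8 4-bit values from a packed uint32."""
    return [(packed >> (4 * i)) & 0xF for i in range(8)]

vals = [0, 2, 7, 10, 15, 1, 3, 9]   # a..h
packed = pack8(vals)
print(hex(packed))                  # 0x931fa720, with h in the top nibble
assert unpack8(packed) == vals
```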

## Quantized Matrix Multiplication

### Mathematical Formulation

For standard matrix multiplication $C = AB^T$ where:

- $A$: shape $(M, N)$, float16/bfloat16 (activations)
- $B$: shape $(K, N)$, **quantized** to int4 (weights)
- $C$: shape $(M, K)$, float16/bfloat16 (output)

Each element $C[i, k]$ is computed as:

$$
C[i, k] = \sum_{j=0}^{N-1} A[i, j] \times B[k, j]
$$

With quantization, $B[k, j]$ is represented as:

$$
B[k, j] = B_{\text{quantized}}[k, j] \times \text{scale}[k, g] + \text{bias}[k, g]
$$

where $g = \lfloor j / G \rfloor$ is the group index.

Substituting:

$$
C[i, k] = \sum_{g=0}^{N/G-1} \sum_{j'=0}^{G-1} A[i, g \times G + j'] \times (B_{\text{quantized}}[k, g \times G + j'] \times \text{scale}[k, g] + \text{bias}[k, g])
$$

Rearranging:

$$
C[i, k] = \sum_{g=0}^{N/G-1} \left( \text{scale}[k, g] \sum_{j'=0}^{G-1} A[i, g \times G + j'] \times B_{\text{quantized}}[k, g \times G + j'] + \text{bias}[k, g] \sum_{j'=0}^{G-1} A[i, g \times G + j'] \right)
$$

This shows we can factor out the scale and bias per group, reducing the number of floating-point operations.
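As a sanity check, the factored form can be verified numerically against direct dequantize-then-multiply. A NumPy sketch with a small group size ($G = 4$) for brevity:

```python
import numpy as np

rng = np.random.default_rng(0)
G, N = 4, 16                        # tiny group size for illustration
a = rng.standard_normal(N)          # one row of A
bq = rng.integers(0, 16, N)         # one row of B_quantized (4-bit values)
scale = rng.standard_normal(N // G) # one scale per group
bias = rng.standard_normal(N // G)  # one bias per group

# Direct: dequantize B, then take the dot product.
b = bq * np.repeat(scale, G) + np.repeat(bias, G)
direct = a @ b

# Factored: per group, scale * (a . bq) + bias * sum(a).
factored = sum(
    scale[g] * (a[g*G:(g+1)*G] @ bq[g*G:(g+1)*G]) + bias[g] * a[g*G:(g+1)*G].sum()
    for g in range(N // G)
)

assert np.isclose(direct, factored)
```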

### Computation Flow

```plain
Input:
  A:           M × N      (float16, activations)
  B_quantized: K × (N/8)  (uint32, packed weights)
  scales:      K × (N/64) (float16)
  biases:      K × (N/64) (float16)

Output:
  C: M × K (float16)

For each output element C[i, k]:
  sum = 0
  for each group g in 0..(N/64 - 1):
    scale = scales[k, g]
    bias = biases[k, g]

    # Process 64 values in the group (8 uint32 packs)
    for each pack p in 0..7:
      packed_value = B_quantized[k, g*8 + p]

      # Unpack 8 × 4-bit values
      for bit_offset in [0, 4, 8, 12, 16, 20, 24, 28]:
        quantized = (packed_value >> bit_offset) & 0xF
        b_value = quantized * scale + bias
        a_value = A[i, g*64 + p*8 + bit_offset/4]
        sum += a_value * b_value

  C[i, k] = sum
```
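Translated directly into Python, the flow above looks like this. This is a reference sketch (group size 64, 4-bit packing); names are illustrative, not the repo's API, and it is far too slow for real use — its value is as an executable specification of the loop you will write in C++:

```python
import numpy as np

def quantized_matmul_ref(a, b_packed, scales, biases, group_size=64):
    """Reference C = A @ dequant(B)^T with 4-bit weights packed 8 per uint32."""
    m, n = a.shape
    k = b_packed.shape[0]
    c = np.zeros((m, k), dtype=np.float32)
    packs_per_group = group_size // 8
    for i in range(m):
        for row in range(k):
            acc = 0.0                      # accumulate in float to limit rounding error
            for g in range(n // group_size):
                scale = float(scales[row, g])
                bias = float(biases[row, g])
                for p in range(packs_per_group):
                    packed = int(b_packed[row, g * packs_per_group + p])
                    for s in range(8):     # unpack 8 × 4-bit values
                        q = (packed >> (4 * s)) & 0xF
                        b_val = q * scale + bias
                        a_val = float(a[i, g * group_size + p * 8 + s])
                        acc += a_val * b_val
            c[i, row] = acc
    return c
```

A useful property for testing: the result must match dequantizing `B` up front and doing an ordinary `a @ b.T`, up to floating-point rounding.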

## Task 1: Implement QuantizedWeights

```
src/tiny_llm/quantize.py
```

First, familiarize yourself with the `QuantizedWeights` class, which stores quantized weight information:

| Field | Shape | Description |
|-------|-------|-------------|
| `weight` | $(K, N/8)$ uint32 | Packed quantized weights. Each uint32 stores 8 consecutive 4-bit values. The original weight matrix has shape $(K, N)$; after packing, it becomes $(K, N/8)$. |
| `scales` | $(K, N/G)$ float16 | Per-group scale factors for dequantization. Each group of $G$ consecutive values shares one scale. Recall: $\text{scale} = (v_{max} - v_{min}) / 15$ |
| `biases` | $(K, N/G)$ float16 | Per-group bias (offset) for dequantization. Recall: $\text{bias} = v_{min}$ |
| `group_size` | int | Number of consecutive values that share the same scale/bias (typically 64) |
| `bits` | int | Quantization bit width (typically 4, meaning values are in range $[0, 15]$) |

The `from_mlx_layer` static method extracts these fields from MLX's quantized linear layers when loading the model.

Next, implement the `quantized_linear` function, a wrapper around `quantized_matmul` that mimics the standard `linear` function interface; we'll implement `quantized_matmul` itself in the next task.
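The wrapper itself is thin. A hedged sketch of the shape it might take, using a plain dict as a stand-in for `QuantizedWeights` and a dequantize-then-matmul stand-in for the real kernel (all names here are hypothetical; match them to the actual class and function signatures in the repo):

```python
import numpy as np

def dequantize(weight, scales, biases, group_size=64):
    """Expand packed 4-bit weights back to float (reference stand-in, 4-bit only)."""
    k, packed_n = weight.shape
    shifts = 4 * np.arange(8, dtype=np.uint32)
    q = (weight[:, :, None] >> shifts) & 0xF            # (K, N/8, 8) nibbles
    q = q.reshape(k, packed_n * 8).astype(np.float32)   # (K, N)
    return q * np.repeat(scales, group_size, axis=1) + np.repeat(biases, group_size, axis=1)

def quantized_matmul(x, weight, scales, biases, group_size=64):
    # Reference semantics only: dequantize, then an ordinary x @ W^T.
    return x @ dequantize(weight, scales, biases, group_size).T

def quantized_linear(x, qw, bias=None):
    """Mimics linear(): y = x @ W^T (+ bias), with W stored quantized."""
    y = quantized_matmul(x, qw["weight"], qw["scales"], qw["biases"], qw["group_size"])
    return y if bias is None else y + bias
```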

## Task 2: Implement `quantized_matmul` (CPU version)

In this task, we will implement the quantized matmul as an MLX C++ extension. The pattern is identical to the existing `axpby` example in the codebase: read through `axpby.h`, `axpby.cpp`, and the corresponding binding in `bindings.cpp` first as your reference.

```
src/extensions/src/tiny_llm_ext.h
src/extensions/bindings.cpp
src/extensions/src/quantized_matmul.cpp
src/extensions/CMakeLists.txt
```

You need to touch three files, all within the `tiny_llm_ext` namespace:

- **`tiny_llm_ext.h`**: Declare the `quantized_matmul(...)` function signature and define a `QuantizedMatmul` primitive class (inheriting `mx::Primitive`). Store `group_size` and `bits` as private members.
- **`bindings.cpp`**: Add an `m.def(...)` call to expose the function to Python.
- **`quantized_matmul.cpp`**: Implement the `quantized_matmul(...)` function (validate inputs, compute output shape, return a lazy `mx::array`) and the `eval_cpu` method (allocate output, register arrays with the CPU encoder, dispatch the compute kernel).

The `eval_cpu` implementation follows the same CPU encoder pattern as `axpby`: allocate output memory with `out.set_data(mx::allocator::malloc(out.nbytes()))`, register input/output arrays with the encoder, then dispatch a lambda that performs the actual computation. Inside the lambda, implement the nested loop from the Computation Flow section above: iterate over each output element `(i, k)`, accumulate in `float` (fp32) to avoid precision loss, and cast the result back to `float16` when writing to the output.

Don't forget to add `src/quantized_matmul.cpp` to `target_sources` in `CMakeLists.txt`.

You can test your implementation by running:

```bash
pdm run build-ext
pdm run test --week 2 --day 2 -- -k task_2
```

## Task 3: Implement Metal Kernel

TBD...

{{#include copyright.md}}

tests_refsol/test_week_2_day_2.py

Lines changed: 10 additions & 10 deletions

```diff
@@ -35,25 +35,25 @@ def quantized_matmul_helper(
     assert_allclose(user_out, ref_out, precision)


-def test_task_1_quantized_matmul_simple_f16_cpu():
+def test_task_2_quantized_matmul_simple_f16_cpu():
     quantized_matmul_helper(mx.cpu, True, mx.float16)


-def test_task_1_quantized_matmul_complex_f16_cpu():
+def test_task_2_quantized_matmul_complex_f16_cpu():
     quantized_matmul_helper(mx.cpu, False, mx.float16)


-def test_task_2_quantized_matmul_simple_f16_gpu():
-    quantized_matmul_helper(mx.gpu, True, mx.float16)
+def test_task_2_quantized_matmul_simple_f32_cpu():
+    quantized_matmul_helper(mx.cpu, True, mx.float32)


-def test_task_2_quantized_matmul_complex_f16_gpu():
-    quantized_matmul_helper(mx.gpu, False, mx.float16)
+def test_task_2_quantized_matmul_complex_f32_cpu():
+    quantized_matmul_helper(mx.cpu, False, mx.float32)


-def test_task_1_quantized_matmul_simple_f32_cpu():
-    quantized_matmul_helper(mx.cpu, True, mx.float32)
+def test_task_3_quantized_matmul_simple_f16_gpu():
+    quantized_matmul_helper(mx.gpu, True, mx.float16)


-def test_task_1_quantized_matmul_complex_f32_cpu():
-    quantized_matmul_helper(mx.cpu, False, mx.float32)
+def test_task_3_quantized_matmul_complex_f16_gpu():
+    quantized_matmul_helper(mx.gpu, False, mx.float16)
```
