# Week 2 Day 2-3: Quantized Matmul

In this chapter, we will implement quantized matrix multiplication. Quantization compresses model weights from 16-bit floating point to 4-bit integers, which is critical for efficient LLM serving on devices with limited memory bandwidth.

## Readings

- [Model Compression and Quantization](https://huggingface.co/blog/hf-bitsandbytes-integration)
- [MLX Extensions Development Guide](https://ml-explore.github.io/mlx/build/html/dev/extensions.html)
- [Quantized Matmul on CPU (Video)](https://www.youtube.com/watch?v=es6s6T1bTtI)
- [Quantized Matmul on GPU (Video)](https://www.youtube.com/watch?v=jYCxVirq4d0)

## Why Quantization?

As we learned in the KV Cache chapter, the decode phase of LLM inference is **memory-bandwidth bound**. Let's revisit the arithmetic intensity calculation for the Qwen2-0.5B model:

```plain
Per-token computation in decode phase:
- Input: 1 token × 896 dimensions = 896 float16 values = 1.792 KB
- MLP weights: 896 × 4864 × 3 matrices × 2 bytes = ~25 MB per layer
- Attention weights: 896 × 896 × 4 matrices × 2 bytes = ~6 MB per layer
- Total weights per layer: ~31 MB
- Total for 24 layers: ~750 MB

FLOPs (2 per multiply-accumulate):
- MLP per layer: 2 × 3 × 896 × 4864 ≈ 26M
- Attention per layer: 2 × 4 × 896 × 896 ≈ 6.4M
- 24 layers: ~780 million per token

Memory access: ~750 MB
Arithmetic intensity: 780M FLOPs / 750 MB ≈ 1.0 FLOPs/Byte
```

With M3 Max's 400 GB/s memory bandwidth and ~10 TFLOPS compute:

```plain
Memory-bound throughput: 400 GB/s × 1.0 FLOPs/Byte = 400 GFLOPS
Compute-bound throughput: 10 TFLOPS

We're using only ~4% of available compute!
```
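These back-of-the-envelope numbers are easy to re-derive in a few lines of Python (a sketch using this chapter's Qwen2-0.5B shapes; the 400 GB/s figure is the M3 Max bandwidth quoted above):

```python
# Roofline sketch for the decode phase, using this chapter's shapes.
hidden, intermediate, n_layers = 896, 4864, 24   # Qwen2-0.5B

# Bytes of weights read per token (float16 = 2 bytes per value):
mlp_bytes = 3 * hidden * intermediate * 2        # 3 MLP projection matrices
attn_bytes = 4 * hidden * hidden * 2             # q/k/v/o projection matrices
total_bytes = n_layers * (mlp_bytes + attn_bytes)

# FLOPs per token (2 per multiply-accumulate):
flops = n_layers * 2 * (3 * hidden * intermediate + 4 * hidden * hidden)

intensity = flops / total_bytes                  # ≈ 1.0 FLOPs/Byte
throughput_gflops = 400 * intensity              # 400 GB/s × intensity

print(f"{total_bytes / 2**20:.0f} MB, {flops / 1e6:.0f} MFLOPs, "
      f"{intensity:.2f} FLOPs/Byte, {throughput_gflops:.0f} GFLOPS")
```

With 2-byte weights and 2 FLOPs per weight read, the intensity comes out to exactly 1.0 FLOPs/Byte, matching the estimate above.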

### The Solution: Quantization

By compressing weights from 16 bits (float16/bfloat16) to 4 bits (int4), we:

- **Reduce memory traffic by 4×**: 750 MB → ~190 MB per token
- **Improve arithmetic intensity by 4×**: 1.0 → ~4.0 FLOPs/Byte
- **Increase throughput by ~4×**: 400 GFLOPS → ~1.6 TFLOPS

The tradeoff is a small accuracy loss, which proper quantization techniques keep minimal.

### Group-wise Quantization

Instead of quantizing all weights uniformly, we divide them into **groups** and quantize each group independently. This preserves more information about the weight distribution.

For a weight matrix $W$ of shape $(K, N)$, we divide each row into groups of size $G$ (typically 64 or 128):

```plain
Original weight matrix W: K × N (float16/bfloat16)

Group size G = 64
Number of groups per row = N / G

For each group of 64 consecutive values in a row:
  1. Find min and max values
  2. Compute scale and bias to map [min, max] → [0, 15] (4-bit range)
  3. Quantize each value using: quantized = round((value - bias) / scale)
```

### Affine Quantization

We use **affine (asymmetric) quantization**, which maps a floating-point range to the full integer range:

$$
\text{quantized} = \text{round}\left(\frac{\text{value} - \text{bias}}{\text{scale}}\right)
$$

$$
\text{dequantized} = \text{quantized} \times \text{scale} + \text{bias}
$$

For 4-bit quantization, the quantized values are in the range $[0, 15]$.

Given a group with minimum value $v_{min}$ and maximum value $v_{max}$:

$$
\text{scale} = \frac{v_{max} - v_{min}}{2^{\text{bits}} - 1} = \frac{v_{max} - v_{min}}{15}
$$

$$
\text{bias} = v_{min}
$$

**Example:**

```plain
Group values: [-0.5, -0.3, 0.1, 0.4, 0.8]
min = -0.5, max = 0.8

scale = (0.8 - (-0.5)) / 15 = 1.3 / 15 ≈ 0.0867
bias = -0.5

Quantization:
  -0.5 → round((-0.5 - (-0.5)) / 0.0867) = 0
  -0.3 → round((-0.3 - (-0.5)) / 0.0867) = 2
   0.1 → round((0.1 - (-0.5)) / 0.0867) = 7
   0.4 → round((0.4 - (-0.5)) / 0.0867) = 10
   0.8 → round((0.8 - (-0.5)) / 0.0867) = 15

Quantized: [0, 2, 7, 10, 15] (4 bits each)
```
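The worked example can be reproduced with a short Python sketch (the helper names here are illustrative, not the MLX API):

```python
# Quantize one group with affine 4-bit quantization, as derived above.
def quantize_group(values, bits=4):
    v_min, v_max = min(values), max(values)
    scale = (v_max - v_min) / ((1 << bits) - 1)   # (max - min) / 15 for 4-bit
    bias = v_min
    return [round((v - bias) / scale) for v in values], scale, bias

group = [-0.5, -0.3, 0.1, 0.4, 0.8]
q, scale, bias = quantize_group(group)
print(q)  # [0, 2, 7, 10, 15]

# Dequantization recovers each value to within half a quantization step
recovered = [x * scale + bias for x in q]
assert all(abs(r - v) <= scale / 2 + 1e-9 for r, v in zip(recovered, group))
```

The round-trip error is bounded by half the scale, which is why smaller groups (tighter min/max ranges) preserve more precision.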

### Storage Format

For efficient storage and computation, quantized weights are packed:

```plain
Original: K × N float16 (2 bytes each) = 2KN bytes
Quantized: K × N int4 (0.5 bytes each) = 0.5KN bytes

Packing: 8 × 4-bit values fit in one uint32 (32 bits)

Weight matrix shape: K × N
Quantized storage shape: K × (N / 8) uint32
Scales shape: K × (N / 64) float16
Biases shape: K × (N / 64) float16
```

Example packing for 8 consecutive 4-bit values `[a, b, c, d, e, f, g, h]`:

```plain
uint32_value = (h << 28) | (g << 24) | (f << 20) | (e << 16) |
               (d << 12) | (c << 8) | (b << 4) | a

Unpacking:
  a = (uint32_value >> 0) & 0xF
  b = (uint32_value >> 4) & 0xF
  c = (uint32_value >> 8) & 0xF
  ...
  h = (uint32_value >> 28) & 0xF
```
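This packing scheme can be sketched in Python (value `a` occupies the lowest 4 bits, matching the layout above):

```python
# Pack/unpack 8 × 4-bit values into/out of one 32-bit integer.
def pack8(vals):
    assert len(vals) == 8 and all(0 <= v <= 15 for v in vals)
    packed = 0
    for idx, v in enumerate(vals):
        packed |= v << (4 * idx)      # value idx occupies bits [4*idx, 4*idx+4)
    return packed

def unpack8(packed):
    return [(packed >> (4 * idx)) & 0xF for idx in range(8)]

vals = [0, 2, 7, 10, 15, 1, 4, 9]
assert unpack8(pack8(vals)) == vals   # round-trip is lossless
```

The same shift-and-mask logic appears in the CPU kernel's inner loop in Task 2.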

## Quantized Matrix Multiplication

### Mathematical Formulation

For standard matrix multiplication $C = AB^T$ where:

- $A$: shape $(M, N)$, float16/bfloat16 (activations)
- $B$: shape $(K, N)$, **quantized** to int4 (weights)
- $C$: shape $(M, K)$, float16/bfloat16 (output)

Each element $C[i, k]$ is computed as:

$$
C[i, k] = \sum_{j=0}^{N-1} A[i, j] \times B[k, j]
$$

With quantization, $B[k, j]$ is represented as:

$$
B[k, j] = B_{\text{quantized}}[k, j] \times \text{scale}[k, g] + \text{bias}[k, g]
$$

where $g = \lfloor j / G \rfloor$ is the group index.

Substituting:

$$
C[i, k] = \sum_{g=0}^{N/G-1} \sum_{j'=0}^{G-1} A[i, g \times G + j'] \times (B_{\text{quantized}}[k, g \times G + j'] \times \text{scale}[k, g] + \text{bias}[k, g])
$$

Rearranging:

$$
C[i, k] = \sum_{g=0}^{N/G-1} \left( \text{scale}[k, g] \sum_{j'=0}^{G-1} A[i, g \times G + j'] \times B_{\text{quantized}}[k, g \times G + j'] + \text{bias}[k, g] \sum_{j'=0}^{G-1} A[i, g \times G + j'] \right)
$$

This shows that the scale and bias can be factored out of the inner sum and applied once per group instead of once per weight, reducing the number of floating-point multiplications.
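A quick numeric check of the rearranged formula (toy sizes, random data) confirms that the factored form computes the same value as dequantizing every weight individually:

```python
import random

G, N = 4, 12                         # toy group size, three groups per row
random.seed(0)
a = [random.uniform(-1, 1) for _ in range(N)]   # one row of A
q = [random.randrange(16) for _ in range(N)]    # 4-bit codes for one row of B
scales = [random.uniform(0.01, 0.1) for _ in range(N // G)]
biases = [random.uniform(-0.5, 0.5) for _ in range(N // G)]

# Naive: dequantize each weight, then take the dot product
naive = sum(a[j] * (q[j] * scales[j // G] + biases[j // G]) for j in range(N))

# Factored: per group, scale × dot(a, q) + bias × sum(a)
factored = sum(
    scales[g] * sum(a[g * G + j] * q[g * G + j] for j in range(G))
    + biases[g] * sum(a[g * G + j] for j in range(G))
    for g in range(N // G)
)
assert abs(naive - factored) < 1e-9
```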

### Computation Flow

```plain
Input:
  A: M × N (float16, activations)
  B_quantized: K × (N/8) (uint32, packed weights)
  scales: K × (N/64) (float16)
  biases: K × (N/64) (float16)

Output:
  C: M × K (float16)

For each output element C[i, k]:
  sum = 0
  for each group g in 0..(N/64 - 1):
    scale = scales[k, g]
    bias = biases[k, g]

    # Process 64 values in the group (8 uint32 packs)
    for each pack p in 0..7:
      packed_value = B_quantized[k, g*8 + p]

      # Unpack 8 × 4-bit values
      for bit_offset in [0, 4, 8, 12, 16, 20, 24, 28]:
        quantized = (packed_value >> bit_offset) & 0xF
        b_value = quantized * scale + bias
        a_value = A[i, g*64 + p*8 + bit_offset/4]
        sum += a_value * b_value

  C[i, k] = sum
```
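Before writing the C++ kernel, it can help to transcribe the flow above into a pure-Python reference you can test against (a sketch assuming 4-bit quantization with group size 64; the function name and call shape are illustrative, not the extension's API):

```python
def quantized_matmul_ref(A, B_packed, scales, biases, group_size=64):
    """Dequantize-on-the-fly matmul: C = A @ dequant(B).T, 4-bit layout."""
    M, N = len(A), len(A[0])
    K = len(B_packed)
    packs_per_group = group_size // 8        # 8 nibbles per uint32
    C = [[0.0] * K for _ in range(M)]
    for i in range(M):
        for k in range(K):
            acc = 0.0                        # accumulate in full precision
            for g in range(N // group_size):
                scale, bias = scales[k][g], biases[k][g]
                for p in range(packs_per_group):
                    packed = B_packed[k][g * packs_per_group + p]
                    for s in range(8):       # unpack 8 × 4-bit values
                        q = (packed >> (4 * s)) & 0xF
                        acc += A[i][g * group_size + p * 8 + s] * (q * scale + bias)
            C[i][k] = acc
    return C

# Tiny smoke test: one row of ones times one row whose nibbles are all 3
A = [[1.0] * 64]
B_packed = [[0x33333333] * 8]                # 64 quantized values, all equal to 3
print(quantized_matmul_ref(A, B_packed, scales=[[1.0]], biases=[[0.0]]))  # [[192.0]]
```

Note that the accumulator is a full-precision float; in the C++ kernel you should likewise accumulate in fp32 rather than fp16.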

## Task 1: Implement QuantizedWeights

```
src/tiny_llm/quantize.py
```

First, familiarize yourself with the `QuantizedWeights` class, which stores quantized weight information:

| Field | Shape | Description |
|-------|-------|-------------|
| `weight` | $(K, N/8)$ uint32 | Packed quantized weights. Each uint32 stores 8 consecutive 4-bit values. The original weight matrix has shape $(K, N)$; after packing, it becomes $(K, N/8)$. |
| `scales` | $(K, N/G)$ float16 | Per-group scale factors for dequantization. Each group of $G$ consecutive values shares one scale. Recall: $\text{scale} = (v_{max} - v_{min}) / 15$ |
| `biases` | $(K, N/G)$ float16 | Per-group bias (offset) for dequantization. Recall: $\text{bias} = v_{min}$ |
| `group_size` | int | Number of consecutive values that share the same scale/bias (typically 64) |
| `bits` | int | Quantization bit width (typically 4, meaning values are in the range $[0, 15]$) |

The `from_mlx_layer` static method extracts these fields from MLX's quantized linear layers when loading the model.

Next, implement the `quantized_linear` function, a wrapper around `quantized_matmul` that mimics the standard `linear` function interface. We will implement `quantized_matmul` itself in the next task.

## Task 2: Implement `quantized_matmul` (CPU version)

In this task, we will implement the quantized matmul as an MLX C++ extension. The pattern is identical to the existing `axpby` example in the codebase — read through `axpby.h`, `axpby.cpp`, and the corresponding binding in `bindings.cpp` first as your reference.

```
src/extensions/src/tiny_llm_ext.h
src/extensions/bindings.cpp
src/extensions/src/quantized_matmul.cpp
src/extensions/CMakeLists.txt
```

You need to touch three source files, all within the `tiny_llm_ext` namespace:

- **`tiny_llm_ext.h`** — Declare the `quantized_matmul(...)` function signature and define a `QuantizedMatmul` primitive class (inheriting `mx::Primitive`). Store `group_size` and `bits` as private members.
- **`bindings.cpp`** — Add an `m.def(...)` call to expose the function to Python.
- **`quantized_matmul.cpp`** — Implement the `quantized_matmul(...)` function (validate inputs, compute output shape, return a lazy `mx::array`) and the `eval_cpu` method (allocate output, register arrays with the CPU encoder, dispatch the compute kernel).

The `eval_cpu` implementation follows the same CPU encoder pattern as `axpby`: allocate output memory with `out.set_data(mx::allocator::malloc(out.nbytes()))`, register input/output arrays with the encoder, then dispatch a lambda that performs the actual computation. Inside the lambda, implement the nested loop from the Computation Flow section above — iterate over each output element `(i, k)`, accumulate in `float` (fp32) to avoid precision loss, and cast the result back to `float16` when writing to the output.

Don't forget to add `src/quantized_matmul.cpp` to `target_sources` in `CMakeLists.txt`.

You can test your implementation by running:

```bash
pdm run build-ext
pdm run test --week 2 --day 2 -- -k task_2
```

## Task 3: Implement Metal Kernel

TBD...

{{#include copyright.md}}