18
18
from llmcompressor .observers .base import Observer
19
19
from llmcompressor .pytorch .utils .helpers import tensor_sparsity
20
20
21
- torch ._dynamo .config .capture_scalar_outputs = True
22
- torch ._inductor .config .triton .tile_reductions = True
23
- torch .set_float32_matmul_precision ("high" )
24
-
25
21
GPTQ_PRECISION = torch .float32
26
22
27
23
__all__ = ["make_empty_hessian" , "accumulate_hessian" , "quantize_weight" ]
@@ -296,6 +292,7 @@ def _process_block(
296
292
quant_max : int ,
297
293
sym : bool ,
298
294
) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
295
+ """Process a single block of weight columns using with torch.compile support."""
299
296
count = W1 .shape [1 ]
300
297
Q1 = torch .zeros_like (W1 )
301
298
Err1 = torch .zeros_like (W1 )
@@ -349,11 +346,13 @@ def _quantize_core(
349
346
num_rows : int ,
350
347
num_columns : int ,
351
348
) -> Tuple [torch .Tensor , torch .Tensor ]:
349
+ """Core GPTQ quantization loop processing weights in blocks."""
352
350
losses = torch .zeros (num_rows , device = W .device , dtype = W .dtype )
353
351
354
352
for i1 in range (0 , num_columns , blocksize ):
355
353
i2 = min (i1 + blocksize , num_columns )
356
354
355
+ # Extract current block and corresponding Hessian/quantization params
357
356
W1 = W [:, i1 :i2 ].clone ()
358
357
Hinv1 = Hinv [i1 :i2 , i1 :i2 ].contiguous ()
359
358
scale_slice = scale_map [:, i1 :i2 ]
@@ -362,13 +361,15 @@ def _quantize_core(
362
361
if W_nz_mask is not None :
363
362
mask_slice = W_nz_mask [:, i1 :i2 ]
364
363
364
+ # Quantize the current block
365
365
Q1 , Err1 , losses1 = _process_block (
366
366
W1 , Hinv1 , scale_slice , zero_slice , mask_slice , quant_min , quant_max , sym
367
367
)
368
368
369
369
W [:, i1 :i2 ] = Q1
370
370
losses += losses1 .sum (dim = 1 ) / 2
371
371
372
+ # Propagate block error to remaining unprocessed columns
372
373
w_err = Err1 @ Hinv [i1 :i2 , i2 :]
373
374
if W_nz_mask is not None :
374
375
mask_rest = W_nz_mask [:, i2 :]
0 commit comments