This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 9a276c3

[Hackability Refactor] Collapse quant_ops into calling sites (#1062)
* Remove unused quant code
* Decompose qops to call sites
* Fix import call in gguf_loader
* Lint gguf imports
1 parent d0d1105 commit 9a276c3

File tree

3 files changed: +424 −548 lines changed

build/gguf_loader.py

Lines changed: 142 additions & 2 deletions
@@ -7,19 +7,22 @@
 
 import copy
 import logging
-from typing import Any
+from typing import Any, Optional
 
 import gguf
 
 import torch
+import torch.nn.functional as F
 
 from build.gguf_util import Q4_0, to_float
 from build.model import Model, ModelArgs, TransformerArgs
 
 from gguf import GGUFValueType
-from quantization.qops import LinearInt4 as WeightOnlyInt4Linear
 from quantization.quantize import pack_scales_and_zeros
 
+from build.utils import find_multiple, get_precision
+
+
 logger: logging.Logger = logging.getLogger(__name__)
 
 
@@ -97,6 +100,143 @@ def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
     return metadata
 
 
+#########################################################################
+# Note: int4 quantization is migrated to torchao for general quantization.
+# TODO: GGUF workflow needs migration to torchao
+#########################################################################
+
+
+def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsize):
+    origin_input_size = input.size()
+    input = input.reshape(-1, origin_input_size[-1])
+
+    if "cuda" in str(input.device):
+        c = torch.ops.aten._weight_int4pack_mm(
+            input.to(torch.bfloat16),
+            weight_int4pack,
+            groupsize,
+            scales_and_zeros.to(torch.bfloat16),
+        ).to(
+            input.dtype
+        )  # cast back to input.dtype
+    else:
+        c = torch.ops.aten._weight_int4pack_mm(
+            input,
+            weight_int4pack,
+            groupsize,
+            scales_and_zeros,
+        )
+    new_shape = origin_input_size[:-1] + (out_features,)
+    c = c.reshape(new_shape)
+    return c
+
+
+class WeightOnlyInt4Linear(torch.nn.Module):
+    __constants__ = ["in_features", "out_features"]
+    in_features: int
+    out_features: int
+    weight: torch.Tensor
+    scales_and_zeros: torch.Tensor
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias=True,
+        device=None,
+        dtype=None,
+        *,
+        groupsize: int = 128,
+        inner_k_tiles: int = 8,
+        weight: Optional[torch.Tensor] = None,
+        scales_and_zeros: Optional[torch.Tensor] = None,
+    ) -> None:
+        super().__init__()
+        self.padding = not self._check_k(
+            k=in_features,
+            groupsize=groupsize,
+            inner_k_tiles=inner_k_tiles,
+        )
+        if self.padding:
+            self.origin_in_features = in_features
+            in_features = find_multiple(in_features, 1024)
+
+        self.in_features = in_features
+        self.out_features = out_features
+        assert not bias, "require bias=False"
+        self.groupsize = groupsize
+        self.inner_k_tiles = inner_k_tiles
+
+        assert out_features % 8 == 0, "require out_features % 8 == 0"
+        assert (
+            in_features % (inner_k_tiles * 16) == 0
+        ), "require in_features % (innerKTiles * 16) == 0"
+        assert (weight is None) == bool(
+            scales_and_zeros is None
+        ), "must specify both weights and scales_and_zeros, or neither"
+
+        if weight is None:
+            weight = torch.empty(
+                (
+                    out_features // 8,
+                    in_features // (inner_k_tiles * 16),
+                    32,
+                    inner_k_tiles // 2,
+                ),
+                dtype=torch.int32,
+                device=device,
+            )
+            scales_and_zeros = torch.empty(
+                (in_features // groupsize, out_features, 2),
+                dtype=get_precision(),
+                device=device,
+            )
+
+        self.register_buffer(
+            "weight",
+            weight,
+        )
+        self.register_buffer(
+            "scales_and_zeros",
+            scales_and_zeros,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.padding:
+            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
+        return linear_int4(
+            input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
+        )
+
+    @classmethod
+    def _check_k(cls, *, k, groupsize=1, inner_k_tiles=1):
+        return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
+
+    @classmethod
+    def _prepare_weight_and_scales_and_zeros(
+        cls, weight_bf16, groupsize, inner_k_tiles
+    ):
+        from quantization.quantize import group_quantize_tensor
+
+        weight_int32, scales_and_zeros = group_quantize_tensor(
+            weight_bf16, n_bit=4, groupsize=groupsize
+        )
+        weight_uint8 = (weight_int32[::, ::2] << 4 | weight_int32[::, 1::2]).to(
+            torch.uint8
+        )
+        weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+            weight_uint8, inner_k_tiles
+        )
+        return weight_int4pack, scales_and_zeros
+
+    @classmethod
+    def _calc_padded_size(cls, *, k, groupsize=1, innner_k_tiles=1):
+        return find_multiple(k, 1024)
+
+
+#########################################################################
+
+
 def load_model(gguf_file: str) -> torch.nn.Module:
     """
     Parses the GGUF file and returns an nn.Module on meta device.
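A minimal usage sketch of the WeightOnlyInt4Linear module added above (not part of the commit): the 4096x4096 shape, groupsize, and inner_k_tiles values are illustrative, and it assumes a PyTorch build and device where the aten int4 packing/matmul ops invoked by the module are supported.

import torch

from build.gguf_loader import WeightOnlyInt4Linear

in_features, out_features = 4096, 4096  # illustrative dims; must satisfy the asserts above
groupsize, inner_k_tiles = 128, 8

# Pack a bf16 weight into the int4 layout expected by the aten kernel,
# using the classmethod defined in the diff.
w_bf16 = torch.randn(out_features, in_features, dtype=torch.bfloat16)
weight_int4pack, scales_and_zeros = (
    WeightOnlyInt4Linear._prepare_weight_and_scales_and_zeros(
        w_bf16, groupsize, inner_k_tiles
    )
)

layer = WeightOnlyInt4Linear(
    in_features,
    out_features,
    bias=False,  # the module requires bias=False
    groupsize=groupsize,
    inner_k_tiles=inner_k_tiles,
    weight=weight_int4pack,
    scales_and_zeros=scales_and_zeros,
)

x = torch.randn(2, in_features, dtype=torch.bfloat16)
y = layer(x)  # shape (2, out_features)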
