|
4 | 4 | # This source code is licensed under the BSD-style license found in the
5 | 5 | # LICENSE file in the root directory of this source tree.
6 | 6 |
|
| 7 | +import sys |
7 | 8 | import argparse |
8 | 9 | import copy |
9 | 10 | import json |
| 11 | +import torch |
| 12 | +from functools import partial |
| 13 | + |
| 14 | +from lm_eval.evaluator import simple_evaluate |
10 | 15 |
|
11 | 16 | from typing import List, Optional, Tuple |
12 | 17 |
|
|
26 | 31 |
|
27 | 32 | from pytorch_tokenizers import get_tokenizer |
28 | 33 |
|
| 34 | +from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate |
| 35 | + |
| 36 | +from executorch.examples.qualcomm.utils import make_quantizer |
| 37 | + |
| 38 | +from executorch.examples.models.llama.source_transformation.quantize import ( |
| 39 | + get_quant_embedding_transform, |
| 40 | +) |
| 41 | + |
| 42 | +from torchao.quantization.pt2e import MinMaxObserver |
| 43 | +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e |
| 44 | + |
| 45 | + |
| 46 | +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype |
| 47 | +from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d |
| 48 | + |
| 49 | + |
| 50 | +import logging |
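| | +# Exporting large graphs can exceed Python's default recursion limit.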
| 51 | +sys.setrecursionlimit(4096) |
| 52 | +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" |
| 53 | +logging.basicConfig(level=logging.INFO, format=FORMAT) |
| 54 | +logging.getLogger().setLevel(logging.INFO) |
| 55 | + |
29 | 56 |
|
30 | 57 | class WrappedLlamaModel(nn.Module): |
31 | | -    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
| 58 | + def __init__(self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"): |
32 | 59 | super(WrappedLlamaModel, self).__init__() |
33 | 60 | self.model = model |
34 | 61 | self.max_seq_len = max_seq_len |
35 | 62 | self.use_kv_cache = use_kv_cache |
36 | 63 | self.device = device |
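| | +        # Keep a fixed attention mask so forward() does not rebuild it on every call.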
| 64 | + self.atten_mask = atten_mask |
37 | 65 |
|
38 | 66 | def forward( |
39 | 67 | self, |
40 | 68 | tokens: torch.Tensor, |
41 | | -        input_pos: Optional[torch.Tensor] = None,
42 | 69 | *args, |
43 | 70 | ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: |
44 | 71 | # Pad input if necessary, since LlamaModel requires static shape |
45 | 72 | if tokens.shape[1] != self.max_seq_len: |
46 | 73 | tokens = torch.nn.functional.pad( |
47 | | -                tokens, (self.max_seq_len - tokens.shape[1], 0)
| 74 | + tokens, (0, self.max_seq_len - tokens.shape[1]) |
48 | 75 | ) |
49 | | -        atten_mask = (
50 | | -            self.model.get_example_inputs(self.use_kv_cache)[1]
51 | | -            .to(device=self.device)
52 | | -            .to(dtype=torch.bfloat16)
53 | | -        )
54 | | -        return self.model.forward(tokens, atten_mask, input_pos, *args)
| 76 | + return self.model.forward(tokens, self.atten_mask) |
55 | 77 |
|
56 | 78 |
|
57 | 79 | def gen_eval_wrapper(model_name, args): |
@@ -119,14 +141,69 @@ def permute(w, heads): |
119 | 141 | layer.feed_forward.prepare_feedfoward_conv() |
120 | 142 |
|
121 | 143 | model.to(dtype=torch.bfloat16) |
122 | | -    model.to(args.device)
| 144 | + model.to(device=args.device) |
| 145 | + |
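| | +    # Build static-shaped example inputs once; the attention mask is reused for export, calibration, and the eval wrapper.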
| 146 | + tokens, atten_mask = model.get_example_inputs(use_kv_cache=False) |
| 147 | + tokens = tokens.to(device=args.device) |
| 148 | + atten_mask = atten_mask.to(device=args.device) |
| 149 | + atten_mask = atten_mask.to(dtype=torch.bfloat16) |
| 150 | + inputs = (tokens, atten_mask) |
| 151 | + |
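| | +    # Optionally swap the embedding table for a quantized embedding before lowering.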
| 152 | + if args.embedding_quantize: |
| 153 | + model = get_quant_embedding_transform( |
| 154 | + embedding_quantize=args.embedding_quantize |
| 155 | + )(model) |
| 156 | + |
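| | +    # Replace nn.Linear layers with equivalent conv2d ops, which the QNN backend handles better.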
| 157 | + model = convert_linear_to_conv2d(model) |
| 158 | + |
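| | +    # Post-training quantization: export the graph, annotate it with the QNN quantizer, calibrate, then convert.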
| 159 | + if args.ptq: |
| 160 | + quant_dtype = getattr(QuantDtype, f"use_{args.ptq}") |
| 161 | + |
| 162 | + custom_annotations = () |
| 163 | + quantizer = make_quantizer( |
| 164 | + quant_dtype=quant_dtype, |
| 165 | + per_channel_conv=True, |
| 166 | + per_channel_linear=True, |
| 167 | + act_observer=MinMaxObserver, |
| 168 | + ) |
| 169 | + quantizer.add_custom_quant_annotations(custom_annotations) |
| 170 | + |
| 171 | + model.has_quant_io = True |
| 172 | + |
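| | +        # Export to a graph module so the quantizer can annotate and observe it.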
| 173 | + with torch.no_grad(): |
| 174 | + model = torch.export.export( |
| 175 | + model, inputs, strict=True |
| 176 | + ).module() |
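| | +        # Blockwise 16a4w needs an explicit block size for every conv node.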
| 177 | + if quant_dtype == QuantDtype.use_16a4w_block: |
| 178 | + conv_nodes = [ |
| 179 | + n for n in model.graph.nodes if "conv" in n.name |
| 180 | + ] |
| 181 | + block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes} |
| 182 | + quantizer.set_block_size_map(block_size_map) |
| 183 | + |
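| | +        # Insert observers according to the quantizer's annotations.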
| 184 | + model = prepare_pt2e(model, quantizer) |
| 185 | + |
| 186 | + logging.info("Quantizing the model...") |
| 187 | + |
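| | +        # Run a short prompt through the observed model to collect activation statistics.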
| 188 | + calibrate( |
| 189 | + inputs, |
| 190 | +            "Once upon a time",
| 191 | + model, |
| 192 | + tokenizer=tokenizer, |
| 193 | + ar_len=args.prefill_ar_len, |
| 194 | +            max_seq_len=args.max_seq_length,
| 195 | + kv_updater=None, |
| 196 | + use_i64_token=use_i64_token, |
| 197 | + ) |
123 | 198 |
|
124 | | -    wrapped_model = WrappedLlamaModel(
125 | | -        model, args.use_kv_cache, args.max_seq_length, args.device
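| | +    # Replace observers with quantize/dequantize ops to produce the quantized module.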
| 199 | + model = convert_pt2e(model) |
| 200 | + |
| 201 | + model = WrappedLlamaModel( |
| 202 | + model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device |
126 | 203 | ) |
127 | 204 |
|
128 | 205 | return GraphModuleEvalWrapper( |
129 | | -        model=wrapped_model,
| 206 | + model=model, |
130 | 207 | tokenizer=tokenizer, |
131 | 208 | max_seq_length=args.calibration_seq_length, |
132 | 209 | use_kv_cache=args.use_kv_cache, |
@@ -177,7 +254,15 @@ def main() -> None: |
177 | 254 | args.use_kv_cache = False |
178 | 255 | args.prefill_ar_len = args.max_seq_length |
179 | 256 |
|
| 257 | +    # Evaluate on only a fraction of the samples for faster evaluation
| 258 | + args.limit = 0.1 |
| 259 | + # args.samples = {'wikitext': list(range(1))} |
| 260 | + |
180 | 261 | args.device = "cuda" if torch.cuda.is_available() else "cpu" |
| 262 | + torch.set_default_device(args.device) |
| 263 | + |
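| | +    # 16-bit activations with 4-bit weights.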
| 264 | +    args.ptq = "16a4w"
| 265 | + |
181 | 266 |
|
182 | 267 | eval_llama(modelname, args) |
183 | 268 |
|
|