@@ -4,9 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import sys
 import argparse
 import copy
 import json
+import torch
+from functools import partial
+
+from lm_eval.evaluator import simple_evaluate

 from typing import List, Optional, Tuple

@@ -26,32 +31,53 @@

 from pytorch_tokenizers import get_tokenizer

+from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate
+
+from executorch.examples.qualcomm.utils import make_quantizer
+
+from executorch.examples.models.llama.source_transformation.quantize import (
+    get_quant_embedding_transform,
+)
+
+from torchao.quantization.pt2e import MinMaxObserver
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
+from executorch.backends.qualcomm.quantizer.custom_annotation import (
+    annotate_linear_16a8w_in_affine_layer,
+    annotate_matmul_16a8w,
+)
+
+
+import logging
+sys.setrecursionlimit(4096)
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logging.getLogger().setLevel(logging.INFO)
+

 class WrappedLlamaModel(nn.Module):
-    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
+    def __init__(self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"):
         super(WrappedLlamaModel, self).__init__()
         self.model = model
         self.max_seq_len = max_seq_len
         self.use_kv_cache = use_kv_cache
         self.device = device
+        self.atten_mask = atten_mask

     def forward(
         self,
         tokens: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
         *args,
     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
         # Pad input if necessary, since LlamaModel requires static shape
         if tokens.shape[1] != self.max_seq_len:
             tokens = torch.nn.functional.pad(
-                tokens, (self.max_seq_len - tokens.shape[1], 0)
+                tokens, (0, self.max_seq_len - tokens.shape[1])
             )
-        atten_mask = (
-            self.model.get_example_inputs(self.use_kv_cache)[1]
-            .to(device=self.device)
-            .to(dtype=torch.bfloat16)
-        )
-        return self.model.forward(tokens, atten_mask, input_pos, *args)
+        return self.model.forward(tokens, self.atten_mask)


 def gen_eval_wrapper(model_name, args):
@@ -119,14 +145,73 @@ def permute(w, heads):
         layer.feed_forward.prepare_feedfoward_conv()

     model.to(dtype=torch.bfloat16)
-    model.to(args.device)
+    model.to(device=args.device)
+
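+    # Static example inputs (token ids + attention mask) on the target device;
+    # the same mask is handed to WrappedLlamaModel and reused on every forward call.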
+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.bfloat16)
+    inputs = (tokens, atten_mask)
+
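+    # Optionally quantize the embedding weights as requested via args.embedding_quantize.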
+    if args.embedding_quantize:
+        model = get_quant_embedding_transform(
+            embedding_quantize=args.embedding_quantize
+        )(model)
+
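+    # Replace nn.Linear modules with equivalent convs so they take the
+    # per-channel conv quantization path configured below.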
+    model = convert_linear_to_conv2d(model)

-    wrapped_model = WrappedLlamaModel(
-        model, args.use_kv_cache, args.max_seq_length, args.device
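+    # Post-training quantization: export the model, attach the Qualcomm quantization
+    # annotations, calibrate on a sample prompt, then convert to a quantized graph.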
+    if args.ptq:
+        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
+
+        custom_annotations = (annotate_matmul_16a8w,)
+        if args.llama_model == "stories110m":
+            custom_annotations = custom_annotations + (
+                annotate_linear_16a8w_in_affine_layer,
+            )
+        quantizer = make_quantizer(
+            quant_dtype=quant_dtype,
+            per_channel_conv=True,
+            per_channel_linear=True,
+            act_observer=MinMaxObserver,
+        )
+        quantizer.add_custom_quant_annotations(custom_annotations)
+
+        model.has_quant_io = True
+
+        with torch.no_grad():
+            model = torch.export.export(
+                model, inputs, strict=True
+            ).module()
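+            # Block quantization (16a4w_block) needs an explicit block size for
+            # every conv node before the observers are inserted.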
+            if quant_dtype == QuantDtype.use_16a4w_block:
+                conv_nodes = [
+                    n for n in model.graph.nodes if "conv" in n.name
+                ]
+                block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+                quantizer.set_block_size_map(block_size_map)
+
+            model = prepare_pt2e(model, quantizer)
+
+            logging.info("Quantizing the model...")
+
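+            # Run a short prompt through the prepared (observer-instrumented) model
+            # so activation ranges are collected before conversion.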
+            calibrate(
+                inputs,
+                'Once upon a time',
+                model,
+                tokenizer=tokenizer,
+                ar_len=args.prefill_ar_len,
+                max_seq_len=args.max_seq_length,
+                kv_updater=None,
+                use_i64_token=use_i64_token,
+            )
+
+            model = convert_pt2e(model)
+
+    model = WrappedLlamaModel(
+        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
     )

     return GraphModuleEvalWrapper(
-        model=wrapped_model,
+        model=model,
         tokenizer=tokenizer,
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
@@ -167,6 +252,7 @@ def main() -> None:
     modelname = "llama2"
     parser = build_args_parser()
     args = parser.parse_args()
+    args.llama_model = "llama3_2"
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True

@@ -177,7 +263,15 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length

+    # Use only a fraction of the eval samples for faster evaluation.
+    args.limit = 0.1
+    # args.samples = {'wikitext': list(range(1))}
+
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
+    torch.set_default_device(args.device)
+
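+    # Post-training quantization scheme: 8-bit activations / 8-bit weights
+    # (mapped to QuantDtype.use_8a8w in gen_eval_wrapper).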
+    args.ptq = '8a8w'
+

     eval_llama(modelname, args)
