 import copy
 import json
 
-from typing import List, Optional, Tuple
+import logging
+import sys
+
+from typing import List, Tuple
 
 import torch
 import torch.nn as nn
+from executorch.backends.qualcomm.quantizer.custom_annotation import (
+    annotate_linear_16a8w_in_affine_layer,
+    annotate_matmul_16a8w,
+)
+
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
 
 from executorch.examples.models.llama.eval_llama_lib import (
     build_args_parser,
     GraphModuleEvalWrapper,
 )
 
+from executorch.examples.models.llama.source_transformation.quantize import (
+    get_quant_embedding_transform,
+)
+
+from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate
+
 from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
     LlamaModel,
     ModelArgs,
 )
+
+from executorch.examples.qualcomm.utils import make_quantizer
+
 from lm_eval.evaluator import simple_evaluate
 
 from pytorch_tokenizers import get_tokenizer
 
+from torchao.quantization.pt2e import MinMaxObserver
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
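+# Exporting the full Llama graph recurses deeply; raise the limit to avoid RecursionError.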
+sys.setrecursionlimit(4096)
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logging.getLogger().setLevel(logging.INFO)
+
 
 class WrappedLlamaModel(nn.Module):
-    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
+    def __init__(
+        self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"
+    ):
         super(WrappedLlamaModel, self).__init__()
         self.model = model
         self.max_seq_len = max_seq_len
         self.use_kv_cache = use_kv_cache
         self.device = device
+        self.atten_mask = atten_mask
 
     def forward(
         self,
         tokens: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
         *args,
     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
         # Pad input if necessary, since LlamaModel requires static shape
         if tokens.shape[1] != self.max_seq_len:
             tokens = torch.nn.functional.pad(
-                tokens, (self.max_seq_len - tokens.shape[1], 0)
+                tokens, (0, self.max_seq_len - tokens.shape[1])
             )
-        atten_mask = (
-            self.model.get_example_inputs(self.use_kv_cache)[1]
-            .to(device=self.device)
-            .to(dtype=torch.bfloat16)
-        )
-        return self.model.forward(tokens, atten_mask, input_pos, *args)
+        return self.model.forward(tokens, self.atten_mask)
 
 
 def gen_eval_wrapper(model_name, args):
@@ -119,14 +143,69 @@ def permute(w, heads):
         layer.feed_forward.prepare_feedfoward_conv()
 
     model.to(dtype=torch.bfloat16)
-    model.to(args.device)
+    model.to(device=args.device)
 
-    wrapped_model = WrappedLlamaModel(
-        model, args.use_kv_cache, args.max_seq_length, args.device
+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.bfloat16)
+    inputs = (tokens, atten_mask)
+
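+    # Optionally quantize the embedding table, as selected by --embedding_quantize.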
+    if args.embedding_quantize:
+        model = get_quant_embedding_transform(
+            embedding_quantize=args.embedding_quantize
+        )(model)
+
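+    # Swap nn.Linear for 1x1 conv2d, which the Qualcomm quantizer/backend handles better.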
+    model = convert_linear_to_conv2d(model)
+
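+    # Post-training quantization (PTQ) via the PT2E prepare/calibrate/convert flow.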
+    if args.ptq:
+        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
+
+        custom_annotations = (annotate_matmul_16a8w,)
+        if args.llama_model == "stories110m":
+            custom_annotations = custom_annotations + (
+                annotate_linear_16a8w_in_affine_layer,
+            )
+        quantizer = make_quantizer(
+            quant_dtype=quant_dtype,
+            per_channel_conv=True,
+            per_channel_linear=True,
+            act_observer=MinMaxObserver,
+        )
+        quantizer.add_custom_quant_annotations(custom_annotations)
+
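+        # Signal downstream passes that the model's I/O will carry quantized values.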
+        model.has_quant_io = True
+
+        with torch.no_grad():
+            model = torch.export.export(model, inputs, strict=True).module()
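+            # Blockwise weight quantization needs an explicit block size per conv node.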
+            if quant_dtype == QuantDtype.use_16a4w_block:
+                conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
+                block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+                quantizer.set_block_size_map(block_size_map)
+
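+            # Insert observers according to the quantizer's annotations.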
+            model = prepare_pt2e(model, quantizer)
+
+            logging.info("Quantizing the model...")
+
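+            # Feed a sample prompt through the observed model to collect activation ranges.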
+            calibrate(
+                inputs,
+                "Once upon a time",
+                model,
+                tokenizer=tokenizer,
+                ar_len=args.prefill_ar_len,
+                max_seq_len=args.max_seq_length,
+                kv_updater=None,
+                use_i64_token=use_i64_token,
+            )
+
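+            # Replace observers with quantize/dequantize ops to finalize the quantized graph.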
+            model = convert_pt2e(model)
+
+    model = WrappedLlamaModel(
+        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
     )
 
     return GraphModuleEvalWrapper(
-        model=wrapped_model,
+        model=model,
         tokenizer=tokenizer,
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
@@ -167,6 +246,7 @@ def main() -> None:
     modelname = "llama2"
     parser = build_args_parser()
     args = parser.parse_args()
+    args.llama_model = "llama3_2"
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True
 
@@ -177,7 +257,14 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length
 
+    # Limit evaluation to a fraction of the samples (10%) for faster runs.
+    args.limit = 0.1
+    # args.samples = {'wikitext': list(range(1))}
+
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
+    torch.set_default_device(args.device)
+
+    args.ptq = "8a8w"
 
     eval_llama(modelname, args)
 