@@ -17,6 +17,7 @@
 )
 
 from executorch.examples.models.model_base import EagerModelBase
+from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha
 from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder
 from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
 
@@ -53,7 +54,7 @@ def __init__(self, **kwargs):
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.verbose = kwargs.get("verbose", False)
         self.args = kwargs.get("args", None)
-        self.dtype = None
+        self.dtype = kwargs.get("dtype", torch.float16)
         self.use_checkpoint = False
 
         ckpt_dir = get_default_model_resource_dir(__file__)
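The constructor now takes `dtype` from its kwargs instead of resolving it later in `get_eager_model`, keeping `torch.float16` as the fallback. A minimal sketch of the lookup pattern, for illustration only:

```python
import torch

def resolve_dtype(**kwargs) -> torch.dtype:
    # Mirrors the change above: a caller-supplied dtype wins,
    # torch.float16 is the default when the kwarg is absent.
    return kwargs.get("dtype", torch.float16)

assert resolve_dtype() is torch.float16
assert resolve_dtype(dtype=torch.bfloat16) is torch.bfloat16
```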
@@ -72,7 +73,7 @@ def __init__(self, **kwargs):
                 dtype=torch.bool,
             )
         )
-        self.input_pos = torch.arange(self.max_seq_len)
+        self.input_pos = torch.arange(self.max_seq_len, dtype=torch.int64)
 
         # Load checkpoint and params.
         device = "cpu"
@@ -107,6 +108,9 @@ def __init__(self, **kwargs):
             rope_base=params["rope_theta"],
             intermediate_dim=params["intermediate_dim"],
         )
+
+        # Source transformation for MultiHeadAttention
+        self.model_ = replace_mha_with_inference_mha(self.model_)
         # Save params for future use.
         for param_name, param_val in params.items():
             setattr(self.model_, param_name, param_val)
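`replace_mha_with_inference_mha` is a source transformation: it rewrites torchtune's `MultiHeadAttention` modules into inference-friendly equivalents before the KV caches are set up. A rough sketch of the general module-swap pattern such a transform follows (an illustration of the technique, not the executorch implementation):

```python
import torch.nn as nn

def swap_module_type(model: nn.Module, old_type: type, build_replacement) -> nn.Module:
    # Walk the module tree and replace every instance of `old_type`
    # with the module produced by `build_replacement(child)`.
    for name, child in model.named_children():
        if isinstance(child, old_type):
            setattr(model, name, build_replacement(child))
        else:
            swap_module_type(child, old_type, build_replacement)
    return model
```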
@@ -147,39 +151,46 @@ def __init__(self, **kwargs):
             self.model_.setup_caches(
                 batch_size=1,
                 dtype=self.dtype,
+                encoder_max_seq_len=self.encoder_max_seq_len,
                 decoder_max_seq_len=self.max_seq_len,
             )
+        # number of tokens for example input
+        self.n_tokens = 34
+        self.model_.to(self.dtype)
 
     def get_eager_model(self) -> torch.nn.Module:
-        if self.dtype:
-            return self.model_.to(self.dtype)
-        else:
-            return self.model_.to(torch.float16)
+        return self.model_
 
     def get_example_inputs(self):
-        return (torch.ones(1, 32, dtype=torch.long),)
+        return (torch.ones(1, self.n_tokens, dtype=torch.int64),)
 
     def get_example_kwarg_inputs(self):
         # For export we must use the prefill versions of the
         # causal mask and input_pos.
+        # Hardcoding # of tiles to be 2. image tokens per tile is 1601.
         if self.use_kv_cache:
             return {
-                "input_pos": self.input_pos[None, :32],
-                "mask": self.causal_mask[None, :32],
-                # "encoder_input": None,
-                # "encoder_mask": None,
+                "input_pos": self.input_pos[None, : self.n_tokens],
+                "mask": self.causal_mask[None, : self.n_tokens],
+                "encoder_input": torch.randn(
+                    1, self.encoder_max_seq_len, self.model_.dim, dtype=self.dtype
+                ),
+                "encoder_mask": torch.ones(
+                    [1, self.n_tokens, self.encoder_max_seq_len], dtype=torch.bool
+                ),
             }
         else:
             return None
 
     def get_dynamic_shapes(self):
         batch_size = 1
         dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
+        # Hardcoding # of tiles to be 2. image tokens per tile is 1601.
         if self.use_kv_cache:
             dynamic_shapes = {
                 "tokens": {0: batch_size, 1: dim_seq_len},
-                # "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
-                # "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
+                "encoder_input": None,
+                "encoder_mask": {0: 1, 1: dim_seq_len, 2: None},
                 "mask": {0: batch_size, 1: dim_seq_len, 2: None},
                 "input_pos": {0: batch_size, 1: dim_seq_len},
             }
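Together, `get_example_inputs`, `get_example_kwarg_inputs`, and `get_dynamic_shapes` now describe the full prefill signature, tokens plus encoder inputs, that the exporter traces. A hedged sketch of how these hooks would typically be wired into `torch.export` (the helper name and the `decoder` parameter are assumptions, not part of this diff):

```python
import torch

def export_text_decoder(decoder) -> torch.export.ExportedProgram:
    # `decoder` is assumed to be an instance of the wrapper class patched above.
    eager = decoder.get_eager_model()
    return torch.export.export(
        eager,
        args=decoder.get_example_inputs(),            # (tokens,)
        kwargs=decoder.get_example_kwarg_inputs(),    # input_pos, mask, encoder_input, encoder_mask
        dynamic_shapes=decoder.get_dynamic_shapes(),  # "token_dim" ranges over 1..max_seq_len
    )
```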