@@ -122,6 +122,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
122122 qnn_config : Optional [str ] = None ,
123123 config : Optional [AutoConfig ] = None ,
124124 img_size : Optional [int ] = None ,
125+ torch_dtype : Optional [int ] = torch .float32 ,
125126):
126127 """
127128 Unified function to test PyTorch model, PyTorch KV model, ONNX model, and Cloud AI 100 model.
@@ -280,22 +281,21 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
280281 # )
281282
282283 streamer = TextStreamer (processor .tokenizer )
283- LOAD_DTYPE = torch .float16
284284
285285 # ========== Export and Compile Model ==========
286286 if is_intern_model or is_molmo_model :
287287 qeff_model = QEFFAutoModelForCausalLM .from_pretrained (
288288 model_name ,
289289 kv_offload = kv_offload ,
290290 config = config ,
291- torch_dtype = LOAD_DTYPE ,
291+ torch_dtype = torch_dtype ,
292292 )
293293 else :
294294 qeff_model = QEFFAutoModelForImageTextToText .from_pretrained (
295295 model_name ,
296296 kv_offload = kv_offload ,
297297 config = config ,
298- torch_dtype = LOAD_DTYPE ,
298+ torch_dtype = torch_dtype ,
299299 )
300300
301301 qeff_model .export ()
@@ -376,6 +376,50 @@ def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(model_name, kv_offload
376376 )
377377
378378
379+ ### Custom dtype Test ###
380+
381+
@pytest.mark.on_qaic
@pytest.mark.multimodal
@pytest.mark.parametrize("model_name", test_mm_models)
@pytest.mark.parametrize("kv_offload", [True, False])
@pytest.mark.parametrize("torch_dtype", [torch.float16])
def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100_custom_dtype(model_name, kv_offload, torch_dtype):
    """
    Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the
    Cloud AI 100 model, without continuous batching, when the HF model is loaded with a
    custom ``torch_dtype`` instead of the default ``torch.float32``.

    ``Mandatory`` Args:
        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
        :kv_offload (bool): Whether to export/compile with KV offload enabled or disabled.
        :torch_dtype (torch.dtype): dtype the HF weights are loaded in, Example: ``torch.float16``
    """
    # Known-problematic models are skipped regardless of parametrization.
    if model_name in [
        "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        "allenai/Molmo-7B-D-0924",
        "meta-llama/Llama-3.2-11B-Vision-Instruct",
    ]:
        pytest.skip("Test skipped for this model due to some issues.")
    # These models only support the kv_offload=True path; skip the other leg of the matrix.
    if (
        model_name in ["OpenGVLab/InternVL2_5-1B", "OpenGVLab/InternVL3_5-1B", "Qwen/Qwen2.5-VL-3B-Instruct"]
        and not kv_offload
    ):
        pytest.skip("These models require kv_offload=True for testing.")
    # Get img_size for standard models, None for InternVL and Molmo
    img_size = model_config_dict[model_name].get("img_size")

    # TODO: Add custom dtype support in ORT and Pytorch_KV APIs
    check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
        model_name=model_name,
        prompt_len=model_config_dict[model_name]["prompt_len"],
        ctx_len=model_config_dict[model_name]["ctx_len"],
        max_gen_len=NEW_GENERATION_TOKENS,
        img_size=img_size,
        img_url=model_config_dict[model_name]["img_url"],
        query=model_config_dict[model_name]["text_prompt"],
        n_layer=model_config_dict[model_name]["num_layers"],
        batch_size=model_config_dict[model_name]["batch_size"],
        kv_offload=kv_offload,
        torch_dtype=torch_dtype,
    )
421+
422+
379423### QNN Tests ###
380424
381425
0 commit comments