Skip to content

Commit a81f486

Browse files
committed
Updated the test file
Signed-off-by: asmigosw <asmigosw@qti.qualcomm.com> Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
1 parent ef0d944 commit a81f486

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

examples/image_text_to_text/models/molmo/molmo_example.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
# For faster execution user can run on 2 layers, This is only for testing purpose
2020
# config.num_hidden_layers = 2
2121

22-
2322
# load the model
2423
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, kv_offload=True, trust_remote_code=True, config=config)
2524
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
122122
qnn_config: Optional[str] = None,
123123
config: Optional[AutoConfig] = None,
124124
img_size: Optional[int] = None,
125+
torch_dtype: Optional[int] = torch.float32,
125126
):
126127
"""
127128
Unified function to test PyTorch model, PyTorch KV model, ONNX model, and Cloud AI 100 model.
@@ -280,22 +281,21 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
280281
# )
281282

282283
streamer = TextStreamer(processor.tokenizer)
283-
LOAD_DTYPE = torch.float16
284284

285285
# ========== Export and Compile Model ==========
286286
if is_intern_model or is_molmo_model:
287287
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
288288
model_name,
289289
kv_offload=kv_offload,
290290
config=config,
291-
torch_dtype=LOAD_DTYPE,
291+
torch_dtype=torch_dtype,
292292
)
293293
else:
294294
qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
295295
model_name,
296296
kv_offload=kv_offload,
297297
config=config,
298-
torch_dtype=LOAD_DTYPE,
298+
torch_dtype=torch_dtype,
299299
)
300300

301301
qeff_model.export()
@@ -376,6 +376,50 @@ def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(model_name, kv_offload
376376
)
377377

378378

379+
### Custom dtype Test ###
380+
381+
382+
@pytest.mark.on_qaic
383+
@pytest.mark.multimodal
384+
@pytest.mark.parametrize("model_name", test_mm_models)
385+
@pytest.mark.parametrize("kv_offload", [True, False])
386+
@pytest.mark.parametrize("torch_dtype", [torch.float16])
387+
def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100_custom_dtype(model_name, kv_offload, torch_dtype):
388+
"""
389+
Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, without continuous batching.
390+
``Mandatory`` Args:
391+
:model_name (str): Hugging Face Model Card name, Example: ``gpt2``
392+
"""
393+
if model_name in [
394+
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
395+
"allenai/Molmo-7B-D-0924",
396+
"meta-llama/Llama-3.2-11B-Vision-Instruct",
397+
]:
398+
pytest.skip("Test skipped for this model due to some issues.")
399+
if (
400+
model_name in ["OpenGVLab/InternVL2_5-1B", "OpenGVLab/InternVL3_5-1B", "Qwen/Qwen2.5-VL-3B-Instruct"]
401+
and not kv_offload
402+
):
403+
pytest.skip("These models require kv_offload=True for testing.")
404+
# Get img_size for standard models, None for InternVL and Molmo
405+
img_size = model_config_dict[model_name].get("img_size")
406+
407+
# TODO: Add custom dtype support in ORT and Pytorch_KV APIs
408+
check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
409+
model_name=model_name,
410+
prompt_len=model_config_dict[model_name]["prompt_len"],
411+
ctx_len=model_config_dict[model_name]["ctx_len"],
412+
max_gen_len=NEW_GENERATION_TOKENS,
413+
img_size=img_size,
414+
img_url=model_config_dict[model_name]["img_url"],
415+
query=model_config_dict[model_name]["text_prompt"],
416+
n_layer=model_config_dict[model_name]["num_layers"],
417+
batch_size=model_config_dict[model_name]["batch_size"],
418+
kv_offload=kv_offload,
419+
torch_dtype=torch_dtype,
420+
)
421+
422+
379423
### QNN Tests ###
380424

381425

0 commit comments

Comments
 (0)