Skip to content

Commit c658e0f

Browse files
asmigosw and quic-dhirajku
authored and committed
Comments Addressed
Signed-off-by: asmigosw <asmigosw@qti.qualcomm.com>
Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
1 parent f7c88be commit c658e0f

File tree

5 files changed

+76
-43
lines changed

5 files changed

+76
-43
lines changed

QEfficient/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ def forward(
536536
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
537537

538538
if self.config.torch_dtype == torch.float16:
539-
logger.warning("Accucary might drop with float16 as torch_dtype")
539+
logger.warning("Accuracy might drop with float16 as torch_dtype")
540540

541541
outputs = self.model(
542542
input_ids=input_ids,

QEfficient/transformers/models/modeling_auto.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
from QEfficient.utils.logging_utils import logger
7777
from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs
7878

79-
DTYPE_TO_STRING_MAP = {
79+
CUSTOM_IO_DTYPE_MAP = {
8080
torch.float16: "float16",
8181
torch.bfloat16: "bfloat16",
8282
torch.float32: "float16", # Since compiler doesn't support fp32
@@ -463,7 +463,7 @@ def compile(
463463
compile_dir=compile_dir,
464464
compile_only=True,
465465
specializations=specializations,
466-
convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16"),
466+
convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[needed_dtype] == "float16"),
467467
mxfp6_matmul=mxfp6_matmul,
468468
mdp_ts_num_devices=num_devices,
469469
aic_num_cores=num_cores,
@@ -804,7 +804,7 @@ def compile(
804804
compile_dir=compile_dir,
805805
compile_only=True,
806806
specializations=specializations,
807-
convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16"),
807+
convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[needed_dtype] == "float16"),
808808
mxfp6_matmul=mxfp6_matmul,
809809
mdp_ts_num_devices=num_devices,
810810
aic_num_cores=num_cores,
@@ -1478,17 +1478,17 @@ def compile(
14781478

14791479
custom_io_vision = {}
14801480
needed_dtype = self.model.config.torch_dtype
1481-
kv_cache_dtype = "mxint8" if mxint8_kv_cache else DTYPE_TO_STRING_MAP[needed_dtype]
1481+
kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[needed_dtype]
14821482
molmo = hasattr(self.model.config, "model_type") and self.model.config.model_type == "molmo"
14831483
if molmo:
1484-
custom_io_vision["image_masks"] = DTYPE_TO_STRING_MAP[needed_dtype]
1485-
custom_io_vision["pixel_values"] = DTYPE_TO_STRING_MAP[needed_dtype]
1484+
custom_io_vision["image_masks"] = CUSTOM_IO_DTYPE_MAP[needed_dtype]
1485+
custom_io_vision["pixel_values"] = CUSTOM_IO_DTYPE_MAP[needed_dtype]
14861486

14871487
for output_name in output_names["vision"]:
14881488
if output_name.startswith("past_"):
14891489
custom_io_vision[output_name] = kv_cache_dtype
14901490
else:
1491-
custom_io_vision[output_name] = DTYPE_TO_STRING_MAP[needed_dtype]
1491+
custom_io_vision[output_name] = CUSTOM_IO_DTYPE_MAP[needed_dtype]
14921492

14931493
if vision_onnx_path:
14941494
self.vision_model.onnx_path = vision_onnx_path
@@ -1531,21 +1531,21 @@ def compile(
15311531
for output_name in output_names["lang"]:
15321532
if output_name.endswith("_RetainedState"):
15331533
custom_io_lang[output_name[: -len("_RetainedState")]] = (
1534-
DTYPE_TO_STRING_MAP[needed_dtype] if "vision_embeds" in output_name else kv_cache_dtype
1534+
CUSTOM_IO_DTYPE_MAP[needed_dtype] if "vision_embeds" in output_name else kv_cache_dtype
15351535
)
15361536

15371537
# outputs
15381538
for output_name in output_names["lang"]:
15391539
if output_name.endswith("_RetainedState"):
15401540
custom_io_lang[output_name] = (
1541-
DTYPE_TO_STRING_MAP[needed_dtype] if "vision_embeds" in output_name else kv_cache_dtype
1541+
CUSTOM_IO_DTYPE_MAP[needed_dtype] if "vision_embeds" in output_name else kv_cache_dtype
15421542
)
15431543
self.lang_model._compile(
15441544
compile_dir=compile_dir,
15451545
compile_only=True,
15461546
retained_state=True,
15471547
specializations=specializations["lang"],
1548-
convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16"),
1548+
convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[needed_dtype] == "float16"),
15491549
mxfp6_matmul=mxfp6_matmul,
15501550
mdp_ts_num_devices=num_devices,
15511551
aic_num_cores=num_cores,
@@ -2160,19 +2160,19 @@ def compile(
21602160

21612161
custom_io = {}
21622162
needed_dtype = self.model.config.torch_dtype
2163-
kv_cache_dtype = "mxint8" if mxint8_kv_cache else DTYPE_TO_STRING_MAP[needed_dtype]
2163+
kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[needed_dtype]
21642164
# inputs
21652165
for input_name in output_names:
21662166
if input_name.endswith("_RetainedState"):
21672167
custom_io[input_name[: -len("_RetainedState")]] = (
2168-
DTYPE_TO_STRING_MAP[needed_dtype] if "pixel_values" in input_name else kv_cache_dtype
2168+
CUSTOM_IO_DTYPE_MAP[needed_dtype] if "pixel_values" in input_name else kv_cache_dtype
21692169
)
21702170

21712171
# outputs
21722172
for output_name in output_names:
21732173
if output_name.endswith("_RetainedState"):
21742174
custom_io[output_name] = (
2175-
DTYPE_TO_STRING_MAP[needed_dtype] if "pixel_values" in output_name else kv_cache_dtype
2175+
CUSTOM_IO_DTYPE_MAP[needed_dtype] if "pixel_values" in output_name else kv_cache_dtype
21762176
)
21772177

21782178
# TODO this hould be removed once the continous batching is supported for all the models.
@@ -2185,7 +2185,7 @@ def compile(
21852185
compile_only=True,
21862186
retained_state=True,
21872187
specializations=specializations,
2188-
convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16"),
2188+
convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[needed_dtype] == "float16"),
21892189
mxfp6_matmul=mxfp6_matmul,
21902190
custom_io=custom_io,
21912191
mdp_ts_num_devices=num_devices,
@@ -3437,7 +3437,7 @@ def compile(
34373437
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
34383438
custom_io = {}
34393439
needed_dtype = self.model.config.torch_dtype
3440-
kv_cache_dtype = "mxint8" if mxint8_kv_cache else DTYPE_TO_STRING_MAP[needed_dtype]
3440+
kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[needed_dtype]
34413441

34423442
for suffix in ["", "_RetainedState"]:
34433443
for i in range(self.num_layers):
@@ -3449,7 +3449,7 @@ def compile(
34493449
compile_only=True,
34503450
retained_state=True,
34513451
specializations=specializations,
3452-
convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16"),
3452+
convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[needed_dtype] == "float16"),
34533453
mxfp6_matmul=mxfp6_matmul,
34543454
custom_io=custom_io,
34553455
mdp_ts_num_devices=num_devices,
@@ -3795,7 +3795,7 @@ def compile(
37953795
output_names = self.model.get_output_names()
37963796

37973797
needed_dtype = self.model.config.torch_dtype
3798-
kv_cache_dtype = DTYPE_TO_STRING_MAP[needed_dtype]
3798+
kv_cache_dtype = CUSTOM_IO_DTYPE_MAP[needed_dtype]
37993799
custom_io = {}
38003800

38013801
custom_io["input_features"] = kv_cache_dtype
@@ -3816,7 +3816,7 @@ def compile(
38163816
compile_only=True,
38173817
retained_state=True,
38183818
specializations=specializations,
3819-
convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16"),
3819+
convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[needed_dtype] == "float16"),
38203820
mxfp6_matmul=mxfp6_matmul,
38213821
mdp_ts_num_devices=num_devices,
38223822
aic_num_cores=num_cores,
@@ -4224,7 +4224,7 @@ def cloud_ai_100_feature_generate(
42244224
torch.nn.functional.pad(inputs["input_values"], (0, self.seq_len - input_ids_len), "constant", 0)
42254225
)
42264226
needed_dtype = self.model.config.torch_dtype
4227-
input_values = input_values.astype(DTYPE_TO_STRING_MAP[needed_dtype])
4227+
input_values = input_values.astype(CUSTOM_IO_DTYPE_MAP[needed_dtype])
42284228
inputs = dict(input_values=input_values)
42294229
outputs = self.qpc_session.run(inputs)
42304230

QEfficient/transformers/models/molmo/modeling_molmo.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
1919
from QEfficient.utils import constants
2020
from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config
21+
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
2122

2223

2324
def _non_meta_init_device(config) -> torch.device:
@@ -54,7 +55,9 @@ def eager_attention_forward(
5455

5556
attn_weights = torch.matmul(q, k.transpose(2, 3)) * scale_factor
5657
if attention_mask is not None:
57-
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=k.dtype), attn_weights)
58+
attn_weights = torch.where(
59+
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=k.dtype), attn_weights
60+
)
5861

5962
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
6063
attn_output = torch.matmul(attn_weights, v)

tests/configs/image_text_model_configs.json

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,5 +204,43 @@
204204
"num_layers": 1,
205205
"additional_params": {}
206206
}
207+
],
208+
"image_text_custom_dtype_models":[
209+
{
210+
"model_name": "OpenGVLab/InternVL2_5-1B",
211+
"model_type": "internvl_chat",
212+
"batch_size": 1,
213+
"prompt_len": 384,
214+
"ctx_len": 512,
215+
"img_size": null,
216+
"img_url": "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg",
217+
"text_prompt": "Please describe the image in detail.",
218+
"num_layers": 2,
219+
"additional_params": {}
220+
},
221+
{
222+
"model_name": "google/gemma-3-4b-it",
223+
"model_type": "gemma3",
224+
"batch_size": 1,
225+
"prompt_len": 128,
226+
"ctx_len": 3072,
227+
"img_size": 896,
228+
"img_url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
229+
"text_prompt": "Can you describe the image in detail.",
230+
"num_layers": 6,
231+
"additional_params": {}
232+
},
233+
{
234+
"model_name": "llava-hf/llava-1.5-7b-hf",
235+
"model_type": "llava",
236+
"batch_size": 1,
237+
"prompt_len": 784,
238+
"ctx_len": 1024,
239+
"img_size": 336,
240+
"img_url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
241+
"text_prompt": "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
242+
"num_layers": 1,
243+
"additional_params": {}
244+
}
207245
]
208246
}

tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,11 @@
3838
with open(CONFIG_PATH, "r") as f:
3939
config_data = json.load(f)
4040
multimodal_models = config_data["image_text_models"]
41+
custom_dtype_support_models = config_data["image_text_custom_dtype_models"]
4142
test_mm_models = [model_config["model_name"] for model_config in multimodal_models]
4243
model_config_dict = {model["model_name"]: model for model in multimodal_models}
44+
test_custom_dtype_support_models = [model_config["model_name"] for model_config in custom_dtype_support_models]
45+
custom_dtype_support_models_config_dict = {model["model_name"]: model for model in custom_dtype_support_models}
4346

4447

4548
def load_image_text_to_text_model(model_config):
@@ -122,7 +125,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
122125
qnn_config: Optional[str] = None,
123126
config: Optional[AutoConfig] = None,
124127
img_size: Optional[int] = None,
125-
torch_dtype: Optional[int] = torch.float32,
128+
torch_dtype: Optional[torch.dtype] = torch.float32,
126129
):
127130
"""
128131
Unified function to test PyTorch model, PyTorch KV model, ONNX model, and Cloud AI 100 model.
@@ -381,40 +384,29 @@ def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(model_name, kv_offload
381384

382385
@pytest.mark.on_qaic
383386
@pytest.mark.multimodal
384-
@pytest.mark.parametrize("model_name", test_mm_models)
385-
@pytest.mark.parametrize("kv_offload", [True, False])
387+
@pytest.mark.parametrize("model_name", test_custom_dtype_support_models)
388+
@pytest.mark.parametrize("kv_offload", [True])
386389
@pytest.mark.parametrize("torch_dtype", [torch.float16])
387390
def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100_custom_dtype(model_name, kv_offload, torch_dtype):
388391
"""
389392
Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, without continuous batching.
390393
``Mandatory`` Args:
391394
:model_name (str): Hugging Face Model Card name, Example: ``gpt2``
392395
"""
393-
if model_name in [
394-
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
395-
"allenai/Molmo-7B-D-0924",
396-
"meta-llama/Llama-3.2-11B-Vision-Instruct",
397-
]:
398-
pytest.skip("Test skipped for this model due to some issues.")
399-
if (
400-
model_name in ["OpenGVLab/InternVL2_5-1B", "OpenGVLab/InternVL3_5-1B", "Qwen/Qwen2.5-VL-3B-Instruct"]
401-
and not kv_offload
402-
):
403-
pytest.skip("These models require kv_offload=True for testing.")
404-
# Get img_size for standard models, None for InternVL and Molmo
405-
img_size = model_config_dict[model_name].get("img_size")
396+
# Get img_size for standard models, None for InternVL
397+
img_size = custom_dtype_support_models_config_dict[model_name].get("img_size")
406398

407399
# TODO: Add custom dtype support in ORT and Pytorch_KV APIs
408400
check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
409401
model_name=model_name,
410-
prompt_len=model_config_dict[model_name]["prompt_len"],
411-
ctx_len=model_config_dict[model_name]["ctx_len"],
402+
prompt_len=custom_dtype_support_models_config_dict[model_name]["prompt_len"],
403+
ctx_len=custom_dtype_support_models_config_dict[model_name]["ctx_len"],
412404
max_gen_len=NEW_GENERATION_TOKENS,
413405
img_size=img_size,
414-
img_url=model_config_dict[model_name]["img_url"],
415-
query=model_config_dict[model_name]["text_prompt"],
416-
n_layer=model_config_dict[model_name]["num_layers"],
417-
batch_size=model_config_dict[model_name]["batch_size"],
406+
img_url=custom_dtype_support_models_config_dict[model_name]["img_url"],
407+
query=custom_dtype_support_models_config_dict[model_name]["text_prompt"],
408+
n_layer=custom_dtype_support_models_config_dict[model_name]["num_layers"],
409+
batch_size=custom_dtype_support_models_config_dict[model_name]["batch_size"],
418410
kv_offload=kv_offload,
419411
torch_dtype=torch_dtype,
420412
)

0 commit comments

Comments
 (0)