Skip to content

Commit a0db0d6

Browse files
committed
Revert "Comments Addressed"
This reverts commit c658e0f.
1 parent c658e0f commit a0db0d6

File tree

5 files changed

+43
-76
lines changed

5 files changed

+43
-76
lines changed

QEfficient/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ def forward(
536536
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
537537

538538
if self.config.torch_dtype == torch.float16:
539-
logger.warning("Accuracy might drop with float16 as torch_dtype")
539+
logger.warning("Accuracy might drop with float16 as torch_dtype")
540540

541541
outputs = self.model(
542542
input_ids=input_ids,

QEfficient/transformers/models/modeling_auto.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
from QEfficient.utils.logging_utils import logger
7777
from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs
7878

79-
CUSTOM_IO_DTYPE_MAP = {
79+
DTYPE_TO_STRING_MAP = {
8080
torch.float16: "float16",
8181
torch.bfloat16: "bfloat16",
8282
torch.float32: "float16", # Since compiler doesn't support fp32
@@ -463,7 +463,7 @@ def compile(
463463
compile_dir=compile_dir,
464464
compile_only=True,
465465
specializations=specializations,
466-
convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[needed_dtype] == "float16"),
466+
convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16"),
467467
mxfp6_matmul=mxfp6_matmul,
468468
mdp_ts_num_devices=num_devices,
469469
aic_num_cores=num_cores,
@@ -804,7 +804,7 @@ def compile(
804804
compile_dir=compile_dir,
805805
compile_only=True,
806806
specializations=specializations,
807-
convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[needed_dtype] == "float16"),
807+
convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16"),
808808
mxfp6_matmul=mxfp6_matmul,
809809
mdp_ts_num_devices=num_devices,
810810
aic_num_cores=num_cores,
@@ -1478,17 +1478,17 @@ def compile(
14781478

14791479
custom_io_vision = {}
14801480
needed_dtype = self.model.config.torch_dtype
1481-
kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[needed_dtype]
1481+
kv_cache_dtype = "mxint8" if mxint8_kv_cache else DTYPE_TO_STRING_MAP[needed_dtype]
14821482
molmo = hasattr(self.model.config, "model_type") and self.model.config.model_type == "molmo"
14831483
if molmo:
1484-
custom_io_vision["image_masks"] = CUSTOM_IO_DTYPE_MAP[needed_dtype]
1485-
custom_io_vision["pixel_values"] = CUSTOM_IO_DTYPE_MAP[needed_dtype]
1484+
custom_io_vision["image_masks"] = DTYPE_TO_STRING_MAP[needed_dtype]
1485+
custom_io_vision["pixel_values"] = DTYPE_TO_STRING_MAP[needed_dtype]
14861486

14871487
for output_name in output_names["vision"]:
14881488
if output_name.startswith("past_"):
14891489
custom_io_vision[output_name] = kv_cache_dtype
14901490
else:
1491-
custom_io_vision[output_name] = CUSTOM_IO_DTYPE_MAP[needed_dtype]
1491+
custom_io_vision[output_name] = DTYPE_TO_STRING_MAP[needed_dtype]
14921492

14931493
if vision_onnx_path:
14941494
self.vision_model.onnx_path = vision_onnx_path
@@ -1531,21 +1531,21 @@ def compile(
15311531
for output_name in output_names["lang"]:
15321532
if output_name.endswith("_RetainedState"):
15331533
custom_io_lang[output_name[: -len("_RetainedState")]] = (
1534-
CUSTOM_IO_DTYPE_MAP[needed_dtype] if "vision_embeds" in output_name else kv_cache_dtype
1534+
DTYPE_TO_STRING_MAP[needed_dtype] if "vision_embeds" in output_name else kv_cache_dtype
15351535
)
15361536

15371537
# outputs
15381538
for output_name in output_names["lang"]:
15391539
if output_name.endswith("_RetainedState"):
15401540
custom_io_lang[output_name] = (
1541-
CUSTOM_IO_DTYPE_MAP[needed_dtype] if "vision_embeds" in output_name else kv_cache_dtype
1541+
DTYPE_TO_STRING_MAP[needed_dtype] if "vision_embeds" in output_name else kv_cache_dtype
15421542
)
15431543
self.lang_model._compile(
15441544
compile_dir=compile_dir,
15451545
compile_only=True,
15461546
retained_state=True,
15471547
specializations=specializations["lang"],
1548-
convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[needed_dtype] == "float16"),
1548+
convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16"),
15491549
mxfp6_matmul=mxfp6_matmul,
15501550
mdp_ts_num_devices=num_devices,
15511551
aic_num_cores=num_cores,
@@ -2160,19 +2160,19 @@ def compile(
21602160

21612161
custom_io = {}
21622162
needed_dtype = self.model.config.torch_dtype
2163-
kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[needed_dtype]
2163+
kv_cache_dtype = "mxint8" if mxint8_kv_cache else DTYPE_TO_STRING_MAP[needed_dtype]
21642164
# inputs
21652165
for input_name in output_names:
21662166
if input_name.endswith("_RetainedState"):
21672167
custom_io[input_name[: -len("_RetainedState")]] = (
2168-
CUSTOM_IO_DTYPE_MAP[needed_dtype] if "pixel_values" in input_name else kv_cache_dtype
2168+
DTYPE_TO_STRING_MAP[needed_dtype] if "pixel_values" in input_name else kv_cache_dtype
21692169
)
21702170

21712171
# outputs
21722172
for output_name in output_names:
21732173
if output_name.endswith("_RetainedState"):
21742174
custom_io[output_name] = (
2175-
CUSTOM_IO_DTYPE_MAP[needed_dtype] if "pixel_values" in output_name else kv_cache_dtype
2175+
DTYPE_TO_STRING_MAP[needed_dtype] if "pixel_values" in output_name else kv_cache_dtype
21762176
)
21772177

21782178
# TODO this should be removed once the continuous batching is supported for all the models.
@@ -2185,7 +2185,7 @@ def compile(
21852185
compile_only=True,
21862186
retained_state=True,
21872187
specializations=specializations,
2188-
convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[needed_dtype] == "float16"),
2188+
convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16"),
21892189
mxfp6_matmul=mxfp6_matmul,
21902190
custom_io=custom_io,
21912191
mdp_ts_num_devices=num_devices,
@@ -3437,7 +3437,7 @@ def compile(
34373437
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
34383438
custom_io = {}
34393439
needed_dtype = self.model.config.torch_dtype
3440-
kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[needed_dtype]
3440+
kv_cache_dtype = "mxint8" if mxint8_kv_cache else DTYPE_TO_STRING_MAP[needed_dtype]
34413441

34423442
for suffix in ["", "_RetainedState"]:
34433443
for i in range(self.num_layers):
@@ -3449,7 +3449,7 @@ def compile(
34493449
compile_only=True,
34503450
retained_state=True,
34513451
specializations=specializations,
3452-
convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[needed_dtype] == "float16"),
3452+
convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16"),
34533453
mxfp6_matmul=mxfp6_matmul,
34543454
custom_io=custom_io,
34553455
mdp_ts_num_devices=num_devices,
@@ -3795,7 +3795,7 @@ def compile(
37953795
output_names = self.model.get_output_names()
37963796

37973797
needed_dtype = self.model.config.torch_dtype
3798-
kv_cache_dtype = CUSTOM_IO_DTYPE_MAP[needed_dtype]
3798+
kv_cache_dtype = DTYPE_TO_STRING_MAP[needed_dtype]
37993799
custom_io = {}
38003800

38013801
custom_io["input_features"] = kv_cache_dtype
@@ -3816,7 +3816,7 @@ def compile(
38163816
compile_only=True,
38173817
retained_state=True,
38183818
specializations=specializations,
3819-
convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[needed_dtype] == "float16"),
3819+
convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16"),
38203820
mxfp6_matmul=mxfp6_matmul,
38213821
mdp_ts_num_devices=num_devices,
38223822
aic_num_cores=num_cores,
@@ -4224,7 +4224,7 @@ def cloud_ai_100_feature_generate(
42244224
torch.nn.functional.pad(inputs["input_values"], (0, self.seq_len - input_ids_len), "constant", 0)
42254225
)
42264226
needed_dtype = self.model.config.torch_dtype
4227-
input_values = input_values.astype(CUSTOM_IO_DTYPE_MAP[needed_dtype])
4227+
input_values = input_values.astype(DTYPE_TO_STRING_MAP[needed_dtype])
42284228
inputs = dict(input_values=input_values)
42294229
outputs = self.qpc_session.run(inputs)
42304230

QEfficient/transformers/models/molmo/modeling_molmo.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
1919
from QEfficient.utils import constants
2020
from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config
21-
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
2221

2322

2423
def _non_meta_init_device(config) -> torch.device:
@@ -55,9 +54,7 @@ def eager_attention_forward(
5554

5655
attn_weights = torch.matmul(q, k.transpose(2, 3)) * scale_factor
5756
if attention_mask is not None:
58-
attn_weights = torch.where(
59-
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=k.dtype), attn_weights
60-
)
57+
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=k.dtype), attn_weights)
6158

6259
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
6360
attn_output = torch.matmul(attn_weights, v)

tests/configs/image_text_model_configs.json

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -204,43 +204,5 @@
204204
"num_layers": 1,
205205
"additional_params": {}
206206
}
207-
],
208-
"image_text_custom_dtype_models":[
209-
{
210-
"model_name": "OpenGVLab/InternVL2_5-1B",
211-
"model_type": "internvl_chat",
212-
"batch_size": 1,
213-
"prompt_len": 384,
214-
"ctx_len": 512,
215-
"img_size": null,
216-
"img_url": "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg",
217-
"text_prompt": "Please describe the image in detail.",
218-
"num_layers": 2,
219-
"additional_params": {}
220-
},
221-
{
222-
"model_name": "google/gemma-3-4b-it",
223-
"model_type": "gemma3",
224-
"batch_size": 1,
225-
"prompt_len": 128,
226-
"ctx_len": 3072,
227-
"img_size": 896,
228-
"img_url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
229-
"text_prompt": "Can you describe the image in detail.",
230-
"num_layers": 6,
231-
"additional_params": {}
232-
},
233-
{
234-
"model_name": "llava-hf/llava-1.5-7b-hf",
235-
"model_type": "llava",
236-
"batch_size": 1,
237-
"prompt_len": 784,
238-
"ctx_len": 1024,
239-
"img_size": 336,
240-
"img_url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
241-
"text_prompt": "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
242-
"num_layers": 1,
243-
"additional_params": {}
244-
}
245207
]
246208
}

tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,8 @@
3838
with open(CONFIG_PATH, "r") as f:
3939
config_data = json.load(f)
4040
multimodal_models = config_data["image_text_models"]
41-
custom_dtype_support_models = config_data["image_text_custom_dtype_models"]
4241
test_mm_models = [model_config["model_name"] for model_config in multimodal_models]
4342
model_config_dict = {model["model_name"]: model for model in multimodal_models}
44-
test_custom_dtype_support_models = [model_config["model_name"] for model_config in custom_dtype_support_models]
45-
custom_dtype_support_models_config_dict = {model["model_name"]: model for model in custom_dtype_support_models}
4643

4744

4845
def load_image_text_to_text_model(model_config):
@@ -125,7 +122,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
125122
qnn_config: Optional[str] = None,
126123
config: Optional[AutoConfig] = None,
127124
img_size: Optional[int] = None,
128-
torch_dtype: Optional[torch.dtype] = torch.float32,
125+
torch_dtype: Optional[int] = torch.float32,
129126
):
130127
"""
131128
Unified function to test PyTorch model, PyTorch KV model, ONNX model, and Cloud AI 100 model.
@@ -384,29 +381,40 @@ def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(model_name, kv_offload
384381

385382
@pytest.mark.on_qaic
386383
@pytest.mark.multimodal
387-
@pytest.mark.parametrize("model_name", test_custom_dtype_support_models)
388-
@pytest.mark.parametrize("kv_offload", [True])
384+
@pytest.mark.parametrize("model_name", test_mm_models)
385+
@pytest.mark.parametrize("kv_offload", [True, False])
389386
@pytest.mark.parametrize("torch_dtype", [torch.float16])
390387
def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100_custom_dtype(model_name, kv_offload, torch_dtype):
391388
"""
392389
Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, without continuous batching.
393390
``Mandatory`` Args:
394391
:model_name (str): Hugging Face Model Card name, Example: ``gpt2``
395392
"""
396-
# Get img_size for standard models, None for InternVL
397-
img_size = custom_dtype_support_models_config_dict[model_name].get("img_size")
393+
if model_name in [
394+
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
395+
"allenai/Molmo-7B-D-0924",
396+
"meta-llama/Llama-3.2-11B-Vision-Instruct",
397+
]:
398+
pytest.skip("Test skipped for this model due to some issues.")
399+
if (
400+
model_name in ["OpenGVLab/InternVL2_5-1B", "OpenGVLab/InternVL3_5-1B", "Qwen/Qwen2.5-VL-3B-Instruct"]
401+
and not kv_offload
402+
):
403+
pytest.skip("These models require kv_offload=True for testing.")
404+
# Get img_size for standard models, None for InternVL and Molmo
405+
img_size = model_config_dict[model_name].get("img_size")
398406

399407
# TODO: Add custom dtype support in ORT and Pytorch_KV APIs
400408
check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
401409
model_name=model_name,
402-
prompt_len=custom_dtype_support_models_config_dict[model_name]["prompt_len"],
403-
ctx_len=custom_dtype_support_models_config_dict[model_name]["ctx_len"],
410+
prompt_len=model_config_dict[model_name]["prompt_len"],
411+
ctx_len=model_config_dict[model_name]["ctx_len"],
404412
max_gen_len=NEW_GENERATION_TOKENS,
405413
img_size=img_size,
406-
img_url=custom_dtype_support_models_config_dict[model_name]["img_url"],
407-
query=custom_dtype_support_models_config_dict[model_name]["text_prompt"],
408-
n_layer=custom_dtype_support_models_config_dict[model_name]["num_layers"],
409-
batch_size=custom_dtype_support_models_config_dict[model_name]["batch_size"],
414+
img_url=model_config_dict[model_name]["img_url"],
415+
query=model_config_dict[model_name]["text_prompt"],
416+
n_layer=model_config_dict[model_name]["num_layers"],
417+
batch_size=model_config_dict[model_name]["batch_size"],
410418
kv_offload=kv_offload,
411419
torch_dtype=torch_dtype,
412420
)

0 commit comments

Comments
 (0)