7676from QEfficient .utils .logging_utils import logger
7777from QEfficient .utils .sampler_utils import get_sampling_inputs_and_outputs
7878
79- CUSTOM_IO_DTYPE_MAP = {
79+ DTYPE_TO_STRING_MAP = {
8080 torch .float16 : "float16" ,
8181 torch .bfloat16 : "bfloat16" ,
8282 torch .float32 : "float16" , # Since compiler doesn't support fp32
@@ -463,7 +463,7 @@ def compile(
463463 compile_dir = compile_dir ,
464464 compile_only = True ,
465465 specializations = specializations ,
466- convert_to_fp16 = (CUSTOM_IO_DTYPE_MAP [needed_dtype ] == "float16" ),
466+ convert_to_fp16 = (DTYPE_TO_STRING_MAP [needed_dtype ] == "float16" ),
467467 mxfp6_matmul = mxfp6_matmul ,
468468 mdp_ts_num_devices = num_devices ,
469469 aic_num_cores = num_cores ,
@@ -804,7 +804,7 @@ def compile(
804804 compile_dir = compile_dir ,
805805 compile_only = True ,
806806 specializations = specializations ,
807- convert_to_fp16 = (CUSTOM_IO_DTYPE_MAP [needed_dtype ] == "float16" ),
807+ convert_to_fp16 = (DTYPE_TO_STRING_MAP [needed_dtype ] == "float16" ),
808808 mxfp6_matmul = mxfp6_matmul ,
809809 mdp_ts_num_devices = num_devices ,
810810 aic_num_cores = num_cores ,
@@ -1478,17 +1478,17 @@ def compile(
14781478
14791479 custom_io_vision = {}
14801480 needed_dtype = self .model .config .torch_dtype
1481- kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP [needed_dtype ]
1481+ kv_cache_dtype = "mxint8" if mxint8_kv_cache else DTYPE_TO_STRING_MAP [needed_dtype ]
14821482 molmo = hasattr (self .model .config , "model_type" ) and self .model .config .model_type == "molmo"
14831483 if molmo :
1484- custom_io_vision ["image_masks" ] = CUSTOM_IO_DTYPE_MAP [needed_dtype ]
1485- custom_io_vision ["pixel_values" ] = CUSTOM_IO_DTYPE_MAP [needed_dtype ]
1484+ custom_io_vision ["image_masks" ] = DTYPE_TO_STRING_MAP [needed_dtype ]
1485+ custom_io_vision ["pixel_values" ] = DTYPE_TO_STRING_MAP [needed_dtype ]
14861486
14871487 for output_name in output_names ["vision" ]:
14881488 if output_name .startswith ("past_" ):
14891489 custom_io_vision [output_name ] = kv_cache_dtype
14901490 else :
1491- custom_io_vision [output_name ] = CUSTOM_IO_DTYPE_MAP [needed_dtype ]
1491+ custom_io_vision [output_name ] = DTYPE_TO_STRING_MAP [needed_dtype ]
14921492
14931493 if vision_onnx_path :
14941494 self .vision_model .onnx_path = vision_onnx_path
@@ -1531,21 +1531,21 @@ def compile(
15311531 for output_name in output_names ["lang" ]:
15321532 if output_name .endswith ("_RetainedState" ):
15331533 custom_io_lang [output_name [: - len ("_RetainedState" )]] = (
1534- CUSTOM_IO_DTYPE_MAP [needed_dtype ] if "vision_embeds" in output_name else kv_cache_dtype
1534+ DTYPE_TO_STRING_MAP [needed_dtype ] if "vision_embeds" in output_name else kv_cache_dtype
15351535 )
15361536
15371537 # outputs
15381538 for output_name in output_names ["lang" ]:
15391539 if output_name .endswith ("_RetainedState" ):
15401540 custom_io_lang [output_name ] = (
1541- CUSTOM_IO_DTYPE_MAP [needed_dtype ] if "vision_embeds" in output_name else kv_cache_dtype
1541+ DTYPE_TO_STRING_MAP [needed_dtype ] if "vision_embeds" in output_name else kv_cache_dtype
15421542 )
15431543 self .lang_model ._compile (
15441544 compile_dir = compile_dir ,
15451545 compile_only = True ,
15461546 retained_state = True ,
15471547 specializations = specializations ["lang" ],
1548- convert_to_fp16 = (CUSTOM_IO_DTYPE_MAP [needed_dtype ] == "float16" ),
1548+ convert_to_fp16 = (DTYPE_TO_STRING_MAP [needed_dtype ] == "float16" ),
15491549 mxfp6_matmul = mxfp6_matmul ,
15501550 mdp_ts_num_devices = num_devices ,
15511551 aic_num_cores = num_cores ,
@@ -2160,19 +2160,19 @@ def compile(
21602160
21612161 custom_io = {}
21622162 needed_dtype = self .model .config .torch_dtype
2163- kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP [needed_dtype ]
2163+ kv_cache_dtype = "mxint8" if mxint8_kv_cache else DTYPE_TO_STRING_MAP [needed_dtype ]
21642164 # inputs
21652165 for input_name in output_names :
21662166 if input_name .endswith ("_RetainedState" ):
21672167 custom_io [input_name [: - len ("_RetainedState" )]] = (
2168- CUSTOM_IO_DTYPE_MAP [needed_dtype ] if "pixel_values" in input_name else kv_cache_dtype
2168+ DTYPE_TO_STRING_MAP [needed_dtype ] if "pixel_values" in input_name else kv_cache_dtype
21692169 )
21702170
21712171 # outputs
21722172 for output_name in output_names :
21732173 if output_name .endswith ("_RetainedState" ):
21742174 custom_io [output_name ] = (
2175- CUSTOM_IO_DTYPE_MAP [needed_dtype ] if "pixel_values" in output_name else kv_cache_dtype
2175+ DTYPE_TO_STRING_MAP [needed_dtype ] if "pixel_values" in output_name else kv_cache_dtype
21762176 )
21772177
21782178 # TODO this should be removed once the continuous batching is supported for all the models.
@@ -2185,7 +2185,7 @@ def compile(
21852185 compile_only = True ,
21862186 retained_state = True ,
21872187 specializations = specializations ,
2188- convert_to_fp16 = (CUSTOM_IO_DTYPE_MAP [needed_dtype ] == "float16" ),
2188+ convert_to_fp16 = (DTYPE_TO_STRING_MAP [needed_dtype ] == "float16" ),
21892189 mxfp6_matmul = mxfp6_matmul ,
21902190 custom_io = custom_io ,
21912191 mdp_ts_num_devices = num_devices ,
@@ -3437,7 +3437,7 @@ def compile(
34373437 kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
34383438 custom_io = {}
34393439 needed_dtype = self .model .config .torch_dtype
3440- kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP [needed_dtype ]
3440+ kv_cache_dtype = "mxint8" if mxint8_kv_cache else DTYPE_TO_STRING_MAP [needed_dtype ]
34413441
34423442 for suffix in ["" , "_RetainedState" ]:
34433443 for i in range (self .num_layers ):
@@ -3449,7 +3449,7 @@ def compile(
34493449 compile_only = True ,
34503450 retained_state = True ,
34513451 specializations = specializations ,
3452- convert_to_fp16 = (CUSTOM_IO_DTYPE_MAP [needed_dtype ] == "float16" ),
3452+ convert_to_fp16 = (DTYPE_TO_STRING_MAP [needed_dtype ] == "float16" ),
34533453 mxfp6_matmul = mxfp6_matmul ,
34543454 custom_io = custom_io ,
34553455 mdp_ts_num_devices = num_devices ,
@@ -3795,7 +3795,7 @@ def compile(
37953795 output_names = self .model .get_output_names ()
37963796
37973797 needed_dtype = self .model .config .torch_dtype
3798- kv_cache_dtype = CUSTOM_IO_DTYPE_MAP [needed_dtype ]
3798+ kv_cache_dtype = DTYPE_TO_STRING_MAP [needed_dtype ]
37993799 custom_io = {}
38003800
38013801 custom_io ["input_features" ] = kv_cache_dtype
@@ -3816,7 +3816,7 @@ def compile(
38163816 compile_only = True ,
38173817 retained_state = True ,
38183818 specializations = specializations ,
3819- convert_to_fp16 = (CUSTOM_IO_DTYPE_MAP [needed_dtype ] == "float16" ),
3819+ convert_to_fp16 = (DTYPE_TO_STRING_MAP [needed_dtype ] == "float16" ),
38203820 mxfp6_matmul = mxfp6_matmul ,
38213821 mdp_ts_num_devices = num_devices ,
38223822 aic_num_cores = num_cores ,
@@ -4224,7 +4224,7 @@ def cloud_ai_100_feature_generate(
42244224 torch .nn .functional .pad (inputs ["input_values" ], (0 , self .seq_len - input_ids_len ), "constant" , 0 )
42254225 )
42264226 needed_dtype = self .model .config .torch_dtype
4227- input_values = input_values .astype (CUSTOM_IO_DTYPE_MAP [needed_dtype ])
4227+ input_values = input_values .astype (DTYPE_TO_STRING_MAP [needed_dtype ])
42284228 inputs = dict (input_values = input_values )
42294229 outputs = self .qpc_session .run (inputs )
42304230
0 commit comments