@@ -1473,13 +1473,19 @@ def compile(
14731473 for output_name in output_names ["lang" ]:
14741474 if output_name .endswith ("_RetainedState" ):
14751475 custom_io_lang [output_name [: - len ("_RetainedState" )]] = (
1476- "float16" if "vision_embeds" in output_name else kv_cache_dtype
1476+ "float16"
1477+ if ("vision_embeds" in output_name or "deepstack_features" in output_name )
1478+ else kv_cache_dtype
14771479 )
14781480
14791481 # outputs
14801482 for output_name in output_names ["lang" ]:
14811483 if output_name .endswith ("_RetainedState" ):
1482- custom_io_lang [output_name ] = "float16" if "vision_embeds" in output_name else kv_cache_dtype
1484+ custom_io_lang [output_name ] = (
1485+ "float16"
1486+ if ("vision_embeds" in output_name or "deepstack_features" in output_name )
1487+ else kv_cache_dtype
1488+ )
14831489 self .lang_model ._compile (
14841490 compile_dir = compile_dir ,
14851491 compile_only = True ,
@@ -1654,7 +1660,6 @@ def kv_offload_generate(
16541660 [x [lang_session .binding_index_map ["input_ids" ]][1 ][1 ] for x in lang_session .allowed_shapes ]
16551661 + [lang_session .bindings [lang_session .binding_index_map ["input_ids" ]].dims [1 ]]
16561662 )
1657- # breakpoint()
16581663 input_len = inputs ["attention_mask" ].sum (1 , keepdims = True )
16591664 input_ids_length = inputs ["input_ids" ].shape [1 ]
16601665 num_chunks = - (input_ids_length // - prefill_seq_len ) # ceil divide without float
@@ -1700,7 +1705,6 @@ def kv_offload_generate(
17001705 vision_end = perf_counter ()
17011706
17021707 lang_inputs = {k : v for k , v in inputs .items () if k not in vision_inputs }
1703- # breakpoint()
17041708 if "position_ids" in inputs :
17051709 lang_inputs ["position_ids" ] = inputs ["position_ids" ]
17061710 lang_inputs .pop ("attention_mask" )
@@ -1712,7 +1716,6 @@ def kv_offload_generate(
17121716 not_mllama = hasattr (self .model .config , "model_type" ) and self .model .config .model_type != "mllama"
17131717 if not_mllama :
17141718 lang_inputs ["image_idx" ] = np .array ([[0 ]])
1715- # breakpoint()
17161719 if self .vision_model .qpc_path :
17171720 vision_session .deactivate ()
17181721 lang_session .activate ()
@@ -1727,7 +1730,6 @@ def kv_offload_generate(
17271730 lang_inputs ["comp_ctx_lengths" ] = list_of_comp_ctx_lengths_prefill [prefill_ccl_id ]
17281731
17291732 lang_start = perf_counter ()
1730- # breakpoint()
17311733 # Run prefill
17321734 chunk_inputs = lang_inputs .copy ()
17331735 for i in range (num_chunks ):
@@ -1756,7 +1758,6 @@ def kv_offload_generate(
17561758 )
17571759 if not_mllama :
17581760 lang_session .skip_buffers (vision_outputs .keys ())
1759- # breakpoint()
17601761 # Get first token
17611762 lang_inputs ["input_ids" ] = outputs ["logits" ].argmax (2 )
17621763 lang_inputs ["position_ids" ] = np .max (lang_inputs ["position_ids" ], axis = - 1 , keepdims = True ) + 1
0 commit comments