@@ -79,6 +79,16 @@ def _make_sub_component_name(
7979 return f"{ instantiation_name } _{ part_index + 1 } _of_{ num_splits } "
8080
8181
def _onnx_has_dynamic_shapes(onnx_path: str) -> bool:
    """Return True if any graph input of the ONNX model at *onnx_path* has a
    symbolic (named) dimension, i.e. the model was exported with dynamic shapes.

    The model proto is loaded without external weight data, since only the
    graph-input shape metadata is inspected.
    """
    model = onnx.load(onnx_path, load_external_data=False)
    # A dimension is dynamic when `dim_param` (its symbolic name) is non-empty;
    # static dims leave `dim_param` empty and set `dim_value` instead.
    return any(
        dim.dim_param
        for graph_input in model.graph.input
        for dim in graph_input.type.tensor_type.shape.dim
    )
90+
91+
8292# TODO(#12640): Unnecessary if we can pull this information directly.
8393def _infer_output_specs (
8494 instantiations : list [tuple [str , int , int ]],
@@ -301,7 +311,7 @@ def export_model(
301311 link_jobs : dict [str , hub .client .LinkJob ] = {}
302312 profile_options_per_subcomponent : dict [str , str ] = {}
303313 onnx_model_path_from_sub_component_name : dict [str , str ] = {}
304- llm_config : PretrainedConfig | None = None
314+ llm_config : PretrainedConfig
305315
306316 sub_component_names : dict [str , list [str ]] = {}
307317 component_from_sub_component_names = {}
@@ -316,94 +326,172 @@ def export_model(
316326 onnx_export_dir = tmpdir_handler .name
317327 Path (onnx_export_dir ).mkdir (parents = True , exist_ok = True )
318328
319- for instantiation_name , seq_len , ctx_len in instantiations :
320- full_name = f"{ model_name } _{ instantiation_name } "
321- model = model_cls .from_pretrained (
322- sequence_length = seq_len ,
323- context_length = ctx_len ,
324- precision = precision ,
325- ** model_params ,
329+ # Load the model once to determine checkpoint type and get llm_config.
330+ _first_instantiation_name , first_seq_len , first_ctx_len = instantiations [0 ]
331+ model = model_cls .from_pretrained (
332+ sequence_length = first_seq_len ,
333+ context_length = first_ctx_len ,
334+ precision = precision ,
335+ ** model_params ,
336+ )
337+ llm_config = model .llm_config
338+
339+ # Check if the checkpoint has a dynamic ONNX model.
340+ # Dynamic models are exported/split/uploaded once and compiled with
341+ # explicit input_specs for each (seq_len, ctx_len) combo.
342+ # Static models are exported/split/uploaded per combo.
343+ has_dynamic_onnx = (
344+ hasattr (model , "checkpoint" )
345+ and model .checkpoint is not None
346+ and os .path .exists (os .path .join (model .checkpoint , "model_dynamic.onnx" ))
347+ and _onnx_has_dynamic_shapes (
348+ os .path .join (model .checkpoint , "model_dynamic.onnx" )
326349 )
350+ )
327351
328- llm_config = model .llm_config
329- sub_component_names [instantiation_name ] = []
330-
331- input_spec = model .get_input_spec (
352+ def _export_split_upload (
353+ label : str ,
354+ inst_model : LLM_AIMETOnnx ,
355+ inst_seq_len : int ,
356+ inst_ctx_len : int ,
357+ ) -> tuple [list [Any ], list [list [str ]]]:
358+ """Export ONNX, split, upload parts. Returns (uploaded_models, split_input_names)."""
359+ inst_input_spec = inst_model .get_input_spec (
332360 ** {
333- ** get_input_spec_kwargs (model , additional_model_kwargs ),
334- "sequence_length" : seq_len ,
335- "context_length" : model . context_length ,
361+ ** get_input_spec_kwargs (inst_model , additional_model_kwargs ),
362+ "sequence_length" : inst_seq_len ,
363+ "context_length" : inst_ctx_len ,
336364 "llm_config" : llm_config .to_dict (),
337- "llm_io_type" : model .llm_io_type ,
365+ "llm_io_type" : inst_model .llm_io_type ,
338366 },
339367 )
340-
341- sub_output_path = Path (onnx_export_dir ) / instantiation_name
342- source_model_dir = model .convert_to_hub_source_model (
368+ sub_output_path = Path (onnx_export_dir ) / label
369+ source_model_dir = inst_model .convert_to_hub_source_model (
343370 target_runtime ,
344371 sub_output_path ,
345- input_spec ,
372+ inst_input_spec ,
346373 external_onnx_weights = True ,
347- output_names = model .get_output_names (),
374+ output_names = inst_model .get_output_names (),
348375 )
349376 assert source_model_dir is not None
350377 source_model_bundle = ONNXBundle .from_bundle_path (source_model_dir )
378+ nonlocal input_encodings_path
351379 input_encodings_path = str (source_model_bundle .aimet_encodings_path )
352- # Split encodings
353- model_artifact = Path (onnx_export_dir ) / instantiation_name
354- os .makedirs (model_artifact , exist_ok = True )
355380
381+ model_artifact = Path (onnx_export_dir ) / label
382+ os .makedirs (model_artifact , exist_ok = True )
356383 onnx .checker .check_model (source_model_bundle .onnx_graph_path , full_check = True )
357- subcomponent_onnx_bundles : list [ ONNXBundle ]
384+
358385 if num_splits == 1 :
359- subcomponent_onnx_bundles = [source_model_bundle ]
386+ bundles = [source_model_bundle ]
360387 else :
361- subcomponent_onnx_bundles = utils .split_onnx (
388+ bundles = utils .split_onnx (
362389 onnxfile = source_model_bundle ,
363- modelname = full_name ,
390+ modelname = f" { model_name } _ { label } " ,
364391 num_splits = num_splits ,
365392 num_layers_per_split = num_layers_per_split ,
366393 output_dir = model_artifact ,
367394 split_embedding = True ,
368395 using_qairt_workflow = True ,
369396 )
370397
371- # Submit the parts for compilation
372- for i , onnx_model_bundle in enumerate (subcomponent_onnx_bundles ):
373- # Sequence length (ar...) and context lenght (cl...) in graph name
374- # are semantically important to Genie
398+ uploaded : list [Any ] = []
399+ input_names : list [list [str ]] = []
400+ for i , bundle in enumerate (bundles ):
401+ onnx_path = bundle .onnx_graph_path .as_posix ()
402+ split_model = onnx .load (onnx_path , load_external_data = False )
403+ onnx .checker .check_model (onnx_path , full_check = True )
404+ input_names .append ([inp .name for inp in split_model .graph .input ])
405+
406+ cache_keys : dict [str , str ] = {"precision" : str (precision )}
407+ if not has_dynamic_onnx :
408+ cache_keys ["context_length" ] = str (inst_ctx_len )
409+ cache_keys ["sequence_length" ] = str (inst_seq_len )
410+
411+ uploaded .append (
412+ get_or_create_cached_model (
413+ model_name = model_name ,
414+ model_asset_version = model_asset_version ,
415+ cache_name = f"{ label } _part_{ i + 1 } _of_{ num_splits } " ,
416+ cache_mode = model_cache_mode ,
417+ model_path = bundle .bundle_path .as_posix (),
418+ additional_keys = cache_keys ,
419+ )
420+ )
421+ return uploaded , input_names
422+
423+ if has_dynamic_onnx :
424+ # Dynamic path: Export+Upload once per part (regardless of instantiations)
425+ uploaded_models , split_onnx_input_names = _export_split_upload (
426+ "dynamic" , model , first_seq_len , first_ctx_len
427+ )
428+ else :
429+ # Static path: Export+Upload once per instantiation and part
430+ uploaded_models = []
431+ split_onnx_input_names = []
432+
433+ # Submit compile jobs for each (seq_len, ctx_len) combo
434+ for inst_idx , (instantiation_name , seq_len , ctx_len ) in enumerate (instantiations ):
435+ sub_component_names [instantiation_name ] = []
436+
437+ if not has_dynamic_onnx :
438+ # Static path: export/split/upload per instantiation
439+ # Reuse the model from the initial load for the first instantiation.
440+ if inst_idx > 0 :
441+ model = model_cls .from_pretrained (
442+ sequence_length = seq_len ,
443+ context_length = ctx_len ,
444+ precision = precision ,
445+ ** model_params ,
446+ )
447+ llm_config = model .llm_config
448+ uploaded_models , split_onnx_input_names = _export_split_upload (
449+ instantiation_name , model , seq_len , ctx_len
450+ )
451+ else :
452+ model .sequence_length = seq_len
453+ model .context_length = ctx_len
454+
455+ input_spec = model .get_input_spec (
456+ ** {
457+ ** get_input_spec_kwargs (model , additional_model_kwargs ),
458+ "sequence_length" : seq_len ,
459+ "context_length" : ctx_len ,
460+ "llm_config" : llm_config .to_dict (),
461+ "llm_io_type" : model .llm_io_type ,
462+ },
463+ )
464+
465+ for i in range (len (uploaded_models )):
375466 sub_component_name = _make_sub_component_name (
376467 instantiation_name , i , num_splits
377468 )
378469 component_name = f"part_{ i + 1 } _of_{ num_splits } "
379470 sub_component_names [instantiation_name ].append (sub_component_name )
380471 full_name = f"{ model_name } _{ sub_component_name } "
381472
382- onnx_path = onnx_model_bundle .onnx_graph_path .as_posix ()
383- onnx .checker .check_model (onnx_path , full_check = True )
384-
385- onnx_model_path_from_sub_component_name [sub_component_name ] = str (onnx_path )
386473 model_compile_options = model .get_hub_compile_options (
387474 target_runtime ,
388475 precision ,
389476 compile_options ,
390477 context_graph_name = model .get_qnn_context_graph_name (i , num_splits ),
391478 )
392- current_model = get_or_create_cached_model (
393- model_name = model_name ,
394- model_asset_version = model_asset_version ,
395- cache_name = sub_component_name ,
396- cache_mode = model_cache_mode ,
397- model_path = onnx_model_bundle . bundle_path . as_posix (),
398- additional_keys = {
399- "context_length" : str ( model . context_length ),
400- "sequence_length" : str ( seq_len ),
401- "precision" : str ( precision ),
402- } ,
403- )
479+
480+ # Build per-split input spec from the split's ONNX input names.
481+ # Intermediate tensors (e.g. "embedding") from previous splits are float32.
482+ split_input_spec = {}
483+ for inp_name in split_onnx_input_names [ i ]:
484+ if inp_name in input_spec :
485+ split_input_spec [ inp_name ] = input_spec [ inp_name ]
486+ else :
487+ split_input_spec [ inp_name ] = (
488+ ( 1 , seq_len , llm_config . hidden_size ),
489+ "float32" ,
490+ )
404491
405492 submitted_compile_job = hub .submit_compile_job (
406- model = current_model ,
493+ model = uploaded_models [i ],
494+ input_specs = split_input_spec ,
407495 device = device ,
408496 name = full_name ,
409497 options = model_compile_options ,
0 commit comments