Skip to content

Commit 0f28d4f

Browse files
author
Gustav Larsson
authored
Add optional path for Dynamo and dynamic paths. (#2902)
- This is publicly exposed to quantize.py just using `--use-dynamic-shapes`. This is mostly to get it in and make testing easy.
- Note, if you use dynamic shapes I'm also using Dynamo. I'm tying these two changes together to prevent a 2x2 test matrix.
- I imagine some new models may force dynamic shape usage (those checks are not present right now). This is not included in this PR and currently no model would use this by default.
- Eventually, the static path should be retired, but that will require upgrading all the grab-and-go encodings. (tracked here: https://github.com/qcom-ai-hub/tetracode/issues/18070)
- Adds a test to Llama 3.2 1B to exercise this path.
- I'm not 100% sure this requires torch 2.10. I remember that being the case in an early version of the branch, but I haven't verified again. It is good that the static path doesn't support 2.10, so that we can use the torch version to tweak adaptations (otherwise it gets pretty ugly to send this information into the adaptations). So, even if we can support lower versions, this actually works out pretty well.
1 parent 075154b commit 0f28d4f

File tree

27 files changed

+507
-133
lines changed

27 files changed

+507
-133
lines changed

qai_hub_models/models/_shared/llama3/model_adaptations.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
LlamaMLP,
2121
)
2222

23+
from qai_hub_models.models._shared.llm.common import TORCH_SUPPORTS_DYNAMIC_SHAPE
2324
from qai_hub_models.models._shared.llm.model_adaptations import (
2425
ConvInplaceLinear,
2526
_apply_rope_single,
@@ -199,7 +200,12 @@ def forward_sha(
199200
past_key_value = past_key_values
200201
bsz, q_len, _ = hidden_states.size()
201202

202-
hidden_states = torch.reshape(hidden_states, (bsz, -1, 1, self.hidden_size_))
203+
if TORCH_SUPPORTS_DYNAMIC_SHAPE:
204+
hidden_states = hidden_states.unsqueeze(2)
205+
else:
206+
hidden_states = torch.reshape(
207+
hidden_states, (bsz, -1, 1, self.hidden_size_)
208+
)
203209
hidden_states = hidden_states.transpose(1, 3)
204210

205211
query_states: list[torch.Tensor] = [
@@ -311,7 +317,12 @@ def forward_sha(
311317
attn_output_return = attn_output_return.permute(0, 3, 1, 2)
312318
attn_output_return = self.o_proj_conv(attn_output_return)
313319
attn_output_return = attn_output_return.transpose(1, 3)
314-
attn_output_return = attn_output_return.reshape(bsz, q_len, self.hidden_size_)
320+
if TORCH_SUPPORTS_DYNAMIC_SHAPE:
321+
attn_output_return = attn_output_return.squeeze(2)
322+
else:
323+
attn_output_return = attn_output_return.reshape(
324+
bsz, q_len, self.hidden_size_
325+
)
315326

316327
attn_weights_return = attn_weights if output_attentions else None
317328

qai_hub_models/models/_shared/llm/common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@
66
from enum import Enum
77

88
import torch
9+
from packaging.version import Version
10+
11+
# Minimum torch version required for dynamic-shape ONNX export (dynamo export).
12+
# Note that earlier versions did support dynamic shapes in general, but did
13+
# not work well for LLMs until 2.10.
14+
TORCH_DYNAMIC_SHAPE_MIN_VERSION = "2.10"
15+
TORCH_SUPPORTS_DYNAMIC_SHAPE = Version(torch.__version__) >= Version(
16+
TORCH_DYNAMIC_SHAPE_MIN_VERSION
17+
)
918

1019

1120
def cleanup() -> None:

qai_hub_models/models/_shared/llm/export.py

Lines changed: 137 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,16 @@ def _make_sub_component_name(
7979
return f"{instantiation_name}_{part_index + 1}_of_{num_splits}"
8080

8181

82+
def _onnx_has_dynamic_shapes(onnx_path: str) -> bool:
83+
"""Check if an ONNX model has dynamic (symbolic) input dimensions."""
84+
model = onnx.load(onnx_path, load_external_data=False)
85+
for inp in model.graph.input:
86+
for dim in inp.type.tensor_type.shape.dim:
87+
if dim.dim_param:
88+
return True
89+
return False
90+
91+
8292
# TODO(#12640): Unnecessary if we can pull this information directly.
8393
def _infer_output_specs(
8494
instantiations: list[tuple[str, int, int]],
@@ -301,7 +311,7 @@ def export_model(
301311
link_jobs: dict[str, hub.client.LinkJob] = {}
302312
profile_options_per_subcomponent: dict[str, str] = {}
303313
onnx_model_path_from_sub_component_name: dict[str, str] = {}
304-
llm_config: PretrainedConfig | None = None
314+
llm_config: PretrainedConfig
305315

306316
sub_component_names: dict[str, list[str]] = {}
307317
component_from_sub_component_names = {}
@@ -316,94 +326,172 @@ def export_model(
316326
onnx_export_dir = tmpdir_handler.name
317327
Path(onnx_export_dir).mkdir(parents=True, exist_ok=True)
318328

319-
for instantiation_name, seq_len, ctx_len in instantiations:
320-
full_name = f"{model_name}_{instantiation_name}"
321-
model = model_cls.from_pretrained(
322-
sequence_length=seq_len,
323-
context_length=ctx_len,
324-
precision=precision,
325-
**model_params,
329+
# Load the model once to determine checkpoint type and get llm_config.
330+
_first_instantiation_name, first_seq_len, first_ctx_len = instantiations[0]
331+
model = model_cls.from_pretrained(
332+
sequence_length=first_seq_len,
333+
context_length=first_ctx_len,
334+
precision=precision,
335+
**model_params,
336+
)
337+
llm_config = model.llm_config
338+
339+
# Check if the checkpoint has a dynamic ONNX model.
340+
# Dynamic models are exported/split/uploaded once and compiled with
341+
# explicit input_specs for each (seq_len, ctx_len) combo.
342+
# Static models are exported/split/uploaded per combo.
343+
has_dynamic_onnx = (
344+
hasattr(model, "checkpoint")
345+
and model.checkpoint is not None
346+
and os.path.exists(os.path.join(model.checkpoint, "model_dynamic.onnx"))
347+
and _onnx_has_dynamic_shapes(
348+
os.path.join(model.checkpoint, "model_dynamic.onnx")
326349
)
350+
)
327351

328-
llm_config = model.llm_config
329-
sub_component_names[instantiation_name] = []
330-
331-
input_spec = model.get_input_spec(
352+
def _export_split_upload(
353+
label: str,
354+
inst_model: LLM_AIMETOnnx,
355+
inst_seq_len: int,
356+
inst_ctx_len: int,
357+
) -> tuple[list[Any], list[list[str]]]:
358+
"""Export ONNX, split, upload parts. Returns (uploaded_models, split_input_names)."""
359+
inst_input_spec = inst_model.get_input_spec(
332360
**{
333-
**get_input_spec_kwargs(model, additional_model_kwargs),
334-
"sequence_length": seq_len,
335-
"context_length": model.context_length,
361+
**get_input_spec_kwargs(inst_model, additional_model_kwargs),
362+
"sequence_length": inst_seq_len,
363+
"context_length": inst_ctx_len,
336364
"llm_config": llm_config.to_dict(),
337-
"llm_io_type": model.llm_io_type,
365+
"llm_io_type": inst_model.llm_io_type,
338366
},
339367
)
340-
341-
sub_output_path = Path(onnx_export_dir) / instantiation_name
342-
source_model_dir = model.convert_to_hub_source_model(
368+
sub_output_path = Path(onnx_export_dir) / label
369+
source_model_dir = inst_model.convert_to_hub_source_model(
343370
target_runtime,
344371
sub_output_path,
345-
input_spec,
372+
inst_input_spec,
346373
external_onnx_weights=True,
347-
output_names=model.get_output_names(),
374+
output_names=inst_model.get_output_names(),
348375
)
349376
assert source_model_dir is not None
350377
source_model_bundle = ONNXBundle.from_bundle_path(source_model_dir)
378+
nonlocal input_encodings_path
351379
input_encodings_path = str(source_model_bundle.aimet_encodings_path)
352-
# Split encodings
353-
model_artifact = Path(onnx_export_dir) / instantiation_name
354-
os.makedirs(model_artifact, exist_ok=True)
355380

381+
model_artifact = Path(onnx_export_dir) / label
382+
os.makedirs(model_artifact, exist_ok=True)
356383
onnx.checker.check_model(source_model_bundle.onnx_graph_path, full_check=True)
357-
subcomponent_onnx_bundles: list[ONNXBundle]
384+
358385
if num_splits == 1:
359-
subcomponent_onnx_bundles = [source_model_bundle]
386+
bundles = [source_model_bundle]
360387
else:
361-
subcomponent_onnx_bundles = utils.split_onnx(
388+
bundles = utils.split_onnx(
362389
onnxfile=source_model_bundle,
363-
modelname=full_name,
390+
modelname=f"{model_name}_{label}",
364391
num_splits=num_splits,
365392
num_layers_per_split=num_layers_per_split,
366393
output_dir=model_artifact,
367394
split_embedding=True,
368395
using_qairt_workflow=True,
369396
)
370397

371-
# Submit the parts for compilation
372-
for i, onnx_model_bundle in enumerate(subcomponent_onnx_bundles):
373-
# Sequence length (ar...) and context lenght (cl...) in graph name
374-
# are semantically important to Genie
398+
uploaded: list[Any] = []
399+
input_names: list[list[str]] = []
400+
for i, bundle in enumerate(bundles):
401+
onnx_path = bundle.onnx_graph_path.as_posix()
402+
split_model = onnx.load(onnx_path, load_external_data=False)
403+
onnx.checker.check_model(onnx_path, full_check=True)
404+
input_names.append([inp.name for inp in split_model.graph.input])
405+
406+
cache_keys: dict[str, str] = {"precision": str(precision)}
407+
if not has_dynamic_onnx:
408+
cache_keys["context_length"] = str(inst_ctx_len)
409+
cache_keys["sequence_length"] = str(inst_seq_len)
410+
411+
uploaded.append(
412+
get_or_create_cached_model(
413+
model_name=model_name,
414+
model_asset_version=model_asset_version,
415+
cache_name=f"{label}_part_{i + 1}_of_{num_splits}",
416+
cache_mode=model_cache_mode,
417+
model_path=bundle.bundle_path.as_posix(),
418+
additional_keys=cache_keys,
419+
)
420+
)
421+
return uploaded, input_names
422+
423+
if has_dynamic_onnx:
424+
# Dynamic path: Export+Upload once per part (regardless of instantiations)
425+
uploaded_models, split_onnx_input_names = _export_split_upload(
426+
"dynamic", model, first_seq_len, first_ctx_len
427+
)
428+
else:
429+
# Static path: Export+Upload once per instantiation and part
430+
uploaded_models = []
431+
split_onnx_input_names = []
432+
433+
# Submit compile jobs for each (seq_len, ctx_len) combo
434+
for inst_idx, (instantiation_name, seq_len, ctx_len) in enumerate(instantiations):
435+
sub_component_names[instantiation_name] = []
436+
437+
if not has_dynamic_onnx:
438+
# Static path: export/split/upload per instantiation
439+
# Reuse the model from the initial load for the first instantiation.
440+
if inst_idx > 0:
441+
model = model_cls.from_pretrained(
442+
sequence_length=seq_len,
443+
context_length=ctx_len,
444+
precision=precision,
445+
**model_params,
446+
)
447+
llm_config = model.llm_config
448+
uploaded_models, split_onnx_input_names = _export_split_upload(
449+
instantiation_name, model, seq_len, ctx_len
450+
)
451+
else:
452+
model.sequence_length = seq_len
453+
model.context_length = ctx_len
454+
455+
input_spec = model.get_input_spec(
456+
**{
457+
**get_input_spec_kwargs(model, additional_model_kwargs),
458+
"sequence_length": seq_len,
459+
"context_length": ctx_len,
460+
"llm_config": llm_config.to_dict(),
461+
"llm_io_type": model.llm_io_type,
462+
},
463+
)
464+
465+
for i in range(len(uploaded_models)):
375466
sub_component_name = _make_sub_component_name(
376467
instantiation_name, i, num_splits
377468
)
378469
component_name = f"part_{i + 1}_of_{num_splits}"
379470
sub_component_names[instantiation_name].append(sub_component_name)
380471
full_name = f"{model_name}_{sub_component_name}"
381472

382-
onnx_path = onnx_model_bundle.onnx_graph_path.as_posix()
383-
onnx.checker.check_model(onnx_path, full_check=True)
384-
385-
onnx_model_path_from_sub_component_name[sub_component_name] = str(onnx_path)
386473
model_compile_options = model.get_hub_compile_options(
387474
target_runtime,
388475
precision,
389476
compile_options,
390477
context_graph_name=model.get_qnn_context_graph_name(i, num_splits),
391478
)
392-
current_model = get_or_create_cached_model(
393-
model_name=model_name,
394-
model_asset_version=model_asset_version,
395-
cache_name=sub_component_name,
396-
cache_mode=model_cache_mode,
397-
model_path=onnx_model_bundle.bundle_path.as_posix(),
398-
additional_keys={
399-
"context_length": str(model.context_length),
400-
"sequence_length": str(seq_len),
401-
"precision": str(precision),
402-
},
403-
)
479+
480+
# Build per-split input spec from the split's ONNX input names.
481+
# Intermediate tensors (e.g. "embedding") from previous splits are float32.
482+
split_input_spec = {}
483+
for inp_name in split_onnx_input_names[i]:
484+
if inp_name in input_spec:
485+
split_input_spec[inp_name] = input_spec[inp_name]
486+
else:
487+
split_input_spec[inp_name] = (
488+
(1, seq_len, llm_config.hidden_size),
489+
"float32",
490+
)
404491

405492
submitted_compile_job = hub.submit_compile_job(
406-
model=current_model,
493+
model=uploaded_models[i],
494+
input_specs=split_input_spec,
407495
device=device,
408496
name=full_name,
409497
options=model_compile_options,

0 commit comments

Comments (0)