Conversation

@titaiwangms
Collaborator

Prior to this PR, for the text-generation task, we were exporting and testing LLMs with multi-turn conversations, where the model runs "prompt processing" --> [a for loop of "token generation"] --> feeds the loop output back into prompt processing, and the whole run is restricted to batch_size=1 by the GQA contrib op.

In this PR, we export LLMs with the token generation setting, and test the model in both the token generation and prompt processing scenarios.
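The multi-turn flow described above can be sketched in plain Python, tracking only how the KV-cache length grows across turns. This is an illustration of the loop structure, not code from the PR; the function name `run_turn` and all numbers are hypothetical.

```python
# Hypothetical sketch of the multi-turn loop: each turn does one prefill pass
# over the whole prompt, then a decode loop that appends one token at a time.
# Only the KV-cache length (past_len) is modeled here.

def run_turn(past_len: int, prompt_len: int, new_tokens: int) -> int:
    """Return the KV-cache length after one conversation turn."""
    past_len += prompt_len       # prefill consumes the whole prompt at once
    for _ in range(new_tokens):  # decode loop: one token per forward pass
        past_len += 1
    return past_len

past = 0
past = run_turn(past, prompt_len=8, new_tokens=4)  # turn 1: cache grows to 12
past = run_turn(past, prompt_len=5, new_tokens=3)  # turn 2: cache grows to 20
```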

Citing @kunal-vaishnavi:

These are the general shapes:

input_ids = (batch_size, sequence_length)
attn_mask = (batch_size, past_sequence_length + sequence_length)
pos_ids = (batch_size, sequence_length)
past_key_values = (batch_size, num_key_value_heads, past_sequence_length, head_dim)
present_key_values = (batch_size, num_key_value_heads, past_sequence_length + sequence_length, head_dim)
Prompt processing (aka prefill):

input_ids = (batch_size, prompt_length)
attn_mask = (batch_size, 0 + prompt_length) = (batch_size, prompt_length)
pos_ids = (batch_size, prompt_length)
past_key_values = (batch_size, num_key_value_heads, 0, head_dim)
present_key_values = (batch_size, num_key_value_heads, 0 + prompt_length, head_dim) = (batch_size, num_key_value_heads, prompt_length, head_dim)
Token generation (aka decode):

input_ids = (batch_size, 1)
attn_mask = (batch_size, past_sequence_length + 1)
pos_ids = (batch_size, 1)
past_key_values = (batch_size, num_key_value_heads, past_sequence_length, head_dim)
present_key_values = (batch_size, num_key_value_heads, past_sequence_length + 1, head_dim)
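The prefill and decode cases above are both instances of the general shape formulas, differing only in `past_sequence_length` and `sequence_length`. A minimal sketch making that concrete (the function and the concrete sizes are illustrative, not exporter APIs):

```python
# Illustrative sketch: compute the expected I/O shapes for one LLM forward
# pass from the general formulas in the comment above. Prefill is the special
# case past_len=0, seq_len=prompt_length; decode is the case seq_len=1.

def llm_io_shapes(batch_size, num_kv_heads, head_dim, past_len, seq_len):
    """Return the shapes of the main LLM inputs/outputs for one forward pass."""
    return {
        "input_ids": (batch_size, seq_len),
        "attn_mask": (batch_size, past_len + seq_len),
        "pos_ids": (batch_size, seq_len),
        "past_key_values": (batch_size, num_kv_heads, past_len, head_dim),
        "present_key_values": (batch_size, num_kv_heads, past_len + seq_len, head_dim),
    }

# Prompt processing (prefill): empty cache, full prompt in one pass.
prefill = llm_io_shapes(batch_size=2, num_kv_heads=8, head_dim=64,
                        past_len=0, seq_len=16)
# Token generation (decode): one new token, cache holds the history.
decode = llm_io_shapes(batch_size=2, num_kv_heads=8, head_dim=64,
                       past_len=16, seq_len=1)
```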

@titaiwangms
Collaborator Author

This PR tried to export LLMs with a multi-turn conversation (batch_size = 1, sequence_length > 1, and past_sequence_length > 1) but was blocked by torch.export.export dynamism with 0/1 specialization. The phi models complain in two scenarios:

(1) The input batch is dynamic, while the output batch is specialized to 1

_unittests/ut_torch_models/test_validate_whole_models.py::TestValidateWholeModels::test_o_validate_phi35_4k_mini_instruct - onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Concat node. Name:'cat16' Status Message: concat.cc:154 PrepareForCompute Non concat axis dimensions must match: Axis 0 has mismatched dimensions of 1 and 2
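The Concat failure above is the shape rule that ONNX's Concat enforces: all non-concat axes must match, so a batch dimension specialized to 1 cannot be concatenated with a dynamic batch of 2. A minimal sketch of that rule (this mimics the check, it is not ONNX Runtime code):

```python
# Sketch of ONNX Concat shape inference: dimensions on every axis except the
# concat axis must agree across inputs, mirroring the error message above.

def concat_output_shape(shapes, axis):
    """Return the output shape of Concat, or raise if non-concat dims differ."""
    ref = list(shapes[0])
    for s in shapes[1:]:
        for i, (a, b) in enumerate(zip(ref, s)):
            if i != axis and a != b:
                raise ValueError(
                    f"Non concat axis dimensions must match: "
                    f"axis {i} has mismatched dimensions of {a} and {b}"
                )
        ref[axis] += s[axis]
    return tuple(ref)

# OK: batch dims agree, concat along axis 1.
concat_output_shape([(2, 3), (2, 5)], axis=1)
# Fails like the log above: batch 1 vs batch 2 on a non-concat axis.
# concat_output_shape([(1, 4), (2, 4)], axis=1)  -> ValueError
```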

(2) Shape env issue

FAILED _unittests/ut_torch_models/test_validate_models.py::TestValidateModel::test_validate_microsoft_phi4_reasoning - AssertionError: [patched_ShapeEnv] Ignored guard s60 + s70 <= s31 + s70 == True, this could result in accuracy problems
FAILED 

@titaiwangms
Collaborator Author

Closing this per the comment above.

@titaiwangms titaiwangms closed this Oct 7, 2025
@xadupre xadupre deleted the titaiwang/fix_modelbuilder_discrepancy branch November 12, 2025 12:25
@xadupre xadupre restored the titaiwang/fix_modelbuilder_discrepancy branch November 12, 2025 12:25
