Conversation

@titaiwangms (Collaborator) commented on Sep 22, 2025

Prior to this PR, for the text-generation task we exported and tested LLMs in a multi-turn conversation: the model ran "prompt processing" --> [a for loop of "token generation"] --> used the loop output for prompt processing again, and the whole run was restricted to batch_size=1 by the GQA contrib op.

In this PR, we export LLMs with the token-generation setting and test the exported model in both the token-generation and prompt-processing scenarios.

Citing @kunal-vaishnavi:

These are the general shapes (a short sketch of matching dummy inputs follows each list):

  • input_ids = (batch_size, sequence_length)
  • attn_mask = (batch_size, past_sequence_length + sequence_length)
  • pos_ids = (batch_size, sequence_length)
  • past_key_values = (batch_size, num_key_value_heads, past_sequence_length, head_dim)
  • present_key_values = (batch_size, num_key_value_heads, past_sequence_length + sequence_length, head_dim)
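As a rough illustration of these shapes, here is a minimal sketch that builds matching dummy inputs. The helper name `make_gqa_inputs` and the default sizes (number of layers, KV heads, head_dim, vocab size) are assumptions for the example, not values taken from this PR.

```python
import torch


def make_gqa_inputs(
    batch_size: int,
    sequence_length: int,
    past_sequence_length: int,
    num_hidden_layers: int = 2,
    num_key_value_heads: int = 8,
    head_dim: int = 64,
    vocab_size: int = 32000,
):
    """Builds dummy inputs matching the shapes listed above (illustrative only)."""
    input_ids = torch.randint(
        0, vocab_size, (batch_size, sequence_length), dtype=torch.int64
    )
    # The mask covers the cached tokens plus the new ones.
    attention_mask = torch.ones(
        (batch_size, past_sequence_length + sequence_length), dtype=torch.int64
    )
    # Positions continue from the end of the cache.
    position_ids = (
        torch.arange(
            past_sequence_length,
            past_sequence_length + sequence_length,
            dtype=torch.int64,
        )
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    # One (key, value) pair per layer, each of shape
    # (batch_size, num_key_value_heads, past_sequence_length, head_dim).
    past_key_values = [
        (
            torch.randn(batch_size, num_key_value_heads, past_sequence_length, head_dim),
            torch.randn(batch_size, num_key_value_heads, past_sequence_length, head_dim),
        )
        for _ in range(num_hidden_layers)
    ]
    return input_ids, attention_mask, position_ids, past_key_values
```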

Prompt processing (aka prefill):

  • input_ids = (batch_size, prompt_length)
  • attn_mask = (batch_size, 0 + prompt_length) = (batch_size, prompt_length)
  • pos_ids = (batch_size, prompt_length)
  • past_key_values = (batch_size, num_key_value_heads, 0, head_dim)
  • present_key_values = (batch_size, num_key_value_heads, 0 + prompt_length, head_dim) = (batch_size, num_key_value_heads, prompt_length, head_dim)
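With that hypothetical helper, prefill is simply the empty-cache configuration (past_sequence_length = 0):

```python
# Prefill: no cache yet, so past_sequence_length = 0.
batch_size, prompt_length = 2, 16
input_ids, attention_mask, position_ids, past_key_values = make_gqa_inputs(
    batch_size=batch_size, sequence_length=prompt_length, past_sequence_length=0
)
assert input_ids.shape == (batch_size, prompt_length)
assert attention_mask.shape == (batch_size, prompt_length)
assert past_key_values[0][0].shape == (batch_size, 8, 0, 64)  # empty KV cache
```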

Token generation (aka decode):

  • input_ids = (batch_size, 1)
  • attn_mask = (batch_size, past_sequence_length + 1)
  • pos_ids = (batch_size, 1)
  • past_key_values = (batch_size, num_key_value_heads, past_sequence_length, head_dim)
  • present_key_values = (batch_size, num_key_value_heads, past_sequence_length + 1, head_dim)
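And decode feeds a single new token against a non-empty cache, again using the hypothetical helper:

```python
# Decode: one new token against a non-empty KV cache.
batch_size, past_sequence_length = 2, 16
input_ids, attention_mask, position_ids, past_key_values = make_gqa_inputs(
    batch_size=batch_size, sequence_length=1, past_sequence_length=past_sequence_length
)
assert input_ids.shape == (batch_size, 1)
assert attention_mask.shape == (batch_size, past_sequence_length + 1)
assert position_ids[0, 0].item() == past_sequence_length
assert past_key_values[0][0].shape == (batch_size, 8, past_sequence_length, 64)
```

Since this PR tests both scenarios against the same exported model, the exported graph presumably keeps sequence_length and past_sequence_length dynamic, so either configuration above can be fed to it.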

@titaiwangms marked this pull request as ready for review on September 23, 2025, 22:01
@titaiwangms (Collaborator, Author)

@sdpython @xadupre Is there a full benchmark I can run before merging this?

@sdpython (Owner)

CI is not running, and I am not sure why, so I can't tell whether the tests are passing. Let me create a temporary PR.

@sdpython (Owner)

CI is running.

@titaiwangms (Collaborator, Author)

#236
