
Commit d65493b

support prompt processing and token generation
1 parent: 8bd2fa1

File tree: 2 files changed (+17, -18 lines)


onnx_diagnostic/tasks/text_generation.py

Lines changed: 9 additions & 10 deletions
@@ -230,8 +230,11 @@ def get_inputs(
             0: batch,
             1: seq_length,
         },
+        "past_key_values": [
+            [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
+            [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
+        ],
     }
-
     inputs = dict(
         input_ids=torch.randint(
             0, dummy_max_token_id, (batch_size, sequence_length)
@@ -244,10 +247,7 @@ def get_inputs(
         )
         .to(torch.int64)
         .expand((batch_size, -1)),
-    )
-    # Caches are involved
-    if past_sequence_length > 0:
-        inputs["past_key_values"] = make_cache(
+        past_key_values=make_cache(
             [
                 (
                     torch.randn(
@@ -259,11 +259,10 @@ def get_inputs(
                 )
                 for i in range(num_hidden_layers)
             ]
-        )
-        shapes["past_key_values"] = [
-            [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
-            [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
-        ]
+        ),
+    )
+    # NOTE: past_sequence_length can be 0 when testing prompt processing,
+    # in which case the cache tensors are empty
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
         # prompt processing (prefill) testing
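The effect of this change is that `past_key_values` is now always part of the generated inputs: during prompt processing (prefill) the cache is simply built with a past length of 0, so its tensors are empty along the sequence dimension, while token generation (decode) uses a non-empty cache. Below is a minimal sketch of that idea using plain tensors in the legacy (key, value)-per-layer layout; the sizes and the `make_legacy_cache` helper are illustrative stand-ins, not the repository's `make_cache`:

```python
import torch

# Illustrative sizes (assumptions, not taken from the repository).
batch_size, num_heads, head_dim, num_hidden_layers = 2, 4, 8, 3

def make_legacy_cache(past_sequence_length):
    # One (key, value) pair per layer, shaped [batch, heads, past_seq, head_dim].
    # With past_sequence_length == 0 the tensors are empty along dim 2.
    return [
        (
            torch.randn(batch_size, num_heads, past_sequence_length, head_dim),
            torch.randn(batch_size, num_heads, past_sequence_length, head_dim),
        )
        for _ in range(num_hidden_layers)
    ]

# Prompt processing (prefill): the whole prompt, an empty cache.
prefill_inputs = dict(
    input_ids=torch.randint(0, 1000, (batch_size, 12)),
    past_key_values=make_legacy_cache(0),
)

# Token generation (decode): a single new token, a non-empty cache.
decode_inputs = dict(
    input_ids=torch.randint(0, 1000, (batch_size, 1)),
    past_key_values=make_legacy_cache(12),
)

print(prefill_inputs["past_key_values"][0][0].shape)  # torch.Size([2, 4, 0, 8])
print(decode_inputs["past_key_values"][0][0].shape)   # torch.Size([2, 4, 12, 8])
```

Because both phases now share the same input signature, the dynamic-shapes declaration above can keep `{0: batch, 2: past_seq_length}` on the cache entries in both cases.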

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 8 additions & 8 deletions
@@ -1657,7 +1657,7 @@ def patched_sdpa_attention_forward(
     is_causal: Optional[bool] = None,
     **kwargs,
 ) -> tuple[torch.Tensor, None]:
-    """manual patch for function ```transformers.integrations.sdpa_attention.sdpa_attention_forward```."""
+    """manual patch for function ```transformers.integrations.sdpa_attention.sdpa_attention_forward```."""  # noqa: E501
     if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
         logger.warning_once(
             "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
@@ -1674,18 +1674,18 @@ def patched_sdpa_attention_forward(
     if attention_mask is not None and attention_mask.ndim == 4:
         attention_mask = attention_mask[:, :, :, : key.shape[-2]]

-    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment  # noqa: E501
+    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.  # noqa: E501
+    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`  # noqa: E501
     if is_causal is None:
-        # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
-        # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
-        # is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)
+        # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag  # noqa: E501
+        # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns  # noqa: E501
+        # is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)  # noqa: E501
         # NOTE: query.shape[2] == 1 or > 1 should have the same output for causal attention
         # so we simplify the condition to:
         is_causal = attention_mask is None and getattr(module, "is_causal", True)

-    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
+    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.  # noqa: E501
     # We convert it to a bool for the SDPA kernel that only accepts bools.
     if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
         is_causal = is_causal.item()

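The lines above only append `# noqa: E501` markers, but the surrounding comments describe how the patched function picks `is_causal` before calling SDPA. A condensed sketch of that dispatch, assuming only the public `torch.nn.functional.scaled_dot_product_attention` API; `sdpa_call` and `module_is_causal` are hypothetical names for illustration, not the patched function itself:

```python
import torch
import torch.nn.functional as F

def sdpa_call(query, key, value, attention_mask=None, is_causal=None, module_is_causal=True):
    # Trim a 4D mask to the key length, as in the patched function.
    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
    if is_causal is None:
        # Simplified condition: causal only when no explicit mask is provided.
        is_causal = attention_mask is None and module_is_causal
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        # SDPA only accepts a plain bool, not a traced tensor.
        is_causal = is_causal.item()
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, is_causal=is_causal
    )

q = k = v = torch.randn(1, 4, 10, 8)
print(sdpa_call(q, k, v).shape)  # torch.Size([1, 4, 10, 8]) -- causal path, no mask
```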