30 changes: 0 additions & 30 deletions onnx_diagnostic/helpers/helper.py
@@ -1061,36 +1061,6 @@ def max_diff(
print(f"[max_diff] to_tuple2: {string_type(expected)} ? {string_type(got)}")
return max_diff(expected, got.to_tuple(), debug_info=_debug("to_tuple2"), **_dkws)

if isinstance(got, (list, tuple)):
if len(got) != 1:
if verbose >= 6:
print(
f"[max_diff] list,tuple,2: {string_type(expected)} "
f"? {string_type(got)}"
)
if verbose > 2:
import torch

print(
f"[max_diff] (a) inf because len(expected)={len(expected)}!=1, "
f"len(got)={len(got)}, level={level}, _index={_index}"
)
for i, (a, b) in enumerate(zip(expected, got)):
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
print(
f" i={i} expected {a.dtype}:{a.shape}, "
f"has {b.dtype}:{b.shape}, _index={_index}"
)
else:
print(
f" i={i} a is {type(a)}, "
f"b is {type(b)}, _index={_index}"
)
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
if verbose >= 6:
print(f"[max_diff] list,tuple,1: {string_type(expected)} ? {string_type(got)}")
return max_diff(expected, got[0], debug_info=_debug("lt1"), **_dkws)

if isinstance(expected, (tuple, list)):
if verbose >= 6:
print(f"[max_diff] list,tuple,0: {string_type(expected)} ? {string_type(got)}")
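The removed branch silently unwrapped a length-1 got list or tuple before comparing and returned infinite differences on any other length mismatch. For readers who do not know max_diff, the following is a minimal, hypothetical sketch of the kind of summary dictionary it returns and of the stricter length handling that remains after this deletion; it is not the library's actual implementation, and the metric formulas are illustrative only.

import numpy as np
import torch


def max_diff_sketch(expected, got):
    """Simplified stand-in for onnx_diagnostic.helpers.helper.max_diff.

    Returns the same style of summary dictionary (abs/rel/sum/n/dnan) and
    recurses over tuples/lists of matching length; any length mismatch is
    reported as infinity, which is the behaviour kept by this deletion.
    """
    if isinstance(expected, (list, tuple)) and isinstance(got, (list, tuple)):
        if len(expected) != len(got):
            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
        parts = [max_diff_sketch(e, g) for e, g in zip(expected, got)]
        if not parts:
            return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0.0)
        return dict(
            abs=max(p["abs"] for p in parts),
            rel=max(p["rel"] for p in parts),
            sum=sum(p["sum"] for p in parts),
            n=sum(p["n"] for p in parts),
            dnan=sum(p["dnan"] for p in parts),
        )
    # Leaf case: compare two tensors element-wise.
    e = expected.detach().cpu().numpy().astype(np.float64)
    g = got.detach().cpu().numpy().astype(np.float64)
    diff = np.abs(e - g)
    return dict(
        abs=float(diff.max()),
        rel=float((diff / (np.abs(e) + 1e-6)).max()),
        sum=float(diff.sum()),
        n=float(diff.size),
        dnan=float(abs(int(np.isnan(e).sum()) - int(np.isnan(g).sum()))),
    )


t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
print(max_diff_sketch((t,), (t + 1e-3,)))  # small absolute difference
print(max_diff_sketch((t,), (t, t)))       # length mismatch -> infinities
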
126 changes: 79 additions & 47 deletions onnx_diagnostic/tasks/text_generation.py
@@ -59,8 +59,8 @@ def get_inputs(
dummy_max_token_id: int,
num_hidden_layers: int,
batch_size: int = 2,
sequence_length: int = 30,
sequence_length2: int = 3,
past_sequence_length: int = 30,
sequence_length: int = 3,
dynamic_rope: bool = False,
num_key_value_heads: Optional[int] = None,
head_dim: Optional[int] = None,
@@ -76,17 +76,18 @@ def get_inputs(
:param head_dim: last dimension of the cache
:param dummy_max_token_id: dummy max token id
:param batch_size: batch size
:param sequence_length: sequence length
:param sequence_length2: new sequence length
:param past_sequence_length: past sequence length
:param sequence_length: new sequence length
:param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
:param cls_cache: cache class, by default it is
:class:`transformers.cache_utils.DynamicCache`
:return: dictionary
"""
batch = "batch"
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
seq_length = "seq_length"
past_seq_length = "past_seq_length"

# TODO(team): Is this code block still necessary?
if config is not None and config.__class__.__name__ == "FalconMambaConfig":
try:
from transformers.models.mamba.modeling_mamba import MambaCache
@@ -98,23 +98,23 @@ def get_inputs(
MambaCache,
), f"Unexpected value for cls_cache={cls_cache} and config={config}"
seq_length_multiple = 8
sequence_length = (
(sequence_length + seq_length_multiple)
past_sequence_length = (
(past_sequence_length + seq_length_multiple)
// seq_length_multiple
* seq_length_multiple
)
# sequence_inc = seq_length_multiple
sequence_length2 = seq_length_multiple
sequence_length = seq_length_multiple

shapes = {
"input_ids": {0: batch, 1: "sequence_length"},
"attention_mask": {
0: batch,
1: "cache+seq", # cache_length + seq_length
1: "cache+seq", # past_seq_length + seq_length
},
"cache_position": {
0: batch,
1: "cache+seq", # cache_length + seq_length
1: "cache+seq", # past_seq_length + seq_length
},
"cache_params": [
[{0: batch} for _ in range(num_hidden_layers)],
@@ -123,9 +124,9 @@ def get_inputs(
}
inputs = dict(
input_ids=torch.randint(
0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
0, dummy_max_token_id, (batch_size, past_sequence_length + sequence_length)
).to(torch.int64),
attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
attention_mask=torch.ones((batch_size, past_sequence_length + sequence_length)).to(
torch.int64
),
cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
@@ -167,46 +168,54 @@ def get_inputs(
make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name]
is_static = cache_name == "StaticCache"

# TODO(team): Is this code block still necessary?
if is_static:
# static
shapes = {
"input_ids": {0: batch, 1: seq_length},
"attention_mask": {0: batch, 2: "seq"},
"cache_position": {0: "seq"},
"attention_mask": {0: batch, 2: "sequence_length+past_sequence_length"},
"cache_position": {0: "sequence_length+past_sequence_length"},
"past_key_values": [
# [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
# [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
# [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
# [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
[{0: batch} for _ in range(num_hidden_layers)],
[{0: batch} for _ in range(num_hidden_layers)],
],
}
inputs = dict(
input_ids=torch.randint(
0, dummy_max_token_id, (batch_size, sequence_length2)
0, dummy_max_token_id, (batch_size, sequence_length)
).to(torch.int64),
attention_mask=torch.ones(
(batch_size, num_key_value_heads, sequence_length2, head_dim)
(
batch_size,
num_key_value_heads,
past_sequence_length + sequence_length,
head_dim,
)
).to(torch.bool),
cache_position=torch.arange(sequence_length2).to(torch.int64),
cache_position=torch.arange(past_sequence_length + sequence_length).to(
torch.int64
),
past_key_values=make_static_cache(
[
(
torch.randn(
batch_size,
num_key_value_heads,
sequence_length + sequence_length2,
past_sequence_length + sequence_length,
head_dim,
),
torch.randn(
batch_size,
num_key_value_heads,
sequence_length + sequence_length2,
sequence_length + past_sequence_length,
head_dim,
),
)
for i in range(num_hidden_layers)
],
max_cache_len=max(sequence_length + sequence_length2, head_dim),
max_cache_len=max(sequence_length + past_sequence_length, head_dim),
),
)
else:
@@ -215,53 +224,57 @@ def get_inputs(
"input_ids": {0: batch, 1: seq_length},
"attention_mask": {
0: batch,
1: "cache+seq", # cache_length + seq_length
1: "cache+seq", # past_seq_length + seq_length
},
"position_ids": {
0: batch,
1: "cache+seq", # cache_length + seq_length
1: seq_length,
},
"past_key_values": [
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
],
}

inputs = dict(
input_ids=torch.randint(
0, dummy_max_token_id, (batch_size, sequence_length2)
0, dummy_max_token_id, (batch_size, sequence_length)
).to(torch.int64),
attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
torch.int64
),
position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
attention_mask=torch.ones(
(batch_size, sequence_length + past_sequence_length)
).to(torch.int64),
position_ids=torch.arange(
past_sequence_length, sequence_length + past_sequence_length
)
.to(torch.int64)
.expand((batch_size, -1)),
past_key_values=make_cache( # type: ignore[operator]
)
# Caches are involved
if past_sequence_length > 0:
inputs["past_key_values"] = make_cache(
[
(
torch.randn(
batch_size, num_key_value_heads, sequence_length, head_dim
batch_size, num_key_value_heads, past_sequence_length, head_dim
),
torch.randn(
batch_size, num_key_value_heads, sequence_length, head_dim
batch_size, num_key_value_heads, past_sequence_length, head_dim
),
)
for i in range(num_hidden_layers)
]
),
)
)
shapes["past_key_values"] = [
[{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
[{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
]
res = dict(inputs=inputs, dynamic_shapes=shapes)
if add_second_input:
# prompt processing (prefill) testing
res["inputs2"] = get_inputs(
model=model,
config=config,
dummy_max_token_id=dummy_max_token_id,
num_hidden_layers=num_hidden_layers,
batch_size=(batch_size + 1) if add_second_input > 0 else 1,
sequence_length=sequence_length + 1,
sequence_length2=sequence_length2
+ (add_second_input if add_second_input > 0 else -add_second_input),
batch_size=batch_size,
past_sequence_length=0,
sequence_length=32,
dynamic_rope=dynamic_rope,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
@@ -276,6 +289,23 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
"""
Inputs kwargs.

NOTE: We test two scenarios:
1. prompt processing (aka prefill):
input_ids=(batch_size, prompt_length)
attn_mask=(batch_size, 0+prompt_length) = (batch_size, prompt_length)
pos_ids=(batch_size, prompt_length)
past_key_values=(batch_size, num_key_value_heads, 0, head_dim)
present_key_values=(batch_size, num_key_value_heads, 0+prompt_length, head_dim)
2. token generation (aka decode):
input_ids=(batch_size, 1)
attn_mask=(batch_size, past_sequence_length+1)
pos_ids=(batch_size, 1)
past_key_values=(batch_size, num_key_value_heads, past_sequence_length,
head_dim)
present_key_values=(batch_size, num_key_value_heads,
past_sequence_length+1, head_dim)


If the configuration is None, the function selects typical dimensions.
"""
if config is not None:
@@ -290,8 +320,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
check_hasattr(config, "conv_kernel", "state_size", "intermediate_size") # 4 and 8
kwargs = dict(
batch_size=2,
sequence_length=30,
sequence_length2=3,
past_sequence_length=30,
sequence_length=3,
dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
num_hidden_layers=4 if config is None else config.num_hidden_layers,
intermediate_size=256 if config is None else config.intermediate_size,
@@ -300,10 +330,12 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
conv_kernel=8 if config is None else getattr(config, "conv_kernel", None),
)
else:
# Token generation (decode) testing
# NOTE: We have to export model in decode mode to preserve the cache
kwargs = dict(
batch_size=2,
sequence_length=30,
sequence_length2=3,
past_sequence_length=32,
sequence_length=1,
head_dim=(
16
if config is None
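With the renaming, the two scenarios described in the docstring map directly onto the arguments: prompt processing uses past_sequence_length=0 and the full prompt in sequence_length, while token generation uses a populated cache and sequence_length=1, with the second input now exercising the prefill case. The sketch below builds tensors with those shapes using plain torch; the dimensions and the vocabulary size are made up for illustration, and real callers go through get_inputs, which also wraps past_key_values in a cache class such as transformers' DynamicCache.

import torch

# Illustrative dimensions only.
batch_size, num_key_value_heads, head_dim = 2, 4, 16
num_hidden_layers = 2


def dummy_inputs(past_sequence_length: int, sequence_length: int, vocab: int = 31999):
    """Build tensors with the shapes described in the NOTE above.

    past_sequence_length=0, sequence_length=N  -> prompt processing (prefill)
    past_sequence_length=M, sequence_length=1  -> token generation (decode)
    """
    inputs = dict(
        input_ids=torch.randint(
            0, vocab, (batch_size, sequence_length), dtype=torch.int64
        ),
        attention_mask=torch.ones(
            (batch_size, past_sequence_length + sequence_length), dtype=torch.int64
        ),
        position_ids=torch.arange(
            past_sequence_length, past_sequence_length + sequence_length
        ).expand(batch_size, -1),
    )
    if past_sequence_length > 0:
        # Caches only appear in the decode scenario; plain (key, value) pairs
        # stand in for the DynamicCache/StaticCache built by get_inputs.
        inputs["past_key_values"] = [
            (
                torch.randn(batch_size, num_key_value_heads, past_sequence_length, head_dim),
                torch.randn(batch_size, num_key_value_heads, past_sequence_length, head_dim),
            )
            for _ in range(num_hidden_layers)
        ]
    return inputs


prefill = dummy_inputs(past_sequence_length=0, sequence_length=32)
decode = dummy_inputs(past_sequence_length=32, sequence_length=1)
print({k: tuple(v.shape) for k, v in prefill.items()})
print(decode["past_key_values"][0][0].shape)  # (2, 4, 32, 16)
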
35 changes: 35 additions & 0 deletions onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -415,6 +415,11 @@ def torch_export_patches(
except ImportError:
masking_utils = None

try:
import transformers.modeling_utils as modeling_utils
except ImportError:
modeling_utils = None

if verbose:
import transformers

@@ -509,6 +514,23 @@ def torch_export_patches(
patch_transformers_list.patched_sdpa_mask_recent_torch
)

if (
modeling_utils
and patch_transformers_list.patch_modeling_utils
and "sdpa" in modeling_utils.ALL_ATTENTION_FUNCTIONS
):
if verbose:
print(
"[torch_export_patches] patches "
"transformers.modeling_utils.sdpa_attention_forward"
)
f_transformers_sdpa_attention_forward = modeling_utils.ALL_ATTENTION_FUNCTIONS[
"sdpa"
]
modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"] = (
patch_transformers_list.patched_sdpa_attention_forward
)

if custom_patches:
if verbose:
print("[torch_export_patches] applies custom patches")
@@ -688,6 +710,19 @@ def torch_export_patches(
"transformers.masking_utils.sdpa_mask "
"in ALL_MASK_ATTENTION_FUNCTIONS"
)
if (
modeling_utils
and patch_transformers_list.patch_modeling_utils
and "sdpa" in modeling_utils.ALL_ATTENTION_FUNCTIONS
):
modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"] = (
f_transformers_sdpa_attention_forward
)
if verbose:
print(
"[torch_export_patches] restored "
"transformers.modeling_utils.sdpa_attention_forward"
)

########
# caches
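The new patch follows the same save/override/restore dance as the other registries handled by torch_export_patches: keep a reference to the original ALL_ATTENTION_FUNCTIONS["sdpa"] entry, install the patched forward, and put the original back in the teardown branch shown above. Below is a generic, hypothetical sketch of that pattern as a context manager; swap_registry_entry and the toy registry are illustrative names only, not part of onnx_diagnostic or transformers.

from contextlib import contextmanager


@contextmanager
def swap_registry_entry(registry, key, replacement):
    """Temporarily replace registry[key] (e.g. transformers'
    ALL_ATTENTION_FUNCTIONS["sdpa"]) and always restore the original entry,
    even if the body raises."""
    original = registry[key]
    registry[key] = replacement
    try:
        yield original
    finally:
        registry[key] = original


# Usage with a toy registry standing in for ALL_ATTENTION_FUNCTIONS.
def sdpa_attention_forward(*args, **kwargs):
    return "original"


def patched_sdpa_attention_forward(*args, **kwargs):
    return "patched"


ALL_ATTENTION_FUNCTIONS = {"sdpa": sdpa_attention_forward}
with swap_registry_entry(ALL_ATTENTION_FUNCTIONS, "sdpa", patched_sdpa_attention_forward):
    assert ALL_ATTENTION_FUNCTIONS["sdpa"]() == "patched"
assert ALL_ATTENTION_FUNCTIONS["sdpa"]() == "original"
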