4 changes: 3 additions & 1 deletion _unittests/ut_tasks/test_tasks.py
@@ -43,7 +43,9 @@ def test_text_generation(self):
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
model(**inputs)
model(**data["inputs2"])
with torch_export_patches(patch_transformers=True, verbose=10):
with torch_export_patches(
patch_transformers=True, verbose=10
), torch.fx.experimental._config.patch(backed_size_oblivious=True):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)
39 changes: 39 additions & 0 deletions _unittests/ut_tasks/test_tasks_text_generation.py
@@ -0,0 +1,39 @@
import unittest
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
hide_stdout,
requires_transformers,
requires_torch,
)
from onnx_diagnostic.torch_models.validate import validate_model


class TestTasksTextGeneration(ExtTestCase):
@hide_stdout()
@requires_transformers("4.53")
@requires_torch("2.7.99")
def test_text_generation(self):
mid = "microsoft/phi-2"
summary, data = validate_model(
mid,
do_run=True,
verbose=10,
exporter="onnx-dynamo",
dump_folder="dump_test/microsoft_phi-2",
inputs2=True,
patch=True,
)
self.assertIsInstance(summary, dict)
# token generation
self.assertLess(summary["disc_onnx_ort_run_abs"], 3e-2)
# prompt processing
self.assertLess(summary["disc_onnx_ort_run2_abs"], 3e-2)
# multi-turn conversation
self.assertLess(summary["disc_onnx_ort_run3_abs"], 3e-2)
self.assertIsInstance(data, dict)
onnx_filename = data["onnx_filename"]
self.assertExists(onnx_filename)


if __name__ == "__main__":
unittest.main(verbosity=2)
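
Side note, not part of the diff: once validate_model has produced the ONNX file checked above, the exported graph can be inspected directly with onnxruntime. A minimal sketch, assuming onnx_filename holds the path stored in data["onnx_filename"] and that the CPU execution provider is available:

import onnxruntime as ort

# onnx_filename is assumed to be the path returned by validate_model in data["onnx_filename"].
sess = ort.InferenceSession(onnx_filename, providers=["CPUExecutionProvider"])
# The graph input names show how to build the feed dictionary expected by sess.run(None, feeds).
print([i.name for i in sess.get_inputs()])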
8 changes: 6 additions & 2 deletions _unittests/ut_torch_export_patches/test_dynamic_class.py
@@ -307,7 +307,9 @@ def test_phi2_export_module(self):
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
)

with torch_export_patches(patch_transformers=True):
with torch_export_patches(
patch_transformers=True
), torch.fx.experimental._config.patch(backed_size_oblivious=True):
ep = torch.export.export(
model,
(),
@@ -346,7 +348,9 @@ def test_phi2_export_interpreter(self):
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
)

with torch_export_patches(patch_transformers=True, verbose=1):
with torch_export_patches(
patch_transformers=True, verbose=1
), torch.fx.experimental._config.patch(backed_size_oblivious=True):
if masking_utils is not None:
self.assertEqual(
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"],
30 changes: 0 additions & 30 deletions onnx_diagnostic/helpers/helper.py
@@ -1061,36 +1061,6 @@ def max_diff(
print(f"[max_diff] to_tuple2: {string_type(expected)} ? {string_type(got)}")
return max_diff(expected, got.to_tuple(), debug_info=_debug("to_tuple2"), **_dkws)

if isinstance(got, (list, tuple)):
if len(got) != 1:
if verbose >= 6:
print(
f"[max_diff] list,tuple,2: {string_type(expected)} "
f"? {string_type(got)}"
)
if verbose > 2:
import torch

print(
f"[max_diff] (a) inf because len(expected)={len(expected)}!=1, "
f"len(got)={len(got)}, level={level}, _index={_index}"
)
for i, (a, b) in enumerate(zip(expected, got)):
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
print(
f" i={i} expected {a.dtype}:{a.shape}, "
f"has {b.dtype}:{b.shape}, _index={_index}"
)
else:
print(
f" i={i} a is {type(a)}, "
f"b is {type(b)}, _index={_index}"
)
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
if verbose >= 6:
print(f"[max_diff] list,tuple,1: {string_type(expected)} ? {string_type(got)}")
return max_diff(expected, got[0], debug_info=_debug("lt1"), **_dkws)

if isinstance(expected, (tuple, list)):
if verbose >= 6:
print(f"[max_diff] list,tuple,0: {string_type(expected)} ? {string_type(got)}")
132 changes: 89 additions & 43 deletions onnx_diagnostic/tasks/text_generation.py
@@ -59,8 +59,8 @@ def get_inputs(
dummy_max_token_id: int,
num_hidden_layers: int,
batch_size: int = 2,
sequence_length: int = 30,
sequence_length2: int = 3,
past_sequence_length: int = 30,
sequence_length: int = 3,
dynamic_rope: bool = False,
num_key_value_heads: Optional[int] = None,
head_dim: Optional[int] = None,
@@ -76,17 +76,18 @@
:param head_dim: last dimension of the cache
:param dummy_max_token_id: dummy max token id
:param batch_size: batch size
:param sequence_length: sequence length
:param sequence_length2: new sequence length
:param past_sequence_length: past sequence length
:param sequence_length: new sequence length
:param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
:param cls_cache: cache class, by default it is
:class:`transformers.cache_utils.DynamicCache`
:return: dictionary
"""
batch = "batch"
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
seq_length = "seq_length"
past_seq_length = "past_seq_length"

# TODO(team): Is this code block still necessary?
if config is not None and config.__class__.__name__ == "FalconMambaConfig":
try:
from transformers.models.mamba.modeling_mamba import MambaCache
@@ -98,23 +99,23 @@
MambaCache,
), f"Unexpected value for cls_cache={cls_cache} and config={config}"
seq_length_multiple = 8
sequence_length = (
(sequence_length + seq_length_multiple)
past_sequence_length = (
(past_sequence_length + seq_length_multiple)
// seq_length_multiple
* seq_length_multiple
)
# sequence_inc = seq_length_multiple
sequence_length2 = seq_length_multiple
sequence_length = seq_length_multiple

shapes = {
"input_ids": {0: batch, 1: "sequence_length"},
"attention_mask": {
0: batch,
1: "cache+seq", # cache_length + seq_length
1: "cache+seq", # past_seq_length + seq_length
},
"cache_position": {
0: batch,
1: "cache+seq", # cache_length + seq_length
1: "cache+seq", # past_seq_length + seq_length
},
"cache_params": [
[{0: batch} for _ in range(num_hidden_layers)],
@@ -123,9 +124,9 @@
}
inputs = dict(
input_ids=torch.randint(
0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
0, dummy_max_token_id, (batch_size, past_sequence_length + sequence_length)
).to(torch.int64),
attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
attention_mask=torch.ones((batch_size, past_sequence_length + sequence_length)).to(
torch.int64
),
cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
@@ -167,46 +168,51 @@
make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name]
is_static = cache_name == "StaticCache"

# TODO(team): Is this code block still necessary?
if is_static:
# static
shapes = {
"input_ids": {0: batch, 1: seq_length},
"attention_mask": {0: batch, 2: "seq"},
"cache_position": {0: "seq"},
"attention_mask": {0: batch, 2: "past_sequence_length"},
"cache_position": {0: "past_sequence_length"},
"past_key_values": [
# [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
# [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
# past_sequence_length is now static
[{0: batch} for _ in range(num_hidden_layers)],
[{0: batch} for _ in range(num_hidden_layers)],
],
}
inputs = dict(
input_ids=torch.randint(
0, dummy_max_token_id, (batch_size, sequence_length2)
0, dummy_max_token_id, (batch_size, past_sequence_length)
).to(torch.int64),
attention_mask=torch.ones(
(batch_size, num_key_value_heads, sequence_length2, head_dim)
(
batch_size,
num_key_value_heads,
past_sequence_length,
head_dim,
)
).to(torch.bool),
cache_position=torch.arange(sequence_length2).to(torch.int64),
cache_position=torch.arange(past_sequence_length).to(torch.int64),
past_key_values=make_static_cache(
[
(
torch.randn(
batch_size,
num_key_value_heads,
sequence_length + sequence_length2,
sequence_length + past_sequence_length,
head_dim,
),
torch.randn(
batch_size,
num_key_value_heads,
sequence_length + sequence_length2,
sequence_length + past_sequence_length,
head_dim,
),
)
for i in range(num_hidden_layers)
],
max_cache_len=max(sequence_length + sequence_length2, head_dim),
max_cache_len=max(past_sequence_length, head_dim),
),
)
else:
@@ -215,53 +221,74 @@
"input_ids": {0: batch, 1: seq_length},
"attention_mask": {
0: batch,
1: "cache+seq", # cache_length + seq_length
1: "past_seq_length+seq_length", # past_seq_length + seq_length
},
"position_ids": {
0: batch,
1: "cache+seq", # cache_length + seq_length
1: seq_length,
},
"past_key_values": [
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
[{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
[{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
],
}

inputs = dict(
input_ids=torch.randint(
0, dummy_max_token_id, (batch_size, sequence_length2)
0, dummy_max_token_id, (batch_size, sequence_length)
).to(torch.int64),
attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
torch.int64
),
position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
attention_mask=torch.ones(
(batch_size, sequence_length + past_sequence_length)
).to(torch.int64),
position_ids=torch.arange(
past_sequence_length, sequence_length + past_sequence_length
)
.to(torch.int64)
.expand((batch_size, -1)),
past_key_values=make_cache( # type: ignore[operator]
[
(
torch.randn(
batch_size, num_key_value_heads, sequence_length, head_dim
batch_size, num_key_value_heads, past_sequence_length, head_dim
),
torch.randn(
batch_size, num_key_value_heads, sequence_length, head_dim
batch_size, num_key_value_heads, past_sequence_length, head_dim
),
)
for i in range(num_hidden_layers)
]
),
)
# NOTE: past_sequence_length can be 0 when testing prompt processing,
# in which case the past key/value tensors are empty
res = dict(inputs=inputs, dynamic_shapes=shapes)
if add_second_input:
# prompt processing (prefill) testing
res["inputs2"] = get_inputs(
model=model,
config=config,
dummy_max_token_id=dummy_max_token_id,
num_hidden_layers=num_hidden_layers,
batch_size=(batch_size + 1) if add_second_input > 0 else 1,
sequence_length=sequence_length + 1,
sequence_length2=sequence_length2
+ (add_second_input if add_second_input > 0 else -add_second_input),
batch_size=batch_size,
past_sequence_length=0,
sequence_length=32,
dynamic_rope=dynamic_rope,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
cls_cache=cls_cache,
add_second_input=0,
**kwargs,
)["inputs"]
# multi-turn conversation:
# prompt processing -> token generation (loop output) ->
# prompt processing from the loop output
res["inputs3"] = get_inputs(
model=model,
config=config,
dummy_max_token_id=dummy_max_token_id,
num_hidden_layers=num_hidden_layers,
batch_size=1,
past_sequence_length=32,
sequence_length=8,
dynamic_rope=dynamic_rope,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
@@ -276,6 +303,23 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
"""
Inputs kwargs.

NOTE: We test two scenarios:
1. prompt processing (aka prefill):
input_ids=(batch_size, prompt_length)
attn_mask=(batch_size, 0+prompt_length) = (batch_size, prompt_length)
pos_ids=(batch_size, prompt_length)
past_key_values=(batch_size, num_key_value_heads, 0, head_dim)
present_key_values=(batch_size, num_key_value_heads, 0+prompt_length, head_dim)
2. token generation (aka decode):
input_ids=(batch_size, 1)
attn_mask=(batch_size, past_sequence_length+1)
pos_ids=(batch_size, 1)
past_key_values=(batch_size, num_key_value_heads, past_sequence_length,
head_dim)
present_key_values=(batch_size, num_key_value_heads,
past_sequence_length+1, head_dim)


If the configuration is None, the function selects typical dimensions.
"""
if config is not None:
@@ -290,8 +334,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
check_hasattr(config, "conv_kernel", "state_size", "intermediate_size") # 4 and 8
kwargs = dict(
batch_size=2,
sequence_length=30,
sequence_length2=3,
past_sequence_length=30,
sequence_length=3,
dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
num_hidden_layers=4 if config is None else config.num_hidden_layers,
intermediate_size=256 if config is None else config.intermediate_size,
Expand All @@ -300,10 +344,12 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
conv_kernel=8 if config is None else getattr(config, "conv_kernel", None),
)
else:
# Token generation (decode) testing
# NOTE: we have to export the model in decode mode to preserve the cache
kwargs = dict(
batch_size=2,
sequence_length=30,
sequence_length2=3,
past_sequence_length=32,
sequence_length=1,
head_dim=(
16
if config is None
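
To make the renamed dimensions concrete, here is a small, self-contained sketch (not part of the PR) of the two scenarios described in the random_input_kwargs docstring, using plain tensors; the batch size, head count, head dimension and vocabulary size below are arbitrary placeholders:

import torch

batch_size, num_key_value_heads, head_dim, vocab_size = 2, 4, 64, 1000

# 1. Prompt processing (prefill): the cache is empty, past_sequence_length == 0.
prompt_length = 32
prefill_input_ids = torch.randint(0, vocab_size, (batch_size, prompt_length))
prefill_attention_mask = torch.ones(batch_size, 0 + prompt_length, dtype=torch.int64)
prefill_past_key = torch.randn(batch_size, num_key_value_heads, 0, head_dim)
assert prefill_past_key.numel() == 0  # an empty key/value tensor

# 2. Token generation (decode): one new token against a cache of length 32.
past_sequence_length, sequence_length = 32, 1
decode_input_ids = torch.randint(0, vocab_size, (batch_size, sequence_length))
decode_attention_mask = torch.ones(
    batch_size, past_sequence_length + sequence_length, dtype=torch.int64
)
decode_position_ids = torch.arange(
    past_sequence_length, past_sequence_length + sequence_length
).expand(batch_size, -1)
decode_past_key = torch.randn(
    batch_size, num_key_value_heads, past_sequence_length, head_dim
)
# After one decode step the present key/value length grows to past_sequence_length + 1.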