
Commit 65f1ca0

draft improve llm random inputs
1 parent 0f9667b commit 65f1ca0

File tree: 3 files changed, +78 −101 lines


onnx_diagnostic/helpers/helper.py

Lines changed: 0 additions & 30 deletions
@@ -1061,36 +1061,6 @@ def max_diff(
             print(f"[max_diff] to_tuple2: {string_type(expected)} ? {string_type(got)}")
         return max_diff(expected, got.to_tuple(), debug_info=_debug("to_tuple2"), **_dkws)

-    if isinstance(got, (list, tuple)):
-        if len(got) != 1:
-            if verbose >= 6:
-                print(
-                    f"[max_diff] list,tuple,2: {string_type(expected)} "
-                    f"? {string_type(got)}"
-                )
-            if verbose > 2:
-                import torch
-
-                print(
-                    f"[max_diff] (a) inf because len(expected)={len(expected)}!=1, "
-                    f"len(got)={len(got)}, level={level}, _index={_index}"
-                )
-                for i, (a, b) in enumerate(zip(expected, got)):
-                    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
-                        print(
-                            f"  i={i} expected {a.dtype}:{a.shape}, "
-                            f"has {b.dtype}:{b.shape}, _index={_index}"
-                        )
-                    else:
-                        print(
-                            f"  i={i} a is {type(a)}, "
-                            f"b is {type(b)}, _index={_index}"
-                        )
-            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
-        if verbose >= 6:
-            print(f"[max_diff] list,tuple,1: {string_type(expected)} ? {string_type(got)}")
-        return max_diff(expected, got[0], debug_info=_debug("lt1"), **_dkws)
-
     if isinstance(expected, (tuple, list)):
         if verbose >= 6:
             print(f"[max_diff] list,tuple,0: {string_type(expected)} ? {string_type(got)}")

onnx_diagnostic/tasks/text_generation.py

Lines changed: 78 additions & 47 deletions
@@ -59,8 +59,8 @@ def get_inputs(
     dummy_max_token_id: int,
     num_hidden_layers: int,
     batch_size: int = 2,
-    sequence_length: int = 30,
-    sequence_length2: int = 3,
+    past_sequence_length: int = 30,
+    sequence_length: int = 3,
     dynamic_rope: bool = False,
     num_key_value_heads: Optional[int] = None,
     head_dim: Optional[int] = None,
@@ -76,17 +76,18 @@ def get_inputs(
     :param head_dim: last dimension of the cache
     :param dummy_max_token_id: dummy max token id
     :param batch_size: batch size
-    :param sequence_length: sequence length
-    :param sequence_length2: new sequence length
+    :param past_sequence_length: past sequence length
+    :param sequence_length: new sequence length
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :param cls_cache: cache class, by default it is
         :class:`transformers.cache_utils.DynamicCache`
     :return: dictionary
     """
     batch = "batch"
-    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-    cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
+    seq_length = "seq_length"
+    past_seq_length = "past_seq_length"

+    # TODO(team): Is this code block still necessary?
     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
         try:
             from transformers.models.mamba.modeling_mamba import MambaCache
@@ -98,23 +99,23 @@ def get_inputs(
             MambaCache,
         ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
         seq_length_multiple = 8
-        sequence_length = (
-            (sequence_length + seq_length_multiple)
+        past_sequence_length = (
+            (past_sequence_length + seq_length_multiple)
             // seq_length_multiple
             * seq_length_multiple
         )
         # sequence_inc = seq_length_multiple
-        sequence_length2 = seq_length_multiple
+        sequence_length = seq_length_multiple

         shapes = {
             "input_ids": {0: batch, 1: "sequence_length"},
             "attention_mask": {
                 0: batch,
-                1: "cache+seq",  # cache_length + seq_length
+                1: "cache+seq",  # past_seq_length + seq_length
             },
             "cache_position": {
                 0: batch,
-                1: "cache+seq",  # cache_length + seq_length
+                1: "cache+seq",  # past_seq_length + seq_length
             },
             "cache_params": [
                 [{0: batch} for _ in range(num_hidden_layers)],
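
A worked example of the rounding above, using the default past_sequence_length of 30 (a self-contained sketch of the same arithmetic):

seq_length_multiple = 8
past_sequence_length = 30
past_sequence_length = (
    (past_sequence_length + seq_length_multiple)
    // seq_length_multiple
    * seq_length_multiple
)
assert past_sequence_length == 32  # (30 + 8) // 8 * 8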
@@ -123,9 +124,9 @@ def get_inputs(
         }
         inputs = dict(
             input_ids=torch.randint(
-                0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
+                0, dummy_max_token_id, (batch_size, past_sequence_length + sequence_length)
             ).to(torch.int64),
-            attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+            attention_mask=torch.ones((batch_size, past_sequence_length + sequence_length)).to(
                 torch.int64
             ),
             cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
@@ -167,46 +168,54 @@ def get_inputs(
     make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name]
     is_static = cache_name == "StaticCache"

+    # TODO(team): Is this code block still necessary?
     if is_static:
         # static
         shapes = {
             "input_ids": {0: batch, 1: seq_length},
-            "attention_mask": {0: batch, 2: "seq"},
-            "cache_position": {0: "seq"},
+            "attention_mask": {0: batch, 2: "sequence_length+past_sequence_length"},
+            "cache_position": {0: "sequence_length+past_sequence_length"},
             "past_key_values": [
-                # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-                # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                # [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
+                # [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
                 [{0: batch} for _ in range(num_hidden_layers)],
                 [{0: batch} for _ in range(num_hidden_layers)],
             ],
         }
         inputs = dict(
             input_ids=torch.randint(
-                0, dummy_max_token_id, (batch_size, sequence_length2)
+                0, dummy_max_token_id, (batch_size, sequence_length)
             ).to(torch.int64),
             attention_mask=torch.ones(
-                (batch_size, num_key_value_heads, sequence_length2, head_dim)
+                (
+                    batch_size,
+                    num_key_value_heads,
+                    past_sequence_length + sequence_length,
+                    head_dim,
+                )
             ).to(torch.bool),
-            cache_position=torch.arange(sequence_length2).to(torch.int64),
+            cache_position=torch.arange(past_sequence_length + sequence_length).to(
+                torch.int64
+            ),
             past_key_values=make_static_cache(
                 [
                     (
                         torch.randn(
                             batch_size,
                             num_key_value_heads,
-                            sequence_length + sequence_length2,
+                            past_sequence_length + sequence_length,
                             head_dim,
                         ),
                         torch.randn(
                             batch_size,
                             num_key_value_heads,
-                            sequence_length + sequence_length2,
+                            sequence_length + past_sequence_length,
                             head_dim,
                         ),
                     )
                     for i in range(num_hidden_layers)
                 ],
-                max_cache_len=max(sequence_length + sequence_length2, head_dim),
+                max_cache_len=max(sequence_length + past_sequence_length, head_dim),
             ),
         )
     else:
@@ -215,53 +224,56 @@ def get_inputs(
             "input_ids": {0: batch, 1: seq_length},
             "attention_mask": {
                 0: batch,
-                1: "cache+seq",  # cache_length + seq_length
+                1: "cache+seq",  # past_seq_length + seq_length
             },
             "position_ids": {
                 0: batch,
-                1: "cache+seq",  # cache_length + seq_length
+                1: seq_length,
             },
-            "past_key_values": [
-                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-            ],
         }

         inputs = dict(
             input_ids=torch.randint(
-                0, dummy_max_token_id, (batch_size, sequence_length2)
+                0, dummy_max_token_id, (batch_size, sequence_length)
             ).to(torch.int64),
-            attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
-                torch.int64
-            ),
-            position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
+            attention_mask=torch.ones(
+                (batch_size, sequence_length + past_sequence_length)
+            ).to(torch.int64),
+            position_ids=torch.arange(
+                past_sequence_length, sequence_length + past_sequence_length
+            )
             .to(torch.int64)
             .expand((batch_size, -1)),
-            past_key_values=make_cache(  # type: ignore[operator]
+        )
+        if past_sequence_length > 0:
+            inputs["past_key_values"] = make_cache(
                 [
                     (
                         torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
+                            batch_size, num_key_value_heads, past_sequence_length, head_dim
                         ),
                         torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
+                            batch_size, num_key_value_heads, past_sequence_length, head_dim
                         ),
                     )
                     for i in range(num_hidden_layers)
                 ]
-            ),
-        )
+            )
+            shapes["past_key_values"] = [
+                [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)],
+            ]
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        # prompt processing (prefill) testing
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
             dummy_max_token_id=dummy_max_token_id,
             num_hidden_layers=num_hidden_layers,
-            batch_size=(batch_size + 1) if add_second_input > 0 else 1,
-            sequence_length=sequence_length + 1,
-            sequence_length2=sequence_length2
-            + (add_second_input if add_second_input > 0 else -add_second_input),
+            batch_size=batch_size,
+            past_sequence_length=0,
+            sequence_length=32,
             dynamic_rope=dynamic_rope,
             num_key_value_heads=num_key_value_heads,
             head_dim=head_dim,
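
For illustration, a sketch of what the new second-input call exercises, assuming the default DynamicCache path (arguments mirror the nested call above; the values are the defaults shown in this hunk). With past_sequence_length=0 the `if past_sequence_length > 0:` branch is skipped, so the prefill inputs carry no past_key_values:

second = get_inputs(
    model=model,
    config=config,
    dummy_max_token_id=dummy_max_token_id,
    num_hidden_layers=num_hidden_layers,
    batch_size=batch_size,
    past_sequence_length=0,   # empty cache -> prompt processing (prefill)
    sequence_length=32,       # whole prompt in one step
    num_key_value_heads=num_key_value_heads,
    head_dim=head_dim,
)
# second["inputs"]["input_ids"] has shape (batch_size, 32); no "past_key_values"
# entry is created because the cache is empty.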
@@ -276,6 +288,23 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     """
     Inputs kwargs.

+    NOTE: We test two scenarios:
+    1. prompt processing (aka prefill):
+        input_ids=(batch_size, prompt_length)
+        attn_mask=(batch_size, 0+prompt_length) = (batch_size, prompt_length)
+        pos_ids=(batch_size, prompt_length)
+        past_key_values=(batch_size, num_key_value_heads, 0, head_dim)
+        present_key_values=(batch_size, num_key_value_heads, 0+prompt_length, head_dim)
+    2. token generation (aka decode):
+        input_ids=(batch_size, 1)
+        attn_mask=(batch_size, past_sequence_length+1)
+        pos_ids=(batch_size, 1)
+        past_key_values=(batch_size, num_key_value_heads, past_sequence_length,
+            head_dim)
+        present_key_values=(batch_size, num_key_value_heads,
+            past_sequence_length+1, head_dim)
+
     If the configuration is None, the function selects typical dimensions.
     """
     if config is not None:
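
To make the NOTE above concrete, a small sketch of the shapes involved under assumed dimensions (batch_size=2, num_key_value_heads=8, head_dim=16, prompt_length=32, past_sequence_length=32):

# 1. prompt processing (prefill): whole prompt, empty cache
prefill_shapes = {
    "input_ids": (2, 32),
    "attention_mask": (2, 32),         # 0 + prompt_length
    "position_ids": (2, 32),
    "past_key_values": (2, 8, 0, 16),  # per layer, key and value
}
# 2. token generation (decode): one new token on top of 32 cached positions
decode_shapes = {
    "input_ids": (2, 1),
    "attention_mask": (2, 33),         # past_sequence_length + 1
    "position_ids": (2, 1),
    "past_key_values": (2, 8, 32, 16),
}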
@@ -290,8 +319,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
         kwargs = dict(
             batch_size=2,
-            sequence_length=30,
-            sequence_length2=3,
+            past_sequence_length=30,
+            sequence_length=3,
             dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
             num_hidden_layers=4 if config is None else config.num_hidden_layers,
             intermediate_size=256 if config is None else config.intermediate_size,
@@ -300,10 +329,12 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             conv_kernel=8 if config is None else getattr(config, "conv_kernel", None),
         )
     else:
+        # Token generation (decode) testing
+        # NOTE: We have to export the model in decode mode to preserve the cache
         kwargs = dict(
             batch_size=2,
-            sequence_length=30,
-            sequence_length2=3,
+            past_sequence_length=32,
+            sequence_length=1,
             head_dim=(
                 16
                 if config is None
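
A short usage sketch of the decode defaults above (it is assumed here that the returned callable is get_inputs, as the hunks above suggest, and that config is a decoder-only configuration taking this branch):

kwargs, fct = random_input_kwargs(config)
# kwargs carries past_sequence_length=32 and sequence_length=1, so the dummy
# inputs built from it describe one new token on top of a 32-token cache, which
# keeps the past_key_values inputs alive in the exported graph.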

onnx_diagnostic/torch_models/validate.py

Lines changed: 0 additions & 24 deletions
@@ -493,30 +493,6 @@ def validate_model(
             f.write(f"model_id: {model_id}\n------\n")
             f.write(pprint.pformat(dump_info))

-    if exporter == "modelbuilder":
-        # Models used with ModelBuilder do not like batch size > 1.
-        # Let's change that.
-        for k in ["inputs", "inputs2"]:
-            if k not in data:
-                continue
-            if verbose:
-                print(f"[validate_model] set batch=1 for data[{k!r}]")
-                print(f"[validate_model] batch=1 === {string_type(data[k], with_shape=True)}")
-            cpl = CoupleInputsDynamicShapes(
-                tuple(), data[k], dynamic_shapes=data["dynamic_shapes"]
-            )
-            if patch_kwargs.get("patch", False):
-                with torch_export_patches(**patch_kwargs):  # type: ignore[arg-type]
-                    data[k] = cpl.change_dynamic_dimensions(
-                        desired_values=dict(batch=1), only_desired=True
-                    )
-            else:
-                data[k] = cpl.change_dynamic_dimensions(
-                    desired_values=dict(batch=1), only_desired=True
-                )
-            if verbose:
-                print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}")
-
     data["input_options"] = iop
     data["model_options"] = mop
     data["model_dump_folder"] = dump_folder
