
Commit be6482b

obvious

1 parent f1e1f91 commit be6482b

File tree

2 files changed: +5 additions, -1 deletion


_unittests/ut_torch_models/test_tiny_llms_bypassed.py

Lines changed: 1 addition & 0 deletions

@@ -62,6 +62,7 @@ def test_export_phi2_2_bypassed(self):
             {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
         )
         model(**torch_deepcopy(inputs))
+        ds = use_dyn_not_str(ds)
         with torch_export_patches(patch_transformers=True, stop_if_static=1) as modificator:
             inputs = modificator(inputs)
             ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
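
The added line converts the dynamic-shapes specification before calling the exporter. A minimal sketch of the idea, assuming ``use_dyn_not_str`` rewrites string dimension names into ``torch.export`` dynamic markers; the ``Tiny`` model, the inputs, and both ``ds_*`` specs below are illustrative, not taken from the repository:

import torch

# Hypothetical model standing in for the Phi-2 model used by the test.
class Tiny(torch.nn.Module):
    def forward(self, input_ids):
        return input_ids.float() * 2

model = Tiny()
inputs = {"input_ids": torch.arange(8).reshape(2, 4)}

# String-named dimensions: the form the test builds as ``ds``.
ds_str = {"input_ids": {0: "batch", 1: "seq_length"}}

# Assumed converted form: explicit dynamic markers. torch.export.Dim.DYNAMIC
# requires a recent PyTorch; bounded Dim objects work on older versions.
ds_dim = {"input_ids": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}}

ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds_dim, strict=False)
print(ep)  # both dimensions of input_ids appear as symbolic sizes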

onnx_diagnostic/torch_models/untrained/llm_phi2.py

Lines changed: 4 additions & 1 deletion

@@ -9,6 +9,7 @@ def get_phi2(
     sequence_length: int = 30,
     sequence_length2: int = 3,
     dynamic_rope: bool = False,
+    use_dim_not_dynamic: bool = False,
     **kwargs,
 ) -> Dict[str, Any]:
     """
@@ -18,6 +19,8 @@ def get_phi2(
     :param sequence_length: sequence length
     :param sequence_length2: new sequence length
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
+    :param use_dim_not_dynamic: uses ``torch.export.Dim`` and not a string for the batch size,
+        the sequence length and the cache length
     :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
     :return: dictionary

@@ -62,7 +65,7 @@ def get_phi2(
     n_layers = config["num_hidden_layers"]
     num_key_value_heads = config["num_key_value_heads"]

-    if batch_size == 1:
+    if use_dim_not_dynamic:
         batch = torch.export.Dim("batch", min=1, max=1024)
         seq_length = torch.export.Dim("seq_length", min=1, max=4096)
         cache_length = torch.export.Dim("cache_length", min=1, max=4096)
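
To see what the new flag selects, here is a hedged reconstruction of the branch in context. Only the ``if`` line and the three ``Dim`` lines appear in the hunk; the ``make_dims`` wrapper and the ``else`` branch are assumptions for illustration:

import torch

def make_dims(use_dim_not_dynamic: bool):
    # The flag now chooses bounded Dim objects explicitly, instead of the
    # choice being tied to ``batch_size == 1`` as before the commit.
    if use_dim_not_dynamic:
        # Bounded symbolic dimensions, exactly as in the new branch.
        batch = torch.export.Dim("batch", min=1, max=1024)
        seq_length = torch.export.Dim("seq_length", min=1, max=4096)
        cache_length = torch.export.Dim("cache_length", min=1, max=4096)
    else:
        # String names, resolved later (e.g. by use_dyn_not_str) -- an
        # assumption about the other branch, which the hunk does not show.
        batch, seq_length, cache_length = "batch", "seq_length", "cache_length"
    return batch, seq_length, cache_length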
