Skip to content

Commit 7037ea0

Browse files
committed
fix
1 parent 0eb978b commit 7037ea0

File tree

2 files changed: +25 −4 lines changed

_unittests/ut_torch_models/test_llm_phi2.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_get_phi2(self):
2222
@ignore_warnings(UserWarning)
2323
@requires_transformers("4.54")
2424
@requires_torch("2.9.99")
25-
def test_export_phi2_1(self):
25+
def test_export_phi2_1_batch_size_1(self):
2626
# exporting vmap does not work
2727
data = get_phi2(num_hidden_layers=2)
2828
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
@@ -35,6 +35,22 @@ def test_export_phi2_1(self):
3535
)
3636
assert ep
3737

38+
@ignore_warnings(UserWarning)
39+
@requires_transformers("4.54")
40+
@requires_torch("2.9.99")
41+
def test_export_phi2_1_batch_size_2(self):
42+
# exporting vmap does not work
43+
data = get_phi2(num_hidden_layers=2, batch=2)
44+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
45+
self.assertEqual(
46+
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
47+
)
48+
with torch_export_patches(patch_transformers=True):
49+
ep = torch.export.export(
50+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
51+
)
52+
assert ep
53+
3854

3955
if __name__ == "__main__":
4056
unittest.main(verbosity=2)

onnx_diagnostic/torch_models/untrained/llm_phi2.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,14 @@ def get_phi2(
6262
n_layers = config["num_hidden_layers"]
6363
num_key_value_heads = config["num_key_value_heads"]
6464

65-
batch = torch.export.Dim("batch", min=1, max=1024)
66-
seq_length = torch.export.Dim("seq_length", min=1, max=4096)
67-
cache_length = torch.export.Dim("cache_length", min=1, max=4096)
65+
if batch_size == 1:
66+
batch = torch.export.Dim("batch", min=1, max=1024)
67+
seq_length = torch.export.Dim("seq_length", min=1, max=4096)
68+
cache_length = torch.export.Dim("cache_length", min=1, max=4096)
69+
else:
70+
batch = "batch"
71+
seq_length = "seq_length"
72+
cache_length = "cache_length"
6873

6974
shapes = {
7075
"input_ids": {0: batch, 1: seq_length},

0 commit comments

Comments (0)