Skip to content

Commit f1e1f91

Browse files
committed
another try
1 parent 7037ea0 commit f1e1f91

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

_unittests/ut_torch_models/test_llm_phi2.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,15 @@ def test_get_phi2(self):
2424
@requires_torch("2.9.99")
2525
def test_export_phi2_1_batch_size_1(self):
2626
# exporting vmap does not work
27-
data = get_phi2(num_hidden_layers=2)
27+
data = get_phi2(num_hidden_layers=2, batch_size=1)
2828
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
29+
self.assertEqual(inputs["input_ids"].shape[0], 1)
2930
self.assertEqual(
3031
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
3132
)
32-
with torch_export_patches(patch_transformers=True):
33+
with torch.fx.experimental._config.patch(
34+
backed_size_oblivious=True
35+
), torch_export_patches(patch_transformers=True):
3336
ep = torch.export.export(
3437
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
3538
)
@@ -40,8 +43,9 @@ def test_export_phi2_1_batch_size_1(self):
4043
@requires_torch("2.9.99")
4144
def test_export_phi2_1_batch_size_2(self):
4245
# exporting vmap does not work
43-
data = get_phi2(num_hidden_layers=2, batch=2)
46+
data = get_phi2(num_hidden_layers=2, batch_size=2)
4447
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
48+
self.assertEqual(inputs["input_ids"].shape[0], 2)
4549
self.assertEqual(
4650
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
4751
)

0 commit comments

Comments
 (0)