Skip to content

Commit 7dbc196

Browse files
committed
unit
1 parent dbc6950 commit 7dbc196

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

_unittests/ut_tasks/test_tasks_text_generation.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,20 @@ def test_text_generation_tiny_llm(self):
5858
expected = model(**torch_deepcopy(inputs))
5959
model(**data["inputs2"])
6060
fake = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes=ds)[0]
61-
with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
61+
with torch_export_patches(patch_transformers=True, verbose=1, patch_torch=False):
6262
ep = torch.export.export(
6363
model, (), kwargs=fake, dynamic_shapes=use_dyn_not_str(ds), strict=False
6464
)
6565
# print(ep)
66-
got = ep.module()(**inputs_copied)
66+
rem = []
67+
for node in ep.graph.nodes:
68+
if "_assert" in str(node.target):
69+
rem.append(node)
70+
for node in rem:
71+
ep.graph.erase_node(node)
72+
ep.graph.lint()
73+
mod = ep.module(check_guards=False)
74+
got = mod(**inputs_copied)
6775
self.assertEqualAny(expected.past_key_values, got.past_key_values)
6876
self.assertEqualArray(expected.logits, got.logits)
6977

0 commit comments

Comments
 (0)