Skip to content

Commit c49315e

Browse files
tugsbayasgalanpytorchmergebot
authored andcommitted
Improve attr mismatch msg (pytorch#149576)
Differential Revision: [D71513041](https://our.internmc.facebook.com/intern/diff/D71513041) Pull Request resolved: pytorch#149576 Approved by: https://github.com/avikchaudhuri
1 parent fdc4394 commit c49315e

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

test/export/test_export.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11046,6 +11046,21 @@ def forward(self, x, y, div="floor"):
1104611046
self.assertEqual(div_spec.arg.name, "div")
1104711047
self.assertEqual(div_spec.arg.value, "floor")
1104811048

11049+
def test_attr_assignment_extra(self):
11050+
class Foo(torch.nn.Module):
11051+
def __init__(self):
11052+
super().__init__()
11053+
11054+
def forward(self, x):
11055+
self.bar = x.sum()
11056+
return x + 2
11057+
11058+
with self.assertRaisesRegex(
11059+
ValueError,
11060+
"During torch.export, following attrs were created in the model.forward:",
11061+
):
11062+
_ = export(Foo(), (torch.randn(4, 4),), strict=False)
11063+
1104911064
def test_unbacked_deferred_runtime_retrace(self):
1105011065
class Foo(torch.nn.Module):
1105111066
def forward(self, x, y):

torch/_functorch/aot_autograd.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1659,8 +1659,29 @@ def _collect_assigned_tensor_attributes(kp, v, _v):
16591659
# TODO(avik): Assigning all other types are allowed right now.
16601660
# Maybe in the future we want to limit this to primitive types?
16611661

1662+
new_attrs = _get_attributes(mod)
1663+
if len(new_attrs) != len(snapshot):
1664+
added_attrs = new_attrs.keys() - snapshot.keys()
1665+
deleted_attrs = snapshot.keys() - new_attrs.keys()
1666+
1667+
if len(added_attrs) > 0:
1668+
raise ValueError(
1669+
f"During torch.export, following attrs were created in the model.forward: {added_attrs} "
1670+
f"Such attributes must be registered as buffers using the `register_buffer` "
1671+
f"API and must be initialized at model.__init__ "
1672+
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
1673+
)
1674+
1675+
if len(deleted_attrs) > 0:
1676+
raise ValueError(
1677+
f"During torch.export, following attrs were deleted in the model.forward: {deleted_attrs} "
1678+
f"Such attributes must be registered as buffers using the `register_buffer` "
1679+
f"API and must be initialized at model.__init__ "
1680+
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
1681+
)
1682+
16621683
pytree.tree_map_with_path(
1663-
_collect_assigned_tensor_attributes, snapshot, _get_attributes(mod)
1684+
_collect_assigned_tensor_attributes, snapshot, new_attrs
16641685
)
16651686
# restore state of all attributes (including, e.g., of primitive types)
16661687
mod.__dict__.update(snapshot)

0 commit comments

Comments
 (0)