Skip to content

Commit a99e8ee

Browse files
angelayipytorchmergebot
authored andcommitted
Propagate real tensor tracing with torchbind + fixing side effects (pytorch#138797)
Summary: * Fixed real tensor tracing w/ torchbind objs by passing the cloned tensor obj. For now I just catch the exception and have an error message if the `_clone` fails, but up for discussion on what to do here * Separate question, should we require people to set up FakeScriptObjects and stuff for draft mode? * Prevent side effects from happening when we do the first pass of custom ops profiling by cloning/copying everything. Not sure if deepcopying the model will succeed in all cases... But also I guess this path can be removed once custom ops profiling turns into one pass. Test Plan: `buck2 run @//mode/dev-nosan //scripts/angelayi/draft_export:test_draft_export` Reviewed By: ydwu4 Differential Revision: D64124825 Pull Request resolved: pytorch#138797 Approved by: https://github.com/ydwu4
1 parent dd9ff9f commit a99e8ee

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

torch/_library/fake_class_registry.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# mypy: allow-untyped-defs
2+
import copy
23
import logging
34
from typing import Any, Dict, Optional, Protocol, Tuple, Union
45

56
import torch
67
from torch._library.utils import parse_namespace
8+
from torch.utils._python_dispatch import _disable_current_modes
79

810

911
log = logging.getLogger(__name__)
@@ -15,7 +17,18 @@ def __init__(self, wrapped_obj: Any, script_class_name: str, x: torch.ScriptObje
1517

1618
# The fully qualified name of the class of original script object
1719
self.script_class_name = script_class_name
18-
self.real_obj = x
20+
try:
21+
with _disable_current_modes():
22+
self.real_obj = copy.deepcopy(x)
23+
except RuntimeError:
24+
log.warning(
25+
"Unable to deepcopy the custom object %s. "
26+
"Defaulting to the user given object. This might be "
27+
"dangerous as side effects may be directly applied "
28+
"to the object.",
29+
script_class_name,
30+
)
31+
self.real_obj = x
1932

2033

2134
class FakeScriptMethod:

torch/_subclasses/fake_tensor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import torch
3939
from torch import SymBool, SymFloat, SymInt, Tensor
4040
from torch._C._functorch import is_functorch_wrapped_tensor, is_legacy_batchedtensor
41+
from torch._library.fake_class_registry import FakeScriptObject
4142
from torch._prims_common import suggest_memory_format
4243
from torch._subclasses.meta_utils import (
4344
assert_eq,
@@ -1947,7 +1948,9 @@ def _dispatch_impl(
19471948
args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
19481949
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
19491950

1950-
def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]:
1951+
def maybe_to_real_tensor(
1952+
t: T,
1953+
) -> Optional[Union[T, Tensor, torch._C.ScriptObject]]:
19511954
if isinstance(t, FakeTensor):
19521955
return t.real_tensor
19531956
elif isinstance(t, py_sym_types):
@@ -1957,6 +1960,8 @@ def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]:
19571960
self.shape_env.unbacked_var_to_val
19581961
)
19591962
)
1963+
elif isinstance(t, FakeScriptObject):
1964+
return t.real_obj
19601965
else:
19611966
return t
19621967

0 commit comments

Comments
 (0)