
Commit 46132dc

mlazos authored and pytorchmergebot committed
[Dynamo] Refactor wrap_fx_proxy (pytorch#138933)

During the work to dedup graphs for hierarchical compilation, I tried to tame the `wrap_fx_proxy_cls` mess by separating the wrapping into three distinct scenarios (vs. a jumble of conditionals):

1) wrapping a preexisting tensor (`_wrap_fx_preexisting_tensor`)
2) wrapping and tracing a new op into the graph (`_wrap_fx_proxy`)
3) handling a value that is some other proxyable data structure (`handle_traced_output`)

See `wrap_fx_proxy_cls` for the conditional tree handling these three cases.

Pull Request resolved: pytorch#138933
Approved by: https://github.com/williamwen42
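Editor's note: the three-way dispatch described above can be condensed into a minimal, self-contained sketch (toy stand-ins only; the real functions live in torch/_dynamo/variables/builder.py and additionally take target_cls, tx, proxy, subclass_type, and options):

import torch

def wrap_value_sketch(example_value):
    # Case 2: no example value yet -> trace the op and fakeify its output
    if example_value is None:
        return "_wrap_fx_proxy"
    # Case 1: a tensor that already exists -> wrap (and fakeify) it
    elif isinstance(example_value, torch.Tensor):
        return "_wrap_fx_preexisting_tensor"
    # Case 3: some other proxyable structure -> recurse over its contents
    else:
        return "handle_traced_output"

print(wrap_value_sketch(None))                 # _wrap_fx_proxy
print(wrap_value_sketch(torch.randn(2)))       # _wrap_fx_preexisting_tensor
print(wrap_value_sketch((torch.randn(2), 3)))  # handle_traced_output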
1 parent 9ca749d · commit 46132dc

1 file changed: torch/_dynamo/variables/builder.py (+87, -40)
@@ -2026,6 +2026,24 @@ def _dataclasses_fields_lambda(obj):
     return TupleVariable(items)
 
 
+def _clone_input(value, fake_mode):
+    if isinstance(value, torch.Tensor):
+        # tensor subclasses will not be converted to FakeTensors and need to be cloned
+        if not (
+            isinstance(value, FakeTensor)
+            or (
+                # Is functional tensor fakeified by this instance of Dynamo
+                torch._is_functional_tensor(value)
+                and maybe_get_fake_mode(value) is fake_mode
+            )
+            or value.is_nested
+        ):
+            # NB: ensure strides are preserved
+            value = clone_input(value)
+
+    return value
+
+
 def wrap_fx_proxy(
     tx, proxy, example_value=None, subclass_type=None, **options
 ) -> VariableTracker:
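Editor's note: clone_input here is torch._dynamo.utils.clone_input (builder.py imports it from ..utils), whose point is the "NB: ensure strides are preserved" comment above. A hedged, self-contained illustration, assuming only a stock PyTorch install:

import torch
from torch._dynamo.utils import clone_input

# A non-contiguous view: transposing swaps the strides.
t = torch.randn(4, 6).t()
c = clone_input(t)

# The clone reproduces the input's exact size and strides.
assert c.shape == t.shape and c.stride() == t.stride()
print(t.stride(), c.stride())  # (1, 4) (1, 4)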
@@ -2071,7 +2089,7 @@ def wrap_fx_proxy(
 # instance of Dynamo.
 #
 # Upon closer inspection, you may notice that there are a slurry of non-Tensor
-# output cases. What gives? Well, we sometimes trace operations into the
+# output cases in handle_traced_output. What gives? Well, we sometimes trace operations into the
 # graph that don't involve tensors.
 #
 # * Some operators return tuples; we need to recursively handle their
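Editor's note: the non-Tensor outputs this comment refers to are easy to see outside Dynamo; some ops produce tuples or plain Python values rather than a single Tensor. A small standalone illustration:

import torch

x = torch.randn(3, 4)

values, indices = torch.max(x, dim=1)  # a (values, indices) pair of Tensors
size = x.size()                        # torch.Size, not a Tensor
same = torch.equal(x, x)               # a plain Python bool

print(type(values).__name__, type(size).__name__, type(same).__name__)
# Tensor Size bool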
@@ -2090,54 +2108,53 @@
 # this function without a proxy.
 def wrap_fx_proxy_cls(
     target_cls, tx, proxy, example_value=None, subclass_type=None, **options
+):
+    if example_value is None:
+        return _wrap_fx_proxy(
+            target_cls, tx, proxy, example_value, subclass_type, **options
+        )
+    elif isinstance(example_value, torch.Tensor):
+        return _wrap_fx_preexisting_tensor(
+            target_cls, tx, proxy, example_value, subclass_type, **options
+        )
+    else:
+        # This will skip tracing an op and recursively reinvoke wrap_fx_proxy_cls on supported
+        # data structures. In essence this just handles tracing some other value which may
+        # contain Fake Tensors or is otherwise proxyable.
+        return handle_traced_output(
+            example_value, tx, proxy, options, subclass_type, target_cls
+        )
+
+
+# This is 1 above (wrapping a preexisting tensor)
+def _wrap_fx_preexisting_tensor(
+    target_cls, tx, proxy, tensor, subclass_type=None, **options
 ):
     from ..symbolic_convert import InstructionTranslatorBase
 
+    assert isinstance(
+        tensor, torch.Tensor
+    ), f"_wrap_fx_preexisting_tensor expected tensor, got {type(tensor)}"
+
     assert isinstance(tx, InstructionTranslatorBase)
     if "guards" in options and options["guards"] is not None:
         tx.output.guards.update(options["guards"])
 
     assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}"
 
-    initial_example_value = example_value
-
-    def _clone_input(value):
-        if isinstance(value, torch.Tensor):
-            # tensor subclasses will not be converted to FakeTensors and need to be cloned
-            if not (
-                isinstance(value, FakeTensor)
-                or (
-                    # Is functional tensor fakeified by this instance of Dynamo
-                    torch._is_functional_tensor(value)
-                    and maybe_get_fake_mode(value) is tx.fake_mode
-                )
-                or value.is_nested
-            ):
-                # NB: ensure strides are preserved
-                value = clone_input(value)
-
-        return value
-
     # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
     with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
-        # with preserve_rng_state():
-        if example_value is None:
-            # only allow_non_graph_fake in this instance because we handle the non-fake
-            # cases properly below.
-            example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
-
         # Handle recursive calls here
-        elif maybe_get_fake_mode(example_value) is tx.fake_mode:
+        if maybe_get_fake_mode(tensor) is tx.fake_mode:
             pass
-
-        elif isinstance(example_value, torch.Tensor):
+        else:
             if tx.export:
                 # The legacy behavior for real value cache with subclasses was
                 # to perform a clone WITHOUT preserving the subclass. It's
                 # not entirely clear this is what you actually want though.
                 with torch._C.DisableTorchFunctionSubclass():
                     proxy.tracer.real_value_cache[proxy.node] = _clone_input(
-                        example_value
+                        tensor, tx.fake_mode
                     )
             # NB: If we're ignoring subclass, then the expectation is you will
             # take the returned TensorVariable and wrap it into a more
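Editor's note: the `maybe_get_fake_mode(tensor) is tx.fake_mode` test above asks whether the tensor was already fakeified by this Dynamo instance. A hedged sketch of that primitive in isolation, using a standalone FakeTensorMode in place of a real tx:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, maybe_get_fake_mode

mode = FakeTensorMode()
fake = mode.from_tensor(torch.randn(2, 2))  # fakeified by this mode
real = torch.randn(2, 2)                    # a preexisting real tensor

assert maybe_get_fake_mode(fake) is mode  # recursive call: nothing to do
assert maybe_get_fake_mode(real) is None  # must be wrapped to a FakeTensor first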
@@ -2149,27 +2166,57 @@ def _clone_input(value):
             }
             assert "source" in options and options["source"] is not None
             kwargs["source"] = options["source"]
-            example_value = wrap_to_fake_tensor_and_record(
-                example_value, tx=tx, **kwargs
-            )
-        if (
-            isinstance(example_value, torch.Tensor)
-            and example_value.device.type != "meta"
-            and (maybe_get_fake_mode(example_value) is not tx.fake_mode)
+            tensor = wrap_to_fake_tensor_and_record(tensor, tx=tx, **kwargs)
+
+        if tensor.device.type != "meta" and (
+            maybe_get_fake_mode(tensor) is not tx.fake_mode
         ):
             raise InternalTorchDynamoError(
-                "`example_value` needs to be a `FakeTensor`"
-                f"wrapped by this instance of Dynamo. Found: {example_value}"
+                "`tensor` needs to be a `FakeTensor`"
+                f"wrapped by this instance of Dynamo. Found: {tensor}"
             )
 
+    return handle_traced_output(tensor, tx, proxy, options, subclass_type, target_cls)
+
+
+# This is 2 in the above comment (wrapping the output of a traced op)
+def _wrap_fx_proxy(
+    target_cls, tx, proxy, example_value=None, subclass_type=None, **options
+):
+    from ..symbolic_convert import InstructionTranslatorBase
+
+    assert isinstance(tx, InstructionTranslatorBase)
+    if "guards" in options and options["guards"] is not None:
+        tx.output.guards.update(options["guards"])
+
+    assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}"
+
+    # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
+    with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
+        # with preserve_rng_state():
+        # only allow_non_graph_fake in this instance because we handle the non-fake
+        # cases properly below.
+        example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
+
+    return handle_traced_output(
+        example_value, tx, proxy, options, subclass_type, target_cls
+    )
+
+
+# This handles wrapping of the output of an op traced into the graph
+def handle_traced_output(example_value, tx, proxy, options, subclass_type, target_cls):
+    import torch._functorch.vmap
+    import torch._subclasses.fake_tensor
+    import torch._utils
+
     if isinstance(example_value, torch.Tensor):
         is_parameter = isinstance(example_value, torch.nn.Parameter)
         is_buffer = isinstance(example_value, torch.nn.Buffer)
 
         # NB: In most (all?) cases, this does not actually do a clone.
         # (WARNING: this means that if we mutate metadata on the fake
         # tensor, the stored example value will update too!)
-        example_value = _clone_input(example_value)
+        example_value = _clone_input(example_value, tx.fake_mode)
         set_example_value(proxy.node, example_value)
         specialized_props = target_cls.specialize(example_value)
         # TODO: not sure about this fake mode test
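Editor's note: what get_fake_value accomplishes in _wrap_fx_proxy (computing an example value for a freshly traced op without running real kernels) can be approximated outside Dynamo with FakeTensorMode. A sketch of the idea, not Dynamo's actual code path:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    a = torch.empty(8, 16)        # factory call under the mode -> FakeTensor
    out = torch.matmul(a, a.t())  # "traced" op: metadata only, no real compute

print(type(out).__name__, tuple(out.shape))  # FakeTensor (8, 8)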
