@@ -2026,6 +2026,24 @@ def _dataclasses_fields_lambda(obj):
     return TupleVariable(items)


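+# Helper used below: clone a tensor unless it is already a FakeTensor, a
+# functional tensor fakeified by this instance of Dynamo, or a nested tensor.
+# Cloning goes through clone_input() so strides are preserved.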
+def _clone_input(value, fake_mode):
+    if isinstance(value, torch.Tensor):
+        # tensor subclasses will not be converted to FakeTensors and need to be cloned
+        if not (
+            isinstance(value, FakeTensor)
+            or (
+                # Is functional tensor fakeified by this instance of Dynamo
+                torch._is_functional_tensor(value)
+                and maybe_get_fake_mode(value) is fake_mode
+            )
+            or value.is_nested
+        ):
+            # NB: ensure strides are preserved
+            value = clone_input(value)
+
+    return value
+
+
 def wrap_fx_proxy(
     tx, proxy, example_value=None, subclass_type=None, **options
 ) -> VariableTracker:
@@ -2071,7 +2089,7 @@ def wrap_fx_proxy(
 # instance of Dynamo.
 #
 # Upon closer inspection, you may notice that there are a slurry of non-Tensor
-# output cases. What gives? Well, we sometimes trace operations into the
+# output cases in handle_traced_output. What gives? Well, we sometimes trace operations into the
 # graph that don't involve tensors.
 #
 # * Some operators return tuples; we need to recursively handle their
@@ -2090,54 +2108,53 @@ def wrap_fx_proxy(
 # this function without a proxy.
 def wrap_fx_proxy_cls(
     target_cls, tx, proxy, example_value=None, subclass_type=None, **options
+):
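+    # Dispatch on example_value: with no example_value we still need to run the
+    # op under fake tensor propagation; a pre-existing torch.Tensor has to be
+    # fakeified first; any other value is handled structurally by
+    # handle_traced_output.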
+    if example_value is None:
+        return _wrap_fx_proxy(
+            target_cls, tx, proxy, example_value, subclass_type, **options
+        )
+    elif isinstance(example_value, torch.Tensor):
+        return _wrap_fx_preexisting_tensor(
+            target_cls, tx, proxy, example_value, subclass_type, **options
+        )
+    else:
+        # This will skip tracing an op and recursively reinvoke wrap_fx_proxy_cls on supported
+        # data structures. In essence this just handles tracing some other value which may
+        # contain Fake Tensors or is otherwise proxyable.
+        return handle_traced_output(
+            example_value, tx, proxy, options, subclass_type, target_cls
+        )
+
+
+# This is 1 above (wrapping a preexisting tensor)
+def _wrap_fx_preexisting_tensor(
+    target_cls, tx, proxy, tensor, subclass_type=None, **options
 ):
     from ..symbolic_convert import InstructionTranslatorBase

+    assert isinstance(
+        tensor, torch.Tensor
+    ), f"_wrap_fx_preexisting_tensor expected tensor, got {type(tensor)}"
+
     assert isinstance(tx, InstructionTranslatorBase)
     if "guards" in options and options["guards"] is not None:
         tx.output.guards.update(options["guards"])

     assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}"

-    initial_example_value = example_value
-
-    def _clone_input(value):
-        if isinstance(value, torch.Tensor):
-            # tensor subclasses will not be converted to FakeTensors and need to be cloned
-            if not (
-                isinstance(value, FakeTensor)
-                or (
-                    # Is functional tensor fakeified by this instance of Dynamo
-                    torch._is_functional_tensor(value)
-                    and maybe_get_fake_mode(value) is tx.fake_mode
-                )
-                or value.is_nested
-            ):
-                # NB: ensure strides are preserved
-                value = clone_input(value)
-
-        return value
-
     # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
     with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
-        # with preserve_rng_state():
-        if example_value is None:
-            # only allow_non_graph_fake in this instance because we handle the non-fake
-            # cases properly below.
-            example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
-
         # Handle recursive calls here
-        elif maybe_get_fake_mode(example_value) is tx.fake_mode:
+        if maybe_get_fake_mode(tensor) is tx.fake_mode:
             pass
-
-        elif isinstance(example_value, torch.Tensor):
+        else:
             if tx.export:
                 # The legacy behavior for real value cache with subclasses was
                 # to perform a clone WITHOUT preserving the subclass. It's
                 # not entirely clear this is what you actually want though.
                 with torch._C.DisableTorchFunctionSubclass():
                     proxy.tracer.real_value_cache[proxy.node] = _clone_input(
-                        example_value
+                        tensor, tx.fake_mode
                     )
             # NB: If we're ignoring subclass, then the expectation is you will
             # take the returned TensorVariable and wrap it into a more
@@ -2149,27 +2166,57 @@ def _clone_input(value):
             }
             assert "source" in options and options["source"] is not None
             kwargs["source"] = options["source"]
-            example_value = wrap_to_fake_tensor_and_record(
-                example_value, tx=tx, **kwargs
-            )
-        if (
-            isinstance(example_value, torch.Tensor)
-            and example_value.device.type != "meta"
-            and (maybe_get_fake_mode(example_value) is not tx.fake_mode)
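+            # Wrap the real tensor into a FakeTensor and record it with this
+            # instance of Dynamo before handing it to handle_traced_output.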
+            tensor = wrap_to_fake_tensor_and_record(tensor, tx=tx, **kwargs)
+
+        if tensor.device.type != "meta" and (
+            maybe_get_fake_mode(tensor) is not tx.fake_mode
         ):
             raise InternalTorchDynamoError(
-                "`example_value` needs to be a `FakeTensor`"
-                f"wrapped by this instance of Dynamo. Found: {example_value}"
+                "`tensor` needs to be a `FakeTensor`"
+                f"wrapped by this instance of Dynamo. Found: {tensor}"
             )

+    return handle_traced_output(tensor, tx, proxy, options, subclass_type, target_cls)
+
+
+# This is 2 in the above comment (wrapping the output of a traced op)
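+# No example_value was provided, so run fake tensor propagation on the proxy's
+# node (via get_fake_value) to produce one before wrapping.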
+def _wrap_fx_proxy(
+    target_cls, tx, proxy, example_value=None, subclass_type=None, **options
+):
+    from ..symbolic_convert import InstructionTranslatorBase
+
+    assert isinstance(tx, InstructionTranslatorBase)
+    if "guards" in options and options["guards"] is not None:
+        tx.output.guards.update(options["guards"])
+
+    assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}"
+
+    # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
+    with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
+        # with preserve_rng_state():
+        # only allow_non_graph_fake in this instance because we handle the non-fake
+        # cases properly below.
+        example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
+
+    return handle_traced_output(
+        example_value, tx, proxy, options, subclass_type, target_cls
+    )
+
+
+# This handles wrapping of the output of an op traced into the graph
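+# (or of any already-fakeified value). Tensors become target_cls variables,
+# while non-Tensor outputs (tuples, lists, ints, SymInts, and so on) are
+# handled recursively, as described in the comment above wrap_fx_proxy.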
+def handle_traced_output(example_value, tx, proxy, options, subclass_type, target_cls):
+    import torch._functorch.vmap
+    import torch._subclasses.fake_tensor
+    import torch._utils
+
     if isinstance(example_value, torch.Tensor):
         is_parameter = isinstance(example_value, torch.nn.Parameter)
         is_buffer = isinstance(example_value, torch.nn.Buffer)

         # NB: In most (all?) cases, this does not actually do a clone.
         # (WARNING: this means that if we mutate metadata on the fake
         # tensor, the stored example value will update too!)
-        example_value = _clone_input(example_value)
+        example_value = _clone_input(example_value, tx.fake_mode)
         set_example_value(proxy.node, example_value)
         specialized_props = target_cls.specialize(example_value)
         # TODO: not sure about this fake mode test