|
23 | 23 | from torch.utils._python_dispatch import is_traceable_wrapper_subclass |
24 | 24 |
|
25 | 25 |
|
@contextmanager
def _skip_nested_compile() -> Generator[None, None, None]:
    """Temporarily turn off dynamo's nested-fx-trace error while make_fx runs.

    make_fx is unable to trace through functions wrapped in torch.compile
    (for example the compiled flex_attention in FlexAttentionWrapper). With
    ``error_on_nested_fx_trace`` set to False, dynamo silently inlines the
    wrapped function rather than raising, which lets make_fx trace the
    underlying ops as usual.

    The value the flag had on entry is put back on exit, including when the
    managed body raises.
    """
    dynamo_config = torch._dynamo.config
    # Remember the caller's setting so it can be reinstated afterwards.
    saved_flag = dynamo_config.error_on_nested_fx_trace
    dynamo_config.error_on_nested_fx_trace = False
    try:
        yield
    finally:
        dynamo_config.error_on_nested_fx_trace = saved_flag
| 41 | + |
| 42 | + |
26 | 43 | @dataclass |
27 | 44 | class SubclassMeta: |
28 | 45 | cls: type |
@@ -219,7 +236,7 @@ def _copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None: |
219 | 236 | seq_nr_to_fwd_node: dict[int, torch.fx.Node] = {} |
220 | 237 |
|
221 | 238 | for node in fx_g.graph.nodes: |
222 | | - if node.op != "call_function" or "seq_nr" not in node.meta: |
| 239 | + if node.op not in ("call_function", "get_attr") or "seq_nr" not in node.meta: |
223 | 240 | continue |
224 | 241 | seq_nr = node.meta["seq_nr"] |
225 | 242 | if seq_nr not in seq_nr_to_fwd_node: |
@@ -275,11 +292,14 @@ def functional_call(*all_args): |
275 | 292 | unwrapped_args.append(arg) |
276 | 293 | input_layouts.append(SubclassLayout(1, None)) |
277 | 294 |
|
278 | | - fake_mode = FakeTensorMode(allow_non_fake_inputs=True) |
| 295 | + fake_mode = FakeTensorMode( |
| 296 | + allow_non_fake_inputs=True, |
| 297 | + shape_env=torch.fx.experimental.symbolic_shapes.ShapeEnv(), |
| 298 | + ) |
279 | 299 |
|
280 | 300 | def to_fake(t): |
281 | 301 | if isinstance(t, torch.Tensor): |
282 | | - return fake_mode.from_tensor(t) |
| 302 | + return fake_mode.from_tensor(t, static_shapes=True) |
283 | 303 | return t |
284 | 304 |
|
285 | 305 | fake_args = tuple(to_fake(a) for a in unwrapped_args) |
@@ -315,7 +335,7 @@ def fn_with_subclass_handling(*plain_args): |
315 | 335 | return unwrapped_outputs |
316 | 336 |
|
317 | 337 | # preserve_node_meta propagates fx.traceback.annotate metadata to traced nodes |
318 | | - with fake_mode, preserve_node_meta(): |
| 338 | + with fake_mode, preserve_node_meta(), _skip_nested_compile(): |
319 | 339 | traced = make_fx( |
320 | 340 | fn_with_subclass_handling, |
321 | 341 | record_stack_traces=True, |
|
0 commit comments