Skip to content

Commit fdf532c

Browse files
t-vikshitij12345
andauthored
Skip sections with vmap for thunderfx (Lightning-AI#2504)
Co-authored-by: Kshiteej K <kshitijkalambarkar@gmail.com>
1 parent 5564291 commit fdf532c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

thunder/dynamo/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def get_nodes_in_unsupported_ctx_regions(gm: torch.fx.GraphModule) -> set[torch.
338338
nodes_in_unsupported_ctx_regions: set[torch.fx.Node] = set()
339339
ctx_cnt = 0 # Count of we have seen till now
340340

341-
UNSUPPORTED_THUNDER_CTX = ()
341+
UNSUPPORTED_THUNDER_CTX = (torch._C._functorch._vmap_increment_nesting, torch._C._functorch._vmap_decrement_nesting)
342342
for node in gm.graph.nodes:
343343
if node.op == "call_function" and node.target in UNSUPPORTED_THUNDER_CTX:
344344
ctx_cnt += 1

0 commit comments

Comments
 (0)