Skip to content

Commit fecdebe

Browse files
malfet authored and pytorchmergebot committed
[CI][MPS] Fix compile benchmark correctness (pytorch#159731)
By passing the `fullgraph=True` attribute and increasing the cache size limit to 2**16. Otherwise, the compiler might decide not to fall back to eager to avoid recompilations. Pull Request resolved: pytorch#159731. Approved by: https://github.com/dcci
1 parent e136a91 commit fecdebe

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

test/bench_mps_ops.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def f(t):
9090
return reduction_func(t, dim=0)
9191

9292
f.__name__ = reduction_func.__name__
93-
f_c = torch.compile(f, dynamic=False)
93+
f_c = torch.compile(f, dynamic=False, fullgraph=True)
9494

9595
for size in (512, 1024, 2048, 4096):
9696
x = torch.testing.make_tensor(size, size, device=device, dtype=dtype)
@@ -116,7 +116,7 @@ def bench_scan(
116116
def f(t):
117117
return scan_func(t, dim=dim)
118118

119-
f_c = torch.compile(f, dynamic=False)
119+
f_c = torch.compile(f, dynamic=False, fullgraph=True)
120120

121121
for size in (32, 128, 512, 1024):
122122
f.__name__ = f"{scan_func.__name__}-dim{dim}-{size}x{size}"
@@ -135,7 +135,7 @@ def f(t):
135135
def f_1d(t):
136136
return scan_func(t, dim=0)
137137

138-
f_1d_c = torch.compile(f_1d, dynamic=False)
138+
f_1d_c = torch.compile(f_1d, dynamic=False, fullgraph=True)
139139

140140
for size in (100, 10000, 1000000):
141141
f_1d.__name__ = f"{scan_func.__name__}-1d-{size}"
@@ -204,4 +204,5 @@ def main() -> None:
204204

205205

206206
if __name__ == "__main__":
207+
torch._dynamo.config.cache_size_limit = 2**16
207208
main()

0 commit comments

Comments (0)