Skip to content

Commit 32bb6f8

Browse files
exclamafortepytorchmergebot
authored andcommitted
Make sure that benchmark_harness is set before running (pytorch#145532)
Running torch compile with these options causes an error, because the benchmark code isn't generated but is still called. ``` options={'profile_bandwidth_output': 'foo', 'benchmark_harness': False} ``` Pull Request resolved: pytorch#145532 Approved by: https://github.com/eellison
1 parent 25ca05e commit 32bb6f8

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2089,6 +2089,28 @@ def forward(self, x):
20892089
x = torch.randn(5, device=self.device)
20902090
self.check_model(Model(self.device), (x,))
20912091

2092+
def test_profile_benchmark_harness(self):
2093+
batch_size = 32
2094+
seq_length = 50
2095+
hidden_size = 768
2096+
2097+
def create_test_fn():
2098+
def test_fn():
2099+
inp = torch.randn(
2100+
batch_size, seq_length, hidden_size, device=self.device
2101+
)
2102+
weight = torch.randn(hidden_size, hidden_size, device=self.device)
2103+
matmul_output = inp @ weight
2104+
torch.nn.LayerNorm(hidden_size, device=self.device)(matmul_output)
2105+
return True
2106+
2107+
return test_fn
2108+
2109+
fn = torch.compile(
2110+
options={"profile_bandwidth_output": "foo", "benchmark_harness": False}
2111+
)(create_test_fn())
2112+
fn()
2113+
20922114
def test_with_profiler(self):
20932115
class Model(torch.nn.Module):
20942116
def __init__(self) -> None:

torch/_inductor/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2044,7 +2044,7 @@ def _compile_to_module(self) -> ModuleType:
20442044
self.cache_path = path
20452045
self.cache_linemap = linemap # type: ignore[assignment]
20462046

2047-
if config.profile_bandwidth_output:
2047+
if config.benchmark_harness and config.profile_bandwidth_output:
20482048
# run the inputs code gen to get the bandwidth info
20492049
mod.benchmark_compiled_module(times=1, repeat=1)
20502050
# Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029

0 commit comments

Comments
 (0)