Skip to content

Commit 134d92a

Browse files
committed
Modify example code to fix torch logs doc
1 parent 4ed884d commit 134d92a

File tree

1 file changed

+51
-38
lines changed

1 file changed

+51
-38
lines changed

recipes_source/torch_logs.py

Lines changed: 51 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -31,52 +31,65 @@
3131
# variable setting is shown for each example.
3232

3333
import torch
34+
import sys
3435

35-
# exit cleanly if we are on a device that doesn't support torch.compile
36-
if torch.cuda.get_device_capability() < (7, 0):
37-
print("Skipping because torch.compile is not supported on this device.")
38-
else:
39-
@torch.compile()
40-
def fn(x, y):
41-
z = x + y
42-
return z + 2
43-
44-
45-
inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda"))
4636

37+
def env_setup():
38+
"""Set up environment for running the example. Exit cleanly if CUDA is not available."""
39+
if not torch.cuda.is_available():
40+
print("CUDA is not available. Exiting.")
41+
sys.exit(0)
42+
43+
if torch.cuda.get_device_capability() < (7, 0):
44+
print("Skipping because torch.compile is not supported on this device.")
45+
sys.exit(0)
4746

48-
# print separator and reset dynamo
49-
# between each example
50-
def separator(name):
51-
print(f"==================={name}=========================")
52-
torch._dynamo.reset()
5347

48+
def separator(name):
49+
"""Print separator and reset dynamo between each example"""
50+
print(f"\n{'='*20} {name} {'='*20}")
51+
torch._dynamo.reset()
5452

55-
separator("Dynamo Tracing")
56-
# View dynamo tracing
57-
# TORCH_LOGS="+dynamo"
58-
torch._logging.set_logs(dynamo=logging.DEBUG)
59-
fn(*inputs)
6053

61-
separator("Traced Graph")
62-
# View traced graph
63-
# TORCH_LOGS="graph"
64-
torch._logging.set_logs(graph=True)
65-
fn(*inputs)
54+
def run_debugging_suite():
55+
"""Run the complete debugging suite with all logging options"""
56+
env_setup()
6657

67-
separator("Fusion Decisions")
68-
# View fusion decisions
69-
# TORCH_LOGS="fusion"
70-
torch._logging.set_logs(fusion=True)
71-
fn(*inputs)
72-
73-
separator("Output Code")
74-
# View output code generated by inductor
75-
# TORCH_LOGS="output_code"
76-
torch._logging.set_logs(output_code=True)
77-
fn(*inputs)
58+
@torch.compile()
59+
def fn(x, y):
60+
z = x + y
61+
return z + 2
7862

79-
separator("")
63+
inputs = (
64+
torch.ones(2, 2, device="cuda"),
65+
torch.zeros(2, 2, device="cuda")
66+
)
67+
68+
logging_scenarios = [
69+
# View dynamo tracing; TORCH_LOGS="+dynamo"
70+
("Dynamo Tracing", {"dynamo": logging.DEBUG}),
71+
72+
# View traced graph; TORCH_LOGS="graph"
73+
("Traced Graph", {"graph": True}),
74+
75+
# View fusion decisions; TORCH_LOGS="fusion"
76+
("Fusion Decisions", {"fusion": True}),
77+
78+
# View output code generated by inductor; TORCH_LOGS="output_code"
79+
("Output Code", {"output_code": True})
80+
]
81+
82+
for name, log_config in logging_scenarios:
83+
separator(name)
84+
torch._logging.set_logs(**log_config)
85+
try:
86+
result = fn(*inputs)
87+
print(f"Function output shape: {result.shape}")
88+
except Exception as e:
89+
print(f"Error during {name}: {str(e)}")
90+
91+
if __name__ == "__main__":
92+
run_debugging_suite()
8093

8194
######################################################################
8295
# Conclusion

0 commit comments

Comments
 (0)