|
31 | 31 | # variable setting is shown for each example.
|
32 | 32 |
|
33 | 33 | import torch
|
| 34 | +import sys |
34 | 35 |
|
35 |
| -# exit cleanly if we are on a device that doesn't support torch.compile |
36 |
| -if torch.cuda.get_device_capability() < (7, 0): |
37 |
| - print("Skipping because torch.compile is not supported on this device.") |
38 |
| -else: |
39 |
| - @torch.compile() |
40 |
| - def fn(x, y): |
41 |
| - z = x + y |
42 |
| - return z + 2 |
43 |
| - |
44 |
| - |
45 |
| - inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda")) |
46 | 36 |
|
| 37 | +def env_setup(): |
| 38 | + """Set up environment for running the example. Exit cleanly if CUDA is not available.""" |
| 39 | + if not torch.cuda.is_available(): |
| 40 | + print("CUDA is not available. Exiting.") |
| 41 | + sys.exit(0) |
| 42 | + |
| 43 | + if torch.cuda.get_device_capability() < (7, 0): |
| 44 | + print("Skipping because torch.compile is not supported on this device.") |
| 45 | + sys.exit(0) |
47 | 46 |
|
48 |
| -# print separator and reset dynamo |
49 |
| -# between each example |
50 |
| - def separator(name): |
51 |
| - print(f"==================={name}=========================") |
52 |
| - torch._dynamo.reset() |
53 | 47 |
|
| 48 | +def separator(name): |
| 49 | + """Print separator and reset dynamo between each example""" |
| 50 | + print(f"\n{'='*20} {name} {'='*20}") |
| 51 | + torch._dynamo.reset() |
54 | 52 |
|
55 |
| - separator("Dynamo Tracing") |
56 |
| -# View dynamo tracing |
57 |
| -# TORCH_LOGS="+dynamo" |
58 |
| - torch._logging.set_logs(dynamo=logging.DEBUG) |
59 |
| - fn(*inputs) |
60 | 53 |
|
61 |
| - separator("Traced Graph") |
62 |
| -# View traced graph |
63 |
| -# TORCH_LOGS="graph" |
64 |
| - torch._logging.set_logs(graph=True) |
65 |
| - fn(*inputs) |
| 54 | +def run_debugging_suite(): |
| 55 | + """Run the complete debugging suite with all logging options""" |
| 56 | + env_setup() |
66 | 57 |
|
67 |
| - separator("Fusion Decisions") |
68 |
| -# View fusion decisions |
69 |
| -# TORCH_LOGS="fusion" |
70 |
| - torch._logging.set_logs(fusion=True) |
71 |
| - fn(*inputs) |
72 |
| - |
73 |
| - separator("Output Code") |
74 |
| -# View output code generated by inductor |
75 |
| -# TORCH_LOGS="output_code" |
76 |
| - torch._logging.set_logs(output_code=True) |
77 |
| - fn(*inputs) |
| 58 | + @torch.compile() |
| 59 | + def fn(x, y): |
| 60 | + z = x + y |
| 61 | + return z + 2 |
78 | 62 |
|
79 |
| - separator("") |
| 63 | + inputs = ( |
| 64 | + torch.ones(2, 2, device="cuda"), |
| 65 | + torch.zeros(2, 2, device="cuda") |
| 66 | + ) |
| 67 | + |
| 68 | + logging_scenarios = [ |
| 69 | + # View dynamo tracing; TORCH_LOGS="+dynamo" |
| 70 | + ("Dynamo Tracing", {"dynamo": logging.DEBUG}), |
| 71 | + |
| 72 | + # View traced graph; TORCH_LOGS="graph" |
| 73 | + ("Traced Graph", {"graph": True}), |
| 74 | + |
| 75 | + # View fusion decisions; TORCH_LOGS="fusion" |
| 76 | + ("Fusion Decisions", {"fusion": True}), |
| 77 | + |
| 78 | + # View output code generated by inductor; TORCH_LOGS="output_code" |
| 79 | + ("Output Code", {"output_code": True}) |
| 80 | + ] |
| 81 | + |
| 82 | + for name, log_config in logging_scenarios: |
| 83 | + separator(name) |
| 84 | + torch._logging.set_logs(**log_config) |
| 85 | + try: |
| 86 | + result = fn(*inputs) |
| 87 | + print(f"Function output shape: {result.shape}") |
| 88 | + except Exception as e: |
| 89 | + print(f"Error during {name}: {str(e)}") |
| 90 | + |
| 91 | +if __name__ == "__main__": |
| 92 | + run_debugging_suite() |
80 | 93 |
|
81 | 94 | ######################################################################
|
82 | 95 | # Conclusion
|
|
0 commit comments