|
32 | 32 |
|
33 | 33 | import torch
|
34 | 34 |
|
35 |
| -# exit cleanly if we are on a device that doesn't support torch.compile |
36 |
| -if torch.cuda.get_device_capability() < (7, 0): |
37 |
| - print("Skipping because torch.compile is not supported on this device.") |
38 |
| -else: |
39 |
| - @torch.compile() |
40 |
| - def fn(x, y): |
41 |
| - z = x + y |
42 |
| - return z + 2 |
| 35 | +@torch.compile() |
| 36 | +def fn(x, y): |
| 37 | + z = x + y |
| 38 | + return z + 2 |
43 | 39 |
|
44 | 40 |
|
45 |
| - inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda")) |
| 41 | +inputs = (torch.ones(2, 2), torch.zeros(2, 2)) |
46 | 42 |
|
47 | 43 |
|
48 | 44 | # print separator and reset dynamo
|
49 | 45 | # between each example
|
50 |
| - def separator(name): |
51 |
| - print(f"==================={name}=========================") |
52 |
| - torch._dynamo.reset() |
| 46 | +def separator(name): |
| 47 | + print(f"==================={name}=========================") |
| 48 | + torch._dynamo.reset() |
53 | 49 |
|
54 | 50 |
|
55 |
| - separator("Dynamo Tracing") |
| 51 | +separator("Dynamo Tracing") |
56 | 52 | # View dynamo tracing
|
57 | 53 | # TORCH_LOGS="+dynamo"
|
58 |
| - torch._logging.set_logs(dynamo=logging.DEBUG) |
59 |
| - fn(*inputs) |
| 54 | +torch._logging.set_logs(dynamo=logging.DEBUG) |
| 55 | +fn(*inputs) |
60 | 56 |
|
61 |
| - separator("Traced Graph") |
| 57 | +separator("Traced Graph") |
62 | 58 | # View traced graph
|
63 | 59 | # TORCH_LOGS="graph"
|
64 |
| - torch._logging.set_logs(graph=True) |
65 |
| - fn(*inputs) |
| 60 | +torch._logging.set_logs(graph=True) |
| 61 | +fn(*inputs) |
66 | 62 |
|
67 |
| - separator("Fusion Decisions") |
| 63 | +separator("Fusion Decisions") |
68 | 64 | # View fusion decisions
|
69 | 65 | # TORCH_LOGS="fusion"
|
70 |
| - torch._logging.set_logs(fusion=True) |
71 |
| - fn(*inputs) |
| 66 | +torch._logging.set_logs(fusion=True) |
| 67 | +fn(*inputs) |
72 | 68 |
|
73 |
| - separator("Output Code") |
| 69 | +separator("Output Code") |
74 | 70 | # View output code generated by inductor
|
75 | 71 | # TORCH_LOGS="output_code"
|
76 |
| - torch._logging.set_logs(output_code=True) |
77 |
| - fn(*inputs) |
| 72 | +torch._logging.set_logs(output_code=True) |
| 73 | +fn(*inputs) |
78 | 74 |
|
79 |
| - separator("") |
| 75 | +separator("") |
80 | 76 |
|
81 | 77 | ######################################################################
|
82 | 78 | # Conclusion
|
|
0 commit comments