diff --git a/recipes_source/torch_logs.py b/recipes_source/torch_logs.py index b5c3f0bd8ac..407921921a3 100644 --- a/recipes_source/torch_logs.py +++ b/recipes_source/torch_logs.py @@ -32,51 +32,47 @@ import torch -# exit cleanly if we are on a device that doesn't support torch.compile -if torch.cuda.get_device_capability() < (7, 0): - print("Skipping because torch.compile is not supported on this device.") -else: - @torch.compile() - def fn(x, y): - z = x + y - return z + 2 +@torch.compile() +def fn(x, y): + z = x + y + return z + 2 - inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda")) +inputs = (torch.ones(2, 2), torch.zeros(2, 2)) # print separator and reset dynamo # between each example - def separator(name): - print(f"==================={name}=========================") - torch._dynamo.reset() +def separator(name): + print(f"==================={name}=========================") + torch._dynamo.reset() - separator("Dynamo Tracing") +separator("Dynamo Tracing") # View dynamo tracing # TORCH_LOGS="+dynamo" - torch._logging.set_logs(dynamo=logging.DEBUG) - fn(*inputs) +torch._logging.set_logs(dynamo=logging.DEBUG) +fn(*inputs) - separator("Traced Graph") +separator("Traced Graph") # View traced graph # TORCH_LOGS="graph" - torch._logging.set_logs(graph=True) - fn(*inputs) +torch._logging.set_logs(graph=True) +fn(*inputs) - separator("Fusion Decisions") +separator("Fusion Decisions") # View fusion decisions # TORCH_LOGS="fusion" - torch._logging.set_logs(fusion=True) - fn(*inputs) +torch._logging.set_logs(fusion=True) +fn(*inputs) - separator("Output Code") +separator("Output Code") # View output code generated by inductor # TORCH_LOGS="output_code" - torch._logging.set_logs(output_code=True) - fn(*inputs) +torch._logging.set_logs(output_code=True) +fn(*inputs) - separator("") +separator("") ###################################################################### # Conclusion