Skip to content

Commit 81c251b

Browse files
committed
[PROTON-DEV] Add the instrumentation mode and clean up dependencies (#5742)
https://github.com/orgs/triton-lang/projects/6?pane=issue&itemId=95371680
1 parent ce156bf commit 81c251b

File tree

6 files changed

+44
-17
lines changed

6 files changed

+44
-17
lines changed

third_party/proton/proton/hook.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from .state import enter_state, exit_state
22
from .scope import enter_scope, exit_scope
3-
from triton.compiler import CompiledKernel, LazyDict
43

54
COMPUTE_METADATA_SCOPE_NAME = "__proton_launch_metadata"
65

76

87
class TritonHook:
8+
from triton.compiler import LazyDict
9+
910
flops_width = [8, 16, 32, 64]
10-
metrics = [f"flops{width}" for width in flops_width] + ["bytes"] + ["flops"]
11+
metrics = [f"flops{width}" for width in flops_width] + \
12+
["bytes"] + ["flops"]
1113

1214
@staticmethod
1315
def enter(lazy_dict: LazyDict) -> None:
@@ -23,12 +25,14 @@ def exit(lazy_dict: LazyDict) -> None:
2325

2426

2527
def register_triton_hook() -> None:
28+
from triton.compiler import CompiledKernel
2629
if CompiledKernel.launch_enter_hook is None:
2730
CompiledKernel.launch_enter_hook = TritonHook.enter
2831
CompiledKernel.launch_exit_hook = TritonHook.exit
2932

3033

3134
def unregister_triton_hook() -> None:
35+
from triton.compiler import CompiledKernel
3236
if CompiledKernel.launch_enter_hook == TritonHook.enter:
3337
CompiledKernel.launch_enter_hook = None
3438
CompiledKernel.launch_exit_hook = None

third_party/proton/proton/profile.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,25 @@ def _check_env(backend: str) -> None:
4343
)
4444

4545

46+
def _check_mode(backend: str, mode: Optional[str]) -> None:
47+
# TODO(Keren): Need a better mode registration mechanism
48+
backend_modes = {
49+
"cupti": [None, "pcsampling"],
50+
"roctracer": [None],
51+
"instrumentation": [None],
52+
}
53+
54+
if mode not in backend_modes[backend]:
55+
raise ValueError(f"Invalid mode {mode} for backend {backend}")
56+
57+
4658
def start(
4759
name: Optional[str] = None,
4860
*,
4961
context: Optional[str] = "shadow",
5062
data: Optional[str] = "tree",
5163
backend: Optional[str] = None,
64+
mode: Optional[str] = None,
5265
hook: Optional[str] = None,
5366
):
5467
"""
@@ -65,15 +78,20 @@ def start(
6578
Args:
6679
name (str, optional): The name (with path) of the profiling session.
6780
If not provided, the default name is "~/proton.hatchet".
68-
backend (str, optional): The backend to use for profiling.
69-
Available options are [None, "cupti", "cupti_pcsampling", "roctracer"].
70-
Defaults to None, which automatically selects the backend matching the current active runtime.
7181
context (str, optional): The context to use for profiling.
7282
Available options are ["shadow", "python"].
7383
Defaults to "shadow".
7484
data (str, optional): The data structure to use for profiling.
7585
Available options are ["tree"].
7686
Defaults to "tree".
87+
backend (str, optional): The backend to use for profiling.
88+
Available options are [None, "cupti", "roctracer", "instrumentation"].
89+
Defaults to None, which automatically selects the backend matching the current active runtime.
90+
mode (str, optional): The "mode" to use for profiling, which is specific to the backend.
91+
Defaults to None.
92+
For "cupti", available options are [None, "pcsampling"].
93+
For "roctracer", available options are [None].
94+
For "instrumentation", available options are [None].
7795
hook (str, optional): The hook to use for profiling.
7896
Available options are [None, "triton"].
7997
Defaults to None.
@@ -91,6 +109,7 @@ def start(
91109
backend = _select_backend()
92110

93111
_check_env(backend)
112+
_check_mode(backend, mode)
94113

95114
backend_path = _get_backend_default_path(backend)
96115

@@ -167,6 +186,7 @@ def _profiling(
167186
context: Optional[str] = "shadow",
168187
data: Optional[str] = "tree",
169188
backend: Optional[str] = None,
189+
mode: Optional[str] = None,
170190
hook: Optional[str] = None,
171191
):
172192
"""
@@ -181,7 +201,7 @@ def _profiling(
181201

182202
@functools.wraps(func)
183203
def wrapper(*args, **kwargs):
184-
session = start(name, context=context, data=data, backend=backend, hook=hook)
204+
session = start(name, context=context, data=data, backend=backend, mode=mode, hook=hook)
185205
ret = func(*args, **kwargs)
186206
deactivate(session)
187207
return ret
@@ -196,6 +216,7 @@ def profile(
196216
context: Optional[str] = "shadow",
197217
data: Optional[str] = "tree",
198218
backend: Optional[str] = None,
219+
mode: Optional[str] = None,
199220
hook: Optional[str] = None,
200221
):
201222
"""
@@ -218,9 +239,9 @@ def foo():
218239
if func is None:
219240
# It's being used with parentheses, so return a decorator
220241
def decorator(f):
221-
return _profiling(f, name=name, context=context, data=data, backend=backend, hook=hook)
242+
return _profiling(f, name=name, context=context, data=data, backend=backend, mode=mode, hook=hook)
222243

223244
return decorator
224245
else:
225246
# It's being used without parentheses, so apply the decorator directly
226-
return _profiling(func, name=name, context=context, data=data, backend=backend, hook=hook)
247+
return _profiling(func, name=name, context=context, data=data, backend=backend, mode=mode, hook=hook)

third_party/proton/proton/proton.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ def parse_arguments():
1616
""", formatter_class=argparse.RawTextHelpFormatter)
1717
parser.add_argument("-n", "--name", type=str, help="Name of the profiling session")
1818
parser.add_argument("-b", "--backend", type=str, help="Profiling backend", default=None,
19-
choices=["cupti", "cupti_pcsampling", "roctracer"])
19+
choices=["cupti", "roctracer"])
2020
parser.add_argument("-c", "--context", type=str, help="Profiling context", default="shadow",
2121
choices=["shadow", "python"])
22+
parser.add_argument("-m", "--mode", type=str, help="Profiling mode", default=None)
2223
parser.add_argument("-d", "--data", type=str, help="Profiling data", default="tree", choices=["tree"])
2324
parser.add_argument("-k", "--hook", type=str, help="Profiling hook", default=None, choices=[None, "triton"])
2425
parser.add_argument("-i", "--instrument", type=str, help="Instrumentation analysis type", default=None,

third_party/proton/test/test_profile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def foo(x, y, size: tl.constexpr):
256256
tl.store(y + offs, tl.load(x + offs))
257257

258258
temp_file = tmp_path / "test_pcsampling.hatchet"
259-
proton.start(str(temp_file.with_suffix("")), hook="triton", backend="cupti_pcsampling")
259+
proton.start(str(temp_file.with_suffix("")), hook="triton", backend="cupti", mode="pcsampling")
260260
with proton.scope("init"):
261261
x = torch.ones((1024, ), device="cuda", dtype=torch.float32)
262262
y = torch.zeros_like(x)

third_party/proton/tutorials/dynamic_net.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import triton.profiler as proton
66
import argparse
77

8-
mode = "torch"
8+
engine = "torch"
99

1010

1111
class DynamicNet(torch.nn.Module):
@@ -53,7 +53,7 @@ def run():
5353

5454
# Construct our model by instantiating the class defined above
5555
model = DynamicNet().to("cuda")
56-
if mode == "torchinductor":
56+
if engine == "torchinductor":
5757
model = torch.compile(model)
5858

5959
# Construct our loss function and an Optimizer. Training this strange model with
@@ -83,16 +83,17 @@ def run():
8383

8484
argparser = argparse.ArgumentParser()
8585
argparser.add_argument("--profile", action="store_true")
86-
argparser.add_argument("--mode", default="torch", choices=["torch", "torchinductor"])
86+
argparser.add_argument("--engine", default="torch", choices=["torch", "torchinductor"])
8787
argparser.add_argument("--context", default="shadow", choices=["shadow", "python"])
88-
argparser.add_argument("--backend", default=None, choices=["cupti", "roctracer", "cupti_pcsampling"])
88+
argparser.add_argument("--backend", default=None, choices=["cupti", "roctracer"])
89+
argparser.add_argument("--mode", default=None)
8990

9091
args = argparser.parse_args()
9192

92-
mode = args.mode
93+
engine = args.engine
9394

9495
if args.profile:
95-
func = proton.profile(run, name="dynamic_net", context=args.context, backend=args.backend)
96+
func = proton.profile(run, name="dynamic_net", context=args.context, backend=args.backend, mode=args.mode)
9697
else:
9798
func = run
9899

third_party/proton/tutorials/matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def perf(ms):
308308
if args.profile:
309309
if args.pcsampling:
310310
# proton-viewer -m num_samples/%,time/s ./matmul.hatchet
311-
proton.start("matmul", hook="triton", backend="cupti_pcsampling")
311+
proton.start("matmul", hook="triton", backend="cupti", mode="pcsampling")
312312
else:
313313
# proton-viewer -m tflop/s,time/s ./matmul.hatchet
314314
proton.start("matmul", hook="triton")

0 commit comments

Comments
 (0)