Skip to content

Commit 58225df

Browse files
committed
[PROTON-DEV] Improve profile interface (#5793)
1 parent 9131f4e commit 58225df

File tree

3 files changed

+26
-29
lines changed

3 files changed

+26
-29
lines changed

third_party/proton/csrc/Proton.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,20 @@ void initProton(pybind11::module &&m) {
1313
using ret = pybind11::return_value_policy;
1414
using namespace pybind11::literals;
1515

16-
m.def("start",
17-
[](const std::string &path, const std::string &contextSourceName,
18-
const std::string &dataName, const std::string &profilerName,
19-
const std::string &profilerPath, const std::string &mode) {
20-
auto sessionId = SessionManager::instance().addSession(
21-
path, profilerName, profilerPath, contextSourceName, dataName,
22-
mode);
23-
SessionManager::instance().activateSession(sessionId);
24-
return sessionId;
25-
});
16+
m.def(
17+
"start",
18+
[](const std::string &path, const std::string &contextSourceName,
19+
const std::string &dataName, const std::string &profilerName,
20+
const std::string &mode, const std::string &profilerPath) {
21+
auto sessionId = SessionManager::instance().addSession(
22+
path, profilerName, profilerPath, contextSourceName, dataName,
23+
mode);
24+
SessionManager::instance().activateSession(sessionId);
25+
return sessionId;
26+
},
27+
pybind11::arg("path"), pybind11::arg("contextSourceName"),
28+
pybind11::arg("dataName"), pybind11::arg("profilerName"),
29+
pybind11::arg("mode") = "", pybind11::arg("profilerPath") = "");
2630

2731
m.def("activate", [](size_t sessionId) {
2832
SessionManager::instance().activateSession(sessionId);

third_party/proton/proton/profile.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ def _check_env(backend: str) -> None:
4646
def _check_mode(backend: str, mode: Optional[str]) -> None:
4747
# TODO(Keren): Need a better mode registration mechanism
4848
backend_modes = {
49-
"cupti": [None, "pcsampling"],
50-
"roctracer": [None],
51-
"instrumentation": [None],
49+
"cupti": ["", "pcsampling"],
50+
"roctracer": [""],
51+
"instrumentation": [""],
5252
}
5353

5454
if mode not in backend_modes[backend]:
@@ -102,28 +102,21 @@ def start(
102102
# Ignore the start() call if the script is run from the command line.
103103
return
104104

105-
if name is None:
106-
name = DEFAULT_PROFILE_NAME
107-
108-
if backend is None:
109-
backend = _select_backend()
105+
name = DEFAULT_PROFILE_NAME if name is None else name
106+
backend = _select_backend() if backend is None else backend
107+
mode = "" if mode is None else mode
110108

111109
_check_env(backend)
112110
_check_mode(backend, mode)
113111

114112
backend_path = _get_backend_default_path(backend)
115113

116-
_check_mode(backend, mode)
117-
118-
if mode is None:
119-
mode = ""
120-
121-
backend_path = _get_backend_default_path(backend)
122-
123114
set_profiling_on()
124-
if hook and hook == "triton":
115+
116+
if hook == "triton":
125117
register_triton_hook()
126-
return libproton.start(name, context, data, backend, backend_path, mode)
118+
119+
return libproton.start(name, context, data, backend, mode, backend_path)
127120

128121

129122
def activate(session: Optional[int] = None) -> None:

third_party/proton/test/test_lib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_op():
3232

3333
def test_session(tmp_path: pathlib.Path):
3434
temp_file = tmp_path / "test_session.hatchet"
35-
session_id = libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend(), "", "")
35+
session_id = libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend())
3636
libproton.deactivate(session_id)
3737
libproton.activate(session_id)
3838
libproton.finalize(session_id, "hatchet")
@@ -42,7 +42,7 @@ def test_session(tmp_path: pathlib.Path):
4242

4343
def test_add_metrics(tmp_path: pathlib.Path):
4444
temp_file = tmp_path / "test_add_metrics.hatchet"
45-
libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend(), "", "")
45+
libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend())
4646
id1 = libproton.record_scope()
4747
libproton.enter_scope(id1, "one")
4848
libproton.add_metrics(id1, {"a": 1.0, "b": 2.0})

0 commit comments

Comments
 (0)