Skip to content

Commit 9131f4e

Browse files
fywkevinYuanwei Fang
authored andcommitted
[Proton] Fixed pc sampling error (#5787)
Fixed PC sampling error in proton when we have `mode`. --------- Co-authored-by: Yuanwei Fang <[email protected]>
1 parent 48aed55 commit 9131f4e

File tree

5 files changed

+28
-16
lines changed

5 files changed

+28
-16
lines changed

third_party/proton/csrc/Proton.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ void initProton(pybind11::module &&m) {
1616
m.def("start",
1717
[](const std::string &path, const std::string &contextSourceName,
1818
const std::string &dataName, const std::string &profilerName,
19-
const std::string &profilerPath) {
19+
const std::string &profilerPath, const std::string &mode) {
2020
auto sessionId = SessionManager::instance().addSession(
21-
path, profilerName, profilerPath, contextSourceName, dataName);
21+
path, profilerName, profilerPath, contextSourceName, dataName,
22+
mode);
2223
SessionManager::instance().activateSession(sessionId);
2324
return sessionId;
2425
});

third_party/proton/csrc/include/Session/Session.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class SessionManager : public Singleton<SessionManager> {
7474
size_t addSession(const std::string &path, const std::string &profilerName,
7575
const std::string &profilerPath,
7676
const std::string &contextSourceName,
77-
const std::string &dataName);
77+
const std::string &dataName, const std::string &mode);
7878

7979
void finalizeSession(size_t sessionId, OutputFormat outputFormat);
8080

@@ -106,7 +106,8 @@ class SessionManager : public Singleton<SessionManager> {
106106
const std::string &profilerName,
107107
const std::string &profilerPath,
108108
const std::string &contextSourceName,
109-
const std::string &dataName);
109+
const std::string &dataName,
110+
const std::string &mode);
110111

111112
void activateSessionImpl(size_t sessionId);
112113

third_party/proton/csrc/lib/Session/Session.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
namespace proton {
1010

1111
namespace {
12-
Profiler *getProfiler(const std::string &name, const std::string &path) {
12+
Profiler *getProfiler(const std::string &name, const std::string &path,
13+
const std::string &mode) {
1314
if (proton::toLower(name) == "cupti") {
14-
return &CuptiProfiler::instance().setLibPath(path);
15-
}
16-
if (proton::toLower(name) == "cupti_pcsampling") {
17-
return &CuptiProfiler::instance().setLibPath(path).enablePCSampling();
15+
auto *profiler = &CuptiProfiler::instance();
16+
profiler->setLibPath(path);
17+
if (proton::toLower(mode) == "pcsampling")
18+
profiler->enablePCSampling();
19+
return profiler;
1820
}
1921
if (proton::toLower(name) == "roctracer") {
2022
return &RoctracerProfiler::instance();
@@ -72,8 +74,8 @@ void Session::finalize(OutputFormat outputFormat) {
7274
std::unique_ptr<Session> SessionManager::makeSession(
7375
size_t id, const std::string &path, const std::string &profilerName,
7476
const std::string &profilerPath, const std::string &contextSourceName,
75-
const std::string &dataName) {
76-
auto profiler = getProfiler(profilerName, profilerPath);
77+
const std::string &dataName, const std::string &mode) {
78+
auto profiler = getProfiler(profilerName, profilerPath, mode);
7779
auto contextSource = makeContextSource(contextSourceName);
7880
auto data = makeData(dataName, path, contextSource.get());
7981
auto *session = new Session(id, path, profiler, std::move(contextSource),
@@ -142,7 +144,8 @@ size_t SessionManager::addSession(const std::string &path,
142144
const std::string &profilerName,
143145
const std::string &profilerPath,
144146
const std::string &contextSourceName,
145-
const std::string &dataName) {
147+
const std::string &dataName,
148+
const std::string &mode) {
146149
std::lock_guard<std::mutex> lock(mutex);
147150
if (hasSession(path)) {
148151
auto sessionId = getSessionId(path);
@@ -152,7 +155,7 @@ size_t SessionManager::addSession(const std::string &path,
152155
auto sessionId = nextSessionId++;
153156
sessionPaths[path] = sessionId;
154157
sessions[sessionId] = makeSession(sessionId, path, profilerName, profilerPath,
155-
contextSourceName, dataName);
158+
contextSourceName, dataName, mode);
156159
return sessionId;
157160
}
158161

third_party/proton/proton/profile.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,17 @@ def start(
113113

114114
backend_path = _get_backend_default_path(backend)
115115

116+
_check_mode(backend, mode)
117+
118+
if mode is None:
119+
mode = ""
120+
121+
backend_path = _get_backend_default_path(backend)
122+
116123
set_profiling_on()
117124
if hook and hook == "triton":
118125
register_triton_hook()
119-
return libproton.start(name, context, data, backend, backend_path)
126+
return libproton.start(name, context, data, backend, backend_path, mode)
120127

121128

122129
def activate(session: Optional[int] = None) -> None:

third_party/proton/test/test_lib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_op():
3232

3333
def test_session(tmp_path: pathlib.Path):
3434
temp_file = tmp_path / "test_session.hatchet"
35-
session_id = libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend(), "")
35+
session_id = libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend(), "", "")
3636
libproton.deactivate(session_id)
3737
libproton.activate(session_id)
3838
libproton.finalize(session_id, "hatchet")
@@ -42,7 +42,7 @@ def test_session(tmp_path: pathlib.Path):
4242

4343
def test_add_metrics(tmp_path: pathlib.Path):
4444
temp_file = tmp_path / "test_add_metrics.hatchet"
45-
libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend(), "")
45+
libproton.start(str(temp_file.with_suffix("")), "shadow", "tree", _select_backend(), "", "")
4646
id1 = libproton.record_scope()
4747
libproton.enter_scope(id1, "one")
4848
libproton.add_metrics(id1, {"a": 1.0, "b": 2.0})

0 commit comments

Comments
 (0)