Skip to content

Commit be81f0a

Browse files
authored
[PROTON] Fix proton's support for multiple profiling sessions (#5140)
Also change the default behavior or `activate` and `deactivate` to apply for all sessions but not only session 0
1 parent 32b0fce commit be81f0a

File tree

7 files changed

+90
-15
lines changed

7 files changed

+90
-15
lines changed

third_party/proton/csrc/Proton.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,16 @@ void initProton(pybind11::module &&m) {
2626
SessionManager::instance().activateSession(sessionId);
2727
});
2828

29+
m.def("activate_all",
30+
[]() { SessionManager::instance().activateAllSessions(); });
31+
2932
m.def("deactivate", [](size_t sessionId) {
3033
SessionManager::instance().deactivateSession(sessionId);
3134
});
3235

36+
m.def("deactivate_all",
37+
[]() { SessionManager::instance().deactivateAllSessions(); });
38+
3339
m.def("finalize", [](size_t sessionId, const std::string &outputFormat) {
3440
auto outputFormatEnum = parseOutputFormat(outputFormat);
3541
SessionManager::instance().finalizeSession(sessionId, outputFormatEnum);

third_party/proton/csrc/include/Profiler/Profiler.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@ class Profiler {
2727
/// If the profiler is already started, this function does nothing.
2828
Profiler *start() {
2929
std::unique_lock<std::shared_mutex> lock(mutex);
30-
if (this->isInitialized)
31-
return this;
32-
this->doStart();
33-
this->isInitialized = true;
30+
if (this->initializedCount == 0)
31+
this->doStart();
32+
this->initializedCount++;
3433
return this;
3534
}
3635

@@ -45,10 +44,11 @@ class Profiler {
4544
/// Stop the profiler.
4645
Profiler *stop() {
4746
std::unique_lock<std::shared_mutex> lock(mutex);
48-
if (!this->isInitialized)
47+
if (this->initializedCount == 0)
4948
return this;
50-
this->doStop();
51-
this->isInitialized = false;
49+
this->initializedCount--;
50+
if (this->initializedCount == 0)
51+
this->doStop();
5252
return this;
5353
}
5454

@@ -80,7 +80,9 @@ class Profiler {
8080

8181
mutable std::shared_mutex mutex;
8282
std::set<Data *> dataSet;
83-
bool isInitialized{false};
83+
84+
private:
85+
int initializedCount{};
8486
};
8587

8688
} // namespace proton

third_party/proton/csrc/include/Session/Session.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,12 @@ class SessionManager : public Singleton<SessionManager> {
7777

7878
void activateSession(size_t sessionId);
7979

80+
void activateAllSessions();
81+
8082
void deactivateSession(size_t sessionId);
8183

84+
void deactivateAllSessions();
85+
8286
void enterScope(const Scope &scope);
8387

8488
void exitScope(const Scope &scope);

third_party/proton/csrc/lib/Session/Session.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,25 @@ void SessionManager::activateSession(size_t sessionId) {
8484
activateSessionImpl(sessionId);
8585
}
8686

87+
void SessionManager::activateAllSessions() {
88+
std::unique_lock<std::shared_mutex> lock(mutex);
89+
for (auto iter : sessionActive) {
90+
activateSessionImpl(iter.first);
91+
}
92+
}
93+
8794
void SessionManager::deactivateSession(size_t sessionId) {
8895
std::unique_lock<std::shared_mutex> lock(mutex);
8996
deActivateSessionImpl(sessionId);
9097
}
9198

99+
void SessionManager::deactivateAllSessions() {
100+
std::unique_lock<std::shared_mutex> lock(mutex);
101+
for (auto iter : sessionActive) {
102+
deActivateSessionImpl(iter.first);
103+
}
104+
}
105+
92106
void SessionManager::activateSessionImpl(size_t sessionId) {
93107
throwIfSessionNotInitialized(sessions, sessionId);
94108
if (sessionActive[sessionId])
@@ -116,6 +130,7 @@ void SessionManager::removeSession(size_t sessionId) {
116130
}
117131
auto path = sessions[sessionId]->path;
118132
sessionPaths.erase(path);
133+
sessionActive.erase(sessionId);
119134
sessions.erase(sessionId);
120135
}
121136

third_party/proton/proton/profile.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,36 +85,42 @@ def start(
8585
return libproton.start(name, context, data, backend)
8686

8787

88-
def activate(session: Optional[int] = 0) -> None:
88+
def activate(session: Optional[int] = None) -> None:
8989
"""
9090
Activate the specified session.
9191
The profiling session will be active and data will be recorded.
9292
9393
Args:
94-
session (int): The session ID of the profiling session. Defaults to 0 (the first session started.)
94+
session (int): The session ID of the profiling session. Defaults to None (all sessions)
9595
9696
Returns:
9797
None
9898
"""
9999
if is_command_line() and session != 0:
100100
raise ValueError("Only one session can be activated when running from the command line.")
101-
libproton.activate(session)
101+
if session is None:
102+
libproton.activate_all()
103+
else:
104+
libproton.activate(session)
102105

103106

104-
def deactivate(session: Optional[int] = 0) -> None:
107+
def deactivate(session: Optional[int] = None) -> None:
105108
"""
106109
Stop the specified session.
107110
The profiling session's data will still be in the memory, but no more data will be recorded.
108111
109112
Args:
110-
session (int): The session ID of the profiling session. Defaults to 0 (the first session started.)
113+
session (int): The session ID of the profiling session. Defaults to None (all sessions)
111114
112115
Returns:
113116
None
114117
"""
115118
if is_command_line() and session != 0:
116119
raise ValueError("Only one session can be deactivated when running from the command line.")
117-
libproton.deactivate(session)
120+
if session is None:
121+
libproton.deactivate_all()
122+
else:
123+
libproton.deactivate(session)
118124

119125

120126
def finalize(session: Optional[int] = None, output_format: str = "hatchet") -> None:

third_party/proton/test/test_api.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pathlib
44

55

6-
def test_profile(tmp_path: pathlib.Path):
6+
def test_profile_single_session(tmp_path: pathlib.Path):
77
temp_file0 = tmp_path / "test_profile0.hatchet"
88
session_id0 = proton.start(str(temp_file0.with_suffix("")))
99
proton.activate()
@@ -29,6 +29,28 @@ def test_profile(tmp_path: pathlib.Path):
2929
pathlib.Path("test.hatchet").unlink()
3030

3131

32+
def test_profile_multiple_sessions(tmp_path: pathlib.Path):
33+
temp_file0 = tmp_path / "test_profile0.hatchet"
34+
proton.start(str(temp_file0.with_suffix("")))
35+
temp_file1 = tmp_path / "test_profile1.hatchet"
36+
proton.start(str(temp_file1.with_suffix("")))
37+
proton.activate()
38+
proton.deactivate()
39+
proton.finalize()
40+
assert temp_file0.exists()
41+
assert temp_file1.exists()
42+
43+
temp_file2 = tmp_path / "test_profile2.hatchet"
44+
session_id2 = proton.start(str(temp_file2.with_suffix("")))
45+
temp_file3 = tmp_path / "test_profile3.hatchet"
46+
session_id3 = proton.start(str(temp_file3.with_suffix("")))
47+
proton.deactivate(session_id2)
48+
proton.deactivate(session_id3)
49+
proton.finalize()
50+
assert temp_file2.exists()
51+
assert temp_file3.exists()
52+
53+
3254
def test_profile_decorator(tmp_path: pathlib.Path):
3355
temp_file = tmp_path / "test_profile_decorator.hatchet"
3456

third_party/proton/test/test_profile.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,23 @@ def test_deactivate(tmp_path: pathlib.Path):
257257
assert "device_id" not in data[0]["metrics"]
258258
assert len(data[0]["children"]) == 1
259259
assert "device_id" in data[0]["children"][0]["metrics"]
260+
261+
262+
def test_multiple_sessions(tmp_path: pathlib.Path):
263+
temp_file0 = tmp_path / "test_multiple_sessions0.hatchet"
264+
temp_file1 = tmp_path / "test_multiple_sessions1.hatchet"
265+
session_id0 = proton.start(str(temp_file0.with_suffix("")))
266+
session_id1 = proton.start(str(temp_file1.with_suffix("")))
267+
torch.randn((10, 10), device="cuda")
268+
torch.randn((10, 10), device="cuda")
269+
proton.deactivate(session_id0)
270+
proton.finalize(session_id0)
271+
torch.randn((10, 10), device="cuda")
272+
proton.finalize(session_id1)
273+
# kernel has been invokved twice in session 0 and three times in session 1
274+
with temp_file0.open() as f:
275+
data = json.load(f)
276+
assert int(data[0]["children"][0]["metrics"]["count"]) == 2
277+
with temp_file1.open() as f:
278+
data = json.load(f)
279+
assert int(data[0]["children"][0]["metrics"]["count"]) == 3

0 commit comments

Comments
 (0)