Skip to content

Commit 6d3f3da

Browse files
authored
[PROTON] Implement the get_data api to export profile directly in Python (#8928)
1 parent 3d33f74 commit 6d3f3da

File tree

12 files changed

+236
-140
lines changed

12 files changed

+236
-140
lines changed

third_party/proton/csrc/Proton.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,13 @@ static void initProton(pybind11::module &&m) {
165165
m.def("get_context_depth", [](size_t sessionId) {
166166
return SessionManager::instance().getContextDepth(sessionId);
167167
});
168+
169+
m.def(
170+
"get_data",
171+
[](size_t sessionId) {
172+
return SessionManager::instance().getData(sessionId);
173+
},
174+
pybind11::arg("sessionId"));
168175
}
169176

170177
PYBIND11_MODULE(libproton, m) {

third_party/proton/csrc/include/Data/Data.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ class Data : public ScopeInterface {
4343
/// Clear all caching data.
4444
virtual void clear() = 0;
4545

46+
/// Serialize the collected data to a JSON-formatted string.
47+
virtual std::string toJsonString() const = 0;
48+
4649
/// Dump the data to the given output format.
4750
void dump(const std::string &outputFormat);
4851

third_party/proton/csrc/include/Data/TraceData.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ class TraceData : public Data {
2222
addMetrics(size_t scopeId,
2323
const std::map<std::string, MetricValueType> &metrics) override;
2424

25+
std::string toJsonString() const override;
26+
2527
void clear() override;
2628

2729
class Trace;

third_party/proton/csrc/include/Data/TreeData.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@
33

44
#include "Context/Context.h"
55
#include "Data.h"
6+
#include "nlohmann/json.hpp"
67
#include <stdexcept>
8+
#include <string>
79
#include <unordered_map>
810

11+
using json = nlohmann::json;
12+
913
namespace proton {
1014

1115
class TreeData : public Data {
@@ -25,6 +29,8 @@ class TreeData : public Data {
2529
addMetrics(size_t scopeId,
2630
const std::map<std::string, MetricValueType> &metrics) override;
2731

32+
std::string toJsonString() const override;
33+
2834
void clear() override;
2935

3036
protected:
@@ -34,6 +40,12 @@ class TreeData : public Data {
3440
void exitScope(const Scope &scope) override;
3541

3642
private:
43+
// `tree` and `scopeIdToContextId` can be accessed by both the user thread and
44+
// the background threads concurrently, so methods that access them should be
45+
// protected by a (shared) mutex.
46+
class Tree;
47+
json buildHatchetJson(TreeData::Tree *tree) const;
48+
3749
void dumpHatchet(std::ostream &os) const;
3850

3951
void doDump(std::ostream &os, OutputFormat outputFormat) const override;
@@ -42,10 +54,6 @@ class TreeData : public Data {
4254
return OutputFormat::Hatchet;
4355
}
4456

45-
// `tree` and `scopeIdToContextId` can be accessed by both the user thread and
46-
// the background threads concurrently, so methods that access them should be
47-
// protected by a (shared) mutex.
48-
class Tree;
4957
std::unique_ptr<Tree> tree;
5058
// ScopeId -> ContextId
5159
std::unordered_map<size_t, size_t> scopeIdToContextId;

third_party/proton/csrc/include/Session/Session.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ class SessionManager : public Singleton<SessionManager> {
9393

9494
size_t getContextDepth(size_t sessionId);
9595

96+
std::string getData(size_t sessionId);
97+
9698
void enterScope(const Scope &scope);
9799

98100
void exitScope(const Scope &scope);

third_party/proton/csrc/lib/Data/TraceData.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,8 @@ void TraceData::addMetrics(
229229
}
230230
}
231231

232+
/// JSON string export for trace data.
///
/// This override performs no serialization: it unconditionally throws
/// NotImplemented.
std::string TraceData::toJsonString() const {
  throw NotImplemented();
}
233+
232234
void TraceData::clear() {
233235
std::unique_lock<std::shared_mutex> lock(mutex);
234236
scopeIdToContextId.clear();

third_party/proton/csrc/lib/Data/TreeData.cpp

Lines changed: 135 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,13 @@
22
#include "Context/Context.h"
33
#include "Data/Metric.h"
44
#include "Device.h"
5-
#include "nlohmann/json.hpp"
65

76
#include <limits>
87
#include <map>
98
#include <mutex>
109
#include <set>
1110
#include <stdexcept>
1211

13-
using json = nlohmann::json;
14-
1512
namespace proton {
1613

1714
class TreeData::Tree {
@@ -106,6 +103,134 @@ class TreeData::Tree {
106103
std::map<size_t, TreeNode> treeNodeMap;
107104
};
108105

106+
/// Build the Hatchet-style JSON document for `tree`.
///
/// The returned value is a JSON array with two elements:
///   [0] the root node. Every node is an object with the keys "frame"
///       ({"name", "type"}), "metrics" and "children" (an array of nodes of the
///       same shape).
///   [1] an object mapping device type name -> device id (as a string) ->
///       device properties, for every (type, id) pair seen in a Kernel or Cycle
///       metric during the walk.
///
/// This function takes no lock itself; toJsonString() holds a shared lock on
/// `mutex` around its call.
///
/// @param tree Tree to serialize; it is dereferenced unconditionally.
/// @return The JSON array described above.
/// @throws std::runtime_error if a node carries an unsupported MetricKind.
json TreeData::buildHatchetJson(TreeData::Tree *tree) const {
  using TreeNode = TreeData::Tree::TreeNode;

  // Context id -> JSON object that the walk fills in for that context.
  std::map<size_t, json *> jsonNodes;
  json output = json::array();
  output.push_back(json::object());
  jsonNodes[TreeNode::RootId] = &(output.back());
  std::set<std::string> inclusiveValueNames;
  // Device type -> set of device ids referenced by the metrics.
  std::map<uint64_t, std::set<uint64_t>> deviceIds;

  tree->template walk<TreeData::Tree::WalkPolicy::PreOrder>(
      [&](TreeNode &treeNode) {
        // Bind by reference: the node is only read here, so copying its name,
        // metrics and children for every visited node is unnecessary.
        const auto &contextName = treeNode.name;
        json *jsonNode = jsonNodes[treeNode.id];
        (*jsonNode)["frame"] = {{"name", contextName}, {"type", "function"}};
        (*jsonNode)["metrics"] = json::object();
        json &metricsJson = (*jsonNode)["metrics"];

        for (auto &[metricKind, metric] : treeNode.metrics) {
          if (metricKind == MetricKind::Kernel) {
            std::shared_ptr<KernelMetric> kernelMetric =
                std::dynamic_pointer_cast<KernelMetric>(metric);
            uint64_t duration = std::get<uint64_t>(
                kernelMetric->getValue(KernelMetric::Duration));
            uint64_t invocations = std::get<uint64_t>(
                kernelMetric->getValue(KernelMetric::Invocations));
            uint64_t deviceId = std::get<uint64_t>(
                kernelMetric->getValue(KernelMetric::DeviceId));
            uint64_t deviceType = std::get<uint64_t>(
                kernelMetric->getValue(KernelMetric::DeviceType));
            std::string deviceTypeName =
                getDeviceTypeString(static_cast<DeviceType>(deviceType));
            metricsJson[kernelMetric->getValueName(KernelMetric::Duration)] =
                duration;
            metricsJson[kernelMetric->getValueName(
                KernelMetric::Invocations)] = invocations;
            metricsJson[kernelMetric->getValueName(KernelMetric::DeviceId)] =
                std::to_string(deviceId);
            metricsJson[kernelMetric->getValueName(KernelMetric::DeviceType)] =
                deviceTypeName;
            inclusiveValueNames.insert(
                kernelMetric->getValueName(KernelMetric::Duration));
            inclusiveValueNames.insert(
                kernelMetric->getValueName(KernelMetric::Invocations));
            deviceIds[deviceType].insert(deviceId);
          } else if (metricKind == MetricKind::PCSampling) {
            auto pcSamplingMetric =
                std::dynamic_pointer_cast<PCSamplingMetric>(metric);
            for (size_t i = 0; i < PCSamplingMetric::Count; i++) {
              const auto valueName = pcSamplingMetric->getValueName(i);
              inclusiveValueNames.insert(valueName);
              std::visit(
                  [&](auto &&value) { metricsJson[valueName] = value; },
                  pcSamplingMetric->getValues()[i]);
            }
          } else if (metricKind == MetricKind::Cycle) {
            auto cycleMetric = std::dynamic_pointer_cast<CycleMetric>(metric);
            uint64_t duration = std::get<uint64_t>(
                cycleMetric->getValue(CycleMetric::Duration));
            double normalizedDuration = std::get<double>(
                cycleMetric->getValue(CycleMetric::NormalizedDuration));
            uint64_t deviceId = std::get<uint64_t>(
                cycleMetric->getValue(CycleMetric::DeviceId));
            uint64_t deviceType = std::get<uint64_t>(
                cycleMetric->getValue(CycleMetric::DeviceType));
            metricsJson[cycleMetric->getValueName(CycleMetric::Duration)] =
                duration;
            metricsJson[cycleMetric->getValueName(
                CycleMetric::NormalizedDuration)] = normalizedDuration;
            metricsJson[cycleMetric->getValueName(CycleMetric::DeviceId)] =
                std::to_string(deviceId);
            // NOTE(review): Kernel metrics emit the device type *name* under
            // this key while Cycle metrics emit the numeric value as a string
            // -- confirm this difference is intended.
            metricsJson[cycleMetric->getValueName(CycleMetric::DeviceType)] =
                std::to_string(deviceType);
            deviceIds[deviceType].insert(deviceId);
          } else if (metricKind == MetricKind::Flexible) {
            // Flexible metrics are handled in a different way
          } else {
            throw std::runtime_error("MetricKind not supported");
          }
        }

        for (auto &[_, flexibleMetric] : treeNode.flexibleMetrics) {
          const auto valueName = flexibleMetric.getValueName(0);
          if (!flexibleMetric.isExclusive(0))
            inclusiveValueNames.insert(valueName);
          std::visit([&](auto &&value) { metricsJson[valueName] = value; },
                     flexibleMetric.getValues()[0]);
        }

        (*jsonNode)["children"] = json::array();
        json &childrenJson = (*jsonNode)["children"];
        // Create every child slot first and only then record their addresses:
        // growing the array while holding element pointers could invalidate
        // them.
        for ([[maybe_unused]] const auto &child : treeNode.children) {
          childrenJson.push_back(json::object());
        }
        size_t idx = 0;
        for (const auto &[childKey, childId] : treeNode.children) {
          jsonNodes[childId] = &childrenJson[idx];
          idx++;
        }
      });

  // Hints for all inclusive metrics
  for (const auto &valueName : inclusiveValueNames) {
    output[TreeNode::RootId]["metrics"][valueName] = 0;
  }

  // Prepare the device information
  // Note that this is done from the application thread,
  // query device information from the tool thread (e.g., CUPTI) will have
  // problems
  output.push_back(json::object());
  auto &deviceJson = output.back();
  for (const auto &[deviceType, deviceIdSet] : deviceIds) {
    auto deviceTypeName =
        getDeviceTypeString(static_cast<DeviceType>(deviceType));
    if (!deviceJson.contains(deviceTypeName))
      deviceJson[deviceTypeName] = json::object();
    for (auto deviceId : deviceIdSet) {
      Device device = getDevice(static_cast<DeviceType>(deviceType), deviceId);
      deviceJson[deviceTypeName][std::to_string(deviceId)] = {
          {"clock_rate", device.clockRate},
          {"memory_clock_rate", device.memoryClockRate},
          {"bus_width", device.busWidth},
          {"arch", device.arch},
          {"num_sms", device.numSms}};
    }
  }
  return output;
}
233+
109234
void TreeData::enterScope(const Scope &scope) {
110235
// enterOp and addMetric maybe called from different threads
111236
std::unique_lock<std::shared_mutex> lock(mutex);
@@ -201,136 +326,16 @@ void TreeData::clear() {
201326
}
202327

203328
void TreeData::dumpHatchet(std::ostream &os) const {
204-
std::map<size_t, json *> jsonNodes;
205-
json output = json::array();
206-
output.push_back(json::object());
207-
jsonNodes[Tree::TreeNode::RootId] = &(output.back());
208-
std::set<std::string> inclusiveValueNames;
209-
std::map<uint64_t, std::set<uint64_t>> deviceIds;
210-
this->tree->template walk<Tree::WalkPolicy::PreOrder>([&](Tree::TreeNode
211-
&treeNode) {
212-
const auto contextName = treeNode.name;
213-
auto contextId = treeNode.id;
214-
json *jsonNode = jsonNodes[contextId];
215-
(*jsonNode)["frame"] = {{"name", contextName}, {"type", "function"}};
216-
(*jsonNode)["metrics"] = json::object();
217-
for (auto [metricKind, metric] : treeNode.metrics) {
218-
if (metricKind == MetricKind::Kernel) {
219-
std::shared_ptr<KernelMetric> kernelMetric =
220-
std::dynamic_pointer_cast<KernelMetric>(metric);
221-
uint64_t duration =
222-
std::get<uint64_t>(kernelMetric->getValue(KernelMetric::Duration));
223-
uint64_t invocations = std::get<uint64_t>(
224-
kernelMetric->getValue(KernelMetric::Invocations));
225-
uint64_t deviceId =
226-
std::get<uint64_t>(kernelMetric->getValue(KernelMetric::DeviceId));
227-
uint64_t deviceType = std::get<uint64_t>(
228-
kernelMetric->getValue(KernelMetric::DeviceType));
229-
std::string deviceTypeName =
230-
getDeviceTypeString(static_cast<DeviceType>(deviceType));
231-
(*jsonNode)["metrics"]
232-
[kernelMetric->getValueName(KernelMetric::Duration)] =
233-
duration;
234-
(*jsonNode)["metrics"]
235-
[kernelMetric->getValueName(KernelMetric::Invocations)] =
236-
invocations;
237-
(*jsonNode)["metrics"]
238-
[kernelMetric->getValueName(KernelMetric::DeviceId)] =
239-
std::to_string(deviceId);
240-
(*jsonNode)["metrics"]
241-
[kernelMetric->getValueName(KernelMetric::DeviceType)] =
242-
deviceTypeName;
243-
inclusiveValueNames.insert(
244-
kernelMetric->getValueName(KernelMetric::Duration));
245-
inclusiveValueNames.insert(
246-
kernelMetric->getValueName(KernelMetric::Invocations));
247-
deviceIds[deviceType].insert(deviceId);
248-
} else if (metricKind == MetricKind::PCSampling) {
249-
auto pcSamplingMetric =
250-
std::dynamic_pointer_cast<PCSamplingMetric>(metric);
251-
for (size_t i = 0; i < PCSamplingMetric::Count; i++) {
252-
auto valueName = pcSamplingMetric->getValueName(i);
253-
inclusiveValueNames.insert(valueName);
254-
std::visit(
255-
[&](auto &&value) { (*jsonNode)["metrics"][valueName] = value; },
256-
pcSamplingMetric->getValues()[i]);
257-
}
258-
} else if (metricKind == MetricKind::Cycle) {
259-
auto cycleMetric = std::dynamic_pointer_cast<CycleMetric>(metric);
260-
uint64_t duration =
261-
std::get<uint64_t>(cycleMetric->getValue(CycleMetric::Duration));
262-
double normalizedDuration = std::get<double>(
263-
cycleMetric->getValue(CycleMetric::NormalizedDuration));
264-
uint64_t deviceId =
265-
std::get<uint64_t>(cycleMetric->getValue(CycleMetric::DeviceId));
266-
uint64_t deviceType =
267-
std::get<uint64_t>(cycleMetric->getValue(CycleMetric::DeviceType));
268-
(*jsonNode)["metrics"]
269-
[cycleMetric->getValueName(CycleMetric::Duration)] =
270-
duration;
271-
(*jsonNode)["metrics"][cycleMetric->getValueName(
272-
CycleMetric::NormalizedDuration)] = normalizedDuration;
273-
(*jsonNode)["metrics"]
274-
[cycleMetric->getValueName(CycleMetric::DeviceId)] =
275-
std::to_string(deviceId);
276-
(*jsonNode)["metrics"]
277-
[cycleMetric->getValueName(CycleMetric::DeviceType)] =
278-
std::to_string(deviceType);
279-
deviceIds[deviceType].insert(deviceId);
280-
} else if (metricKind == MetricKind::Flexible) {
281-
// Flexible metrics are handled in a different way
282-
} else {
283-
throw std::runtime_error("MetricKind not supported");
284-
}
285-
}
286-
for (auto [_, flexibleMetric] : treeNode.flexibleMetrics) {
287-
auto valueName = flexibleMetric.getValueName(0);
288-
if (!flexibleMetric.isExclusive(0))
289-
inclusiveValueNames.insert(valueName);
290-
std::visit(
291-
[&](auto &&value) { (*jsonNode)["metrics"][valueName] = value; },
292-
flexibleMetric.getValues()[0]);
293-
}
294-
(*jsonNode)["children"] = json::array();
295-
auto children = treeNode.children;
296-
for (auto _ : children) {
297-
(*jsonNode)["children"].push_back(json::object());
298-
}
299-
auto idx = 0;
300-
for (auto child : children) {
301-
auto [index, childId] = child;
302-
jsonNodes[childId] = &(*jsonNode)["children"][idx];
303-
idx++;
304-
}
305-
});
306-
// Hints for all inclusive metrics
307-
for (auto valueName : inclusiveValueNames) {
308-
output[Tree::TreeNode::RootId]["metrics"][valueName] = 0;
309-
}
310-
// Prepare the device information
311-
// Note that this is done from the application thread,
312-
// query device information from the tool thread (e.g., CUPTI) will have
313-
// problems
314-
output.push_back(json::object());
315-
auto &deviceJson = output.back();
316-
for (auto [deviceType, deviceIdSet] : deviceIds) {
317-
auto deviceTypeName =
318-
getDeviceTypeString(static_cast<DeviceType>(deviceType));
319-
if (!deviceJson.contains(deviceTypeName))
320-
deviceJson[deviceTypeName] = json::object();
321-
for (auto deviceId : deviceIdSet) {
322-
Device device = getDevice(static_cast<DeviceType>(deviceType), deviceId);
323-
deviceJson[deviceTypeName][std::to_string(deviceId)] = {
324-
{"clock_rate", device.clockRate},
325-
{"memory_clock_rate", device.memoryClockRate},
326-
{"bus_width", device.busWidth},
327-
{"arch", device.arch},
328-
{"num_sms", device.numSms}};
329-
}
330-
}
329+
auto output = buildHatchetJson(tree.get());
331330
os << std::endl << output.dump(4) << std::endl;
332331
}
333332

333+
std::string TreeData::toJsonString() const {
334+
std::shared_lock<std::shared_mutex> lock(mutex);
335+
auto output = buildHatchetJson(tree.get());
336+
return output.dump();
337+
}
338+
334339
void TreeData::doDump(std::ostream &os, OutputFormat outputFormat) const {
335340
if (outputFormat == OutputFormat::Hatchet) {
336341
dumpHatchet(os);

0 commit comments

Comments
 (0)