From 1b41db78305f30fa857e3dc59061091701872bb2 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Sun, 7 Dec 2025 19:03:32 -0500 Subject: [PATCH 1/6] Update --- third_party/proton/csrc/Proton.cpp | 7 + third_party/proton/csrc/include/Data/Data.h | 3 + .../proton/csrc/include/Data/TraceData.h | 2 + .../proton/csrc/include/Data/TreeData.h | 16 +- .../proton/csrc/include/Session/Session.h | 2 + .../proton/csrc/lib/Data/TraceData.cpp | 4 + third_party/proton/csrc/lib/Data/TreeData.cpp | 263 +++++++++--------- .../proton/csrc/lib/Session/Session.cpp | 17 ++ third_party/proton/proton/__init__.py | 1 + third_party/proton/proton/data.py | 14 + third_party/proton/test/test_profile.py | 28 +- 11 files changed, 221 insertions(+), 136 deletions(-) create mode 100644 third_party/proton/proton/data.py diff --git a/third_party/proton/csrc/Proton.cpp b/third_party/proton/csrc/Proton.cpp index 67c978221dcb..e28f93efe7e0 100644 --- a/third_party/proton/csrc/Proton.cpp +++ b/third_party/proton/csrc/Proton.cpp @@ -165,6 +165,13 @@ static void initProton(pybind11::module &&m) { m.def("get_context_depth", [](size_t sessionId) { return SessionManager::instance().getContextDepth(sessionId); }); + + m.def( + "get_data", + [](size_t sessionId) { + return SessionManager::instance().getData(sessionId); + }, + pybind11::arg("sessionId")); } PYBIND11_MODULE(libproton, m) { diff --git a/third_party/proton/csrc/include/Data/Data.h b/third_party/proton/csrc/include/Data/Data.h index 740bd2043acc..6d85a6afe56c 100644 --- a/third_party/proton/csrc/include/Data/Data.h +++ b/third_party/proton/csrc/include/Data/Data.h @@ -43,6 +43,9 @@ class Data : public ScopeInterface { /// Clear all caching data. virtual void clear() = 0; + /// To Json + virtual std::string toJsonString() const = 0; + /// Dump the data to the given output format. void dump(const std::string &outputFormat); diff --git a/third_party/proton/csrc/include/Data/TraceData.h b/third_party/proton/csrc/include/Data/TraceData.h index 89f4c596b361..723f30baaa05 100644 --- a/third_party/proton/csrc/include/Data/TraceData.h +++ b/third_party/proton/csrc/include/Data/TraceData.h @@ -22,6 +22,8 @@ class TraceData : public Data { addMetrics(size_t scopeId, const std::map &metrics) override; + std::string toJsonString() const override; + void clear() override; class Trace; diff --git a/third_party/proton/csrc/include/Data/TreeData.h b/third_party/proton/csrc/include/Data/TreeData.h index 35ec617351b6..fae55a33a668 100644 --- a/third_party/proton/csrc/include/Data/TreeData.h +++ b/third_party/proton/csrc/include/Data/TreeData.h @@ -3,8 +3,12 @@ #include "Context/Context.h" #include "Data.h" +#include "nlohmann/json.hpp" #include #include +#include + +using json = nlohmann::json; namespace proton { @@ -25,6 +29,8 @@ class TreeData : public Data { addMetrics(size_t scopeId, const std::map &metrics) override; + std::string toJsonString() const override; + void clear() override; protected: @@ -34,6 +40,12 @@ class TreeData : public Data { void exitScope(const Scope &scope) override; private: + // `tree` and `scopeIdToContextId` can be accessed by both the user thread and + // the background threads concurrently, so methods that access them should be + // protected by a (shared) mutex. + class Tree; + json buildHatchetJson(TreeData::Tree *tree) const; + void dumpHatchet(std::ostream &os) const; void doDump(std::ostream &os, OutputFormat outputFormat) const override; @@ -42,10 +54,6 @@ class TreeData : public Data { return OutputFormat::Hatchet; } - // `tree` and `scopeIdToContextId` can be accessed by both the user thread and - // the background threads concurrently, so methods that access them should be - // protected by a (shared) mutex. - class Tree; std::unique_ptr tree; // ScopeId -> ContextId std::unordered_map scopeIdToContextId; diff --git a/third_party/proton/csrc/include/Session/Session.h b/third_party/proton/csrc/include/Session/Session.h index 25303d3181d9..387cb4d4d2b7 100644 --- a/third_party/proton/csrc/include/Session/Session.h +++ b/third_party/proton/csrc/include/Session/Session.h @@ -93,6 +93,8 @@ class SessionManager : public Singleton { size_t getContextDepth(size_t sessionId); + std::string getData(size_t sessionId); + void enterScope(const Scope &scope); void exitScope(const Scope &scope); diff --git a/third_party/proton/csrc/lib/Data/TraceData.cpp b/third_party/proton/csrc/lib/Data/TraceData.cpp index 38728127faa7..14a7aab4e9c6 100644 --- a/third_party/proton/csrc/lib/Data/TraceData.cpp +++ b/third_party/proton/csrc/lib/Data/TraceData.cpp @@ -229,6 +229,10 @@ void TraceData::addMetrics( } } +std::string TraceData::toJsonString() const { + throw NotImplemented(); +} + void TraceData::clear() { std::unique_lock lock(mutex); scopeIdToContextId.clear(); diff --git a/third_party/proton/csrc/lib/Data/TreeData.cpp b/third_party/proton/csrc/lib/Data/TreeData.cpp index 946bcece10f7..ea676cc9fb45 100644 --- a/third_party/proton/csrc/lib/Data/TreeData.cpp +++ b/third_party/proton/csrc/lib/Data/TreeData.cpp @@ -2,7 +2,6 @@ #include "Context/Context.h" #include "Data/Metric.h" #include "Device.h" -#include "nlohmann/json.hpp" #include #include @@ -10,8 +9,6 @@ #include #include -using json = nlohmann::json; - namespace proton { class TreeData::Tree { @@ -106,6 +103,132 @@ class TreeData::Tree { std::map treeNodeMap; }; +json TreeData::buildHatchetJson(TreeData::Tree *tree) const { + std::map jsonNodes; + json output = json::array(); + output.push_back(json::object()); + jsonNodes[TreeData::Tree::TreeNode::RootId] = &(output.back()); + std::set inclusiveValueNames; + std::map> deviceIds; + tree->template walk( + [&](TreeData::Tree::TreeNode &treeNode) { + const auto contextName = treeNode.name; + auto contextId = treeNode.id; + json *jsonNode = jsonNodes[contextId]; + (*jsonNode)["frame"] = {{"name", contextName}, {"type", "function"}}; + (*jsonNode)["metrics"] = json::object(); + for (auto [metricKind, metric] : treeNode.metrics) { + if (metricKind == MetricKind::Kernel) { + std::shared_ptr kernelMetric = + std::dynamic_pointer_cast(metric); + uint64_t duration = + std::get(kernelMetric->getValue(KernelMetric::Duration)); + uint64_t invocations = std::get( + kernelMetric->getValue(KernelMetric::Invocations)); + uint64_t deviceId = + std::get(kernelMetric->getValue(KernelMetric::DeviceId)); + uint64_t deviceType = std::get( + kernelMetric->getValue(KernelMetric::DeviceType)); + std::string deviceTypeName = + getDeviceTypeString(static_cast(deviceType)); + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::Duration)] = + duration; + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::Invocations)] = + invocations; + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::DeviceId)] = + std::to_string(deviceId); + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::DeviceType)] = + deviceTypeName; + inclusiveValueNames.insert( + kernelMetric->getValueName(KernelMetric::Duration)); + inclusiveValueNames.insert( + kernelMetric->getValueName(KernelMetric::Invocations)); + deviceIds.insert({deviceType, {deviceId}}); + } else if (metricKind == MetricKind::PCSampling) { + auto pcSamplingMetric = + std::dynamic_pointer_cast(metric); + for (size_t i = 0; i < PCSamplingMetric::Count; i++) { + auto valueName = pcSamplingMetric->getValueName(i); + inclusiveValueNames.insert(valueName); + std::visit( + [&](auto &&value) { (*jsonNode)["metrics"][valueName] = value; }, + pcSamplingMetric->getValues()[i]); + } + } else if (metricKind == MetricKind::Cycle) { + auto cycleMetric = std::dynamic_pointer_cast(metric); + uint64_t duration = + std::get(cycleMetric->getValue(CycleMetric::Duration)); + double normalizedDuration = std::get( + cycleMetric->getValue(CycleMetric::NormalizedDuration)); + uint64_t deviceId = + std::get(cycleMetric->getValue(CycleMetric::DeviceId)); + uint64_t deviceType = std::get( + cycleMetric->getValue(CycleMetric::DeviceType)); + (*jsonNode)["metrics"] + [cycleMetric->getValueName(CycleMetric::Duration)] = + duration; + (*jsonNode)["metrics"][cycleMetric->getValueName( + CycleMetric::NormalizedDuration)] = normalizedDuration; + (*jsonNode)["metrics"] + [cycleMetric->getValueName(CycleMetric::DeviceId)] = + std::to_string(deviceId); + (*jsonNode)["metrics"] + [cycleMetric->getValueName(CycleMetric::DeviceType)] = + std::to_string(deviceType); + deviceIds.insert({deviceType, {deviceId}}); + } else if (metricKind == MetricKind::Flexible) { + // Flexible metrics are handled in a different way + } else { + throw std::runtime_error("MetricKind not supported"); + } + } + for (auto [_, flexibleMetric] : treeNode.flexibleMetrics) { + auto valueName = flexibleMetric.getValueName(0); + if (!flexibleMetric.isExclusive(0)) + inclusiveValueNames.insert(valueName); + std::visit( + [&](auto &&value) { (*jsonNode)["metrics"][valueName] = value; }, + flexibleMetric.getValues()[0]); + } + (*jsonNode)["children"] = json::array(); + auto children = treeNode.children; + for (auto _ : children) { + (*jsonNode)["children"].push_back(json::object()); + } + auto idx = 0; + for (auto child : children) { + auto [index, childId] = child; + jsonNodes[childId] = &(*jsonNode)["children"][idx]; + idx++; + } + }); + for (auto valueName : inclusiveValueNames) { + output[TreeData::Tree::TreeNode::RootId]["metrics"][valueName] = 0; + } + output.push_back(json::object()); + auto &deviceJson = output.back(); + for (auto [deviceType, deviceIds] : deviceIds) { + auto deviceTypeName = + getDeviceTypeString(static_cast(deviceType)); + if (!deviceJson.contains(deviceTypeName)) + deviceJson[deviceTypeName] = json::object(); + for (auto deviceId : deviceIds) { + Device device = getDevice(static_cast(deviceType), deviceId); + deviceJson[deviceTypeName][std::to_string(deviceId)] = { + {"clock_rate", device.clockRate}, + {"memory_clock_rate", device.memoryClockRate}, + {"bus_width", device.busWidth}, + {"arch", device.arch}, + {"num_sms", device.numSms}}; + } + } + return output; +} + void TreeData::enterScope(const Scope &scope) { // enterOp and addMetric maybe called from different threads std::unique_lock lock(mutex); @@ -201,136 +324,16 @@ void TreeData::clear() { } void TreeData::dumpHatchet(std::ostream &os) const { - std::map jsonNodes; - json output = json::array(); - output.push_back(json::object()); - jsonNodes[Tree::TreeNode::RootId] = &(output.back()); - std::set inclusiveValueNames; - std::map> deviceIds; - this->tree->template walk([&](Tree::TreeNode - &treeNode) { - const auto contextName = treeNode.name; - auto contextId = treeNode.id; - json *jsonNode = jsonNodes[contextId]; - (*jsonNode)["frame"] = {{"name", contextName}, {"type", "function"}}; - (*jsonNode)["metrics"] = json::object(); - for (auto [metricKind, metric] : treeNode.metrics) { - if (metricKind == MetricKind::Kernel) { - std::shared_ptr kernelMetric = - std::dynamic_pointer_cast(metric); - uint64_t duration = - std::get(kernelMetric->getValue(KernelMetric::Duration)); - uint64_t invocations = std::get( - kernelMetric->getValue(KernelMetric::Invocations)); - uint64_t deviceId = - std::get(kernelMetric->getValue(KernelMetric::DeviceId)); - uint64_t deviceType = std::get( - kernelMetric->getValue(KernelMetric::DeviceType)); - std::string deviceTypeName = - getDeviceTypeString(static_cast(deviceType)); - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::Duration)] = - duration; - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::Invocations)] = - invocations; - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::DeviceId)] = - std::to_string(deviceId); - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::DeviceType)] = - deviceTypeName; - inclusiveValueNames.insert( - kernelMetric->getValueName(KernelMetric::Duration)); - inclusiveValueNames.insert( - kernelMetric->getValueName(KernelMetric::Invocations)); - deviceIds.insert({deviceType, {deviceId}}); - } else if (metricKind == MetricKind::PCSampling) { - auto pcSamplingMetric = - std::dynamic_pointer_cast(metric); - for (size_t i = 0; i < PCSamplingMetric::Count; i++) { - auto valueName = pcSamplingMetric->getValueName(i); - inclusiveValueNames.insert(valueName); - std::visit( - [&](auto &&value) { (*jsonNode)["metrics"][valueName] = value; }, - pcSamplingMetric->getValues()[i]); - } - } else if (metricKind == MetricKind::Cycle) { - auto cycleMetric = std::dynamic_pointer_cast(metric); - uint64_t duration = - std::get(cycleMetric->getValue(CycleMetric::Duration)); - double normalizedDuration = std::get( - cycleMetric->getValue(CycleMetric::NormalizedDuration)); - uint64_t deviceId = - std::get(cycleMetric->getValue(CycleMetric::DeviceId)); - uint64_t deviceType = - std::get(cycleMetric->getValue(CycleMetric::DeviceType)); - (*jsonNode)["metrics"] - [cycleMetric->getValueName(CycleMetric::Duration)] = - duration; - (*jsonNode)["metrics"][cycleMetric->getValueName( - CycleMetric::NormalizedDuration)] = normalizedDuration; - (*jsonNode)["metrics"] - [cycleMetric->getValueName(CycleMetric::DeviceId)] = - std::to_string(deviceId); - (*jsonNode)["metrics"] - [cycleMetric->getValueName(CycleMetric::DeviceType)] = - std::to_string(deviceType); - deviceIds.insert({deviceType, {deviceId}}); - } else if (metricKind == MetricKind::Flexible) { - // Flexible metrics are handled in a different way - } else { - throw std::runtime_error("MetricKind not supported"); - } - } - for (auto [_, flexibleMetric] : treeNode.flexibleMetrics) { - auto valueName = flexibleMetric.getValueName(0); - if (!flexibleMetric.isExclusive(0)) - inclusiveValueNames.insert(valueName); - std::visit( - [&](auto &&value) { (*jsonNode)["metrics"][valueName] = value; }, - flexibleMetric.getValues()[0]); - } - (*jsonNode)["children"] = json::array(); - auto children = treeNode.children; - for (auto _ : children) { - (*jsonNode)["children"].push_back(json::object()); - } - auto idx = 0; - for (auto child : children) { - auto [index, childId] = child; - jsonNodes[childId] = &(*jsonNode)["children"][idx]; - idx++; - } - }); - // Hints for all inclusive metrics - for (auto valueName : inclusiveValueNames) { - output[Tree::TreeNode::RootId]["metrics"][valueName] = 0; - } - // Prepare the device information - // Note that this is done from the application thread, - // query device information from the tool thread (e.g., CUPTI) will have - // problems - output.push_back(json::object()); - auto &deviceJson = output.back(); - for (auto [deviceType, deviceIds] : deviceIds) { - auto deviceTypeName = - getDeviceTypeString(static_cast(deviceType)); - if (!deviceJson.contains(deviceTypeName)) - deviceJson[deviceTypeName] = json::object(); - for (auto deviceId : deviceIds) { - Device device = getDevice(static_cast(deviceType), deviceId); - deviceJson[deviceTypeName][std::to_string(deviceId)] = { - {"clock_rate", device.clockRate}, - {"memory_clock_rate", device.memoryClockRate}, - {"bus_width", device.busWidth}, - {"arch", device.arch}, - {"num_sms", device.numSms}}; - } - } + auto output = buildHatchetJson(tree.get()); os << std::endl << output.dump(4) << std::endl; } +std::string TreeData::toJsonString() const { + std::shared_lock lock(mutex); + auto output = buildHatchetJson(tree.get()); + return output.dump(); +} + void TreeData::doDump(std::ostream &os, OutputFormat outputFormat) const { if (outputFormat == OutputFormat::Hatchet) { dumpHatchet(os); diff --git a/third_party/proton/csrc/lib/Session/Session.cpp b/third_party/proton/csrc/lib/Session/Session.cpp index 761e12c51a82..03db95adf6dd 100644 --- a/third_party/proton/csrc/lib/Session/Session.cpp +++ b/third_party/proton/csrc/lib/Session/Session.cpp @@ -304,4 +304,21 @@ size_t SessionManager::getContextDepth(size_t sessionId) { return sessions[sessionId]->getContextDepth(); } +std::string SessionManager::getData(size_t sessionId) { + std::lock_guard lock(mutex); + throwIfSessionNotInitialized(sessions, sessionId); + auto *profiler = sessions[sessionId]->getProfiler(); + auto dataSet = profiler->getDataSet(); + if (dataSet.find(sessions[sessionId]->data.get()) != dataSet.end()) { + throw std::runtime_error( + "Cannot get data while the session is active. Please deactivate the " + "session first."); + } + auto *treeData = dynamic_cast(sessions[sessionId]->data.get()); + if (!treeData) { + throw std::runtime_error("Only TreeData is supported for getData() for now"); + } + return treeData->toJsonString(); +} + } // namespace proton diff --git a/third_party/proton/proton/__init__.py b/third_party/proton/proton/__init__.py index f6685397ddce..a4e1fd1073a7 100644 --- a/third_party/proton/proton/__init__.py +++ b/third_party/proton/proton/__init__.py @@ -9,4 +9,5 @@ profile, DEFAULT_PROFILE_NAME, ) +from .data import get_data from . import context, specs, mode diff --git a/third_party/proton/proton/data.py b/third_party/proton/proton/data.py new file mode 100644 index 000000000000..223a103c2568 --- /dev/null +++ b/third_party/proton/proton/data.py @@ -0,0 +1,14 @@ +from triton._C.libproton import proton as libproton # type: ignore + +def get_data(session: int) -> str: + """ + Retrieves profiling data for a given session. + + Args: + session (int): The session ID of the profiling session. + + Returns: + str: The profiling data in JSON string format. + """ + return libproton.get_data(session) + \ No newline at end of file diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py index 36008550651c..2ff71a3936e2 100644 --- a/third_party/proton/test/test_profile.py +++ b/third_party/proton/test/test_profile.py @@ -15,6 +15,8 @@ import triton.language as tl from triton.profiler.hooks.launch import COMPUTE_METADATA_SCOPE_NAME import triton.profiler.hooks.launch as proton_launch +from triton.profiler import get_data +import triton.profiler.viewer as viewer from triton._internal_testing import is_hip @@ -198,8 +200,30 @@ def test_cpu_timed_scope(tmp_path: pathlib.Path): assert test0_frame["metrics"]["cpu_time (ns)"] > 0 test1_frame = test0_frame["children"][0] assert test1_frame["metrics"]["cpu_time (ns)"] > 0 - kernel_frame = test1_frame["children"][0] - assert kernel_frame["metrics"]["time (ns)"] > 0 + + +def test_get_data(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_tree_json.hatchet" + session = proton.start(str(temp_file.with_suffix("")), context="shadow") + + @triton.jit + def foo(x, size: tl.constexpr, y): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + with proton.scope("test"): + torch.ones((2, 2), device="cuda") + foo[(1, )](x, 4) + foo[(1, )](x, 4) + + proton.deactivate(session) + + json_str = proton.get_data(session) + gf, _, _, _ = viewer.get_raw_metrics(json_str) + useful = gf.filter(f"MATCH ('*', c) WHERE c.'name' =~ '.*foo.*' AND c IS LEAF").dataframe + + proton.finalize() + print(useful) def test_hook_launch(tmp_path: pathlib.Path): From d3e73816a0366fd1d658589fcea4a6167a470948 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Sun, 7 Dec 2025 19:06:07 -0500 Subject: [PATCH 2/6] Update --- third_party/proton/test/test_profile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py index 2ff71a3936e2..7800accc50c7 100644 --- a/third_party/proton/test/test_profile.py +++ b/third_party/proton/test/test_profile.py @@ -212,7 +212,7 @@ def foo(x, size: tl.constexpr, y): tl.store(y + offs, tl.load(x + offs)) with proton.scope("test"): - torch.ones((2, 2), device="cuda") + x = torch.ones((2, 2), device="cuda") foo[(1, )](x, 4) foo[(1, )](x, 4) From 3d7dfca8d856b46b0360d20fb7ef4a615d8e5b98 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Sun, 7 Dec 2025 19:06:38 -0500 Subject: [PATCH 3/6] Update --- third_party/proton/test/test_profile.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py index 7800accc50c7..b5c9dd8cd1e0 100644 --- a/third_party/proton/test/test_profile.py +++ b/third_party/proton/test/test_profile.py @@ -207,14 +207,14 @@ def test_get_data(tmp_path: pathlib.Path): session = proton.start(str(temp_file.with_suffix("")), context="shadow") @triton.jit - def foo(x, size: tl.constexpr, y): + def foo(x, y, size: tl.constexpr): offs = tl.arange(0, size) tl.store(y + offs, tl.load(x + offs)) with proton.scope("test"): x = torch.ones((2, 2), device="cuda") - foo[(1, )](x, 4) - foo[(1, )](x, 4) + foo[(1, )](x, x, 4) + foo[(1, )](x, x, 4) proton.deactivate(session) From f470b6fbe22af80e2c81b0821853f6f85d459dbd Mon Sep 17 00:00:00 2001 From: Jokeren Date: Sun, 7 Dec 2025 19:52:57 -0500 Subject: [PATCH 4/6] Update --- third_party/proton/proton/data.py | 5 +++-- third_party/proton/proton/viewer.py | 9 +++++---- third_party/proton/test/test_profile.py | 4 ++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/third_party/proton/proton/data.py b/third_party/proton/proton/data.py index 223a103c2568..55b2f14b116e 100644 --- a/third_party/proton/proton/data.py +++ b/third_party/proton/proton/data.py @@ -1,6 +1,7 @@ from triton._C.libproton import proton as libproton # type: ignore +import json as json -def get_data(session: int) -> str: +def get_data(session: int): """ Retrieves profiling data for a given session. @@ -10,5 +11,5 @@ def get_data(session: int) -> str: Returns: str: The profiling data in JSON string format. """ - return libproton.get_data(session) + return json.loads(libproton.get_data(session)) \ No newline at end of file diff --git a/third_party/proton/proton/viewer.py b/third_party/proton/proton/viewer.py index 500e76d9d174..839c6b3c0fea 100644 --- a/third_party/proton/proton/viewer.py +++ b/third_party/proton/proton/viewer.py @@ -64,8 +64,7 @@ def remove_frame_helper(node): return new_database -def get_raw_metrics(file): - database = json.load(file) +def get_raw_metrics(database) -> tuple[ht.GraphFrame, list[str], list[str], dict]: database = remove_frames(database) device_info = database.pop(1) gf = ht.GraphFrame.from_literal(database) @@ -259,7 +258,8 @@ def print_tree(gf, metrics, depth=100, format=None, print_sorted=False): def read(filename): with open(filename, "r") as f: - gf, inclusive_metrics, exclusive_metrics, device_info = get_raw_metrics(f) + database = json.load(f) + gf, inclusive_metrics, exclusive_metrics, device_info = get_raw_metrics(database) assert len(inclusive_metrics + exclusive_metrics) > 0, "No metrics found in the input file" gf.update_inclusive_columns() return gf, inclusive_metrics, exclusive_metrics, device_info @@ -289,7 +289,8 @@ def apply_diff_profile(gf, derived_metrics, diff_file, metrics, include, exclude def show_metrics(file_name): with open(file_name, "r") as f: - _, inclusive_metrics, exclusive_metrics, _ = get_raw_metrics(f) + database = json.load(f) + _, inclusive_metrics, exclusive_metrics, _ = get_raw_metrics(database) print("Available inclusive metrics:") if inclusive_metrics: for raw_metric in inclusive_metrics: diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py index b5c9dd8cd1e0..561363992890 100644 --- a/third_party/proton/test/test_profile.py +++ b/third_party/proton/test/test_profile.py @@ -218,8 +218,8 @@ def foo(x, y, size: tl.constexpr): proton.deactivate(session) - json_str = proton.get_data(session) - gf, _, _, _ = viewer.get_raw_metrics(json_str) + database = proton.get_data(session) + gf, _, _, _ = viewer.get_raw_metrics(database) useful = gf.filter(f"MATCH ('*', c) WHERE c.'name' =~ '.*foo.*' AND c IS LEAF").dataframe proton.finalize() From df23e759006eff8c804a9b11f397dec754635252 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Sun, 7 Dec 2025 19:55:31 -0500 Subject: [PATCH 5/6] Update --- third_party/proton/test/test_profile.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py index 561363992890..d1635758b54c 100644 --- a/third_party/proton/test/test_profile.py +++ b/third_party/proton/test/test_profile.py @@ -216,14 +216,23 @@ def foo(x, y, size: tl.constexpr): foo[(1, )](x, x, 4) foo[(1, )](x, x, 4) + try: + _ = proton.get_data(session) + except RuntimeError as e: + assert "Cannot get data while the session is active" in str(e) + proton.deactivate(session) database = proton.get_data(session) gf, _, _, _ = viewer.get_raw_metrics(database) - useful = gf.filter(f"MATCH ('*', c) WHERE c.'name' =~ '.*foo.*' AND c IS LEAF").dataframe + foo_frame = gf.filter(f"MATCH ('*', c) WHERE c.'name' =~ '.*foo.*' AND c IS LEAF").dataframe + ones_frame = gf.filter(f"MATCH ('*', c) WHERE c.'name' =~ '.*elementwise.*' AND c IS LEAF").dataframe proton.finalize() - print(useful) + assert len(foo_frame) == 1 + assert int(foo_frame["count"].values[0]) == 2 + assert len(ones_frame) == 1 + assert int(ones_frame["count"].values[0]) == 1 def test_hook_launch(tmp_path: pathlib.Path): From a348f29d414a83b9e2ad1c9efdfc47111885ff9c Mon Sep 17 00:00:00 2001 From: Jokeren Date: Sun, 7 Dec 2025 20:06:35 -0500 Subject: [PATCH 6/6] Fix --- .../proton/csrc/include/Data/TreeData.h | 2 +- .../proton/csrc/lib/Data/TraceData.cpp | 4 +--- third_party/proton/csrc/lib/Data/TreeData.cpp | 20 ++++++++++--------- .../proton/csrc/lib/Session/Session.cpp | 3 ++- third_party/proton/proton/data.py | 2 +- third_party/proton/test/test_profile.py | 5 ++--- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/third_party/proton/csrc/include/Data/TreeData.h b/third_party/proton/csrc/include/Data/TreeData.h index fae55a33a668..da50ed6e4515 100644 --- a/third_party/proton/csrc/include/Data/TreeData.h +++ b/third_party/proton/csrc/include/Data/TreeData.h @@ -5,8 +5,8 @@ #include "Data.h" #include "nlohmann/json.hpp" #include -#include #include +#include using json = nlohmann::json; diff --git a/third_party/proton/csrc/lib/Data/TraceData.cpp b/third_party/proton/csrc/lib/Data/TraceData.cpp index 14a7aab4e9c6..b10f5e94b15c 100644 --- a/third_party/proton/csrc/lib/Data/TraceData.cpp +++ b/third_party/proton/csrc/lib/Data/TraceData.cpp @@ -229,9 +229,7 @@ void TraceData::addMetrics( } } -std::string TraceData::toJsonString() const { - throw NotImplemented(); -} +std::string TraceData::toJsonString() const { throw NotImplemented(); } void TraceData::clear() { std::unique_lock lock(mutex); diff --git a/third_party/proton/csrc/lib/Data/TreeData.cpp b/third_party/proton/csrc/lib/Data/TreeData.cpp index ea676cc9fb45..9c863c080d24 100644 --- a/third_party/proton/csrc/lib/Data/TreeData.cpp +++ b/third_party/proton/csrc/lib/Data/TreeData.cpp @@ -121,12 +121,12 @@ json TreeData::buildHatchetJson(TreeData::Tree *tree) const { if (metricKind == MetricKind::Kernel) { std::shared_ptr kernelMetric = std::dynamic_pointer_cast(metric); - uint64_t duration = - std::get(kernelMetric->getValue(KernelMetric::Duration)); + uint64_t duration = std::get( + kernelMetric->getValue(KernelMetric::Duration)); uint64_t invocations = std::get( kernelMetric->getValue(KernelMetric::Invocations)); - uint64_t deviceId = - std::get(kernelMetric->getValue(KernelMetric::DeviceId)); + uint64_t deviceId = std::get( + kernelMetric->getValue(KernelMetric::DeviceId)); uint64_t deviceType = std::get( kernelMetric->getValue(KernelMetric::DeviceType)); std::string deviceTypeName = @@ -155,17 +155,19 @@ json TreeData::buildHatchetJson(TreeData::Tree *tree) const { auto valueName = pcSamplingMetric->getValueName(i); inclusiveValueNames.insert(valueName); std::visit( - [&](auto &&value) { (*jsonNode)["metrics"][valueName] = value; }, + [&](auto &&value) { + (*jsonNode)["metrics"][valueName] = value; + }, pcSamplingMetric->getValues()[i]); } } else if (metricKind == MetricKind::Cycle) { auto cycleMetric = std::dynamic_pointer_cast(metric); - uint64_t duration = - std::get(cycleMetric->getValue(CycleMetric::Duration)); + uint64_t duration = std::get( + cycleMetric->getValue(CycleMetric::Duration)); double normalizedDuration = std::get( cycleMetric->getValue(CycleMetric::NormalizedDuration)); - uint64_t deviceId = - std::get(cycleMetric->getValue(CycleMetric::DeviceId)); + uint64_t deviceId = std::get( + cycleMetric->getValue(CycleMetric::DeviceId)); uint64_t deviceType = std::get( cycleMetric->getValue(CycleMetric::DeviceType)); (*jsonNode)["metrics"] diff --git a/third_party/proton/csrc/lib/Session/Session.cpp b/third_party/proton/csrc/lib/Session/Session.cpp index 03db95adf6dd..d57cff2a1241 100644 --- a/third_party/proton/csrc/lib/Session/Session.cpp +++ b/third_party/proton/csrc/lib/Session/Session.cpp @@ -316,7 +316,8 @@ std::string SessionManager::getData(size_t sessionId) { } auto *treeData = dynamic_cast(sessions[sessionId]->data.get()); if (!treeData) { - throw std::runtime_error("Only TreeData is supported for getData() for now"); + throw std::runtime_error( + "Only TreeData is supported for getData() for now"); } return treeData->toJsonString(); } diff --git a/third_party/proton/proton/data.py b/third_party/proton/proton/data.py index 55b2f14b116e..b79337ddc221 100644 --- a/third_party/proton/proton/data.py +++ b/third_party/proton/proton/data.py @@ -1,6 +1,7 @@ from triton._C.libproton import proton as libproton # type: ignore import json as json + def get_data(session: int): """ Retrieves profiling data for a given session. @@ -12,4 +13,3 @@ def get_data(session: int): str: The profiling data in JSON string format. """ return json.loads(libproton.get_data(session)) - \ No newline at end of file diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py index d1635758b54c..f74a2bf5cb1f 100644 --- a/third_party/proton/test/test_profile.py +++ b/third_party/proton/test/test_profile.py @@ -15,7 +15,6 @@ import triton.language as tl from triton.profiler.hooks.launch import COMPUTE_METADATA_SCOPE_NAME import triton.profiler.hooks.launch as proton_launch -from triton.profiler import get_data import triton.profiler.viewer as viewer from triton._internal_testing import is_hip @@ -225,8 +224,8 @@ def foo(x, y, size: tl.constexpr): database = proton.get_data(session) gf, _, _, _ = viewer.get_raw_metrics(database) - foo_frame = gf.filter(f"MATCH ('*', c) WHERE c.'name' =~ '.*foo.*' AND c IS LEAF").dataframe - ones_frame = gf.filter(f"MATCH ('*', c) WHERE c.'name' =~ '.*elementwise.*' AND c IS LEAF").dataframe + foo_frame = gf.filter("MATCH ('*', c) WHERE c.'name' =~ '.*foo.*' AND c IS LEAF").dataframe + ones_frame = gf.filter("MATCH ('*', c) WHERE c.'name' =~ '.*elementwise.*' AND c IS LEAF").dataframe proton.finalize() assert len(foo_frame) == 1