diff --git a/third_party/proton/csrc/Proton.cpp b/third_party/proton/csrc/Proton.cpp index 67c978221dcb..e28f93efe7e0 100644 --- a/third_party/proton/csrc/Proton.cpp +++ b/third_party/proton/csrc/Proton.cpp @@ -165,6 +165,13 @@ static void initProton(pybind11::module &&m) { m.def("get_context_depth", [](size_t sessionId) { return SessionManager::instance().getContextDepth(sessionId); }); + + m.def( + "get_data", + [](size_t sessionId) { + return SessionManager::instance().getData(sessionId); + }, + pybind11::arg("sessionId")); } PYBIND11_MODULE(libproton, m) { diff --git a/third_party/proton/csrc/include/Data/Data.h b/third_party/proton/csrc/include/Data/Data.h index 740bd2043acc..6d85a6afe56c 100644 --- a/third_party/proton/csrc/include/Data/Data.h +++ b/third_party/proton/csrc/include/Data/Data.h @@ -43,6 +43,9 @@ class Data : public ScopeInterface { /// Clear all caching data. virtual void clear() = 0; + /// To Json + virtual std::string toJsonString() const = 0; + /// Dump the data to the given output format. void dump(const std::string &outputFormat); diff --git a/third_party/proton/csrc/include/Data/TraceData.h b/third_party/proton/csrc/include/Data/TraceData.h index 89f4c596b361..723f30baaa05 100644 --- a/third_party/proton/csrc/include/Data/TraceData.h +++ b/third_party/proton/csrc/include/Data/TraceData.h @@ -22,6 +22,8 @@ class TraceData : public Data { addMetrics(size_t scopeId, const std::map &metrics) override; + std::string toJsonString() const override; + void clear() override; class Trace; diff --git a/third_party/proton/csrc/include/Data/TreeData.h b/third_party/proton/csrc/include/Data/TreeData.h index 35ec617351b6..da50ed6e4515 100644 --- a/third_party/proton/csrc/include/Data/TreeData.h +++ b/third_party/proton/csrc/include/Data/TreeData.h @@ -3,9 +3,13 @@ #include "Context/Context.h" #include "Data.h" +#include "nlohmann/json.hpp" #include +#include #include +using json = nlohmann::json; + namespace proton { class TreeData : public Data { @@ -25,6 +29,8 @@ class TreeData : public Data { addMetrics(size_t scopeId, const std::map &metrics) override; + std::string toJsonString() const override; + void clear() override; protected: @@ -34,6 +40,12 @@ class TreeData : public Data { void exitScope(const Scope &scope) override; private: + // `tree` and `scopeIdToContextId` can be accessed by both the user thread and + // the background threads concurrently, so methods that access them should be + // protected by a (shared) mutex. + class Tree; + json buildHatchetJson(TreeData::Tree *tree) const; + void dumpHatchet(std::ostream &os) const; void doDump(std::ostream &os, OutputFormat outputFormat) const override; @@ -42,10 +54,6 @@ class TreeData : public Data { return OutputFormat::Hatchet; } - // `tree` and `scopeIdToContextId` can be accessed by both the user thread and - // the background threads concurrently, so methods that access them should be - // protected by a (shared) mutex. - class Tree; std::unique_ptr tree; // ScopeId -> ContextId std::unordered_map scopeIdToContextId; diff --git a/third_party/proton/csrc/include/Session/Session.h b/third_party/proton/csrc/include/Session/Session.h index 25303d3181d9..387cb4d4d2b7 100644 --- a/third_party/proton/csrc/include/Session/Session.h +++ b/third_party/proton/csrc/include/Session/Session.h @@ -93,6 +93,8 @@ class SessionManager : public Singleton { size_t getContextDepth(size_t sessionId); + std::string getData(size_t sessionId); + void enterScope(const Scope &scope); void exitScope(const Scope &scope); diff --git a/third_party/proton/csrc/lib/Data/TraceData.cpp b/third_party/proton/csrc/lib/Data/TraceData.cpp index 38728127faa7..b10f5e94b15c 100644 --- a/third_party/proton/csrc/lib/Data/TraceData.cpp +++ b/third_party/proton/csrc/lib/Data/TraceData.cpp @@ -229,6 +229,8 @@ void TraceData::addMetrics( } } +std::string TraceData::toJsonString() const { throw NotImplemented(); } + void TraceData::clear() { std::unique_lock lock(mutex); scopeIdToContextId.clear(); diff --git a/third_party/proton/csrc/lib/Data/TreeData.cpp b/third_party/proton/csrc/lib/Data/TreeData.cpp index 0b62e21b7d6f..6666dfc39508 100644 --- a/third_party/proton/csrc/lib/Data/TreeData.cpp +++ b/third_party/proton/csrc/lib/Data/TreeData.cpp @@ -2,7 +2,6 @@ #include "Context/Context.h" #include "Data/Metric.h" #include "Device.h" -#include "nlohmann/json.hpp" #include #include @@ -10,8 +9,6 @@ #include #include -using json = nlohmann::json; - namespace proton { class TreeData::Tree { @@ -106,6 +103,134 @@ class TreeData::Tree { std::map treeNodeMap; }; +json TreeData::buildHatchetJson(TreeData::Tree *tree) const { + std::map jsonNodes; + json output = json::array(); + output.push_back(json::object()); + jsonNodes[TreeData::Tree::TreeNode::RootId] = &(output.back()); + std::set inclusiveValueNames; + std::map> deviceIds; + tree->template walk( + [&](TreeData::Tree::TreeNode &treeNode) { + const auto contextName = treeNode.name; + auto contextId = treeNode.id; + json *jsonNode = jsonNodes[contextId]; + (*jsonNode)["frame"] = {{"name", contextName}, {"type", "function"}}; + (*jsonNode)["metrics"] = json::object(); + for (auto [metricKind, metric] : treeNode.metrics) { + if (metricKind == MetricKind::Kernel) { + std::shared_ptr kernelMetric = + std::dynamic_pointer_cast(metric); + uint64_t duration = std::get( + kernelMetric->getValue(KernelMetric::Duration)); + uint64_t invocations = std::get( + kernelMetric->getValue(KernelMetric::Invocations)); + uint64_t deviceId = std::get( + kernelMetric->getValue(KernelMetric::DeviceId)); + uint64_t deviceType = std::get( + kernelMetric->getValue(KernelMetric::DeviceType)); + std::string deviceTypeName = + getDeviceTypeString(static_cast(deviceType)); + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::Duration)] = + duration; + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::Invocations)] = + invocations; + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::DeviceId)] = + std::to_string(deviceId); + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::DeviceType)] = + deviceTypeName; + inclusiveValueNames.insert( + kernelMetric->getValueName(KernelMetric::Duration)); + inclusiveValueNames.insert( + kernelMetric->getValueName(KernelMetric::Invocations)); + deviceIds[deviceType].insert(deviceId); + } else if (metricKind == MetricKind::PCSampling) { + auto pcSamplingMetric = + std::dynamic_pointer_cast(metric); + for (size_t i = 0; i < PCSamplingMetric::Count; i++) { + auto valueName = pcSamplingMetric->getValueName(i); + inclusiveValueNames.insert(valueName); + std::visit( + [&](auto &&value) { + (*jsonNode)["metrics"][valueName] = value; + }, + pcSamplingMetric->getValues()[i]); + } + } else if (metricKind == MetricKind::Cycle) { + auto cycleMetric = std::dynamic_pointer_cast(metric); + uint64_t duration = std::get( + cycleMetric->getValue(CycleMetric::Duration)); + double normalizedDuration = std::get( + cycleMetric->getValue(CycleMetric::NormalizedDuration)); + uint64_t deviceId = std::get( + cycleMetric->getValue(CycleMetric::DeviceId)); + uint64_t deviceType = std::get( + cycleMetric->getValue(CycleMetric::DeviceType)); + (*jsonNode)["metrics"] + [cycleMetric->getValueName(CycleMetric::Duration)] = + duration; + (*jsonNode)["metrics"][cycleMetric->getValueName( + CycleMetric::NormalizedDuration)] = normalizedDuration; + (*jsonNode)["metrics"] + [cycleMetric->getValueName(CycleMetric::DeviceId)] = + std::to_string(deviceId); + (*jsonNode)["metrics"] + [cycleMetric->getValueName(CycleMetric::DeviceType)] = + std::to_string(deviceType); + deviceIds[deviceType].insert(deviceId); + } else if (metricKind == MetricKind::Flexible) { + // Flexible metrics are handled in a different way + } else { + throw std::runtime_error("MetricKind not supported"); + } + } + for (auto [_, flexibleMetric] : treeNode.flexibleMetrics) { + auto valueName = flexibleMetric.getValueName(0); + if (!flexibleMetric.isExclusive(0)) + inclusiveValueNames.insert(valueName); + std::visit( + [&](auto &&value) { (*jsonNode)["metrics"][valueName] = value; }, + flexibleMetric.getValues()[0]); + } + (*jsonNode)["children"] = json::array(); + auto children = treeNode.children; + for (auto _ : children) { + (*jsonNode)["children"].push_back(json::object()); + } + auto idx = 0; + for (auto child : children) { + auto [index, childId] = child; + jsonNodes[childId] = &(*jsonNode)["children"][idx]; + idx++; + } + }); + for (auto valueName : inclusiveValueNames) { + output[TreeData::Tree::TreeNode::RootId]["metrics"][valueName] = 0; + } + output.push_back(json::object()); + auto &deviceJson = output.back(); + for (auto [deviceType, deviceIdSet] : deviceIds) { + auto deviceTypeName = + getDeviceTypeString(static_cast(deviceType)); + if (!deviceJson.contains(deviceTypeName)) + deviceJson[deviceTypeName] = json::object(); + for (auto deviceId : deviceIdSet) { + Device device = getDevice(static_cast(deviceType), deviceId); + deviceJson[deviceTypeName][std::to_string(deviceId)] = { + {"clock_rate", device.clockRate}, + {"memory_clock_rate", device.memoryClockRate}, + {"bus_width", device.busWidth}, + {"arch", device.arch}, + {"num_sms", device.numSms}}; + } + } + return output; +} + void TreeData::enterScope(const Scope &scope) { // enterOp and addMetric maybe called from different threads std::unique_lock lock(mutex); @@ -201,136 +326,16 @@ void TreeData::clear() { } void TreeData::dumpHatchet(std::ostream &os) const { - std::map jsonNodes; - json output = json::array(); - output.push_back(json::object()); - jsonNodes[Tree::TreeNode::RootId] = &(output.back()); - std::set inclusiveValueNames; - std::map> deviceIds; - this->tree->template walk([&](Tree::TreeNode - &treeNode) { - const auto contextName = treeNode.name; - auto contextId = treeNode.id; - json *jsonNode = jsonNodes[contextId]; - (*jsonNode)["frame"] = {{"name", contextName}, {"type", "function"}}; - (*jsonNode)["metrics"] = json::object(); - for (auto [metricKind, metric] : treeNode.metrics) { - if (metricKind == MetricKind::Kernel) { - std::shared_ptr kernelMetric = - std::dynamic_pointer_cast(metric); - uint64_t duration = - std::get(kernelMetric->getValue(KernelMetric::Duration)); - uint64_t invocations = std::get( - kernelMetric->getValue(KernelMetric::Invocations)); - uint64_t deviceId = - std::get(kernelMetric->getValue(KernelMetric::DeviceId)); - uint64_t deviceType = std::get( - kernelMetric->getValue(KernelMetric::DeviceType)); - std::string deviceTypeName = - getDeviceTypeString(static_cast(deviceType)); - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::Duration)] = - duration; - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::Invocations)] = - invocations; - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::DeviceId)] = - std::to_string(deviceId); - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::DeviceType)] = - deviceTypeName; - inclusiveValueNames.insert( - kernelMetric->getValueName(KernelMetric::Duration)); - inclusiveValueNames.insert( - kernelMetric->getValueName(KernelMetric::Invocations)); - deviceIds[deviceType].insert(deviceId); - } else if (metricKind == MetricKind::PCSampling) { - auto pcSamplingMetric = - std::dynamic_pointer_cast(metric); - for (size_t i = 0; i < PCSamplingMetric::Count; i++) { - auto valueName = pcSamplingMetric->getValueName(i); - inclusiveValueNames.insert(valueName); - std::visit( - [&](auto &&value) { (*jsonNode)["metrics"][valueName] = value; }, - pcSamplingMetric->getValues()[i]); - } - } else if (metricKind == MetricKind::Cycle) { - auto cycleMetric = std::dynamic_pointer_cast(metric); - uint64_t duration = - std::get(cycleMetric->getValue(CycleMetric::Duration)); - double normalizedDuration = std::get( - cycleMetric->getValue(CycleMetric::NormalizedDuration)); - uint64_t deviceId = - std::get(cycleMetric->getValue(CycleMetric::DeviceId)); - uint64_t deviceType = - std::get(cycleMetric->getValue(CycleMetric::DeviceType)); - (*jsonNode)["metrics"] - [cycleMetric->getValueName(CycleMetric::Duration)] = - duration; - (*jsonNode)["metrics"][cycleMetric->getValueName( - CycleMetric::NormalizedDuration)] = normalizedDuration; - (*jsonNode)["metrics"] - [cycleMetric->getValueName(CycleMetric::DeviceId)] = - std::to_string(deviceId); - (*jsonNode)["metrics"] - [cycleMetric->getValueName(CycleMetric::DeviceType)] = - std::to_string(deviceType); - deviceIds[deviceType].insert(deviceId); - } else if (metricKind == MetricKind::Flexible) { - // Flexible metrics are handled in a different way - } else { - throw std::runtime_error("MetricKind not supported"); - } - } - for (auto [_, flexibleMetric] : treeNode.flexibleMetrics) { - auto valueName = flexibleMetric.getValueName(0); - if (!flexibleMetric.isExclusive(0)) - inclusiveValueNames.insert(valueName); - std::visit( - [&](auto &&value) { (*jsonNode)["metrics"][valueName] = value; }, - flexibleMetric.getValues()[0]); - } - (*jsonNode)["children"] = json::array(); - auto children = treeNode.children; - for (auto _ : children) { - (*jsonNode)["children"].push_back(json::object()); - } - auto idx = 0; - for (auto child : children) { - auto [index, childId] = child; - jsonNodes[childId] = &(*jsonNode)["children"][idx]; - idx++; - } - }); - // Hints for all inclusive metrics - for (auto valueName : inclusiveValueNames) { - output[Tree::TreeNode::RootId]["metrics"][valueName] = 0; - } - // Prepare the device information - // Note that this is done from the application thread, - // query device information from the tool thread (e.g., CUPTI) will have - // problems - output.push_back(json::object()); - auto &deviceJson = output.back(); - for (auto [deviceType, deviceIdSet] : deviceIds) { - auto deviceTypeName = - getDeviceTypeString(static_cast(deviceType)); - if (!deviceJson.contains(deviceTypeName)) - deviceJson[deviceTypeName] = json::object(); - for (auto deviceId : deviceIdSet) { - Device device = getDevice(static_cast(deviceType), deviceId); - deviceJson[deviceTypeName][std::to_string(deviceId)] = { - {"clock_rate", device.clockRate}, - {"memory_clock_rate", device.memoryClockRate}, - {"bus_width", device.busWidth}, - {"arch", device.arch}, - {"num_sms", device.numSms}}; - } - } + auto output = buildHatchetJson(tree.get()); os << std::endl << output.dump(4) << std::endl; } +std::string TreeData::toJsonString() const { + std::shared_lock lock(mutex); + auto output = buildHatchetJson(tree.get()); + return output.dump(); +} + void TreeData::doDump(std::ostream &os, OutputFormat outputFormat) const { if (outputFormat == OutputFormat::Hatchet) { dumpHatchet(os); diff --git a/third_party/proton/csrc/lib/Session/Session.cpp b/third_party/proton/csrc/lib/Session/Session.cpp index 761e12c51a82..d57cff2a1241 100644 --- a/third_party/proton/csrc/lib/Session/Session.cpp +++ b/third_party/proton/csrc/lib/Session/Session.cpp @@ -304,4 +304,22 @@ size_t SessionManager::getContextDepth(size_t sessionId) { return sessions[sessionId]->getContextDepth(); } +std::string SessionManager::getData(size_t sessionId) { + std::lock_guard lock(mutex); + throwIfSessionNotInitialized(sessions, sessionId); + auto *profiler = sessions[sessionId]->getProfiler(); + auto dataSet = profiler->getDataSet(); + if (dataSet.find(sessions[sessionId]->data.get()) != dataSet.end()) { + throw std::runtime_error( + "Cannot get data while the session is active. Please deactivate the " + "session first."); + } + auto *treeData = dynamic_cast(sessions[sessionId]->data.get()); + if (!treeData) { + throw std::runtime_error( + "Only TreeData is supported for getData() for now"); + } + return treeData->toJsonString(); +} + } // namespace proton diff --git a/third_party/proton/proton/__init__.py b/third_party/proton/proton/__init__.py index f6685397ddce..a4e1fd1073a7 100644 --- a/third_party/proton/proton/__init__.py +++ b/third_party/proton/proton/__init__.py @@ -9,4 +9,5 @@ profile, DEFAULT_PROFILE_NAME, ) +from .data import get_data from . import context, specs, mode diff --git a/third_party/proton/proton/data.py b/third_party/proton/proton/data.py new file mode 100644 index 000000000000..b79337ddc221 --- /dev/null +++ b/third_party/proton/proton/data.py @@ -0,0 +1,15 @@ +from triton._C.libproton import proton as libproton # type: ignore +import json as json + + +def get_data(session: int): + """ + Retrieves profiling data for a given session. + + Args: + session (int): The session ID of the profiling session. + + Returns: + str: The profiling data in JSON string format. + """ + return json.loads(libproton.get_data(session)) diff --git a/third_party/proton/proton/viewer.py b/third_party/proton/proton/viewer.py index 500e76d9d174..839c6b3c0fea 100644 --- a/third_party/proton/proton/viewer.py +++ b/third_party/proton/proton/viewer.py @@ -64,8 +64,7 @@ def remove_frame_helper(node): return new_database -def get_raw_metrics(file): - database = json.load(file) +def get_raw_metrics(database) -> tuple[ht.GraphFrame, list[str], list[str], dict]: database = remove_frames(database) device_info = database.pop(1) gf = ht.GraphFrame.from_literal(database) @@ -259,7 +258,8 @@ def print_tree(gf, metrics, depth=100, format=None, print_sorted=False): def read(filename): with open(filename, "r") as f: - gf, inclusive_metrics, exclusive_metrics, device_info = get_raw_metrics(f) + database = json.load(f) + gf, inclusive_metrics, exclusive_metrics, device_info = get_raw_metrics(database) assert len(inclusive_metrics + exclusive_metrics) > 0, "No metrics found in the input file" gf.update_inclusive_columns() return gf, inclusive_metrics, exclusive_metrics, device_info @@ -289,7 +289,8 @@ def apply_diff_profile(gf, derived_metrics, diff_file, metrics, include, exclude def show_metrics(file_name): with open(file_name, "r") as f: - _, inclusive_metrics, exclusive_metrics, _ = get_raw_metrics(f) + database = json.load(f) + _, inclusive_metrics, exclusive_metrics, _ = get_raw_metrics(database) print("Available inclusive metrics:") if inclusive_metrics: for raw_metric in inclusive_metrics: diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py index 4a6cb1603c6c..6843047a88e8 100644 --- a/third_party/proton/test/test_profile.py +++ b/third_party/proton/test/test_profile.py @@ -15,6 +15,7 @@ import triton.language as tl from triton.profiler.hooks.launch import COMPUTE_METADATA_SCOPE_NAME import triton.profiler.hooks.launch as proton_launch +import triton.profiler.viewer as viewer from triton._internal_testing import is_hip @@ -198,8 +199,39 @@ def test_cpu_timed_scope(tmp_path: pathlib.Path): assert test0_frame["metrics"]["cpu_time (ns)"] > 0 test1_frame = test0_frame["children"][0] assert test1_frame["metrics"]["cpu_time (ns)"] > 0 - kernel_frame = test1_frame["children"][0] - assert kernel_frame["metrics"]["time (ns)"] > 0 + + +def test_get_data(tmp_path: pathlib.Path): + temp_file = tmp_path / "test_tree_json.hatchet" + session = proton.start(str(temp_file.with_suffix("")), context="shadow") + + @triton.jit + def foo(x, y, size: tl.constexpr): + offs = tl.arange(0, size) + tl.store(y + offs, tl.load(x + offs)) + + with proton.scope("test"): + x = torch.ones((2, 2), device="cuda") + foo[(1, )](x, x, 4) + foo[(1, )](x, x, 4) + + try: + _ = proton.get_data(session) + except RuntimeError as e: + assert "Cannot get data while the session is active" in str(e) + + proton.deactivate(session) + + database = proton.get_data(session) + gf, _, _, _ = viewer.get_raw_metrics(database) + foo_frame = gf.filter("MATCH ('*', c) WHERE c.'name' =~ '.*foo.*' AND c IS LEAF").dataframe + ones_frame = gf.filter("MATCH ('*', c) WHERE c.'name' =~ '.*elementwise.*' AND c IS LEAF").dataframe + + proton.finalize() + assert len(foo_frame) == 1 + assert int(foo_frame["count"].values[0]) == 2 + assert len(ones_frame) == 1 + assert int(ones_frame["count"].values[0]) == 1 def test_hook_launch(tmp_path: pathlib.Path):