
Commit 1c15a29

[PROTON] Fix metric buffer deadlock and support multi-device metric profiling (#8943)
1 parent 49fd500 commit 1c15a29

4 files changed: +172 -44 lines changed

third_party/proton/csrc/include/Data/Metric.h

Lines changed: 1 addition & 1 deletion
@@ -407,9 +407,9 @@ class MetricBuffer {
   void reserve() { getOrCreateBuffer(); }
 
   template <typename Func> void flush(Func callback, bool flushAll = false) {
-    std::lock_guard<std::mutex> lock(bufferMutex);
     std::vector<DeviceBuffer> buffersToFlush;
     if (flushAll) {
+      std::lock_guard<std::mutex> lock(bufferMutex);
       for (auto &[device, buffer] : deviceBuffers) {
         buffersToFlush.push_back(buffer);
       }
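
The fix narrows the critical section to the flushAll branch, so bufferMutex is held only while the per-device buffers are snapshotted into buffersToFlush rather than for the whole call; a plausible reading of the deadlock named in the commit title is that code reached later in flush(), or the callback it invokes, acquires bufferMutex again, which would self-deadlock under the old function-wide guard. A minimal standalone sketch of the lock-narrowing pattern follows; DeviceBuffer and BufferRegistry here are simplified stand-ins, not the real MetricBuffer:

#include <cstdint>
#include <map>
#include <mutex>
#include <vector>

struct DeviceBuffer {
  std::vector<uint8_t> bytes;
};

class BufferRegistry {
public:
  template <typename Func> void flush(Func callback, bool flushAll = false) {
    std::vector<DeviceBuffer> buffersToFlush;
    if (flushAll) {
      // Lock held only while snapshotting the per-device buffers.
      std::lock_guard<std::mutex> lock(mutex);
      for (auto &[device, buffer] : deviceBuffers)
        buffersToFlush.push_back(buffer);
    }
    // The callback runs without the lock, so it can safely call back into
    // code that locks the same mutex.
    for (auto &buffer : buffersToFlush)
      callback(buffer.bytes.data(), buffer.bytes.size());
  }

private:
  std::map<int, DeviceBuffer> deviceBuffers;
  std::mutex mutex;
};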

third_party/proton/csrc/lib/Data/TreeData.cpp

Lines changed: 4 additions & 4 deletions
@@ -244,7 +244,7 @@ void TreeData::dumpHatchet(std::ostream &os) const {
           kernelMetric->getValueName(KernelMetric::Duration));
       inclusiveValueNames.insert(
           kernelMetric->getValueName(KernelMetric::Invocations));
-      deviceIds.insert({deviceType, {deviceId}});
+      deviceIds[deviceType].insert(deviceId);
     } else if (metricKind == MetricKind::PCSampling) {
       auto pcSamplingMetric =
           std::dynamic_pointer_cast<PCSamplingMetric>(metric);
@@ -276,7 +276,7 @@ void TreeData::dumpHatchet(std::ostream &os) const {
       (*jsonNode)["metrics"]
           [cycleMetric->getValueName(CycleMetric::DeviceType)] =
               std::to_string(deviceType);
-      deviceIds.insert({deviceType, {deviceId}});
+      deviceIds[deviceType].insert(deviceId);
     } else if (metricKind == MetricKind::Flexible) {
       // Flexible metrics are handled in a different way
     } else {
@@ -313,12 +313,12 @@ void TreeData::dumpHatchet(std::ostream &os) const {
   // problems
   output.push_back(json::object());
   auto &deviceJson = output.back();
-  for (auto [deviceType, deviceIds] : deviceIds) {
+  for (auto [deviceType, deviceIdSet] : deviceIds) {
     auto deviceTypeName =
         getDeviceTypeString(static_cast<DeviceType>(deviceType));
     if (!deviceJson.contains(deviceTypeName))
       deviceJson[deviceTypeName] = json::object();
-    for (auto deviceId : deviceIds) {
+    for (auto deviceId : deviceIdSet) {
       Device device = getDevice(static_cast<DeviceType>(deviceType), deviceId);
       deviceJson[deviceTypeName][std::to_string(deviceId)] = {
           {"clock_rate", device.clockRate},
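
The deviceIds change fixes multi-device accumulation: std::map::insert is a no-op when the key already exists, so the old code kept only the first device ID seen for each device type, while operator[] plus set insertion collects every ID. A small standalone demonstration of the difference, with toy values rather than Proton code:

#include <cstdio>
#include <map>
#include <set>

int main() {
  std::map<int, std::set<int>> byInsert, byIndex;
  const int deviceType = 0;

  for (int deviceId : {0, 1}) {
    byInsert.insert({deviceType, {deviceId}});  // second call is a no-op
    byIndex[deviceType].insert(deviceId);       // accumulates both IDs
  }

  // Prints "insert(): 1 id(s), operator[]: 2 id(s)"
  std::printf("insert(): %zu id(s), operator[]: %zu id(s)\n",
              byInsert[deviceType].size(), byIndex[deviceType].size());
  return 0;
}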

third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp

Lines changed: 76 additions & 39 deletions
@@ -268,53 +268,73 @@ struct GraphState {
   size_t numInstances{1};
 };
 
-// FIXME: it should be a per-stream queue in case we capture graphs from
-// different streams or different devices
+// Track pending graphs per device so flushing a single device won't drain
+// graphs from other devices.
 class PendingGraphQueue {
 public:
+  explicit PendingGraphQueue(Runtime *runtime) : runtime(runtime) {}
+
   struct PendingGraph {
     size_t externId;
     std::map<Data *, std::vector<std::pair<bool, size_t>>> dataToScopeIds;
     size_t numMetricNodes;
   };
   using PopResult = std::pair<size_t, std::vector<PendingGraph>>;
 
-  PendingGraphQueue() = default;
-
   void push(size_t externId,
             const std::map<Data *, std::vector<std::pair<bool, size_t>>>
                 &dataToScopeIds,
             size_t numNodes) {
     std::lock_guard<std::mutex> lock(mutex);
-    pendingGraphs.push_back(PendingGraph{externId, dataToScopeIds, numNodes});
-    this->totalNumNodes += numNodes;
+    auto device = runtime->getDevice();
+    auto &queue = deviceQueues[device];
+    queue.pendingGraphs.push_back(
+        PendingGraph{externId, dataToScopeIds, numNodes});
+    queue.totalNumNodes += numNodes;
   }
 
-  PopResult popAllIfReachCapacity(size_t numNewNodes, size_t capacity) {
+  PopResult pop(size_t numNewNodes, size_t capacity) {
     std::lock_guard<std::mutex> lock(mutex);
-    if ((this->totalNumNodes + numNewNodes) * 2 * sizeof(uint64_t) <=
+    if (deviceQueues.empty()) {
+      return {0, {}};
+    }
+    auto device = runtime->getDevice();
+    auto &queue = deviceQueues[device];
+    if ((queue.totalNumNodes + numNewNodes) * 2 * sizeof(uint64_t) <=
         capacity) {
       return {0, {}};
     }
-    return popAllLocked();
+    return popLocked(queue);
   }
 
-  PopResult popAll() {
+  std::vector<PopResult> popAll() {
     std::lock_guard<std::mutex> lock(mutex);
-    return popAllLocked();
+    if (deviceQueues.empty()) {
+      return {{0, {}}};
+    }
+    std::vector<PopResult> results;
+    for (auto &[device, queue] : deviceQueues) {
+      results.emplace_back(popLocked(queue));
+    }
+    return results;
   }
 
 private:
-  PopResult popAllLocked() {
+  struct Queue {
+    size_t totalNumNodes{};
+    std::vector<PendingGraph> pendingGraphs;
+  };
+
+  PopResult popLocked(Queue &queue) {
     std::vector<PendingGraph> items;
-    items.swap(pendingGraphs);
-    size_t numNodes = totalNumNodes;
-    totalNumNodes = 0;
+    items.swap(queue.pendingGraphs);
+    size_t numNodes = queue.totalNumNodes;
+    queue.totalNumNodes = 0;
     return {numNodes, items};
   }
 
-  size_t totalNumNodes{};
-  std::vector<PendingGraph> pendingGraphs;
+  Runtime *runtime{};
+  std::map<void *, Queue> deviceQueues;
   mutable std::mutex mutex;
 };
 
@@ -323,7 +343,8 @@ class PendingGraphQueue {
 struct CuptiProfiler::CuptiProfilerPimpl
     : public GPUProfiler<CuptiProfiler>::GPUProfilerPimplInterface {
   CuptiProfilerPimpl(CuptiProfiler &profiler)
-      : GPUProfiler<CuptiProfiler>::GPUProfilerPimplInterface(profiler) {
+      : GPUProfiler<CuptiProfiler>::GPUProfilerPimplInterface(profiler),
+        pendingGraphQueue(&CudaRuntime::instance()) {
     runtime = &CudaRuntime::instance();
     metricBuffer = std::make_unique<MetricBuffer>(1024 * 1024 * 64, runtime);
   }
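
The device-keyed queue keeps the existing capacity accounting of 2 * sizeof(uint64_t) (16 bytes) per metric node, but now applies it per device: with the 64 MiB metric buffer constructed above, a device's pending graphs are drained once its own node count plus the incoming graph's would exceed 64 MiB / 16 B = 4,194,304 nodes, and popAll() drains every device's queue. A condensed standalone sketch of that bookkeeping, with the PendingGraph payload reduced to a node count (a simplified illustration, not the Proton class):

#include <cstddef>
#include <cstdint>
#include <map>
#include <mutex>
#include <utility>
#include <vector>

class PerDeviceQueue {
public:
  using PopResult = std::pair<size_t, std::vector<size_t>>;

  void push(void *device, size_t numNodes) {
    std::lock_guard<std::mutex> lock(mutex);
    auto &queue = deviceQueues[device];  // lazily creates one queue per device
    queue.pendingNodes.push_back(numNodes);
    queue.totalNumNodes += numNodes;
  }

  // Drain only the calling device's queue, and only once it would overflow
  // the metric buffer; other devices keep their pending graphs.
  PopResult pop(void *device, size_t numNewNodes, size_t capacity) {
    std::lock_guard<std::mutex> lock(mutex);
    auto &queue = deviceQueues[device];
    if ((queue.totalNumNodes + numNewNodes) * 2 * sizeof(uint64_t) <= capacity)
      return {0, {}};
    PopResult result{queue.totalNumNodes, std::move(queue.pendingNodes)};
    queue.pendingNodes.clear();
    queue.totalNumNodes = 0;
    return result;
  }

private:
  struct Queue {
    size_t totalNumNodes{};
    std::vector<size_t> pendingNodes;
  };
  std::map<void *, Queue> deviceQueues;
  std::mutex mutex;
};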
@@ -619,8 +640,33 @@ void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData,
             }
           }
         }
+      }
+    }
+    profiler.correlation.correlate(callbackData->correlationId, numInstances);
+    if (profiler.pcSamplingEnabled && isDriverAPILaunch(cbId)) {
+      pImpl->pcSampling.start(callbackData->context);
+    }
+  } else if (callbackData->callbackSite == CUPTI_API_EXIT) {
+    auto externId = profiler.correlation.externIdQueue.back();
+    if (profiler.pcSamplingEnabled && isDriverAPILaunch(cbId)) {
+      // XXX: Conservatively stop every GPU kernel for now
+      pImpl->pcSampling.stop(
+          callbackData->context, externId,
+          profiler.correlation.apiExternIds.contain(externId));
+    }
+    if (cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch ||
+        cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch_ptsz) {
+      // Cuda context can be lazily initialized, so we need to call device get
+      // here after the first kernel is launched
+      auto graphExec = static_cast<const cuGraphLaunch_params *>(
+                           callbackData->functionParams)
+                           ->hGraph;
+      uint32_t graphExecId = 0;
+      cupti::getGraphExecId<true>(graphExec, &graphExecId);
+      if (pImpl->graphStates.contain(graphExecId)) {
         std::map<Data *, std::vector<std::pair<bool, size_t>>>
             metricNodeScopes;
+        auto dataSet = profiler.getDataSet();
         for (auto *data : dataSet) {
           auto &nodeToScopeId =
               profiler.correlation.externIdToGraphNodeScopeId[externId][data];
@@ -639,8 +685,8 @@ void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData,
             pImpl->metricBuffer->getCapacity(); // bytes
         auto metricNodeCount =
             pImpl->graphStates[graphExecId].metricKernelNodeIds.size();
-        auto drained = pImpl->pendingGraphQueue.popAllIfReachCapacity(
-            metricNodeCount, metricBufferCapacity);
+        auto drained = pImpl->pendingGraphQueue.pop(metricNodeCount,
+                                                    metricBufferCapacity);
         if (drained.first != 0) { // Reached capacity
           pImpl->metricBuffer->flush([&](uint8_t *data, size_t dataSize) {
             auto *recordPtr = reinterpret_cast<uint64_t *>(data);
@@ -651,18 +697,6 @@ void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData,
                 metricNodeCount);
           }
         }
-    profiler.correlation.correlate(callbackData->correlationId, numInstances);
-    if (profiler.pcSamplingEnabled && isDriverAPILaunch(cbId)) {
-      pImpl->pcSampling.start(callbackData->context);
-    }
-  } else if (callbackData->callbackSite == CUPTI_API_EXIT) {
-    if (profiler.pcSamplingEnabled && isDriverAPILaunch(cbId)) {
-      // XXX: Conservatively stop every GPU kernel for now
-      auto scopeId = profiler.correlation.externIdQueue.back();
-      pImpl->pcSampling.stop(
-          callbackData->context, scopeId,
-          profiler.correlation.apiExternIds.contain(scopeId));
-    }
     threadState.exitOp();
     profiler.correlation.submit(callbackData->correlationId);
   }
@@ -713,14 +747,17 @@ void CuptiProfiler::CuptiProfilerPimpl::doFlush() {
   // new activities.
   cupti::activityFlushAll<true>(/*flag=*/CUPTI_ACTIVITY_FLAG_FLUSH_FORCED);
   // Flush the tensor metric buffer
-  auto dataSet = profiler.getDataSet();
   auto popResult = pendingGraphQueue.popAll();
-  metricBuffer->flush(
-      [&](uint8_t *data, size_t dataSize) {
-        auto *recordPtr = reinterpret_cast<uint64_t *>(data);
-        emitMetricRecords(recordPtr, popResult.second);
-      },
-      /*flushAll=*/true);
+  if (!popResult.empty()) {
+    auto resultIdx = 0;
+    metricBuffer->flush(
+        [&](uint8_t *data, size_t dataSize) {
+          auto *recordPtr = reinterpret_cast<uint64_t *>(data);
+          emitMetricRecords(recordPtr, popResult[resultIdx].second);
+          resultIdx++;
+        },
+        /*flushAll=*/true);
+  }
 }
 
 void CuptiProfiler::CuptiProfilerPimpl::doStop() {
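
doFlush() now pairs the per-device results of popAll() with the per-device flushes of the metric buffer by index: the flush callback fires once per device buffer and resultIdx advances on each invocation, so each drained queue is decoded against the corresponding device's records (this pairing assumes both sides walk the devices in the same order). A toy illustration of that index pairing; flushAll, the buffers, and the graph counts below are placeholders, not the real MetricBuffer or emitMetricRecords:

#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

// Stand-in for MetricBuffer::flush(callback, /*flushAll=*/true): the callback
// fires once per device buffer, in device order.
void flushAll(const std::vector<std::vector<uint64_t>> &deviceBuffers,
              const std::function<void(uint8_t *, size_t)> &callback) {
  for (const auto &buffer : deviceBuffers) {
    auto *bytes =
        reinterpret_cast<uint8_t *>(const_cast<uint64_t *>(buffer.data()));
    callback(bytes, buffer.size() * sizeof(uint64_t));
  }
}

int main() {
  std::vector<std::vector<uint64_t>> deviceBuffers = {{1, 2}, {3, 4, 5, 6}};
  // One entry per device, mirroring popAll() returning one PopResult per
  // device queue.
  std::vector<size_t> pendingGraphsPerDevice = {1, 2};

  size_t resultIdx = 0;
  flushAll(deviceBuffers, [&](uint8_t *data, size_t dataSize) {
    (void)data;  // real code decodes metric records from here
    std::printf("device %zu: %zu bytes, %zu pending graph(s)\n", resultIdx,
                dataSize, pendingGraphsPerDevice[resultIdx]);
    resultIdx++;  // advance to the next device's pop result
  });
  return 0;
}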

third_party/proton/test/test_profile.py

Lines changed: 91 additions & 0 deletions
@@ -669,3 +669,94 @@ def fn():
     assert scope_a_frame["metrics"]["bytes"] == 160
     assert scope_b_frame is not None
     assert scope_b_frame["metrics"]["sum"] == 40.0
+
+
+@pytest.mark.skipif(is_hip(), reason="HIP backend does not support metrics profiling in cudagraphs")
+def test_tensor_metrics_multi_device_cudagraph(tmp_path: pathlib.Path):
+    if torch.cuda.device_count() < 2:
+        pytest.skip("Requires at least two CUDA devices")
+
+    devices = [torch.device(f"cuda:{i}") for i in range(2)]
+    streams = []
+    for device in devices:
+        with torch.cuda.device(device):
+            streams.append(torch.cuda.Stream(device=device))
+
+    def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict):
+        x = args["x"]
+        x_sum = x.sum()
+        device_idx = x.device.index
+        return {"name": f"foo_test_{device_idx}", "bytes": x.numel() * x.element_size(), "flops": x_sum}
+
+    @triton.jit(launch_metadata=metadata_fn)
+    def foo(x, y, z):
+        tl.store(z, tl.load(y) + tl.load(x))
+
+    def run_on_device(device_id):
+        with proton.scope(f"scope_a_{device_id}", metrics={"bytes": 4 * 4}):
+            a = torch.ones((2, 2), device=f"cuda:{device_id}")
+        with proton.metadata_state():
+            a_sum = a.sum()
+        with proton.scope(f"scope_b_{device_id}", metrics={"sum": a_sum}):
+            b = torch.ones((2, 2), device=f"cuda:{device_id}")
+        c = a + b
+        foo[(1, )](a, b, c)
+
+    temp_file = tmp_path / "test_tensor_metrics_multi_device_cudagraph.hatchet"
+    proton.start(str(temp_file.with_suffix("")), context="shadow", hook="triton")
+
+    graphs = []
+    for device, stream in zip(devices, streams):
+        with torch.cuda.device(device):
+            torch.cuda.set_stream(stream)
+            # warmup
+            run_on_device(device.index)
+            # graph capture
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g, stream=stream):
+                for _ in range(10):
+                    run_on_device(device.index)
+            graphs.append((device, stream, g))
+
+    for device, stream, graph in graphs:
+        with torch.cuda.device(device):
+            torch.cuda.set_stream(stream)
+            with proton.scope(f"test_device_{device.index}"):
+                graph.replay()
+
+    proton.finalize()
+
+    with temp_file.open() as f:
+        data = json.load(f)
+
+    children = data[0]["children"]
+    for device in devices:
+        device_name = f"test_device_{device.index}"
+        launch_frame = next((child for child in children if child["frame"]["name"] == device_name), None)
+        assert launch_frame is not None
+        capture_at_frame = launch_frame["children"][0]
+        assert capture_at_frame["frame"]["name"] == "<captured_at>"
+
+        foo_frame = None
+        scope_a_frame = None
+        scope_b_frame = None
+        for child in capture_at_frame["children"]:
+            if child["frame"]["name"] == f"foo_test_{device.index}":
+                foo_frame = child
+            if child["frame"]["name"] == f"scope_a_{device.index}":
+                scope_a_frame = child
+            if child["frame"]["name"] == f"scope_b_{device.index}":
+                scope_b_frame = child
+
+        assert foo_frame is not None
+        assert scope_a_frame is not None
+        assert scope_b_frame is not None
+        assert foo_frame["metrics"]["bytes"] == 160
+        assert foo_frame["metrics"]["flops"] == 40
+        assert foo_frame["metrics"]["device_id"] == str(device.index)
+        assert scope_a_frame["metrics"]["bytes"] == 160
+        assert scope_b_frame["metrics"]["sum"] == 40.0
+
+    assert len(data) > 1
+    cuda_devices = data[1].get("CUDA", {})
+    assert len(cuda_devices) >= 2
