@@ -268,53 +268,73 @@ struct GraphState {
268268 size_t numInstances{1 };
269269};
270270
271- // FIXME: it should be a per-stream queue in case we capture graphs from
272- // different streams or different devices
271+ // Track pending graphs per device so flushing a single device won't drain
272+ // graphs from other devices.
273273class PendingGraphQueue {
274274public:
275+ explicit PendingGraphQueue (Runtime *runtime) : runtime(runtime) {}
276+
275277 struct PendingGraph {
276278 size_t externId;
277279 std::map<Data *, std::vector<std::pair<bool , size_t >>> dataToScopeIds;
278280 size_t numMetricNodes;
279281 };
280282 using PopResult = std::pair<size_t , std::vector<PendingGraph>>;
281283
282- PendingGraphQueue () = default ;
283-
284284 void push (size_t externId,
285285 const std::map<Data *, std::vector<std::pair<bool , size_t >>>
286286 &dataToScopeIds,
287287 size_t numNodes) {
288288 std::lock_guard<std::mutex> lock (mutex);
289- pendingGraphs.push_back (PendingGraph{externId, dataToScopeIds, numNodes});
290- this ->totalNumNodes += numNodes;
289+ auto device = runtime->getDevice ();
290+ auto &queue = deviceQueues[device];
291+ queue.pendingGraphs .push_back (
292+ PendingGraph{externId, dataToScopeIds, numNodes});
293+ queue.totalNumNodes += numNodes;
291294 }
292295
293- PopResult popAllIfReachCapacity (size_t numNewNodes, size_t capacity) {
296+ PopResult pop (size_t numNewNodes, size_t capacity) {
294297 std::lock_guard<std::mutex> lock (mutex);
295- if ((this ->totalNumNodes + numNewNodes) * 2 * sizeof (uint64_t ) <=
298+ if (deviceQueues.empty ()) {
299+ return {0 , {}};
300+ }
301+ auto device = runtime->getDevice ();
302+ auto &queue = deviceQueues[device];
303+ if ((queue.totalNumNodes + numNewNodes) * 2 * sizeof (uint64_t ) <=
296304 capacity) {
297305 return {0 , {}};
298306 }
299- return popAllLocked ( );
307+ return popLocked (queue );
300308 }
301309
302- PopResult popAll () {
310+ std::vector< PopResult> popAll () {
303311 std::lock_guard<std::mutex> lock (mutex);
304- return popAllLocked ();
312+ if (deviceQueues.empty ()) {
313+ return {{0 , {}}};
314+ }
315+ std::vector<PopResult> results;
316+ for (auto &[device, queue] : deviceQueues) {
317+ results.emplace_back (popLocked (queue));
318+ }
319+ return results;
305320 }
306321
307322private:
308- PopResult popAllLocked () {
323+ struct Queue {
324+ size_t totalNumNodes{};
325+ std::vector<PendingGraph> pendingGraphs;
326+ };
327+
328+ PopResult popLocked (Queue &queue) {
309329 std::vector<PendingGraph> items;
310- items.swap (pendingGraphs);
311- size_t numNodes = totalNumNodes;
312- totalNumNodes = 0 ;
330+ items.swap (queue. pendingGraphs );
331+ size_t numNodes = queue. totalNumNodes ;
332+ queue. totalNumNodes = 0 ;
313333 return {numNodes, items};
314334 }
315335
316- size_t totalNumNodes {};
317- std::vector<PendingGraph> pendingGraphs ;
336+ Runtime *runtime {};
337+ std::map< void *, Queue> deviceQueues ;
318338 mutable std::mutex mutex;
319339};
320340
@@ -323,7 +343,8 @@ class PendingGraphQueue {
323343struct CuptiProfiler ::CuptiProfilerPimpl
324344 : public GPUProfiler<CuptiProfiler>::GPUProfilerPimplInterface {
325345 CuptiProfilerPimpl (CuptiProfiler &profiler)
326- : GPUProfiler<CuptiProfiler>::GPUProfilerPimplInterface(profiler) {
346+ : GPUProfiler<CuptiProfiler>::GPUProfilerPimplInterface(profiler),
347+ pendingGraphQueue (&CudaRuntime::instance ()) {
327348 runtime = &CudaRuntime::instance ();
328349 metricBuffer = std::make_unique<MetricBuffer>(1024 * 1024 * 64 , runtime);
329350 }
@@ -619,8 +640,33 @@ void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData,
619640 }
620641 }
621642 }
643+ }
644+ }
645+ profiler.correlation .correlate (callbackData->correlationId , numInstances);
646+ if (profiler.pcSamplingEnabled && isDriverAPILaunch (cbId)) {
647+ pImpl->pcSampling .start (callbackData->context );
648+ }
649+ } else if (callbackData->callbackSite == CUPTI_API_EXIT) {
650+ auto externId = profiler.correlation .externIdQueue .back ();
651+ if (profiler.pcSamplingEnabled && isDriverAPILaunch (cbId)) {
652+ // XXX: Conservatively stop every GPU kernel for now
653+ pImpl->pcSampling .stop (
654+ callbackData->context , externId,
655+ profiler.correlation .apiExternIds .contain (externId));
656+ }
657+ if (cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch ||
658+ cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch_ptsz) {
659+ // Cuda context can be lazily initialized, so we need to call device get
660+ // here after the first kernel is launched
661+ auto graphExec = static_cast <const cuGraphLaunch_params *>(
662+ callbackData->functionParams )
663+ ->hGraph ;
664+ uint32_t graphExecId = 0 ;
665+ cupti::getGraphExecId<true >(graphExec, &graphExecId);
666+ if (pImpl->graphStates .contain (graphExecId)) {
622667 std::map<Data *, std::vector<std::pair<bool , size_t >>>
623668 metricNodeScopes;
669+ auto dataSet = profiler.getDataSet ();
624670 for (auto *data : dataSet) {
625671 auto &nodeToScopeId =
626672 profiler.correlation .externIdToGraphNodeScopeId [externId][data];
@@ -639,8 +685,8 @@ void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData,
639685 pImpl->metricBuffer ->getCapacity (); // bytes
640686 auto metricNodeCount =
641687 pImpl->graphStates [graphExecId].metricKernelNodeIds .size ();
642- auto drained = pImpl->pendingGraphQueue .popAllIfReachCapacity (
643- metricNodeCount, metricBufferCapacity);
688+ auto drained = pImpl->pendingGraphQueue .pop (metricNodeCount,
689+ metricBufferCapacity);
644690 if (drained.first != 0 ) { // Reached capacity
645691 pImpl->metricBuffer ->flush ([&](uint8_t *data, size_t dataSize) {
646692 auto *recordPtr = reinterpret_cast <uint64_t *>(data);
@@ -651,18 +697,6 @@ void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData,
651697 metricNodeCount);
652698 }
653699 }
654- profiler.correlation .correlate (callbackData->correlationId , numInstances);
655- if (profiler.pcSamplingEnabled && isDriverAPILaunch (cbId)) {
656- pImpl->pcSampling .start (callbackData->context );
657- }
658- } else if (callbackData->callbackSite == CUPTI_API_EXIT) {
659- if (profiler.pcSamplingEnabled && isDriverAPILaunch (cbId)) {
660- // XXX: Conservatively stop every GPU kernel for now
661- auto scopeId = profiler.correlation .externIdQueue .back ();
662- pImpl->pcSampling .stop (
663- callbackData->context , scopeId,
664- profiler.correlation .apiExternIds .contain (scopeId));
665- }
666700 threadState.exitOp ();
667701 profiler.correlation .submit (callbackData->correlationId );
668702 }
@@ -713,14 +747,17 @@ void CuptiProfiler::CuptiProfilerPimpl::doFlush() {
713747 // new activities.
714748 cupti::activityFlushAll<true >(/* flag=*/ CUPTI_ACTIVITY_FLAG_FLUSH_FORCED);
715749 // Flush the tensor metric buffer
716- auto dataSet = profiler.getDataSet ();
717750 auto popResult = pendingGraphQueue.popAll ();
718- metricBuffer->flush (
719- [&](uint8_t *data, size_t dataSize) {
720- auto *recordPtr = reinterpret_cast <uint64_t *>(data);
721- emitMetricRecords (recordPtr, popResult.second );
722- },
723- /* flushAll=*/ true );
751+ if (!popResult.empty ()) {
752+ auto resultIdx = 0 ;
753+ metricBuffer->flush (
754+ [&](uint8_t *data, size_t dataSize) {
755+ auto *recordPtr = reinterpret_cast <uint64_t *>(data);
756+ emitMetricRecords (recordPtr, popResult[resultIdx].second );
757+ resultIdx++;
758+ },
759+ /* flushAll=*/ true );
760+ }
724761}
725762
726763void CuptiProfiler::CuptiProfilerPimpl::doStop () {
0 commit comments