@@ -519,11 +519,9 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(
519519 // DEFAULT_FLAGS = cudaEventDisableTiming.
520520 if (cudaEventCacheEnabled) {
521521 ncclStartEvent_ = enableTiming
522- ? ProcessGroupNCCL::CUDAEventCache::get (device.index ())
523- ->create (enableTiming)
522+ ? CUDAEventCache::get (device.index ())->create (enableTiming)
524523 : nullptr ;
525- ncclEndEvent_ = ProcessGroupNCCL::CUDAEventCache::get (device.index ())
526- ->create (enableTiming);
524+ ncclEndEvent_ = CUDAEventCache::get (device.index ())->create (enableTiming);
527525 } else {
528526 ncclStartEvent_ = enableTiming
529527 ? std::make_shared<at::cuda::CUDAEvent>(cudaEventDefault)
@@ -860,61 +858,6 @@ void ProcessGroupNCCL::WorkNCCL::abort() {
860858 }
861859}
862860
863- ProcessGroupNCCL::CUDAEventCache::CUDAEventCache () = default;
864-
865- // CUDA event is used to record the start/end of one Work.
866- // Instead of let the CUDA event gets destroyed, we now reuse it after the Work
867- // has been erased from workMetaList_.
868- // This is to avoid the potential deadlock caused by CudaEventDestroy.
869- std::shared_ptr<at::cuda::CUDAEvent> ProcessGroupNCCL::CUDAEventCache::create (
870- bool timing) {
871- // Register the deleter as a callback when the WorkNCCL object is destroyed.
872- // Each deleter keeps a ref count to the cache object, so that even when
873- // the thread that creates the cache is gone, the cache object won't be
874- // destroyed until all the events in the cache are destroyed (ref number drops
875- // to zero).
876- auto deleter = [cache = shared_from_this (),
877- timing](at::cuda::CUDAEvent* event) {
878- std::lock_guard<std::mutex> lock (cache->cacheMutex_ );
879- // We put the event back to the cache deque once the WorkNCCL object is
880- // destroyed.
881- cache->eventsArray_ [timing ? 1 : 0 ].push_back (event);
882- };
883- at::cuda::CUDAEvent* event = nullptr ;
884- {
885- std::lock_guard<std::mutex> lock (cacheMutex_);
886- auto & events = eventsArray_[timing ? 1 : 0 ];
887- // If we still have events in the cache, we reuse it. Otherwise, we create a
888- // new one.
889- if (!events.empty ()) {
890- event = events.front ();
891- events.pop_front ();
892- } else {
893- event = new at::cuda::CUDAEvent (
894- timing ? cudaEventDefault : cudaEventDisableTiming);
895- }
896- }
897- return std::shared_ptr<at::cuda::CUDAEvent>(event, std::move (deleter));
898- }
899-
900- std::shared_ptr<ProcessGroupNCCL::CUDAEventCache> ProcessGroupNCCL::
901- CUDAEventCache::get (at::DeviceIndex device) {
902- // A per-thread singleton of device-to-CUDAEventCache map.
903- // Map is needed because events cannot be reused across devices.
904- // Per-thread ownership is needed to support multi-threaded case (instead of
905- // multi-process case).
906- static thread_local std::
907- map<at::DeviceIndex, std::shared_ptr<ProcessGroupNCCL::CUDAEventCache>>
908- cacheDeviceMap;
909- // Check if device has already been in the map, if not, add a new entry
910- auto it = cacheDeviceMap.find (device);
911- if (it == cacheDeviceMap.end ()) {
912- cacheDeviceMap.emplace (
913- device, std::make_shared<ProcessGroupNCCL::CUDAEventCache>());
914- }
915- return cacheDeviceMap[device];
916- }
917-
918861static std::atomic<size_t > process_group_id = 0 ;
919862
920863constexpr const char * MULTI_DEVICE_ERROR_MSG =
0 commit comments