@@ -418,12 +418,18 @@ auto build_graph_and_tensors(
           .set_name("Seed")
           .set_dim({1, 1, 1, 1})
           .set_stride({1, 1, 1, 1})
-          .set_data_type(fe::DataType_t::INT32));
+          .set_data_type(
+              dropoutseed.dtype() == kInt
+                  ? fe::DataType_t::INT32
+                  : fe::DataType_t::INT64));
   auto offset = mha_graph->tensor(fe::graph::Tensor_attributes()
           .set_name("Offset")
           .set_dim({1, 1, 1, 1})
           .set_stride({1, 1, 1, 1})
-          .set_data_type(fe::DataType_t::INT32));
+          .set_data_type(
+              dropoutoffset.dtype() == kInt
+                  ? fe::DataType_t::INT32
+                  : fe::DataType_t::INT64));
   auto scaled_dot_product_flash_attention_options =
       fe::graph::SDPA_attributes()
           .set_name("CUDNN_SDPA")
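
Note: the ternary above keys the graph tensor's cuDNN data type off the ATen dtype of the incoming seed/offset tensor, so the same builder accepts either int32 or int64 seeds. A minimal sketch of that pattern factored into a helper; the name make_seed_offset_tensor is hypothetical and not part of this patch:

// Sketch only: assumes the cuDNN frontend (namespace fe) and the ATen
// kInt/kLong scalar types used in the hunk above.
static auto make_seed_offset_tensor(
    std::shared_ptr<fe::graph::Graph> mha_graph,
    const at::Tensor& t,
    const char* name) {
  return mha_graph->tensor(fe::graph::Tensor_attributes()
                               .set_name(name)
                               .set_dim({1, 1, 1, 1})
                               .set_stride({1, 1, 1, 1})
                               .set_data_type(
                                   t.dtype() == at::kInt
                                       ? fe::DataType_t::INT32
                                       : fe::DataType_t::INT64));
}
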
@@ -564,12 +570,20 @@ auto build_graph_and_tensors_backward(
564570 .set_name (" Seed" )
565571 .set_dim ({1 , 1 , 1 , 1 })
566572 .set_stride ({1 , 1 , 1 , 1 })
567- .set_data_type (fe::DataType_t::INT32));
573+ .set_data_type (
574+ dropoutseed.dtype () == kInt
575+ ? fe::DataType_t::INT32
576+ : fe::DataType_t::INT64));
577+
568578 auto Offset = mha_graph->tensor (fe::graph::Tensor_attributes ()
569579 .set_name (" Offset" )
570580 .set_dim ({1 , 1 , 1 , 1 })
571581 .set_stride ({1 , 1 , 1 , 1 })
572- .set_data_type (fe::DataType_t::INT32));
582+ .set_data_type (
583+ dropoutoffset.dtype () == kInt
584+ ? fe::DataType_t::INT32
585+ : fe::DataType_t::INT64));
586+
573587 auto O = mha_graph->tensor (fe::graph::Tensor_attributes ()
574588 .set_name (" O" )
575589 .set_dim (o.sizes ().vec ())
@@ -633,6 +647,15 @@ void run_cudnn_SDP_fprop(
     Tensor& o,
     Tensor& dropoutseed,
    Tensor& dropoutoffset) {
+  const auto dprops = at::cuda::getCurrentDeviceProperties();
+  auto _dropoutseed = dropoutseed;
+  auto _dropoutoffset = dropoutoffset;
+  // cuDNN dropout bug requires these to be in int64
+  if (dprops->major == 10 && dprops->minor == 0) {
+    _dropoutseed = dropoutseed.to(kLong);
+    _dropoutoffset = dropoutoffset.to(kLong);
+  }
+
   cudnnHandle_t handle = getCudnnHandle();
   if (!o.defined()) {
     // q is passed to us in BHSD dim order
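
Note: compute capability major 10, minor 0 corresponds to sm_100 (Blackwell). The workaround promotes the seed/offset tensors to int64 before graph construction, which is why the graph builders above now pick INT32 vs. INT64 from the tensor dtype. A standalone sketch of the same gate using only the CUDA runtime API; the function name is illustrative, not from the patch:

#include <cuda_runtime.h>

// Sketch only: mirrors the dprops->major == 10 && dprops->minor == 0
// check above, using the raw CUDA runtime instead of ATen's cached
// device properties.
static bool needs_int64_dropout_workaround() {
  int device = 0;
  cudaGetDevice(&device);
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device);
  return prop.major == 10 && prop.minor == 0;
}
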
@@ -685,8 +708,8 @@ void run_cudnn_SDP_fprop(
         attn_bias,
         softmaxstats,
         o,
-        dropoutseed,
-        dropoutoffset,
+        _dropoutseed,
+        _dropoutoffset,
         handle);
   }
   auto [mha_graph, Q, K, V, bias, attn_scale, seed, offset, O, Stats] =
@@ -697,8 +720,8 @@ void run_cudnn_SDP_fprop(
       {K, k.data_ptr()},
       {V, v.data_ptr()},
       {attn_scale, &scaling_factor},
-      {seed, dropoutseed.data_ptr()},
-      {offset, dropoutoffset.data_ptr()},
+      {seed, _dropoutseed.data_ptr()},
+      {offset, _dropoutoffset.data_ptr()},
       {O, o.data_ptr()}};
   if (return_softmaxstats) {
     variant_pack[Stats] = softmaxstats.data_ptr();
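
Note: the variant pack is the cuDNN frontend's map from graph tensor handles to runtime pointers; device tensors contribute data_ptr(), while host scalars such as scaling_factor are passed by address. A hedged sketch of how a pack like this is consumed at execution time, following the frontend's graph API (workspace sizing as in the frontend samples; error handling elided):

// Sketch only: execute the built graph against the current cuDNN handle.
// mha_graph, variant_pack, and handle are the objects from the code above.
auto workspace = at::empty(
    {mha_graph->get_workspace_size()}, q.options().dtype(at::kByte));
AT_ASSERT(
    mha_graph->execute(handle, variant_pack, workspace.data_ptr()).is_good());
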
@@ -741,6 +764,14 @@ void run_cudnn_SDP_bprop(
       !softmaxstats.numel()) {
     return;
   }
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  auto _dropoutseed = dropoutseed;
+  auto _dropoutoffset = dropoutoffset;
+  // cuDNN dropout bug requires these to be in int64
+  if (dprops->major == 10 && dprops->minor == 0) {
+    _dropoutseed = dropoutseed.to(kLong);
+    _dropoutoffset = dropoutoffset.to(kLong);
+  }
 
   Tensor dO_ = dO;
   // cuDNN < 9.5.1 assumes gradOutput has same strides as Output
@@ -803,8 +834,8 @@ void run_cudnn_SDP_bprop(
         dQ,
         dK,
         dV,
-        dropoutseed,
-        dropoutoffset,
+        _dropoutseed,
+        _dropoutoffset,
         handle);
   }
   auto
@@ -837,8 +868,8 @@ void run_cudnn_SDP_bprop(
       // pass by value
       {attn_scale, &scaling_factor}};
   if (dropout_probability != 0.0f) {
-    variant_pack[Seed] = dropoutseed.data_ptr();
-    variant_pack[Offset] = dropoutoffset.data_ptr();
+    variant_pack[Seed] = _dropoutseed.data_ptr();
+    variant_pack[Offset] = _dropoutoffset.data_ptr();
   }
   if (attn_bias.has_value()) {
     variant_pack[bias.value()] = attn_bias.value().data_ptr();