
Commit a58f421

eqy authored and pytorchmergebot committed
[CUDA][CUDNN][SDPA] Pass dropout seed and offset to cuDNN in int64 (pytorch#146734)

Workaround for a limitation in cuDNN that does not accept dropout seed/offset in `int32` for SM 10.0 kernels.

Pull Request resolved: pytorch#146734
Approved by: https://github.com/Skylion007
1 parent 281249b commit a58f421
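
For context, the core of the workaround is a device-capability gate followed by a dtype promotion, applied before the seed/offset tensors reach graph construction. Below is a minimal standalone sketch of that pattern; the helper name promote_dropout_state_for_sm100 is hypothetical, and the actual commit inlines this logic directly in run_cudnn_SDP_fprop and run_cudnn_SDP_bprop (see the diff below).

#include <utility>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

// Sketch only: on SM 10.0 devices, promote int32 dropout seed/offset
// tensors to int64 before handing them to cuDNN, since the SM 10.0 SDPA
// kernels do not accept them in int32. The helper name is hypothetical;
// the commit inlines this logic rather than factoring it out.
std::pair<at::Tensor, at::Tensor> promote_dropout_state_for_sm100(
    const at::Tensor& dropoutseed,
    const at::Tensor& dropoutoffset) {
  const auto* dprops = at::cuda::getCurrentDeviceProperties();
  if (dprops->major == 10 && dprops->minor == 0) {
    // .to(at::kLong) produces fresh int64 copies; the original int32
    // tensors remain untouched for other code paths.
    return {dropoutseed.to(at::kLong), dropoutoffset.to(at::kLong)};
  }
  return {dropoutseed, dropoutoffset};
}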

File tree

1 file changed: +43 −12 lines changed


aten/src/ATen/native/cudnn/MHA.cpp

Lines changed: 43 additions & 12 deletions
@@ -418,12 +418,18 @@ auto build_graph_and_tensors(
                                .set_name("Seed")
                                .set_dim({1, 1, 1, 1})
                                .set_stride({1, 1, 1, 1})
-                               .set_data_type(fe::DataType_t::INT32));
+                               .set_data_type(
+                                   dropoutseed.dtype() == kInt
+                                       ? fe::DataType_t::INT32
+                                       : fe::DataType_t::INT64));
   auto offset = mha_graph->tensor(fe::graph::Tensor_attributes()
                                       .set_name("Offset")
                                       .set_dim({1, 1, 1, 1})
                                       .set_stride({1, 1, 1, 1})
-                                      .set_data_type(fe::DataType_t::INT32));
+                                      .set_data_type(
+                                          dropoutoffset.dtype() == kInt
+                                              ? fe::DataType_t::INT32
+                                              : fe::DataType_t::INT64));
   auto scaled_dot_product_flash_attention_options =
       fe::graph::SDPA_attributes()
           .set_name("CUDNN_SDPA")
@@ -564,12 +570,20 @@ auto build_graph_and_tensors_backward(
                               .set_name("Seed")
                               .set_dim({1, 1, 1, 1})
                               .set_stride({1, 1, 1, 1})
-                              .set_data_type(fe::DataType_t::INT32));
+                              .set_data_type(
+                                  dropoutseed.dtype() == kInt
+                                      ? fe::DataType_t::INT32
+                                      : fe::DataType_t::INT64));
+
   auto Offset = mha_graph->tensor(fe::graph::Tensor_attributes()
                                       .set_name("Offset")
                                       .set_dim({1, 1, 1, 1})
                                       .set_stride({1, 1, 1, 1})
-                                      .set_data_type(fe::DataType_t::INT32));
+                                      .set_data_type(
+                                          dropoutoffset.dtype() == kInt
+                                              ? fe::DataType_t::INT32
+                                              : fe::DataType_t::INT64));
+
   auto O = mha_graph->tensor(fe::graph::Tensor_attributes()
                                  .set_name("O")
                                  .set_dim(o.sizes().vec())
@@ -633,6 +647,15 @@ void run_cudnn_SDP_fprop(
     Tensor& o,
     Tensor& dropoutseed,
     Tensor& dropoutoffset) {
+  const auto dprops = at::cuda::getCurrentDeviceProperties();
+  auto _dropoutseed = dropoutseed;
+  auto _dropoutoffset = dropoutoffset;
+  // cuDNN dropout bug requires these to be in int64
+  if (dprops->major == 10 && dprops->minor == 0) {
+    _dropoutseed = dropoutseed.to(kLong);
+    _dropoutoffset = dropoutoffset.to(kLong);
+  }
+
   cudnnHandle_t handle = getCudnnHandle();
   if (!o.defined()) {
     // q is passed to us in BHSD dim order
@@ -685,8 +708,8 @@ void run_cudnn_SDP_fprop(
         attn_bias,
         softmaxstats,
         o,
-        dropoutseed,
-        dropoutoffset,
+        _dropoutseed,
+        _dropoutoffset,
         handle);
   }
   auto [mha_graph, Q, K, V, bias, attn_scale, seed, offset, O, Stats] =
@@ -697,8 +720,8 @@ void run_cudnn_SDP_fprop(
       {K, k.data_ptr()},
       {V, v.data_ptr()},
       {attn_scale, &scaling_factor},
-      {seed, dropoutseed.data_ptr()},
-      {offset, dropoutoffset.data_ptr()},
+      {seed, _dropoutseed.data_ptr()},
+      {offset, _dropoutoffset.data_ptr()},
       {O, o.data_ptr()}};
   if (return_softmaxstats) {
     variant_pack[Stats] = softmaxstats.data_ptr();
@@ -741,6 +764,14 @@ void run_cudnn_SDP_bprop(
       !softmaxstats.numel()) {
     return;
   }
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  auto _dropoutseed = dropoutseed;
+  auto _dropoutoffset = dropoutoffset;
+  // cuDNN dropout bug requires these to be in int64
+  if (dprops->major == 10 && dprops->minor == 0) {
+    _dropoutseed = dropoutseed.to(kLong);
+    _dropoutoffset = dropoutoffset.to(kLong);
+  }

   Tensor dO_ = dO;
   // cuDNN < 9.5.1 assumes gradOutput has same strides as Output
@@ -803,8 +834,8 @@ void run_cudnn_SDP_bprop(
         dQ,
         dK,
         dV,
-        dropoutseed,
-        dropoutoffset,
+        _dropoutseed,
+        _dropoutoffset,
         handle);
   }
   auto
@@ -837,8 +868,8 @@ void run_cudnn_SDP_bprop(
       // pass by value
       {attn_scale, &scaling_factor}};
   if (dropout_probability != 0.0f) {
-    variant_pack[Seed] = dropoutseed.data_ptr();
-    variant_pack[Offset] = dropoutoffset.data_ptr();
+    variant_pack[Seed] = _dropoutseed.data_ptr();
+    variant_pack[Offset] = _dropoutoffset.data_ptr();
   }
   if (attn_bias.has_value()) {
     variant_pack[bias.value()] = attn_bias.value().data_ptr();
