Skip to content

Commit dbc93d4

Browse files
jwfrommfacebook-github-bot
authored andcommitted
Fix overflow in attention workspace calculation (pytorch#4939)
Summary: X-link: facebookresearch/FBGEMM#1961 Pull Request resolved: pytorch#4939 During workspace calculation some of the additions risked overflowing int32 and returning an invalid value. This diff makes sure that we're doing higher precision arithmetic in the workspace to prevent this issue. Reviewed By: q10 Differential Revision: D83370379 fbshipit-source-id: b481d16ce4053a21ac19f1060eec6d57bbeb9eda
1 parent 7a1e337 commit dbc93d4

File tree

1 file changed

+5
-4
lines changed
  • fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device

1 file changed

+5
-4
lines changed

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha_device_bwd.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -267,14 +267,15 @@ class Sm100FmhaBwd {
267267
auto [Q_, K, D, D_VO, HB] = args.problem_shape;
268268
auto [H, B] = product_each(HB);
269269
D = cutlass::round_up(D, 8); // Alignment
270-
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
270+
size_t Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
271271
size_t workspace_bytes = 0;
272+
size_t accum_size = sizeof(ElementAccumulator);
272273
// OdO vector
273-
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
274+
workspace_bytes += static_cast<size_t>(B)*static_cast<size_t>(H)*Q * accum_size;
274275
// scaled LSE vector
275-
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
276+
workspace_bytes += static_cast<size_t>(B)*static_cast<size_t>(H)*Q * accum_size;
276277
// FP32 versions of outputs that are churned (start off with Q only)
277-
workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator);
278+
workspace_bytes += static_cast<size_t>(B)*static_cast<size_t>(H)*Q*static_cast<size_t>(D) * accum_size;
278279
return workspace_bytes;
279280
}
280281

0 commit comments

Comments
 (0)