Skip to content

Commit ef94b5b

Browse files
Updated cudnn executor checker (Lightning-AI#2771)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f6f0496 commit ef94b5b

File tree

2 files changed

+110
-3
lines changed

2 files changed

+110
-3
lines changed

thunder/executors/cudnn_sdpa.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,9 +334,36 @@ def _cudnn_sdpa_checker(
334334
return False
335335
_, _, _, d_kv = value.size()
336336

337-
# Bug in cudnn 8.9.5 and earlier where embedding dim support is missing
338-
for d in [d_q, d_kv]:
339-
if d % 8 != 0 or d > 128:
337+
# Embedding dim must be divisible by 8.
338+
# Max dimension depends on cuDNN version and GPU arch:
339+
# - Hopper (SM90) with cuDNN 9.x: max is 256
340+
# - Blackwell (SM100) with cuDNN 9.11+: max is 128 (special case: d_qk=192 with d_v=128)
341+
# - Other GPUs: max is 128
342+
# https://github.com/NVIDIA/cudnn-frontend/blob/v1.16.0/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h#L816
343+
if d_q % 8 != 0 or d_kv % 8 != 0:
344+
return False
345+
is_supported_dim = False
346+
if cudnn_backend_version >= 90000:
347+
cc_major = torch.cuda.get_device_capability(query.device.index)[0]
348+
349+
if cc_major == 9: # Hopper
350+
# Validate basic dimension requirements
351+
if 91100 <= cudnn_backend_version < 91300:
352+
if 128 < d_q <= 192 and 64 < d_kv <= 128:
353+
# DeepSeek case, 9.11 only supports 192 hidden dim
354+
if d_kv != 128 and d_q != 192:
355+
return False
356+
357+
is_supported_dim = d_q <= 256 and d_kv <= 256
358+
359+
elif cc_major == 10 and cudnn_backend_version >= 91100: # Blackwell with cuDNN 9.11+
360+
if d_q == 192:
361+
is_supported_dim = d_kv == 128
362+
else:
363+
is_supported_dim = d_q <= 128 and d_kv <= 128
364+
if not is_supported_dim:
365+
# Check fallback for older GPUs or unsupported newer configs
366+
if d_q > 128 or d_kv > 128:
340367
return False
341368

342369
dropout_p = pyval(dropout_p)

thunder/tests/test_cudnn_executor.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,77 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req
8585
v_broadcast = torch.as_strided(v, size=v.shape, stride=(0, 0, Ev, 1))
8686
yield SampleInput(q_broadcast, k_broadcast, v_broadcast, None, dropout_p=0.0, is_causal=True)
8787

88+
# Additional dimension test cases for different GPU architectures and cuDNN versions
89+
b, h, s_q, s_kv = 2, 4, 64, 64
90+
91+
# Standard dimensions - should work on all GPUs (d_q <= 128, d_kv <= 128)
92+
standard_dims = [
93+
(64, 64), # standard small
94+
(32, 32), # divisible by 8 small
95+
(96, 96), # divisible by 8 mid
96+
(120, 88), # divisible by 8 asymmetric
97+
]
98+
for d_q, d_kv in standard_dims:
99+
q = make(b, h, s_q, d_q)
100+
k = make(b, h, s_kv, d_q)
101+
v = make(b, h, s_kv, d_kv)
102+
yield SampleInput(q, k, v, None, dropout_p=0.0, is_causal=True)
103+
104+
# Larger dimensions - only supported on Hopper (SM90) with cuDNN 9.x
105+
hopper_dims = [
106+
(192, 192), # hopper 192
107+
(256, 256), # hopper max
108+
(256, 128), # hopper asymmetric
109+
]
110+
for d_q, d_kv in hopper_dims:
111+
q = make(b, h, s_q, d_q)
112+
k = make(b, h, s_kv, d_q)
113+
v = make(b, h, s_kv, d_kv)
114+
yield SampleInput(q, k, v, None, dropout_p=0.0, is_causal=True)
115+
116+
# DeepSeek-style dimensions (d_q=192, d_kv=128) - Hopper 9.11+ or Blackwell 9.11+
117+
d_q, d_kv = 192, 128
118+
q = make(b, h, s_q, d_q)
119+
k = make(b, h, s_kv, d_q)
120+
v = make(b, h, s_kv, d_kv)
121+
yield SampleInput(q, k, v, None, dropout_p=0.0, is_causal=True)
122+
123+
124+
def _should_skip_sdpa_sample(sample) -> str | None:
    """Return a skip reason if the SDPA sample dimensions are not supported on current GPU/cuDNN, else None."""
    # The sample packs (query, key, value, ...) positionally; only the head dims matter here.
    query, _key, value = sample.args[:3]
    head_dim_qk = query.shape[-1]
    head_dim_v = value.shape[-1]

    backend_version = cudnn.backend_version()
    sm_major = torch.cuda.get_device_capability()[0]

    # Head dims up to 128 are supported everywhere — nothing to skip.
    if head_dim_qk <= 128 and head_dim_v <= 128:
        return None

    # Anything larger requires a cuDNN 9.x backend.
    if backend_version < 90000:
        return f"cuDNN 9.x required for dimensions > 128 (d_q={head_dim_qk}, d_kv={head_dim_v})"

    # DeepSeek-style head dims: d_qk=192 paired with d_v=128.
    if (head_dim_qk, head_dim_v) == (192, 128):
        if backend_version < 91100:
            return "cuDNN 9.11+ required for DeepSeek dimensions (d_q=192, d_kv=128)"
        if sm_major not in (9, 10):
            return "Hopper (SM90) or Blackwell (SM100) required for DeepSeek dimensions"
        return None

    # At this point at least one head dim exceeds 128: only Hopper supports that,
    # and only up to 256.
    if sm_major != 9:
        return f"Hopper GPU (SM90) required for dimensions > 128 (d_q={head_dim_qk}, d_kv={head_dim_v})"
    if head_dim_qk > 256 or head_dim_v > 256:
        return f"Dimensions exceed Hopper max of 256 (d_q={head_dim_qk}, d_kv={head_dim_v})"
    return None
158+
88159

89160
grad_sdpa_cudnn_opinfo = OpInfo(
90161
thunder.torch.scaled_dot_product_attention,
@@ -196,6 +267,11 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_):
196267
cfn = thunder.jit(op.op, executors=[cudnn_ex, cudnn_layernorm_ex])
197268

198269
for sample in op.reference_inputs(device, dtype, requires_grad=False):
270+
# Skip SDPA samples with unsupported dimensions for current GPU/cuDNN
271+
if op.name == "grad_forward_scaled_dot_product_attention":
272+
if _should_skip_sdpa_sample(sample):
273+
continue
274+
199275
result = run_snippet(
200276
snippet_torch_consistency,
201277
op,
@@ -225,6 +301,10 @@ def test_vjp_correctness_cudnn_sdpa(dtype, may_cat_grad_qkv):
225301
_maybe_xfail()
226302

227303
for sample in grad_sdpa_cudnn_opinfo.reference_inputs("cuda", dtype, requires_grad=True):
304+
# Skip samples with unsupported dimensions for current GPU/cuDNN
305+
if _should_skip_sdpa_sample(sample):
306+
continue
307+
228308
# Enforce tensor arguments are contiguous for torch reference
229309
contiguous_args = list(map(lambda a: a.contiguous() if isinstance(a, torch.Tensor) else a, sample.args))
230310

0 commit comments

Comments
 (0)