@@ -85,6 +85,77 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req
8585 v_broadcast = torch .as_strided (v , size = v .shape , stride = (0 , 0 , Ev , 1 ))
8686 yield SampleInput (q_broadcast , k_broadcast , v_broadcast , None , dropout_p = 0.0 , is_causal = True )
8787
88+ # Additional dimension test cases for different GPU architectures and cuDNN versions
89+ b , h , s_q , s_kv = 2 , 4 , 64 , 64
90+
91+ # Standard dimensions - should work on all GPUs (d_q <= 128, d_kv <= 128)
92+ standard_dims = [
93+ (64 , 64 ), # standard small
94+ (32 , 32 ), # divisible by 8 small
95+ (96 , 96 ), # divisible by 8 mid
96+ (120 , 88 ), # divisible by 8 asymmetric
97+ ]
98+ for d_q , d_kv in standard_dims :
99+ q = make (b , h , s_q , d_q )
100+ k = make (b , h , s_kv , d_q )
101+ v = make (b , h , s_kv , d_kv )
102+ yield SampleInput (q , k , v , None , dropout_p = 0.0 , is_causal = True )
103+
104+ # Larger dimensions - only supported on Hopper (SM90) with cuDNN 9.x
105+ hopper_dims = [
106+ (192 , 192 ), # hopper 192
107+ (256 , 256 ), # hopper max
108+ (256 , 128 ), # hopper asymmetric
109+ ]
110+ for d_q , d_kv in hopper_dims :
111+ q = make (b , h , s_q , d_q )
112+ k = make (b , h , s_kv , d_q )
113+ v = make (b , h , s_kv , d_kv )
114+ yield SampleInput (q , k , v , None , dropout_p = 0.0 , is_causal = True )
115+
116+ # DeepSeek-style dimensions (d_q=192, d_kv=128) - Hopper 9.11+ or Blackwell 9.11+
117+ d_q , d_kv = 192 , 128
118+ q = make (b , h , s_q , d_q )
119+ k = make (b , h , s_kv , d_q )
120+ v = make (b , h , s_kv , d_kv )
121+ yield SampleInput (q , k , v , None , dropout_p = 0.0 , is_causal = True )
122+
123+
def _should_skip_sdpa_sample(sample) -> str | None:
    """Return a skip reason if the SDPA sample's head dimensions are unsupported, else None.

    Inspects the query/value head dimensions of an SDPA SampleInput and checks them
    against the current GPU architecture and cuDNN backend version:

    - d_q <= 128 and d_kv <= 128: supported on all GPUs (never skipped).
    - d_q == 192, d_kv == 128 (DeepSeek-style): needs cuDNN 9.11+ on Hopper (SM90)
      or Blackwell (SM100).
    - 128 < d <= 256 otherwise: needs cuDNN 9.x on Hopper (SM90) only.

    Args:
        sample: A SampleInput whose first three positional args are the q, k, v tensors.

    Returns:
        A human-readable skip reason string, or None if the sample is supported.
    """
    q, k, v = sample.args[:3]
    d_q = q.shape[-1]   # query/key head dimension
    d_kv = v.shape[-1]  # value head dimension

    # Fast path first: standard dimensions work everywhere, so don't touch the
    # CUDA device or cuDNN backend at all for these samples.
    if d_q <= 128 and d_kv <= 128:
        return None

    cudnn_version = cudnn.backend_version()
    cc_major = torch.cuda.get_device_capability()[0]

    # Any head dimension > 128 requires cuDNN 9.x.
    if cudnn_version < 90000:
        return f"cuDNN 9.x required for dimensions > 128 (d_q={d_q}, d_kv={d_kv})"

    # DeepSeek-style case (d_q=192, d_kv=128): Hopper or Blackwell with cuDNN 9.11+.
    if d_q == 192 and d_kv == 128:
        if cudnn_version < 91100:
            return "cuDNN 9.11+ required for DeepSeek dimensions (d_q=192, d_kv=128)"
        if cc_major not in (9, 10):
            return "Hopper (SM90) or Blackwell (SM100) required for DeepSeek dimensions"
        return None

    # Remaining case is necessarily d_q > 128 or d_kv > 128 (the <=128 fast path
    # returned above), so no extra guard is needed here: Hopper-only, capped at 256.
    if cc_major != 9:
        return f"Hopper GPU (SM90) required for dimensions > 128 (d_q={d_q}, d_kv={d_kv})"
    if d_q > 256 or d_kv > 256:
        return f"Dimensions exceed Hopper max of 256 (d_q={d_q}, d_kv={d_kv})"
    return None
158+
88159
89160grad_sdpa_cudnn_opinfo = OpInfo (
90161 thunder .torch .scaled_dot_product_attention ,
@@ -196,6 +267,11 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_):
196267 cfn = thunder .jit (op .op , executors = [cudnn_ex , cudnn_layernorm_ex ])
197268
198269 for sample in op .reference_inputs (device , dtype , requires_grad = False ):
270+ # Skip SDPA samples with unsupported dimensions for current GPU/cuDNN
271+ if op .name == "grad_forward_scaled_dot_product_attention" :
272+ if _should_skip_sdpa_sample (sample ):
273+ continue
274+
199275 result = run_snippet (
200276 snippet_torch_consistency ,
201277 op ,
@@ -225,6 +301,10 @@ def test_vjp_correctness_cudnn_sdpa(dtype, may_cat_grad_qkv):
225301 _maybe_xfail ()
226302
227303 for sample in grad_sdpa_cudnn_opinfo .reference_inputs ("cuda" , dtype , requires_grad = True ):
304+ # Skip samples with unsupported dimensions for current GPU/cuDNN
305+ if _should_skip_sdpa_sample (sample ):
306+ continue
307+
228308 # Enforce tensor arguments are contiguous for torch reference
229309 contiguous_args = list (map (lambda a : a .contiguous () if isinstance (a , torch .Tensor ) else a , sample .args ))
230310
0 commit comments