@@ -1401,6 +1401,18 @@ def patched_sdpa_attention_forward(
     is_causal = attention_mask is None and is_causal
 
     if not is_causal:
+        torch._check(query.shape[0] > 0)
+        torch._check(query.shape[1] > 0)
+        torch._check(query.shape[2] > 0)
+        torch._check(query.shape[3] > 0)
+        torch._check(key.shape[0] > 0)
+        torch._check(key.shape[1] > 0)
+        torch._check(key.shape[2] > 0)
+        torch._check(key.shape[3] > 0)
+        torch._check(value.shape[0] > 0)
+        torch._check(value.shape[1] > 0)
+        torch._check(value.shape[2] > 0)
+        torch._check(value.shape[3] > 0)
         return (
             torch.nn.functional.scaled_dot_product_attention(
                 query,
@@ -2342,25 +2354,29 @@ def forward(
                 **kwargs,
             )
         elif _is_torchdynamo_exporting():
+            if (
+                attention_interface
+                is transformers.integrations.sdpa_attention.sdpa_attention_forward
+            ):
+                attention_interface = patched_sdpa_attention_forward
 
             def _iteration(start_end, query_states, key_states, value_states):
-                a, b = start_end
+                a = start_end[0]
+                b = start_end[1]
                 q = query_states[:, :, a:b, :]
                 k = key_states[:, :, a:b, :]
                 v = value_states[:, :, a:b, :]
-                return (
-                    attention_interface(
-                        self,
-                        q,
-                        k,
-                        v,
-                        attention_mask=None,
-                        scaling=self.scaling,
-                        dropout=0.0 if not self.training else self.attention_dropout,
-                        is_causal=False,
-                        **kwargs,
-                    )[0],
-                )
+                return attention_interface(
+                    self,
+                    q,
+                    k,
+                    v,
+                    attention_mask=None,
+                    scaling=self.scaling,
+                    dropout=0.0 if not self.training else self.attention_dropout,
+                    is_causal=False,
+                    **kwargs,
+                )[0]
 
             starts = cu_seqlens[:-1]
             ends = cu_seqlens[1:]
@@ -2369,6 +2385,13 @@ def _iteration(start_end, query_states, key_states, value_states):
                 _iteration(start_end, query_states, key_states, value_states)
                 for start_end in starts_ends
             ]
+            # attn_outputs = torch._higher_order_ops.while_loop(
+            # attn_outputs = torch.ops.higher_order.while_loop(
+            #     (lambda it, starts_ends, *_args: it < starts_ends.shape[0]),
+            #     _iteration,
+            #     (torch.tensor(0),
+            #      starts_ends, query_states, key_states, value_states), tuple(),
+            # )
             attn_output = torch.cat(attn_outputs, dim=1)
         else:
             # Other implementations: Process each chunk separately
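
The `torch._check` calls in the first hunk tell `torch.export` that every dimension of `query`, `key`, and `value` is strictly positive, so the exporter can treat the attention inputs as non-empty instead of failing on guards like `shape[2] > 0` when the sequence dimensions are dynamic. A minimal, self-contained sketch of that pattern (not part of the patch; the module and dimension names are illustrative, and it assumes a recent PyTorch with `torch.export`):

```python
import torch


class TinyAttention(torch.nn.Module):
    def forward(self, query, key, value):
        # No-ops in eager mode; under torch.export these checks refine the
        # value range of each symbolic dimension to "strictly positive".
        for t in (query, key, value):
            for d in range(t.ndim):
                torch._check(t.shape[d] > 0)
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, is_causal=False
        )


q = k = v = torch.randn(1, 2, 8, 16)
seq = torch.export.Dim("seq")  # hypothetical dynamic sequence dimension
ep = torch.export.export(
    TinyAttention(),
    (q, k, v),
    dynamic_shapes=({2: seq}, {2: seq}, {2: seq}),
)
print(ep.graph)
```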