1 file changed: +0 -14 lines changed
Original file line number | Diff line number | Diff line change
10 | 10 |  import torch
11 | 11 |  import torch._inductor.config
12 | 12 |  import torch.nn as nn
13 |    | -from torch.backends.cuda import SDPBackend
14 | 13 |  from torch.distributed.device_mesh import DeviceMesh
15 | 14 |  from torch.distributed.tensor import Replicate, Shard
16 | 15 |  from torch.distributed.tensor.parallel import (
@@ -135,19 +134,6 @@ def parallelize_qwen3(
135 | 134 |          attn_backend,
136 | 135 |      )
137 | 136 |
138 |     | -    if parallel_dims.tp_enabled and parallel_dims.cp_enabled:
139 |     | -        # Workaround: cuDNN SDPA backward has a stride mismatch bug with CP.
140 |     | -        # Exclude cuDNN until PyTorch fix lands. See https://github.com/pytorch/pytorch/issues/176915.
141 |     | -        if attn_backend == "sdpa":
142 |     | -            # pyrefly: ignore [missing-attribute, not-callable]
143 |     | -            for block in model.layers.values():
144 |     | -                block.attention.inner_attention.sdpa_backends = (
145 |     | -                    [  # pyrefly: ignore [missing-attribute]
146 |     | -                        SDPBackend.FLASH_ATTENTION,
147 |     | -                        SDPBackend.MATH,
148 |     | -                    ]
149 |     | -                )
150 |     | -
151 | 137 |      if ac_config.mode != "none":
152 | 138 |          apply_ac(
153 | 139 |              model,
You can’t perform that action at this time.
0 commit comments