Skip to content

Commit 9da4e39

Browse files
committed
move cudnn backend
1 parent 3b1cde9 commit 9da4e39

File tree

1 file changed

+0
-14
lines changed

1 file changed

+0
-14
lines changed

torchtitan/models/qwen3/parallelize.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import torch
1111
import torch._inductor.config
1212
import torch.nn as nn
13-
from torch.backends.cuda import SDPBackend
1413
from torch.distributed.device_mesh import DeviceMesh
1514
from torch.distributed.tensor import Replicate, Shard
1615
from torch.distributed.tensor.parallel import (
@@ -135,19 +134,6 @@ def parallelize_qwen3(
135134
attn_backend,
136135
)
137136

138-
if parallel_dims.tp_enabled and parallel_dims.cp_enabled:
139-
# Workaround: cuDNN SDPA backward has a stride mismatch bug with CP.
140-
# Exclude cuDNN until PyTorch fix lands. See https://github.com/pytorch/pytorch/issues/176915.
141-
if attn_backend == "sdpa":
142-
# pyrefly: ignore [missing-attribute, not-callable]
143-
for block in model.layers.values():
144-
block.attention.inner_attention.sdpa_backends = (
145-
[ # pyrefly: ignore [missing-attribute]
146-
SDPBackend.FLASH_ATTENTION,
147-
SDPBackend.MATH,
148-
]
149-
)
150-
151137
if ac_config.mode != "none":
152138
apply_ac(
153139
model,

0 commit comments

Comments
 (0)