Skip to content

Commit b7758b8

Browse files
committed
Switch llama3 and qwen3 model configs from sdpa/causal to flex/block_causal
With packed datasets, sdpa/causal allows cross-document attention leakage and uses sequential positions across document boundaries. flex/block_causal isolates documents in attention and enables per-document RoPE position IDs. llama4, deepseek_v3, and gpt_oss already used flex/block_causal.
1 parent 5ddb317 commit b7758b8

File tree

2 files changed

+32
-14
lines changed

2 files changed

+32
-14
lines changed

torchtitan/models/llama3/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@
4444
),
4545
attention=GQAttention.Config(
4646
n_heads=16,
47-
attn_backend="sdpa",
47+
attn_backend="flex",
48+
attn_mask_type="block_causal",
4849
rope_backend="complex",
4950
),
5051
),
@@ -57,6 +58,7 @@
5758
scaling="llama",
5859
),
5960
),
61+
# TODO: now identical to "debugmodel", can be removed
6062
"debugmodel_flex_attn": Llama3Model.Config(
6163
dim=256,
6264
n_layers=6,
@@ -130,7 +132,8 @@
130132
attention=GQAttention.Config(
131133
n_heads=32,
132134
n_kv_heads=8,
133-
attn_backend="sdpa",
135+
attn_backend="flex",
136+
attn_mask_type="block_causal",
134137
rope_backend="complex",
135138
),
136139
),
@@ -142,6 +145,7 @@
142145
scaling="llama",
143146
),
144147
),
148+
# TODO: now identical to "8B", can be removed
145149
"8B_flex": Llama3Model.Config(
146150
dim=4096,
147151
n_layers=32,
@@ -219,7 +223,8 @@
219223
attention=GQAttention.Config(
220224
n_heads=64,
221225
n_kv_heads=8,
222-
attn_backend="sdpa",
226+
attn_backend="flex",
227+
attn_mask_type="block_causal",
223228
rope_backend="complex",
224229
),
225230
),
@@ -248,7 +253,8 @@
248253
attention=GQAttention.Config(
249254
n_heads=128,
250255
n_kv_heads=8,
251-
attn_backend="sdpa",
256+
attn_backend="flex",
257+
attn_mask_type="block_causal",
252258
rope_backend="complex",
253259
),
254260
),

torchtitan/models/qwen3/__init__.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@
4646
head_dim=128,
4747
q_norm=RMSNorm.Config(eps=1e-6),
4848
k_norm=RMSNorm.Config(eps=1e-6),
49-
attn_backend="sdpa",
49+
attn_backend="flex",
50+
attn_mask_type="block_causal",
5051
rope_backend="cos_sin",
5152
),
5253
),
@@ -57,6 +58,7 @@
5758
backend="cos_sin",
5859
),
5960
),
61+
# TODO: now identical to "debugmodel", can be removed
6062
"debugmodel_flex": Qwen3Model.Config(
6163
vocab_size=2048,
6264
dim=256,
@@ -76,6 +78,7 @@
7678
q_norm=RMSNorm.Config(eps=1e-6),
7779
k_norm=RMSNorm.Config(eps=1e-6),
7880
attn_backend="flex",
81+
attn_mask_type="block_causal",
7982
rope_backend="cos_sin",
8083
),
8184
),
@@ -106,7 +109,8 @@
106109
head_dim=128,
107110
q_norm=RMSNorm.Config(eps=1e-6),
108111
k_norm=RMSNorm.Config(eps=1e-6),
109-
attn_backend="sdpa",
112+
attn_backend="flex",
113+
attn_mask_type="block_causal",
110114
rope_backend="cos_sin",
111115
),
112116
),
@@ -137,7 +141,8 @@
137141
head_dim=128,
138142
q_norm=RMSNorm.Config(eps=1e-6),
139143
k_norm=RMSNorm.Config(eps=1e-6),
140-
attn_backend="sdpa",
144+
attn_backend="flex",
145+
attn_mask_type="block_causal",
141146
rope_backend="cos_sin",
142147
),
143148
),
@@ -168,7 +173,8 @@
168173
head_dim=128,
169174
q_norm=RMSNorm.Config(eps=1e-6),
170175
k_norm=RMSNorm.Config(eps=1e-6),
171-
attn_backend="sdpa",
176+
attn_backend="flex",
177+
attn_mask_type="block_causal",
172178
rope_backend="cos_sin",
173179
),
174180
),
@@ -198,7 +204,8 @@
198204
head_dim=128,
199205
q_norm=RMSNorm.Config(eps=1e-6),
200206
k_norm=RMSNorm.Config(eps=1e-6),
201-
attn_backend="sdpa",
207+
attn_backend="flex",
208+
attn_mask_type="block_causal",
202209
rope_backend="cos_sin",
203210
),
204211
),
@@ -228,7 +235,8 @@
228235
head_dim=128,
229236
q_norm=RMSNorm.Config(eps=1e-6),
230237
k_norm=RMSNorm.Config(eps=1e-6),
231-
attn_backend="sdpa",
238+
attn_backend="flex",
239+
attn_mask_type="block_causal",
232240
rope_backend="cos_sin",
233241
),
234242
),
@@ -258,7 +266,8 @@
258266
head_dim=128,
259267
q_norm=RMSNorm.Config(eps=1e-6),
260268
k_norm=RMSNorm.Config(eps=1e-6),
261-
attn_backend="sdpa",
269+
attn_backend="flex",
270+
attn_mask_type="block_causal",
262271
rope_backend="cos_sin",
263272
),
264273
),
@@ -300,7 +309,8 @@
300309
head_dim=128,
301310
q_norm=RMSNorm.Config(eps=1e-6),
302311
k_norm=RMSNorm.Config(eps=1e-6),
303-
attn_backend="sdpa",
312+
attn_backend="flex",
313+
attn_mask_type="block_causal",
304314
rope_backend="cos_sin",
305315
),
306316
),
@@ -341,7 +351,8 @@
341351
head_dim=128,
342352
q_norm=RMSNorm.Config(eps=1e-6),
343353
k_norm=RMSNorm.Config(eps=1e-6),
344-
attn_backend="sdpa",
354+
attn_backend="flex",
355+
attn_mask_type="block_causal",
345356
rope_backend="cos_sin",
346357
),
347358
),
@@ -382,7 +393,8 @@
382393
head_dim=128,
383394
q_norm=RMSNorm.Config(eps=1e-6),
384395
k_norm=RMSNorm.Config(eps=1e-6),
385-
attn_backend="sdpa",
396+
attn_backend="flex",
397+
attn_mask_type="block_causal",
386398
rope_backend="cos_sin",
387399
),
388400
),

0 commit comments

Comments
 (0)