@@ -151,6 +151,8 @@ def _get_or_create_block_mask(
         kv_len: int,
         device: torch.device,
         block_ids: Optional[torch.Tensor] = None,
+        orig_positions: Optional[torch.Tensor] = None,
+        token_anchor_pos: Optional[torch.Tensor] = None,
     ) -> "BlockMask":
         """Get cached BlockMask or create a new one."""
         if block_ids is None:
@@ -165,17 +167,21 @@ def _get_or_create_block_mask(
 
         if block_ids is not None:
             _block_ids = block_ids
+            _orig_pos = orig_positions
+            _anchor_pos = token_anchor_pos
 
             def dflash_mask_fn(b, h, q_idx, kv_idx):
                 L = q_len
                 is_ctx = kv_idx < L
                 q_b = _block_ids[b, q_idx]
-                k_ctx = _block_ids[b, kv_idx.clamp(max=L - 1)]
+                k_ctx_id = _block_ids[b, kv_idx.clamp(max=L - 1)]
                 k_noise = _block_ids[b, (kv_idx - L).clamp(min=0, max=L - 1)]
                 q_valid = q_b >= 0
-                k_ctx_valid = k_ctx >= 0
+                k_ctx_valid = k_ctx_id >= 0
                 k_noise_valid = k_noise >= 0
-                ctx_visible = is_ctx & q_valid & k_ctx_valid & (k_ctx < q_b)
+                kv_orig = _orig_pos[b, kv_idx.clamp(max=L - 1)]
+                q_anchor = _anchor_pos[b, q_idx]
+                ctx_visible = is_ctx & q_valid & k_ctx_valid & (kv_orig < q_anchor)
                 noise_visible = (~is_ctx) & q_valid & k_noise_valid & (k_noise == q_b)
                 return ctx_visible | noise_visible
@@ -213,6 +219,8 @@ def _create_parallel_attention_mask(
         seq_len: int,
         device: torch.device,
         block_ids: Optional[torch.Tensor] = None,
+        orig_positions: Optional[torch.Tensor] = None,
+        token_anchor_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Create [bsz, L, 2L] attention mask for parallel training."""
         if block_ids is None:
@@ -226,11 +234,13 @@ def _create_parallel_attention_mask(
             full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min)
             return full_mask.unsqueeze(0).expand(bsz, -1, -1)
 
+        q_anchor = token_anchor_pos.unsqueeze(2)
+        k_orig = orig_positions.unsqueeze(1)
         q_ids = block_ids.unsqueeze(2)
         k_ids = block_ids.unsqueeze(1)
         q_valid = q_ids >= 0
         k_valid = k_ids >= 0
-        ctx_mask = q_valid & k_valid & (k_ids < q_ids)
+        ctx_mask = q_valid & k_valid & (k_orig < q_anchor)
         noise_mask = q_valid & k_valid & (k_ids == q_ids)
         full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=2)
         full_mask = torch.zeros_like(full_mask_bool, dtype=torch.float32)
@@ -248,6 +258,8 @@ def forward(
         bsz, seq_len = input_ids.shape
         device = input_ids.device
         block_ids = None
+        orig_positions = None
+        token_anchor_pos = None
 
         if self.random_anchor and self.training:
             anchor_positions, block_keep_mask = self._sample_anchor_positions(
@@ -264,6 +276,10 @@ def forward(
             )
             effective_len = input_ids.shape[1]
             base_positions = block_positions
+            orig_positions = block_positions
+            token_anchor_pos = anchor_positions.repeat_interleave(
+                self.block_size, dim=1
+            )
         else:
             n_blocks = seq_len // self.block_size
             effective_len = n_blocks * self.block_size
@@ -291,10 +307,12 @@ def forward(
                 kv_len=effective_len * 2,
                 device=device,
                 block_ids=block_ids,
+                orig_positions=orig_positions,
+                token_anchor_pos=token_anchor_pos,
             )
         else:
             dflash_attn_mask = self._create_parallel_attention_mask(
-                bsz, effective_len, device, block_ids
+                bsz, effective_len, device, block_ids, orig_positions, token_anchor_pos
            )
             dflash_attn_mask = dflash_attn_mask.to(dtype=hidden_states.dtype)
             dflash_attn_mask = dflash_attn_mask.unsqueeze(1)
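
For context, here is a minimal standalone sketch of the new mask semantics, using toy values that are not from the patch: context visibility is now decided by comparing each key's original sequence position against the query block's anchor position (`k_orig < q_anchor`) rather than by block index (`k_ids < q_ids`), which matters when random anchoring reorders blocks relative to their positions. The sketch mirrors the dense-mask path in `_create_parallel_attention_mask` above.

```python
import torch

# Toy setup (hypothetical values): batch of 1, L = 4 context tokens
# split into two blocks of size 2.
block_ids = torch.tensor([[0, 0, 1, 1]])         # block index per token
orig_positions = torch.tensor([[0, 1, 2, 3]])    # original sequence position per token
token_anchor_pos = torch.tensor([[2, 2, 4, 4]])  # anchor position of each token's block

q_anchor = token_anchor_pos.unsqueeze(2)  # [bsz, L, 1]
k_orig = orig_positions.unsqueeze(1)      # [bsz, 1, L]
q_ids = block_ids.unsqueeze(2)
k_ids = block_ids.unsqueeze(1)
q_valid = q_ids >= 0                      # tokens with block id -1 are masked out
k_valid = k_ids >= 0

# New rule: a context key is visible iff its original position precedes
# the query block's anchor position (previously: k_ids < q_ids).
ctx_mask = q_valid & k_valid & (k_orig < q_anchor)
# Noise keys are unchanged: visible only within the query's own block.
noise_mask = q_valid & k_valid & (k_ids == q_ids)

full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=2)  # [bsz, L, 2L]
print(full_mask_bool.int())
```

With these toy inputs, every token sees all context keys whose original position is below its block's anchor, so the two mask formulations only diverge once anchors are sampled at random positions instead of at block boundaries.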