
Commit 4f77d53

Skip padding tokens (#792)
* Skip padding tokens
* Write out 0s explicitly on update, add unit test
1 parent 3d3f2d5 commit 4f77d53
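
In short: padding positions in a variable-length batch are flagged with negative indices (a negative state_batch_indices entry during decoding, seq_idx == -1 during the chunked prefill). The Triton kernels now detect these flags, skip the state read/update for padded rows, and write explicit zeros to the corresponding outputs instead of leaving them uninitialized. A varlen-generation-with-padding unit test exercises the new path.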

File tree (3 files changed: +89, −4 lines):

* mamba_ssm/ops/triton/selective_state_update.py
* mamba_ssm/ops/triton/ssd_chunk_state.py
* tests/test_generation.py

mamba_ssm/ops/triton/selective_state_update.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -51,9 +51,17 @@ def _selective_scan_update_kernel(
     pid_b = tl.program_id(axis=1)
     pid_h = tl.program_id(axis=2)
 
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
+    out_ptrs = out_ptr + offs_m * stride_out_dim
+
     if HAS_STATE_BATCH_INDICES:
         state_batch_indices_ptr += pid_b
         state_batch_idx = tl.load(state_batch_indices_ptr)
+        # Skip padding tokens
+        if state_batch_idx < 0:
+            tl.store(out_ptrs, 0.0, mask=offs_m < dim)
+            return
         state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
     else:
         state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
@@ -67,9 +75,7 @@ def _selective_scan_update_kernel(
     C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
     if HAS_Z:
         z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
-    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
 
-    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
     state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
     x_ptrs = x_ptr + offs_m * stride_x_dim
@@ -85,7 +91,6 @@ def _selective_scan_update_kernel(
     D_ptrs = D_ptr + offs_m * stride_D_dim
     if HAS_Z:
         z_ptrs = z_ptr + offs_m * stride_z_dim
-    out_ptrs = out_ptr + offs_m * stride_out_dim
 
     state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
     x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
```
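
For intuition, here is a minimal plain-PyTorch sketch of the behavior this hunk adds, not the Triton kernel itself: a row whose state_batch_indices entry is negative is treated as padding, its output is written as explicit zeros, and its recurrent state is neither read nor updated. The function name, tensor shapes, and the state-update formula below are illustrative placeholders, not the real selective state update.

```python
import torch

def selective_update_sketch(state, x, state_batch_indices):
    """Toy stand-in for the real selective state update.

    state: (num_cache_slots, dim) recurrent state cache
    x: (batch, dim) current-token inputs
    state_batch_indices: (batch,) cache slot per row, or -1 for padding
    """
    out = torch.empty_like(x)
    for b in range(x.shape[0]):
        slot = int(state_batch_indices[b])
        if slot < 0:
            out[b] = 0.0  # padding row: write zeros, leave the cache untouched
            continue
        state[slot] = 0.9 * state[slot] + x[b]  # placeholder for the real SSM update
        out[b] = state[slot]
    return out

# Example: row 1 is padding and produces an all-zero output.
state = torch.zeros(2, 4)
x = torch.randn(3, 4)
out = selective_update_sketch(state, x, torch.tensor([0, -1, 1]))
print(out[1])  # tensor([0., 0., 0., 0.])
```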

mamba_ssm/ops/triton/ssd_chunk_state.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -233,7 +233,7 @@ def _chunk_state_fwd_kernel(
             scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k
         else:
             # scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
-            scale = tl.where(seq_idx_k == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
+            scale = tl.where((seq_idx_last >= 0) & (seq_idx_k == seq_idx_last), tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
         b *= scale[:, None]
         b = b.to(x_ptr.dtype.element_ty)
         acc += tl.dot(x, b)
```
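
The extra `seq_idx_last >= 0` guard matters because padding tokens are tagged with seq_idx == -1: if the chunk's last token is itself padding, the old condition seq_idx_k == seq_idx_last is also true for every padding token (-1 == -1), letting them contribute to the chunk state. Below is a hedged plain-PyTorch sketch of the two conditions, with the tl.* ops replaced by their torch equivalents and toy values standing in for the kernel's loads.

```python
import torch

seq_idx_k = torch.tensor([0, 0, 1, -1, -1])  # per-token sequence ids within one chunk
seq_idx_last = torch.tensor(-1)              # sequence id of the chunk's last token (padding here)
dt_k = torch.full((5,), 0.5)
dA_diff = torch.zeros(5)                     # stand-in for dA_cs_last - dA_cs_k
zero = torch.zeros(5)

# Old condition: the two padding tokens match (-1 == -1) and get a nonzero scale.
scale_old = torch.where(seq_idx_k == seq_idx_last,
                        torch.exp(torch.clamp(dA_diff, max=0.0)) * dt_k, zero)
# New condition: a negative seq_idx_last zeroes the whole chunk's contribution.
scale_new = torch.where((seq_idx_last >= 0) & (seq_idx_k == seq_idx_last),
                        torch.exp(torch.clamp(dA_diff, max=0.0)) * dt_k, zero)
print(scale_old)  # tensor([0.0000, 0.0000, 0.0000, 0.5000, 0.5000])
print(scale_new)  # tensor([0., 0., 0., 0., 0.])
```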

tests/test_generation.py

Lines changed: 80 additions & 0 deletions
```diff
@@ -111,3 +111,83 @@ def test_generation_varlen():
     out_varlen = torch.cat(scores, dim=1)
     print(f"Max diff: {(out_varlen - out_ref).abs().max()}")
     assert (out_varlen - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()
+
+def test_generation_varlen_with_padding():
+    seqlens = [170, 65, 100]
+    non_padded_seqlen = sum(seqlens)
+    padded_seqlen = 512
+    seqlens.append(padded_seqlen - non_padded_seqlen)
+    genlen = 20
+    total_seqlen = sum(seqlens)
+    assert total_seqlen == padded_seqlen
+    device = "cuda"
+    dtype = torch.float16
+
+    config = MambaConfig(
+        d_model=1024,
+        n_layer=4,
+        vocab_size=50277,
+        ssm_cfg=dict(layer="Mamba2"),
+        rms_norm=True,
+        residual_in_fp32=True,
+        fused_add_norm=True,
+        pad_vocab_size_multiple=16,
+    )
+    torch.manual_seed(2357)
+    model = MambaLMHeadModel(config, device=device, dtype=dtype)
+    xs = [torch.randint(0, 1000, (1, seqlen), device=device, dtype=torch.long) for seqlen in seqlens]
+
+    # Reference 1: Forward pass with seq_idx
+    x = torch.cat(xs[:-1], dim=1)
+    seq_idx = torch.cat([torch.full((ids.shape[1],), i, dtype=torch.int32, device=device)
+                         for i, ids in enumerate(xs[:-1])], dim=0).unsqueeze(0)
+    cu_seqlens = F.pad(torch.tensor(seqlens[:-1], device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
+
+    out_ref = model(x, seq_idx=seq_idx).logits
+    # Only take the last @genlen logits of each sequence
+    out_ref = torch.cat([out_ref[:, cu_seqlens[i + 1] - genlen - 1:cu_seqlens[i + 1] - 1]
+                         for i in range(len(seqlens) - 1)], dim=0)
+
+    # Reference 2: Generate the last @genlen tokens of each sequence in a for loop
+    out_loop = []
+    for input_ids in xs[:-1]:
+        out = model.generate(
+            input_ids=input_ids[:, :-genlen], max_length=input_ids.shape[1], output_scores=True,
+            return_dict_in_generate=True, cg=True, teacher_outputs=input_ids,
+        ).scores
+        out_loop.append(torch.stack(out, dim=1))
+    out_loop = torch.cat(out_loop, dim=0)
+    print(f"Max diff between ref1 and ref2: {(out_loop - out_ref).abs().max()}")
+
+    # Varlen generation
+    input_ids = torch.cat([ids[:, :-genlen] for ids in xs], dim=1)
+    prompt_seqlens = [seqlen - genlen for seqlen in seqlens]
+    cu_seqlens = F.pad(torch.tensor(prompt_seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
+    seq_idx = torch.cat([torch.full((seqlen,), i, dtype=torch.int32, device=device)
+                         for i, seqlen in enumerate(prompt_seqlens)], dim=0).unsqueeze(0)
+    inference_params = InferenceParams(max_seqlen=2048, max_batch_size=len(seqlens))
+
+    # Account for padding
+    offset = genlen * len(seqlens)
+    seq_idx[non_padded_seqlen - offset : padded_seqlen - offset] = -1
+    cu_seqlens[-1] = cu_seqlens[-2]
+
+    scores, sequences = [], []
+    # Both seq_idx and cu_seqlens must be passed in for varlen generation
+    logits = model(input_ids, inference_params=inference_params, seq_idx=seq_idx, cu_seqlens=cu_seqlens).logits
+    logits = rearrange(logits[0, cu_seqlens[1:] - 1], "b d -> b 1 d")
+    scores.append(logits)
+    # In practice we should sample. In this case we take from the teacher_output for testing
+    sampled_tokens = rearrange(torch.stack([ids[0, -genlen] for ids in xs], dim=0), "b -> b 1")
+    sequences.append(sampled_tokens)
+    for i in range(1, genlen):
+        inference_params.seqlen_offset += 1
+        logits = model(sampled_tokens, inference_params=inference_params, num_last_tokens=1).logits
+        scores.append(logits)
+        # In practice we should sample. In this case we take from the teacher_output for testing
+        sampled_tokens = rearrange(torch.stack([ids[0, -genlen + i] for ids in xs], dim=0), "b -> b 1")
+        sequences.append(sampled_tokens)
+    out_varlen = torch.cat(scores, dim=1)
+
+    print(f"Max diff: {(out_varlen[:-1] - out_ref).abs().max()}")
+    assert (out_varlen[:-1] - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()
```
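
Assuming the repository's usual pytest workflow and a CUDA device (the test hard-codes device = "cuda"), the new case can presumably be run on its own with `pytest tests/test_generation.py -k varlen_with_padding`.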
