Commit 6f88f76

Fix OOM in attention kernel test (#1223)
1 parent 202351d

File tree

1 file changed (+5, -2 lines)


tests/kernels/test_attention.py

Lines changed: 5 additions & 2 deletions
@@ -247,8 +247,11 @@ def test_multi_query_kv_attention(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)

-    seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
-    seq_lens[-1] = MAX_SEQ_LEN
+    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
+    # As the xformers library is already tested with its own tests, we can use
+    # a smaller MAX_SEQ_LEN here.
+    max_len = min(MAX_SEQ_LEN, 4096)
+    seq_lens = random.sample(range(1, max_len), num_seqs)
     num_tokens = sum(seq_lens)

     scale = float(1.0 / (head_size**0.5))
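
Why capping the length helps: a naive reference attention materializes the full [L, L] score matrix per head, so its memory footprint grows quadratically with sequence length. The sketch below is not the repo's code; the 16-head count and fp16 element size are hypothetical values chosen only to illustrate the scale.

    # Minimal sketch (hypothetical parameters) of the quadratic memory cost
    # of a naive reference attention that materializes [L, L] scores per head.

    def score_matrix_bytes(seq_len: int, num_heads: int, dtype_size: int = 2) -> int:
        # One [seq_len, seq_len] score matrix per head, dtype_size bytes each.
        return num_heads * seq_len * seq_len * dtype_size

    for L in (4096, 8192, 16384):
        gib = score_matrix_bytes(L, num_heads=16) / 2**30
        print(f"seq_len={L}: ~{gib:.1f} GiB for attention scores alone")

Under these assumed settings, capping max_len at 4096 keeps the score matrices near 0.5 GiB, whereas a 16384-token sequence would need about 8 GiB for scores before counting any other activations.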
