1 parent 202351d commit 6f88f76
tests/kernels/test_attention.py
@@ -247,8 +247,11 @@ def test_multi_query_kv_attention(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
 
-    seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
-    seq_lens[-1] = MAX_SEQ_LEN
+    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
+    # As the xformers library is already tested with its own tests, we can use
+    # a smaller MAX_SEQ_LEN here.
+    max_len = min(MAX_SEQ_LEN, 4096)
+    seq_lens = random.sample(range(1, max_len), num_seqs)
     num_tokens = sum(seq_lens)
 
     scale = float(1.0 / (head_size**0.5))
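In isolation, the capped sampling the patch introduces looks like the sketch below. The `MAX_SEQ_LEN` and `num_seqs` values here are placeholders, not the test module's actual constants:

```python
import random

MAX_SEQ_LEN = 8192  # placeholder; the test module defines its own constant
num_seqs = 4        # placeholder batch size

# Cap sequence lengths: per the patch comment, the reference attention
# implementation can OOM on very long sequences, so the kernel test
# bounds the lengths it samples.
max_len = min(MAX_SEQ_LEN, 4096)

# Draw distinct lengths from [1, max_len); random.sample draws without
# replacement, so num_seqs must not exceed max_len - 1.
seq_lens = random.sample(range(1, max_len), num_seqs)
num_tokens = sum(seq_lens)
```

Note one side effect of the change: the old code pinned `seq_lens[-1] = MAX_SEQ_LEN`, so the maximum length was always exercised, while the new code samples all lengths below the cap, trading that guaranteed coverage for keeping the reference path within memory.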