
Error when increasing the sequence length #53

@ZetangForward


Hi, I increased `seqlen` from 1024 * 8 to 1024 * 64 here:

Then I ran the benchmark with:

```
torchrun benchmark/benchmark_varlen_qkvpacked_func.py
```

The program then fails with the following error log:

```
# flash_attn_varlen_qkvpacked_func
329.0089328816957 iter/s, 0.303943115234375 sec
# ring_flash_attn_varlen_qkvpacked_func
125.49088812377029 iter/s, 0.79687060546875 sec
# zigzag_ring_flash_attn_varlen_qkvpacked_func
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/zecheng/lcm_stack/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py", line 99, in <module>
[rank0]:     benchmark(f, forward_only=forward_only, log=False)
[rank0]:   File "/data/zecheng/lcm_stack/ring-flash-attention/benchmark/benchmark_varlen_qkvpacked_func.py", line 64, in benchmark
[rank0]:     out = f(
[rank0]:   File "/data/zecheng/lcm_stack/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py", line 413, in zigzag_ring_flash_attn_varlen_qkvpacked_func
[rank0]:     return ZigZagRingFlashAttnVarlenFunc.apply(
[rank0]:   File "/data/anaconda3/envs/new_zecheng/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/data/zecheng/lcm_stack/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py", line 331, in forward
[rank0]:     out, softmax_lse = zigzag_ring_flash_attn_varlen_forward(
[rank0]:   File "/data/zecheng/lcm_stack/ring-flash-attention/ring_flash_attn/zigzag_ring_flash_attn_varlen.py", line 94, in zigzag_ring_flash_attn_varlen_forward
[rank0]:     q1 = q[half_index1]
[rank0]: IndexError: The shape of the mask [8192] at index 0 does not match the shape of the indexed tensor [65536, 5, 128] at index 0
E1003 16:05:54.677000 139778620753728 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 0 (pid: 140219) of binary: /data/anaconda3/envs/new_zecheng/bin/python3.10
```
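From the traceback, the failure is a boolean-mask indexing error: PyTorch requires a boolean mask to have as many entries as the dimension it indexes, but here `half_index1` has 8192 entries while `q` has 65536 tokens. Since 8192 is the old sequence length, one plausible cause (an assumption, not verified against the benchmark) is that the `cu_seqlens` used to build the mask still end at 8192 after `seqlen` was raised. The pure-Python sketch below reproduces only that shape logic. `front_half_mask`, `check_mask_matches`, `rescale_cu_seqlens`, and the boundary values are hypothetical illustrations, not ring-flash-attention APIs.

```python
# Hypothetical pure-Python sketch of the shape logic only (no torch needed).
# It is NOT the ring-flash-attention implementation.

def front_half_mask(cu_seqlens):
    """One bool per token, True for the first half of each packed sequence.
    The mask length is cu_seqlens[-1], i.e. the total token count the
    boundaries describe."""
    mask = [False] * cu_seqlens[-1]
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        mid = (start + end) // 2
        for i in range(start, mid):
            mask[i] = True
    return mask


def check_mask_matches(cu_seqlens, num_tokens):
    """Mimic PyTorch's rule: a boolean mask indexing dim 0 must have exactly
    as many entries as that dimension has elements."""
    mask = front_half_mask(cu_seqlens)
    if len(mask) != num_tokens:
        raise IndexError(
            f"The shape of the mask [{len(mask)}] at index 0 does not match "
            f"the shape of the indexed tensor [{num_tokens}] at index 0"
        )
    return mask


def rescale_cu_seqlens(cu_seqlens, new_total):
    """Scale every boundary so the last one equals new_total.
    Assumes new_total is an integer multiple of the old total."""
    old_total = cu_seqlens[-1]
    if new_total % old_total != 0:
        raise ValueError("new_total must be a multiple of cu_seqlens[-1]")
    factor = new_total // old_total
    return [b * factor for b in cu_seqlens]


if __name__ == "__main__":
    old = [0, 1024, 4096, 8192]  # hypothetical boundaries sized for 8192 tokens
    try:
        check_mask_matches(old, 65536)  # stale boundaries vs. 65536 tokens
    except IndexError as e:
        print("mismatch:", e)
    new = rescale_cu_seqlens(old, 65536)
    print(new, len(check_mask_matches(new, 65536)))
```

With stale boundaries the check fails exactly like the traceback; once the last boundary matches the token count, the mask and the tensor agree in length.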

How can I fix this?
