
Conversation

@leloykun

Description

This PR implements a minimal backward pass for flash attention.

I got these results on my RTX 2060:

=== profiling manual attention (backward pass) ===
...
Self CPU time total: 11.139ms
Self CUDA time total: 1.721ms
=== profiling minimal flash attention (backward pass) === 
...
Self CPU time total: 31.466ms
Self CUDA time total: 629.000us

A ~2.7x speedup in self CUDA time (1.721 ms vs. 629 µs), though the self CPU time is higher.

Though my GPU can only handle size-16 blocks (vs. size-32 blocks on a T4).

@hypertseng

@leloykun hello Franz! I'm having some trouble with the code and flash attention. First, why does the attention-values sanity check return False when seq_len is lower than 32? This breaks inference, where seq_len is usually 1; I guess the block size may be causing it? Also, how do I choose an appropriate block size? Looking forward to your reply!

@leloykun
Author

leloykun commented Apr 17, 2024

Hi @hypertseng!

I believe it was because we weren't exiting the loops after going past the seq length. The forward pass should be fixed in my repo here: https://github.com/leloykun/flash-hyperbolic-attention-minimal
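For anyone hitting the same issue, here is a minimal sketch of the kind of guard involved, assuming the usual flash-attention-minimal tiling names (N = sequence length, Bc = column block size, Tc = number of column tiles). The real fix lives in the repo's kernels; this toy kernel only illustrates the early exit:

#include <cuda_runtime.h>

// Toy illustration only: with N < Bc * Tc, the tiled loops must stop at N
// instead of reading (and normalizing over) garbage past the sequence end.
__global__ void tiled_guard(const float* K, float* out, int N, int Bc, int Tc) {
    float acc = 0.0f;
    for (int j = 0; j < Tc; j++) {
        if (j * Bc >= N) break;          // whole tile starts past the end
        for (int y = 0; y < Bc; y++) {
            int idx = j * Bc + y;
            if (idx >= N) break;         // partial tail tile
            acc += K[idx];               // stand-in for the real S = QK^T work
        }
    }
    if (threadIdx.x == 0 && blockIdx.x == 0) out[0] = acc;
}

Without those two break statements, a seq_len below the block size makes the kernel fold out-of-range values into the softmax, which is exactly the sanity-check failure described above.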

@hypertseng

@leloykun Recently I found that the flash_attn_bwd implementation in your repo is slower than the manual implementation, and it seems to be mostly because of an implicit call to cudaDeviceSynchronize, which increases the CPU time a lot. Do you have any idea how to solve this problem?
By the way, I found that changing the atomicAdd to a normal add decreases the cudaDeviceSynchronize occupancy, but I don't know why; I'm a beginner at CUDA, haha.
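For reference, here is a hypothetical sketch of the accumulation pattern in question (the names and layout are illustrative, not the repo's exact code): every KV tile contributes to the same rows of dQ, so when different thread blocks handle different KV tiles, the update has to be atomic, and a plain add would be a data race. Atomics also serialize conflicting writes, which can change both the runtime and how the profiler attributes time.

#include <cuda_runtime.h>

// Hypothetical toy kernel: one thread block per KV tile j, one thread per
// query row. Every block adds into the same dQ rows, so the write must be
// atomic; a plain `dQ[...] += ...` here would race across blocks.
__global__ void accumulate_dq(float* dQ, const float* contrib, int N, int d) {
    int row = threadIdx.x;   // query row (toy sizes: N <= blockDim.x)
    int j = blockIdx.x;      // KV tile index
    if (row >= N) return;
    for (int x = 0; x < d; x++) {
        atomicAdd(&dQ[row * d + x], contrib[(j * N + row) * d + x]);
    }
}

Dropping the atomicAdd is only safe when exactly one thread in the whole grid writes each element.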

@FumoTime

FumoTime commented Jul 8, 2024

@hypertseng Most likely the cudaDeviceSynchronize time includes the kernel execution time: cudaDeviceSynchronize blocks the CPU until all queued GPU work finishes, so the profiler attributes the kernels' runtime to that wait. You can use CUDA events to time the kernel instead:

import torch

torch.cuda.reset_peak_memory_stats()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()                      # marks the start on the CUDA stream
minimal_result = minimal_attn.forward(q, k, v)
end_event.record()                        # marks the end on the same stream
torch.cuda.synchronize()                  # wait until both events have completed

elapsed_time_ms = start_event.elapsed_time(end_event)           # GPU time between events
max_vram_MB = torch.cuda.max_memory_allocated() / (1024 * 1024)
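The elapsed time here is measured between the two events as they execute on the GPU stream, so host-side overhead (including the synchronize wait itself) doesn't inflate the measurement the way it inflates the profiler's self-CPU totals.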
