@@ -921,65 +921,64 @@ def test_flash_attn_kvcache(
921921
922922 # # Check that FlashAttention's numerical error is at most twice the numerical error
923923 # # of a Pytorch implementation.
924- # if new_kv:
925- # if page_size is None:
926- # k_cache_select = (
927- # k_cache.to(dtype_ref)
928- # if not has_batch_idx
929- # else k_cache.to(dtype_ref)[cache_batch_idx]
930- # )
931- # v_cache_select = (
932- # v_cache.to(dtype_ref)
933- # if not has_batch_idx
934- # else v_cache.to(dtype_ref)[cache_batch_idx]
935- # )
936- # else:
937- # k_cache_select = rearrange(
938- # k_cache_paged.to(dtype_ref)[
939- # (
940- # page_table
941- # if not has_batch_idx
942- # else page_table[cache_batch_idx]
943- # ).flatten()
944- # ],
945- # "(b nblocks) block_size ... -> b (nblocks block_size) ...",
946- # b=batch_size,
947- # )[:, :seqlen_k].to(dtype_ref)
948- # v_cache_select = rearrange(
949- # v_cache_paged.to(dtype_ref)[
950- # (
951- # page_table
952- # if not has_batch_idx
953- # else page_table[cache_batch_idx]
954- # ).flatten()
955- # ],
956- # "(b nblocks) block_size ... -> b (nblocks block_size) ...",
957- # b=batch_size,
958- # )[:, :seqlen_k].to(dtype_ref)
959- # k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)
960- # v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)
961- # # if dtype is not torch.float8_e4m3fn:
962- # # import pdb; pdb.set_trace()
963- # # assert torch.equal(v_cache_select, v_cache_ref)
964- # # else:
965- # # assert torch.allclose(
966- # # v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3
967- # # )
968- # # breakpoint()
969- # # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:
970- # # if rotary_dim == 0:
971- # # assert torch.equal(k_cache_select, k_cache_ref)
972- # # else:
973- # # # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):
974- # # # breakpoint()
975- # # if dtype is not torch.float8_e4m3fn:
976- # # assert torch.allclose(
977- # # k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3
978- # # )
979- # # else:
980- # # assert torch.allclose(
981- # # k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1
982- # # )
924+ if new_kv :
925+ if page_size is None :
926+ k_cache_select = (
927+ k_cache .to (dtype_ref )
928+ if not has_batch_idx
929+ else k_cache .to (dtype_ref )[cache_batch_idx ]
930+ )
931+ v_cache_select = (
932+ v_cache .to (dtype_ref )
933+ if not has_batch_idx
934+ else v_cache .to (dtype_ref )[cache_batch_idx ]
935+ )
936+ else :
937+ k_cache_select = rearrange (
938+ k_cache_paged .to (dtype_ref )[
939+ (
940+ page_table
941+ if not has_batch_idx
942+ else page_table [cache_batch_idx ]
943+ ).flatten ()
944+ ],
945+ "(b nblocks) block_size ... -> b (nblocks block_size) ..." ,
946+ b = batch_size ,
947+ )[:, :seqlen_k ].to (dtype_ref )
948+ v_cache_select = rearrange (
949+ v_cache_paged .to (dtype_ref )[
950+ (
951+ page_table
952+ if not has_batch_idx
953+ else page_table [cache_batch_idx ]
954+ ).flatten ()
955+ ],
956+ "(b nblocks) block_size ... -> b (nblocks block_size) ..." ,
957+ b = batch_size ,
958+ )[:, :seqlen_k ].to (dtype_ref )
959+ k_cache_ref = k_cache_ref .to (dtype ).to (dtype_ref )
960+ v_cache_ref = v_cache_ref .to (dtype ).to (dtype_ref )
961+ if dtype is not torch .float8_e4m3fn :
962+ import pdb ; pdb .set_trace ()
963+ assert torch .equal (v_cache_select , v_cache_ref )
964+ else :
965+ assert torch .allclose (
966+ v_cache_select , v_cache_ref , rtol = 1e-3 , atol = 1e-3
967+ )
968+ breakpoint ()
969+ if rotary_dim == 0 and dtype is not torch .float8_e4m3fn :
970+ assert torch .equal (k_cache_select , k_cache_ref )
971+ else :
972+ # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):
973+ # breakpoint()
974+ if dtype is not torch .float8_e4m3fn :
975+ assert torch .allclose (
976+ k_cache_select , k_cache_ref , rtol = 1e-3 , atol = 1e-3
977+ )
978+ else :
979+ assert torch .allclose (
980+ k_cache_select , k_cache_ref , rtol = 1e-1 , atol = 1e-1
981+ )
983982 mult = 4 if dtype == torch .float8_e4m3fn else 2
984983 assert (out - out_ref ).abs ().max ().item () <= mult * (
985984 out_pt - out_ref
0 commit comments