@@ -50,6 +50,7 @@ def _fwd_kernel(
     BLOCK_DMODEL: tl.constexpr,  # head size
     BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
     BLOCK_N: tl.constexpr,
+    SLIDING_WINDOW: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -62,42 +63,53 @@ def _fwd_kernel(
     cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
     cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len
 
+    # start position inside of the query
+    # generally, N goes over kv, while M goes over query_len
     block_start_loc = BLOCK_M * start_m
 
     # initialize offsets
+    # [N]; starts at 0
     offs_n = tl.arange(0, BLOCK_N)
+    # [D]; starts at 0
     offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
+    # [M]; starts at current position in query
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    # [M,D]
     off_q = (
         (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
         cur_head * stride_qh + offs_d[None, :] * stride_qd)
 
     dim_mask = tl.where(
-        tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
+        tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
+        0).to(tl.int1)  # [D]
 
     q = tl.load(Q + off_q,
                 mask=dim_mask[None, :] &
                 (offs_m[:, None] < cur_batch_query_len),
-                other=0.0)
+                other=0.0)  # [M,D]
 
-    # # initialize pointer to m and l
-    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
+    # initialize pointer to m and l
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # [M]
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # [M]
+    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
+                   dtype=tl.float32)  # [M,D]
 
+    # compute query against context (no causal mask here)
     for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                      ((start_n + offs_n) // block_size) * stride_b_loc_s,
                      mask=(start_n + offs_n) < cur_batch_ctx_len,
-                     other=0)
+                     other=0)  # [N]
+        # [D,N]
         off_k = (bn[None, :] * stride_k_cache_bs +
                  cur_kv_head * stride_k_cache_h +
                  (offs_d[:, None] // x) * stride_k_cache_d +
                  ((start_n + offs_n[None, :]) % block_size) *
                  stride_k_cache_bl +
                  (offs_d[:, None] % x) * stride_k_cache_x)
+        # [N,D]
         off_v = (
             bn[:, None] * stride_v_cache_bs +
             cur_kv_head * stride_v_cache_h +
@@ -106,23 +118,39 @@ def _fwd_kernel(
         k = tl.load(K_cache + off_k,
                     mask=dim_mask[:, None] &
                     ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
-                    other=0.0)
+                    other=0.0)  # [D,N]
 
-        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
         qk += tl.dot(q, k)
         qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                       float("-inf"))
         qk *= sm_scale
+        if SLIDING_WINDOW > 0:
+            # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
+            # Q entries in sequence
+            # (start_n + offs_n[None, :]) are the positions of
+            # KV entries in sequence
+            # So the condition makes sure each entry in Q only attends
+            # to KV entries not more than SLIDING_WINDOW away.
+            #
+            # We can't use -inf here, because the
+            # sliding window may lead to the entire row being masked.
+            # This then makes m_ij contain -inf, which causes NaNs in
+            # exp().
+            qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
+                          (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
+                          -10000)
 
         # -- compute m_ij, p, l_ij
-        m_ij = tl.max(qk, 1)
-        p = tl.exp(qk - m_ij[:, None])
-        l_ij = tl.sum(p, 1)
+        m_ij = tl.max(qk, 1)  # [M]
+        p = tl.exp(qk - m_ij[:, None])  # [M,N]
+        l_ij = tl.sum(p, 1)  # [M]
         # -- update m_i and l_i
-        m_i_new = tl.maximum(m_i, m_ij)
-        alpha = tl.exp(m_i - m_i_new)
-        beta = tl.exp(m_ij - m_i_new)
-        l_i_new = alpha * l_i + beta * l_ij
+        m_i_new = tl.maximum(m_i, m_ij)  # [M]
+        alpha = tl.exp(m_i - m_i_new)  # [M]
+        beta = tl.exp(m_ij - m_i_new)  # [M]
+        l_i_new = alpha * l_i + beta * l_ij  # [M]
+
         # -- update output accumulator --
         # scale p
         p_scale = beta / l_i_new
@@ -134,7 +162,7 @@ def _fwd_kernel(
         v = tl.load(V_cache + off_v,
                     mask=dim_mask[None, :] &
                     ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
-                    other=0.0)
+                    other=0.0)  # [N,D]
 
         p = p.to(v.dtype)
         acc += tl.dot(p, v)
@@ -149,8 +177,10 @@ def _fwd_kernel(
     k_ptrs = K + off_k
     v_ptrs = V + off_v
 
+    # block_mask is 0 when we're already past the current query length
     block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
 
+    # compute query against itself (with causal mask)
     for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
@@ -163,8 +193,13 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
+        # apply causal mask
         qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                       float("-inf"))
+        if SLIDING_WINDOW > 0:
+            qk = tl.where(
+                offs_m[:, None] -
+                (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000)
 
         # -- compute m_ij, p, l_ij
         m_ij = tl.max(qk, 1)
@@ -636,15 +671,16 @@ def context_attention_fwd(q,
                           b_seq_len,
                           b_ctx_len,
                           max_input_len,
-                          alibi_slopes=None):
+                          alibi_slopes=None,
+                          sliding_window=None):
 
     cap = torch.cuda.get_device_capability()
     BLOCK = 128 if cap[0] >= 8 else 64
     # shape constraints
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk and Lk == Lv
     # round up Lk to a power of 2 - this is required for Triton block size
-    Lk_padded = 2**((Lk - 1).bit_length())
+    Lk_padded = triton.next_power_of_2(Lk)
 
     sm_scale = 1.0 / (Lq**0.5)
     batch, head = b_seq_len.shape[0], q.shape[1]
@@ -749,6 +785,7 @@ def context_attention_fwd(q,
         BLOCK_DMODEL=Lk,
         BLOCK_DMODEL_PADDED=Lk_padded,
         BLOCK_N=BLOCK,
+        SLIDING_WINDOW=sliding_window if sliding_window is not None else 0,
         num_warps=num_warps,
         num_stages=1,
     )
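
For reference, the masking added in this diff can be reproduced outside Triton in a few lines of PyTorch. The sketch below is only an illustration under assumed shapes (single head, contiguous K/V, hypothetical name ref_sliding_window_attention); it is not part of the change, but it shows the q_pos - kv_pos < SLIDING_WINDOW window combined with the causal mask, and why the window uses a finite -10000 instead of -inf.

import torch


def ref_sliding_window_attention(q, k, v, ctx_len, sliding_window, sm_scale):
    # Illustrative reference only; names and layout are assumptions.
    # q: [query_len, d]; k, v: [seq_len, d] with seq_len = ctx_len + query_len.
    query_len, seq_len = q.shape[0], k.shape[0]
    qk = (q @ k.T) * sm_scale  # [query_len, seq_len]

    # Absolute positions: query token i lives at ctx_len + i, kv token j at j.
    q_pos = ctx_len + torch.arange(query_len)[:, None]
    kv_pos = torch.arange(seq_len)[None, :]

    # Causal mask: never attend to future tokens.
    qk = qk.masked_fill(kv_pos > q_pos, float("-inf"))

    if sliding_window > 0:
        # Keep only KV entries less than `sliding_window` positions behind the
        # query. A large negative constant (mirroring the kernel's -10000) is
        # used instead of -inf so that a fully masked row still produces
        # finite softmax inputs rather than NaNs.
        qk = qk.masked_fill(q_pos - kv_pos >= sliding_window, -10000.0)

    return torch.softmax(qk, dim=-1) @ v  # [query_len, d]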