@@ -133,7 +133,6 @@ struct Flash_fwd_params {
133133
134134 // Local window size
135135 int window_size_left, window_size_right;
136- int attention_chunk;
137136
138137 // Pointer to the RNG seed (idx 0) and offset (idx 1).
139138 uint64_t * rng_state;
@@ -541,14 +540,13 @@ std::vector<at::Tensor> mha_fwd(
541540 std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2)
542541 std::optional<const at::Tensor>& rotary_sin_, // seqlen_ro x (rotary_dim / 2)
543542 std::optional<const at::Tensor>& seqlens_rotary_, // b
544- // std::optional<at::Tensor> & q_descale_, // (b, h_k), not (b, h)
545- // std::optional<at::Tensor> & k_descale_, // (b, h_k)
546- // std::optional<at::Tensor> & v_descale_, // (b, h_k)
547- std::optional< double > softmax_scale_,
543+ std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h)
544+ std::optional<at::Tensor>& k_descale_, // (b, h_k)
545+ std::optional<at::Tensor>& v_descale_, // (b, h_k)
546+ const float softmax_scale_,
548547 bool is_causal,
549548 int window_size_left,
550549 int window_size_right,
551- int attention_chunk,
552550 float const softcap,
553551 bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
554552 std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
@@ -619,10 +617,8 @@ std::vector<at::Tensor> mha_fwd(
619617 int const total_k = !is_varlen_k ? batch_size * k.size (1 ) : k.size (0 );
620618 int const num_heads_k = k.size (-2 );
621619 int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size (0 ) : cu_seqlens_k.size (0 ) - 1 ) : page_table.size (0 );
622- double softmax_scale = 1.0 / sqrt (double (head_size));
623- if (softmax_scale_.has_value ()) {
624- softmax_scale = softmax_scale_.value ();
625- }
620+ float softmax_scale = softmax_scale_;
621+
626622 if (!kv_batch_idx_.has_value ()) {
627623 TORCH_CHECK (batch_size == batch_size_k, " batch_size must be equal to batch_size_k" );
628624 }
@@ -791,8 +787,8 @@ std::vector<at::Tensor> mha_fwd(
791787
792788 // Causal is the special case where window_size_right == 0 and window_size_left < 0.
793789 // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
794- params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0 ;
795- params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1 ) && !params.is_causal ;
790+ params.is_causal = window_size_left < 0 && window_size_right == 0 ;
791+ params.is_local = (window_size_left >= 0 || window_size_right >= 0 ) && !params.is_causal ;
796792
797793 // TODO: check this
798794 if (window_size_left < 0 ) {
@@ -801,13 +797,8 @@ std::vector<at::Tensor> mha_fwd(
801797 if (window_size_right < 0 ) {
802798 window_size_right = seqlen_q - 1 ;
803799 }
804- if (attention_chunk > 0 ) {
805- window_size_left = std::min (window_size_left, attention_chunk - 1 );
806- window_size_right = std::min (window_size_right, attention_chunk - 1 );
807- }
808800 params.window_size_left = window_size_left;
809801 params.window_size_right = window_size_right;
810- params.attention_chunk = attention_chunk;
811802
812803 params.total_q = total_q;
813804 params.total_k = total_k;
0 commit comments