11# mypy: allow-untyped-defs
22""" Triton Implementation of the flex_attention Kernel"""
33
4+ import copy
45import logging
56import math
67from collections .abc import Sequence
1415from torch ._inductor .virtualized import V
1516from torch .utils ._ordered_set import OrderedSet
1617from torch .utils ._pytree import tree_map
18+ from torch .utils ._sympy .numbers import int_oo
19+ from torch .utils ._sympy .value_ranges import ValueRanges
1720
1821from .. import config
1922from ..ir import (
@@ -100,10 +103,21 @@ def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta):
100103
101104
def create_placeholder(
    name: str,
    dtype: torch.dtype,
    device: torch.device,
    size: Optional[list[int]] = None,
) -> TensorBox:
    """Create a placeholder input buffer used to produce the subgraph output.

    When ``size`` is omitted (or empty), a scalar (0-d) layout with empty
    shape/strides is built; otherwise a contiguous layout of the given
    shape is used.
    """
    if size:
        shape = size
        strides = FlexibleLayout.contiguous_strides(size)
    else:
        shape, strides = [], []
    layout = FixedLayout(device, dtype, shape, strides)
    return TensorBox.create(InputBuffer(name=name, layout=layout))
108122
109123
@@ -173,7 +187,9 @@ def zeros_and_scatter_lowering(shape: list[int], indices, values):
173187SubgraphResults = Union [list [Optional [ComputedBuffer ]], Optional [ComputedBuffer ]]
174188
175189
176- def build_subgraph_buffer (args : list [TensorBox ], subgraph : Subgraph ) -> SubgraphResults :
190+ def build_subgraph_module_buffer (
191+ args : list [TensorBox ], graph_module : torch .fx .GraphModule
192+ ) -> SubgraphResults :
177193 """This function's goal is to take in the required args and produce the subgraph buffer
178194 The subgraph buffer is a ComputedBuffer that will be inlined into the triton template
179195
@@ -184,7 +200,7 @@ def build_subgraph_buffer(args: list[TensorBox], subgraph: Subgraph) -> Subgraph
184200 from ..subgraph_lowering import PointwiseSubgraphLowering
185201
186202 pw_subgraph = PointwiseSubgraphLowering (
187- subgraph . graph_module ,
203+ graph_module ,
188204 root_graph_lowering = V .graph ,
189205 allowed_mutations = OrderedSet ([torch .ops .flex_lib .zeros_and_scatter .default ]),
190206 additional_lowerings = {
@@ -228,6 +244,10 @@ def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]:
228244 return tree_map (convert_output_node_to_buffer , pw_subgraph .graph_outputs )
229245
230246
def build_subgraph_buffer(args: list[TensorBox], subgraph: Subgraph) -> SubgraphResults:
    """Thin wrapper: lower the Subgraph's FX GraphModule into a subgraph buffer."""
    gm = subgraph.graph_module
    return build_subgraph_module_buffer(args, gm)
249+
250+
231251# Inner Triton functions shared by flex_attention & split-k decoding kernels.
232252compute_next_offset_func = r"""
233253@triton.jit
@@ -921,14 +941,31 @@ def lower_cpu(
921941 )
922942
923943 fake_buffers : list [Buffer ] = [] # noqa: F821
944+
945+ # [Note] Handle the case where the split sizes are not statically known.
946+ # The value of cur_qSplitSize and cur_kvSplitSize are decided during runtime.
947+ # We use symbols to represent them during the compilation here.
948+ # They'll be replaced by the string "cur_qSplitSize" and "cur_kvSplitSize" in
949+ # the modification function of the CppFlexAttentionTemplate class.
950+ cur_qSplitSize = V .graph .sizevars .shape_env .create_unbacked_symint ().node .expr
951+ cur_kvSplitSize = V .graph .sizevars .shape_env .create_unbacked_symint ().node .expr
952+ shape_env = V .graph .sizevars .shape_env
953+
    # We don't know the concrete value of cur_qSplitSize and cur_kvSplitSize during the compilation.
955+ # Mark symbols > 1 to ensure broadcasting is always applied.
956+ # This avoids treating them as equal when `eq(var, 1)` is evaluated in `broadcast_symbolic_shapes`.
957+ shape_env .var_to_range [cur_qSplitSize ] = ValueRanges (2 , int_oo )
958+ shape_env .var_to_range [cur_kvSplitSize ] = ValueRanges (2 , int_oo )
959+
960+ score_dtype = torch .float
924961 placeholder_inps = [
925- create_placeholder (name , dtype , query .get_device ())
926- for name , dtype in [
927- ("score" , torch . float ),
928- ("b" , torch .int64 ),
929- ("h" , torch .int64 ),
930- ("q_idx" , torch .int64 ),
931- ("kv_idx" , torch .int64 ),
962+ create_placeholder (name , dtype , query .get_device (), size )
963+ for name , dtype , size in [
964+ ("score" , score_dtype , [ cur_qSplitSize , cur_kvSplitSize ] ),
965+ ("b" , torch .int64 , [] ),
966+ ("h" , torch .int64 , [] ),
967+ ("q_idx" , torch .int64 , [ cur_qSplitSize , 1 ] ),
968+ ("kv_idx" , torch .int64 , [ 1 , cur_kvSplitSize ] ),
932969 ]
933970 ]
934971 subgraph_buffer = build_subgraph_buffer (
@@ -942,18 +979,83 @@ def lower_cpu(
942979 else :
943980 subgraph_buffer .freeze_layout ()
944981 mask_graph_placeholder_inps = [
945- create_placeholder (name , dtype , query .get_device ())
946- for name , dtype in [
947- ("b" , torch .int64 ),
948- ("h" , torch .int64 ),
949- ("q_idx" , torch .int64 ),
950- ("kv_idx" , torch .int64 ),
982+ create_placeholder (name , dtype , query .get_device (), size )
983+ for name , dtype , size in [
984+ ("score" , score_dtype , [cur_qSplitSize , cur_kvSplitSize ]),
985+ ("b" , torch .int64 , []),
986+ ("h" , torch .int64 , []),
987+ ("q_idx" , torch .int64 , [cur_qSplitSize , 1 ]),
988+ ("kv_idx" , torch .int64 , [1 , cur_kvSplitSize ]),
951989 ]
952990 ]
953- mask_graph_buffer = build_subgraph_buffer (
954- mask_graph_placeholder_inps + list (mask_mod_other_buffers ), mask_graph
991+
992+ # The original mask_graph works on a scalar and only includes
993+ # the logic of calculating the mask value.
    # We need to add the logic of applying the mask to the qk_data tensor
995+ # into the graph for the later codegen of this part.
996+ # Example:
997+ # mask_graph:
998+ # def mask_fn(b, h, q_idx, kv_idx):
999+ # mask = q_idx >= kv_idx
1000+ # return mask
1001+ # The converted_mask_graph should be:
1002+ # def converted_mask_fn(qk_data, b, h, q_idx, kv_idx):
1003+ # mask = q_idx >= kv_idx
1004+ # qk_data = torch.where(mask, qk_data, torch.full_like(qk_data, -float("inf")))
1005+ # return qk_data
    def convert_mask_graph_module(mask_graph) -> torch.fx.GraphModule:
        """Rewrite the scalar mask graph so it applies the mask to qk_data.

        The input ``mask_graph.graph_module`` computes
        ``mask_fn(b, h, q_idx, kv_idx) -> mask``; the returned module computes
        ``converted(qk_data, b, h, q_idx, kv_idx) -> qk_data`` where positions
        with a false mask are filled with ``-inf``.

        NOTE(review): closes over ``cur_qSplitSize``, ``cur_kvSplitSize`` and
        ``score_dtype`` from the enclosing ``lower_cpu``; assumes the graph's
        output node returns a single mask value in ``args[0]`` — confirm for
        mask graphs with structured outputs.
        """
        # Deep-copy so the caller's original graph module is left untouched.
        gm = copy.deepcopy(mask_graph.graph_module)
        graph = gm.graph
        # Add qk_data as the first input (placeholders must precede all
        # other nodes, so insert before the current first node).
        with graph.inserting_before(next(iter(graph.nodes))):
            qk_data_node = graph.placeholder("qk_data")

        # Find the node that returns the mask.
        output_node = None
        for node in graph.nodes:
            if node.op == "output":
                output_node = node
                break

        # Get the mask node (the single value currently being returned).
        assert output_node is not None
        mask_node = output_node.args[0]

        # Symbolic split sizes; replaced by runtime names later (see the
        # note above their creation in lower_cpu).
        size_node = [cur_qSplitSize, cur_kvSplitSize]
        # Create a new node for torch.full: the -inf fill tensor.
        with graph.inserting_after(mask_node):
            full_node = graph.call_function(
                torch.full,
                args=(size_node, -float("inf")),
                kwargs={"dtype": score_dtype},
            )

        # Create a new node for torch.where: keep qk_data where the mask
        # holds, -inf elsewhere.
        with graph.inserting_after(full_node):
            where_node = graph.call_function(
                torch.ops.aten.where, args=(mask_node, qk_data_node, full_node)
            )

        # Update the output node to return the result of torch.where.
        output_node.args = (where_node,)

        graph.lint()
        converted = torch.fx.GraphModule(gm, graph)
        return converted
1045+
1046+ converted_mask_graph_module = convert_mask_graph_module (mask_graph )
1047+
1048+ mask_graph_buffer = build_subgraph_module_buffer (
1049+ mask_graph_placeholder_inps + list (mask_mod_other_buffers ),
1050+ converted_mask_graph_module ,
9551051 )
9561052
1053+ # Clear the pending fresh unbacked symbols that are created for cur_qSplitSize and cur_kvSplitSize in the current kernel.
1054+ pending = V .graph .sizevars .shape_env .pending_fresh_unbacked_symbols
1055+ V .graph .sizevars .shape_env .pending_fresh_unbacked_symbols = [
1056+ x for x in pending if x not in (cur_qSplitSize , cur_kvSplitSize )
1057+ ]
1058+
9571059 buffer_list = (
9581060 placeholder_inps
9591061 + list (score_mod_other_buffers )
@@ -1066,6 +1168,7 @@ def lower_cpu(
10661168 len_score_other = len (score_mod_other_buffers ),
10671169 len_mask_other = len (mask_mod_other_buffers ),
10681170 kernel_input_name_to_buffer = kernel_input_name_to_buffer ,
1171+ block_vars = (cur_qSplitSize , cur_kvSplitSize ),
10691172 )
10701173 inputs_for_autotuning = [
10711174 query ,
0 commit comments