
Commit 331d5cf

chunyuan-w authored and pytorchmergebot committed
[inductor] [cpp] Support vectorization for score and mask in FlexAttention CPU (pytorch#143638)
## Description

This PR generates vectorized kernels for the score and mask functions in FlexAttention on CPU.

## Modification

The main changes include:

- For the input and output buffers of the mask and score functions, we now pass tensors instead of scalars.
- For the mask function, the original function works on a scalar and only contains the logic for computing the mask value. This PR adds the logic of applying the mask to the `qk_data` tensor into the graph and then leverages the CPP backend to generate vectorized kernels.

The original mask graph:
```python
def mask_fn(b, h, q_idx, kv_idx):
    mask = q_idx >= kv_idx
    return mask
```

The converted mask graph should be:
```python
def converted_mask_fn(qk_data, b, h, q_idx, kv_idx):
    mask = q_idx >= kv_idx
    qk_data = torch.where(mask, qk_data, torch.full_like(qk_data, -float("inf")))
    return qk_data
```

## Benchmark

For q, k, v of shape `[1, 32, 1024, 128]`, using 40 CPU cores, we observe over a 20x speedup compared with the non-vectorized version, for both `is_causal = False` and `is_causal = True`.

## Test plan

The existing FlexAttention UTs (`test/inductor/test_flex_attention.py`, `test/inductor/test_flex_decoding.py`) cover the change in this PR.

## Output code

**Code before this PR is in scalar version:**
```cpp
// apply score mod function
for (int64_t row = 0; row < cur_qSplitSize; ++row) {
  for (int64_t col = 0; col < cur_kvSplitSize; col++) {
    std::vector<int64_t> b_idx = {i};
    std::vector<int64_t> h_idx = {j};
    std::vector<int64_t> q_idx = {m+row};
    int64_t phisical_kv_idx = n+col;
    if (use_kv_indice) {
      phisical_kv_idx= *kv_logical_data * kvBlockSize + col;
    }
    std::vector<int64_t> kv_idx = {phisical_kv_idx};
    accum_t* in_ptr0 = qk_data + row * cur_kvSplitSize + col;
    auto in_ptr1 = b_idx.data();
    auto in_ptr2 = h_idx.data();
    auto in_ptr3 = q_idx.data();
    auto in_ptr4 = kv_idx.data();
    accum_t* out_ptr0 = in_ptr0;
    {
      {
        {
          auto tmp0 = in_ptr0[static_cast<int64_t>(0L)];
          out_ptr0[static_cast<int64_t>(0L)] = tmp0;
        }
      }
    }
  }
}
// Apply block mask, fill unused with -inf
for (int64_t row = 0; row < cur_qSplitSize; ++row) {
  for (int64_t col = 0; col < cur_kvSplitSize; col++) {
    std::vector<int64_t> b_idx = {i};
    std::vector<int64_t> h_idx = {j};
    std::vector<int64_t> q_idx = {m+row};
    int64_t phisical_kv_idx = n+col;
    if (use_kv_indice) {
      phisical_kv_idx= *kv_logical_data * kvBlockSize + col;
    }
    std::vector<int64_t> kv_idx = {phisical_kv_idx};
    accum_t* qk_block = qk_data + row * cur_kvSplitSize + col;
    auto in_ptr1 = b_idx.data();
    auto in_ptr2 = h_idx.data();
    auto in_ptr3 = q_idx.data();
    auto in_ptr4 = kv_idx.data();
    std::vector<int64_t> temp = {0};
    int64_t* out_ptr1 = temp.data();
    {
      {
        {
          auto tmp0 = static_cast<bool>(true);
          out_ptr1[static_cast<int64_t>(0L)] = tmp0;
        }
      }
    }
    *qk_block = *out_ptr1 != 0
        ? *qk_block
        : -std::numeric_limits<accum_t>::infinity();
  }
}
```

**Code after this PR will be vectorized:**
```cpp
accum_t* in_ptr0 = qk_data;
auto in_ptr1 = b_idx.data();
auto in_ptr2 = h_idx.data();
auto in_ptr3 = q_idx.data();
auto in_ptr4 = kv_idx.data();
// apply score mod function
{
  accum_t* out_ptr0 = in_ptr0;
  {
    #pragma GCC ivdep
    for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(cur_qSplitSize); x0+=static_cast<int64_t>(1L))
    {
      for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(cur_kvSplitSize); x1+=static_cast<int64_t>(16L))
      {
        {
          if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(cur_kvSplitSize), static_cast<int64_t>(16L))))))
          {
            auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1 + cur_kvSplitSize*x0), static_cast<int64_t>(16));
            tmp0.store(out_ptr0 + static_cast<int64_t>(x1 + cur_kvSplitSize*x0));
          }
          if(C10_UNLIKELY(x1 >= static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(cur_kvSplitSize), static_cast<int64_t>(16L)))) && x1 < static_cast<int64_t>(cur_kvSplitSize)))
          {
            auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1 + cur_kvSplitSize*x0), static_cast<int64_t>(cur_kvSplitSize + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(cur_kvSplitSize), static_cast<int64_t>(16L))))));
            tmp0.store(out_ptr0 + static_cast<int64_t>(x1 + cur_kvSplitSize*x0), static_cast<int64_t>(cur_kvSplitSize + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>(cur_kvSplitSize), static_cast<int64_t>(16L))))));
          }
        }
      }
    }
  }
}
// Apply block mask, fill unused with -inf
{
  accum_t* out_ptr1 = in_ptr0;
  {
    #pragma GCC ivdep
    for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(cur_qSplitSize); x0+=static_cast<int64_t>(1L))
    {
      for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(cur_kvSplitSize); x1+=static_cast<int64_t>(16L))
      {
        {
          if(C10_LIKELY(x1 >= static_cast<int64_t>(0) && x1 < static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(cur_kvSplitSize), static_cast<int64_t>(16L))))))
          {
            auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x1 + cur_kvSplitSize*x0), static_cast<int64_t>(16));
            auto tmp1 = static_cast<bool>(true);
            auto tmp2 = -std::numeric_limits<float>::infinity();
            auto tmp3 = at::vec::VecMask<float,1>::from(tmp1);
            auto tmp4 = at::vec::Vectorized<float>(tmp2);
            auto tmp5 = decltype(tmp0)::blendv(tmp4, tmp0, tmp3.template cast<float,1>());
            tmp5.store(out_ptr1 + static_cast<int64_t>(x1 + cur_kvSplitSize*x0));
          }
          if(C10_UNLIKELY(x1 >= static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(cur_kvSplitSize), static_cast<int64_t>(16L)))) && x1 < static_cast<int64_t>(cur_kvSplitSize)))
          {
            for (int64_t x1_tail = static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>(cur_kvSplitSize), static_cast<int64_t>(16L)))); x1_tail < static_cast<int64_t>(cur_kvSplitSize); x1_tail++)
            {
              auto tmp0 = in_ptr0[static_cast<int64_t>(x1_tail + cur_kvSplitSize*x0)];
              auto tmp1 = static_cast<bool>(true);
              auto tmp2 = -std::numeric_limits<float>::infinity();
              auto tmp3 = tmp1 ? tmp0 : tmp2;
              out_ptr1[static_cast<int64_t>(x1_tail + cur_kvSplitSize*x0)] = tmp3;
            }
          }
        }
      }
    }
  }
}
```

Pull Request resolved: pytorch#143638
Approved by: https://github.com/jgong5, https://github.com/drisspg, https://github.com/leslie-fang-intel
1 parent ce38bfd commit 331d5cf
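
For context, here is a minimal usage sketch (an illustration assumed by this note, not code from the PR) of how the CPU FlexAttention path exercised by the benchmark above can be driven from Python; the shapes match the Benchmark section and the causal `mask_mod` mirrors the `mask_fn` example:

```python
# Minimal sketch (assumed usage, not part of this PR): run FlexAttention on CPU with a
# causal block mask so Inductor's CPP backend generates kernels like the ones shown above.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 1, 32, 1024, 128  # shapes from the Benchmark section

q = torch.randn(B, H, S, D, device="cpu")
k = torch.randn(B, H, S, D, device="cpu")
v = torch.randn(B, H, S, D, device="cpu")

def causal(b, h, q_idx, kv_idx):
    # Same predicate as the mask_fn example above.
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, B, H, S, S, device="cpu")
compiled_flex = torch.compile(flex_attention)  # lowers to the CPP FlexAttention template on CPU
out = compiled_flex(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 32, 1024, 128])
```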

3 files changed: +183 -62 lines changed

torch/_inductor/codegen/cpp_flex_attention_template.py

Lines changed: 57 additions & 43 deletions
@@ -621,53 +621,45 @@
   {{kernel.kernel_name}}_mul_scale_kernel<accum_t>(qk_data, scaling_factor, cur_qSplitSize*cur_kvSplitSize);
 
 {%- if score_mod and mask_mod %}
-  // TODO: vectorization optimization for below score and mask codegen functions
-  // apply score mod function
-  for (int64_t row = 0; row < cur_qSplitSize; ++row) {
-    for (int64_t col = 0; col < cur_kvSplitSize; col++) {
-      std::vector<int64_t> b_idx = {i};
-      std::vector<int64_t> h_idx = {j};
-      std::vector<int64_t> q_idx = {m+row};
-      int64_t phisical_kv_idx = n+col;
+  // TODO: reduce the number of calls of q_idx and kv_idx initialization
+  std::vector<int64_t> q_idx(cur_qSplitSize);
+  for (int64_t i = 0; i < cur_qSplitSize; ++i) {
+    q_idx[i] = m + i;
+  }
+
+  std::vector<int64_t> kv_idx(cur_kvSplitSize);
+  for (int64_t i = 0; i < cur_kvSplitSize; ++i) {
     if (use_kv_indice) {
-      phisical_kv_idx= *kv_logical_data * kvBlockSize + col;
+      kv_idx[i] = *kv_logical_data * kvBlockSize + i;
+    } else {
+      kv_idx[i] = n + i;
     }
-    std::vector<int64_t> kv_idx = {phisical_kv_idx};
-    accum_t* in_ptr0 = qk_data + row * cur_kvSplitSize + col;
-    auto in_ptr1 = b_idx.data();
-    auto in_ptr2 = h_idx.data();
-    auto in_ptr3 = q_idx.data();
-    auto in_ptr4 = kv_idx.data();
+  }
+
+  std::vector<int64_t> b_idx = {i};
+  std::vector<int64_t> h_idx = {j};
+
+  accum_t* in_ptr0 = qk_data;
+
+  auto in_ptr1 = b_idx.data();
+  auto in_ptr2 = h_idx.data();
+  auto in_ptr3 = q_idx.data();
+  auto in_ptr4 = kv_idx.data();
+
+  // apply score mod function
+  {
     {{ template.generate_other_buffer("score_others", 0, "len_score_other", kernel.args) }}
     accum_t* out_ptr{{score_buf_idx}} = in_ptr0;
-    {{ template.modification(score_mod, score_buf_name, score_buf_idx) }}
-  }
+    {{ template.modification(score_mod, score_buf_name, score_buf_idx)|indent(12, false) }}
   }
+
   // Apply block mask, fill unused with -inf
-  for (int64_t row = 0; row < cur_qSplitSize; ++row) {
-    for (int64_t col = 0; col < cur_kvSplitSize; col++) {
-      std::vector<int64_t> b_idx = {i};
-      std::vector<int64_t> h_idx = {j};
-      std::vector<int64_t> q_idx = {m+row};
-      int64_t phisical_kv_idx = n+col;
-      if (use_kv_indice) {
-        phisical_kv_idx= *kv_logical_data * kvBlockSize + col;
-      }
-      std::vector<int64_t> kv_idx = {phisical_kv_idx};
-      accum_t* qk_block = qk_data + row * cur_kvSplitSize + col;
-      auto in_ptr1 = b_idx.data();
-      auto in_ptr2 = h_idx.data();
-      auto in_ptr3 = q_idx.data();
-      auto in_ptr4 = kv_idx.data();
+  {
     {{ template.generate_other_buffer("mask_others", -1, "len_mask_other", kernel.args) }}
-    std::vector<int64_t> temp = {0};
-    int64_t* out_ptr{{mask_buf_idx}} = temp.data();
-    {{ template.modification(mask_mod, mask_buf_name, mask_buf_idx) }}
-    *qk_block = *out_ptr{{mask_buf_idx}} != 0
-      ? *qk_block
-      : -std::numeric_limits<accum_t>::infinity();
-  }
+    accum_t* out_ptr{{mask_buf_idx}} = in_ptr0;
+    {{ template.modification(mask_mod, mask_buf_name, mask_buf_idx)|indent(12, false) }}
   }
+
 {%- endif %}
   // Update coefficients with Softmax
   accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
@@ -792,6 +784,7 @@ def __init__(
         len_score_other,
         len_mask_other,
         kernel_input_name_to_buffer,
+        block_vars,
     ) -> None:
         assert layout.dtype in [torch.float, torch.bfloat16]
         super().__init__("flex_attention", input_nodes, layout, parallel_num_threads())
@@ -824,6 +817,7 @@ def get_idx(buf_name):
         self.len_score_other = len_score_other
         self.len_mask_other = len_mask_other
        self.kernel_input_name_to_buffer = kernel_input_name_to_buffer
+        self.block_vars = block_vars
         self.extra_sizevars = list(
             OrderedSet(
                 val
@@ -935,14 +929,15 @@ def modification(self, subgraph_buffer, output_name, output_idx):
         cpp_kernel_proxy = CppKernelProxy(kernel_group)
         bodies = []
         var_sizes_list = []
-
-        var_sizes = tuple([])  # type: ignore[var-annotated] # noqa: C409
-        output_index = 0
+        var_sizes = tuple(subgraph_buffer.get_size())
         var_ranges = {
             sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
             for i, sz in enumerate(var_sizes)
         }
 
+        dst_layout = subgraph_buffer.get_layout()
+        output_index = dst_layout.make_indexer()([*var_ranges.keys()])
+
         def fn(*args):
             V.ops.store(
                 output_name,
@@ -970,7 +965,24 @@ def fn(*args):
 
         cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
         kernel_group.finalize_kernel(cpp_kernel_proxy, [])
-        return kernel_group.loops_code.getvalue()
+        output_code = kernel_group.loops_code.getvalue()
+
+        var_q_symbol, var_kv_symbol = self.block_vars
+        # See [Note] Handle the case where the split sizes are not statically known.
+        # We don't know the value of qBlockSize and rkvBlockSize during compilation time
+        # thus we've represented them by symbols.
+        # We change the symbol strings back to "cur_qSplitSize" and "cur_kvSplitSize"
+        # in the generated code thus they'll be filled with the real value during runtime.
+        if var_q_symbol in kernel_group.args.sizevars:
+            output_code = output_code.replace(
+                kernel_group.args.sizevars[var_q_symbol], "cur_qSplitSize"
+            )
+        if var_kv_symbol in kernel_group.args.sizevars:
+            output_code = output_code.replace(
+                kernel_group.args.sizevars[var_kv_symbol], "cur_kvSplitSize"
+            )
+
+        return output_code
 
     @staticmethod
     def add_choices(
@@ -987,6 +999,7 @@ def add_choices(
         len_score_other,
         len_mask_other,
         kernel_input_name_to_buffer,
+        block_vars,
     ):
         def preprocessor(input_nodes, layout):
             return input_nodes, layout
@@ -1010,6 +1023,7 @@ def postprocessor(output):
             len_score_other=len_score_other,
             len_mask_other=len_mask_other,
             kernel_input_name_to_buffer=kernel_input_name_to_buffer,
+            block_vars=block_vars,
         )
         template.maybe_append_choice(choices)
         return template

torch/_inductor/codegen/cpp_template_kernel.py

Lines changed: 4 additions & 0 deletions
@@ -502,6 +502,10 @@ def store_outputs(
         )
         return ""
 
+    def check_bounds(self, expr, size, lower, upper):
+        # CppTemplateKernel does not need codegen related operations
+        return
+
 
 class CppTemplateCaller(ir.ChoiceCaller):
     """

torch/_inductor/kernel/flex_attention.py

Lines changed: 122 additions & 19 deletions
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
 """ Triton Implementation of the flex_attention Kernel"""
 
+import copy
 import logging
 import math
 from collections.abc import Sequence
@@ -14,6 +15,8 @@
 from torch._inductor.virtualized import V
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._pytree import tree_map
+from torch.utils._sympy.numbers import int_oo
+from torch.utils._sympy.value_ranges import ValueRanges
 
 from .. import config
 from ..ir import (
@@ -100,10 +103,21 @@ def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta):
 
 
 def create_placeholder(
-    name: str, dtype: torch.dtype, device: torch.device
+    name: str,
+    dtype: torch.dtype,
+    device: torch.device,
+    size: Optional[list[int]] = None,
 ) -> TensorBox:
     """Creates a placeholder input buffers for producing subgraph_output."""
-    input_buffer = InputBuffer(name=name, layout=FixedLayout(device, dtype, [], []))
+    input_buffer = InputBuffer(
+        name=name,
+        layout=FixedLayout(
+            device,
+            dtype,
+            size if size else [],
+            FlexibleLayout.contiguous_strides(size) if size else [],
+        ),
+    )
     return TensorBox.create(input_buffer)
 
 
@@ -173,7 +187,9 @@ def zeros_and_scatter_lowering(shape: list[int], indices, values):
 SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]]
 
 
-def build_subgraph_buffer(args: list[TensorBox], subgraph: Subgraph) -> SubgraphResults:
+def build_subgraph_module_buffer(
+    args: list[TensorBox], graph_module: torch.fx.GraphModule
+) -> SubgraphResults:
     """This function's goal is to take in the required args and produce the subgraph buffer
     The subgraph buffer is a ComputedBuffer that will be inlined into the triton template
 
@@ -184,7 +200,7 @@ def build_subgraph_buffer(args: list[TensorBox], subgraph: Subgraph) -> Subgraph
     from ..subgraph_lowering import PointwiseSubgraphLowering
 
     pw_subgraph = PointwiseSubgraphLowering(
-        subgraph.graph_module,
+        graph_module,
         root_graph_lowering=V.graph,
         allowed_mutations=OrderedSet([torch.ops.flex_lib.zeros_and_scatter.default]),
         additional_lowerings={
@@ -228,6 +244,10 @@ def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]:
     return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs)
 
 
+def build_subgraph_buffer(args: list[TensorBox], subgraph: Subgraph) -> SubgraphResults:
+    return build_subgraph_module_buffer(args, subgraph.graph_module)
+
+
 # Inner Triton functions shared by flex_attention & split-k decoding kernels.
 compute_next_offset_func = r"""
 @triton.jit
@@ -921,14 +941,31 @@ def lower_cpu(
     )
 
     fake_buffers: list[Buffer] = []  # noqa: F821
+
+    # [Note] Handle the case where the split sizes are not statically known.
+    # The value of cur_qSplitSize and cur_kvSplitSize are decided during runtime.
+    # We use symbols to represent them during the compilation here.
+    # They'll be replaced by the string "cur_qSplitSize" and "cur_kvSplitSize" in
+    # the modification function of the CppFlexAttentionTemplate class.
+    cur_qSplitSize = V.graph.sizevars.shape_env.create_unbacked_symint().node.expr
+    cur_kvSplitSize = V.graph.sizevars.shape_env.create_unbacked_symint().node.expr
+    shape_env = V.graph.sizevars.shape_env
+
+    # We don't know the concret value of cur_qSplitSize and cur_kvSplitSize during the compilation.
+    # Mark symbols > 1 to ensure broadcasting is always applied.
+    # This avoids treating them as equal when `eq(var, 1)` is evaluated in `broadcast_symbolic_shapes`.
+    shape_env.var_to_range[cur_qSplitSize] = ValueRanges(2, int_oo)
+    shape_env.var_to_range[cur_kvSplitSize] = ValueRanges(2, int_oo)
+
+    score_dtype = torch.float
     placeholder_inps = [
-        create_placeholder(name, dtype, query.get_device())
-        for name, dtype in [
-            ("score", torch.float),
-            ("b", torch.int64),
-            ("h", torch.int64),
-            ("q_idx", torch.int64),
-            ("kv_idx", torch.int64),
+        create_placeholder(name, dtype, query.get_device(), size)
+        for name, dtype, size in [
+            ("score", score_dtype, [cur_qSplitSize, cur_kvSplitSize]),
+            ("b", torch.int64, []),
+            ("h", torch.int64, []),
+            ("q_idx", torch.int64, [cur_qSplitSize, 1]),
+            ("kv_idx", torch.int64, [1, cur_kvSplitSize]),
         ]
     ]
     subgraph_buffer = build_subgraph_buffer(
@@ -942,18 +979,83 @@ def lower_cpu(
     else:
         subgraph_buffer.freeze_layout()
     mask_graph_placeholder_inps = [
-        create_placeholder(name, dtype, query.get_device())
-        for name, dtype in [
-            ("b", torch.int64),
-            ("h", torch.int64),
-            ("q_idx", torch.int64),
-            ("kv_idx", torch.int64),
+        create_placeholder(name, dtype, query.get_device(), size)
+        for name, dtype, size in [
+            ("score", score_dtype, [cur_qSplitSize, cur_kvSplitSize]),
+            ("b", torch.int64, []),
+            ("h", torch.int64, []),
+            ("q_idx", torch.int64, [cur_qSplitSize, 1]),
+            ("kv_idx", torch.int64, [1, cur_kvSplitSize]),
         ]
     ]
-    mask_graph_buffer = build_subgraph_buffer(
-        mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
+
+    # The original mask_graph works on a scalar and only includes
+    # the logic of calculating the mask value.
+    # We need to add the logic of applying the mark to the qk_data tensor
+    # into the graph for the later codegen of this part.
+    # Example:
+    # mask_graph:
+    # def mask_fn(b, h, q_idx, kv_idx):
+    #     mask = q_idx >= kv_idx
+    #     return mask
+    # The converted_mask_graph should be:
+    # def converted_mask_fn(qk_data, b, h, q_idx, kv_idx):
+    #     mask = q_idx >= kv_idx
+    #     qk_data = torch.where(mask, qk_data, torch.full_like(qk_data, -float("inf")))
+    #     return qk_data
+    def convert_mask_graph_module(mask_graph):
+        gm = copy.deepcopy(mask_graph.graph_module)
+        graph = gm.graph
+        # Add qk_data as the first input
+        with graph.inserting_before(next(iter(graph.nodes))):
+            qk_data_node = graph.placeholder("qk_data")
+
+        # Find the node that returns the mask
+        output_node = None
+        for node in graph.nodes:
+            if node.op == "output":
+                output_node = node
+                break
+
+        # Get the mask node
+        assert output_node is not None
+        mask_node = output_node.args[0]
+
+        size_node = [cur_qSplitSize, cur_kvSplitSize]
+        # Create a new node for torch.full
+        with graph.inserting_after(mask_node):
+            full_node = graph.call_function(
+                torch.full,
+                args=(size_node, -float("inf")),
+                kwargs={"dtype": score_dtype},
+            )
+
+        # Create a new node for torch.where
+        with graph.inserting_after(full_node):
+            where_node = graph.call_function(
+                torch.ops.aten.where, args=(mask_node, qk_data_node, full_node)
+            )
+
+        # Update the output node to return the result of torch.where
+        output_node.args = (where_node,)
+
+        graph.lint()
+        converted = torch.fx.GraphModule(gm, graph)
+        return converted
+
+    converted_mask_graph_module = convert_mask_graph_module(mask_graph)
+
+    mask_graph_buffer = build_subgraph_module_buffer(
+        mask_graph_placeholder_inps + list(mask_mod_other_buffers),
+        converted_mask_graph_module,
     )
 
+    # Clear the pending fresh unbacked symbols that are created for cur_qSplitSize and cur_kvSplitSize in the current kernel.
+    pending = V.graph.sizevars.shape_env.pending_fresh_unbacked_symbols
+    V.graph.sizevars.shape_env.pending_fresh_unbacked_symbols = [
+        x for x in pending if x not in (cur_qSplitSize, cur_kvSplitSize)
+    ]
+
     buffer_list = (
         placeholder_inps
         + list(score_mod_other_buffers)
@@ -1066,6 +1168,7 @@ def lower_cpu(
         len_score_other=len(score_mod_other_buffers),
         len_mask_other=len(mask_mod_other_buffers),
         kernel_input_name_to_buffer=kernel_input_name_to_buffer,
+        block_vars=(cur_qSplitSize, cur_kvSplitSize),
     )
     inputs_for_autotuning = [
         query,

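To make the mask-graph rewrite in `flex_attention.py` above concrete, here is a self-contained FX sketch of the same idea applied to a plain Python mask function. It is illustrative only: it uses `torch.full_like` in place of the template's symbolic `[cur_qSplitSize, cur_kvSplitSize]` sizes, and the variable names are assumptions, not the PR's code.

```python
# Self-contained sketch (illustrative, not the PR's code): rewrite a scalar mask
# function so it also applies the mask to a qk tensor, mirroring
# convert_mask_graph_module above but with torch.full_like instead of symbolic sizes.
import torch
import torch.fx as fx

def mask_fn(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

gm = fx.symbolic_trace(mask_fn)
graph = gm.graph

# Add qk_data as the first placeholder so the converted graph also takes the scores.
with graph.inserting_before(next(iter(graph.nodes))):
    qk_data_node = graph.placeholder("qk_data")

output_node = next(n for n in graph.nodes if n.op == "output")
mask_node = output_node.args[0]

# Build the -inf fill and where(mask, qk_data, -inf) right after the mask is computed.
with graph.inserting_after(mask_node):
    full_node = graph.call_function(torch.full_like, args=(qk_data_node, -float("inf")))
with graph.inserting_after(full_node):
    where_node = graph.call_function(torch.where, args=(mask_node, qk_data_node, full_node))

# Return the masked scores instead of the boolean mask.
output_node.args = (where_node,)
graph.lint()
converted = fx.GraphModule(gm, graph)

qk = torch.randn(4, 4)
q_idx = torch.arange(4).view(4, 1)   # broadcasts like the [cur_qSplitSize, 1] placeholder
kv_idx = torch.arange(4).view(1, 4)  # broadcasts like the [1, cur_kvSplitSize] placeholder
out = converted(qk, 0, 0, q_idx, kv_idx)  # entries with kv_idx > q_idx are now -inf
```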