
Commit 658b5b2

[AMD] Add scheduling hint for attention optimizations (#6290)
This PR cleans up the iglp scheduling variants and introduces a new scheduling variant, `attention`, which groups attention-related optimizations together to improve usability. It includes:

- the `sink-insts-to-avoid-spills` LLVM command-line option, to avoid register spills;
- the `ROCDL::IglpOpt` intrinsic, with a `ROCDL::SchedBarrier` around it, to reschedule instructions; specifically, iglp variant 2 is used to interleave mfma and exp instructions.

This is an experimental feature for now and may change in the future.
1 parent 6fa33ef commit 658b5b2
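For context, a minimal usage sketch (not part of the commit): like other HIP backend options, the renamed `schedule_hint` knob is expected to be accepted as a kernel-launch keyword argument, the same way `num_stages` or `waves_per_eu` are, though whether launch kwargs reach backend options this way depends on the Triton version. The kernel below is hypothetical and shape-only; it just exhibits the chained dot -> exp2 -> dot pattern inside a loop, which is what the new variant targets.

import torch
import triton
import triton.language as tl

@triton.jit
def chained_dot_kernel(q_ptr, k_ptr, v_ptr, o_ptr, N, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]
    q = tl.load(q_ptr + idx)
    acc = tl.zeros((BLOCK, BLOCK), dtype=tl.float32)
    for i in range(0, N, BLOCK):
        k = tl.load(k_ptr + idx + i * BLOCK)    # next K tile
        v = tl.load(v_ptr + idx + i * BLOCK)    # next V tile
        s = tl.dot(q, k)                        # first dot: the chain head
        p = tl.math.exp2(s)                     # exp between the dots, as in attention
        acc += tl.dot(p.to(tl.float16), v)      # second dot consumes the first: the chain tail
    tl.store(o_ptr + idx, acc)

BLOCK, N = 64, 256
q = torch.randn(BLOCK, BLOCK, device="cuda", dtype=torch.float16)
k = torch.randn(N, BLOCK, device="cuda", dtype=torch.float16)
v = torch.randn(N, BLOCK, device="cuda", dtype=torch.float16)
o = torch.empty(BLOCK, BLOCK, device="cuda", dtype=torch.float32)
# 'attention' inserts the iglp-2 hint into the loop with chained dots and turns
# on the sink-insts-to-avoid-spills LLVM flag when the assembly is emitted.
chained_dot_kernel[(1,)](q, k, v, o, N, BLOCK=BLOCK, schedule_hint="attention")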

File tree

8 files changed: +138 −72 lines


test/TritonGPU/amd/amd-instruction-sched.mlir

Lines changed: 0 additions & 7 deletions
@@ -1,14 +1,10 @@
-// RUN: triton-opt %s -triton-amdgpu-insert-instruction-sched-hints="variant=llvm_iglp_0" -triton-amdgpu-lower-insert-instruction-sched-hints -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0
-// RUN: triton-opt %s -triton-amdgpu-insert-instruction-sched-hints="variant=llvm_iglp_1" -triton-amdgpu-lower-insert-instruction-sched-hints -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
 // RUN: triton-opt %s -convert-triton-to-tritongpu="target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64" -tritongpu-coalesce -tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1" -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline="num_stages=1" -triton-amdgpu-insert-instruction-sched-hints="variant=local_prefetch" -tritongpu-reduce-data-duplication -optimize-amd-lds-usage="target-arch=gfx942" -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm="arch=gfx942" -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
 // RUN: triton-opt %s -convert-triton-to-tritongpu="target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64" -tritongpu-coalesce -tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1" -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline="num_stages=2" -triton-amdgpu-insert-instruction-sched-hints="variant=local_prefetch" -tritongpu-reduce-data-duplication -optimize-amd-lds-usage="target-arch=gfx942" -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm="arch=gfx942" -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
 // RUN: triton-opt %s -convert-triton-to-tritongpu="target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64" -tritongpu-coalesce -tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=16 kPack=1" -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline="num_stages=2" -triton-amdgpu-insert-instruction-sched-hints="variant=local_prefetch" -tritongpu-reduce-data-duplication -optimize-amd-lds-usage="target-arch=gfx942" -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm="arch=gfx942" -triton-amdgpu-lower-insert-instruction-sched-hints="arch=gfx942 num_stages=2" -debug-only="lower-insert-instruction-sched-hints" -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
 // RUN: triton-opt %s -convert-triton-to-tritongpu="target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64" -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline="num_stages=1" | FileCheck %s -check-prefix=LABELING_PS_1
 // RUN: triton-opt %s -convert-triton-to-tritongpu="target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64" -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline="num_stages=2" | FileCheck %s -check-prefix=LABELING_PS_2

 module {
-  // INSERT_IGLP0-LABEL: @test_dot_op
-  // INSERT_IGLP1-LABEL: @test_dot_op
   // INSTR_COUNT_NS1-LABEL: @test_dot_op
   // INSTR_COUNT_NS2-LABEL: @test_dot_op
   // USE_LOCAL_PREFETCH_GLOBAL_LOAD: @test_dot_op
@@ -44,9 +40,6 @@ module {
     %a = tt.load %a_ptr : tensor<128x32x!tt.ptr<f16>>
     %b = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr<f16>>

-    // INSERT_IGLP0: rocdl.iglp.opt 0
-    // INSERT_IGLP1: rocdl.iglp.opt 1
-
     // INSTR_COUNT_NS1: amdgpu.instruction_sched_hint
     // INSTR_COUNT_NS1-SAME: isBufferLoadsAEnabled = false
     // INSTR_COUNT_NS1-SAME: isBufferLoadsBEnabled = false
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints="variant=attention" | FileCheck %s -check-prefix=INSTR_HINT
+// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints="variant=attention" -triton-amdgpu-lower-insert-instruction-sched-hints -verify-diagnostics | FileCheck %s -check-prefix=LOWER_HINT
+
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>
+#dot_op_a = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>
+#dot_op_b = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>
+// INSTR_HINT-LABEL: @insert_schedule_hint
+// LOWER_HINT-LABEL: @insert_schedule_hint
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @insert_schedule_hint(
+      %lb : index, %ub : index, %step : index,
+      %arg0: tensor<128x128xf32, #dot_op_a>,
+      %arg1: tensor<128x128xf32, #dot_op_b>,
+      %arg2: tensor<128x128x!tt.ptr<f32>, #blocked>
+  ) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
+    // INSTR_HINT: scf.for
+    // INSTR_HINT-NEXT: amdgpu.instruction_sched_hint
+    // INSTR_HINT-SAME: variant = #amdgpu.SchedHintVariant<attention>
+
+    // LOWER_HINT: scf.for
+    // LOWER_HINT-NEXT: rocdl.sched.barrier 0
+    // LOWER_HINT-COUNT-2: tt.dot
+    // LOWER_HINT: rocdl.iglp.opt 2
+    // LOWER_HINT-NEXT: rocdl.sched.barrier 0
+    // LOWER_HINT-NEXT: scf.yield
+    %loop = scf.for %iv = %lb to %ub step %step iter_args(%c = %cst) -> (tensor<128x128xf32, #mma>) {
+      %4 = tt.dot %arg0, %arg1, %c : tensor<128x128xf32, #dot_op_a> * tensor<128x128xf32, #dot_op_b> -> tensor<128x128xf32, #mma>
+      %5 = math.exp2 %4 : tensor<128x128xf32, #mma>
+      %6 = ttg.convert_layout %5 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #dot_op_a>
+      %7 = tt.dot %6, %arg1, %c : tensor<128x128xf32, #dot_op_a> * tensor<128x128xf32, #dot_op_b> -> tensor<128x128xf32, #mma>
+      scf.yield %7 : tensor<128x128xf32, #mma>
+    }
+    %8 = ttg.convert_layout %loop : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked>
+    tt.store %arg2, %8 : tensor<128x128x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}

third_party/amd/backend/compiler.py

Lines changed: 16 additions & 11 deletions
@@ -53,16 +53,14 @@ class HIPOptions:
     #
     # Current experimental scheduling variants:
     #
-    # llvm-iglp-0: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `0` to the GEMM's
-    #              k-loop; i.e., "interleave DS and MFMA instructions for small GEMM kernels".
-    # llvm-iglp-1: injects `llvm.amdgcn.iglp_opt` intrinsic call with value `1` to the GEMM's
-    #              k-loop; i.e., "interleave DS and MFMA instructions for single wave small
-    #              GEMM kernels.".
     # local-prefetch: implements instruction scheduling similar to the one from the ROCm Composable
     #                 Kernel library. Note, this variant requires the use of buffer load/store ops
     #                 and a special software pipelining style - i.e., 1x LDS and 1x register
     #                 prefetch buffers for each GEMM tile.
-    instruction_sched_variant: str = 'none'
+    # attention: enables a bunch of optimizations for attention kernels, including:
+    #            - iglp 2 and sched.barrier around it
+    #            - sink-insts-to-avoid-spills flag to avoid register spills
+    schedule_hint: str = 'none'

     def __post_init__(self):
         default_libdir = Path(__file__).parent / 'lib'
@@ -242,7 +240,7 @@ def make_ttgir(mod, metadata, options):
     use_async_copy = int(os.getenv("TRITON_HIP_USE_ASYNC_COPY", "0")) == 1

     # The `local-prefetch` scheduling variant requires turning on buffer ops.
-    if options.instruction_sched_variant == "local-prefetch":
+    if options.schedule_hint == "local-prefetch":
         global_prefetch = local_prefetch = 1

     if amd.has_matrix_core_feature(options.arch):
@@ -256,8 +254,8 @@ def make_ttgir(mod, metadata, options):
         if use_async_copy:
             amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
     passes.common.add_canonicalizer(pm)
-    if options.instruction_sched_variant.lower() != "none":
-        amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.instruction_sched_variant)
+    if options.schedule_hint.lower() != "none":
+        amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.schedule_hint)
     passes.ttgpuir.add_optimize_dot_operands(pm, True)
     passes.ttgpuir.add_remove_layout_conversions(pm)
     passes.ttgpuir.add_reduce_data_duplication(pm)
@@ -314,7 +312,7 @@ def make_llir(src, metadata, options):
     passes.common.add_canonicalizer(pm)
     passes.common.add_cse(pm)
     passes.common.add_symbol_dce(pm)
-    if options.instruction_sched_variant.lower() != "none":
+    if options.schedule_hint.lower() != "none":
         amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
     if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
         passes.llvmir.add_di_scope(pm)
@@ -396,7 +394,14 @@ def make_amdgcn(src, metadata, options):
     assert len(names) == 1
     metadata["name"] = names[0]
     # llvm -> hsaco
-    amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', [], options.enable_fp_fusion, False)
+    flags = []
+    # The sink-insts-to-avoid-spills flag asks LLVM backend to sink instructions
+    # into loops to avoid register spills in the MachineSinking pass, while it
+    # can also lead to regression in some cases. But from current observation,
+    # the regression is not significant. It would be better to have some heuristics.
+    if options.schedule_hint == 'attention':
+        flags.append('sink-insts-to-avoid-spills')
+    amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', flags, options.enable_fp_fusion, False)
     if os.environ.get("AMDGCN_ENABLE_DUMP", "0") == "1":
         print("// -----// AMDGCN Dump //----- //")
         print(amdgcn)
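A quick way to check that the hint and the new flag actually took effect is the existing `AMDGCN_ENABLE_DUMP` escape hatch that `make_amdgcn` reads above. A small sketch follows; the expected instruction mix is an assumption based on the iglp 2 description, not a guaranteed output.

import os

# Must be set before the kernel is first compiled; make_amdgcn (above) checks it
# and prints the final AMDGCN assembly.
os.environ["AMDGCN_ENABLE_DUMP"] = "1"

# ... compile/launch a kernel with schedule_hint="attention" (e.g. the sketch
# near the top of this page) ...
#
# With the hint active, v_mfma_* and v_exp_f32 instructions should appear
# interleaved inside the loop body; the rocdl.sched.barrier / rocdl.iglp.opt
# pair exists only at the LLVM IR level and leaves no marker in the assembly.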

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td

Lines changed: 3 additions & 5 deletions
@@ -73,16 +73,14 @@ class TritonAMDGPU_I32EnumAttr<string mnemonic, TritonAMDGPU_I32Enum enumInfo> :
 }

 def SchedHintCaseNone : I32EnumAttrCase<"none", 0>;
-def SchedHintCaseLLVMIglp0 : I32EnumAttrCase<"llvm_iglp_0", 1>;
-def SchedHintCaseLLVMIglp1 : I32EnumAttrCase<"llvm_iglp_1", 2>;
-def SchedHintCaseLocalPrefetch : I32EnumAttrCase<"local_prefetch", 3>;
+def SchedHintCaseLocalPrefetch : I32EnumAttrCase<"local_prefetch", 1>;
+def SchedHintCaseAttention : I32EnumAttrCase<"attention", 2>;

 def TritonAMDGPU_SchedHintsEnum : TritonAMDGPU_I32Enum<
   "SchedHint", "Instruction Scheduling Hints for AMD GPUs", [
     SchedHintCaseNone,
-    SchedHintCaseLLVMIglp0,
-    SchedHintCaseLLVMIglp1,
     SchedHintCaseLocalPrefetch,
+    SchedHintCaseAttention,
   ]>;

 def TritonAMDGPU_SchedHintVariantAttr :

third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp

Lines changed: 31 additions & 6 deletions
@@ -1,6 +1,7 @@
 #include "SchedInstructions.h"
 #include "TritonAMDGPUToLLVM/Passes.h"
 #include "TritonAMDGPUToLLVM/TargetUtils.h"
+#include "Utility.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -19,6 +20,7 @@ namespace mlir::triton {
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

 using namespace mlir;
+using ::mlir::LLVM::AMD::isChainDotHead;

 // TODO: The following passes/algorithms are applicable only for a single
 // `tt.dot` op in a `scf.for` block -i.e., a single schedule hint op per block.
@@ -425,7 +427,8 @@ struct InstructionSchedHintsRewriter
     // not supposed to be used together with IGLP OPT according to the AMDGPU
     // backend documentation.
     const bool limitSchedulingRange =
-        schedVariant == mlir::triton::amdgpu::SchedHint::local_prefetch;
+        schedVariant == mlir::triton::amdgpu::SchedHint::local_prefetch ||
+        schedVariant == mlir::triton::amdgpu::SchedHint::attention;
     ;
     Location loc = instructionSchedHint->getLoc();
     Block *block = instructionSchedHint->getBlock();
@@ -438,13 +441,12 @@ struct InstructionSchedHintsRewriter
     rewriter.setInsertionPoint(block, std::prev(block->end()));

     switch (schedVariant) {
-    case mlir::triton::amdgpu::SchedHint::llvm_iglp_0:
-    case mlir::triton::amdgpu::SchedHint::llvm_iglp_1:
-      createIglpOpt(rewriter, loc, static_cast<int>(schedVariant) - 1);
-      break;
     case mlir::triton::amdgpu::SchedHint::local_prefetch:
       createLocalPrefetchSchedule(rewriter, loc, instructionSchedHint);
       break;
+    case mlir::triton::amdgpu::SchedHint::attention:
+      createIglpOpt(rewriter, loc, 2);
+      break;
     case mlir::triton::amdgpu::SchedHint::none:
     default:
       break;
@@ -520,7 +522,8 @@ struct TritonAMDGPUInsertInstructionSchedHints
       return;
     }

-    if (schedHint != mlir::triton::amdgpu::SchedHint::none) {
+    switch (schedHint) {
+    case mlir::triton::amdgpu::SchedHint::local_prefetch:
       mod.walk([&](scf::ForOp forOp) {
         // Note, instruction schedule barriers are inserted only in the case of
         // a single `tt.dot` op in a `scf::ForOp` scope in the current
@@ -532,6 +535,28 @@ struct TritonAMDGPUInsertInstructionSchedHints
                                                                 schedHint);
         }
       });
+      break;
+    case mlir::triton::amdgpu::SchedHint::attention:
+      mod.walk([&](scf::ForOp forOp) {
+        // The attention schedule hint is inserted to the beginning of a
+        // for-loop with chained dots.
+        auto result = forOp->walk([](triton::DotOp op) {
+          if (isChainDotHead(op))
+            return WalkResult::interrupt();
+          return WalkResult::advance();
+        });
+
+        if (result.wasInterrupted()) {
+          OpBuilder rewriter(ctx);
+          rewriter.setInsertionPointToStart(forOp.getBody());
+          rewriter.create<triton::amdgpu::InstructionSchedHint>(forOp->getLoc(),
+                                                                schedHint);
+        }
+      });
+      break;
+    case mlir::triton::amdgpu::SchedHint::none:
+    default:
+      break;
     }
   }
 };

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 40 additions & 0 deletions
@@ -8,6 +8,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"

+namespace tt = mlir::triton;
 using mlir::triton::ModuleAxisInfoAnalysis;
 using mlir::triton::AMD::DppCtrl;
 using mlir::triton::AMD::ISAFamily;
@@ -641,4 +642,43 @@ bool isUsedByDotScaledOp(Operation *op) {
   });
 }

+bool isChainDotHead(tt::DotOpInterface dotOp) {
+  auto isInSameRegion = [&dotOp](Operation *op) {
+    return op->getParentRegion() == dotOp->getParentRegion();
+  };
+  ForwardSliceOptions fwdOpt;
+  fwdOpt.filter = isInSameRegion;
+  SetVector<mlir::Operation *> fwdSlices;
+  getForwardSlice(dotOp, &fwdSlices, fwdOpt);
+  for (Operation *op : fwdSlices) {
+    if (auto dOp = dyn_cast<tt::DotOpInterface>(op)) {
+      assert(dOp != dotOp);
+      auto opA = dOp.getA().getDefiningOp();
+      if (opA && fwdSlices.contains(opA)) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+bool isChainDotTail(tt::DotOpInterface dotOp) {
+  auto isInSameRegion = [&dotOp](Operation *op) {
+    return op->getParentRegion() == dotOp->getParentRegion();
+  };
+  BackwardSliceOptions bwdOpt;
+  bwdOpt.omitBlockArguments = true;
+  bwdOpt.filter = isInSameRegion;
+  SetVector<Operation *> bwdSlices;
+  Operation *opA = dotOp.getA().getDefiningOp();
+  if (!opA)
+    return false;
+  getBackwardSlice(opA, &bwdSlices, bwdOpt);
+  if (llvm::find_if(bwdSlices, [](Operation *op) {
+        return isa<tt::DotOpInterface>(op);
+      }) != bwdSlices.end())
+    return true;
+  return false;
+}
+
 } // namespace mlir::LLVM::AMD

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h

Lines changed: 7 additions & 0 deletions
@@ -95,6 +95,13 @@ bool canCoalesceWriteIntoSharedMemory(RewriterBase &rewriter,
 // Return true if op is used by DotScaledOp or UpcastMXFPOp ops.
 bool isUsedByDotScaledOp(Operation *op);

+// Check if the result of this tl.dot is used as opA of another tl.dot
+// in the same region
+bool isChainDotHead(mlir::triton::DotOpInterface dotOp);
+
+// Check if the opA of this tl.dot is the result of another tl.dot
+// in the same region
+bool isChainDotTail(mlir::triton::DotOpInterface dotOp);
 } // namespace mlir::LLVM::AMD

 #endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_UTILITY_H_

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 2 additions & 43 deletions
@@ -15,6 +15,8 @@
 using namespace mlir;
 namespace tt = mlir::triton;
 namespace ttg = mlir::triton::gpu;
+using ::mlir::LLVM::AMD::isChainDotHead;
+using ::mlir::LLVM::AMD::isChainDotTail;
 using ::mlir::LLVM::AMD::scaleDotElemTypeToMLIRType;
 using mlir::triton::gpu::chooseScaledMfmaScaleLayout;

@@ -55,49 +57,6 @@ FailureOr<ScaleDotElemType> mlirTypeToScaledElemType(Type type) {
       .Default([](Type) { return failure(); });
 }

-// Check if the result of this tl.dot is used as opA of another tl.dot
-// in the same region
-bool isChainDotHead(tt::DotOpInterface dotOp) {
-  auto isInSameRegion = [&dotOp](Operation *op) {
-    return op->getParentRegion() == dotOp->getParentRegion();
-  };
-  ForwardSliceOptions fwdOpt;
-  fwdOpt.filter = isInSameRegion;
-  SetVector<mlir::Operation *> fwdSlices;
-  getForwardSlice(dotOp, &fwdSlices, fwdOpt);
-  for (Operation *op : fwdSlices) {
-    if (auto dOp = dyn_cast<tt::DotOpInterface>(op)) {
-      assert(dOp != dotOp);
-      auto opA = dOp.getA().getDefiningOp();
-      if (opA && fwdSlices.contains(opA)) {
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
-// Check if the opA of this tl.dot is the result of another tl.dot
-// in the same region
-bool isChainDotTail(tt::DotOpInterface dotOp) {
-  auto isInSameRegion = [&dotOp](Operation *op) {
-    return op->getParentRegion() == dotOp->getParentRegion();
-  };
-  BackwardSliceOptions bwdOpt;
-  bwdOpt.omitBlockArguments = true;
-  bwdOpt.filter = isInSameRegion;
-  SetVector<Operation *> bwdSlices;
-  Operation *opA = dotOp.getA().getDefiningOp();
-  if (!opA)
-    return false;
-  getBackwardSlice(opA, &bwdSlices, bwdOpt);
-  if (llvm::find_if(bwdSlices, [](Operation *op) {
-        return isa<tt::DotOpInterface>(op);
-      }) != bwdSlices.end())
-    return true;
-  return false;
-}
-
 SmallVector<unsigned, 3>
 warpsPerTile(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps,
              std::pair<int64_t, int64_t> shapePerWarp) {
