
Commit 4211e98

Merge branch 'fix-nogil-autotune' of https://github.com/qubitium/triton into fix-nogil-autotune

2 parents: 76089e6 + c41aae5

File tree

13 files changed: +1086 -437 lines


lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 31 additions & 0 deletions
@@ -442,6 +442,37 @@ LogicalResult Fp4ToFpOp::verifyFp4ToFp(mlir::Operation *op,
               << ", dst=" << resShape[i] << ", axis=" << axis << ")";
     }
   }
+  if (bool(resTy.getEncoding()) != bool(srcTy.getEncoding()))
+    return op->emitError()
+           << "source and result must both have an encoding, or neither";
+  if (!resTy.getEncoding()) {
+    return success();
+  }
+  auto srcLl = toLinearLayout(srcTy);
+  auto resLl = toLinearLayout(resTy);
+  auto *ctx = srcTy.getContext();
+  auto regDim = StringAttr::get(ctx, "register");
+  auto outDims = standardOutDimNames(ctx, rank);
+
+  // We use backward inference here as it is strictly more general.
+  Attribute inferSrc;
+  auto dialect =
+      resTy.getEncoding()
+          .getDialect()
+          .getRegisteredInterface<triton::DialectInferLayoutInterface>();
+  assert(dialect);
+  if (failed(dialect->inferFp4ToFpOpEncoding(
+          resTy.getShape(), axis, resTy.getEncoding(), inferSrc,
+          /*fwdInference*/ false, std::nullopt))) {
+    return op->emitError() << "failed to infer encoding";
+  }
+  if (!areLayoutsEquivalent(srcTy.getShape(),
+                            cast<LayoutEncodingTrait>(inferSrc),
+                            cast<LayoutEncodingTrait>(srcTy.getEncoding())))
+    return op->emitError()
+           << "Src and Dst encodings are not compatible:\n"
+           << toLinearLayout(srcTy.getShape(), inferSrc).toString() << "\n"
+           << srcLl.toString();
   return success();
 }
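The shape of this new check: run the dialect's layout inference backward from the result encoding to the encoding the source would need, then compare the two as linear layouts rather than by attribute identity. Below is a minimal, self-contained sketch of that pattern with the MLIR/Triton types mocked out; Encoding, LinearLayout, inferSrcFromResult, and verifyFp4ToFp are all illustrative stand-ins, not Triton APIs.

#include <cassert>
#include <iostream>
#include <optional>
#include <string>

// Hypothetical stand-ins for Triton's encoding attributes and linear layouts.
struct Encoding {
  std::string repr;
};
struct LinearLayout {
  std::string repr;
  bool operator==(const LinearLayout &o) const { return repr == o.repr; }
};

// Stand-in for Triton's toLinearLayout(shape, encoding).
LinearLayout toLinearLayout(const Encoding &e) { return {e.repr}; }

// Stand-in for DialectInferLayoutInterface::inferFp4ToFpOpEncoding with
// fwdInference = false: derive the source encoding from the result encoding.
std::optional<Encoding> inferSrcFromResult(const Encoding &resEnc) {
  // The real implementation adjusts the layout along the packed axis; here
  // we just tag the string so the round trip is visible.
  return Encoding{resEnc.repr + "/packed-axis"};
}

// The verification pattern: infer backward, then compare as linear layouts
// rather than by attribute identity, so distinct-but-equivalent encodings
// still verify.
bool verifyFp4ToFp(const Encoding &srcEnc, const Encoding &resEnc) {
  auto inferred = inferSrcFromResult(resEnc);
  if (!inferred) {
    std::cerr << "failed to infer encoding\n";
    return false;
  }
  if (!(toLinearLayout(*inferred) == toLinearLayout(srcEnc))) {
    std::cerr << "Src and Dst encodings are not compatible\n";
    return false;
  }
  return true;
}

int main() {
  Encoding res{"blocked"};
  Encoding goodSrc{"blocked/packed-axis"};
  Encoding badSrc{"mma"};
  assert(verifyFp4ToFp(goodSrc, res));
  assert(!verifyFp4ToFp(badSrc, res));
  std::cout << "ok\n";
}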

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 1 addition & 1 deletion
@@ -431,7 +431,7 @@ static Attribute inferDstEncoding(triton::gpu::Fp4ToFpOp op, Attribute srcEnc) {
 
 static Attribute inferSrcEncoding(triton::gpu::Fp4ToFpOp op, Attribute dstEnc) {
   Attribute srcEnc;
-  auto shape = op.getSrc().getType().getShape();
+  auto shape = op.getType().getShape();
   if (succeeded(
           dstEnc.getDialect()
               .getRegisteredInterface<triton::DialectInferLayoutInterface>()
python/src/llvm.cc

Lines changed: 6 additions & 5 deletions
@@ -47,8 +47,8 @@ std::unique_ptr<TargetMachine>
 createTargetMachine(llvm::Module *module, std::string proc,
                     bool enable_fp_fusion, const std::string &features) {
   std::string error;
-  auto target = llvm::TargetRegistry::lookupTarget(
-      module->getTargetTriple().str(), error);
+  auto target =
+      llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
   llvm::TargetOptions opt;
   bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT");
   if (enable_fp_fusion)
@@ -278,15 +278,16 @@ void init_triton_llvm(py::module &&m) {
                        const std::string proc,
                        const std::string features) {
     std::string error;
-    auto target = llvm::TargetRegistry::lookupTarget(triple, error);
+    llvm::Triple targetTriple(triple);
+    auto target = llvm::TargetRegistry::lookupTarget(targetTriple, error);
     if (!target) {
       throw std::runtime_error("target lookup error: " + error);
     }
     llvm::TargetOptions opt;
     // Target machine is only used to create the data layout.
     std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
-        llvm::Triple(triple), proc, features, opt, llvm::Reloc::PIC_,
-        std::nullopt, llvm::CodeGenOptLevel::None)};
+        targetTriple, proc, features, opt, llvm::Reloc::PIC_, std::nullopt,
+        llvm::CodeGenOptLevel::None)};
     // set data layout
     mod->setDataLayout(machine->createDataLayout());
   });
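Both hunks track the same LLVM API migration: TargetRegistry::lookupTarget is called with an llvm::Triple rather than a triple string, and the parsed triple is reused for createTargetMachine instead of being re-parsed. A minimal sketch of the updated usage, assuming a recent LLVM that provides the Triple overload (makeMachine is an illustrative wrapper, not Triton code):

#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Triple.h"
#include <memory>
#include <string>

std::unique_ptr<llvm::TargetMachine>
makeMachine(const std::string &triple, const std::string &proc,
            const std::string &features) {
  llvm::InitializeAllTargetInfos();
  llvm::InitializeAllTargets();
  llvm::InitializeAllTargetMCs();
  std::string error;
  llvm::Triple targetTriple(triple); // parse the triple string once
  const llvm::Target *target =
      llvm::TargetRegistry::lookupTarget(targetTriple, error);
  if (!target)
    return nullptr; // `error` holds the lookup failure message
  llvm::TargetOptions opt;
  // Reuse the parsed Triple for machine creation, as in the diff above.
  return std::unique_ptr<llvm::TargetMachine>(target->createTargetMachine(
      targetTriple, proc, features, opt, llvm::Reloc::PIC_, std::nullopt,
      llvm::CodeGenOptLevel::None));
}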

test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir

Lines changed: 14 additions & 5 deletions
@@ -36,13 +36,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
     %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
     // COMMON: buffer_load %arg0[%[[offset]]]
     %9 = tt.load %6 : tensor<256x!tt.ptr<f32>, #blocked0>
-    // COMMON: buffer_load %arg1[%[[offset]]]
+    // Note: offset = pid * 256 + arange(0, 256); the byte offset "offset * sizeof(f32)" cannot be proven to stay below 2 GiB.
+    // COMMON-NOT: buffer_load %arg1[%[[offset]]]
     %10 = tt.load %8 : tensor<256x!tt.ptr<f32>, #blocked0>
     // COMMON: %[[data:.*]] = arith.addf
     %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
     %12 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
     %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
-    // COMMON: buffer_store %[[data]], %arg2[%[[offset]]]
+    // Note: see the explanation above.
+    // COMMON-NOT: buffer_store %[[data]], %arg2[%[[offset]]]
     tt.store %13, %11 : tensor<256x!tt.ptr<f32>, #blocked0>
     tt.return
   }
@@ -70,7 +72,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     %5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
     %8 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
     %9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
-    // COMMON: buffer_load %[[scalar_ptr]][%[[offset]]]
+    // Note: the base %[[scalar_ptr]] points to %arg0, which is a large tensor.
+    // The offset is "%sub + arange(0, 1024)" where "%sub = pid * 1024 - 128";
+    // we can prove "offset > 0", but not that the byte offset is < 2 GiB.
+    // COMMON-NOT: buffer_load %[[scalar_ptr]][%[[offset]]]
     %10 = tt.load %9 : tensor<1024x!tt.ptr<f32>, #blocked>
     tt.return %10 : tensor<1024xf32, #blocked>
   }
@@ -122,7 +127,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     // COMMON: %[[offset_32_bit:.*]] = arith.trunci
     %narrow4 = arith.trunci %4 : tensor<1024xi64, #blocked> to tensor<1024xi32, #blocked>
     %9 = tt.addptr %8, %narrow4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
-    // COMMON: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]]
+    // Note: the base is %arg0, which is a large tensor; the offset is
+    // trunc_i32(i64(pid * 1024) * i64(arange(0, 1024))), so the element offset
+    // lies in [0, i32-max], but the byte offset may still exceed 2 GiB.
+    // COMMON-NOT: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]]
     %10 = tt.load %9 : tensor<1024x!tt.ptr<f32>, #blocked>
     tt.return %10 : tensor<1024xf32, #blocked>
   }
@@ -555,7 +562,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     %5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
     %6 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
     %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
-    // COMMON: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]]
+    // Note: a large tensor is accessed and the offset range is [0, smax];
+    // without tl.assume the range would be [-128, smax].
+    // COMMON-NOT: amdgpu.buffer_atomic_rmw
     %8 = tt.atomic_rmw fadd, acq_rel, gpu, %7, %arg1 : (tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked>
     tt.return %8 : tensor<1024xf32, #blocked>
   }
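These tests pin down why the pass must now be conservative: AMD buffer instructions address memory with a 32-bit offset, so rewriting tt.load/tt.store to buffer ops is only sound when the byte offset (element offset times element size) is provably non-negative and below 2 GiB. A toy version of that bound check follows; fitsBufferOffset is illustrative, and the real analysis lives in the convert-buffer-ops pass.

#include <cassert>
#include <cstdint>

// An access may only be rewritten to a buffer op when the maximum *byte*
// offset is provably non-negative and below 2^31.
bool fitsBufferOffset(int64_t maxElementOffset, int64_t elementSize) {
  return maxElementOffset >= 0 &&
         maxElementOffset * elementSize < (int64_t{1} << 31);
}

int main() {
  // pid * 256 + arange(0, 256) with an unbounded pid: not provably small, so
  // the pass keeps the plain tt.load / tt.store (the COMMON-NOT lines above).
  assert(!fitsBufferOffset(INT32_MAX, /*sizeof(f32)=*/4));
  // A known-small bound would still qualify for the buffer_load rewrite.
  assert(fitsBufferOffset(1024, 4));
}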
