Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def TT_AddPtrOp : TT_Op<"addptr",
def TT_LoadOp : TT_Op<"load",
[SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
SameVariadicOperandSize,
AttrSizedOperandSegments,
MemoryEffects<[MemRead]>,
TypesMatchWith<"infer ptr type from result type",
"result", "ptr",
Expand Down
15 changes: 4 additions & 11 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,17 +242,10 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op.getNumOperands() == 2) { // ptr & mask
rewriter.replaceOpWithNewOp<triton::LoadOp>(
op, typeConverter->convertType(op.getType()), adaptor.ptr(),
adaptor.getOperands()[1], adaptor.other(), adaptor.cache(),
adaptor.evict(), adaptor.isVolatile());
} else {
rewriter.replaceOpWithNewOp<triton::LoadOp>(
op, typeConverter->convertType(op.getType()), adaptor.ptr(),
adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(),
adaptor.isVolatile());
}
rewriter.replaceOpWithNewOp<triton::LoadOp>(
op, typeConverter->convertType(op.getType()), adaptor.ptr(),
adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(),
adaptor.isVolatile());
return success();
}
};
Expand Down
22 changes: 19 additions & 3 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,34 @@ ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {

SmallVector<Type> operandTypes;
operandTypes.push_back(getPointerTypeFromTensor(resultTypes[0])); // ptr
if (allOperands.size() >= 2)
int hasMask = 0, hasOther = 0;
if (allOperands.size() >= 2) {
operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask
if (allOperands.size() >= 3)
hasMask = 1;
}
if (allOperands.size() >= 3) {
operandTypes.push_back(resultTypes[0]); // other
hasOther = 1;
}

if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
result.operands))
return failure();
// Deduce operand_segment_sizes from the number of the operands.
auto operand_segment_sizesAttrName =
LoadOp::operand_segment_sizesAttrName(result.name);
result.addAttribute(
operand_segment_sizesAttrName,
parser.getBuilder().getI32VectorAttr({1, hasMask, hasOther}));
return success();
}

void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) {
printer << " ";
printer << loadOp.getOperation()->getOperands();
printer.printOptionalAttrDict(loadOp->getAttrs(), /*elidedAttrs=*/{});
// "operand_segment_sizes" can be deduced, so we don't print it.
printer.printOptionalAttrDict(loadOp->getAttrs(),
{loadOp.operand_segment_sizesAttrName()});
printer << " : ";
printer.printStrippedAttrOrType(loadOp.result().getType());
}
Expand Down Expand Up @@ -148,6 +161,9 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
state.addOperands(other);
}
}
state.addAttribute(
operand_segment_sizesAttrName(state.name),
builder.getI32VectorAttr({1, (mask ? 1 : 0), (other ? 1 : 0)}));
state.addAttribute(
cacheAttrName(state.name),
::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
Expand Down
19 changes: 18 additions & 1 deletion test/Conversion/triton_to_tritongpu.mlir
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
// RUN: triton-opt %s -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s

func @ops() {
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
%c = arith.constant dense<3.00e+00> : tensor<128x128xf32>
%0 = tt.dot %a, %b, %c {allowTF32 = true} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32>
return
}

func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
// Test if LoadOp is lowered properly (see #771)
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
%mask = arith.constant dense<true> : tensor<128xi1>
%other = arith.constant dense<0.0e+0> : tensor<128xf32>
// CHECK: %{{.*}} = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : {{.*}}
%a = tt.load %ptrs {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<128xf32>
// CHECK: %{{.*}} = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : {{.*}}
%b = tt.load %ptrs, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<128xf32>
// CHECK: %{{.*}} = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : {{.*}}
%c = tt.load %ptrs, %mask, %other {cache = 1 : i32, evict = 1 : i32, isVolatile = true} : tensor<128xf32>
tt.store %ptrs, %a : tensor<128xf32>
tt.store %ptrs, %b : tensor<128xf32>
tt.store %ptrs, %c : tensor<128xf32>
return
}