Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,32 @@ on:
pull_request:
branches:
- main
- triton-mlir

jobs:

Runner-Preparation:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- name: Prepare runner matrix
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix::[["self-hosted", "A10"], "macos-latest"]'
else
echo '::set-output name=matrix::["ubuntu-latest", "macos-latest"]'
fi

Integration-Tests:
needs: Runner-Preparation

runs-on: ${{ matrix.runner }}

runs-on: self-hosted
strategy:
matrix:
runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix)}}

steps:

Expand All @@ -19,26 +39,29 @@ jobs:

- name: Clear cache
run: |
rm -r ~/.triton/cache/
continue-on-error: true
rm -rf ~/.triton/cache/

- name: Check imports
if: ${{ matrix.runner != 'macos-latest' }}
run: |
pip install isort
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )

- name: Check python style
if: ${{ matrix.runner != 'macos-latest' }}
run: |
pip install autopep8
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )

- name: Check cpp style
if: ${{ matrix.runner != 'macos-latest' }}
run: |
sudo apt-get install -y clang-format
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)

- name: Flake8
if: ${{ matrix.runner != 'macos-latest' }}
run: |
pip install flake8
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )
Expand All @@ -59,6 +82,7 @@ jobs:
lit -v "$LIT_TEST_DIR"

- name: Run python tests
if: ${{ matrix.runner[0] == 'self-hosted' }}
run: |
cd python/tests
pytest
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ struct PTXBuilder {

Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0);

llvm::SmallVector<Operand *> getAllArgs() const;
llvm::SmallVector<Operand *, 4> getAllArgs() const;

llvm::SmallVector<Value, 4> getAllMLIRArgs() const;

Expand Down
32 changes: 19 additions & 13 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,26 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffec
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast int64 to pointer";

let arguments = (ins I64Tensor:$from);
let arguments = (ins TT_I64Like:$from);

let results = (outs TT_PtrTensor:$result);
let results = (outs TT_PtrLike:$result);

let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}

def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast pointer to int64";

let arguments = (ins TT_PtrTensor:$from);
let arguments = (ins TT_PtrLike:$from);

let results = (outs I64Tensor:$result);
let results = (outs TT_I64Like:$result);

let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}

def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Floating point casting for custom types";

let description = [{
Expand All @@ -54,9 +58,11 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
BF8 <-> F8, FP16, FP32
}];

let arguments = (ins TT_FloatTensor:$from);
let arguments = (ins TT_FloatLike:$from);

let results = (outs TT_FloatLike:$result);

let results = (outs TT_FloatTensor:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";

// TODO: We need a verifier here.
}
Expand Down Expand Up @@ -127,16 +133,16 @@ def TT_StoreOp : TT_Op<"store",
let hasCanonicalizer = 1;
}

def TT_GEPOp : TT_Op<"getelementptr",
def TT_AddPtrOp : TT_Op<"addptr",
[NoSideEffect, SameOperandsAndResultShape,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">,
TypesMatchWith<"result shape matches offset shape",
"result", "offset",
"getI32SameShape($_self)">]> {
let arguments = (ins TT_PtrTensor:$ptr, I32Tensor:$offset);
let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);

let results = (outs TT_PtrTensor:$result);
let results = (outs TT_PtrLike:$result);

let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
}
Expand Down Expand Up @@ -278,7 +284,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
return $old
}];

let arguments = (ins TT_Pointer:$ptr, TT_Type:$cmp, TT_Type:$val);
let arguments = (ins TT_Ptr:$ptr, TT_Type:$cmp, TT_Type:$val);

let results = (outs TT_Type:$result);
}
Expand Down Expand Up @@ -318,7 +324,7 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {

let arguments = (ins I32Attr:$start, I32Attr:$end);

let results = (outs TT_IntegerTensor:$result);
let results = (outs TT_IntTensor:$result);

let assemblyFormat = "attr-dict `:` type($result)";
}
Expand Down
34 changes: 26 additions & 8 deletions include/triton/Dialect/Triton/IR/TritonTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,36 @@ class TritonTypeDef<string name, string _mnemonic>
let mnemonic = _mnemonic;
}

// Floating-point Type
def F8 : TritonTypeDef<"Float8", "f8">;
def BF8 : TritonTypeDef<"BFloat8", "bf8">;

def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

// IntegerType
// Boolean Type
// TT_Bool -> I1
def TT_BoolTensor : TensorOf<[I1]>;
def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;

// Integer Type
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
def TT_IntegerTensor : TensorOf<[TT_Int]>;
def TT_IntTensor : TensorOf<[TT_Int]>;
def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;

// I32 Type
// TT_I32 -> I32
// TT_I32Tensor -> I32Tensor
def TT_I32Like: AnyTypeOf<[I32, I32Tensor]>;

// I64 Type
// TT_I64 -> I64
// TT_I64Tensor -> I64Tensor
def TT_I64Like: AnyTypeOf<[I64, I64Tensor]>;

// PointerType
def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
// Pointer Type
def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
let summary = "pointer type";

let description = [{
Expand All @@ -43,12 +61,12 @@ def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {

let skipDefaultBuilders = 1;
}
def TT_PtrTensor : TensorOf<[TT_Pointer]>;
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>;

def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>;
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntTensor]>;
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;

def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
TT_Pointer, TT_PtrTensor]>;
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike]>;

#endif
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp">
let description = [{
dot(a, b, 0) + c => dot(a, b, c)

gep(gep(ptr, idx0), idx1) => gep(ptr, AddI(idx0, idx1))
addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1))

select(cond, load(ptrs, broadcast(cond), ???), other) =>
load(ptrs, broadcast(cond), other)
Expand Down
16 changes: 5 additions & 11 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,6 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType

def TT_BoolTensor : TensorOf<[I1]>;

def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
def TT_IntegerLike : AnyTypeOf<[TT_Int, TT_IntegerTensor]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

class TTG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonGPU_Dialect, mnemonic, traits>;

Expand Down Expand Up @@ -48,8 +42,8 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
let description = [{}];

let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
TT_IntegerLike:$lhs,
TT_IntegerLike:$rhs);
TT_IntLike:$lhs,
TT_IntLike:$rhs);

let results = (outs TT_BoolLike:$result);
}
Expand All @@ -66,7 +60,7 @@ def TTG_CmpFOp : TTG_Op<"cmpf"> {
let results = (outs TT_BoolLike:$result);
}

def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
[SameVariadicOperandSize,
MemoryEffects<[MemRead, MemWrite]>,
TypesMatchWith<"infer mask type from src type",
Expand Down Expand Up @@ -94,7 +88,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
* other: optional tensor-rank number of other tensors which specify what
values are inserted into the `$dst` tensor if the corresponding
element of the `$mask` tensor is false.

In the future, we may decompose this operation into a sequence of:

* `async` operation to specify a sequence of asynchronous operations
Expand Down Expand Up @@ -191,7 +185,7 @@ def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect]> {
Note: This op can be repalced to a `bufferization.alloc_tensor` in LLVM 16.
}];

let assemblyFormat = [{attr-dict `:` type($result)}];
let assemblyFormat = [{attr-dict `:` type($result)}];

let results = (outs TT_Tensor:$result);

Expand Down
9 changes: 8 additions & 1 deletion include/triton/driver/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,15 @@ class dispatch {
initializer();
if (cache == nullptr) {
cache = dlsym(lib_h, name);
if (cache == 0)
if (cache == 0) {
#ifdef __EXCEPTIONS
throw std::runtime_error("dlsym unable to load function");
#else
std::cerr << "Triton: dlsym unable to load function `" << name << "`"
<< std::endl;
std::abort();
#endif
}
}
FunPtrT fptr;
*reinterpret_cast<void **>(&fptr) = cache;
Expand Down
2 changes: 1 addition & 1 deletion lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
}
}
// Addition
if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)) {
if (llvm::isa<arith::AddIOp, triton::AddPtrOp>(op)) {
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ llvm::SmallVector<Value, 4> PTXBuilder::getAllMLIRArgs() const {
return res;
}

SmallVector<PTXBuilder::Operand *> PTXBuilder::getAllArgs() const {
SmallVector<PTXBuilder::Operand *, 4> PTXBuilder::getAllArgs() const {
llvm::SmallVector<Operand *, 4> res;
for (auto &x : argArchive)
if (!x->isList())
Expand Down
11 changes: 6 additions & 5 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ struct StoreOpConversion

const int numVecs = numElems / vec;
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
// TODO: optimization when ptr is GEP with constant offset
// TODO: optimization when ptr is AddPtr with constant offset
size_t in_off = 0;

const int maxWordWidth = std::max<int>(32, valueElemNbits);
Expand Down Expand Up @@ -1173,12 +1173,13 @@ struct GetProgramIdOpConversion
}
};

struct GEPOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::GEPOp> {
struct AddPtrOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AddPtrOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::GEPOp>::ConvertTritonGPUOpToLLVMPattern;
triton::AddPtrOp>::ConvertTritonGPUOpToLLVMPattern;

LogicalResult
matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor,
matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
Expand Down Expand Up @@ -1298,7 +1299,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,

patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
patterns.add<GEPOpConversion>(typeConverter, benefit);
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<LoadOpConversion>(typeConverter, analysis, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
patterns.add< // TODO: view should have custom pattern that views the layout
TritonGenericPattern<triton::ViewOp>,
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern>(
typeConverter, context);
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ static Type getI1SameShape(Type type) {
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type,
tensorType.getEncoding());
return Type();
return i1Type;
}

static Type getI32SameShape(Type type) {
auto i32Type = IntegerType::get(type.getContext(), 32);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i32Type,
tensorType.getEncoding());
return Type();
return i32Type;
}

static Type getPointerTypeFromTensor(Type type) {
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Triton/Transforms/Combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
patterns.add<CombineDotAddFRevPattern>(context);
// %}
patterns.add<CombineSelectMaskedLoadPattern>(context);
patterns.add<CombineGEPPattern>(context);
patterns.add<CombineAddPtrPattern>(context);
patterns.add<CombineBroadcastConstantPattern>(context);

if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
Expand Down
Loading