Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ jobs:
- name: Check cpp style
if: ${{ matrix.runner != 'macos-latest' }}
run: |
sudo apt-get install -y clang-format
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)
pip install clang-format
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -not -path "./third-party/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -not -path "./third-party/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)

- name: Flake8
if: ${{ matrix.runner != 'macos-latest' }}
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "deps/dlfcn-win32"]
path = deps/dlfcn-win32
url = https://github.com/dlfcn-win32/dlfcn-win32.git
[submodule "third-party/pybind11"]
path = third-party/pybind11
url = https://github.com/pybind/pybind11.git
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ endif()
# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)

# Third-party
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third-party/pybind11/include)

if(WIN32)
SET(BUILD_SHARED_LIBS OFF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/deps/dlfcn-win32/src)
Expand Down
10 changes: 6 additions & 4 deletions lib/Dialect/Triton/Transforms/Combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {
mlir::Value falseValue = selectOp.getFalseValue();

auto *loadOpCandidate = trueValue.getDefiningOp();
auto loadOp = llvm::dyn_cast<triton::LoadOp>(loadOpCandidate);
auto loadOp = llvm::dyn_cast_or_null<triton::LoadOp>(loadOpCandidate);
if (!loadOp)
return mlir::failure();

Expand All @@ -81,7 +81,7 @@ class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {

auto *broadcastOpCandidate = mask.getDefiningOp();
auto broadcastOp =
llvm::dyn_cast<triton::BroadcastOp>(broadcastOpCandidate);
llvm::dyn_cast_or_null<triton::BroadcastOp>(broadcastOpCandidate);
if (!broadcastOp)
return mlir::failure();

Expand All @@ -106,7 +106,8 @@ struct CanonicalizeMaskedLoadPattern
if (!mask)
return mlir::failure();

auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
auto constantMask =
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
if (!constantMask)
return mlir::failure();

Expand Down Expand Up @@ -152,7 +153,8 @@ struct CanonicalizeMaskedStorePattern
if (!mask)
return mlir::failure();

auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
auto constantMask =
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
if (!constantMask)
return mlir::failure();

Expand Down
14 changes: 11 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,17 @@ void LoopPipeliner::emitPrologue() {
}

// If this is a load/async_copy, we need to update the mask
if (llvm::isa<triton::LoadOp, triton::gpu::InsertSliceAsyncOp>(newOp)) {
Value mask = llvm::isa<triton::LoadOp>(newOp) ? newOp->getOperand(1)
: newOp->getOperand(3);
if (Value mask = [&]() {
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(newOp)) {
return loadOp.mask();
} else if (auto insertSliceAsyncOp =
llvm::dyn_cast<triton::gpu::InsertSliceAsyncOp>(
newOp)) {
return insertSliceAsyncOp.mask();
} else {
return mlir::Value();
}
}()) {
// assert(I1 or TensorOf<[I1]>);
OpBuilder::InsertionGuard g(builder);
// TODO: move this out of the loop
Expand Down
15 changes: 15 additions & 0 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@ def check_env_flag(name: str, default: str = "") -> bool:
return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]


def check_submodule():
submodule_paths = ["third-party/pybind11/include/pybind11"]
if not all([os.path.exists(p) for p in submodule_paths]):
print("initializing submodules ...")
try:
cwd = os.path.abspath(os.path.dirname(__file__))
subprocess.check_call(["git", "submodule", "update", "--init", "--recursive"], cwd=cwd)
print("submodule initialization succeeded")
except Exception:
print("submodule initialization failed")
print(" Please run:\n\tgit submodule update --init --recursive")
exit(-1)


def get_llvm():
# download if nothing is installed
system = platform.system()
Expand Down Expand Up @@ -81,6 +95,7 @@ def run(self):
self.build_extension(ext)

def build_extension(self, ext):
check_submodule()
llvm_include_dir, llvm_library_dir = get_llvm()
# lit is used by the test suite
lit_dir = shutil.which('lit')
Expand Down
Loading