-
Notifications
You must be signed in to change notification settings - Fork 2.4k
[AMD] AsyncCopyGlobalToLocal lowering to global.load.lds #5729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 20 commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
ba17a9e
Basic lowering AsyncCommitGroup and AsyncWait
AlexAUT 2327587
WIP lowering of AsyncCopy
AlexAUT 1d4edf6
Added layout checks for asynccopy lowering
AlexAUT ead4915
Support direct to lds
AlexAUT 3141ba4
Enable non working masking
AlexAUT 644aa1e
Add support to enable disable direct to lds with env var AMDGCN_USE_D…
AlexAUT 7c9bab1
Fix masking and others for direct to lds
AlexAUT cb823d0
Fix when AsycCopy is lowered without a mask
AlexAUT c097616
Use ROCDL instead of intrinsics
AlexAUT 1a9f1e0
Cleanup and simplify AsyncCopy lowering
AlexAUT a20b686
CacheModifiers for AsyncCopy
AlexAUT 97d677d
Add lit test for AsyncCopy
AlexAUT 30352ad
Split AsyncCopy Lit for gfx950
AlexAUT fe8619d
Add const to getCtrlBitsForCacheModifierOnTarget
AlexAUT 7941a30
Cleanup StreamPipeliner changes
AlexAUT def9313
Revert stream pipeline related changes
AlexAUT 318caa2
Add missing CDNA1 to AsyncCopy support list
AlexAUT 6600138
Cleanup
AlexAUT ea02c3c
Replace macros for llvm ops with TritonLLVMOpBuilder
AlexAUT 13419bb
Fix wrong value in supported bit width for global.to.lds
AlexAUT ca8b441
Addressing review comments
AlexAUT 6aa3554
Unified async ops lit tests
AlexAUT 04fad93
Emit correct wmcnt wait instead of waiting on all cnts
AlexAUT f6cbe22
Add tests for AsyncWait/AsyncCommitGroup
AlexAUT 3d30f43
Limit AsyncWait conversion to gfx9
AlexAUT 0c382db
Add AsyncOpy lowering lit test with masking and other values
AlexAUT f560aeb
Added async copy lit tests with cache modifiers
AlexAUT d6b0d02
Merge branch 'main' into global_to_lds_lowering
AlexAUT d90ffbe
Adjust to shared encoding changes
AlexAUT 5356802
Fix a few small issues
antiagainst File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 --convert-builtin-func-to-llvm | FileCheck %s | ||
|
|
||
| #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> | ||
| #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> | ||
| #smem = #ttg.shared_memory | ||
| module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { | ||
| // CHECK-LABEL: async_copy_vectorized | ||
| tt.func public @async_copy_vectorized(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, | ||
| %arg1: i32 {tt.divisibility = 16 : i32}, | ||
| %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { | ||
| // We need the index calculation so AxisAnalysis sees that we can vectorize the load | ||
| %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | ||
| %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> | ||
| %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> | ||
| %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked> | ||
| %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked> | ||
|
|
||
| // Each thread needs to load 8 elements and we load 8 (sizePerThread) per global.load.lds | ||
| // CHECK: rocdl.global.load.lds | ||
| // CHECK-NOT: rocdl.global.load.lds | ||
| %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable> | ||
| tt.return | ||
| } | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.