Commit b69fd34

[Offload] Add oneIterationPerThread param to loop device RTL (llvm#151959)
Currently, Flang can generate no-loop kernels for all OpenMP target kernels in a program if the -fopenmp-assume-teams-oversubscription or -fopenmp-assume-threads-oversubscription flag is set. With an additional parameter, we can choose in the future which OpenMP kernels should be generated as no-loop kernels. This PR does not modify the current behavior of the oversubscription flags.

RFC for no-loop kernels: https://discourse.llvm.org/t/rfc-no-loop-mode-for-openmp-gpu-kernels/87517
1 parent 0977a6d commit b69fd34
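For context, the behavior the new parameter selects can be sketched as follows. This is a simplified illustration of the no-loop pattern, not the actual DeviceRTL code; the function and variable names are placeholders:

#include <cstdint>

// Sketch: when OneIterationPerThread is clear, each thread strides over
// the iteration space; when it is set, the outer loop disappears and each
// thread runs at most one iteration (a "no-loop" kernel).
static void runLoopSketch(void (*Body)(int32_t, void *), void *Arg,
                          int32_t NumIters, int32_t TId, int32_t NumThreads,
                          bool OneIterationPerThread) {
  if (OneIterationPerThread) {
    // The caller guarantees NumThreads >= NumIters on this path.
    if (TId < NumIters)
      Body(TId, Arg);
    return;
  }
  for (int32_t I = TId; I < NumIters; I += NumThreads)
    Body(I, Arg);
}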

File tree: 6 files changed, +81 -44 lines changed

llvm/include/llvm/Frontend/OpenMP/OMPKinds.def

Lines changed: 18 additions & 18 deletions
@@ -470,18 +470,18 @@ __OMP_RTL(__kmpc_target_deinit, false, Void,)
 __OMP_RTL(__kmpc_kernel_prepare_parallel, false, Void, VoidPtr)
 __OMP_RTL(__kmpc_parallel_51, false, Void, IdentPtr, Int32, Int32, Int32, Int32,
           VoidPtr, VoidPtr, VoidPtrPtr, SizeTy)
-__OMP_RTL(__kmpc_for_static_loop_4, false, Void, IdentPtr, VoidPtr, VoidPtr, Int32, Int32, Int32)
-__OMP_RTL(__kmpc_for_static_loop_4u, false, Void, IdentPtr, VoidPtr, VoidPtr, Int32, Int32, Int32)
-__OMP_RTL(__kmpc_for_static_loop_8, false, Void, IdentPtr, VoidPtr, VoidPtr, Int64, Int64, Int64)
-__OMP_RTL(__kmpc_for_static_loop_8u, false, Void, IdentPtr, VoidPtr, VoidPtr, Int64, Int64, Int64)
-__OMP_RTL(__kmpc_distribute_static_loop_4, false, Void, IdentPtr, VoidPtr, VoidPtr, Int32, Int32)
-__OMP_RTL(__kmpc_distribute_static_loop_4u, false, Void, IdentPtr, VoidPtr, VoidPtr, Int32, Int32)
-__OMP_RTL(__kmpc_distribute_static_loop_8, false, Void, IdentPtr, VoidPtr, VoidPtr, Int64, Int64)
-__OMP_RTL(__kmpc_distribute_static_loop_8u, false, Void, IdentPtr, VoidPtr, VoidPtr, Int64, Int64)
-__OMP_RTL(__kmpc_distribute_for_static_loop_4, false, Void, IdentPtr, VoidPtr, VoidPtr, Int32, Int32, Int32, Int32)
-__OMP_RTL(__kmpc_distribute_for_static_loop_4u, false, Void, IdentPtr, VoidPtr, VoidPtr, Int32, Int32, Int32, Int32)
-__OMP_RTL(__kmpc_distribute_for_static_loop_8, false, Void, IdentPtr, VoidPtr, VoidPtr, Int64, Int64, Int64, Int64)
-__OMP_RTL(__kmpc_distribute_for_static_loop_8u, false, Void, IdentPtr, VoidPtr, VoidPtr, Int64, Int64, Int64, Int64)
+__OMP_RTL(__kmpc_for_static_loop_4, false, Void, IdentPtr, VoidPtr, VoidPtr, Int32, Int32, Int32, Int8)
+__OMP_RTL(__kmpc_for_static_loop_4u, false, Void, IdentPtr, VoidPtr, VoidPtr, Int32, Int32, Int32, Int8)
+__OMP_RTL(__kmpc_for_static_loop_8, false, Void, IdentPtr, VoidPtr, VoidPtr, Int64, Int64, Int64, Int8)
+__OMP_RTL(__kmpc_for_static_loop_8u, false, Void, IdentPtr, VoidPtr, VoidPtr, Int64, Int64, Int64, Int8)
+__OMP_RTL(__kmpc_distribute_static_loop_4, false, Void, IdentPtr, VoidPtr, VoidPtr, Int32, Int32, Int8)
+__OMP_RTL(__kmpc_distribute_static_loop_4u, false, Void, IdentPtr, VoidPtr, VoidPtr, Int32, Int32, Int8)
+__OMP_RTL(__kmpc_distribute_static_loop_8, false, Void, IdentPtr, VoidPtr, VoidPtr, Int64, Int64, Int8)
+__OMP_RTL(__kmpc_distribute_static_loop_8u, false, Void, IdentPtr, VoidPtr, VoidPtr, Int64, Int64, Int8)
+__OMP_RTL(__kmpc_distribute_for_static_loop_4, false, Void, IdentPtr, VoidPtr, VoidPtr, Int32, Int32, Int32, Int32, Int8)
+__OMP_RTL(__kmpc_distribute_for_static_loop_4u, false, Void, IdentPtr, VoidPtr, VoidPtr, Int32, Int32, Int32, Int32, Int8)
+__OMP_RTL(__kmpc_distribute_for_static_loop_8, false, Void, IdentPtr, VoidPtr, VoidPtr, Int64, Int64, Int64, Int64, Int8)
+__OMP_RTL(__kmpc_distribute_for_static_loop_8u, false, Void, IdentPtr, VoidPtr, VoidPtr, Int64, Int64, Int64, Int64, Int8)
 __OMP_RTL(__kmpc_kernel_parallel, false, Int1, VoidPtrPtr)
 __OMP_RTL(__kmpc_kernel_end_parallel, false, Void, )
 __OMP_RTL(__kmpc_serialized_parallel, false, Void, IdentPtr, Int32)
@@ -674,22 +674,22 @@ __OMP_RTL_ATTRS(__kmpc_cancel_barrier, BarrierAttrs, SExt,
                 ParamAttrs(ReadOnlyPtrAttrs, SExt))
 __OMP_RTL_ATTRS(__kmpc_distribute_for_static_loop_4, AlwaysInlineAttrs, AttributeSet(),
                 ParamAttrs(ReadOnlyPtrAttrs, AttributeSet(), AttributeSet(),
-                           SExt, SExt, SExt, SExt))
+                           SExt, SExt, SExt, SExt, ZExt))
 __OMP_RTL_ATTRS(__kmpc_distribute_for_static_loop_4u, AlwaysInlineAttrs, AttributeSet(),
                 ParamAttrs(ReadOnlyPtrAttrs, AttributeSet(), AttributeSet(),
-                           ZExt, ZExt, ZExt, ZExt))
+                           ZExt, ZExt, ZExt, ZExt, ZExt))
 __OMP_RTL_ATTRS(__kmpc_distribute_static_loop_4, AlwaysInlineAttrs, AttributeSet(),
                 ParamAttrs(ReadOnlyPtrAttrs, AttributeSet(), AttributeSet(),
-                           SExt, SExt))
+                           SExt, SExt, ZExt))
 __OMP_RTL_ATTRS(__kmpc_distribute_static_loop_4u, AlwaysInlineAttrs, AttributeSet(),
                 ParamAttrs(ReadOnlyPtrAttrs, AttributeSet(), AttributeSet(),
-                           ZExt, ZExt))
+                           ZExt, ZExt, ZExt))
 __OMP_RTL_ATTRS(__kmpc_for_static_loop_4, AlwaysInlineAttrs, AttributeSet(),
                 ParamAttrs(ReadOnlyPtrAttrs, AttributeSet(), AttributeSet(),
-                           SExt, SExt, SExt))
+                           SExt, SExt, SExt, ZExt))
 __OMP_RTL_ATTRS(__kmpc_for_static_loop_4u, AlwaysInlineAttrs, AttributeSet(),
                 ParamAttrs(ReadOnlyPtrAttrs, AttributeSet(), AttributeSet(),
-                           ZExt, ZExt, ZExt))
+                           ZExt, ZExt, ZExt, ZExt))
 __OMP_RTL_ATTRS(__kmpc_error, AttributeSet(), AttributeSet(),
                 ParamAttrs(AttributeSet(), SExt))
 __OMP_RTL_ATTRS(__kmpc_flush, BarrierAttrs, AttributeSet(),
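Each loop entry thus gains a trailing Int8 parameter, zero-extended at the call boundary (the added ZExt in the attribute hunk above). On the device side this corresponds to a prototype of roughly the following shape, mirroring the Workshare.cpp change at the end of this commit (IdentTy is the DeviceRTL ident type):

extern "C" void __kmpc_for_static_loop_4(
    IdentTy *loc, void (*fn)(int32_t, void *), void *arg, int32_t num_iters,
    int32_t num_threads, int32_t thread_chunk,
    uint8_t one_iteration_per_thread);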

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 2 additions & 0 deletions
@@ -4969,6 +4969,7 @@ static void createTargetLoopWorkshareCall(OpenMPIRBuilder *OMPBuilder,
   RealArgs.push_back(TripCount);
   if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
     RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
+    RealArgs.push_back(ConstantInt::get(Builder.getInt8Ty(), 0));
     Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
     Builder.CreateCall(RTLFn, RealArgs);
     return;
@@ -4984,6 +4985,7 @@ static void createTargetLoopWorkshareCall(OpenMPIRBuilder *OMPBuilder,
   if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
     RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
   }
+  RealArgs.push_back(ConstantInt::get(Builder.getInt8Ty(), 0));
 
   Builder.CreateCall(RTLFn, RealArgs);
 }
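Both insertion points pin the new argument to a constant zero, so the generated calls simply gain a trailing i8 0 and behave exactly as before (see the updated MLIR tests below). If a frontend later wants to request a no-loop kernel per target region, only this constant needs to change. A hypothetical sketch, where EmitNoLoopKernel is an assumed flag and not an existing OpenMPIRBuilder option:

// Hypothetical: derive the i8 argument from a per-kernel frontend decision
// instead of hard-coding zero (EmitNoLoopKernel is an assumption).
bool EmitNoLoopKernel = false;
RealArgs.push_back(
    ConstantInt::get(Builder.getInt8Ty(), EmitNoLoopKernel ? 1 : 0));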

mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK-SAME: #[[ATTRS1:[0-9]+]]
 // CHECK: call void @__kmpc_for_static_loop_4u(ptr addrspacecast (ptr addrspace(1) @[[GLOB]] to ptr),
 // CHECK-SAME: ptr @[[LOOP_BODY_FUNC:.*]], ptr %[[LOO_BODY_FUNC_ARG:.*]], i32 10,
-// CHECK-SAME: i32 %[[THREAD_NUM:.*]], i32 0)
+// CHECK-SAME: i32 %[[THREAD_NUM:.*]], i8 0)
 
 // CHECK: define internal void @[[LOOP_BODY_FUNC]](i32 %[[CNT:.*]], ptr %[[LOOP_BODY_ARG_PTR:.*]]) #[[ATTRS2:[0-9]+]] {

mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK: define void @[[FUNC_COLLAPSED_WSLOOP:.*]](ptr %[[ARG0:.*]])
 // CHECK: call void @__kmpc_for_static_loop_4u(ptr addrspacecast (ptr addrspace(1) @[[GLOB2:[0-9]+]] to ptr),
 // CHECK-SAME: ptr @[[COLLAPSED_WSLOOP_BODY_FN:.*]], ptr %[[STRUCT_ARG:.*]], i32 10000,
-// CHECK-SAME: i32 %[[NUM_THREADS:.*]], i32 0)
+// CHECK-SAME: i32 %[[NUM_THREADS:.*]], i8 0)
 
 // CHECK: define internal void @[[COLLAPSED_WSLOOP_BODY_FN]](i32 %[[LOOP_CNT:.*]], ptr %[[LOOP_BODY_ARG:.*]])
 // CHECK: %[[TMP0:.*]] = urem i32 %[[LOOP_CNT]], 100

mlir/test/Target/LLVMIR/omptarget-wsloop.mlir

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK: %[[GEP:.*]] = getelementptr { ptr }, ptr addrspace(5) %[[STRUCTARG]], i32 0, i32 0
 // CHECK: store ptr %[[ARG0]], ptr addrspace(5) %[[GEP]], align 8
 // CHECK: %[[NUM_THREADS:.*]] = call i32 @omp_get_num_threads()
-// CHECK: call void @__kmpc_for_static_loop_4u(ptr addrspacecast (ptr addrspace(1) @[[GLOB1:[0-9]+]] to ptr), ptr @[[LOOP_BODY_FN:.*]], ptr %[[STRUCTARG_ASCAST]], i32 10, i32 %[[NUM_THREADS]], i32 0)
+// CHECK: call void @__kmpc_for_static_loop_4u(ptr addrspacecast (ptr addrspace(1) @[[GLOB1:[0-9]+]] to ptr), ptr @[[LOOP_BODY_FN:.*]], ptr %[[STRUCTARG_ASCAST]], i32 10, i32 %[[NUM_THREADS]], i32 0, i8 0)
 
 // CHECK: define internal void @[[LOOP_BODY_FN]](i32 %[[LOOP_CNT:.*]], ptr %[[LOOP_BODY_ARG:.*]])
 // CHECK: %[[GEP2:.*]] = getelementptr { ptr }, ptr %[[LOOP_BODY_ARG]], i32 0, i32 0
@@ -46,6 +46,6 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK: store i32 %[[VAL0:.*]], ptr %[[GEP3]], align 4
 
 // CHECK: define void @[[FUNC_EMPTY_WSLOOP:.*]]()
-// CHECK: call void @__kmpc_for_static_loop_4u(ptr addrspacecast (ptr addrspace(1) @[[GLOB2:[0-9]+]] to ptr), ptr @[[LOOP_EMPTY_BODY_FN:.*]], ptr null, i32 10, i32 %[[NUM_THREADS:.*]], i32 0)
+// CHECK: call void @__kmpc_for_static_loop_4u(ptr addrspacecast (ptr addrspace(1) @[[GLOB2:[0-9]+]] to ptr), ptr @[[LOOP_EMPTY_BODY_FN:.*]], ptr null, i32 10, i32 %[[NUM_THREADS:.*]], i32 0, i8 0)
 
 // CHECK: define internal void @[[LOOP_EMPTY_BODY_FN]](i32 %[[LOOP_CNT:.*]])

offload/DeviceRTL/src/Workshare.cpp

Lines changed: 57 additions & 22 deletions
@@ -698,7 +698,7 @@ template <typename Ty> class StaticLoopChunker {
   static void NormalizedLoopNestNoChunk(void (*LoopBody)(Ty, void *), void *Arg,
                                         Ty NumBlocks, Ty BId, Ty NumThreads,
                                         Ty TId, Ty NumIters,
-                                        bool OneIterationPerThread) {
+                                        uint8_t OneIterationPerThread) {
     Ty KernelIteration = NumBlocks * NumThreads;
 
     // Start index in the normalized space.
@@ -729,7 +729,7 @@ template <typename Ty> class StaticLoopChunker {
                                       Ty BlockChunk, Ty NumBlocks, Ty BId,
                                       Ty ThreadChunk, Ty NumThreads, Ty TId,
                                       Ty NumIters,
-                                      bool OneIterationPerThread) {
+                                      uint8_t OneIterationPerThread) {
     Ty KernelIteration = NumBlocks * BlockChunk;
 
     // Start index in the chunked space.
@@ -767,8 +767,18 @@
 
 public:
   /// Worksharing `for`-loop.
+  /// \param[in] Loc Description of source location
+  /// \param[in] LoopBody Function which corresponds to loop body
+  /// \param[in] Arg Pointer to struct which contains loop body args
+  /// \param[in] NumIters Number of loop iterations
+  /// \param[in] NumThreads Number of GPU threads
+  /// \param[in] ThreadChunk Size of thread chunk
+  /// \param[in] OneIterationPerThread If true/nonzero, each thread executes
+  /// only one loop iteration or one thread chunk. This avoids an outer loop
+  /// over all loop iterations/chunks.
   static void For(IdentTy *Loc, void (*LoopBody)(Ty, void *), void *Arg,
-                  Ty NumIters, Ty NumThreads, Ty ThreadChunk) {
+                  Ty NumIters, Ty NumThreads, Ty ThreadChunk,
+                  uint8_t OneIterationPerThread) {
     ASSERT(NumIters >= 0, "Bad iteration count");
     ASSERT(ThreadChunk >= 0, "Bad thread count");
 
@@ -790,12 +800,13 @@
 
     // If we know we have more threads than iterations we can indicate that to
     // avoid an outer loop.
-    bool OneIterationPerThread = false;
     if (config::getAssumeThreadsOversubscription()) {
-      ASSERT(NumThreads >= NumIters, "Broken assumption");
       OneIterationPerThread = true;
     }
 
+    if (OneIterationPerThread)
+      ASSERT(NumThreads >= NumIters, "Broken assumption");
+
     if (ThreadChunk != 1)
       NormalizedLoopNestChunked(LoopBody, Arg, BlockChunk, NumBlocks, BId,
                                 ThreadChunk, NumThreads, TId, NumIters,
@@ -806,8 +817,17 @@
   }
 
   /// Worksharing `distribute`-loop.
+  /// \param[in] Loc Description of source location
+  /// \param[in] LoopBody Function which corresponds to loop body
+  /// \param[in] Arg Pointer to struct which contains loop body args
+  /// \param[in] NumIters Number of loop iterations
+  /// \param[in] BlockChunk Size of block chunk
+  /// \param[in] OneIterationPerThread If true/nonzero, each thread executes
+  /// only one loop iteration or one thread chunk. This avoids an outer loop
+  /// over all loop iterations/chunks.
   static void Distribute(IdentTy *Loc, void (*LoopBody)(Ty, void *), void *Arg,
-                         Ty NumIters, Ty BlockChunk) {
+                         Ty NumIters, Ty BlockChunk,
+                         uint8_t OneIterationPerThread) {
     ASSERT(icv::Level == 0, "Bad distribute");
     ASSERT(icv::ActiveLevel == 0, "Bad distribute");
     ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");
@@ -831,12 +851,13 @@
 
     // If we know we have more blocks than iterations we can indicate that to
     // avoid an outer loop.
-    bool OneIterationPerThread = false;
     if (config::getAssumeTeamsOversubscription()) {
-      ASSERT(NumBlocks >= NumIters, "Broken assumption");
       OneIterationPerThread = true;
     }
 
+    if (OneIterationPerThread)
+      ASSERT(NumBlocks >= NumIters, "Broken assumption");
+
     if (BlockChunk != NumThreads)
       NormalizedLoopNestChunked(LoopBody, Arg, BlockChunk, NumBlocks, BId,
                                 ThreadChunk, NumThreads, TId, NumIters,
@@ -852,9 +873,20 @@
   }
 
   /// Worksharing `distribute parallel for`-loop.
+  /// \param[in] Loc Description of source location
+  /// \param[in] LoopBody Function which corresponds to loop body
+  /// \param[in] Arg Pointer to struct which contains loop body args
+  /// \param[in] NumIters Number of loop iterations
+  /// \param[in] NumThreads Number of GPU threads
+  /// \param[in] BlockChunk Size of block chunk
+  /// \param[in] ThreadChunk Size of thread chunk
+  /// \param[in] OneIterationPerThread If true/nonzero, each thread executes
+  /// only one loop iteration or one thread chunk. This avoids an outer loop
+  /// over all loop iterations/chunks.
   static void DistributeFor(IdentTy *Loc, void (*LoopBody)(Ty, void *),
                             void *Arg, Ty NumIters, Ty NumThreads,
-                            Ty BlockChunk, Ty ThreadChunk) {
+                            Ty BlockChunk, Ty ThreadChunk,
+                            uint8_t OneIterationPerThread) {
     ASSERT(icv::Level == 1, "Bad distribute");
     ASSERT(icv::ActiveLevel == 1, "Bad distribute");
     ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");
@@ -882,13 +914,14 @@
 
     // If we know we have more threads (across all blocks) than iterations we
     // can indicate that to avoid an outer loop.
-    bool OneIterationPerThread = false;
     if (config::getAssumeTeamsOversubscription() &
         config::getAssumeThreadsOversubscription()) {
       OneIterationPerThread = true;
-      ASSERT(NumBlocks * NumThreads >= NumIters, "Broken assumption");
     }
 
+    if (OneIterationPerThread)
+      ASSERT(NumBlocks * NumThreads >= NumIters, "Broken assumption");
+
     if (BlockChunk != NumThreads || ThreadChunk != 1)
       NormalizedLoopNestChunked(LoopBody, Arg, BlockChunk, NumBlocks, BId,
                                 ThreadChunk, NumThreads, TId, NumIters,
@@ -907,24 +940,26 @@
 
 #define OMP_LOOP_ENTRY(BW, TY)                                                 \
   [[gnu::flatten, clang::always_inline]] void                                  \
-  __kmpc_distribute_for_static_loop##BW(                                       \
-      IdentTy *loc, void (*fn)(TY, void *), void *arg, TY num_iters,           \
-      TY num_threads, TY block_chunk, TY thread_chunk) {                       \
+  __kmpc_distribute_for_static_loop##BW(                                       \
+      IdentTy *loc, void (*fn)(TY, void *), void *arg, TY num_iters,           \
+      TY num_threads, TY block_chunk, TY thread_chunk,                         \
+      uint8_t one_iteration_per_thread) {                                      \
     ompx::StaticLoopChunker<TY>::DistributeFor(                                \
-        loc, fn, arg, num_iters, num_threads, block_chunk, thread_chunk);      \
+        loc, fn, arg, num_iters, num_threads, block_chunk, thread_chunk,       \
+        one_iteration_per_thread);                                             \
   }                                                                            \
   [[gnu::flatten, clang::always_inline]] void                                  \
-  __kmpc_distribute_static_loop##BW(IdentTy *loc, void (*fn)(TY, void *),      \
-                                    void *arg, TY num_iters,                   \
-                                    TY block_chunk) {                          \
-    ompx::StaticLoopChunker<TY>::Distribute(loc, fn, arg, num_iters,           \
-                                            block_chunk);                      \
+  __kmpc_distribute_static_loop##BW(IdentTy *loc, void (*fn)(TY, void *),      \
+                                    void *arg, TY num_iters, TY block_chunk,   \
+                                    uint8_t one_iteration_per_thread) {        \
+    ompx::StaticLoopChunker<TY>::Distribute(                                   \
+        loc, fn, arg, num_iters, block_chunk, one_iteration_per_thread);       \
   }                                                                            \
   [[gnu::flatten, clang::always_inline]] void __kmpc_for_static_loop##BW(      \
       IdentTy *loc, void (*fn)(TY, void *), void *arg, TY num_iters,           \
-      TY num_threads, TY thread_chunk) {                                       \
+      TY num_threads, TY thread_chunk, uint8_t one_iteration_per_thread) {     \
     ompx::StaticLoopChunker<TY>::For(loc, fn, arg, num_iters, num_threads,     \
-                                     thread_chunk);                            \
+                                     thread_chunk, one_iteration_per_thread);  \
   }
 
 extern "C" {
