Skip to content

Commit d733f2c

Browse files
committed
[OpenMPIRBuilder] Support opaque pointers in reduction handling
Make the reduction handling in OpenMPIRBuilder compatible with opaque pointers by explicitly storing the element type in ReductionInfo, and also passing it to the atomic reduction callback, as at least the ones in the test need the type there. This doesn't make things fully compatible yet, there are other uses of element types in this class. I also left one getPointerElementType() call in mlir, because I'm not familiar with that area. Differential Revison: https://reviews.llvm.org/D115638
1 parent 26f6fbe commit d733f2c

File tree

4 files changed

+48
-46
lines changed

4 files changed

+48
-46
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -539,24 +539,27 @@ class OpenMPIRBuilder {
539539
function_ref<InsertPointTy(InsertPointTy, Value *, Value *, Value *&)>;
540540

541541
/// Functions used to generate atomic reductions. Such functions take two
542-
/// Values representing pointers to LHS and RHS of the reduction. They are
543-
/// expected to atomically update the LHS to the reduced value.
542+
/// Values representing pointers to LHS and RHS of the reduction, as well as
543+
/// the element type of these pointers. They are expected to atomically
544+
/// update the LHS to the reduced value.
544545
using AtomicReductionGenTy =
545-
function_ref<InsertPointTy(InsertPointTy, Value *, Value *)>;
546+
function_ref<InsertPointTy(InsertPointTy, Type *, Value *, Value *)>;
546547

547548
/// Information about an OpenMP reduction.
548549
struct ReductionInfo {
549-
ReductionInfo(Value *Variable, Value *PrivateVariable,
550+
ReductionInfo(Type *ElementType, Value *Variable, Value *PrivateVariable,
550551
ReductionGenTy ReductionGen,
551552
AtomicReductionGenTy AtomicReductionGen)
552-
: Variable(Variable), PrivateVariable(PrivateVariable),
553-
ReductionGen(ReductionGen), AtomicReductionGen(AtomicReductionGen) {}
554-
555-
/// Returns the type of the element being reduced.
556-
Type *getElementType() const {
557-
return Variable->getType()->getPointerElementType();
553+
: ElementType(ElementType), Variable(Variable),
554+
PrivateVariable(PrivateVariable), ReductionGen(ReductionGen),
555+
AtomicReductionGen(AtomicReductionGen) {
556+
assert(cast<PointerType>(Variable->getType())
557+
->isOpaqueOrPointeeTypeMatches(ElementType) && "Invalid elem type");
558558
}
559559

560+
/// Reduction element type, must match pointee type of variable.
561+
Type *ElementType;
562+
560563
/// Reduction variable of pointer type.
561564
Value *Variable;
562565

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,7 +1156,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
11561156
Builder.SetInsertPoint(NonAtomicRedBlock);
11571157
for (auto En : enumerate(ReductionInfos)) {
11581158
const ReductionInfo &RI = En.value();
1159-
Type *ValueType = RI.getElementType();
1159+
Type *ValueType = RI.ElementType;
11601160
Value *RedValue = Builder.CreateLoad(ValueType, RI.Variable,
11611161
"red.value." + Twine(En.index()));
11621162
Value *PrivateRedValue =
@@ -1181,8 +1181,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
11811181
Builder.SetInsertPoint(AtomicRedBlock);
11821182
if (CanGenerateAtomic) {
11831183
for (const ReductionInfo &RI : ReductionInfos) {
1184-
Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.Variable,
1185-
RI.PrivateVariable));
1184+
Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.ElementType,
1185+
RI.Variable, RI.PrivateVariable));
11861186
if (!Builder.GetInsertBlock())
11871187
return InsertPointTy();
11881188
}
@@ -1207,13 +1207,13 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
12071207
RedArrayTy, LHSArrayPtr, 0, En.index());
12081208
Value *LHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), LHSI8PtrPtr);
12091209
Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
1210-
Value *LHS = Builder.CreateLoad(RI.getElementType(), LHSPtr);
1210+
Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
12111211
Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
12121212
RedArrayTy, RHSArrayPtr, 0, En.index());
12131213
Value *RHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), RHSI8PtrPtr);
12141214
Value *RHSPtr =
12151215
Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
1216-
Value *RHS = Builder.CreateLoad(RI.getElementType(), RHSPtr);
1216+
Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
12171217
Value *Reduced;
12181218
Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
12191219
if (!Builder.GetInsertBlock())

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3028,10 +3028,10 @@ sumReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS,
30283028
}
30293029

30303030
static OpenMPIRBuilder::InsertPointTy
3031-
sumAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS) {
3031+
sumAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Type *Ty, Value *LHS,
3032+
Value *RHS) {
30323033
IRBuilder<> Builder(IP.getBlock(), IP.getPoint());
3033-
Value *Partial = Builder.CreateLoad(RHS->getType()->getPointerElementType(),
3034-
RHS, "red.partial");
3034+
Value *Partial = Builder.CreateLoad(Ty, RHS, "red.partial");
30353035
Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, LHS, Partial, None,
30363036
AtomicOrdering::Monotonic);
30373037
return Builder.saveIP();
@@ -3046,10 +3046,10 @@ xorReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS,
30463046
}
30473047

30483048
static OpenMPIRBuilder::InsertPointTy
3049-
xorAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Value *LHS, Value *RHS) {
3049+
xorAtomicReduction(OpenMPIRBuilder::InsertPointTy IP, Type *Ty, Value *LHS,
3050+
Value *RHS) {
30503051
IRBuilder<> Builder(IP.getBlock(), IP.getPoint());
3051-
Value *Partial = Builder.CreateLoad(RHS->getType()->getPointerElementType(),
3052-
RHS, "red.partial");
3052+
Value *Partial = Builder.CreateLoad(Ty, RHS, "red.partial");
30533053
Builder.CreateAtomicRMW(AtomicRMWInst::Xor, LHS, Partial, None,
30543054
AtomicOrdering::Monotonic);
30553055
return Builder.saveIP();
@@ -3081,13 +3081,15 @@ TEST_F(OpenMPIRBuilderTest, CreateReductions) {
30813081
// Create variables to be reduced.
30823082
InsertPointTy OuterAllocaIP(&F->getEntryBlock(),
30833083
F->getEntryBlock().getFirstInsertionPt());
3084+
Type *SumType = Builder.getFloatTy();
3085+
Type *XorType = Builder.getInt32Ty();
30843086
Value *SumReduced;
30853087
Value *XorReduced;
30863088
{
30873089
IRBuilderBase::InsertPointGuard Guard(Builder);
30883090
Builder.restoreIP(OuterAllocaIP);
3089-
SumReduced = Builder.CreateAlloca(Builder.getFloatTy());
3090-
XorReduced = Builder.CreateAlloca(Builder.getInt32Ty());
3091+
SumReduced = Builder.CreateAlloca(SumType);
3092+
XorReduced = Builder.CreateAlloca(XorType);
30913093
}
30923094

30933095
// Store initial values of reductions into global variables.
@@ -3109,12 +3111,8 @@ TEST_F(OpenMPIRBuilderTest, CreateReductions) {
31093111
Value *TID = OMPBuilder.getOrCreateThreadID(Ident);
31103112
Value *SumLocal =
31113113
Builder.CreateUIToFP(TID, Builder.getFloatTy(), "sum.local");
3112-
Value *SumPartial =
3113-
Builder.CreateLoad(SumReduced->getType()->getPointerElementType(),
3114-
SumReduced, "sum.partial");
3115-
Value *XorPartial =
3116-
Builder.CreateLoad(XorReduced->getType()->getPointerElementType(),
3117-
XorReduced, "xor.partial");
3114+
Value *SumPartial = Builder.CreateLoad(SumType, SumReduced, "sum.partial");
3115+
Value *XorPartial = Builder.CreateLoad(XorType, XorReduced, "xor.partial");
31183116
Value *Sum = Builder.CreateFAdd(SumPartial, SumLocal, "sum");
31193117
Value *Xor = Builder.CreateXor(XorPartial, TID, "xor");
31203118
Builder.CreateStore(Sum, SumReduced);
@@ -3164,8 +3162,8 @@ TEST_F(OpenMPIRBuilderTest, CreateReductions) {
31643162
Builder.restoreIP(AfterIP);
31653163

31663164
OpenMPIRBuilder::ReductionInfo ReductionInfos[] = {
3167-
{SumReduced, SumPrivatized, sumReduction, sumAtomicReduction},
3168-
{XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}};
3165+
{SumType, SumReduced, SumPrivatized, sumReduction, sumAtomicReduction},
3166+
{XorType, XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}};
31693167

31703168
OMPBuilder.createReductions(BodyIP, BodyAllocaIP, ReductionInfos);
31713169

@@ -3319,13 +3317,15 @@ TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) {
33193317
// Create variables to be reduced.
33203318
InsertPointTy OuterAllocaIP(&F->getEntryBlock(),
33213319
F->getEntryBlock().getFirstInsertionPt());
3320+
Type *SumType = Builder.getFloatTy();
3321+
Type *XorType = Builder.getInt32Ty();
33223322
Value *SumReduced;
33233323
Value *XorReduced;
33243324
{
33253325
IRBuilderBase::InsertPointGuard Guard(Builder);
33263326
Builder.restoreIP(OuterAllocaIP);
3327-
SumReduced = Builder.CreateAlloca(Builder.getFloatTy());
3328-
XorReduced = Builder.CreateAlloca(Builder.getInt32Ty());
3327+
SumReduced = Builder.CreateAlloca(SumType);
3328+
XorReduced = Builder.CreateAlloca(XorType);
33293329
}
33303330

33313331
// Store initial values of reductions into global variables.
@@ -3344,9 +3344,7 @@ TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) {
33443344
Value *TID = OMPBuilder.getOrCreateThreadID(Ident);
33453345
Value *SumLocal =
33463346
Builder.CreateUIToFP(TID, Builder.getFloatTy(), "sum.local");
3347-
Value *SumPartial =
3348-
Builder.CreateLoad(SumReduced->getType()->getPointerElementType(),
3349-
SumReduced, "sum.partial");
3347+
Value *SumPartial = Builder.CreateLoad(SumType, SumReduced, "sum.partial");
33503348
Value *Sum = Builder.CreateFAdd(SumPartial, SumLocal, "sum");
33513349
Builder.CreateStore(Sum, SumReduced);
33523350

@@ -3364,9 +3362,7 @@ TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) {
33643362
Constant *SrcLocStr = OMPBuilder.getOrCreateSrcLocStr(Loc);
33653363
Value *Ident = OMPBuilder.getOrCreateIdent(SrcLocStr);
33663364
Value *TID = OMPBuilder.getOrCreateThreadID(Ident);
3367-
Value *XorPartial =
3368-
Builder.CreateLoad(XorReduced->getType()->getPointerElementType(),
3369-
XorReduced, "xor.partial");
3365+
Value *XorPartial = Builder.CreateLoad(XorType, XorReduced, "xor.partial");
33703366
Value *Xor = Builder.CreateXor(XorPartial, TID, "xor");
33713367
Builder.CreateStore(Xor, XorReduced);
33723368

@@ -3421,10 +3417,10 @@ TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) {
34213417

34223418
OMPBuilder.createReductions(
34233419
FirstBodyIP, FirstBodyAllocaIP,
3424-
{{SumReduced, SumPrivatized, sumReduction, sumAtomicReduction}});
3420+
{{SumType, SumReduced, SumPrivatized, sumReduction, sumAtomicReduction}});
34253421
OMPBuilder.createReductions(
34263422
SecondBodyIP, SecondBodyAllocaIP,
3427-
{{XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}});
3423+
{{XorType, XorReduced, XorPrivatized, xorReduction, xorAtomicReduction}});
34283424

34293425
Builder.restoreIP(AfterIP);
34303426
Builder.CreateRetVoid();

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,8 @@ using OwningReductionGen = std::function<llvm::OpenMPIRBuilder::InsertPointTy(
415415
llvm::Value *&)>;
416416
using OwningAtomicReductionGen =
417417
std::function<llvm::OpenMPIRBuilder::InsertPointTy(
418-
llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *)>;
418+
llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
419+
llvm::Value *)>;
419420
} // namespace
420421

421422
/// Create an OpenMPIRBuilder-compatible reduction generator for the given
@@ -462,7 +463,7 @@ makeAtomicReductionGen(omp::ReductionDeclareOp decl,
462463
// (which aren't actually mutating it), and we must capture decl by-value to
463464
// avoid the dangling reference after the parent function returns.
464465
OwningAtomicReductionGen atomicGen =
465-
[&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
466+
[&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
466467
llvm::Value *lhs, llvm::Value *rhs) mutable {
467468
Region &atomicRegion = decl.atomicReductionRegion();
468469
moduleTranslation.mapValue(atomicRegion.front().getArgument(0), lhs);
@@ -763,9 +764,11 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
763764
llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen = nullptr;
764765
if (owningAtomicReductionGens[i])
765766
atomicGen = owningAtomicReductionGens[i];
766-
reductionInfos.push_back(
767-
{moduleTranslation.lookupValue(loop.reduction_vars()[i]),
768-
privateReductionVariables[i], owningReductionGens[i], atomicGen});
767+
llvm::Value *variable =
768+
moduleTranslation.lookupValue(loop.reduction_vars()[i]);
769+
reductionInfos.push_back({variable->getType()->getPointerElementType(),
770+
variable, privateReductionVariables[i],
771+
owningReductionGens[i], atomicGen});
769772
}
770773

771774
// The call to createReductions below expects the block to have a

0 commit comments

Comments
 (0)