Skip to content

Commit fcbf00f

Browse files
committed
[mlir][OpenMP] Added ReductionClauseInterface
This patch adds the ReductionClauseInterface and also adds reduction support for `omp.parallel` operation. Reviewed By: kiranchandramohan Differential Revision: https://reviews.llvm.org/D122402
1 parent 1f52d02 commit fcbf00f

File tree

5 files changed

+214
-40
lines changed

5 files changed

+214
-40
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
205205
// Create and insert the operation.
206206
auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
207207
currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
208-
ValueRange(), ValueRange(),
208+
/*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
209+
/*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
209210
procBindClauseOperand.dyn_cast_or_null<omp::ClauseProcBindKindAttr>());
210211
// Handle attribute based clauses.
211212
for (const auto &clause : parallelOpClauseList.v) {

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def OpenMP_PointerLikeType : Type<
6666
def ParallelOp : OpenMP_Op<"parallel", [
6767
AutomaticAllocationScope, AttrSizedOperandSegments,
6868
DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>,
69-
RecursiveSideEffects]> {
69+
RecursiveSideEffects, ReductionClauseInterface]> {
7070
let summary = "parallel construct";
7171
let description = [{
7272
The parallel construct includes a region of code which is to be executed
@@ -83,6 +83,18 @@ def ParallelOp : OpenMP_Op<"parallel", [
8383
The $allocators_vars and $allocate_vars parameters are a variadic list of values
8484
that specify the memory allocator to be used to obtain storage for private values.
8585

86+
Reductions can be performed in a parallel construct by specifying reduction
87+
accumulator variables in `reduction_vars` and symbols referring to reduction
88+
declarations in the `reductions` attribute. Each reduction is identified
89+
by the accumulator it uses and accumulators must not be repeated in the same
90+
reduction. The `omp.reduction` operation accepts the accumulator and a
91+
partial value which is considered to be produced by the thread for the
92+
given reduction. If multiple values are produced for the same accumulator,
93+
i.e. there are multiple `omp.reduction`s, the last value is taken. The
94+
reduction declaration specifies how to combine the values from each thread
95+
into the final value, which is available in the accumulator after all the
96+
threads complete.
97+
8698
The optional $proc_bind_val attribute controls the thread affinity for the execution
8799
of the parallel region.
88100
}];
@@ -91,6 +103,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
91103
Optional<AnyType>:$num_threads_var,
92104
Variadic<AnyType>:$allocate_vars,
93105
Variadic<AnyType>:$allocators_vars,
106+
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
107+
OptionalAttr<SymbolRefArrayAttr>:$reductions,
94108
OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
95109

96110
let regions = (region AnyRegion:$region);
@@ -99,7 +113,11 @@ def ParallelOp : OpenMP_Op<"parallel", [
99113
OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
100114
];
101115
let assemblyFormat = [{
102-
oilist( `if` `(` $if_expr_var `:` type($if_expr_var) `)`
116+
oilist( `reduction` `(`
117+
custom<ReductionVarList>(
118+
$reduction_vars, type($reduction_vars), $reductions
119+
) `)`
120+
| `if` `(` $if_expr_var `:` type($if_expr_var) `)`
103121
| `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
104122
| `allocate` `(`
105123
custom<AllocateAndAllocator>(
@@ -110,6 +128,12 @@ def ParallelOp : OpenMP_Op<"parallel", [
110128
) $region attr-dict
111129
}];
112130
let hasVerifier = 1;
131+
let extraClassDeclaration = [{
132+
// TODO: remove this once emitAccessorPrefix is set to
133+
// kEmitAccessorPrefix_Prefixed for the dialect.
134+
/// Returns the reduction variables
135+
operand_range getReductionVars() { return reduction_vars(); }
136+
}];
113137
}
114138

115139
def TerminatorOp : OpenMP_Op<"terminator", [Terminator]> {
@@ -156,7 +180,8 @@ def SectionOp : OpenMP_Op<"section", [HasParent<"SectionsOp">]> {
156180
let assemblyFormat = "$region attr-dict";
157181
}
158182

159-
def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> {
183+
def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments,
184+
ReductionClauseInterface]> {
160185
let summary = "sections construct";
161186
let description = [{
162187
The sections construct is a non-iterative worksharing construct that
@@ -207,6 +232,13 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> {
207232

208233
let hasVerifier = 1;
209234
let hasRegionVerifier = 1;
235+
236+
let extraClassDeclaration = [{
237+
// TODO: remove this once emitAccessorPrefix is set to
238+
// kEmitAccessorPrefix_Prefixed for the dialect.
239+
/// Returns the reduction variables
240+
operand_range getReductionVars() { return reduction_vars(); }
241+
}];
210242
}
211243

212244
//===----------------------------------------------------------------------===//
@@ -247,7 +279,7 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
247279

248280
def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
249281
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
250-
RecursiveSideEffects]> {
282+
RecursiveSideEffects, ReductionClauseInterface]> {
251283
let summary = "workshare loop construct";
252284
let description = [{
253285
The workshare loop construct specifies that the iterations of the loop(s)
@@ -338,6 +370,11 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
338370

339371
/// Returns the number of reduction variables.
340372
unsigned getNumReductionVars() { return reduction_vars().size(); }
373+
374+
// TODO: remove this once emitAccessorPrefix is set to
375+
// kEmitAccessorPrefix_Prefixed for the dialect.
376+
/// Returns the reduction variables
377+
operand_range getReductionVars() { return reduction_vars(); }
341378
}];
342379
let hasCustomAssemblyFormat = 1;
343380
let assemblyFormat = [{

mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,18 @@ def OutlineableOpenMPOpInterface : OpInterface<"OutlineableOpenMPOpInterface"> {
3131
];
3232
}
3333

34+
def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> {
35+
let description = [{
36+
OpenMP operations that support reduction clause have this interface.
37+
}];
38+
39+
let cppNamespace = "::mlir::omp";
40+
41+
let methods = [
42+
InterfaceMethod<
43+
"Get reduction vars", "::mlir::Operation::operand_range",
44+
"getReductionVars">,
45+
];
46+
}
47+
3448
#endif // OpenMP_OPS_INTERFACES

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
2929
#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
30+
#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
3031
#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
3132

3233
using namespace mlir;
@@ -58,19 +59,6 @@ void OpenMPDialect::initialize() {
5859
MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
5960
}
6061

61-
//===----------------------------------------------------------------------===//
62-
// ParallelOp
63-
//===----------------------------------------------------------------------===//
64-
65-
void ParallelOp::build(OpBuilder &builder, OperationState &state,
66-
ArrayRef<NamedAttribute> attributes) {
67-
ParallelOp::build(
68-
builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
69-
/*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
70-
/*proc_bind_val=*/nullptr);
71-
state.addAttributes(attributes);
72-
}
73-
7462
//===----------------------------------------------------------------------===//
7563
// Parser and printer for Allocate Clause
7664
//===----------------------------------------------------------------------===//
@@ -142,13 +130,6 @@ void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
142130
p << stringifyEnum(attr.getValue());
143131
}
144132

145-
LogicalResult ParallelOp::verify() {
146-
if (allocate_vars().size() != allocators_vars().size())
147-
return emitError(
148-
"expected equal sizes for allocate and allocator variables");
149-
return success();
150-
}
151-
152133
//===----------------------------------------------------------------------===//
153134
// Parser and printer for Linear Clause
154135
//===----------------------------------------------------------------------===//
@@ -469,6 +450,27 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
469450
return success();
470451
}
471452

453+
//===----------------------------------------------------------------------===//
454+
// ParallelOp
455+
//===----------------------------------------------------------------------===//
456+
457+
void ParallelOp::build(OpBuilder &builder, OperationState &state,
458+
ArrayRef<NamedAttribute> attributes) {
459+
ParallelOp::build(
460+
builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
461+
/*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
462+
/*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
463+
/*proc_bind_val=*/nullptr);
464+
state.addAttributes(attributes);
465+
}
466+
467+
LogicalResult ParallelOp::verify() {
468+
if (allocate_vars().size() != allocators_vars().size())
469+
return emitError(
470+
"expected equal sizes for allocate and allocator variables");
471+
return verifyReductionVarList(*this, reductions(), reduction_vars());
472+
}
473+
472474
//===----------------------------------------------------------------------===//
473475
// Verifier for SectionsOp
474476
//===----------------------------------------------------------------------===//
@@ -709,13 +711,17 @@ LogicalResult ReductionDeclareOp::verifyRegions() {
709711
}
710712

711713
LogicalResult ReductionOp::verify() {
712-
// TODO: generalize this to an op interface when there is more than one op
713-
// that supports reductions.
714-
auto container = (*this)->getParentOfType<WsLoopOp>();
715-
for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
716-
if (container.reduction_vars()[i] == accumulator())
717-
return success();
718-
714+
auto *op = (*this)->getParentWithTrait<ReductionClauseInterface::Trait>();
715+
if (!op)
716+
return emitOpError() << "must be used within an operation supporting "
717+
"reduction clause interface";
718+
while (op) {
719+
for (const auto &var :
720+
cast<ReductionClauseInterface>(op).getReductionVars())
721+
if (var == accumulator())
722+
return success();
723+
op = op->getParentWithTrait<ReductionClauseInterface::Trait>();
724+
}
719725
return emitOpError() << "the accumulator is not used by the parent";
720726
}
721727

0 commit comments

Comments
 (0)