Skip to content

Commit 3915171

Browse files
committed
[mlir][OpenMP] Added assemblyFormat for ParallelOp
This patch adds assemblyFormat for omp.parallel operation. Some existing functions have been altered to fit the custom directive in assemblyFormat. This has led to their callsites to get modified too, but those will be removed in later patches, when other operations get their assemblyFormat. All operations were not changed in one patch for ease of review. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D120157
1 parent 357b18e commit 3915171

File tree

4 files changed

+75
-86
lines changed

4 files changed

+75
-86
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,17 @@ def ParallelOp : OpenMP_Op<"parallel", [
9797
let builders = [
9898
OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
9999
];
100-
let hasCustomAssemblyFormat = 1;
100+
let assemblyFormat = [{
101+
oilist( `if` `(` $if_expr_var `:` type($if_expr_var) `)`
102+
| `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
103+
| `allocate` `(`
104+
custom<AllocateAndAllocator>(
105+
$allocate_vars, type($allocate_vars),
106+
$allocators_vars, type($allocators_vars)
107+
) `)`
108+
| `proc_bind` `(` custom<ProcBindKind>($proc_bind_val) `)`
109+
) $region attr-dict
110+
}];
101111
let hasVerifier = 1;
102112
}
103113

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 49 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -89,35 +89,53 @@ static ParseResult parseAllocateAndAllocator(
8989
SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
9090
SmallVectorImpl<Type> &typesAllocator) {
9191

92-
return parser.parseCommaSeparatedList(
93-
OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
94-
OpAsmParser::OperandType operand;
95-
Type type;
96-
if (parser.parseOperand(operand) || parser.parseColonType(type))
97-
return failure();
98-
operandsAllocator.push_back(operand);
99-
typesAllocator.push_back(type);
100-
if (parser.parseArrow())
101-
return failure();
102-
if (parser.parseOperand(operand) || parser.parseColonType(type))
103-
return failure();
92+
return parser.parseCommaSeparatedList([&]() -> ParseResult {
93+
OpAsmParser::OperandType operand;
94+
Type type;
95+
if (parser.parseOperand(operand) || parser.parseColonType(type))
96+
return failure();
97+
operandsAllocator.push_back(operand);
98+
typesAllocator.push_back(type);
99+
if (parser.parseArrow())
100+
return failure();
101+
if (parser.parseOperand(operand) || parser.parseColonType(type))
102+
return failure();
104103

105-
operandsAllocate.push_back(operand);
106-
typesAllocate.push_back(type);
107-
return success();
108-
});
104+
operandsAllocate.push_back(operand);
105+
typesAllocate.push_back(type);
106+
return success();
107+
});
109108
}
110109

111110
/// Print allocate clause
112-
static void printAllocateAndAllocator(OpAsmPrinter &p,
111+
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
113112
OperandRange varsAllocate,
114-
OperandRange varsAllocator) {
115-
p << "allocate(";
113+
TypeRange typesAllocate,
114+
OperandRange varsAllocator,
115+
TypeRange typesAllocator) {
116116
for (unsigned i = 0; i < varsAllocate.size(); ++i) {
117-
std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
118-
p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
119-
p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
117+
std::string separator = i == varsAllocate.size() - 1 ? "" : ", ";
118+
p << varsAllocator[i] << " : " << typesAllocator[i] << " -> ";
119+
p << varsAllocate[i] << " : " << typesAllocate[i] << separator;
120+
}
121+
}
122+
123+
ParseResult parseProcBindKind(OpAsmParser &parser,
124+
omp::ClauseProcBindKindAttr &procBindAttr) {
125+
StringRef procBindStr;
126+
if (parser.parseKeyword(&procBindStr))
127+
return failure();
128+
if (auto procBindVal = symbolizeClauseProcBindKind(procBindStr)) {
129+
procBindAttr =
130+
ClauseProcBindKindAttr::get(parser.getContext(), *procBindVal);
131+
return success();
120132
}
133+
return failure();
134+
}
135+
136+
void printProcBindKind(OpAsmPrinter &p, Operation *op,
137+
omp::ClauseProcBindKindAttr procBindAttr) {
138+
p << stringifyClauseProcBindKind(procBindAttr.getValue());
121139
}
122140

123141
LogicalResult ParallelOp::verify() {
@@ -127,24 +145,6 @@ LogicalResult ParallelOp::verify() {
127145
return success();
128146
}
129147

130-
void ParallelOp::print(OpAsmPrinter &p) {
131-
p << " ";
132-
if (auto ifCond = if_expr_var())
133-
p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
134-
135-
if (auto threads = num_threads_var())
136-
p << "num_threads(" << threads << " : " << threads.getType() << ") ";
137-
138-
if (!allocate_vars().empty())
139-
printAllocateAndAllocator(p, allocate_vars(), allocators_vars());
140-
141-
if (auto bind = proc_bind_val())
142-
p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") ";
143-
144-
p << ' ';
145-
p.printRegion(getRegion());
146-
}
147-
148148
//===----------------------------------------------------------------------===//
149149
// Parser and printer for Linear Clause
150150
//===----------------------------------------------------------------------===//
@@ -626,9 +626,10 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
626626
return failure();
627627
clauseSegments[pos[threadLimitClause]] = 1;
628628
} else if (clauseKeyword == "allocate") {
629-
if (checkAllowed(allocateClause) ||
629+
if (checkAllowed(allocateClause) || parser.parseLParen() ||
630630
parseAllocateAndAllocator(parser, allocates, allocateTypes,
631-
allocators, allocatorTypes))
631+
allocators, allocatorTypes) ||
632+
parser.parseRParen())
632633
return failure();
633634
clauseSegments[pos[allocateClause]] = allocates.size();
634635
clauseSegments[pos[allocateClause] + 1] = allocators.size();
@@ -803,32 +804,6 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
803804
return success();
804805
}
805806

806-
/// Parses a parallel operation.
807-
///
808-
/// operation ::= `omp.parallel` clause-list
809-
/// clause-list ::= clause | clause clause-list
810-
/// clause ::= if | num-threads | allocate | proc-bind
811-
///
812-
ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
813-
SmallVector<ClauseType> clauses = {ifClause, numThreadsClause, allocateClause,
814-
procBindClause};
815-
816-
SmallVector<int> segments;
817-
818-
if (failed(parseClauses(parser, result, clauses, segments)))
819-
return failure();
820-
821-
result.addAttribute("operand_segment_sizes",
822-
parser.getBuilder().getI32VectorAttr(segments));
823-
824-
Region *body = result.addRegion();
825-
SmallVector<OpAsmParser::OperandType> regionArgs;
826-
SmallVector<Type> regionArgTypes;
827-
if (parser.parseRegion(*body, regionArgs, regionArgTypes))
828-
return failure();
829-
return success();
830-
}
831-
832807
//===----------------------------------------------------------------------===//
833808
// Parser, printer and verifier for SectionsOp
834809
//===----------------------------------------------------------------------===//
@@ -863,8 +838,12 @@ void SectionsOp::print(OpAsmPrinter &p) {
863838
if (!reduction_vars().empty())
864839
printReductionVarList(p, reductions(), reduction_vars());
865840

866-
if (!allocate_vars().empty())
867-
printAllocateAndAllocator(p, allocate_vars(), allocators_vars());
841+
if (!allocate_vars().empty()) {
842+
printAllocateAndAllocator(p << "allocate(", *this, allocate_vars(),
843+
allocate_vars().getTypes(), allocators_vars(),
844+
allocators_vars().getTypes());
845+
p << ")";
846+
}
868847

869848
if (nowait())
870849
p << "nowait";

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
22

33
func @unknown_clause() {
4-
// expected-error@+1 {{invalid is not a valid clause}}
4+
// expected-error@+1 {{expected '{' to begin a region}}
55
omp.parallel invalid {
66
}
77

@@ -11,7 +11,7 @@ func @unknown_clause() {
1111
// -----
1212

1313
func @if_once(%n : i1) {
14-
// expected-error@+1 {{at most one if clause can appear on the omp.parallel operation}}
14+
// expected-error@+1 {{`if` clause can appear at most once in the expansion of the oilist directive}}
1515
omp.parallel if(%n : i1) if(%n : i1) {
1616
}
1717

@@ -21,7 +21,7 @@ func @if_once(%n : i1) {
2121
// -----
2222

2323
func @num_threads_once(%n : si32) {
24-
// expected-error@+1 {{at most one num_threads clause can appear on the omp.parallel operation}}
24+
// expected-error@+1 {{`num_threads` clause can appear at most once in the expansion of the oilist directive}}
2525
omp.parallel num_threads(%n : si32) num_threads(%n : si32) {
2626
}
2727

@@ -31,54 +31,54 @@ func @num_threads_once(%n : si32) {
3131
// -----
3232

3333
func @nowait_not_allowed(%n : memref<i32>) {
34-
// expected-error@+1 {{nowait is not a valid clause for the omp.parallel operation}}
34+
// expected-error@+1 {{expected '{' to begin a region}}
3535
omp.parallel nowait {}
3636
return
3737
}
3838

3939
// -----
4040

4141
func @linear_not_allowed(%data_var : memref<i32>, %linear_var : i32) {
42-
// expected-error@+1 {{linear is not a valid clause for the omp.parallel operation}}
42+
// expected-error@+1 {{expected '{' to begin a region}}
4343
omp.parallel linear(%data_var = %linear_var : memref<i32>) {}
4444
return
4545
}
4646

4747
// -----
4848

4949
func @schedule_not_allowed() {
50-
// expected-error@+1 {{schedule is not a valid clause for the omp.parallel operation}}
50+
// expected-error@+1 {{expected '{' to begin a region}}
5151
omp.parallel schedule(static) {}
5252
return
5353
}
5454

5555
// -----
5656

5757
func @collapse_not_allowed() {
58-
// expected-error@+1 {{collapse is not a valid clause for the omp.parallel operation}}
58+
// expected-error@+1 {{expected '{' to begin a region}}
5959
omp.parallel collapse(3) {}
6060
return
6161
}
6262

6363
// -----
6464

6565
func @order_not_allowed() {
66-
// expected-error@+1 {{order is not a valid clause for the omp.parallel operation}}
66+
// expected-error@+1 {{expected '{' to begin a region}}
6767
omp.parallel order(concurrent) {}
6868
return
6969
}
7070

7171
// -----
7272

7373
func @ordered_not_allowed() {
74-
// expected-error@+1 {{ordered is not a valid clause for the omp.parallel operation}}
74+
// expected-error@+1 {{expected '{' to begin a region}}
7575
omp.parallel ordered(2) {}
7676
}
7777

7878
// -----
7979

8080
func @proc_bind_once() {
81-
// expected-error@+1 {{at most one proc_bind clause can appear on the omp.parallel operation}}
81+
// expected-error@+1 {{`proc_bind` clause can appear at most once in the expansion of the oilist directive}}
8282
omp.parallel proc_bind(close) proc_bind(spread) {
8383
}
8484

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
5959
// CHECK: omp.parallel num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
6060
"omp.parallel"(%num_threads, %data_var, %data_var) ({
6161
omp.terminator
62-
}) {operand_segment_sizes = dense<[0,1,1,1]>: vector<4xi32>} : (si32, memref<i32>, memref<i32>) -> ()
62+
}) {num_threads, allocate, operand_segment_sizes = dense<[0,1,1,1]>: vector<4xi32>} : (si32, memref<i32>, memref<i32>) -> ()
6363

6464
// CHECK: omp.barrier
6565
omp.barrier
@@ -68,22 +68,22 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
6868
// CHECK: omp.parallel if(%{{.*}}) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
6969
"omp.parallel"(%if_cond, %data_var, %data_var) ({
7070
omp.terminator
71-
}) {operand_segment_sizes = dense<[1,0,1,1]> : vector<4xi32>} : (i1, memref<i32>, memref<i32>) -> ()
71+
}) {if, allocate, operand_segment_sizes = dense<[1,0,1,1]> : vector<4xi32>} : (i1, memref<i32>, memref<i32>) -> ()
7272

7373
// test without allocate
7474
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32)
7575
"omp.parallel"(%if_cond, %num_threads) ({
7676
omp.terminator
77-
}) {operand_segment_sizes = dense<[1,1,0,0]> : vector<4xi32>} : (i1, si32) -> ()
77+
}) {if, num_threads, operand_segment_sizes = dense<[1,1,0,0]> : vector<4xi32>} : (i1, si32) -> ()
7878

7979
omp.terminator
80-
}) {operand_segment_sizes = dense<[1,1,1,1]> : vector<4xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
80+
}) {if, num_threads, allocate, operand_segment_sizes = dense<[1,1,1,1]> : vector<4xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
8181

8282
// test with multiple parameters for single variadic argument
8383
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
8484
"omp.parallel" (%data_var, %data_var) ({
8585
omp.terminator
86-
}) {operand_segment_sizes = dense<[0,0,1,1]> : vector<4xi32>} : (memref<i32>, memref<i32>) -> ()
86+
}) {allocate, operand_segment_sizes = dense<[0,0,1,1]> : vector<4xi32>} : (memref<i32>, memref<i32>) -> ()
8787

8888
return
8989
}

0 commit comments

Comments
 (0)