Skip to content

Commit f4d18c0

Browse files
authored
[MLIR][Transform][Tune] Introduce transform.tune.alternatives op (llvm#160724)
This op enables expressing uncertainty regarding what should be happening at particular places in transform-dialect schedules. In particular, it enables representing a choice among alternative regions. This choice is resolved through providing a `selected_region` argument. When this argument is provided, the semantics are such that it is valid to rewrite the op through substituting in the selected region -- with the op's interpreted semantics corresponding to exactly this. This op represents another piece of the puzzle w.r.t. a toolkit for expressing autotuning problems with the transform dialect. Note that this goes beyond tuning knobs _on_ transforms, going further by making it tunable which (sequences of) transforms are to be applied.
1 parent a33544b commit f4d18c0

File tree

7 files changed

+604
-17
lines changed

7 files changed

+604
-17
lines changed

mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
1010
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
1111

12+
#include "mlir/Dialect/Transform/IR/TransformOps.h"
1213
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
1314
#include "mlir/IR/BuiltinAttributes.h"
1415
#include "mlir/IR/OpDefinition.h"

mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,15 @@
1111

1212
include "mlir/Dialect/Transform/IR/TransformDialect.td"
1313
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
14+
include "mlir/Interfaces/ControlFlowInterfaces.td"
1415
include "mlir/Interfaces/SideEffectInterfaces.td"
1516
include "mlir/IR/BuiltinAttributes.td"
1617
include "mlir/IR/CommonAttrConstraints.td"
1718

19+
//===----------------------------------------------------------------------===//
20+
// KnobOp
21+
//===----------------------------------------------------------------------===//
22+
1823
def KnobOp : Op<Transform_Dialect, "tune.knob", [
1924
DeclareOpInterfaceMethods<TransformOpInterface>,
2025
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -52,4 +57,53 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [
5257
"`<` $name `>` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)";
5358
}
5459

60+
//===----------------------------------------------------------------------===//
61+
// AlternativesOp
62+
//===----------------------------------------------------------------------===//
63+
64+
def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [
65+
DeclareOpInterfaceMethods<RegionBranchOpInterface,
66+
["getEntrySuccessorOperands", "getSuccessorRegions",
67+
"getRegionInvocationBounds"]>,
68+
DeclareOpInterfaceMethods<TransformOpInterface>,
69+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
70+
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">,
71+
NoRegionArguments
72+
]> {
73+
let summary = "Represents a choice among its regions, i.e. sub-schedules";
74+
75+
let description = [{
76+
This op represents a choice over which of its regions is to be used.
77+
78+
When `selected_region` is provided, the semantics are that this op is to be
79+
substituted for by the selected region, meaning the region's results become
80+
the results of this op. Without a provided `selected_region`, the semantics
81+
are that this non-deterministic choice is yet to be resolved -- which in
82+
terms of the op's interpreted semantics is a failure.
83+
84+
The `selected_region` argument is either an `IntegerAttr` or a param holding
85+
an `IntegerAttr`, which should provide a valid zero-based index with respect
86+
to the number of alternatives, i.e. regions.
87+
}];
88+
let cppNamespace = [{ mlir::transform::tune }];
89+
90+
let arguments = (ins Builtin_StringAttr:$name,
91+
OptionalAttr<APIntAttr>:$selected_region_attr,
92+
Optional<TransformParamTypeInterface>:$selected_region_param);
93+
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
94+
let regions = (region VariadicRegion<SizedRegion<1>>:$alternatives);
95+
96+
let assemblyFormat = [{
97+
`<` $name `>`
98+
(`selected_region` `=` custom<AlternativesOpSelectedRegion>(
99+
$selected_region_attr, $selected_region_param)^)?
100+
attr-dict-with-keyword
101+
(`:` type($selected_region_param)^)?
102+
(`->` type($results)^)?
103+
regions
104+
}];
105+
106+
let hasVerifier = 1;
107+
}
108+
55109
#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS

mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,24 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Transform/IR/TransformOps.h"
910
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
11+
#include "mlir/IR/OpImplementation.h"
1012
#include "llvm/Support/Debug.h"
1113

1214
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
1315

1416
using namespace mlir;
1517

18+
static ParseResult parseAlternativesOpSelectedRegion(
19+
OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
20+
std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam);
21+
22+
static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
23+
Operation *op,
24+
IntegerAttr selectedRegionAttr,
25+
Value selectedRegionParam);
26+
1627
#define GET_OP_CLASSES
1728
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
1829

@@ -57,3 +68,176 @@ LogicalResult transform::tune::KnobOp::verify() {
5768

5869
return success();
5970
}
71+
72+
//===----------------------------------------------------------------------===//
73+
// AlternativesOp
74+
//===----------------------------------------------------------------------===//
75+
76+
static ParseResult parseAlternativesOpSelectedRegion(
77+
OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
78+
std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam) {
79+
size_t selectedRegionIdx;
80+
OptionalParseResult attrParseRes =
81+
parser.parseOptionalInteger(selectedRegionIdx);
82+
if (attrParseRes.has_value()) {
83+
if (failed(*attrParseRes))
84+
return failure();
85+
86+
selectedRegionAttr = parser.getBuilder().getIndexAttr(selectedRegionIdx);
87+
return success();
88+
}
89+
90+
OpAsmParser::UnresolvedOperand param;
91+
auto paramParseRes = parser.parseOptionalOperand(param);
92+
if (paramParseRes.has_value()) {
93+
if (failed(*paramParseRes))
94+
return failure();
95+
96+
selectedRegionParam = param;
97+
return success();
98+
}
99+
100+
return parser.emitError(parser.getCurrentLocation())
101+
<< "expected either an integer attribute or a transform.param operand";
102+
}
103+
104+
static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
105+
Operation *op,
106+
IntegerAttr selectedRegionAttr,
107+
Value selectedRegionParam) {
108+
if (selectedRegionAttr)
109+
printer << selectedRegionAttr.getValue();
110+
if (selectedRegionParam)
111+
printer << selectedRegionParam;
112+
}
113+
114+
OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
115+
RegionBranchPoint point) {
116+
// No operands will be forwarded to the region(s).
117+
return getOperands().slice(0, 0);
118+
}
119+
120+
void transform::tune::AlternativesOp::getSuccessorRegions(
121+
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
122+
if (point.isParent())
123+
if (auto selectedRegionIdx = getSelectedRegionAttr())
124+
regions.emplace_back(
125+
&getAlternatives()[selectedRegionIdx->getSExtValue()],
126+
Block::BlockArgListType());
127+
else
128+
for (Region &alternative : getAlternatives())
129+
regions.emplace_back(&alternative, Block::BlockArgListType());
130+
else
131+
regions.emplace_back(getOperation()->getResults());
132+
}
133+
134+
void transform::tune::AlternativesOp::getRegionInvocationBounds(
135+
ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
136+
(void)operands;
137+
bounds.reserve(getNumRegions());
138+
139+
if (auto selectedRegionIdx = getSelectedRegionAttr()) {
140+
bounds.resize(getNumRegions(), InvocationBounds(0, 0));
141+
bounds[selectedRegionIdx->getSExtValue()] = InvocationBounds(1, 1);
142+
} else {
143+
bounds.resize(getNumRegions(), InvocationBounds(0, 1));
144+
}
145+
}
146+
147+
void transform::tune::AlternativesOp::getEffects(
148+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
149+
onlyReadsHandle(getSelectedRegionParamMutable(), effects);
150+
producesHandle(getOperation()->getOpResults(), effects);
151+
// TODO: should effects from regions be forwarded?
152+
}
153+
154+
DiagnosedSilenceableFailure
155+
transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter,
156+
transform::TransformResults &results,
157+
transform::TransformState &state) {
158+
std::optional<size_t> selectedRegionIdx;
159+
160+
if (auto selectedRegionAttr = getSelectedRegionAttr())
161+
selectedRegionIdx = selectedRegionAttr->getSExtValue();
162+
163+
if (Value selectedRegionParam = getSelectedRegionParam()) {
164+
ArrayRef<Attribute> associatedAttrs = state.getParams(selectedRegionParam);
165+
IntegerAttr selectedRegionAttr;
166+
if (associatedAttrs.size() != 1 ||
167+
!(selectedRegionAttr = dyn_cast<IntegerAttr>(associatedAttrs[0])))
168+
return emitDefiniteFailure()
169+
<< "param should hold exactly one integer attribute, got: "
170+
<< associatedAttrs[0];
171+
selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue();
172+
}
173+
174+
if (!selectedRegionIdx)
175+
return emitDefiniteFailure() << "non-deterministic choice " << getName()
176+
<< " is only resolved through providing a "
177+
"`selected_region` attr/param";
178+
179+
if (*selectedRegionIdx < 0 || *selectedRegionIdx >= getNumRegions())
180+
return emitDefiniteFailure()
181+
<< "'selected_region' attribute/param specifies region at index "
182+
<< *selectedRegionIdx << " while op has only " << getNumRegions()
183+
<< " regions";
184+
185+
Region &selectedRegion = getRegion(*selectedRegionIdx);
186+
auto scope = state.make_region_scope(selectedRegion);
187+
Block &block = selectedRegion.front();
188+
// Apply the region's ops one by one.
189+
for (Operation &transform : block.without_terminator()) {
190+
DiagnosedSilenceableFailure result =
191+
state.applyTransform(cast<transform::TransformOpInterface>(transform));
192+
if (result.isDefiniteFailure())
193+
return result;
194+
195+
if (result.isSilenceableFailure()) {
196+
for (const auto &res : getResults())
197+
results.set(res, {});
198+
return result;
199+
}
200+
}
201+
// Forward the operation mapping for values yielded from the region to the
202+
// values produced by the alternatives op.
203+
transform::detail::forwardTerminatorOperands(&block, state, results);
204+
return DiagnosedSilenceableFailure::success();
205+
}
206+
207+
LogicalResult transform::tune::AlternativesOp::verify() {
208+
for (auto *region : getRegions()) {
209+
auto yieldTerminator =
210+
llvm::dyn_cast_if_present<transform::YieldOp>(region->front().back());
211+
if (!yieldTerminator)
212+
return emitOpError() << "expected '"
213+
<< transform::YieldOp::getOperationName()
214+
<< "' as terminator";
215+
216+
if (yieldTerminator->getNumOperands() != getNumResults())
217+
return yieldTerminator.emitOpError()
218+
<< "expected terminator to have as many operands as the parent op "
219+
"has results";
220+
221+
for (auto [i, operandType, resultType] : llvm::zip_equal(
222+
llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
223+
yieldTerminator->getOperands().getType(), getResultTypes())) {
224+
if (operandType == resultType)
225+
continue;
226+
return yieldTerminator.emitOpError()
227+
<< "the type of the terminator operand #" << i
228+
<< " must match the type of the corresponding parent op result ("
229+
<< operandType << " vs " << resultType << ")";
230+
}
231+
}
232+
233+
if (auto selectedRegionAttr = getSelectedRegionAttr()) {
234+
size_t regionIdx = selectedRegionAttr->getSExtValue();
235+
if (regionIdx < 0 || regionIdx >= getNumRegions())
236+
return emitOpError()
237+
<< "'selected_region' attribute specifies region at index "
238+
<< regionIdx << " while op has only " << getNumRegions()
239+
<< " regions";
240+
}
241+
242+
return success();
243+
}

mlir/python/mlir/dialects/transform/tune.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
from ...ir import (
88
Type,
9+
Value,
10+
Operation,
11+
OpView,
912
Attribute,
1013
ArrayAttr,
1114
StringAttr,
@@ -19,7 +22,10 @@
1922
from .._transform_tune_extension_ops_gen import _Dialect
2023

2124
try:
22-
from .._ods_common import _cext as _ods_cext
25+
from .._ods_common import (
26+
get_op_result_or_value as _get_op_result_or_value,
27+
_cext as _ods_cext,
28+
)
2329
except ImportError as e:
2430
raise RuntimeError("Error loading imports from extension module") from e
2531

@@ -36,7 +42,7 @@ def __init__(
3642
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
3743
],
3844
*,
39-
selected: Optional[Attribute] = None,
45+
selected: Optional[Union[Attribute, bool, int, float, str]] = None,
4046
loc=None,
4147
ip=None,
4248
):
@@ -75,8 +81,62 @@ def knob(
7581
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
7682
],
7783
*,
78-
selected: Optional[Attribute] = None,
84+
selected: Optional[Union[Attribute, bool, int, float, str]] = None,
7985
loc=None,
8086
ip=None,
8187
):
8288
return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip)
89+
90+
91+
@_ods_cext.register_operation(_Dialect, replace=True)
92+
class AlternativesOp(AlternativesOp):
93+
def __init__(
94+
self,
95+
results: Sequence[Type],
96+
name: Union[StringAttr, str],
97+
num_alternatives: int,
98+
*,
99+
selected_region: Optional[
100+
Union[int, IntegerAttr, Value, Operation, OpView]
101+
] = None,
102+
loc=None,
103+
ip=None,
104+
):
105+
if isinstance(name, str):
106+
name = StringAttr.get(name)
107+
108+
selected_region_attr = selected_region_param = None
109+
if isinstance(selected_region, IntegerAttr):
110+
selected_region_attr = selected_region
111+
elif isinstance(selected_region, int):
112+
selected_region_attr = IntegerAttr.get(
113+
IntegerType.get_signless(32), selected_region
114+
)
115+
elif isinstance(selected_region, (Value, Operation, OpView)):
116+
selected_region_param = _get_op_result_or_value(selected_region)
117+
118+
super().__init__(
119+
results,
120+
name,
121+
num_alternatives,
122+
selected_region_attr=selected_region_attr,
123+
selected_region_param=selected_region_param,
124+
loc=loc,
125+
ip=ip,
126+
)
127+
for region in self.regions:
128+
region.blocks.append()
129+
130+
131+
def alternatives(
132+
results: Sequence[Type],
133+
name: Union[StringAttr, str],
134+
num_alternatives: int,
135+
*,
136+
selected_region: Optional[Union[int, IntegerAttr, Value, Operation, OpView]] = None,
137+
loc=None,
138+
ip=None,
139+
):
140+
return AlternativesOp(
141+
results, name, num_alternatives, selected_region=selected_region, loc=loc, ip=ip
142+
)

0 commit comments

Comments
 (0)