Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit 97e69c2

Browse files
River707tensorflower-gardener
authored andcommitted
Refactor the way that pass options are specified.
This change refactors pass options to be more similar to how statistics are modeled. More specifically, the options are specified directly on the pass instead of in a separate options class. (Note that the behavior and specification for pass pipelines remains the same.) This brings about several benefits: * The specification of options is much simpler * The round-trip format of a pass can be generated automatically * This gives a somewhat deeper integration with "configuring" a pass, which we could potentially expose to users in the future. PiperOrigin-RevId: 286953824
1 parent 0130826 commit 97e69c2

File tree

10 files changed

+369
-177
lines changed

10 files changed

+369
-177
lines changed

g3doc/WritingAPass.md

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,8 @@ options ::= '{' (key ('=' value)?)+ '}'
421421
pass pipeline, e.g. `cse` or `canonicalize`.
422422
* `options`
423423
* Options are pass specific key value pairs that are handled as described
424-
in the instance specific pass options section.
424+
in the [instance specific pass options](#instance-specific-pass-options)
425+
section.
425426

426427
For example, the following pipeline:
427428

@@ -443,30 +444,47 @@ options in the format described above.
443444
### Instance Specific Pass Options
444445

445446
Options may be specified for a parametric pass. Individual options are defined
446-
using `llvm::cl::opt` flag definition rules. These options will then be parsed
447-
at pass construction time independently for each instance of the pass. The
448-
`PassRegistration` and `PassPipelineRegistration` templates take an additional
449-
optional template parameter that is the Option struct definition to be used for
450-
that pass. To use pass specific options, create a class that inherits from
451-
`mlir::PassOptions` and then add a new constructor that takes `const
452-
MyPassOptions&` and constructs the pass. When using `PassPipelineRegistration`,
453-
the constructor now takes a function with the signature `void (OpPassManager
454-
&pm, const MyPassOptions&)` which should construct the passes from the options
455-
and pass them to the pm. The user code will look like the following:
447+
using the [LLVM command line](https://llvm.org/docs/CommandLine.html) flag
448+
definition rules. These options will then be parsed at pass construction time
449+
independently for each instance of the pass. To provide options for passes, the
450+
`Option<>` and `OptionList<>` classes may be used:
456451

457452
```c++
458-
class MyPass ... {
459-
public:
460-
MyPass(const MyPassOptions& options) ...
453+
struct MyPass ... {
454+
/// Make sure that we have a valid default constructor and copy constructor to
455+
/// make sure that the options are initialized properly.
456+
MyPass() = default;
457+
MyPass(const MyPass& pass) {}
458+
459+
// These just forward onto llvm::cl::list and llvm::cl::opt respectively.
460+
Option<int> exampleOption{*this, "flag-name", llvm::cl::desc("...")};
461+
ListOption<int> exampleListOption{*this, "list-flag-name",
462+
llvm::cl::desc("...")};
461463
};
464+
```
462465
463-
struct MyPassOptions : public PassOptions<MyPassOptions> {
466+
For pass pipelines, the `PassPipelineRegistration` templates take an additional
467+
optional template parameter that is the Option struct definition to be used for
468+
that pipeline. To use pipeline specific options, create a class that inherits
469+
from `mlir::PassPipelineOptions` that contains the desired options. When using
470+
`PassPipelineRegistration`, the constructor now takes a function with the
471+
signature `void (OpPassManager &pm, const MyPipelineOptions&)` which should
472+
construct the passes from the options and pass them to the pm:
473+
474+
```c++
475+
struct MyPipelineOptions : public PassPipelineOptions {
464476
// These just forward onto llvm::cl::list and llvm::cl::opt respectively.
465477
Option<int> exampleOption{*this, "flag-name", llvm::cl::desc("...")};
466-
List<int> exampleListOption{*this, "list-flag-name", llvm::cl::desc("...")};
478+
ListOption<int> exampleListOption{*this, "list-flag-name",
479+
llvm::cl::desc("...")};
467480
};
468481
469-
static PassRegistration<MyPass, MyPassOptions> pass("my-pass", "description");
482+
483+
static mlir::PassPipelineRegistration<MyPipelineOptions> pipeline(
484+
"example-pipeline", "Run an example pipeline.",
485+
[](OpPassManager &pm, const MyPipelineOptions &pipelineOptions) {
486+
// Initialize the pass manager.
487+
});
470488
```
471489

472490
## Pass Statistics

include/mlir/Pass/Pass.h

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,40 @@ class Pass {
6161
/// this is a generic OperationPass.
6262
Optional<StringRef> getOpName() const { return opName; }
6363

64+
//===--------------------------------------------------------------------===//
65+
// Options
66+
//===--------------------------------------------------------------------===//
67+
68+
/// This class represents a specific pass option, with a provided data type.
69+
template <typename DataType>
70+
struct Option : public detail::PassOptions::Option<DataType> {
71+
template <typename... Args>
72+
Option(Pass &parent, StringRef arg, Args &&... args)
73+
: detail::PassOptions::Option<DataType>(parent.passOptions, arg,
74+
std::forward<Args>(args)...) {}
75+
using detail::PassOptions::Option<DataType>::operator=;
76+
};
77+
/// This class represents a specific pass option that contains a list of
78+
/// values of the provided data type.
79+
template <typename DataType>
80+
struct ListOption : public detail::PassOptions::ListOption<DataType> {
81+
template <typename... Args>
82+
ListOption(Pass &parent, StringRef arg, Args &&... args)
83+
: detail::PassOptions::ListOption<DataType>(
84+
parent.passOptions, arg, std::forward<Args>(args)...) {}
85+
using detail::PassOptions::ListOption<DataType>::operator=;
86+
};
87+
88+
/// Attempt to initialize the options of this pass from the given string.
89+
LogicalResult initializeOptions(StringRef options);
90+
6491
/// Prints out the pass in the textual representation of pipelines. If this is
6592
/// an adaptor pass, print with the op_name(sub_pass,...) format.
66-
/// Note: The default implementation uses the class name and does not respect
67-
/// options used to construct the pass. Override this method to allow for your
68-
/// pass to be to be round-trippable to the textual format.
69-
virtual void printAsTextualPipeline(raw_ostream &os);
93+
void printAsTextualPipeline(raw_ostream &os);
94+
95+
//===--------------------------------------------------------------------===//
96+
// Statistics
97+
//===--------------------------------------------------------------------===//
7098

7199
/// This class represents a single pass statistic. This statistic functions
72100
/// similarly to an unsigned integer value, and may be updated and incremented
@@ -119,6 +147,10 @@ class Pass {
119147
return getPassState().analysisManager;
120148
}
121149

150+
/// Copy the option values from 'other', which is another instance of this
151+
/// pass.
152+
void copyOptionValuesFrom(const Pass *other);
153+
122154
private:
123155
/// Forwarding function to execute this pass on the given operation.
124156
LLVM_NODISCARD
@@ -141,6 +173,9 @@ class Pass {
141173
/// The set of statistics held by this pass.
142174
std::vector<Statistic *> statistics;
143175

176+
/// The pass options registered to this pass instance.
177+
detail::PassOptions passOptions;
178+
144179
/// Allow access to 'clone' and 'run'.
145180
friend class OpPassManager;
146181
};
@@ -204,7 +239,9 @@ class PassModel : public BasePassT {
204239

205240
/// A clone method to create a copy of this pass.
206241
std::unique_ptr<Pass> clone() const override {
207-
return std::make_unique<PassT>(*static_cast<const PassT *>(this));
242+
auto newInst = std::make_unique<PassT>(*static_cast<const PassT *>(this));
243+
newInst->copyOptionValuesFrom(this);
244+
return newInst;
208245
}
209246

210247
/// Returns the analysis for the parent operation if it exists.

include/mlir/Pass/PassOptions.h

Lines changed: 169 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,50 +24,202 @@
2424

2525
namespace mlir {
2626
namespace detail {
27-
/// Base class for PassOptions<T> that holds all of the non-CRTP features.
28-
class PassOptionsBase : protected llvm::cl::SubCommand {
27+
/// Base container class and manager for all pass options.
28+
class PassOptions : protected llvm::cl::SubCommand {
29+
private:
30+
/// This is the type-erased option base class. This provides some additional
31+
/// hooks into the options that are not available via llvm::cl::Option.
32+
class OptionBase {
33+
public:
34+
virtual ~OptionBase() = default;
35+
36+
/// Out of line virtual function to provide home for the class.
37+
virtual void anchor();
38+
39+
/// Print the name and value of this option to the given stream.
40+
virtual void print(raw_ostream &os) = 0;
41+
42+
/// Return the argument string of this option.
43+
StringRef getArgStr() const { return getOption()->ArgStr; }
44+
45+
protected:
46+
/// Return the main option instance.
47+
virtual const llvm::cl::Option *getOption() const = 0;
48+
49+
/// Copy the value from the given option into this one.
50+
virtual void copyValueFrom(const OptionBase &other) = 0;
51+
52+
/// Allow access to private methods.
53+
friend PassOptions;
54+
};
55+
56+
/// This is the parser that is used by pass options that use literal options.
57+
/// This is a thin wrapper around the llvm::cl::parser, that exposes some
58+
/// additional methods.
59+
template <typename DataType>
60+
struct GenericOptionParser : public llvm::cl::parser<DataType> {
61+
using llvm::cl::parser<DataType>::parser;
62+
63+
/// Returns an argument name that maps to the specified value.
64+
Optional<StringRef> findArgStrForValue(const DataType &value) {
65+
for (auto &it : this->Values)
66+
if (it.V.compare(value))
67+
return it.Name;
68+
return llvm::None;
69+
}
70+
};
71+
72+
/// The specific parser to use depending on llvm::cl parser used. This is only
73+
/// necessary because we need to provide additional methods for certain data
74+
/// type parsers.
75+
/// TODO(riverriddle) We should upstream the methods in GenericOptionParser to
76+
/// avoid the need to do this.
77+
template <typename DataType>
78+
using OptionParser =
79+
std::conditional_t<std::is_base_of<llvm::cl::generic_parser_base,
80+
llvm::cl::parser<DataType>>::value,
81+
GenericOptionParser<DataType>,
82+
llvm::cl::parser<DataType>>;
83+
84+
/// Utility methods for printing option values.
85+
template <typename DataT>
86+
static void printOptionValue(raw_ostream &os,
87+
GenericOptionParser<DataT> &parser,
88+
const DataT &value) {
89+
if (Optional<StringRef> argStr = parser.findArgStrForValue(value))
90+
os << argStr;
91+
else
92+
llvm_unreachable("unknown data value for option");
93+
}
94+
template <typename DataT, typename ParserT>
95+
static void printOptionValue(raw_ostream &os, ParserT &parser,
96+
const DataT &value) {
97+
os << value;
98+
}
99+
template <typename ParserT>
100+
static void printOptionValue(raw_ostream &os, ParserT &parser,
101+
const bool &value) {
102+
os << (value ? StringRef("true") : StringRef("false"));
103+
}
104+
29105
public:
30106
/// This class represents a specific pass option, with a provided data type.
31-
template <typename DataType> struct Option : public llvm::cl::opt<DataType> {
107+
template <typename DataType>
108+
class Option : public llvm::cl::opt<DataType, /*ExternalStorage=*/false,
109+
OptionParser<DataType>>,
110+
public OptionBase {
111+
public:
32112
template <typename... Args>
33-
Option(PassOptionsBase &parent, StringRef arg, Args &&... args)
34-
: llvm::cl::opt<DataType>(arg, llvm::cl::sub(parent),
35-
std::forward<Args>(args)...) {
113+
Option(PassOptions &parent, StringRef arg, Args &&... args)
114+
: llvm::cl::opt<DataType, /*ExternalStorage=*/false,
115+
OptionParser<DataType>>(arg, llvm::cl::sub(parent),
116+
std::forward<Args>(args)...) {
36117
assert(!this->isPositional() && !this->isSink() &&
37118
"sink and positional options are not supported");
119+
parent.options.push_back(this);
120+
}
121+
using llvm::cl::opt<DataType, /*ExternalStorage=*/false,
122+
OptionParser<DataType>>::operator=;
123+
~Option() override = default;
124+
125+
private:
126+
/// Return the main option instance.
127+
const llvm::cl::Option *getOption() const final { return this; }
128+
129+
/// Print the name and value of this option to the given stream.
130+
void print(raw_ostream &os) final {
131+
os << this->ArgStr << '=';
132+
printOptionValue(os, this->getParser(), this->getValue());
133+
}
134+
135+
/// Copy the value from the given option into this one.
136+
void copyValueFrom(const OptionBase &other) final {
137+
this->setValue(static_cast<const Option<DataType> &>(other).getValue());
38138
}
39139
};
40140

41141
/// This class represents a specific pass option that contains a list of
42142
/// values of the provided data type.
43-
template <typename DataType> struct List : public llvm::cl::list<DataType> {
143+
template <typename DataType>
144+
class ListOption : public llvm::cl::list<DataType, /*StorageClass=*/bool,
145+
OptionParser<DataType>>,
146+
public OptionBase {
147+
public:
44148
template <typename... Args>
45-
List(PassOptionsBase &parent, StringRef arg, Args &&... args)
46-
: llvm::cl::list<DataType>(arg, llvm::cl::sub(parent),
47-
std::forward<Args>(args)...) {
149+
ListOption(PassOptions &parent, StringRef arg, Args &&... args)
150+
: llvm::cl::list<DataType, /*StorageClass=*/bool,
151+
OptionParser<DataType>>(arg, llvm::cl::sub(parent),
152+
std::forward<Args>(args)...) {
48153
assert(!this->isPositional() && !this->isSink() &&
49154
"sink and positional options are not supported");
155+
parent.options.push_back(this);
156+
}
157+
~ListOption() override = default;
158+
159+
/// Allow assigning from an ArrayRef.
160+
ListOption<DataType> &operator=(ArrayRef<DataType> values) {
161+
(*this)->assign(values.begin(), values.end());
162+
return *this;
163+
}
164+
165+
std::vector<DataType> *operator->() { return &*this; }
166+
167+
private:
168+
/// Return the main option instance.
169+
const llvm::cl::Option *getOption() const final { return this; }
170+
171+
/// Print the name and value of this option to the given stream.
172+
void print(raw_ostream &os) final {
173+
os << this->ArgStr << '=';
174+
auto printElementFn = [&](const DataType &value) {
175+
printOptionValue(os, this->getParser(), value);
176+
};
177+
interleave(*this, os, printElementFn, ",");
178+
}
179+
180+
/// Copy the value from the given option into this one.
181+
void copyValueFrom(const OptionBase &other) final {
182+
(*this) = ArrayRef<DataType>((ListOption<DataType> &)other);
50183
}
51184
};
52185

186+
PassOptions() = default;
187+
188+
/// Copy the option values from 'other' into 'this', where 'other' has the
189+
/// same options as 'this'.
190+
void copyOptionValuesFrom(const PassOptions &other);
191+
53192
/// Parse options out as key=value pairs that can then be handed off to the
54193
/// `llvm::cl` command line passing infrastructure. Everything is space
55194
/// separated.
56195
LogicalResult parseFromString(StringRef options);
196+
197+
/// Print the options held by this struct in a form that can be parsed via
198+
/// 'parseFromString'.
199+
void print(raw_ostream &os);
200+
201+
private:
202+
/// A list of all of the opaque options.
203+
std::vector<OptionBase *> options;
57204
};
58205
} // end namespace detail
59206

60-
/// Subclasses of PassOptions provide a set of options that can be used to
61-
/// initialize a pass instance. See PassRegistration for usage details.
207+
//===----------------------------------------------------------------------===//
208+
// PassPipelineOptions
209+
//===----------------------------------------------------------------------===//
210+
211+
/// Subclasses of PassPipelineOptions provide a set of options that can be used
212+
/// to initialize a pass pipeline. See PassPipelineRegistration for usage
213+
/// details.
62214
///
63215
/// Usage:
64216
///
65-
/// struct MyPassOptions : PassOptions<MyPassOptions> {
66-
/// List<int> someListFlag{
217+
/// struct MyPipelineOptions : PassPipelineOptions<MyPassOptions> {
218+
/// ListOption<int> someListFlag{
67219
/// *this, "flag-name", llvm::cl::MiscFlags::CommaSeparated,
68220
/// llvm::cl::desc("...")};
69221
/// };
70-
template <typename T> class PassOptions : public detail::PassOptionsBase {
222+
template <typename T> class PassPipelineOptions : public detail::PassOptions {
71223
public:
72224
/// Factory that parses the provided options and returns a unique_ptr to the
73225
/// struct.
@@ -81,7 +233,8 @@ template <typename T> class PassOptions : public detail::PassOptionsBase {
81233

82234
/// A default empty option struct to be used for passes that do not need to take
83235
/// any options.
84-
struct EmptyPassOptions : public PassOptions<EmptyPassOptions> {};
236+
struct EmptyPipelineOptions : public PassPipelineOptions<EmptyPipelineOptions> {
237+
};
85238

86239
} // end namespace mlir
87240

0 commit comments

Comments
 (0)