@@ -33,6 +33,14 @@ class TransformOpInterface;
3333// / expected to populate the `TransformResults` class instance in order to
3434// / update the mapping. The `applyTransform` method takes care of propagating
3535// / the state of `TransformResults` into the instance of this class.
36+ // /
37+ // / When applying transform IR operations with regions, the client is expected
38+ // / to create a RegionScope RAII object to create a new "stack frame" for
39+ // / values defined inside the region. The mappings from and to these values will
40+ // / be automatically dropped when the object goes out of scope, typically at the
41+ // / end of the "apply" function of the parent operation. If a region contains
42+ // / blocks with arguments, the client can map those arguments to payload IR ops
43+ // / using "mapBlockArguments".
3644class TransformState {
3745 // / Mapping between a Value in the transform IR and the corresponding set of
3846 // / operations in the payload IR.
@@ -42,9 +50,19 @@ class TransformState {
4250 // / currently associated with.
4351 using TransformOpReverseMapping = DenseMap<Operation *, Value>;
4452
53+ // / Bidirectional mappings between transform IR values and payload IR
54+ // / operations.
55+ struct Mappings {
56+ TransformOpMapping direct;
57+ TransformOpReverseMapping reverse;
58+ };
59+
4560public:
46- // / Creates a state for the transformation rooted at the given op.
47- explicit TransformState (Operation *root);
61+ // / Creates a state for transform ops living in the given region. The parent
62+ // / operation of the region. The second argument points to the root operation
63+ // / in the payload IR beind transformed, which may or may not contain the
64+ // / region with transform ops.
65+ TransformState (Region ®ion, Operation *root);
4866
4967 // / Returns the op at which the transformation state is rooted. This is
5068 // / typically helpful for transformations that apply globally.
@@ -58,10 +76,96 @@ class TransformState {
5876 // / the state accordingly.
5977 LogicalResult applyTransform (TransformOpInterface transform);
6078
79+ // / Records the mapping between a block argument in the transform IR and a
80+ // / list of operations in the payload IR. The arguments must be defined in
81+ // / blocks of the currently processed transform IR region, typically after a
82+ // / region scope is defined.
83+ LogicalResult mapBlockArguments (BlockArgument argument,
84+ ArrayRef<Operation *> operations) {
85+ #if LLVM_ENABLE_ABI_BREAKING_CHECKS
86+ assert (argument.getParentRegion () == regionStack.back () &&
87+ " mapping block arguments from a region other than the active one" );
88+ #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
89+ return setPayloadOps (argument, operations);
90+ }
91+
92+ // Forward declarations to support limited visibility.
93+ class RegionScope ;
94+
95+ // / Creates a new region scope for the given region. The region is expected to
96+ // / be nested in the currently processed region.
97+ // Implementation note: this method is inline but implemented outside of the
98+ // class body to comply with visibility and full-declaration requirements.
99+ inline RegionScope make_region_scope (Region ®ion);
100+
101+ // / A RAII object maintaining a "stack frame" for a transform IR region. When
102+ // / applying a transform IR operation that contains a region, the caller is
103+ // / expected to create a RegionScope before applying the ops contained in the
104+ // / region. This ensures that the mappings between values defined in the
105+ // / transform IR region and payload IR operations are cleared when the region
106+ // / processing ends; such values cannot be accessed outside the region.
107+ class RegionScope {
108+ public:
109+ // / Forgets the mapping from or to values defined in the associated
110+ // / transform IR region.
111+ ~RegionScope () {
112+ state.mappings .erase (region);
113+ #if LLVM_ENABLE_ABI_BREAKING_CHECKS
114+ state.regionStack .pop_back ();
115+ #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
116+ }
117+
118+ private:
119+ // / Creates a new scope for mappings between values defined in the given
120+ // / transform IR region and payload IR operations.
121+ RegionScope (TransformState &state, Region ®ion)
122+ : state(state), region(®ion) {
123+ auto res = state.mappings .try_emplace (this ->region );
124+ assert (res.second && " the region scope is already present" );
125+ (void )res;
126+ #if LLVM_ENABLE_ABI_BREAKING_CHECKS
127+ assert (state.regionStack .back ()->isProperAncestor (®ion) &&
128+ " scope started at a non-nested region" );
129+ state.regionStack .push_back (®ion);
130+ #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
131+ }
132+
133+ // / Back-reference to the transform state.
134+ TransformState &state;
135+
136+ // / The region this scope is associated with.
137+ Region *region;
138+
139+ friend RegionScope TransformState::make_region_scope (Region &);
140+ };
141+ friend class RegionScope ;
142+
61143private:
62144 // / Identifier for storing top-level value in the `operations` mapping.
63145 static constexpr Value kTopLevelValue = Value();
64146
147+ // / Returns the mappings frame for the reigon in which the value is defined.
148+ const Mappings &getMapping (Value value) const {
149+ return const_cast <TransformState *>(this )->getMapping (value);
150+ }
151+ Mappings &getMapping (Value value) {
152+ auto it = mappings.find (value.getParentRegion ());
153+ assert (it != mappings.end () &&
154+ " trying to find a mapping for a value from an unmapped region" );
155+ return it->second ;
156+ }
157+
158+ // / Returns the mappings frame for the region in which the operation resides.
159+ const Mappings &getMapping (Operation *operation) const {
160+ return const_cast <TransformState *>(this )->getMapping (operation);
161+ }
162+ Mappings &getMapping (Operation *operation) {
163+ auto it = mappings.find (operation->getParentRegion ());
164+ assert (it != mappings.end () &&
165+ " trying to find a mapping for an operation from an unmapped region" );
166+ return it->second ;
167+ }
168+
65169 // / Sets the payload IR ops associated with the given transform IR value.
66170 // / Fails if this would result in multiple transform IR values with uses
67171 // / corresponding to the same payload IR ops. For example, a hypothetical
@@ -88,9 +192,19 @@ class TransformState {
88192 void updatePayloadOps (Value value,
89193 function_ref<Operation *(Operation *)> callback);
90194
91- // / The mapping between payload IR values and transform IR ops.
92- TransformOpMapping operationMapping;
93- TransformOpReverseMapping reverseMapping;
195+ // / The mappings between transform IR values and payload IR ops, aggregated by
196+ // / the region in which the transform IR values are defined.
197+ llvm::SmallDenseMap<Region *, Mappings> mappings;
198+
199+ // / The top-level operation that contains all payload IR, typically a module.
200+ Operation *topLevel;
201+
202+ #if LLVM_ENABLE_ABI_BREAKING_CHECKS
203+ // / A stack of nested regions that are being processed in the transform IR.
204+ // / Each region must be an ancestor of the following regions in this list.
205+ // / These are also the keys for "mappings".
206+ SmallVector<Region *> regionStack;
207+ #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
94208};
95209
96210// / Local mapping between values defined by a specific op implementing the
@@ -123,6 +237,10 @@ class TransformResults {
123237 SmallVector<Operation *> operations;
124238};
125239
240+ TransformState::RegionScope TransformState::make_region_scope (Region ®ion) {
241+ return RegionScope (*this , region);
242+ }
243+
126244} // namespace transform
127245} // namespace mlir
128246
0 commit comments