1010
1111#include " mlir/IR/AffineMap.h"
1212#include " mlir/IR/Builders.h"
13+ #include " mlir/IR/BuiltinTypeInterfaces.h"
14+ #include " llvm/ADT/ArrayRef.h"
15+ #include " llvm/ADT/SmallVector.h"
16+ #include " llvm/Support/LogicalResult.h"
1317
1418#include < numeric>
1519#include < optional>
@@ -28,67 +32,329 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
2832 return std::nullopt ;
2933}
3034
31- std::optional<SmallVector<ReassociationIndices>>
32- mlir::getReassociationIndicesForCollapse (ArrayRef<int64_t > sourceShape,
33- ArrayRef<int64_t > targetShape) {
34- if (sourceShape.size () <= targetShape.size ())
35- return std::nullopt ;
36- unsigned sourceDim = 0 ;
37- SmallVector<ReassociationIndices> reassociationMap;
38- reassociationMap.reserve (targetShape.size ());
35+ namespace {
36+ // / A simple struct to represent ReassociationIndices as an inclusive interval.
37+ // / It's designed to be feasibly minimal, so the call sites should manage the
38+ // / validity of the range manually.
39+ struct ReassociationIndexRange {
40+ // / FIXME: Signed type is used for consistency with ReassociationIndices.
41+ // / We should consider refactoring all reassociation utilities to use unsigned
42+ // / types.
43+ int64_t leftIdx = 0 , rightIdx = 0 ;
44+
45+ // / Util for manual checks of the range's validity
46+ LogicalResult verify () const {
47+ return leftIdx >= 0 && (leftIdx <= rightIdx) ? success () : failure ();
48+ }
49+
50+ // / Checks range's containment within another range. Treats the edges
51+ // / non-exclusively.
52+ bool isInRange (const ReassociationIndexRange &outerRange) const {
53+ return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx ;
54+ }
55+
56+ unsigned size () const {
57+ assert (succeeded (verify ()));
58+ return rightIdx - leftIdx + 1 ;
59+ }
60+ bool containsSingleIndex () const { return size () == 1 ; }
61+
62+ // / Collects indices that do not overlap between this and another range.
63+ ReassociationIndices
64+ getNonOverlappingIndicesWith (ReassociationIndexRange &rhs) const {
65+ if (rightIdx < rhs.leftIdx ) {
66+ // The intervals do not overlap - concatenate the indices from both.
67+ auto jointFullIndices = getFullIndices ();
68+ jointFullIndices.append (rhs.getFullIndices ());
69+ return jointFullIndices;
70+ }
71+ ReassociationIndices result;
72+ // Handle the chunk left of the overlapping range.
73+ int64_t leftStart = std::min (leftIdx, rhs.leftIdx );
74+ int64_t leftEnd = std::max (leftIdx, rhs.leftIdx );
75+ llvm::append_range (result, llvm::seq (leftStart, leftEnd));
76+ // Handle the chunk right of the overlapping range. Symmetrically, we should
77+ // skip the edge of the overlap AND include the rightmost index.
78+ int64_t rightStart = std::min (rightIdx, rhs.rightIdx ) + 1 ;
79+ int64_t rightEnd = std::max (rightIdx, rhs.rightIdx );
80+ if (rightStart < rightEnd)
81+ llvm::append_range (result, llvm::seq_inclusive (rightStart, rightEnd));
82+ return result;
83+ }
84+
85+ // / Converts the range into ReassociationIndices.
86+ ReassociationIndices getFullIndices () const {
87+ ReassociationIndices result;
88+ for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
89+ result.push_back (idx);
90+ }
91+ return result;
92+ }
93+ };
94+ } // namespace
95+
96+ // / Starting from `sourceStartIdx`, searches `sourceShape` for the first
97+ // / sequence that can be collapsed into a dynamic dimension (at least one must
98+ // / be present in the source).
99+ // / By default, lazily returns once the first dynamic dimension has been found.
100+ // / Setting `matchGreedily` as `true` will also mark all subsequent
101+ // / source dimensions for collapsing into the target.
102+ static FailureOr<ReassociationIndexRange>
103+ findReassociationRangeForDynamicDim (ArrayRef<int64_t > sourceShape,
104+ int64_t sourceStartIdx,
105+ bool matchGreedily = false ) {
106+ const unsigned numSourceDims = sourceShape.size ();
107+ ReassociationIndexRange sourceShapeAsRange{0 , numSourceDims - 1 };
108+ std::optional<ReassociationIndexRange> resultRange = std::nullopt ;
109+
110+ ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
111+ for (; iterationRange.isInRange (sourceShapeAsRange);
112+ iterationRange.rightIdx ++) {
113+ int64_t sourceSize = sourceShape[iterationRange.rightIdx ];
114+ if (sourceSize == ShapedType::kDynamic ) {
115+ resultRange = iterationRange;
116+ break ;
117+ }
118+ }
119+ if (!resultRange)
120+ return failure ();
121+ if (matchGreedily)
122+ resultRange->rightIdx = sourceShapeAsRange.rightIdx ;
123+ return *resultRange;
124+ }
39125
40- ReassociationIndices currIndices;
/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
/// sequence of static dimensions such that their product matches `targetSize`.
/// By default, lazily returns once the product matches the target size. Setting
/// `matchGreedily` as `true` will append all neighboring unit dimensions
/// (dimensions of 1) to the match.
static FailureOr<ReassociationIndexRange>
findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
                              int64_t sourceStartIdx, int64_t targetSize,
                              bool matchGreedily = false) {
  const unsigned numSourceDims = sourceShape.size();
  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
  std::optional<ReassociationIndexRange> resultRange = std::nullopt;

  // Sliding-window search: `iterationRange` is the candidate window and
  // `prodOfCollapsedDims` the running product of its (static) sizes.
  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
  int64_t prodOfCollapsedDims = 1;
  while (iterationRange.isInRange(sourceShapeAsRange)) {
    int64_t sourceSize = sourceShape[iterationRange.rightIdx];
    if (sourceSize == ShapedType::kDynamic) {
      // Reassociation for a static dim cannot include a dynamic dim. Reset
      // induction variables to essentially restart the loop from the next
      // source dimension.
      prodOfCollapsedDims = 1;
      iterationRange = {iterationRange.rightIdx + 1,
                        iterationRange.rightIdx + 1};
      continue;
    }
    prodOfCollapsedDims *= sourceSize;
    // If the target size has been exceeded without matching, we need to shift
    // the range start right. From the start of the range, roll back the
    // multiplication until the target size exceeds the product again.
    // The window is never shrunk below a single index.
    while (prodOfCollapsedDims > targetSize &&
           !iterationRange.containsSingleIndex()) {
      int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
      prodOfCollapsedDims /= frontSourceSize;
      // Shrink the range rightwards
      iterationRange.leftIdx++;
    }
    // We could've reached the target size with the current dimension,
    // also as a result of the above shift to right.
    if (prodOfCollapsedDims == targetSize) {
      resultRange = iterationRange;
      break;
    }
    // Increment the iteration range
    iterationRange.rightIdx++;
  }
  if (!resultRange)
    return failure();
  if (matchGreedily) {
    // We now want to collect all unit dimensions directly after the target
    // product match. Advance the iterator to avoid OOB when the product match
    // happens at the last element.
    iterationRange.rightIdx++;
    while (iterationRange.isInRange(sourceShapeAsRange) &&
           sourceShape[iterationRange.rightIdx] == 1) {
      resultRange = iterationRange;
      iterationRange.rightIdx++;
    }
  }
  return *resultRange;
}
48187
49- int64_t currTargetShape = targetShape[targetDim];
50- while (sourceDim < (sourceShape.size () - 1 ) &&
51- sourceShape[sourceDim] != ShapedType::kDynamic &&
52- prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
53- prodOfCollapsedDims *= sourceShape[sourceDim];
54- currIndices.push_back (sourceDim++);
/// Attempts to find a valid collapsing reassociation of `sourceShape` into
/// `targetShape` through a simple traversal. If successful, an array of source
/// index ranges is returned, correspondingly to each dimension in the target
/// shape. The resulting indices shall fully cover the `sourceShape` without
/// overlaps.
///
/// The algorithm is essentially a lazy one, searching for non-greedy matches -
/// it will only yield a greedy match for the last target dimension.
/// FIXME: The algorithm can only backtrack when it needs to append an offset
/// for a static target dimension to the preceding dynamic one (this retains the
/// linear complexity). As feasible, consider adding further backtracking
/// routines to enable more reassociations, e.g.:
/// - ?x2x?x2 into ?x2
static FailureOr<SmallVector<ReassociationIndexRange>>
findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
                                   ArrayRef<int64_t> targetShape) {
  unsigned numSourceDims = sourceShape.size(),
           numTargetDims = targetShape.size();
  assert(numSourceDims > numTargetDims);
  // NOTE(review): a non-empty `targetShape` is also expected - the
  // `reassocRanges.back()` below would be UB otherwise; the public caller
  // handles rank-0 targets before reaching this helper.
  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};

  SmallVector<ReassociationIndexRange> reassocRanges;
  reassocRanges.reserve(numTargetDims);
  // We'll iterate in strides of 2 to enable pseudo-backtracking for simple
  // cases, e.g.:
  // - ?x2x3x5 into ?x15
  std::optional<int64_t> prevTargetSize = std::nullopt;
  for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
       targetDimIdx < numTargetDims; ++targetDimIdx) {
    int64_t targetSize = targetShape[targetDimIdx];
    // Simply check if there are any subsequent target dimensions left - if not,
    // the match must be made greedily.
    bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
    FailureOr<ReassociationIndexRange> sourceRange;
    if (targetSize == ShapedType::kDynamic) {
      sourceRange = findReassociationRangeForDynamicDim(
          sourceShape, sourceDimIdx, shouldMatchGreedily);
    } else {
      sourceRange = findReassociationRangeForSize(
          sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
    }

    // Run sanity checks on the returned index range.
    if (failed(sourceRange) || failed(sourceRange->verify()) ||
        !sourceRange->isInRange(sourceShapeAsRange))
      return failure();
    if (sourceRange->leftIdx > sourceDimIdx) {
      // If some source dimensions had to be skipped in order to find a match,
      // they must be collapsed into the directly preceding dynamic dimension.
      if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
        return failure();
      reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
    }

    // Store the gathered information as required for the next iteration.
    prevTargetSize = targetSize;
    sourceDimIdx = sourceRange->rightIdx + 1;
    reassocRanges.push_back(*sourceRange);
  }
  // Fail if the source shape wasn't a full match for the target shape. We only
  // need to check the last recorded index - any other gaps should have been
  // mended by the main loop.
  if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
    return failure();
  return reassocRanges;
}
254+
255+ // / A variant of `findReassociationRangesForCollapse(...)` that can also scan
256+ // / the shapes right-to-left.
257+ static FailureOr<SmallVector<ReassociationIndexRange>>
258+ findReassociationRangesForCollapse (ArrayRef<int64_t > sourceShape,
259+ ArrayRef<int64_t > targetShape,
260+ bool iterateRightToLeft) {
261+ if (!iterateRightToLeft)
262+ return findReassociationRangesForCollapse (sourceShape, targetShape);
263+ // NB: To iterate right-to-left, we currently reverse the shapes and then
264+ // reverse the result back. The reversed shapes must not be temporary, as
265+ // we're passing through an ArrayRef.
266+ // FIXME: It would be preferable to avoid the expensive copies. At the moment,
267+ // this approach is chosen for readability of the main implementation.
268+ std::vector<int64_t > sourceToReverse = sourceShape.vec (),
269+ targetToReverse = targetShape.vec ();
270+ std::reverse (sourceToReverse.begin (), sourceToReverse.end ());
271+ std::reverse (targetToReverse.begin (), targetToReverse.end ());
272+ auto invertedRanges =
273+ findReassociationRangesForCollapse (sourceToReverse, targetToReverse);
274+ if (failed (invertedRanges))
275+ return failure ();
276+ SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges;
277+ unsigned numSourceDims = sourceShape.size ();
278+ // We have received the ranges for inverted shapes. Now we have to invert
279+ // the ranges back to correspond with the original source shape.
280+ for (auto &range : rangesToInvert) {
281+ int64_t invLeftIdx = range.leftIdx , invRightIdx = range.rightIdx ;
282+ range.leftIdx = numSourceDims - 1 - invRightIdx;
283+ range.rightIdx = numSourceDims - 1 - invLeftIdx;
284+ }
285+ // Also invert the ordering of the ranges to correspond with the original
286+ // target shape.
287+ std::reverse (rangesToInvert.begin (), rangesToInvert.end ());
288+ return rangesToInvert;
289+ }
290+
291+ std::optional<SmallVector<ReassociationIndices>>
292+ mlir::getReassociationIndicesForCollapse (ArrayRef<int64_t > sourceShape,
293+ ArrayRef<int64_t > targetShape) {
294+ unsigned numSourceDims = sourceShape.size (),
295+ numTargetDims = targetShape.size ();
296+ // We're supposed to search for a collapsing reassociation. If the sizes
297+ // match, there's no actual collapsing taking place - it's either a no-op or a
298+ // `tensor.reshape`-style reassociation (that would be beyond the scope of
299+ // this utility).
300+ if (numSourceDims <= numTargetDims)
301+ return std::nullopt ;
302+ // Early handling for scalar target types.
303+ if (numTargetDims == 0 ) {
304+ ReassociationIndices allSourceIndices;
305+ allSourceIndices.reserve (numSourceDims);
306+ for (unsigned sourceDimIdx = 0 ; sourceDimIdx < numSourceDims;
307+ ++sourceDimIdx) {
308+ int64_t sourceSize = sourceShape[sourceDimIdx];
309+ // All source dimensions must be unit or dynamic.
310+ if (sourceSize != 1 && sourceSize != ShapedType::kDynamic )
311+ return std::nullopt ;
312+ allSourceIndices.push_back (sourceDimIdx);
313+ }
314+ return SmallVector<ReassociationIndices>{allSourceIndices};
315+ }
316+
317+ // Collect source ranges by iterating over the target shape left-to-right.
318+ FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
319+ findReassociationRangesForCollapse (sourceShape, targetShape);
320+ if (failed (maybeForwardRanges))
321+ return std::nullopt ;
322+ auto &ranges = *maybeForwardRanges;
323+ // Now do the same in reverse. We need to get another valid reassociation
324+ // through some other strategy, and then compare the results in order to
325+ // disambiguate mixed subshapes, such as:
326+ // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
327+ // This leads us to lose some of the reassociation opportunities that can only
328+ // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
329+ // backtracking, the algorithm will fail right-to-left. However, this is the
330+ // best way to preserve correctness.
331+ FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
332+ findReassociationRangesForCollapse (sourceShape, targetShape,
333+ /* iterateRightToLeft=*/ true );
334+ if (failed (maybeReverseRanges))
335+ return std::nullopt ;
336+ auto &reverseRanges = *maybeReverseRanges;
337+
338+ if (ranges.size () != numTargetDims || reverseRanges.size () != numTargetDims)
82339 return std::nullopt ;
83- // Process any remaining entries in the source shape. They all need to be
84- // 1 or dynamic.
85- for (; sourceDim < sourceShape.size (); sourceDim++) {
86- if (sourceShape[sourceDim] != ShapedType::kDynamic &&
87- sourceShape[sourceDim] != 1 )
88- return std::nullopt ;
89- // The map is empty when the target type is a scalar.
90- if (!reassociationMap.empty ())
91- reassociationMap.back ().push_back (sourceDim);
340+ // Now we can check for ambiguity of each target dimension's reassociation. If
341+ // successful, we put the full indices into our result map for the target
342+ // shape.
343+ SmallVector<ReassociationIndices> reassociationMap (numTargetDims);
344+ for (unsigned targetDimIdx = 0 ; targetDimIdx < numTargetDims;
345+ ++targetDimIdx) {
346+ ReassociationIndexRange &range = ranges[targetDimIdx];
347+ ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
348+ // Get non-overlapping indices between the ranges
349+ ReassociationIndices nonMatchingIndices =
350+ range.getNonOverlappingIndicesWith (reverseRange);
351+ // Unit dimensions can be collapsed wherever - this is the only ambiguity
352+ // that we allow.
353+ for (int64_t sourceDimIdx : nonMatchingIndices) {
354+ if (sourceShape[sourceDimIdx] != 1 )
355+ return std::nullopt ;
356+ }
357+ reassociationMap[targetDimIdx] = range.getFullIndices ();
92358 }
93359 return reassociationMap;
94360}
0 commit comments