
Commit ca4c5c5

IanWood1 authored and weidel-p committed
[Codegen] Handle multiple dyn dims in tensor load pattern (iree-org#22328)

Fix a compile error when `FoldExpandShapeIntoInterfaceTensorLoad` tries to fold a `tensor.expand_shape` with multiple dynamic dims into an `iree_tensor_ext.dispatch.tensor.load` op. When the expanded output shape cannot be inferred, the pattern now falls back to the output-shape SSA values of the expand shape.

Fixes iree-org#22324

Signed-off-by: Ian Wood <[email protected]>
Signed-off-by: Philipp <[email protected]>
1 parent 8abb0cd commit ca4c5c5
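
For orientation, here is a minimal before/after sketch of the IR this change targets. It is not part of the commit: the value names %n, %sz0, and %sz1 are invented, and the subspan operands are elided. Before the fold, the load yields a 1-D tensor that is expanded with two dynamic dims in a single reassociation group, so the expanded shape cannot be computed from %n alone:

    %s = hal.interface.binding.subspan ...
        : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?xf32>>{%n}
    %t = iree_tensor_ext.dispatch.tensor.load %s, offsets = [0], sizes = [%n], strides = [1]
        : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?xf32>>{%n} -> tensor<?xf32>
    %e = tensor.expand_shape %t [[0, 1]] output_shape [%sz0, %sz1]
        : tensor<?xf32> into tensor<?x?xf32>

After the fold, the pattern reuses the output_shape operands to describe the expanded subspan and load directly, mirroring the CHECK lines of the new test below:

    %s = hal.interface.binding.subspan ...
        : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf32>>{%sz0, %sz1}
    %e = iree_tensor_ext.dispatch.tensor.load %s, offsets = [0, 0], sizes = [%sz0, %sz1], strides = [1, 1]
        : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf32>>{%sz0, %sz1} -> tensor<?x?xf32>

This only works when %sz0 and %sz1 are available before the subspan op, which is what the new moveValueDefinitions check in the patch enforces.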

2 files changed: +60, -1

compiler/src/iree/compiler/Codegen/Common/ReshapePatterns.cpp

Lines changed: 12 additions & 1 deletion
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 #define DEBUG_TYPE "iree-codegen-reshape-patterns"
 
@@ -207,12 +208,22 @@ struct FoldExpandShapeIntoInterfaceTensorLoad
     auto currStaticDims = loadOp.getType().getShape();
     auto currOfrDynamicDims =
         mlir::getMixedValues(currStaticDims, currDynamicDims, rewriter);
+
+    // Try to infer the expanded shape. This only works if each reassociation
+    // has <=1 dyn dim.
     std::optional<SmallVector<OpFoldResult>> expandedDims =
         mlir::inferExpandShapeOutputShape(
             rewriter, subspanOp.getLoc(), reshapeOp.getType(),
             reshapeOp.getReassociationIndices(), currOfrDynamicDims);
     if (!expandedDims) {
-      return reshapeOp.emitOpError("failure in expanded shape");
+      // If inference fails, try to use the reshape's SSA values.
+      if (failed(mlir::moveValueDefinitions(
+              rewriter, reshapeOp.getOutputShape(), subspanOp))) {
+        return rewriter.notifyMatchFailure(reshapeOp,
+                                           "could not infer output shape or "
+                                           "move SSA values before subspan op");
+      }
+      expandedDims = reshapeOp.getMixedOutputShape();
     }
 
     auto tensorAccess =
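
The comment in the hunk above is the crux: mlir::inferExpandShapeOutputShape recovers a dynamic output extent by dividing the group's input extent by the product of the remaining static extents, so it fails as soon as one reassociation group contains two or more dynamic dims. A hedged illustration with invented names (%t of dynamic size %n, plus %d, %d0, %d1), not taken from the patch:

    // Inferable: the single dynamic extent in the group must be %n floordiv 32.
    %a = tensor.expand_shape %t [[0, 1, 2]] output_shape [2, %d, 16]
        : tensor<?xf32> into tensor<2x?x16xf32>
    // Not inferable: %n = %d0 * %d1 has no unique factorization, so the pattern
    // now falls back to the output_shape operands %d0 and %d1 instead.
    %b = tensor.expand_shape %t [[0, 1]] output_shape [%d0, %d1]
        : tensor<?xf32> into tensor<?x?xf32>

The fallback is only legal if the output_shape values can be made available before the subspan op; mlir::moveValueDefinitions attempts that hoist, and the pattern now reports a match failure instead of a hard error when it cannot.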

compiler/src/iree/compiler/Codegen/Common/test/fold_reshape_into_interface_tensor.mlir

Lines changed: 48 additions & 0 deletions
@@ -47,6 +47,54 @@ func.func @fold_expand_into_loads_dynamic() -> tensor<2x?x16x32xf32> {
 // CHECK-SAME:   offsets = [0, 0, 0, 0], sizes = [2, %[[SHAPE]], 16, 32], strides = [1, 1, 1, 1]
 // CHECK-SAME:   !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x?x16x32xf32>>{%[[SHAPE]]}
 
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 3, bindings = [
+  #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">], flags = Indirect>
+func.func @fold_expand_into_loads_fully_dynamic() -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+  %1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
+  %2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
+  %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+      flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?xf32>>{%0}
+  %4 = iree_tensor_ext.dispatch.tensor.load %3, offsets = [0], sizes = [%0], strides = [1]
+      : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?xf32>>{%0} -> tensor<?xf32>
+  %5 = tensor.expand_shape %4 [[0, 1]] output_shape [%1, %2] : tensor<?xf32> into tensor<?x?xf32>
+  return %5 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @fold_expand_into_loads_fully_dynamic()
+// CHECK-DAG:   %[[CONST0:.+]] = hal.interface.constant.load {{.*}} ordinal(1)
+// CHECK-DAG:   %[[CONST1:.+]] = hal.interface.constant.load {{.*}} ordinal(2)
+// CHECK:       %[[SUBSPAN:.+]] = hal.interface.binding.subspan
+// CHECK-SAME:    !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf32>>{%[[CONST0]], %[[CONST1]]}
+// CHECK:       %[[LOAD:.+]] = iree_tensor_ext.dispatch.tensor.load %[[SUBSPAN]]
+// CHECK-SAME:    offsets = [0, 0], sizes = [%[[CONST0]], %[[CONST1]]]
+// CHECK-SAME:    !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf32>>{%[[CONST0]], %[[CONST1]]}
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 2, bindings = [
+  #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">], flags = Indirect>
+func.func @no_fold_expand_into_loads_fully_dynamic() -> tensor<?x?xindex> {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+  %1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
+  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+      flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?xindex>>{%0}
+  %3 = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0], sizes = [%0], strides = [1]
+      : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?xindex>>{%0} -> tensor<?xindex>
+  %4 = tensor.extract %3[%c0] : tensor<?xindex>
+  %5 = tensor.expand_shape %3 [[0, 1]] output_shape [%1, %4] : tensor<?xindex> into tensor<?x?xindex>
+  return %5 : tensor<?x?xindex>
+}
+// This case cannot be folded because the expanded sizes depend on the tensor
+// itself, so the size cannot be known before the load.
+
+// CHECK-LABEL: func @no_fold_expand_into_loads_fully_dynamic()
+// CHECK:       tensor.expand_shape
+
+
 // -----
 
 #pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [