
Commit b68730d

set viewShapeInts[dim] = size; in DecomposeComplexOps
1 parent 5d374ba commit b68730d

3 files changed: +115 -2 lines changed


lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 1 addition & 2 deletions
@@ -12855,8 +12855,7 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
     Value index = rewriter.create<Torch::AtenArangeOp>(
         loc, arangeType, end, cstNone, cstNone, cstNone, cstNone);
 
-    // Set the current dimension to -1 for broadcasting
-    viewShapeInts[dim] = -1;
+    viewShapeInts[dim] = size;
     viewShapeListElems[dim] = cstMinusOne;
 
     Value viewShapeList = rewriter.create<Torch::PrimListConstructOp>(
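
For context, the pattern views an arange over one dimension so it can broadcast against the other index tensors. The runtime view shape still keeps -1 at the target dim (the cstMinusOne list element, untouched by the diff), so the view op continues to infer that size; the fix is to the parallel vector of static sizes, viewShapeInts, which apparently recorded -1 and thereby marked the dimension unknown instead of using the known size. A minimal PyTorch sketch of the runtime step, with size, dim, and rank as hypothetical stand-ins for the pattern's values:

import torch

# Hypothetical stand-ins for the pattern's `size`, `dim`, and input rank.
size, dim, rank = 3, 1, 3

index = torch.arange(size)

# The runtime view shape keeps -1 at `dim` (cstMinusOne in the pattern),
# so view() infers the size there; every other dim is 1 for broadcasting.
view_shape = [1] * rank
view_shape[dim] = -1
index = index.view(view_shape)

assert index.shape == (1, size, 1)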

projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py

Lines changed: 40 additions & 0 deletions
@@ -1281,6 +1281,46 @@ def UnflattenIntNegativeOneSizeStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 12, 3))
 
 
+class UnflattenIntDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, 12], torch.float32, True),
+        ]
+    )
+    def forward(self, inputs):
+        return torch.ops.aten.unflatten(inputs, 1, [3, 4])
+
+
+@register_test_case(module_factory=lambda: UnflattenIntDynamicModule())
+def UnflattenIntDynamicModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 12))
+
+
+class UnflattenIntDynamicWithInferredSizeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, 20], torch.float32, True),
+        ]
+    )
+    def forward(self, inputs):
+        return torch.ops.aten.unflatten(inputs, 1, [4, -1])
+
+
+@register_test_case(module_factory=lambda: UnflattenIntDynamicWithInferredSizeModule())
+def UnflattenIntDynamicWithInferredSizeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 20))
+
+
 # ==============================================================================
 
 
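The two new e2e cases cover aten.unflatten on an input with a dynamic leading dimension, once with explicit split sizes and once with an inferred size. The same calls in plain PyTorch, for reference:

import torch

x = torch.rand(2, 12)
y = torch.ops.aten.unflatten(x, 1, [3, 4])   # explicit sizes: 12 == 3 * 4
assert y.shape == (2, 3, 4)

z = torch.rand(3, 20)
w = torch.ops.aten.unflatten(z, 1, [4, -1])  # -1 is inferred: 20 // 4 == 5
assert w.shape == (3, 4, 5)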

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func.func @torch.aten.unflatten.int$static
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.expand_shape
+// CHECK: torch_c.from_builtin_tensor
+func.func @torch.aten.unflatten.int$static(%arg0: !torch.vtensor<[2,6,4],f32>) -> !torch.vtensor<[2,2,3,4],f32> {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[2,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,2,3,4],f32>
+  return %1 : !torch.vtensor<[2,2,3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.unflatten.int$negative_dim
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.expand_shape
+// CHECK: torch_c.from_builtin_tensor
+func.func @torch.aten.unflatten.int$negative_dim(%arg0: !torch.vtensor<[2,6,4],f32>) -> !torch.vtensor<[2,2,3,4],f32> {
+  %int-2 = torch.constant.int -2
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.unflatten.int %arg0, %int-2, %0 : !torch.vtensor<[2,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,2,3,4],f32>
+  return %1 : !torch.vtensor<[2,2,3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.unflatten.int$inferred_size
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.expand_shape
+// CHECK: torch_c.from_builtin_tensor
+func.func @torch.aten.unflatten.int$inferred_size(%arg0: !torch.vtensor<[3,12],f32>) -> !torch.vtensor<[3,2,6],f32> {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int-1 = torch.constant.int -1
+  %0 = torch.prim.ListConstruct %int2, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[3,12],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[3,2,6],f32>
+  return %1 : !torch.vtensor<[3,2,6],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.unflatten.int$dynamic_input
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.expand_shape
+// CHECK: torch_c.from_builtin_tensor
+func.func @torch.aten.unflatten.int$dynamic_input(%arg0: !torch.vtensor<[?,6],f32>) -> !torch.vtensor<[?,2,3],f32> {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[?,6],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,2,3],f32>
+  return %1 : !torch.vtensor<[?,2,3],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.unflatten.int$two_dynamic_dims
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.from_elements
+// CHECK: tensor.reshape
+// CHECK: torch_c.from_builtin_tensor
+func.func @torch.aten.unflatten.int$two_dynamic_dims(%arg0: !torch.vtensor<[?,12],f32>) -> !torch.vtensor<[?,?,?],f32> {
+  %int1 = torch.constant.int 1
+  %2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,12],f32>, !torch.int -> !torch.int
+  %0 = torch.prim.ListConstruct %2, %2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[?,12],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
+  return %1 : !torch.vtensor<[?,?,?],f32>
+}
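
The first four cases expect the lowering to produce a tensor.expand_shape, whose split sizes are statically resolvable in each of them. The last case makes both target sizes non-constant at compile time, so its CHECK lines instead expect the fallback path: the target shape is assembled with tensor.from_elements and applied through a runtime tensor.reshape. A rough PyTorch analogue of that fallback, with d0 and d1 as hypothetical values only known at runtime:

import torch

x = torch.rand(2, 12)
d0, d1 = 3, 4                   # hypothetical sizes computed at runtime
shape = (x.shape[0], d0, d1)    # assembled element-wise, cf. tensor.from_elements
y = x.reshape(shape)            # runtime reshape, cf. tensor.reshape
assert y.shape == (2, 3, 4)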
