|
| 1 | +// RUN: iree-opt %s --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-set-split-reduction-sizes))" --split-input-file > %t |
| 2 | +// RUN: FileCheck %s < %t |
| 3 | + |
| 4 | +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1 + d5, d2 + d6, d3)> |
| 5 | +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d0)> |
| 6 | +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> |
| 7 | +util.func public @conv_2d_chwn_chwf(%arg0: tensor<16x227x227x16xf32>, %arg1: tensor<16x225x225x64xf32>, %arg2: tensor<64x3x3x16xf32>) -> tensor<64x3x3x16xf32> { |
| 8 | + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x227x227x16xf32>, tensor<16x225x225x64xf32>) outs(%arg2 : tensor<64x3x3x16xf32>) { |
| 9 | + ^bb0(%in: f32, %in_3: f32, %out: f32): |
| 10 | + %12 = arith.mulf %in, %in_3 : f32 |
| 11 | + %13 = arith.addf %out, %12 : f32 |
| 12 | + linalg.yield %13 : f32 |
| 13 | + } -> tensor<64x3x3x16xf32> |
| 14 | + util.return %0 : tensor<64x3x3x16xf32> |
| 15 | +} |
| 16 | + |
| 17 | +// CHECK-LABEL: @conv_2d_chwn_chwf |
| 18 | +// CHECK: iree_linalg_ext.split_reduction = [1 : index, 225 : index, 225 : index] |
| 19 | + |
| 20 | +// ----- |
| 21 | + |
| 22 | +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> |
| 23 | +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> |
| 24 | +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> |
| 25 | +util.func public @no_split_conv_2d_nhwc_fhwc(%arg0: tensor<16x227x227x16xf32>, %arg1: tensor<64x3x3x16xf32>, %arg2: tensor<16x225x225x64xf32>) -> tensor<16x225x225x64xf32> { |
| 26 | + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x227x227x16xf32>, tensor<64x3x3x16xf32>) outs(%arg2 : tensor<16x225x225x64xf32>) { |
| 27 | + ^bb0(%in: f32, %in_0: f32, %out: f32): |
| 28 | + %3 = arith.mulf %in, %in_0 : f32 |
| 29 | + %4 = arith.addf %out, %3 : f32 |
| 30 | + linalg.yield %4 : f32 |
| 31 | + } -> tensor<16x225x225x64xf32> |
| 32 | + util.return %0 : tensor<16x225x225x64xf32> |
| 33 | +} |
| 34 | + |
| 35 | +// CHECK-LABEL: @no_split_conv_2d_nhwc_fhwc |
| 36 | +// CHECK-NOT: iree_linalg_ext.split_reduction |
| 37 | + |
| 38 | +// ----- |
| 39 | + |
| 40 | +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1 + d5, d2 + d6, d3)> |
| 41 | +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d0)> |
| 42 | +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> |
| 43 | +util.func public @no_split_large_N_F_sizes(%arg0: tensor<16x98x50x1024xf32>, %arg1: tensor<16x96x48x1024xf32>, %arg2: tensor<1024x3x3x1024xf32>) -> tensor<1024x3x3x1024xf32> { |
| 44 | + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x98x50x1024xf32>, tensor<16x96x48x1024xf32>) outs(%arg2 : tensor<1024x3x3x1024xf32>) { |
| 45 | + ^bb0(%in: f32, %in_3: f32, %out: f32): |
| 46 | + %12 = arith.mulf %in, %in_3 : f32 |
| 47 | + %13 = arith.addf %out, %12 : f32 |
| 48 | + linalg.yield %13 : f32 |
| 49 | + } -> tensor<1024x3x3x1024xf32> |
| 50 | + util.return %0 : tensor<1024x3x3x1024xf32> |
| 51 | +} |
| 52 | + |
| 53 | +// CHECK-LABEL: @no_split_large_N_F_sizes |
| 54 | +// CHECK-NOT: iree_linalg_ext.split_reduction |
| 55 | + |
| 56 | +// ----- |
| 57 | + |
| 58 | +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1 + d5, d2 + d6, d3)> |
| 59 | +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d0)> |
| 60 | +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> |
| 61 | +util.func public @no_split_small_H_W_sizes(%arg0: tensor<16x26x18x288xf32>, %arg1: tensor<16x24x16x288xf32>, %arg2: tensor<288x3x3x288xf32>) -> tensor<288x3x3x288xf32> { |
| 62 | + %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<16x26x18x288xf32>, tensor<16x24x16x288xf32>) outs(%arg2 : tensor<288x3x3x288xf32>) { |
| 63 | + ^bb0(%in: f32, %in_3: f32, %out: f32): |
| 64 | + %12 = arith.mulf %in, %in_3 : f32 |
| 65 | + %13 = arith.addf %out, %12 : f32 |
| 66 | + linalg.yield %13 : f32 |
| 67 | + } -> tensor<288x3x3x288xf32> |
| 68 | + util.return %0 : tensor<288x3x3x288xf32> |
| 69 | +} |
| 70 | + |
| 71 | +// CHECK-LABEL: @no_split_small_H_W_sizes |
| 72 | +// CHECK-NOT: iree_linalg_ext.split_reduction |
0 commit comments