@@ -649,6 +649,33 @@ func.func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2:
649649
650650// -----
651651
#map = affine_map<(d0, d1) -> (d0, d1)>
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
// Static-shape inference must preserve the sparse encoding when casting the
// dynamically-shaped operands to the static shape implied by the output.
// CHECK-DAG: #[[$SPARSE:.+]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
// CHECK-LABEL: func @static_shape_inference_with_encoding(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
func.func @static_shape_inference_with_encoding(%arg0: tensor<?x?xf32, #sparse>, %arg1: tensor<?x?xf32>) -> tensor<3x4xf32> {
  %0 = tensor.empty() : tensor<3x4xf32>
  %1 = linalg.generic {
    indexing_maps = [#map, #map, #map],
    iterator_types = ["parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<?x?xf32, #sparse>, tensor<?x?xf32>)
    outs(%0 : tensor<3x4xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %2 = arith.addf %in, %in_0 : f32
    linalg.yield %2 : f32
  } -> tensor<3x4xf32>
  return %1 : tensor<3x4xf32>
  // The casts inserted on both inputs carry the static 3x4 shape; the sparse
  // input keeps its encoding through the cast.
  // CHECK:      %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32, #[[$SPARSE]]> to tensor<3x4xf32, #[[$SPARSE]]>
  // CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?xf32> to tensor<3x4xf32>
  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
  // CHECK-SAME:     ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<3x4xf32, #[[$SPARSE]]>, tensor<3x4xf32>)
  // CHECK-SAME:     outs({{.*}} : tensor<3x4xf32>)
}

// -----

652 679  // CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 1)>
653 681  // CHECK-LABEL: func @insert_pad_into_fill
654 682  // CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x?xf32>, %[[LOW0:.+]]: index, %[[LOW1:.+]]: index, %{{.+}}: index, %{{.+}}: index)
0 commit comments