@@ -782,30 +782,42 @@ func.func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
782782
783783// -----
784784
785- // CHECK-LABEL: fold_extract_shapecast_negative
786- // CHECK: %[[V :.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32 >
787- // CHECK: %[[R:.*]] = vector.extract %[[V ]][1 ] : vector<4x2xf32> from vector<2x4x2xf32 >
788- // CHECK: return %[[R]] : vector<4x2xf32>
789- func.func @fold_extract_shapecast_negative (%arg0 : vector <16 x f32 >) -> vector < 4 x 2 x f32 > {
790- %0 = vector.shape_cast %arg0 : vector <16 x f32 > to vector <2 x 4 x 2 x f32 >
791- %r = vector.extract %0 [1 ] : vector < 4 x 2 x f32 > from vector <2 x 4 x 2 x f32 >
792- return %r : vector < 4 x 2 x f32 >
785+ // CHECK-LABEL: fold_extract_shapecast_0d_result
786+ // CHECK-SAME: %[[IN :.*]]: vector<1x1x1xf32 >
787+ // CHECK: %[[R:.*]] = vector.extract %[[IN ]][0, 0, 0 ] : f32 from vector<1x1x1xf32 >
788+ // CHECK: return %[[R]] : f32
789+ func.func @fold_extract_shapecast_0d_result (%arg0 : vector <1 x 1 x 1 x f32 >) -> f32 {
790+ %0 = vector.shape_cast %arg0 : vector <1 x 1 x 1 x f32 > to vector <f32 >
791+ %r = vector.extract %0 [] : f32 from vector <f32 >
792+ return %r : f32
793793}
794794
795795// -----
796796
797- // CHECK-LABEL: dont_fold_0d_extract_shapecast
798- // CHECK: %[[V :.*]] = vector.shape_cast %{{.*}} : vector<f32> to vector<1xf32 >
799- // CHECK: %[[R:.*]] = vector.extract %[[V ]][0 ] : f32 from vector<1xf32 >
797+ // CHECK-LABEL: fold_extract_shapecast_0d_source
798+ // CHECK-SAME: %[[IN :.*]]: vector<f32>
799+ // CHECK: %[[R:.*]] = vector.extract %[[IN ]][] : f32 from vector<f32 >
800800// CHECK: return %[[R]] : f32
801- func.func @dont_fold_0d_extract_shapecast (%arg0 : vector <f32 >) -> f32 {
801+ func.func @fold_extract_shapecast_0d_source (%arg0 : vector <f32 >) -> f32 {
802802 %0 = vector.shape_cast %arg0 : vector <f32 > to vector <1 xf32 >
803803 %r = vector.extract %0 [0 ] : f32 from vector <1 xf32 >
804804 return %r : f32
805805}
806806
807807// -----
808808
809+ // CHECK-LABEL: fold_extract_shapecast_negative
810+ // CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
811+ // CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
812+ // CHECK: return %[[R]] : vector<4x2xf32>
813+ func.func @fold_extract_shapecast_negative (%arg0 : vector <16 xf32 >) -> vector <4 x2 xf32 > {
814+ %0 = vector.shape_cast %arg0 : vector <16 xf32 > to vector <2 x4 x2 xf32 >
815+ %r = vector.extract %0 [1 ] : vector <4 x2 xf32 > from vector <2 x4 x2 xf32 >
816+ return %r : vector <4 x2 xf32 >
817+ }
818+
819+ // -----
820+
809821// CHECK-LABEL: fold_extract_shapecast_to_shapecast
810822// CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
811823// CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
0 commit comments