@@ -273,7 +273,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
273273 arith::ConstantOp constOp, OpAdaptor adaptor,
274274 ConversionPatternRewriter &rewriter) const {
275275 auto srcType = constOp.getType ().dyn_cast <ShapedType>();
276- if (!srcType)
276+ if (!srcType || srcType. getNumElements () == 1 )
277277 return failure ();
278278
279279 // arith.constant should only have vector or tenor types.
@@ -358,16 +358,25 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
358358 arith::ConstantOp constOp, OpAdaptor adaptor,
359359 ConversionPatternRewriter &rewriter) const {
360360 Type srcType = constOp.getType ();
361+ if (auto shapedType = srcType.dyn_cast <ShapedType>()) {
362+ if (shapedType.getNumElements () != 1 )
363+ return failure ();
364+ srcType = shapedType.getElementType ();
365+ }
361366 if (!srcType.isIntOrIndexOrFloat ())
362367 return failure ();
363368
369+ Attribute cstAttr = constOp.getValue ();
370+ if (cstAttr.getType ().isa <ShapedType>())
371+ cstAttr = cstAttr.cast <DenseElementsAttr>().getSplatValue <Attribute>();
372+
364373 Type dstType = getTypeConverter ()->convertType (srcType);
365374 if (!dstType)
366375 return failure ();
367376
368377 // Floating-point types.
369378 if (srcType.isa <FloatType>()) {
370- auto srcAttr = constOp. getValue () .cast <FloatAttr>();
379+ auto srcAttr = cstAttr .cast <FloatAttr>();
371380 auto dstAttr = srcAttr;
372381
373382 // Floating-point types not supported in the target environment are all
@@ -386,7 +395,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
386395 if (srcType.isInteger (1 )) {
387396 // arith.constant can use 0/1 instead of true/false for i1 values. We need
388397 // to handle that here.
389- auto dstAttr = convertBoolAttr (constOp. getValue () , rewriter);
398+ auto dstAttr = convertBoolAttr (cstAttr , rewriter);
390399 if (!dstAttr)
391400 return failure ();
392401 rewriter.replaceOpWithNewOp <spirv::ConstantOp>(constOp, dstType, dstAttr);
@@ -395,7 +404,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
395404
396405 // IndexType or IntegerType. Index values are converted to 32-bit integer
397406 // values when converting to SPIR-V.
398- auto srcAttr = constOp. getValue () .cast <IntegerAttr>();
407+ auto srcAttr = cstAttr .cast <IntegerAttr>();
399408 auto dstAttr =
400409 convertIntegerAttr (srcAttr, dstType.cast <IntegerType>(), rewriter);
401410 if (!dstAttr)
0 commit comments