@@ -1022,6 +1022,51 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
10221022 }
10231023};
10241024
1025+ struct VectorToElementOpConvert final
1026+ : OpConversionPattern<vector::ToElementsOp> {
1027+ using OpConversionPattern::OpConversionPattern;
1028+
1029+ LogicalResult
1030+ matchAndRewrite (vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1031+ ConversionPatternRewriter &rewriter) const override {
1032+
1033+ SmallVector<Value> results (toElementsOp->getNumResults ());
1034+ Location loc = toElementsOp.getLoc ();
1035+
1036+ // Input vectors of size 1 are converted to scalars by the type converter.
1037+ // We cannot use `spirv::CompositeExtractOp` directly in this case.
1038+ // For a scalar source, the result is just the scalar itself.
1039+ if (isa<spirv::ScalarType>(adaptor.getSource ().getType ())) {
1040+ results[0 ] = adaptor.getSource ();
1041+ rewriter.replaceOp (toElementsOp, results);
1042+ return success ();
1043+ }
1044+
1045+ Type srcElementType = toElementsOp.getElements ().getType ().front ();
1046+ Type elementType = getTypeConverter ()->convertType (srcElementType);
1047+ if (!elementType)
1048+ return rewriter.notifyMatchFailure (
1049+ toElementsOp,
1050+ llvm::formatv (" failed to convert element type '{0}' to SPIR-V" ,
1051+ srcElementType));
1052+
1053+ for (auto [idx, element] : llvm::enumerate (toElementsOp.getElements ())) {
1054+ // Create an CompositeExtract operation only for results that are not
1055+ // dead.
1056+ if (element.use_empty ())
1057+ continue ;
1058+
1059+ Value result = rewriter.create <spirv::CompositeExtractOp>(
1060+ loc, elementType, adaptor.getSource (),
1061+ rewriter.getI32ArrayAttr ({static_cast <int32_t >(idx)}));
1062+ results[idx] = result;
1063+ }
1064+
1065+ rewriter.replaceOp (toElementsOp, results);
1066+ return success ();
1067+ }
1068+ };
1069+
10251070} // namespace
10261071#define CL_INT_MAX_MIN_OPS \
10271072 spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -1039,8 +1084,8 @@ void mlir::populateVectorToSPIRVPatterns(
10391084 VectorExtractElementOpConvert, VectorExtractOpConvert,
10401085 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
10411086 VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1042- VectorInsertElementOpConvert, VectorInsertOpConvert ,
1043- VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1087+ VectorToElementOpConvert, VectorInsertElementOpConvert ,
1088+ VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
10441089 VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
10451090 VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
10461091 VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
0 commit comments