@@ -260,6 +260,31 @@ struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
260260 }
261261};
262262
263+ struct SimplifyVecSplat : public OpRewritePattern <VecSplatOp> {
264+ using OpRewritePattern<VecSplatOp>::OpRewritePattern;
265+ LogicalResult matchAndRewrite (VecSplatOp op,
266+ PatternRewriter &rewriter) const override {
267+ mlir::Value splatValue = op.getValue ();
268+ auto constant =
269+ mlir::dyn_cast_if_present<cir::ConstantOp>(splatValue.getDefiningOp ());
270+ if (!constant)
271+ return mlir::failure ();
272+
273+ auto value = constant.getValue ();
274+ if (!mlir::isa_and_nonnull<cir::IntAttr>(value) &&
275+ !mlir::isa_and_nonnull<cir::FPAttr>(value))
276+ return mlir::failure ();
277+
278+ cir::VectorType resultType = op.getResult ().getType ();
279+ SmallVector<mlir::Attribute, 16 > elements (resultType.getSize (), value);
280+ auto constVecAttr = cir::ConstVectorAttr::get (
281+ resultType, mlir::ArrayAttr::get (getContext (), elements));
282+
283+ rewriter.replaceOpWithNewOp <cir::ConstantOp>(op, constVecAttr);
284+ return mlir::success ();
285+ }
286+ };
287+
263288// ===----------------------------------------------------------------------===//
264289// CIRSimplifyPass
265290// ===----------------------------------------------------------------------===//
@@ -275,7 +300,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
275300 patterns.add <
276301 SimplifyTernary,
277302 SimplifySelect,
278- SimplifySwitch
303+ SimplifySwitch,
304+ SimplifyVecSplat
279305 >(patterns.getContext ());
280306 // clang-format on
281307}
@@ -288,7 +314,7 @@ void CIRSimplifyPass::runOnOperation() {
288314 // Collect operations to apply patterns.
289315 llvm::SmallVector<Operation *, 16 > ops;
290316 getOperation ()->walk ([&](Operation *op) {
291- if (isa<TernaryOp, SelectOp, SwitchOp>(op))
317+ if (isa<TernaryOp, SelectOp, SwitchOp, VecSplatOp >(op))
292318 ops.push_back (op);
293319 });
294320
0 commit comments