@@ -86,42 +86,62 @@ LogicalResult getListFromTensor(Value value, SmallVector<OpFoldResult> &vals) {
                 getAsOpFoldResult(full.getFillValue()));
     return success();
   }
-  // TODO: Add a case for unsqueeze of a primnumtotensorscalarop?
+
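+  // An unsqueeze of the result of prim.NumToTensor.Scalar wraps a single
+  // scalar into a one-element tensor, e.g. (illustrative IR sketch):
+  //   %t = torch.prim.NumToTensor.Scalar %int5 : !torch.int -> !torch.vtensor<[],si64>
+  //   %u = torch.aten.unsqueeze %t, %int0 : ... -> !torch.vtensor<[1],si64>
+  // so the recovered list is just the wrapped scalar.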
+  if (auto unsqueeze = value.getDefiningOp<Torch::AtenUnsqueezeOp>()) {
+    Value usqSelf = unsqueeze.getSelf();
+    if (auto numToTensor =
+            usqSelf.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
+      vals.push_back(getAsOpFoldResult(numToTensor.getA()));
+      return success();
+    }
+  }
+
+  // A common rank 0 tensor producer
+  if (auto numToTensor =
+          value.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
+    vals.push_back(getAsOpFoldResult(numToTensor.getA()));
+    return success();
+  }
 
   // Last supported case: ValueTensorLiteralOp
   auto literalOp = value.getDefiningOp<Torch::ValueTensorLiteralOp>();
   if (!literalOp)
     return failure();
 
-  // Check the type. We make sure the type is not unsigned here before trying to
-  // materialize
+  // Check the type.
   auto ty = cast<ValueTensorType>(literalOp.getType());
   if (!ty.hasSizes() || ty.getSizes().size() > 1)
     return failure();
-  int64_t listSize = ty.getSizes().size() == 1 ? ty.getSizes().front() : 1;
+  // make sure the type is not unsigned here before trying to materialize
   auto intTy = dyn_cast_or_null<IntegerType>(ty.getDtype());
   if (!intTy || intTy.isUnsigned())
     return failure();
 
+  // if we have a rank 0 literal, we will be adding one element to the list
+  int64_t listSize = ty.getSizes().size() == 1 ? ty.getSizes().front() : 1;
+
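+  // Cap how many scalar values we are willing to materialize from a single
+  // literal.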
+  if (listSize > kMaxFold)
+    return failure();
+
+  // check for a splat or dense attr
   auto splattr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
   auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(literalOp.getValue());
 
   if (!splattr && !denseAttr)
     return failure();
 
+  // These are not mutually exclusive, so try splat first.
   if (splattr) {
     auto attr = splattr.getSplatValue<Attribute>();
     vals.resize((int64_t)vals.size() + listSize, attr);
+    return success();
   }
 
-  if (denseAttr && !splattr) {
-    for (auto e : denseAttr.getValues<Attribute>())
-      vals.push_back(e);
-  }
-
-  if ((int64_t)vals.size() != listSize)
+  // remaining case: denseAttr
+  if ((int64_t)denseAttr.getValues<Attribute>().size() != listSize)
     return failure();
-
+  for (auto e : denseAttr.getValues<Attribute>())
+    vals.push_back(e);
   return success();
 }
 
@@ -143,6 +163,45 @@ Value constructAtenTensorOpFromList(ImplicitLocOpBuilder b, mlir::Type resultTy,
 // [scalarOpsA -> ListA -> TensorA] -> OpB. Then OpB will be able to
 // getListFromTensor(A), and further propagate scalarization.
 
+namespace {
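+// Broadcasting a one-element integer tensor to a statically known 1-D shape
+// is the same as filling that shape with the element, e.g. (illustrative IR
+// sketch):
+//   %b = torch.aten.broadcast_to %single, %sizes
+// becomes an aten.full over %sizes with the element extracted from %single.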
+class PropagateAtenBroadcastToPattern
+    : public OpRewritePattern<AtenBroadcastToOp> {
+public:
+  using OpRewritePattern<AtenBroadcastToOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenBroadcastToOp op,
+                                PatternRewriter &rewriter) const override {
+    constexpr int64_t kMaxFold = 16;
+    // for tensor<si64>, or tensor<1xsi64>, broadcasted to tensor<nxsi64>, grab
+    // the element and convert to a full op.
+    auto ty = cast<ValueTensorType>(op.getType());
+    if (!ty.areAllSizesKnown() || ty.getSizes().size() != 1)
+      return failure();
+
+    if (ty.getSizes()[0] > kMaxFold)
+      return failure();
+
+    SmallVector<OpFoldResult> fillFold;
+    if (failed(getListFromTensor(op.getSelf(), fillFold)) ||
+        fillFold.size() != 1)
+      return failure();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    SmallVector<Value, 1> fillVals;
+    if (failed(materializeFolds(b, fillFold, fillVals)))
+      return failure();
+
+    Value size = b.create<Torch::ConstantIntOp>(ty.getSizes().front());
+    Value sizeList = b.create<Torch::PrimListConstructOp>(
+        rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
+        size);
+    Value none = b.create<Torch::ConstantNoneOp>();
+    Value cstFalse = b.create<Torch::ConstantBoolOp>(false);
+    rewriter.replaceOpWithNewOp<AtenFullOp>(op, ty, sizeList, fillVals.front(),
+                                            none, none, none, cstFalse);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class PropagateAtenShapeToTensorPattern
     : public OpRewritePattern<Aten_ShapeAsTensorOp> {
@@ -541,9 +600,128 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
 };
 } // namespace
 
+namespace {
+
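+// ArithmeticHelper checks the optional `alpha` operand: add/sub carry an
+// alpha that must be the constant 1; other ops trivially pass with alpha = 1.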
+template <typename OpTy> struct ArithmeticHelper {
+  static LogicalResult getAlphaAndVerify(OpTy &op, int64_t &alpha) {
+    alpha = 1;
+    return success();
+  }
+};
+
+template <> struct ArithmeticHelper<AtenAddTensorOp> {
+  static LogicalResult getAlphaAndVerify(AtenAddTensorOp &op, int64_t &alpha) {
+    if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1)
+      return failure();
+    return success();
+  }
+};
+
+template <> struct ArithmeticHelper<AtenSubTensorOp> {
+  static LogicalResult getAlphaAndVerify(AtenSubTensorOp &op, int64_t &alpha) {
+    if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1)
+      return failure();
+    return success();
+  }
+};
627+
+template <typename OpTy, typename ScalarOpTy>
+class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Check type
+    auto resultTy = cast<ValueTensorType>(op.getType());
+    if (resultTy.getSizes().size() > 1)
+      return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
+    if (!resultTy.hasDtype() || !isa<mlir::IntegerType>(resultTy.getDtype()))
+      return rewriter.notifyMatchFailure(op, "not an int type");
+
+    int64_t alpha;
+    if (failed(ArithmeticHelper<OpTy>::getAlphaAndVerify(op, alpha)))
+      return rewriter.notifyMatchFailure(op, "alpha must be 1");
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    SmallVector<OpFoldResult> selfFold, otherFold;
+    if (failed(getListFromTensor(op.getSelf(), selfFold)) ||
+        failed(getListFromTensor(op.getOther(), otherFold)) ||
+        selfFold.size() != otherFold.size())
+      return failure();
+    SmallVector<Value> selfVals, otherVals;
+    if (failed(materializeFolds(b, selfFold, selfVals)) ||
+        failed(materializeFolds(b, otherFold, otherVals)))
+      return failure();
+    SmallVector<OpFoldResult> resultFolds;
+    for (uint64_t i = 0; i < selfVals.size(); i++) {
+      resultFolds.push_back(b.createOrFold<ScalarOpTy>(
+          selfVals[i].getType(), selfVals[i], otherVals[i]));
+    }
+    SmallVector<Value> resultVals;
+    if (failed(materializeFolds(b, resultFolds, resultVals)))
+      return failure();
+
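+    // A rank 0 result has no list form; wrap the single scalar directly.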
+    if (resultTy.getSizes().size() == 0) {
+      rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
+          op, resultTy, resultVals.front());
+      return success();
+    }
+
+    Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 /// ------ Fold Patterns ------ ///
 // These are shape-specific folding patterns
 
+namespace {
+class FoldAtenEqIntPattern : public OpRewritePattern<AtenEqIntOp> {
+public:
+  using OpRewritePattern<AtenEqIntOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenEqIntOp op,
+                                PatternRewriter &rewriter) const override {
+    // replaces (size.int == 0) with false and adds an assert
+    // these comparisons are getting generated because onnx.Reshape considers 0
+    // to mean "don't change this dim". However, if the size we are passing to
+    // onnx.Reshape is a tensor dim, this is definitely never supposed to be
+    // interpreted as "don't change this dim".
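+    // E.g. (illustrative): `torch.aten.eq.int %size, %int0` folds to `false`,
+    // guarded below by a runtime assert that %size > 0.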
+    int64_t otherInt;
+    if (!matchPattern(op.getB(), m_TorchConstantInt(&otherInt)) ||
+        otherInt != 0)
+      return failure();
+
+    // in case the shape is a product of two ints, check each
+    if (auto mulOp = op.getA().getDefiningOp<AtenMulIntOp>()) {
+      Value self = mulOp.getA();
+      Value other = mulOp.getB();
+      Value selfEq = rewriter.create<AtenEqIntOp>(op.getLoc(), self, op.getB());
+      Value otherEq =
+          rewriter.create<AtenEqIntOp>(op.getLoc(), other, op.getB());
+      rewriter.replaceOpWithNewOp<Aten__Or__BoolOp>(op, selfEq, otherEq);
+      return success();
+    }
+
+    // if lhs is size.int op, assert size > 0 and replace with false.
+    if (auto sizeOp = op.getA().getDefiningOp<AtenSizeIntOp>()) {
+      Value selfGtOther = rewriter.create<AtenGtIntOp>(
+          op.getLoc(), op.getType(), op.getA(), op.getB());
+      rewriter.create<Torch::RuntimeAssertOp>(
+          op.getLoc(), selfGtOther,
+          rewriter.getStringAttr("Expected dim size > 0."));
+      Value cstFalse =
+          rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
+      rewriter.replaceOp(op, cstFalse);
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
 namespace {
 class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
 public:
@@ -594,16 +772,24 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
 } // namespace
 
 namespace {
-class FoldAtenSqueezePattern : public OpRewritePattern<AtenSqueezeOp> {
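+// Templated so the same fold applies to both aten.squeeze and
+// aten.squeeze.dim; both instantiations are registered in
+// populateScalarizationFoldPatterns below.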
+template <typename SqueezeOp>
+class FoldAtenSqueezePattern : public OpRewritePattern<SqueezeOp> {
 public:
-  using OpRewritePattern<AtenSqueezeOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenSqueezeOp op,
+  using OpRewritePattern<SqueezeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SqueezeOp op,
                                 PatternRewriter &rewriter) const override {
     auto resultTy = cast<ValueTensorType>(op.getType());
     if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
       return rewriter.notifyMatchFailure(op, "Unknown result shape");
 
-    if (auto atenFull = op.getSelf().getDefiningOp<AtenFullOp>()) {
+    Value self = op.getSelf();
+    if (auto atenFull = self.getDefiningOp<AtenFullOp>()) {
+      // in the rank 0 case, just return the rank 0 scalar
+      if (resultTy.getSizes().size() == 0) {
+        rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
+            op, resultTy, atenFull.getFillValue());
+        return success();
+      }
       SmallVector<Value> sizes;
       for (int i = 0, s = resultTy.getSizes().size(); i < s; ++i)
         sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
@@ -874,9 +1060,16 @@ bool isPrimListOfInts(Operation *op) {
   return llvm::isa<Torch::IntType>(listType.getContainedType());
 }
 
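+// Anchor ops are the shape-scalar consumers that seed the bottom-up walk of
+// the pass: runtime asserts, arange ops, and prim lists of ints.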
+bool isAnchorOp(Operation *op) {
+  return isa<Torch::RuntimeAssertOp>(op) || isa<AtenArangeStartStepOp>(op) ||
+         isPrimListOfInts(op);
+}
+
 void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
-  patterns.insert<FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
-                  FoldAtenWhereSelf, FoldAtenTensorSplatPattern>(
+  patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
+                  FoldAtenSqueezePattern<AtenSqueezeDimOp>,
+                  FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
+                  FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>(
       patterns.getContext());
 }
 
@@ -885,10 +1078,21 @@ void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
 }
 
 void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
-  patterns.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
-                  PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
-                  PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
-                  PropagateAtenWhereSelfPattern>(patterns.getContext());
+  // A note on division: onnx.Div from int, int -> int types rounds towards
+  // zero. The torch DivTensorOp actually doesn't allow returning an int dtype,
+  // but this was artificially plumbed through. Unfortunately, there is no
+  // scalar trunc div op in torch; however, we can safely assume all operands
+  // are positive, so floor divide should be a sufficient scalar replacement.
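+  // For positive operands the two conventions agree, e.g. trunc(9/4) and
+  // floor(9/4) are both 2; they diverge only when an operand is negative.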
+  patterns.insert<
+      PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
+      PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
+      PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
+      PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
+      PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
+      PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
+      PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
+      PropagateAtenArithmeticPattern<AtenDivTensorOp, AtenFloordivIntOp>>(
+      patterns.getContext());
 }
 
 void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
@@ -940,7 +1144,7 @@ class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
         [&](Operation *op) {
           // Walking bottom-up, start adding ops when we reach an anchor point
           // (a prim list of ints)
-          if (isPrimListOfInts(op)) {
+          if (isAnchorOp(op)) {
             shapeCalculationOps.insert(op);
             return;
           }