@@ -126,12 +126,12 @@ static Value createLinalgBodyCalculationForElementwiseOp(
126126 if (isa<tosa::MulOp>(op)) {
127127 auto shiftVal = cast<tosa::MulOp>(op).getShift ();
128128 DenseElementsAttr shiftElem;
129- if (! matchPattern (shiftVal, m_Constant (&shiftElem))) {
130- ( void )rewriter. notifyMatchFailure (op, " shift value of mul not found " ) ;
131- return nullptr ;
132- }
133-
134- int32_t shift = shiftElem. getValues <IntegerAttr>()[ 0 ]. getInt () ;
129+ bool shiftIsConstant = true ;
130+ int32_t shift = 0 ;
131+ if ( matchPattern (shiftVal, m_Constant (&shiftElem)))
132+ shift = shiftElem. getValues <IntegerAttr>()[ 0 ]. getInt ();
133+ else
134+ shiftIsConstant = false ;
135135
136136 if (isa<FloatType>(elementTy)) {
137137 if (shift != 0 ) {
@@ -147,23 +147,24 @@ static Value createLinalgBodyCalculationForElementwiseOp(
147147 Value a = args[0 ];
148148 Value b = args[1 ];
149149
150- if (shift > 0 ) {
151- auto shiftConst =
152- arith::ConstantIntOp::create (rewriter, loc, shift, /* bitwidth=*/ 8 );
150+ if (shift > 0 || !shiftIsConstant) {
151+ Value shiftConst;
152+ if (shiftIsConstant)
153+ shiftConst =
154+ rewriter.create <arith::ConstantIntOp>(loc, shift, /* bitwidth=*/ 8 );
155+
153156 if (!a.getType ().isInteger (32 ))
154157 a = arith::ExtSIOp::create (rewriter, loc, rewriter.getI32Type (), a);
155158
156159 if (!b.getType ().isInteger (32 ))
157160 b = arith::ExtSIOp::create (rewriter, loc, rewriter.getI32Type (), b);
158161
162+ auto shiftAmount = shiftIsConstant ? shiftConst : args[2 ];
159163 auto result = tosa::ApplyScaleOp::create (
160- rewriter, loc, rewriter.getI32Type (), a, b, shiftConst ,
164+ rewriter, loc, rewriter.getI32Type (), a, b, shiftAmount ,
161165 rewriter.getStringAttr (" SINGLE_ROUND" ));
162166
163- if (elementTy.isInteger (32 ))
164- return result;
165-
166- return arith::TruncIOp::create (rewriter, loc, elementTy, result);
167+ return result;
167168 }
168169
169170 int aWidth = a.getType ().getIntOrFloatBitWidth ();
@@ -918,6 +919,18 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
918919 if (operands.size () == 1 )
919920 return operands;
920921
922+ // No need to broadcast for static shape
923+ bool hasDynamic = false ;
924+ for (auto op : operands) {
925+ const auto tType = dyn_cast<RankedTensorType>(op.getType ());
926+ if (tType && !tType.hasStaticShape ()) {
927+ hasDynamic = true ;
928+ break ;
929+ }
930+ }
931+ if (!hasDynamic)
932+ return operands;
933+
921934 // Broadcast dynamic dimensions operand by operand
922935 return llvm::map_to_vector (operands, [&](Value operand) {
923936 return broadcastDynamicDimensions (rewriter, loc, indexPool, operand,
@@ -990,8 +1003,14 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
9901003static ValueRange getBroadcastableOperands (Operation *operation,
9911004 ValueRange operands) {
9921005 // Shift cannot broadcast
993- if (isa<tosa::MulOp>(operation))
994- return operands.take_front (2 );
1006+ if (isa<tosa::MulOp>(operation)) {
1007+ DenseElementsAttr shiftElems;
1008+ // Shift cannot broadcast when it is constant
1009+ if (matchPattern (operation->getOperand (2 ), m_Constant (&shiftElems)))
1010+ return operands.take_front (2 );
1011+ else
1012+ return operands.take_front (3 );
1013+ }
9951014 // Input1_zp and output_zp cannot broadcast
9961015 if (isa<tosa::NegateOp>(operation))
9971016 return operands.take_front (1 );
0 commit comments