@@ -3047,13 +3047,6 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
30473047 // instruction cost.
30483048 return 0 ;
30493049 case Instruction::Call: {
3050- if (!isSingleScalar ()) {
3051- // TODO: Handle remaining call costs here as well.
3052- if (VF.isScalable ())
3053- return InstructionCost::getInvalid ();
3054- break ;
3055- }
3056-
30573050 auto *CalledFn =
30583051 cast<Function>(getOperand (getNumOperands () - 1 )->getLiveInIRValue ());
30593052 if (CalledFn->isIntrinsic ())
@@ -3063,7 +3056,42 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
30633056 for (VPValue *ArgOp : drop_end (operands ()))
30643057 Tys.push_back (Ctx.Types .inferScalarType (ArgOp));
30653058 Type *ResultTy = Ctx.Types .inferScalarType (this );
3066- return Ctx.TTI .getCallInstrCost (CalledFn, ResultTy, Tys, Ctx.CostKind );
3059+ InstructionCost ScalarCallCost =
3060+ Ctx.TTI .getCallInstrCost (CalledFn, ResultTy, Tys, Ctx.CostKind );
3061+ if (isSingleScalar ())
3062+ return ScalarCallCost;
3063+
3064+ if (VF.isScalable ())
3065+ return InstructionCost::getInvalid ();
3066+
3067+ // Compute the cost of scalarizing the result and operands if needed.
3068+ InstructionCost ScalarizationCost = 0 ;
3069+ if (VF.isVector ()) {
3070+ if (!ResultTy->isVoidTy ()) {
3071+ for (Type *VectorTy :
3072+ to_vector (getContainedTypes (toVectorizedTy (ResultTy, VF)))) {
3073+ ScalarizationCost += Ctx.TTI .getScalarizationOverhead (
3074+ cast<VectorType>(VectorTy), APInt::getAllOnes (VF.getFixedValue ()),
3075+ /* Insert=*/ true ,
3076+ /* Extract=*/ false , Ctx.CostKind );
3077+ }
3078+ }
3079+ // Skip operands that do not require extraction/scalarization and do not
3080+ // incur any overhead.
3081+ SmallPtrSet<const VPValue *, 4 > UniqueOperands;
3082+ Tys.clear ();
3083+ for (auto *Op : drop_end (operands ())) {
3084+ if (Op->isLiveIn () || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) ||
3085+ !UniqueOperands.insert (Op).second )
3086+ continue ;
3087+ Tys.push_back (toVectorizedTy (Ctx.Types .inferScalarType (Op), VF));
3088+ }
3089+ ScalarizationCost +=
3090+ Ctx.TTI .getOperandsScalarizationOverhead (Tys, Ctx.CostKind );
3091+ }
3092+
3093+ return ScalarCallCost * (isSingleScalar () ? 1 : VF.getFixedValue ()) +
3094+ ScalarizationCost;
30673095 }
30683096 case Instruction::Add:
30693097 case Instruction::Sub:
0 commit comments