@@ -76,10 +76,10 @@ object BetaReduce:
7676    val  bindingsBuf  =  new  ListBuffer [DefTree ]
7777    def  recur (fn : Tree , argss : List [List [Tree ]]):  Option [Tree ] =  fn match 
7878      case  Block ((ddef  : DefDef ) ::  Nil , closure : Closure ) if  ddef.symbol ==  closure.meth.symbol => 
79-         Some ( reduceApplication(ddef, argss, bindingsBuf) )
79+         reduceApplication(ddef, argss, bindingsBuf)
8080      case  Block ((TypeDef (_, template : Template )) ::  Nil , Typed (Apply (Select (New (_), _), _), _)) if  template.constr.rhs.isEmpty => 
8181        template.body match 
82-           case  (ddef : DefDef ) ::  Nil  =>  Some ( reduceApplication(ddef, argss, bindingsBuf) )
82+           case  (ddef : DefDef ) ::  Nil  =>  reduceApplication(ddef, argss, bindingsBuf)
8383          case  _ =>  None 
8484      case  Block (stats, expr) if  stats.forall(isPureBinding) => 
8585        recur(expr, argss).map(cpy.Block (fn)(stats, _))
@@ -106,12 +106,22 @@ object BetaReduce:
106106      case  _ => 
107107        tree
108108
109-   /**  Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */  
110-   def  reduceApplication (ddef : DefDef , argss : List [List [Tree ]], bindings : ListBuffer [DefTree ])(using  Context ):  Tree  = 
109+   /**  Beta-reduces a call to `ddef` with arguments `args` and registers new bindings. 
110+    *  @return  optionally, the expanded call, or none if the actual argument 
111+    *          lists do not match in shape the formal parameters 
112+    */  
113+   def  reduceApplication (ddef : DefDef , argss : List [List [Tree ]], bindings : ListBuffer [DefTree ])
114+       (using  Context ):  Option [Tree ] = 
111115    val  (targs, args) =  argss.flatten.partition(_.isType)
112116    val  tparams  =  ddef.leadingTypeParams
113117    val  vparams  =  ddef.termParamss.flatten
114118
119+     def  shapeMatch (paramss : List [ParamClause ], argss : List [List [Tree ]]):  Boolean  =  (paramss, argss) match 
120+       case  (params ::  paramss1, args ::  argss1) if  params.length ==  args.length => 
121+         shapeMatch(paramss1, argss1)
122+       case  (Nil , Nil ) =>  true 
123+       case  _ =>  false 
124+ 
115125    val  targSyms  = 
116126      for  (targ, tparam) <-  targs.zip(tparams) yield 
117127        targ.tpe.dealias match 
@@ -143,19 +153,26 @@ object BetaReduce:
143153              bindings +=  binding.withSpan(arg.span)
144154            bindingSymbol
145155
146-     val  expansion  =  TreeTypeMap (
147-       oldOwners =  ddef.symbol ::  Nil ,
148-       newOwners =  ctx.owner ::  Nil ,
149-       substFrom =  (tparams :::  vparams).map(_.symbol),
150-       substTo =  targSyms :::  argSyms
151-     ).transform(ddef.rhs)
152- 
153-     val  expansion1  =  new  TreeMap  {
154-       override  def  transform (tree : Tree )(using  Context ) =  tree.tpe.widenTermRefExpr match 
155-         case  ConstantType (const) if  isPureExpr(tree) =>  cpy.Literal (tree)(const)
156-         case  tpe : TypeRef  if  tree.isTerm &&  tpe.derivesFrom(defn.UnitClass ) &&  isPureExpr(tree) => 
157-           cpy.Literal (tree)(Constant (()))
158-         case  _ =>  super .transform(tree)
159-     }.transform(expansion)
160- 
161-     expansion1
156+     if  shapeMatch(ddef.paramss, argss) then 
157+       //  We can't assume arguments always match. It's possible to construct a
158+       //  function with wrong apply method by hand which causes `shapeMatch` to fail.
159+       //  See neg/i21952.scala
160+       val  expansion  =  TreeTypeMap (
161+         oldOwners =  ddef.symbol ::  Nil ,
162+         newOwners =  ctx.owner ::  Nil ,
163+         substFrom =  (tparams :::  vparams).map(_.symbol),
164+         substTo =  targSyms :::  argSyms
165+       ).transform(ddef.rhs)
166+ 
167+       val  expansion1  =  new  TreeMap  {
168+         override  def  transform (tree : Tree )(using  Context ) =  tree.tpe.widenTermRefExpr match 
169+           case  ConstantType (const) if  isPureExpr(tree) =>  cpy.Literal (tree)(const)
170+           case  tpe : TypeRef  if  tree.isTerm &&  tpe.derivesFrom(defn.UnitClass ) &&  isPureExpr(tree) => 
171+             cpy.Literal (tree)(Constant (()))
172+           case  _ =>  super .transform(tree)
173+       }.transform(expansion)
174+ 
175+       Some (expansion1)
176+     else  None 
177+   end  reduceApplication 
178+ end  BetaReduce 
0 commit comments