@@ -20,6 +20,7 @@ import scala.annotation.tailrec
2020import dotty .tools .dotc .core .Denotations .SingleDenotation
2121import dotty .tools .dotc .core .Denotations .MultiDenotation
2222import dotty .tools .dotc .util .Spans .Span
23+ import dotty .tools .dotc .core .Symbols
2324
2425object ApplyExtractor :
2526 def unapply (path : List [Tree ])(using Context ): Option [Apply ] =
@@ -44,8 +45,10 @@ object ApplyExtractor:
4445
4546
4647object ApplyArgsExtractor :
48+ // normally symbol but for refinment types method type
49+ type Method = Symbol | Type
4750 def getArgsAndParams (
48- optIndexedContext : Option [ IndexedContext ] ,
51+ indexedContext : IndexedContext ,
4952 apply : Apply ,
5053 span : Span
5154 )(using Context ): List [(List [Tree ], List [ParamSymbol ])] =
@@ -78,47 +81,56 @@ object ApplyArgsExtractor:
7881
7982 // fallback for when multiple overloaded methods match the supplied args
8083 def fallbackFindMatchingMethods () =
81- def matchingMethodsSymbols (
82- indexedContext : IndexedContext ,
83- method : Tree
84- ): List [Symbol ] =
84+ def matchingMethodsSymbols (method : Tree ): List [Method ] =
8585 method match
8686 case Ident (name) => indexedContext.findSymbol(name).getOrElse(Nil )
87- case Select (This (_), name) => indexedContext.findSymbol(name).getOrElse(Nil )
87+ case Select (t @ This (_), name) =>
88+ val res = indexedContext.findSymbol(name).getOrElse(Nil ).filter(_.exists)
89+ res ++ findRefinments(t.symbol.info, name)
8890 case sel @ Select (from, name) =>
8991 val symbol = from.symbol
9092 val ownerSymbol =
9193 if symbol.is(Method ) && symbol.owner.isClass then
9294 Some (symbol.owner)
9395 else Try (symbol.info.classSymbol).toOption
94- ownerSymbol.map(sym => sym.info.member(name)).collect{
96+ val res = ownerSymbol.map(sym => sym.info.member(name)).collect{
9597 case single : SingleDenotation => List (single.symbol)
9698 case multi : MultiDenotation => multi.allSymbols
9799 }.getOrElse(Nil )
98- case Apply (fun, _) => matchingMethodsSymbols(indexedContext, fun)
100+ res ++ findRefinments(symbol.info, name)
101+ case Apply (fun, _) => matchingMethodsSymbols(fun)
102+ case TypeApply (fun, args) =>
103+ matchingMethodsSymbols(fun).map {
104+ case t : PolyType => t.appliedTo(args.map(_.tpe))
105+ case s => s
106+ }
99107 case _ => Nil
100108 val matchingMethods =
101109 for
102- indexedContext <- optIndexedContext.toList
103- potentialMatch <- matchingMethodsSymbols(indexedContext, method)
104- if potentialMatch.is(Flags .Method ) &&
105- potentialMatch.vparamss.length >= argss.length &&
106- Try (potentialMatch.isAccessibleFrom(apply.symbol.info)).toOption
107- .getOrElse(false ) &&
108- potentialMatch.vparamss
110+ potentialMatch <- matchingMethodsSymbols(method)
111+ if potentialMatch match
112+ case s : Symbol => s.is(Flags .Method ) && Try (s.isAccessibleFrom(apply.symbol.info)).toOption.getOrElse(false )
113+ case _ => true
114+ if potentialMatch.vparamss.length >= argss.length &&
115+ (potentialMatch match {
116+ case s : Symbol =>
117+ s.symVparamss
109118 .zip(argss)
110119 .reverse
111120 .zipWithIndex
112121 .forall { case (pair, index) =>
113- FuzzyArgMatcher (potentialMatch.tparams )
122+ FuzzyArgMatcher (s.symTparams )
114123 .doMatch(allArgsProvided = index != 0 , span)
115124 .tupled(pair)
116125 }
126+ case _ => true
127+ })
128+
117129 yield potentialMatch
118130 matchingMethods
119131 end fallbackFindMatchingMethods
120132
121- val matchingMethods : List [Symbol ] =
133+ val matchingMethods : List [Method ] =
122134 if method.symbol.paramSymss.nonEmpty then
123135 val allArgsAreSupplied =
124136 val vparamss = method.symbol.vparamss
@@ -157,11 +169,10 @@ object ApplyArgsExtractor:
157169 // def curry(x: Int)(apple: String, banana: String) = ???
158170 // curry(1)(apple = "test", b@@)
159171 // ```
160- val (baseParams0 , baseArgs) =
172+ val (defaultBaseParams , baseArgs) =
161173 vparamss.zip(argss).lastOption.getOrElse((Nil , Nil ))
162174
163175 val baseParams : List [ParamSymbol ] =
164- def defaultBaseParams = baseParams0.map(JustSymbol (_))
165176 @ tailrec
166177 def getRefinedParams (refinedType : Type , level : Int ): List [ParamSymbol ] =
167178 if level > 0 then
@@ -176,8 +187,8 @@ object ApplyArgsExtractor:
176187 else
177188 refinedType match
178189 case RefinedType (AppliedType (_, args), _, MethodType (ri)) =>
179- baseParams0 .zip(ri).zip(args).map { case ((sym, name), arg) =>
180- RefinedSymbol (sym, name, arg)
190+ defaultBaseParams .zip(ri).zip(args).map { case ((sym, name), arg) =>
191+ RefinedSymbol (sym.symbol , name, arg)
181192 }
182193 case _ => defaultBaseParams
183194 // finds param refinements for lambda expressions
@@ -198,11 +209,35 @@ object ApplyArgsExtractor:
198209 (baseArgs, baseParams)
199210 }
200211
201- extension (method : Symbol )
202- def vparamss (using Context ) = method.filteredParamss(_.isTerm)
203- def tparams (using Context ) = method.filteredParamss(_.isType).flatten
204- def filteredParamss (f : Symbol => Boolean )(using Context ) =
205- method.paramSymss.filter(params => params.forall(f))
212+ @ tailrec
213+ private def findRefinments (tpe : Type , name : Name , acc : List [Method ] = Nil ): List [Method ] =
214+ tpe match
215+ case RefinedType (parent, `name`, refinedInfo) =>
216+ findRefinments(parent, name, refinedInfo :: acc)
217+ case RefinedType (parent, _, s) => findRefinments(parent, name, acc)
218+ case _ => acc.reverse
219+
220+
221+ extension (method : Method )
222+ def vparamss (using Context ): List [List [ParamSymbol ]] =
223+ method match
224+ case s : Symbol => s.symVparamss.map(_.map(JustSymbol (_)))
225+ case m : MethodType =>
226+ m.paramInfoss.zipWithIndex.map {
227+ case (params, idx) =>
228+ params.zip(m.paramNamess.get(idx).getOrElse(Nil )).map{
229+ case (tpe, name) => RefinedSymbol (Symbols .NoSymbol , name, tpe)
230+ }
231+ }
232+ case _ => Nil
233+
234+ extension (sym : Symbol )
235+ def symVparamss (using Context ): List [List [Symbol ]] = filteredParamss(sym, _.isTerm)
236+
237+ def symTparams (using Context ): List [Symbol ] = filteredParamss(sym, _.isType).flatten
238+
239+ private def filteredParamss (s : Symbol , f : Symbol => Boolean )(using Context ): List [List [Symbol ]] =
240+ s.paramSymss.filter(params => params.forall(f))
206241sealed trait ParamSymbol :
207242 def name : Name
208243 def info : Type
0 commit comments