Skip to content

Commit 814ae65

Browse files
committed
fix: show correctly typed hole on applyDynamic
1 parent 4ec6751 commit 814ae65

File tree

5 files changed

+111
-31
lines changed

5 files changed

+111
-31
lines changed

presentation-compiler/src/main/dotty/tools/pc/ApplyArgsExtractor.scala

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import scala.annotation.tailrec
2020
import dotty.tools.dotc.core.Denotations.SingleDenotation
2121
import dotty.tools.dotc.core.Denotations.MultiDenotation
2222
import dotty.tools.dotc.util.Spans.Span
23+
import dotty.tools.dotc.core.Symbols
2324

2425
object ApplyExtractor:
2526
def unapply(path: List[Tree])(using Context): Option[Apply] =
@@ -44,8 +45,10 @@ object ApplyExtractor:
4445

4546

4647
object ApplyArgsExtractor:
48+
// normally symbol but for refinment types method type
49+
type Method = Symbol | Type
4750
def getArgsAndParams(
48-
optIndexedContext: Option[IndexedContext],
51+
indexedContext: IndexedContext,
4952
apply: Apply,
5053
span: Span
5154
)(using Context): List[(List[Tree], List[ParamSymbol])] =
@@ -78,47 +81,56 @@ object ApplyArgsExtractor:
7881

7982
// fallback for when multiple overloaded methods match the supplied args
8083
def fallbackFindMatchingMethods() =
81-
def matchingMethodsSymbols(
82-
indexedContext: IndexedContext,
83-
method: Tree
84-
): List[Symbol] =
84+
def matchingMethodsSymbols(method: Tree): List[Method] =
8585
method match
8686
case Ident(name) => indexedContext.findSymbol(name).getOrElse(Nil)
87-
case Select(This(_), name) => indexedContext.findSymbol(name).getOrElse(Nil)
87+
case Select(t @ This(_), name) =>
88+
val res = indexedContext.findSymbol(name).getOrElse(Nil).filter(_.exists)
89+
res ++ findRefinments(t.symbol.info, name)
8890
case sel @ Select(from, name) =>
8991
val symbol = from.symbol
9092
val ownerSymbol =
9193
if symbol.is(Method) && symbol.owner.isClass then
9294
Some(symbol.owner)
9395
else Try(symbol.info.classSymbol).toOption
94-
ownerSymbol.map(sym => sym.info.member(name)).collect{
96+
val res = ownerSymbol.map(sym => sym.info.member(name)).collect{
9597
case single: SingleDenotation => List(single.symbol)
9698
case multi: MultiDenotation => multi.allSymbols
9799
}.getOrElse(Nil)
98-
case Apply(fun, _) => matchingMethodsSymbols(indexedContext, fun)
100+
res ++ findRefinments(symbol.info, name)
101+
case Apply(fun, _) => matchingMethodsSymbols(fun)
102+
case TypeApply(fun, args) =>
103+
matchingMethodsSymbols(fun).map {
104+
case t: PolyType => t.appliedTo(args.map(_.tpe))
105+
case s => s
106+
}
99107
case _ => Nil
100108
val matchingMethods =
101109
for
102-
indexedContext <- optIndexedContext.toList
103-
potentialMatch <- matchingMethodsSymbols(indexedContext, method)
104-
if potentialMatch.is(Flags.Method) &&
105-
potentialMatch.vparamss.length >= argss.length &&
106-
Try(potentialMatch.isAccessibleFrom(apply.symbol.info)).toOption
107-
.getOrElse(false) &&
108-
potentialMatch.vparamss
110+
potentialMatch <- matchingMethodsSymbols(method)
111+
if potentialMatch match
112+
case s: Symbol => s.is(Flags.Method) && Try(s.isAccessibleFrom(apply.symbol.info)).toOption.getOrElse(false)
113+
case _ => true
114+
if potentialMatch.vparamss.length >= argss.length &&
115+
(potentialMatch match {
116+
case s: Symbol =>
117+
s.symVparamss
109118
.zip(argss)
110119
.reverse
111120
.zipWithIndex
112121
.forall { case (pair, index) =>
113-
FuzzyArgMatcher(potentialMatch.tparams)
122+
FuzzyArgMatcher(s.symTparams)
114123
.doMatch(allArgsProvided = index != 0, span)
115124
.tupled(pair)
116125
}
126+
case _ => true
127+
})
128+
117129
yield potentialMatch
118130
matchingMethods
119131
end fallbackFindMatchingMethods
120132

121-
val matchingMethods: List[Symbol] =
133+
val matchingMethods: List[Method] =
122134
if method.symbol.paramSymss.nonEmpty then
123135
val allArgsAreSupplied =
124136
val vparamss = method.symbol.vparamss
@@ -157,11 +169,10 @@ object ApplyArgsExtractor:
157169
// def curry(x: Int)(apple: String, banana: String) = ???
158170
// curry(1)(apple = "test", b@@)
159171
// ```
160-
val (baseParams0, baseArgs) =
172+
val (defaultBaseParams, baseArgs) =
161173
vparamss.zip(argss).lastOption.getOrElse((Nil, Nil))
162174

163175
val baseParams: List[ParamSymbol] =
164-
def defaultBaseParams = baseParams0.map(JustSymbol(_))
165176
@tailrec
166177
def getRefinedParams(refinedType: Type, level: Int): List[ParamSymbol] =
167178
if level > 0 then
@@ -176,8 +187,8 @@ object ApplyArgsExtractor:
176187
else
177188
refinedType match
178189
case RefinedType(AppliedType(_, args), _, MethodType(ri)) =>
179-
baseParams0.zip(ri).zip(args).map { case ((sym, name), arg) =>
180-
RefinedSymbol(sym, name, arg)
190+
defaultBaseParams.zip(ri).zip(args).map { case ((sym, name), arg) =>
191+
RefinedSymbol(sym.symbol, name, arg)
181192
}
182193
case _ => defaultBaseParams
183194
// finds param refinements for lambda expressions
@@ -198,11 +209,35 @@ object ApplyArgsExtractor:
198209
(baseArgs, baseParams)
199210
}
200211

201-
extension (method: Symbol)
202-
def vparamss(using Context) = method.filteredParamss(_.isTerm)
203-
def tparams(using Context) = method.filteredParamss(_.isType).flatten
204-
def filteredParamss(f: Symbol => Boolean)(using Context) =
205-
method.paramSymss.filter(params => params.forall(f))
212+
@tailrec
213+
private def findRefinments(tpe: Type, name: Name, acc: List[Method] = Nil): List[Method] =
214+
tpe match
215+
case RefinedType(parent, `name`, refinedInfo) =>
216+
findRefinments(parent, name, refinedInfo :: acc)
217+
case RefinedType(parent, _, s) => findRefinments(parent, name, acc)
218+
case _ => acc.reverse
219+
220+
221+
extension (method: Method)
222+
def vparamss(using Context): List[List[ParamSymbol]] =
223+
method match
224+
case s: Symbol => s.symVparamss.map(_.map(JustSymbol(_)))
225+
case m: MethodType =>
226+
m.paramInfoss.zipWithIndex.map {
227+
case (params, idx) =>
228+
params.zip(m.paramNamess.get(idx).getOrElse(Nil)).map{
229+
case (tpe, name) => RefinedSymbol(Symbols.NoSymbol, name, tpe)
230+
}
231+
}
232+
case _ => Nil
233+
234+
extension (sym: Symbol)
235+
def symVparamss(using Context): List[List[Symbol]] = filteredParamss(sym, _.isTerm)
236+
237+
def symTparams(using Context): List[Symbol] = filteredParamss(sym, _.isType).flatten
238+
239+
private def filteredParamss(s: Symbol, f: Symbol => Boolean)(using Context): List[List[Symbol]] =
240+
s.paramSymss.filter(params => params.forall(f))
206241
sealed trait ParamSymbol:
207242
def name: Name
208243
def info: Type

presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@ class InferExpectedType(
5050
val indexedCtx = IndexedContext(pos)(using locatedCtx)
5151
val printer =
5252
ShortenedTypePrinter(search, IncludeDefaultParam.ResolveLater)(using indexedCtx)
53-
InterCompletionType.inferType(path)(using newctx).map{
53+
InferCompletionType.inferType(path)(using newctx).map{
5454
tpe => printer.tpe(tpe)
5555
}
5656
case None => None
5757

58-
object InterCompletionType:
58+
object InferCompletionType:
5959
def inferType(path: List[Tree])(using Context): Option[Type] =
6060
path match
6161
case (lit: Literal) :: Select(Literal(_), _) :: Apply(Select(Literal(_), _), List(s: Select)) :: rest if s.symbol == defn.Predef_undefined => inferType(rest, lit.span)
@@ -94,7 +94,7 @@ object InterCompletionType:
9494
else Some(UnapplyArgs(fun.tpe.finalResultType, fun, pats, NoSourcePosition).argTypes(ind))
9595
// f(@@)
9696
case ApplyExtractor(app) =>
97-
val argsAndParams = ApplyArgsExtractor.getArgsAndParams(None, app, span).headOption
97+
val argsAndParams = ApplyArgsExtractor.getArgsAndParams(IndexedContext.Empty, app, span).headOption
9898
argsAndParams.flatMap:
9999
case (args, params) =>
100100
val idx = args.indexWhere(_.span.contains(span))

presentation-compiler/src/main/dotty/tools/pc/completions/Completions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ class Completions(
520520
config.isCompletionSnippetsEnabled()
521521
)
522522
(args, false)
523-
val singletonCompletions = InterCompletionType.inferType(path).map(
523+
val singletonCompletions = InferCompletionType.inferType(path).map(
524524
SingletonCompletions.contribute(path, _, completionPos)
525525
).getOrElse(Nil)
526526
(singletonCompletions ++ advanced, exclusive)

presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ object NamedArgCompletions:
7474
case _ => false
7575

7676
val argsAndParams = ApplyArgsExtractor.getArgsAndParams(
77-
Some(indexedContext),
77+
indexedContext,
7878
apply,
7979
ident.span
8080
)

presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,48 @@ class InferExpectedTypeSuite extends BasePCSuite:
335335
"""|String
336336
|""".stripMargin
337337
)
338+
339+
@Test def `apply-dynamic` =
340+
check(
341+
"""|object TypedHoleApplyDynamic {
342+
| val obj: reflect.Selectable {
343+
| def method(x: Int): Unit
344+
| } = new reflect.Selectable {
345+
| def method(x: Int): Unit = ()
346+
| }
347+
|
348+
| obj.method(@@)
349+
|}
350+
|""".stripMargin,
351+
"Int"
352+
)
353+
354+
@Test def `apply-dynamic-2` =
355+
check(
356+
"""|object TypedHoleApplyDynamic {
357+
| val obj: reflect.Selectable {
358+
| def method[T](x: Int, y: T): Unit
359+
| } = new reflect.Selectable {
360+
| def method[T](x: Int, y: T): Unit = ()
361+
| }
362+
|
363+
| obj.method[Int](1, @@)
364+
|}
365+
|""".stripMargin,
366+
"Int"
367+
)
368+
369+
@Test def `apply-dynamic-3` =
370+
check(
371+
"""|object TypedHoleApplyDynamic {
372+
| val obj: reflect.Selectable {
373+
| def method[T](a: Int)(x: Int, y: T): Unit
374+
| } = new reflect.Selectable {
375+
| def method[T](a: Int)(x: Int, y: T): Unit = ()
376+
| }
377+
|
378+
| obj.method[String](1)(1, @@)
379+
|}
380+
|""".stripMargin,
381+
"String"
382+
)

0 commit comments

Comments
 (0)