@@ -20,6 +20,9 @@ import dotty.tools.pc.utils.InteractiveEnrichments.*
2020import scala .meta .pc .reports .ReportContext
2121import scala .meta .pc .OffsetParams
2222import scala .meta .pc .SymbolSearch
23+ import dotty .tools .dotc .util .Signatures
24+ import dotty .tools .dotc .util .Signatures .MethodParam
25+ import dotty .tools .dotc .util .Signatures .TypeParam
2326
2427class InferExpectedType (
2528 search : SymbolSearch ,
@@ -50,32 +53,32 @@ class InferExpectedType(
5053 val indexedCtx = IndexedContext (pos)(using locatedCtx)
5154 val printer =
5255 ShortenedTypePrinter (search, IncludeDefaultParam .ResolveLater )(using indexedCtx)
53- InterCompletionType .inferType(path)(using newctx).map{
56+ InferCompletionType .inferType(path)(using newctx).map{
5457 tpe => printer.tpe(tpe)
5558 }
5659 case None => None
5760
58- object InterCompletionType :
61+ object InferCompletionType :
5962 def inferType (path : List [Tree ])(using Context ): Option [Type ] =
6063 path match
61- case (lit : Literal ) :: Select (Literal (_), _) :: Apply (Select (Literal (_), _), List (s : Select )) :: rest if s.symbol == defn.Predef_undefined => inferType(rest, lit.span)
62- case ident :: rest => inferType(rest, ident.span)
64+ case (lit : Literal ) :: Select (Literal (_), _) :: Apply (Select (Literal (_), _), List (s : Select )) :: rest if s.symbol == defn.Predef_undefined => inferType(rest, lit.span, path )
65+ case ident :: rest => inferType(rest, ident.span, path )
6366 case _ => None
6467
65- def inferType (path : List [Tree ], span : Span )(using Context ): Option [Type ] =
68+ def inferType (path : List [Tree ], span : Span , fullPath : List [ Tree ] )(using Context ): Option [Type ] =
6669 path match
6770 case Typed (expr, tpt) :: _ if expr.span.contains(span) && ! tpt.tpe.isErroneous => Some (tpt.tpe)
6871 case Block (_, expr) :: rest if expr.span.contains(span) =>
69- inferType(rest, span)
70- case Bind (_, body) :: rest if body.span.contains(span) => inferType(rest, span)
71- case Alternative (_) :: rest => inferType(rest, span)
72- case Try (block, _, _) :: rest if block.span.contains(span) => inferType(rest, span)
73- case CaseDef (_, _, body) :: Try (_, cases, _) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) => inferType(rest, span)
74- case If (cond, _, _) :: rest if ! cond.span.contains(span) => inferType(rest, span)
72+ inferType(rest, span, fullPath )
73+ case Bind (_, body) :: rest if body.span.contains(span) => inferType(rest, span, fullPath )
74+ case Alternative (_) :: rest => inferType(rest, span, fullPath )
75+ case Try (block, _, _) :: rest if block.span.contains(span) => inferType(rest, span, fullPath )
76+ case CaseDef (_, _, body) :: Try (_, cases, _) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) => inferType(rest, span, fullPath )
77+ case If (cond, _, _) :: rest if ! cond.span.contains(span) => inferType(rest, span, fullPath )
7578 case If (cond, _, _) :: rest if cond.span.contains(span) => Some (defn.BooleanType )
7679 case CaseDef (_, _, body) :: Match (_, cases) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) =>
77- inferType(rest, span)
78- case NamedArg (_, arg) :: rest if arg.span.contains(span) => inferType(rest, span)
80+ inferType(rest, span, fullPath )
81+ case NamedArg (_, arg) :: rest if arg.span.contains(span) => inferType(rest, span, fullPath )
7982 // x match
8083 // case @@
8184 case CaseDef (pat, _, _) :: Match (sel, cases) :: rest if pat.span.contains(span) && cases.exists(_.span.contains(span)) && ! sel.tpe.isErroneous =>
@@ -94,37 +97,34 @@ object InterCompletionType:
9497 else Some (UnapplyArgs (fun.tpe.finalResultType, fun, pats, NoSourcePosition ).argTypes(ind))
9598 // f(@@)
9699 case ApplyExtractor (app) =>
97- val argsAndParams = ApplyArgsExtractor .getArgsAndParams(None , app, span).headOption
98- argsAndParams.flatMap:
99- case (args, params) =>
100- val idx = args.indexWhere(_.span.contains(span))
101- val param =
102- if idx >= 0 && params.length > idx then Some (params(idx).info)
103- else None
104- param match
105- // def f[T](a: T): T = ???
106- // f[Int](@@)
107- // val _: Int = f(@@)
108- case Some (t : TypeRef ) if t.symbol.is(Flags .TypeParam ) =>
109- for
110- (typeParams, args) <-
111- app match
112- case Apply (TypeApply (fun, args), _) =>
113- val typeParams = fun.symbol.paramSymss.headOption.filter(_.forall(_.isTypeParam))
114- typeParams.map((_, args.map(_.tpe)))
115- // val f: (j: "a") => Int
116- // f(@@)
117- case Apply (Select (v, StdNames .nme.apply), _) =>
118- v.symbol.info match
119- case AppliedType (des, args) =>
120- Some ((des.typeSymbol.typeParams, args))
121- case _ => None
122- case _ => None
123- ind = typeParams.indexOf(t.symbol)
124- tpe <- args.get(ind)
125- if ! tpe.isErroneous
126- yield tpe
127- case Some (tpe) => Some (tpe)
128- case _ => None
100+ val (idx, _, signatures) = Signatures .signatureHelp(fullPath, span)
101+
102+ val types : List [Type ] = signatures.flatMap { s =>
103+ s.paramss.flatten.get(idx) match {
104+ case Some (mp : MethodParam ) =>
105+ mp.tpe match
106+ case t : TypeParamRef =>
107+ for
108+ args <-
109+ app match
110+ case Apply (TypeApply (fun, args), _) =>
111+ Some (args.map(_.tpe))
112+ // val f: (j: "a") => Int
113+ // f(@@)
114+ case Apply (Select (v, StdNames .nme.apply), _) =>
115+ v.symbol.info match
116+ case AppliedType (des, args) =>
117+ Some (args)
118+ case _ => None
119+ case _ => None
120+ tpe <- args.get(t.paramNum)
121+ if ! tpe.isErroneous
122+ yield tpe
123+ case tpe => Some (tpe)
124+ case _ => None
125+ }
126+ }
127+ if (types.isEmpty) None
128+ else Some (types.reduce(_ | _))
129129 case _ => None
130130
0 commit comments