Skip to content

Commit e5091a6

Browse files
committed
Add InferredMethodProvider
1 parent 13fd5af commit e5091a6

File tree

3 files changed

+869
-0
lines changed

3 files changed

+869
-0
lines changed
Lines changed: 351 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,351 @@
1+
package dotty.tools.pc
2+
3+
import java.nio.file.Paths
4+
5+
import scala.annotation.tailrec
6+
7+
import scala.meta.pc.OffsetParams
8+
import scala.meta.pc.PresentationCompilerConfig
9+
import scala.meta.pc.SymbolSearch
10+
import scala.meta.pc.reports.ReportContext
11+
12+
import dotty.tools.dotc.ast.tpd.*
13+
import dotty.tools.dotc.core.Contexts.*
14+
import dotty.tools.dotc.core.Names.Name
15+
import dotty.tools.dotc.core.Symbols.*
16+
import dotty.tools.dotc.core.Symbols.defn
17+
import dotty.tools.dotc.core.Types.*
18+
import dotty.tools.dotc.interactive.Interactive
19+
import dotty.tools.dotc.interactive.InteractiveDriver
20+
import dotty.tools.dotc.util.SourceFile
21+
import dotty.tools.dotc.util.SourcePosition
22+
import dotty.tools.pc.printer.ShortenedTypePrinter
23+
import dotty.tools.pc.printer.ShortenedTypePrinter.IncludeDefaultParam
24+
import dotty.tools.pc.utils.InteractiveEnrichments.*
25+
26+
import org.eclipse.lsp4j.TextEdit
27+
import org.eclipse.lsp4j as l
28+
29+
/**
30+
* Tries to calculate edits needed to create a method that will fix missing symbol
31+
* in all the places that it is possible such as:
32+
* - apply inside method invocation `method(.., nonExistent(param), ...)` and `method(.., nonExistent, ...)`
33+
* - method in val definition `val value: DefinedType = nonExistent(param)` and `val value: DefinedType = nonExistent`
34+
* - simple method call `nonExistent(param)` and `nonExistent`
35+
* - method call inside a container `container.nonExistent(param)` and `container.nonExistent`
36+
*
37+
* @param params position and actual source
38+
* @param driver Scala 3 interactive compiler driver
39+
* @param config presentation compiler configuration
40+
* @param symbolSearch symbol search
41+
*/
42+
final class InferredMethodProvider(
43+
params: OffsetParams,
44+
driver: InteractiveDriver,
45+
config: PresentationCompilerConfig,
46+
symbolSearch: SymbolSearch
47+
)(using ReportContext):
48+
49+
case class AdjustTypeOpts(
50+
text: String,
51+
adjustedEndPos: l.Position
52+
)
53+
54+
def inferredMethodEdits(
55+
adjustOpt: Option[AdjustTypeOpts] = None
56+
): List[TextEdit] =
57+
val uri = params.uri().nn
58+
val filePath = Paths.get(uri).nn
59+
60+
val sourceText = adjustOpt.map(_.text).getOrElse(params.text().nn)
61+
val source =
62+
SourceFile.virtual(filePath.toString(), sourceText)
63+
driver.run(uri, source)
64+
val unit = driver.currentCtx.run.nn.units.head
65+
val pos = driver.sourcePosition(params)
66+
val path =
67+
Interactive.pathTo(driver.openedTrees(uri), pos)(using driver.currentCtx)
68+
69+
given locatedCtx: Context = driver.localContext(params)
70+
val indexedCtx = IndexedContext(pos)(using locatedCtx)
71+
72+
val autoImportsGen = AutoImports.generator(
73+
pos,
74+
sourceText,
75+
unit.tpdTree,
76+
unit.comments,
77+
indexedCtx,
78+
config
79+
)
80+
81+
val printer = ShortenedTypePrinter(
82+
symbolSearch,
83+
includeDefaultParam = IncludeDefaultParam.ResolveLater,
84+
isTextEdit = true
85+
)(using indexedCtx)
86+
87+
def imports: List[TextEdit] =
88+
printer.imports(autoImportsGen)
89+
90+
def printType(tpe: Type): String =
91+
printer.tpe(tpe)
92+
93+
def printName(name: Name): String =
94+
printer.nameString(name)
95+
96+
def printParams(params: List[Type]): String =
97+
params.zipWithIndex
98+
.map { case (p, index) =>
99+
s"arg$index: ${printType(p)}"
100+
}
101+
.mkString(", ")
102+
103+
def printSignature(
104+
methodName: Name,
105+
params: Option[List[Type]],
106+
retTypeOpt: Option[Type]
107+
): String =
108+
val retTypeString = retTypeOpt match
109+
case Some(retType) =>
110+
val printRetType = printType(retType)
111+
if printRetType contains "Any" then ""
112+
else s": $printRetType"
113+
case _ => ""
114+
115+
val paramsString = params match
116+
case Some(params) => s"(${printParams(params)})"
117+
case None => ""
118+
119+
s"def ${printName(methodName)}${paramsString}$retTypeString = ???"
120+
121+
@tailrec
122+
def countIndent(text: String, index: Int, acc: Int): Int =
123+
if text(index) != '\n' then countIndent(text, index - 1, acc + 1)
124+
else acc
125+
126+
def indentation(text: String, pos: Int): String =
127+
if pos > 0 then
128+
val isSpace = text(pos) == ' '
129+
val isTab = text(pos) == '\t'
130+
val indent = countIndent(params.text(), pos, 0)
131+
132+
if isSpace then " " * indent else if isTab then "\t" * indent else ""
133+
else ""
134+
135+
def insertPosition() =
136+
val blockOrTemplateIndex =
137+
path.tail.indexWhere {
138+
case _: Block | _: Template => true
139+
case _ => false
140+
}
141+
path(blockOrTemplateIndex).sourcePos
142+
143+
/**
144+
* Returns the position to insert the method signature for a container.
145+
* If the container has an empty body, the position is the end of the container.
146+
* If the container has a non-empty body, the position is the end of the last element in the body.
147+
*
148+
* @param container the container to insert the method signature for
149+
* @return the position to insert the method signature for the container and a boolean indicating if the container has an empty body
150+
*/
151+
def insertPositionFor(container: Tree): Option[(SourcePosition, Boolean)] =
152+
val typeSymbol = container.tpe.widenDealias.typeSymbol
153+
if typeSymbol.exists then
154+
val trees = driver.openedTrees(params.uri().nn)
155+
val include = Interactive.Include.definitions | Interactive.Include.local
156+
Interactive.findTreesMatching(trees, include, typeSymbol).headOption match
157+
case Some(srcTree) =>
158+
srcTree.tree match
159+
case classDef: TypeDef if classDef.rhs.isInstanceOf[Template] =>
160+
val template = classDef.rhs.asInstanceOf[Template]
161+
val (pos, hasEmptyBody) = template.body.lastOption match
162+
case Some(last) => (last.sourcePos, false)
163+
case None => (classDef.sourcePos, true)
164+
Some((pos, hasEmptyBody))
165+
case _ => None
166+
case None => None
167+
else None
168+
169+
/**
170+
* Extracts type information for a specific parameter in a method signature.
171+
* If the parameter is a function type, extracts both the function's argument types
172+
* and return type. Otherwise, extracts just the parameter type.
173+
*
174+
* @param methodType the method type to analyze
175+
* @param argIndex the index of the parameter to extract information for
176+
* @return a tuple of (argument types, return type) where:
177+
* - argument types: Some(List[Type]) if parameter is a function, None otherwise
178+
* - return type: Some(Type) representing either the function's return type or the parameter type itself
179+
*/
180+
def extractParameterTypeInfo(methodType: Type, argIndex: Int): (Option[List[Type]], Option[Type]) =
181+
methodType match
182+
case m @ MethodType(param) =>
183+
val expectedFunctionType = m.paramInfos(argIndex)
184+
if defn.isFunctionType(expectedFunctionType) then
185+
expectedFunctionType match
186+
case defn.FunctionOf(argTypes, retType, _) =>
187+
(Some(argTypes), Some(retType))
188+
case _ =>
189+
(None, Some(expectedFunctionType))
190+
else
191+
(None, Some(m.paramInfos(argIndex)))
192+
case _ => (None, None)
193+
194+
def signatureEdits(signature: String): List[TextEdit] =
195+
val pos = insertPosition()
196+
val indent = indentation(params.text(), pos.start - 1)
197+
val lspPos = pos.toLsp
198+
lspPos.setEnd(lspPos.getStart())
199+
200+
List(
201+
TextEdit(
202+
lspPos,
203+
s"$signature\n$indent",
204+
)
205+
) ::: imports
206+
207+
def signatureEditsForContainer(signature: String, container: Tree): List[TextEdit] =
208+
insertPositionFor(container) match
209+
case Some((pos, hasEmptyBody)) =>
210+
val lspPos = pos.toLsp
211+
lspPos.setStart(lspPos.getEnd())
212+
213+
if hasEmptyBody then
214+
List(
215+
TextEdit(
216+
lspPos,
217+
s":\n $signature",
218+
)
219+
) ::: imports
220+
else
221+
val indent = indentation(params.text(), pos.start - 1)
222+
List(
223+
TextEdit(
224+
lspPos,
225+
s"\n$indent$signature",
226+
)
227+
) ::: imports
228+
case None => Nil
229+
230+
path match
231+
/**
232+
* outerArgs
233+
* ---------------------------
234+
* method(..., errorMethod(args), ...)
235+
*
236+
*/
237+
case (id @ Ident(errorMethod)) ::
238+
(apply @ Apply(func, args)) ::
239+
Apply(method, outerArgs) ::
240+
_ if id.symbol == NoSymbol && func == id && method != apply =>
241+
242+
val argIndex = outerArgs.indexOf(apply)
243+
val (argTypes, retTypeOpt) = extractParameterTypeInfo(method.tpe.widenDealias, argIndex)
244+
245+
val allArgTypes = args.map(_.typeOpt.widenDealias) ::: argTypes.getOrElse(Nil)
246+
val signature = printSignature(errorMethod, Some(allArgTypes), retTypeOpt)
247+
248+
signatureEdits(signature)
249+
250+
/**
251+
* outerArgs
252+
* ---------------------
253+
* method(..., errorMethod, ...)
254+
*
255+
*/
256+
case (id @ Ident(errorMethod)) ::
257+
Apply(method, outerArgs) ::
258+
_ if id.symbol == NoSymbol && method != id =>
259+
260+
val argIndex = outerArgs.indexOf(id)
261+
262+
val (argTypes, retTypeOpt) = extractParameterTypeInfo(method.tpe.widenDealias, argIndex)
263+
val signature = printSignature(errorMethod, argTypes, retTypeOpt)
264+
265+
signatureEdits(signature)
266+
267+
/**
268+
* tpt body
269+
* ----------- ----------------
270+
* val value: DefinedType = errorMethod(args)
271+
*
272+
*/
273+
case (id @ Ident(errorMethod)) ::
274+
(apply @ Apply(func, args)) ::
275+
ValDef(_, tpt, body) ::
276+
_ if id.symbol == NoSymbol && func == id && apply == body =>
277+
278+
val retType = tpt.tpe.widenDealias
279+
val argTypes = args.map(_.typeOpt.widenDealias)
280+
281+
val signature = printSignature(errorMethod, Some(argTypes), Some(retType))
282+
signatureEdits(signature)
283+
284+
/**
285+
* tpt body
286+
* ----------- -----------
287+
* val value: DefinedType = errorMethod
288+
*
289+
*/
290+
case (id @ Ident(errorMethod)) ::
291+
ValDef(_, tpt, body) ::
292+
_ if id.symbol == NoSymbol && id == body =>
293+
294+
val retType = tpt.tpe.widenDealias
295+
296+
val signature = printSignature(errorMethod, None, Some(retType))
297+
signatureEdits(signature)
298+
299+
/**
300+
*
301+
* errorMethod(args)
302+
*
303+
*/
304+
case (id @ Ident(errorMethod)) ::
305+
(apply @ Apply(func, args)) ::
306+
_ if id.symbol == NoSymbol && func == id =>
307+
308+
val argTypes = args.map(_.typeOpt.widenDealias)
309+
310+
val signature = printSignature(errorMethod, Some(argTypes), None)
311+
signatureEdits(signature)
312+
313+
/**
314+
*
315+
* errorMethod
316+
*
317+
*/
318+
case (id @ Ident(errorMethod)) ::
319+
_ if id.symbol == NoSymbol =>
320+
321+
val signature = printSignature(errorMethod, None, None)
322+
signatureEdits(signature)
323+
324+
/**
325+
*
326+
* container.errorMethod(args)
327+
*
328+
*/
329+
case (select @ Select(container, errorMethod)) ::
330+
(apply @ Apply(func, args)) ::
331+
_ if select.symbol == NoSymbol && func == select =>
332+
333+
val argTypes = args.map(_.typeOpt.widenDealias)
334+
val signature = printSignature(errorMethod, Some(argTypes), None)
335+
signatureEditsForContainer(signature, container)
336+
337+
/**
338+
*
339+
* container.errorMethod
340+
*
341+
*/
342+
case (select @ Select(container, errorMethod)) ::
343+
_ if select.symbol == NoSymbol =>
344+
345+
val signature = printSignature(errorMethod, None, None)
346+
signatureEditsForContainer(signature, container)
347+
348+
case _ => Nil
349+
350+
end inferredMethodEdits
351+
end InferredMethodProvider

presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ case class ScalaPresentationCompiler(
6464
CodeActionId.ExtractMethod,
6565
CodeActionId.InlineValue,
6666
CodeActionId.InsertInferredType,
67+
CodeActionId.InsertInferredMethod,
6768
PcConvertToNamedLambdaParameters.codeActionId
6869
).asJava
6970

@@ -92,6 +93,8 @@ case class ScalaPresentationCompiler(
9293
implementAbstractMembers(params)
9394
case (CodeActionId.InsertInferredType, _) =>
9495
insertInferredType(params)
96+
case (CodeActionId.InsertInferredMethod, _) =>
97+
insertInferredMethod(params)
9598
case (CodeActionId.InlineValue, _) =>
9699
inlineValue(params)
97100
case (CodeActionId.ExtractMethod, Some(extractionPos: OffsetParams)) =>
@@ -352,6 +355,19 @@ case class ScalaPresentationCompiler(
352355
.asJava
353356
}(params.toQueryContext)
354357

358+
def insertInferredMethod(
359+
params: OffsetParams
360+
): CompletableFuture[ju.List[l.TextEdit]] =
361+
val empty: ju.List[l.TextEdit] = new ju.ArrayList[l.TextEdit]()
362+
compilerAccess.withNonInterruptableCompiler(
363+
empty,
364+
params.token()
365+
) { pc =>
366+
new InferredMethodProvider(params, pc.compiler(), config, search)
367+
.inferredMethodEdits()
368+
.asJava
369+
}(params.toQueryContext)
370+
355371
override def inlineValue(
356372
params: OffsetParams
357373
): CompletableFuture[ju.List[l.TextEdit]] =

0 commit comments

Comments
 (0)