Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
59e29d8
Implement basic version of desugaring context bounds for poly functions
KacperFKorban Sep 24, 2024
d97ddd6
Handle named context bounds in poly function context bound desugaring
KacperFKorban Sep 24, 2024
34827b1
Correctly-ish desugar poly function context bounds in function types
KacperFKorban Sep 24, 2024
c8399ea
Fix pickling issue
KacperFKorban Sep 24, 2024
76455c0
Hide context bounds expansion for poly functions under modularity fea…
KacperFKorban Sep 24, 2024
7da2270
Small cleanup
KacperFKorban Sep 25, 2024
a2c6c4b
Add more test cases
KacperFKorban Sep 25, 2024
930d420
Change the implementation of context bound expansion for poly functio…
KacperFKorban Oct 2, 2024
019c6cf
Add support for some type aliases, when expanding context bounds for …
KacperFKorban Oct 3, 2024
eae738e
Make the expandion of context bounds for poly types slightly more ele…
KacperFKorban Oct 4, 2024
a679f11
Add more aliases tests for context bounds with poly functions
KacperFKorban Oct 8, 2024
341f643
Bring back the restriction for requiring value parameters in poly fun…
KacperFKorban Oct 14, 2024
8178cf4
Cleanup dead code
KacperFKorban Oct 18, 2024
38b7785
Reuse addEvidenceParams logic, but no aliases
KacperFKorban Nov 14, 2024
682691c
Cleanup context bounds for poly functions implementation, make the im…
KacperFKorban Nov 14, 2024
96e07b2
More cleanup of poly context bound desugaring
KacperFKorban Nov 14, 2024
44577f6
Short circuit adding evidence params to poly functions, when there ar…
KacperFKorban Nov 14, 2024
1de5b3e
Add a run test for poly context bounds; cleanup typer changes
KacperFKorban Nov 15, 2024
386d83d
Cleanup context bounds for poly functions implementation after review
KacperFKorban Nov 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 139 additions & 63 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ object desugar {
*/
val ContextBoundParam: Property.Key[Unit] = Property.StickyKey()

/** Marks a poly fcuntion apply method, so that we can handle adding evidence parameters to them in a special way
*/
val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey()

/** What static check should be applied to a Match? */
enum MatchCheck {
case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom
Expand Down Expand Up @@ -242,7 +246,7 @@ object desugar {
* def f$default$2[T](x: Int) = x + "m"
*/
private def defDef(meth: DefDef, isPrimaryConstructor: Boolean = false)(using Context): Tree =
addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor))
addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor).asInstanceOf[DefDef])

/** Drop context bounds in given TypeDef, replacing them with evidence ValDefs that
* get added to a buffer.
Expand Down Expand Up @@ -304,10 +308,8 @@ object desugar {
tdef1
end desugarContextBounds

private def elimContextBounds(meth: DefDef, isPrimaryConstructor: Boolean)(using Context): DefDef =
val DefDef(_, paramss, tpt, rhs) = meth
def elimContextBounds(meth: Tree, isPrimaryConstructor: Boolean = false)(using Context): Tree =
val evidenceParamBuf = mutable.ListBuffer[ValDef]()

var seenContextBounds: Int = 0
def freshName(unused: Tree) =
seenContextBounds += 1 // Start at 1 like FreshNameCreator.
Expand All @@ -317,7 +319,7 @@ object desugar {
// parameters of the method since shadowing does not affect
// implicit resolution in Scala 3.

val paramssNoContextBounds =
def paramssNoContextBounds(paramss: List[ParamClause]): List[ParamClause] =
val iflag = paramss.lastOption.flatMap(_.headOption) match
case Some(param) if param.mods.isOneOf(GivenOrImplicit) =>
param.mods.flags & GivenOrImplicit
Expand All @@ -329,15 +331,32 @@ object desugar {
tparam => desugarContextBounds(tparam, evidenceParamBuf, flags, freshName, paramss)
}(identity)

rhs match
case MacroTree(call) =>
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
case _ =>
addEvidenceParams(
cpy.DefDef(meth)(
name = normalizeName(meth, tpt).asTermName,
paramss = paramssNoContextBounds),
evidenceParamBuf.toList)
meth match
case meth @ DefDef(_, paramss, tpt, rhs) =>
val newParamss = paramssNoContextBounds(paramss)
rhs match
case MacroTree(call) =>
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
case _ =>
addEvidenceParams(
cpy.DefDef(meth)(
name = normalizeName(meth, tpt).asTermName,
paramss = newParamss
),
evidenceParamBuf.toList
)
case meth @ PolyFunction(tparams, fun) =>
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = meth: @unchecked
val Function(vparams: List[untpd.ValDef] @unchecked, rhs) = fun: @unchecked
val newParamss = paramssNoContextBounds(tparams :: vparams :: Nil)
val params = evidenceParamBuf.toList
if params.isEmpty then
meth
else
val boundNames = getBoundNames(params, newParamss)
val recur = fitEvidenceParams(params, nme.apply, boundNames)
val (paramsFst, paramsSnd) = recur(newParamss)
functionsOf((paramsFst ++ paramsSnd).filter(_.nonEmpty), rhs)
end elimContextBounds

def addDefaultGetters(meth: DefDef)(using Context): Tree =
Expand Down Expand Up @@ -465,6 +484,74 @@ object desugar {
case _ =>
(Nil, tree)

private def referencesName(vdef: ValDef, names: Set[TermName])(using Context): Boolean =
vdef.tpt.existsSubTree:
case Ident(name: TermName) => names.contains(name)
case _ => false

/** Fit evidence `params` into the `mparamss` parameter lists, making sure
* that all parameters referencing `params` are after them.
* - for methods the final parameter lists are := result._1 ++ result._2
* - for poly functions, each element of the pair contains at most one term
* parameter list
*
* @param params the evidence parameters list that should fit into `mparamss`
* @param methName the name of the method that `mparamss` belongs to
* @param boundNames the names of the evidence parameters
* @param mparamss the original parameter lists of the method
* @return a pair of parameter lists containing all parameter lists in a
* reference-correct order; make sure that `params` is always at the
* intersection of the pair elements; this is relevant, for poly functions
* where `mparamss` is guaranteed to have exectly one term parameter list,
* then each pair element will have at most one term parameter list
*/
private def fitEvidenceParams(
params: List[ValDef],
methName: Name,
boundNames: Set[TermName]
)(mparamss: List[ParamClause])(using Context): (List[ParamClause], List[ParamClause]) = mparamss match
case ValDefs(mparams) :: _ if mparams.exists(referencesName(_, boundNames)) =>
(params :: Nil) -> mparamss
case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) =>
val normParams =
if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then
params.map: param =>
val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit))
param.withMods(param.mods.withFlags(normFlags))
.showing(i"adapted param $result ${result.mods.flags} for ${methName}", Printers.desugar)
else params
((normParams ++ mparams) :: Nil) -> Nil
case mparams :: mparamss1 =>
val (fst, snd) = fitEvidenceParams(params, methName, boundNames)(mparamss1)
(mparams :: fst) -> snd
case Nil =>
Nil -> (params :: Nil)

/** Create a chain of possibly contextual functions from the parameter lists */
private def functionsOf(paramss: List[ParamClause], rhs: Tree)(using Context): Tree = paramss match
case Nil => rhs
case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) =>
val paramTpts = head.map(_.tpt)
val paramNames = head.map(_.name)
val paramsErased = head.map(_.mods.flags.is(Erased))
makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span)
case ValDefs(head) :: rest =>
Function(head, functionsOf(rest, rhs))
case TypeDefs(head) :: rest =>
PolyFunction(head, functionsOf(rest, rhs))
case _ =>
assert(false, i"unexpected paramss $paramss")
EmptyTree

private def getBoundNames(params: List[ValDef], paramss: List[ParamClause])(using Context): Set[TermName] =
var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
for mparams <- paramss; mparam <- mparams do
mparam match
case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) =>
boundNames += tparam.name.toTermName
case _ =>
boundNames

/** Add all evidence parameters in `params` as implicit parameters to `meth`.
* The position of the added parameters is determined as follows:
*
Expand All @@ -479,36 +566,23 @@ object desugar {
private def addEvidenceParams(meth: DefDef, params: List[ValDef])(using Context): DefDef =
if params.isEmpty then return meth

var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
for mparams <- meth.paramss; mparam <- mparams do
mparam match
case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) =>
boundNames += tparam.name.toTermName
case _ =>
val boundNames = getBoundNames(params, meth.paramss)

def referencesBoundName(vdef: ValDef): Boolean =
vdef.tpt.existsSubTree:
case Ident(name: TermName) => boundNames.contains(name)
case _ => false
val fitParams = fitEvidenceParams(params, meth.name, boundNames)

def recur(mparamss: List[ParamClause]): List[ParamClause] = mparamss match
case ValDefs(mparams) :: _ if mparams.exists(referencesBoundName) =>
params :: mparamss
case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) =>
val normParams =
if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then
params.map: param =>
val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit))
param.withMods(param.mods.withFlags(normFlags))
.showing(i"adapted param $result ${result.mods.flags} for ${meth.name}", Printers.desugar)
else params
(normParams ++ mparams) :: Nil
case mparams :: mparamss1 =>
mparams :: recur(mparamss1)
case Nil =>
params :: Nil

cpy.DefDef(meth)(paramss = recur(meth.paramss))
if meth.removeAttachment(PolyFunctionApply).isDefined then
// for PolyFunctions we are limited to a single term param list, so we
// reuse the fitEvidenceParams logic to compute the new parameter lists
// and then we add the other parameter lists as function types to the
// return type
val (paramsFst, paramsSnd) = fitParams(meth.paramss)
if ctx.mode.is(Mode.Type) then
cpy.DefDef(meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt))
else
cpy.DefDef(meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs))
else
val (paramsFst, paramsSnd) = fitParams(meth.paramss)
cpy.DefDef(meth)(paramss = paramsFst ++ paramsSnd)
end addEvidenceParams

/** The parameters generated from the contextual bounds of `meth`, as generated by `desugar.defDef` */
Expand Down Expand Up @@ -1224,27 +1298,29 @@ object desugar {
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
*/
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
val paramFlags = fun match
case fun: FunctionWithMods =>
// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)

// Function flags to be propagated to each parameter in the desugared method type.
val givenFlag = fun.mods.flags.toTermFlags & Given
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
case _ =>
vparamTypes.map(_ => EmptyFlags)

val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
}.toList

RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
)).withSpan(tree.span)
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = (tree: @unchecked) match
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) =>
val paramFlags = fun match
case fun: FunctionWithMods =>
// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)

// Function flags to be propagated to each parameter in the desugared method type.
val givenFlag = fun.mods.flags.toTermFlags & Given
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
case _ =>
vparamTypes.map(_ => EmptyFlags)

val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
}.toList

RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
.withFlags(Synthetic)
.withAttachment(PolyFunctionApply, ())
)).withSpan(tree.span)
end makePolyFunctionType

/** Invent a name for an anonympus given of type or template `impl`. */
Expand Down
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3460,7 +3460,7 @@ object Parsers {
*
* TypTypeParamClause::= ‘[’ TypTypeParam {‘,’ TypTypeParam} ‘]’
* TypTypeParam ::= {Annotation}
* (id | ‘_’) [HkTypeParamClause] TypeBounds
* (id | ‘_’) [HkTypeParamClause] TypeAndCtxBounds
*
* HkTypeParamClause ::= ‘[’ HkTypeParam {‘,’ HkTypeParam} ‘]’
* HkTypeParam ::= {Annotation} [‘+’ | ‘-’]
Expand Down Expand Up @@ -3491,7 +3491,9 @@ object Parsers {
else ident().toTypeName
val hkparams = typeParamClauseOpt(ParamOwner.Hk)
val bounds =
if paramOwner.acceptsCtxBounds then typeAndCtxBounds(name) else typeBounds()
if paramOwner.acceptsCtxBounds then typeAndCtxBounds(name)
else if in.featureEnabled(Feature.modularity) && paramOwner == ParamOwner.Type then typeAndCtxBounds(name)
else typeBounds()
TypeDef(name, lambdaAbstract(hkparams, bounds)).withMods(mods)
}
}
Expand Down
5 changes: 2 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1917,7 +1917,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
val tree1 = desugar.normalizePolyFunction(tree)
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
else typedPolyFunctionValue(tree1, pt)
else typedPolyFunctionValue(desugar.elimContextBounds(tree1).asInstanceOf[untpd.PolyFunction], pt)

def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
Expand Down Expand Up @@ -2471,7 +2471,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val TypeDef(_, impl: Template) = typed(refineClsDef): @unchecked
val refinements1 = impl.body
val seen = mutable.Set[Symbol]()
for (refinement <- refinements1) { // TODO: get clarity whether we want to enforce these conditions
for refinement <- refinements1 do // TODO: get clarity whether we want to enforce these conditions
typr.println(s"adding refinement $refinement")
checkRefinementNonCyclic(refinement, refineCls, seen)
val rsym = refinement.symbol
Expand All @@ -2485,7 +2485,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val member = refineCls.info.member(rsym.name)
if (member.isOverloaded)
report.error(OverloadInRefinement(rsym), refinement.srcPos)
}
assignType(cpy.RefinedTypeTree(tree)(tpt1, refinements1), tpt1, refinements1, refineCls)
}

Expand Down
Loading
Loading