@@ -340,11 +340,13 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
340340 * @param fn the function
341341 * @param parts the function prefix followed by the flattened argument list
342342 * @param polyArg the clashing argument to a polymorphic formal
343- * @param clashing the argument with which it clashes
343+ * @param clashing the argument, function prefix, or entire function application result with
344+ * which it clashes,
345+ *
344346 */
345347 def sepApplyError (fn : Tree , parts : List [Tree ], polyArg : Tree , clashing : Tree )(using Context ): Unit =
346348 val polyArgIdx = parts.indexOf(polyArg).ensuring(_ >= 0 ) - 1
347- val clashIdx = parts.indexOf(clashing).ensuring(_ >= 0 )
349+ val clashIdx = parts.indexOf(clashing) // -1 means entire function application
348350 def paramName (mt : Type , idx : Int ): Option [Name ] = mt match
349351 case mt @ MethodType (pnames) =>
350352 if idx < pnames.length then Some (pnames(idx)) else paramName(mt.resType, idx - pnames.length)
@@ -363,11 +365,12 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
363365 if isShowableMethod then i " ${fn.symbol}: ${fn.symbol.info}"
364366 else i " a function of type ${funType.widen}"
365367 def clashArgStr = clashIdx match
366- case 0 => " function prefix"
367- case 1 => " first argument "
368- case 2 => " second argument"
369- case 3 => " third argument "
370- case n => s " ${n}th argument "
368+ case - 1 => " function result"
369+ case 0 => " function prefix"
370+ case 1 => " first argument "
371+ case 2 => " second argument"
372+ case 3 => " third argument "
373+ case n => s " ${n}th argument "
371374 def clashTypeStr =
372375 if clashIdx == 0 && ! isShowableMethod then " " // we already mentioned the type in `funStr`
373376 else i " with type ${clashing.nuType}"
@@ -455,11 +458,12 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
455458 *
456459 * @param fn the applied function
457460 * @param args the flattened argument lists
461+ * @param app the entire application tree
458462 * @param deps cross argument dependencies: maps argument trees to
459463 * those other arguments that where mentioned by coorresponding
460464 * formal parameters.
461465 */
462- private def checkApply (fn : Tree , args : List [Tree ], deps : collection.Map [Tree , List [Tree ]])(using Context ): Unit =
466+ private def checkApply (fn : Tree , args : List [Tree ], app : Tree , deps : collection.Map [Tree , List [Tree ]])(using Context ): Unit =
463467 val (qual, fnCaptures) = methPart(fn) match
464468 case Select (qual, _) => (qual, qual.nuType.captureSet)
465469 case _ => (fn, CaptureSet .empty)
@@ -511,6 +515,29 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
511515 currentPeaks = PeaksPair (
512516 currentPeaks.actual ++ argPeaks.actual,
513517 currentPeaks.hidden ++ argPeaks.hidden)
518+ end for
519+
520+ def collectRefs (args : List [Type ], res : Type ) =
521+ args.foldLeft(argCaptures(res)): (refs, arg) =>
522+ refs ++ arg.deepCaptureSet.elems
523+
524+ /** The deep capture sets of all parameters of this type (if it is a function type) */
525+ def argCaptures (tpe : Type ): Refs = tpe match
526+ case defn.FunctionOf (args, resultType, isContextual) =>
527+ collectRefs(args, resultType)
528+ case defn.RefinedFunctionOf (mt) =>
529+ collectRefs(mt.paramInfos, mt.resType)
530+ case CapturingType (parent, _) =>
531+ argCaptures(parent)
532+ case _ =>
533+ emptyRefs
534+
535+ if ! deps(app).isEmpty then
536+ lazy val appPeaks = argCaptures(app.nuType).peaks
537+ lazy val partPeaks = partsWithPeaks.toMap
538+ for arg <- deps(app) do
539+ if arg.needsSepCheck && ! partPeaks(arg).hidden.sharedWith(appPeaks).isEmpty then
540+ sepApplyError(fn, parts, arg, app)
514541 end checkApply
515542
516543 /** 1. Check that the capabilities used at `tree` don't overlap with
@@ -782,44 +809,55 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
782809 *
783810 * f(x: A, y: B^{cap, x}, z: C^{x, y}): D
784811 *
785- * then the dependencies of an application `f(a, b)` is a map that takes
786- * `b` to `List(a)` and `c` to `List(a, b)`.
812+ * then the dependencies of an application `f(a, b, c)` of type C^{y} is the map
813+ *
814+ * [ b -> [a]
815+ * , c -> [a, b]
816+ * , f(a, b, c) -> [b]]
787817 */
788- private def dependencies (fn : Tree , argss : List [List [Tree ]])(using Context ): collection.Map [Tree , List [Tree ]] =
818+ private def dependencies (fn : Tree , argss : List [List [Tree ]], app : Tree )(using Context ): collection.Map [Tree , List [Tree ]] =
819+ def isFunApply (sym : Symbol ) =
820+ sym.name == nme.apply && defn.isFunctionClass(sym.owner)
789821 val mtpe =
790- if fn.symbol.exists then fn.symbol.info
791- else fn.tpe .widen // happens for PolyFunction applies
822+ if fn.symbol.exists && ! isFunApply(fn.symbol) then fn.symbol.info
823+ else fn.nuType .widen
792824 val mtps = collectMethodTypes(mtpe)
793825 assert(mtps.hasSameLengthAs(argss), i " diff for $fn: ${fn.symbol} /// $mtps /// $argss" )
794826 val mtpsWithArgs = mtps.zip(argss)
795827 val argMap = mtpsWithArgs.toMap
796828 val deps = mutable.HashMap [Tree , List [Tree ]]().withDefaultValue(Nil )
797- for
798- (mt, args) <- mtpsWithArgs
799- (formal, arg) <- mt.paramInfos.zip(args)
800- dep <- formal.captureSet.elems.toList
801- do
802- val referred = dep.stripReach match
803- case dep : TermParamRef =>
804- argMap(dep.binder)(dep.paramNum) :: Nil
805- case dep : ThisType if dep.cls == fn.symbol.owner =>
806- val Select (qual, _) = fn : @ unchecked // TODO can we use fn instead?
807- qual :: Nil
808- case _ =>
809- Nil
810- deps(arg) ++= referred
829+
830+ def recordDeps (formal : Type , actual : Tree ) =
831+ for dep <- formal.captureSet.elems.toList do
832+ val referred = dep.stripReach match
833+ case dep : TermParamRef =>
834+ argMap(dep.binder)(dep.paramNum) :: Nil
835+ case dep : ThisType if dep.cls == fn.symbol.owner =>
836+ val Select (qual, _) = fn : @ unchecked // TODO can we use fn instead?
837+ qual :: Nil
838+ case _ =>
839+ Nil
840+ deps(actual) ++= referred
841+
842+ for (mt, args) <- mtpsWithArgs; (formal, arg) <- mt.paramInfos.zip(args) do
843+ recordDeps(formal, arg)
844+ recordDeps(mtpe.finalResultType, app)
845+ capt.println(i " deps for $app = ${deps.toList}" )
811846 deps
812847
848+
813849 /** Decompose an application into a function prefix and a list of argument lists.
814850 * If some of the arguments need a separation check because they are capture polymorphic,
815851 * perform a separation check with `checkApply`
816852 */
817- private def traverseApply (tree : Tree , argss : List [List [Tree ]])(using Context ): Unit = tree match
818- case Apply (fn, args) => traverseApply(fn, args :: argss)
819- case TypeApply (fn, args) => traverseApply(fn, argss) // skip type arguments
820- case _ =>
821- if argss.nestedExists(_.needsSepCheck) then
822- checkApply(tree, argss.flatten, dependencies(tree, argss))
853+ private def traverseApply (app : Tree )(using Context ): Unit =
854+ def recur (tree : Tree , argss : List [List [Tree ]]): Unit = tree match
855+ case Apply (fn, args) => recur(fn, args :: argss)
856+ case TypeApply (fn, args) => recur(fn, argss) // skip type arguments
857+ case _ =>
858+ if argss.nestedExists(_.needsSepCheck) then
859+ checkApply(tree, argss.flatten, app, dependencies(tree, argss, app))
860+ recur(app, Nil )
823861
824862 /** Is `tree` an application of `caps.unsafe.unsafeAssumeSeparate`? */
825863 def isUnsafeAssumeSeparate (tree : Tree )(using Context ): Boolean = tree match
@@ -866,7 +904,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
866904 traverseChildren(tree)
867905 tree.tpe match
868906 case _ : MethodOrPoly =>
869- case _ => traverseApply(tree, Nil )
907+ case _ => traverseApply(tree)
870908 case _ : Block | _ : Template =>
871909 traverseSection(tree)
872910 case tree : ValDef =>
0 commit comments