Skip to content

Commit 663e750

Browse files
Extension copy (#262)
Co-authored-by: Kacper Korban <[email protected]>
1 parent 7cf831e commit 663e750

File tree

5 files changed

+247
-64
lines changed

5 files changed

+247
-64
lines changed

build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ excludeLintKeys in Global ++= Set(ideSkipProject)
1414
val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq(
1515
organization := "com.softwaremill.quicklens",
1616
updateDocs := UpdateVersionInDocs(sLog.value, organization.value, version.value, List(file("README.md"))),
17-
scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked"), // useful for debugging macros: "-Ycheck:all"
17+
scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked"), // useful for debugging macros: "-Ycheck:all", "-Xcheck-macros"
1818
ideSkipProject := (scalaVersion.value != scalaIdeaVersion)
1919
)
2020

quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala

Lines changed: 133 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ object QuicklensMacros {
5858
def noSuchMember(tpeStr: String, name: String) =
5959
s"$tpeStr has no member named $name"
6060

61+
def noSuitableMember(tpeStr: String, name: String, argNames: Iterable[String]) =
62+
s"$tpeStr has no member $name with parameters ${argNames.mkString("(", ", ", ")")}"
63+
6164
def multipleMatchingMethods(tpeStr: String, name: String, syms: Seq[Symbol]) =
6265
val symsStr = syms.map(s => s" - $s: ${s.termRef.dealias.widen.show}").mkString("\n", "\n", "")
6366
s"Multiple methods named $name found in $tpeStr: $symsStr"
@@ -109,11 +112,14 @@ object QuicklensMacros {
109112
case (symbol :: tail) => PathTree.Node(Seq(symbol -> Seq(tail.toPathTree)))
110113

111114
enum PathSymbol:
112-
case Field(name: String)
113-
case FunctionDelegate(name: String, givn: Term, typeTree: TypeTree, args: List[Term])
115+
case Field(override val name: String)
116+
case Extension(term: Term, override val name: String)
117+
case FunctionDelegate(override val name: String, givn: Term, typeTree: TypeTree, args: List[Term])
118+
def name: String
114119

115120
def equiv(other: Any): Boolean = (this, other) match
116121
case (Field(name1), Field(name2)) => name1 == name2
122+
case (Extension(term1, name1), Extension(term2, name2)) => term1 == term2 && name1 == name2
117123
case (FunctionDelegate(name1, _, typeTree1, args1), FunctionDelegate(name2, _, typeTree2, args2)) =>
118124
name1 == name2 && typeTree1.tpe == typeTree2.tpe && args1 == args2
119125
case _ => false
@@ -133,6 +139,9 @@ object QuicklensMacros {
133139
/** Method call with one type parameter and using clause */
134140
case a @ Apply(TypeApply(Apply(TypeApply(Ident(s), _), idents), typeTrees), List(givn)) if methodSupported(s) =>
135141
idents.flatMap(toPath(_, focus)) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
142+
/** Extension method, which is called e.g. as x(_$1) */
143+
case Apply(obj@Select(term, member), Seq(deep)) if obj.symbol.flags.is(Flags.ExtensionMethod) =>
144+
toPath(deep, focus) :+ PathSymbol.Extension(term, member)
136145
/** Field access */
137146
case Apply(deep, idents) =>
138147
toPath(deep, focus) ++ idents.flatMap(toPath(_, focus))
@@ -157,43 +166,104 @@ object QuicklensMacros {
157166
def matchingTypeSymbol: Symbol = tpe.widenAll match {
158167
case AndType(l, r) =>
159168
val lSym = l.matchingTypeSymbol
160-
if l.matchingTypeSymbol != Symbol.noSymbol then lSym else r.matchingTypeSymbol
161-
case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) =>
162-
tpe.typeSymbol
163-
case tpe if isProductLike(tpe.typeSymbol) =>
169+
if lSym != Symbol.noSymbol then lSym else r.matchingTypeSymbol
170+
case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) || isProductLike(tpe.typeSymbol) =>
164171
tpe.typeSymbol
165172
case _ =>
166173
Symbol.noSymbol
167174
}
168175

169-
def symbolAccessorByNameOrError(sym: Symbol, name: String): Symbol = {
170-
val mem = sym.fieldMember(name)
171-
if mem != Symbol.noSymbol then mem
172-
else methodSymbolByNameOrError(sym, name)
176+
extension (term: Term)
177+
def appliedToIfNeeded(args: List[Term]): Term =
178+
if args.isEmpty then term else term.appliedToArgs(args)
179+
180+
def symbolAccessorByNameOrError(obj: Term, name: String): Term = {
181+
val objTpe = obj.tpe.widenAll
182+
val objSymbol = objTpe.matchingTypeSymbol
183+
// opaque types can find members of underlying types - ignore them (see https://github.com/scala/scala3/issues/22143)
184+
val fieldMemberSym = objSymbol.fieldMember(name)
185+
if !objSymbol.flags.is(Flags.Deferred) && fieldMemberSym.exists then
186+
Select(obj, fieldMemberSym)
187+
else
188+
objSymbol.methodMember(name) match
189+
case List(m) =>
190+
Select(obj, m)
191+
case lst =>
192+
report.errorAndAbort(reportMethodError(objSymbol, name, lst))
193+
}
194+
195+
def reportMethodError(sym: Symbol, name: String, lst: List[Symbol], maybeArgNames: Option[Iterable[String]] = None): String = {
196+
(lst, maybeArgNames) match
197+
case (Nil, _) => noSuchMember(sym.name, name)
198+
case (lst, None) => multipleMatchingMethods(sym.name, name, lst)
199+
case (lst, Some(argNames)) => noSuitableMember(sym.name, name, argNames)
173200
}
174201

175202
def methodSymbolByNameOrError(sym: Symbol, name: String): Symbol = {
176203
sym.methodMember(name) match
177204
case List(m) => m
178-
case Nil => report.errorAndAbort(noSuchMember(sym.name, name))
179-
case lst => report.errorAndAbort(multipleMatchingMethods(sym.name, name, lst))
205+
case lst => report.errorAndAbort(reportMethodError(sym, name, lst))
180206
}
181207

182-
def methodSymbolByNameAndArgsOrError(sym: Symbol, name: String, argsMap: Map[String, Term]): Symbol = {
208+
def filterMethodsByNameAndArgs(allMethods: List[Symbol], argsMap: Map[String, Term]): Option[Symbol] = {
183209
val argNames = argsMap.keys
184-
sym.methodMember(name).filter{ msym =>
210+
allMethods.filter { msym =>
185211
// for copy, we filter out the methods that don't have the desired parameter names
186212
val paramNames = msym.paramSymss.flatten.filter(_.isTerm).map(_.name)
187213
argNames.forall(paramNames.contains)
188214
} match
189-
case List(m) => m
190-
case Nil => report.errorAndAbort(noSuchMember(sym.name, name))
191-
case lst @ (m :: _) =>
215+
case List(m) => Some(m)
216+
case Nil => None
217+
case lst@(m :: _) =>
192218
// if we have multiple matching copy methods, pick the synthetic one, if it exists, otherwise, pick any method
193219
val syntheticCopies = lst.filter(_.flags.is(Flags.Synthetic))
194220
syntheticCopies match
195-
case List(mSynth) => mSynth
196-
case _ => m
221+
case List(mSynth) => Some(mSynth)
222+
case _ => Some(m)
223+
}
224+
225+
def methodSymbolByNameAndArgs(sym: Symbol, name: String, argsMap: Map[String, Term]): Either[String, Symbol] = {
226+
if !sym.flags.is(Flags.Deferred) then
227+
val memberMethods = sym.methodMember(name)
228+
filterMethodsByNameAndArgs(memberMethods, argsMap)
229+
.toRight(reportMethodError(sym, name, memberMethods, Some(argsMap.keys)))
230+
else Left(s"Deferred type ${sym.name}")
231+
}
232+
233+
/**
234+
* @param argsMap normal methods receive one parameter list, extensions methods two, the first one contains the value
235+
* on which the extension is called
236+
* */
237+
def callMethod(obj: Term, copy: Symbol, argsMap: List[Map[String, Term]]) = {
238+
require(argsMap.size == 1 || argsMap.size == 2, s"argsMap.size should be either 1 or 2, got: ${argsMap.size} ($argsMap)")
239+
val objTpe = obj.tpe.widenAll
240+
val objSymbol = objTpe.matchingTypeSymbol
241+
242+
val typeParams = objTpe.typeArgs
243+
val copyTree: DefDef = copy.tree.asInstanceOf[DefDef]
244+
val copyParams: List[(String, Option[Term])] = copyTree.termParamss.zip(argsMap)
245+
.map((params, args) => params.params.map(_.name).map(name => name -> args.get(name)))
246+
.flatten.toList
247+
248+
val args = copyParams.zipWithIndex.map { case ((n, v), _i) =>
249+
val i = _i + 1
250+
def defaultMethod: Term =
251+
val methodSymbol = methodSymbolByNameOrError(objSymbol, copy.name + "$default$" + i.toString)
252+
// default values in extension methods take the extension receiver as the first parameter
253+
val defaultMethodArgs = argsMap.dropRight(1).flatMap(_.values)
254+
obj.select(methodSymbol).appliedToIfNeeded(defaultMethodArgs)
255+
n -> v.getOrElse(defaultMethod)
256+
}.toMap
257+
258+
val argLists: List[List[Term]] = copyTree.termParamss.take(argsMap.size).map(list => list.params.map(p => args(p.name)))
259+
260+
if copyTree.termParamss.drop(argLists.size).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
261+
report.errorAndAbort(
262+
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit. ${copyTree.termParamss.drop(1)}"
263+
)
264+
265+
val withTypeParamsApplied = obj.select(copy).appliedToTypes(typeParams)
266+
argLists.foldLeft(withTypeParamsApplied)(Apply(_, _))
197267
}
198268

199269
def termMethodByNameUnsafe(term: Term, name: String): Symbol = {
@@ -210,15 +280,32 @@ object QuicklensMacros {
210280
(sym.flags.is(Flags.Sealed) && (sym.flags.is(Flags.Trait) || sym.flags.is(Flags.Abstract)))
211281
}
212282

283+
def findCompanionLikeObject(objSymbol: Symbol): Symbol = {
284+
if objSymbol.companionModule.exists then
285+
objSymbol.companionModule
286+
else
287+
val namedFromOwnerScope = objSymbol.owner.fieldMember(objSymbol.name)
288+
if namedFromOwnerScope.flags.is(Flags.Module) then namedFromOwnerScope
289+
else Symbol.noSymbol
290+
}
291+
292+
def hasExtensionNamed(sym: Symbol, methodName: String): List[Symbol] = {
293+
val companionSymbol = findCompanionLikeObject(sym)
294+
if companionSymbol.exists then
295+
companionSymbol.methodMember(methodName).filter(s => s.name == methodName && s.flags.is(Flags.ExtensionMethod))
296+
else
297+
Nil
298+
}
299+
213300
def isProductLike(sym: Symbol): Boolean = {
214-
sym.methodMember("copy").size >= 1
301+
sym.methodMember("copy").nonEmpty || hasExtensionNamed(sym, "copy").nonEmpty
215302
}
216303

217304
def caseClassCopy(
218305
owner: Symbol,
219306
mod: Expr[A => A],
220307
obj: Term,
221-
fields: Seq[(PathSymbol.Field, Seq[PathTree])]
308+
fields: Seq[(PathSymbol.Field | PathSymbol.Extension, Seq[PathTree])]
222309
): Term = {
223310
val objTpe = obj.tpe.widenAll
224311
val objSymbol = objTpe.matchingTypeSymbol
@@ -248,50 +335,39 @@ object QuicklensMacros {
248335
}
249336

250337
val elseThrow = '{ throw new IllegalStateException() }.asTerm
338+
251339
ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) =>
252340
If(ifCond, ifThen, ifElse)
253341
}
254342
} else if isProduct(objSymbol) || isProductLike(objSymbol) then {
255343
val argsMap: Map[String, Term] = fields.map { (field, trees) =>
256-
val fieldMethod = symbolAccessorByNameOrError(objSymbol, field.name)
257-
val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) =>
344+
val fieldMethod = field match {
345+
case PathSymbol.Field(name) =>
346+
symbolAccessorByNameOrError(obj, name)
347+
case PathSymbol.Extension(term, name) =>
348+
val extensionMethod = symbolAccessorByNameOrError(term, name)
349+
Apply(extensionMethod, List(obj))
350+
}
351+
val resTerm: Term = trees.foldLeft[Term](fieldMethod) { (term, tree) =>
258352
mapToCopy(owner, mod, term, tree)
259353
}
260354
val namedArg = NamedArg(field.name, resTerm)
261355
field.name -> namedArg
262356
}.toMap
263-
val copy = methodSymbolByNameAndArgsOrError(objSymbol, "copy", argsMap)
264-
265-
val typeParams = objTpe match {
266-
case AppliedType(_, typeParams) => Some(typeParams)
267-
case _ => None
268-
}
269-
val copyTree: DefDef = copy.tree.asInstanceOf[DefDef]
270-
val copyParamNames: List[String] = copyTree.termParamss.headOption.map(_.params).toList.flatten.map(_.name)
271-
272-
val args = copyParamNames.zipWithIndex.map { (n, _i) =>
273-
val i = _i + 1
274-
val defaultMethod = obj.select(methodSymbolByNameOrError(objSymbol, "copy$default$" + i.toString))
275-
// for extension methods, might need sth more like this: (or probably some weird implicit conversion)
276-
// val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n))
277-
argsMap.getOrElse(
278-
n,
279-
defaultMethod
280-
)
281-
}.toList
282-
283-
if copyTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
284-
report.errorAndAbort(
285-
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit."
286-
)
287-
288-
typeParams match {
289-
// if the object's type is parametrised, we need to call .copy with the same type parameters
290-
case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args)
291-
case _ => Apply(Select(obj, copy), args)
292-
}
357+
methodSymbolByNameAndArgs(objSymbol, "copy", argsMap) match
358+
case Right(copy) =>
359+
callMethod(obj, copy, List(argsMap))
360+
case Left(error) =>
361+
val objCompanion = findCompanionLikeObject(objSymbol)
362+
methodSymbolByNameAndArgs(objCompanion, "copy", argsMap).toOption match
363+
case Some(copy) =>
364+
// now try to call the extension as a method, assume the object is its first parameter
365+
val extensionParameter = copy.paramSymss.headOption.map(_.headOption).flatten
366+
val argsWithObj = List(extensionParameter.map(name => name.name -> obj).toMap, argsMap)
367+
callMethod(Ref(objCompanion), copy, argsWithObj)
368+
case None => report.errorAndAbort(error)
293369
} else
294-
report.errorAndAbort(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol of type ${objTpe.show} (${obj.show})")
370+
report.errorAndAbort(s"Unsupported source object: must be a case class, sealed trait or class with copy method, but got: $objSymbol of type ${objTpe.show} (${obj.show})")
295371
}
296372

297373
def applyFunctionDelegate(
@@ -331,9 +407,9 @@ object QuicklensMacros {
331407
case Nil =>
332408
objTerm
333409

334-
case (_: PathSymbol.Field, _) :: _ =>
335-
val (fs, funs) = pathSymbols.span(_._1.isInstanceOf[PathSymbol.Field])
336-
val fields = fs.collect { case (p: PathSymbol.Field, trees) => p -> trees }
410+
case (_: (PathSymbol.Field | PathSymbol.Extension), _) :: _ =>
411+
val (fs, funs) = pathSymbols.span((ps, _) => ps.isInstanceOf[PathSymbol.Field] || ps.isInstanceOf[PathSymbol.Extension])
412+
val fields = fs.collect { case (p: (PathSymbol.Field | PathSymbol.Extension), trees) => p -> trees }
337413
val withCopiedFields: Term = caseClassCopy(owner, mod, objTerm, fields)
338414
accumulateToCopy(owner, mod, withCopiedFields, funs)
339415

quicklens/src/main/scala-3/com/softwaremill/quicklens/package.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ package object quicklens {
154154
def map[A](fa: M[A], f: A => A): M[A] = {
155155
val mapped = fa.view.mapValues(f)
156156
(fa match {
157-
case sfa: SortedMap[K, A] => sfa.sortedMapFactory.from(mapped)(using sfa.ordering)
157+
case sfa: SortedMap[K, A]@unchecked => sfa.sortedMapFactory.from(mapped)(using sfa.ordering)
158158
case _ => mapped.to(fa.mapFactory)
159159
}).asInstanceOf[M[A]]
160160
}

quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
package com.softwaremill.quicklens
2+
package test
23

34
import org.scalatest.flatspec.AnyFlatSpec
45
import org.scalatest.matchers.should.Matchers
@@ -33,7 +34,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
3334
def paths(paths: Paths): Docs = copy(paths = paths)
3435
}
3536
val docs = Docs()
36-
docs.modify(_.paths.pathItems).using(m => m + ("a" -> PathItem()))
37+
val r = docs.modify(_.paths.pathItems).using(m => m + ("a" -> PathItem()))
38+
r.paths.pathItems should contain ("a" -> PathItem())
3739
}
3840

3941
it should "modify a case class with an additional explicit copy" in {
@@ -42,7 +44,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
4244
}
4345

4446
val f = Frozen("A", 0)
45-
f.modify(_.state).setTo("B")
47+
val r = f.modify(_.state).setTo("B")
48+
r.state shouldEqual "B"
4649
}
4750

4851
it should "modify a case class with an ambiguous additional explicit copy" in {
@@ -51,7 +54,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
5154
}
5255

5356
val f = Frozen("A", 0)
54-
f.modify(_.state).setTo("B")
57+
val r = f.modify(_.state).setTo("B")
58+
r.state shouldEqual "B"
5559
}
5660

5761
it should "modify a class with two explicit copy methods" in {
@@ -61,7 +65,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
6165
}
6266

6367
val f = new Frozen("A", 0)
64-
f.modify(_.state).setTo("B")
68+
val r = f.modify(_.state).setTo("B")
69+
r.state shouldEqual "B"
6570
}
6671

6772
it should "modify a case class with an ambiguous additional explicit copy and pick the synthetic one first" in {
@@ -77,6 +82,19 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
7782
accessed shouldEqual 0
7883
}
7984

85+
it should "not compile when modifying a field which is not present as a copy parameter" in {
86+
"""
87+
case class Content(x: String)
88+
89+
class A(val c: Content) {
90+
def copy(x: String = c.x): A = new A(Content(x))
91+
}
92+
93+
val a = new A(Content("A"))
94+
val am = a.modify(_.c).setTo(Content("B"))
95+
""" shouldNot compile
96+
}
97+
8098
// TODO: Would be nice to be able to handle this case. Based on the types, it
8199
// is obvious, that the explicit copy should be picked, but I'm not sure if we
82100
// can get that information
@@ -90,5 +108,4 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
90108
// val f = Frozen("A", 0)
91109
// f.modify(_.state).setTo('B')
92110
// }
93-
94111
}

0 commit comments

Comments
 (0)