@@ -58,6 +58,9 @@ object QuicklensMacros {
5858 def noSuchMember (tpeStr : String , name : String ) =
5959 s " $tpeStr has no member named $name"
6060
61+ def noSuitableMember (tpeStr : String , name : String , argNames : Iterable [String ]) =
62+ s " $tpeStr has no member $name with parameters ${argNames.mkString(" (" , " , " , " )" )}"
63+
6164 def multipleMatchingMethods (tpeStr : String , name : String , syms : Seq [Symbol ]) =
6265 val symsStr = syms.map(s => s " - $s: ${s.termRef.dealias.widen.show}" ).mkString(" \n " , " \n " , " " )
6366 s " Multiple methods named $name found in $tpeStr: $symsStr"
@@ -109,11 +112,14 @@ object QuicklensMacros {
109112 case (symbol :: tail) => PathTree .Node (Seq (symbol -> Seq (tail.toPathTree)))
110113
111114 enum PathSymbol :
112- case Field (name : String )
113- case FunctionDelegate (name : String , givn : Term , typeTree : TypeTree , args : List [Term ])
115+ case Field (override val name : String )
116+ case Extension (term : Term , override val name : String )
117+ case FunctionDelegate (override val name : String , givn : Term , typeTree : TypeTree , args : List [Term ])
118+ def name : String
114119
115120 def equiv (other : Any ): Boolean = (this , other) match
116121 case (Field (name1), Field (name2)) => name1 == name2
122+ case (Extension (term1, name1), Extension (term2, name2)) => term1 == term2 && name1 == name2
117123 case (FunctionDelegate (name1, _, typeTree1, args1), FunctionDelegate (name2, _, typeTree2, args2)) =>
118124 name1 == name2 && typeTree1.tpe == typeTree2.tpe && args1 == args2
119125 case _ => false
@@ -133,6 +139,9 @@ object QuicklensMacros {
133139 /** Method call with one type parameter and using clause */
134140 case a @ Apply (TypeApply (Apply (TypeApply (Ident (s), _), idents), typeTrees), List (givn)) if methodSupported(s) =>
135141 idents.flatMap(toPath(_, focus)) :+ PathSymbol .FunctionDelegate (s, givn, typeTrees.last, List .empty)
142+ /** Extension method, which is called e.g. as x(_$1) */
143+ case Apply (obj@ Select (term, member), Seq (deep)) if obj.symbol.flags.is(Flags .ExtensionMethod ) =>
144+ toPath(deep, focus) :+ PathSymbol .Extension (term, member)
136145 /** Field access */
137146 case Apply (deep, idents) =>
138147 toPath(deep, focus) ++ idents.flatMap(toPath(_, focus))
@@ -157,43 +166,104 @@ object QuicklensMacros {
157166 def matchingTypeSymbol : Symbol = tpe.widenAll match {
158167 case AndType (l, r) =>
159168 val lSym = l.matchingTypeSymbol
160- if l.matchingTypeSymbol != Symbol .noSymbol then lSym else r.matchingTypeSymbol
161- case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) =>
162- tpe.typeSymbol
163- case tpe if isProductLike(tpe.typeSymbol) =>
169+ if lSym != Symbol .noSymbol then lSym else r.matchingTypeSymbol
170+ case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) || isProductLike(tpe.typeSymbol) =>
164171 tpe.typeSymbol
165172 case _ =>
166173 Symbol .noSymbol
167174 }
168175
169- def symbolAccessorByNameOrError (sym : Symbol , name : String ): Symbol = {
170- val mem = sym.fieldMember(name)
171- if mem != Symbol .noSymbol then mem
172- else methodSymbolByNameOrError(sym, name)
176+ extension (term : Term )
177+ def appliedToIfNeeded (args : List [Term ]): Term =
178+ if args.isEmpty then term else term.appliedToArgs(args)
179+
180+ def symbolAccessorByNameOrError (obj : Term , name : String ): Term = {
181+ val objTpe = obj.tpe.widenAll
182+ val objSymbol = objTpe.matchingTypeSymbol
183+ // opaque types can find members of underlying types - ignore them (see https://github.com/scala/scala3/issues/22143)
184+ val fieldMemberSym = objSymbol.fieldMember(name)
185+ if ! objSymbol.flags.is(Flags .Deferred ) && fieldMemberSym.exists then
186+ Select (obj, fieldMemberSym)
187+ else
188+ objSymbol.methodMember(name) match
189+ case List (m) =>
190+ Select (obj, m)
191+ case lst =>
192+ report.errorAndAbort(reportMethodError(objSymbol, name, lst))
193+ }
194+
195+ def reportMethodError (sym : Symbol , name : String , lst : List [Symbol ], maybeArgNames : Option [Iterable [String ]] = None ): String = {
196+ (lst, maybeArgNames) match
197+ case (Nil , _) => noSuchMember(sym.name, name)
198+ case (lst, None ) => multipleMatchingMethods(sym.name, name, lst)
199+ case (lst, Some (argNames)) => noSuitableMember(sym.name, name, argNames)
173200 }
174201
175202 def methodSymbolByNameOrError (sym : Symbol , name : String ): Symbol = {
176203 sym.methodMember(name) match
177204 case List (m) => m
178- case Nil => report.errorAndAbort(noSuchMember(sym.name, name))
179- case lst => report.errorAndAbort(multipleMatchingMethods(sym.name, name, lst))
205+ case lst => report.errorAndAbort(reportMethodError(sym, name, lst))
180206 }
181207
182- def methodSymbolByNameAndArgsOrError ( sym : Symbol , name : String , argsMap : Map [String , Term ]): Symbol = {
208+ def filterMethodsByNameAndArgs ( allMethods : List [ Symbol ], argsMap : Map [String , Term ]): Option [ Symbol ] = {
183209 val argNames = argsMap.keys
184- sym.methodMember(name). filter{ msym =>
210+ allMethods. filter { msym =>
185211 // for copy, we filter out the methods that don't have the desired parameter names
186212 val paramNames = msym.paramSymss.flatten.filter(_.isTerm).map(_.name)
187213 argNames.forall(paramNames.contains)
188214 } match
189- case List (m) => m
190- case Nil => report.errorAndAbort(noSuchMember(sym.name, name))
191- case lst @ (m :: _) =>
215+ case List (m) => Some (m)
216+ case Nil => None
217+ case lst@ (m :: _) =>
192218 // if we have multiple matching copy methods, pick the synthetic one, if it exists, otherwise, pick any method
193219 val syntheticCopies = lst.filter(_.flags.is(Flags .Synthetic ))
194220 syntheticCopies match
195- case List (mSynth) => mSynth
196- case _ => m
221+ case List (mSynth) => Some (mSynth)
222+ case _ => Some (m)
223+ }
224+
225+ def methodSymbolByNameAndArgs (sym : Symbol , name : String , argsMap : Map [String , Term ]): Either [String , Symbol ] = {
226+ if ! sym.flags.is(Flags .Deferred ) then
227+ val memberMethods = sym.methodMember(name)
228+ filterMethodsByNameAndArgs(memberMethods, argsMap)
229+ .toRight(reportMethodError(sym, name, memberMethods, Some (argsMap.keys)))
230+ else Left (s " Deferred type ${sym.name}" )
231+ }
232+
233+ /**
234+ * @param argsMap normal methods receive one parameter list, extensions methods two, the first one contains the value
235+ * on which the extension is called
236+ * */
237+ def callMethod (obj : Term , copy : Symbol , argsMap : List [Map [String , Term ]]) = {
238+ require(argsMap.size == 1 || argsMap.size == 2 , s " argsMap.size should be either 1 or 2, got: ${argsMap.size} ( $argsMap) " )
239+ val objTpe = obj.tpe.widenAll
240+ val objSymbol = objTpe.matchingTypeSymbol
241+
242+ val typeParams = objTpe.typeArgs
243+ val copyTree : DefDef = copy.tree.asInstanceOf [DefDef ]
244+ val copyParams : List [(String , Option [Term ])] = copyTree.termParamss.zip(argsMap)
245+ .map((params, args) => params.params.map(_.name).map(name => name -> args.get(name)))
246+ .flatten.toList
247+
248+ val args = copyParams.zipWithIndex.map { case ((n, v), _i) =>
249+ val i = _i + 1
250+ def defaultMethod : Term =
251+ val methodSymbol = methodSymbolByNameOrError(objSymbol, copy.name + " $default$" + i.toString)
252+ // default values in extension methods take the extension receiver as the first parameter
253+ val defaultMethodArgs = argsMap.dropRight(1 ).flatMap(_.values)
254+ obj.select(methodSymbol).appliedToIfNeeded(defaultMethodArgs)
255+ n -> v.getOrElse(defaultMethod)
256+ }.toMap
257+
258+ val argLists : List [List [Term ]] = copyTree.termParamss.take(argsMap.size).map(list => list.params.map(p => args(p.name)))
259+
260+ if copyTree.termParamss.drop(argLists.size).exists(_.params.exists(! _.symbol.flags.is(Flags .Implicit ))) then
261+ report.errorAndAbort(
262+ s " Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit. ${copyTree.termParamss.drop(1 )}"
263+ )
264+
265+ val withTypeParamsApplied = obj.select(copy).appliedToTypes(typeParams)
266+ argLists.foldLeft(withTypeParamsApplied)(Apply (_, _))
197267 }
198268
199269 def termMethodByNameUnsafe (term : Term , name : String ): Symbol = {
@@ -210,15 +280,32 @@ object QuicklensMacros {
210280 (sym.flags.is(Flags .Sealed ) && (sym.flags.is(Flags .Trait ) || sym.flags.is(Flags .Abstract )))
211281 }
212282
283+ def findCompanionLikeObject (objSymbol : Symbol ): Symbol = {
284+ if objSymbol.companionModule.exists then
285+ objSymbol.companionModule
286+ else
287+ val namedFromOwnerScope = objSymbol.owner.fieldMember(objSymbol.name)
288+ if namedFromOwnerScope.flags.is(Flags .Module ) then namedFromOwnerScope
289+ else Symbol .noSymbol
290+ }
291+
292+ def hasExtensionNamed (sym : Symbol , methodName : String ): List [Symbol ] = {
293+ val companionSymbol = findCompanionLikeObject(sym)
294+ if companionSymbol.exists then
295+ companionSymbol.methodMember(methodName).filter(s => s.name == methodName && s.flags.is(Flags .ExtensionMethod ))
296+ else
297+ Nil
298+ }
299+
213300 def isProductLike (sym : Symbol ): Boolean = {
214- sym.methodMember(" copy" ).size >= 1
301+ sym.methodMember(" copy" ).nonEmpty || hasExtensionNamed(sym, " copy " ).nonEmpty
215302 }
216303
217304 def caseClassCopy (
218305 owner : Symbol ,
219306 mod : Expr [A => A ],
220307 obj : Term ,
221- fields : Seq [(PathSymbol .Field , Seq [PathTree ])]
308+ fields : Seq [(PathSymbol .Field | PathSymbol . Extension , Seq [PathTree ])]
222309 ): Term = {
223310 val objTpe = obj.tpe.widenAll
224311 val objSymbol = objTpe.matchingTypeSymbol
@@ -248,50 +335,39 @@ object QuicklensMacros {
248335 }
249336
250337 val elseThrow = ' { throw new IllegalStateException () }.asTerm
338+
251339 ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) =>
252340 If (ifCond, ifThen, ifElse)
253341 }
254342 } else if isProduct(objSymbol) || isProductLike(objSymbol) then {
255343 val argsMap : Map [String , Term ] = fields.map { (field, trees) =>
256- val fieldMethod = symbolAccessorByNameOrError(objSymbol, field.name)
257- val resTerm : Term = trees.foldLeft[Term ](Select (obj, fieldMethod)) { (term, tree) =>
344+ val fieldMethod = field match {
345+ case PathSymbol .Field (name) =>
346+ symbolAccessorByNameOrError(obj, name)
347+ case PathSymbol .Extension (term, name) =>
348+ val extensionMethod = symbolAccessorByNameOrError(term, name)
349+ Apply (extensionMethod, List (obj))
350+ }
351+ val resTerm : Term = trees.foldLeft[Term ](fieldMethod) { (term, tree) =>
258352 mapToCopy(owner, mod, term, tree)
259353 }
260354 val namedArg = NamedArg (field.name, resTerm)
261355 field.name -> namedArg
262356 }.toMap
263- val copy = methodSymbolByNameAndArgsOrError(objSymbol, " copy" , argsMap)
264-
265- val typeParams = objTpe match {
266- case AppliedType (_, typeParams) => Some (typeParams)
267- case _ => None
268- }
269- val copyTree : DefDef = copy.tree.asInstanceOf [DefDef ]
270- val copyParamNames : List [String ] = copyTree.termParamss.headOption.map(_.params).toList.flatten.map(_.name)
271-
272- val args = copyParamNames.zipWithIndex.map { (n, _i) =>
273- val i = _i + 1
274- val defaultMethod = obj.select(methodSymbolByNameOrError(objSymbol, " copy$default$" + i.toString))
275- // for extension methods, might need sth more like this: (or probably some weird implicit conversion)
276- // val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n))
277- argsMap.getOrElse(
278- n,
279- defaultMethod
280- )
281- }.toList
282-
283- if copyTree.termParamss.drop(1 ).exists(_.params.exists(! _.symbol.flags.is(Flags .Implicit ))) then
284- report.errorAndAbort(
285- s " Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit. "
286- )
287-
288- typeParams match {
289- // if the object's type is parametrised, we need to call .copy with the same type parameters
290- case Some (typeParams) => Apply (TypeApply (Select (obj, copy), typeParams.map(Inferred (_))), args)
291- case _ => Apply (Select (obj, copy), args)
292- }
357+ methodSymbolByNameAndArgs(objSymbol, " copy" , argsMap) match
358+ case Right (copy) =>
359+ callMethod(obj, copy, List (argsMap))
360+ case Left (error) =>
361+ val objCompanion = findCompanionLikeObject(objSymbol)
362+ methodSymbolByNameAndArgs(objCompanion, " copy" , argsMap).toOption match
363+ case Some (copy) =>
364+ // now try to call the extension as a method, assume the object is its first parameter
365+ val extensionParameter = copy.paramSymss.headOption.map(_.headOption).flatten
366+ val argsWithObj = List (extensionParameter.map(name => name.name -> obj).toMap, argsMap)
367+ callMethod(Ref (objCompanion), copy, argsWithObj)
368+ case None => report.errorAndAbort(error)
293369 } else
294- report.errorAndAbort(s " Unsupported source object: must be a case class or sealed trait, but got: $objSymbol of type ${objTpe.show} ( ${obj.show}) " )
370+ report.errorAndAbort(s " Unsupported source object: must be a case class, sealed trait or class with copy method , but got: $objSymbol of type ${objTpe.show} ( ${obj.show}) " )
295371 }
296372
297373 def applyFunctionDelegate (
@@ -331,9 +407,9 @@ object QuicklensMacros {
331407 case Nil =>
332408 objTerm
333409
334- case (_ : PathSymbol .Field , _) :: _ =>
335- val (fs, funs) = pathSymbols.span(_._1. isInstanceOf [PathSymbol .Field ])
336- val fields = fs.collect { case (p : PathSymbol .Field , trees) => p -> trees }
410+ case (_ : ( PathSymbol .Field | PathSymbol . Extension ) , _) :: _ =>
411+ val (fs, funs) = pathSymbols.span((ps, _) => ps. isInstanceOf [PathSymbol .Field ] || ps. isInstanceOf [ PathSymbol . Extension ])
412+ val fields = fs.collect { case (p : ( PathSymbol .Field | PathSymbol . Extension ) , trees) => p -> trees }
337413 val withCopiedFields : Term = caseClassCopy(owner, mod, objTerm, fields)
338414 accumulateToCopy(owner, mod, withCopiedFields, funs)
339415
0 commit comments