|
| 1 | +package x |
| 2 | + |
| 3 | +import scala.quoted.* |
| 4 | + |
| 5 | +def fun(x:Int): Int = ??? |
| 6 | + |
| 7 | +transparent inline def in1[T](inline expr: Int => Int): Int => Int = |
| 8 | + ${ |
| 9 | + M.transformLambdaImpl('expr) |
| 10 | + } |
| 11 | + |
| 12 | +object M: |
| 13 | + |
| 14 | + def transformLambdaImpl(cexpr: Expr[Int => Int])(using Quotes): Expr[Int => Int] = |
| 15 | + import quotes.reflect.* |
| 16 | + |
| 17 | + def extractLambda(f:Term): (ValDef, Term, Term => Term ) = |
| 18 | + f match |
| 19 | + case Inlined(call, bindings, body) => |
| 20 | + val inner = extractLambda(body) |
| 21 | + (inner._1, inner._2, t => Inlined(call, bindings, t) ) |
| 22 | + case Lambda(params,body) => |
| 23 | + params match |
| 24 | + case List(vd) => (vd, body, identity) |
| 25 | + case _ => report.throwError(s"lambda with one argument expected, we have ${params}",cexpr) |
| 26 | + case Block(Nil,nested@Lambda(params,body)) => extractLambda(nested) |
| 27 | + case _ => |
| 28 | + report.throwError(s"lambda expected, have: ${f}", cexpr) |
| 29 | + |
| 30 | + val (oldValDef, body, inlineBack) = extractLambda(cexpr.asTerm) |
| 31 | + val mt = MethodType(List(oldValDef.name))( _ => List(oldValDef.tpt.tpe), _ => TypeRepr.of[Int]) |
| 32 | + val nLambda = Lambda(Symbol.spliceOwner, mt, (owner, params) => { |
| 33 | + val argTransformer = new TreeMap() { |
| 34 | + override def transformTerm(tree: Term)(owner: Symbol): Term = |
| 35 | + tree match |
| 36 | + case Ident(name) if (tree.symbol == oldValDef.symbol) => Ref(params.head.symbol) |
| 37 | + case _ => super.transformTerm(tree)(owner) |
| 38 | + } |
| 39 | + argTransformer.transformTerm('{ fun(${body.asExprOf[Int]}) }.asTerm )(owner) |
| 40 | + }) |
| 41 | + inlineBack(nLambda).asExprOf[Int => Int] |
0 commit comments