Skip to content

Commit e8dd3a7

Browse files
committed
implememted fold for Inpit
1 parent c5ad64c commit e8dd3a7

File tree

2 files changed

+91
-21
lines changed

2 files changed

+91
-21
lines changed

src/main/scala/gopher/channels/Input.scala

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ import scala.reflect.macros.blackbox.Context
88
import scala.reflect.api._
99
import scala.util._
1010
import java.util.concurrent.ConcurrentLinkedQueue
11+
1112
import gopher._
13+
import gopher.util._
14+
1215

1316
import java.util.concurrent.atomic._
1417

@@ -316,13 +319,49 @@ trait Input[A]
316319
}
317320

318321

319-
/*
320-
def fold[S,B](s0:S)(f:(S,A)=>(S,Option[B])) = new Input[A] {
322+
def fold[S,B](s0:S)(f:(S,A)=>S): S = macro InputMacro.foldImpl[A,S]
321323

322-
def cbread[C](f: ContRead[A,C] => Option[ContRead.In[A]=>Future[Continuated[C]]], ft: FlowTermination[C] ): Unit =
324+
def afold[S,B](s0:S)(f:(S,A)=>S): Future[S] = macro InputMacro.afoldImpl[A,S]
325+
326+
327+
def afoldSync[S,B](s0:S)(f:(S,A)=>S): Future[S] =
328+
{
329+
val ft = PromiseFlowTermination[S]
330+
var s = s0
331+
def applyF(cont:ContRead[A,S]):Option[ContRead.In[A]=>Future[Continuated[S]]] =
332+
{
333+
val contFold = ContRead(applyF,this,ft)
334+
Some{
335+
case ContRead.ChannelClosed => Future successful Done(s,ft)
336+
case ContRead.Value(a) => s = f(s,a)
337+
Future successful contFold
338+
case ContRead.Skip => Future successful contFold
339+
case ContRead.Failure(ex) => Future failed ex
340+
}
341+
}
342+
cbread(applyF,ft)
343+
ft.future
344+
}
323345

346+
def afoldAsync[S,B](s0:S)(f:(S,A)=>Future[S])(implicit ec:ExecutionContext): Future[S] =
347+
{
348+
val ft = PromiseFlowTermination[S]
349+
var s = s0
350+
def applyF(cont:ContRead[A,S]):Option[ContRead.In[A]=>Future[Continuated[S]]] =
351+
{
352+
Some{
353+
case ContRead.ChannelClosed => Future successful Done(s,ft)
354+
case ContRead.Value(a) => f(s,a) map { x =>
355+
s = x
356+
ContRead(applyF,this,ft)
357+
}
358+
case ContRead.Skip => Future successful ContRead(applyF,this,ft)
359+
case ContRead.Failure(ex) => Future failed ex
360+
}
361+
}
362+
cbread(applyF,ft)
363+
ft.future
324364
}
325-
*/
326365

327366
}
328367

@@ -391,25 +430,10 @@ object InputMacro
391430
def aforeachImpl[A](c:Context)(f:c.Expr[A=>Unit]): c.Expr[Future[Unit]] =
392431
{
393432
import c.universe._
394-
val findAwait = new Traverser {
395-
var found = false
396-
override def traverse(tree:Tree):Unit =
397-
{
398-
if (!found) {
399-
tree match {
400-
case Apply(TypeApply(Select(obj,TermName("await")),objType), args) =>
401-
if (obj.tpe =:= typeOf[scala.async.Async.type]) {
402-
found=true
403-
} else super.traverse(tree)
404-
case _ => super.traverse(tree)
405-
}
406-
}
407-
}
408-
}
409433
f.tree match {
410434
case Function(valdefs,body) =>
411-
findAwait.traverse(body)
412-
if (findAwait.found) {
435+
if (MacroUtil.hasAwait(c)(body)) {
436+
// TODO: add support for flow-termination (?)
413437
val nbody = q"scala.async.Async.async(${body})"
414438
val nfunction = atPos(f.tree.pos)(Function(valdefs,nbody))
415439
val ntree = q"${c.prefix}.foreachAsync(${nfunction})"
@@ -421,5 +445,27 @@ object InputMacro
421445
}
422446
}
423447

448+
def foldImpl[A,S](c:Context)(s0:c.Expr[S])(f:c.Expr[(S,A)=>S]): c.Expr[S] =
449+
{
450+
import c.universe._
451+
c.Expr[S](q"scala.async.Async.await(${afoldImpl(c)(s0)(f)})")
452+
}
453+
454+
def afoldImpl[A,S](c:Context)(s0:c.Expr[S])(f:c.Expr[(S,A)=>S]): c.Expr[Future[S]] =
455+
{
456+
import c.universe._
457+
f.tree match {
458+
case Function(valdefs,body) =>
459+
if (MacroUtil.hasAwait(c)(body)) {
460+
val nbody = atPos(body.pos)(q"scala.async.Async.async(${body})")
461+
val nfunction = atPos(f.tree.pos)(Function(valdefs,nbody))
462+
val ntree = q"${c.prefix}.afoldAsync(${s0.tree})(${nfunction})"
463+
c.Expr[Future[S]](c.untypecheck(ntree))
464+
} else {
465+
c.Expr[Future[S]](q"${c.prefix}.afoldSync(${s0.tree})(${f.tree})")
466+
}
467+
}
468+
}
469+
424470

425471
}

src/main/scala/gopher/util/MacroUtil.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package gopher.util
22

33
import scala.reflect.macros.blackbox.Context
44
import scala.reflect.api._
5+
import scala.language.reflectiveCalls
56

67

78
object MacroUtil
@@ -30,5 +31,28 @@ object MacroUtil
3031
}
3132
}
3233

34+
def hasAwait(c:Context)(x: c.Tree):Boolean =
35+
{
36+
import c.universe._
37+
val findAwait = new Traverser {
38+
var found = false
39+
override def traverse(tree:Tree):Unit =
40+
{
41+
if (!found) {
42+
tree match {
43+
case Apply(TypeApply(Select(obj,TermName("await")),objType), args) =>
44+
if (obj.tpe =:= typeOf[scala.async.Async.type]) {
45+
found=true
46+
} else super.traverse(tree)
47+
case _ => super.traverse(tree)
48+
}
49+
}
50+
}
51+
}
52+
findAwait.traverse(x)
53+
findAwait.found
54+
}
55+
56+
3357
final val SHORT_LEN = 80
3458
}

0 commit comments

Comments
 (0)