Skip to content

Commit 58e1726

Browse files
committed
Implement flatMap
1 parent 99071e0 commit 58e1726

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
6
22

3-
12
3+
12
4+
5+
36

tests/run-with-compiler/staged-streams_1.scala

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,9 @@ object Test {
7272
})
7373
}
7474
}
75-
case nested: Nested[a, bt] => ???
76-
// {
77-
// fold_raw[bt](((e: bt) => fold_raw(consumer, nested.nestedf(e))), Linear(nested.producer))
78-
// }
75+
case nested: Nested[a, bt] => {
76+
fold_raw[bt](((e: bt) => fold_raw(consumer, nested.nestedf(e))), Linear(nested.producer))
77+
}
7978
}
8079
}
8180

@@ -108,13 +107,23 @@ object Test {
108107

109108
Linear(prod)
110109
}
111-
case nested: Nested[a, bt] => ???
112-
// {
113-
// Nested(nested.producer, (a: bt) => mapRaw[A, B](f, nested.nestedf(a)))
114-
// }
110+
case nested: Nested[a, bt] => {
111+
Nested(nested.producer, (a: bt) => mapRaw[A, B](f, nested.nestedf(a)))
112+
}
115113
}
116114
}
117115

116+
def flatMap[B : Type](f: (Expr[A] => Stream[B])): Stream[B] = {
117+
Stream(flatMapRaw[Expr[A], Expr[B]]((a => { val Stream (nested) = f(a); nested }), stream))
118+
}
119+
120+
def flatMapRaw[A, B](f: (A => StagedStream[B]), stream: StagedStream[A]): StagedStream[B] = {
121+
stream match {
122+
case Linear(producer) => Nested(producer, f)
123+
case nested: Nested[a, bt] =>
124+
Nested(nested.producer, (a: bt) => flatMapRaw[A, B](f, nested.nestedf(a)))
125+
}
126+
}
118127
}
119128

120129
object Stream {
@@ -162,10 +171,17 @@ object Test {
162171
.map((a: Expr[Int]) => '{ ~a * 2 })
163172
.fold('{0}, ((a: Expr[Int], b : Expr[Int]) => '{ ~a + ~b }))
164173

174+
def test3() = Stream
175+
.of('{Array(1, 2, 3)})
176+
.flatMap((d: Expr[Int]) => Stream.of('{Array(1, 2, 3)}).map((dp: Expr[Int]) => '{ ~d * ~dp }))
177+
.fold('{0}, ((a: Expr[Int], b : Expr[Int]) => '{ ~a + ~b }))
178+
165179
def main(args: Array[String]): Unit = {
166180
println(test1().run)
167181
println
168182
println(test2().run)
183+
println
184+
println(test3().run)
169185
}
170186
}
171187

0 commit comments

Comments
 (0)