@@ -72,10 +72,9 @@ object Test {
72
72
})
73
73
}
74
74
}
75
- case nested : Nested [a, bt] => ???
76
- // {
77
- // fold_raw[bt](((e: bt) => fold_raw(consumer, nested.nestedf(e))), Linear(nested.producer))
78
- // }
75
+ case nested : Nested [a, bt] => {
76
+ fold_raw[bt](((e : bt) => fold_raw(consumer, nested.nestedf(e))), Linear (nested.producer))
77
+ }
79
78
}
80
79
}
81
80
@@ -108,13 +107,23 @@ object Test {
108
107
109
108
Linear (prod)
110
109
}
111
- case nested : Nested [a, bt] => ???
112
- // {
113
- // Nested(nested.producer, (a: bt) => mapRaw[A, B](f, nested.nestedf(a)))
114
- // }
110
+ case nested : Nested [a, bt] => {
111
+ Nested (nested.producer, (a : bt) => mapRaw[A , B ](f, nested.nestedf(a)))
112
+ }
115
113
}
116
114
}
117
115
116
+ def flatMap [B : Type ](f : (Expr [A ] => Stream [B ])): Stream [B ] = {
117
+ Stream (flatMapRaw[Expr [A ], Expr [B ]]((a => { val Stream (nested) = f(a); nested }), stream))
118
+ }
119
+
120
+ def flatMapRaw [A , B ](f : (A => StagedStream [B ]), stream : StagedStream [A ]): StagedStream [B ] = {
121
+ stream match {
122
+ case Linear (producer) => Nested (producer, f)
123
+ case nested : Nested [a, bt] =>
124
+ Nested (nested.producer, (a : bt) => flatMapRaw[A , B ](f, nested.nestedf(a)))
125
+ }
126
+ }
118
127
}
119
128
120
129
object Stream {
@@ -162,10 +171,17 @@ object Test {
162
171
.map((a : Expr [Int ]) => ' { ~ a * 2 })
163
172
.fold(' {0 }, ((a : Expr [Int ], b : Expr [Int ]) => ' { ~ a + ~ b }))
164
173
174
+ def test3 () = Stream
175
+ .of(' {Array (1 , 2 , 3 )})
176
+ .flatMap((d : Expr [Int ]) => Stream .of(' {Array (1 , 2 , 3 )}).map((dp : Expr [Int ]) => ' { ~ d * ~ dp }))
177
+ .fold(' {0 }, ((a : Expr [Int ], b : Expr [Int ]) => ' { ~ a + ~ b }))
178
+
165
179
def main (args : Array [String ]): Unit = {
166
180
println(test1().run)
167
181
println
168
182
println(test2().run)
183
+ println
184
+ println(test3().run)
169
185
}
170
186
}
171
187
0 commit comments