Skip to content

Commit d82ccc1

Browse files
authored
faster iteration over a Flatten of heterogenous iterators (JuliaLang#58522)
seems to help in many cases. would fix the precise MWE given in JuliaLang#52552, but does not necessarily fix comprehensively all perf issues of all heterogenous flattens. but, may as well be better when it's possible setup: ``` julia> using BenchmarkTools julia> A = rand(Int, 100000); B = 1:100000; julia> function g(it) s = 0 for i in it s += i end s end ``` before: ``` julia> @Btime g($(Iterators.flatten((A, B)))) 12.461 ms (698979 allocations: 18.29 MiB) julia> @Btime g($(Iterators.flatten(i for i in (A, B)))) 12.393 ms (698979 allocations: 18.29 MiB) julia> @Btime g($(Iterators.flatten([A, B]))) 15.115 ms (999494 allocations: 25.93 MiB) julia> @Btime g($(Iterators.flatten((A, Iterators.flatten((A, B)))))) 82.585 ms (2997964 allocations: 106.78 MiB) ``` after: ``` julia> @Btime g($(Iterators.flatten((A, B)))) 135.958 μs (2 allocations: 64 bytes) julia> @Btime g($(Iterators.flatten(i for i in (A, B)))) 149.500 μs (2 allocations: 64 bytes) julia> @Btime g($(Iterators.flatten([A, B]))) 17.130 ms (999498 allocations: 25.93 MiB) julia> @Btime g($(Iterators.flatten((A, Iterators.flatten((A, B)))))) 13.716 ms (398983 allocations: 10.67 MiB) ```
1 parent 69a2a57 commit d82ccc1

File tree

1 file changed

+40
-12
lines changed

1 file changed

+40
-12
lines changed

base/iterators.jl

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,20 +1239,48 @@ flatten_length(f, T) = throw(ArgumentError(
12391239
length(f::Flatten{I}) where {I} = flatten_length(f, eltype(I))
12401240
length(f::Flatten{Tuple{}}) = 0
12411241

1242-
@propagate_inbounds function iterate(f::Flatten, state=())
1243-
if state !== ()
1244-
y = iterate(tail(state)...)
1245-
y !== nothing && return (y[1], (state[1], state[2], y[2]))
1242+
@propagate_inbounds function iterate(fl::Flatten)
1243+
it_result = iterate(fl.it)
1244+
it_result === nothing && return nothing
1245+
1246+
inner_iterator, next_outer_state = it_result
1247+
inner_it_result = iterate(inner_iterator)
1248+
1249+
while inner_it_result === nothing
1250+
it_result = iterate(fl.it, next_outer_state)
1251+
it_result === nothing && return nothing
1252+
1253+
inner_iterator, next_outer_state = it_result
1254+
inner_it_result = iterate(inner_iterator)
12461255
end
1247-
x = (state === () ? iterate(f.it) : iterate(f.it, state[1]))
1248-
x === nothing && return nothing
1249-
y = iterate(x[1])
1250-
while y === nothing
1251-
x = iterate(f.it, x[2])
1252-
x === nothing && return nothing
1253-
y = iterate(x[1])
1256+
1257+
item, next_inner_state = inner_it_result
1258+
return item, (next_outer_state, inner_iterator, next_inner_state)
1259+
end
1260+
1261+
@propagate_inbounds function iterate(fl::Flatten, state)
1262+
next_outer_state, inner_iterator, next_inner_state = state
1263+
1264+
# try to advance the inner iterator
1265+
inner_it_result = iterate(inner_iterator, next_inner_state)
1266+
if inner_it_result !== nothing
1267+
item, next_inner_state = inner_it_result
1268+
return item, (next_outer_state, inner_iterator, next_inner_state)
1269+
end
1270+
1271+
# advance the outer iterator
1272+
while true
1273+
outer_it_result = iterate(fl.it, next_outer_state)
1274+
outer_it_result === nothing && return nothing
1275+
1276+
inner_iterator, next_outer_state = outer_it_result
1277+
inner_it_result = iterate(inner_iterator)
1278+
1279+
if inner_it_result !== nothing
1280+
item, next_inner_state = inner_it_result
1281+
return item, (next_outer_state, inner_iterator, next_inner_state)
1282+
end
12541283
end
1255-
return y[1], (x[2], x[1], y[2])
12561284
end
12571285

12581286
reverse(f::Flatten) = Flatten(reverse(itr) for itr in reverse(f.it))

0 commit comments

Comments
 (0)