Skip to content

Commit bc9fdce

Browse files
N5N3aviatesk
andauthored
Improve the inferability of some index-related self-recursive functions in Base (JuliaLang#45672)
* Better inferred `to_indices` * Better inferred `CartedianIndex(x::Union{Integer,CartedianIndex}...)` Co-authored-by: Shuhei Kadowaki <[email protected]>
1 parent 1f99ee9 commit bc9fdce

File tree

4 files changed

+39
-41
lines changed

4 files changed

+39
-41
lines changed

base/indices.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -349,11 +349,15 @@ to_indices(A, I::Tuple{}) = ()
349349
to_indices(A, I::Tuple{Vararg{Int}}) = I
350350
to_indices(A, I::Tuple{Vararg{Integer}}) = (@inline; to_indices(A, (), I))
351351
to_indices(A, inds, ::Tuple{}) = ()
352-
to_indices(A, inds, I::Tuple{Any, Vararg{Any}}) =
353-
(@inline; (to_index(A, I[1]), to_indices(A, _maybetail(inds), tail(I))...))
352+
function to_indices(A, inds, I::Tuple{Any, Vararg{Any}})
353+
@inline
354+
head = _to_indices1(A, inds, I[1])
355+
rest = to_indices(A, _cutdim(inds, I[1]), tail(I))
356+
(head..., rest...)
357+
end
354358

355-
_maybetail(::Tuple{}) = ()
356-
_maybetail(t::Tuple) = tail(t)
359+
_to_indices1(A, inds, I1) = (to_index(A, I1),)
360+
_cutdim(inds, I1) = safe_tail(inds)
357361

358362
"""
359363
Slice(indices)

base/multidimensional.jl

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
module IteratorsMD
55
import .Base: eltype, length, size, first, last, in, getindex, setindex!, IndexStyle,
66
min, max, zero, oneunit, isless, eachindex, ndims, IteratorSize,
7-
convert, show, iterate, promote_rule, to_indices, to_index
7+
convert, show, iterate, promote_rule
88

99
import .Base: +, -, *, (:)
1010
import .Base: simd_outer_range, simd_inner_length, simd_index, setindex
11+
import .Base: to_indices, to_index, _to_indices1, _cutdim
1112
using .Base: IndexLinear, IndexCartesian, AbstractCartesianIndex, fill_to_length, tail,
12-
ReshapedArray, ReshapedArrayLF, OneTo
13+
ReshapedArray, ReshapedArrayLF, OneTo, Fix1
1314
using .Base.Iterators: Reverse, PartitionIterator
1415
using .Base: @propagate_inbounds
1516

@@ -75,13 +76,9 @@ module IteratorsMD
7576
CartesianIndex{N}() where {N} = CartesianIndex{N}(())
7677
# Un-nest passed CartesianIndexes
7778
CartesianIndex(index::Union{Integer, CartesianIndex}...) = CartesianIndex(flatten(index))
78-
flatten(I::Tuple{}) = I
79-
flatten(I::Tuple{Any}) = I
80-
flatten(I::Tuple{<:CartesianIndex}) = I[1].I
81-
@inline flatten(I) = _flatten(I...)
82-
@inline _flatten() = ()
83-
@inline _flatten(i, I...) = (i, _flatten(I...)...)
84-
@inline _flatten(i::CartesianIndex, I...) = (i.I..., _flatten(I...)...)
79+
flatten(::Tuple{}) = ()
80+
flatten(I::Tuple{Any}) = Tuple(I[1])
81+
@inline flatten(I::Tuple) = (Tuple(I[1])..., flatten(tail(I))...)
8582
CartesianIndex(index::Tuple{Vararg{Union{Integer, CartesianIndex}}}) = CartesianIndex(index...)
8683
show(io::IO, i::CartesianIndex) = (print(io, "CartesianIndex"); show(io, i.I))
8784

@@ -457,13 +454,11 @@ module IteratorsMD
457454
last(iter::CartesianIndices) = CartesianIndex(map(last, iter.indices))
458455

459456
# When used as indices themselves, CartesianIndices can simply become its tuple of ranges
460-
@inline function to_indices(A, inds, I::Tuple{CartesianIndices{N}, Vararg{Any}}) where N
461-
_, indstail = split(inds, Val(N))
462-
(map(i -> to_index(A, i), I[1].indices)..., to_indices(A, indstail, tail(I))...)
463-
end
457+
_to_indices1(A, inds, I1::CartesianIndices) = map(Fix1(to_index, A), I1.indices)
458+
_cutdim(inds::Tuple, I1::CartesianIndices) = split(inds, Val(ndims(I1)))[2]
459+
464460
# but preserve CartesianIndices{0} as they consume a dimension.
465-
@inline to_indices(A, inds, I::Tuple{CartesianIndices{0},Vararg{Any}}) =
466-
(first(I), to_indices(A, inds, tail(I))...)
461+
_to_indices1(A, inds, I1::CartesianIndices{0}) = (I1,)
467462

468463
@inline in(i::CartesianIndex, r::CartesianIndices) = false
469464
@inline in(i::CartesianIndex{N}, r::CartesianIndices{N}) where {N} = all(map(in, i.I, r.indices))
@@ -835,33 +830,23 @@ ensure_indexable(I::Tuple{}) = ()
835830
@inline to_indices(A, I::Tuple{Vararg{Union{Integer, CartesianIndex}}}) = to_indices(A, (), I)
836831
# But some index types require more context spanning multiple indices
837832
# CartesianIndex is unfolded outside the inner to_indices for better inference
838-
@inline function to_indices(A, inds, I::Tuple{CartesianIndex{N}, Vararg{Any}}) where N
839-
_, indstail = IteratorsMD.split(inds, Val(N))
840-
(map(i -> to_index(A, i), I[1].I)..., to_indices(A, indstail, tail(I))...)
841-
end
833+
_to_indices1(A, inds, I1::CartesianIndex) = map(Fix1(to_index, A), I1.I)
834+
_cutdim(inds, I1::CartesianIndex) = IteratorsMD.split(inds, Val(length(I1)))[2]
842835
# For arrays of CartesianIndex, we just skip the appropriate number of inds
843-
@inline function to_indices(A, inds, I::Tuple{AbstractArray{CartesianIndex{N}}, Vararg{Any}}) where N
844-
_, indstail = IteratorsMD.split(inds, Val(N))
845-
(to_index(A, I[1]), to_indices(A, indstail, tail(I))...)
846-
end
836+
_cutdim(inds, I1::AbstractArray{CartesianIndex{N}}) where {N} = IteratorsMD.split(inds, Val(N))[2]
847837
# And boolean arrays behave similarly; they also skip their number of dimensions
848-
@inline function to_indices(A, inds, I::Tuple{AbstractArray{Bool, N}, Vararg{Any}}) where N
849-
_, indstail = IteratorsMD.split(inds, Val(N))
850-
(to_index(A, I[1]), to_indices(A, indstail, tail(I))...)
851-
end
838+
_cutdim(inds::Tuple, I1::AbstractArray{Bool}) = IteratorsMD.split(inds, Val(ndims(I1)))[2]
852839
# As an optimization, we allow trailing Array{Bool} and BitArray to be linear over trailing dimensions
853840
@inline to_indices(A, inds, I::Tuple{Union{Array{Bool,N}, BitArray{N}}}) where {N} =
854841
(_maybe_linear_logical_index(IndexStyle(A), A, I[1]),)
855842
_maybe_linear_logical_index(::IndexStyle, A, i) = to_index(A, i)
856843
_maybe_linear_logical_index(::IndexLinear, A, i) = LogicalIndex{Int}(i)
857844

858845
# Colons get converted to slices by `uncolon`
859-
@inline to_indices(A, inds, I::Tuple{Colon, Vararg{Any}}) =
860-
(uncolon(inds, I), to_indices(A, _maybetail(inds), tail(I))...)
846+
_to_indices1(A, inds, I1::Colon) = (uncolon(inds),)
861847

862-
const CI0 = Union{CartesianIndex{0}, AbstractArray{CartesianIndex{0}}}
863-
uncolon(inds::Tuple{}, I::Tuple{Colon, Vararg{Any}}) = Slice(OneTo(1))
864-
uncolon(inds::Tuple, I::Tuple{Colon, Vararg{Any}}) = Slice(inds[1])
848+
uncolon(::Tuple{}) = Slice(OneTo(1))
849+
uncolon(inds::Tuple) = Slice(inds[1])
865850

866851
### From abstractarray.jl: Internal multidimensional indexing definitions ###
867852
getindex(x::Union{Number,AbstractChar}, ::CartesianIndex{0}) = x

test/abstractarray.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,10 +1607,13 @@ end
16071607
end
16081608

16091609
@testset "to_indices inference (issue #42001 #44059)" begin
1610-
@test (@inferred to_indices([], ntuple(Returns(CartesianIndex(1)), 32))) == ntuple(Returns(1), 32)
1611-
@test (@inferred to_indices([], ntuple(Returns(CartesianIndices(1:1)), 32))) == ntuple(Returns(Base.OneTo(1)), 32)
1612-
@test (@inferred to_indices([], (CartesianIndex(),1,CartesianIndex(1,1,1)))) == ntuple(Returns(1), 4)
1613-
A = randn(2,2,2,2,2,2);
1614-
i = CartesianIndex((1,1))
1610+
CIdx = CartesianIndex
1611+
CIdc = CartesianIndices
1612+
@test (@inferred to_indices([], ntuple(Returns(CIdx(1)), 32))) == ntuple(Returns(1), 32)
1613+
@test (@inferred to_indices([], ntuple(Returns(CIdc(1:1)), 32))) == ntuple(Returns(Base.OneTo(1)), 32)
1614+
@test (@inferred to_indices([], (CIdx(), 1, CIdx(1,1,1)))) == ntuple(Returns(1), 4)
1615+
A = randn(2, 2, 2, 2, 2, 2);
1616+
i = CIdx((1, 1))
16151617
@test (@inferred A[i,i,i]) === A[1]
1618+
@test (@inferred to_indices([], (1, CIdx(1, 1), 1, CIdx(1, 1), 1, CIdx(1, 1), 1))) == ntuple(Returns(1), 10)
16161619
end

test/cartesian.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,3 +530,9 @@ end
530530
@test isassigned(A, CartesianIndex(1, 2), 3)
531531
@test !isassigned(A, CartesianIndex(5, 2), 3)
532532
end
533+
534+
@testset "`CartedianIndex(x::Union{Integer,CartedianIndex}...)`'s stability" begin
535+
CI = CartesianIndex
536+
inds2 = (1, CI(1, 2), 1, CI(1, 2), 1, CI(1, 2), 1)
537+
@test (@inferred CI(inds2)) == CI(1, 1, 2, 1, 1, 2, 1, 1, 2, 1)
538+
end

0 commit comments

Comments
 (0)