Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/methods/aggregate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -457,20 +457,20 @@ end
return _reduce_noskip(sum, block, mv, dst)
end
@propagate_inbounds function _reduce_noskip(::typeof(sum), block, mv, dst)
agg = zero(eltype(block))
agg = Base.add_sum(zero(eltype(block)), zero(eltype(block)))
for x in block
_ismissing(x, mv) && return _missingval_or_missing(dst)
agg += x
agg += Base.add_sum(agg, x)
end
return agg
end
@propagate_inbounds function _reduce_noskip(::typeof(DD.Statistics.mean), block, mv, dst)
agg = zero(eltype(block))
agg = float(zero(eltype(block))) # Force floating point for mean
n = 0
for x in block
_ismissing(x, mv) && return _missingval_or_missing(dst)
n += 1
agg += x
agg += Base.add_sum(agg, x)
end
return agg / n
end
Expand All @@ -488,24 +488,25 @@ end
return _reduce_skip(sum, block, mv, dst)
end
@propagate_inbounds function _reduce_skip(::typeof(sum), block, mv, dst)
agg = zero(eltype(block))
# Use add_sum to get the correct type, e.g. UInt64 from UInt
agg = Base.add_sum(zero(eltype(block)), zero(eltype(block)))
found = false
for x in block
_ismissing(x, mv) && continue
found = true
agg += x
agg += Base.add_sum(agg, x)
end
return found ? agg : _missingval_or_missing(dst)
end
@propagate_inbounds function _reduce_skip(::typeof(DD.Statistics.mean), block, mv, dst)
agg = zero(eltype(block))
agg = float(zero(eltype(block))) # Force floating point for mean
found = false
n = 0
for x in block
_ismissing(x, mv) && continue
found = true
n += 1
agg += x
agg += Base.add_sum(agg, x)
end
return found ? agg / n : _missingval_or_missing(dst)
end
Expand Down
9 changes: 6 additions & 3 deletions src/methods/rasterize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@ _reduce_init(reducer, nt::NamedTuple, missingval) = map(x -> _reduce_init(reduce
_reduce_init(f, x, missingval) = _reduce_init(f, typeof(x), missingval)

_reduce_init(::Nothing, x::Type{T}, missingval) where T = zero(T)
_reduce_init(f::Function, ::Type{T}, missingval) where T = zero(f((zero(nonmissingtype(T)), zero(nonmissingtype(T)))))
_reduce_init(::typeof(sum), ::Type{T}, missingval) where T = zero(nonmissingtype(T))
_reduce_init(::typeof(prod), ::Type{T}, missingval) where T = oneunit(nonmissingtype(T))
_reduce_init(f::Function, ::Type{T}, missingval) where T =
zero(f((zero(nonmissingtype(T)), zero(nonmissingtype(T)))))
_reduce_init(::typeof(sum), ::Type{T}, missingval) where T =
Base.add_sum(zero(nonmissingtype(T)), zero(nonmissingtype(T))) # add_sum(zero, zero) for correct type
_reduce_init(::typeof(prod), ::Type{T}, missingval) where T =
Base.mul_prod(oneunit(nonmissingtype(T)), one(nonmissingtype(T))) # mul_prod(oneunit, one) for correct type
_reduce_init(::typeof(minimum), ::Type{T}, missingval) where T = typemax(nonmissingtype(T))
_reduce_init(::typeof(maximum), ::Type{T}, missingval) where T = typemin(nonmissingtype(T))
_reduce_init(::typeof(last), ::Type{T}, missingval) where T = _maybe_to_missing(missingval)
Expand Down
24 changes: 24 additions & 0 deletions test/aggregate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,27 @@ end
@test eager_disag_series == lazy_disag_series
@test all(x -> all(x -> x isa SubArray, parent(x)), lazy_disag_series)
end

@testset "aggregate with integer types does not overflow" begin
uint8_a = Raster(fill(UInt8(200), (X(1:2), Y(1:2))))
int8_a = Raster(fill(Int8(100), (X(1:2), Y(1:2))))

# Summing four UInt8(200) values will overflow UInt8, should return UInt
sum_uint8 = aggregate(sum, uint8_a, 2)
@test eltype(sum_uint8) == UInt
@test sum_uint8[1, 1] == 800

# Summing four Int8(100) values will overflow Int8, should return Int
sum_int8 = aggregate(sum, int8_a, 2)
@test eltype(sum_int8) == Int
@test sum_int8[1, 1] == 400

# Mean should return Float64
mean_uint8 = aggregate(mean, uint8_a, 2)
@test eltype(mean_uint8) == Float64
@test mean_uint8[1, 1] == 200.0

mean_int8 = aggregate(mean, int8_a, 2)
@test eltype(mean_int8) == Float64
@test mean_int8[1, 1] == 100.0
end
29 changes: 28 additions & 1 deletion test/rasterize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,33 @@ end
@test replace_missing(rasterize(last, pointtable; to = A1, fill = 1), 0) == replace_missing(rasterize(last, fancy_table; to = A1, fill = 1), 0) # sanity check
end

@testset "Type promotion handles integer overflows in rasterize" begin
poly1 = GI.Polygon([[[-20.0, 30.0], [-20.0, 28.0], [-18.0, 28.0], [-18.0, 30.0], [-20.0, 30.0]]])
poly2 = GI.Polygon([[[-19.0, 29.0], [-19.0, 27.0], [-17.0, 27.0], [-17.0, 29.0], [-19.0, 29.0]]])
polys = [poly1, poly2]

# Test UInt8 sum: 200 + 200 would overflow UInt8, should promote to UInt
r_sum_uint = rasterize(sum, polys; res=1.0, fill=UInt8(200), missingval=UInt8(0))
@test eltype(r_sum_uint) === UInt
@test r_sum_uint[X=-19, Y=28] === UInt(400)
@test r_sum_uint[X=-20, Y=30] === UInt(200)
@test r_sum_uint[X=-17, Y=27] === UInt(200)

# Test Int8 sum: -100 + -100 would overflow Int8, should promote to Int
r_sum_int = rasterize(sum, polys; res=1.0, fill=Int8(-100), missingval=Int8(0))
@test eltype(r_sum_int) === Int
@test r_sum_int[X=-19, Y=28] === Int(-200)
@test r_sum_int[X=-20, Y=30] === Int(-100)
@test r_sum_int[X=-17, Y=27] === Int(-100)

# Test UInt8 prod: 20 * 20 would overflow UInt8, should promote to UInt
r_prod_uint = rasterize(prod, polys; res=1.0, fill=UInt8(20), missingval=UInt8(0))
@test eltype(r_prod_uint) === UInt
@test r_prod_uint[X=-19, Y=28] === UInt(400)
@test r_prod_uint[X=-20, Y=30] === UInt(20)
@test r_prod_uint[X=-17, Y=27] === UInt(20)
end

@testset "rasterizing strange types" begin
@testset "vector of feature indices, with overlap" begin
polygons = [
Expand Down Expand Up @@ -584,4 +611,4 @@ end
@test count(x -> x == [2], result) == 12
@test count(x -> x == [1, 2], result) == 12
end
end
end
Loading