diff --git a/src/Extents.jl b/src/Extents.jl index 8febd59..6b0d177 100644 --- a/src/Extents.jl +++ b/src/Extents.jl @@ -128,9 +128,9 @@ function union(ext1::Extent, ext2::Extent; strict=false) else values = map(keys) do k k = _unwrap(k) - k_exts = (ext1[k], ext2[k]) - a = min(map(first, k_exts)...) - b = max(map(last, k_exts)...) + b1, b2 = ext1[k], ext2[k] + a = _nanfree(min, b1[1], b2[1]) + b = _nanfree(max, b1[2], b2[2]) (a, b) end return Extent{map(_unwrap, keys)}(values) @@ -142,6 +142,13 @@ union(::Nothing, ::Nothing; kw...) = nothing union(a, b; kw...) = union(extent(a), extent(b)) union(a, b, c, args...; kw...) = union(union(a, b), c, args...) +_nanfree(f, a, b) = f(a, b) +function _nanfree(f, a::F, b::F) where F<:AbstractFloat + isnan(a) && return b + isnan(b) && return a + f(a, b) +end + """ intersection(ext1::Extent, ext2::Extent; strict=false) @@ -160,9 +167,9 @@ function intersection(a::Extent, b::Extent; strict=false) # Get a symbol from `Val{:k}` k = _unwrap(k) # Acces the k symbol of `a` and `b` - k_exts = (a[k], b[k]) - maxs = max(map(first, k_exts)...) - mins = min(map(last, k_exts)...) + ba, bb = a[k], b[k] + maxs = _nanfree(max, ba[1], bb[1]) + mins = _nanfree(min, ba[2], bb[2]) (maxs, mins) end return Extent{map(_unwrap, keys)}(values) diff --git a/test/runtests.jl b/test/runtests.jl index c2b7e9f..205dd70 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,9 +57,11 @@ end a = E(X=(0.1, 0.5), Y=(1.0, 2.0)) b = E(X=(2.1, 2.5), Y=(3.0, 4.0), Z=(0.0, 1.0)) c = E(Z=(0.2, 2.0)) + n = E(X=(NaN, 0.9), Y=(1.0, NaN)) @test Extents.union(a, b) == Extents.union(a, b, a) == E(X=(0.1, 2.5), Y=(1.0, 4.0)) @test Extents.union(a, b; strict=true) === nothing @test Extents.union(a, c) === nothing + @test Extents.union(a, n) === E(X=(0.1, 0.9), Y=(1.0, 2.0)) # If either argument is nothing, return the other @test Extents.union(a, nothing) === a @@ -279,6 +281,7 @@ end c = E(X=(0.4, 2.5), Y=(1.5, 4.0), Z=(0.0, 1.0)) d = E(X=(0.2, 0.45)) e = E(A=(0.0, 1.0)) + n = E(X=(NaN, 0.9), Y=(1.0, NaN)) # a and b don't intersect @test Extents.intersects(a, b) == false @test Extents.intersects(b, a) == false @@ -306,6 +309,7 @@ end # Objects that have extents can be used @test Extents.intersects(HasExtent(), Extents.extent(HasExtent())) == true @test Extents.intersects(Extents.extent(HasExtent()), HasExtent()) == true + @test Extents.intersects(a, n) == false # a and b are disjoint @test Extents.disjoint(a, b) == true @@ -334,6 +338,7 @@ end # Objects that have extents can be used @test Extents.disjoint(HasExtent(), Extent(X=(2, 3), Y=(4, 5))) == true @test Extents.disjoint(Extent(X=(2, 3), Y=(4, 5)), HasExtent()) == true + @test Extents.disjoint(a, n) == true # a and b do not intersect @test Extents.intersection(a, b) === nothing @@ -348,6 +353,8 @@ end # Unless strict is true @test Extents.intersection(a, c; strict=true) === nothing @test Extents.intersection(c, a; strict=true) === nothing + # NaN returns nothing + @test Extents.intersection(a, n) == nothing # a and d intersect on X @test Extents.intersection(a, d) == Extents.intersection(d, a) == Extent(X=(0.2, 0.45))