diff --git a/src/read.jl b/src/read.jl index 8fcc60d..cd4a8da 100644 --- a/src/read.jl +++ b/src/read.jl @@ -93,13 +93,13 @@ end const FLOAT_INT_BOUND = 2.0^53 -function read!(buf, pos, len, b, tape, tapeidx, ::Type{Any}, checkint=true; allow_inf::Bool=false) +function read!(buf, pos, len, b, tape, tapeidx, ::Type{Any}, checkint=true; inf_mapping::Union{Function,Nothing}=nothing, allow_inf::Bool=(inf_mapping !== nothing)) if b == UInt8('{') - return read!(buf, pos, len, b, tape, tapeidx, Object, checkint; allow_inf=allow_inf) + return read!(buf, pos, len, b, tape, tapeidx, Object, checkint; allow_inf=allow_inf, inf_mapping=inf_mapping) elseif b == UInt8('[') - return read!(buf, pos, len, b, tape, tapeidx, Array, checkint; allow_inf=allow_inf) + return read!(buf, pos, len, b, tape, tapeidx, Array, checkint; allow_inf=allow_inf, inf_mapping=inf_mapping) elseif b == UInt8('"') - return read!(buf, pos, len, b, tape, tapeidx, String) + return read!(buf, pos, len, b, tape, tapeidx, String; inf_mapping=inf_mapping) elseif b == UInt8('n') return read!(buf, pos, len, b, tape, tapeidx, Nothing) elseif b == UInt8('t') @@ -148,7 +148,7 @@ function read!(buf, pos, len, b, tape, tapeidx, ::Type{Any}, checkint=true; allo invalid(InvalidChar, buf, pos, Any) end -function read!(buf, pos, len, b, tape, tapeidx, ::Type{String}) +function read!(buf, pos, len, b, tape, tapeidx, ::Type{String}; inf_mapping::Union{Function,Nothing}=nothing) pos += 1 @eof strpos = pos @@ -171,6 +171,23 @@ function read!(buf, pos, len, b, tape, tapeidx, ::Type{String}) b = getbyte(buf, pos) end @check + if inf_mapping !== nothing + val = view(buf, strpos-1:pos) + float = if val == codeunits(inf_mapping(Inf)) + Inf + elseif val == codeunits(inf_mapping(-Inf)) + -Inf + elseif val == codeunits(inf_mapping(NaN)) + NaN + else + 0.0 + end + if float != 0.0 + @inbounds tape[tapeidx] = FLOAT + @inbounds tape[tapeidx+1] = Core.bitcast(UInt64, float) + return pos + 1, tapeidx + 2 + end + end @inbounds tape[tapeidx] = string(strlen) @inbounds tape[tapeidx+1] = ifelse(escaped, ESCAPE_BIT | strpos, strpos) return pos + 1, tapeidx + 2 diff --git a/src/write.jl b/src/write.jl index bfff8a1..cc07cfa 100644 --- a/src/write.jl +++ b/src/write.jl @@ -25,6 +25,11 @@ Write JSON. ## Keyword Args * `allow_inf`: Allow writing of `Inf` and `NaN` values (not part of the JSON standard). [default `false`] +* `inf_mapping`: A function to map `Inf`, `-Inf` and `NaN` values to a custom representation. [default `nothing`] + + if `inf_mapping` is `nothing` the mapping is equivalent to + `inf_mapping = x -> x == Inf ? "Infinity" : x == -Inf ? "-Infinity" : "NaN"`. + Specifying `inf_mapping` will automatically set the default value of `allow_inf` to `true`. * `dateformat`: A [`DateFormat`](https://docs.julialang.org/en/v1/stdlib/Dates/#Dates.DateFormat) describing how to format `Date`s in the object. [default `Dates.default_format(T)`] """ function write(io::IO, obj::T; kw...) where {T} @@ -279,19 +284,26 @@ function write(::NumberType, buf, pos, len, x::AbstractFloat; allow_inf::Bool=fa return buf, pos, len end -@inline function write(::NumberType, buf, pos, len, x::T; allow_inf::Bool=false, kw...) where {T <: Base.IEEEFloat} - isfinite(x) || allow_inf || error("$x not allowed to be written in JSON spec") - if isinf(x) +@inline function write(::NumberType, buf, pos, len, x::T; inf_mapping::Union{Function, Nothing} = nothing, allow_inf::Bool = inf_mapping !== nothing, kw...) where {T <: Base.IEEEFloat} + if isfinite(x) || (allow_inf && inf_mapping === nothing && isnan(x)) + @check Ryu.neededdigits(T) + pos = Ryu.writeshortest(buf, pos, x) + else + allow_inf || error("$x not allowed to be written in JSON spec") # Although this is non-standard JSON, "Infinity" is commonly used. # See https://docs.python.org/3/library/json.html#infinite-and-nan-number-values. - if sign(x) == -1 - @writechar '-' + if inf_mapping === nothing + sign(x) == -1 && @writechar '-' + @writechar 'I' 'n' 'f' 'i' 'n' 'i' 't' 'y' + else + bytes = codeunits(inf_mapping(x)) + @check length(bytes) + for b in bytes + @inbounds buf[pos] = b + pos += 1 + end end - @writechar 'I' 'n' 'f' 'i' 'n' 'i' 't' 'y' - return buf, pos, len end - @check Ryu.neededdigits(T) - pos = Ryu.writeshortest(buf, pos, x) return buf, pos, len end diff --git a/test/json.jl b/test/json.jl index 98b1e2c..40f8ecd 100644 --- a/test/json.jl +++ b/test/json.jl @@ -46,6 +46,11 @@ end @test JSON3.read("Inf"; allow_inf=true) === Inf @test JSON3.read("Infinity"; allow_inf=true) === Inf @test JSON3.read("-Infinity"; allow_inf=true) === -Inf + + quoted_inf_mapping(x) = x == Inf ? "\"Infinity\"" : x == -Inf ? "\"-Infinity\"" : "\"NaN\"" + @test JSON3.write(NaN, inf_mapping = quoted_inf_mapping) == "\"NaN\"" + @test JSON3.write(Inf, inf_mapping = quoted_inf_mapping) == "\"Infinity\"" + @test JSON3.write(-Inf, inf_mapping = quoted_inf_mapping) == "\"-Infinity\"" end @testset "Char" begin