
Ordered transformations #83

@cscherrer


Hi, have you looked into ordered transformations, as in Stan's approach?
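For reference, Stan's ordered-vector transform, as I read the reference manual, maps an unconstrained vector x to an increasing y with y[1] = x[1] and y[k] = y[k-1] + exp(x[k]). Since dy[k]/dx[k] = exp(x[k]) and the Jacobian is lower triangular, its log-determinant is the sum of x[2:end]. Below is a minimal Base-only sketch of that map and its inverse; `stan_ordered` and `stan_ordered_inverse` are hypothetical names, not part of TransformVariables or Stan.

```julia
# Hypothetical helper: Stan-style unconstrained -> strictly increasing map.
# y[1] = x[1], y[k] = y[k-1] + exp(x[k]); log|det J| = sum(x[2:end]).
function stan_ordered(x::AbstractVector{<:Real})
    y = similar(x, float(eltype(x)))
    isempty(x) && return (y, zero(eltype(y)))
    y[1] = x[1]
    logjac = zero(eltype(y))
    for k in 2:length(x)
        y[k] = y[k-1] + exp(x[k])   # positive increment enforces ordering
        logjac += x[k]              # log of diagonal entry exp(x[k])
    end
    return (y, logjac)
end

# Hypothetical inverse: x[1] = y[1], x[k] = log(y[k] - y[k-1]).
# Only valid for strictly increasing y.
function stan_ordered_inverse(y::AbstractVector{<:Real})
    x = similar(y, float(eltype(y)))
    isempty(y) && return x
    x[1] = y[1]
    for k in 2:length(y)
        x[k] = log(y[k] - y[k-1])
    end
    return x
end
```

The first element is left unconstrained, so this version has no bounds; the chained-interval approach below is one way to add them.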

I've been experimenting with this, and so far the following seems to work:

```julia
using TransformVariables
const TV = TransformVariables

# Wraps a scalar transformation and produces an increasing vector of length `dim`
struct Ordered{T <: TV.AbstractTransform} <: TV.VectorTransform
    transformation::T
    dim::Int
end

TV.dimension(t::Ordered) = t.dim

# Accumulate log-Jacobian terms; stays NoLogJac when the Jacobian is not requested
addlogjac(ℓ, Δℓ) = ℓ + Δℓ
addlogjac(::TV.NoLogJac, Δℓ) = TV.NoLogJac()

# (lower, upper) bounds of the supported scalar transformations
bounds(t::TV.ShiftedExp{true}) = (t.shift, TV.∞)
bounds(t::TV.ShiftedExp{false}) = (-TV.∞, t.shift)
bounds(t::TV.ScaledShiftedLogistic) = (t.shift, t.scale + t.shift)
bounds(::TV.Identity) = (-TV.∞, TV.∞)

# See https://mc-stan.org/docs/2_27/reference-manual/ordered-vector.html
function TV.transform_with(flag::TV.LogJacFlag, t::Ordered, x, index)
    transformation, len = t.transformation, t.dim
    @assert dimension(transformation) == 1
    y = similar(x, len)

    lo, hi = bounds(transformation)

    # First element lies in (lo, hi). Keep the index returned by each scalar
    # transform so that every element consumes its own entry of x.
    y[1], ℓ, index = TV.transform_with(flag, as(Real, lo, hi), x, index)

    # Each later element lies in (y[i-1], hi), which enforces the ordering
    for i in 2:len
        y[i], Δℓ, index = TV.transform_with(flag, as(Real, y[i-1], hi), x, index)
        ℓ = addlogjac(ℓ, Δℓ)
    end

    return (y, ℓ, index)
end

TV.inverse_eltype(t::Ordered, y::AbstractVector) = TV.extended_eltype(y)
```
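Stripped of TransformVariables, the chaining in `transform_with` amounts to the following for a finite interval (lo, hi). This is a minimal Base-only sketch; `ordered_interval`, `to_interval`, and `softplus` are hypothetical helper names, and the scaled logistic is written out by hand rather than taken from the package. Because y[i] depends on x[j] only for j ≤ i, the Jacobian is lower triangular, so the log-Jacobian is the sum of the log-diagonal terms log((hi - y[i-1]) σ'(x[i])).

```julia
# Hypothetical helper: log(1 + exp(z)), rearranged to avoid overflow for large z
softplus(z) = z > 0 ? z + log1p(exp(-z)) : log1p(exp(z))

# Scaled logistic into (a, b), plus the log of its derivative w.r.t. z:
# log((b - a) * σ(z) * (1 - σ(z))) = log(b - a) - softplus(-z) - softplus(z)
function to_interval(z, a, b)
    y = a + (b - a) / (1 + exp(-z))
    logdy = log(b - a) - softplus(-z) - softplus(z)
    return y, logdy
end

# Hypothetical Base-only version of the chaining:
# y[1] ∈ (lo, hi), and y[i] ∈ (y[i-1], hi) for i > 1
function ordered_interval(x::AbstractVector{<:Real}, lo::Real, hi::Real)
    @assert lo < hi
    y = similar(x, float(eltype(x)))
    logjac = 0.0
    prev = lo                       # current lower bound
    for i in eachindex(x)
        y[i], Δ = to_interval(x[i], prev, hi)
        logjac += Δ                 # lower-triangular Jacobian: add log-diagonal
        prev = y[i]
    end
    return y, logjac
end
```

Each step only shrinks the remaining gap hi - y[i-1], so on a bounded interval the elements crowd toward hi as the vector gets longer.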

For example,

```julia
julia> transform(Ordered(as𝕀, 10), randn(10))
10-element Vector{Float64}:
 0.43442911666628803
 0.6801295759251248
 0.8652960112555127
 0.9378879818399162
 0.97186447061553
 0.992691488683781
 0.9954155804092922
 0.9973011716146748
 0.9984116151813937
 0.9991915044548042
```

```julia
julia> transform(Ordered(asℝ₊, 10), randn(10))
10-element Vector{Float64}:
 0.5238744065026507
 1.0477488130053014
 1.250352594752778
 1.6981546986067726
 2.4349325599881855
 4.314050541989461
 4.56554977859454
 5.699772340085799
 7.099209192367974
 7.7740732481190165
```

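It does sometimes run into trouble, for example with longer vectors on the unit interval, where `y[i-1]` eventually rounds to the upper bound `1.0` and the next interval `(y[i-1], hi)` is empty: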

```julia
julia> transform(Ordered(as𝕀, 100), randn(100))
ERROR: ArgumentError: the interval (1.0, 1.0) is empty
Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/ArgCheck/5xEDR/src/checks.jl:243 [inlined]
 [2] as(#unused#::Type{Real}, left::Float64, right::Float64)
   @ TransformVariables ~/.julia/packages/TransformVariables/DDsiH/src/scalar.jl:173
 [3] transform_with(flag::TransformVariables.NoLogJac, t::Ordered{TransformVariables.ScaledShiftedLogistic{Float64}}, x::Vector{Float64}, index::Int64)
   @ Main ./REPL[163]:10
 [4] transform(t::Ordered{TransformVariables.ScaledShiftedLogistic{Float64}}, x::Vector{Float64})
   @ TransformVariables ~/.julia/packages/TransformVariables/DDsiH/src/generic.jl:261
 [5] top-level scope
   @ REPL[173]:1
```

With some more careful floating-point handling, I think this could be made to map to non-decreasing sequences instead of requiring strict monotonicity.
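One possible reading of that idea, as an untested sketch rather than a worked-out proposal: evaluate the scaled logistic directly instead of building a fresh `as(Real, y[i-1], hi)` (which rejects an empty interval), and clamp the result to [y[i-1], hi] so rounding can never leave it. Once the gap collapses, later elements simply repeat the bound, giving a non-decreasing sequence. `nondecreasing_interval` is a hypothetical name, and the log-Jacobian is deliberately omitted: at a tie the diagonal term log(hi - y[i-1]) is -Inf, so it would need separate handling.

```julia
# Hypothetical non-decreasing variant of the chained interval map.
# Allows y[i] == y[i-1] once the remaining gap has rounded to zero.
function nondecreasing_interval(x::AbstractVector{<:Real}, lo::Real, hi::Real)
    @assert lo < hi
    y = similar(x, float(eltype(x)))
    prev = float(lo)
    for i in eachindex(x)
        s = 1 / (1 + exp(-x[i]))          # logistic of the unconstrained input
        # (hi - prev) may be zero here; clamp keeps rounding inside [prev, hi]
        y[i] = clamp(prev + (hi - prev) * s, prev, hi)
        prev = y[i]
    end
    return y
end
```

For instance, `nondecreasing_interval(fill(5.0, 100), 0.0, 1.0)` runs without error, and its trailing entries are all exactly `1.0`.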
