
Commit b7c6cd1

mcabbott authored and ToucheSir committed
Print the state of Dropout etc. (FluxML#2222)
* print the state of Dropout etc.
* add tests
* doc improvements
* simpler scheme for testmode/trainmode
* simplify active keyword a bit
* a bug
* fix tests
* Update test/layers/normalisation.jl
  Co-authored-by: Brian Chen <[email protected]>
* Update src/functor.jl
* extend docstrings & warnings

---------

Co-authored-by: Brian Chen <[email protected]>
1 parent b1f6ceb commit b7c6cd1

6 files changed: +145 −54 lines changed


docs/src/models/layers.md

Lines changed: 2 additions & 1 deletion
@@ -146,6 +146,7 @@ Several normalisation layers behave differently under training and inference (te
 The functions `Flux.trainmode!` and `Flux.testmode!` let you manually specify which behaviour you want. When called on a model, they will place all layers within the model into the specified mode.

 ```@docs
-Flux.testmode!
+testmode!(::Any)
+testmode!(::Any, ::Any)
 trainmode!
 ```
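
For readers skimming this diff, the documented behaviour can be exercised roughly as in the sketch below. This is not part of the commit; the model and layer choices are illustrative.

```julia
using Flux

# An illustrative model mixing layers that change behaviour between phases.
model = Chain(Dense(2 => 3, relu), BatchNorm(3), Dropout(0.4))

Flux.testmode!(model)         # force every layer into inference behaviour
Flux.trainmode!(model)        # force every layer into training behaviour
Flux.testmode!(model, :auto)  # restore the default automatic switching
```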

src/deprecations.jl

Lines changed: 17 additions & 0 deletions
@@ -187,6 +187,23 @@ function update!(opt::Optimise.AbstractOptimiser, ::Params, grads::Union{Tuple,
 """)
 end

+"""
+    trainmode!(m, active)
+
+!!! warning
+    This two-argument method is deprecated.
+
+Possible values of `active` are:
+- `true` for training, or
+- `false` for testing, same as [`testmode!`](@ref)`(m)`
+- `:auto` or `nothing` for Flux to detect training automatically.
+"""
+function trainmode!(m, active::Bool)
+  Base.depwarn("trainmode!(m, active::Bool) is deprecated", :trainmode)
+  testmode!(m, !active)
+end
+
+
 # v0.14 deprecations

 # Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc:
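
Roughly, code relying on the deprecated two-argument `trainmode!` migrates as follows. This is an illustrative sketch, not taken from the commit.

```julia
using Flux

m = Dropout(0.3)        # any layer with an `active` field behaves the same way

# Old (still works, but now prints a deprecation warning):
#   trainmode!(m, true)     # force training mode
#   trainmode!(m, false)    # force test mode, same as testmode!(m)

# New single-argument / symbol forms:
trainmode!(m)           # force training mode
testmode!(m)            # force test mode
testmode!(m, :auto)     # let Flux detect the mode automatically again
```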

src/functor.jl

Lines changed: 59 additions & 20 deletions
@@ -5,36 +5,75 @@ import Functors: Functors, @functor, functor, fmap, isleaf
 using SparseArrays: AbstractSparseArray

 """
-    testmode!(m, mode = true)
+    testmode!(model, [mode]) -> model

-Set a layer or model's test mode (see below).
-Using `:auto` mode will treat any gradient computation as training.
+Set a layer, or all layers in a model, to test mode.
+This disables the effect of [`Dropout`](@ref) and
+some other regularisation layers.

-_Note_: if you manually set a model into test mode, you need to manually place
-it back into train mode during training phase.
+If you manually set a model into test mode, you need to manually place
+it back into train mode during training phase, using [`trainmode!`](@ref).

-Possible values include:
-- `false` for training
-- `true` for testing
-- `:auto` or `nothing` for Flux to detect the mode automatically
+There is an optional second argument, which takes a symbol `:auto` to
+reset all layers back to the default automatic mode.
+
+# Example
+
+```jldoctest
+julia> d = Dropout(0.3)
+Dropout(0.3)
+
+julia> testmode!(d)   # dropout is now always disabled
+Dropout(0.3, active=false)
+
+julia> trainmode!(d)  # dropout is now always enabled
+Dropout(0.3, active=true)
+
+julia> testmode!(d, :auto)  # back to default
+Dropout(0.3)
+```
 """
-testmode!(m, mode = true) = (foreach(x -> testmode!(x, mode), trainable(m)); m)
+testmode!(m) = testmode!(m, true)

 """
-    trainmode!(m, mode = true)
+    trainmode!(model) -> model

-Set a layer of model's train mode (see below).
-Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)`).
+Set a layer, or all layers in a model, to training mode.
+Opposite to [`testmode!`](@ref), see further details there.
+"""
+trainmode!(m) = testmode!(m, false)
+trainmode!(m, mode::Symbol) = testmode!(m, mode)
+trainmode!(m, ::Nothing) = testmode!(m, nothing) # why do we have so much API?
+
+"""
+    testmode!(model, inactive)
+
+This two-argument method is largely internal. It recurses into the `model`,
+and until a method like `testmode!(d::Dropout, inactive)` alters the activity of a layer.
+Custom layers can support manual `testmode!` / `trainmode!` switching
+by defining such a method.

-_Note_: if you manually set a model into train mode, you need to manually place
-it into test mode during testing phase.
+Possible values of `inactive` are:
+- `true` for testing, i.e. `active=false`
+- `false` for training, same as [`trainmode!`](@ref)`(m)`
+- `:auto` or `nothing` for Flux to detect training automatically.

-Possible values include:
-- `true` for training
-- `false` for testing
-- `:auto` or `nothing` for Flux to detect the mode automatically
+!!! compat
+    This method may be removed in a future breaking change, to separate
+    the user-facing `testmode!` from the internal recursion.
 """
-trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
+function testmode!(m, mode)
+  inactive = if mode isa Symbol
+    mode === :auto || throw(ArgumentError("testmode! accepts only the symbol :auto, got :$mode"))
+    nothing
+  elseif mode isa Union{Bool,Nothing}
+    mode
+  else
+    throw(ArgumentError("testmode! does not accept $(repr(mode)) as the 2nd argument"))
+  end
+  foreach(x -> testmode!(x, inactive), trainable(m))
+  m
+end

 function params!(p::Params, x, seen = IdSet())
   if x isa AbstractArray{<:Number} && Functors.isleaf(x)
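
Since the two-argument `testmode!` docstring points custom layers at defining their own method, here is a minimal sketch of what that could look like for a hypothetical `NoiseLayer`. The layer, its fields and its forward pass are invented for illustration; only the `testmode!` pattern mirrors what this commit does for `Dropout`.

```julia
using Flux

# Hypothetical layer holding its own `active` flag, like Flux's Dropout does.
mutable struct NoiseLayer{T}
    scale::T
    active::Union{Bool,Nothing}   # `nothing` means "decide automatically"
end
NoiseLayer(scale) = NoiseLayer(scale, nothing)

Flux.@functor NoiseLayer
Flux.trainable(::NoiseLayer) = (;)   # no trainable parameters

# Forward pass: add noise only when explicitly active. A real layer would
# also consult NNlib.within_gradient(x) when active === nothing, as the
# internal _isactive helper in this commit does.
(n::NoiseLayer)(x) = n.active === true ? x .+ n.scale .* randn(float(eltype(x)), size(x)) : x

# The method the docstring refers to: map `inactive` onto the layer's flag.
function Flux.testmode!(n::NoiseLayer, inactive)
    n.active = (isnothing(inactive) || inactive === :auto) ? nothing : !inactive
    return n
end
```

With such a method in place, `testmode!(model)` and `trainmode!(model)` on a model containing a `NoiseLayer` reach it through the recursion in the generic `testmode!(m, mode)` above.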

src/layers/normalise.jl

Lines changed: 38 additions & 24 deletions
@@ -1,8 +1,13 @@
 # Internal function, used only for layers defined in this file.
 _isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active

+# Internal function, used only in this file.
+_tidy_active(mode::Bool) = mode
+_tidy_active(::Nothing) = nothing
+_tidy_active(mode) = mode === :auto ? nothing : throw(ArgumentError("active = $(repr(mode)) is not accepted, must be true/false/nothing or :auto"))
+
 """
-    Dropout(p; [dims, rng])
+    Dropout(p; [dims, rng, active])

 Layer implementing [dropout](https://arxiv.org/abs/1207.0580) with the given probability.
 This is used as a regularisation, i.e. to reduce overfitting.
@@ -12,7 +17,8 @@ or else scales it by `1 / (1 - p)`, using the [`NNlib.dropout`](@ref) function.
 While testing, it has no effect.

 By default the mode will switch automatically, but it can also
-be controlled manually via [`Flux.testmode!`](@ref).
+be controlled manually via [`Flux.testmode!`](@ref),
+or by passing keyword `active=true` for training mode.

 By default every input is treated independently. With the `dims` keyword,
 instead it takes a random choice only along that dimension.
@@ -36,7 +42,11 @@ julia> m(ones(2, 7)) # test mode, no effect
 2.0 2.0 2.0 2.0 2.0 2.0 2.0
 2.0 2.0 2.0 2.0 2.0 2.0 2.0

-julia> Flux.trainmode!(m); # equivalent to use within gradient
+julia> Flux.trainmode!(m) # equivalent to use within gradient
+Chain(
+  Dense(2 => 3), # 9 parameters
+  Dropout(0.4, active=true),
+)

 julia> m(ones(2, 7))
 3×7 Matrix{Float64}:
@@ -63,9 +73,9 @@ mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
 end
 Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value())

-function Dropout(p::Real; dims=:, rng = default_rng_value())
+function Dropout(p::Real; dims=:, active::Union{Bool,Nothing} = nothing, rng = default_rng_value())
   0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1, got p = $p"))
-  Dropout(p, dims, nothing, rng)
+  Dropout(p, dims, active, rng)
 end

 @functor Dropout
@@ -74,16 +84,17 @@ trainable(a::Dropout) = (;)
 (a::Dropout)(x) = dropout(a.rng, x, a.p * _isactive(a, x); dims=a.dims)

 testmode!(m::Dropout, mode=true) =
-  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+  (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

 function Base.show(io::IO, d::Dropout)
   print(io, "Dropout(", d.p)
-  d.dims != (:) && print(io, ", dims = $(repr(d.dims))")
+  d.dims != (:) && print(io, ", dims=", d.dims)
+  d.active == nothing || print(io, ", active=", d.active)
   print(io, ")")
 end

 """
-    AlphaDropout(p; rng = default_rng_value())
+    AlphaDropout(p; [rng, active])

 A dropout layer. Used in
 [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
@@ -112,13 +123,13 @@ mutable struct AlphaDropout{F,R<:AbstractRNG}
   p::F
   active::Union{Bool, Nothing}
   rng::R
-  function AlphaDropout(p, active, rng)
-    @assert 0 ≤ p ≤ 1
-    new{typeof(p), typeof(rng)}(p, active, rng)
-  end
 end
+
 AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
-AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)
+function AlphaDropout(p; rng = default_rng_value(), active::Union{Bool,Nothing} = nothing)
+  0 ≤ p ≤ 1 || throw(ArgumentError("AlphaDropout expects 0 ≤ p ≤ 1, got p = $p"))
+  AlphaDropout(p, active, rng)
+end

 @functor AlphaDropout
 trainable(a::AlphaDropout) = (;)
@@ -138,7 +149,7 @@ function (a::AlphaDropout)(x::AbstractArray{T}) where T
 end

 testmode!(m::AlphaDropout, mode=true) =
-  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+  (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

 """
     LayerNorm(size..., λ=identity; affine=true, ϵ=1fe-5)
@@ -257,7 +268,7 @@ ChainRulesCore.@non_differentiable _track_stats!(::Any...)
 """
     BatchNorm(channels::Integer, λ=identity;
               initβ=zeros32, initγ=ones32,
-              affine = true, track_stats = true,
+              affine=true, track_stats=true, active=nothing,
               ϵ=1f-5, momentum= 0.1f0)

 [Batch Normalization](https://arxiv.org/abs/1502.03167) layer.
@@ -310,7 +321,7 @@ end

 function BatchNorm(chs::Int, λ=identity;
           initβ=zeros32, initγ=ones32,
-          affine=true, track_stats=true,
+          affine=true, track_stats=true, active::Union{Bool,Nothing}=nothing,
           ϵ=1f-5, momentum=0.1f0)

   β = affine ? initβ(chs) : nothing
@@ -321,7 +332,7 @@ function BatchNorm(chs::Int, λ=identity;
   return BatchNorm(λ, β, γ,
             μ, σ², ϵ, momentum,
             affine, track_stats,
-            nothing, chs)
+            active, chs)
 end

 @functor BatchNorm
@@ -335,12 +346,13 @@ function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N}
 end

 testmode!(m::BatchNorm, mode=true) =
-  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+  (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

 function Base.show(io::IO, l::BatchNorm)
   print(io, "BatchNorm($(l.chs)")
   (l.λ == identity) || print(io, ", $(l.λ)")
   hasaffine(l) || print(io, ", affine=false")
+  l.active == nothing || print(io, ", active=", l.active)
   print(io, ")")
 end

@@ -399,7 +411,7 @@ end

 function InstanceNorm(chs::Int, λ=identity;
             initβ=zeros32, initγ=ones32,
-            affine=false, track_stats=false,
+            affine=false, track_stats=false, active::Union{Bool,Nothing}=nothing,
             ϵ=1f-5, momentum=0.1f0)

   β = affine ? initβ(chs) : nothing
@@ -410,7 +422,7 @@ function InstanceNorm(chs::Int, λ=identity;
   return InstanceNorm(λ, β, γ,
             μ, σ², ϵ, momentum,
             affine, track_stats,
-            nothing, chs)
+            active, chs)
 end

 @functor InstanceNorm
@@ -424,12 +436,13 @@ function (l::InstanceNorm)(x::AbstractArray{T,N}) where {T,N}
 end

 testmode!(m::InstanceNorm, mode=true) =
-  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+  (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

 function Base.show(io::IO, l::InstanceNorm)
   print(io, "InstanceNorm($(l.chs)")
   l.λ == identity || print(io, ", $(l.λ)")
   hasaffine(l) || print(io, ", affine=false")
+  l.active == nothing || print(io, ", active=", l.active)
   print(io, ")")
 end

@@ -495,7 +508,7 @@ trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)

 function GroupNorm(chs::Int, G::Int, λ=identity;
             initβ=zeros32, initγ=ones32,
-            affine=true, track_stats=false,
+            affine=true, track_stats=false, active::Union{Bool,Nothing}=nothing,
             ϵ=1f-5, momentum=0.1f0)

   if track_stats
@@ -514,7 +527,7 @@ end
             μ, σ²,
             ϵ, momentum,
             affine, track_stats,
-            nothing, chs)
+            active, chs)
 end

 function (gn::GroupNorm)(x::AbstractArray)
@@ -529,13 +542,14 @@ function (gn::GroupNorm)(x::AbstractArray)
 end

 testmode!(m::GroupNorm, mode = true) =
-  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+  (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

 function Base.show(io::IO, l::GroupNorm)
   # print(io, "GroupNorm($(join(size(l.β), ", "))", ", ", l.G)
   print(io, "GroupNorm($(l.chs), $(l.G)")
   l.λ == identity || print(io, ", ", l.λ)
   hasaffine(l) || print(io, ", affine=false")
+  l.active == nothing || print(io, ", active=", l.active)
   print(io, ")")
 end

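
Putting the layer-side changes together, the new `active` keyword and the extended `show` methods behave roughly as in the sketch below (based on the diff above; the printed forms follow the new `show` methods).

```julia
using Flux

d = Dropout(0.4; active=true)    # constructed already forced into training mode
# now printed as: Dropout(0.4, active=true)

bn = BatchNorm(3; active=false)  # forced into test mode from the start
# now printed as: BatchNorm(3, active=false)

Flux.testmode!(d, :auto)         # back to automatic switching, normalised by
                                 # _tidy_active; `d` prints as Dropout(0.4) again
```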

test/layers/basic.jl

Lines changed: 9 additions & 9 deletions
@@ -1,5 +1,5 @@
 using Test, Random
-import Flux: activations
+using Flux: activations

 @testset "basic" begin
   @testset "helpers" begin
@@ -16,11 +16,11 @@ import Flux: activations
   end

   @testset "Chain" begin
-    @test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
-    @test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
+    @test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn32(10))
+    @test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn32(10))
     # numeric test should be put into testset of corresponding layer

-    @test_nowarn Chain(first = Dense(10, 5, σ), second = Dense(5, 2))(randn(10))
+    @test_nowarn Chain(first = Dense(10, 5, σ), second = Dense(5, 2))(randn32(10))
     m = Chain(first = Dense(10, 5, σ), second = Dense(5, 2))
     @test m[:first] == m[1]
     @test m[1:2] == m
@@ -72,10 +72,10 @@ import Flux: activations
       @test_throws MethodError Dense(rand(5), rand(5), tanh)
     end
     @testset "dimensions" begin
-      @test length(Dense(10, 5)(randn(10))) == 5
-      @test_throws DimensionMismatch Dense(10, 5)(randn(1))
-      @test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
-      @test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
+      @test length(Dense(10 => 5)(randn32(10))) == 5
+      @test_throws DimensionMismatch Dense(10 => 5)(randn32(1))
+      @test_throws MethodError Dense(10 => 5)(1) # avoid broadcasting
+      @test_throws MethodError Dense(10 => 5).(randn32(10)) # avoid broadcasting
       @test size(Dense(10, 5)(randn(10))) == (5,)
       @test size(Dense(10, 5)(randn(10,2))) == (5,2)
       @test size(Dense(10, 5)(randn(10,2,3))) == (5,2,3)
@@ -333,7 +333,7 @@ import Flux: activations
     y = m(x)
     @test y isa Array{Float32, 3}
     @test size(y) == (embed_size, 3, 4)
-    x3 = onehotbatch(x, 1:1:vocab_size)
+    x3 = Flux.onehotbatch(x, 1:1:vocab_size)
     @test size(x3) == (vocab_size, 3, 4)
     y3 = m(x3)
     @test size(y3) == (embed_size, 3, 4)
