@@ -1,8 +1,13 @@
 # Internal function, used only for layers defined in this file.
 _isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active
 
+# Internal function, used only in this file.
+_tidy_active(mode::Bool) = mode
+_tidy_active(::Nothing) = nothing
+_tidy_active(mode) = mode === :auto ? nothing : throw(ArgumentError("active = $(repr(mode)) is not accepted, must be true/false/nothing or :auto"))
+
 """
-    Dropout(p; [dims, rng])
+    Dropout(p; [dims, rng, active])
 
 Layer implementing [dropout](https://arxiv.org/abs/1207.0580) with the given probability.
 This is used as a regularisation, i.e. to reduce overfitting.
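The added `_tidy_active` helper normalises user-supplied mode values: a `Bool` passes through, `nothing` and `:auto` both mean "switch automatically", and anything else throws. A minimal sketch of the intended behaviour, with hand-written output rather than a doctest (`:always` is just an arbitrary invalid value):

```julia
julia> Flux._tidy_active(true), Flux._tidy_active(nothing), Flux._tidy_active(:auto)
(true, nothing, nothing)

julia> Flux._tidy_active(:always)  # any other value is rejected
ERROR: ArgumentError: active = :always is not accepted, must be true/false/nothing or :auto
```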
@@ -12,7 +17,8 @@ or else scales it by `1 / (1 - p)`, using the [`NNlib.dropout`](@ref) function.
 While testing, it has no effect.
 
 By default the mode will switch automatically, but it can also
-be controlled manually via [`Flux.testmode!`](@ref).
+be controlled manually via [`Flux.testmode!`](@ref),
+or by passing keyword `active=true` for training mode.
 
 By default every input is treated independently. With the `dims` keyword,
 instead it takes a random choice only along that dimension.
@@ -36,7 +42,11 @@ julia> m(ones(2, 7))  # test mode, no effect
 2.0  2.0  2.0  2.0  2.0  2.0  2.0
 2.0  2.0  2.0  2.0  2.0  2.0  2.0
 
-julia> Flux.trainmode!(m);  # equivalent to use within gradient
+julia> Flux.trainmode!(m)  # equivalent to use within gradient
+Chain(
+  Dense(2 => 3),                        # 9 parameters
+  Dropout(0.4, active=true),
+)
 
 julia> m(ones(2, 7))
 3×7 Matrix{Float64}:
@@ -63,9 +73,9 @@ mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
 end
 Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value())
 
-function Dropout(p::Real; dims=:, rng = default_rng_value())
+function Dropout(p::Real; dims=:, active::Union{Bool,Nothing}=nothing, rng = default_rng_value())
   0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1, got p = $p"))
-  Dropout(p, dims, nothing, rng)
+  Dropout(p, dims, active, rng)
 end
 
 @functor Dropout
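With the keyword threaded through the constructor, a `Dropout` layer can now be pinned to one mode at construction time, while out-of-range probabilities still throw. A rough sketch of the expected behaviour:

```julia
julia> d = Dropout(0.4, active=true);  # always applies dropout, even outside training

julia> d.active
true

julia> Dropout(1.5)  # the range check is unchanged
ERROR: ArgumentError: Dropout expects 0 ≤ p ≤ 1, got p = 1.5
```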
@@ -74,16 +84,17 @@ trainable(a::Dropout) = (;)
 (a::Dropout)(x) = dropout(a.rng, x, a.p * _isactive(a, x); dims=a.dims)
 
 testmode!(m::Dropout, mode=true) =
-  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+  (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)
 
 function Base.show(io::IO, d::Dropout)
   print(io, "Dropout(", d.p)
-  d.dims != (:) && print(io, ", dims = $(repr(d.dims))")
+  d.dims != (:) && print(io, ", dims=", d.dims)
+  d.active == nothing || print(io, ", active=", d.active)
   print(io, ")")
 end
 
 """
-    AlphaDropout(p; rng = default_rng_value())
+    AlphaDropout(p; [rng, active])
 
 A dropout layer. Used in
 [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
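Under the new `testmode!` definition, `true` pins test mode, `false` pins training mode, and `nothing` or `:auto` restores automatic switching; anything else now fails loudly inside `_tidy_active`. A sketch of the resulting `active` field:

```julia
julia> d = Dropout(0.3);

julia> Flux.testmode!(d).active         # pin test mode (default mode=true)
false

julia> Flux.testmode!(d, false).active  # pin training mode
true

julia> Flux.testmode!(d, :auto).active === nothing  # back to automatic
true
```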
@@ -112,13 +123,13 @@ mutable struct AlphaDropout{F,R<:AbstractRNG}
   p::F
   active::Union{Bool, Nothing}
   rng::R
-  function AlphaDropout(p, active, rng)
-    @assert 0 ≤ p ≤ 1
-    new{typeof(p), typeof(rng)}(p, active, rng)
-  end
 end
+
 AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
-AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)
+function AlphaDropout(p; rng = default_rng_value(), active::Union{Bool,Nothing} = nothing)
+  0 ≤ p ≤ 1 || throw(ArgumentError("AlphaDropout expects 0 ≤ p ≤ 1, got p = $p"))
+  AlphaDropout(p, active, rng)
+end
 
 @functor AlphaDropout
 trainable(a::AlphaDropout) = (;)
@@ -138,7 +149,7 @@ function (a::AlphaDropout)(x::AbstractArray{T}) where T
 end
 
 testmode!(m::AlphaDropout, mode=true) =
-  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+  (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)
 
 """
     LayerNorm(size..., λ=identity; affine=true, ϵ=1f-5)
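The `AlphaDropout` inner constructor's `@assert` is replaced by an explicit check in the keyword constructor, so a bad `p` now raises an `ArgumentError`, and the mode can be pinned with the same keyword; roughly:

```julia
julia> AlphaDropout(2.0)
ERROR: ArgumentError: AlphaDropout expects 0 ≤ p ≤ 1, got p = 2.0

julia> a = AlphaDropout(0.5, active=true);  # keyword pins training mode

julia> a.active
true
```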
@@ -257,7 +268,7 @@ ChainRulesCore.@non_differentiable _track_stats!(::Any...)
 """
     BatchNorm(channels::Integer, λ=identity;
               initβ=zeros32, initγ=ones32,
-              affine = true, track_stats = true,
+              affine=true, track_stats=true, active=nothing,
               ϵ=1f-5, momentum= 0.1f0)
 
 [Batch Normalization](https://arxiv.org/abs/1502.03167) layer.
@@ -310,7 +321,7 @@
 
 function BatchNorm(chs::Int, λ=identity;
                    initβ=zeros32, initγ=ones32,
-                   affine=true, track_stats=true,
+                   affine=true, track_stats=true, active::Union{Bool,Nothing}=nothing,
                    ϵ=1f-5, momentum=0.1f0)
 
   β = affine ? initβ(chs) : nothing
@@ -321,7 +332,7 @@ function BatchNorm(chs::Int, λ=identity;
   return BatchNorm(λ, β, γ,
                    μ, σ², ϵ, momentum,
                    affine, track_stats,
-                   nothing, chs)
+                   active, chs)
 end
 
 @functor BatchNorm
@@ -335,12 +346,13 @@ function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N}
 end
 
 testmode!(m::BatchNorm, mode=true) =
-  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+  (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)
 
 function Base.show(io::IO, l::BatchNorm)
   print(io, "BatchNorm($(l.chs)")
   (l.λ == identity) || print(io, ", $(l.λ)")
   hasaffine(l) || print(io, ", affine=false")
+  l.active == nothing || print(io, ", active=", l.active)
   print(io, ")")
 end
 
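`BatchNorm` follows the same pattern: the keyword lands in the `active` field and, when set, surfaces in the printed form. A sketch using `repr`, which dispatches to the two-argument `show` method above:

```julia
julia> repr(BatchNorm(3))                # automatic mode: no flag printed
"BatchNorm(3)"

julia> repr(BatchNorm(3, active=false))  # pinned to test mode
"BatchNorm(3, active=false)"
```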
@@ -399,7 +411,7 @@
 
 function InstanceNorm(chs::Int, λ=identity;
                       initβ=zeros32, initγ=ones32,
-                      affine=false, track_stats=false,
+                      affine=false, track_stats=false, active::Union{Bool,Nothing}=nothing,
                       ϵ=1f-5, momentum=0.1f0)
 
   β = affine ? initβ(chs) : nothing
@@ -410,7 +422,7 @@ function InstanceNorm(chs::Int, λ=identity;
   return InstanceNorm(λ, β, γ,
                       μ, σ², ϵ, momentum,
                       affine, track_stats,
-                      nothing, chs)
+                      active, chs)
 end
 
 @functor InstanceNorm
@@ -424,12 +436,13 @@ function (l::InstanceNorm)(x::AbstractArray{T,N}) where {T,N}
 end
 
 testmode!(m::InstanceNorm, mode=true) =
-  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+  (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)
 
 function Base.show(io::IO, l::InstanceNorm)
   print(io, "InstanceNorm($(l.chs)")
   l.λ == identity || print(io, ", $(l.λ)")
   hasaffine(l) || print(io, ", affine=false")
+  l.active == nothing || print(io, ", active=", l.active)
   print(io, ")")
 end
 
@@ -495,7 +508,7 @@ trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)
 
 function GroupNorm(chs::Int, G::Int, λ=identity;
                    initβ=zeros32, initγ=ones32,
-                   affine=true, track_stats=false,
+                   affine=true, track_stats=false, active::Union{Bool,Nothing}=nothing,
                    ϵ=1f-5, momentum=0.1f0)
 
   if track_stats
@@ -514,7 +527,7 @@
                    μ, σ²,
                    ϵ, momentum,
                    affine, track_stats,
-                   nothing, chs)
+                   active, chs)
 end
 
 function (gn::GroupNorm)(x::AbstractArray)
@@ -529,13 +542,14 @@ function (gn::GroupNorm)(x::AbstractArray)
 end
 
 testmode!(m::GroupNorm, mode = true) =
-  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
+  (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)
 
 function Base.show(io::IO, l::GroupNorm)
   # print(io, "GroupNorm($(join(size(l.β), ", "))", ", ", l.G)
   print(io, "GroupNorm($(l.chs), $(l.G)")
   l.λ == identity || print(io, ", ", l.λ)
   hasaffine(l) || print(io, ", affine=false")
+  l.active == nothing || print(io, ", active=", l.active)
   print(io, ")")
 end
 
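`InstanceNorm` and `GroupNorm` receive the identical treatment, so every mode-dependent layer in this file can be pinned at construction. A closing sanity sketch, assuming Flux with these changes applied:

```julia
using Flux

# `nothing` (the default) keeps automatic train/test switching;
# a Bool pins the layer to that mode permanently.
model = Chain(
    Conv((3, 3), 1 => 8),
    GroupNorm(8, 4, active=true),    # always uses training-mode behaviour
    InstanceNorm(8, active=false),   # always uses test-mode behaviour
)
```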