@@ -27,13 +27,16 @@ def __init__(self, lr=0.001, initial_accumulator=0.1, eps=1e-07, weight_decay=0.
         self.eps = eps
         self.weight_decay = weight_decay
         self.adagrad = optimizer.Adagrad
+        self.init_optim = False

     def apply_gradients(self, grads_and_vars):
         grads, vars = list(zip(*grads_and_vars))
-        optimizer = self.adagrad(
+        if not self.init_optim:
+            self.optimizer = self.adagrad(
             vars, learning_rate=self.lr, accum=self.initial_accumulator, weight_decay=self.weight_decay
         )
-        optimizer(grads)
+            self.init_optim = True
+        self.optimizer(grads)


 class Adam(Cell):
@@ -46,14 +49,17 @@ def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, eps=1e-8, weight_decay=0.
         self.beta_2 = beta_2
         self.eps = eps
         self.weight_decay = weight_decay
+        self.init_optim = False

     def apply_gradients(self, grads_and_vars):
         grads, vars = list(zip(*grads_and_vars))
-        optimizer_adam = self.adam(
+        if not self.init_optim:
+            self.optimizer_adam = self.adam(
             vars, learning_rate=self.lr, beta1=self.beta_1, beta2=self.beta_2, eps=self.eps,
             weight_decay=self.weight_decay
         )
-        optimizer_adam(grads)
+            self.init_optim = True
+        self.optimizer_adam(grads)


 class Adamax(Cell):
@@ -80,14 +86,17 @@ def __init__(
         self.l1 = l1_regularization_strength
         self.l2 = l2_regularization_strength
         self.weight_decay = weight_decay
+        self.init_optim = False

     def apply_gradients(self, grads_and_vars):
         grads, vars = list(zip(*grads_and_vars))
-        optimizer_adam = self.ftrl(
+        if not self.init_optim:
+            self.optimizer = self.ftrl(
             vars, learning_rate=self.lr, initial_accum=self.init_accum, lr_power=self.lr_power, l1=self.l1, l2=self.l2,
             weight_decay=self.weight_decay
         )
-        optimizer_adam(grads)
+            self.init_optim = True
+        self.optimizer(grads)


 class Nadam(Cell):
@@ -113,14 +122,17 @@ def __init__(
         self.centered = centered
         self.weight_decay = weight_decay
         self.rmsprop = optimizer.RMSProp
+        self.init_optim = False

     def apply_gradients(self, grads_and_vars):
         grads, vars = list(zip(*grads_and_vars))
-        optimizer_adam = self.rmsprop(
+        if not self.init_optim:
+            self.optimizer = self.rmsprop(
             vars, learning_rate=self.lr, decay=self.rho, momentum=self.momentum, epsilon=self.eps,
             centered=self.centered, weight_decay=self.weight_decay
         )
-        optimizer_adam(grads)
+            self.init_optim = True
+        self.optimizer(grads)


 class SGD(Cell):
@@ -131,13 +143,16 @@ def __init__(self, lr=0.1, momentum=0.0, weight_decay=0.0, grad_clip=None):
         self.lr = lr
         self.momentum = momentum
         self.weight_decay = weight_decay
+        self.init_optim = False

     def apply_gradients(self, grads_and_vars):
         grads, vars = list(zip(*grads_and_vars))
-        optimizer_sgd = self.sgd(
+        if not self.init_optim:
+            self.optimizer_sgd = self.sgd(
             vars, learning_rate=self.lr, momentum=self.momentum, weight_decay=self.weight_decay
         )
-        optimizer_sgd(grads)
+            self.init_optim = True
+        self.optimizer_sgd(grads)


 class Momentum(Cell):
@@ -149,14 +164,17 @@ def __init__(self, lr, momentum, use_nesterov=False, weight_decay=0.0, grad_clip
         self.momentum = momentum
         self.use_nesterov = use_nesterov
         self.weight_decay = weight_decay
+        self.init_optim = False

     def apply_gradients(self, grads_and_vars):
         grads, vars = list(zip(*grads_and_vars))
-        optimizer_mom = self.mom(
+        if not self.init_optim:
+            self.optimizer_mom = self.mom(
             vars, learning_rate=self.lr, momentum=self.momentum, use_nesterov=self.use_nesterov,
             weight_decay=self.weight_decay
         )
-        optimizer_mom(grads)
+            self.init_optim = True
+        self.optimizer_mom(grads)


 class Lamb(Cell):
@@ -169,14 +187,17 @@ def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, eps=1.0e-6, weight_decay=
         self.beta2 = beta_2
         self.eps = eps
         self.weight_decay = weight_decay
+        self.init_optim = False

     def apply_gradients(self, grads_and_vars):
         grads, vars = list(zip(*grads_and_vars))
-        optimizer_lamb = self.lamb(
+        if not self.init_optim:
+            self.optimizer_lamb = self.lamb(
             vars, learning_rate=self.lr, beta1=self.beta1, beta2=self.beta2, eps=self.eps,
             weight_decay=self.weight_decay
         )
-        optimizer_lamb(grads)
+            self.init_optim = True
+        self.optimizer_lamb(grads)


 class LARS(Cell):
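
Every hunk above makes the same change: the wrapped MindSpore optimizer is built once, on the first apply_gradients call, cached on self, and reused on later steps, so internal state such as Adam's moment estimates or Adagrad's accumulators carries over between updates instead of being recreated on every call. The following is a minimal, self-contained sketch of that pattern only; DummyOptimizer is an assumed stand-in for optimizer.Adam / optimizer.Adagrad, and the Cell base class is omitted, so this is an illustration of the caching behaviour rather than the library's actual API.

class DummyOptimizer:
    """Assumed stand-in for a MindSpore optimizer; it only records how many
    times it was constructed and called."""

    def __init__(self, vars, learning_rate):
        self.vars = vars
        self.learning_rate = learning_rate
        self.steps = 0

    def __call__(self, grads):
        # A real optimizer would update self.vars from grads here.
        self.steps += 1


class Adam:
    """Mirrors the wrapper pattern used throughout the diff (Cell base dropped)."""

    def __init__(self, lr=0.001):
        self.lr = lr
        self.adam = DummyOptimizer
        self.init_optim = False  # becomes True after the first apply_gradients call

    def apply_gradients(self, grads_and_vars):
        grads, vars = list(zip(*grads_and_vars))
        if not self.init_optim:
            # Construct the underlying optimizer exactly once and cache it,
            # so its state persists across training steps.
            self.optimizer_adam = self.adam(vars, learning_rate=self.lr)
            self.init_optim = True
        self.optimizer_adam(grads)


opt = Adam()
params = ["w", "b"]                    # placeholders for trainable variables
for _ in range(3):
    grads = [0.1, 0.2]                 # placeholders for computed gradients
    opt.apply_gradients(zip(grads, params))
assert opt.optimizer_adam.steps == 3   # one optimizer instance, stepped three times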