Commit 4453a4d

update mindspore opti

1 parent: 686c964

File tree

1 file changed: +35 −14 lines


tensorlayerx/optimizers/mindspore_optimizers.py

Lines changed: 35 additions & 14 deletions
@@ -27,13 +27,16 @@ def __init__(self, lr=0.001, initial_accumulator=0.1, eps=1e-07, weight_decay=0.
         self.eps = eps
         self.weight_decay = weight_decay
         self.adagrad = optimizer.Adagrad
+        self.init_optim = False
 
     def apply_gradients(self, grads_and_vars):
         grads, vars = list(zip(*grads_and_vars))
-        optimizer = self.adagrad(
+        if not self.init_optim:
+            self.optimizer = self.adagrad(
             vars, learning_rate=self.lr, accum=self.initial_accumulator, weight_decay=self.weight_decay
         )
-        optimizer(grads)
+            self.init_optim = True
+        self.optimizer(grads)
 
 
 class Adam(Cell):
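
The same edit is applied to every optimizer wrapper in this file: the MindSpore optimizer cell is now created lazily on the first call to apply_gradients and cached on the instance, instead of being rebuilt on every call. Below is a minimal sketch of the resulting Adagrad wrapper, reconstructed from the hunk above; the import lines and the full __init__ body are assumptions (they are not part of this diff), and mindspore.nn.Adagrad is used directly where the original file reaches it through an `optimizer` module alias.

import mindspore.nn as nn
from mindspore.nn import Cell


class Adagrad(Cell):

    def __init__(self, lr=0.001, initial_accumulator=0.1, eps=1e-07, weight_decay=0.0):
        super(Adagrad, self).__init__()
        self.lr = lr
        self.initial_accumulator = initial_accumulator
        self.eps = eps
        self.weight_decay = weight_decay
        self.adagrad = nn.Adagrad      # the optimizer class is stored; instantiation is deferred
        self.init_optim = False        # flag added by this commit; True after the first update

    def apply_gradients(self, grads_and_vars):
        grads, vars = list(zip(*grads_and_vars))
        if not self.init_optim:
            # Build the MindSpore optimizer cell once, on the first step,
            # so its per-parameter accumulators persist across steps.
            self.optimizer = self.adagrad(
                vars, learning_rate=self.lr, accum=self.initial_accumulator,
                weight_decay=self.weight_decay
            )
            self.init_optim = True
        self.optimizer(grads)
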
@@ -46,14 +49,17 @@ def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, eps=1e-8, weight_decay=0.
         self.beta_2 = beta_2
         self.eps = eps
         self.weight_decay = weight_decay
+        self.init_optim = False
 
     def apply_gradients(self, grads_and_vars):
         grads, vars = list(zip(*grads_and_vars))
-        optimizer_adam = self.adam(
+        if not self.init_optim:
+            self.optimizer_adam = self.adam(
             vars, learning_rate=self.lr, beta1=self.beta_1, beta2=self.beta_2, eps=self.eps,
             weight_decay=self.weight_decay
         )
-        optimizer_adam(grads)
+            self.init_optim = True
+        self.optimizer_adam(grads)
 
 
 class Adamax(Cell):
@@ -80,14 +86,17 @@ def __init__(
         self.l1 = l1_regularization_strength
         self.l2 = l2_regularization_strength
         self.weight_decay = weight_decay
+        self.init_optim = False
 
     def apply_gradients(self, grads_and_vars):
         grads, vars = list(zip(*grads_and_vars))
-        optimizer_adam = self.ftrl(
+        if not self.init_optim:
+            self.optimizer = self.ftrl(
             vars, learning_rate=self.lr, initial_accum=self.init_accum, lr_power=self.lr_power, l1=self.l1, l2=self.l2,
             weight_decay=self.weight_decay
         )
-        optimizer_adam(grads)
+            self.init_optim = True
+        self.optimizer(grads)
 
 
 class Nadam(Cell):
@@ -113,14 +122,17 @@ def __init__(
         self.centered = centered
         self.weight_decay = weight_decay
         self.rmsprop = optimizer.RMSProp
+        self.init_optim = False
 
     def apply_gradients(self, grads_and_vars):
         grads, vars = list(zip(*grads_and_vars))
-        optimizer_adam = self.rmsprop(
+        if not self.init_optim:
+            self.optimizer = self.rmsprop(
             vars, learning_rate=self.lr, decay=self.rho, momentum=self.momentum, epsilon=self.eps,
             centered=self.centered, weight_decay=self.weight_decay
         )
-        optimizer_adam(grads)
+            self.init_optim = True
+        self.optimizer(grads)
 
 
 class SGD(Cell):
@@ -131,13 +143,16 @@ def __init__(self, lr=0.1, momentum=0.0, weight_decay=0.0, grad_clip=None):
         self.lr = lr
         self.momentum = momentum
         self.weight_decay = weight_decay
+        self.init_optim = False
 
     def apply_gradients(self, grads_and_vars):
         grads, vars = list(zip(*grads_and_vars))
-        optimizer_sgd = self.sgd(
+        if not self.init_optim:
+            self.optimizer_sgd = self.sgd(
             vars, learning_rate=self.lr, momentum=self.momentum, weight_decay=self.weight_decay
         )
-        optimizer_sgd(grads)
+            self.init_optim = True
+        self.optimizer_sgd(grads)
 
 
 class Momentum(Cell):
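
One practical consequence, shown with the SGD wrapper from the hunk above: the cell stored in optimizer_sgd is now the same object across calls. A hypothetical quick check, assuming the MindSpore backend is active so this module imports cleanly, PyNative execution mode, and made-up Parameter and gradient values:

import numpy as np
from mindspore import Tensor, Parameter
from tensorlayerx.optimizers.mindspore_optimizers import SGD  # assumes TL_BACKEND=mindspore

w = Parameter(Tensor(np.ones((2, 2), np.float32)), name="w")
g = Tensor(np.full((2, 2), 0.1, np.float32))

opt = SGD(lr=0.1)
opt.apply_gradients(zip([g], [w]))
first_cell = opt.optimizer_sgd
opt.apply_gradients(zip([g], [w]))
assert opt.optimizer_sgd is first_cell   # one MindSpore SGD cell, built once and reused
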
@@ -149,14 +164,17 @@ def __init__(self, lr, momentum, use_nesterov=False, weight_decay=0.0, grad_clip
         self.momentum = momentum
         self.use_nesterov = use_nesterov
         self.weight_decay = weight_decay
+        self.init_optim = False
 
     def apply_gradients(self, grads_and_vars):
         grads, vars = list(zip(*grads_and_vars))
-        optimizer_mom = self.mom(
+        if not self.init_optim:
+            self.optimizer_mom = self.mom(
             vars, learning_rate=self.lr, momentum=self.momentum, use_nesterov=self.use_nesterov,
             weight_decay=self.weight_decay
         )
-        optimizer_mom(grads)
+            self.init_optim = True
+        self.optimizer_mom(grads)
 
 
 class Lamb(Cell):
@@ -169,14 +187,17 @@ def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, eps=1.0e-6, weight_decay=
         self.beta2 = beta_2
         self.eps = eps
         self.weight_decay = weight_decay
+        self.init_optim = False
 
     def apply_gradients(self, grads_and_vars):
         grads, vars = list(zip(*grads_and_vars))
-        optimizer_lamb = self.lamb(
+        if not self.init_optim:
+            self.optimizer_lamb = self.lamb(
             vars, learning_rate=self.lr, beta1=self.beta1, beta2=self.beta2, eps=self.eps,
             weight_decay=self.weight_decay
         )
-        optimizer_lamb(grads)
+            self.init_optim = True
+        self.optimizer_lamb(grads)
 
 
 class LARS(Cell):
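
Finally, a hypothetical view of where apply_gradients sits in a training loop; the names net, compute_grads, and train_data are placeholders and not part of this commit. Because the loop calls apply_gradients on every step, the change means the underlying MindSpore cell, and any state it carries such as Adam's moment estimates, is created on the first step and then reused rather than re-initialized each time.

# Hypothetical training loop; `net`, `compute_grads`, and `train_data` are placeholders.
optimizer = Adam(lr=0.001)   # the wrapper from this file, MindSpore backend assumed

for step, (x, y) in enumerate(train_data):
    grads = compute_grads(net, x, y)        # e.g. obtained via MindSpore's GradOperation
    weights = net.trainable_weights         # variables aligned one-to-one with grads
    optimizer.apply_gradients(zip(grads, weights))
    # step 0: the MindSpore Adam cell is built and cached on the wrapper;
    # step 1+: the cached cell (and its moment state) is reused.
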
