
Commit 9f0136d

update paddle backend
1 parent f704d91 commit 9f0136d

File tree

2 files changed: +10 -6 lines changed

tensorlayer/backend/ops/paddle_nn.py

Lines changed: 5 additions & 1 deletion
@@ -338,7 +338,11 @@ def __init__(self, keep, seed=1):
         self.seed = seed
 
     def __call__(self, inputs):
-        raise NotImplementedError
+        output = F.dropout(
+            inputs,
+            p=self.keep,
+            mode='upscale_in_train')
+        return output
 
 
 class BiasAdd(object):
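
For reference, a minimal standalone sketch of the op this hunk implements, assuming (as elsewhere in paddle_nn.py) that F is paddle.nn.functional. Note that paddle's F.dropout treats p as the probability of dropping an element, so the stored keep value is forwarded directly as that rate here.

import paddle
import paddle.nn.functional as F


class Dropout(object):
    """Sketch of the patched op: forwards `keep` as F.dropout's `p` argument."""

    def __init__(self, keep, seed=1):
        self.keep = keep
        self.seed = seed  # stored, but not passed on to F.dropout in this sketch

    def __call__(self, inputs):
        # p is paddle's drop probability; 'upscale_in_train' scales the kept
        # activations by 1/(1-p) during training, so inference needs no rescaling.
        output = F.dropout(inputs, p=self.keep, mode='upscale_in_train')
        return output


# Usage: with p taken from keep, Dropout(keep=0.1)(x) zeroes roughly 10% of x.
x = paddle.rand([4, 8])
y = Dropout(keep=0.1)(x)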

tensorlayer/models/core.py

Lines changed: 5 additions & 5 deletions
@@ -30,9 +30,9 @@ class Model:
         The training or testing network.
     loss_fn : function
         Objective function
-    optimizer : function
+    optimizer : class
         Optimizer for updating the weights
-    metrics : function
+    metrics : class
         Dict or set of metrics to be evaluated by the model during
 
     Methods
@@ -65,7 +65,7 @@ class Model:
     >>> return out
     >>>
     >>> net = Net()
-    >>> loss = tl.cost.cross_entropy
+    >>> loss = tl.cost.softmax_cross_entropy_with_logits
     >>> optim = tl.optimizers.Momentum(params=net.trainable_weights, learning_rate=0.1, momentum=0.9)
     >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
     >>> dataset = get_dataset()
@@ -150,7 +150,7 @@ def save_weights(self, file_path, format=None):
     >>> net = vgg16()
     >>> optimizer = tl.optimizers.Adam(learning_rate=0.001)
     >>> metric = tl.metric.Accuracy()
-    >>> model = tl.models.Model(network=net, loss_fn=tl.cost.cross_entropy, optimizer=optimizer, metrics=metric)
+    >>> model = tl.models.Model(network=net, loss_fn=tl.cost.softmax_cross_entropy_with_logits, optimizer=optimizer, metrics=metric)
     >>> model.save_weights('./model.h5')
     ...
     >>> model.load_weights('./model.h5')
@@ -195,7 +195,7 @@ def load_weights(self, file_path, format=None, in_order=True, skip=False):
     >>> net = vgg16()
     >>> optimizer = tl.optimizers.Adam(learning_rate=0.001)
     >>> metric = tl.metric.Accuracy()
-    >>> model = tl.models.Model(network=net, loss_fn=tl.cost.cross_entropy, optimizer=optimizer, metrics=metric)
+    >>> model = tl.models.Model(network=net, loss_fn=tl.cost.softmax_cross_entropy_with_logits, optimizer=optimizer, metrics=metric)
     >>> model.load_weights('./model_graph.h5', in_order=False, skip=True)  # load weights by name, skipping mismatch
     >>> model.load_weights('./model_eager.h5')  # load sequentially
 
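
The docstring examples now point at tl.cost.softmax_cross_entropy_with_logits instead of tl.cost.cross_entropy. As a hedged illustration of what a loss with that name conventionally computes (softmax followed by cross-entropy on raw logits), the same quantity expressed directly with paddle looks like this; the TensorLayer wrapper's exact signature is not shown in this diff and is not assumed here.

import paddle
import paddle.nn.functional as F

logits = paddle.randn([4, 10])                      # unnormalized class scores
labels = paddle.randint(low=0, high=10, shape=[4])  # integer class ids

# F.cross_entropy applies softmax internally, which is what the "with_logits"
# naming signals: the inputs are raw scores, not probabilities.
loss = F.cross_entropy(logits, labels)
print(float(loss))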
