@@ -30,9 +30,9 @@ class Model:
         The training or testing network.
     loss_fn : function
         Objective function
-    optimizer : function
+    optimizer : class
         Optimizer for updating the weights
-    metrics : function
+    metrics : class
         Dict or set of metrics to be evaluated by the model during

     Methods
@@ -65,7 +65,7 @@ class Model:
     >>> return out
     >>>
     >>> net = Net()
-    >>> loss = tl.cost.cross_entropy
+    >>> loss = tl.cost.softmax_cross_entropy_with_logits
     >>> optim = tl.optimizers.Momentum(params=net.trainable_weights, learning_rate=0.1, momentum=0.9)
     >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
     >>> dataset = get_dataset()
@@ -150,7 +150,7 @@ def save_weights(self, file_path, format=None):
         >>> net = vgg16()
         >>> optimizer = tl.optimizers.Adam(learning_rate=0.001)
         >>> metric = tl.metric.Accuracy()
-        >>> model = tl.models.Model(network=net, loss_fn=tl.cost.cross_entropy , optimizer=optimizer, metrics=metric)
+        >>> model = tl.models.Model(network=net, loss_fn=tl.cost.softmax_cross_entropy_with_logits , optimizer=optimizer, metrics=metric)
         >>> model.save_weights('./model.h5')
         ...
         >>> model.load_weights('./model.h5')
@@ -195,7 +195,7 @@ def load_weights(self, file_path, format=None, in_order=True, skip=False):
         >>> net = vgg16()
         >>> optimizer = tl.optimizers.Adam(learning_rate=0.001)
         >>> metric = tl.metric.Accuracy()
-        >>> model = tl.models.Model(network=net, loss_fn=tl.cost.cross_entropy , optimizer=optimizer, metrics=metric)
+        >>> model = tl.models.Model(network=net, loss_fn=tl.cost.softmax_cross_entropy_with_logits , optimizer=optimizer, metrics=metric)
         >>> model.load_weights('./model_graph.h5', in_order=False, skip=True) # load weights by name, skipping mismatch
         >>> model.load_weights('./model_eager.h5') # load sequentially

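Continuing the sketch above, the save/load round trip these docstrings describe would look roughly as follows. Only calls that appear in the docstrings (save_weights, and load_weights with in_order and skip) are used; the .h5 file names are placeholders.

    model.save_weights('./model.h5')                              # persist the trained weights
    model.load_weights('./model.h5')                              # restore sequentially (in_order=True by default)
    model.load_weights('./model.h5', in_order=False, skip=True)   # restore by weight name, skipping mismatched layers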