Skip to content

Commit 09a739d

Browse files
authored
tl.iterate.minibatch support list with shuffle (#474)
1 parent e8f6d34 commit 09a739d

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

tensorlayer/iterate.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,10 @@ def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):
5959
excerpt = indices[start_idx:start_idx + batch_size]
6060
else:
6161
excerpt = slice(start_idx, start_idx + batch_size)
62-
yield inputs[excerpt], targets[excerpt]
62+
if (isinstance(inputs, list) or isinstance(targets, list)) and (shuffle == True):
63+
yield [inputs[i] for i in excerpt], [targets[i] for i in excerpt] # zsdonghao: for list indexing when shuffle==True
64+
else:
65+
yield inputs[excerpt], targets[excerpt]
6366

6467

6568
def seq_minibatches(inputs, targets, batch_size, seq_length, stride=1):

tensorlayer/models/mobilenetv1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class MobileNetV1(Layer):
5656
>>> # restore pre-trained parameters
5757
>>> cnn.restore_params(sess)
5858
>>> # train your own classifier (only update the last layer)
59-
>>> train_params = tl.layers.get_variables_with_name('output')
59+
>>> train_params = tl.layers.get_variables_with_name('out')
6060
6161
Reuse model
6262

tensorlayer/models/vgg16.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def restore_params(self, sess):
231231

232232

233233
class VGG16(VGG16Base):
234-
"""Pre-trained VGG-16 Model.
234+
"""Pre-trained VGG-16 model.
235235
236236
Parameters
237237
------------

0 commit comments

Comments
 (0)