Skip to content

Commit 4279e54

Browse files
authored
tidy switchout (#65)
1 parent af48c96 commit 4279e54

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

common/training/switchout.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
def switchout_target(self, source, targetb_blank_idx: int,
55
target_num_labels: int, time_factor: int = 6,
66
switchout_prob: float = 0.05,
7-
switchout_blank_prob: float = 0.5, *, **kwargs):
7+
switchout_blank_prob: float = 0.5, **kwargs):
88
"""Switchout. It takes as input a batch of outputs and returns a switchout version of it.
99
Usage:
1010
{
@@ -33,8 +33,6 @@ def switchout_target(self, source, targetb_blank_idx: int,
3333
def get_switched():
3434
x_ = x
3535
shape = tf.shape(x)
36-
n_batch = tf.shape(x)[data.batch_dim_axis]
37-
n_time = tf.shape(x)[data.time_dim_axis]
3836
take_rnd_mask = tf.less(tf.random_uniform(shape=shape, minval=0., maxval=1.), switchout_prob)
3937
take_blank_mask = tf.less(tf.random_uniform(shape=shape, minval=0., maxval=1.), switchout_blank_prob)
4038
rnd_label = tf.random_uniform(shape=shape, minval=0, maxval=target_num_labels, dtype=tf.int32)

0 commit comments

Comments
 (0)