Skip to content

Commit 2aef09d

Browse files
lingvo-bot authored and copybara-github committed
Making get_zero_batch in gshard_decode a @classmethod
PiperOrigin-RevId: 488698222
1 parent f296b81 commit 2aef09d

File tree

1 file changed

+39
-38
lines changed

1 file changed

+39
-38
lines changed

lingvo/core/gshard_decode.py

Lines changed: 39 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -55,43 +55,6 @@ def preload_zero(n=None, batch_size=None, max_len=None, key_size=2):
5555
return batch
5656

5757

58-
def get_zero_batch(batch_size=None,
59-
max_len=None,
60-
key_size=2,
61-
return_tgt_mask=False,
62-
return_scorer_alpha=False):
63-
"""Returns zero batch.
64-
65-
Args:
66-
batch_size: batch size.
67-
max_len: max length.
68-
key_size: key size.
69-
return_tgt_mask: if to return tgt_mask.
70-
return_scorer_alpha: if to return scorer_alpha used to set scaling factor
71-
for controlled decoding.
72-
Returns: a tuple of tensors
73-
key: int32 tensor [batch_size, key_size]
74-
tgt_id: int32 tensor [batch_size, max_len]
75-
tgt_segment_id: float32 tensor [batch_size, max_len]
76-
tgt_segment_pos: int32 tensor [batch_size, max_len]
77-
tgt_labels: int32 tensor [batch_size, max_len]
78-
tgt_sample_temperature: float32 tensor [batch_size]
79-
tgt_mask: optional float32 tensor [batch_size, max_len, max_len]
80-
tgt_scorer_alpha: float32 tensor [batch_size]
81-
"""
82-
batch = preload_zero(
83-
n=1, batch_size=batch_size, max_len=max_len, key_size=key_size)
84-
batch = py_utils.Transform(lambda x: np.squeeze(x, 0), batch)
85-
if return_tgt_mask:
86-
tgt_mask = np.zeros([batch_size, max_len, max_len], np.float32)
87-
batch = (*batch, tgt_mask)
88-
if return_scorer_alpha:
89-
assert not return_tgt_mask
90-
scorer_alpha = np.zeros([batch_size], np.float32)
91-
batch = (*batch, scorer_alpha)
92-
return batch
93-
94-
9558
# mimic training_loop.repeat(), but make it repeat forever.
9659
def infinite_repeat(body_fn, infeed_queue):
9760
"""Builds infinite loop.
@@ -353,6 +316,44 @@ def run_heartbeat_loop():
353316

354317
daemon(run_heartbeat_loop)
355318

319+
@classmethod
320+
def get_zero_batch(cls,
321+
batch_size=None,
322+
max_len=None,
323+
key_size=2,
324+
return_tgt_mask=False,
325+
return_scorer_alpha=False):
326+
"""Returns zero batch.
327+
328+
Args:
329+
batch_size: batch size.
330+
max_len: max length.
331+
key_size: key size.
332+
return_tgt_mask: if to return tgt_mask.
333+
return_scorer_alpha: if to return scorer_alpha used to set scaling factor
334+
for controlled decoding.
335+
Returns: a tuple of tensors
336+
key: int32 tensor [batch_size, key_size]
337+
tgt_id: int32 tensor [batch_size, max_len]
338+
tgt_segment_id: float32 tensor [batch_size, max_len]
339+
tgt_segment_pos: int32 tensor [batch_size, max_len]
340+
tgt_labels: int32 tensor [batch_size, max_len]
341+
tgt_sample_temperature: float32 tensor [batch_size]
342+
tgt_mask: optional float32 tensor [batch_size, max_len, max_len]
343+
tgt_scorer_alpha: float32 tensor [batch_size]
344+
"""
345+
batch = preload_zero(
346+
n=1, batch_size=batch_size, max_len=max_len, key_size=key_size)
347+
batch = py_utils.Transform(lambda x: np.squeeze(x, 0), batch)
348+
if return_tgt_mask:
349+
tgt_mask = np.zeros([batch_size, max_len, max_len], np.float32)
350+
batch = (*batch, tgt_mask)
351+
if return_scorer_alpha:
352+
assert not return_tgt_mask
353+
scorer_alpha = np.zeros([batch_size], np.float32)
354+
batch = (*batch, scorer_alpha)
355+
return batch
356+
356357
def _config_infeed(self,
357358
num_partitions,
358359
device_assignment,
@@ -362,7 +363,7 @@ def _config_infeed(self,
362363
return_scorer_alpha=False,
363364
use_partitioned_infeed_queue=False):
364365
"""Config the infeed ops and args."""
365-
zero_batch = get_zero_batch(
366+
zero_batch = self.get_zero_batch(
366367
batch_size=batch_size,
367368
max_len=self._prefix_max_len,
368369
key_size=key_size,

0 commit comments

Comments
 (0)