@@ -55,43 +55,6 @@ def preload_zero(n=None, batch_size=None, max_len=None, key_size=2):
5555 return batch
5656
5757
58- def get_zero_batch (batch_size = None ,
59- max_len = None ,
60- key_size = 2 ,
61- return_tgt_mask = False ,
62- return_scorer_alpha = False ):
63- """Returns zero batch.
64-
65- Args:
66- batch_size: batch size.
67- max_len: max length.
68- key_size: key size.
69- return_tgt_mask: if to return tgt_mask.
70- return_scorer_alpha: if to return scorer_alpha used to set scaling factor
71- for controlled decoding.
72- Returns: a tuple of tensors
73- key: int32 tensor [batch_size, key_size]
74- tgt_id: int32 tensor [batch_size, max_len]
75- tgt_segment_id: float32 tensor [batch_size, max_len]
76- tgt_segment_pos: int32 tensor [batch_size, max_len]
77- tgt_labels: int32 tensor [batch_size, max_len]
78- tgt_sample_temperature: float32 tensor [batch_size]
79- tgt_mask: optional float32 tensor [batch_size, max_len, max_len]
80- tgt_scorer_alpha: float32 tensor [batch_size]
81- """
82- batch = preload_zero (
83- n = 1 , batch_size = batch_size , max_len = max_len , key_size = key_size )
84- batch = py_utils .Transform (lambda x : np .squeeze (x , 0 ), batch )
85- if return_tgt_mask :
86- tgt_mask = np .zeros ([batch_size , max_len , max_len ], np .float32 )
87- batch = (* batch , tgt_mask )
88- if return_scorer_alpha :
89- assert not return_tgt_mask
90- scorer_alpha = np .zeros ([batch_size ], np .float32 )
91- batch = (* batch , scorer_alpha )
92- return batch
93-
94-
9558# mimic training_loop.repeat(), but make it repeat forever.
9659def infinite_repeat (body_fn , infeed_queue ):
9760 """Builds infinite loop.
@@ -353,6 +316,44 @@ def run_heartbeat_loop():
353316
354317 daemon (run_heartbeat_loop )
355318
319+ @classmethod
320+ def get_zero_batch (cls ,
321+ batch_size = None ,
322+ max_len = None ,
323+ key_size = 2 ,
324+ return_tgt_mask = False ,
325+ return_scorer_alpha = False ):
326+ """Returns zero batch.
327+
328+ Args:
329+ batch_size: batch size.
330+ max_len: max length.
331+ key_size: key size.
332+ return_tgt_mask: if to return tgt_mask.
333+ return_scorer_alpha: if to return scorer_alpha used to set scaling factor
334+ for controlled decoding.
335+ Returns: a tuple of tensors
336+ key: int32 tensor [batch_size, key_size]
337+ tgt_id: int32 tensor [batch_size, max_len]
338+ tgt_segment_id: float32 tensor [batch_size, max_len]
339+ tgt_segment_pos: int32 tensor [batch_size, max_len]
340+ tgt_labels: int32 tensor [batch_size, max_len]
341+ tgt_sample_temperature: float32 tensor [batch_size]
342+ tgt_mask: optional float32 tensor [batch_size, max_len, max_len]
343+ tgt_scorer_alpha: float32 tensor [batch_size]
344+ """
345+ batch = preload_zero (
346+ n = 1 , batch_size = batch_size , max_len = max_len , key_size = key_size )
347+ batch = py_utils .Transform (lambda x : np .squeeze (x , 0 ), batch )
348+ if return_tgt_mask :
349+ tgt_mask = np .zeros ([batch_size , max_len , max_len ], np .float32 )
350+ batch = (* batch , tgt_mask )
351+ if return_scorer_alpha :
352+ assert not return_tgt_mask
353+ scorer_alpha = np .zeros ([batch_size ], np .float32 )
354+ batch = (* batch , scorer_alpha )
355+ return batch
356+
356357 def _config_infeed (self ,
357358 num_partitions ,
358359 device_assignment ,
@@ -362,7 +363,7 @@ def _config_infeed(self,
362363 return_scorer_alpha = False ,
363364 use_partitioned_infeed_queue = False ):
364365 """Config the infeed ops and args."""
365- zero_batch = get_zero_batch (
366+ zero_batch = self . get_zero_batch (
366367 batch_size = batch_size ,
367368 max_len = self ._prefix_max_len ,
368369 key_size = key_size ,
0 commit comments