Skip to content

Commit 8ce7aa8

Browse files
lingvo-botcopybara-github
authored andcommitted
Add batch util for calculating per-worker batch size.
PiperOrigin-RevId: 476979808
1 parent f6c05c0 commit 8ce7aa8

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

lingvo/core/batch_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,22 @@ def scale_split_to_infeed(split_batch_size, use_per_host_infeed):
7979
return global_batch_size // cluster.num_tpu_hosts
8080
else:
8181
return global_batch_size
82+
83+
84+
def scale_global_to_worker(global_batch_size):
85+
"""Obtains per-worker batch size given a global batch size.
86+
87+
Args:
88+
global_batch_size: int: Global batch size.
89+
90+
Returns:
91+
int: per-worker batch size.
92+
"""
93+
cluster = cluster_factory.Cluster.Top()
94+
if not cluster:
95+
raise ValueError('Called scale_global_to_worker without a current cluster.')
96+
q, r = divmod(global_batch_size, cluster.total_worker_devices)
97+
if r:
98+
raise ValueError(f'global_batch_size {global_batch_size} did not divide'
99+
f' evenly by {cluster.total_worker_devices} workers.')
100+
return q

lingvo/core/batch_utils_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ def testScaleSplitToInfeedTPU(self, use_per_host_infeed, split_size,
9595
batch_utils.scale_split_to_infeed(1024, use_per_host_infeed),
9696
1024 * num_splits // num_infeeds)
9797

98+
@parameterized.product(tpus=[64, 128])
99+
def testScaleGlobalToWorkerTPU(self, tpus):
100+
with cluster_factory.ForTestingWorker(tpus=tpus) as cluster:
101+
self.assertEqual(cluster.total_worker_devices, tpus)
102+
self.assertEqual(batch_utils.scale_global_to_worker(1024), 1024 // tpus)
103+
98104

99105
if __name__ == '__main__':
100106
tf.test.main()

0 commit comments

Comments
 (0)