❓ Questions and Help
Hi, I have noticed that when `world_size == 1`, `all_reduce` is a no-op and does not apply `scale`.

In `torch_xla.core.xla_model`, inside `def all_reduce`:
```python
  # No-op if there is only one device
  if runtime.world_size() == 1 and not xu.getenv_as('XLA_ALWAYS_ALLREDUCE',
                                                     bool, False):
    if isinstance(inputs, torch.Tensor):
      return inputs.clone()
    else:
      return inputs
```
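A minimal single-device sketch of the effect (the tensor value and the scale are just illustrative assumptions):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
loss_accum = torch.tensor(8.0, device=device)

# With world_size == 1 the call short-circuits and returns a clone,
# so `scale` is never applied: this prints 8.0 instead of 8.0 * 0.5 = 4.0.
reduced = xm.all_reduce(xm.REDUCE_SUM, loss_accum, scale=0.5)
print(reduced.item())
```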
Is this intended behavior? If it is, it makes the use of `all_reduce` inconsistent between `world_size == 1` and `world_size > 1`. The issue manifests, for example, when logging a running-average loss value:

```python
epoch_loss = xm.all_reduce(xm.REDUCE_SUM, loss_accum, scale=1.0 / ((idx + 1) * world_size))
```
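If this is indeed intended, one possible workaround (a sketch, not an official API; `all_reduce_with_scale` is a hypothetical helper name) is a thin wrapper that applies the scale itself in the single-device case, so the call site behaves the same for any world size:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def all_reduce_with_scale(reduce_type, inputs, scale=1.0):
  """all_reduce that applies `scale` even in the single-device no-op case."""
  if xr.world_size() == 1:
    if isinstance(inputs, torch.Tensor):
      return inputs * scale
    return [t * scale for t in inputs]
  return xm.all_reduce(reduce_type, inputs, scale=scale)


# The call site stays identical regardless of world size, e.g. for the
# running-average loss from the example above:
# epoch_loss = all_reduce_with_scale(
#     xm.REDUCE_SUM, loss_accum, scale=1.0 / ((idx + 1) * world_size))
```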