diff --git a/tools/train.py b/tools/train.py index 6ae1b9f..dad70df 100644 --- a/tools/train.py +++ b/tools/train.py @@ -2,6 +2,7 @@ # Copyright (C) Alibaba Group Holding Limited. All rights reserved. import argparse import copy +import os import torch from loguru import logger @@ -45,7 +46,8 @@ def main(): args = make_parser().parse_args() torch.cuda.set_device(args.local_rank) - torch.distributed.init_process_group(backend='nccl', init_method='env://') + backend = 'nccl' if os.name == 'posix' else 'gloo' + torch.distributed.init_process_group(backend=backend, init_method='env://') synchronize() if args.tea_config is not None: tea_config = parse_config(args.tea_config)