From 48ada3075f9c78d485721b7b333a6bdffb2f2ca3 Mon Sep 17 00:00:00 2001 From: ophirhan <67260757+ophirhan@users.noreply.github.com> Date: Thu, 8 Jun 2023 10:39:44 +0300 Subject: [PATCH 1/3] fix: switch backend to 'gloo' when running on Windows --- tools/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tools/train.py b/tools/train.py index 6ae1b9f..bd2975a 100644 --- a/tools/train.py +++ b/tools/train.py @@ -2,6 +2,7 @@ # Copyright (C) Alibaba Group Holding Limited. All rights reserved. import argparse import copy +import os import torch from loguru import logger @@ -45,6 +46,7 @@ def main(): args = make_parser().parse_args() torch.cuda.set_device(args.local_rank) + backend = 'nccl' if os.name == 'posix' else 'gloo' torch.distributed.init_process_group(backend='nccl', init_method='env://') synchronize() if args.tea_config is not None: From 560c2da15e2cef3a5f6d554f40a7f4eaf87c7bbb Mon Sep 17 00:00:00 2001 From: ophirhan <67260757+ophirhan@users.noreply.github.com> Date: Thu, 8 Jun 2023 17:03:47 +0300 Subject: [PATCH 2/3] fix: pass the selected backend to init_process_group --- tools/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/train.py b/tools/train.py index bd2975a..1fad6b1 100644 --- a/tools/train.py +++ b/tools/train.py @@ -47,7 +47,7 @@ def main(): torch.cuda.set_device(args.local_rank) backend = 'nccl' if os.name == 'posix' else 'gloo' - torch.distributed.init_process_group(backend='nccl', init_method='env://') + torch.distributed.init_process_group(backend=backend', init_method='env://') synchronize() if args.tea_config is not None: tea_config = parse_config(args.tea_config) From fc6b9f18ad43804879f61ee3f2b2436839d785ea Mon Sep 17 00:00:00 2001 From: ophirhan <67260757+ophirhan@users.noreply.github.com> Date: Thu, 8 Jun 2023 17:04:32 +0300 Subject: [PATCH 3/3] fix: remove stray quote from the backend argument --- tools/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/train.py b/tools/train.py index 1fad6b1..dad70df 100644 --- a/tools/train.py +++ 
b/tools/train.py @@ -47,7 +47,7 @@ def main(): torch.cuda.set_device(args.local_rank) backend = 'nccl' if os.name == 'posix' else 'gloo' - torch.distributed.init_process_group(backend=backend', init_method='env://') + torch.distributed.init_process_group(backend=backend, init_method='env://') synchronize() if args.tea_config is not None: tea_config = parse_config(args.tea_config)