Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion resemble_enhance/enhancer/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from pathlib import Path

import torch
import torch.distributed as dist
import torchaudio
from tqdm import tqdm

from .inference import denoise, enhance

from ..utils.distributed import local_rank, fix_unset_envs

@torch.inference_mode()
def main():
Expand Down Expand Up @@ -68,6 +69,11 @@ def main():
action="store_true",
help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel",
)
parser.add_argument(
"--distributed_mode",
action="store_true",
help="Enable distributed training across multiple GPUs",
)

args = parser.parse_args()

Expand All @@ -86,6 +92,14 @@ def main():
if args.parallel_mode:
random.shuffle(paths)

if args.distributed_mode:
fix_unset_envs()
dist.init_process_group(backend='nccl' if device == "cuda" else "gloo")
torch.cuda.set_device(local_rank())
num_processed = dist.get_world_size()
rank = dist.get_rank()
paths = paths[rank::num_processed]

if len(paths) == 0:
print(f"No {args.suffix} files found in the following path: {args.in_dir}")
return
Expand Down