diff --git a/README.md b/README.md index bcbbfdb4..3932f47b 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,16 @@ # Pytorch-cifar100 - -practice on cifar100 using pytorch +This is a repository forked from weiaicunzai/pytorch-cifar100 ## Requirements - -This is my experiment eviroument -- python3.6 -- pytorch1.6.0+cu101 -- tensorboard 2.2.2(optional) - +- python >= 3.8 +- torch >= 2.0.0 +- tensorboard(optional) ## Usage ### 1. enter directory ```bash -$ cd pytorch-cifar100 +$ cd pytorch-cifar100-ddp ``` ### 2. dataset @@ -32,14 +28,26 @@ $ tensorboard --logdir='runs' --port=6006 --host='localhost' ``` ### 4. train the model -You need to specify the net you want to train using arg -net +You need to specify the net you want to train using arg --net ```bash -# use gpu to train vgg16 -$ python train.py -net vgg16 -gpu +# use cpu only (default) to train vgg16 +$ python train.py --net vgg16 +# use a single gpu to train vgg16 +$ python train.py --net vgg16 --gpu [gpu_id] +# for example, use GPU 0 to train vgg16 +$ python train.py --net vgg16 --gpu 0 + +# use multi gpus to train vgg16 +$ torchrun --master_addr [MASTER_ADDR] --master_port [MASTER_PORT] --nproc_per_node [NUM_GPUs_Per_Node] train.py --net vgg16 --gpu [gpu1_id,gpu2_id,...,gpun_id] +# for example, use GPU 0 and GPU 1 in one node to train vgg16 +$ torchrun --master_addr localhost --master_port 6000 --nproc_per_node 2 train.py --net vgg16 --gpu 0,1 + +# set training arguments if you want, for example, train 100 epochs and batch size is 256: +$ python train.py --net vgg16 --gpu 0 --epoch 100 --batch 256 ``` -sometimes, you might want to use warmup training by set ```-warm``` to 1 or 2, to prevent network +sometimes, you might want to use warmup training by set ```--warmup``` to 1 or 2, to prevent network diverge during early training phase. The supported net args are: @@ -92,12 +100,15 @@ Normally, the weights file with the best accuracy would be written to the disk w ### 5. 
test the model -Test the model using test.py +Test the model using test.py (Not implementing with DDP, which is not necessary...) ```bash -$ python test.py -net vgg16 -weights path_to_vgg16_weights_file +# use cpu (default) +$ python test.py --net vgg16 --weights path_to_vgg16_weights_file +# use gpu +$ python test.py --net vgg16 --weights path_to_vgg16_weights_file --gpu 0 ``` -## Implementated NetWork +## Implemented Network - vgg [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556v6) - googlenet [Going Deeper with Convolutions](https://arxiv.org/abs/1409.4842v1) diff --git a/test.py b/test.py index dab61a05..6c3bffc5 100644 --- a/test.py +++ b/test.py @@ -8,37 +8,56 @@ author baiyu """ +import os import argparse - +from collections import OrderedDict from matplotlib import pyplot as plt import torch -import torchvision.transforms as transforms -from torch.utils.data import DataLoader from conf import settings -from utils import get_network, get_test_dataloader +from utils import get_network, get_dataloader if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-net', type=str, required=True, help='net type') - parser.add_argument('-weights', type=str, required=True, help='the weights file you want to test') - parser.add_argument('-gpu', action='store_true', default=False, help='use gpu or not') - parser.add_argument('-b', type=int, default=16, help='batch size for dataloader') + parser.add_argument('--net', type=str, required=True, help='net type') + parser.add_argument('--weights', type=str, required=True, help='the weights file you want to test') + parser.add_argument('--gpu', type=str, default='-1', help='gpu device id, set `-1` to use cpu only') + parser.add_argument('--batch', '-b', type=int, default=16, help='batch size for dataloader') args = parser.parse_args() - net = get_network(args) - - cifar100_test_loader = get_test_dataloader( + if args.gpu == '-1': + device = 'cpu' 
+ else: + if torch.cuda.is_available(): + device = f'cuda:{args.gpu}' + os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + else: + raise ValueError('GPU is not available. please set `--gpu -1` to use cpu only. ') + + net = get_network(args.net).to(device) + + _, cifar100_test_loader = get_dataloader( settings.CIFAR100_TRAIN_MEAN, settings.CIFAR100_TRAIN_STD, - #settings.CIFAR100_PATH, + rank=0, num_workers=4, - batch_size=args.b, + batch_size=args.batch, + const_test_batch=False ) - net.load_state_dict(torch.load(args.weights)) + state_dict = torch.load(args.weights, weights_only=True) + new_state_dict = OrderedDict() + + # If training and saving model with DDP, it's necessary to remove the prefix `module.` + for k, v in state_dict.items(): + if k.startswith('module.'): + new_state_dict[k[7:]] = v + else: + new_state_dict[k] = v + + net.load_state_dict(new_state_dict) print(net) net.eval() @@ -50,12 +69,7 @@ for n_iter, (image, label) in enumerate(cifar100_test_loader): print("iteration: {}\ttotal {} iterations".format(n_iter + 1, len(cifar100_test_loader))) - if args.gpu: - image = image.cuda() - label = label.cuda() - print('GPU INFO.....') - print(torch.cuda.memory_summary(), end='') - + image, label = image.to(device), label.to(device) output = net(image) _, pred = output.topk(5, 1, largest=True, sorted=True) @@ -69,7 +83,7 @@ #compute top1 correct_1 += correct[:, :1].sum() - if args.gpu: + if device != 'cpu': print('GPU INFO.....') print(torch.cuda.memory_summary(), end='') diff --git a/train.py b/train.py index c5034606..6df420ee 100644 --- a/train.py +++ b/train.py @@ -7,34 +7,29 @@ """ import os -import sys import argparse import time -from datetime import datetime -import numpy as np import torch import torch.nn as nn import torch.optim as optim -import torchvision -import torchvision.transforms as transforms -from torch.utils.data import DataLoader +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + from 
torch.utils.tensorboard import SummaryWriter from conf import settings -from utils import get_network, get_training_dataloader, get_test_dataloader, WarmUpLR, \ - most_recent_folder, most_recent_weights, last_epoch, best_acc_weights +from utils import get_network, get_dataloader, WarmUpLR, \ + most_recent_folder, most_recent_weights, last_epoch, \ + best_acc_weights, reduce_metric def train(epoch): - start = time.time() net.train() + step_times = [] for batch_index, (images, labels) in enumerate(cifar100_training_loader): - - if args.gpu: - labels = labels.cuda() - images = images.cuda() + images, labels = images.to(device), labels.to(device) optimizer.zero_grad() outputs = net(images) @@ -44,35 +39,45 @@ def train(epoch): n_iter = (epoch - 1) * len(cifar100_training_loader) + batch_index + 1 + if batch_index > 0 and batch_index < len(cifar100_training_loader) - 1: + step_time = time.time() - start + step_times.append(step_time) + last_layer = list(net.children())[-1] for name, para in last_layer.named_parameters(): if 'weight' in name: writer.add_scalar('LastLayerGradients/grad_norm2_weights', para.grad.norm(), n_iter) if 'bias' in name: writer.add_scalar('LastLayerGradients/grad_norm2_bias', para.grad.norm(), n_iter) - - print('Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}\tLR: {:0.6f}'.format( - loss.item(), - optimizer.param_groups[0]['lr'], - epoch=epoch, - trained_samples=batch_index * args.b + len(images), - total_samples=len(cifar100_training_loader.dataset) - )) + + if master_process: + print('Training Epoch: {epoch} [{train_step}/{total_step}]\tLoss: {:0.4f}\tLR: {:0.6f}'.format( + loss.item(), + optimizer.param_groups[0]['lr'], + epoch=epoch, + train_step=(batch_index + 1), + total_step=len(cifar100_training_loader) + )) #update training loss for each iteration writer.add_scalar('Train/loss', loss.item(), n_iter) - if epoch <= args.warm: + if epoch <= args.warmup: warmup_scheduler.step() - for name, param in 
net.named_parameters(): - layer, attr = os.path.splitext(name) - attr = attr[1:] - writer.add_histogram("{}/{}".format(layer, attr), param, epoch) - - finish = time.time() + if master_process: + for name, param in net.named_parameters(): + layer, attr = os.path.splitext(name) + attr = attr[1:] + writer.add_histogram("{}/{}".format(layer, attr), param, epoch) + + avg_step_time = sum(step_times) / len(step_times) if len(step_times) > 0 else 0 + avg_throughput = len(cifar100_training_loader.dataset) / avg_step_time if avg_step_time > 0 else 0 - print('epoch {} training time consumed: {:.2f}s'.format(epoch, finish - start)) + epoch_duration = time.time() - start + if master_process: + print('epoch {} training time consumed: {:.2f}s, avg_throughput: {:.2f} imgs/sec'.format( + epoch, epoch_duration, avg_throughput)) @torch.no_grad() def eval_training(epoch=0, tb=True): @@ -84,73 +89,104 @@ def eval_training(epoch=0, tb=True): correct = 0.0 for (images, labels) in cifar100_test_loader: - - if args.gpu: - images = images.cuda() - labels = labels.cuda() - + images, labels = images.to(device), labels.to(device) outputs = net(images) loss = loss_function(outputs, labels) - test_loss += loss.item() _, preds = outputs.max(1) correct += preds.eq(labels).sum() + if ddp: + correct = correct.clone().detach().to(device) + correct = reduce_metric(correct) + correct = 100. 
* correct.item() + finish = time.time() - if args.gpu: - print('GPU INFO.....') - print(torch.cuda.memory_summary(), end='') - print('Evaluating Network.....') - print('Test set: Epoch: {}, Average loss: {:.4f}, Accuracy: {:.4f}, Time consumed:{:.2f}s'.format( - epoch, - test_loss / len(cifar100_test_loader.dataset), - correct.float() / len(cifar100_test_loader.dataset), - finish - start - )) - print() + + if master_process: + if device != 'cpu': + print('GPU INFO.....') + print(torch.cuda.memory_summary(), end='') + print('Evaluating Network.....') + print('Test set: Epoch: {}, Average loss: {:.4f}, Accuracy: {:.4f}, Time consumed:{:.2f}s'.format( + epoch, + test_loss / len(cifar100_test_loader), + correct / len(cifar100_test_loader.dataset), + finish - start + )) + print() #add informations to tensorboard - if tb: - writer.add_scalar('Test/Average loss', test_loss / len(cifar100_test_loader.dataset), epoch) - writer.add_scalar('Test/Accuracy', correct.float() / len(cifar100_test_loader.dataset), epoch) + if tb and master_process: + writer.add_scalar('Test/Average loss', test_loss / len(cifar100_test_loader), epoch) + writer.add_scalar('Test/Accuracy', correct / len(cifar100_test_loader.dataset), epoch) - return correct.float() / len(cifar100_test_loader.dataset) + return correct / len(cifar100_test_loader.dataset) if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-net', type=str, required=True, help='net type') - parser.add_argument('-gpu', action='store_true', default=False, help='use gpu or not') - parser.add_argument('-b', type=int, default=128, help='batch size for dataloader') - parser.add_argument('-warm', type=int, default=1, help='warm up training phase') - parser.add_argument('-lr', type=float, default=0.1, help='initial learning rate') - parser.add_argument('-resume', action='store_true', default=False, help='resume training') + parser.add_argument('--net', type=str, required=True, help='net type') + 
parser.add_argument('--gpu', type=str, default='-1', help='gpu device id, set `-1` to use cpu only') + parser.add_argument('--batch', '-b', type=int, default=128, help='batch size for dataloader') + parser.add_argument('--use_test_batch', action='store_true', default=False, help='Whether to set test_dataloader batch size as args.batch') + parser.add_argument('--epoch', type=int, default=100, help='training epochs') + parser.add_argument('--warmup', type=int, default=1, help='warm up training phase') + parser.add_argument('--learning_rate', '-lr', type=float, default=0.1, help='initial learning rate') + parser.add_argument('--optimizer', type=str, default='SGD', help='optimizer name') + parser.add_argument('--momentum', type=float, default=0.9, help='momentum for optimizer') + parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay for optimizer') + parser.add_argument('--num_workers', type=int, default=4, help='num workers for dataloader') + parser.add_argument('--resume', action='store_true', default=False, help='resume training') args = parser.parse_args() - net = get_network(args) + if args.gpu == '-1': + device = 'cpu' + else: + if torch.cuda.is_available(): + device = 'cuda' + os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + else: + raise ValueError('GPU is not available. please set `--gpu -1` to use cpu only. 
') + + ddp = int(os.environ.get('RANK', -1)) != -1 + + if ddp: + dist.init_process_group(backend='nccl') + rank = int(os.environ['RANK']) + local_rank = int(os.environ['LOCAL_RANK']) + world_size = int(os.environ['WORLD_SIZE']) + device = torch.device(f'cuda:{local_rank}') + master_process = rank == 0 + seed_offset = rank + else: + master_process = True + rank = 0 + seed_offset = 0 + world_size = 1 + + torch.manual_seed(2024 + seed_offset) + net = get_network(args.net).to(device) + if ddp: + net = DDP(net, device_ids=[local_rank], output_device=local_rank) + #data preprocessing: - cifar100_training_loader = get_training_dataloader( - settings.CIFAR100_TRAIN_MEAN, - settings.CIFAR100_TRAIN_STD, - num_workers=4, - batch_size=args.b, - shuffle=True - ) - - cifar100_test_loader = get_test_dataloader( + cifar100_training_loader, cifar100_test_loader = get_dataloader( settings.CIFAR100_TRAIN_MEAN, settings.CIFAR100_TRAIN_STD, - num_workers=4, - batch_size=args.b, + rank=rank, + batch_size=args.batch, + num_workers=args.num_workers, + const_test_batch=(not args.use_test_batch), shuffle=True ) loss_function = nn.CrossEntropyLoss() - optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) - train_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=settings.MILESTONES, gamma=0.2) #learning rate decay + optimizer = getattr(optim, args.optimizer, 'SGD')(net.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) + train_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epoch, eta_min=1e-6) iter_per_epoch = len(cifar100_training_loader) - warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warm) + warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warmup) if args.resume: recent_folder = most_recent_folder(os.path.join(settings.CHECKPOINT_PATH, args.net), fmt=settings.DATE_FORMAT) @@ -168,12 +204,15 @@ def eval_training(epoch=0, tb=True): #since tensorboard 
can't overwrite old values #so the only way is to create a new tensorboard log + writer = SummaryWriter(log_dir=os.path.join( settings.LOG_DIR, args.net, settings.TIME_NOW)) - input_tensor = torch.Tensor(1, 3, 32, 32) - if args.gpu: - input_tensor = input_tensor.cuda() - writer.add_graph(net, input_tensor) + input_tensor = torch.Tensor(1, 3, 32, 32).to(device) + if ddp: + if dist.get_rank() == 0: + writer.add_graph(net.module, input_tensor) + else: + writer.add_graph(net, input_tensor) #create checkpoint folder to save model if not os.path.exists(checkpoint_path): @@ -197,12 +236,13 @@ def eval_training(epoch=0, tb=True): weights_path = os.path.join(settings.CHECKPOINT_PATH, args.net, recent_folder, recent_weights_file) print('loading weights file {} to resume training.....'.format(weights_path)) net.load_state_dict(torch.load(weights_path)) - + if ddp: + net = DDP(net, device_ids=[local_rank], output_device=local_rank) resume_epoch = last_epoch(os.path.join(settings.CHECKPOINT_PATH, args.net, recent_folder)) - for epoch in range(1, settings.EPOCH + 1): - if epoch > args.warm: + for epoch in range(1, args.epoch + 1): + if epoch > args.warmup: train_scheduler.step(epoch) if args.resume: @@ -212,17 +252,24 @@ def eval_training(epoch=0, tb=True): train(epoch) acc = eval_training(epoch) - #start to save best performance model after learning rate decay to 0.01 - if epoch > settings.MILESTONES[1] and best_acc < acc: - weights_path = checkpoint_path.format(net=args.net, epoch=epoch, type='best') - print('saving weights file to {}'.format(weights_path)) - torch.save(net.state_dict(), weights_path) - best_acc = acc - continue + if master_process: + #start to save best performance model after learning rate decay to 0.01 + if epoch > settings.MILESTONES[1] and best_acc < acc: + weights_path = checkpoint_path.format(net=args.net, epoch=epoch, type='best') + print('saving weights file to {}'.format(weights_path)) + torch.save(net.state_dict(), weights_path) + best_acc = acc + 
continue + + if not epoch % settings.SAVE_EPOCH: + weights_path = checkpoint_path.format(net=args.net, epoch=epoch, type='regular') + print('saving weights file to {}'.format(weights_path)) + torch.save(net.state_dict(), weights_path) - if not epoch % settings.SAVE_EPOCH: - weights_path = checkpoint_path.format(net=args.net, epoch=epoch, type='regular') - print('saving weights file to {}'.format(weights_path)) - torch.save(net.state_dict(), weights_path) + if ddp: + dist.barrier() writer.close() + if ddp: + dist.destroy_process_group() + diff --git a/utils.py b/utils.py index f2cfac38..5febee69 100644 --- a/utils.py +++ b/utils.py @@ -5,151 +5,152 @@ import os import sys import re +import time import datetime import numpy -import torch from torch.optim.lr_scheduler import _LRScheduler -import torchvision -import torchvision.transforms as transforms +from torchvision import datasets, transforms from torch.utils.data import DataLoader +import torch.distributed as dist +from torch.utils.data.distributed import DistributedSampler -def get_network(args): +def get_network(net): """ return given network """ - if args.net == 'vgg16': + if net == 'vgg16': from models.vgg import vgg16_bn net = vgg16_bn() - elif args.net == 'vgg13': + elif net == 'vgg13': from models.vgg import vgg13_bn net = vgg13_bn() - elif args.net == 'vgg11': + elif net == 'vgg11': from models.vgg import vgg11_bn net = vgg11_bn() - elif args.net == 'vgg19': + elif net == 'vgg19': from models.vgg import vgg19_bn net = vgg19_bn() - elif args.net == 'densenet121': + elif net == 'densenet121': from models.densenet import densenet121 net = densenet121() - elif args.net == 'densenet161': + elif net == 'densenet161': from models.densenet import densenet161 net = densenet161() - elif args.net == 'densenet169': + elif net == 'densenet169': from models.densenet import densenet169 net = densenet169() - elif args.net == 'densenet201': + elif net == 'densenet201': from models.densenet import densenet201 net = 
densenet201() - elif args.net == 'googlenet': + elif net == 'googlenet': from models.googlenet import googlenet net = googlenet() - elif args.net == 'inceptionv3': + elif net == 'inceptionv3': from models.inceptionv3 import inceptionv3 net = inceptionv3() - elif args.net == 'inceptionv4': + elif net == 'inceptionv4': from models.inceptionv4 import inceptionv4 net = inceptionv4() - elif args.net == 'inceptionresnetv2': + elif net == 'inceptionresnetv2': from models.inceptionv4 import inception_resnet_v2 net = inception_resnet_v2() - elif args.net == 'xception': + elif net == 'xception': from models.xception import xception net = xception() - elif args.net == 'resnet18': + elif net == 'resnet18': from models.resnet import resnet18 net = resnet18() - elif args.net == 'resnet34': + elif net == 'resnet34': from models.resnet import resnet34 net = resnet34() - elif args.net == 'resnet50': + elif net == 'resnet50': from models.resnet import resnet50 net = resnet50() - elif args.net == 'resnet101': + elif net == 'resnet101': from models.resnet import resnet101 net = resnet101() - elif args.net == 'resnet152': + elif net == 'resnet152': from models.resnet import resnet152 net = resnet152() - elif args.net == 'preactresnet18': + elif net == 'preactresnet18': from models.preactresnet import preactresnet18 net = preactresnet18() - elif args.net == 'preactresnet34': + elif net == 'preactresnet34': from models.preactresnet import preactresnet34 net = preactresnet34() - elif args.net == 'preactresnet50': + elif net == 'preactresnet50': from models.preactresnet import preactresnet50 net = preactresnet50() - elif args.net == 'preactresnet101': + elif net == 'preactresnet101': from models.preactresnet import preactresnet101 net = preactresnet101() - elif args.net == 'preactresnet152': + elif net == 'preactresnet152': from models.preactresnet import preactresnet152 net = preactresnet152() - elif args.net == 'resnext50': + elif net == 'resnext50': from models.resnext import resnext50 
net = resnext50() - elif args.net == 'resnext101': + elif net == 'resnext101': from models.resnext import resnext101 net = resnext101() - elif args.net == 'resnext152': + elif net == 'resnext152': from models.resnext import resnext152 net = resnext152() - elif args.net == 'shufflenet': + elif net == 'shufflenet': from models.shufflenet import shufflenet net = shufflenet() - elif args.net == 'shufflenetv2': + elif net == 'shufflenetv2': from models.shufflenetv2 import shufflenetv2 net = shufflenetv2() - elif args.net == 'squeezenet': + elif net == 'squeezenet': from models.squeezenet import squeezenet net = squeezenet() - elif args.net == 'mobilenet': + elif net == 'mobilenet': from models.mobilenet import mobilenet net = mobilenet() - elif args.net == 'mobilenetv2': + elif net == 'mobilenetv2': from models.mobilenetv2 import mobilenetv2 net = mobilenetv2() - elif args.net == 'nasnet': + elif net == 'nasnet': from models.nasnet import nasnet net = nasnet() - elif args.net == 'attention56': + elif net == 'attention56': from models.attention import attention56 net = attention56() - elif args.net == 'attention92': + elif net == 'attention92': from models.attention import attention92 net = attention92() - elif args.net == 'seresnet18': + elif net == 'seresnet18': from models.senet import seresnet18 net = seresnet18() - elif args.net == 'seresnet34': + elif net == 'seresnet34': from models.senet import seresnet34 net = seresnet34() - elif args.net == 'seresnet50': + elif net == 'seresnet50': from models.senet import seresnet50 net = seresnet50() - elif args.net == 'seresnet101': + elif net == 'seresnet101': from models.senet import seresnet101 net = seresnet101() - elif args.net == 'seresnet152': + elif net == 'seresnet152': from models.senet import seresnet152 net = seresnet152() - elif args.net == 'wideresnet': + elif net == 'wideresnet': from models.wideresidual import wideresnet net = wideresnet() - elif args.net == 'stochasticdepth18': + elif net == 
'stochasticdepth18': from models.stochasticdepth import stochastic_depth_resnet18 net = stochastic_depth_resnet18() - elif args.net == 'stochasticdepth34': + elif net == 'stochasticdepth34': from models.stochasticdepth import stochastic_depth_resnet34 net = stochastic_depth_resnet34() - elif args.net == 'stochasticdepth50': + elif net == 'stochasticdepth50': from models.stochasticdepth import stochastic_depth_resnet50 net = stochastic_depth_resnet50() - elif args.net == 'stochasticdepth101': + elif net == 'stochasticdepth101': from models.stochasticdepth import stochastic_depth_resnet101 net = stochastic_depth_resnet101() @@ -157,20 +158,18 @@ def get_network(args): print('the network name you have entered is not supported yet') sys.exit() - if args.gpu: #use_gpu - net = net.cuda() - return net -def get_training_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True): +def get_dataloader(mean, std, rank=0, batch_size=16, num_workers=2, const_test_batch=True, shuffle=True): """ return training dataloader Args: mean: mean of cifar100 training dataset std: std of cifar100 training dataset - path: path to cifar100 training python dataset - batch_size: dataloader batchsize + rank: global rank of the current process + batch_size: dataloader batch size num_workers: dataloader num_works + const_test_batch: if set `True`, then the batch size of test_dataloader is set to 64, regardless of args.batch shuffle: whether to shuffle Returns: train_data_loader:torch dataloader object """ @@ -183,35 +182,45 @@ def get_training_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=Tru transforms.ToTensor(), transforms.Normalize(mean, std) ]) - #cifar100_training = CIFAR100Train(path, transform=transform_train) - cifar100_training = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train) - cifar100_training_loader = DataLoader( - cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size) - - return cifar100_training_loader - 
-def get_test_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True): - """ return training dataloader - Args: - mean: mean of cifar100 test dataset - std: std of cifar100 test dataset - path: path to cifar100 test python dataset - batch_size: dataloader batchsize - num_workers: dataloader num_works - shuffle: whether to shuffle - Returns: cifar100_test_loader:torch dataloader object - """ transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean, std) ]) - #cifar100_test = CIFAR100Test(path, transform=transform_test) - cifar100_test = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test) + + #cifar100_training = CIFAR100Train(path, transform=transform_train) + root_dir = './data/' + + if rank == 0: + cifar100_training = datasets.CIFAR100(root=root_dir, train=True, download=True, transform=transform_train) + cifar100_test = datasets.CIFAR100(root=root_dir, train=False, download=True, transform=transform_test) + with open(os.path.join(root_dir + 'download_complete.lock'), 'w') as f: + f.write('done') + + else: + while not os.path.exists(os.path.join(root_dir + 'download_complete.lock')): + time.sleep(1) + cifar100_training = datasets.CIFAR100(root=root_dir, train=True, download=False, transform=transform_train) + cifar100_test = datasets.CIFAR100(root=root_dir, train=False, download=False, transform=transform_test) + + if dist.is_initialized(): + shuffle = False + + cifar100_training_sampler = DistributedSampler(cifar100_training, num_replicas=dist.get_world_size(), rank=dist.get_rank()) if dist.is_initialized() else None + cifar100_test_sampler = DistributedSampler(cifar100_test, num_replicas=dist.get_world_size(), rank=dist.get_rank()) if dist.is_initialized() else None + + cifar100_training_loader = DataLoader( + cifar100_training, sampler=cifar100_training_sampler, + shuffle=shuffle, num_workers=num_workers, batch_size=batch_size) + + test_batch_size = 64 if const_test_batch 
else batch_size + cifar100_test_loader = DataLoader( - cifar100_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size) + cifar100_test, sampler=cifar100_test_sampler, + shuffle=shuffle, num_workers=num_workers, batch_size=test_batch_size) + + return cifar100_training_loader, cifar100_test_loader - return cifar100_test_loader def compute_mean_std(cifar100_dataset): """compute the mean and std of cifar100 dataset @@ -238,7 +247,6 @@ class WarmUpLR(_LRScheduler): total_iters: totoal_iters of warmup phase """ def __init__(self, optimizer, total_iters, last_epoch=-1): - self.total_iters = total_iters super().__init__(optimizer, last_epoch) @@ -305,4 +313,9 @@ def best_acc_weights(weights_folder): return '' best_files = sorted(best_files, key=lambda w: int(re.search(regex_str, w).groups()[1])) - return best_files[-1] \ No newline at end of file + return best_files[-1] + +def reduce_metric(metric): + dist.all_reduce(metric, op=dist.ReduceOp.SUM) + return metric +