@@ -47,9 +47,15 @@ def train(args, model, device, train_loader, optimizer, epoch):
         loss.backward()
         optimizer.step()
         if batch_idx % args.log_interval == 0:
-            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.item()))
+            print(
+                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+                    epoch,
+                    batch_idx * len(data),
+                    len(train_loader.dataset),
+                    100.0 * batch_idx / len(train_loader),
+                    loss.item(),
+                )
+            )
         if args.dry_run:
             break
 
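The reformatted call is behavior-identical to the one it replaces. As a quick standalone check of the format string (the values below are illustrative, not from a real run):

```python
# Illustrative values only; batch_idx * len(data) is the number of samples seen so far.
epoch, batch_idx, batch_len = 1, 10, 64
n_samples, n_batches, loss_val = 60000, 938, 0.373651
print(
    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
        epoch, batch_idx * batch_len, n_samples, 100.0 * batch_idx / n_batches, loss_val
    )
)
# -> Train Epoch: 1 [640/60000 (1%)]	Loss: 0.373651
```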
@@ -62,66 +68,110 @@ def test(model, device, test_loader):
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
             output = model(data)
-            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
-            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+            test_loss += F.nll_loss(
+                output, target, reduction="sum"
+            ).item()  # sum up batch loss
+            pred = output.argmax(
+                dim=1, keepdim=True
+            )  # get the index of the max log-probability
             correct += pred.eq(target.view_as(pred)).sum().item()
 
     test_loss /= len(test_loader.dataset)
 
-    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-        test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+    print(
+        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
+            test_loss,
+            correct,
+            len(test_loader.dataset),
+            100.0 * correct / len(test_loader.dataset),
+        )
+    )
 
 
 def main():
     # Training settings
-    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
-    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
-                        help='input batch size for training (default: 64)')
-    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
-                        help='input batch size for testing (default: 1000)')
-    parser.add_argument('--epochs', type=int, default=14, metavar='N',
-                        help='number of epochs to train (default: 14)')
-    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
-                        help='learning rate (default: 1.0)')
-    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
-                        help='Learning rate step gamma (default: 0.7)')
-    parser.add_argument('--no-cuda', action='store_true', default=False,
-                        help='disables CUDA training')
-    parser.add_argument('--dry-run', action='store_true', default=False,
-                        help='quickly check a single pass')
-    parser.add_argument('--seed', type=int, default=1, metavar='S',
-                        help='random seed (default: 1)')
-    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
-                        help='how many batches to wait before logging training status')
-    parser.add_argument('--save-model', action='store_true', default=False,
-                        help='For Saving the current Model')
-    parser.add_argument('-D', '--define', nargs='*', default=[], action="extend")
+    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=64,
+        metavar="N",
+        help="input batch size for training (default: 64)",
+    )
+    parser.add_argument(
+        "--test-batch-size",
+        type=int,
+        default=1000,
+        metavar="N",
+        help="input batch size for testing (default: 1000)",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=14,
+        metavar="N",
+        help="number of epochs to train (default: 14)",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1.0,
+        metavar="LR",
+        help="learning rate (default: 1.0)",
+    )
+    parser.add_argument(
+        "--gamma",
+        type=float,
+        default=0.7,
+        metavar="M",
+        help="Learning rate step gamma (default: 0.7)",
+    )
+    parser.add_argument(
+        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        default=False,
+        help="quickly check a single pass",
+    )
+    parser.add_argument(
+        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+    )
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status",
+    )
+    parser.add_argument(
+        "--save-model",
+        action="store_true",
+        default=False,
+        help="For Saving the current Model",
+    )
+    parser.add_argument("-D", "--define", nargs="*", default=[], action="extend")
     args = parser.parse_args()
     use_cuda = not args.no_cuda and torch.cuda.is_available()
 
     torch.manual_seed(args.seed)
 
     device = torch.device("cuda" if use_cuda else "cpu")
 
-    train_kwargs = {'batch_size': args.batch_size}
-    test_kwargs = {'batch_size': args.test_batch_size}
+    train_kwargs = {"batch_size": args.batch_size}
+    test_kwargs = {"batch_size": args.test_batch_size}
     if use_cuda:
-        cuda_kwargs = {'num_workers': 1,
-                       'pin_memory': True,
-                       'shuffle': True}
+        cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)
 
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.1307,), (0.3081,))
-    ])
-    dataset1 = datasets.MNIST('../data', train=True, download=True,
-                              transform=transform)
-    dataset2 = datasets.MNIST('../data', train=False,
-                              transform=transform)
-    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+    )
+    dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
+    dataset2 = datasets.MNIST("../data", train=False, transform=transform)
+    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
     test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
 
     with param_scope(*args.define):
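One non-obvious pair of constants in this hunk: `Normalize((0.1307,), (0.3081,))` uses the mean and standard deviation of the MNIST training-set pixel values. A sketch to verify them, assuming torchvision is installed and may download MNIST to `../data` (the same path the script uses):

```python
import torch
from torchvision import datasets, transforms

# Load the raw training images as tensors in [0, 1] and compute their statistics.
dataset = datasets.MNIST("../data", train=True, download=True,
                         transform=transforms.ToTensor())
pixels = torch.stack([img for img, _ in dataset])  # shape: (60000, 1, 28, 28)
print(pixels.mean().item(), pixels.std().item())   # ~0.1307, ~0.3081
```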
@@ -138,5 +188,5 @@ def main():
         torch.save(model.state_dict(), "mnist_cnn.pt")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
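The one functional oddity worth calling out is `-D/--define`, which feeds hyperparameter overrides to `param_scope`. With `nargs="*"` and `action="extend"` (available from Python 3.8), repeated `-D` flags accumulate into a single list of `key=value` strings. A minimal demonstration, with parameter names made up for illustration (the real names depend on what the model reads inside the scope):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-D", "--define", nargs="*", default=[], action="extend")

# Repeated -D flags, each taking one or more key=value strings, accumulate:
args = parser.parse_args(
    ["-D", "model.dropout1=0.3", "-D", "model.dropout2=0.6", "lr=0.5"]
)
print(args.define)  # ['model.dropout1=0.3', 'model.dropout2=0.6', 'lr=0.5']
```

On the command line this would look like `python main.py -D model.dropout1=0.3 -D model.dropout2=0.6` (the filename `main.py` is an assumption; it is not shown in this diff).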