diff --git a/mnist/README.md b/mnist/README.md index dc97bd5c53..fc12f4c302 100644 --- a/mnist/README.md +++ b/mnist/README.md @@ -3,5 +3,6 @@ ```bash pip install -r requirements.txt python main.py +python main.py --data-dir /tmp/mnist-data # CUDA_VISIBLE_DEVICES=2 python main.py # to specify GPU id to ex. 2 ``` diff --git a/mnist/main.py b/mnist/main.py index dee5a384cb..6962a2d880 100644 --- a/mnist/main.py +++ b/mnist/main.py @@ -90,6 +90,8 @@ def main(): help='random seed (default: 1)') parser.add_argument('--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status') + parser.add_argument('--data-dir', type=str, default='../data', + help='directory for storing input data (default: ../data)') parser.add_argument('--save-model', action='store_true', help='For Saving the current Model') args = parser.parse_args() @@ -117,9 +119,9 @@ def main(): transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) - dataset1 = datasets.MNIST('../data', train=True, download=True, + dataset1 = datasets.MNIST(args.data_dir, train=True, download=True, transform=transform) - dataset2 = datasets.MNIST('../data', train=False, + dataset2 = datasets.MNIST(args.data_dir, train=False, transform=transform) train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)