diff --git a/vae/README.md b/vae/README.md index fcaae9a286..35cac8899e 100644 --- a/vae/README.md +++ b/vae/README.md @@ -22,13 +22,21 @@ To force execution on the CPU, use `--no-accel` command line argument: python main.py --no-accel ``` +To run on a TPU via XLA, install `torch_xla` and use the `--xla` flag: + +```bash +pip install torch_xla[tpu] +python main.py --xla +``` + The main.py script accepts the following optional arguments: ```bash --batch-size input batch size for training (default: 128) --epochs number of epochs to train (default: 10) --no-accel disables accelerator +--xla enables XLA device (e.g. TPU). Requires torch_xla. --seed random seed (default: 1) ---log-interval how many batches to wait before logging training status +--log-interval how many batches to wait before logging training status ``` diff --git a/vae/main.py b/vae/main.py index 6390965810..2902661636 100644 --- a/vae/main.py +++ b/vae/main.py @@ -8,39 +8,62 @@ from torchvision.utils import save_image +_XLA_AVAILABLE = False +try: + import torch_xla + _XLA_AVAILABLE = True +except ImportError: + pass + + parser = argparse.ArgumentParser(description='VAE MNIST Example') parser.add_argument('--batch-size', type=int, default=128, metavar='N', help='input batch size for training (default: 128)') parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs to train (default: 10)') -parser.add_argument('--no-accel', action='store_true', +parser.add_argument('--no-accel', action='store_true', help='disables accelerator') +parser.add_argument('--xla', action='store_true', default=False, + help='enables XLA device (e.g. TPU). Requires torch_xla.') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument('--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status') args = parser.parse_args() -use_accel = not args.no_accel and torch.accelerator.is_available() +if args.xla: + if not _XLA_AVAILABLE: + raise RuntimeError( + "--xla flag requires torch_xla to be installed. " + "Install with: pip install torch_xla[tpu]" + ) + device = torch_xla.device() +else: + use_accel = not args.no_accel and torch.accelerator.is_available() + device = torch.accelerator.current_accelerator() if use_accel else torch.device("cpu") torch.manual_seed(args.seed) +print(f"Using device: {device}") -if use_accel: - device = torch.accelerator.current_accelerator() -else: - device = torch.device("cpu") +train_kwargs = {'batch_size': args.batch_size} +test_kwargs = {'batch_size': args.batch_size} -print(f"Using device: {device}") +if args.xla: + train_kwargs.update({'num_workers': 4, 'persistent_workers': True, + 'shuffle': True, 'drop_last': True}) + test_kwargs.update({'num_workers': 4, 'persistent_workers': True}) +elif not args.no_accel and torch.accelerator.is_available(): + train_kwargs.update({'num_workers': 1, 'pin_memory': True, 'shuffle': True}) + test_kwargs.update({'num_workers': 1, 'pin_memory': True}) -kwargs = {'num_workers': 1, 'pin_memory': True} if use_accel else {} train_loader = torch.utils.data.DataLoader( datasets.MNIST('../data', train=True, download=True, transform=transforms.ToTensor()), - batch_size=args.batch_size, shuffle=True, **kwargs) + **train_kwargs) test_loader = torch.utils.data.DataLoader( datasets.MNIST('../data', train=False, transform=transforms.ToTensor()), - batch_size=args.batch_size, shuffle=False, **kwargs) + **test_kwargs) class VAE(nn.Module): @@ -91,15 +114,17 @@ def loss_function(recon_x, x, mu, logvar): def train(epoch): model.train() - train_loss = 0 + train_loss = torch.tensor(0.0, device=device) for batch_idx, (data, _) in enumerate(train_loader): data = data.to(device) optimizer.zero_grad() recon_batch, mu, logvar = model(data) loss = loss_function(recon_batch, data, mu, logvar) loss.backward() - train_loss += loss.item() + train_loss += loss.detach() optimizer.step() + if args.xla: + torch_xla.sync() if batch_idx % args.log_interval == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(data), len(train_loader.dataset), @@ -107,7 +132,7 @@ def train(epoch): loss.item() / len(data))) print('====> Epoch: {} Average loss: {:.4f}'.format( - epoch, train_loss / len(train_loader.dataset))) + epoch, train_loss.item() / len(train_loader.dataset))) def test(epoch):