def train_pytorch():
    """Train a small CNN on Fashion-MNIST with DistributedDataParallel.

    All imports and the model class live inside the function body.
    NOTE(review): presumably this is so the function can be shipped and
    executed as a standalone unit on the training nodes -- confirm before
    hoisting anything to module level.

    Configuration is read from environment variables:

    * ``LOCAL_RANK``   -- device index used by this process (default ``0``).
    * ``DATASET_PATH`` -- root directory of the dataset (default ``./data``).
    * ``BATCH_SIZE``   -- per-process batch size (default ``64``).
    * ``NUM_EPOCHS``   -- number of passes over the data (default ``1``).

    The process group itself is initialised with PyTorch's defaults, so the
    launcher is expected to have prepared the distributed environment.

    Returns:
        None. Progress is reported through a console logger.

    Raises:
        ValueError: if one of the numeric environment variables is not an
            integer.
        Exception: whatever ``torchvision`` raises when the dataset is not
            present under ``DATASET_PATH`` (``download=False`` is used on
            purpose to avoid network access).
    """
    import os
    import logging

    import torch
    from torch import nn
    import torch.nn.functional as F

    from torchvision import datasets, transforms
    import torch.distributed as dist
    from torch.utils.data import DataLoader, DistributedSampler

    # Configure logger (similar to KFTO mnist.py)
    log_formatter = logging.Formatter(
        "%(asctime)s %(levelname)-8s %(message)s", "%Y-%m-%dT%H:%M:%SZ"
    )
    # `__file__` does not exist in interactive sessions; fall back to the
    # module name there instead of failing with NameError.
    logger = logging.getLogger(globals().get("__file__", __name__))
    if not logger.handlers:
        # Attach the console handler only once so that calling this function
        # again in the same process does not duplicate every log line.
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(log_formatter)
        logger.addHandler(console_handler)
    logger.setLevel(logging.INFO)

    # [1] Configure CPU/GPU device and distributed backend.
    # Kubeflow Trainer will automatically configure the distributed environment.
    use_cuda = torch.cuda.is_available()
    device_name, backend = ("cuda", "nccl") if use_cuda else ("cpu", "gloo")
    # Label used in log prefixes; computed once so it is also defined when
    # the epoch loop runs zero times.
    device_type = "GPU" if use_cuda else "CPU"

    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    if use_cuda:
        # Bind this process to its own GPU before any NCCL collective runs.
        torch.cuda.set_device(local_rank)
    dist.init_process_group(backend=backend)

    global_rank = dist.get_rank()
    world_size = dist.get_world_size()
    logger.info(
        "Distributed Training with WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}.".format(
            world_size,
            global_rank,
            local_rank,
        )
    )

    # [2] Define PyTorch CNN Model to be trained.
    class Net(nn.Module):
        """Two 3x3 convolutions, one 2x2 max-pool and two linear layers."""

        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x):
            """Return per-class log-probabilities for a batch of images."""
            x = F.relu(self.conv1(x))
            x = F.relu(self.conv2(x))
            x = F.max_pool2d(x, 2)
            # 9216 = 64 channels * 12 * 12, the feature-map size obtained
            # from 28x28 single-channel inputs (28 -> 26 -> 24 -> pool 12).
            x = x.view(-1, 9216)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=1)

    # [3] Attach model to the correct device.
    # Use the (type, index) constructor rather than building a "type:index"
    # string: torch.device rejects strings containing stray whitespace.
    device = torch.device(device_name, local_rank)
    model = nn.parallel.DistributedDataParallel(Net().to(device))
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # [4] Get the Fashion-MNIST dataset.
    # Dataset should be pre-downloaded to avoid network dependencies.
    dataset_path = os.getenv("DATASET_PATH", "./data")

    # Load dataset (download=False assumes dataset is already present)
    dataset = datasets.FashionMNIST(
        dataset_path,
        train=True,
        download=False,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    # Batch size configurable via env var (smaller = more iterations = longer training)
    batch_size = int(os.getenv("BATCH_SIZE", "64"))
    sampler = DistributedSampler(dataset)
    train_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
    )

    # [5] Define the training loop.
    num_epochs = int(os.getenv("NUM_EPOCHS", "1"))
    # The number of steps per epoch does not change between epochs.
    num_batches = len(train_loader)

    for epoch in range(num_epochs):
        # Log epoch start from ALL ranks
        logger.info(f"[{device_type} {global_rank} ] Epoch {epoch} | Batchsize: {batch_size} | Steps: {num_batches} | World Size: {world_size} ")

        # Set epoch for DistributedSampler to ensure proper shuffling
        sampler.set_epoch(epoch)

        epoch_loss = 0.0
        num_batches_processed = 0

        for batch_idx, (inputs, labels) in enumerate(train_loader):
            # Attach tensors to the device.
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = F.nll_loss(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once: every .item() call synchronises with the
            # device, and the value is needed both for the running total and
            # for the progress log below.
            loss_value = loss.item()
            epoch_loss += loss_value
            num_batches_processed += 1

            # Log detailed training progress from rank 0 only (to avoid log spam)
            if batch_idx % 10 == 0 and global_rank == 0:
                logger.info(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\t Loss: {:.6f}".format(
                        epoch,
                        batch_idx * len(inputs) * world_size,  # Adjust for distributed training
                        len(train_loader.dataset),
                        100.0 * batch_idx / num_batches,
                        loss_value,
                    )
                )

        # End-of-epoch summary from ALL ranks. A rank that received no
        # batches reports NaN instead of crashing with ZeroDivisionError.
        if num_batches_processed:
            avg_loss = epoch_loss / num_batches_processed
        else:
            avg_loss = float("nan")
        logger.info(f"[{device_type} {global_rank} ] Epoch {epoch} completed | Avg Loss: {avg_loss:.6f} | Batches: {num_batches_processed} ")

    # Wait for all ranks to finish training, then destroy the PyTorch
    # distributed process group.
    dist.barrier()
    # All ranks report completion
    logger.info(f"[{device_type} {global_rank} ] Training is finished")
    dist.destroy_process_group()
136+
137+
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train_pytorch()
0 commit comments