@@ -199,13 +199,15 @@ def __init__(
         # First norm + conv
         self.norm1 = get_norm_layer(in_channels, norm_type, norm_num_groups)
         self.act1 = get_activation(activation)
-        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+        self.conv1 = nn.Conv2d(in_channels, out_channels,
+                               kernel_size=3, padding=1)
 
         # Second norm + conv
         self.norm2 = get_norm_layer(out_channels, norm_type, norm_num_groups)
         self.act2 = get_activation(activation)
         self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
-        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+        self.conv2 = nn.Conv2d(out_channels, out_channels,
+                               kernel_size=3, padding=1)
 
         # Skip connection (1x1 conv if channels differ)
         if in_channels != out_channels:
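This hunk only rewraps long lines; the surrounding class is a standard pre-activation residual block. A minimal self-contained sketch of the same pattern, with plain GroupNorm and SiLU standing in for the file's get_norm_layer/get_activation helpers (that substitution is an assumption, and channel counts must divide the group count):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlockSketch(nn.Module):
    # Norm -> act -> 3x3 conv, twice; a 1x1 conv on the skip path only
    # when in/out channels differ, so the residual addition type-checks.
    def __init__(self, in_ch: int, out_ch: int, groups: int = 8):
        super().__init__()
        self.norm1 = nn.GroupNorm(groups, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.skip = (nn.Conv2d(in_ch, out_ch, kernel_size=1)
                     if in_ch != out_ch else nn.Identity())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.conv1(F.silu(self.norm1(x)))
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.skip(x)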
@@ -301,7 +303,8 @@ class Downsample(nn.Module):
 
     def __init__(self, channels: int):
         super().__init__()
-        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
+        self.conv = nn.Conv2d(channels, channels,
+                              kernel_size=3, stride=2, padding=1)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.conv(x)
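Downsample halves the spatial resolution with a stride-2 convolution rather than pooling. A quick shape check, with assumed example sizes:

import torch
import torch.nn as nn

# Assumed example sizes: 64 channels, 32x32 feature map.
down = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
x = torch.randn(1, 64, 32, 32)
print(down(x).shape)  # torch.Size([1, 64, 16, 16]): H and W halve, channels kept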
@@ -369,7 +372,8 @@ def __init__(
         self.use_attention_at = set(use_attention_at)
 
         # Initial convolution
-        self.conv_in = nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1)
+        self.conv_in = nn.Conv2d(
+            in_channels, base_channels, kernel_size=3, padding=1)
 
         # Downsampling stages
         self.stages = nn.ModuleList()
@@ -540,7 +544,8 @@ def __init__(
 
         # Initial conv from latent
         first_ch = base_channels * channel_multipliers_rev[0]
-        self.conv_in = nn.Conv2d(latent_channels, first_ch, kernel_size=3, padding=1)
+        self.conv_in = nn.Conv2d(
+            latent_channels, first_ch, kernel_size=3, padding=1)
 
         # Store for later
         self._channel_multipliers_rev = channel_multipliers_rev
@@ -594,7 +599,8 @@ def __init__(
         final_ch = base_channels * channel_multipliers_rev[-1]
         self.norm_out = get_norm_layer(final_ch, norm_type, norm_num_groups)
         self.act_out = get_activation(activation)
-        self.conv_out = nn.Conv2d(final_ch, out_channels, kernel_size=3, padding=1)
+        self.conv_out = nn.Conv2d(
+            final_ch, out_channels, kernel_size=3, padding=1)
 
         # Create attention modules
         self._create_attention_modules()
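first_ch and final_ch come from reading the encoder's channel multipliers in reverse, so the decoder starts at its widest point and narrows back to the base width. With illustrative values (base_channels=64 and multipliers (1, 2, 4, 8) are assumptions, not taken from this diff):

base_channels = 64                    # assumed example value
channel_multipliers = (1, 2, 4, 8)    # assumed encoder widths: 64, 128, 256, 512
channel_multipliers_rev = tuple(reversed(channel_multipliers))   # (8, 4, 2, 1)

first_ch = base_channels * channel_multipliers_rev[0]    # 512, decoder input width
final_ch = base_channels * channel_multipliers_rev[-1]   # 64, width before conv_out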
@@ -825,7 +831,8 @@ def forward(
         x_recon = torch.clamp(x_recon, min=-1.0, max=1.0)
 
         # Compute losses
-        recon_loss = reconstruction_loss(x, x_recon, loss_type=self.recon_loss_type)
+        recon_loss = reconstruction_loss(
+            x, x_recon, loss_type=self.recon_loss_type)
         kl_loss = kl_divergence(mu, logvar)
 
         # Clamp individual losses to prevent extreme values
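kl_divergence itself sits outside this hunk. For a diagonal Gaussian posterior against a standard normal prior, the usual closed form is KL = -1/2 * sum(1 + logvar - mu^2 - exp(logvar)); a sketch of such a helper (the sum-over-latents, mean-over-batch reduction is an assumption):

import torch

def kl_divergence(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # KL(N(mu, sigma^2) || N(0, 1)) per latent element, summed over latent
    # dimensions and averaged over the batch (reduction choice assumed).
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    return kl.flatten(1).sum(dim=1).mean()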
@@ -1142,7 +1149,8 @@ def train_epoch(
             if max_grad_norm is not None:
                 # Unscale once before clipping (official pattern)
                 scaler.unscale_(optimizer)
-                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
+                torch.nn.utils.clip_grad_norm_(
+                    model.parameters(), max_grad_norm)
 
             # internally checks for NaN/Inf grads and skips update if needed
             scaler.step(optimizer)
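The "official pattern" in the comment is the documented torch.cuda.amp recipe: unscale the gradients exactly once, clip the now-true gradients, then let the scaler step (it skips the update on non-finite gradients) and update its scale. A condensed sketch; model is assumed to return a dict with a "loss" entry, as the surrounding script does:

import torch

def amp_train_step(model, optimizer, scaler, x, max_grad_norm=1.0):
    # Sketch of the GradScaler recipe the hunk above follows.
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda"):
        loss = model(x)["loss"]          # "loss" key assumed from the script
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)           # unscale exactly once before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)               # skips the step on NaN/Inf gradients
    scaler.update()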
@@ -1159,13 +1167,15 @@ def train_epoch(
             loss.backward()
 
             if has_nonfinite_gradients(model):
-                print(f"Skipping batch {batch_idx} due to non-finite gradients")
+                print(
+                    f"Skipping batch {batch_idx} due to non-finite gradients")
                 optimizer.zero_grad(set_to_none=True)
                 continue
 
             # Gradient clipping
             if max_grad_norm is not None:
-                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
+                torch.nn.utils.clip_grad_norm_(
+                    model.parameters(), max_grad_norm)
 
             optimizer.step()
 
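has_nonfinite_gradients is defined elsewhere in the file; a plausible implementation (an assumption, not the author's code) just scans parameter gradients for NaN or Inf:

import torch
import torch.nn as nn

def has_nonfinite_gradients(model: nn.Module) -> bool:
    # True if any parameter gradient contains NaN or +/-Inf.
    return any(
        p.grad is not None and not torch.isfinite(p.grad).all()
        for p in model.parameters()
    )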
@@ -1181,18 +1191,21 @@ def train_epoch(
             writer.add_scalar(
                 "train/recon_loss", outputs["recon_loss"].item(), global_step
             )
-            writer.add_scalar("train/kl_loss", outputs["kl_loss"].item(), global_step)
+            writer.add_scalar(
+                "train/kl_loss", outputs["kl_loss"].item(), global_step)
             writer.add_scalar("train/kl_weight", kl_weight, global_step)
             # Log mu and logvar histograms for diagnosing posterior collapse
             if "mu" in outputs and "logvar" in outputs:
-                writer.add_histogram("train/mu", outputs["mu"].detach(), global_step)
+                writer.add_histogram(
+                    "train/mu", outputs["mu"].detach(), global_step)
                 writer.add_histogram(
                     "train/logvar", outputs["logvar"].detach(), global_step
                 )
 
         # Image logging
         if writer is not None and global_step % image_log_interval == 0:
-            log_images(writer, x, outputs["x_recon"], global_step, prefix="train")
+            log_images(writer, x, outputs["x_recon"],
+                       global_step, prefix="train")
 
         global_step += 1
 
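The histograms are logged because posterior collapse shows up as mu near 0 and logvar near 0 across the whole latent, i.e. q(z|x) degenerating to the N(0, 1) prior. Hypothetical companion scalars (same SummaryWriter API, metric names invented here) make the same signal readable at a glance:

# Hypothetical extra scalars: mu_abs_mean -> 0 together with var_mean -> 1
# suggests the posterior has collapsed onto the prior.
writer.add_scalar("train/mu_abs_mean",
                  outputs["mu"].detach().abs().mean().item(), global_step)
writer.add_scalar("train/var_mean",
                  outputs["logvar"].detach().exp().mean().item(), global_step)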
@@ -1527,7 +1540,8 @@ def _extract_random_tile(
     # Some TIF backgrounds render as black (0,0,0) rather than transparent
     arr = np.array(img)
     near_black_mask = (
-        (arr[:, :, 0] < 4) & (arr[:, :, 1] < 4) & (arr[:, :, 2] < 4)
+        (arr[:, :, 0] < 4) & (
+            arr[:, :, 1] < 4) & (arr[:, :, 2] < 4)
     )
     arr[near_black_mask] = [255, 255, 255]
     img = Image.fromarray(arr)
@@ -1701,7 +1715,8 @@ def parse_args() -> argparse.Namespace:
         "--epochs", type=int, default=100, help="Number of training epochs"
     )
     parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
-    parser.add_argument("--weight-decay", type=float, default=0.01, help="Weight decay")
+    parser.add_argument("--weight-decay", type=float,
+                        default=0.01, help="Weight decay")
     parser.add_argument(
         "--beta", type=float, default=0.3, help="Maximum KL weight (beta-VAE)"
     )
@@ -1754,7 +1769,8 @@ def parse_args() -> argparse.Namespace:
     )
 
     # Device
-    parser.add_argument("--device", type=str, default="cuda", help="Device to use")
+    parser.add_argument("--device", type=str,
+                        default="cuda", help="Device to use")
 
     # Reproducibility
     parser.add_argument("--seed", type=int, default=42, help="Random seed")
@@ -1784,7 +1800,8 @@ def main():
     )
 
     # Parse channel multipliers and attention resolutions
-    channel_multipliers = tuple(int(x) for x in args.channel_multipliers.split(","))
+    channel_multipliers = tuple(int(x)
+                                for x in args.channel_multipliers.split(","))
     use_attention_at = tuple(int(x) for x in args.use_attention_at.split(","))
 
     # Create config
@@ -1807,7 +1824,8 @@ def main():
     config.validate()
 
     print(f"\nVAE Configuration:")
-    print(f"  Image size: {config.img_size}x{config.img_size}x{config.img_channels}")
+    print(
+        f"  Image size: {config.img_size}x{config.img_size}x{config.img_channels}")
     print(
         f"  Latent size: {config.latent_size}x{config.latent_size}x{config.latent_channels}"
     )
@@ -1825,7 +1843,8 @@ def main():
 
     # Count parameters
     num_params = sum(p.numel() for p in model.parameters())
-    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    num_trainable = sum(p.numel()
+                        for p in model.parameters() if p.requires_grad)
     print(f"Model parameters: {num_params:,} ({num_trainable:,} trainable)")
 
     # Create optimizer
@@ -1941,10 +1960,13 @@ def main():
             writer.add_scalar(
                 "epoch/train_recon_loss", train_metrics["recon_loss"], epoch
             )
-            writer.add_scalar("epoch/train_kl_loss", train_metrics["kl_loss"], epoch)
+            writer.add_scalar("epoch/train_kl_loss",
+                              train_metrics["kl_loss"], epoch)
             writer.add_scalar("epoch/val_loss", val_metrics["loss"], epoch)
-            writer.add_scalar("epoch/val_recon_loss", val_metrics["recon_loss"], epoch)
-            writer.add_scalar("epoch/val_kl_loss", val_metrics["kl_loss"], epoch)
+            writer.add_scalar("epoch/val_recon_loss",
+                              val_metrics["recon_loss"], epoch)
+            writer.add_scalar("epoch/val_kl_loss",
+                              val_metrics["kl_loss"], epoch)
 
         # Print progress
         print(
@@ -1989,7 +2011,8 @@ def main():
         # Save best checkpoint
         if val_metrics["loss"] < best_val_loss:
             best_val_loss = val_metrics["loss"]
-            save_path = os.path.join(args.checkpoint_dir, "checkpoint_best.pt")
+            save_path = os.path.join(
+                args.checkpoint_dir, "checkpoint_best.pt")
             torch.save(checkpoint, save_path)
             print(f"Saved best checkpoint: {save_path}")