Skip to content

Commit f9f0273

Browse files
committed
tf implementation
1 parent 1ed4d3e commit f9f0273

4 files changed

Lines changed: 267 additions & 109 deletions

File tree

tensorboard.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/bin/bash
22
# Run TensorBoard in Docker with the runs directory mounted
33

4-
LOGDIR="${1:-./runs_vae_tf}"
4+
LOGDIR="${1:-./runs_vae}"
55
PORT="${2:-6006}"
66

77
docker run --rm -it \

vae.py

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,15 @@ def __init__(
199199
# First norm + conv
200200
self.norm1 = get_norm_layer(in_channels, norm_type, norm_num_groups)
201201
self.act1 = get_activation(activation)
202-
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
202+
self.conv1 = nn.Conv2d(in_channels, out_channels,
203+
kernel_size=3, padding=1)
203204

204205
# Second norm + conv
205206
self.norm2 = get_norm_layer(out_channels, norm_type, norm_num_groups)
206207
self.act2 = get_activation(activation)
207208
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
208-
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
209+
self.conv2 = nn.Conv2d(out_channels, out_channels,
210+
kernel_size=3, padding=1)
209211

210212
# Skip connection (1x1 conv if channels differ)
211213
if in_channels != out_channels:
@@ -301,7 +303,8 @@ class Downsample(nn.Module):
301303

302304
def __init__(self, channels: int):
303305
super().__init__()
304-
self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
306+
self.conv = nn.Conv2d(channels, channels,
307+
kernel_size=3, stride=2, padding=1)
305308

306309
def forward(self, x: torch.Tensor) -> torch.Tensor:
307310
return self.conv(x)
@@ -369,7 +372,8 @@ def __init__(
369372
self.use_attention_at = set(use_attention_at)
370373

371374
# Initial convolution
372-
self.conv_in = nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1)
375+
self.conv_in = nn.Conv2d(
376+
in_channels, base_channels, kernel_size=3, padding=1)
373377

374378
# Downsampling stages
375379
self.stages = nn.ModuleList()
@@ -540,7 +544,8 @@ def __init__(
540544

541545
# Initial conv from latent
542546
first_ch = base_channels * channel_multipliers_rev[0]
543-
self.conv_in = nn.Conv2d(latent_channels, first_ch, kernel_size=3, padding=1)
547+
self.conv_in = nn.Conv2d(
548+
latent_channels, first_ch, kernel_size=3, padding=1)
544549

545550
# Store for later
546551
self._channel_multipliers_rev = channel_multipliers_rev
@@ -594,7 +599,8 @@ def __init__(
594599
final_ch = base_channels * channel_multipliers_rev[-1]
595600
self.norm_out = get_norm_layer(final_ch, norm_type, norm_num_groups)
596601
self.act_out = get_activation(activation)
597-
self.conv_out = nn.Conv2d(final_ch, out_channels, kernel_size=3, padding=1)
602+
self.conv_out = nn.Conv2d(
603+
final_ch, out_channels, kernel_size=3, padding=1)
598604

599605
# Create attention modules
600606
self._create_attention_modules()
@@ -825,7 +831,8 @@ def forward(
825831
x_recon = torch.clamp(x_recon, min=-1.0, max=1.0)
826832

827833
# Compute losses
828-
recon_loss = reconstruction_loss(x, x_recon, loss_type=self.recon_loss_type)
834+
recon_loss = reconstruction_loss(
835+
x, x_recon, loss_type=self.recon_loss_type)
829836
kl_loss = kl_divergence(mu, logvar)
830837

831838
# Clamp individual losses to prevent extreme values
@@ -1142,7 +1149,8 @@ def train_epoch(
11421149
if max_grad_norm is not None:
11431150
# Unscale once before clipping (official pattern)
11441151
scaler.unscale_(optimizer)
1145-
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
1152+
torch.nn.utils.clip_grad_norm_(
1153+
model.parameters(), max_grad_norm)
11461154

11471155
# internally checks for NaN/Inf grads and skips update if needed
11481156
scaler.step(optimizer)
@@ -1159,13 +1167,15 @@ def train_epoch(
11591167
loss.backward()
11601168

11611169
if has_nonfinite_gradients(model):
1162-
print(f"Skipping batch {batch_idx} due to non-finite gradients")
1170+
print(
1171+
f"Skipping batch {batch_idx} due to non-finite gradients")
11631172
optimizer.zero_grad(set_to_none=True)
11641173
continue
11651174

11661175
# Gradient clipping
11671176
if max_grad_norm is not None:
1168-
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
1177+
torch.nn.utils.clip_grad_norm_(
1178+
model.parameters(), max_grad_norm)
11691179

11701180
optimizer.step()
11711181

@@ -1181,18 +1191,21 @@ def train_epoch(
11811191
writer.add_scalar(
11821192
"train/recon_loss", outputs["recon_loss"].item(), global_step
11831193
)
1184-
writer.add_scalar("train/kl_loss", outputs["kl_loss"].item(), global_step)
1194+
writer.add_scalar(
1195+
"train/kl_loss", outputs["kl_loss"].item(), global_step)
11851196
writer.add_scalar("train/kl_weight", kl_weight, global_step)
11861197
# Log mu and logvar histograms for diagnosing posterior collapse
11871198
if "mu" in outputs and "logvar" in outputs:
1188-
writer.add_histogram("train/mu", outputs["mu"].detach(), global_step)
1199+
writer.add_histogram(
1200+
"train/mu", outputs["mu"].detach(), global_step)
11891201
writer.add_histogram(
11901202
"train/logvar", outputs["logvar"].detach(), global_step
11911203
)
11921204

11931205
# Image logging
11941206
if writer is not None and global_step % image_log_interval == 0:
1195-
log_images(writer, x, outputs["x_recon"], global_step, prefix="train")
1207+
log_images(writer, x, outputs["x_recon"],
1208+
global_step, prefix="train")
11961209

11971210
global_step += 1
11981211

@@ -1527,7 +1540,8 @@ def _extract_random_tile(
15271540
# Some TIF backgrounds render as black (0,0,0) rather than transparent
15281541
arr = np.array(img)
15291542
near_black_mask = (
1530-
(arr[:, :, 0] < 4) & (arr[:, :, 1] < 4) & (arr[:, :, 2] < 4)
1543+
(arr[:, :, 0] < 4) & (
1544+
arr[:, :, 1] < 4) & (arr[:, :, 2] < 4)
15311545
)
15321546
arr[near_black_mask] = [255, 255, 255]
15331547
img = Image.fromarray(arr)
@@ -1701,7 +1715,8 @@ def parse_args() -> argparse.Namespace:
17011715
"--epochs", type=int, default=100, help="Number of training epochs"
17021716
)
17031717
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
1704-
parser.add_argument("--weight-decay", type=float, default=0.01, help="Weight decay")
1718+
parser.add_argument("--weight-decay", type=float,
1719+
default=0.01, help="Weight decay")
17051720
parser.add_argument(
17061721
"--beta", type=float, default=0.3, help="Maximum KL weight (beta-VAE)"
17071722
)
@@ -1754,7 +1769,8 @@ def parse_args() -> argparse.Namespace:
17541769
)
17551770

17561771
# Device
1757-
parser.add_argument("--device", type=str, default="cuda", help="Device to use")
1772+
parser.add_argument("--device", type=str,
1773+
default="cuda", help="Device to use")
17581774

17591775
# Reproducibility
17601776
parser.add_argument("--seed", type=int, default=42, help="Random seed")
@@ -1784,7 +1800,8 @@ def main():
17841800
)
17851801

17861802
# Parse channel multipliers and attention resolutions
1787-
channel_multipliers = tuple(int(x) for x in args.channel_multipliers.split(","))
1803+
channel_multipliers = tuple(int(x)
1804+
for x in args.channel_multipliers.split(","))
17881805
use_attention_at = tuple(int(x) for x in args.use_attention_at.split(","))
17891806

17901807
# Create config
@@ -1807,7 +1824,8 @@ def main():
18071824
config.validate()
18081825

18091826
print(f"\nVAE Configuration:")
1810-
print(f" Image size: {config.img_size}x{config.img_size}x{config.img_channels}")
1827+
print(
1828+
f" Image size: {config.img_size}x{config.img_size}x{config.img_channels}")
18111829
print(
18121830
f" Latent size: {config.latent_size}x{config.latent_size}x{config.latent_channels}"
18131831
)
@@ -1825,7 +1843,8 @@ def main():
18251843

18261844
# Count parameters
18271845
num_params = sum(p.numel() for p in model.parameters())
1828-
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
1846+
num_trainable = sum(p.numel()
1847+
for p in model.parameters() if p.requires_grad)
18291848
print(f"Model parameters: {num_params:,} ({num_trainable:,} trainable)")
18301849

18311850
# Create optimizer
@@ -1941,10 +1960,13 @@ def main():
19411960
writer.add_scalar(
19421961
"epoch/train_recon_loss", train_metrics["recon_loss"], epoch
19431962
)
1944-
writer.add_scalar("epoch/train_kl_loss", train_metrics["kl_loss"], epoch)
1963+
writer.add_scalar("epoch/train_kl_loss",
1964+
train_metrics["kl_loss"], epoch)
19451965
writer.add_scalar("epoch/val_loss", val_metrics["loss"], epoch)
1946-
writer.add_scalar("epoch/val_recon_loss", val_metrics["recon_loss"], epoch)
1947-
writer.add_scalar("epoch/val_kl_loss", val_metrics["kl_loss"], epoch)
1966+
writer.add_scalar("epoch/val_recon_loss",
1967+
val_metrics["recon_loss"], epoch)
1968+
writer.add_scalar("epoch/val_kl_loss",
1969+
val_metrics["kl_loss"], epoch)
19481970

19491971
# Print progress
19501972
print(
@@ -1989,7 +2011,8 @@ def main():
19892011
# Save best checkpoint
19902012
if val_metrics["loss"] < best_val_loss:
19912013
best_val_loss = val_metrics["loss"]
1992-
save_path = os.path.join(args.checkpoint_dir, "checkpoint_best.pt")
2014+
save_path = os.path.join(
2015+
args.checkpoint_dir, "checkpoint_best.pt")
19932016
torch.save(checkpoint, save_path)
19942017
print(f"Saved best checkpoint: {save_path}")
19952018

0 commit comments

Comments
 (0)