Skip to content

Commit ed8e7e9

Browse files
authored
Merge pull request #7 from shahineb/rebuttal
rebuttal changes
2 parents 22305b6 + c339d53 commit ed8e7e9

File tree

42 files changed

+4400
-165
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

42 files changed

+4400
-165
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ cython_debug/
183183
*.eqx
184184
*.jpg
185185
*.png
186+
*.eps
186187

187188

188189
# VS Code

paper/access/config.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ class DataConfig:
3434
3535
Specifies dataset paths, climate model, experiments, and pattern scaling parameters.
3636
"""
37-
root_dir: str = "/orcd/data/raffaele/001/shahineb/cmip6/processed" # CMIP6 data directory
37+
root_dir: str = "/orcd/data/raffaele/001/shahineb/products/cmip6/processed" # CMIP6 data directory
3838
model_name: str = "ACCESS-ESM1-5" # Climate model to use
3939
train_experiments: List[str] = ("piControl", "historical", "ssp126", "ssp585") # Training experiments
40-
val_experiments: List[str] = ("ssp370",) # Validation experiments
40+
val_experiments: List[str] = ("1pctCO2",) # Validation experiments
4141
variables: List[str] = ("tas", "pr", "hurs", "sfcWind") # Climate variables
42-
val_time_slice: Tuple[str, str] = ("2080-01", "2100-12") # Time range for validation
42+
val_time_slice: Tuple[str, str] = (None, None) # Time range for validation
4343
pattern_scaling_path: str = os.path.join(CACHE_DIR, "β.npy") # Path to save/load pattern scaling coefficients
4444
norm_stats_path: str = os.path.join(CACHE_DIR, "μ_σ.npz") # Path to save/load normalization statistics
4545
in_memory: bool = True # Whether to load full dataset into memory
@@ -55,9 +55,9 @@ class TrainingConfig:
5555
Defines hyperparameters, logging intervals, and output paths.
5656
"""
5757
batch_size: int = 32 # Number of samples per batch
58-
learning_rate: float = 3e-4 # Adam optimizer learning rate
58+
learning_rate: float = 1e-4 # Adam optimizer learning rate
5959
ema_decay: float = 0.999 # Exponential moving average decay
60-
epochs: int = 10 # Number of training epochs
60+
epochs: int = 15 # Number of training epochs
6161
log_interval: int = 20 # Steps between metric logging
6262
queue_length: int = 30 # Length of sliding window for metrics
6363
sample_interval: int = 10000 # Steps between sample generation
@@ -90,7 +90,7 @@ class SamplingConfig:
9090
n_samples: int = 50 # Number of samples to generate per test point
9191
batch_size: int = 2 # Batch size for evaluation
9292
random_seed: int = 2100 # Seed for reproducibility
93-
output_dir: str = f"/orcd/data/raffaele/001/shahineb/jax-esm-emulation/paper/{EXPERIMENT_NAME}/outputs" # Output directory for inference
93+
output_dir: str = f"/orcd/data/raffaele/001/shahineb/emulated/climemu/paper/{EXPERIMENT_NAME}/outputs" # Output directory for inference
9494

9595

9696
@dataclass

paper/access/data.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,12 +196,34 @@ def estimate_sigma_max(
196196
print(f"Loading σmax = {σmax} from {sigma_max_path}")
197197
return σmax
198198

199+
# Estimate the leading principal component
200+
dataset_size = len(dataset)
201+
subset_size = min(10000, dataset_size)
202+
key = jr.PRNGKey(42)
203+
indices = jr.permutation(key, dataset_size)[:subset_size].tolist()
204+
dataset_subset = Subset(dataset, indices)
205+
dummy_loader = DataLoader(dataset_subset, batch_size=10, collate_fn=numpy_collate)
206+
X = []
207+
for batch in tqdm(dummy_loader, desc=f"Loading {subset_size} samples"):
208+
X.append(utils.process_batch(batch, μ, σ)[:, :-ctx_size])
209+
X = jnp.concatenate(X)
210+
μX = X.mean(axis=0)
211+
Xc = X - μX
212+
wlat = jnp.cos(jnp.deg2rad(dataset.cmip6data.lat))
213+
G = jnp.einsum("nchw,h,mchw->nm", Xc, wlat, Xc)
214+
Σ2, U = jnp.linalg.eigh(G)
215+
u1 = U[:, -1]
216+
σ1 = jnp.sqrt(Σ2[-1])
217+
v1 = jnp.einsum("nchw,n->chw", Xc, u1) / σ1
218+
v1 = v1 * wlat[:, None]
219+
v1 = v1.ravel()
220+
199221
# Define search parameters
200222
σmax_low, σmax_high = search_interval
201223
max_split = 20
202224
n_montecarlo = 100
203225
max_montecarlo = 10000
204-
npool = 50000
226+
popsize = 8
205227
tgt_pow = 0.1
206228
tol = 0.001 + 1.96 * np.sqrt(tgt_pow * (1 - tgt_pow) / max_montecarlo)
207229
key = jr.PRNGKey(seed)
@@ -218,7 +240,9 @@ def estimate_sigma_max(
218240
σmax=σmax,
219241
α=alpha,
220242
n_montecarlo=n_montecarlo,
221-
npool=npool,
243+
popsize=popsize,
244+
v1=v1,
245+
μX=μX,
222246
μ=μ,
223247
σ=σ,
224248
ctx_size=ctx_size,

paper/access/plots/piControl/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from dask.diagnostics import ProgressBar
99

1010
# Module-level path configuration
11-
CLIMATOLOGY_ROOT = "/home/shahineb/data/cmip6/processed"
11+
CLIMATOLOGY_ROOT = "/home/shahineb/data/products/cmip6/processed"
1212
CLIMATOLOGY_MODEL = 'ACCESS-ESM1-5'
1313
CLIMATOLOGY_MEMBER = 'r1i1p1f1'
14-
RAW_CMIP6_ROOT = "/orcd/home/002/shahineb/data/cmip6/raw"
14+
RAW_CMIP6_ROOT = "/orcd/home/002/shahineb/data/products/cmip6/raw"
1515

1616

1717
def groupby_month_and_year(ds):

paper/access/trainer.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,31 @@ class TrainingState:
3636
epoch: int = 0
3737

3838

39+
def log_training_metrics(config, state, loss, grad):
40+
wandb.log({"Train Loss": loss, "Gradient norm": grad}, step=state.step)
41+
42+
43+
def log_validation_metrics(config, state, val_loader, μ, σ, schedule, χval):
44+
# Validation phase
45+
val_loss = 0
46+
n_val_steps = len(val_loader)
47+
with tqdm(total=n_val_steps, desc="Evaluation") as pbar:
48+
for batch_idx, batch in enumerate(val_loader):
49+
# Process batch and compute validation loss
50+
x = utils.process_batch(batch, μ, σ)
51+
_, χval = jr.split(χval)
52+
val_value = denoising_batch_loss(
53+
state.ema_model, config.model.context_channels, schedule, x, χval
54+
)
55+
val_loss += val_value.item()
56+
# Update progress bar
57+
pbar.set_description(f"Epoch {state.epoch + 1} | Val {round(val_loss / (batch_idx + 1), 2)}")
58+
pbar.update(1)
59+
# Log validation loss
60+
wandb.log({"Validation Loss": val_loss / n_val_steps}, step=state.step)
61+
62+
63+
3964
def train_epoch(
4065
state: TrainingState,
4166
train_loader: DataLoader,
@@ -104,38 +129,21 @@ def train_epoch(
104129
pbar.set_description(f"Epoch {state.epoch + 1} | Loss {round(running_loss, 2)}")
105130
_ = pbar.update(1)
106131

107-
# Log metrics at specified intervals
108-
if (state.step + 1) % config.training.log_interval == 0:
109-
wandb.log({
110-
"Train Loss": running_loss,
111-
"Gradient norm": running_grad
112-
}, step=state.step)
113-
114-
# Generate and log samples at specified intervals
115-
if (state.step + 1) % config.training.sample_interval == 0:
132+
# Log training metrics at specified intervals
133+
if (state.step + 1) % config.training.log_interval == 0 or (state.step + 1) & state.step == 0:
134+
log_training_metrics(config, state, running_loss, running_grad)
135+
136+
# log validation metrics + samples at specified intervals
137+
if (state.step + 1) % config.training.sample_interval == 0 or (state.step + 1) & state.step == 0:
138+
log_validation_metrics(config, state, val_loader, μ, σ, schedule, χval)
139+
_, χval = jr.split(χval)
140+
116141
# Generate samples from current model
117142
pred_samples = log_sampler(model=ema_model, key=χtrain)
118143

119144
# Log samples and metrics to wandb
120145
utils.log_samples(pred_samples, log_target_data, config.data.variables, state.step)
121146

122-
# Validation phase
123-
val_loss = 0
124-
with tqdm(total=n_val_steps, desc="Evaluation") as pbar:
125-
for batch_idx, batch in enumerate(val_loader):
126-
# Process batch and compute validation loss
127-
x = utils.process_batch(batch, μ, σ)
128-
val_value = denoising_batch_loss(
129-
state.ema_model, config.model.context_channels, schedule, x, χval
130-
)
131-
val_loss += val_value.item()
132-
# Update progress bar
133-
pbar.set_description(f"Epoch {state.epoch + 1} | Val {round(val_loss / (batch_idx + 1), 2)}")
134-
pbar.update(1)
135-
136-
# Log validation loss
137-
wandb.log({"Validation Loss": val_loss / n_val_steps}, step=state.step)
138-
139147
# Checkpoint weights
140148
if (state.epoch + 1) % config.training.checkpoint_interval == 0:
141149
eqx.tree_serialise_leaves(config.training.checkpoint_filename, state.ema_model)

paper/access/utils.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,12 @@ def process_batch(batch: Tuple, μ: jnp.ndarray, σ: jnp.ndarray) -> jnp.ndarray
8282
################################################################################
8383

8484

85-
def estimate_power(dataset, σmax, α, n_montecarlo, npool, μ, σ, ctx_size, key):
85+
def estimate_power(dataset, σmax, α, n_montecarlo, popsize, v1, μX, μ, σ, ctx_size, key):
8686
# Initialize dataloader on subset of size n_iter
8787
dataset_size = len(dataset)
88-
indices = jr.permutation(key, dataset_size)[:n_montecarlo].tolist()
89-
rejections = 0
88+
indices = jr.permutation(key, dataset_size)[:n_montecarlo * popsize].tolist()
9089
dataset_subset = Subset(dataset, indices)
91-
dummy_loader = DataLoader(dataset_subset, batch_size=1, shuffle=True, collate_fn=numpy_collate)
90+
dummy_loader = DataLoader(dataset_subset, batch_size=popsize, shuffle=True, collate_fn=numpy_collate)
9291

9392
# Estimate power on this subset
9493
rejections = 0
@@ -97,12 +96,14 @@ def estimate_power(dataset, σmax, α, n_montecarlo, npool, μ, σ, ctx_size, ke
9796
for batch in dummy_loader:
9897
# Draw sample and flatten
9998
x = process_batch(batch, μ, σ)[:, :-ctx_size]
100-
x0 = np.array(x.ravel())
101-
x0 = np.random.choice(x0, size=npool, replace=False)
99+
x0 = np.array(x - μX).reshape(popsize, -1)
100+
101+
# Add noise and project against lead PC
102+
xn = x0 + σmax * np.random.randn(*x0.shape)
103+
xnTv1 = xn @ v1
102104

103-
# Add noise and perform test
104-
xn = x0 + σmax * np.random.randn(len(x0))
105-
_, pvalue = kstest(xn, "norm", args=(0, σmax))
105+
# Perform test
106+
_, pvalue = kstest(xnTv1, "norm", args=(0, σmax))
106107
rejections += (pvalue < α)
107108
_ = pbar.update(1)
108109
return rejections / n_montecarlo

paper/intermodel/plot_losses.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# %%
2+
import os
3+
import numpy as np
4+
import pandas as pd
5+
import matplotlib.pyplot as plt
6+
7+
8+
9+
# %%
10+
root = "/Users/shahine/Documents/Research/MIT/code/repos/climemu-private/paper/intermodel/wandb/"
11+
train_losses = {"MIROC6": pd.read_csv(os.path.join(root, 'miroc_train.csv')),
12+
"MPI-ESM1-2-LR": pd.read_csv(os.path.join(root, 'mpi_train.csv')),
13+
"ACCESS-ESM1-5": pd.read_csv(os.path.join(root, 'access_train.csv'))}
14+
15+
val_losses = {"MIROC6": pd.read_csv(os.path.join(root, 'miroc_val.csv')),
16+
"MPI-ESM1-2-LR": pd.read_csv(os.path.join(root, 'mpi_val.csv')),
17+
"ACCESS-ESM1-5": pd.read_csv(os.path.join(root, 'access_val.csv'))}
18+
19+
grad_df = {"MIROC6": pd.read_csv(os.path.join(root, 'miroc_grad.csv')),
20+
"MPI-ESM1-2-LR": pd.read_csv(os.path.join(root, 'mpi_grad.csv')),
21+
"ACCESS-ESM1-5": pd.read_csv(os.path.join(root, 'access_grad.csv'))}
22+
23+
24+
25+
# %%
26+
fig, ax = plt.subplots(1, 2, figsize=(15, 5), gridspec_kw={'width_ratios': [1.5, 1]})
27+
28+
train_df = train_losses['MIROC6']
29+
val_df = val_losses['MIROC6']
30+
ax[0].plot(train_df['Step'], train_df.iloc[:, 1], label='MIROC6 Training Loss', alpha=0.5, color='#0072B2', zorder=0)
31+
ax[0].plot(val_df['Step'], val_df.iloc[:, 1], ls='--', label='MIROC6 Validation Loss', color='#0072B2')
32+
33+
train_df = train_losses['MPI-ESM1-2-LR']
34+
val_df = val_losses['MPI-ESM1-2-LR']
35+
ax[0].plot(train_df['Step'], train_df.iloc[:, 1], label='MPI-ESM1-2-LR Training Loss', alpha=0.5, color='#E69F00', zorder=0)
36+
ax[0].plot(val_df['Step'], val_df.iloc[:, 1], ls='--', label='MPI-ESM1-2-LR Validation Loss', color='#E69F00')
37+
38+
train_df = train_losses['ACCESS-ESM1-5']
39+
val_df = val_losses['ACCESS-ESM1-5']
40+
ax[0].plot(train_df['Step'], train_df.iloc[:, 1], label='ACCESS-ESM1-5 Training Loss', alpha=0.5, color='#CC79A7', zorder=0)
41+
ax[0].plot(val_df['Step'], val_df.iloc[:, 1], ls='--', label='ACCESS-ESM1-5 Validation Loss', color='#CC79A7')
42+
43+
ax[0].legend(frameon=False, fontsize=12, loc='lower left')
44+
ax[0].set_yscale('log')
45+
ax[0].set_xscale('log')
46+
ax[0].set_xlabel("Training Steps", fontsize=14)
47+
ax[0].set_ylabel("Loss", fontsize=14)
48+
ax[0].margins(0.01)
49+
50+
grad_df_miroc = grad_df['MIROC6']
51+
ax[1].plot(grad_df_miroc['Step'], grad_df_miroc.iloc[:, 1], label='MIROC6 Gradient Norm', alpha=0.5, color='#0072B2')
52+
53+
grad_df_mpi = grad_df['MPI-ESM1-2-LR']
54+
ax[1].plot(grad_df_mpi['Step'], grad_df_mpi.iloc[:, 1], label='MPI-ESM1-2-LR Gradient Norm', alpha=0.5, color='#E69F00')
55+
56+
grad_df_access = grad_df['ACCESS-ESM1-5']
57+
ax[1].plot(grad_df_access['Step'], grad_df_access.iloc[:, 1], label='ACCESS-ESM1-5 Gradient Norm', alpha=0.5, color='#CC79A7')
58+
59+
ax[1].legend(frameon=False, fontsize=12)
60+
ax[1].set_yscale('log')
61+
ax[1].set_xscale('log')
62+
ax[1].set_xlabel("Training Steps", fontsize=14)
63+
ax[1].set_ylabel("Gradient Norm", fontsize=14)
64+
ax[1].margins(0.01)
65+
66+
plt.savefig("losses.jpg", dpi=300, bbox_inches="tight")
67+
68+
# %%

paper/miroc/config.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import os
44

55

6-
76
EXPERIMENT_DIR = os.path.dirname(__file__)
87
CACHE_DIR = os.path.join(EXPERIMENT_DIR, "cache")
98
EXPERIMENT_NAME = os.path.basename(EXPERIMENT_DIR)
@@ -34,18 +33,18 @@ class DataConfig:
3433
3534
Specifies dataset paths, climate model, experiments, and pattern scaling parameters.
3635
"""
37-
root_dir: str = "/orcd/data/raffaele/001/shahineb/cmip6/processed" # CMIP6 data directory
36+
root_dir: str = "/orcd/data/raffaele/001/shahineb/products/cmip6/processed" # CMIP6 data directory
3837
model_name: str = "MIROC6" # Climate model to use
3938
train_experiments: List[str] = ("piControl", "historical", "ssp126", "ssp585") # Training experiments
40-
val_experiments: List[str] = ("ssp370",) # Validation experiments
39+
val_experiments: List[str] = ("1pctCO2",) # Validation experiments
4140
variables: List[str] = ("tas", "pr", "hurs", "sfcWind") # Climate variables
42-
val_time_slice: Tuple[str, str] = ("2080-01", "2100-12") # Time range for validation
41+
val_time_slice: Tuple[str, str] = (None, None) # Time range for validation
4342
pattern_scaling_path: str = os.path.join(CACHE_DIR, "β.npy") # Path to save/load pattern scaling coefficients
4443
norm_stats_path: str = os.path.join(CACHE_DIR, "μ_σ.npz") # Path to save/load normalization statistics
4544
in_memory: bool = True # Whether to load full dataset into memory
4645
norm_max_samples: int = 10000 # Maximum number of samples to use for normalization
4746
sigma_max_path: str = os.path.join(CACHE_DIR, "σmax.npy") # Path to save/load σmax
48-
sigma_max_search_interval: List[int] = (0, 200) # Interval in which we search for sigma max
47+
sigma_max_search_interval: List[int] = (0, 400) # Interval in which we search for sigma max
4948

5049

5150
@dataclass
@@ -55,9 +54,9 @@ class TrainingConfig:
5554
Defines hyperparameters, logging intervals, and output paths.
5655
"""
5756
batch_size: int = 32 # Number of samples per batch
58-
learning_rate: float = 3e-4 # Adam optimizer learning rate
57+
learning_rate: float = 1e-4 # Adam optimizer learning rate
5958
ema_decay: float = 0.999 # Exponential moving average decay
60-
epochs: int = 10 # Number of training epochs
59+
epochs: int = 15 # Number of training epochs
6160
log_interval: int = 20 # Steps between metric logging
6261
queue_length: int = 30 # Length of sliding window for metrics
6362
sample_interval: int = 10000 # Steps between sample generation
@@ -90,7 +89,7 @@ class SamplingConfig:
9089
n_samples: int = 50 # Number of samples to generate per test point
9190
batch_size: int = 2 # Batch size for evaluation
9291
random_seed: int = 2100 # Seed for reproducibility
93-
output_dir: str = f"/orcd/data/raffaele/001/shahineb/jax-esm-emulation/paper/{EXPERIMENT_NAME}/outputs" # Output directory for inference
92+
output_dir: str = f"/orcd/data/raffaele/001/shahineb/emulated/climemu/paper/{EXPERIMENT_NAME}/outputs" # Output directory for inference
9493

9594

9695
@dataclass

paper/miroc/data.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,12 +196,34 @@ def estimate_sigma_max(
196196
print(f"Loading σmax = {σmax} from {sigma_max_path}")
197197
return σmax
198198

199+
# Estimate the leading principal component
200+
dataset_size = len(dataset)
201+
subset_size = min(10000, dataset_size)
202+
key = jr.PRNGKey(42)
203+
indices = jr.permutation(key, dataset_size)[:subset_size].tolist()
204+
dataset_subset = Subset(dataset, indices)
205+
dummy_loader = DataLoader(dataset_subset, batch_size=10, collate_fn=numpy_collate)
206+
X = []
207+
for batch in tqdm(dummy_loader, desc=f"Loading {subset_size} samples"):
208+
X.append(utils.process_batch(batch, μ, σ)[:, :-ctx_size])
209+
X = jnp.concatenate(X)
210+
μX = X.mean(axis=0)
211+
Xc = X - μX
212+
wlat = jnp.cos(jnp.deg2rad(dataset.cmip6data.lat))
213+
G = jnp.einsum("nchw,h,mchw->nm", Xc, wlat, Xc)
214+
Σ2, U = jnp.linalg.eigh(G)
215+
u1 = U[:, -1]
216+
σ1 = jnp.sqrt(Σ2[-1])
217+
v1 = jnp.einsum("nchw,n->chw", Xc, u1) / σ1
218+
v1 = v1 * wlat[:, None]
219+
v1 = v1.ravel()
220+
199221
# Define search parameters
200222
σmax_low, σmax_high = search_interval
201223
max_split = 20
202224
n_montecarlo = 100
203225
max_montecarlo = 10000
204-
npool = 50000
226+
popsize = 8
205227
tgt_pow = 0.1
206228
tol = 0.001 + 1.96 * np.sqrt(tgt_pow * (1 - tgt_pow) / max_montecarlo)
207229
key = jr.PRNGKey(seed)
@@ -218,7 +240,9 @@ def estimate_sigma_max(
218240
σmax=σmax,
219241
α=alpha,
220242
n_montecarlo=n_montecarlo,
221-
npool=npool,
243+
popsize=popsize,
244+
v1=v1,
245+
μX=μX,
222246
μ=μ,
223247
σ=σ,
224248
ctx_size=ctx_size,
@@ -246,7 +270,6 @@ def estimate_sigma_max(
246270
if np.allclose(σmax_low, σmax_high, atol=1):
247271
break
248272

249-
250273
# Save and return
251274
if sigma_max_path:
252275
print(f"Saving σmax = {σmax} to {sigma_max_path}")

0 commit comments

Comments
 (0)