-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheggroll_mnist_evolutionary.py
More file actions
334 lines (273 loc) · 12.9 KB
/
eggroll_mnist_evolutionary.py
File metadata and controls
334 lines (273 loc) · 12.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from torch.func import functional_call, vmap
import time
import matplotlib.pyplot as plt
# --- Configuration ---
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
POPULATION_SIZE = 10000 # Total generation size (candidate evaluations per generation)
PARALLEL_BATCH_SIZE = 1000 # How many models to run in parallel VRAM
SIGMA = 0.02 # Noise standard deviation (initial; decayed by the scheduler below)
LEARNING_RATE = 0.001 # Base model update rate (decayed together with sigma)
RANK = 16 # Increased Rank for deeper model (rank of low-rank noise factors)
WEIGHT_DECAY = 0.005 # L2 Regularization applied to 'weight' parameters each update
EVAL_BATCH_SIZE = 1000 # Images to evaluate per generation (Training)
VAL_BATCH_SIZE = 1000 # Images to evaluate for validation
GENERATIONS = 5000 # Total number of ES generations to run
PLOT_INTERVAL = 10 # Update plot every N generations
HIDDEN_SIZES = [256, 128, 64] # Deeper architecture: 784 -> 256 -> 128 -> 64 -> 10
# Adaptive Sigma Configuration (decay when validation loss stops improving)
SIGMA_DECAY = 0.95 # Multiplicative decay factor for sigma (and lr)
PATIENCE = 10 # Generations without a new best val loss before decaying
MIN_SIGMA = 0.001 # Floor below which sigma is no longer decayed
print(f"Configuration: Device={DEVICE}, Pop={POPULATION_SIZE}, Parallel={PARALLEL_BATCH_SIZE}, Rank={RANK}, Hidden={HIDDEN_SIZES}")
# --- 1. Define Model ---
class SimpleMLP(nn.Module):
    """Fully-connected classifier: Flatten -> [Linear -> LayerNorm -> ReLU]* -> Linear.

    Defaults to a single hidden layer of 128 units when no sizes are given.
    """

    def __init__(self, input_size=28*28, hidden_sizes=None, output_size=10):
        super().__init__()
        if hidden_sizes is None:
            hidden_sizes = [128]
        self.flatten = nn.Flatten()
        dims = [input_size] + list(hidden_sizes)
        stack = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            # LayerNorm keeps activations well-scaled for stability.
            stack += [nn.Linear(d_in, d_out), nn.LayerNorm(d_out), nn.ReLU()]
        stack.append(nn.Linear(dims[-1], output_size))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten the image batch and return raw class logits."""
        return self.net(self.flatten(x))
# --- 2. Low-Rank Noise Generation (Batched) ---
def generate_low_rank_perturbations(param_shapes, rank, num_models, generator=None, device=None):
    """
    Generates low-rank noise components for a batch of models.

    Args:
        param_shapes: {param_name: shape} for every parameter to perturb.
        rank: rank of the factorization used for 2D weight matrices.
        num_models: number of models in the batch (leading dimension).
        generator: optional torch.Generator for reproducible sampling
            (the training loop replays seeds to regenerate this noise).
        device: device to allocate noise on; defaults to the module-level
            DEVICE for backward compatibility.

    Returns a dictionary: {param_name: (U, V)} where:
        - U: (num_models, rows, rank)
        - V: (num_models, cols, rank)
    For 1D biases, returns just the noise vector (num_models, dim).

    Note: the 1/sqrt(rank) variance correction for the U @ V^T product is
    applied later, in reconstruct_params, not here.
    """
    if device is None:
        device = DEVICE
    perturbations = {}
    for name, shape in param_shapes.items():
        if len(shape) == 2:  # Weight matrix (Out, In): factored noise U @ V^T
            rows, cols = shape
            u = torch.randn(num_models, rows, rank, device=device, generator=generator)
            v = torch.randn(num_models, cols, rank, device=device, generator=generator)
            perturbations[name] = (u, v)
        else:  # Bias (any non-2D parameter): dense full-rank noise
            perturbations[name] = torch.randn(num_models, *shape, device=device, generator=generator)
    return perturbations
def reconstruct_params(base_params, perturbations, sigma, rank, mirror=False):
    """
    Combines base parameters with low-rank perturbations into batched params.

    Args:
        base_params: {name: tensor} unbatched base parameters.
        perturbations: output of generate_low_rank_perturbations — (U, V)
            tuples for 2D weights, dense noise tensors for biases.
        sigma: noise scale. 2D noise is additionally divided by sqrt(rank)
            to keep the per-entry variance of U @ V^T at sigma^2.
        rank: rank used when the 2D factors were generated.
        mirror: if True, subtracts the noise instead of adding it
            (antithetic / mirrored sampling).

    Returns {name: tensor} with a leading population dimension. Parameters
    without an entry in perturbations are broadcast across the population
    unchanged.
    """
    batched_params = {}
    scale_factor = sigma / (rank ** 0.5)
    sign = -1.0 if mirror else 1.0
    # Population size, robust to the first entry being a (U, V) tuple.
    # (The original indexed .shape on whatever came first, which crashes
    # when that entry is a tuple.)
    first = next(iter(perturbations.values()))
    num_models = first[0].shape[0] if isinstance(first, tuple) else first.shape[0]
    for name, p in base_params.items():
        pert = perturbations.get(name)
        if pert is None:
            # Unperturbed parameter: share the base value across the population.
            batched_params[name] = p.unsqueeze(0).expand(num_models, *p.shape)
        elif isinstance(pert, tuple):
            u, v = pert
            noise = torch.bmm(u, v.transpose(1, 2))  # (pop, rows, cols)
            batched_params[name] = p.unsqueeze(0) + (noise * scale_factor * sign)
        else:  # Bias noise is full-rank; no sqrt(rank) correction needed
            batched_params[name] = p.unsqueeze(0) + (pert * sigma * sign)
    return batched_params
# --- 3. Functional Evaluation ---
def evaluate_batch(model, params, data, target):
    """Fitness of one candidate parameter set on a data batch.

    Runs the shared module architecture with externally supplied parameters
    (via functional_call) and returns the negated cross-entropy loss, so that
    a larger value means a fitter candidate for the evolution strategy.
    """
    logits = functional_call(model, params, (data,))
    return -F.cross_entropy(logits, target)
def evaluate_accuracy(model, params, data, target):
    """Classification accuracy of one candidate parameter set on a batch."""
    logits = functional_call(model, params, (data,))
    hits = logits.argmax(dim=1).eq(target)
    return hits.float().mean()
# --- Main Training Loop ---
def train():
    """Train the MLP on MNIST with a mirrored-sampling evolution strategy.

    Each generation samples POPULATION_SIZE/2 low-rank noise directions,
    evaluates both the +noise and -noise candidate for every direction on one
    training batch (POPULATION_SIZE evaluations total), then steps the base
    parameters along the reward-weighted noise. Noise is regenerated from
    stored seeds during the update step so it never has to be kept in memory.
    Sigma and the learning rate decay together when validation loss plateaus.

    Side effects: downloads MNIST into ./data, shows a live matplotlib plot,
    and periodically saves it to evolution_progress.png.
    """
    # --- Data ---
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # standard MNIST mean/std
    ])
    full_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    # Split into Train and Validation
    train_size = 50000
    val_size = 10000
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    # Loaders (drop_last keeps every generation's training batch full-size)
    train_loader = DataLoader(train_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False)

    # --- Base model ---
    model = SimpleMLP(hidden_sizes=HIDDEN_SIZES).to(DEVICE)
    model.eval()  # ES never calls backward; no dropout/batchnorm tracking needed
    # Detached copies: ES updates these in place, outside autograd
    base_params = {k: v.detach() for k, v in model.named_parameters()}
    param_shapes = {k: v.shape for k, v in base_params.items()}
    # vmap over the population dimension of params; data/targets are shared
    batch_eval_fn = vmap(lambda p, d, t: evaluate_batch(model, p, d, t), in_dims=(0, None, None))

    # --- Live plot ---
    plt.ion()
    fig, ax = plt.subplots()
    train_losses = []
    val_losses = []
    epochs = []
    line_train, = ax.plot([], [], label='Train Loss')
    line_val, = ax.plot([], [], label='Val Loss')
    ax.legend()
    ax.set_xlabel('Generation')
    ax.set_ylabel('Loss')
    ax.set_title('Evolutionary Training Progress')

    print("Starting Evolution...")
    data_iterator = iter(train_loader)

    # Scheduler state: sigma and lr decay together so the effective step
    # size lr / (pairs * sigma) stays stable as sigma shrinks.
    current_sigma = SIGMA
    current_lr = LEARNING_RATE
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(1, GENERATIONS + 1):
        start_time = time.time()
        # Fresh training batch; restart the iterator when an epoch ends
        try:
            inputs, targets = next(data_iterator)
        except StopIteration:
            data_iterator = iter(train_loader)
            inputs, targets = next(data_iterator)
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        chunk_seeds = []

        # --- A. Rollout (mirrored sampling) ---
        # Each chunk evaluates chunk_size_pairs (+noise, -noise) pairs, so
        # the loop performs POPULATION_SIZE evaluations in total.
        num_pairs = POPULATION_SIZE // 2
        chunk_size_pairs = PARALLEL_BATCH_SIZE // 2
        num_chunks = num_pairs // chunk_size_pairs
        all_fitness_pos = []
        all_fitness_neg = []
        for i in range(num_chunks):
            # Save the seed so the identical noise can be regenerated in the
            # update step instead of being stored.
            chunk_seed = torch.randint(0, 2**32, (1,)).item()
            chunk_seeds.append(chunk_seed)
            rng = torch.Generator(device=DEVICE)
            rng.manual_seed(chunk_seed)
            perturbations = generate_low_rank_perturbations(
                param_shapes, RANK, chunk_size_pairs, generator=rng
            )
            # Evaluate +noise. BUGFIX: use current_sigma (was the constant
            # SIGMA), otherwise the mirrored pair becomes asymmetric as soon
            # as the scheduler decays sigma, corrupting the gradient estimate.
            params_pos = reconstruct_params(base_params, perturbations, current_sigma, RANK, mirror=False)
            fitness_pos = batch_eval_fn(params_pos, inputs, targets)
            all_fitness_pos.append(fitness_pos)
            # Evaluate -noise (same perturbations, opposite sign)
            params_neg = reconstruct_params(base_params, perturbations, current_sigma, RANK, mirror=True)
            fitness_neg = batch_eval_fn(params_neg, inputs, targets)
            all_fitness_neg.append(fitness_neg)
        all_fitness_pos = torch.cat(all_fitness_pos)
        all_fitness_neg = torch.cat(all_fitness_neg)
        # Combined population fitness, used for normalization stats
        all_fitness = torch.cat([all_fitness_pos, all_fitness_neg])

        # --- B. Update step (mirrored) ---
        # z-score normalization across the entire population (pos + neg)
        rewards = all_fitness.float()
        mean_reward = rewards.mean()
        std_reward = rewards.std()
        norm_pos = (all_fitness_pos - mean_reward) / (std_reward + 1e-8)
        norm_neg = (all_fitness_neg - mean_reward) / (std_reward + 1e-8)
        # Update = (pos * noise) + (neg * -noise) = (pos - neg) * noise;
        # halved so each pair contributes one averaged direction.
        step_scores = (norm_pos - norm_neg) / 2.0

        total_updates = {k: torch.zeros_like(v) for k, v in base_params.items()}
        for i in range(num_chunks):
            # Regenerate exactly the noise used in this chunk's rollout
            rng = torch.Generator(device=DEVICE)
            rng.manual_seed(chunk_seeds[i])
            perturbations = generate_low_rank_perturbations(
                param_shapes, RANK, chunk_size_pairs, generator=rng
            )
            chunk_steps = step_scores[i*chunk_size_pairs : (i+1)*chunk_size_pairs]
            chunk_steps_view = chunk_steps.view(-1, 1, 1)
            chunk_steps_bias_view = chunk_steps.view(-1, 1)
            for name, pert in perturbations.items():
                if isinstance(pert, tuple):
                    u, v = pert
                    # sum_i step_i * (U_i @ V_i^T), computed as a single bmm
                    weighted_noise = torch.bmm(u * chunk_steps_view, v.transpose(1, 2))
                    total_updates[name].add_(weighted_noise.sum(dim=0))
                else:
                    weighted_noise = pert * chunk_steps_bias_view
                    total_updates[name].add_(weighted_noise.sum(dim=0))

        # theta <- theta*(1 - lr*wd) + lr/(pairs*sigma) * sum(step * noise)
        scale = current_lr / (num_pairs * current_sigma)
        for name, p in base_params.items():
            if 'weight' in name:
                # NOTE(review): this also decays LayerNorm gain parameters
                # (their names contain 'weight') — confirm that is intended.
                p.mul_(1.0 - (current_lr * WEIGHT_DECAY))
            p.add_(total_updates[name], alpha=scale)

        # --- Logging & validation ---
        # Population-mean training loss (fitness is negative CE loss)
        current_train_loss = -mean_reward.item()
        # Validate the BASE (unperturbed) model. shuffle=False means this is
        # always the same first validation batch.
        val_inputs, val_targets = next(iter(val_loader))
        val_inputs, val_targets = val_inputs.to(DEVICE), val_targets.to(DEVICE)
        with torch.no_grad():
            val_out = functional_call(model, base_params, (val_inputs,))
            val_loss = F.cross_entropy(val_out, val_targets).item()
            val_acc = (val_out.argmax(dim=1) == val_targets).float().mean().item()

        # Adaptive sigma scheduler: after PATIENCE generations without a new
        # best validation loss, decay sigma and lr together.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE and current_sigma > MIN_SIGMA:
                current_sigma *= SIGMA_DECAY
                current_lr *= SIGMA_DECAY
                patience_counter = 0
                print(f" -> Decaying Sigma to {current_sigma:.5f} | LR to {current_lr:.5f}")

        print(f"Gen {epoch}: Train Loss: {current_train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}% | "
              f"Sigma: {current_sigma:.4f} | Time: {time.time()-start_time:.2f}s")

        # Update plot every PLOT_INTERVAL generations
        if epoch % PLOT_INTERVAL == 0:
            train_losses.append(current_train_loss)
            val_losses.append(val_loss)
            epochs.append(epoch)
            line_train.set_data(epochs, train_losses)
            line_val.set_data(epochs, val_losses)
            ax.relim()
            ax.autoscale_view()
            fig.canvas.draw()
            fig.canvas.flush_events()
            # Save plot to file just in case the interactive window is lost
            plt.savefig('evolution_progress.png')

    plt.ioff()
    plt.show()
# Entry point: run training only when executed as a script, not on import.
if __name__ == "__main__":
    train()