-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_convergence.py
More file actions
69 lines (54 loc) · 2.38 KB
/
plot_convergence.py
File metadata and controls
69 lines (54 loc) · 2.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import json
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
def _load_histories(sol_file):
    """Return ``(f0_history, f1_history)`` read from a ``.sol`` JSON file.

    Returns ``None`` when the file does not contain a JSON object holding both
    ``convergence_f0`` and ``convergence_f1``. Read and decode errors
    (``OSError``, ``ValueError`` including ``json.JSONDecodeError``) propagate
    to the caller.
    """
    with open(sol_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    # A non-object payload (e.g. a bare list or number) cannot carry the
    # histories; treat it like a file without convergence data.
    if not isinstance(data, dict):
        return None
    if "convergence_f0" not in data or "convergence_f1" not in data:
        return None
    return data["convergence_f0"], data["convergence_f1"]


def _save_convergence_plot(title, f0_history, f1_history, plot_path):
    """Draw the f0/f1 histories as two stacked panels and save a PNG.

    Each series is plotted against its own 1-based generation index, so the
    two histories may have different lengths. The figure is closed even when
    drawing or saving fails, so failures do not accumulate open figures over
    a batch.
    """
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
    try:
        # Plot f0 (Total Distance)
        ax1.plot(list(range(1, len(f0_history) + 1)), f0_history,
                 color='blue', linewidth=1.5)
        ax1.set_ylabel('f0 (Total Distance)', color='blue')
        ax1.tick_params(axis='y', labelcolor='blue')
        ax1.grid(True, linestyle='--', alpha=0.7)
        ax1.set_title(title)

        # Plot f1 (Random Cost)
        ax2.plot(list(range(1, len(f1_history) + 1)), f1_history,
                 color='red', linewidth=1.5)
        ax2.set_ylabel('f1 (Random Cost)', color='red')
        ax2.tick_params(axis='y', labelcolor='red')
        ax2.grid(True, linestyle='--', alpha=0.7)
        ax2.set_xlabel('Generation')

        fig.tight_layout()
        fig.savefig(plot_path, dpi=150)
    finally:
        plt.close(fig)


def plot_convergences(labels_dir=None, plots_dir=None):
    """Render a convergence plot for every ``.sol`` label file.

    Every ``*.sol`` file under *labels_dir* (searched recursively) is read as
    JSON. Files that hold both ``convergence_f0`` and ``convergence_f1`` get a
    two-panel PNG (f0 on top, f1 below, x-axis = generation) named
    ``<file stem>.png`` in *plots_dir*. Files without both keys are skipped.
    A file that fails to load or plot is reported and skipped, so one bad file
    does not abort the batch.

    Args:
        labels_dir: Directory to search for ``*.sol`` files (``str`` or
            ``Path``). Defaults to ``data/nsga_codex_labels`` under the parent
            of the directory containing this script.
        plots_dir: Output directory for the PNGs (``str`` or ``Path``),
            created only when there are label files to process. Defaults to
            ``data/convergence_plots`` under the same base directory.

    Returns:
        The number of plots written (0 when the input directory is missing
        or contains no ``.sol`` files).
    """
    base_dir = Path(__file__).resolve().parent.parent
    labels_dir = (base_dir / "data" / "nsga_codex_labels"
                  if labels_dir is None else Path(labels_dir))
    plots_dir = (base_dir / "data" / "convergence_plots"
                 if plots_dir is None else Path(plots_dir))

    # is_dir() rather than exists(): a regular file at this path cannot be
    # searched for label files.
    if not labels_dir.is_dir():
        print(f"Error: {labels_dir} does not exist or is not a directory.")
        return 0

    # Find all .sol files
    sol_files = list(labels_dir.rglob("*.sol"))
    if not sol_files:
        print("No .sol files found.")
        return 0

    # Create the output directory only once there is something to write.
    plots_dir.mkdir(parents=True, exist_ok=True)
    print(f"Found {len(sol_files)} label files. Generating plots...")

    saved = skipped = failed = 0
    # Output names are flat (<stem>.png) while the search is recursive, so
    # same-named files in different subdirectories map to one PNG; track what
    # was written this run so such overwrites are reported, not silent.
    written = {}
    for sol_file in tqdm(sol_files):
        try:
            histories = _load_histories(sol_file)
            if histories is None:
                skipped += 1
                continue
            plot_path = plots_dir / f"{sol_file.stem}.png"
            if plot_path in written:
                print(f"Warning: {sol_file} overwrites {plot_path.name}, "
                      f"previously written for {written[plot_path]}.")
            _save_convergence_plot(f"Convergence for {sol_file.stem}",
                                   histories[0], histories[1], plot_path)
            written[plot_path] = sol_file
            saved += 1
        except Exception as e:
            # Deliberately broad: best-effort batch job, so report the bad
            # file and move on rather than aborting the remaining files.
            print(f"Failed to process {sol_file.name}: {e}")
            failed += 1

    print(f"Saved {saved} plot(s) to {plots_dir}; "
          f"skipped {skipped} without convergence data; {failed} failed.")
    return saved
if __name__ == "__main__":
    # plot_convergences() reports a missing input directory, an empty input
    # directory, and per-file failures by printing rather than raising, so the
    # closing message must not claim that every plot was generated.
    plot_convergences()
    print("Convergence plotting finished.")