-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
124 lines (98 loc) · 3.53 KB
/
main.py
File metadata and controls
124 lines (98 loc) · 3.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
""" Usage: python main.py -f input_file.yaml """
import numpy as np
import torch
import torch.nn as nn
import yaml
import argparse
from trainer import AdiabaticMathieu
# Pick the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
# The model and the collocation tensors built in the __main__ section are
# moved to this device.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("# Device:", device)
class FCN(nn.Module):
    """Fully connected feed-forward network (multilayer perceptron).

    Architecture::

        Linear(N_INPUT, N_HIDDEN) -> activation
        -> N_LAYERS x [Linear(N_HIDDEN, N_HIDDEN) -> activation]
        -> Linear(N_HIDDEN, N_OUTPUT)          (no output activation)

    Args:
        N_INPUT: Size of the last dimension of the input tensor.
        N_OUTPUT: Size of the last dimension of the output tensor.
        N_HIDDEN: Width of every hidden layer.
        N_LAYERS: Number of hidden->hidden blocks after the input layer
            (0 is allowed: the input layer then feeds the output layer directly).
        activation: Activation *class* (not instance), instantiated once per
            layer. Defaults to ``nn.Tanh``, the previously hard-coded choice.
    """

    def __init__(self, N_INPUT, N_OUTPUT, N_HIDDEN, N_LAYERS, activation=nn.Tanh):
        super().__init__()
        # NOTE: state_dict keys derive from the attribute names fcs/fch/fce;
        # keep them stable so previously saved weights still load.
        self.fcs = nn.Sequential(
            nn.Linear(N_INPUT, N_HIDDEN),
            activation(),
        )
        self.fch = nn.Sequential(*[
            nn.Sequential(
                nn.Linear(N_HIDDEN, N_HIDDEN),
                activation(),
            )
            for _ in range(N_LAYERS)
        ])
        self.fce = nn.Linear(N_HIDDEN, N_OUTPUT)

    def forward_aperiodic(self, x):
        """Evaluate the MLP on ``x`` as-is (no input transformation).

        Args:
            x: Tensor whose last dimension is ``N_INPUT``.

        Returns:
            Tensor with the same leading dimensions and last dimension ``N_OUTPUT``.
        """
        hidden = self.fcs(x)
        hidden = self.fch(hidden)
        return self.fce(hidden)

    def forward(self, x):
        """Default forward pass; delegates to :meth:`forward_aperiodic`."""
        return self.forward_aperiodic(x)
def validate_mode(n, periodicity, parity):
    """Reject inconsistent (n, periodicity, parity) combinations from the config.

    Only the combinations below are checked; other ``periodicity``/``parity``
    strings are passed through unchanged.

    Args:
        n: Mode index read from the config.
        periodicity: ``"periodic"`` or ``"antiperiodic"``.
        parity: ``"even"`` or ``"odd"``.

    Raises:
        ValueError: If ``n`` is odd with ``"periodic"``, even with
            ``"antiperiodic"``, or zero with ``"odd"`` parity.
    """
    # The original code used raise("..."), which raises TypeError in Python 3
    # ("exceptions must derive from BaseException") and hides the real reason.
    if periodicity == "periodic" and n % 2 != 0:
        raise ValueError(f"Invalid parameters: periodic requires an even n (got n={n})")
    if periodicity == "antiperiodic" and n % 2 == 0:
        raise ValueError(f"Invalid parameters: antiperiodic requires an odd n (got n={n})")
    if parity == "odd" and n == 0:
        raise ValueError("Invalid parameters: odd parity requires n != 0")


def build_q_pairs(q_start, q_end, q_step):
    """Return the (q_prev, q_next) steps of the adiabatic q sweep.

    If ``q_start`` is 0, the list starts with ``(0.0, 0.0)`` followed by
    ``(0, q_step)``. Otherwise the first pair is ``(q_start - q_step, q_start)``.
    Each following pair advances by ``q_step``. Pairs are produced while their
    first element lies below ``q_end``, mirroring ``np.arange(q1, q_end, q_step)``,
    but with a small tolerance so round-off cannot append a pair past ``q_end``.

    Args:
        q_start: First target q value.
        q_end: Exclusive upper bound for the first element of each pair.
        q_step: Positive increment between consecutive q values.

    Returns:
        List of ``(float, float)`` tuples.

    Raises:
        ValueError: If ``q_step`` is not positive.
    """
    if q_step <= 0:
        raise ValueError(f"q_step must be positive (got {q_step})")
    q1 = q_start if q_start == 0 else q_start - q_step
    # np.arange with a float step can return one element too many when
    # (q_end - q1) / q_step is an integer spoiled by round-off
    # (np.arange(1.0, 1.3, 0.1) ends at 1.3), so count the steps explicitly.
    n_steps = max(int(np.ceil((q_end - q1) / q_step - 1e-9)), 0)
    q_pairs = [(float(q), float(q + q_step)) for q in q1 + q_step * np.arange(n_steps)]
    if q_start == 0.0:
        q_pairs = [(0.0, 0.0)] + q_pairs
    return q_pairs


if __name__ == "__main__":
    # Usage: python main.py -f input_file.yaml
    parser = argparse.ArgumentParser(description="Read values from a YAML config file.")
    # required=True: a missing -f otherwise reaches open(None) and fails with
    # a TypeError instead of argparse's usage message.
    parser.add_argument("-f", "--file", required=True, help="Path to the YAML configuration file")
    args = parser.parse_args()
    with open(args.file, "r", encoding="utf-8") as fh:
        run_info = yaml.safe_load(fh)

    # Mode selection; validated before any model/tensor work so a bad config fails fast.
    n = run_info["n"]
    periodicity = run_info["periodicity"]
    parity = run_info["parity"]
    validate_mode(n, periodicity, parity)

    # Network maps a scalar x to a scalar output.
    N_INPUT = 1
    N_OUTPUT = 1
    N_LAYERS = run_info["N_LAYERS"]  # e.g. 2
    N_HIDDEN = run_info["N_HIDDEN"]  # e.g. 128 or 256
    torch.manual_seed(seed=82354)  # fixed seed -> reproducible weight initialisation
    model = FCN(N_INPUT=N_INPUT, N_OUTPUT=N_OUTPUT, N_HIDDEN=N_HIDDEN, N_LAYERS=N_LAYERS).to(device)
    print(f"# Creating model on {device}")
    print(f"n={n}, parity={parity}, periodicity={periodicity}")

    # Collocation points on [-pi, pi].
    N_phys = run_info["N_phys"]  # e.g. 500
    x_left, x_right = -np.pi, np.pi
    x_np = np.linspace(x_left, x_right, N_phys)
    # Two half-interval grids, [-pi, 0] and [0, pi]: shape (2, N_phys), then
    # (2, N_phys, 1) after unsqueeze. Presumably used by the trainer for the
    # periodicity condition -- confirm in trainer.AdiabaticMathieu.
    x_np_period = np.array([np.linspace(-np.pi, 0, N_phys), np.linspace(0, np.pi, N_phys)])
    # Build on `device` first, then enable grad: calling .to(device) after
    # requires_grad_() returns a non-leaf copy when device is a GPU.
    x_phys = torch.tensor(x_np, dtype=torch.float32, device=device).unsqueeze(dim=-1).requires_grad_(True)
    x_period = torch.tensor(x_np_period, dtype=torch.float32, device=device).unsqueeze(dim=-1).requires_grad_(True)

    # Adiabatic sweep in q and optimiser settings.
    q_step = run_info["q_step"]  # e.g. 0.05
    q_start = run_info["q_start"]
    q_end = run_info["q_end"]
    N_epochs = run_info["N_epochs"]
    patience = run_info["patience"]
    lr_start = run_info["lr_start"]
    q_pairs = build_q_pairs(q_start, q_end, q_step)

    AM = AdiabaticMathieu(
        model=model,
        n=n, periodicity=periodicity, parity=parity,
        q_pairs=q_pairs,
        x_phys=x_phys,
        x_left=x_left,
        x_right=x_right,
        x_period=x_period,
        N_epochs=N_epochs,
        patience=patience,
        lr_start=lr_start,
        aux_dir=f"./models-n{n}_{parity}_{periodicity}/",
    )
    AM.training_loop()
    print("Done!")