 from itertools import count
 
 import numpy as np
-
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
 
-from lagom import set_global_seeds
+from lagom import Logger
+from lagom.utils import pickle_dump
+from lagom.utils import set_global_seeds
+
 from lagom import BaseAlgorithm
-from lagom import pickle_dump
 
 from lagom.envs import make_gym_env
 from lagom.envs import make_vec_env
 from lagom.envs import EnvSpec
 from lagom.envs.vec_env import SerialVecEnv
-from lagom.envs.vec_env import ParallelVecEnv
 from lagom.envs.vec_env import VecStandardize
 
-from lagom.core.policies import CategoricalPolicy
-from lagom.core.policies import GaussianPolicy
-
 from lagom.runner import TrajectoryRunner
 from lagom.runner import SegmentRunner
 
-from lagom.agents import A2CAgent
-
+from model import Agent
 from engine import Engine
-from policy import Network
-from policy import LSTM
 
 
 class Algorithm(BaseAlgorithm):
-    def __call__(self, config, seed, device_str):
+    def __call__(self, config, seed, device):
         set_global_seeds(seed)
-        device = torch.device(device_str)
         logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)
-
-        # Environment related
+
         env = make_vec_env(vec_env_class=SerialVecEnv,
                            make_env=make_gym_env,
                            env_id=config['env.id'],
                            num_env=config['train.N'],  # batched environment
-                           init_seed=seed,
-                           rolling=True)
+                           init_seed=seed)
         eval_env = make_vec_env(vec_env_class=SerialVecEnv,
                                 make_env=make_gym_env,
                                 env_id=config['env.id'],
                                 num_env=config['eval.N'],
-                                init_seed=seed,
-                                rolling=False)
+                                init_seed=seed)
         if config['env.standardize']:  # running averages of observation and reward
             env = VecStandardize(venv=env,
                                  use_obs=True,
@@ -60,102 +47,37 @@ def __call__(self, config, seed, device_str):
                                  clip_reward=10.,
                                  gamma=0.99,
                                  eps=1e-8)
-            eval_env = VecStandardize(venv=eval_env,  # remember to synchronize running averages during evaluation !!!
+            eval_env = VecStandardize(venv=eval_env,
                                       use_obs=True,
                                       use_reward=False,  # do not process rewards, no training
                                       clip_obs=env.clip_obs,
                                       clip_reward=env.clip_reward,
                                       gamma=env.gamma,
                                       eps=env.eps,
-                                      constant_obs_mean=env.obs_runningavg.mu,  # use current running average as constant
+                                      constant_obs_mean=env.obs_runningavg.mu,
                                       constant_obs_std=env.obs_runningavg.sigma)
         env_spec = EnvSpec(env)
+
+        agent = Agent(config, env_spec, device)
 
-        # Network and policy
-        if config['network.recurrent']:
-            network = LSTM(config=config, device=device, env_spec=env_spec)
-        else:
-            network = Network(config=config, device=device, env_spec=env_spec)
-        if env_spec.control_type == 'Discrete':
-            policy = CategoricalPolicy(config=config,
-                                       network=network,
-                                       env_spec=env_spec,
-                                       device=device,
-                                       learn_V=True)
-        elif env_spec.control_type == 'Continuous':
-            policy = GaussianPolicy(config=config,
-                                    network=network,
-                                    env_spec=env_spec,
-                                    device=device,
-                                    learn_V=True,
-                                    min_std=config['agent.min_std'],
-                                    std_style=config['agent.std_style'],
-                                    constant_std=config['agent.constant_std'],
-                                    std_state_dependent=config['agent.std_state_dependent'],
-                                    init_std=config['agent.init_std'])
-
-        # Optimizer and learning rate scheduler
-        optimizer = optim.Adam(policy.network.parameters(), lr=config['algo.lr'])
-        if config['algo.use_lr_scheduler']:
-            if 'train.iter' in config:  # iteration-based
-                max_epoch = config['train.iter']
-            elif 'train.timestep' in config:  # timestep-based
-                max_epoch = config['train.timestep'] + 1  # avoid zero lr in final iteration
-            lambda_f = lambda epoch: 1 - epoch / max_epoch
-            lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_f)
-
-        # Agent
-        kwargs = {'device': device}
-        if config['algo.use_lr_scheduler']:
-            kwargs['lr_scheduler'] = lr_scheduler
-        agent = A2CAgent(config=config,
-                         policy=policy,
-                         optimizer=optimizer,
-                         **kwargs)
-
-        # Runner
-        runner = SegmentRunner(agent=agent,
-                               env=env,
-                               gamma=config['algo.gamma'])
-        eval_runner = TrajectoryRunner(agent=agent,
-                                       env=eval_env,
-                                       gamma=1.0)
+        runner = SegmentRunner(config, agent, env)
+        eval_runner = TrajectoryRunner(config, agent, eval_env)
 
-        # Engine
-        engine = Engine(agent=agent,
-                        runner=runner,
-                        config=config,
-                        eval_runner=eval_runner)
+        engine = Engine(agent, runner, config, eval_runner=eval_runner)
 
-        # Training and evaluation
         train_logs = []
         eval_logs = []
-
-        if config['network.recurrent']:
-            rnn_states_buffer = agent.policy.rnn_states  # for SegmentRunner
-
         for i in count():
             if 'train.iter' in config and i >= config['train.iter']:  # enough iterations
                 break
             elif 'train.timestep' in config and agent.total_T >= config['train.timestep']:  # enough timesteps
                 break
 
-            if config['network.recurrent']:
-                if isinstance(rnn_states_buffer, list):  # LSTM: [h, c]
-                    rnn_states_buffer = [buf.detach() for buf in rnn_states_buffer]
-                else:
-                    rnn_states_buffer = rnn_states_buffer.detach()
-                agent.policy.rnn_states = rnn_states_buffer
-
-            train_output = engine.train(n=i)
+            train_output = engine.train(i)
 
-            # Logging
             if i == 0 or (i + 1) % config['log.record_interval'] == 0 or (i + 1) % config['log.print_interval'] == 0:
                 train_log = engine.log_train(train_output)
 
-                if config['network.recurrent']:
-                    rnn_states_buffer = agent.policy.rnn_states  # for SegmentRunner
-
                 with torch.no_grad():  # disable grad, save memory
                     eval_output = engine.eval(n=i)
                     eval_log = engine.log_eval(eval_output)
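
The removed lines in the hunk above did the recurrent-state bookkeeping by hand: before each training iteration the LSTM's hidden and cell states were detached so that backpropagation never crossed a segment boundary (truncated BPTT); after this commit that responsibility presumably lives inside the new Agent. A minimal, self-contained sketch of the same pattern, where the network shapes and the squared-output loss are placeholders rather than anything from this repository:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16)
h = torch.zeros(1, 4, 16)  # [num_layers, batch, hidden], like the old rnn_states [h, c]
c = torch.zeros(1, 4, 16)

for step in range(3):
    x = torch.randn(5, 4, 8)   # [seq_len, batch, features], a stand-in segment
    out, (h, c) = lstm(x, (h, c))
    loss = out.pow(2).mean()   # placeholder for the real A2C loss
    lstm.zero_grad()
    loss.backward()
    # Detach, as the removed code did: the next segment reuses the state
    # values, but its graph no longer extends into this iteration.
    h, c = h.detach(), c.detach()

Without the detach, each loss.backward() would try to backpropagate through every earlier segment's graph, which grows memory without bound and fails outright once an earlier graph has been freed.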
@@ -164,7 +86,6 @@ def __call__(self, config, seed, device_str):
                 train_logs.append(train_log)
                 eval_logs.append(eval_log)
 
-        # Save all loggings
         pickle_dump(obj=train_logs, f=logdir / 'train_logs', ext='.pkl')
         pickle_dump(obj=eval_logs, f=logdir / 'eval_logs', ext='.pkl')
 
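
For reference, a minimal sketch of driving the refactored entry point. The config keys below are exactly the ones this file reads; the values, the environment id, and any extra keys consumed by Agent, the runners, or Engine are assumptions:

import torch

from algo import Algorithm  # assuming this module is saved as algo.py

config = {
    'ID': 0,
    'log.dir': 'logs',
    'env.id': 'CartPole-v1',    # assumed environment
    'env.standardize': False,   # skip the VecStandardize wrapping
    'train.N': 4,               # batched training environments
    'eval.N': 10,               # evaluation environments
    'train.iter': 100,          # stop after 100 iterations
    'log.record_interval': 10,
    'log.print_interval': 10,
}

algorithm = Algorithm()
algorithm(config, seed=0, device=torch.device('cpu'))

Note that the new signature takes a torch.device directly; the old code accepted a device string and called torch.device(device_str) itself.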