
Commit 24cd620

Update
Former-commit-id: b73f9f094298c17cf8dbf4d921d666ed3df6b451 [formerly d0e5b23b985d3ede488cc0866ab16796bea317ba] Former-commit-id: 1e1166290a20adfe31d678533f7022141ecc7af0
1 parent 54c9f0d commit 24cd620

24 files changed: +1116 −758 lines

docs/source/history.rst
Lines changed: 17 additions & 0 deletions

@@ -15,3 +15,20 @@ lagom.history: History
 
 .. autoclass:: Segment
     :members:
+
+Metrics
+----------------
+
+.. currentmodule:: lagom.history.metrics
+
+.. autofunction:: terminal_state_from_trajectory
+
+.. autofunction:: terminal_state_from_segment
+
+.. autofunction:: final_state_from_trajectory
+
+.. autofunction:: final_state_from_segment
+
+.. autofunction:: bootstrapped_returns_from_trajectory
+
+.. autofunction:: bootstrapped_returns_from_segment
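The metrics documented above operate on collected Trajectory and Segment objects. As a conceptual sketch only (the exact signatures in lagom.history.metrics are not shown in this commit), a bootstrapped return replaces the unobserved tail of a truncated rollout with a value estimate of the final state:

    # Conceptual sketch, not lagom's implementation; the real
    # bootstrapped_returns_from_trajectory/segment signatures are not part of this diff.
    def bootstrapped_returns(rewards, last_V, gamma, done):
        """Discounted returns where the tail beyond the final step is
        approximated by V(s_T), unless the episode actually terminated."""
        R = 0.0 if done else last_V  # bootstrap only for truncated rollouts
        out = []
        for r in reversed(rewards):
            R = r + gamma*R
            out.append(R)
        return list(reversed(out))

    # e.g. bootstrapped_returns([1.0, 1.0, 1.0], last_V=2.0, gamma=0.99, done=False)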

examples/policy_gradient/a2c/algo.py
Lines changed: 17 additions & 96 deletions

@@ -3,55 +3,42 @@
 from itertools import count
 
 import numpy as np
-
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
 
-from lagom import set_global_seeds
+from lagom import Logger
+from lagom.utils import pickle_dump
+from lagom.utils import set_global_seeds
+
 from lagom import BaseAlgorithm
-from lagom import pickle_dump
 
 from lagom.envs import make_gym_env
 from lagom.envs import make_vec_env
 from lagom.envs import EnvSpec
 from lagom.envs.vec_env import SerialVecEnv
-from lagom.envs.vec_env import ParallelVecEnv
 from lagom.envs.vec_env import VecStandardize
 
-from lagom.core.policies import CategoricalPolicy
-from lagom.core.policies import GaussianPolicy
-
 from lagom.runner import TrajectoryRunner
 from lagom.runner import SegmentRunner
 
-from lagom.agents import A2CAgent
-
+from model import Agent
 from engine import Engine
-from policy import Network
-from policy import LSTM
 
 
 class Algorithm(BaseAlgorithm):
-    def __call__(self, config, seed, device_str):
+    def __call__(self, config, seed, device):
        set_global_seeds(seed)
-        device = torch.device(device_str)
        logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)
-
-        # Environment related
+
        env = make_vec_env(vec_env_class=SerialVecEnv,
                           make_env=make_gym_env,
                           env_id=config['env.id'],
                           num_env=config['train.N'],  # batched environment
-                           init_seed=seed,
-                           rolling=True)
+                           init_seed=seed)
        eval_env = make_vec_env(vec_env_class=SerialVecEnv,
                                make_env=make_gym_env,
                                env_id=config['env.id'],
                                num_env=config['eval.N'],
-                                init_seed=seed,
-                                rolling=False)
+                                init_seed=seed)
        if config['env.standardize']:  # running averages of observation and reward
            env = VecStandardize(venv=env,
                                 use_obs=True,
@@ -60,102 +47,37 @@ def __call__(self, config, seed, device_str):
                                 clip_reward=10.,
                                 gamma=0.99,
                                 eps=1e-8)
-            eval_env = VecStandardize(venv=eval_env,  # remember to synchronize running averages during evaluation !!!
+            eval_env = VecStandardize(venv=eval_env,
                                      use_obs=True,
                                      use_reward=False,  # do not process rewards, no training
                                      clip_obs=env.clip_obs,
                                      clip_reward=env.clip_reward,
                                      gamma=env.gamma,
                                      eps=env.eps,
-                                      constant_obs_mean=env.obs_runningavg.mu,  # use current running average as constant
+                                      constant_obs_mean=env.obs_runningavg.mu,
                                      constant_obs_std=env.obs_runningavg.sigma)
        env_spec = EnvSpec(env)
+
+        agent = Agent(config, env_spec, device)
 
-        # Network and policy
-        if config['network.recurrent']:
-            network = LSTM(config=config, device=device, env_spec=env_spec)
-        else:
-            network = Network(config=config, device=device, env_spec=env_spec)
-        if env_spec.control_type == 'Discrete':
-            policy = CategoricalPolicy(config=config,
-                                       network=network,
-                                       env_spec=env_spec,
-                                       device=device,
-                                       learn_V=True)
-        elif env_spec.control_type == 'Continuous':
-            policy = GaussianPolicy(config=config,
-                                    network=network,
-                                    env_spec=env_spec,
-                                    device=device,
-                                    learn_V=True,
-                                    min_std=config['agent.min_std'],
-                                    std_style=config['agent.std_style'],
-                                    constant_std=config['agent.constant_std'],
-                                    std_state_dependent=config['agent.std_state_dependent'],
-                                    init_std=config['agent.init_std'])
-
-        # Optimizer and learning rate scheduler
-        optimizer = optim.Adam(policy.network.parameters(), lr=config['algo.lr'])
-        if config['algo.use_lr_scheduler']:
-            if 'train.iter' in config:  # iteration-based
-                max_epoch = config['train.iter']
-            elif 'train.timestep' in config:  # timestep-based
-                max_epoch = config['train.timestep'] + 1  # avoid zero lr in final iteration
-            lambda_f = lambda epoch: 1 - epoch/max_epoch
-            lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_f)
-
-        # Agent
-        kwargs = {'device': device}
-        if config['algo.use_lr_scheduler']:
-            kwargs['lr_scheduler'] = lr_scheduler
-        agent = A2CAgent(config=config,
-                         policy=policy,
-                         optimizer=optimizer,
-                         **kwargs)
-
-        # Runner
-        runner = SegmentRunner(agent=agent,
-                               env=env,
-                               gamma=config['algo.gamma'])
-        eval_runner = TrajectoryRunner(agent=agent,
-                                       env=eval_env,
-                                       gamma=1.0)
+        runner = SegmentRunner(config, agent, env)
+        eval_runner = TrajectoryRunner(config, agent, eval_env)
 
-        # Engine
-        engine = Engine(agent=agent,
-                        runner=runner,
-                        config=config,
-                        eval_runner=eval_runner)
+        engine = Engine(agent, runner, config, eval_runner=eval_runner)
 
-        # Training and evaluation
        train_logs = []
        eval_logs = []
-
-        if config['network.recurrent']:
-            rnn_states_buffer = agent.policy.rnn_states  # for SegmentRunner
-
        for i in count():
            if 'train.iter' in config and i >= config['train.iter']:  # enough iterations
                break
            elif 'train.timestep' in config and agent.total_T >= config['train.timestep']:  # enough timesteps
                break
 
-            if config['network.recurrent']:
-                if isinstance(rnn_states_buffer, list):  # LSTM: [h, c]
-                    rnn_states_buffer = [buf.detach() for buf in rnn_states_buffer]
-                else:
-                    rnn_states_buffer = rnn_states_buffer.detach()
-                agent.policy.rnn_states = rnn_states_buffer
-
-            train_output = engine.train(n=i)
+            train_output = engine.train(i)
 
-            # Logging
            if i == 0 or (i+1) % config['log.record_interval'] == 0 or (i+1) % config['log.print_interval'] == 0:
                train_log = engine.log_train(train_output)
 
-                if config['network.recurrent']:
-                    rnn_states_buffer = agent.policy.rnn_states  # for SegmentRunner
-
                with torch.no_grad():  # disable grad, save memory
                    eval_output = engine.eval(n=i)
                    eval_log = engine.log_eval(eval_output)
@@ -164,7 +86,6 @@ def __call__(self, config, seed, device_str):
                train_logs.append(train_log)
                eval_logs.append(eval_log)
 
-        # Save all loggings
        pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl')
        pickle_dump(obj=eval_logs, f=logdir/'eval_logs', ext='.pkl')
 
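The rewritten entry point delegates network, policy, and optimizer construction to model.Agent, whose source is not part of this excerpt. A hypothetical sketch of what such an agent might bundle, inferred only from the code removed above (an Adam optimizer plus an optional linearly decaying LambdaLR schedule) and from the engine calling agent.train() and agent.eval():

    # Hypothetical sketch; the real Agent lives in model.py, which this commit view does not show.
    import torch.nn as nn
    import torch.optim as optim

    class Agent(nn.Module):  # nn.Module provides the .train()/.eval() switches the engine uses
        def __init__(self, config, env_spec, device, policy_network=None):
            super().__init__()
            self.config = config
            self.env_spec = env_spec
            self.device = device
            # Placeholder network; a real agent would build it from env_spec and config.
            self.policy = policy_network if policy_network is not None else nn.Linear(4, 2)
            self.optimizer = optim.Adam(self.policy.parameters(), lr=config['algo.lr'])
            self.lr_scheduler = None
            if config.get('algo.use_lr_scheduler', False):
                # Same linear decay as the removed code: multiplier 1 - epoch/max_epoch
                if 'train.iter' in config:
                    max_epoch = config['train.iter']
                else:
                    max_epoch = config['train.timestep'] + 1  # avoid zero lr in the final iteration
                self.lr_scheduler = optim.lr_scheduler.LambdaLR(
                    self.optimizer, lr_lambda=lambda epoch: 1 - epoch/max_epoch)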

examples/policy_gradient/a2c/engine.py
Lines changed: 32 additions & 42 deletions

@@ -3,7 +3,7 @@
 import torch
 
 from lagom import Logger
-from lagom import color_str
+from lagom.utils import color_str
 
 from lagom.engine import BaseEngine
 
@@ -15,12 +15,10 @@
 
 class Engine(BaseEngine):
     def train(self, n):
-        self.agent.policy.network.train()  # train mode
+        self.agent.train()
 
-        # Collect a list of Segment
        D = self.runner(T=self.config['train.T'])
 
-        # Train agent with collected data
        out_agent = self.agent.learn(D)
 
        train_output = {}
@@ -31,35 +29,31 @@ def train(self, n):
        return train_output
 
    def log_train(self, train_output, **kwargs):
-        # Unpack
        D = train_output['D']
        out_agent = train_output['out_agent']
        n = train_output['n']
 
-        # Loggings
-        logger = Logger(name='train_logger')
-        logger.log('train_iteration', n+1)  # starts from 1
-        if self.config['algo.use_lr_scheduler']:
-            logger.log('current_lr', out_agent['current_lr'])
-
-        logger.log('loss', out_agent['loss'])
-        logger.log('policy_loss', out_agent['policy_loss'])
-        logger.log('policy_entropy', -out_agent['entropy_loss'])  # entropy: negative entropy loss
-        logger.log('value_loss', out_agent['value_loss'])
+        logger = Logger()
+        logger('train_iteration', n+1)  # starts from 1
+        if 'current_lr' in out_agent:
+            logger('current_lr', out_agent['current_lr'])
+        logger('loss', out_agent['loss'])
+        logger('policy_loss', out_agent['policy_loss'])
+        logger('policy_entropy', -out_agent['entropy_loss'])
+        logger('value_loss', out_agent['value_loss'])
 
        all_immediate_reward = [segment.all_r for segment in D]
        num_timesteps = sum([segment.T for segment in D])
 
-        logger.log('num_segments', len(D))
-        logger.log('num_subsegments', sum([len(segment.trajectories) for segment in D]))
-        logger.log('num_timesteps', num_timesteps)
-        logger.log('accumulated_trained_timesteps', self.agent.total_T)
-        logger.log('average_immediate_reward', np.mean(all_immediate_reward))
-        logger.log('std_immediate_reward', np.std(all_immediate_reward))
-        logger.log('min_immediate_reward', np.min(all_immediate_reward))
-        logger.log('max_immediate_reward', np.max(all_immediate_reward))
+        logger('num_segments', len(D))
+        logger('num_subsegments', sum([len(segment.trajectories) for segment in D]))
+        logger('num_timesteps', num_timesteps)
+        logger('accumulated_trained_timesteps', self.agent.total_T)
+        logger('average_immediate_reward', np.mean(all_immediate_reward))
+        logger('std_immediate_reward', np.std(all_immediate_reward))
+        logger('min_immediate_reward', np.min(all_immediate_reward))
+        logger('max_immediate_reward', np.max(all_immediate_reward))
 
-        # Dump loggings
        if n == 0 or (n+1) % self.config['log.print_interval'] == 0:
            print('-'*50)
            logger.dump(keys=None, index=None, indent=0)
@@ -68,16 +62,15 @@ def log_train(self, train_output, **kwargs):
        return logger.logs
 
    def eval(self, n):
-        self.agent.policy.network.eval()  # evaluation mode
+        self.agent.eval()
 
        # Synchronize running average of observations for evaluation
        if self.config['env.standardize']:
            self.eval_runner.env.constant_obs_mean = self.runner.env.obs_runningavg.mu
            self.eval_runner.env.constant_obs_std = self.runner.env.obs_runningavg.sigma
 
-        # Collect a list of Trajectory
        T = self.eval_runner.env.T
-        D = self.eval_runner(T=T)
+        D = self.eval_runner(T)
 
        eval_output = {}
        eval_output['D'] = D
@@ -87,29 +80,26 @@ def eval(self, n):
        return eval_output
 
    def log_eval(self, eval_output, **kwargs):
-        # Unpack
        D = eval_output['D']
        n = eval_output['n']
        T = eval_output['T']
 
-        # Loggings
-        logger = Logger(name='eval_logger')
+        logger = Logger()
 
        batch_returns = [sum(trajectory.all_r) for trajectory in D]
        batch_T = [trajectory.T for trajectory in D]
 
-        logger.log('evaluation_iteration', n+1)
-        logger.log('num_trajectories', len(D))
-        logger.log('max_allowed_horizon', T)
-        logger.log('average_horizon', np.mean(batch_T))
-        logger.log('num_timesteps', np.sum(batch_T))
-        logger.log('accumulated_trained_timesteps', self.agent.total_T)
-        logger.log('average_return', np.mean(batch_returns))
-        logger.log('std_return', np.std(batch_returns))
-        logger.log('min_return', np.min(batch_returns))
-        logger.log('max_return', np.max(batch_returns))
-
-        # Dump loggings
+        logger('evaluation_iteration', n+1)
+        logger('num_trajectories', len(D))
+        logger('max_allowed_horizon', T)
+        logger('average_horizon', np.mean(batch_T))
+        logger('num_timesteps', np.sum(batch_T))
+        logger('accumulated_trained_timesteps', self.agent.total_T)
+        logger('average_return', np.mean(batch_returns))
+        logger('std_return', np.std(batch_returns))
+        logger('min_return', np.min(batch_returns))
+        logger('max_return', np.max(batch_returns))
+
        if n == 0 or (n+1) % self.config['log.print_interval'] == 0:
            print(color_str('+'*50, 'yellow', 'bold'))
            logger.dump(keys=None, index=None, indent=0)
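Throughout the engine, logger.log(key, value) becomes a direct call on the Logger instance and the name constructor argument is dropped. A minimal sketch of the callable interface these call sites rely on (an illustration under assumptions; lagom's actual Logger is not shown in this commit):

    # Minimal sketch of the interface used above, not lagom's actual Logger.
    from collections import OrderedDict

    class Logger:
        def __init__(self):
            self.logs = OrderedDict()  # engine.log_train() returns logger.logs

        def __call__(self, key, value):
            # Record a value under `key`; calling the instance replaces the old .log() method.
            self.logs.setdefault(key, []).append(value)

        def dump(self, keys=None, index=None, indent=0):
            keys = list(self.logs.keys()) if keys is None else keys
            for key in keys:
                values = self.logs[key] if index is None else self.logs[key][index]
                print('\t'*indent + f'{key}: {values}')

    # Usage mirroring the engine code:
    # logger = Logger()
    # logger('train_iteration', 1)
    # logger.dump(keys=None, index=None, indent=0)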
