Problem Description
The current PPO-with-LSTM script, ppo_atari_lstm.py, steps through the LSTM sequentially, i.e. each step of the sequence is processed individually:
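```python
# Iterate over time steps: one slice of `hidden` and one vector of done flags per step.
for h, d in zip(hidden, done):
    # Zero the LSTM state of every env whose episode ended before this step,
    # then advance the LSTM by a single time step.
    h, lstm_state = self.lstm(
        h.unsqueeze(0),
        (
            (1.0 - d).view(1, -1, 1) * lstm_state[0],
            (1.0 - d).view(1, -1, 1) * lstm_state[1],
        ),
    )
    new_hidden += [h]
```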
This is much slower than passing the entire sequence of observations to the LSTM in a single call:
```python
h, lstm_state = self.lstm(hidden, lstm_state)
```
In RL this usually cannot be done directly. The hidden state has to be reset whenever an episode ends, and a rollout sequence can contain several episode boundaries.
Other PPO implementations use a trick: a sequence containing several trajectories is split into several sequences that each contain only one trajectory. This is done by cutting the input sequence wherever a done flag occurs and padding the resulting pieces to a common length. This can be visualized as:
```
Original sequences: [ [a1, a2, a3, a4 | a5, a6],
                      [b1, b2 | b3, b4, b5 | b6] ]

Split sequences:    [ [a1, a2, a3, a4],
                      [a5, a6,  0,  0],
                      [b1, b2,  0,  0],
                      [b3, b4, b5,  0],
                      [b6,  0,  0,  0] ]
```
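The splitting step can be sketched in plain Python. This is a minimal illustration, not the implementation from any existing PPO codebase. It assumes `dones[e][t]` is True when step `t` of env `e` begins a new episode, which matches how `d` is used in the loop above: the state is zeroed before step `t` is processed. `split_and_pad`, `unsplit` and `PAD` are hypothetical names. Each piece is padded to the length of the longest piece, as in the diagram. The function also records `(env, start, length)` for every piece so that outputs can be scattered back to their original positions.

```python
from typing import List, Sequence, Tuple

PAD = 0  # hypothetical padding value, matching the 0s in the diagram


def split_and_pad(
    seqs: Sequence[Sequence], dones: Sequence[Sequence[bool]]
) -> Tuple[List[list], List[Tuple[int, int, int]]]:
    """Cut each env's sequence before every step that starts a new episode,
    then pad all pieces to the length of the longest one."""
    seq_len = len(seqs[0])
    pieces: List[list] = []
    index: List[Tuple[int, int, int]] = []  # (env, start, length) per piece
    for env, (seq, done) in enumerate(zip(seqs, dones)):
        start = 0
        for t in range(1, seq_len + 1):
            # Close the current piece at the end of the sequence or right
            # before a step that begins a new episode.
            if t == seq_len or done[t]:
                pieces.append(list(seq[start:t]))
                index.append((env, start, t - start))
                start = t
    max_len = max(len(p) for p in pieces)
    padded = [p + [PAD] * (max_len - len(p)) for p in pieces]
    return padded, index


def unsplit(pieces: List[list], index: List[Tuple[int, int, int]],
            num_envs: int, seq_len: int) -> List[list]:
    """Drop the padding and put every valid element back where it came from."""
    out = [[PAD] * seq_len for _ in range(num_envs)]
    for piece, (env, start, length) in zip(pieces, index):
        out[env][start:start + length] = piece[:length]
    return out


# The two sequences from the diagram; True marks the first step of a new episode.
seqs = [["a1", "a2", "a3", "a4", "a5", "a6"], ["b1", "b2", "b3", "b4", "b5", "b6"]]
dones = [[False, False, False, False, True, False],
         [False, False, True, False, False, True]]
padded, index = split_and_pad(seqs, dones)
```

With the inputs above, `padded` reproduces the "Split sequences" layout. `unsplit(padded, index, 2, 6)` recovers the original sequences. In a real implementation, `unsplit` would be applied to the LSTM outputs rather than to the inputs.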
With this trick, all of these sequences, across the whole batch, can be processed with a single call to the LSTM.
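The sketch below checks that the trick gives the same outputs as the step-by-step loop quoted earlier. To stay dependency-free, it replaces the LSTM with a toy scalar recurrence, h_t = 0.5 · h_{t-1} + x_t. That substitution is an assumption for illustration only. `DECAY`, `sequential` and `split_batched` are hypothetical names. The sketch also assumes that each env's first piece starts from the state stored at the beginning of the rollout, unless that piece itself begins a new episode, and that every later piece starts from zero. The batched loop advances all padded pieces in lockstep over time, playing the role of one `nn.LSTM` call on a padded batch.

```python
from typing import List

DECAY = 0.5  # toy recurrence h_t = DECAY * h_{t-1} + x_t, a stand-in for the LSTM


def sequential(xs: List[List[float]], dones: List[List[bool]],
               h0: List[float]) -> List[List[float]]:
    """Step-by-step processing with resets, mirroring the loop in the script."""
    out = []
    for env, (x, d) in enumerate(zip(xs, dones)):
        h, ys = h0[env], []
        for xt, dt in zip(x, d):
            h = (1.0 - dt) * h  # reset the state when a new episode starts
            h = DECAY * h + xt
            ys.append(h)
        out.append(ys)
    return out


def split_batched(xs: List[List[float]], dones: List[List[bool]],
                  h0: List[float]) -> List[List[float]]:
    """Split at episode starts, pad, run one batched pass, then unpad."""
    seq_len = len(xs[0])
    pieces = []  # (env, start, length, initial state)
    for env, d in enumerate(dones):
        starts = [0] + [t for t in range(1, seq_len) if d[t]]
        ends = starts[1:] + [seq_len]
        for s, e in zip(starts, ends):
            # Only an env's first piece may continue from the stored state.
            init = h0[env] if s == 0 and not d[0] else 0.0
            pieces.append((env, s, e - s, init))
    max_len = max(p[2] for p in pieces)
    batch = [xs[env][s:s + n] + [0.0] * (max_len - n) for env, s, n, _ in pieces]
    # The "single call": advance every padded piece in lockstep over time.
    h = [p[3] for p in pieces]
    outs: List[List[float]] = [[] for _ in pieces]
    for t in range(max_len):
        for i, row in enumerate(batch):
            h[i] = DECAY * h[i] + row[t]
            outs[i].append(h[i])
    # Padding only follows the valid steps, so discarding it is safe.
    result = [[0.0] * seq_len for _ in xs]
    for (env, s, n, _), ys in zip(pieces, outs):
        result[env][s:s + n] = ys[:n]
    return result


xs = [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [0.5, -1.0, 2.0, 0.0, 3.0, 1.5]]
dones = [[False, False, False, False, True, False],
         [False, False, True, False, False, True]]
h0 = [0.25, -2.0]
```

Because padding only ever comes after the valid steps, it cannot affect the outputs that are kept. With a real `nn.LSTM`, however, the padded tail would still change the final `(h, c)` returned for shorter rows. If that final state is needed, packing the batch with `torch.nn.utils.rnn.pack_padded_sequence` is one option, since the LSTM then stops each sequence at its true length.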
Proposal
I implemented a version of the script that uses this trick to process all sequences with one call. In my setup, training was 4x faster, at the cost of higher memory usage (about 2x in my setup). The final performance of the policy is similar to that of the original script.
Would you be interested in adding this script to the repo? Should I make a PR to create a new file using this trick?