Commit 92e1abc

Fix DQN with RNN tutorial

1 parent 9a44439 commit 92e1abc

1 file changed: +8 −3 lines

intermediate_source/dqn_with_rnn_tutorial.py

Lines changed: 8 additions & 3 deletions
@@ -342,7 +342,10 @@
 # will return a new instance of the LSTM (with shared weights) that will
 # assume that the input data is sequential in nature.
 #
-policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)
+from torchrl.modules import set_recurrent_mode
+
+with set_recurrent_mode(True):
+    policy = Seq(feature, lstm, mlp, qval)
 
 ######################################################################
 # Because we still have a couple of uninitialized parameters we should
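
For reference, the new code relies on TorchRL's set_recurrent_mode context manager rather than the older lstm.set_recurrent_mode(True) call, which returned a new LSTM instance with shared weights. The sketch below shows the pattern in isolation; the feature extractor, MLP and Q-value head are illustrative stand-ins, not the tutorial's exact definitions:

from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torch import nn
from torchrl.modules import LSTMModule, MLP, QValueModule, set_recurrent_mode

# Illustrative stand-ins for the tutorial's feature / mlp / qval modules.
feature = Mod(nn.LazyLinear(64), in_keys=["observation"], out_keys=["embed"])
lstm = LSTMModule(input_size=64, hidden_size=64, in_key="embed", out_key="embed")
mlp = Mod(MLP(out_features=2, num_cells=[64]), in_keys=["embed"], out_keys=["action_value"])
qval = QValueModule(action_space="one_hot")

# With recurrent mode enabled, the LSTM assumes its input data is sequential
# in nature (i.e. carries a time dimension) rather than being a single step.
with set_recurrent_mode(True):
    policy = Seq(feature, lstm, mlp, qval)
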
@@ -389,7 +392,9 @@
 # For the sake of efficiency, we're only running a few thousands iterations
 # here. In a real setting, the total number of frames should be set to 1M.
 #
-collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200, device=device)
+collector = SyncDataCollector(
+    env, stoch_policy, frames_per_batch=50, total_frames=200, device=device
+)
 rb = TensorDictReplayBuffer(
     storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10
 )
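
As a rough usage note (a generic TorchRL pattern, not part of this commit), the collector and the replay buffer defined here are typically consumed together in a loop like this:

# Generic collection loop (illustrative only): iterate the collector,
# store each batch, then sample from the buffer for the optimizer step.
for data in collector:
    rb.extend(data.cpu())
    sample = rb.sample()  # draws batch_size=4 elements, as configured above
    # ... loss computation and optimizer step would go here ...
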
@@ -464,5 +469,5 @@
 #
 # Further Reading
 # ---------------
-#
+#
 # - The TorchRL documentation can be found `here <https://pytorch.org/rl/>`_.
