# Later, we will see how the target parameters should be updated in TorchRL.
#

-from tensordict.nn import TensorDictModule
+from tensordict.nn import TensorDictModule, TensorDictSequential


def _init(
@@ -290,12 +290,11 @@ def _loss_actor(
) -> torch.Tensor:
    td_copy = tensordict.select(*self.actor_in_keys)
    # Get an action from the actor network: since we made it functional, we need to pass the params
-    td_copy = self.actor_network(td_copy, params=self.actor_network_params)
+    with self.actor_network_params.to_module(self.actor_network):
+        td_copy = self.actor_network(td_copy)
    # get the value associated with that action
-    td_copy = self.value_network(
-        td_copy,
-        params=self.value_network_params.detach(),
-    )
+    with self.value_network_params.detach().to_module(self.value_network):
+        td_copy = self.value_network(td_copy)
    return -td_copy.get("state_action_value")

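For reference, the pattern introduced here is the `tensordict` functional API: a parameter `TensorDict` can be temporarily loaded into a module with `to_module` used as a context manager, and the original parameters are restored on exit. A minimal sketch with an illustrative stand-in network (not the tutorial's classes):

```python
import torch
from torch import nn
from tensordict import TensorDict

# Illustrative stand-in for the value network.
value_net = nn.Linear(2, 1)

# Snapshot the module's parameters as a TensorDict mirroring its structure.
params = TensorDict.from_module(value_net)

x = torch.randn(3, 2)
# Loading a detached copy for the duration of the block means the forward pass
# does not track gradients through these parameters; the original parameters
# are put back when the block exits.
with params.detach().to_module(value_net):
    out = value_net(x)
```

This mirrors the actor loss above, where the value network runs under detached parameters so that only the actor receives gradients.
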
@@ -317,7 +316,8 @@ def _loss_value(
    td_copy = tensordict.clone()

    # V(s, a)
-    self.value_network(td_copy, params=self.value_network_params)
+    with self.value_network_params.to_module(self.value_network):
+        self.value_network(td_copy)
    pred_val = td_copy.get("state_action_value").squeeze(-1)

    # we manually reconstruct the parameters of the actor-critic, where the first
@@ -332,9 +332,8 @@ def _loss_value(
        batch_size=self.target_actor_network_params.batch_size,
        device=self.target_actor_network_params.device,
    )
-    target_value = self.value_estimator.value_estimate(
-        tensordict, target_params=target_params
-    ).squeeze(-1)
+    with target_params.to_module(self.actor_critic):
+        target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)

    # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function`
    loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function)
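The same mechanism works for composed modules: a single parameter `TensorDict` covering the whole actor-critic is swapped in while the target value is computed, instead of passing `target_params` to `value_estimate`. A small sketch with illustrative stand-ins (the tutorial builds `target_params` from the target actor and target value parameters; the modules and keys below are assumptions for the example):

```python
import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

# Illustrative actor/value stand-ins composed the way the tutorial composes actor_critic.
actor = TensorDictModule(nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"])
qvalue = TensorDictModule(nn.Linear(2, 1), in_keys=["action"], out_keys=["state_action_value"])
actor_critic = TensorDictSequential(actor, qvalue)

# A detached snapshot of every parameter plays the role of the target parameters.
target_params = TensorDict.from_module(actor_critic).detach().clone()

td = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
# Temporarily load the target parameters into the composed module, then run it.
with target_params.to_module(actor_critic):
    target_value = actor_critic(td).get("state_action_value").squeeze(-1)
```
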
@@ -717,7 +716,7 @@ def get_env_stats():
    ActorCriticWrapper,
    DdpgMlpActor,
    DdpgMlpQNet,
-    OrnsteinUhlenbeckProcessWrapper,
+    OrnsteinUhlenbeckProcessModule,
    ProbabilisticActor,
    TanhDelta,
    ValueOperator,
@@ -776,15 +775,18 @@ def make_ddpg_actor(
# Exploration
# ~~~~~~~~~~~
#
-# The policy is wrapped in a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper`
+# The policy is passed into a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule`
# exploration module, as suggested in the original paper.
# Let's define the number of frames before OU noise reaches its minimum value
annealing_frames = 1_000_000

-actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
+actor_model_explore = TensorDictSequential(
    actor,
-    annealing_num_steps=annealing_frames,
-).to(device)
+    OrnsteinUhlenbeckProcessModule(
+        spec=actor.spec.clone(),
+        annealing_num_steps=annealing_frames,
+    ).to(device),
+)
if device == torch.device("cpu"):
    actor_model_explore.share_memory()

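For context on the new construction: rather than wrapping the policy, the exploration logic is now a separate stage chained after the actor inside a `TensorDictSequential`, reading and rewriting the `"action"` entry. A minimal sketch of that composition using an illustrative additive-noise stage (the tutorial's real stage is `OrnsteinUhlenbeckProcessModule`, which additionally takes the action spec and annealing settings):

```python
import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential


class AdditiveNoise(nn.Module):
    """Illustrative exploration stage: perturbs the action it receives."""

    def __init__(self, sigma: float = 0.1):
        super().__init__()
        self.sigma = sigma

    def forward(self, action: torch.Tensor) -> torch.Tensor:
        return action + self.sigma * torch.randn_like(action)


# A toy actor writing "action", followed by the exploration stage rewriting it.
actor = TensorDictModule(nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"])
explore = TensorDictModule(AdditiveNoise(), in_keys=["action"], out_keys=["action"])
actor_explore = TensorDictSequential(actor, explore)

td = actor_explore(TensorDict({"observation": torch.randn(5, 4)}, batch_size=[5]))
```
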
@@ -1168,7 +1170,7 @@ def ceil_div(x, y):
    )

    # update the exploration strategy
-    actor_model_explore.step(current_frames)
+    actor_model_explore[1].step(current_frames)

collector.shutdown()
del collector
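
Because the exploration policy is now a `TensorDictSequential`, the noise module sits at index 1 and its annealing has to be stepped explicitly. A short sketch of the indexing behaviour with illustrative stand-in stages (in the tutorial, `actor_model_explore[1]` is the `OrnsteinUhlenbeckProcessModule` and `step` advances its noise annealing):

```python
import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

# Illustrative stand-ins: a "policy" stage followed by a post-processing stage.
policy = TensorDictModule(nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"])
squash = TensorDictModule(nn.Tanh(), in_keys=["action"], out_keys=["action"])
seq = TensorDictSequential(policy, squash)

seq(TensorDict({"observation": torch.randn(3, 4)}, batch_size=[3]))

# Integer indexing returns the sub-module at that position; this is how the
# training loop reaches the second stage, e.g. actor_model_explore[1].step(...).
second_stage = seq[1]
```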