208208"""
209209Generate the next state and reward for the AST MDP (handles episodic reward problems). Overrides `gen` from the `POMDPs` interface.
210210"""
211- function POMDPs. gen (:: DDNOut , mdp:: ASTMDP , s:: ASTState , a:: ASTAction , rng:: AbstractRNG )
211+ function POMDPs. gen (mdp:: ASTMDP , s:: ASTState , a:: ASTAction , rng:: AbstractRNG )
212212 @assert mdp. sim_hash == s. hash
213213 mdp. t_index += 1
214214 set_global_seed (a. rsg)
@@ -279,7 +279,7 @@ function go_to_state(mdp::ASTMDP, target_state::ASTState)
279279 actions = action_trace (target_state)
280280 R = 0.0
281281 for a in actions
282- s, r = gen (DDNOut ( :sp , :r ), mdp, s, a, Random. GLOBAL_RNG)
282+ s, r = @ gen (:sp , :r )( mdp, s, a, Random. GLOBAL_RNG)
283283 R += r
284284 end
285285 @assert s == target_state
@@ -318,7 +318,7 @@ function rollout(mdp::ASTMDP, s::ASTState, d::Int64)
318318 else
319319 a:: ASTAction = random_action (mdp)
320320
321- (sp, r) = gen (DDNOut ( :sp , :r ), mdp, s, a, Random. GLOBAL_RNG)
321+ (sp, r) = @ gen (:sp , :r )( mdp, s, a, Random. GLOBAL_RNG)
322322 q_value = r + discount (mdp)* rollout (mdp, sp, d- 1 )
323323
324324 return q_value
@@ -367,7 +367,7 @@ function rollout_end(mdp::ASTMDP, s::ASTState, d::Int64; max_depth=-1, feed_gen=
367367 if feed_available && (start_of_rollout_feed || mid_rollout_feed)
368368 (sp, r) = feed_gen (mdp, s, a, Random. GLOBAL_RNG)
369369 else
370- (sp, r) = gen (DDNOut ( :sp , :r ), mdp, s, a, Random. GLOBAL_RNG)
370+ (sp, r) = @ gen (:sp , :r )( mdp, s, a, Random. GLOBAL_RNG)
371371 end
372372
373373 # Note, pass all keywords.
@@ -444,7 +444,7 @@ function playback(mdp::ASTMDP, actions::Vector{ASTAction}, func=nothing; verbose
444444 @show func (mdp. sim)
445445 end
446446 for a in actions
447- (sp, r) = gen (DDNOut ( :sp , :r ), mdp, s, a, rng)
447+ (sp, r) = @ gen (:sp , :r )( mdp, s, a, rng)
448448 s = sp
449449 if display_trace
450450 @show func (mdp. sim)
@@ -469,12 +469,12 @@ function online_path(mdp::MDP, planner::Policy; verbose::Bool=false)
469469 a = action (planner, s)
470470 actions = ASTAction[a]
471471 printstep (mdp, a)
472- (s, r) = gen (DDNOut ( :sp , :r ), mdp, s, a, Random. GLOBAL_RNG)
472+ (s, r) = @ gen (:sp , :r )( mdp, s, a, Random. GLOBAL_RNG)
473473
474474 while ! BlackBox. isterminal! (mdp. sim)
475475 a = action (planner, s)
476476 push! (actions, a)
477- (s, r) = gen (DDNOut ( :sp , :r ), mdp, s, a, Random. GLOBAL_RNG)
477+ (s, r) = @ gen (:sp , :r )( mdp, s, a, Random. GLOBAL_RNG)
478478 printstep (mdp, a)
479479 end
480480
0 commit comments