Skip to content

Commit 9fa0dc6

Browse files
committed
POMDPs.jl v0.9 support
1 parent ec20805 commit 9fa0dc6

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "POMDPStressTesting"
22
uuid = "6fc570d8-62cd-4d35-b113-bbf3c1b8276a"
33
authors = ["Robert Moss <[email protected]>"]
4-
version = "0.2.0"
4+
version = "0.2.1"
55

66
[deps]
77
D3Trees = "e3df1716-f71e-5df9-9e2d-98e193103c45"
@@ -17,3 +17,7 @@ PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
1717
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1818
Seaborn = "d2ef9438-c967-53ab-8060-373fdd9e13eb"
1919
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
20+
21+
[compat]
22+
POMDPs = "0.9"
23+
julia = "1"

src/AST.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ end
208208
"""
209209
Generate next state and reward for AST MDP (handles episodic reward problems). Overridden from `POMDPs.gen` interface.
210210
"""
211-
function POMDPs.gen(::DDNOut, mdp::ASTMDP, s::ASTState, a::ASTAction, rng::AbstractRNG)
211+
function POMDPs.gen(mdp::ASTMDP, s::ASTState, a::ASTAction, rng::AbstractRNG)
212212
@assert mdp.sim_hash == s.hash
213213
mdp.t_index += 1
214214
set_global_seed(a.rsg)
@@ -279,7 +279,7 @@ function go_to_state(mdp::ASTMDP, target_state::ASTState)
279279
actions = action_trace(target_state)
280280
R = 0.0
281281
for a in actions
282-
s, r = gen(DDNOut(:sp, :r), mdp, s, a, Random.GLOBAL_RNG)
282+
s, r = @gen(:sp, :r)(mdp, s, a, Random.GLOBAL_RNG)
283283
R += r
284284
end
285285
@assert s == target_state
@@ -318,7 +318,7 @@ function rollout(mdp::ASTMDP, s::ASTState, d::Int64)
318318
else
319319
a::ASTAction = random_action(mdp)
320320

321-
(sp, r) = gen(DDNOut(:sp, :r), mdp, s, a, Random.GLOBAL_RNG)
321+
(sp, r) = @gen(:sp, :r)(mdp, s, a, Random.GLOBAL_RNG)
322322
q_value = r + discount(mdp)*rollout(mdp, sp, d-1)
323323

324324
return q_value
@@ -367,7 +367,7 @@ function rollout_end(mdp::ASTMDP, s::ASTState, d::Int64; max_depth=-1, feed_gen=
367367
if feed_available && (start_of_rollout_feed || mid_rollout_feed)
368368
(sp, r) = feed_gen(mdp, s, a, Random.GLOBAL_RNG)
369369
else
370-
(sp, r) = gen(DDNOut(:sp, :r), mdp, s, a, Random.GLOBAL_RNG)
370+
(sp, r) = @gen(:sp, :r)(mdp, s, a, Random.GLOBAL_RNG)
371371
end
372372

373373
# Note, pass all keywords.
@@ -444,7 +444,7 @@ function playback(mdp::ASTMDP, actions::Vector{ASTAction}, func=nothing; verbose
444444
@show func(mdp.sim)
445445
end
446446
for a in actions
447-
(sp, r) = gen(DDNOut(:sp, :r), mdp, s, a, rng)
447+
(sp, r) = @gen(:sp, :r)(mdp, s, a, rng)
448448
s = sp
449449
if display_trace
450450
@show func(mdp.sim)
@@ -469,12 +469,12 @@ function online_path(mdp::MDP, planner::Policy; verbose::Bool=false)
469469
a = action(planner, s)
470470
actions = ASTAction[a]
471471
printstep(mdp, a)
472-
(s, r) = gen(DDNOut(:sp, :r), mdp, s, a, Random.GLOBAL_RNG)
472+
(s, r) = @gen(:sp, :r)(mdp, s, a, Random.GLOBAL_RNG)
473473

474474
while !BlackBox.isterminal!(mdp.sim)
475475
a = action(planner, s)
476476
push!(actions, a)
477-
(s, r) = gen(DDNOut(:sp, :r), mdp, s, a, Random.GLOBAL_RNG)
477+
(s, r) = @gen(:sp, :r)(mdp, s, a, Random.GLOBAL_RNG)
478478
printstep(mdp, a)
479479
end
480480

0 commit comments

Comments
 (0)