Lowering
SARSA_PARAMS["pos_states"] = 51
to
SARSA_PARAMS["pos_states"] = 31
causes assertions at:
https://github.com/tud-cor-sr/ics-pa-sv/blob/ca94a78cdb3c2cfdd9f3475e7d6869a6d83205ae/assignment/problem_3/task_3_rl.ipynb?short_path=28a759c#L952-L968
to fail. At that level of discretization, the test positions (pi/2 and its 0.9x / 1.1x neighbours) map to the same discrete position, so observe_reward() returns the same reward for them and the strict r_t > r_1 comparison cannot hold unless a hack is added to circumvent the assertion. Testing observe_reward() independently of discretize_state() would fix this.
Rendered code:
x0 = jnp.array([jnp.pi / 2, 0])
s = discretize_state(x0, SARSA_PARAMS)
r_t = observe_reward(aa, s, SARSA_PARAMS)
x1 = jnp.array([0.9 * jnp.pi / 2, 0])
s1 = discretize_state(x1, SARSA_PARAMS)
r_1 = observe_reward(aa, s1, SARSA_PARAMS)
assert r_t > r_1
x1 = jnp.array([1.1 * jnp.pi / 2, 0])
s1 = discretize_state(x1, SARSA_PARAMS)
r_1 = observe_reward(aa, s1, SARSA_PARAMS)
assert r_t > r_1
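One way to decouple the check from the resolution is to pick discrete states directly rather than deriving them from continuous offsets. The sketch below is only an illustration: toy_reward is a hypothetical stand-in for observe_reward() (whose real signature and state layout are not shown here), and the uniform grid over [-pi, pi] is an assumption.

```python
import math

LO, HI = -math.pi, math.pi  # assumed position range, not taken from the notebook

def bin_center(idx, n_bins):
    """Continuous position represented by discrete index idx."""
    step = (HI - LO) / (n_bins - 1)
    return LO + idx * step

def toy_reward(pos_idx, n_bins, target=math.pi / 2):
    """Hypothetical stand-in for observe_reward(): larger when closer to target."""
    return -abs(bin_center(pos_idx, n_bins) - target)

def reward_peaks_near_target(n_bins, target=math.pi / 2):
    """Compare the bin nearest the target with bins two steps away on each side.

    Working in index space keeps the compared states distinct for any n_bins,
    so the strict inequality does not depend on the discretization level.
    """
    step = (HI - LO) / (n_bins - 1)
    best = round((target - LO) / step)
    r_best = toy_reward(best, n_bins, target)
    r_below = toy_reward(best - 2, n_bins, target)
    r_above = toy_reward(best + 2, n_bins, target)
    return r_best > r_below and r_best > r_above

for n in (31, 51):
    print(n, reward_peaks_near_target(n))
```

The comparison uses bins two steps away because pi/2 may fall exactly halfway between two grid points, in which case the immediate neighbour can have an almost identical reward.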