Skip to content

Commit 3be540e

Browse files
committed
Updates from deployable-rl
1 parent 280b358, commit 3be540e

File tree

6 files changed

+9
-23
lines changed

6 files changed

+9
-23
lines changed

poetry.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ dm-env = "^1.6"
2222
distrax = "^0.1.5"
2323
pillow = "^10.2.0"
2424
moviepy = "^1.0.3"
25-
safe-adaptation-gym = {git = "https://git@github.com/lasgroup/safe-adaptation-gym.git"}
25+
safe-adaptation-gym = {git = "ssh://git@github.com/lasgroup/safe-adaptation-gym"}
2626
jmp = {git = "https://github.com/deepmind/jmp"}
2727
tensorboard = "^2.16.2"
2828

safe_opax/configs/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ training:
4242
episodes_per_epoch: 5
4343
epochs: 200
4444
action_repeat: 1
45-
render_episodes: 1
45+
render_episodes: 0
4646
parallel_envs: 10
4747
scale_reward: 1.
4848
exploration_steps: 5000

safe_opax/configs/experiment/safety_gym_doggo.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ training:
66
epochs: 200
77
safe: true
88
action_repeat: 2
9-
episodes_per_epoch: 10
9+
episodes_per_epoch: 5
1010

1111
environment:
1212
safe_adaptation_gym:

safe_opax/la_mbda/actor_critic.py

Lines changed: 2 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -9,18 +9,6 @@
99
from safe_opax.rl.utils import rl_initialize_weights_trick
1010

1111

12-
class StableTanh(trx.Tanh):
13-
def inverse_and_log_det(self, y):
14-
dtype = y.dtype
15-
y = y.astype(jnp.float32)
16-
# Clip to avoid computing very large gradients outside of
17-
# the given range.
18-
y = jnp.clip(y, -0.99999997, 0.99999997)
19-
x = jnp.arctanh(y)
20-
x = x.astype(dtype)
21-
return x, -self.forward_log_det_jacobian(x)
22-
23-
2412
class ContinuousActor(eqx.Module):
2513
net: eqx.nn.MLP
2614
init_stddev: float = eqx.static_field()
@@ -55,9 +43,8 @@ def __call__(self, state: jax.Array) -> trx.Transformed:
5543
init_std = inv_softplus(self.init_stddev)
5644
stddev = jnn.softplus(stddev + init_std) + 1e-4
5745
mu = 5.0 * jnn.tanh(mu / 5.0)
58-
dist = trx.MultivariateNormalDiag(mu, stddev)
59-
bijector = trx.Block(StableTanh(), 1)
60-
dist = trx.Transformed(dist, bijector)
46+
dist = trx.Normal(mu, stddev)
47+
dist = trx.Transformed(dist, trx.Tanh())
6148
return dist
6249

6350
def act(

safe_opax/la_mbda/safe_actor_critic.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -196,8 +196,7 @@ def evaluate_actor(
196196
objective_sentiment: Sentiment,
197197
constraint_sentiment: Sentiment,
198198
) -> ActorEvaluation:
199-
keys = jnp.asarray(jax.random.split(key, initial_states.shape[0]))
200-
trajectories, priors = rollout_fn(horizon, initial_states, keys, actor.act)
199+
trajectories, priors = rollout_fn(horizon, initial_states, key, actor.act)
201200
next_step = lambda x: x[:, 1:]
202201
current_step = lambda x: x[:, :-1]
203202
next_states = next_step(trajectories.next_state)

0 commit comments

Comments
 (0)