
Commit dde4c59

Unsupervised (#25)
* Version update initial implementation unsupervised
* Mask out rewards
* Reload new agent when transitioning from unsupervised
* Take the mean
* Default flags
* Unsupervised in agent
* Change only replay buffer
* Learn reward and policy in unsupervised
* Update configs
* Scale up rewards
* Update exploration steps
* Initialize model weights better
1 parent 2a176dd commit dde4c59

10 files changed, +90 -95 lines


poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default.

safe_opax/configs/agent/la_mbda.yaml

Lines changed: 3 additions & 1 deletion
@@ -52,4 +52,6 @@ safety_slack: 0.
 evaluate_model: false
 exploration_strategy: uniform
 exploration_steps: 5000
-exploration_reward_scale: 10.0
+exploration_reward_scale: 10.0
+unsupervised: false
+reward_scale: 1.
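
The two new agent keys are read at run time through the OmegaConf/Hydra-style config object the repo already uses (`config.agent.reward_scale`, `config.agent.unsupervised`; see the `la_mbda.py` diff below). A minimal standalone sketch of that access pattern, assuming only the `omegaconf` dependency; the toy config holds just the defaults from this file, not the repo's full config tree:

```python
from omegaconf import OmegaConf

# Toy config containing only the keys touched by this diff.
config = OmegaConf.create(
    {"agent": {"unsupervised": False, "reward_scale": 1.0, "exploration_reward_scale": 10.0}}
)

# Replay-buffer rewards are multiplied by reward_scale before model/policy updates.
scaled_reward = 0.5 * config.agent.reward_scale
# When unsupervised is true, the reward head and task policy are not trained while exploring.
train_reward_head = not config.agent.unsupervised
print(scaled_reward, train_reward_head)
```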

safe_opax/configs/experiment/active_exploration.yaml

Lines changed: 0 additions & 16 deletions
This file was deleted.

safe_opax/configs/experiment/cartpole_sparse_hard.yaml

Lines changed: 0 additions & 12 deletions
This file was deleted.

safe_opax/configs/experiment/safety_gym_doggo_explore.yaml renamed to safe_opax/configs/experiment/safe_sparse_goal.yaml

Lines changed: 9 additions & 8 deletions
@@ -2,18 +2,19 @@
 defaults:
   - override /environment: safe_adaptation_gym
 
+environment:
+  safe_adaptation_gym:
+    task: go_to_goal_scarce
 
 training:
-  epochs: 200
+  epochs: 100
   safe: true
   action_repeat: 2
-  episodes_per_epoch: 5
-
-environment:
-  safe_adaptation_gym:
-    robot_name: doggo
-    task: collect
 
 agent:
   exploration_strategy: opax
-  exploration_steps: 1000000
+  exploration_steps: 850000
+  actor:
+    init_stddev: 0.025
+  sentiment:
+    model_initialization_scale: 0.05

safe_opax/configs/experiment/safety_gym_doggo.yaml

Lines changed: 0 additions & 18 deletions
This file was deleted.

safe_opax/configs/experiment/unsupervised.yaml

Lines changed: 8 additions & 1 deletion
@@ -4,11 +4,18 @@ defaults:
 
 training:
   trainer: unsupervised
-  epochs: 100
+  epochs: 200
   safe: true
   action_repeat: 2
+  episodes_per_epoch: 5
   exploration_steps: 1000000
+  test_task_name: go_to_goal
+
+environment:
+  safe_adaptation_gym:
+    robot_name: doggo
 
 agent:
   exploration_strategy: opax
   exploration_steps: 1000000
+  unsupervised: true

safe_opax/la_mbda/la_mbda.py

Lines changed: 21 additions & 3 deletions
@@ -52,8 +52,6 @@ def init(cls, batch_size: int, cell: rssm.RSSM, action_dim: int) -> "AgentState"
         return self
 
 
-
-
 class LaMBDA:
     def __init__(
         self,
@@ -148,21 +146,40 @@ def observe_transition(self, transition: Transition) -> None:
     def update(self):
         total_steps = self.config.agent.update_steps
         for batch in self.replay_buffer.sample(total_steps):
+            batch = TrajectoryData(
+                batch.observation,
+                batch.next_observation,
+                batch.action,
+                batch.reward * self.config.agent.reward_scale,
+                batch.cost,
+            )
             inferred_rssm_states = self.update_model(batch)
             initial_states = inferred_rssm_states.reshape(
                 -1, inferred_rssm_states.shape[-1]
             )
-            outs = self.actor_critic.update(self.model, initial_states, next(self.prng))
             if self.should_explore():
+                if not self.config.agent.unsupervised:
+                    outs = self.actor_critic.update(
+                        self.model, initial_states, next(self.prng)
+                    )
+                else:
+                    outs = {}
                 exploration_outs = self.exploration.update(
                     self.model, initial_states, next(self.prng)
                 )
                 outs.update(exploration_outs)
+            else:
+                outs = self.actor_critic.update(
+                    self.model, initial_states, next(self.prng)
+                )
             for k, v in outs.items():
                 self.metrics_monitor[k] = v
 
     def update_model(self, batch: TrajectoryData) -> jax.Array:
         features, actions = _prepare_features(batch)
+        learn_reward = not self.should_explore() or (
+            self.should_explore() and not self.config.agent.unsupervised
+        )
         (self.model, self.model_learner.state), (loss, rest) = variational_step(
             features,
             actions,
@@ -173,6 +190,7 @@ def update_model(self, batch: TrajectoryData) -> jax.Array:
             self.config.agent.beta,
             self.config.agent.free_nats,
             self.config.agent.kl_mix,
+            learn_reward,
        )
         self.metrics_monitor["agent/model/loss"] = float(loss.mean())
         self.metrics_monitor["agent/model/reconstruction"] = float(
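
The `learn_reward` expression added to `update_model` above amounts to "train the reward head unless the agent is both exploring and in unsupervised mode." A standalone truth-table check of that equivalence (not repo code):

```python
from itertools import product

# The condition as written in the diff equals the simpler De Morgan form.
for should_explore, unsupervised in product([False, True], repeat=2):
    as_written = not should_explore or (should_explore and not unsupervised)
    simplified = not (should_explore and unsupervised)
    assert as_written == simplified
print("learn_reward is False only during unsupervised exploration")
```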

safe_opax/la_mbda/world_model.py

Lines changed: 19 additions & 11 deletions
@@ -250,24 +250,32 @@ def variational_step(
     beta: float = 1.0,
     free_nats: float = 0.0,
     kl_mix: float = 0.8,
+    with_reward: bool = True,
 ) -> tuple[tuple[WorldModel, OptState], tuple[jax.Array, TrainingResults]]:
     def loss_fn(model):
         infer_fn = lambda features, actions: model(features, actions, key)
         inference_result: InferenceResult = eqx.filter_vmap(infer_fn)(features, actions)
-        y = features.observation, jnp.concatenate([features.reward, features.cost], -1)
-        y_hat = inference_result.image, inference_result.reward_cost
         batch_ndim = 2
-        reconstruction_loss = -sum(
-            map(
-                lambda predictions, targets: dtx.Independent(
-                    dtx.Normal(targets, 1.0), targets.ndim - batch_ndim
-                )
-                .log_prob(predictions)
-                .mean(),
-                y_hat,
-                y,
+        logprobs = (
+            lambda predictions, targets: dtx.Independent(
+                dtx.Normal(targets, 1.0), targets.ndim - batch_ndim
             )
+            .log_prob(predictions)
+            .mean()
+        )
+        if not with_reward:
+            reward = jnp.zeros_like(features.reward)
+            _, pred_cost = jnp.split(inference_result.reward_cost, 2, -1)
+            reward_cost = jnp.concatenate([reward, pred_cost], -1)
+        else:
+            reward = features.reward
+            reward_cost = inference_result.reward_cost
+        reward_cost_logprobs = logprobs(
+            reward_cost,
+            jnp.concatenate([reward, features.cost], -1),
         )
+        image_logprobs = logprobs(inference_result.image, features.observation)
+        reconstruction_loss = -reward_cost_logprobs - image_logprobs
         kl_loss = kl_divergence(
             inference_result.posteriors, inference_result.priors, free_nats, kl_mix
         )
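
A standalone JAX sketch (not the repository's loss) of why the `with_reward=False` branch above stops reward learning: the reward slice of the prediction is replaced by zeros, so it no longer depends on the model and its gradient contribution vanishes, while the cost slice keeps training as usual. Here `theta` is a toy stand-in for the reward/cost head, and squared error stands in for the `Normal(target, 1)` log-likelihood, which it matches up to scale and an additive constant:

```python
import jax
import jax.numpy as jnp

def loss(theta, targets, with_reward):
    pred = theta * jnp.ones_like(targets)           # toy "reward_cost" prediction
    pred_reward, pred_cost = jnp.split(pred, 2, -1)
    tgt_reward, tgt_cost = jnp.split(targets, 2, -1)
    if not with_reward:
        pred_reward = jnp.zeros_like(pred_reward)   # no longer depends on theta
        tgt_reward = jnp.zeros_like(tgt_reward)     # constant target for the masked slice
    pred = jnp.concatenate([pred_reward, pred_cost], -1)
    tgt = jnp.concatenate([tgt_reward, tgt_cost], -1)
    return jnp.sum((pred - tgt) ** 2)

targets = jnp.array([1.0, 2.0])                     # [reward, cost]
print(jax.grad(loss)(0.5, targets, True))           # reward and cost terms both contribute
print(jax.grad(loss)(0.5, targets, False))          # only the cost term contributes
```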

safe_opax/rl/trainer.py

Lines changed: 26 additions & 21 deletions
@@ -12,12 +12,11 @@
 from safe_opax.rl import acting, episodic_async_env
 from safe_opax.rl.epoch_summary import EpochSummary
 from safe_opax.rl.logging import StateWriter, TrainingLogger
-from safe_opax.rl.types import Agent, EnvironmentFactory
+from safe_opax.rl.types import EnvironmentFactory
 from safe_opax.rl.utils import PRNGSequence
 
 from safe_adaptation_gym.benchmark import TASKS
 from safe_adaptation_gym.tasks import Task
-from safe_opax.benchmark_suites.safe_adaptation_gym import sample_task
 
 _LOG = logging.getLogger(__name__)
 
@@ -58,7 +57,7 @@ def __init__(
         self,
         config: DictConfig,
         make_env: EnvironmentFactory,
-        agent: Agent | None = None,
+        agent: LaMBDA | LaMBDADalal | None = None,
         start_epoch: int = 0,
         step: int = 0,
         seeds: PRNGSequence | None = None,
@@ -86,24 +85,27 @@ def __enter__(self):
         if self.seeds is None:
             self.seeds = PRNGSequence(self.config.training.seed)
         if self.agent is None:
-            if self.config.agent.name == "lambda":
-                self.agent = LaMBDA(
-                    self.env.observation_space,
-                    self.env.action_space,
-                    self.config,
-                )
-            elif self.config.agent.name == "lambda_dalal":
-                self.agent = LaMBDADalal(
-                    self.env.observation_space,
-                    self.env.action_space,
-                    self.config,
-                )
-            else:
-                raise NotImplementedError(
-                    f"Unknown agent type: {self.config.agent.name}"
-                )
+            self.agent = self.make_agent()
         return self
 
+    def make_agent(self) -> LaMBDA | LaMBDADalal:
+        assert self.env is not None
+        if self.config.agent.name == "lambda":
+            agent = LaMBDA(
+                self.env.observation_space,
+                self.env.action_space,
+                self.config,
+            )
+        elif self.config.agent.name == "lambda_dalal":
+            agent = LaMBDADalal(
+                self.env.observation_space,
+                self.env.action_space,
+                self.config,
+            )
+        else:
+            raise NotImplementedError(f"Unknown agent type: {self.config.agent.name}")
+        return agent
+
     def __exit__(self, exc_type, exc_val, exc_tb):
         assert self.logger is not None and self.state_writer is not None
         self.state_writer.close()
@@ -197,13 +199,13 @@ def __init__(
         self,
         config: DictConfig,
         make_env: EnvironmentFactory,
-        agent: Agent | None = None,
+        agent: LaMBDA | LaMBDADalal | None = None,
         start_epoch: int = 0,
         step: int = 0,
         seeds: PRNGSequence | None = None,
     ):
         super().__init__(config, make_env, agent, start_epoch, step, seeds)
-        self.test_task_name = sample_task(self.config.training.seed)
+        self.test_task_name = self.config.training.test_task_name
         self.test_tasks: list[Task] | None = None
 
     def __enter__(self):
@@ -233,4 +235,7 @@ def _run_training_epoch(
         ]
         assert self.env is not None
         self.env.reset(options={"task": self.test_tasks})
+        assert self.agent is not None
+        new_agent = self.make_agent()
+        self.agent.replay_buffer = new_agent.replay_buffer
         return outs
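
The three lines added at the end of `_run_training_epoch` implement the "Reload new agent when transitioning from unsupervised" / "Change only replay buffer" items from the commit message: a fresh agent is built only to donate its newly initialized (presumably empty) replay buffer, while the trained world model and policies of the current agent are kept. A condensed sketch of that hand-off using the names from the diff; the `trainer` parameter and function name are hypothetical, not part of the commit:

```python
def switch_to_test_tasks(trainer) -> None:
    # Mirrors the added lines above; sketch only, not new functionality.
    assert trainer.env is not None and trainer.agent is not None
    trainer.env.reset(options={"task": trainer.test_tasks})  # swap in the downstream tasks
    new_agent = trainer.make_agent()                         # same config, fresh buffer
    trainer.agent.replay_buffer = new_agent.replay_buffer    # keep model/policies, reset data
```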
