
Commit f5a53c4

pseudo-rnd-thoughts and Mark Towers authored
[RLlib] Fix MultiAgentEpisode.env_t_to_agent_t (#60319)
## Description

While testing the off-policy algorithms (DQN, SAC, etc.) with TicTacToe, where agents act in turns, an error was raised related to `MultiAgentEpisode.env_t_to_agent_t`. This PR does a deep dive, adding comprehensive testing of the interaction between `MultiAgentEpisode`, `MultiAgentEnvRunner`, and a `MultiAgentEnv`, and resolves several problems identified along the way.

---------

Signed-off-by: Mark Towers <mark@anyscale.com>
Co-authored-by: Mark Towers <mark@anyscale.com>
1 parent db822f5 commit f5a53c4

File tree

5 files changed: +405 -208 lines changed


rllib/env/multi_agent_episode.py

Lines changed: 99 additions & 23 deletions
```diff
@@ -39,8 +39,8 @@ class MultiAgentEpisode:
 
     Each AgentID in the `MultiAgentEpisode` has its own `SingleAgentEpisode` object
     in which this agent's data is stored. Together with the env_t_to_agent_t mapping,
-    we can extract information either on any individual agent's time scale or from
-    the (global) multi-agent environment time scale.
+    we can extract information either on any individual agent's timescale or from
+    the (global) multi-agent environment timescale.
 
     Extraction of data from a MultiAgentEpisode happens via the getter APIs, e.g.
     `get_observations()`, which work analogous to the ones implemented in the
```
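To make the two timescales concrete, here is a small self-contained sketch (illustrative only, not RLlib code) of an `env_t_to_agent_t`-style mapping for a turn-based, two-agent game such as the TicTacToe setup from the description; `SKIP` stands in for the real `SKIP_ENV_TS_TAG`:

```python
# Hypothetical mapping for two agents acting in alternating turns.
SKIP = "S"  # stand-in for MultiAgentEpisode.SKIP_ENV_TS_TAG

# List index = global env timestep, value = that agent's own timestep (or SKIP).
env_t_to_agent_t = {
    "X": [0, SKIP, 1, SKIP, 2],
    "O": [SKIP, 0, SKIP, 1, SKIP],
}

def agent_step_for(agent_id, env_t):
    """Translate a global env timestep into the agent's local timestep, or None."""
    agent_t = env_t_to_agent_t[agent_id][env_t]
    return None if agent_t == SKIP else agent_t

assert agent_step_for("X", 2) == 1     # X's data for env step 2 lives at agent step 1.
assert agent_step_for("O", 2) is None  # O did not act at env step 2.
```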
```diff
@@ -156,8 +156,8 @@ def __init__(
                 of the episode. This is only larger zero, if an already ongoing episode
                 chunk is being created, for example by slicing an ongoing episode or
                 by calling the `cut()` method on an ongoing episode.
-            agent_t_started: A dict mapping AgentIDs to the respective agent's (local)
-                timestep at which its SingleAgentEpisode chunk started.
+            agent_t_started: A dict mapping AgentIDs to the agent's timestep
+                (not global env timestep) at which its SingleAgentEpisode chunk started.
             len_lookback_buffer: The size of the lookback buffers to keep in
                 front of this Episode for each type of data (observations, actions,
                 etc..). If larger 0, will interpret the first `len_lookback_buffer`
@@ -628,7 +628,7 @@ def add_env_step(
             )
             # Update the env- to agent-step mapping.
             self.env_t_to_agent_t[agent_id].append(
-                len(sa_episode) + sa_episode.observations.lookback
+                len(sa_episode) + self.agent_t_started[agent_id]
             )
 
             # Agent is also done. -> Erase all hanging values for this agent
```
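The change above makes each mapping entry an absolute agent timestep (the agent's chunk length so far plus `agent_t_started`) rather than a position inside the observation lookback buffer. A standalone sketch of the arithmetic with made-up numbers:

```python
# Hypothetical numbers (not taken from the PR): an agent whose episode chunk was
# cut at agent timestep 5 and that has taken 3 further steps in the new chunk.
agent_t_started = 5        # agent timestep at which this SingleAgentEpisode chunk began
steps_in_this_chunk = 3    # what len(sa_episode) would report for the current chunk
lookback = 2               # observations carried over from the previous chunk

# Old value (buffer position): shifts whenever the lookback size changes.
old_mapping_value = steps_in_this_chunk + lookback          # -> 5

# New value (absolute agent timestep): independent of the lookback size.
new_mapping_value = steps_in_this_chunk + agent_t_started   # -> 8

assert (old_mapping_value, new_mapping_value) == (5, 8)
```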
```diff
@@ -832,8 +832,16 @@ def concat_episode(self, other: "MultiAgentEpisode") -> None:
             # wrt agent in `self`.
             if sa_episode is None:
                 self.agent_episodes[agent_id] = other.agent_episodes[agent_id]
-                self.env_t_to_agent_t[agent_id] = other.env_t_to_agent_t[agent_id]
                 self.agent_t_started[agent_id] = other.agent_t_started[agent_id]
+
+                # If agent only has the first reset observation then no episode exists but `env_t_to_agent_t` does
+                if agent_id not in self.env_t_to_agent_t:
+                    self.env_t_to_agent_t[agent_id] = other.env_t_to_agent_t[agent_id]
+                else:
+                    # For a cut episode, the first timestep is a copy of the last timestep from the previous episode
+                    for val in other.env_t_to_agent_t[agent_id][1:]:
+                        self.env_t_to_agent_t[agent_id].append(val)
+
                 self._copy_hanging(agent_id, other)
 
             # If the agent was done in `self`, ignore and continue. There should not be
@@ -858,12 +866,10 @@ def concat_episode(self, other: "MultiAgentEpisode") -> None:
             )
 
             # Concatenate the env- to agent-timestep mappings.
-            j = self.env_t
-            for i, val in enumerate(other.env_t_to_agent_t[agent_id][1:]):
-                if val == self.SKIP_ENV_TS_TAG:
-                    self.env_t_to_agent_t[agent_id].append(self.SKIP_ENV_TS_TAG)
-                else:
-                    self.env_t_to_agent_t[agent_id].append(i + 1 + j)
+            # Skip the first element (overlapping boundary) and append the rest.
+            # Values are agent timesteps, so append them directly.
+            for val in other.env_t_to_agent_t[agent_id][1:]:
+                self.env_t_to_agent_t[agent_id].append(val)
 
             # Otherwise, the agent is only in `self` and not done. All data is stored
             # already -> skip
```
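Because the entries are now absolute agent timesteps, concatenating two chunks of the same agent reduces to appending the other chunk's mapping while dropping its overlapping first entry (which repeats the last entry of the first chunk). A minimal sketch with invented values:

```python
# Hypothetical env_t_to_agent_t fragments for one agent across an episode cut.
SKIP = "S"

first_chunk = [0, SKIP, 1, 2]    # mapping held by the ongoing episode (self)
second_chunk = [2, SKIP, 3, 4]   # mapping of the continuation chunk (other); its
                                 # first entry repeats first_chunk[-1]

# Append everything except the overlapping boundary element.
combined = list(first_chunk)
for val in second_chunk[1:]:
    combined.append(val)

assert combined == [0, SKIP, 1, 2, SKIP, 3, 4]
```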
```diff
@@ -1581,7 +1587,7 @@ def slice(
             if start < len(mapping):
                 for i in range(start, len(mapping)):
                     if mapping[i] != self.SKIP_ENV_TS_TAG:
-                        agent_t_started[aid] = sa_episode.t_started + mapping[i]
+                        agent_t_started[aid] = mapping[i]
                         break
         terminateds["__all__"] = all(
             terminateds.get(aid) for aid in self.agent_episodes
```
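With absolute agent timesteps in the mapping, slicing can read an agent's new `t_started` straight from the first non-skip entry at or after the slice start; the previous `sa_episode.t_started +` offset is no longer needed. A tiny standalone sketch:

```python
SKIP = "S"  # stand-in for the skip tag

def first_agent_t_at_or_after(mapping, start):
    """Return the first real agent timestep at or after `start`, or None."""
    for i in range(start, len(mapping)):
        if mapping[i] != SKIP:
            return mapping[i]
    return None

# Hypothetical mapping: the agent acted on env steps 0, 2 and 4 only.
mapping = [0, SKIP, 1, SKIP, 2]
assert first_agent_t_at_or_after(mapping, start=1) == 1
assert first_agent_t_at_or_after(mapping, start=4) == 2
```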
```diff
@@ -2158,6 +2164,36 @@ def _init_single_agent_episodes(
         )
         agent_module_ids = agent_module_ids or {}
 
+        # First pass: count observations per agent in lookback AND total.
+        # This allows us to recover the correct env_t_to_agent_t mapping.
+        lookback_obs_count_per_agent = defaultdict(int)
+        total_obs_count_per_agent = defaultdict(int)
+        for data_idx, obs in enumerate(observations):
+            for agent_id in obs:
+                total_obs_count_per_agent[agent_id] += 1
+                if data_idx < self._len_lookback_buffers:
+                    lookback_obs_count_per_agent[agent_id] += 1
+
+        # Compute the starting agent_t for each agent.
+        # The formula depends on whether there are observations after the lookback:
+        # - If new_chunk_obs > 0: first_agent_t = agent_t_started - lookback_count
+        # - If new_chunk_obs == 0: first_agent_t = agent_t_started - lookback_count + 1
+        # This is because agent_t_started = len(completed_actions), which equals the
+        # observation_index of the NEXT observation if there is one, or the LAST
+        # observation if the action is still hanging.
+        current_agent_t = {}
+        for agent_id, lookback_count in lookback_obs_count_per_agent.items():
+            total_count = total_obs_count_per_agent[agent_id]
+            new_chunk_obs = total_count - lookback_count
+            if new_chunk_obs > 0:
+                current_agent_t[agent_id] = (
+                    self.agent_t_started[agent_id] - lookback_count
+                )
+            else:
+                current_agent_t[agent_id] = (
+                    self.agent_t_started[agent_id] - lookback_count + 1
+                )
+
         # Step through all observations and interpret these as the (global) env steps.
         for data_idx, (obs, inf) in enumerate(zip(observations, infos)):
             # If we do have actions/extra outs/rewards for this timestep, use the data.
```
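A self-contained sketch (with invented inputs rather than the PR's own data structures) of the first pass described in the comments above: count each agent's observations inside and outside the lookback window, then derive the agent's first agent_t from the two-case formula:

```python
from collections import defaultdict

len_lookback = 2  # hypothetical: the first 2 entries of `observations` are lookback data

# Hypothetical per-env-step observation dicts (only the keys matter here).
observations = [{"X": 0}, {"O": 0}, {"X": 0}, {"O": 0}, {"X": 0}]
agent_t_started = {"X": 1, "O": 1}  # hypothetical starting agent timesteps

lookback_count = defaultdict(int)
total_count = defaultdict(int)
for data_idx, obs in enumerate(observations):
    for agent_id in obs:
        total_count[agent_id] += 1
        if data_idx < len_lookback:
            lookback_count[agent_id] += 1

first_agent_t = {}
for agent_id, lb in lookback_count.items():
    new_chunk_obs = total_count[agent_id] - lb
    # With observations after the lookback, agent_t_started points at the next
    # observation; without them it points at the last one (hanging-action case).
    first_agent_t[agent_id] = agent_t_started[agent_id] - lb + (0 if new_chunk_obs > 0 else 1)

assert first_agent_t == {"X": 0, "O": 0}
```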
```diff
@@ -2216,10 +2252,15 @@ def _init_single_agent_episodes(
                 elif data_idx < len(observations) - 1:
                     done_per_agent[agent_id] = terminateds[agent_id] = True
 
-                # Update env_t_to_agent_t mapping.
-                self.env_t_to_agent_t[agent_id].append(
-                    len(observations_per_agent[agent_id]) - 1
-                )
+                # Update env_t_to_agent_t mapping using the recovered agent_t.
+                # For agents in the lookback, current_agent_t was computed earlier as:
+                #     agent_t_started - lookback_obs_count
+                # For agents not in lookback but with prior history, use agent_t_started.
+                # For truly new agents (no prior history), start at 0.
+                if agent_id not in current_agent_t:
+                    current_agent_t[agent_id] = self.agent_t_started.get(agent_id, 0)
+                self.env_t_to_agent_t[agent_id].append(current_agent_t[agent_id])
+                current_agent_t[agent_id] += 1
 
             # Those agents that did NOT step:
             # - Get self.SKIP_ENV_TS_TAG added to their env_t_to_agent_t mapping.
@@ -2297,7 +2338,7 @@ def _init_single_agent_episodes(
                 t_started=self.agent_t_started[agent_id],
                 len_lookback_buffer=max(len_lookback_buffer_per_agent[agent_id], 0),
             )
-            # .. and store it.
+            # and store it.
            self.agent_episodes[agent_id] = sa_episode
 
    def _get(
@@ -2377,7 +2418,9 @@ def _get_data_by_agent_steps(
                 _add_last_ts_value=hanging_val,
                 **one_hot_discrete,
             )
-            if agent_value is None or agent_value == []:
+            if agent_value is None or (
+                isinstance(agent_value, list) and agent_value == []
+            ):
                 continue
             ret[agent_id] = agent_value
         return ret
```
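The added `isinstance` guard means the `== []` comparison only ever runs on a real Python list; a plausible motivation (not spelled out in the diff) is that the getters can also return NumPy arrays, for which an unguarded `agent_value == []` becomes an element-wise comparison instead of a plain boolean. A small sketch mirroring the guarded check:

```python
import numpy as np

def is_empty_list(value):
    # Only a genuine empty Python list counts as "no data"; anything else
    # (arrays included) short-circuits before the == [] comparison runs.
    return isinstance(value, list) and value == []

assert is_empty_list([]) is True
assert is_empty_list([1, 2]) is False
assert is_empty_list(np.array([])) is False
assert is_empty_list(np.zeros(3)) is False
```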
```diff
@@ -2399,7 +2442,7 @@ def _get_data_by_env_steps_as_list(
         for agent_id in self.agent_episodes.keys():
             if agent_id not in agent_ids:
                 continue
-            agent_indices[agent_id] = self.env_t_to_agent_t[agent_id].get(
+            agent_t_indices = self.env_t_to_agent_t[agent_id].get(
                 indices,
                 neg_index_as_lookback=neg_index_as_lookback,
                 fill=self.SKIP_ENV_TS_TAG,
@@ -2408,6 +2451,24 @@ def _get_data_by_env_steps_as_list(
                 # the env_t_to_agent_t mappings.
                 _ignore_last_ts=what not in ["observations", "infos"],
             )
+            # Convert absolute agent_t to buffer position (including lookback offset).
+            # Formula: buffer_pos = agent_t - agent_t_started + lookback
+            sa_episode = self.agent_episodes[agent_id]
+            lookback = sa_episode.observations.lookback
+            if isinstance(agent_t_indices, int):
+                if agent_t_indices != self.SKIP_ENV_TS_TAG:
+                    agent_t_indices = (
+                        agent_t_indices - self.agent_t_started[agent_id] + lookback
+                    )
+            else:
+                assert isinstance(agent_t_indices, list)
+                agent_t_indices = [
+                    index - self.agent_t_started[agent_id] + lookback
+                    if index != self.SKIP_ENV_TS_TAG
+                    else index
+                    for index in agent_t_indices
+                ]
+            agent_indices[agent_id] = agent_t_indices
         if not agent_indices:
             return []
         ret = []
```
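The conversion comments above state the invariant `buffer_pos = agent_t - agent_t_started + lookback`. A standalone sketch (invented numbers) of the round trip from an absolute agent timestep to a position in a lookback-prefixed buffer:

```python
# Hypothetical agent buffer: 2 lookback observations carried over from the previous
# chunk, followed by the observations collected in the current chunk.
lookback = 2
agent_t_started = 5                  # agent timestep where the current chunk begins
buffer = ["obs3", "obs4",            # lookback entries: agent timesteps 3 and 4
          "obs5", "obs6", "obs7"]    # current chunk: agent timesteps 5, 6 and 7

def to_buffer_pos(agent_t):
    """Map an absolute agent timestep onto an index into `buffer`."""
    return agent_t - agent_t_started + lookback

assert buffer[to_buffer_pos(5)] == "obs5"  # first timestep of the current chunk
assert buffer[to_buffer_pos(3)] == "obs3"  # reaches back into the lookback section
```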
```diff
@@ -2479,7 +2540,17 @@ def _get_data_by_env_steps(
                 hanging_val,
                 filter_for_skip_indices=agent_indices,
             )
+            # Convert absolute agent_t to buffer position (including lookback offset).
+            # Formula: buffer_pos = agent_t - agent_t_started + lookback
+            lookback = sa_episode.observations.lookback
             if isinstance(agent_indices, list):
+                agent_indices = [
+                    index - self.agent_t_started[agent_id] + lookback
+                    if index != self.SKIP_ENV_TS_TAG
+                    else index
+                    for index in agent_indices
+                ]
+
                 agent_values = self._get_single_agent_data_by_env_step_indices(
                     what=what,
                     agent_id=agent_id,
@@ -2492,6 +2563,11 @@ def _get_data_by_env_steps(
                 if len(agent_values) > 0:
                     ret[agent_id] = agent_values
             else:
+                if agent_indices != self.SKIP_ENV_TS_TAG:
+                    agent_indices = (
+                        agent_indices - self.agent_t_started[agent_id] + lookback
+                    )
+
                 agent_values = self._get_single_agent_data_by_index(
                     what=what,
                     inf_lookback_buffer=inf_lookback_buffer,
@@ -2523,7 +2599,7 @@ def _get_single_agent_data_by_index(
         if index_incl_lookback == self.SKIP_ENV_TS_TAG:
             # We don't want to fill -> Skip this agent.
             if fill is None:
-                return
+                return None
             # Provide filled value for this agent.
             return getattr(sa_episode, f"get_{what}")(
                 indices=1000000000000,
@@ -2605,7 +2681,7 @@ def _get_single_agent_data_by_env_step_indices(
                 lookback buffer should be returned, not the first value after the
                 lookback buffer (which would be normal behavior for pulling items from
                 an `InfiniteLookbackBuffer` object).
-            agent_id: The individual agent ID to pull data for. Used to lookup the
+            agent_id: The individual agent ID to pull data for. Used to look up the
                 `SingleAgentEpisode` object for this agent in `self`.
             fill: An optional float value to use for filling up the returned results at
                 the boundaries. This filling only happens if the requested index range's
@@ -2627,7 +2703,7 @@ def _get_single_agent_data_by_env_step_indices(
             hanging_val: In case we are pulling actions, rewards, or extra_model_outputs
                 data, there might be information "hanging" (cached). For example,
                 if an agent receives an observation o0 and then immediately sends an
-                action a0 back, but then does NOT immediately reveive a next
+                action a0 back, but then does NOT immediately retrieve the next
                 observation, a0 is now cached (not fully logged yet with this
                 episode). The currently cached value must be provided here to be able
                 to return it in case the index is -1 (most recent timestep).
```
