@@ -39,8 +39,8 @@ class MultiAgentEpisode:
 
     Each AgentID in the `MultiAgentEpisode` has its own `SingleAgentEpisode` object
     in which this agent's data is stored. Together with the env_t_to_agent_t mapping,
-    we can extract information either on any individual agent's time scale or from
-    the (global) multi-agent environment time scale.
+    we can extract information either on any individual agent's timescale or from
+    the (global) multi-agent environment timescale.
 
     Extraction of data from a MultiAgentEpisode happens via the getter APIs, e.g.
     `get_observations()`, which work analogously to the ones implemented in the
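To make the mapping concrete, here is a minimal, self-contained sketch (not the actual RLlib implementation; the `SKIP` tag and the toy data are assumptions for illustration) of how a global env-step index is translated into an agent-step index via an env_t_to_agent_t list:

# Toy env_t_to_agent_t mapping for one agent: entry i holds the agent's own
# timestep at global env step i, or a skip tag if the agent did not act then.
SKIP = "S"  # assumed placeholder, analogous to SKIP_ENV_TS_TAG
env_t_to_agent_t = [0, 1, SKIP, 2, SKIP, 3]

def agent_obs_at_env_step(agent_observations, env_t):
    """Return the agent's observation at global env step `env_t`, or None."""
    agent_t = env_t_to_agent_t[env_t]
    if agent_t == SKIP:
        return None  # the agent did not receive an observation at this env step
    return agent_observations[agent_t]

agent_observations = ["o0", "o1", "o2", "o3"]
print(agent_obs_at_env_step(agent_observations, 3))  # -> "o2"
print(agent_obs_at_env_step(agent_observations, 2))  # -> None (agent skipped this step)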
@@ -156,8 +156,8 @@ def __init__(
                 of the episode. This is only larger than zero if an already ongoing episode
                 chunk is being created, for example by slicing an ongoing episode or
                 by calling the `cut()` method on an ongoing episode.
-            agent_t_started: A dict mapping AgentIDs to the respective agent's (local)
-                timestep at which its SingleAgentEpisode chunk started.
+            agent_t_started: A dict mapping AgentIDs to the agent's timestep
+                (not the global env timestep) at which its SingleAgentEpisode chunk started.
             len_lookback_buffer: The size of the lookback buffers to keep in
                 front of this Episode for each type of data (observations, actions,
                 etc.). If larger than 0, will interpret the first `len_lookback_buffer`
@@ -628,7 +628,7 @@ def add_env_step(
             )
             # Update the env- to agent-step mapping.
             self.env_t_to_agent_t[agent_id].append(
-                len(sa_episode) + sa_episode.observations.lookback
+                len(sa_episode) + self.agent_t_started[agent_id]
             )
 
             # Agent is also done. -> Erase all hanging values for this agent
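A small sketch of the arithmetic this change implies (toy numbers; `agent_t_started`, the chunk length, and the lookback size are assumptions, not values from a real episode): the appended mapping value is now the agent's absolute timestep rather than a lookback-relative buffer position.

# Toy illustration of the value appended to env_t_to_agent_t per env step.
agent_t_started = 5   # assumed: agent's timestep when this episode chunk began
chunk_len = 2         # assumed: steps the agent has taken in the current chunk
lookback = 3          # assumed: size of the lookback buffer

# Previous behavior (the removed line, roughly): a lookback-relative buffer position.
old_value = chunk_len + lookback          # -> 5

# New behavior: the agent's absolute timestep since its very first observation.
new_value = chunk_len + agent_t_started   # -> 7

print(old_value, new_value)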
@@ -832,8 +832,16 @@ def concat_episode(self, other: "MultiAgentEpisode") -> None:
             # wrt agent in `self`.
             if sa_episode is None:
                 self.agent_episodes[agent_id] = other.agent_episodes[agent_id]
-                self.env_t_to_agent_t[agent_id] = other.env_t_to_agent_t[agent_id]
                 self.agent_t_started[agent_id] = other.agent_t_started[agent_id]
+
+                # If the agent only has its first (reset) observation, no SingleAgentEpisode exists in `self` yet, but an `env_t_to_agent_t` entry might.
+                if agent_id not in self.env_t_to_agent_t:
+                    self.env_t_to_agent_t[agent_id] = other.env_t_to_agent_t[agent_id]
+                else:
+                    # For a cut episode, the first timestep of `other` is a copy of the last timestep of the previous chunk, so skip it.
+                    for val in other.env_t_to_agent_t[agent_id][1:]:
+                        self.env_t_to_agent_t[agent_id].append(val)
+
                 self._copy_hanging(agent_id, other)
 
             # If the agent was done in `self`, ignore and continue. There should not be
@@ -858,12 +866,10 @@ def concat_episode(self, other: "MultiAgentEpisode") -> None:
             )
 
             # Concatenate the env- to agent-timestep mappings.
-            j = self.env_t
-            for i, val in enumerate(other.env_t_to_agent_t[agent_id][1:]):
-                if val == self.SKIP_ENV_TS_TAG:
-                    self.env_t_to_agent_t[agent_id].append(self.SKIP_ENV_TS_TAG)
-                else:
-                    self.env_t_to_agent_t[agent_id].append(i + 1 + j)
+            # Skip the first element (the overlapping boundary) and append the rest.
+            # The values are agent timesteps, so they can be appended directly.
+            for val in other.env_t_to_agent_t[agent_id][1:]:
+                self.env_t_to_agent_t[agent_id].append(val)
 
             # Otherwise, the agent is only in `self` and not done. All data is stored
             # already -> skip
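A minimal sketch of this concatenation logic (toy lists; the `SKIP` tag stands in for SKIP_ENV_TS_TAG and is an assumption for illustration): because a cut chunk repeats the boundary timestep as its first mapping entry, only `other`'s entries from index 1 onward are appended, and since the values are already absolute agent timesteps they need no re-indexing.

SKIP = "S"  # assumed placeholder for SKIP_ENV_TS_TAG

# Mapping of the first chunk (self) and of the continuation chunk (other).
self_map  = [0, 1, SKIP, 2]   # ends at agent timestep 2
other_map = [2, SKIP, 3, 4]   # starts with a copy of that boundary entry

# Skip the overlapping first element and append the rest verbatim.
for val in other_map[1:]:
    self_map.append(val)

print(self_map)  # -> [0, 1, 'S', 2, 'S', 3, 4]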
@@ -1581,7 +1587,7 @@ def slice(
             if start < len(mapping):
                 for i in range(start, len(mapping)):
                     if mapping[i] != self.SKIP_ENV_TS_TAG:
-                        agent_t_started[aid] = sa_episode.t_started + mapping[i]
+                        agent_t_started[aid] = mapping[i]
                         break
         terminateds["__all__"] = all(
             terminateds.get(aid) for aid in self.agent_episodes
@@ -2158,6 +2164,36 @@ def _init_single_agent_episodes(
         )
         agent_module_ids = agent_module_ids or {}
 
+        # First pass: count observations per agent in the lookback AND in total.
+        # This allows us to recover the correct env_t_to_agent_t mapping.
+        lookback_obs_count_per_agent = defaultdict(int)
+        total_obs_count_per_agent = defaultdict(int)
+        for data_idx, obs in enumerate(observations):
+            for agent_id in obs:
+                total_obs_count_per_agent[agent_id] += 1
+                if data_idx < self._len_lookback_buffers:
+                    lookback_obs_count_per_agent[agent_id] += 1
+
+        # Compute the starting agent_t for each agent. The formula depends on
+        # whether there are observations after the lookback:
+        # - If new_chunk_obs > 0: first_agent_t = agent_t_started - lookback_count
+        # - If new_chunk_obs == 0: first_agent_t = agent_t_started - lookback_count + 1
+        # This is because agent_t_started = len(completed_actions), which equals the
+        # observation index of the NEXT observation if there is one, or of the LAST
+        # observation if the action is still hanging.
+        current_agent_t = {}
+        for agent_id, lookback_count in lookback_obs_count_per_agent.items():
+            total_count = total_obs_count_per_agent[agent_id]
+            new_chunk_obs = total_count - lookback_count
+            if new_chunk_obs > 0:
+                current_agent_t[agent_id] = (
+                    self.agent_t_started[agent_id] - lookback_count
+                )
+            else:
+                current_agent_t[agent_id] = (
+                    self.agent_t_started[agent_id] - lookback_count + 1
+                )
+
         # Step through all observations and interpret these as the (global) env steps.
         for data_idx, (obs, inf) in enumerate(zip(observations, infos)):
             # If we do have actions/extra outs/rewards for this timestep, use the data.
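The comment block above encodes a small arithmetic rule; here is a self-contained sketch of it (toy counts, not data from a real episode), showing how the agent timestep of the first lookback observation is recovered from `agent_t_started` and the observation counts:

def first_agent_t(agent_t_started, lookback_count, total_count):
    """Recover the agent timestep of the first lookback observation."""
    new_chunk_obs = total_count - lookback_count
    if new_chunk_obs > 0:
        # agent_t_started indexes the first observation AFTER the lookback.
        return agent_t_started - lookback_count
    # No observations after the lookback: agent_t_started indexes the LAST
    # lookback observation (its action is still hanging).
    return agent_t_started - lookback_count + 1

# Toy numbers: the agent started this chunk at t=10 with 3 lookback observations.
print(first_agent_t(10, 3, 5))  # -> 7 (two more observations follow the lookback)
print(first_agent_t(10, 3, 3))  # -> 8 (nothing follows the lookback)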
@@ -2216,10 +2252,15 @@ def _init_single_agent_episodes(
                 elif data_idx < len(observations) - 1:
                     done_per_agent[agent_id] = terminateds[agent_id] = True
 
-                # Update env_t_to_agent_t mapping.
-                self.env_t_to_agent_t[agent_id].append(
-                    len(observations_per_agent[agent_id]) - 1
-                )
+                # Update the env_t_to_agent_t mapping using the recovered agent_t.
+                # For agents present in the lookback, current_agent_t was computed
+                # above as agent_t_started - lookback_obs_count. For agents not in
+                # the lookback but with prior history, use agent_t_started. For
+                # truly new agents (no prior history), start at 0.
+                if agent_id not in current_agent_t:
+                    current_agent_t[agent_id] = self.agent_t_started.get(agent_id, 0)
+                self.env_t_to_agent_t[agent_id].append(current_agent_t[agent_id])
+                current_agent_t[agent_id] += 1
 
             # Those agents that did NOT step:
             # - Get self.SKIP_ENV_TS_TAG added to their env_t_to_agent_t mapping.
@@ -2297,7 +2338,7 @@ def _init_single_agent_episodes(
                 t_started=self.agent_t_started[agent_id],
                 len_lookback_buffer=max(len_lookback_buffer_per_agent[agent_id], 0),
             )
-            # .. and store it.
+            # and store it.
             self.agent_episodes[agent_id] = sa_episode
 
     def _get(
@@ -2377,7 +2418,9 @@ def _get_data_by_agent_steps(
                 _add_last_ts_value=hanging_val,
                 **one_hot_discrete,
             )
-            if agent_value is None or agent_value == []:
+            if agent_value is None or (
+                isinstance(agent_value, list) and agent_value == []
+            ):
                 continue
             ret[agent_id] = agent_value
         return ret
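A minimal sketch of why the isinstance guard matters (the helper name is assumed; NumPy is used only for illustration): comparing a NumPy array to `[]` with `==` does not produce a single boolean the way list equality does, so only genuine Python lists are tested for emptiness, and array-valued results are kept.

import numpy as np

def should_skip(agent_value):
    # Skip only None and genuinely empty Python lists; keep NumPy arrays as-is.
    return agent_value is None or (
        isinstance(agent_value, list) and agent_value == []
    )

print(should_skip(None))             # True
print(should_skip([]))               # True
print(should_skip(np.array([0.1])))  # False -> array results are returned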
@@ -2399,7 +2442,7 @@ def _get_data_by_env_steps_as_list(
         for agent_id in self.agent_episodes.keys():
             if agent_id not in agent_ids:
                 continue
-            agent_indices[agent_id] = self.env_t_to_agent_t[agent_id].get(
+            agent_t_indices = self.env_t_to_agent_t[agent_id].get(
                 indices,
                 neg_index_as_lookback=neg_index_as_lookback,
                 fill=self.SKIP_ENV_TS_TAG,
@@ -2408,6 +2451,24 @@ def _get_data_by_env_steps_as_list(
                 # the env_t_to_agent_t mappings.
                 _ignore_last_ts=what not in ["observations", "infos"],
             )
+            # Convert absolute agent_t values into buffer positions (including the
+            # lookback offset): buffer_pos = agent_t - agent_t_started + lookback.
+            sa_episode = self.agent_episodes[agent_id]
+            lookback = sa_episode.observations.lookback
+            if isinstance(agent_t_indices, int):
+                if agent_t_indices != self.SKIP_ENV_TS_TAG:
+                    agent_t_indices = (
+                        agent_t_indices - self.agent_t_started[agent_id] + lookback
+                    )
+            else:
+                assert isinstance(agent_t_indices, list)
+                agent_t_indices = [
+                    index - self.agent_t_started[agent_id] + lookback
+                    if index != self.SKIP_ENV_TS_TAG
+                    else index
+                    for index in agent_t_indices
+                ]
+            agent_indices[agent_id] = agent_t_indices
         if not agent_indices:
             return []
         ret = []
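A short sketch of the index conversion introduced above (toy values; the formula is the one stated in the comment): an absolute agent timestep is turned into a position inside the agent's lookback-extended buffer, while skip tags pass through unchanged.

SKIP = "S"  # assumed placeholder for SKIP_ENV_TS_TAG

def to_buffer_pos(agent_t, agent_t_started, lookback):
    """buffer_pos = agent_t - agent_t_started + lookback (skip tags pass through)."""
    if agent_t == SKIP:
        return agent_t
    return agent_t - agent_t_started + lookback

# Toy values: the chunk started at agent timestep 5 with a lookback of 3.
indices = [5, SKIP, 6, 7]
print([to_buffer_pos(i, agent_t_started=5, lookback=3) for i in indices])
# -> [3, 'S', 4, 5]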
@@ -2479,7 +2540,17 @@ def _get_data_by_env_steps(
                 hanging_val,
                 filter_for_skip_indices=agent_indices,
             )
+            # Convert absolute agent_t values into buffer positions (including the
+            # lookback offset): buffer_pos = agent_t - agent_t_started + lookback.
+            lookback = sa_episode.observations.lookback
             if isinstance(agent_indices, list):
+                agent_indices = [
+                    index - self.agent_t_started[agent_id] + lookback
+                    if index != self.SKIP_ENV_TS_TAG
+                    else index
+                    for index in agent_indices
+                ]
+
                 agent_values = self._get_single_agent_data_by_env_step_indices(
                     what=what,
                     agent_id=agent_id,
@@ -2492,6 +2563,11 @@ def _get_data_by_env_steps(
                 if len(agent_values) > 0:
                     ret[agent_id] = agent_values
             else:
+                if agent_indices != self.SKIP_ENV_TS_TAG:
+                    agent_indices = (
+                        agent_indices - self.agent_t_started[agent_id] + lookback
+                    )
+
                 agent_values = self._get_single_agent_data_by_index(
                     what=what,
                     inf_lookback_buffer=inf_lookback_buffer,
@@ -2523,7 +2599,7 @@ def _get_single_agent_data_by_index(
         if index_incl_lookback == self.SKIP_ENV_TS_TAG:
             # We don't want to fill -> Skip this agent.
             if fill is None:
-                return
+                return None
             # Provide filled value for this agent.
             return getattr(sa_episode, f"get_{what}")(
                 indices=1000000000000,
@@ -2605,7 +2681,7 @@ def _get_single_agent_data_by_env_step_indices(
                 lookback buffer should be returned, not the first value after the
                 lookback buffer (which would be normal behavior for pulling items from
                 an `InfiniteLookbackBuffer` object).
-            agent_id: The individual agent ID to pull data for. Used to lookup the
+            agent_id: The individual agent ID to pull data for. Used to look up the
                 `SingleAgentEpisode` object for this agent in `self`.
             fill: An optional float value to use for filling up the returned results at
                 the boundaries. This filling only happens if the requested index range's
@@ -2627,7 +2703,7 @@ def _get_single_agent_data_by_env_step_indices(
             hanging_val: In case we are pulling actions, rewards, or extra_model_outputs
                 data, there might be information "hanging" (cached). For example,
                 if an agent receives an observation o0 and then immediately sends an
-                action a0 back, but then does NOT immediately reveive a next
+                action a0 back, but then does NOT immediately receive the next
                 observation, a0 is now cached (not fully logged yet with this
                 episode). The currently cached value must be provided here to be able
                 to return it in case the index is -1 (most recent timestep).