
[RLlib] Nested spaces (tuple or dict) break FlattenObservations in partially active multi-agent setups #59849

@cyianor

Description


What happened + What you expected to happen

When using a nested tuple or dict observation space in a multi-agent setup, it is sometimes necessary to apply FlattenObservations to only some agents. The observation spaces of the other agents should simply be passed through unchanged. However, recompute_observation_space fails in this case, for the following reason.

Assume a nested observation space is used, for example the following one (an extension of the test code in flatten_observations.py):

import gymnasium as gym
import numpy as np

# Some arbitrarily nested, complex observation space.
obs_space = gym.spaces.Dict(
    {
        "a": gym.spaces.Box(-10.0, 10.0, (), np.float32),
        "b": gym.spaces.Tuple(  # nested tuple space
            [
                gym.spaces.Discrete(2),
                gym.spaces.Box(-1.0, 1.0, (2, 1), np.float32),
            ]
        ),
        "c": gym.spaces.MultiDiscrete([2, 3]),
        "d": gym.spaces.Dict(  # nested dict space
            {
                "e": gym.spaces.Box(-10.0, 10.0, (), np.float32),
            }
        ),
    }
)

When the connector is created with multi_agent=True but is not applied to a specific agent with ID agent_id, the base structure of that agent's input observation space is copied:

spaces[agent_id] = self._input_obs_base_struct[agent_id]

For all other agents, the input observation space is flattened to a gym.spaces.Box. At the end, a new dictionary space is created:

return gym.spaces.Dict(spaces)

Constructing this gym.spaces.Dict fails, however, when a passed-through agent has a nested observation space such as the one above.

The reason is that the observation space structure was obtained earlier as follows:

self._input_obs_base_struct = get_base_struct_from_space(

where nested tuple and dict spaces get converted to plain Python tuples and dictionaries. This is what tree.map_structure expects, but it is not what should be copied over when the connector is not applied to an agent's observation space. Instead, the original space should be copied:

if self._agent_ids and agent_id not in self._agent_ids:
    spaces[agent_id] = deepcopy(self.input_observation_space[agent_id])

This ensures that the observation space for agents unaffected by this connector stays intact.

Versions / Dependencies

Ray version 2.53.0; the same result also occurs in earlier versions.

Reproduction script

import gymnasium as gym
import numpy as np

from ray.rllib.connectors.env_to_module import FlattenObservations
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.test_utils import check

# Some arbitrarily nested, complex observation space.
obs_space = gym.spaces.Dict(
    {
        "a": gym.spaces.Box(-10.0, 10.0, (), np.float32),
        "b": gym.spaces.Tuple(  # nested tuple space
            [
                gym.spaces.Discrete(2),
                gym.spaces.Box(-1.0, 1.0, (2, 1), np.float32),
            ]
        ),
        "c": gym.spaces.MultiDiscrete([2, 3]),
        "d": gym.spaces.Dict(  # nested dict space
            {
                "e": gym.spaces.Box(-10.0, 10.0, (), np.float32),
            }
        ),
    }
)
act_space = gym.spaces.Discrete(2)

# Two example episodes, both with initial (reset) observations coming from the
# above defined observation space.
episode_1 = SingleAgentEpisode(
    observations=[
        {
            "a": np.array(-10.0, np.float32),
            "b": (1, np.array([[-1.0], [-1.0]], np.float32)),
            "c": np.array([0, 2]),
            "d": {"e": np.array(-10.0, np.float32)},
        },
    ],
    agent_id="a1",
)
episode_2 = SingleAgentEpisode(
    observations=[
        {
            "a": np.array(10.0, np.float32),
            "b": (0, np.array([[1.0], [1.0]], np.float32)),
            "c": np.array([1, 1]),
            "d": {"e": np.array(10.0, np.float32)},
        },
    ],
    agent_id="a2",
)

# Construct our connector piece.
connector = FlattenObservations(
    obs_space, act_space, multi_agent=True, agent_ids=["a1"]
)

Running this script results in:

AssertionError: Dict space element is not an instance of Space: key='b', space=(Discrete(2), Box(-1.0, 1.0, (2, 1), float32))

Issue Severity

High: It blocks me from completing my task.

Metadata


Labels

bug: Something that is supposed to be working; but isn't
community-backlog
rllib: RLlib related issues
stability
triage: Needs triage (eg: priority, bug/not-bug, and owning component)
