[RLlib] Clean up offline prelearner and its unit testing #60632

Open
ArturNiederfahrenhorst wants to merge 11 commits into ray-project:master from ArturNiederfahrenhorst:fixofflineprelearnertest

@ArturNiederfahrenhorst ArturNiederfahrenhorst commented Jan 31, 2026

Description

The offline prelearner unit tests time out often.
This is because test_offline_prelearner_sample_from_episode_data collects many samples and takes 2 minutes on my MacBook Pro. I'm not sure how long it takes on CI, but it appears to be long enough to time out often. The test also uses two env runners by default, so two Ray Data datasets are executed at the same time for writing. That spawns too many tasks for unit testing on my dev machine, which freezes while the test is running (same for @pseudo-rnd-thoughts).

Therefore, this PR reduces the test runtime from >2 minutes to 8 seconds on my MBP and uses fewer resources by running only one env runner. The PR also cleans up the OfflinePreLearner class to make it more maintainable for upcoming changes.

Removes >150 LOC. The added lines mostly handle deprecation.

@ArturNiederfahrenhorst ArturNiederfahrenhorst requested a review from a team as a code owner January 31, 2026 11:09

# If multi-agent we need to extract the agent ID.
# TODO (simon): Check, what happens with the module ID.
if is_multi_agent:
Contributor Author

The multi-agent case is not implemented.

to_numpy: bool = False,
input_compress_columns: Optional[List[str]] = None,
**kwargs: Dict[str, Any],
Contributor Author

We don't need these parameters as arguments of this method because they can be expected to be constant over the lifetime of an OfflinePreLearner.

gemini-code-assist bot (Contributor) left a comment
The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request effectively cleans up the OfflinePreLearner class and its associated unit tests, leading to improved maintainability and faster test execution. The deprecation of ignore_final_observation and the refactoring of method signatures are well-executed. However, the refactoring has introduced breaking changes in some unit tests by converting static methods to instance methods without updating their call sites in the tests. These issues are critical and need to be addressed. I've also identified a minor type hint inaccuracy that should be corrected.

Comment on lines 421 to +422
def _map_sample_batch_to_episode(
is_multi_agent: bool,
self,
Contributor

critical

Changing _map_sample_batch_to_episode from a static method to an instance method breaks the test_offline_prelearner_convert_from_old_sample_batch_to_episodes unit test in rllib/offline/tests/test_offline_prelearner.py. The test still calls this method statically (OfflinePreLearner._map_sample_batch_to_episode(...)), which will now fail. The test needs to be updated to instantiate OfflinePreLearner and then call this method on the instance.
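
A minimal sketch of the failure mode described above, using a hypothetical stand-in class (`PreLearnerSketch` and its method body are assumptions for illustration, not RLlib code). Once the method takes `self`, a keyword-only call through the class has nothing to bind to `self` and raises a `TypeError`, while the same call on an instance works.

```python
class PreLearnerSketch:
    """Hypothetical stand-in for OfflinePreLearner, for illustration only."""

    def __init__(self, is_multi_agent: bool = False):
        # Formerly passed into the method on every call.
        self.is_multi_agent = is_multi_agent

    # Formerly a @staticmethod that took `is_multi_agent` as an argument.
    def _map_sample_batch_to_episode(self, batch: dict) -> dict:
        return {"episodes": [batch], "multi_agent": self.is_multi_agent}


def call_statically(batch: dict):
    # Old call style: through the class, without an instance.
    try:
        return PreLearnerSketch._map_sample_batch_to_episode(batch=batch)
    except TypeError as exc:
        return exc


def call_on_instance(batch: dict) -> dict:
    # Updated call style: instantiate first, then call the method.
    return PreLearnerSketch()._map_sample_batch_to_episode(batch=batch)
```

In a test, the second style usually means building the instance once (for example in `setUp`) and reusing it across test methods.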


logger = logging.getLogger(__name__)

def _validate_deprecated_map_args(kwargs: dict, config: "AlgorithmConfig") -> Set:
Contributor

medium

The return type hint for this function is Set, but it returns a tuple of three elements (is_multi_agent, schema, input_compress_columns). This should be corrected to Tuple[bool, Dict, List] for better type safety and clarity.

Suggested change
- def _validate_deprecated_map_args(kwargs: dict, config: "AlgorithmConfig") -> Set:
+ def _validate_deprecated_map_args(kwargs: dict, config: "AlgorithmConfig") -> Tuple[bool, Dict, List]:
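
To illustrate why the tuple annotation matters, here is a rough sketch of a helper with the same shape of return value. The function body, the warning behavior, and the stub config field names are assumptions made for this example; the actual helper in the PR may resolve the arguments differently.

```python
import warnings
from types import SimpleNamespace
from typing import Any, Dict, List, Tuple


def validate_deprecated_map_args_sketch(
    kwargs: Dict[str, Any], config: Any
) -> Tuple[bool, Dict, List]:
    """Illustrative only: returns (is_multi_agent, schema, input_compress_columns)."""
    for name in ("is_multi_agent", "schema", "input_compress_columns"):
        if name in kwargs:
            # Assumed behavior: warn, then prefer the explicitly passed value.
            warnings.warn(f"Passing `{name}` is deprecated.", DeprecationWarning)
    return (
        kwargs.get("is_multi_agent", config.is_multi_agent),
        kwargs.get("schema", config.schema),
        kwargs.get("input_compress_columns", config.input_compress_columns),
    )


# Stub config with hypothetical field names standing in for AlgorithmConfig.
stub_config = SimpleNamespace(
    is_multi_agent=False, schema={"obs": "obs"}, input_compress_columns=["obs"]
)
result = validate_deprecated_map_args_sketch({}, stub_config)
is_multi_agent, schema, columns = result
```

With the `Tuple[bool, Dict, List]` annotation, a type checker can verify the three-way unpacking at the call site, which a `Set` annotation would not allow.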

unpacked_obs = (
    unpack_if_needed(obs)
    if Columns.OBS in input_compress_columns
    else obs
)
# Set the next observation.
if ignore_final_observation:
Contributor Author

ignore_final_observation is not tested, and we don't use it anywhere in the codebase.

# If multi-agent we need to extract the agent ID.
# TODO (simon): Check, what happens with the module ID.
if is_multi_agent:
    agent_id = (
Contributor Author

The multi-agent case is not implemented.

input_compress_columns: Optional[List[str]] = None,
ignore_final_observation: Optional[bool] = False,
observation_space: gym.Space = None,
action_space: gym.Space = None,
Contributor Author

We don't need these parameters as arguments of this method because they can be expected to be constant over the lifetime of an OfflinePreLearner.


# Run the `Learner`'s connector pipeline.
batch = self._learner_connector(
    rl_module=self._module,
    batch={},
    episodes=episodes,
    shared_data={},
    # TODO (sven): Add MetricsLogger to non-Learner components that have a
    # LearnerConnector pipeline.
Contributor Author

Comment moved to group with other TODOs

to_numpy=False,
input_compress_columns=self.config.input_compress_columns,
observation_space=self.observation_space,
action_space=self.action_space,
)["episodes"]
Contributor Author

The effect of removing these args can be seen here. 5 of the 7 arguments are constant so no need to parameterize.
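
The pattern argued for here can be sketched as follows, assuming a hypothetical stand-in class and a stub config (neither is the RLlib implementation): values that stay fixed for the object's lifetime are stored in the constructor, and the method only takes the data that changes per call.

```python
from types import SimpleNamespace
from typing import Any, Dict, List


class PreLearnerSketch:
    """Hypothetical stand-in; not the RLlib OfflinePreLearner."""

    def __init__(self, config: Any, observation_space: Any, action_space: Any):
        # Values that are constant over the object's lifetime are stored once.
        self.config = config
        self.observation_space = observation_space
        self.action_space = action_space

    def _map_to_episodes(self, batch: Dict[str, List]) -> Dict[str, List]:
        # Only the per-call data is a parameter; the rest is read from self.
        compressed = set(self.config.input_compress_columns)
        episodes = [
            {"obs": obs, "was_compressed": "obs" in compressed}
            for obs in batch["obs"]
        ]
        return {"episodes": episodes}


# The stub config is an assumption for this sketch, not an AlgorithmConfig.
config = SimpleNamespace(input_compress_columns=["obs"])
prelearner = PreLearnerSketch(config, observation_space=None, action_space=None)
episodes = prelearner._map_to_episodes({"obs": [[0.0], [1.0]]})["episodes"]
```

The call site shrinks to the batch alone, which is the effect visible in the diff above.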

self.config: AlgorithmConfig = config
self.input_read_episodes: bool = self.config.input_read_episodes
self.input_read_sample_batches: bool = self.config.input_read_sample_batches
Contributor Author

These parameters are only used in one or at most two places, so we can just access the config there and keep the constructor clean.

@ray-gardener ray-gardener bot added the rllib RLlib related issues label Jan 31, 2026
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
simonsays1980 and others added 3 commits February 3, 2026 18:22
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
Signed-off-by: simonsays1980 <simon.zehnder@gmail.com>
return policies_to_train(module_id, multi_agent_batch)

@OverrideToImplementCustomLogic
@staticmethod
Contributor Author

Making this a static method is convenient for testing (because you don't need a proper AlgorithmConfig to initialize the class).
But at the same time, this is true for many classes in RLlib.

I propose that we don't move towards making methods static if their arguments don't vary during runtime. Things that don't vary should still go into the init args so that we don't blow up our interfaces.

Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
cursor bot left a comment

Cursor Bugbot has reviewed your changes and found 2 potential issues.

import functools
import shutil
import unittest
import functools

Duplicate import functools statement in test file

Low Severity

import functools appears on both line 1 and line 3, which is a redundant duplicate import.


self.assertTrue(
    all(all(eps.get_observations()[-1] == [0.0] * 4) for eps in episodes)
)

Test exercises removed ignore_final_observation feature

Medium Severity

test_offline_prelearner_ignore_final_observation was uncommented but tests the ignore_final_observation parameter which was removed from _map_to_episodes in this commit. The parameter goes into **kwargs and _validate_deprecated_map_args does not process it, so it is silently ignored. The assertion expecting zeroed-out final observations will always fail because the feature no longer exists.
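
A small sketch of the behavior described above, with a hypothetical function standing in for the real method (the name, the deprecated-argument list, and the warning are assumptions for illustration). A removed parameter that lands in `**kwargs` has no effect on the result; emitting a warning at least makes that visible to callers.

```python
import warnings
from typing import Any, Dict, List

# Assumed list of removed options, for illustration only.
REMOVED_KWARGS = ("ignore_final_observation",)


def map_to_episodes_sketch(batch: Dict[str, List], **kwargs: Any) -> Dict[str, List]:
    """Hypothetical stand-in for a method whose removed parameter lands in **kwargs."""
    for name in REMOVED_KWARGS:
        if name in kwargs:
            # Surface the removed option instead of dropping it silently.
            warnings.warn(f"`{name}` was removed and has no effect.", DeprecationWarning)
    # The final observation is passed through unchanged; nothing is zeroed out.
    return {"episodes": [{"obs": list(obs)} for obs in batch["obs"]]}


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = map_to_episodes_sketch(
        {"obs": [[1.0, 2.0]]}, ignore_final_observation=True
    )
```

A test that still asserts zeroed final observations would fail against this kind of function, so it would need to be removed or rewritten to assert on the warning instead.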

Labels

rllib RLlib related issues

2 participants