Labels
bug: Something isn't working
Description
System Info
3 nodes with 8 GPUs each: 2 nodes for the actor, 1 node for rollout
torch 2.7.1
cuda 12.2
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Problem:
In fully async mode, with the following settings:
actor_rollout_ref.actor.strategy=fsdp
critic.strategy=fsdp
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
running a custom script (modified from dapo_7b_math_fsdp2_8_8.sh) fails with:
AssertionError: Expects tensor to be on the compute device cuda:0, was on cpu
Cause:
FSDP's state_dict() requires the parameters to be on the GPU,
but at this point they are on the CPU (because offloading is enabled).
Proposed fix:
=== fully_async_policy/fsdp_workers.py:
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def get_actor_weights_info(self):
    assert self._is_actor
    if hasattr(self, "_weights_info"):
        return self._weights_info
    # added: if the params are offloaded, load them to the GPU first
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)
    if fsdp_version(self.actor_module_fsdp) == 1:
        from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType

        FSDP.set_state_dict_type(
            self.actor_module_fsdp,
            state_dict_type=StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(),
        )
    params = self._get_actor_params()
    ret = []
    for key, tensor in params.items():
        ret.append((key, tensor.size(), tensor.dtype))
    self._weights_info = ret
    # added: offload back to the CPU after use
    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)
    return ret
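One possible hardening of this patch: if `_get_actor_params()` raises, the parameters would stay on the GPU. The sketch below wraps the load/offload pair in a context manager so the offload runs in a `finally` block. The helper name `temporarily_loaded` and the `FakeModule` stand-in are hypothetical and exist only for illustration; in the worker, `load_fn` and `offload_fn` would presumably be `load_fsdp_model_to_gpu` and `offload_fsdp_model_to_cpu`, and `is_offloaded` would be `self._is_offload_param`.

```python
from contextlib import contextmanager


@contextmanager
def temporarily_loaded(module, is_offloaded, load_fn, offload_fn):
    """Hypothetical helper: load `module` for the duration of the block,
    then restore the offloaded state even if the block raises."""
    if is_offloaded:
        load_fn(module)
    try:
        yield module
    finally:
        if is_offloaded:
            offload_fn(module)


# Stand-ins so the sketch runs without a GPU or FSDP (assumptions, not verl code).
class FakeModule:
    def __init__(self):
        self.device = "cpu"


def fake_load(module):
    module.device = "cuda:0"


def fake_offload(module):
    module.device = "cpu"


events = []
module = FakeModule()
with temporarily_loaded(module, True, fake_load, fake_offload):
    events.append(module.device)  # device while inside the block
events.append(module.device)      # device after the block exits
```

With such a helper, the body of `get_actor_weights_info` between the load and the offload could sit inside a single `with` block, which keeps the two calls paired on every exit path.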
Expected behavior
When using FSDP with parameter offloading in fully async mode, model parameters should be loaded to the GPU and offloaded back to the CPU without errors.