Skip to content

Commit 845420a

Browse files
authored
[RLHF] Fix torch.dtype not serializable in example (#22158)
Signed-off-by: 22quinn <[email protected]>
1 parent e27d25a commit 845420a

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

examples/offline_inference/rlhf.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,10 @@ def __init__(self, *args, **kwargs):
126126

127127
# Synchronize the updated weights to the inference engine.
128128
for name, p in train_model.named_parameters():
129-
handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
129+
dtype_name = str(p.dtype).split(".")[-1]
130+
handle = llm.collective_rpc.remote(
131+
"update_weight", args=(name, dtype_name, p.shape)
132+
)
130133
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
131134
ray.get(handle)
132135

examples/offline_inference/rlhf_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def init_weight_update_group(
4545
self.device,
4646
)
4747

48-
def update_weight(self, name, dtype, shape):
48+
def update_weight(self, name, dtype_name, shape):
49+
dtype = getattr(torch, dtype_name)
4950
weight = torch.empty(shape, dtype=dtype, device="cuda")
5051
self.model_update_group.broadcast(
5152
weight, src=0, stream=torch.cuda.current_stream()

0 commit comments

Comments
 (0)