diff --git a/verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py b/verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py index d4aa12490ca..4127d463796 100644 --- a/verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py +++ b/verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py @@ -239,7 +239,7 @@ def receive_weights(self, on_bucket_received: callable): # receive bucket and update weights while True: metadata = self.socket.recv_pyobj() - weights, tensor = [], None + weights = [] for name, meta in metadata["bucket_meta"].items(): shape, dtype, offset = meta["shape"], meta["dtype"], meta["offset"] size = dtype.itemsize * shape.numel() @@ -255,7 +255,10 @@ def receive_weights(self, on_bucket_received: callable): get_torch_device().synchronize() self.socket.send(b"") on_bucket_received(weights) - del weights, tensor + for _, tensor in weights: + del tensor + weights.clear() + del weights if metadata["is_last"]: break finally: