Skip to content

Commit 96ec92b

Browse files
committed
rebase fix
Signed-off-by: Yongji Wu <[email protected]>
1 parent b39f0f3 commit 96ec92b

File tree

5 files changed

+18
-87
lines changed

5 files changed

+18
-87
lines changed

experimental/bench.sh

Lines changed: 0 additions & 16 deletions
This file was deleted.

experimental/scale.sh

Lines changed: 0 additions & 5 deletions
This file was deleted.

experimental/serve.sh

Lines changed: 0 additions & 49 deletions
This file was deleted.

vllm/distributed/elastic_ep/elastic_execute.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,24 +58,25 @@ def batch_transfer_weights(
5858
state_dict = model.state_dict()
5959
all_params = []
6060

61-
for _, param in state_dict.items():
61+
for name, param in state_dict.items():
62+
if name.endswith("expert_map"):
63+
continue
6264
if param.data_ptr() not in expert_weights_set:
6365
all_params.append(param.data)
6466

65-
if all_params:
66-
p2p_ops = []
67-
for param in all_params:
68-
op = object.__new__(P2POp)
69-
if is_sender:
70-
op.op = torch.distributed.isend
71-
op.tensor = param
72-
else:
73-
op.op = torch.distributed.irecv
74-
op.tensor = param
75-
op.group_peer = peer_rank
76-
p2p_ops.append(op)
77-
78-
device_comm.batch_isend_irecv(p2p_ops)
67+
assert len(all_params) > 0
68+
p2p_ops = []
69+
for param in all_params:
70+
op = object.__new__(P2POp)
71+
if is_sender:
72+
op.op = torch.distributed.isend
73+
op.tensor = param
74+
else:
75+
op.op = torch.distributed.irecv
76+
op.tensor = param
77+
op.group_peer = peer_rank
78+
p2p_ops.append(op)
79+
device_comm.batch_isend_irecv(p2p_ops)
7980

8081

8182
def broadcast_expert_mapping(

vllm/v1/engine/core_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,7 @@ def _ensure_output_queue_task(self):
892892
assert output_socket is not None
893893

894894
notification_callback_handler: Optional[
895-
Callable[[AsyncMPClient, tuple[str, int]], Any]
895+
Callable[[AsyncMPClient, Sequence[Any]], Any]
896896
] = getattr(self.__class__, "process_worker_notification", None)
897897

898898
async def process_outputs_socket():
@@ -913,7 +913,7 @@ async def process_outputs_socket():
913913
if outputs.utility_output.result is None:
914914
continue
915915
notification_data = outputs.utility_output.result.result
916-
assert isinstance(notification_data, tuple)
916+
assert isinstance(notification_data, Sequence)
917917
assert len(notification_data) == 2
918918
asyncio.create_task(
919919
notification_callback_handler(_self, notification_data)

0 commit comments

Comments
 (0)