
Commit a486ff8

Authored by LCAIZJ, CaveNightingale, nwpu-zxr, wangxiaoteng888, and Han-Xinlong
KVCache Transfer via Layer-wise Strategy in Disaggregation (#2602)
### What this PR does / why we need it?

See RFC: #2470. This PR adds a new KV connector for layer-wise KV transfer.

### Does this PR introduce _any_ user-facing change?

Yes, a new KV connector is added; users can now enable the layer-wise transfer feature.

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: vllm-project/vllm@releases/v0.11.0

Signed-off-by: leichao.lc <[email protected]>
Signed-off-by: CaveNightingale <[email protected]>
Signed-off-by: nwpu-zxr <[email protected]>
Signed-off-by: wangxiaoteng <[email protected]>
Signed-off-by: hanxinlong <[email protected]>
Signed-off-by: liziyu <[email protected]>
Co-authored-by: CaveNightingale <[email protected]>
Co-authored-by: nwpu-zxr <[email protected]>
Co-authored-by: wangxiaoteng <[email protected]>
Co-authored-by: hanxinlong <[email protected]>
1 parent f8c93d8 commit a486ff8
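
For orientation, a minimal sketch of how a user might enable a disaggregated-prefill KV connector through vLLM's `kv_transfer_config` is shown below. The connector name `"LayerwiseConnector"` is a placeholder assumption, not the name registered by this PR, and the extra fields the new connector expects are not visible in this commit summary.

```python
# Minimal sketch, assuming vLLM's standard KVTransferConfig plumbing.
# "LayerwiseConnector" is a placeholder name (assumption); the real connector
# name registered by this PR is not shown on this page.
from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    kv_transfer_config=KVTransferConfig(
        kv_connector="LayerwiseConnector",  # placeholder (assumption)
        kv_role="kv_producer",              # prefill side; decode side would use "kv_consumer"
    ),
)
```
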

File tree

10 files changed: +3012 −4 lines


examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py

Lines changed: 576 additions & 0 deletions
Large diffs are not rendered by default.

examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py

Lines changed: 1 addition & 1 deletion
@@ -544,4 +544,4 @@ async def healthcheck():
     global global_args
     global_args = parse_args()
     import uvicorn
-    uvicorn.run(app, host=global_args.host, port=global_args.port)
+    uvicorn.run(app, host=global_args.host, port=global_args.port)

tests/ut/distributed/test_parallel_state.py

Lines changed: 7 additions & 2 deletions
@@ -4,8 +4,9 @@
 from vllm.config import ParallelConfig
 
 from vllm_ascend.distributed.parallel_state import (
-    _LMTP, _MC2, _OTP, destroy_ascend_model_parallel, get_lmhead_tp_group,
-    get_mc2_group, get_otp_group, init_ascend_model_parallel)
+    _LMTP, _MC2, _OTP, _P_TP, destroy_ascend_model_parallel,
+    get_lmhead_tp_group, get_mc2_group, get_otp_group, get_p_tp_group,
+    init_ascend_model_parallel)
 
 
 @pytest.fixture
@@ -30,6 +31,7 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
     mock_ascend_config = MagicMock()
     mock_ascend_config.lmhead_tensor_parallel_size = 2
     mock_ascend_config.oproj_tensor_parallel_size = 2
+    mock_ascend_config.pd_tp_ratio = 2
     with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
         patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
         patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
@@ -38,11 +40,14 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
         mc2_group = get_mc2_group()
         lmheadtp_group = get_lmhead_tp_group()
         otp_group = get_otp_group()
+        p_tp_group = get_p_tp_group()
         assert mc2_group is not None
         assert otp_group is not None
         assert lmheadtp_group is not None
+        assert p_tp_group is not None
 
         destroy_ascend_model_parallel()
         assert _MC2 is None
         assert _LMTP is None
         assert _OTP is None
+        assert _P_TP is None
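
For readers unfamiliar with the group helpers touched above: only the names `_P_TP`, `get_p_tp_group`, and `pd_tp_ratio` are confirmed by this diff. Below is a rough sketch of querying the new prefill tensor-parallel group, assuming the getter returns a vLLM-style `GroupCoordinator` like the other vllm-ascend group getters.

```python
# Rough sketch (assumption: get_p_tp_group() returns a vLLM GroupCoordinator,
# as the other vllm_ascend group getters do). Only valid after
# init_ascend_model_parallel() has created the groups.
from vllm_ascend.distributed.parallel_state import get_p_tp_group

def describe_prefill_tp_group() -> str:
    group = get_p_tp_group()
    return f"prefill-TP rank {group.rank_in_group} of {group.world_size}"
```
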
