-
Notifications
You must be signed in to change notification settings - Fork 468
support cp&dcp #3260
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
LookAround0301
wants to merge
14
commits into
vllm-project:main
Choose a base branch
from
LookAround0301:long_seq_dev
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
support cp&dcp #3260
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
f5862ac
[mla backend] support dcp&cp prefill
LookAround0301 d1ad588
model runner support cp: input ids, position ids and slot mapping
HiC4Sh1e c0e0f51
Merge pull request #28 from HiC4Sh1e/long_seq_dev
LookAround0301 b301659
model runner support cp: metadata, logits indices
HiC4Sh1e 2f36197
[mla backend] add num_computed_tokens_of_dcp_sp
LookAround0301 30e8076
Merge pull request #29 from HiC4Sh1e/long_seq_dev
LookAround0301 f887deb
[bug] fix config & block_table bug
LookAround0301 1bc86bc
[optim] support not enable cp and add env
LookAround0301 b69f45a
[bug] fix prefill bug
LookAround0301 8b333b9
[bug] fix decode bug (single batch)
LookAround0301 2470894
[bug] fix dcp bug
LookAround0301 8442fb8
[bug] fix block size bug
LookAround0301 9022138
[optim] clean code
LookAround0301 8dda1ba
GQA support pcp and dcp
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import os | ||
import time | ||
import argparse | ||
|
||
from vllm import LLM, SamplingParams | ||
|
||
os.environ["VLLM_USE_MODELSCOPE"] = "True" | ||
os.environ["VLLM_ASCEND_ENABLE_CP"] = "1" | ||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument('--input_len', type=int, default=1024) | ||
parser.add_argument('--output_len', type=int, default=128) | ||
parser.add_argument('--bs', type=int, default=1) | ||
parser.add_argument('--model_path', type=str, default="deepseek-ai/DeepSeek-V2-Lite") | ||
parser.add_argument('--tp', type=int, default=2) | ||
parser.add_argument('--cp', type=int, default=2) | ||
parser.add_argument('--dcp', type=int, default=1) | ||
parser.add_argument('--iter_times', type=int, default=1) | ||
|
||
args = parser.parse_args() | ||
|
||
prompts = [ | ||
"The capital of France is", | ||
"Hello, my name is Tom, I am", | ||
"The president of United States is", | ||
"AI future is" | ||
] | ||
|
||
sampling_params = SamplingParams(temperature = 0.8, top_p = 0.95, max_tokens=args.output_len) | ||
llm = LLM( | ||
model=args.model_path, | ||
trust_remote_code=True, | ||
enforce_eager=True, | ||
tensor_parallel_size=args.tp, | ||
context_parallel_size=args.cp, | ||
decode_context_parallel_size=args.dcp, | ||
enable_prefix_caching=False, | ||
enable_expert_parallel=True, | ||
enable_chunked_prefill=False, | ||
max_num_batched_tokens=2048, | ||
max_model_len=1024, | ||
additional_config={"ascend_scheduler_config": {"enabled": False}}, | ||
max_num_seqs=1, | ||
block_size=128, | ||
gpu_memory_utilization=0.9 | ||
) | ||
|
||
t0 = time.time() | ||
for _ in range(args.iter_times): | ||
outputs = llm.generate(prompts, sampling_params) | ||
t1 = time.time() | ||
print(f"TTFT: {(t1 - t0) * 1000 / (args.iter_times * args.bs)} ms") | ||
|
||
for i, output in enumerate(outputs): | ||
prompt = output.prompt | ||
generated_text = output.outputs[0].text | ||
print(f"req_num: {i}\nGenerated text: {generated_text!r}") |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.