diff --git a/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml index d70ecb2a7e7b..d392a5f64062 100644 --- a/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml +++ b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash ./run-lm-eval-gsm-vllm-baseline.sh -m deepseek-ai/DeepSeek-V2-Lite-Chat -b "auto" -l 1000 -f 5 -t 2 model_name: "deepseek-ai/DeepSeek-V2-Lite-Chat" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml index 4397effa82cc..4b7776b20da2 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml @@ -1,3 +1,4 @@ +# For hf script, without -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml index fa6ea236ef04..05b66175199e 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml @@ -1,3 +1,4 @@ +# For hf script, without -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-70B-Instruct -b 32 -l 250 -f 5 model_name: "meta-llama/Meta-Llama-3-70B-Instruct" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml index c513159c6fa0..12a87e529014 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1 model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml index 5e57fcbcf7d9..7c7a1ca6edbf 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). 
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 -t 1 model_name: "nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml index 374171f1f915..1d45c3770458 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 1000 -f 5 -t 1 model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml index dc36b705634f..29a145252ef6 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1 model_name: "neuralmagic/Meta-Llama-3-8B-Instruct-FP8" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml index 0ecfc01ef049..3a5f120b3e71 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test -b "auto" -l 250 -f 5 -t 1 model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml index bc2900298596..5ff57bae4921 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). 
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test -b "auto" -l 250 -f 5 -t 1 model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml index 3964f3be5e87..07fb130464ab 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1 model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml index fb4b4915ab95..c27886525bbb 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml @@ -1,4 +1,5 @@ -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-8B-Instruct -b 32 -l 250 -f 5 -t 1 +# For hf script, without -t option (tensor parallel size). +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-8B-Instruct -b 32 -l 250 -f 5 model_name: "meta-llama/Meta-Llama-3-8B-Instruct" tasks: - name: "gsm8k" diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml index 042458659839..56ec933c9cc0 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1 model_name: "HandH1998/QQQ-Llama-3-8b-g128" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml index 78347f63fa79..83e11f2be77e 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1 model_name: "neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml b/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml index 4ef8b5c3709b..15a836dddbd8 100644 --- a/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml +++ b/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). 
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m mgoin/Minitron-4B-Base-FP8 -b auto -l 1000 -f 5 -t 1 model_name: "mgoin/Minitron-4B-Base-FP8" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml b/.buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml index 75a24e408e7a..5633a2d9b821 100644 --- a/.buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml +++ b/.buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic -b "auto" -l 250 -f 5 -t 8 model_name: "neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml index 436ec21924ca..b8024c80e8eb 100644 --- a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml +++ b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8 -b "auto" -l 250 -f 5 -t 4 model_name: "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml index dec9164d1b84..188a112ca3a4 100644 --- a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml +++ b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml @@ -1,4 +1,5 @@ -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1 -b 32 -l 250 -f 5 -t 4 +# For hf script, without -t option (tensor parallel size). +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1 -b 32 -l 250 -f 5 model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" tasks: - name: "gsm8k" diff --git a/.buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml index 166af81a3f0e..099e0f465bac 100644 --- a/.buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml @@ -1,11 +1,12 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16 -b auto -l 1319 -f 5 -t 1 model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16" tasks: - name: "gsm8k" metrics: - name: "exact_match,strict-match" - value: 0.31 + value: 0.30 - name: "exact_match,flexible-extract" - value: 0.47 + value: 0.465 limit: 1319 num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml index 42936fbfbe7d..426e8ff69873 100644 --- a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml +++ b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). 
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-FP8W8 -b auto -l 1000 -f 5 -t 1 model_name: "nm-testing/Qwen2-1.5B-Instruct-FP8W8" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml index 43ff2bc5ce35..8d57e9dabd56 100644 --- a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1 model_name: "neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml index 259799ba8bfa..1bce7e7fdf14 100644 --- a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml +++ b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise -b "auto" -l 1000 -f 5 -t 1 model_name: "nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise" tasks: diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml index 45d5efc8860f..fc9707d0d6f1 100644 --- a/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml +++ b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). # bash ./run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2-57B-A14B-Instruct -b "auto" -l 250 -f 5 -t 4 model_name: "Qwen/Qwen2-57B-A14B-Instruct" tasks: diff --git a/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml b/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml index 2928d75ce446..9a9c749748ec 100644 --- a/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml +++ b/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml @@ -1,3 +1,4 @@ +# For vllm script, with -t option (tensor parallel size). 
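The header comments added to these configs distinguish files meant for the HF baseline script (run without the -t option) from files meant for the vLLM baseline script, which takes -t for the tensor-parallel size. Each config pairs a model with expected GSM8K metric values that the harness checks against measured results within a relative tolerance (the RTOL constant relaxed later in this diff). As a rough sketch of that comparison pattern only — not the actual test_lm_eval_correctness.py code; the helper name and the numpy.isclose-based check are assumptions:

    import yaml
    import numpy as np

    RTOL = 0.08  # assumed relative tolerance; this diff relaxes it from 0.05

    def check_lm_eval_config(path: str, measured: dict[str, float]) -> None:
        # Load a config shaped like the YAML files above:
        # model_name, tasks -> metrics -> {name, value}, limit, num_fewshot.
        with open(path, encoding="utf-8") as f:
            cfg = yaml.safe_load(f)
        for task in cfg["tasks"]:
            for metric in task["metrics"]:
                name, expected = metric["name"], metric["value"]
                assert np.isclose(measured[name], expected, rtol=RTOL), (
                    f"{cfg['model_name']} / {task['name']} / {name}: "
                    f"expected {expected}, measured {measured[name]}")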
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2 model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM" tasks: diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py index 4ae23eff62f3..6015a83e8295 100644 --- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -16,7 +16,7 @@ import pytest import yaml -RTOL = 0.05 +RTOL = 0.08 TEST_DATA_FILE = os.environ.get( "LM_EVAL_TEST_DATA_FILE", ".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml") diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 3354ea37002b..a21a657c4b05 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -86,3 +86,18 @@ steps: - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" env: DOCKER_BUILDKIT: "1" + + - block: "Build Neuron release image" + key: block-neuron-release-image-build + depends_on: ~ + + - label: "Build and publish Neuron release image" + depends_on: block-neuron-release-image-build + agents: + queue: neuron-postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)" + env: + DOCKER_BUILDKIT: "1" diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 469422ddec20..368f30434aa1 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -98,6 +98,13 @@ if [[ $commands == *" kernels "* ]]; then --ignore=kernels/test_machete_mm.py \ --ignore=kernels/test_mha_attn.py \ --ignore=kernels/test_block_fp8.py \ + --ignore=kernels/test_cutlass_moe.py \ + --ignore=kernels/test_mamba_ssm_ssd.py \ + --ignore=kernels/test_attention.py \ + --ignore=kernels/test_block_int8.py \ + --ignore=kernels/test_fused_quant_layernorm.py \ + --ignore=kernels/test_int8_kernel.py \ + --ignore=kernels/test_triton_moe_ptpc_fp8.py \ --ignore=kernels/test_permute_cols.py" fi diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh index 9c5cf7cad948..5d863dd82e9b 100755 --- a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh @@ -5,10 +5,41 @@ set -ex # Setup cleanup -remove_docker_container() { docker rm -f cpu-test || true; docker system prune -f; } +remove_docker_container() { + if [[ -n "$container_id" ]]; then + podman rm -f "$container_id" || true + fi + podman system prune -f +} trap remove_docker_container EXIT remove_docker_container # Try building the docker image -docker build -t cpu-test -f docker/Dockerfile.ppc64le . +podman build -t cpu-test-ubi9-ppc -f docker/Dockerfile.ppc64le . 
+ +# Run the image +container_id=$(podman run -itd --entrypoint /bin/bash -v /tmp/:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN cpu-test-ubi9-ppc) + +function cpu_tests() { + + # offline inference + podman exec -it "$container_id" bash -c " + set -e + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" + + # Run basic model test + podman exec -it "$container_id" bash -c " + set -e + pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib + pip install sentence-transformers datamodel_code_generator + pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach] + pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5] + pytest -v -s tests/models/encoder_decoder/language -m cpu_model" +} + +# All of CPU tests are expected to be finished less than 40 mins. + +export container_id +export -f cpu_tests +timeout 40m bash -c cpu_tests diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index 87f74277cf90..21982b01b9cc 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -17,10 +17,13 @@ source /etc/environment docker run --privileged --net host --shm-size=16G -it \ -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \ vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \ - && python3 -m pip install pytest \ + && python3 -m pip install pytest pytest-asyncio tpu-info \ && python3 -m pip install lm_eval[api]==0.4.4 \ + && export VLLM_XLA_CACHE_PATH= \ && export VLLM_USE_V1=1 \ && export VLLM_XLA_CHECK_RECOMPILATION=1 \ + && echo HARDWARE \ + && tpu-info \ && echo TEST_0 \ && pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \ && echo TEST_1 \ @@ -40,7 +43,11 @@ docker run --privileged --net host --shm-size=16G -it \ && echo TEST_8 \ && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \ && echo TEST_9 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \ + && pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py \ + && echo TEST_10 \ + && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \ + && echo TEST_11 \ + && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \ # TODO: This test fails because it uses RANDOM_SEED sampling diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 16acc2fd1127..20d858cb15a1 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -8,6 +8,7 @@ # Documentation # label(str): the name of the test. emoji allowed. # fast_check(bool): whether to run this on each commit on fastcheck pipeline. +# torch_nightly(bool): whether to run this on vllm against torch nightly pipeline. # fast_check_only(bool): run this test on fastcheck pipeline only # optional(bool): never run this test by default (i.e. need to unblock manually) unless it's scheduled nightly run. # command(str): the single command to run for tests. incompatible with commands. 
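The documentation block above lists the per-step fields recognized by the test pipeline, including the new torch_nightly flag. As a hypothetical illustration of those constraints only (this validator is not part of the pipeline tooling; the function and its checks are assumptions drawn from the comments above):

    def validate_step(step: dict) -> None:
        # "command" is documented as incompatible with "commands".
        if "command" in step and "commands" in step:
            raise ValueError(f"step {step.get('label')!r}: "
                             "'command' is incompatible with 'commands'")
        # Boolean flags described in the documentation block.
        for flag in ("fast_check", "torch_nightly", "fast_check_only", "optional"):
            if flag in step and not isinstance(step[flag], bool):
                raise TypeError(f"step {step.get('label')!r}: "
                                f"'{flag}' must be a bool")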
@@ -70,6 +71,7 @@ steps: - label: Basic Correctness Test # 30min #mirror_hardwares: [amd] fast_check: true + torch_nightly: true source_file_dependencies: - vllm/ - tests/basic_correctness/test_basic_correctness @@ -104,6 +106,7 @@ steps: - label: Entrypoints Test # 40min working_dir: "/vllm-workspace/tests" fast_check: true + torch_nightly: true #mirror_hardwares: [amd] source_file_dependencies: - vllm/ @@ -118,7 +121,7 @@ steps: - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_openai_schema.py - pytest -v -s entrypoints/test_chat_utils.py - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests @@ -205,6 +208,8 @@ steps: - pytest -v -s v1/sample - pytest -v -s v1/worker - pytest -v -s v1/structured_output + - pytest -v -s v1/spec_decode + - pytest -v -s v1/test_serial_utils.py - pytest -v -s v1/test_stats.py - pytest -v -s v1/test_utils.py - pytest -v -s v1/test_oracle.py @@ -294,6 +299,7 @@ steps: commands: - pytest -v -s compile/test_pass_manager.py - pytest -v -s compile/test_fusion.py + - pytest -v -s compile/test_sequence_parallelism.py - label: PyTorch Fullgraph Smoke Test # 9min source_file_dependencies: @@ -312,15 +318,46 @@ steps: commands: - pytest -v -s compile/test_full_graph.py -- label: Kernels Test %N # 1h each - # mirror_hardwares: [amd] +- label: Kernels Core Operation Test source_file_dependencies: - csrc/ + - tests/kernels/core + commands: + - pytest -v -s kernels/core + +- label: Kernels Attention Test %N + source_file_dependencies: + - csrc/attention/ - vllm/attention - - tests/kernels + - vllm/v1/attention + - tests/kernels/attention commands: - - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT - parallelism: 4 + - pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 2 + +- label: Kernels Quantization Test %N + source_file_dependencies: + - csrc/quantization/ + - vllm/model_executor/layers/quantization + - tests/kernels/quantization + commands: + - pytest -v -s kernels/quantization --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 2 + +- label: Kernels MoE Test + source_file_dependencies: + - csrc/moe/ + - tests/kernels/moe + - vllm/model_executor/layers/fused_moe/ + commands: + - pytest -v -s kernels/moe + +- label: Kernels Mamba Test + source_file_dependencies: + - csrc/mamba/ + - tests/kernels/mamba + commands: + - pytest -v -s kernels/mamba - label: Tensorizer Test # 11min # mirror_hardwares: [amd] @@ -341,6 +378,13 @@ steps: commands: - bash scripts/run-benchmarks.sh +- label: Benchmarks CLI Test # 10min + source_file_dependencies: + - vllm/ + - tests/benchmarks/ + commands: + - pytest -v -s benchmarks/ + - label: Quantization Test # 33min source_file_dependencies: - csrc/ @@ -393,8 +437,9 @@ steps: - pytest -v -s 
models/test_transformers.py - pytest -v -s models/test_registry.py # V1 Test: https://github.com/vllm-project/vllm/issues/14531 - - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4' + - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2' - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4' + - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2' - label: Language Models Test (Standard) # 32min #mirror_hardwares: [amd] @@ -404,6 +449,8 @@ steps: - tests/models/embedding/language - tests/models/encoder_decoder/language commands: + # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. + - pip install causal-conv1d - pytest -v -s models/decoder_only/language -m 'core_model or quant_model' - pytest -v -s models/embedding/language -m core_model @@ -415,6 +462,8 @@ steps: - tests/models/embedding/language - tests/models/encoder_decoder/language commands: + # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. + - pip install causal-conv1d - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/language -m 'not core_model' @@ -535,11 +584,14 @@ steps: - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)' - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)' - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)' + # test sequence parallel + - pytest -v -s distributed/test_sequence_parallel.py # this test fails consistently. # TODO: investigate and fix # - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - label: Plugin Tests (2 GPUs) # 40min working_dir: "/vllm-workspace/tests" diff --git a/.github/workflows/snyk-security-scan.yml b/.github/workflows/snyk-security-scan.yml new file mode 100644 index 000000000000..eb0674a2e311 --- /dev/null +++ b/.github/workflows/snyk-security-scan.yml @@ -0,0 +1,24 @@ +name: snyk security scan +run-name: SNYK security scan for '${{ github.ref }}' + +on: [push] + +env: + SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }} + +jobs: + + SNYK: + runs-on: ubuntu-latest + timeout-minutes: 60 + + steps: + - uses: actions/checkout@master + + - uses: snyk/actions/setup@master + + - name: Generate code vulnerability report + id: run_snyk + run: | + snyk code test --project-name="${{ github.repository }}" --report ${{ github.workspace }} + continue-on-error: true diff --git a/.gitignore b/.gitignore index 6f5cbd0733da..728213ceb74f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,6 @@ # vllm-flash-attn built from source vllm/vllm_flash_attn/* -!vllm/vllm_flash_attn/fa_utils.py # Byte-compiled / optimized / DLL files __pycache__/ @@ -203,3 +202,6 @@ benchmarks/**/*.json # Linting actionlint shellcheck*/ + +# Ingore moe/marlin_moe gen code +csrc/moe/marlin_moe_wna16/kernel_* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e921f69925b6..f76b24c025ff 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,6 @@ repos: hooks: - id: yapf args: [--in-place, --verbose] - additional_dependencies: [toml] # TODO: Remove when yapf is upgraded - repo: 
https://github.com/astral-sh/ruff-pre-commit rev: v0.9.3 hooks: diff --git a/CMakeLists.txt b/CMakeLists.txt index a0c25df6bd54..3314f05fd2a0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -251,7 +251,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. # Please keep this in sync with FetchContent_Declare line below. - set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v3.9.0" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -269,7 +269,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git # Please keep this in sync with CUTLASS_REVISION line above. - GIT_TAG v3.8.0 + GIT_TAG v3.9.0 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. @@ -290,7 +290,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" - "csrc/cutlass_extensions/common.cpp") + "csrc/cutlass_extensions/common.cpp" + "csrc/attention/mla/cutlass_mla_entry.cu") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" @@ -463,7 +464,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(FP4_ARCHS) endif() - # + # CUTLASS MLA Archs and flags + cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS) + set(SRCS + "csrc/attention/mla/cutlass_mla_kernels.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${MLA_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1") + # Add MLA-specific include directories only to MLA source files + set_source_files_properties(${SRCS} + PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_DIR}/examples/77_blackwell_fmha;${CUTLASS_DIR}/examples/common") + message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}") + else() + message(STATUS "Not building CUTLASS MLA as no compatible archs were found.") + # clear MLA_ARCHS + set(MLA_ARCHS) + endif() + # CUTLASS MoE kernels # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works @@ -609,21 +629,51 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) - set(MARLIN_MOE_SRC - "csrc/moe/marlin_kernels/marlin_moe_kernel.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu" - "csrc/moe/marlin_moe_ops.cu") + # + # For the Marlin MOE kernels we automatically generate sources for various + # preselected input type pairs and schedules. 
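The CMake logic that follows regenerates the Marlin MoE kernel sources only when the generator script itself changes, by hashing generate_kernels.py with file(MD5 ...) and comparing against a cached value from the previous configure. The same idea as a short Python sketch, for clarity only — the stamp-file approach and names below are assumptions, not how the build actually stores the hash (CMake keeps it in a cache variable):

    import hashlib
    import pathlib
    import subprocess

    def regenerate_if_changed(script: pathlib.Path, stamp: pathlib.Path) -> None:
        new_hash = hashlib.md5(script.read_bytes()).hexdigest()
        old_hash = stamp.read_text().strip() if stamp.exists() else ""
        if new_hash == old_hash:
            print("generator unchanged, skipping kernel generation")
            return
        # Run the generator and record the hash only if it succeeds.
        subprocess.run(["python3", str(script)], check=True)
        stamp.write_text(new_hash)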
+ # Generate sources: + set(MOE_MARLIN_GEN_SCRIPT + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py) + file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH) + + message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}") + message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}") + + if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} + OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH + ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} + RESULT_VARIABLE moe_marlin_generation_result + OUTPUT_VARIABLE moe_marlin_generation_output + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log + ) + + if (NOT moe_marlin_generation_result EQUAL 0) + message(FATAL_ERROR "Marlin MOE generation failed." + " Result: \"${moe_marlin_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log") + else() + set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH} + CACHE STRING "Last run Marlin MOE generate script hash" FORCE) + message(STATUS "Marlin MOE generation completed successfully.") + endif() + else() + message(STATUS "Marlin MOE generation script has not changed, skipping generation.") + endif() + + file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu") set_gencode_flags_for_srcs( - SRCS "${MARLIN_MOE_SRC}" + SRCS "${MOE_WNAA16_MARLIN_SRC}" CUDA_ARCHS "${MARLIN_MOE_ARCHS}") - list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}") + list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC}) + message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") else() message(STATUS "Not building Marlin MOE kernels as no compatible archs found" @@ -648,6 +698,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP") # set(VLLM_ROCM_EXT_SRC "csrc/rocm/torch_bindings.cpp" + "csrc/rocm/skinny_gemms.cu" "csrc/rocm/attention.cu") define_gpu_extension_target( diff --git a/Dockerfile.rocm.ubi b/Dockerfile.rocm.ubi index 88924a2a7768..5673e9831d0c 100644 --- a/Dockerfile.rocm.ubi +++ b/Dockerfile.rocm.ubi @@ -1,9 +1,11 @@ ## Global Args ################################################################## ARG BASE_UBI_IMAGE_TAG=9.5-1742914212 ARG PYTHON_VERSION=3.12 + # Default ROCm ARCHes to build vLLM for. ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" ARG MAX_JOBS=12 +ARG VLLM_TGIS_ADAPTER_VERSION=0.7.0 FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base @@ -44,13 +46,12 @@ gpgcheck=1\n\ gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key" > /etc/yum.repos.d/amdgpu.repo -RUN --mount=type=cache,target=/root/.cache/pip \ - --mount=type=cache,target=/root/.cache/uv \ +RUN --mount=type=cache,target=/root/.cache/uv \ export version="$(awk -F. 
'{print $1"."$2}' <<< $ROCM_VERSION)" && \ uv pip install --pre \ - --index-url "https://download.pytorch.org/whl/nightly/rocm${version}" \ - torch==2.7.0.dev20250308+rocm${version}\ - torchvision==0.22.0.dev20250308+rocm${version} && \ + --index-url "https://download.pytorch.org/whl/rocm${version}" \ + torch==2.7.0+rocm${version}\ + torchvision==0.22.0+rocm${version} && \ # Install libdrm-amdgpu to avoid errors when retrieving device information (amdgpu.ids: No such file or directory) microdnf install -y --nodocs libdrm-amdgpu && \ microdnf clean all @@ -185,6 +186,8 @@ RUN CFLAGS="-O3 -Wall -Werror=format-security -Wno-unused-function -Wp,-D_GLIBCX FROM rocm_base AS vllm-openai ARG MAX_JOBS +ARG FLASH_ATTENTION_WHEEL_STRATEGY +ARG VLLM_WHEEL_STRATEGY WORKDIR /workspace diff --git a/Dockerfile.ubi b/Dockerfile.ubi index b03d42bd3162..b5d0f8ecd0c5 100644 --- a/Dockerfile.ubi +++ b/Dockerfile.ubi @@ -47,8 +47,9 @@ RUN curl -Lo /etc/yum.repos.d/cuda-rhel9.repo \ ENV CUDA_HOME="/usr/local/cuda" \ PATH="${CUDA_HOME}/bin:${PATH}" ENV LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${CUDA_HOME}/lib64/stubs/:${CUDA_HOME}/extras/CUPTI/lib64:${LD_LIBRARY_PATH}" -RUN microdnf install -y --nodocs \ - cuda-nvcc-12-4 cuda-nvtx-12-4 cuda-libraries-devel-12-4 && \ + +RUN microdnf install -y \ + cuda-nvcc-12-8 cuda-nvtx-12-8 cuda-libraries-devel-12-8 && \ microdnf clean all && \ ln -s ${CUDA_HOME}/lib64/stubs/libcuda.so /usr/lib64/ diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 287d500a81de..efd51c79c37c 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import io import json import os import sys @@ -32,6 +33,7 @@ class RequestFuncInput: extra_body: Optional[dict] = None multi_modal_content: Optional[dict] = None ignore_eos: bool = False + language: Optional[str] = None @dataclass @@ -436,6 +438,110 @@ async def async_request_openai_chat_completions( return output +async def async_request_openai_audio( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + # Lazy import without PlaceholderModule to avoid vllm dep. + import soundfile + api_url = request_func_input.api_url + assert api_url.endswith( + ("transcriptions", "translations" + )), "OpenAI Chat Completions API URL must end with 'transcriptions' " + "or `translations`." 
+ + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + content = [{"type": "text", "text": request_func_input.prompt}] + payload = { + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, + "temperature": 0.0, + "max_completion_tokens": request_func_input.output_len, + "stream": True, + "language": "en", + # Flattened due to multipart/form-data + "stream_include_usage": True, + "stream_continuous_usage_stats": True + } + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + # Send audio file + def to_bytes(y, sr): + buffer = io.BytesIO() + soundfile.write(buffer, y, sr, format="WAV") + buffer.seek(0) + return buffer + + with to_bytes(*request_func_input.multi_modal_content['audio']) as f: + form = aiohttp.FormData() + form.add_field('file', f, content_type='audio/wav') + for key, value in payload.items(): + form.add_field(key, str(value)) + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, + data=form, + headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) + + if choices := data.get("choices"): + content = choices[0]["delta"].get( + "content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append( + timestamp - most_recent_timestamp) + + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + + most_recent_timestamp = timestamp + + output.generated_text = generated_text + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + def get_model(pretrained_model_name_or_path: str) -> str: if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true': from modelscope import snapshot_download @@ -493,6 +599,7 @@ def get_tokenizer( "deepspeed-mii": async_request_deepspeed_mii, "openai": async_request_openai_completions, "openai-chat": async_request_openai_chat_completions, + "openai-audio": async_request_openai_audio, "tensorrt-llm": async_request_trt_llm, "scalellm": async_request_openai_completions, "sglang": async_request_openai_completions, diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 63f174275d47..ccbc6c022f1f 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -64,6 +64,7 @@ class SampleRequest: class BenchmarkDataset(ABC): DEFAULT_SEED = 0 + IS_MULTIMODAL = False def __init__( self, @@ -621,6 +622,7 @@ class ConversationDataset(HuggingFaceDataset): SUPPORTED_DATASET_PATHS = { 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered' } + IS_MULTIMODAL = True def sample(self, tokenizer: PreTrainedTokenizerBase, @@ -685,6 +687,7 @@ class 
VisionArenaDataset(HuggingFaceDataset): "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"] } + IS_MULTIMODAL = True def sample( self, @@ -815,3 +818,80 @@ def sample(self, )) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests + + +# ----------------------------------------------------------------------------- +# ASR Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ASRDataset(HuggingFaceDataset): + """ + Dataset class for processing a ASR dataset for transcription. + Tested on the following set: + + +----------------+----------------------------------------+--------------------------+-----------------------------+ + | Dataset | Domain | Speaking Style | hf-subset | + +----------------+----------------------------------------+--------------------------+-----------------------------+ + | TED-LIUM | TED talks | Oratory | release1, release2, release3| + | | | | release3-speaker-adaptation | + | VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... | + | LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" | + | GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test | + | SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test | + | AMI | Meetings | Spontaneous | ihm, sdm | + +----------------+----------------------------------------+--------------------------+-----------------------------+ + + """ # noqa: E501 + SUPPORTED_DATASET_PATHS = { + "openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium", + "edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech" + } + + DEFAULT_OUTPUT_LEN = 128 + IS_MULTIMODAL = True + + # TODO Whisper-specific. Abstract interface when more models are supported. 
+ TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\ + "<|notimestamps|>" + skip_long_audios: bool = True + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + **kwargs, + ) -> list: + import librosa + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + prompt = ASRDataset.TRANSCRIPTION_PREAMBLE + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests = [] + skipped = 0 + for item in self.data: + if len(sampled_requests) >= num_requests: + break + audio = item["audio"] + y, sr = audio["array"], audio["sampling_rate"] + duration_s = librosa.get_duration(y=y, sr=sr) + # Whisper max supported duration + if self.skip_long_audios and duration_s > 30: + skipped += 1 + continue + + mm_content = {"audio": (y, sr)} + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + if skipped: + logger.warning("%d samples discarded from dataset due to" \ + " their length being greater than" \ + " what Whisper supports.", skipped) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 4fff7a8fc8ed..f44da95d3216 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -63,14 +63,16 @@ class Request: output_len: int -def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str: +def sample_tokens(tokenizer: PreTrainedTokenizerBase, + length: int) -> list[int]: vocab = tokenizer.get_vocab() + all_special_ids = set(tokenizer.all_special_ids) + # Remove the special tokens. - vocab = { - k: v - for k, v in vocab.items() if k not in tokenizer.all_special_ids - } - return random.choices(list(vocab.values()), k=length) + return random.choices( + [v for k, v in vocab.items() if k not in all_special_ids], + k=length, + ) def sample_requests_from_dataset( diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index b5bd840d8410..da124e1a81b4 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -50,7 +50,7 @@ except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser -from benchmark_dataset import (AIMODataset, BurstGPTDataset, +from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset, ConversationDataset, HuggingFaceDataset, InstructCoderDataset, RandomDataset, SampleRequest, ShareGPTDataset, SonnetDataset, @@ -274,10 +274,6 @@ async def benchmark( input_requests[0].expected_output_len, \ input_requests[0].multi_modal_data - if backend != "openai-chat" and test_mm_content is not None: - # multi-modal benchmark is only available on OpenAI Chat backend. 
- raise ValueError( - "Multi-modal content is only supported on 'openai-chat' backend.") assert test_mm_content is None or isinstance(test_mm_content, dict) test_input = RequestFuncInput( model=model_id, @@ -604,6 +600,9 @@ def main(args: argparse.Namespace): elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: dataset_class = AIMODataset args.hf_split = "train" + elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: + dataset_class = ASRDataset + args.hf_split = "train" else: supported_datasets = set([ dataset_name for cls in HuggingFaceDataset.__subclasses__() @@ -615,6 +614,13 @@ def main(args: argparse.Namespace): f" from one of following: {supported_datasets}. " "Please consider contributing if you would " "like to add support for additional dataset formats.") + + if (dataset_class.IS_MULTIMODAL and backend not in \ + ["openai-chat", "openai-audio"]): + # multi-modal benchmark is only available on OpenAI Chat backend. + raise ValueError( + "Multi-modal content is only supported on 'openai-chat' and " \ + "'openai-audio' backend.") input_requests = dataset_class( dataset_path=args.dataset_path, dataset_subset=args.hf_subset, @@ -707,7 +713,7 @@ def main(args: argparse.Namespace): )) # Save config and results to json - if args.save_result: + if args.save_result or args.append_result: result_json: dict[str, Any] = {} # Setup @@ -728,6 +734,14 @@ def main(args: argparse.Namespace): raise ValueError( "Invalid metadata format. Please use KEY=VALUE format." ) + # Traffic + result_json["request_rate"] = (args.request_rate if args.request_rate + < float("inf") else "inf") + result_json["burstiness"] = args.burstiness + result_json["max_concurrency"] = args.max_concurrency + + # Merge with benchmark result + result_json = {**result_json, **benchmark_result} if not args.save_detailed: # Remove fields with too many data points @@ -738,15 +752,6 @@ def main(args: argparse.Namespace): if field in result_json: del result_json[field] - # Traffic - result_json["request_rate"] = (args.request_rate if args.request_rate - < float("inf") else "inf") - result_json["burstiness"] = args.burstiness - result_json["max_concurrency"] = args.max_concurrency - - # Merge with benchmark result - result_json = {**result_json, **benchmark_result} - # Save to file base_model_id = model_id.split("/")[-1] max_concurrency_str = (f"-concurrency{args.max_concurrency}" @@ -756,7 +761,12 @@ def main(args: argparse.Namespace): file_name = args.result_filename if args.result_dir: file_name = os.path.join(args.result_dir, file_name) - with open(file_name, "w", encoding='utf-8') as outfile: + with open(file_name, + mode="a+" if args.append_result else "w", + encoding='utf-8') as outfile: + # Append a newline. 
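With the new --append-result flag, each benchmark run is appended to the results file as its own JSON object, with a newline written first when the file is non-empty, so the file ends up in a JSON-Lines-style layout. A hypothetical reader for such a file, shown only to illustrate that layout (not part of the benchmark scripts):

    import json

    def load_appended_results(path: str) -> list[dict]:
        results = []
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    results.append(json.loads(line))
        return results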
+ if args.append_result and outfile.tell() != 0: + outfile.write("\n") json.dump(result_json, outfile) save_to_pytorch_benchmark_format(args, result_json, file_name) @@ -888,6 +898,11 @@ def main(args: argparse.Namespace): help="When saving the results, whether to include per request " "information such as response, error, ttfs, tpots, etc.", ) + parser.add_argument( + "--append-result", + action="store_true", + help="Append the benchmark result to the existing json file.", + ) parser.add_argument( "--metadata", metavar="KEY=VALUE", diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index e52f16a8b129..74ee00ec8930 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -51,7 +51,7 @@ except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser -from vllm.v1.structured_output.utils import ( +from vllm.v1.structured_output.backend_xgrammar import ( has_xgrammar_unsupported_json_features) MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -150,17 +150,17 @@ def get_schema(index: int): elif args.dataset == "grammar": schema = """ - ?start: select_statement + root ::= select_statement - ?select_statement: "SELECT " column_list " FROM " table_name + select_statement ::= "SELECT " column " from " table " where " condition - ?column_list: column_name ("," column_name)* + column ::= "col_1 " | "col_2 " - ?table_name: identifier + table ::= "table_1 " | "table_2 " - ?column_name: identifier + condition ::= column "= " number - ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ + number ::= "1 " | "2 " """ prompt = "Generate an SQL query to show the 'username' \ and 'email' from the 'users' table." diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 67e509c1f550..1f65277e1bfe 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -523,6 +523,13 @@ def validate_args(args): raise ValueError( "Tokenizer must be the same as the model for MII backend.") + # --data-parallel is not supported currently. + # https://github.com/vllm-project/vllm/issues/16222 + if args.data_parallel_size > 1: + raise ValueError( + "Data parallel is not supported in offline benchmark, \ + please use benchmark serving instead") + if __name__ == "__main__": parser = FlexibleArgumentParser(description="Benchmark the throughput.") diff --git a/benchmarks/kernels/benchmark_bitblas.py b/benchmarks/kernels/benchmark_bitblas.py new file mode 100644 index 000000000000..b23b4f3ea685 --- /dev/null +++ b/benchmarks/kernels/benchmark_bitblas.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + MINIMUM_BITBLAS_VERSION) + +try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError("bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") +except ImportError as e: + bitblas_import_exception = e + raise ValueError("Trying to use the bitblas backend, but could not import" + f"with the following error: {bitblas_import_exception}. 
" + "Please install bitblas through the following command: " + f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" + ) from bitblas_import_exception + +from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target + +from vllm.utils import FlexibleArgumentParser + +parser = FlexibleArgumentParser( + description="Benchmark BitBLAS int4 on a specific target.") + +# Add arguments to the parser +parser.add_argument( + "--target", + type=str, + default=auto_detect_nvidia_target(), + help="Specify the target device for benchmarking.", +) +parser.add_argument("--group_size", + type=int, + default=None, + help="Group size for grouped quantization.") +parser.add_argument( + "--A_dtype", + type=str, + default="float16", + choices=["float16", "float32", "float64", "int32", "int8"], + help="Data type of activation A.", +) +parser.add_argument( + "--W_dtype", + type=str, + default="int4", + choices=[ + "float16", + "float32", + "float64", + "int32", + "int8", + "int4", + "int2", + "int1", + "nf4", + "fp4_e2m1", + ], + help="Data type of weight W.", +) +parser.add_argument( + "--accum_dtype", + type=str, + default="float16", + choices=["float16", "int32"], + help="Data type for accumulation.", +) +parser.add_argument( + "--out_dtype", + type=str, + default="float16", + choices=["float16", "float32", "int32", "int8"], + help="Data type for output.", +) +parser.add_argument( + "--layout", + type=str, + default="nt", + choices=["nt", "nn"], + help="Matrix layout, 'nt' for non-transpose A and transpose W.", +) +parser.add_argument("--with_bias", + action="store_true", + help="Include bias in the benchmark.") +parser.add_argument( + "--with_scaling", + action="store_true", + help="Include scaling factor in the quantization.", +) +parser.add_argument("--with_zeros", + action="store_true", + help="Include zeros in the quantization.") +parser.add_argument( + "--zeros_mode", + type=str, + default=None, + choices=["original", "rescale", "quantized"], + help="Specify the mode for calculating zeros.", +) + +# Parse the arguments +args = parser.parse_args() + +# Assign arguments to variables +target = args.target +A_dtype = args.A_dtype +W_dtype = args.W_dtype +accum_dtype = args.accum_dtype +out_dtype = args.out_dtype +layout = args.layout +with_bias = args.with_bias +group_size = args.group_size +with_scaling = args.with_scaling +with_zeros = args.with_zeros +zeros_mode = args.zeros_mode + +# Define a list of shared arguments that repeat in every config +shared_args = [ + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, +] + +# Define just the (M, K, N) shapes in a more compact list +shapes = [ + # square test + (1, 16384, 16384), + # BLOOM-176B + (1, 43008, 14336), + (1, 14336, 14336), + (1, 57344, 14336), + (1, 14336, 57344), + # OPT-65B + (1, 9216, 9216), + (1, 36864, 9216), + (1, 9216, 36864), + (1, 22016, 8192), + # LLAMA-70B/65B + (1, 8192, 22016), + (1, 8192, 8192), + (1, 28672, 8192), + (1, 8192, 28672), + # square test + (16384, 16384, 16384), + # BLOOM-176B + (8192, 43008, 14336), + (8192, 14336, 14336), + (8192, 57344, 14336), + (8192, 14336, 57344), + # OPT-65B + (8192, 9216, 9216), + (8192, 36864, 9216), + (8192, 9216, 36864), + (8192, 22016, 8192), + # LLAMA-70B/65B + (8192, 8192, 22016), + (8192, 8192, 8192), + (8192, 28672, 8192), + (8192, 8192, 28672), +] + +# Build test shapes with all the shared arguments +test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) + for shape in shapes] + +benchmark_sets = 
[] +benchmark_sets.extend(test_shapes) + +benchmark_results = {} +for config_class, operator, input_args in benchmark_sets: + config = config_class(*input_args) + matmul = operator(config, target=target, enable_tuning=True) + kernel_latency = matmul.profile_latency() + + print("Time cost is: {:.3f} ms".format(kernel_latency)) + + profile_config = { + f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": { + "BitBLAS_top20_latency": kernel_latency, + } + } + + benchmark_results.update(profile_config) + +# Define headers for the table +headers = [ + "PrimFunc", + "Input Arguments", + "BitBLAS Top20 Latency", +] + +# Calculate column widths for pretty printing +col_widths = [0, 0, 0] +for config_key, values in benchmark_results.items(): + args_split = config_key.split("-") + func_name = args_split[0] + input_args_str = "-".join(args_split[1:]) + col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2) + col_widths[1] = max(col_widths[1], + len(input_args_str) + 2, + len(headers[1]) + 2) + col_widths[2] = max(col_widths[2], + len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2, + len(headers[2]) + 2) + # break only if you want to measure widths from a single example; + # otherwise, let it loop over all items. + +# Print header +for i, header in enumerate(headers): + headers[i] = header.ljust(col_widths[i]) +print("".join(headers)) +print("-" * sum(col_widths)) + +# Print rows +for config_key, values in benchmark_results.items(): + args_split = config_key.split("-") + func_name = args_split[0] + input_args_str = "-".join(args_split[1:]) + row = [ + func_name, + input_args_str, + f"{values['BitBLAS_top20_latency']:.3f} ms", + ] + row_str = "".join( + [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]) + print(row_str) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index b4b91eda2844..d382ede10b41 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -17,8 +17,14 @@ from utils import ArgPool, Bench, CudaGraphBenchParams from weight_shapes import WEIGHT_SHAPES -from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink -from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand, + lora_shrink) + from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT, + _LORA_B_PTR_DICT) + from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index afe0b53077a7..a274537a6751 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -527,7 +527,7 @@ def get_weight_block_size_safety(config, default_value=None): def main(args: argparse.Namespace): print(args) - block_quant_shape = None + config = AutoConfig.from_pretrained( args.model, trust_remote_code=args.trust_remote_code) if config.architectures[0] == "DbrxForCausalLM": @@ -546,16 +546,16 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size - block_quant_shape = get_weight_block_size_safety(config) - elif config.architectures[0] == "Qwen2MoeForCausalLM": + elif config.architectures[0] in [ + "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM" + ]: E = config.num_experts topk = 
config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size else: - if not hasattr(config, "hidden_size"): - # Support for llama4 - config = config.text_config + # Support for llama4 + config = config.get_text_config() # Default: Mixtral. E = config.num_local_experts topk = config.num_experts_per_tok @@ -566,6 +566,7 @@ def main(args: argparse.Namespace): dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16" + block_quant_shape = get_weight_block_size_safety(config) if args.batch_size is None: batch_sizes = [ diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index afd7c47e8ac0..b04e4c2d06ed 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22 + GIT_TAG 8798f27777fb57f447070301bf33a9f9c607f491 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu index 7af0caceda2f..14e5edd7e283 100644 --- a/csrc/attention/merge_attn_states.cu +++ b/csrc/attention/merge_attn_states.cu @@ -107,13 +107,14 @@ __global__ void merge_attn_states_kernel( #define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \ { \ - vllm::merge_attn_states_kernel<<>>( \ - reinterpret_cast(output.data_ptr()), output_lse_ptr, \ - reinterpret_cast(prefix_output.data_ptr()), \ - reinterpret_cast(prefix_lse.data_ptr()), \ - reinterpret_cast(suffix_output.data_ptr()), \ - reinterpret_cast(suffix_lse.data_ptr()), num_tokens, \ - num_heads, head_size); \ + vllm::merge_attn_states_kernel \ + <<>>( \ + reinterpret_cast(output.data_ptr()), output_lse_ptr, \ + reinterpret_cast(prefix_output.data_ptr()), \ + reinterpret_cast(prefix_lse.data_ptr()), \ + reinterpret_cast(suffix_output.data_ptr()), \ + reinterpret_cast(suffix_lse.data_ptr()), num_tokens, \ + num_heads, head_size); \ } /*@brief Merges the attention states from prefix and suffix @@ -122,10 +123,10 @@ __global__ void merge_attn_states_kernel( * @param output [n,h,d] The output tensor to store the merged attention states. * @param output_lse [h,d] Optional tensor to store the log-sum-exp values. * @param prefix_output [n,h,d] The prefix attention states. - * @param prefix_lse [h,d] The log-sum-exp values for the prefix attention + * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention * states. * @param suffix_output [n,h,d] The suffix attention states. - * @param suffix_lse [h,d] The log-sum-exp values for the suffix attention + * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention * states. */ template @@ -146,13 +147,17 @@ void merge_attn_states_launcher(torch::Tensor& output, if (output_lse.has_value()) { output_lse_ptr = output_lse.value().data_ptr(); } - // process one pack elements per thread. float -> 4, half/bf16 -> 8 + // Process one pack elements per thread. for float, the + // pack_size is 4 for half/bf16, the pack_size is 8. 
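Editorial aside on the pack-size comment just above: the launch geometry that follows it is simple enough to model directly. The sketch below is a minimal Python illustration of that arithmetic, not part of the diff; the 128-thread block size is an assumption used only for the example.

def merge_attn_states_grid(num_tokens, num_heads, head_size, elem_bytes, num_threads=128):
    # One thread handles one pack: 4 elements for 4-byte float, 8 for 2-byte
    # half/bf16, i.e. 16 bytes per thread in both cases.
    pack_size = 16 // elem_bytes
    threads_per_head = head_size // pack_size
    total_threads = num_tokens * num_heads * threads_per_head
    grid_blocks = (total_threads + num_threads - 1) // num_threads  # ceil division
    return grid_blocks, num_threads

# e.g. merge_attn_states_grid(32, 16, 128, elem_bytes=2) -> (64, 128)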
const uint threads_per_head = head_size / pack_size; const uint total_threads = num_tokens * num_heads * threads_per_head; dim3 block(NUM_THREADS); dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS); + const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS); } diff --git a/csrc/attention/mla/cutlass_mla_entry.cu b/csrc/attention/mla/cutlass_mla_entry.cu new file mode 100644 index 000000000000..0319d1daf302 --- /dev/null +++ b/csrc/attention/mla/cutlass_mla_entry.cu @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA +void cutlass_mla_decode_sm100a(torch::Tensor const& out, + torch::Tensor const& q_nope, + torch::Tensor const& q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table, double scale); +#endif + +void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, + torch::Tensor const& q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table, double scale) { +#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA + return cutlass_mla_decode_sm100a(out, q_nope, q_pe, kv_c_and_k_pe_cache, + seq_lens, page_table, scale); +#endif + TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA"); +} diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu new file mode 100644 index 000000000000..6743af0cf2db --- /dev/null +++ b/csrc/attention/mla/cutlass_mla_kernels.cu @@ -0,0 +1,225 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.h" + +#include "cutlass_extensions/common.hpp" + +#include "device/sm100_mla.hpp" +#include "kernel/sm100_mla_tile_scheduler.hpp" + +using namespace cute; +using namespace cutlass::fmha::kernel; + +template +struct MlaSm100 { + using Element = T; + using ElementAcc = float; + using ElementOut = T; + + using TileShape = Shape<_128, _128, Shape<_512, _64>>; + using TileShapeH = cute::tuple_element_t<0, TileShape>; + using TileShapeD = cute::tuple_element_t<2, TileShape>; + + // H K (D_latent D_rope) B + using ProblemShape = cute::tuple; + + using StrideQ = cute::tuple; // H D B + using StrideK = cute::tuple; // K D B + using StrideO = StrideK; // H D B + using StrideLSE = cute::tuple<_1, int>; // H B + + using TileScheduler = + std::conditional_t; + + using FmhaKernel = + cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized< + TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler, + /*kIsCpAsync=*/true>; + using Fmha = cutlass::fmha::device::MLA; +}; + +template +typename T::Fmha::Arguments args_from_options( + at::Tensor const& out, at::Tensor const& q_nope, at::Tensor const& q_pe, + at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens, + at::Tensor const& page_table, double scale) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = q_nope.device().index(); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + + int batches = q_nope.sizes()[0]; + int page_count_per_seq = page_table.sizes()[1]; + int page_count_total = kv_c_and_k_pe_cache.sizes()[0]; + int page_size = kv_c_and_k_pe_cache.sizes()[1]; + int max_seq_len = page_size * page_count_per_seq; + using TileShapeH = typename T::TileShapeH; + using TileShapeD = typename T::TileShapeD; + auto problem_shape = + cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + using StrideQ = typename T::StrideQ; + using StrideK = typename T::StrideK; + using StrideO = typename T::StrideO; + using StrideLSE = typename T::StrideLSE; + + StrideQ stride_Q_latent = cute::make_tuple( + static_cast(D_latent), _1{}, static_cast(H * D_latent)); + StrideQ stride_Q_rope = cute::make_tuple(static_cast(D_rope), _1{}, + static_cast(H * D_rope)); + StrideK stride_C = + cute::make_tuple(static_cast(D_latent + D_rope), _1{}, + static_cast(page_size * (D_latent + D_rope))); + StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq); + StrideLSE stride_LSE = cute::make_tuple(_1{}, static_cast(H)); + StrideO stride_O = cute::make_tuple(static_cast(D_latent), _1{}, + static_cast(H * D_latent)); + + using Element = typename T::Element; + using ElementOut = typename T::ElementOut; + using ElementAcc = typename T::ElementAcc; + auto Q_latent_ptr = static_cast(q_nope.data_ptr()); + auto Q_rope_ptr = static_cast(q_pe.data_ptr()); + auto C_ptr = static_cast(kv_c_and_k_pe_cache.data_ptr()); + auto scale_f = static_cast(scale); + typename T::Fmha::Arguments arguments{ + problem_shape, + {scale_f, Q_latent_ptr, stride_Q_latent, Q_rope_ptr, stride_Q_rope, C_ptr, + stride_C, C_ptr + D_latent, stride_C, + static_cast(seq_lens.data_ptr()), + static_cast(page_table.data_ptr()), stride_PT, page_count_total, + page_size}, + {static_cast(out.data_ptr()), stride_O, + static_cast(nullptr), stride_LSE}, + hw_info, + -1, // split_kv + nullptr, // 
is_var_split_kv + }; + // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute + // split_kv automatically based on batch size and sequence length to balance + // workload across available SMs. Consider using var_split_kv for manual + // control if needed. + T::Fmha::set_split_kv(arguments); + return arguments; +} + +template +void runMla(at::Tensor const& out, at::Tensor const& q_nope, + at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, + at::Tensor const& seq_lens, at::Tensor const& page_table, + float scale, cudaStream_t stream) { + using MlaSm100Type = MlaSm100; + typename MlaSm100Type::Fmha fmha; + auto arguments = args_from_options( + out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale); + size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(q_nope.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + CUTLASS_CHECK(fmha.can_implement(arguments)); + + CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream)); + + CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream)); +} + +void cutlass_mla_decode_sm100a(torch::Tensor const& out, + torch::Tensor const& q_nope, + torch::Tensor const& q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table, double scale) { + TORCH_CHECK(q_nope.device().is_cuda(), "q_nope must be on CUDA"); + TORCH_CHECK(q_nope.dim() == 3, "q_nope must be a 3D tensor"); + TORCH_CHECK(q_pe.dim() == 3, "q_pe must be a 3D tensor"); + TORCH_CHECK(kv_c_and_k_pe_cache.dim() == 3, + "kv_c_and_k_pe_cache must be a 3D tensor"); + TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D tensor"); + TORCH_CHECK(page_table.dim() == 2, "page_table must be a 2D tensor"); + TORCH_CHECK(out.dim() == 3, "out must be a 3D tensor"); + + auto B_q_nope = q_nope.size(0); + auto H_q_nope = q_nope.size(1); + auto D_q_nope = q_nope.size(2); + auto B_q_pe = q_pe.size(0); + auto H_q_pe = q_pe.size(1); + auto D_q_pe = q_pe.size(2); + auto B_pt = page_table.size(0); + auto PAGE_NUM = page_table.size(1); + auto PAGE_SIZE = kv_c_and_k_pe_cache.size(1); + auto D_ckv = kv_c_and_k_pe_cache.size(2); + auto B_o = out.size(0); + auto H_o = out.size(1); + auto D_o = out.size(2); + + TORCH_CHECK(D_q_nope == 512, "D_q_nope must be equal to 512"); + TORCH_CHECK(D_q_pe == 64, "D_q_pe must be equal to 64"); + TORCH_CHECK(D_ckv == 576, "D_ckv must be equal to 576"); + TORCH_CHECK(H_q_nope == H_q_pe && H_q_nope == H_o && H_o == 128, + "H_q_nope, H_q_pe, and H_o must be equal to 128"); + TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0, + "PAGE_SIZE must be a power of 2"); + TORCH_CHECK( + B_q_nope == B_q_pe && B_q_nope == B_pt && B_q_nope == B_o, + "Batch dims must be same for page_table, q_nope and q_pe, and out"); + TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0, + "PAGE_NUM must be divisible by 128 / PAGE_SIZE"); + TORCH_CHECK(D_o == 512, "D_o must be equal to 512"); + + TORCH_CHECK(q_nope.dtype() == at::ScalarType::Half || + q_nope.dtype() == at::ScalarType::BFloat16 || + q_nope.dtype() == at::ScalarType::Float8_e4m3fn, + "q_nope must be a half, bfloat16, or float8_e4m3fn tensor"); + TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope.dtype() && + q_nope.dtype() == q_pe.dtype(), + "kv_c_and_k_pe_cache, q_nope, and q_pe must be the same type"); + TORCH_CHECK(seq_lens.dtype() == torch::kInt32, + "seq_lens must be a 32-bit integer 
tensor"); + TORCH_CHECK(page_table.dtype() == torch::kInt32, + "page_table must be a 32-bit integer tensor"); + + auto in_dtype = q_nope.dtype(); + at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()}; + const cudaStream_t stream = + at::cuda::getCurrentCUDAStream(q_nope.get_device()); + if (in_dtype == at::ScalarType::Half) { + runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, + page_table, scale, stream); + } else if (in_dtype == at::ScalarType::BFloat16) { + runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, + seq_lens, page_table, scale, stream); + } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { + runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, + seq_lens, page_table, scale, stream); + } else { + TORCH_CHECK(false, "Unsupported input data type of MLA"); + } +} diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 0b3f6fc8c19a..88559c8fe718 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel( cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, // head_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int block_stride, const int key_stride, const int value_stride, - const int num_heads, const int head_size, const int block_size, - const float* k_scale, const float* v_scale) { + const int64_t block_stride, const int64_t page_stride, + const int64_t head_stride, const int64_t key_stride, + const int64_t value_stride, const int num_heads, const int head_size, + const int block_size, const float* k_scale, const float* v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel( const int head_idx = i / head_size; const int head_offset = i % head_size; const int64_t tgt_key_value_idx = block_idx * block_stride + - block_offset * num_heads * head_size + - head_idx * head_size + head_offset; + block_offset * page_stride + + head_idx * head_stride + head_offset; scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { @@ -396,16 +397,16 @@ void reshape_and_cache( // KV_T is the data type of key and value tensors. // CACHE_T is the stored data type of kv-cache. // KV_DTYPE is the real data type of kv-cache. 
-#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \ - vllm::reshape_and_cache_flash_kernel \ - <<>>( \ - reinterpret_cast(key.data_ptr()), \ - reinterpret_cast(value.data_ptr()), \ - reinterpret_cast(key_cache.data_ptr()), \ - reinterpret_cast(value_cache.data_ptr()), \ - slot_mapping.data_ptr(), block_stride, key_stride, \ - value_stride, num_heads, head_size, block_size, \ - reinterpret_cast(k_scale.data_ptr()), \ +#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_flash_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, page_stride, \ + head_stride, key_stride, value_stride, num_heads, head_size, \ + block_size, reinterpret_cast(k_scale.data_ptr()), \ reinterpret_cast(v_scale.data_ptr())); void reshape_and_cache_flash( @@ -432,9 +433,11 @@ void reshape_and_cache_flash( int head_size = key.size(2); int block_size = key_cache.size(1); - int key_stride = key.stride(0); - int value_stride = value.stride(0); - int block_stride = key_cache.stride(0); + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); + int64_t block_stride = key_cache.stride(0); + int64_t page_stride = key_cache.stride(1); + int64_t head_stride = key_cache.stride(2); TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0)); dim3 grid(num_tokens); diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py new file mode 100644 index 000000000000..d1c0d92f6814 --- /dev/null +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +import glob +import itertools +import os +import subprocess + +import jinja2 + +FILE_HEAD = """ +// auto generated by generate.py +// clang-format off + +#include "kernel.h" +#include "marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { +""".strip() + +TEMPLATE = ("template __global__ void Marlin<" + "{{scalar_t}}, " + "{{w_type_id}}, " + "{{threads}}, " + "{{thread_m_blocks}}, " + "{{thread_n_blocks}}, " + "{{thread_k_blocks}}, " + "{{'true' if m_block_size_8 else 'false'}}, " + "{{stages}}, " + "{{'true' if has_act_order else 'false'}}, " + "{{'true' if has_zp else 'false'}}, " + "{{group_blocks}}, " + "{{'true' if is_zp_float else 'false'}}>" + "( MARLIN_KERNEL_PARAMS );") + +# int8 with zero point case (vllm::kU8) is also supported, +# we don't add it to reduce wheel size. 
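As a reading aid for the generator that follows: it enumerates combinations of scalar type, dtype, group_blocks, and thread-block configuration, rendering one explicit Marlin template instantiation per combination, and group_blocks encodes the quantization granularity as noted in its comments. A minimal Python restatement of that convention, editorial only and not part of the generated file:

def describe_group_blocks(group_blocks: int) -> str:
    # Convention used by GROUP_BLOCKS below:
    #   0  -> act-order case
    #  -1  -> channelwise quantization
    #  >0  -> grouped quantization with group_size = 16 * group_blocks
    if group_blocks == 0:
        return "act_order"
    if group_blocks == -1:
        return "channelwise"
    return f"group_size={16 * group_blocks}"

assert describe_group_blocks(8) == "group_size=128"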
+SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"] +THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] + +THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] +# group_blocks: +# = 0 : act order case +# = -1 : channelwise quantization +# > 0 : group_size=16*group_blocks +GROUP_BLOCKS = [0, -1, 2, 4, 8] +DTYPES = ["fp16", "bf16"] + + +def remove_old_kernels(): + for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): + subprocess.call(["rm", "-f", filename]) + + +def generate_new_kernels(): + for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): + has_zp = "B" not in scalar_type + all_template_str_list = [] + + for group_blocks, m_blocks, thread_configs in itertools.product( + GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): + + has_act_order = group_blocks == 0 + if has_zp and has_act_order: + continue + if thread_configs[2] == 256: + if m_blocks <= 1 and thread_configs[0] != 128: + continue + if m_blocks > 1 and thread_configs[0] != 64: + continue + + k_blocks = thread_configs[0] // 16 + n_blocks = thread_configs[1] // 16 + threads = thread_configs[2] + + c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" + + template_str = jinja2.Template(TEMPLATE).render( + scalar_t=c_dtype, + w_type_id=scalar_type + ".id()", + threads=threads, + thread_m_blocks=max(m_blocks, 1), + thread_n_blocks=n_blocks, + thread_k_blocks=k_blocks, + m_block_size_8=m_blocks == 0.5, + stages="pipe_stages", + has_act_order=has_act_order, + has_zp=has_zp, + group_blocks=group_blocks, + is_zp_float=False, + ) + + all_template_str_list.append(template_str) + + file_content = FILE_HEAD + "\n\n" + file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" + filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" + + with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: + f.write(file_content) + + +if __name__ == "__main__": + remove_old_kernels() + generate_new_kernels() diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h new file mode 100644 index 000000000000..3d92660e8028 --- /dev/null +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -0,0 +1,44 @@ + +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +#include "quantization/gptq_marlin/marlin.cuh" +#include "quantization/gptq_marlin/marlin_dtypes.cuh" +#include "core/scalar_type.hpp" + +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ + const int *__restrict__ g_idx, \ + const int32_t *__restrict__ sorted_token_ids_ptr, \ + const int32_t *__restrict__ expert_ids_ptr, \ + const int32_t *__restrict__ num_tokens_past_padded_ptr, \ + const float *__restrict__ topk_weights_ptr, int top_k, \ + bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ + int prob_n, int prob_k, int *locks, bool use_atomic_add, \ + bool use_fp32_reduce + +namespace MARLIN_NAMESPACE_NAME { +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? 
+ > +__global__ void Marlin(MARLIN_KERNEL_PARAMS); + +} diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h new file mode 100644 index 000000000000..205b308fe511 --- /dev/null +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -0,0 +1,1917 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +#include "quantization/gptq_marlin/marlin.cuh" +#include "quantization/gptq_marlin/marlin_dtypes.cuh" +#include "core/scalar_type.hpp" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace MARLIN_NAMESPACE_NAME { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids + const int32_t* __restrict__ expert_ids_ptr, // moe expert ids + const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens + const float* __restrict__ topk_weights_ptr, // moe top weights + int top_k, // num of experts per token + bool mul_topk_weights, // mul topk weights or not + bool is_ep, // expert parallelism + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce // whether to use fp32 global reduce +) {} + +} // namespace MARLIN_NAMESPACE_NAME + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. 
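For the m16n8k16 shape convention referenced above: each instruction produces a 16x8 fp32 tile from a 16x16 fp16 (or bf16) A fragment and a 16x8 B fragment, accumulating in fp32. The NumPy reference below captures only the numerics, not the per-thread fragment layout of the PTX instruction; it is an editorial illustration and assumes NumPy is available.

import numpy as np

def mma_m16n8k16_reference(a, b, c):
    # a: (16, 16) float16, b: (16, 8) float16, c: (16, 8) float32 accumulator.
    # fp16 inputs with fp32 accumulation, as in
    # mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32.
    assert a.shape == (16, 16) and b.shape == (16, 8) and c.shape == (16, 8)
    return a.astype(np.float32) @ b.astype(np.float32) + c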
+template +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +template +__device__ inline void mma_trans( + const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + const typename ScalarType::FragB& frag_b2, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* b2 = reinterpret_cast(&frag_b2); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (count == 4) { + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); + } else if constexpr (count == 2) { + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) + : "r"(smem)); + } else if constexpr (count == 1) { + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" + : "=r"(a[0]) + : "r"(smem)); + } else { + static_assert(count == 1 || count == 2 || count == 4, "invalid count"); + } +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. 
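The lop3 helper below drives most of the dequantization bit tricks in this file, so its semantics are worth spelling out: the 8-bit immediate is a truth table over the three inputs, with 0xF0, 0xCC and 0xAA acting as the canonical selectors for a, b and c. A small Python model, editorial only:

def lop3(a: int, b: int, c: int, lut: int) -> int:
    # Bit i of the result is the LUT entry selected by (a_i, b_i, c_i),
    # with a as the most significant selector bit.
    res = 0
    for i in range(32):
        idx = (((a >> i) & 1) << 2) | (((b >> i) & 1) << 1) | ((c >> i) & 1)
        res |= ((lut >> idx) & 1) << i
    return res

# (0xF0 & 0xCC) | 0xAA == 0xEA encodes (a & b) | c, the pattern used by the
# 4-bit dequantization below:
q = 0x12345678
assert lop3(q, 0x000F000F, 0x64006400, 0xEA) == (q & 0x000F000F) | 0x64006400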
+template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline typename ScalarType::FragB dequant( + int q, typename ScalarType::FragB& frag_b); + +// +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +// +template <> +__device__ inline typename ScalarType::FragB dequant( + int q, typename ScalarType::FragB& frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); + int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant(int q, + typename ScalarType::FragB& frag_b) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. 
+ + int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + q >>= 4; + int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC308C308; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +// +template <> +__device__ inline typename ScalarType::FragB dequant( + int q, typename ScalarType::FragB& frag_b) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant(int q, + typename ScalarType::FragB& frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); + + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. 
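Putting the pieces together: the dequant specializations above produce (value - zero_point) in fp16/bf16, and the scale helpers below multiply by the group's scale. As a concrete check of the fp16 kU4B8 constants above, here is an editorial NumPy sketch; it assumes NumPy and a little-endian host so that the low 16 bits of a 32-bit word map to the first half of a half2.

import numpy as np

def dequant_u4b8_fp16_reference(q: int) -> list:
    # Follows dequant<half, kU4B8>: embed each low nibble into the mantissa of
    # a 1024-biased fp16 (0x6400), then strip the bias and the +8 zero point
    # with one subtract (0x6408 == 1032.0); the high nibbles carry an extra
    # factor of 16, removed via an fma with 1/16 (0x2C00) and -72.0 (0xD480).
    def as_halves(word: int) -> np.ndarray:
        return np.frombuffer(np.uint32(word).tobytes(), dtype=np.float16)
    lo = as_halves((q & 0x000F000F) | 0x64006400) - as_halves(0x64086408)
    hi = (as_halves((q & 0x00F000F0) | 0x64006400) * as_halves(0x2C002C00)
          + as_halves(0xD480D480))
    return [float(x) for x in lo] + [float(x) for x in hi]

# Nibbles 0..15 map to -8..7; e.g. q == 0 gives four values of -8.0.
assert dequant_u4b8_fp16_reference(0) == [-8.0, -8.0, -8.0, -8.0]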
+template +__device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +template +__device__ inline void scale_and_sub( + typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s2 = ScalarType::num2num2(s); + scalar_t2 zp2 = ScalarType::num2num2(zp); + frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); + frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); +} + +template +__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, + typename ScalarType::scalar_t2& frag_zp, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 zp = + ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + +// Same as above, but for act_order (each K is multiplied individually) +template +__device__ inline void scale4(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s_1, + typename ScalarType::FragS& frag_s_2, + typename ScalarType::FragS& frag_s_3, + typename ScalarType::FragS& frag_s_4, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float* c, + typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// Wait until value of lock to be negative, and then add 1 +__device__ inline void wait_negative_and_add(int* lock) { + if (threadIdx.x == 0) { + int state = 0; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. 
+ asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state >= 0); + atomicAdd(lock, 1); + } + __syncthreads(); +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids + const int32_t* __restrict__ expert_ids_ptr, // moe expert ids + const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens + const float* __restrict__ topk_weights_ptr, // moe top weights + int top_k, // num of experts per token + bool mul_topk_weights, // mul topk weights or not + bool is_ep, // expert parallelism + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce // whether to use fp32 global reduce +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; + + extern __shared__ int4 sh[]; + static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + + constexpr int pack_factor = 32 / w_type.size_bits(); + static_assert(thread_m_blocks == 1 || !m_block_size_8); + constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); + const int group_size = + (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; + const int scales_expert_stride = prob_n * prob_k / group_size / 8; + const int zp_expert_stride = + is_zp_float ? 
prob_n * prob_k / group_size / 8 + : prob_n * prob_k / group_size / (pack_factor * 4); + + // parallel: num valid moe blocks + int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; + int parallel = num_tokens_past_padded / moe_block_size; + int num_valid_blocks = parallel; + if (is_ep) { + for (int i = 0; i < parallel; i++) { + if (expert_ids_ptr[i] == -1) num_valid_blocks--; + } + } + int num_invalid_blocks = parallel - num_valid_blocks; + parallel = num_valid_blocks; + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + int par_id = 0; + int block_id = -1; + int64_t expert_id = 0; // use int64 to avoid computation result overflow + int old_expert_id = 0; + int64_t B_expert_off = 0; + + int4* sh_block_sorted_ids_int4 = sh; + int32_t* sh_block_sorted_ids = + reinterpret_cast(sh_block_sorted_ids_int4); + int4* sh_block_topk_weights_int4 = + sh_block_sorted_ids_int4 + moe_block_size / 4; + scalar_t2* sh_block_topk_weights = + reinterpret_cast(sh_block_topk_weights_int4); + int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 4; + + int32_t block_num_valid_tokens = 0; + int32_t locks_off = 0; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; + } + if (parallel * n_tiles >= gridDim.x) { + // when parallel * n_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } else { + locks_off = (iters * blockIdx.x) / k_tiles - 1; + } + + // read moe block data given block_id + // block_sorted_ids / block_num_valid_tokens / block_topk_weights + auto read_moe_block_data = [&](int block_id) { + block_num_valid_tokens = moe_block_size; + #pragma unroll + for (int i = 0; i < moe_block_size / 4; i++) { + int4 sorted_token_ids_int4 = reinterpret_cast( + sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; + int* sorted_token_ids = reinterpret_cast(&sorted_token_ids_int4); + #pragma unroll + for (int j = 0; j < 4; j++) { + if (sorted_token_ids[j] >= prob_m * top_k) { + block_num_valid_tokens = i * 4 + j; + break; + } + } + if (block_num_valid_tokens != moe_block_size) break; + } + + __syncthreads(); + int tid4 = threadIdx.x / 4; + if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) { + sh_block_sorted_ids_int4[tid4] = reinterpret_cast( + sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4]; + + if (mul_topk_weights) { + #pragma unroll + for (int i = 0; i < 4; i++) { + sh_block_topk_weights[tid4 * 4 + i] = + Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])); + 
} + } + } + __syncthreads(); + }; + + // when move to next moe block, find the next block_id and expert_id + // and then read moe block data + auto update_next_moe_block_data = [&]() { + if (par_id >= parallel) return; + + old_expert_id = expert_id; + if (num_invalid_blocks > 0) { + int skip_count = block_id == -1 ? par_id : 0; + block_id++; + for (int i = block_id; i < num_tokens_past_padded / moe_block_size; i++) { + expert_id = expert_ids_ptr[i]; + if (expert_id != -1) { + if (skip_count == 0) { + block_id = i; + break; + }; + skip_count--; + }; + } + } else { + block_id = par_id; + expert_id = expert_ids_ptr[block_id]; + } + + B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); + scales_ptr += (expert_id - old_expert_id) * scales_expert_stride; + if constexpr (has_zp) { + zp_ptr += (expert_id - old_expert_id) * zp_expert_stride; + } + if constexpr (has_act_order) { + g_idx += (expert_id - old_expert_id) * prob_k; + } + + read_moe_block_data(block_id); + }; + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&](bool first_init = false) { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (parallel * n_tiles >= gridDim.x) { + if (slice_count > 1 && slice_idx == slice_count - 1) { + locks_off++; + } + } else { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = + div_ceil(block_num_valid_tokens, threads / threads_per_m); + for (int i = 0; i < m_per_thread; i++) { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[row]; + int col = slice_col * 16 * thread_n_blocks / 8 + + threadIdx.x % threads_per_m; + C[sorted_row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. 
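One way to read the lock protocol described above: with use_atomic_add and slice_count cooperating blocks, the first block publishes 1 - slice_count after zero-filling its output slice, and each remaining block increments the negative value once it observes it, so the lock ends at 0 and can be reused. A tiny Python model of the value sequence, editorial only and assuming slice_count - 1 increments as implied by the comment:

def lock_value_sequence(slice_count: int) -> list:
    # locks[locks_off] starts at 1 - slice_count (negative whenever more than
    # one block shares the slice); each wait_negative_and_add() adds 1, so
    # after slice_count - 1 increments the lock is back at 0.
    values = [1 - slice_count]
    for _ in range(slice_count - 1):
        values.append(values[-1] + 1)
    return values

assert lock_value_sequence(4) == [-3, -2, -1, 0]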
+ __syncthreads(); + if (threadIdx.x == 0) locks[locks_off] = 1 - slice_count; + } + + if (slice_col == n_tiles) { + slice_col = 0; + par_id++; + update_next_moe_block_data(); + } + }; + + update_next_moe_block_data(); + init_slice(true); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float + ? 16 * thread_n_blocks / 8 + : ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 
2 : 1)); + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + int zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 8; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + if constexpr (is_zp_float) { + if constexpr (group_blocks != -1) { + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. 
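To make the remapping above concrete, here is the same transform_a computation as an editorial Python sketch. Note that in the CUDA expression `a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row` the addition binds tighter than the XOR, so the whole linear offset is XOR-ed with its row index; that swizzle is what keeps the 16-byte int4 accesses of 8 consecutive threads on distinct shared memory banks, as the comment above explains. The value 16 for a_gl_rd_delta_o below is only an example.

def transform_a(i: int, a_gl_rd_delta_o: int) -> int:
    # Python model of the XOR-based shared-memory layout for A tiles.
    row = i // a_gl_rd_delta_o
    return (a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o)) ^ row

# With a_gl_rd_delta_o == 16, rows 0..3 of column 0 land at offsets
# 0, 17, 34, 51 instead of the conflict-prone 0, 16, 32, 48.
assert [transform_a(16 * r, 16) for r in range(4)] == [0, 17, 34, 51]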
+ int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh_new; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + int4* sh_red = sh_b; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. 
+ int a_remaining_load_count_in_slice = stages; + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + if (prob_k > thread_k_blocks * 16 * stages || slice_col == 0 || + a_remaining_load_count_in_slice > 0) { + a_remaining_load_count_in_slice--; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int64_t sorted_row = 0; + if (!m_block_size_8 || row < 8) + sorted_row = sh_block_sorted_ids[row] / top_k; + int64_t true_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], + row < block_num_valid_tokens); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], + B_ptr[i] + j + B_expert_off); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_col_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + auto fetch_col_scale_to_shared = [&]() { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). 
+ cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm( + frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } else if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } 
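In the non-act-order branch above, the quantization group of a sub-tile is found purely by index arithmetic from the warp's row and the current register sub-iteration. A host-side sketch of that computation with assumed tile parameters (the k_iter_size and group_blocks values are illustrative only):

#include <cstdio>

int main() {
  // Assumed geometry: each warp row spans 16 K values, group_blocks is the
  // quantization group size divided by 16, and k_iter_size is the K span of
  // one register sub-iteration.
  const int group_blocks = 2;     // e.g. group_size 32
  const int k_iter_size = 16;
  const int b_sh_wr_iters = 2;

  for (int warp_row = 0; warp_row < 2; warp_row++) {
    for (int k = 0; k < b_sh_wr_iters; k++) {
      // Same arithmetic as the group_blocks < thread_k_blocks case above.
      int cur_k = warp_row * 16 + k_iter_size * (k % b_sh_wr_iters);
      int k_blocks = cur_k / 16;
      int cur_group_id = k_blocks / group_blocks;
      printf("warp_row=%d k=%d -> k_blocks=%d group=%d\n", warp_row, k,
             k_blocks, cur_group_id);
    }
  }
  return 0;
}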
+ + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp && !is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + #pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + + else if constexpr (has_zp && is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + int cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = + sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } + }; + + // Execute the actual tensor core matmul of a sub-tile. 
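The matmul lambda that follows dequantizes each fragment as w = s * (q - zp), folding the zero-point into the scale so the multiply by s happens only once (scale_and_sub). A scalar model of that arithmetic on two made-up packed 4-bit values, ignoring the interleaved register layout the real dequant routines operate on; symmetric GPTQ types (e.g. u4b8) use a fixed bias instead of a stored zero-point:

#include <cstdio>

int main() {
  const int pack = 0x53;        // two assumed 4-bit values packed: 0x3, 0x5
  const float scale = 0.125f;   // assumed per-group scale
  const int zp = 4;             // assumed per-group zero-point

  int q0 = pack & 0xF;          // low nibble
  int q1 = (pack >> 4) & 0xF;   // high nibble
  float zp_times_s = zp * scale;        // precomputed once per group
  float w0 = q0 * scale - zp_times_s;   // == s * (q0 - zp)
  float w1 = q1 * scale - zp_times_s;
  printf("w0=%f w1=%f\n", w0, w1);      // -0.125, 0.125
  return 0;
}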
+ bool is_first_matmul_in_slice = true; + auto matmul = [&](int k) { + int k2 = k % 2; + const bool is_new_zp = + ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) { + if (is_new_zp) { + if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant(zp_quant_0, frag_zp_0); + dequant(zp_quant_1, frag_zp_1); + + frag_zp[0] = frag_zp_0[0]; + frag_zp[1] = frag_zp_0[1]; + frag_zp[2] = frag_zp_1[0]; + frag_zp[3] = frag_zp_1[1]; + } + } + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + dequant(b_quant_0, frag_b0); + dequant(b_quant_1, frag_b1); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + static_assert(group_blocks != -1); + scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k][2][j], act_frag_s[k2][3][j], 1); + + } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Dtype::nums2num2( + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (has_zp && !is_zp_float && group_blocks != -1) { + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], + *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k % 2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k % 2][j][0].y, frag_zp[j].y); + } else if constexpr (has_zp && is_zp_float && group_blocks != -1) { + if (is_new_zp) + frag_zpf[k2][j] = __hmul2( + frag_zpf[k2][j], *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j].x, frag_zpf[k2][j].x); + scale_and_sub(frag_b1, frag_s[k2][j].y, frag_zpf[k2][j].y); + } else if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + if constexpr (m_block_size_8) { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + } else { + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. 
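thread_block_reduce, defined next, folds the fp32 partials of warps that share an output location pairwise through shared memory, halving the number of active producers each step. A simplified host-side model of that tree reduction (producer count and values are made up):

#include <cstdio>

int main() {
  const int producers = 4;                    // warps per output location
  float partial[producers] = {1.f, 2.f, 3.f, 4.f};

  for (int i = producers / 2; i > 0; i /= 2) {
    // Producers [i, 2i) publish their running sums; producers [0, i) add
    // them. On the device this goes through sh_red with __syncthreads()
    // between steps.
    for (int r = 0; r < i; r++) partial[r] += partial[r + i];
  }
  printf("reduced = %f\n", partial[0]);       // 10
  return 0;
}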
+ auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast( + &sh_red[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { + float* c_rd = + reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + if (!is_th_active) { + return; + } + + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } else { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + if (!first) { + + #pragma unroll + for (int i = 0; i < (m_block_size_8 ? 
2 : thread_m_blocks * 4); i++) { + int c_idx; + if constexpr (m_block_size_8) + c_idx = c_gl_wr + i * c_gl_stride + + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i; + else + c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + if (c_idx / c_gl_stride < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; + int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx]; + } + } + } + + #pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + if (!first) { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + + int c_idx; + if constexpr (m_block_size_8) + c_idx = c_gl_wr + i * c_gl_stride + + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i; + else + c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + if (c_idx / c_gl_stride < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride]; + int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + C[true_idx] = c; + } + } + } + }; + + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. + auto global_reduce_fp32 = [&](bool first = false, bool last = false) { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; + + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; + + int c_cur_offset = locks_off * c_size; + + if (!is_th_active) { + return; + } + + if (!first) { + float* frag_c_ptr = reinterpret_cast(&frag_c); + #pragma unroll + for (int k = 0; k < th_size; k++) { + if constexpr (m_block_size_8) { + if (k % 2) continue; + } else { + if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens) + continue; + } + + sh_red[threadIdx.x] = + C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + + float* sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); + #pragma unroll + for (int f = 0; f < 4; f++) { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } + } + + if (!last) { + int4* frag_c_ptr = reinterpret_cast(&frag_c); + #pragma unroll + for (int k = 0; k < th_size; k++) { + if constexpr (m_block_size_8) { + if (k % 2) continue; + } else { + if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens) + continue; + } + + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
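In the fp32 path above, every K-slice that owns the same output tile adds its partial into the temporary buffer, and only the last slice folds the total back into C; the lock array serializes the slices. A scalar host-side model of that first/last protocol with made-up partials:

#include <cstdio>

int main() {
  const int slice_count = 3;
  float C_tmp = 0.0f;                      // fp32 scratch (C_tmp on device)
  float C_out = 0.0f;                      // final output element
  float partial[3] = {0.5f, 1.25f, -0.25f};

  for (int slice_idx = 0; slice_idx < slice_count; slice_idx++) {
    bool first = slice_idx == 0;
    bool last = slice_idx == slice_count - 1;
    float acc = partial[slice_idx];
    if (!first) acc += C_tmp;              // read earlier partial sums
    if (!last)
      C_tmp = acc;                         // stash for the next slice
    else
      C_out = acc;                         // only the last slice writes C
  }
  printf("C_out = %f\n", C_out);           // 1.5
  return 0;
}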
+ auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr; + if constexpr (m_block_size_8) { + c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + + (threadIdx.x % 32) / 4; + c_sh_wr += 64 * (threadIdx.x / 32); + } else { + c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + } + + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4 && !has_zp) { + res = __hmul2(res, s[0]); + } + + if constexpr (m_block_size_8) { + ((scalar_t*)sh_red)[idx] = res.x; + ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; + } else { + ((scalar_t2*)sh_red)[idx] = res; + } + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + if constexpr (m_block_size_8) { + int wr = c_sh_wr + 16 * j; + write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], + frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], + frag_s[j / 2][2 * (j % 2) + 1]); + } else { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + int row = c_gl_wr / c_gl_stride; + if (row < block_num_valid_tokens) { + int64_t sorted_row = sh_block_sorted_ids[row]; + int64_t true_idx = sorted_row * c_gl_stride + c_gl_wr % c_gl_stride; + scalar_t2 topk_weight_score; + if (mul_topk_weights) topk_weight_score = sh_block_topk_weights[row]; + if (use_atomic_add && slice_count > 1 || mul_topk_weights) { + scalar_t2* C_half2 = reinterpret_cast(&C[true_idx]); + scalar_t2* sh_red_half2 = + reinterpret_cast(&sh_red[c_sh_rd]); + #pragma unroll + for (int a = 0; a < 4; a++) { + scalar_t2 res = sh_red_half2[a]; + if (mul_topk_weights) { + res = __hmul2(res, topk_weight_score); + } + + if (use_atomic_add && slice_count > 1) { + atomicAdd(&C_half2[a], res); + } else { + C_half2[a] = res; + }; + } + } else { + C[true_idx] = sh_red[c_sh_rd]; + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + __syncthreads(); + }; + + // Start global fetch and register load pipelines. 
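start_pipes below pre-issues the first stages - 1 asynchronous tile fetches, and the main loop then keeps that many fetches running ahead of the tile currently being consumed. A schematic host-side model of that overlap (tile counts are made up, and the cp.async mechanics are reduced to counters):

#include <algorithm>
#include <cstdio>

int main() {
  const int stages = 4;
  const int total_tiles = 8;

  int issued = std::min(stages - 1, total_tiles);  // start_pipes()
  for (int tile = 0; tile < total_tiles; tile++) {
    if (issued < total_tiles) issued++;            // fetch_to_shared(...)
    // wait_for_stage() guarantees the buffer holding `tile` has landed
    // before it is consumed.
    int ahead = issued - tile - 1;
    printf("tile %d: %d fetches issued ahead\n", tile, ahead);
  }
  return 0;
}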
+ auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], + g_idx[last_g_idx]); + } + + if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + if (i == 0) { + fetch_col_zp_to_shared(); + fetch_col_scale_to_shared(); + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + a_remaining_load_count_in_slice = 0; + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
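The epilogue below ends in write_result, which optionally scales each output element by the token's top-k routing weight and, when several K-slices target the same rows with atomic adds enabled, accumulates instead of overwriting. A scalar model of that store policy with made-up values:

#include <cstdio>

int main() {
  bool mul_topk_weights = true;
  bool use_atomic_add = true;
  const int slice_count = 2;
  float topk_weight = 0.25f;            // assumed routing weight

  float C_row = 0.0f;                   // one output element
  float partials[2] = {1.5f, 2.5f};     // one partial per K-slice
  for (int s = 0; s < slice_count; s++) {
    float res = partials[s];
    if (mul_topk_weights) res *= topk_weight;
    if (use_atomic_add && slice_count > 1)
      C_row += res;                     // atomicAdd(&C[idx], res) on device
    else
      C_row = res;                      // plain store by the sole writer
  }
  printf("C_row = %f\n", C_row);        // 1.0
  return 0;
}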
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + if constexpr (m_block_size_8) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); + #pragma unroll + for (int i = 0; i < 8; i++) { + frag_s_half2[i] = Dtype::num2num2( + reinterpret_cast(&frag_s_half2[i])[idx]); + } + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8 && !has_zp) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float( + reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); + + if constexpr (!m_block_size_8) { + scale_float( + reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + } + + if (slice_count > 1 && !use_atomic_add) { + // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[locks_off], slice_idx); + if (use_fp32_reduce) { + global_reduce_fp32(slice_idx == 0, last); + } else { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[locks_off], last); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) + wait_negative_and_add(&locks[locks_off]); + if (last || use_atomic_add) + // only the last block in a slice actually writes the result + write_result(); + if (slice_row) a_remaining_load_count_in_slice = stages; + slice_row = 0; + slice_col_par++; + slice_col++; + is_first_matmul_in_slice = true; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } +} + +} // namespace MARLIN_NAMESPACE_NAME + +#endif diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu new file mode 100644 index 
000000000000..a16e955a325e --- /dev/null +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -0,0 +1,927 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 +#endif + +#include "kernel.h" +#include "core/registration.h" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace MARLIN_NAMESPACE_NAME { + +__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; + +using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template +__global__ void permute_cols_kernel( + int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, + const int32_t* __restrict__ sorted_token_ids_ptr, + const int32_t* __restrict__ expert_ids_ptr, + const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m, + int size_k, int top_k) {}; + +} // namespace marlin + +torch::Tensor moe_wna16_marlin_gemm( + torch::Tensor& a, std::optional const& c_or_none, + torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, torch::Tensor& workspace, + torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids, + torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights, + int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, + vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float) { + TORCH_CHECK_NOT_IMPLEMENTED(false, + "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. 
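A host-side model of the column permutation this kernel performs, with a made-up 4-column permutation and top_k = 2; each of the top_k output rows of a token re-reads the same input row, and column k of the output takes column perm[k] of that row:

#include <cstdio>
#include <vector>

int main() {
  const int top_k = 2, size_k = 4, num_tokens = 2;
  std::vector<float> a = {10, 11, 12, 13,    // token 0
                          20, 21, 22, 23};   // token 1
  std::vector<int> perm = {2, 0, 3, 1};      // assumed act-order permutation
  std::vector<float> out(num_tokens * top_k * size_k);

  for (int row = 0; row < num_tokens * top_k; row++)
    for (int k = 0; k < size_k; k++)
      out[row * size_k + k] = a[(row / top_k) * size_k + perm[k]];

  for (int row = 0; row < num_tokens * top_k; row++)
    printf("out row %d: %g %g %g %g\n", row, out[row * size_k + 0],
           out[row * size_k + 1], out[row * size_k + 2],
           out[row * size_k + 3]);
  return 0;
}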
+template +__global__ void permute_cols_kernel( + int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, + const int32_t* __restrict__ sorted_token_ids_ptr, + const int32_t* __restrict__ expert_ids_ptr, + const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m, + int size_k, int top_k) { + int num_tokens_past_padded = num_tokens_past_padded_ptr[0]; + int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size); + int32_t block_sorted_ids[moe_block_size]; + int block_num_valid_tokens = 0; + int64_t old_expert_id = 0; + int64_t expert_id = 0; + int row_stride = size_k * sizeof(half) / 16; + + auto read_moe_block_data = [&](int block_id) { + block_num_valid_tokens = moe_block_size; + int4* tmp_block_sorted_ids = reinterpret_cast(block_sorted_ids); + for (int i = 0; i < moe_block_size / 4; i++) { + tmp_block_sorted_ids[i] = + ((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; + } + for (int i = 0; i < moe_block_size; i++) { + if (block_sorted_ids[i] >= size_m * top_k) { + block_num_valid_tokens = i; + break; + }; + } + }; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int in_offset = (row / top_k) * row_stride; + int out_offset = row * row_stride; + + half const* a_row_half = + reinterpret_cast(a_int4_ptr + in_offset); + half* out_half = reinterpret_cast(out_int4_ptr + out_offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) { + old_expert_id = expert_id; + int tmp_expert_id = expert_ids_ptr[index]; + if (tmp_expert_id == -1) continue; + expert_id = tmp_expert_id; + perm_int_ptr += (expert_id - old_expert_id) * size_k; + read_moe_block_data(index); + + for (int i = 0; i < block_num_valid_tokens; i++) + permute_row(block_sorted_ids[i]); + } +} + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}}; + +typedef struct { + int blocks_per_sm; + thread_config_t tb_cfg; +} exec_config_t; + +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales = tb_groups * 
tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int has_zp, int is_zp_float) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + int tb_m = thread_m_blocks * 16; + + // shm size for block_sorted_ids/block_topk_weights + // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) + int sh_block_meta_size = tb_m * 4 * 2; + int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_s_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; + int sh_zp_size = 0; + if (has_zp) { + if (is_zp_float) + sh_zp_size = sh_s_size; + else if (num_bits == 4) + sh_zp_size = sh_s_size / 4; + else if (num_bits == 8) + sh_zp_size = sh_s_size / 2; + } + + int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size + + sh_g_idx_size + sh_block_meta_size; + + return total_size; +} + +bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int has_zp, int is_zp_float, int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Check that pipeline fits into cache + int cache_size = get_kernel_cache_size( + th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, has_zp, is_zp_float); + return cache_size <= max_shared_mem; +} + + #define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ + NUM_THREADS, IS_ZP_FLOAT) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + m_block_size_8 == M_BLOCK_SIZE_8 && \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin; \ + } + + #define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \ + false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ + NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, 
K_BLOCKS, false, false, false, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ + NUM_THREADS, false) + + #define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ + NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ + NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ + NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ + NUM_THREADS, false) + + #define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \ + false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \ + false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \ + false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ + NUM_THREADS, false) + + #define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ + NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ + NUM_THREADS, false) \ + \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, 
\ + NUM_THREADS, false) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ + NUM_THREADS, false) + + // We currently have 4-bit models only with group_blocks == 4 + #define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \ + true) \ + __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ + NUM_THREADS, true) \ + __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ + NUM_THREADS, true) \ + __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ + NUM_THREADS, true) \ + __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ + NUM_THREADS, true) + +template +MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, + int thread_m_blocks, int thread_n_blocks, + int thread_k_blocks, bool m_block_size_8, + bool has_act_order, bool has_zp, + int group_blocks, int num_threads, + bool is_zp_float) { + int num_bits = q_type.size_bits(); + auto kernel = MarlinDefault; + if (false) { + } + GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256) + GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128) + + GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256) + GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128) + + GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256) + GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128) + + GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256) + GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128) + + AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256) + AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128) + + AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256) + AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128) + + return kernel; +} + +template +exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, + int prob_n, int prob_k, int thread_m_blocks, + bool m_block_size_8, int num_bits, + int group_size, bool has_act_order, + bool is_k_full, bool has_zp, + bool is_zp_float, int max_shared_mem) { + exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; + thread_config_t* thread_configs = thread_m_blocks > 1 + ? large_batch_thread_configs + : small_batch_thread_configs; + int thread_configs_size = + thread_m_blocks > 1 + ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t) + : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + + int count = 0; + constexpr int device_max_reg_size = 255 * 1024; + for (int i = 0; i < thread_configs_size; i++) { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, has_zp, + is_zp_float, max_shared_mem)) { + continue; + } + + int cache_size = get_kernel_cache_size( + th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full, has_zp, is_zp_float); + + int group_blocks = 0; + if (!has_act_order) { + group_blocks = group_size == -1 ? 
-1 : group_size / 16; + } + + auto kernel = get_marlin_kernel( + q_type, thread_m_blocks, th_config.thread_n / 16, + th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, + group_blocks, th_config.num_threads, is_zp_float); + + if (kernel == MarlinDefault) continue; + + if (thread_m_blocks > 1) { + exec_cfg = {1, th_config}; + break; + } else { + cudaFuncAttributes attr; + cudaFuncGetAttributes(&attr, kernel); + int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4; + int allow_count = min(device_max_reg_size / reg_size, + max_shared_mem / (cache_size + 1024)); + allow_count = max(min(allow_count, 4), 1); + if (allow_count > count) { + count = allow_count; + exec_cfg = {count, th_config}; + }; + } + } + + return exec_cfg; +} + +template +void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, + void* zp, void* g_idx, void* perm, void* a_tmp, + void* sorted_token_ids, void* expert_ids, + void* num_tokens_past_padded, void* topk_weights, + int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, + int prob_m, int prob_n, int prob_k, void* workspace, + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, bool has_zp, int num_groups, int group_size, + int dev, cudaStream_t stream, int thread_k, int thread_n, + int sms, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float) { + int thread_m_blocks = div_ceil(moe_block_size, 16); + bool m_block_size_8 = moe_block_size == 8; + + if (has_zp) { + TORCH_CHECK( + q_type == vllm::kU4 || q_type == vllm::kU8, + "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); + } else { + TORCH_CHECK( + q_type == vllm::kU4B8 || q_type == vllm::kU8B128, + "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", + q_type.str()); + } + + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + } + + int num_bits = q_type.size_bits(); + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + int4* C_tmp_ptr = (int4*)C_tmp; + const int4* s_ptr = (const int4*)s; + const int4* zp_ptr = (const int4*)zp; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; + const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids; + const int32_t* expert_ids_ptr = (const int32_t*)expert_ids; + const int32_t* num_tokens_past_padded_ptr = + (const int32_t*)num_tokens_past_padded; + const float* topk_weights_ptr = (const float*)topk_weights; + int* locks = (int*)workspace; + + if (has_act_order) { + // Permute A columns + auto kernel = permute_cols_kernel<8>; + if (moe_block_size == 8) { + } else if (moe_block_size == 16) + kernel = permute_cols_kernel<16>; + else if (moe_block_size == 32) + kernel = permute_cols_kernel<32>; + else if (moe_block_size == 48) + kernel = permute_cols_kernel<48>; + else if (moe_block_size == 64) + kernel = permute_cols_kernel<64>; + else + 
TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size); + + // avoid ">>>" being formatted to "> > >" + // clang-format off + kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr, + num_tokens_past_padded_ptr, prob_m, prob_k, top_k); + // clang-format on + A_ptr = a_tmp_ptr; + prob_m = prob_m * top_k; + top_k = 1; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) has_act_order = false; + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8, + num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, + max_shared_mem); + thread_tfg = exec_cfg.tb_cfg; + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) + max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n, + prob_k, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, max_shared_mem), + "Invalid thread config: thread_m_blocks = ", thread_m_blocks, + ", thread_k = ", thread_tfg.thread_k, + ", thread_n = ", thread_tfg.thread_n, + ", num_threads = ", thread_tfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, + ", max_shared_mem = ", max_shared_mem); + + auto kernel = get_marlin_kernel( + q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, + has_act_order, has_zp, group_blocks, num_threads, is_zp_float); + + if (kernel == MarlinDefault) { + TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, + ", ", prob_k, "]", ", has_act_order = ", has_act_order, + ", num_groups = ", num_groups, ", group_size = ", group_size, + ", thread_m_blocks = ", thread_m_blocks, + ", thread_n_blocks = ", thread_n_blocks, + ", thread_k_blocks = ", thread_k_blocks, + ", num_bits = ", num_bits); + } + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + max_shared_mem); + // avoid ">>>" being formatted to "> > >" + // clang-format off + kernel<<>>( + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, + sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, + topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, + prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce); + // clang-format on +} + +} // namespace MARLIN_NAMESPACE_NAME + +torch::Tensor 
moe_wna16_marlin_gemm( + torch::Tensor& a, std::optional const& c_or_none, + torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, torch::Tensor& workspace, + torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids, + torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights, + int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, + vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float) { + vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); + int pack_factor = 32 / b_q_type.size_bits(); + + if (moe_block_size != 8) { + TORCH_CHECK(moe_block_size % 16 == 0, + "unsupported moe_block_size=", moe_block_size); + TORCH_CHECK(moe_block_size >= 16 && moe_block_size <= 64, + "unsupported moe_block_size=", moe_block_size); + } + + // Verify A + TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), + ", size_m = ", size_m); + TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), + ", size_k = ", size_k); + + // Verify B + TORCH_CHECK( + size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + TORCH_CHECK((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(1), + "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1), + ", size_k = ", size_k, + ", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + TORCH_CHECK( + b_q_weight.size(2) % MARLIN_NAMESPACE_NAME::tile_size == 0, + "b_q_weight.size(2) = ", b_q_weight.size(2), + " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + int actual_size_n = + (b_q_weight.size(2) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, + ", actual_size_n = ", actual_size_n); + + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel + int sms = -1; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c; + if (c_or_none.has_value()) { + c = c_or_none.value(); + TORCH_CHECK(c.device().is_cuda(), "c is not on GPU"); + TORCH_CHECK(c.is_contiguous(), "c is not contiguous"); + TORCH_CHECK(c.size(0) == size_m * top_k, + "Shape mismatch: c.size(0) = ", c.size(0), + ", size_m * topk = ", size_m * top_k); + TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1), + ", size_n = ", size_n); + } else { + c = torch::empty({size_m * top_k, size_n}, options); + } + + // Alloc C tmp buffer that is going to be used for the global 
reduce + torch::Tensor c_tmp; + auto options_fp32 = + torch::TensorOptions().dtype(at::kFloat).device(a.device()); + if (use_fp32_reduce && !use_atomic_add) { + // max num of threadblocks is sms * 4 + long max_c_tmp_size = min( + (long)size_n * sorted_token_ids.size(0), + (long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n); + if (moe_block_size == 8) max_c_tmp_size *= 2; + c_tmp = torch::empty({max_c_tmp_size}, options_fp32); + } else { + c_tmp = torch::empty({0}, options_fp32); + } + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + + int rank = b_scales.sizes().size(); + TORCH_CHECK(rank == 3, "b_scales rank = ", rank, " is not 3"); + TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), + " is not size_n = ", size_n); + num_groups = b_scales.size(1); + + torch::Tensor g_idx, perm, a_tmp; + ; + if (g_idx_or_none.has_value() && perm_or_none.has_value()) { + g_idx = g_idx_or_none.value(); + perm = perm_or_none.value(); + + TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); + TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + + // Verify g_idx and perm + TORCH_CHECK((g_idx.size(-1) == 0 && perm.size(-1) == 0) || + (g_idx.size(-1) == size_k && perm.size(-1) == size_k), + "Unexpected g_idx.size(-1) = ", g_idx.size(-1), + " and perm.size(-1) = ", perm.size(-1), + ", where size_k = ", size_k); + } else { + g_idx = torch::empty({0}, options); + perm = torch::empty({0}, options); + a_tmp = torch::empty({0}, options); + } + bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0; + + if (has_act_order) { + a_tmp = torch::empty({size_m * top_k, size_k}, options); + if (is_k_full) { + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + + } else { + a_tmp = torch::empty({0}, options); + if (num_groups > 1) { + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by b_scales.size(1) = ", b_scales.size(1)); + group_size = size_k / num_groups; + } else { + group_size = -1; + } + } + + torch::Tensor b_zeros; + if (b_zeros_or_none.has_value()) { + b_zeros = b_zeros_or_none.value(); + TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU"); + TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous"); + } else { + b_zeros = torch::empty({0}, options); + } + bool has_zp = b_zeros.size(-1) > 0; + + if (has_zp) { + TORCH_CHECK( + b_q_type == vllm::kU4, + "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); + } else { + TORCH_CHECK( + b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128 when has_zp = False. 
Got = ", + b_q_type.str()); + } + + if (has_zp && is_zp_float) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half, + "Computation type must be float16 (half) when using float zero " + "points."); + } + + // Verify b_zeros + if (has_zp) { + int rank = b_zeros.sizes().size(); + TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3"); + if (is_zp_float) { + TORCH_CHECK(b_zeros.size(2) == size_n, + "b_zeros dim 2 = ", b_zeros.size(2), + " is not size_n = ", size_n); + TORCH_CHECK(num_groups == b_zeros.size(1), + "b_zeros dim 1 = ", b_zeros.size(1), + " is not num_groups = ", num_groups); + TORCH_CHECK(num_groups != -1, "num_groups must be != -1"); + } else { + TORCH_CHECK(b_zeros.size(1) == num_groups, + "b_zeros dim 1 = ", b_zeros.size(1), + " is not num_groups = ", num_groups); + TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor, + "b_zeros dim 2 = ", b_zeros.size(2), + " is not size_n / pack_factor = ", size_n / pack_factor); + } + } + + // Verify workspace size + TORCH_CHECK(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0, + "size_n = ", size_n, ", is not divisible by min_thread_n = ", + MARLIN_NAMESPACE_NAME::min_thread_n); + + int max_n_tiles = size_n / MARLIN_NAMESPACE_NAME::min_thread_n; + int min_workspace_size = min( + max_n_tiles * (int)(sorted_token_ids.size(0) / moe_block_size), sms * 4); + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = ", workspace.numel(), + " is below min_workspace_size = ", min_workspace_size); + + int dev = a.get_device(); + if (a.scalar_type() == at::ScalarType::Half) { + MARLIN_NAMESPACE_NAME::marlin_mm( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + c_tmp.data_ptr(), b_scales.data_ptr(), + b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), + a_tmp.data_ptr(), sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), + topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, + size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order, + is_k_full, has_zp, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + use_atomic_add, use_fp32_reduce, is_zp_float); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + MARLIN_NAMESPACE_NAME::marlin_mm( + a.data_ptr(), b_q_weight.data_ptr(), + c.data_ptr(), c_tmp.data_ptr(), + b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + perm.data_ptr(), a_tmp.data_ptr(), + sorted_token_ids.data_ptr(), expert_ids.data_ptr(), + num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), + moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, + workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); + } else { + TORCH_CHECK(false, + "moe_wna16_marlin_gemm only supports bfloat16 and float16"); + } + + return c; +} + +#endif + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm); +} diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 718418e6cd49..d0de42251f97 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -43,14 +43,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm); m.def( - "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " - "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! 
" - "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, " - "int b_q_type, SymInt size_m, " - "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int " - "topk, " - "int moe_block_size, bool replicate_input, bool apply_weights)" - " -> Tensor"); + "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," + "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none," + "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," + "Tensor sorted_token_ids," + "Tensor! expert_ids, Tensor! num_tokens_past_padded," + "Tensor! topk_weights, int moe_block_size, int top_k, " + "bool mul_topk_weights, bool is_ep, int b_q_type_id," + "int size_m, int size_n, int size_k," + "bool is_full_k, bool use_atomic_add," + "bool use_fp32_reduce, bool is_zp_float) -> Tensor"); + // conditionally compiled so impl registration is in source file #endif diff --git a/csrc/ops.h b/csrc/ops.h index 86039a26041b..fe120af5d568 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -128,6 +128,12 @@ void advance_step_flashinfer( torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds); +void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, + torch::Tensor const& q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table, double scale); + torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor); #ifndef USE_ROCM diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 2fb0417ce6c4..894727383a63 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -46,14 +46,26 @@ __global__ void compute_expert_offsets( } __global__ void compute_arg_sorts(const int* __restrict__ topk_ids, + const int32_t* __restrict__ expert_offsets, int32_t* input_permutation, int32_t* output_permutation, int32_t* atomic_buffer, const int topk_length, const int topk) { - int expert_id = blockIdx.x; + int const blk_expert_id = blockIdx.x; + int const num_experts = gridDim.x; + int32_t const num_tokens = expert_offsets[num_experts]; for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { - if (topk_ids[i] == expert_id) { + int const expert_id = topk_ids[i]; + if (expert_id == -1 && blockIdx.x == 0) { + // output_permutation is used to re-order the moe outputs. It is + // used as c2 = c2[c_map], where c2 is a torch.tensor that is the + // output of the cutlass kernels and c_map is the output_permutation. + // c2 is initialized to zeros, therefore by setting the output_permutation + // to num_tokens, we are guaranteed to fill the moe outputs to zero + // for "invalid" topk_ids. 
+ output_permutation[i] = num_tokens; + } else if (expert_id == blk_expert_id) { int start = atomicAdd(&atomic_buffer[expert_id], 1); input_permutation[start] = i / topk; output_permutation[i] = start; @@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller( static_cast(atomic_buffer.data_ptr()), num_experts); compute_arg_sorts<<>>( static_cast(topk_ids.data_ptr()), + static_cast(expert_offsets.data_ptr()), static_cast(input_permutation.data_ptr()), static_cast(output_permutation.data_ptr()), static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh index a96b9e594a36..bc1f17ec097d 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh @@ -882,7 +882,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out, } uint32_t const mp2 = - std::max(static_cast(32), next_pow_2(m)); // next power of 2 + std::max(static_cast(16), next_pow_2(m)); // next power of 2 if (mp2 <= 16) { // M in [1, 16] diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh index 95723b31ca3c..87be125b2eb3 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh @@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out, uint32_t const m = a.size(0); uint32_t const mp2 = - std::max(static_cast(32), next_pow_2(m)); // next power of 2 + std::max(static_cast(16), next_pow_2(m)); // next power of 2 if (mp2 <= 16) { // M in [1, 16] diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu index 6e14de0c7805..97c0e0da7b1f 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu @@ -134,7 +134,7 @@ typename T::Gemm::Arguments args_from_options( using StrideB = typename T::StrideB; using StrideD = typename T::StrideD; using Sm100BlkScaledConfig = - typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; int m = static_cast(M); int n = static_cast(N); diff --git a/csrc/quantization/gptq_marlin/marlin.cuh b/csrc/quantization/gptq_marlin/marlin.cuh index 74ccbac57bd3..f3b44641e77e 100644 --- a/csrc/quantization/gptq_marlin/marlin.cuh +++ b/csrc/quantization/gptq_marlin/marlin.cuh @@ -9,7 +9,11 @@ #include #include -namespace marlin { +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin +#endif + +namespace MARLIN_NAMESPACE_NAME { // Marlin params @@ -23,6 +27,7 @@ static constexpr int pipe_stages = static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; +static constexpr int max_thread_n = 256; static constexpr int tile_size = 16; static constexpr int max_par = 16; @@ -84,4 +89,4 @@ __device__ inline void cp_async_wait() { #endif -} // namespace marlin +} // namespace MARLIN_NAMESPACE_NAME diff --git a/csrc/quantization/gptq_marlin/marlin_dtypes.cuh b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh index be06c09bee33..cc1605481434 100644 --- a/csrc/quantization/gptq_marlin/marlin_dtypes.cuh +++ b/csrc/quantization/gptq_marlin/marlin_dtypes.cuh @@ -5,7 +5,11 @@ #include #include -namespace marlin { +#ifndef MARLIN_NAMESPACE_NAME + 
#define MARLIN_NAMESPACE_NAME marlin +#endif + +namespace MARLIN_NAMESPACE_NAME { template class ScalarType {}; @@ -54,7 +58,7 @@ class ScalarType { using FragS = Vec; using FragZP = Vec; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); } @@ -74,6 +78,6 @@ class ScalarType { #endif }; -} // namespace marlin +} // namespace MARLIN_NAMESPACE_NAME #endif diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index afb735450e0c..b90cfdc617af 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -2,6 +2,15 @@ #include +torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, + const int64_t rows_per_block); + +torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, + const int64_t CuCount); + +void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount); + void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu new file mode 100644 index 000000000000..72d2820f2aab --- /dev/null +++ b/csrc/rocm/skinny_gemms.cu @@ -0,0 +1,1600 @@ +#include +#include +#include + +#include +#include +#include + +#include +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" +#include "quantization/fp8/common.cuh" + +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__)) + #define __HIP__MI300_MI250__ +#endif + +#if defined(__HIPCC__) && defined(__gfx942__) + #define __HIP__MI300__ +#endif + +#if defined(NDEBUG) + #undef NDEBUG + #include + #define UNREACHABLE_CODE assert(false); + #define NDEBUG +#else + #define UNREACHABLE_CODE assert(false); +#endif + +template +struct scalar {}; + +template +struct scalar2 {}; + +template +__device__ __forceinline__ float2 __s22float2(T v); + +template +__device__ __forceinline__ T __float2s(float v); + +template +__device__ __forceinline__ T __float22s2_rn(float2 v); + +// Definitions and cvt functions for fp16 +template <> +struct scalar { + using type = half; +}; + +template <> +struct scalar2 { + using type = __half2; +}; + +template <> +__device__ __forceinline__ half __float2s(float v) { + return __float2half(v); +} + +template <> +__device__ __forceinline__ float2 __s22float2(__half2 v) { + return __half22float2(v); +} + +template <> +__device__ __forceinline__ __half2 __float22s2_rn(float2 v) { + return __float22half2_rn(v); +} + +// Definitions and cvt functions for bf16 +template <> +struct scalar { + using type = __hip_bfloat16; +}; + +template <> +struct scalar2 { + using type = __hip_bfloat162; +}; + +template <> +__device__ __forceinline__ __hip_bfloat16 __float2s(float v) { + return __float2bfloat16(v); +} + +template <> +__device__ __forceinline__ float2 __s22float2(__hip_bfloat162 v) { + return __bfloat1622float2(v); +} + +template <> +__device__ __forceinline__ __hip_bfloat162 __float22s2_rn(float2 v) { + return __float22bfloat162_rn(v); +} + +template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); +} + +__device__ __forceinline__ float4 load_ntmprl(const float4* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + return make_float4(dat0, dat1, dat2, 
dat3); +} + +// TBlock fetches entire rows of A, and entire col of B (K dimension); assume +// N=1 for time being grid is M/A_NUM_ROWS blocks +template +__global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b, + scalar_t* out_c, const int K) { + using scalar2_t = typename scalar2::type; + auto af4 = reinterpret_cast(in_a); + auto bf4 = reinterpret_cast(in_b); + auto c = reinterpret_cast(out_c); + __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; + const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * K / 8; + const int threadid = threadIdx.x; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int num_warps = blockDim.x / WARP_SIZE; + const int qwarpid = threadid / num_warps; + const int qthreadid = threadid % num_warps; + float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; + scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; + float acc[NUM_A_ROWS_PER_BLOCK]; + scalar2_t acch2; + scalar2_t oval; + + // As we later use warp shuffle operations, we may have more threads in the + // block than the actual available data, hence the if guard here. + if (threadid * 8 < K) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + // rowA_elem4[i] holds 8 * half numbers seen as a single float4. + rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]); + } + } + + colB_elem4x = bf4[threadid * 4 + 0]; + colB_elem4y = bf4[threadid * 4 + 1]; + colB_elem4z = bf4[threadid * 4 + 2]; + colB_elem4w = bf4[threadid * 4 + 3]; + + scalar2_t Af2; + [[maybe_unused]] scalar2_t Bf2; + float2 S; + + auto Ah2ptr = reinterpret_cast(&rowA_elem4); + scalar2_t* ah2lptr; + +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + // Multiply-add on 8 scalar_t. + ah2lptr = Ah2ptr + i * 4; + Af2 = *(ah2lptr); + acch2 = __hmul2(Af2, colB_elem4x); + Af2 = *(ah2lptr + 1); + acch2 = __hfma2(Af2, colB_elem4y, acch2); + Af2 = *(ah2lptr + 2); + acch2 = __hfma2(Af2, colB_elem4z, acch2); + Af2 = *(ah2lptr + 3); + acch2 = __hfma2(Af2, colB_elem4w, acch2); + S = __s22float2(acch2); + + // See comment above concerning the if guard. + acc[i] = (threadid * 8 < K ? S.x + S.y : 0.f); + } + +// all reduce across warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + acc[i] += __shfl_xor(acc[i], mask); + } + } + + // Warp leaders store the data to shared memory. + if (lane < NUM_A_ROWS_PER_BLOCK) { + red_smem[lane][warp] = acc[lane]; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + if (qwarpid < NUM_A_ROWS_PER_BLOCK) { + acc[qwarpid] = qthreadid < num_warps ? 
red_smem[qwarpid][qthreadid] : 0.f; + for (int mask = num_warps / 2; mask >= 1; mask /= 2) { + acc[qwarpid] += __shfl_xor(acc[qwarpid], mask); + } + float oval2 = __shfl_xor(acc[qwarpid], num_warps); + + if (lane % (num_warps * 2) == 0) { + oval = __float22s2_rn(make_float2(acc[qwarpid], oval2)); + c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval; + } + } +} + +torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, + const int64_t rows_per_block) { + auto M = in_a.size(0); + auto K = in_a.size(1); + auto N = in_b.size(0); + + TORCH_CHECK(N == 1, "Row number of activation tensor must be 1."); + TORCH_CHECK(in_a.dtype() == in_b.dtype()); + TORCH_CHECK(in_b.dtype() == torch::kFloat16 || + in_b.dtype() == torch::kBFloat16); + + auto out_c = torch::empty( + {N, M}, torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device())); + + // NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle + // operations. + const int NUM_THREADS = + K * 2 / 16 % WARP_SIZE == 0 + ? K * 2 / 16 + : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE); + + int NUM_BLOCKS = M / rows_per_block; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(in_b)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // call the kernel function... + AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "LLGemm1", [&] { + auto a_ptr = in_a.data_ptr(); + auto b_ptr = in_b.data_ptr(); + auto c_ptr = out_c.data_ptr(); + if (rows_per_block == 2) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else if (rows_per_block == 4) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else if (rows_per_block == 8) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else if (rows_per_block == 16) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else { + NUM_BLOCKS = M / 4; + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } + }); + + return out_c; +} + +#define DOT2C(V0, V2, V3) \ + if constexpr (std::is_same_v) { \ + asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \ + } else if constexpr (std::is_same_v) { \ + float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \ + __bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \ + V0 += (s.x + s.y); \ + } + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets cases where A[] fits LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + union bigType { + scalar_t h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + scalar8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! 
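Illustrative sketch (not part of the patch): the 64 KB LDS buffer below holds 32 * 1024 fp16/bf16 elements, which is why the host-side WVSPLITK dispatch further down only selects this small-A kernel when K_in * N_in <= 32 * 1024. A hypothetical standalone helper that makes the arithmetic explicit:

#include <cstddef>
#include <cstdint>

constexpr std::size_t kLdsBytes = 64 * 1024;  // matches __shared__ s[1024 * 32] of a 2-byte type

inline bool activation_fits_in_lds(std::int64_t K, std::int64_t N,
                                   std::size_t elem_bytes = 2) {
  return static_cast<std::size_t>(K) * static_cast<std::size_t>(N) * elem_bytes <=
         kLdsBytes;
}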
+ //---------------------------------------------------- + __shared__ scalar_t s[1024 * 32]; + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * N, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + if (k_in >= min(K * N, 32 * 1024)) break; + + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + + float sum[N][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (m < M) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) sum[n][i] = 0; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + // for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! 
+ #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const scalar_t* B_ = &B[(m + 0) * K + k_]; + bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if constexpr (YTILE >= 2) + bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); + if constexpr (YTILE >= 3) + bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); + if constexpr (YTILE >= 4) + bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); + if constexpr (YTILE >= 5) + bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); + if constexpr (YTILE >= 6) + bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); + if constexpr (YTILE >= 7) + bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); + if constexpr (YTILE >= 8) + bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int n = 0; n < N; n++) { + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
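Illustrative sketch (not part of the patch): once the per-lane DOT2C accumulation and the cross-lane shuffle reduction further down have run, each sum[n][y] is simply the dot product of activation row n with weight row m + y. A serial plain-float reference of that contract; the function name is hypothetical.

#include <cstdint>

float reference_dot(const float* a_row, const float* b_row, std::int64_t K) {
  float acc = 0.f;
  for (std::int64_t k = 0; k < K; ++k) acc += a_row[k] * b_row[k];
  return acc;  // the 64-lane K-split plus DPP adds reproduce this value
}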
+ #pragma unroll + for (uint32_t n = 0; n < N; n++) { + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]) + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if constexpr (YTILE >= 2) { + DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]); + } + if constexpr (YTILE >= 3) { + DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]); + } + if constexpr (YTILE >= 4) { + DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]); + } + if constexpr (YTILE >= 5) { + DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]); + } + if constexpr (YTILE >= 6) { + DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]); + } + if constexpr (YTILE >= 7) { + DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]); + } + if constexpr (YTILE >= 8) { + DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]); + } + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } + } + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets cases where A[] marginally exceeds LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitK_hf_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + union bigType { + scalar_t h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + scalar8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! 
+ //---------------------------------------------------- + __shared__ scalar_t s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { + commitColumn[i] = 0; + } + m = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * N, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + if (k_in >= min(K * N, 32 * 1024)) break; + + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + float sum[N][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (m < M) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) sum[n][i] = 0; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! 
+ // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const scalar_t* B_ = &B[(m + 0) * K + k_]; + bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if constexpr (YTILE >= 2) + bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); + if constexpr (YTILE >= 3) + bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); + if constexpr (YTILE >= 4) + bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); + if constexpr (YTILE >= 5) + bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); + if constexpr (YTILE >= 6) + bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); + if constexpr (YTILE >= 7) + bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); + if constexpr (YTILE >= 8) + bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int n = 0; n < N; n++) { + if (k_ + K * n < 32 * 1024) + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + else + bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t n = 0; n < N; n++) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
+ #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]); + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if constexpr (YTILE >= 2) { + DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]); + } + if constexpr (YTILE >= 3) { + DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]); + } + if constexpr (YTILE >= 4) { + DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]); + } + if constexpr (YTILE >= 5) { + DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]); + } + if constexpr (YTILE >= 6) { + DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]); + } + if constexpr (YTILE >= 7) { + DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]); + } + if constexpr (YTILE >= 8) { + DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]); + } + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } + } + + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
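Illustrative sketch (not part of the patch): the fragmentation handling that follows shifts the last tile back so it ends exactly at column M and masks off the columns an earlier wave already produced. A hypothetical host-side model of that adjustment:

#include <cstdint>
#include <vector>

struct TailTile {
  std::uint32_t start;       // adjusted value of m
  std::vector<bool> commit;  // commitColumn[i]
};

TailTile adjust_tail(std::uint32_t m, std::uint32_t M, std::uint32_t YTILE) {
  TailTile t{m, std::vector<bool>(YTILE, true)};
  if (m < M && m + YTILE >= M) {
    std::uint32_t startColumn = M - YTILE;
    for (std::uint32_t i = 0; i < m - startColumn; ++i) t.commit[i] = false;
    t.start = startColumn;
  }
  return t;
}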
+ if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { + commitColumn[i] = 0; + } + m = startColumn; + } + } +} + +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets big A[] cases, where it is much larger than LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + + union bigType { + scalar_t h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + scalar8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ scalar_t s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + if (threadIdx.y >= _WvPrGrp) return; + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { + commitColumn[i] = 0; + } + m = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + #define PCML + #ifndef PCML + for (uint32_t k = 0; k < min(K * N, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + if (k_in >= min(K * N, 32 * 1024)) break; + + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + #endif + + #define TUC (THRDS * UNRL * A_CHUNK) + uint32_t kBase = 0; + // find biggest k size that fits in LDS + uint32_t kFit = (32 * 1024) / N; + // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple + // of TUC + kFit = (kFit % TUC == 0) + ? 
kFit + : (kFit - kFit % TUC); // round up to multiple of TUC + // if (kFit == 0) kFit = TUC; + kFit = min(kFit, K); + + float sum[N][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + #ifdef PCML + int YW = (YTILE * _WvPrGrp); + uint32_t Mrndp = (M % YW == 0) ? M : (M - M % YW + YW); + while (m < Mrndp) { + #else + while (m < M) { + #endif + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) sum[n][i] = 0; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + #ifdef PCML + if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS + if (k1 != 0) kBase += kFit; + __syncthreads(); + for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) { + uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + if (kBase + kOff >= K) break; + if (kOff >= kFit) break; + for (uint32_t n = 0; n < N; n++) { + uint32_t k_in = kBase + n * K + kOff; + uint32_t k_ot = n * kFit + kOff; + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + } + } + __syncthreads(); + } + if (m >= M) continue; + #endif + + // Fetch the weight matrix from memory! 
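Illustrative sketch (not part of the patch): in the PCML path above, A[] is streamed through LDS kFit activation elements per row at a time; kFit starts from the 32 * 1024 LDS elements divided by N, is aligned down to a whole TUC (THRDS * UNRL * A_CHUNK) chunk, and is clamped to K. A hypothetical standalone version of that computation:

#include <algorithm>
#include <cstdint>

std::uint32_t compute_kfit(std::uint32_t K, std::uint32_t N, std::uint32_t TUC) {
  std::uint32_t kFit = (32u * 1024u) / N;  // LDS elements available per activation row
  kFit -= kFit % TUC;                      // align down to a whole unroll chunk
  return std::min(kFit, K);
}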
+ #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const scalar_t* B_ = &B[(m + 0) * K + k_]; + bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if constexpr (YTILE >= 2) + bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); + if constexpr (YTILE >= 3) + bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); + if constexpr (YTILE >= 4) + bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); + if constexpr (YTILE >= 5) + bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); + if constexpr (YTILE >= 6) + bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); + if constexpr (YTILE >= 7) + bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); + if constexpr (YTILE >= 8) + bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int n = 0; n < N; n++) { + #ifdef PCML + bigA[n][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * n]))); + #else + if (k_ + K * n < 32 * 1024) + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + else + bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); + #endif + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + #pragma unroll + for (uint32_t n = 0; n < N; n++) { + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
+ #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]); + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if constexpr (YTILE >= 2) { + DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]); + } + if constexpr (YTILE >= 3) { + DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]); + } + if constexpr (YTILE >= 4) { + DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]); + } + if constexpr (YTILE >= 5) { + DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]); + } + if constexpr (YTILE >= 6) { + DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]); + } + if constexpr (YTILE >= 7) { + DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]); + } + if constexpr (YTILE >= 8) { + DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]); + } + } + } + } + } + + #ifdef PCML + if (m >= M) { + m += CuCount * _WvPrGrp * YTILE; + kBase = 0; + continue; + } + #endif + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } + } + + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + kBase = 0; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { + commitColumn[i] = 0; + } + m = startColumn; + } + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +int mindiv(int N, int div1, int div2) { + int nPrRnd = div1 * div2; + int rnds0 = N / nPrRnd; + nPrRnd -= div1 * 3; + int rnds3 = N / nPrRnd; + nPrRnd -= div1; + int rnds4 = N / nPrRnd; + nPrRnd -= div1; + int rnds5 = N / nPrRnd; + nPrRnd -= div1; + int rnds6 = N / nPrRnd; + nPrRnd -= div1; + int rnds7 = N / nPrRnd; + nPrRnd -= div1; + int rnds8 = N / nPrRnd; + nPrRnd -= div1; + int rnds9 = N / nPrRnd; + nPrRnd -= div1; + int rtn = div2; + if (rnds0 == rnds3) rtn = div2 - 3; + if (rnds0 == rnds4) rtn = div2 - 4; + if (rnds0 == rnds5) rtn = div2 - 5; + if (rnds0 == rnds6) rtn = div2 - 6; + if (rnds0 == rnds7) rtn = div2 - 7; + if (rnds0 == rnds8) rtn = div2 - 8; + if (rnds0 == rnds9) rtn = div2 - 9; + return rtn; +} + +torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, + const int64_t CuCount) { + auto M_in = in_a.size(0); + auto K_in = in_a.size(1); + auto N_in = in_b.size(0); + + TORCH_CHECK(in_a.dtype() == in_b.dtype()); + TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0"); + TORCH_CHECK(in_a.dtype() == torch::kFloat16 || + in_a.dtype() == torch::kBFloat16); + + auto out_c = torch::empty( + {N_in, M_in}, + torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device())); + + dim3 grid(CuCount); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + +#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ + wvSplitK_hf_sml_ \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ + } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ + wvSplitK_hf_ \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ + } else { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ + wvSplitK_hf_big_ \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ + } \ + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] { + using fptype = typename scalar::type; + fptype* af4 = reinterpret_cast(in_a.data_ptr()); + const fptype* bf4 = reinterpret_cast(in_b.data_ptr()); + fptype* c = reinterpret_cast(out_c.data_ptr()); + switch (N_in) { + case 1: + WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1) + break; + case 2: + WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2) + break; + case 3: + WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3) + break; + case 4: + WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4) + break; + default: + throw std::runtime_error( + "Unsupported N value: " + std::to_string(M_in) + "," + + std::to_string(K_in) + "," + std::to_string(N_in)); + } + }); + return out_c; +} + +#if defined(__HIP__MI300__) // TODO: Add NAVI support +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const fp8_t* B, + const fp8_t* __restrict__ A, scalar_t* C, + const float* __restrict__ s_A, + const float* __restrict__ s_B, const int 
_WvPrGrp, + const int CuCount) { + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float; + using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; + using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; + union bigType { + char f8[A_CHUNK]; + char2 c2[A_CHUNK / 2]; + scalar_t h[A_CHUNK / 2]; + float f[A_CHUNK / 4]; + int i[A_CHUNK / 4]; + long l[A_CHUNK / 8]; + intx4 l2[A_CHUNK / 16]; + scalar8 h8; + }; + + __shared__ fp8_t s[1024 * 64]; + + for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; + k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { + *((bigType*)(&s[k])) = *((bigType*)(&A[k])); + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + + using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; + floatx16 sum[N][YTILE]; + float sA = *s_A; + float sB = *s_B; + + while (m < M) { + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) sum[n][i] = {0.f}; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + #pragma unroll + for (uint32_t n = 0; n < N; ++n) bigA[n][k2].h8 = {0.f}; + #pragma unroll + for (uint32_t y = 0; y < YTILE; ++y) bigB[y][k2].h8 = {0.f}; + } + + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const fp8_t* B_ = &B[(m + 0) * Kp + k_]; + #pragma unroll + for (uint32_t y = 0; y < YTILE; ++y) { + bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp]))); + } + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + for (int n = 0; n < N; n++) { + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + if (k >= K) break; + + for (uint32_t n = 0; n < N; n++) { + for (int i = 0; i < A_CHUNK; i += 8) { + for (int y = 0; y < YTILE; ++y) { + sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0, + 0); + } + } + } + } + } + + // Final reduction + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + float accm0 = sum[n][y][0]; + float accm16 = sum[n][y][8]; + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][1]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][9]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][2]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][10]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][3]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][11]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][4]), 
"v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][12]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][5]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][13]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][6]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][14]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][7]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][15]), "v"(accm16)); + accm0 += __shfl(accm0, 36); + accm16 += __shfl(accm16, 52); + sum[n][y][0] = accm0 + __shfl(accm16, 16); + } + } + + if (threadIdx.x == 0) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + C[m + y + n * M] = __float2s(sum[n][y][0] * sA * sB); + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + } +} +#else // !defined(__HIP__MI300__) TODO: Add NAVI support +template +__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, + const fp8_t* B, const fp8_t* __restrict__ A, + scalar_t* C, const float* __restrict__ s_A, + const float* __restrict__ s_B, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300__) TODO: Add NAVI support + +#if defined(__HIP__MI300__) // TODO: Add NAVI support +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B, + const fp8_t* __restrict__ A, scalar_t* C, + const float* __restrict__ s_A, const float* __restrict__ s_B, + const int _WvPrGrp, const int CuCount) { + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float; + using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; + using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; + union bigType { + char f8[A_CHUNK]; + char2 c2[A_CHUNK / 2]; + scalar_t h[A_CHUNK / 2]; + float f[A_CHUNK / 4]; + int i[A_CHUNK / 4]; + long l[A_CHUNK / 8]; + intx4 l2[A_CHUNK / 16]; + scalar8 h8; + }; + + __shared__ fp8_t s[1024 * 64]; + + for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; + k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { + *((bigType*)(&s[k])) = *((bigType*)(&A[k])); + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + + using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; + floatx16 sum[N][YTILE]; + float sA = *s_A; + float sB = *s_B; + + while (m < M) { + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) sum[n][i] = {0}; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const fp8_t* B_ = &B[(m + 0) * Kp + k_]; + for (int y = 0; y < YTILE; ++y) { + if (y + m >= M) break; // To avoid mem access fault. 
+ bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp]))); + } + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + for (int n = 0; n < N; n++) { + if (k_ + K * n < 64 * 1024) + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + else + bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + for (uint32_t n = 0; n < N; n++) { + for (int i = 0; i < A_CHUNK; i += 8) { + for (int y = 0; y < YTILE; ++y) { + sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0, + 0); + } + } + } + } + } + + // Final reduction + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + float accm0 = sum[n][y][0]; + float accm16 = sum[n][y][8]; + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][1]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][9]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][2]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][10]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][3]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][11]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][4]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][12]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][5]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][13]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][6]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][14]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][7]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][15]), "v"(accm16)); + accm0 += __shfl(accm0, 36); + accm16 += __shfl(accm16, 52); + sum[n][y][0] = accm0 + __shfl(accm16, 16); + } + } + + if (threadIdx.x == 0) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + if (y + m >= M) break; // To avoid mem access fault. 
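Illustrative sketch (not part of the patch): the store below applies the two per-tensor fp8 dequantization scales exactly once, after the fp32 MFMA accumulation, i.e. each output element is (sum over k of a_q[k] * b_q[k]) * sA * sB. A plain-float reference of that contract; the names are hypothetical.

#include <cstdint>

float fp8_scaled_dot(const float* a_codes, const float* b_codes, std::int64_t K,
                     float scale_a, float scale_b) {
  float acc = 0.f;
  for (std::int64_t k = 0; k < K; ++k) acc += a_codes[k] * b_codes[k];
  return acc * scale_a * scale_b;  // per-tensor scales factor out of the sum
}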
+ C[m + y + n * M] = __float2s(sum[n][y][0] * sA * sB); + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + } +} +#else // !defined(__HIP__MI300__) TODO: Add NAVI support +template +__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, + const fp8_t* B, const fp8_t* __restrict__ A, + scalar_t* C, const float* __restrict__ s_A, + const float* __restrict__ s_B, const int _WvPrGrp, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300__) TODO: Add NAVI support + +void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + at::Tensor& scale_a, at::Tensor& scale_b, + const int64_t CuCount) { + static c10::ScalarType kFp8Type = is_fp8_ocp() + ? c10::ScalarType::Float8_e4m3fn + : c10::ScalarType::Float8_e4m3fnuz; + auto M_in = in_a.size(0); + auto K_in = in_a.size(1); + auto N_in = in_b.size(0); + auto Kp_in = in_a.stride(0); + TORCH_CHECK(K_in % 16 == 0, "k % 16 == 0"); + TORCH_CHECK(in_a.dtype() == in_b.dtype() && in_a.dtype() == kFp8Type); + TORCH_CHECK(out_c.dtype() == torch::kFloat16 || + out_c.dtype() == torch::kBFloat16); + + dim3 grid(CuCount); + const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + +#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ + wvSplitKQ_hf_sml_ \ + <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ + s_a, s_b, __wvPrGrp, CuCount); \ + } else { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ + wvSplitKQ_hf_ \ + <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ + s_a, s_b, __wvPrGrp, CuCount); \ + } \ + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_c.scalar_type(), "wvSplitKQ", [&] { + using fptype = typename scalar::type; + auto c_ptr = reinterpret_cast(out_c.data_ptr()); + auto s_a = scale_a.data_ptr(); + auto s_b = scale_b.data_ptr(); + VLLM_DISPATCH_FP8_TYPES(in_a.scalar_type(), "wvSplitKQ", [&] { + auto a_ptr = in_a.data_ptr(); + auto b_ptr = in_b.data_ptr(); + switch (N_in) { + case 1: + WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1) + break; + case 2: + WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 2) + break; + case 3: + WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 3) + break; + case 4: + WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 4) + break; + default: + throw std::runtime_error( + "Unsupported N value: " + std::to_string(M_in) + "," + + std::to_string(K_in) + "," + std::to_string(N_in)); + } + }); + }); +} diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 537e9357d52b..4ac6fd1e9940 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -14,6 +14,24 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { // vLLM custom ops for rocm + // Custom gemm op for matrix-vector multiplication + rocm_ops.def( + "LLMM1(Tensor in_a, Tensor in_b, int rows_per_block) -> " + "Tensor"); + rocm_ops.impl("LLMM1", torch::kCUDA, &LLMM1); + + // Custom gemm op for skinny matrix-matrix multiplication + rocm_ops.def( + "wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> " + "Tensor"); + rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK); + + // wvSplitK for fp8 + rocm_ops.def( + "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! 
out_c, Tensor scale_a, " + " Tensor scale_b, int CuCount) -> ()"); + rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ); + // Custom attention op // Compute the attention between an input query and the cached // keys/values using PagedAttention. diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b6ff6a006c02..c9a120976b1c 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -130,6 +130,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ") -> ()"); ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer); + // Compute MLA decode using cutlass. + ops.def( + "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," + " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," + " Tensor page_table, float scale) -> ()"); + ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); + // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( diff --git a/docker-bake.hcl b/docker-bake.hcl index 1d60cd769c3c..ad9e00c49d5f 100644 --- a/docker-bake.hcl +++ b/docker-bake.hcl @@ -2,12 +2,25 @@ variable "REPOSITORY" { default = "quay.io/vllm/vllm" } -# GITHUB_* variables are only available in github actions +# GITHUB_* variables are set as env vars in github actions variable "GITHUB_SHA" {} variable "GITHUB_REPOSITORY" {} variable "GITHUB_RUN_ID" {} -variable "VLLM_VERSION" {} # set by github actions or manually? +variable "VLLM_VERSION" {} + +variable "PYTHON_VERSION" { + default = "3.12" +} + +variable "ROCM_VERSION" { + default = "6.3.4" +} + +variable "VLLM_TGIS_ADAPTER_VERSION" { + default = "0.7.0" +} + target "docker-metadata-action" {} // populated by gha docker/metadata-action @@ -15,7 +28,7 @@ target "_common" { context = "." args = { - BASE_UBI_IMAGE_TAG = "9.5-1736404155" + BASE_UBI_IMAGE_TAG = "9.5-1742914212" PYTHON_VERSION = "3.12" } @@ -34,6 +47,7 @@ target "_common" { group "default" { targets = [ "cuda", + "rocm", ] } @@ -42,11 +56,10 @@ target "cuda" { dockerfile = "Dockerfile.ubi" args = { - BASE_UBI_IMAGE_TAG = "9.5-1739420147" - PYTHON_VERSION = "3.12" + PYTHON_VERSION = "${PYTHON_VERSION}" # CUDA_VERSION = "12.4" # TODO: the dockerfile cannot consume the cuda version LIBSODIUM_VERSION = "1.0.20" - VLLM_TGIS_ADAPTER_VERSION = "0.7.0" + VLLM_TGIS_ADAPTER_VERSION = "${VLLM_TGIS_ADAPTER_VERSION}" FLASHINFER_VERSION = "https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl" } @@ -57,3 +70,21 @@ target "cuda" { "${REPOSITORY}:${formatdate("YYYY-MM-DD-hh-mm", timestamp())}" ] } + +target "rocm" { + inherits = ["_common"] + dockerfile = "Dockerfile.rocm.ubi" + + args = { + PYTHON_VERSION = "${PYTHON_VERSION}" + ROCM_VERSION = "${ROCM_VERSION}" + LIBSODIUM_VERSION = "1.0.20" + VLLM_TGIS_ADAPTER_VERSION = "${VLLM_TGIS_ADAPTER_VERSION}" + } + + tags = [ + "${REPOSITORY}:${replace(VLLM_VERSION, "+", "_")}", # vllm_version might contain local version specifiers (+) which are not valid tags + "${REPOSITORY}:${GITHUB_SHA}", + "${REPOSITORY}:${formatdate("YYYY-MM-DD-hh-mm", timestamp())}" + ] +} diff --git a/docker/Dockerfile b/docker/Dockerfile index d1ecef586d50..1b28845d0ac0 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -162,6 +162,9 @@ ENV UV_HTTP_TIMEOUT=500 COPY requirements/lint.txt requirements/lint.txt COPY requirements/test.txt requirements/test.txt COPY requirements/dev.txt requirements/dev.txt +# Workaround for #17068 +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system 
mamba-ssm==2.2.4 --no-build-isolation RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements/dev.txt #################### DEV IMAGE #################### @@ -240,6 +243,8 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \ fi COPY examples examples +COPY benchmarks benchmarks +COPY ./vllm/collect_env.py . # Although we build Flashinfer with AOT mode, there's still # some issues w.r.t. JIT compilation. Therefore we need to @@ -263,6 +268,9 @@ ADD . /vllm-workspace/ ENV UV_HTTP_TIMEOUT=500 # install development dependencies (for testing) +# Workaround for #17068 +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system mamba-ssm==2.2.4 --no-build-isolation RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements/dev.txt @@ -289,6 +297,7 @@ RUN mv vllm test_docs/ #################### OPENAI API SERVER #################### # base openai image with additional requirements, for any subsequent openai-style images FROM vllm-base AS vllm-openai-base +ARG TARGETPLATFORM # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index 54d1ce86d011..c647d9036f40 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -121,6 +121,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ ADD ./tests/ ./tests/ ADD ./examples/ ./examples/ ADD ./benchmarks/ ./benchmarks/ +ADD ./vllm/collect_env.py . # install development dependencies (for testing) RUN --mount=type=cache,target=/root/.cache/uv \ diff --git a/docker/Dockerfile.nightly_torch b/docker/Dockerfile.nightly_torch new file mode 100644 index 000000000000..0063712e4781 --- /dev/null +++ b/docker/Dockerfile.nightly_torch @@ -0,0 +1,307 @@ +# The vLLM Dockerfile is used to construct vLLM image against torch nightly that can be directly used for testing + +# for torch nightly, cuda >=12.6 is required, +# use 12.8 due to FlashAttention issue with cuda 12.6 (https://github.com/vllm-project/vllm/issues/15435#issuecomment-2775924628) +ARG CUDA_VERSION=12.8.0 +# +#################### BASE BUILD IMAGE #################### +# prepare basic build environment +FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base +ARG CUDA_VERSION=12.8.0 +ARG PYTHON_VERSION=3.12 +ARG TARGETPLATFORM +ENV DEBIAN_FRONTEND=noninteractive +# Install Python and other dependencies +RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ + && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ + && apt-get update -y \ + && apt-get install -y ccache software-properties-common git curl sudo \ + && add-apt-repository ppa:deadsnakes/ppa \ + && apt-get update -y \ + && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ + && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ + && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ + && python3 --version \ + && python3 -m pip --version +# Install uv for faster pip installs +RUN 
--mount=type=cache,target=/root/.cache/uv \ + python3 -m pip install uv + +# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out +# Reference: https://github.com/astral-sh/uv/pull/1694 +ENV UV_HTTP_TIMEOUT=500 + +# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519 +# as it was causing spam when compiling the CUTLASS kernels +RUN apt-get install -y gcc-10 g++-10 +RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10 +RUN < torch_build_versions.txt +RUN cat torch_build_versions.txt + +# cuda arch list used by torch +# can be useful for `test` +# explicitly set the list to avoid issues with torch 2.2 +# see https://github.com/pytorch/pytorch/pull/123243 + +# Override the arch list for flash-attn to reduce the binary size +ARG vllm_fa_cmake_gpu_arches='80-real;90-real' +ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches} +#################### BASE BUILD IMAGE #################### + +#################### WHEEL BUILD IMAGE #################### +FROM base AS build +ARG TARGETPLATFORM + +# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out +# Reference: https://github.com/astral-sh/uv/pull/1694 +ENV UV_HTTP_TIMEOUT=500 + +COPY . . + +RUN python3 use_existing_torch.py + +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -r requirements/build.txt + +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi + +# Max jobs used by Ninja to build extensions +ARG max_jobs=16 +ENV MAX_JOBS=${max_jobs} +ARG nvcc_threads=2 +ENV NVCC_THREADS=$nvcc_threads + +ARG USE_SCCACHE +ARG SCCACHE_BUCKET_NAME=vllm-build-sccache +ARG SCCACHE_REGION_NAME=us-west-2 +ARG SCCACHE_S3_NO_CREDENTIALS=0 + +# if USE_SCCACHE is set, use sccache to speed up compilation +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,source=.git,target=.git \ + if [ "$USE_SCCACHE" = "1" ]; then \ + echo "Installing sccache..." 
\ + && curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \ + && tar -xzf sccache.tar.gz \ + && sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \ + && rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \ + && export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \ + && export SCCACHE_REGION=${SCCACHE_REGION_NAME} \ + && export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \ + && export SCCACHE_IDLE_TIMEOUT=0 \ + && export CMAKE_BUILD_TYPE=Release \ + && sccache --show-stats \ + && python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \ + && sccache --show-stats; \ + fi + +ENV CCACHE_DIR=/root/.cache/ccache +RUN --mount=type=cache,target=/root/.cache/ccache \ + --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,source=.git,target=.git \ + if [ "$USE_SCCACHE" != "1" ]; then \ + # Clean any existing CMake artifacts + rm -rf .deps && \ + mkdir -p .deps && \ + python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \ + fi + +#################### WHEEL BUILD IMAGE #################### + +################### VLLM INSTALLED IMAGE #################### +# Setup clean environment for vLLM and its dependencies for test and api server using ubuntu22.04 with AOT flashinfer +FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base +# prepare for environment starts +ARG CUDA_VERSION=12.8.0 +ARG PYTHON_VERSION=3.12 +WORKDIR /vllm-workspace +ENV DEBIAN_FRONTEND=noninteractive +ARG TARGETPLATFORM + +RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \ + echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment + +# Install Python and other dependencies +RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ + && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ + && apt-get update -y \ + && apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \ + && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ + && add-apt-repository ppa:deadsnakes/ppa \ + && apt-get update -y \ + && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ + && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ + && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ + && python3 --version && python3 -m pip --version + +RUN --mount=type=cache,target=/root/.cache/uv \ + python3 -m pip install uv + +# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out +# Reference: https://github.com/astral-sh/uv/pull/1694 +ENV UV_HTTP_TIMEOUT=500 + +# Workaround for https://github.com/openai/triton/issues/2507 and +# https://github.com/pytorch/pytorch/issues/107960 -- hopefully +# this won't be needed for future versions of this docker image +# or future versions of triton. +RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. 
-f1,2)/compat/ + +# get the nightly torch version used in the build to make sure the version is the same +COPY --from=base /workspace/torch_build_versions.txt ./torch_build_versions.txt + +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu128 + +# install the vllm wheel +RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/vllm-dist \ + --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system vllm-dist/*.whl --verbose + +# install xformers again for the new environment +RUN --mount=type=bind,from=base,src=/workspace/xformers-dist,target=/vllm-workspace/xformers-dist \ + --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system /vllm-workspace/xformers-dist/*.whl --verbose + +ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0' + +# install package for build flashinfer +# see issue: https://github.com/flashinfer-ai/flashinfer/issues/738 +RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.post1 + + +# build flashinfer for torch nightly from source around 10 mins +# release version: v0.2.2.post1 +# todo(elainewy): cache flashinfer build result for faster build +ENV CCACHE_DIR=/root/.cache/ccache +RUN --mount=type=cache,target=/root/.cache/ccache \ + --mount=type=cache,target=/root/.cache/uv \ + echo "git clone flashinfer..." \ + && git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \ + && cd flashinfer \ + && git checkout v0.2.2.post1 \ + && git submodule update --init --recursive \ + && echo "finish git clone flashinfer..." \ + && rm -rf build \ + && export TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} \ + && FLASHINFER_ENABLE_AOT=1 python3 setup.py bdist_wheel --dist-dir=../flashinfer-dist --verbose \ + && cd .. \ + && rm -rf flashinfer + +# install flashinfer +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system flashinfer-dist/*.whl --verbose + +# install common packages +COPY requirements/common.txt requirements/common.txt +COPY use_existing_torch.py use_existing_torch.py +COPY pyproject.toml pyproject.toml + +COPY examples examples +COPY benchmarks benchmarks +COPY ./vllm/collect_env.py . 
+ +RUN python3 use_existing_torch.py +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -r requirements/common.txt + +################### VLLM INSTALLED IMAGE #################### + + +#################### UNITTEST IMAGE ############################# +FROM vllm-base as test +COPY tests/ tests/ + +# install build and runtime dependencies without stable torch version +COPY requirements/nightly_torch_test.txt requirements/nightly_torch_test.txt + +# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out +# Reference: https://github.com/astral-sh/uv/pull/1694 +ENV UV_HTTP_TIMEOUT=500 + +# install development dependencies (for testing) +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -e tests/vllm_test_utils + +# enable fast downloads from hf (for testing) +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system hf_transfer +ENV HF_HUB_ENABLE_HF_TRANSFER 1 + +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -r requirements/nightly_torch_test.txt + +#################### UNITTEST IMAGE ############################# + diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index b8523fbc2a01..1776b26d445c 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="1a7f4dfa" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="8970b25b" +ARG AITER_BRANCH="7e1ed08" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base diff --git a/docs/source/assets/deployment/anything-llm-chat-with-doc.png b/docs/source/assets/deployment/anything-llm-chat-with-doc.png new file mode 100644 index 000000000000..f9b57f5c3cec Binary files /dev/null and b/docs/source/assets/deployment/anything-llm-chat-with-doc.png differ diff --git a/docs/source/assets/deployment/anything-llm-chat-without-doc.png b/docs/source/assets/deployment/anything-llm-chat-without-doc.png new file mode 100644 index 000000000000..952a43bcd677 Binary files /dev/null and b/docs/source/assets/deployment/anything-llm-chat-without-doc.png differ diff --git a/docs/source/assets/deployment/anything-llm-provider.png b/docs/source/assets/deployment/anything-llm-provider.png new file mode 100644 index 000000000000..bb699f7571f4 Binary files /dev/null and b/docs/source/assets/deployment/anything-llm-provider.png differ diff --git a/docs/source/assets/deployment/anything-llm-upload-doc.png b/docs/source/assets/deployment/anything-llm-upload-doc.png new file mode 100644 index 000000000000..00c70e9c01f6 Binary files /dev/null and b/docs/source/assets/deployment/anything-llm-upload-doc.png differ diff --git a/docs/source/assets/deployment/open_webui.png b/docs/source/assets/deployment/open_webui.png new file mode 100644 index 000000000000..fe9a7e15ea71 Binary files /dev/null and b/docs/source/assets/deployment/open_webui.png differ diff --git a/docs/source/conf.py b/docs/source/conf.py index a83ad764125c..c2ad6f9fa3a5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -177,6 +177,11 @@ def linkcode_resolve(domain, info): for part in info['fullname'].split('.'): obj = getattr(obj, part) + # Skip decorator wrappers by checking if the object is a function + # and has a __wrapped__ attribute (which decorators typically set) + while hasattr(obj, '__wrapped__'): + obj = 
obj.__wrapped__ + if not (inspect.isclass(obj) or inspect.isfunction(obj) or inspect.ismethod(obj)): obj = obj.__class__ # Get the class of the instance diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md index 03d830fe90f1..b42536f054d7 100644 --- a/docs/source/contributing/model/multimodal.md +++ b/docs/source/contributing/model/multimodal.md @@ -128,11 +128,9 @@ HF processing as well as memory profiling. ### For memory profiling -Override the abstract method {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs` -to construct dummy inputs for memory profiling. This dummy input should result in the worst-case memory usage of -the model so that vLLM can reserve the correct amount of memory for it. +Override the abstract methods {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_text` and {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_mm_data` to construct dummy inputs for memory profiling. These dummy inputs should result in the worst-case memory usage of the model so that vLLM can reserve the correct amount of memory for it. -Assuming that the memory usage increases with the number of tokens, the dummy input can be constructed to maximize the number of output embeddings, which is the same number as placeholder feature tokens. +Assuming that the memory usage increases with the number of tokens, the dummy inputs can be constructed to maximize the number of output embeddings, which is the same number as placeholder feature tokens. ::::{tab-set} :::{tab-item} Basic example: LLaVA @@ -244,38 +242,45 @@ def get_num_image_tokens( ``` Notice that the number of image tokens doesn't depend on the image width and height. -We can simply use a dummy `image_size`: +We can simply use a dummy `image_size` to calculate the multimodal profiling data: ```python +# NOTE: In actuality, this is usually implemented as part of the +# model's subclass of `BaseProcessingInfo`, but we show it as is +# here for simplicity. def get_image_size_with_most_features(self) -> ImageSize: hf_config = self.get_hf_config() width = height = hf_config.image_size return ImageSize(width=width, height=height) -def get_dummy_processor_inputs( +def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], -) -> ProcessorInputs: +) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - processor = self.info.get_hf_processor() - image_token = processor.image_token - - hf_config = self.get_hf_config() - target_width, target_height = self.info.get_image_size_with_most_features() + target_width, target_height = \ + self.info.get_image_size_with_most_features() - mm_data = { + return { "image": self._get_dummy_images(width=target_width, height=target_height, num_images=num_images) } +``` - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=mm_data, - ) +For the text, we simply expand the multimodal image token from the model config to match the desired number of images. + +```python +def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images ``` ::: @@ -412,29 +417,30 @@ def get_image_size_with_most_features(self) -> ImageSize: Fuyu does not expect image placeholders in the inputs to HF processor, so the dummy prompt text is empty regardless of the number of images. 
-Otherwise, the logic of this method is very similar to LLaVA: ```python -def get_dummy_processor_inputs( +def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" +``` + +For the multimodal image profiling data, the logic is very similar to LLaVA: + +```python +def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], -) -> ProcessorInputs: +) -> MultiModalDataDict: target_width, target_height = \ self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) - mm_data = { + return { "image": self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) + height=target_height, + num_images=num_images) } - - return ProcessorInputs( - prompt_text="", - mm_data=mm_data, - ) ``` ::: diff --git a/docs/source/deployment/docker.md b/docs/source/deployment/docker.md index 6b794db656c0..ca56710bc2ef 100644 --- a/docs/source/deployment/docker.md +++ b/docs/source/deployment/docker.md @@ -19,6 +19,18 @@ $ docker run --runtime nvidia --gpus all \ --model mistralai/Mistral-7B-v0.1 ``` +This image can also be used with other container engines such as [Podman](https://podman.io/). + +```console +$ podman run --gpus all \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \ + -p 8000:8000 \ + --ipc=host \ + vllm/vllm-openai:latest \ + --model mistralai/Mistral-7B-v0.1 +``` + You can add any other you need after the image tag (`vllm/vllm-openai:latest`). :::{note} diff --git a/docs/source/deployment/frameworks/anything-llm.md b/docs/source/deployment/frameworks/anything-llm.md new file mode 100644 index 000000000000..d430c170ef54 --- /dev/null +++ b/docs/source/deployment/frameworks/anything-llm.md @@ -0,0 +1,47 @@ +(deployment-anything-llm)= + +# Anything LLM + +[Anything LLM](https://github.com/Mintplex-Labs/anything-llm) is a full-stack application that enables you to turn any document, resource, or piece of content into context that any LLM can use as references during chatting. + +It allows you to deploy a large language model (LLM) server with vLLM as the backend, which exposes OpenAI-compatible endpoints. + +## Prerequisites + +- Setup vLLM environment + +## Deploy + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 +``` + +- Download and install [Anything LLM desktop](https://anythingllm.com/desktop). 
+
+- At the bottom left, open Settings and go to AI Providers --> LLM:
+  - LLM Provider: Generic OpenAI
+  - Base URL: http://{vllm server host}:{vllm server port}/v1
+  - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ`
+
+:::{image} /assets/deployment/anything-llm-provider.png
+:::
+
+- Back on the home page, go to New Workspace --> create a `vllm` workspace, and start chatting:
+
+:::{image} /assets/deployment/anything-llm-chat-without-doc.png
+:::
+
+- Click the upload button:
+  - upload the document
+  - select the document and move it to the workspace
+  - save and embed
+
+:::{image} /assets/deployment/anything-llm-upload-doc.png
+:::
+
+- Chat again:
+
+:::{image} /assets/deployment/anything-llm-chat-with-doc.png
+:::
diff --git a/docs/source/deployment/frameworks/index.md b/docs/source/deployment/frameworks/index.md
index cb758d3e6d2e..a1b405386b77 100644
--- a/docs/source/deployment/frameworks/index.md
+++ b/docs/source/deployment/frameworks/index.md
@@ -3,12 +3,14 @@
 :::{toctree}
 :maxdepth: 1
 
+anything-llm
 bentoml
 cerebrium
 dstack
 helm
 lws
 modal
+open-webui
 skypilot
 triton
 :::
diff --git a/docs/source/deployment/frameworks/open-webui.md b/docs/source/deployment/frameworks/open-webui.md
new file mode 100644
index 000000000000..83e5303a00ef
--- /dev/null
+++ b/docs/source/deployment/frameworks/open-webui.md
@@ -0,0 +1,29 @@
+(deployment-open-webui)=
+
+# Open WebUI
+
+1. Install [Docker](https://docs.docker.com/engine/install/).
+
+2. Start the vLLM server with a supported chat completion model, e.g.
+
+```console
+vllm serve qwen/Qwen1.5-0.5B-Chat
+```
+
+3. Start the [Open WebUI](https://github.com/open-webui/open-webui) docker container (replace the vllm serve host and vllm serve port):
+
+```console
+docker run -d -p 3000:8080 \
+--name open-webui \
+-v open-webui:/app/backend/data \
+-e OPENAI_API_BASE_URL=http://<vllm serve host>:<vllm serve port>/v1 \
+--restart always \
+ghcr.io/open-webui/open-webui:main
+```
+
+4. Open it in the browser:
+
+At the top of the web page, you can see the model `qwen/Qwen1.5-0.5B-Chat`.
+
+:::{image} /assets/deployment/open_webui.png
+:::
diff --git a/docs/source/deployment/integrations/production-stack.md b/docs/source/deployment/integrations/production-stack.md
index e66e8e6a16b2..05f1568306cc 100644
--- a/docs/source/deployment/integrations/production-stack.md
+++ b/docs/source/deployment/integrations/production-stack.md
@@ -16,7 +16,7 @@ Ensure that you have a running Kubernetes environment with GPU (you can follow [
 
 ## Deployment using vLLM production stack
 
-The standard vLLM production stack install uses a Helm chart. You can run this [bash script](https://github.com/vllm-project/production-stack/blob/main/tutorials/install-helm.sh) to install Helm on your GPU server.
+The standard vLLM production stack is installed using a Helm chart. You can run this [bash script](https://github.com/vllm-project/production-stack/blob/main/utils/install-helm.sh) to install Helm on your GPU server.
 
 To install the vLLM production stack, run the following commands on your desktop:
 
diff --git a/docs/source/deployment/security.md b/docs/source/deployment/security.md
new file mode 100644
index 000000000000..e2ef8196c167
--- /dev/null
+++ b/docs/source/deployment/security.md
@@ -0,0 +1,58 @@
+# Security Guide
+
+## Inter-Node Communication
+
+All communications between nodes in a multi-node vLLM deployment are **insecure by default** and must be protected by placing the nodes on an isolated network. This includes:
+
+1. PyTorch Distributed communications
+2. KV cache transfer communications
+3.
Tensor, Pipeline, and Data parallel communications + +### Configuration Options for Inter-Node Communications + +The following options control inter-node communications in vLLM: + +1. **Environment Variables:** + - `VLLM_HOST_IP`: Sets the IP address for vLLM processes to communicate on + +2. **KV Cache Transfer Configuration:** + - `--kv-ip`: The IP address for KV cache transfer communications (default: 127.0.0.1) + - `--kv-port`: The port for KV cache transfer communications (default: 14579) + +3. **Data Parallel Configuration:** + - `data_parallel_master_ip`: IP of the data parallel master (default: 127.0.0.1) + - `data_parallel_master_port`: Port of the data parallel master (default: 29500) + +### Notes on PyTorch Distributed + +vLLM uses PyTorch's distributed features for some inter-node communication. For +detailed information about PyTorch Distributed security considerations, please +refer to the [PyTorch Security +Guide](https://github.com/pytorch/pytorch/security/policy#using-distributed-features). + +Key points from the PyTorch security guide: +- PyTorch Distributed features are intended for internal communication only +- They are not built for use in untrusted environments or networks +- No authorization protocol is included for performance reasons +- Messages are sent unencrypted +- Connections are accepted from anywhere without checks + +### Security Recommendations + +1. **Network Isolation:** + - Deploy vLLM nodes on a dedicated, isolated network + - Use network segmentation to prevent unauthorized access + - Implement appropriate firewall rules + +2. **Configuration Best Practices:** + - Always set `VLLM_HOST_IP` to a specific IP address rather than using defaults + - Configure firewalls to only allow necessary ports between nodes + +3. **Access Control:** + - Restrict physical and network access to the deployment environment + - Implement proper authentication and authorization for management interfaces + - Follow the principle of least privilege for all system components + +## Reporting Security Vulnerabilities + +If you believe you have found a security vulnerability in vLLM, please report it following the project's security policy. For more information on how to report security issues and the project's security policy, please see the [vLLM Security Policy](https://github.com/vllm-project/vllm/blob/main/SECURITY.md). diff --git a/docs/source/design/mm_processing.md b/docs/source/design/mm_processing.md index 0947c1da1e54..dc92a3c2c511 100644 --- a/docs/source/design/mm_processing.md +++ b/docs/source/design/mm_processing.md @@ -47,7 +47,7 @@ Moreover, since the tokenized text has not passed through the HF processor, we h ### Dummy text -We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data. +We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_text`. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data. 
(mm-automatic-prompt-updating)=
diff --git a/docs/source/design/v1/metrics.md b/docs/source/design/v1/metrics.md
index b3981b2dc24a..3f96290798a3 100644
--- a/docs/source/design/v1/metrics.md
+++ b/docs/source/design/v1/metrics.md
@@ -66,8 +66,8 @@ vLLM also provides [a reference example](https://docs.vllm.ai/en/latest/getting_
 The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important:
 
 - `vllm:e2e_request_latency_seconds_bucket` - End to end request latency measured in seconds
-- `vllm:prompt_tokens_total` - Prompt Tokens/Sec
-- `vllm:generation_tokens_total` - Generation Tokens/Sec
+- `vllm:prompt_tokens_total` - Prompt Tokens
+- `vllm:generation_tokens_total` - Generation Tokens
 - `vllm:time_per_output_token_seconds` - Inter token latency (Time Per Output Token, TPOT) in second.
 - `vllm:time_to_first_token_seconds` - Time to First Token (TTFT) latency in seconds.
 - `vllm:num_requests_running` (also, `_swapped` and `_waiting`) - Number of requests in RUNNING, WAITING, and SWAPPED state
@@ -86,6 +86,17 @@ See [the PR which added this Dashboard](gh-pr:2316) for interesting and useful b
 
 Prometheus support was initially added [using the aioprometheus library](gh-pr:1890), but a switch was made quickly to [prometheus_client](gh-pr:2730). The rationale is discussed in both linked PRs.
 
+With the switch away from `aioprometheus`, we lost its `MetricsMiddleware` for tracking HTTP metrics, but this was reinstated [using prometheus_fastapi_instrumentator](gh-pr:15657):
+
+```bash
+$ curl http://0.0.0.0:8000/metrics 2>/dev/null | grep -P '^http_(?!.*(_bucket|_created|_sum)).*'
+http_requests_total{handler="/v1/completions",method="POST",status="2xx"} 201.0
+http_request_size_bytes_count{handler="/v1/completions"} 201.0
+http_response_size_bytes_count{handler="/v1/completions"} 201.0
+http_request_duration_highr_seconds_count 201.0
+http_request_duration_seconds_count{handler="/v1/completions",method="POST"} 201.0
+```
+
 ### Multi-process Mode
 
 In v0, metrics are collected in the engine core process and we use multi-process mode to make them available in the API server process. See .
diff --git a/docs/source/design/v1/torch_compile.md b/docs/source/design/v1/torch_compile.md
index 57dba680b97c..7920131643c2 100644
--- a/docs/source/design/v1/torch_compile.md
+++ b/docs/source/design/v1/torch_compile.md
@@ -99,7 +99,7 @@ This time, Inductor compilation is completely bypassed, and we will load from di
 
 The above example just uses Inductor to compile for a general shape (i.e. symbolic shape). We can also use Inductor to compile for some of the specific shapes, for example:
 
-`VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.2-1B --compilation_config "{'compile_sizes': [1, 2, 4, 8]}"`
+`vllm serve meta-llama/Llama-3.2-1B --compilation_config "{'compile_sizes': [1, 2, 4, 8]}"`
 
 Then it will also compile a specific kernel just for batch size `1, 2, 4, 8`. At this time, all of the shapes in the computation graph are static and known, and we will turn on auto-tuning to tune for max performance. This can be slow when you run it for the first time, but the next time you run it, we can directly bypass the tuning and run the tuned kernel.
 
@@ -134,6 +134,6 @@ The cudagraphs are captured and managed by the compiler backend, and replayed wh
 
 By default, vLLM will try to determine a set of sizes to capture cudagraph.
 You can also override it using the config `cudagraph_capture_sizes`:
 
-`VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.2-1B --compilation_config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"`
+`vllm serve meta-llama/Llama-3.2-1B --compilation-config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"`
 
 Then it will only capture cudagraph for the specified sizes. It can be useful to have fine-grained control over the cudagraph capture.
diff --git a/docs/source/features/disagg_prefill.md b/docs/source/features/disagg_prefill.md
index 52d253b9c2b1..2fa20140c086 100644
--- a/docs/source/features/disagg_prefill.md
+++ b/docs/source/features/disagg_prefill.md
@@ -21,11 +21,11 @@ Disaggregated prefill DOES NOT improve throughput.
 
 ## Usage example
 
-Please refer to `examples/online_serving/disaggregated_prefill.sh` for the example usage of disaggregated prefilling.
+Please refer to <gh-file:examples/online_serving/disaggregated_prefill.sh> for the example usage of disaggregated prefilling.
 
 ## Benchmarks
 
-Please refer to `benchmarks/disagg_benchmarks/` for disaggregated prefilling benchmarks.
+Please refer to <gh-file:benchmarks/disagg_benchmarks> for disaggregated prefilling benchmarks.
 
 ## Development
 
diff --git a/docs/source/features/lora.md b/docs/source/features/lora.md
index a71da72e4360..b5b51095b3a7 100644
--- a/docs/source/features/lora.md
+++ b/docs/source/features/lora.md
@@ -106,19 +106,18 @@ curl http://localhost:8000/v1/completions \
 
 ## Dynamically serving LoRA Adapters
 
-In addition to serving LoRA adapters at server startup, the vLLM server now supports dynamically loading and unloading
-LoRA adapters at runtime through dedicated API endpoints. This feature can be particularly useful when the flexibility
-to change models on-the-fly is needed.
+In addition to serving LoRA adapters at server startup, the vLLM server supports dynamically configuring LoRA adapters at runtime through dedicated API endpoints and plugins. This feature can be particularly useful when the flexibility to change models on-the-fly is needed.
 
 Note: Enabling this feature in production environments is risky as users may participate in model adapter management.
 
-To enable dynamic LoRA loading and unloading, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
-is set to `True`. When this option is enabled, the API server will log a warning to indicate that dynamic loading is active.
+To enable dynamic LoRA configuration, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
+is set to `True`.
 
 ```bash
 export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
 ```
 
+### Using API Endpoints
 Loading a LoRA Adapter:
 
 To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary
@@ -153,6 +152,58 @@ curl -X POST http://localhost:8000/v1/unload_lora_adapter \
   }'
 ```
 
+### Using Plugins
+Alternatively, you can use the LoRAResolver plugin to dynamically load LoRA adapters. LoRAResolver plugins enable you to load LoRA adapters from both local and remote sources such as local file system and S3. On every request, when there's a new model name that hasn't been loaded yet, the LoRAResolver will try to resolve and load the corresponding LoRA adapter.
+
+You can set up multiple LoRAResolver plugins if you want to load LoRA adapters from different sources. For example, you might have one resolver for local files and another for S3 storage. vLLM will load the first LoRA adapter that it finds.
+
+You can either install existing plugins or implement your own.
+
+Steps to implement your own LoRAResolver plugin:
+1. Implement the LoRAResolver interface.
+ + Example of a simple S3 LoRAResolver implementation: + + ```python + import os + import s3fs + from vllm.lora.request import LoRARequest + from vllm.lora.resolver import LoRAResolver + + class S3LoRAResolver(LoRAResolver): + def __init__(self): + self.s3 = s3fs.S3FileSystem() + self.s3_path_format = os.getenv("S3_PATH_TEMPLATE") + self.local_path_format = os.getenv("LOCAL_PATH_TEMPLATE") + + async def resolve_lora(self, base_model_name, lora_name): + s3_path = self.s3_path_format.format(base_model_name=base_model_name, lora_name=lora_name) + local_path = self.local_path_format.format(base_model_name=base_model_name, lora_name=lora_name) + + # Download the LoRA from S3 to the local path + await self.s3._get( + s3_path, local_path, recursive=True, maxdepth=1 + ) + + lora_request = LoRARequest( + lora_name=lora_name, + lora_path=local_path, + lora_int_id=abs(hash(lora_name)) + ) + return lora_request + ``` + +2. Register LoRAResolver plugin. + + ```python + from vllm.lora.resolver import LoRAResolverRegistry + + s3_resolver = S3LoRAResolver() + LoRAResolverRegistry.register_resolver("s3_resolver", s3_resolver) + ``` + + For more details, refer to the [vLLM's Plugins System](../design/plugin_system.md). + ## New format for `--lora-modules` In the previous version, users would provide LoRA modules via the following format, either as a key-value pair or in JSON format. For example: diff --git a/docs/source/features/quantization/auto_awq.md b/docs/source/features/quantization/auto_awq.md index b703d0195319..b4ac597f5a79 100644 --- a/docs/source/features/quantization/auto_awq.md +++ b/docs/source/features/quantization/auto_awq.md @@ -6,13 +6,13 @@ To create a new 4-bit quantized model, you can leverage [AutoAWQ](https://github Quantization reduces the model's precision from BF16/FP16 to INT4 which effectively reduces the total model memory footprint. The main benefits are lower latency and memory usage. -You can quantize your own models by installing AutoAWQ or picking one of the [6500+ models on Huggingface](https://huggingface.co/models?sort=trending&search=awq). +You can quantize your own models by installing AutoAWQ or picking one of the [6500+ models on Huggingface](https://huggingface.co/models?search=awq). ```console pip install autoawq ``` -After installing AutoAWQ, you are ready to quantize a model. Please refer to the `AutoAWQ documentation `_ for further details. Here is an example of how to quantize `mistralai/Mistral-7B-Instruct-v0.2`: +After installing AutoAWQ, you are ready to quantize a model. Please refer to the [AutoAWQ documentation](https://casper-hansen.github.io/AutoAWQ/examples/#basic-quantization) for further details. Here is an example of how to quantize `mistralai/Mistral-7B-Instruct-v0.2`: ```python from awq import AutoAWQForCausalLM diff --git a/docs/source/features/quantization/bitblas.md b/docs/source/features/quantization/bitblas.md new file mode 100644 index 000000000000..d0b2bf858c9b --- /dev/null +++ b/docs/source/features/quantization/bitblas.md @@ -0,0 +1,48 @@ +(bitblas)= + +# BitBLAS + +vLLM now supports [BitBLAS](https://github.com/microsoft/BitBLAS) for more efficient and flexible model inference. Compared to other quantization frameworks, BitBLAS provides more precision combinations. + +:::{note} +Ensure your hardware supports the selected `dtype` (`torch.bfloat16` or `torch.float16`). +Most recent NVIDIA GPUs support `float16`, while `bfloat16` is more common on newer architectures like Ampere or Hopper. 
+For details see [supported hardware](https://docs.vllm.ai/en/latest/features/quantization/supported_hardware.html). +::: + +Below are the steps to utilize BitBLAS with vLLM. + +```console +pip install bitblas>=0.1.0 +``` + +vLLM reads the model's config file and supports pre-quantized checkpoints. + +You can find pre-quantized models on: + +- [Hugging Face (BitBLAS)](https://huggingface.co/models?search=bitblas) +- [Hugging Face (GPTQ)](https://huggingface.co/models?search=gptq) + +Usually, these repositories have a `quantize_config.json` file that includes a `quantization_config` section. + +## Read bitblas format checkpoint + +```python +from vllm import LLM +import torch + +# "hxbgsyxh/llama-13b-4bit-g-1-bitblas" is a pre-quantized checkpoint. +model_id = "hxbgsyxh/llama-13b-4bit-g-1-bitblas" +llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, quantization="bitblas") +``` + +## Read gptq format checkpoint + +```python +from vllm import LLM +import torch + +# "hxbgsyxh/llama-13b-4bit-g-1" is a pre-quantized checkpoint. +model_id = "hxbgsyxh/llama-13b-4bit-g-1" +llm = LLM(model=model_id, dtype=torch.float16, trust_remote_code=True, quantization="bitblas", max_model_len=1024) +``` diff --git a/docs/source/features/quantization/bnb.md b/docs/source/features/quantization/bnb.md index e356b99d85cd..1843a33a3dfd 100644 --- a/docs/source/features/quantization/bnb.md +++ b/docs/source/features/quantization/bnb.md @@ -14,7 +14,7 @@ pip install bitsandbytes>=0.45.3 vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoint. -You can find bitsandbytes quantized models on . +You can find bitsandbytes quantized models on . And usually, these repositories have a config.json file that includes a quantization_config section. ## Read quantized checkpoint diff --git a/docs/source/features/quantization/gptqmodel.md b/docs/source/features/quantization/gptqmodel.md index 34adf6512b7e..9771d5a4fe9e 100644 --- a/docs/source/features/quantization/gptqmodel.md +++ b/docs/source/features/quantization/gptqmodel.md @@ -16,12 +16,16 @@ GPTQModel is one of the few quantization toolkits in the world that allows `Dyna is fully integrated into vLLM and backed up by support from the ModelCloud.AI team. Please refer to [GPTQModel readme](https://github.com/ModelCloud/GPTQModel?tab=readme-ov-file#dynamic-quantization-per-module-quantizeconfig-override) for more details on this and other advanced features. -You can quantize your own models by installing [GPTQModel](https://github.com/ModelCloud/GPTQModel) or picking one of the [5000+ models on Huggingface](https://huggingface.co/models?sort=trending&search=gptq). +## Installation + +You can quantize your own models by installing [GPTQModel](https://github.com/ModelCloud/GPTQModel) or picking one of the [5000+ models on Huggingface](https://huggingface.co/models?search=gptq). ```console pip install -U gptqmodel --no-build-isolation -v ``` +## Quantizing a model + After installing GPTQModel, you are ready to quantize a model. Please refer to the [GPTQModel readme](https://github.com/ModelCloud/GPTQModel/?tab=readme-ov-file#quantization) for further details. 
Here is an example of how to quantize `meta-llama/Llama-3.2-1B-Instruct`: @@ -49,12 +53,16 @@ model.quantize(calibration_dataset, batch_size=2) model.save(quant_path) ``` +## Running a quantized model with vLLM + To run an GPTQModel quantized model with vLLM, you can use [DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2](https://huggingface.co/ModelCloud/DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2) with the following command: ```console -python examples/offline_inference/llm_engine_example.py --model DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2 +python examples/offline_inference/llm_engine_example.py --model ModelCloud/DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2 ``` +## Using GPTQModel with vLLM's Python API + GPTQModel quantized models are also supported directly through the LLM entrypoint: ```python @@ -67,17 +75,22 @@ prompts = [ "The capital of France is", "The future of AI is", ] + # Create a sampling params object. sampling_params = SamplingParams(temperature=0.6, top_p=0.9) # Create an LLM. -llm = LLM(model="DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2") +llm = LLM(model="ModelCloud/DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2") + # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) + # Print the outputs. +print("-"*50) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-"*50) ``` diff --git a/docs/source/features/quantization/index.md b/docs/source/features/quantization/index.md index 6f539f6e3f48..c7c8aeb662a5 100644 --- a/docs/source/features/quantization/index.md +++ b/docs/source/features/quantization/index.md @@ -11,6 +11,7 @@ Quantization trades off model precision for smaller memory footprint, allowing l supported_hardware auto_awq bnb +bitblas gguf gptqmodel int4 diff --git a/docs/source/features/quantization/supported_hardware.md b/docs/source/features/quantization/supported_hardware.md index 2cbe8779dd8a..984e6626e241 100644 --- a/docs/source/features/quantization/supported_hardware.md +++ b/docs/source/features/quantization/supported_hardware.md @@ -74,6 +74,17 @@ The table below shows the compatibility of various quantization implementations * ❌ * ❌ * ❌ +- * BitBLAS (GPTQ) + * ✅︎ + * ✅︎ + * ✅︎ + * ✅︎ + * ✅︎ + * ✅︎ + * ❌ + * ❌ + * ❌ + * ❌ - * AQLM * ✅︎ * ✅︎ diff --git a/docs/source/features/quantization/torchao.md b/docs/source/features/quantization/torchao.md index 9a85f0bab9ec..82100c6ddcac 100644 --- a/docs/source/features/quantization/torchao.md +++ b/docs/source/features/quantization/torchao.md @@ -30,5 +30,4 @@ tokenizer.push_to_hub(hub_repo) quantized_model.push_to_hub(hub_repo, safe_serialization=False) ``` -Alternatively, you can use the TorchAO Quantization space for quantizing models with a simple UI. -See: https://huggingface.co/spaces/medmekk/TorchAO_Quantization +Alternatively, you can use the [TorchAO Quantization space](https://huggingface.co/spaces/medmekk/TorchAO_Quantization) for quantizing models with a simple UI. 
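+
+Once the quantized checkpoint has been pushed, it should load in vLLM like any other model. The snippet below is a minimal, illustrative sketch rather than an official recipe: the repository id is a placeholder for your own `hub_repo`, and passing `quantization="torchao"` explicitly is an assumption shown in case the method is not inferred from the checkpoint's quantization config.
+
+```python
+from vllm import LLM, SamplingParams
+
+# Placeholder repository id; substitute the hub_repo you pushed above.
+llm = LLM(model="your-username/Llama-3.1-8B-Instruct-torchao",
+          quantization="torchao",  # assumed explicit opt-in; often read from the checkpoint config
+          dtype="bfloat16")
+
+outputs = llm.generate(["The capital of France is"],
+                       SamplingParams(temperature=0.8, max_tokens=32))
+print(outputs[0].outputs[0].text)
+```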
diff --git a/docs/source/features/structured_outputs.md b/docs/source/features/structured_outputs.md index de3c5bf5e7ab..03119ec7441c 100644 --- a/docs/source/features/structured_outputs.md +++ b/docs/source/features/structured_outputs.md @@ -2,8 +2,11 @@ # Structured Outputs -vLLM supports the generation of structured outputs using [outlines](https://github.com/dottxt-ai/outlines), [lm-format-enforcer](https://github.com/noamgat/lm-format-enforcer), or [xgrammar](https://github.com/mlc-ai/xgrammar) as backends for the guided decoding. -This document shows you some examples of the different options that are available to generate structured outputs. +vLLM supports the generation of structured outputs using +[xgrammar](https://github.com/mlc-ai/xgrammar) or +[guidance](https://github.com/guidance-ai/llguidance) as backends. +This document shows you some examples of the different options that are +available to generate structured outputs. ## Online Serving (OpenAI API) @@ -15,10 +18,17 @@ The following parameters are supported, which must be added as extra parameters: - `guided_regex`: the output will follow the regex pattern. - `guided_json`: the output will follow the JSON schema. - `guided_grammar`: the output will follow the context free grammar. -- `guided_whitespace_pattern`: used to override the default whitespace pattern for guided json decoding. -- `guided_decoding_backend`: used to select the guided decoding backend to use. Additional backend-specific options can be supplied in a comma separated list following a colon after the backend name. For example `"xgrammar:no-fallback"` will not allow vLLM to fallback to a different backend on error. +- `structural_tag`: Follow a JSON schema within a set of specified tags within the generated text. -You can see the complete list of supported parameters on the [OpenAI-Compatible Server](#openai-compatible-server)page. +You can see the complete list of supported parameters on the [OpenAI-Compatible Server](#openai-compatible-server) page. + +Structured outputs are supported by default in the OpenAI-Compatible Server. You +may choose to specify the backend to use by setting the +`--guided-decoding-backend` flag to `vllm serve`. The default backend is `auto`, +which will try to choose an appropriate backend based on the details of the +request. You may also choose a specific backend, along with +some options. A full set of options is available in the `vllm serve --help` +text. Now let´s see an example for each of the cases, starting with the `guided_choice`, as it´s the easiest one: @@ -50,7 +60,7 @@ completion = client.chat.completions.create( "content": "Generate an example email address for Alan Turing, who works in Enigma. End in .com and new line. Example result: alan.turing@enigma.com\n", } ], - extra_body={"guided_regex": "\w+@\w+\.com\n", "stop": ["\n"]}, + extra_body={"guided_regex": r"\w+@\w+\.com\n", "stop": ["\n"]}, ) print(completion.choices[0].message.content) ``` @@ -96,26 +106,29 @@ print(completion.choices[0].message.content) ``` :::{tip} -While not strictly necessary, normally it´s better to indicate in the prompt that a JSON needs to be generated and which fields and how should the LLM fill them. -This can improve the results notably in most cases. +While not strictly necessary, normally it´s better to indicate in the prompt the +JSON schema and how the fields should be populated. This can improve the +results notably in most cases. 
 :::
 
-Finally we have the `guided_grammar`, which probably is the most difficult one to use but it´s really powerful, as it allows us to define complete languages like SQL queries.
-It works by using a context free EBNF grammar, which for example we can use to define a specific format of simplified SQL queries, like in the example below:
+Finally we have the `guided_grammar` option, which is probably the most
+difficult to use, but it's really powerful. It allows us to define complete
+languages like SQL queries. It works by using a context-free EBNF grammar.
+As an example, we can use it to define a specific format of simplified SQL queries:
 
 ```python
 simplified_sql_grammar = """
-    ?start: select_statement
+    root ::= select_statement
 
-    ?select_statement: "SELECT " column_list " FROM " table_name
+    select_statement ::= "SELECT " column " from " table " where " condition
 
-    ?column_list: column_name ("," column_name)*
+    column ::= "col_1 " | "col_2 "
 
-    ?table_name: identifier
+    table ::= "table_1 " | "table_2 "
 
-    ?column_name: identifier
+    condition ::= column "= " number
 
-    ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
+    number ::= "1 " | "2 "
 """
 
 completion = client.chat.completions.create(
@@ -226,6 +239,8 @@ Step #2: explanation="Next, let's isolate 'x' by dividing both sides of the equa
 Answer: x = -29/8
 ```
 
+An example of using `structural_tag` can be found here:
+
 ## Offline Inference
 
 Offline inference allows for the same types of guided decoding.
@@ -236,11 +251,11 @@ The main available options inside `GuidedDecodingParams` are:
 - `regex`
 - `choice`
 - `grammar`
-- `backend`
-- `whitespace_pattern`
+- `structural_tag`
 
-These parameters can be used in the same way as the parameters from the Online Serving examples above.
-One example for the usage of the `choices` parameter is shown below:
+These parameters can be used in the same way as the parameters from the Online
+Serving examples above. One example for the usage of the `choice` parameter is
+shown below:
 
 ```python
 from vllm import LLM, SamplingParams
diff --git a/docs/source/features/tool_calling.md b/docs/source/features/tool_calling.md
index 8b8bbd28d348..f98ec6108cea 100644
--- a/docs/source/features/tool_calling.md
+++ b/docs/source/features/tool_calling.md
@@ -152,12 +152,14 @@ Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_cha
 
 Supported models:
 
-* `meta-llama/Meta-Llama-3.1-8B-Instruct`
-* `meta-llama/Meta-Llama-3.1-70B-Instruct`
-* `meta-llama/Meta-Llama-3.1-405B-Instruct`
-* `meta-llama/Meta-Llama-3.1-405B-Instruct-FP8`
+All Llama 3.1, 3.2 and 4 models should be supported.
+
+* `meta-llama/Llama-3.1-*`
+* `meta-llama/Llama-3.2-*`
+* `meta-llama/Llama-4-*`
+
+The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below.
 
-The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) in Llama-3.2 models, see the `pythonic` tool parser below.
 Other tool calling formats like the built in python tool calling or custom tool calling are not supported.
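+
+To make the JSON-based flow concrete, here is a minimal client-side sketch. The model name, server address, and the example tool are illustrative assumptions; it presumes a server started along the lines of `vllm serve meta-llama/Llama-3.1-8B-Instruct --enable-auto-tool-choice --tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3.1_json.jinja`:
+
+```python
+from openai import OpenAI
+
+# Assumed local vLLM OpenAI-compatible endpoint.
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
+
+# Hypothetical tool definition used only for illustration.
+tools = [{
+    "type": "function",
+    "function": {
+        "name": "get_weather",
+        "description": "Get the current weather for a city",
+        "parameters": {
+            "type": "object",
+            "properties": {"city": {"type": "string"}},
+            "required": ["city"],
+        },
+    },
+}]
+
+response = client.chat.completions.create(
+    model="meta-llama/Llama-3.1-8B-Instruct",
+    messages=[{"role": "user", "content": "What is the weather in Berlin?"}],
+    tools=tools,
+    tool_choice="auto",
+)
+print(response.choices[0].message.tool_calls)
+```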
 Known issues:
@@ -166,10 +168,20 @@ Known issues:
 2. The model can generate parameters with a wrong format, such as generating an array serialized as string instead of an array.
 
-The `tool_chat_template_llama3_json.jinja` file contains the "official" Llama chat template, but tweaked so that
-it works better with vLLM.
+vLLM provides two JSON-based chat templates for Llama 3.1 and 3.2:
+
+* `examples/tool_chat_template_llama3.1_json.jinja` - this is the "official" chat template for the Llama 3.1
+models, but tweaked so that it works better with vLLM.
+* `examples/tool_chat_template_llama3.2_json.jinja` - this extends the Llama 3.1 chat template by adding support for
+images.
+
+Recommended flags: `--tool-call-parser llama3_json --chat-template {see_above}`
+
+vLLM also provides a JSON-based chat template for Llama 4:
+* `examples/tool_chat_template_llama4_json.jinja` - this is based on the "official" chat template for the Llama 4
+models, but tweaked so that it works better with vLLM.
 
-Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja`
+For Llama 4, use `--tool-call-parser llama4_json --chat-template examples/tool_chat_template_llama4_json.jinja`.
 
 #### IBM Granite
 
diff --git a/docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md b/docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md
index e3046f35ee15..78938de317c4 100644
--- a/docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md
+++ b/docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md
@@ -13,11 +13,11 @@ There are no pre-built wheels or images for this device, so you must build vLLM
 
 - Intel Gaudi accelerator
 - Intel Gaudi software version 1.18.0
 
-Please follow the instructions provided in the [Gaudi Installation
-Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html)
+Please follow the instructions provided in the
+[Gaudi Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html)
 to set up the execution environment. To achieve the best performance,
-please follow the methods outlined in the [Optimizing Training Platform
-Guide](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_Training_Platform.html).
+please follow the methods outlined in the
+[Optimizing Training Platform Guide](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_Training_Platform.html).
 
 ## Configure a new environment
 
@@ -32,15 +32,13 @@ pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloade
 pip list | grep neural # verify that neural_compressor is installed
 ```
 
-Refer to [Intel Gaudi Software Stack
-Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade)
+Refer to [Intel Gaudi Software Stack Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade)
 for more details.
 
 ### Run Docker Image
 
 It is highly recommended to use the latest Docker image from Intel Gaudi
-vault. Refer to the [Intel Gaudi
-documentation](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#pull-prebuilt-containers)
+vault. Refer to the [Intel Gaudi documentation](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#pull-prebuilt-containers)
 for more details.
Use the following commands to run a Docker image: @@ -278,8 +276,9 @@ Lower value corresponds to less usable graph memory reserved for prefill stage, ::: User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. There are two strategies implemented: -\- `max_bs` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode -\- `min_tokens` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (`batch_size*sequence_length`), default strategy for prompt + +- `max_bs` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode +- `min_tokens` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (`batch_size*sequence_length`), default strategy for prompt When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by `max_bs` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in `min_tokens` strategy. @@ -326,8 +325,7 @@ INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of devi - We recommend running inference on Gaudi 2 with `block_size` of 128 for BF16 data type. Using default values (16, 32) might lead to sub-optimal performance due to Matrix Multiplication Engine - under-utilization (see [Gaudi - Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html)). + under-utilization (see [Gaudi Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html)). - For max throughput on Llama 7B, we recommend running with batch size of 128 or 256 and max context length of 2048 with HPU Graphs enabled. If you encounter out-of-memory issues, see troubleshooting section. @@ -336,11 +334,11 @@ INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of devi **Diagnostic and profiling knobs:** -- `VLLM_PROFILER_ENABLED`: if `true`, high level profiler will be enabled. Resulting JSON traces can be viewed in [perfetto.habana.ai](https://perfetto.habana.ai/#!/viewer). Disabled by default. -- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION`: if `true`, will log graph compilations per each vLLM engine step, only when there was any - highly recommended to use alongside `PT_HPU_METRICS_GC_DETAILS=1`. Disabled by default. -- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL`: if `true`, will log graph compilations per each vLLM engine step, always, even if there were none. Disabled by default. -- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS`: if `true`, will log cpu fallbacks per each vLLM engine step, only when there was any. Disabled by default. 
-- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL`: if `true`, will log cpu fallbacks per each vLLM engine step, always, even if there were none. Disabled by default.
+- `VLLM_PROFILER_ENABLED`: If `true`, enable the high level profiler. Resulting JSON traces can be viewed in [perfetto.habana.ai](https://perfetto.habana.ai/#!/viewer). `false` by default.
+- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION`: If `true`, log graph compilations for each vLLM engine step when any occurs. Highly recommended to use with `PT_HPU_METRICS_GC_DETAILS=1`. `false` by default.
+- `VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL`: If `true`, always log graph compilations for each vLLM engine step even if none occurred. `false` by default.
+- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS`: If `true`, log CPU fallbacks for each vLLM engine step when any occurs. `false` by default.
+- `VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL`: If `true`, always log CPU fallbacks for each vLLM engine step even if none occurred. `false` by default.

**Performance tuning knobs:**

@@ -381,7 +379,7 @@ INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of devi

Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution:

-- `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used, if `1` PyTorch Lazy backend for Gaudi will be used, `1` is default
+- `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used; if `1`, PyTorch Lazy backend for Gaudi will be used. `1` is default.
- `PT_HPU_ENABLE_LAZY_COLLECTIVES`: required to be `true` for tensor parallel inference with HPU Graphs

## Troubleshooting: tweaking HPU graphs

diff --git a/docs/source/getting_started/installation/ai_accelerator/tpu.inc.md b/docs/source/getting_started/installation/ai_accelerator/tpu.inc.md
index beb803cf0597..8beb92ef7da0 100644
--- a/docs/source/getting_started/installation/ai_accelerator/tpu.inc.md
+++ b/docs/source/getting_started/installation/ai_accelerator/tpu.inc.md
@@ -44,7 +44,7 @@ There are no pre-built wheels for this device, so you must either use the pre-bu

You can provision Cloud TPUs using the [Cloud TPU API](https://cloud.google.com/tpu/docs/reference/rest)
or the [queued resources](https://cloud.google.com/tpu/docs/queued-resources)
-API. This section shows how to create TPUs using the queued resource API. For
+API (preferred). This section shows how to create TPUs using the queued resource API. For
more information about using the Cloud TPU API, see [Create a Cloud TPU using the Create Node API](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#create-node-api).
Queued resources enable you to request Cloud TPU resources in a queued manner. When you request queued resources, the request is added to a queue maintained by
@@ -97,10 +97,10 @@ gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
`TPU regions and zones `_
- * ACCELERATOR_TYPE
* The TPU version you want to use. Specify the TPU version, for example
- `v5litepod-4` specifies a v5e TPU with 4 cores. For more information,
- see `TPU versions `_.
+ `v5litepod-4` specifies a v5e TPU with 4 cores, `v6e-1` specifies a v6e TPU with 1 core. For more information,
+ see [TPU versions](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions).
- * RUNTIME_VERSION
- * The TPU VM runtime version to use. For more information see `TPU VM images `_.
+ * The TPU VM runtime version to use. For example, use `v2-alpha-tpuv6e` for a VM loaded with one or more v6e TPU(s).
For more information see [TPU VM images](https://cloud.google.com/tpu/docs/runtimes). - * SERVICE_ACCOUNT * The email address for your service account. You can find it in the IAM Cloud Console under *Service Accounts*. For example: diff --git a/docs/source/getting_started/installation/cpu.md b/docs/source/getting_started/installation/cpu.md index db22ef79c926..2c0ec60d7100 100644 --- a/docs/source/getting_started/installation/cpu.md +++ b/docs/source/getting_started/installation/cpu.md @@ -272,7 +272,7 @@ $ python examples/offline_inference/basic/basic.py - Decouple the HTTP serving components from the inference components. In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance. -- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.inc.md#non-uniform-memory-access-numa). For NUMA architecture, Tensor Parallel is a option for better performance. +- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.md#non-uniform-memory-access-numa). For NUMA architecture, Tensor Parallel is a option for better performance. - Tensor Parallel is supported for serving and offline inferencing. In general each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving: diff --git a/docs/source/getting_started/installation/cpu/build.inc.md b/docs/source/getting_started/installation/cpu/build.inc.md index 39d9dfbd2b2e..f385f3d5b198 100644 --- a/docs/source/getting_started/installation/cpu/build.inc.md +++ b/docs/source/getting_started/installation/cpu/build.inc.md @@ -2,7 +2,7 @@ First, install recommended compiler. We recommend to use `gcc/g++ >= 12.3.0` as ```console sudo apt-get update -y -sudo apt-get install -y gcc-12 g++-12 libnuma-dev +sudo apt-get install -y gcc-12 g++-12 libnuma-dev python3-dev sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 ``` @@ -26,3 +26,9 @@ Finally, build and install vLLM CPU backend: ```console VLLM_TARGET_DEVICE=cpu python setup.py install ``` + +If you want to develop vllm, install it in editable mode instead. + +```console +VLLM_TARGET_DEVICE=cpu python setup.py develop +``` diff --git a/docs/source/getting_started/installation/gpu/cuda.inc.md b/docs/source/getting_started/installation/gpu/cuda.inc.md index d3e375aec10c..46bdb08ebb77 100644 --- a/docs/source/getting_started/installation/gpu/cuda.inc.md +++ b/docs/source/getting_started/installation/gpu/cuda.inc.md @@ -46,7 +46,7 @@ LLM inference is a fast-evolving field, and the latest code may contain bug fixe ##### Install the latest code using `pip` ```console -pip install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly +pip install -U vllm --pre --extra-index-url https://wheels.vllm.ai/nightly ``` `--pre` is required for `pip` to consider pre-released versions. 
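+After installing a nightly wheel, it can help to confirm which build was actually picked up. A minimal check (the exact version string will differ from build to build):
+
+```python
+# Print the installed vLLM version to confirm the nightly/pre-release install.
+import vllm
+
+print(vllm.__version__)
+```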
@@ -65,9 +65,11 @@ Note that the wheels are built with Python 3.8 ABI (see [PEP 425](https://peps.p Another way to install the latest code is to use `uv`: ```console -uv pip install vllm --extra-index-url https://wheels.vllm.ai/nightly +uv pip install -U vllm --extra-index-url https://wheels.vllm.ai/nightly ``` +##### Install specific revisions using `uv` + If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), you can specify the commit hash in the URL: ```console @@ -151,7 +153,7 @@ git clone https://github.com/vllm-project/vllm.git cd vllm python use_existing_torch.py pip install -r requirements/build.txt -pip install -e . --no-build-isolation +pip install --no-build-isolation -e . ``` ##### Use the local cutlass for compilation diff --git a/docs/source/getting_started/installation/gpu/xpu.inc.md b/docs/source/getting_started/installation/gpu/xpu.inc.md index c41905f250f8..fbf5421eeec5 100644 --- a/docs/source/getting_started/installation/gpu/xpu.inc.md +++ b/docs/source/getting_started/installation/gpu/xpu.inc.md @@ -23,6 +23,8 @@ Currently, there are no pre-built XPU wheels. - Second, install Python packages for vLLM XPU backend building: ```console +git clone https://github.com/vllm-project/vllm.git +cd vllm pip install --upgrade pip pip install -v -r requirements/xpu.txt ``` diff --git a/docs/source/getting_started/troubleshooting.md b/docs/source/getting_started/troubleshooting.md index 87fa442e9a48..a4744827f226 100644 --- a/docs/source/getting_started/troubleshooting.md +++ b/docs/source/getting_started/troubleshooting.md @@ -24,7 +24,7 @@ To isolate the model downloading and loading issue, you can use the `--load-form ## Out of memory -If the model is too large to fit in a single GPU, you will get an out-of-memory (OOM) error. Consider [using tensor parallelism](#distributed-serving) to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using . The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism. +If the model is too large to fit in a single GPU, you will get an out-of-memory (OOM) error. Consider adopting [these options](#reducing-memory-usage) to reduce the memory consumption. 
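+For orientation, the sketch below shows the kind of knobs that section covers when constructing the engine offline; the model name and values are illustrative only, not recommendations:
+
+```python
+from vllm import LLM
+
+# Illustrative memory-saving settings; see the linked section for details.
+llm = LLM(
+    model="meta-llama/Llama-3.1-8B-Instruct",  # example model
+    max_model_len=4096,            # cap context length to shrink the KV cache
+    gpu_memory_utilization=0.80,   # leave headroom on the GPU
+    enforce_eager=True,            # skip CUDA graph capture memory
+    max_num_seqs=8,                # fewer concurrent sequences
+)
+```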
## Generation quality changed diff --git a/docs/source/getting_started/v1_user_guide.md b/docs/source/getting_started/v1_user_guide.md index a87484c3bb04..de90b8a7851e 100644 --- a/docs/source/getting_started/v1_user_guide.md +++ b/docs/source/getting_started/v1_user_guide.md @@ -44,8 +44,8 @@ This living user guide outlines a few known **important changes and limitations* |-----------------|-----------------------------------------------------------------------------------| | **Prefix Caching** | 🚀 Optimized | | **Chunked Prefill** | 🚀 Optimized | +| **LoRA** | 🚀 Optimized | | **Logprobs Calculation** | 🟢 Functional | -| **LoRA** | 🟢 Functional ([PR #13096](https://github.com/vllm-project/vllm/pull/13096))| | **Multimodal Models** | 🟢 Functional | | **FP8 KV Cache** | 🟢 Functional on Hopper devices ([PR #15191](https://github.com/vllm-project/vllm/pull/15191))| | **Spec Decode** | 🚧 WIP ([PR #13933](https://github.com/vllm-project/vllm/pull/13933))| @@ -121,11 +121,6 @@ Although we have re-implemented and partially optimized many features and models These features are already supported in vLLM V1, but their optimization is still in progress. -- **LoRA**: LoRA is functionally working on vLLM V1 but its performance is - inferior to that of V0. The team is actively working on improving its - performance -(e.g., see [PR #13096](https://github.com/vllm-project/vllm/pull/13096)). - - **Spec Decode**: Currently, only ngram-based spec decode is supported in V1. There will be follow-up work to support other types of spec decode (e.g., see [PR #13933](https://github.com/vllm-project/vllm/pull/13933)). We will prioritize the support for Eagle, MTP compared to draft model based spec decode. diff --git a/docs/source/index.md b/docs/source/index.md index 28dc0f67d774..43b330e4b432 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -132,6 +132,7 @@ serving/integrations/index :caption: Deployment :maxdepth: 1 +deployment/security deployment/docker deployment/k8s deployment/nginx diff --git a/docs/source/models/extensions/fastsafetensor.md b/docs/source/models/extensions/fastsafetensor.md index 66cd710c97e9..531d58690014 100644 --- a/docs/source/models/extensions/fastsafetensor.md +++ b/docs/source/models/extensions/fastsafetensor.md @@ -1,5 +1,5 @@ Loading Model weights with fastsafetensors =================================================================== -Using fastsafetensor library enables loading model weights to GPU memory by leveraging GPU direct storage. See https://github.com/foundation-model-stack/fastsafetensors for more details. +Using fastsafetensors library enables loading model weights to GPU memory by leveraging GPU direct storage. See [their GitHub repository](https://github.com/foundation-model-stack/fastsafetensors) for more details. For enabling this feature, set the environment variable ``USE_FASTSAFETENSOR`` to ``true`` diff --git a/docs/source/models/extensions/runai_model_streamer.md b/docs/source/models/extensions/runai_model_streamer.md index 99c37876a01b..e0daa6f86dde 100644 --- a/docs/source/models/extensions/runai_model_streamer.md +++ b/docs/source/models/extensions/runai_model_streamer.md @@ -51,3 +51,29 @@ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer - :::{note} For further instructions about tunable parameters and additional parameters configurable through environment variables, read the [Environment Variables Documentation](https://github.com/run-ai/runai-model-streamer/blob/master/docs/src/env-vars.md). 
::: + +## Sharded Model Loading + +vLLM also supports loading sharded models using Run:ai Model Streamer. This is particularly useful for large models that are split across multiple files. To use this feature, use the `--load-format runai_streamer_sharded` flag: + +```console +vllm serve /path/to/sharded/model --load-format runai_streamer_sharded +``` + +The sharded loader expects model files to follow the same naming pattern as the regular sharded state loader: `model-rank-{rank}-part-{part}.safetensors`. You can customize this pattern using the `pattern` parameter in `--model-loader-extra-config`: + +```console +vllm serve /path/to/sharded/model --load-format runai_streamer_sharded --model-loader-extra-config '{"pattern":"custom-model-rank-{rank}-part-{part}.safetensors"}' +``` + +To create sharded model files, you can use the script provided in . This script demonstrates how to save a model in the sharded format that is compatible with the Run:ai Model Streamer sharded loader. + +The sharded loader supports all the same tunable parameters as the regular Run:ai Model Streamer, including `concurrency` and `memory_limit`. These can be configured in the same way: + +```console +vllm serve /path/to/sharded/model --load-format runai_streamer_sharded --model-loader-extra-config '{"concurrency":16, "memory_limit":5368709120}' +``` + +:::{note} +The sharded loader is particularly efficient for tensor or pipeline parallel models where each worker only needs to read its own shard rather than the entire checkpoint. +::: diff --git a/docs/source/models/generative_models.md b/docs/source/models/generative_models.md index 63fc53b0e7c5..3291006ed668 100644 --- a/docs/source/models/generative_models.md +++ b/docs/source/models/generative_models.md @@ -59,7 +59,7 @@ A code example can be found here: ]}` (offline) or `--hf_overrides '{"is_matryoshka": true}'`, `--hf_overrides '{"matryoshka_dimensions": []}'`(online). + +Here is an example to serve a model with Matryoshka Embeddings enabled. + +```text +vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"matryoshka_dimensions":[256]}' +``` + +### Offline Inference + +You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the dimensions parameter in {class}`~vllm.PoolingParams`. + +```python +from vllm import LLM, PoolingParams + +model = LLM(model="jinaai/jina-embeddings-v3", + task="embed", + trust_remote_code=True) +outputs = model.embed(["Follow the white rabbit."], + pooling_params=PoolingParams(dimensions=32)) +print(outputs[0].outputs) +``` + +A code example can be found here: + +### Online Inference + +Use the following command to start vllm server. + +```text +vllm serve jinaai/jina-embeddings-v3 --trust-remote-code +``` + +You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the dimensions parameter. 
+
+```text
+curl http://127.0.0.1:8000/v1/embeddings \
+  -H 'accept: application/json' \
+  -H 'Content-Type: application/json' \
+  -d '{
+  "input": "Follow the white rabbit.",
+  "model": "jinaai/jina-embeddings-v3",
+  "encoding_format": "float",
+  "dimensions": 32
+  }'
+```
+
+Expected output:
+
+```json
+{"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}}
+```
+
+An OpenAI client example can be found here: 
diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 6cfd68a30e6e..bc68e34832cc 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -40,29 +40,37 @@ You can force the use of `TransformersForCausalLM` by setting `model_impl="trans
vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM.
:::

-#### Supported features
+#### Custom models

-The Transformers modeling backend explicitly supports the following features:
+If a model is neither supported natively by vLLM nor Transformers, it can still be used in vLLM!

-- (except GGUF)
--
--
+For a model to be compatible with the Transformers backend for vLLM it must:

-#### Remote Code
+- be a Transformers compatible custom model (see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)):
+  * The model directory must have the correct structure (e.g. `config.json` is present).
+  * `config.json` must contain `auto_map.AutoModel`.
+- be a Transformers backend for vLLM compatible model (see ):
+  * Customisation should be done in the base model (e.g. in `MyModel`, not `MyModelForCausalLM`).

-If your model is neither supported natively by vLLM or Transformers, you can still run it in vLLM!
+If the compatible model is:

-Simply set `trust_remote_code=True` and vLLM will run any model on the Model Hub that is compatible with Transformers.
-Provided that the model writer implements their model in a compatible way, this means that you can run new models before they are officially supported in Transformers or vLLM!
+- on the Hugging Face Model Hub, simply set `trust_remote_code=True` for or `--trust-remote-code` for the .
+- in a local directory, simply pass the directory path to `model=` for or `vllm serve ` for the .

-```python
-from vllm import LLM
-llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model
-llm.apply_model(lambda model: print(model.__class__))
-```
+This means that, with the Transformers backend for vLLM, new models can be used before they are officially supported in Transformers or vLLM!
+
+(writing-custom-models)=
+
+#### Writing custom models
+
+This section details the necessary modifications to make to a Transformers compatible custom model that make it compatible with the Transformers backend for vLLM. 
(We assume that a Transformers compatible custom model has already been created, see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)). To make your model compatible with the Transformers backend, it needs: +1. `kwargs` passed down through all modules from `MyModel` to `MyAttention`. +2. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention. +3. `MyModel` must contain `_supports_attention_backend = True`. + ```{code-block} python :caption: modeling_my_model.py @@ -71,7 +79,7 @@ from torch import nn class MyAttention(nn.Module): - def forward(self, hidden_states, **kwargs): # <- kwargs are required + def forward(self, hidden_states, **kwargs): ... attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -87,11 +95,11 @@ class MyModel(PreTrainedModel): _supports_attention_backend = True ``` -Here is what happens in the background: +Here is what happens in the background when this model is loaded: -1. The config is loaded -2. `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`. -3. The `TransformersForCausalLM` backend is used. See , which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`. +1. The config is loaded. +2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`. +3. `MyModel` is loaded into `TransformersForCausalLM` (see ) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. That's it! @@ -129,7 +137,7 @@ class MyConfig(PretrainedConfig): ### Hugging Face Hub -By default, vLLM loads models from [Hugging Face (HF) Hub](https://huggingface.co/models). +By default, vLLM loads models from [Hugging Face (HF) Hub](https://huggingface.co/models). To change the download path for models, you can set the `HF_HOME` environment variable; for more details, refer to [their official documentation](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhome). To determine whether a given model is natively supported, you can check the `config.json` file inside the HF repository. If the `"architectures"` field contains a model architecture listed below, then it should be natively supported. @@ -213,6 +221,16 @@ output = llm.encode("Hello, my name is") print(output) ``` +(feature-status-legend)= + +## Feature Status Legend + +- ✅︎ indicates that the feature is supported for the model. + +- 🚧 indicates that the feature is planned but not yet supported for the model. + +- ⚠️ indicates that the feature is available but may have known issues or limitations. + (supported-text-models)= ## List of Text-only Language Models @@ -314,7 +332,7 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ - * `GemmaForCausalLM` * Gemma - * `google/gemma-2b`, `google/gemma-7b`, etc. + * `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. * ✅︎ * ✅︎ - * `Gemma2ForCausalLM` @@ -334,7 +352,7 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ - * `Glm4ForCausalLM` * GLM-4-0414 - * `THUDM/GLM-4-32B-Chat-0414`, etc. + * `THUDM/GLM-4-32B-0414`, etc. * ✅︎ * ✅︎ - * `GPT2LMHeadModel` @@ -497,6 +515,11 @@ See [this page](#generative-models) for more information on how to use generativ * `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. 
* * ✅︎ +- * `Plamo2ForCausalLM` + * PLaMo2 + * `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. + * + * - * `QWenLMHeadModel` * Qwen * `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. @@ -735,6 +758,11 @@ If your model is not in the above list, we will try to automatically convert the * `BAAI/bge-reranker-v2-m3`, etc. * * +- * `ModernBertForSequenceClassification` + * ModernBert-based + * `Alibaba-NLP/gte-reranker-modernbert-base`, etc. + * + * ::: (supported-mm-models)= @@ -765,6 +793,8 @@ or `--limit-mm-per-prompt` (online serving). For example, to enable passing up t Offline inference: ```python +from vllm import LLM + llm = LLM( model="Qwen/Qwen2-VL-7B-Instruct", limit_mm_per_prompt={"image": 4}, @@ -774,7 +804,7 @@ llm = LLM( Online serving: ```bash -vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt image=4 +vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt '{"image":4}' ``` **This is no longer required if you are using vLLM V1.** @@ -865,6 +895,13 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ * ✅︎ * ✅︎ +- * `GraniteSpeechForConditionalGeneration` + * Granite Speech + * T + A + * `ibm-granite/granite-speech-3.3-8b` + * ✅︎ + * ✅︎ + * ✅︎ - * `H2OVLChatModel` * H2OVL * T + IE+ @@ -886,6 +923,13 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ +- * `KimiVLForConditionalGeneration` + * Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking + * T + I+ + * `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` + * + * + * ✅︎ - * `Llama4ForConditionalGeneration` * Llama 4 * T + I+ @@ -990,7 +1034,7 @@ See [this page](#generative-models) for more information on how to use generativ * `microsoft/Phi-4-multimodal-instruct`, etc. * ✅︎ * - * + * ✅︎ - * `PixtralForConditionalGeneration` * Pixtral * T + I+ @@ -1026,6 +1070,13 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ * ✅︎ * ✅︎ +- * `Qwen2_5OmniThinkerForConditionalGeneration` + * Qwen2.5-Omni + * T + IE+ + VE+ + A+ + * `Qwen/Qwen2.5-Omni-7B` + * + * ✅︎ + * ✅︎\* - * `SkyworkR1VChatModel` * Skywork-R1V-38B * T + I @@ -1057,7 +1108,7 @@ See [this page](#generative-models) for more information on how to use generativ :::{important} Pan-and-scan image pre-processing is currently supported on V0 (but not V1). -You can enable it by passing `--mm-processor-kwargs '{"do_pan_and_scan": True}'`. +You can enable it by passing `--mm-processor-kwargs '{"do_pan_and_scan": true}'`. ::: :::{warning} @@ -1072,7 +1123,7 @@ V0 correctly implements the model's attention pattern: V1 currently uses a simplified attention pattern: - Uses causal attention for all tokens, including image tokens -- Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": True}` +- Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}` - Will be updated in the future to support the correct behavior This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. @@ -1086,6 +1137,36 @@ This limitation exists because the model's mixed attention pattern (bidirectiona To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. 
:::

+:::{warning}
+The output quality of `AllenAI/Molmo-7B-D-0924` (especially in object localization tasks) has deteriorated in recent updates.
+
+For the best results, we recommend using the following dependency versions (tested on A10 and L40):
+
+```text
+# Core vLLM-compatible dependencies with Molmo accuracy setup (tested on L40)
+torch==2.5.1
+torchvision==0.20.1
+transformers==4.48.1
+tokenizers==0.21.0
+tiktoken==0.7.0
+vllm==0.7.0
+
+# Optional but recommended for improved performance and stability
+triton==3.1.0
+xformers==0.0.28.post3
+uvloop==0.21.0
+protobuf==5.29.3
+openai==1.60.2
+opencv-python-headless==4.11.0.86
+pillow==10.4.0
+
+# Installed FlashAttention (for float16 only)
+flash-attn>=2.5.6  # Not used in float32, but should be documented
+```
+
+**Note:** Make sure you understand the security implications of using outdated packages.
+:::
+
:::{note}
The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now. For more details, please see: 
@@ -1095,6 +1176,14 @@ For more details, please see: 
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
:::

+:::{note}
+To use Qwen2.5-Omni, you have to install the Hugging Face Transformers library from source via
+`pip install git+https://github.com/huggingface/transformers.git`.
+
+Reading audio from video during pre-processing is currently supported on V0 (but not V1), because overlapping modalities are not yet supported in V1.
+You can enable it by passing `--mm-processor-kwargs '{"use_audio_in_video": true}'`.
+:::
+
### Pooling Models

See [this page](pooling-models) for more information on how to use pooling models.
diff --git a/docs/source/serving/distributed_serving.md b/docs/source/serving/distributed_serving.md
index 591acc2c9b75..c285ef3e8e1c 100644
--- a/docs/source/serving/distributed_serving.md
+++ b/docs/source/serving/distributed_serving.md
@@ -77,6 +77,10 @@ bash run_cluster.sh \

Then you get a ray cluster of **containers**. Note that you need to keep the shells running these commands alive to hold the cluster. Any shell disconnect will terminate the cluster. In addition, please note that the argument `ip_of_head_node` should be the IP address of the head node, which is accessible by all the worker nodes. The IP addresses of each worker node should be specified in the `VLLM_HOST_IP` environment variable, and should be different for each worker node. Please check the network configuration of your cluster to make sure the nodes can communicate with each other through the specified IP addresses.

+:::{warning}
+It is considered best practice to set `VLLM_HOST_IP` to an address on a private network segment for the vLLM cluster. The traffic sent here is not encrypted. The endpoints are also exchanging data in a format that could be exploited to execute arbitrary code should a malicious party gain access to the network. Please ensure that this network is not reachable by any untrusted parties.
+:::
+
:::{warning}
Since this is a ray cluster of **containers**, all the following commands should be executed in the **containers**, otherwise you are executing the commands on the host machine, which is not connected to the ray cluster. To enter the container, you can use `docker exec -it node /bin/bash`.
::: diff --git a/docs/source/serving/engine_args.md b/docs/source/serving/engine_args.md index e9943571a40a..97ea01cd3b2e 100644 --- a/docs/source/serving/engine_args.md +++ b/docs/source/serving/engine_args.md @@ -16,6 +16,7 @@ Below, you can find an explanation of every engine argument: :func: _engine_args_parser :prog: vllm serve :nodefaultconst: + :markdownhelp: ``` ## Async Engine Arguments @@ -29,4 +30,5 @@ Additional arguments are available to the asynchronous engine which is used for :func: _async_engine_args_parser :prog: vllm serve :nodefaultconst: + :markdownhelp: ``` diff --git a/docs/source/serving/multimodal_inputs.md b/docs/source/serving/multimodal_inputs.md index f45d36c3ccac..d9a093e8d145 100644 --- a/docs/source/serving/multimodal_inputs.md +++ b/docs/source/serving/multimodal_inputs.md @@ -228,7 +228,7 @@ First, launch the OpenAI-compatible server: ```bash vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ - --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 + --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}' ``` Then, you can use the OpenAI client as follows: diff --git a/docs/source/serving/offline_inference.md b/docs/source/serving/offline_inference.md index 85f2cafacdd3..894878ed14e7 100644 --- a/docs/source/serving/offline_inference.md +++ b/docs/source/serving/offline_inference.md @@ -28,6 +28,8 @@ Please refer to the above pages for more details about each API. [API Reference](/api/offline_inference/index) ::: +(configuration-options)= + ## Configuration Options This section lists the most common options for running the vLLM engine. @@ -59,6 +61,8 @@ model = LLM( Our [list of supported models](#supported-models) shows the model architectures that are recognized by vLLM. +(reducing-memory-usage)= + ### Reducing memory usage Large models might cause your machine to run out of memory (OOM). Here are some options that help alleviate this problem. @@ -81,6 +85,12 @@ before initializing vLLM. Otherwise, you may run into an error like `RuntimeErro To control which devices are used, please instead set the `CUDA_VISIBLE_DEVICES` environment variable. ::: +:::{note} +With tensor parallelism enabled, each process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). + +You can convert the model checkpoint to a sharded checkpoint using . The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism. +::: + #### Quantization Quantized models take less memory at the cost of lower precision. @@ -103,6 +113,39 @@ llm = LLM(model="adept/fuyu-8b", max_num_seqs=2) ``` +#### Reduce CUDA Graphs + +By default, we optimize model inference using CUDA graphs which take up extra memory in the GPU. + +:::{important} +CUDA graph capture takes up more memory in V1 than in V0. 
+::: + +You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage: + +```python +from vllm import LLM +from vllm.config import CompilationConfig, CompilationLevel + +llm = LLM( + model="meta-llama/Llama-3.1-8B-Instruct", + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + # By default, it goes up to max_num_seqs + cudagraph_capture_sizes=[1, 2, 4, 8, 16], + ), +) +``` + +You can disable graph capturing completely via the `enforce_eager` flag: + +```python +from vllm import LLM + +llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", + enforce_eager=True) +``` + #### Adjust cache size If you run out of CPU RAM, try the following options: @@ -110,16 +153,25 @@ If you run out of CPU RAM, try the following options: - (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB). - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). -#### Disable unused modalities +#### Multi-modal input limits -You can disable unused modalities (except for text) by setting its limit to zero. +You can allow a smaller number of multi-modal items per prompt to reduce the memory footprint of the model: + +```python +from vllm import LLM + +# Accept up to 3 images and 1 video per prompt +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + limit_mm_per_prompt={"image": 3, "video": 1}) +``` +You can go a step further and disable unused modalities completely by setting its limit to zero. For example, if your application only accepts image input, there is no need to allocate any memory for videos. ```python from vllm import LLM -# Accept images but not videos +# Accept any number of images but no videos llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", limit_mm_per_prompt={"video": 0}) ``` @@ -134,6 +186,29 @@ llm = LLM(model="google/gemma-3-27b-it", limit_mm_per_prompt={"image": 0}) ``` +#### Multi-modal processor arguments + +For certain models, you can adjust the multi-modal processor arguments to +reduce the size of the processed multi-modal inputs, which in turn saves memory. + +Here are some examples: + +```python +from vllm import LLM + +# Available for Qwen2-VL series models +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + mm_processor_kwargs={ + "max_pixels": 768 * 768, # Default is 1280 * 28 * 28 + }) + +# Available for InternVL series models +llm = LLM(model="OpenGVLab/InternVL2-2B", + mm_processor_kwargs={ + "max_dynamic_patch": 4, # Default is 12 + }) +``` + ### Performance optimization and tuning You can potentially improve the performance of vLLM by finetuning various options. diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 11ca571c684a..34382c87a484 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -33,11 +33,13 @@ print(completion.choices[0].message) vLLM supports some parameters that are not supported by OpenAI, `top_k` for example. You can pass these parameters to vLLM using the OpenAI client in the `extra_body` parameter of your requests, i.e. `extra_body={"top_k": 50}` for `top_k`. ::: + :::{important} By default, the server applies `generation_config.json` from the Hugging Face model repository if it exists. This means the default values of certain sampling parameters can be overridden by those recommended by the model creator. 
To disable this behavior, please pass `--generation-config vllm` when launching the server. ::: + ## Supported APIs We currently support the following OpenAI APIs: @@ -172,6 +174,12 @@ print(completion._request_id) The `vllm serve` command is used to launch the OpenAI-compatible server. +:::{tip} +The vast majority of command-line arguments are based on those for offline inference. + +See [here](configuration-options) for some common options. +::: + :::{argparse} :module: vllm.entrypoints.openai.cli_args :func: create_parser_for_docs @@ -394,9 +402,26 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`. ::: +Code example: -Code example: +#### Extra Parameters + +The following [sampling parameters](#sampling-params) are supported. + +:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py +:language: python +:start-after: begin-transcription-sampling-params +:end-before: end-transcription-sampling-params +::: + +The following extra parameters are supported: + +:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py +:language: python +:start-after: begin-transcription-extra-params +:end-before: end-transcription-extra-params +::: (tokenizer-api)= diff --git a/examples/lmcache/README.md b/examples/lmcache/README.md new file mode 100644 index 000000000000..7d0c23f529bb --- /dev/null +++ b/examples/lmcache/README.md @@ -0,0 +1,56 @@ +# LMCache Examples + +This folder demonstrates how to use LMCache for disaggregated prefilling, CPU offloading and KV cache sharing. + +## 1. Disaggregated Prefill in vLLM v1 + +This example demonstrates how to run LMCache with disaggregated prefill using NIXL on a single node. + +### Prerequisites + +- Install [LMCache](https://github.com/LMCache/LMCache). You can simply run `pip install lmcache`. +- Install [NIXL](https://github.com/ai-dynamo/nixl). +- At least 2 GPUs +- Valid Hugging Face token (HF_TOKEN) for Llama 3.1 8B Instruct. + +### Usage + +Run +`cd disagg_prefill_lmcache_v1` +to get into `disagg_prefill_lmcache_v1` folder, and then run + +```bash +bash disagg_example_nixl.sh +``` + +to run disaggregated prefill and benchmark the performance. + +### Components + +#### Server Scripts +- `disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh` - Launches individual vLLM servers for prefill/decode, and also launches the proxy server. +- `disagg_prefill_lmcache_v1/disagg_proxy_server.py` - FastAPI proxy server that coordinates between prefiller and decoder +- `disagg_prefill_lmcache_v1/disagg_example_nixl.sh` - Main script to run the example + +#### Configuration +- `disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml` - Configuration for prefiller server +- `disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml` - Configuration for decoder server + +#### Log Files +The main script generates several log files: +- `prefiller.log` - Logs from the prefill server +- `decoder.log` - Logs from the decode server +- `proxy.log` - Logs from the proxy server + +## 2. CPU Offload Examples + +- `cpu_offload_lmcache_v0.py` - CPU offloading implementation for vLLM v0 +- `cpu_offload_lmcache_v1.py` - CPU offloading implementation for vLLM v1 + +## 3. KV Cache Sharing + +The `kv_cache_sharing_lmcache_v1.py` example demonstrates how to share KV caches between vLLM v1 instances. + +## 4. 
Disaggregated Prefill in vLLM v0 + +The `disaggregated_prefill_lmcache_v0.py` provides an example of how to run disaggregated prefill in vLLM v0. diff --git a/examples/lmcache/cpu_offload_lmcache_v0.py b/examples/lmcache/cpu_offload_lmcache_v0.py new file mode 100644 index 000000000000..37aea281032f --- /dev/null +++ b/examples/lmcache/cpu_offload_lmcache_v0.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file demonstrates the example usage of cpu offloading +with LMCache. + +Note that `lmcache` is needed to run this example. +Requirements: Linux, Python: 3.10 or higher, CUDA: 12.1 +Learn more about LMCache environment setup, please refer to: +https://docs.lmcache.ai/getting_started/installation.html +""" +import contextlib +import os +import time + +from lmcache.experimental.cache_engine import LMCacheEngineBuilder +from lmcache.integration.vllm.utils import ENGINE_NAME + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + + +def setup_environment_variables(): + # LMCache-related environment variables + # Use experimental features in LMCache + os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True" + # LMCache is set to use 256 tokens per chunk + os.environ["LMCACHE_CHUNK_SIZE"] = "256" + # Enable local CPU backend in LMCache + os.environ["LMCACHE_LOCAL_CPU"] = "True" + # Set local CPU memory limit to 5.0 GB + os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0" + + +@contextlib.contextmanager +def build_llm_with_lmcache(): + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}') + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # memory. Reduce the value if your GPU has less memory. + # Note: LMCache supports chunked prefill (see vLLM#14505, LMCache#392). + llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + enable_chunked_prefill=True, + gpu_memory_utilization=0.8) + + try: + yield llm + finally: + # Clean up lmcache backend + LMCacheEngineBuilder.destroy(ENGINE_NAME) + + +def print_output( + llm: LLM, + prompt: list[str], + sampling_params: SamplingParams, + req_str: str, +): + start = time.time() + outputs = llm.generate(prompt, sampling_params) + print("-" * 50) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + print(f"Generation took {time.time() - start:.2f} seconds, " + f"{req_str} request done.") + print("-" * 50) + + +def main(): + setup_environment_variables() + + with build_llm_with_lmcache() as llm: + + # This example script runs two requests with a shared prefix. + # Define the shared prompt and specific prompts + shared_prompt = "Hello, how are you?" 
* 1000 + first_prompt = [ + shared_prompt + "Hello, my name is", + ] + second_prompt = [ + shared_prompt + "Tell me a very long story", + ] + + sampling_params = SamplingParams(temperature=0, + top_p=0.95, + max_tokens=10) + + # Print the first output + print_output(llm, first_prompt, sampling_params, "first") + + time.sleep(1) + + # print the second output + print_output(llm, second_prompt, sampling_params, "second") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/cpu_offload_lmcache.py b/examples/lmcache/cpu_offload_lmcache_v1.py similarity index 76% rename from examples/offline_inference/cpu_offload_lmcache.py rename to examples/lmcache/cpu_offload_lmcache_v1.py index 8211629b24ec..f44075a36965 100644 --- a/examples/offline_inference/cpu_offload_lmcache.py +++ b/examples/lmcache/cpu_offload_lmcache_v1.py @@ -1,13 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 """ This file demonstrates the example usage of cpu offloading -with LMCache. +with LMCache in vLLM v1. -Note that `pip install lmcache` is needed to run this example. +Note that lmcache needs to be installed to run this example. Learn more about LMCache in https://github.com/LMCache/LMCache. """ import os -import time from lmcache.experimental.cache_engine import LMCacheEngineBuilder from lmcache.integration.vllm.utils import ENGINE_NAME @@ -37,29 +36,22 @@ sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}') + '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}') # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # memory. Reduce the value if your GPU has less memory. # Note that LMCache is not compatible with chunked prefill for now. -llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", +llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", kv_transfer_config=ktc, max_model_len=8000, - enable_chunked_prefill=False, gpu_memory_utilization=0.8) +# Should be able to see logs like the following: +# `LMCache INFO: Storing KV cache for 6006 out of 6006 tokens for request 0` +# This indicates that the KV cache has been stored in LMCache. outputs = llm.generate(first_prompt, sampling_params) for output in outputs: generated_text = output.outputs[0].text print(f"Generated text: {generated_text!r}") -print("First request done.") - -time.sleep(1) - -outputs = llm.generate(second_prompt, sampling_params) -for output in outputs: - generated_text = output.outputs[0].text - print(f"Generated text: {generated_text!r}") -print("Second request done.") # Clean up lmcache backend LMCacheEngineBuilder.destroy(ENGINE_NAME) diff --git a/examples/offline_inference/disaggregated_prefill_lmcache.py b/examples/lmcache/disagg_prefill_lmcache_v0.py similarity index 98% rename from examples/offline_inference/disaggregated_prefill_lmcache.py rename to examples/lmcache/disagg_prefill_lmcache_v0.py index 5c84bbfc92c5..7da6fb7aaa23 100644 --- a/examples/offline_inference/disaggregated_prefill_lmcache.py +++ b/examples/lmcache/disagg_prefill_lmcache_v0.py @@ -38,6 +38,10 @@ # `naive` indicates using raw bytes of the tensor without any compression os.environ["LMCACHE_REMOTE_SERDE"] = "naive" +prompts = [ + "Hello, how are you?" * 1000, +] + def run_prefill(prefill_done, prompts): # We use GPU 0 for prefill node. @@ -106,12 +110,7 @@ def run_lmcache_server(port): return server_proc -if __name__ == "__main__": - - prompts = [ - "Hello, how are you?" 
* 1000, - ] - +def main(): prefill_done = Event() prefill_process = Process(target=run_prefill, args=(prefill_done, prompts)) decode_process = Process(target=run_decode, args=(prefill_done, prompts)) @@ -128,3 +127,7 @@ def run_lmcache_server(port): prefill_process.terminate() lmcache_server_process.terminate() lmcache_server_process.wait() + + +if __name__ == "__main__": + main() diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml new file mode 100644 index 000000000000..c3f5a0ae69c0 --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml @@ -0,0 +1,13 @@ +local_cpu: False +max_local_cpu_size: 0 +#local_disk: +max_local_disk_size: 0 +remote_serde: NULL + +enable_nixl: True +nixl_role: "receiver" +nixl_peer_host: "localhost" +nixl_peer_port: 55555 +nixl_buffer_size: 1073741824 # 1GB +nixl_buffer_device: "cuda" +nixl_enable_gc: True diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml new file mode 100644 index 000000000000..8b0e82958a64 --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml @@ -0,0 +1,13 @@ +local_cpu: False +max_local_cpu_size: 0 +#local_disk: +max_local_disk_size: 0 +remote_serde: NULL + +enable_nixl: True +nixl_role: "sender" +nixl_peer_host: "localhost" +nixl_peer_port: 55555 +nixl_buffer_size: 1073741824 # 1GB +nixl_buffer_device: "cuda" +nixl_enable_gc: True diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh new file mode 100644 index 000000000000..df8a41293504 --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh @@ -0,0 +1,136 @@ +#!/bin/bash + +echo "Warning: LMCache disaggregated prefill support for vLLM v1 is experimental and subject to change." + + +PIDS=() + +# Switch to the directory of the current script +cd "$(dirname "${BASH_SOURCE[0]}")" + +check_hf_token() { + if [ -z "$HF_TOKEN" ]; then + echo "HF_TOKEN is not set. Please set it to your Hugging Face token." + exit 1 + fi + if [[ "$HF_TOKEN" != hf_* ]]; then + echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token." + exit 1 + fi + echo "HF_TOKEN is set and valid." +} + +check_num_gpus() { + # can you check if the number of GPUs are >=2 via nvidia-smi? + num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) + if [ "$num_gpus" -lt 2 ]; then + echo "You need at least 2 GPUs to run disaggregated prefill." + exit 1 + else + echo "Found $num_gpus GPUs." + fi +} + +ensure_python_library_installed() { + echo "Checking if $1 is installed..." + python -c "import $1" > /dev/null 2>&1 + if [ $? -ne 0 ]; then + if [ "$1" == "nixl" ]; then + echo "$1 is not installed. Please refer to https://github.com/ai-dynamo/nixl for installation." + else + echo "$1 is not installed. Please install it via pip install $1." + fi + exit 1 + else + echo "$1 is installed." + fi +} + +cleanup() { + echo "Stopping everything…" + trap - INT TERM # prevent re-entrancy + kill -- -$$ # negative PID == “this whole process-group” + wait # reap children so we don't leave zombies + exit 0 +} + +wait_for_server() { + local port=$1 + local timeout_seconds=1200 + local start_time=$(date +%s) + + echo "Waiting for server on port $port..." 
+ + while true; do + if curl -s "localhost:${port}/v1/completions" > /dev/null; then + return 0 + fi + + local now=$(date +%s) + if (( now - start_time >= timeout_seconds )); then + echo "Timeout waiting for server" + return 1 + fi + + sleep 1 + done +} + + +main() { + check_hf_token + check_num_gpus + ensure_python_library_installed lmcache + ensure_python_library_installed nixl + ensure_python_library_installed pandas + ensure_python_library_installed datasets + ensure_python_library_installed vllm + + trap cleanup INT + trap cleanup USR1 + trap cleanup TERM + + echo "Launching prefiller, decoder and proxy..." + echo "Please check prefiller.log, decoder.log and proxy.log for logs." + + bash disagg_vllm_launcher.sh prefiller \ + > >(tee prefiller.log) 2>&1 & + prefiller_pid=$! + PIDS+=($prefiller_pid) + + bash disagg_vllm_launcher.sh decoder \ + > >(tee decoder.log) 2>&1 & + decoder_pid=$! + PIDS+=($decoder_pid) + + python3 disagg_proxy_server.py \ + --host localhost \ + --port 9000 \ + --prefiller-host localhost \ + --prefiller-port 8100 \ + --decoder-host localhost \ + --decoder-port 8200 \ + > >(tee proxy.log) 2>&1 & + proxy_pid=$! + PIDS+=($proxy_pid) + + wait_for_server 8100 + wait_for_server 8200 + wait_for_server 9000 + + echo "All servers are up. Starting benchmark..." + + # begin benchmark + cd ../../../benchmarks/ + python benchmark_serving.py --port 9000 --seed $(date +%s) \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --dataset-name random --random-input-len 7500 --random-output-len 200 \ + --num-prompts 200 --burstiness 100 --request-rate 3.6 | tee benchmark.log + + echo "Benchmarking done. Cleaning up..." + + cleanup + +} + +main \ No newline at end of file diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py new file mode 100644 index 000000000000..8db93bc8931b --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +import time +from contextlib import asynccontextmanager + +import httpx +import numpy as np +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. 
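+
+    On startup, persistent httpx.AsyncClient instances are created for the
+    prefiller and decoder endpoints; on shutdown, both clients are closed.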
+ """ + # Startup: Initialize clients + prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1' + decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1' + + app.state.prefill_client = httpx.AsyncClient(timeout=None, + base_url=prefiller_base_url) + app.state.decode_client = httpx.AsyncClient(timeout=None, + base_url=decoder_base_url) + + yield + + # Shutdown: Close clients + await app.state.prefill_client.aclose() + await app.state.decode_client.aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +class StatsCalculator: + + def __init__(self): + self._stats = [] + self._last_log_time = time.time() + + def add(self, value): + self._stats.append(value) + if time.time() - self._last_log_time > 5: + self._log_stats() + self._last_log_time = time.time() + + def _log_stats(self): + # Print average, median, and 99th percentile + np_arr = np.array(self._stats) + output_str = f"\nNum requests: {len(self._stats)}" + \ + "\nPrefill node TTFT stats:" + \ + f"\n - Average (ms): {np.mean(np_arr)}" + \ + f"\n - Median (ms): {np.median(np_arr)}" + \ + f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n" + print("===============================", output_str, + "===============================") + + +stats_calculator = StatsCalculator() +counter = 0 + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--prefiller-host", type=str, default="localhost") + parser.add_argument("--prefiller-port", type=int, default=8100) + parser.add_argument("--decoder-host", type=str, default="localhost") + parser.add_argument("--decoder-port", type=int, default=8200) + args = parser.parse_args() + return args + + +# Initialize variables to hold the persistent clients +app.state.prefill_client = None +app.state.decode_client = None + + +async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, + req_data: dict): + """ + Send a request to a service using a persistent client. + """ + req_data = req_data.copy() + req_data['max_tokens'] = 1 + if 'max_completion_tokens' in req_data: + req_data['max_completion_tokens'] = 1 + + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + response = await client.post(endpoint, json=req_data, headers=headers) + response.raise_for_status() + return response + + +async def stream_service_response(client: httpx.AsyncClient, endpoint: str, + req_data: dict): + """ + Asynchronously stream the response from a service using a persistent client. 
+ """ + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + async with client.stream("POST", endpoint, json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + + # Send request to prefill service, ignore the response + await send_request_to_service(app.state.prefill_client, "/completions", + req_data) + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(app.state.decode_client, + "/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.post("/v1/chat/completions") +async def handle_chat_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + + # Send request to prefill service, ignore the response + await send_request_to_service(app.state.prefill_client, + "/chat/completions", req_data) + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(app.state.decode_client, + "/chat/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server " + " - chat completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh new file mode 100644 index 000000000000..831ef0bb574b --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +if [[ $# -lt 1 ]]; then + echo "Usage: $0 [model]" + exit 1 +fi + +if [[ $# -eq 1 ]]; then + echo "Using default model: meta-llama/Llama-3.1-8B-Instruct" + MODEL="meta-llama/Llama-3.1-8B-Instruct" +else + echo "Using model: $2" + MODEL=$2 +fi + + +if [[ $1 == "prefiller" ]]; then + # Prefiller listens on port 8100 + prefill_config_file=$SCRIPT_DIR/configs/lmcache-prefiller-config.yaml + + UCX_TLS=cuda_ipc,cuda_copy,tcp \ + LMCACHE_CONFIG_FILE=$prefill_config_file \ + LMCACHE_USE_EXPERIMENTAL=True \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + CUDA_VISIBLE_DEVICES=0 \ + vllm serve $MODEL \ + --port 8100 \ + --disable-log-requests \ + --enforce-eager \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}' + + +elif [[ $1 
== "decoder" ]]; then + # Decoder listens on port 8200 + decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml + + UCX_TLS=cuda_ipc,cuda_copy,tcp \ + LMCACHE_CONFIG_FILE=$decode_config_file \ + LMCACHE_USE_EXPERIMENTAL=True \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + CUDA_VISIBLE_DEVICES=1 \ + vllm serve $MODEL \ + --port 8200 \ + --disable-log-requests \ + --enforce-eager \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}' + + +else + echo "Invalid role: $1" + echo "Should be either prefill, decode" + exit 1 +fi diff --git a/examples/lmcache/kv_cache_sharing_lmcache_v1.py b/examples/lmcache/kv_cache_sharing_lmcache_v1.py new file mode 100644 index 000000000000..af1b4351dd54 --- /dev/null +++ b/examples/lmcache/kv_cache_sharing_lmcache_v1.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file demonstrates the example usage of remote KV cache sharing +with LMCache. +We will launch 2 vllm instances, and launch an additional LMCache server. +KV cache is transferred in the following manner: +(1) vLLM instance 1 -> LMCache server (KV cache store). +(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve). + +Note that lmcache needs to be installed to run this example. +Learn more about LMCache in https://github.com/LMCache/LMCache. +""" +import os +import subprocess +import time +from multiprocessing import Event, Process + +from lmcache.experimental.cache_engine import LMCacheEngineBuilder +from lmcache.integration.vllm.utils import ENGINE_NAME + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +# LMCache-related environment variables +# The port to start LMCache server +port = 8100 +# Use experimental features in LMCache +os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True" +# LMCache is set to use 256 tokens per chunk +os.environ["LMCACHE_CHUNK_SIZE"] = "256" +# Disable local CPU backend in LMCache +os.environ["LMCACHE_LOCAL_CPU"] = "False" +# Set local CPU memory buffer limit to 5.0 GB +os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0" +# Set the remote URL for LMCache server +os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}" +# Set the serializer/deserializer between vllm and LMCache server +# `naive` indicates using raw bytes of the tensor without any compression +os.environ["LMCACHE_REMOTE_SERDE"] = "naive" + +prompts = [ + "Hello, how are you?" * 1000, +] + + +def run_store(store_done, prompts): + # We use GPU 0 for KV cache store process. + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}') + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # memory. Reduce the value if your GPU has less memory. + llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + enforce_eager=True) + + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + print("KV cache store is finished.") + store_done.set() + + # Clean up lmcache backend + LMCacheEngineBuilder.destroy(ENGINE_NAME) + + +def run_retrieve(store_done, prompts, timeout=1): + # We use GPU 1 for KV cache retrieve process. 
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1" + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + + ktc = KVTransferConfig.from_cli( + '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}') + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # of memory. Reduce the value if your GPU has less memory. + llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + enforce_eager=True) + + print("Waiting for KV cache store to finish...") + store_done.wait() + time.sleep(timeout) + + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + + # Clean up lmcache backend + LMCacheEngineBuilder.destroy(ENGINE_NAME) + + +def run_lmcache_server(port): + server_proc = subprocess.Popen([ + "python", "-m", "lmcache.experimental.server", "localhost", + str(port) + ]) + return server_proc + + +def main(): + store_done = Event() + store_process = Process(target=run_store, args=(store_done, prompts)) + retrieve_process = Process(target=run_retrieve, args=(store_done, prompts)) + lmcache_server_process = run_lmcache_server(port) + + # Start KV cache store process + store_process.start() + + # Start KV cache retrieve process + retrieve_process.start() + + # Clean up the processes + store_process.join() + retrieve_process.terminate() + lmcache_server_process.terminate() + lmcache_server_process.wait() + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 248090474de6..bab41c915c32 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -38,6 +38,37 @@ class ModelRequestData(NamedTuple): # Unless specified, these settings have been tested to work on a single L4. +# Granite Speech +def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: + # NOTE - the setting in this example are somehat different than what is + # optimal for granite speech, and it is generally recommended to use beam + # search. Check the model README for suggested settings. + # https://huggingface.co/ibm-granite/granite-speech-3.3-8b + model_name = "ibm-granite/granite-speech-3.3-8b" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=2048, + max_num_seqs=2, + enable_lora=True, + max_lora_rank=64, + limit_mm_per_prompt={"audio": audio_count}, + ) + + # The model has an audio-specific lora directly in its model dir; + # it should be enabled whenever you pass audio inputs to the model. + speech_lora_path = model_name + audio_placeholder = "<|audio|>" * audio_count + prompts = f"<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. 
You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>{audio_placeholder}{question}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" # noqa: E501 + + return ModelRequestData( + engine_args=engine_args, + prompt=prompts, + lora_requests=[LoRARequest("speech", 1, speech_lora_path)], + ) + + # MiniCPM-O def run_minicpmo(question: str, audio_count: int) -> ModelRequestData: model_name = "openbmb/MiniCPM-o-2_6" @@ -89,7 +120,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData: engine_args = EngineArgs( model=model_path, trust_remote_code=True, - max_model_len=4096, + max_model_len=12800, max_num_seqs=2, enable_lora=True, max_lora_rank=320, @@ -130,6 +161,36 @@ def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData: ) +# Qwen2.5-Omni +def run_qwen2_5_omni(question: str, audio_count: int): + model_name = "Qwen/Qwen2.5-Omni-7B" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={"audio": audio_count}, + ) + + audio_in_prompt = "".join([ + "<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count) + ]) + + default_system = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " + "Group, capable of perceiving auditory and visual inputs, as well as " + "generating text and speech.") + + prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n" + f"{audio_in_prompt}{question}<|im_end|>\n" + "<|im_start|>assistant\n") + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + ) + + # Ultravox 0.5-1B def run_ultravox(question: str, audio_count: int) -> ModelRequestData: model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b" @@ -179,14 +240,43 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData: model_example_map = { + "granite_speech": run_granite_speech, "minicpmo": run_minicpmo, "phi4_mm": run_phi4mm, "qwen2_audio": run_qwen2_audio, + "qwen2_5_omni": run_qwen2_5_omni, "ultravox": run_ultravox, "whisper": run_whisper, } +def parse_args(): + parser = FlexibleArgumentParser( + description='Demo on using vLLM for offline inference with ' + 'audio language models') + parser.add_argument('--model-type', + '-m', + type=str, + default="ultravox", + choices=model_example_map.keys(), + help='Huggingface "model_type".') + parser.add_argument('--num-prompts', + type=int, + default=1, + help='Number of prompts to run.') + parser.add_argument("--num-audios", + type=int, + default=1, + choices=[0, 1, 2], + help="Number of audio items per prompt.") + parser.add_argument("--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.") + + return parser.parse_args() + + def main(args): model = args.model_type if model not in model_example_map: @@ -240,28 +330,5 @@ def main(args): if __name__ == "__main__": - parser = FlexibleArgumentParser( - description='Demo on using vLLM for offline inference with ' - 'audio language models') - parser.add_argument('--model-type', - '-m', - type=str, - default="ultravox", - choices=model_example_map.keys(), - help='Huggingface "model_type".') - parser.add_argument('--num-prompts', - type=int, - default=1, - help='Number of prompts to run.') - parser.add_argument("--num-audios", - type=int, - default=1, - choices=[0, 1, 2], - help="Number of audio items per prompt.") - parser.add_argument("--seed", - type=int, - default=None, - help="Set the seed when initializing `vllm.LLM`.") - - args = parser.parse_args() + args = parse_args() 
main(args) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index 2ba5ec1192b1..ae5ae7cb4834 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -12,16 +12,23 @@ # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) -# Create an LLM. -llm = LLM(model="facebook/opt-125m") -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. -outputs = llm.generate(prompts, sampling_params) -# Print the outputs. -print("\nGenerated Outputs:\n" + "-" * 60) -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}") - print(f"Output: {generated_text!r}") - print("-" * 60) \ No newline at end of file + +def main(): + # Create an LLM. + llm = LLM(model="facebook/opt-125m") + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/basic/chat.py b/examples/offline_inference/basic/chat.py index 2dea45f843cf..6857c6e9e31d 100644 --- a/examples/offline_inference/basic/chat.py +++ b/examples/offline_inference/basic/chat.py @@ -4,6 +4,24 @@ from vllm.utils import FlexibleArgumentParser +def create_parser(): + parser = FlexibleArgumentParser() + # Add engine args + engine_group = parser.add_argument_group("Engine arguments") + EngineArgs.add_cli_args(engine_group) + engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") + # Add sampling params + sampling_group = parser.add_argument_group("Sampling parameters") + sampling_group.add_argument("--max-tokens", type=int) + sampling_group.add_argument("--temperature", type=float) + sampling_group.add_argument("--top-p", type=float) + sampling_group.add_argument("--top-k", type=int) + # Add example params + parser.add_argument("--chat-template-path", type=str) + + return parser + + def main(args: dict): # Pop arguments not used by LLM max_tokens = args.pop("max_tokens") @@ -82,18 +100,6 @@ def print_outputs(outputs): if __name__ == "__main__": - parser = FlexibleArgumentParser() - # Add engine args - engine_group = parser.add_argument_group("Engine arguments") - EngineArgs.add_cli_args(engine_group) - engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") - # Add sampling params - sampling_group = parser.add_argument_group("Sampling parameters") - sampling_group.add_argument("--max-tokens", type=int) - sampling_group.add_argument("--temperature", type=float) - sampling_group.add_argument("--top-p", type=float) - sampling_group.add_argument("--top-k", type=int) - # Add example params - parser.add_argument("--chat-template-path", type=str) + parser = create_parser() args: dict = vars(parser.parse_args()) main(args) diff --git a/examples/offline_inference/basic/classify.py b/examples/offline_inference/basic/classify.py index 72c29e4c77c3..5b6dcb41eee1 100644 --- a/examples/offline_inference/basic/classify.py +++ b/examples/offline_inference/basic/classify.py @@ -6,6 +6,16 @@ from 
vllm.utils import FlexibleArgumentParser +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach", + task="classify", + enforce_eager=True) + return parser.parse_args() + + def main(args: Namespace): # Sample prompts. prompts = [ @@ -34,11 +44,5 @@ def main(args: Namespace): if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser = EngineArgs.add_cli_args(parser) - # Set example specific arguments - parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach", - task="classify", - enforce_eager=True) - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/basic/embed.py b/examples/offline_inference/basic/embed.py index 0283909a2a84..cb5f923ffb69 100644 --- a/examples/offline_inference/basic/embed.py +++ b/examples/offline_inference/basic/embed.py @@ -6,6 +6,16 @@ from vllm.utils import FlexibleArgumentParser +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults(model="intfloat/e5-mistral-7b-instruct", + task="embed", + enforce_eager=True) + return parser.parse_args() + + def main(args: Namespace): # Sample prompts. prompts = [ @@ -34,11 +44,5 @@ def main(args: Namespace): if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser = EngineArgs.add_cli_args(parser) - # Set example specific arguments - parser.set_defaults(model="intfloat/e5-mistral-7b-instruct", - task="embed", - enforce_eager=True) - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/basic/generate.py b/examples/offline_inference/basic/generate.py index 93f4f2a36fac..54b52b22a45a 100644 --- a/examples/offline_inference/basic/generate.py +++ b/examples/offline_inference/basic/generate.py @@ -4,6 +4,22 @@ from vllm.utils import FlexibleArgumentParser +def create_parser(): + parser = FlexibleArgumentParser() + # Add engine args + engine_group = parser.add_argument_group("Engine arguments") + EngineArgs.add_cli_args(engine_group) + engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") + # Add sampling params + sampling_group = parser.add_argument_group("Sampling parameters") + sampling_group.add_argument("--max-tokens", type=int) + sampling_group.add_argument("--temperature", type=float) + sampling_group.add_argument("--top-p", type=float) + sampling_group.add_argument("--top-k", type=int) + + return parser + + def main(args: dict): # Pop arguments not used by LLM max_tokens = args.pop("max_tokens") @@ -35,23 +51,15 @@ def main(args: dict): ] outputs = llm.generate(prompts, sampling_params) # Print the outputs. 
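+ # Each RequestOutput holds the original prompt and its generated completions.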
+ print("-" * 50) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) if __name__ == "__main__": - parser = FlexibleArgumentParser() - # Add engine args - engine_group = parser.add_argument_group("Engine arguments") - EngineArgs.add_cli_args(engine_group) - engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") - # Add sampling params - sampling_group = parser.add_argument_group("Sampling parameters") - sampling_group.add_argument("--max-tokens", type=int) - sampling_group.add_argument("--temperature", type=float) - sampling_group.add_argument("--top-p", type=float) - sampling_group.add_argument("--top-k", type=int) + parser = create_parser() args: dict = vars(parser.parse_args()) main(args) diff --git a/examples/offline_inference/basic/score.py b/examples/offline_inference/basic/score.py index 83b8253f4e25..d2bda8b3180c 100644 --- a/examples/offline_inference/basic/score.py +++ b/examples/offline_inference/basic/score.py @@ -6,6 +6,16 @@ from vllm.utils import FlexibleArgumentParser +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults(model="BAAI/bge-reranker-v2-m3", + task="score", + enforce_eager=True) + return parser.parse_args() + + def main(args: Namespace): # Sample prompts. text_1 = "What is the capital of France?" @@ -30,11 +40,5 @@ def main(args: Namespace): if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser = EngineArgs.add_cli_args(parser) - # Set example specific arguments - parser.set_defaults(model="BAAI/bge-reranker-v2-m3", - task="score", - enforce_eager=True) - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/batch_llm_inference.py b/examples/offline_inference/batch_llm_inference.py new file mode 100644 index 000000000000..6548857b6d11 --- /dev/null +++ b/examples/offline_inference/batch_llm_inference.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This example shows how to use Ray Data for data parallel batch inference. + +Ray Data is a data processing framework that can handle large datasets +and integrates tightly with vLLM for data-parallel inference. + +As of Ray 2.44, Ray Data has a native integration with +vLLM (under ray.data.llm). + +Ray Data provides functionality for: +* Reading and writing to cloud storage (S3, GCS, etc.) +* Automatic sharding and load-balancing across a cluster +* Optimized configuration of vLLM using continuous batching +* Compatible with tensor/pipeline parallel inference as well. + +Learn more about Ray Data's LLM integration: +https://docs.ray.io/en/latest/data/working-with-llms.html +""" +import ray +from packaging.version import Version +from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig + +assert Version(ray.__version__) >= Version( + "2.44.1"), "Ray version must be at least 2.44.1" + +# Uncomment to reduce clutter in stdout +# ray.init(log_to_driver=False) +# ray.data.DataContext.get_current().enable_progress_bars = False + +# Read one text file from S3. Ray Data supports reading multiple files +# from cloud storage (such as JSONL, Parquet, CSV, binary format). 
+ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt") +print(ds.schema()) + +size = ds.count() +print(f"Size of dataset: {size} prompts") + +# Configure vLLM engine. +config = vLLMEngineProcessorConfig( + model_source="unsloth/Llama-3.1-8B-Instruct", + engine_kwargs={ + "enable_chunked_prefill": True, + "max_num_batched_tokens": 4096, + "max_model_len": 16384, + }, + concurrency=1, # set the number of parallel vLLM replicas + batch_size=64, +) + +# Create a Processor object, which will be used to +# do batch inference on the dataset +vllm_processor = build_llm_processor( + config, + preprocess=lambda row: dict( + messages=[{ + "role": "system", + "content": "You are a bot that responds with haikus." + }, { + "role": "user", + "content": row["text"] + }], + sampling_params=dict( + temperature=0.3, + max_tokens=250, + )), + postprocess=lambda row: dict( + answer=row["generated_text"], + **row # This will return all the original columns in the dataset. + ), +) + +ds = vllm_processor(ds) + +# Peek first 10 results. +# NOTE: This is for local testing and debugging. For production use case, +# one should write full result out as shown below. +outputs = ds.take(limit=10) + +for output in outputs: + prompt = output["prompt"] + generated_text = output["generated_text"] + print(f"Prompt: {prompt!r}") + print(f"Generated text: {generated_text!r}") + +# Write inference output data out as Parquet files to S3. +# Multiple files would be written to the output destination, +# and each task would write one or more files separately. +# +# ds.write_parquet("s3://") diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 04a79e2f8ae6..965915beaf58 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -34,6 +34,40 @@ from vllm.utils import get_open_port +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description="Data Parallel Inference") + parser.add_argument("--model", + type=str, + default="ibm-research/PowerMoE-3b", + help="Model name or path") + parser.add_argument("--dp-size", + type=int, + default=2, + help="Data parallel size") + parser.add_argument("--tp-size", + type=int, + default=2, + help="Tensor parallel size") + parser.add_argument("--node-size", + type=int, + default=1, + help="Total number of nodes") + parser.add_argument("--node-rank", + type=int, + default=0, + help="Rank of the current node") + parser.add_argument("--master-addr", + type=str, + default="", + help="Master node IP address") + parser.add_argument("--master-port", + type=int, + default=0, + help="Master node port") + return parser.parse_args() + + def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) @@ -95,37 +129,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser(description="Data Parallel Inference") - parser.add_argument("--model", - type=str, - default="ibm-research/PowerMoE-3b", - help="Model name or path") - parser.add_argument("--dp-size", - type=int, - default=2, - help="Data parallel size") - parser.add_argument("--tp-size", - type=int, - default=2, - help="Tensor parallel size") - parser.add_argument("--node-size", - type=int, - default=1, - help="Total number of nodes") - parser.add_argument("--node-rank", - type=int, - default=0, - help="Rank of the current node") - 
parser.add_argument("--master-addr", - type=str, - default="", - help="Master node IP address") - parser.add_argument("--master-port", - type=int, - default=0, - help="Master node port") - args = parser.parse_args() + + args = parse_args() dp_size = args.dp_size tp_size = args.tp_size diff --git a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py new file mode 100644 index 000000000000..66efbc0c9dee --- /dev/null +++ b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +# Read prompts from output.txt +prompts = [] +try: + with open("output.txt") as f: + for line in f: + prompts.append(line.strip()) + print(f"Loaded {len(prompts)} prompts from output.txt") +except FileNotFoundError: + print("Error: output.txt file not found") + exit(-1) + +sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + +llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + max_num_batched_tokens=64, + max_num_seqs=16, + kv_transfer_config=KVTransferConfig.from_cli( + '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both",' + '"kv_connector_extra_config": {"shared_storage_path": "local_storage"}}' + )) #, max_model_len=2048, max_num_batched_tokens=2048) + +# 1ST generation (prefill instance) +outputs = llm.generate(prompts, sampling_params) + +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py new file mode 100644 index 000000000000..f7cbf6557d54 --- /dev/null +++ b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +context = "Hi " * 1000 +context2 = "Hey " * 500 +prompts = [ + context + "Hello, my name is", + context + "The capital of France is", + context2 + "Your name is", + context2 + "The capital of China is", +] + +sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + +llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + kv_transfer_config=KVTransferConfig.from_cli( + '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both", ' + '"kv_connector_extra_config": ' + '{"shared_storage_path": "local_storage"}}') + ) #, max_model_len=2048, max_num_batched_tokens=2048) + +# 1ST generation (prefill instance) +outputs = llm.generate( + prompts, + sampling_params, +) + +new_prompts = [] +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + new_prompts.append(prompt + generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +# Write new_prompts to output.txt +with open("output.txt", "w") as f: + for prompt in new_prompts: + f.write(prompt + "\n") +print(f"Saved {len(new_prompts)} prompts to output.txt") diff --git a/examples/offline_inference/disaggregated-prefill-v1/run.sh b/examples/offline_inference/disaggregated-prefill-v1/run.sh new file mode 100644 index 000000000000..0ebf45a1586a --- /dev/null +++ 
b/examples/offline_inference/disaggregated-prefill-v1/run.sh @@ -0,0 +1,5 @@ +rm -rf local_storage/ +rm output.txt + +VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py +VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py index 36ee24bf7f18..d60985146c5c 100644 --- a/examples/offline_inference/disaggregated_prefill.py +++ b/examples/offline_inference/disaggregated_prefill.py @@ -95,7 +95,7 @@ def run_decode(prefill_done): print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") -if __name__ == "__main__": +def main(): prefill_done = Event() prefill_process = Process(target=run_prefill, args=(prefill_done, )) decode_process = Process(target=run_decode, args=(prefill_done, )) @@ -109,3 +109,7 @@ def run_decode(prefill_done): # Terminate the prefill node when decode is finished decode_process.join() prefill_process.terminate() + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/distributed.py b/examples/offline_inference/distributed.py deleted file mode 100644 index e890c6dad8bd..000000000000 --- a/examples/offline_inference/distributed.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -This example shows how to use Ray Data for running offline batch inference -distributively on a multi-nodes cluster. - -Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html -""" - -from typing import Any - -import numpy as np -import ray -from packaging.version import Version -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy - -from vllm import LLM, SamplingParams - -assert Version(ray.__version__) >= Version( - "2.22.0"), "Ray version must be at least 2.22.0" - -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - -# Set tensor parallelism per instance. -tensor_parallel_size = 1 - -# Set number of instances. Each instance will use tensor_parallel_size GPUs. -num_instances = 1 - - -# Create a class to do batch inference. -class LLMPredictor: - - def __init__(self): - # Create an LLM. - self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", - tensor_parallel_size=tensor_parallel_size) - - def __call__(self, batch: dict[str, np.ndarray]) -> dict[str, list]: - # Generate texts from the prompts. - # The output is a list of RequestOutput objects that contain the prompt, - # generated text, and other information. - outputs = self.llm.generate(batch["text"], sampling_params) - prompt: list[str] = [] - generated_text: list[str] = [] - for output in outputs: - prompt.append(output.prompt) - generated_text.append(' '.join([o.text for o in output.outputs])) - return { - "prompt": prompt, - "generated_text": generated_text, - } - - -# Read one text file from S3. Ray Data supports reading multiple files -# from cloud storage (such as JSONL, Parquet, CSV, binary format). -ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt") - - -# For tensor_parallel_size > 1, we need to create placement groups for vLLM -# to use. Every actor has to have its own placement group. 
-def scheduling_strategy_fn(): - # One bundle per tensor parallel worker - pg = ray.util.placement_group( - [{ - "GPU": 1, - "CPU": 1 - }] * tensor_parallel_size, - strategy="STRICT_PACK", - ) - return dict(scheduling_strategy=PlacementGroupSchedulingStrategy( - pg, placement_group_capture_child_tasks=True)) - - -resources_kwarg: dict[str, Any] = {} -if tensor_parallel_size == 1: - # For tensor_parallel_size == 1, we simply set num_gpus=1. - resources_kwarg["num_gpus"] = 1 -else: - # Otherwise, we have to set num_gpus=0 and provide - # a function that will create a placement group for - # each instance. - resources_kwarg["num_gpus"] = 0 - resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn - -# Apply batch inference for all input data. -ds = ds.map_batches( - LLMPredictor, - # Set the concurrency to the number of LLM instances. - concurrency=num_instances, - # Specify the batch size for inference. - batch_size=32, - **resources_kwarg, -) - -# Peek first 10 results. -# NOTE: This is for local testing and debugging. For production use case, -# one should write full result out as shown below. -outputs = ds.take(limit=10) -for output in outputs: - prompt = output["prompt"] - generated_text = output["generated_text"] - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - -# Write inference output data out as Parquet files to S3. -# Multiple files would be written to the output destination, -# and each task would write one or more files separately. -# -# ds.write_parquet("s3://") diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index 453ae7b6f56f..474b745a6106 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -27,7 +27,7 @@ def load_prompts(dataset_path, num_prompts): return prompts[:num_prompts] -def main(): +def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( "--dataset", @@ -45,10 +45,15 @@ def main(): parser.add_argument("--enable_chunked_prefill", action='store_true') parser.add_argument("--max_num_batched_tokens", type=int, default=2048) parser.add_argument("--temp", type=float, default=0) - args = parser.parse_args() + return parser.parse_args() + + +def main(): - model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" - eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm" + args = parse_args() + + model_dir = "meta-llama/Llama-3.1-8B-Instruct" + eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" max_model_len = 2048 @@ -76,7 +81,7 @@ def main(): max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, speculative_config={ - "method": "eagle", + "method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle", "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, "draft_tensor_parallel_size": args.draft_tp, @@ -90,6 +95,9 @@ def main(): outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) + if not hasattr(outputs, "metrics") or outputs.metrics is None: + return + # calculate the average number of accepted tokens per forward pass, +1 is # to account for the token from the target model that's always going to be # accepted @@ -104,6 +112,11 @@ def main(): {sum(acceptance_counts) / acceptance_counts[0]:.2f}") print("-" * 50) + # print acceptance at each token position + for i in range(len(acceptance_counts)): + print(f"acceptance at token {i}:" + f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}") + if __name__ == "__main__": main() diff --git a/examples/offline_inference/embed_jina_embeddings_v3.py 
b/examples/offline_inference/embed_jina_embeddings_v3.py index f7d9e47e7953..b347ddbf3197 100644 --- a/examples/offline_inference/embed_jina_embeddings_v3.py +++ b/examples/offline_inference/embed_jina_embeddings_v3.py @@ -6,6 +6,16 @@ from vllm.utils import FlexibleArgumentParser +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults(model="jinaai/jina-embeddings-v3", + task="embed", + trust_remote_code=True) + return parser.parse_args() + + def main(args: Namespace): # Sample prompts. prompts = [ @@ -40,11 +50,5 @@ def main(args: Namespace): if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser = EngineArgs.add_cli_args(parser) - # Set example specific arguments - parser.set_defaults(model="jinaai/jina-embeddings-v3", - task="embed", - trust_remote_code=True) - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/embed_matryoshka_fy.py b/examples/offline_inference/embed_matryoshka_fy.py index ab71fbe73e6a..7a6cb02556d9 100644 --- a/examples/offline_inference/embed_matryoshka_fy.py +++ b/examples/offline_inference/embed_matryoshka_fy.py @@ -6,6 +6,16 @@ from vllm.utils import FlexibleArgumentParser +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults(model="jinaai/jina-embeddings-v3", + task="embed", + trust_remote_code=True) + return parser.parse_args() + + def main(args: Namespace): # Sample prompts. prompts = [ @@ -38,11 +48,5 @@ def main(args: Namespace): if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser = EngineArgs.add_cli_args(parser) - # Set example specific arguments - parser.set_defaults(model="jinaai/jina-embeddings-v3", - task="embed", - trust_remote_code=True) - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/encoder_decoder.py b/examples/offline_inference/encoder_decoder.py index c6ccfd42ec85..c4916e00f473 100644 --- a/examples/offline_inference/encoder_decoder.py +++ b/examples/offline_inference/encoder_decoder.py @@ -8,94 +8,112 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt, zip_enc_dec_prompts) -dtype = "float" - -# Create a BART encoder/decoder model instance -llm = LLM( - model="facebook/bart-large-cnn", - dtype=dtype, -) - -# Get BART tokenizer -tokenizer = llm.llm_engine.get_tokenizer_group() - -# Test prompts -# -# This section shows all of the valid ways to prompt an -# encoder/decoder model. -# -# - Helpers for building prompts -text_prompt_raw = "Hello, my name is" -text_prompt = TextPrompt(prompt="The president of the United States is") -tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode( - prompt="The capital of France is")) -# - Pass a single prompt to encoder/decoder model -# (implicitly encoder input prompt); -# decoder input prompt is assumed to be None - -single_text_prompt_raw = text_prompt_raw # Pass a string directly -single_text_prompt = text_prompt # Pass a TextPrompt -single_tokens_prompt = tokens_prompt # Pass a TokensPrompt - -# - Pass explicit encoder and decoder input prompts within one data structure. -# Encoder and decoder prompts can both independently be text or tokens, with -# no requirement that they be the same prompt type. Some example prompt-type -# combinations are shown below, note that these are not exhaustive. 
- -enc_dec_prompt1 = ExplicitEncoderDecoderPrompt( - # Pass encoder prompt string directly, & - # pass decoder prompt tokens - encoder_prompt=single_text_prompt_raw, - decoder_prompt=single_tokens_prompt, -) -enc_dec_prompt2 = ExplicitEncoderDecoderPrompt( - # Pass TextPrompt to encoder, and - # pass decoder prompt string directly - encoder_prompt=single_text_prompt, - decoder_prompt=single_text_prompt_raw, -) -enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( - # Pass encoder prompt tokens directly, and - # pass TextPrompt to decoder - encoder_prompt=single_tokens_prompt, - decoder_prompt=single_text_prompt, -) - -# - Finally, here's a useful helper function for zipping encoder and -# decoder prompts together into a list of ExplicitEncoderDecoderPrompt -# instances -zipped_prompt_list = zip_enc_dec_prompts( - ['An encoder prompt', 'Another encoder prompt'], - ['A decoder prompt', 'Another decoder prompt']) - -# - Let's put all of the above example prompts together into one list -# which we will pass to the encoder/decoder LLM. -prompts = [ - single_text_prompt_raw, single_text_prompt, single_tokens_prompt, - enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3 -] + zipped_prompt_list + +def create_prompts(tokenizer): + # Test prompts + # + # This section shows all of the valid ways to prompt an + # encoder/decoder model. + # + # - Helpers for building prompts + text_prompt_raw = "Hello, my name is" + text_prompt = TextPrompt(prompt="The president of the United States is") + tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode( + prompt="The capital of France is")) + # - Pass a single prompt to encoder/decoder model + # (implicitly encoder input prompt); + # decoder input prompt is assumed to be None + + single_text_prompt_raw = text_prompt_raw # Pass a string directly + single_text_prompt = text_prompt # Pass a TextPrompt + single_tokens_prompt = tokens_prompt # Pass a TokensPrompt + + # ruff: noqa: E501 + # - Pass explicit encoder and decoder input prompts within one data structure. + # Encoder and decoder prompts can both independently be text or tokens, with + # no requirement that they be the same prompt type. Some example prompt-type + # combinations are shown below, note that these are not exhaustive. + + enc_dec_prompt1 = ExplicitEncoderDecoderPrompt( + # Pass encoder prompt string directly, & + # pass decoder prompt tokens + encoder_prompt=single_text_prompt_raw, + decoder_prompt=single_tokens_prompt, + ) + enc_dec_prompt2 = ExplicitEncoderDecoderPrompt( + # Pass TextPrompt to encoder, and + # pass decoder prompt string directly + encoder_prompt=single_text_prompt, + decoder_prompt=single_text_prompt_raw, + ) + enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( + # Pass encoder prompt tokens directly, and + # pass TextPrompt to decoder + encoder_prompt=single_tokens_prompt, + decoder_prompt=single_text_prompt, + ) + + # - Finally, here's a useful helper function for zipping encoder and + # decoder prompts together into a list of ExplicitEncoderDecoderPrompt + # instances + zipped_prompt_list = zip_enc_dec_prompts( + ['An encoder prompt', 'Another encoder prompt'], + ['A decoder prompt', 'Another decoder prompt']) + + # - Let's put all of the above example prompts together into one list + # which we will pass to the encoder/decoder LLM. + return [ + single_text_prompt_raw, single_text_prompt, single_tokens_prompt, + enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3 + ] + zipped_prompt_list + # Create a sampling params object. 
-sampling_params = SamplingParams( - temperature=0, - top_p=1.0, - min_tokens=0, - max_tokens=20, -) - -# Generate output tokens from the prompts. The output is a list of -# RequestOutput objects that contain the prompt, generated -# text, and other information. -outputs = llm.generate(prompts, sampling_params) +def create_sampling_params(): + return SamplingParams( + temperature=0, + top_p=1.0, + min_tokens=0, + max_tokens=20, + ) + # Print the outputs. -print("-" * 50) -for i, output in enumerate(outputs): - prompt = output.prompt - encoder_prompt = output.encoder_prompt - generated_text = output.outputs[0].text - print(f"Output {i+1}:") - print(f"Encoder prompt: {encoder_prompt!r}\n" - f"Decoder prompt: {prompt!r}\n" - f"Generated text: {generated_text!r}") +def print_outputs(outputs): print("-" * 50) + for i, output in enumerate(outputs): + prompt = output.prompt + encoder_prompt = output.encoder_prompt + generated_text = output.outputs[0].text + print(f"Output {i+1}:") + print(f"Encoder prompt: {encoder_prompt!r}\n" + f"Decoder prompt: {prompt!r}\n" + f"Generated text: {generated_text!r}") + print("-" * 50) + + +def main(): + dtype = "float" + + # Create a BART encoder/decoder model instance + llm = LLM( + model="facebook/bart-large-cnn", + dtype=dtype, + ) + + # Get BART tokenizer + tokenizer = llm.llm_engine.get_tokenizer_group() + + prompts = create_prompts(tokenizer) + sampling_params = create_sampling_params() + + # Generate output tokens from the prompts. The output is a list of + # RequestOutput objects that contain the prompt, generated + # text, and other information. + outputs = llm.generate(prompts, sampling_params) + + print_outputs(outputs) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index 456ee60eaabf..2883c37ca236 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -22,7 +22,7 @@ class ModelRequestData(NamedTuple): def run_florence2(): engine_args = EngineArgs( model="microsoft/Florence-2-large", - tokenizer="facebook/bart-large", + tokenizer="Isotr0py/Florence-2-tokenizer", max_num_seqs=8, trust_remote_code=True, limit_mm_per_prompt={"image": 1}, @@ -126,6 +126,23 @@ def run_whisper(): } +def parse_args(): + parser = FlexibleArgumentParser( + description='Demo on using vLLM for offline inference with ' + 'vision language models for text generation') + parser.add_argument('--model-type', + '-m', + type=str, + default="mllama", + choices=model_example_map.keys(), + help='Huggingface "model_type".') + parser.add_argument("--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.") + return parser.parse_args() + + def main(args): model = args.model_type if model not in model_example_map: @@ -148,6 +165,7 @@ def main(args): temperature=0, top_p=1.0, max_tokens=64, + skip_special_tokens=False, ) start = time.time() @@ -171,19 +189,5 @@ def main(args): if __name__ == "__main__": - parser = FlexibleArgumentParser( - description='Demo on using vLLM for offline inference with ' - 'vision language models for text generation') - parser.add_argument('--model-type', - '-m', - type=str, - default="mllama", - choices=model_example_map.keys(), - help='Huggingface "model_type".') - parser.add_argument("--seed", - type=int, - default=None, - help="Set the seed when initializing `vllm.LLM`.") - - args = parser.parse_args() + args = parse_args() 
main(args) diff --git a/examples/offline_inference/llm_engine_example.py b/examples/offline_inference/llm_engine_example.py index abff90d1c0cb..d84cd9ee9f52 100644 --- a/examples/offline_inference/llm_engine_example.py +++ b/examples/offline_inference/llm_engine_example.py @@ -50,6 +50,13 @@ def initialize_engine(args: argparse.Namespace) -> LLMEngine: return LLMEngine.from_engine_args(engine_args) +def parse_args(): + parser = FlexibleArgumentParser( + description='Demo on using the LLMEngine class directly') + parser = EngineArgs.add_cli_args(parser) + return parser.parse_args() + + def main(args: argparse.Namespace): """Main function that sets up and runs the prompt processing.""" engine = initialize_engine(args) @@ -58,8 +65,5 @@ def main(args: argparse.Namespace): if __name__ == '__main__': - parser = FlexibleArgumentParser( - description='Demo on using the LLMEngine class directly') - parser = EngineArgs.add_cli_args(parser) - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/mistral-small.py b/examples/offline_inference/mistral-small.py index efa1aa5b0369..37c3181dc5fa 100644 --- a/examples/offline_inference/mistral-small.py +++ b/examples/offline_inference/mistral-small.py @@ -16,11 +16,11 @@ # # Mistral format # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \ # --tokenizer-mode mistral --config-format mistral --load-format mistral \ -# --limit-mm-per-prompt 'image=4' --max-model-len 16384 +# --limit-mm-per-prompt '{"image":4}' --max-model-len 16384 # # # HF format # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \ -# --limit-mm-per-prompt 'image=4' --max-model-len 16384 +# --limit-mm-per-prompt '{"image":4}' --max-model-len 16384 # ``` # # - Client: @@ -62,6 +62,7 @@ def run_simple_demo(args: argparse.Namespace): tokenizer_mode="mistral" if args.format == "mistral" else "auto", config_format="mistral" if args.format == "mistral" else "auto", load_format="mistral" if args.format == "mistral" else "auto", + limit_mm_per_prompt={"image": 1}, max_model_len=4096, max_num_seqs=2, tensor_parallel_size=2, @@ -168,7 +169,7 @@ def run_advanced_demo(args: argparse.Namespace): print("-" * 50) -def main(): +def parse_args(): parser = argparse.ArgumentParser( description="Run a demo in simple or advanced mode.") @@ -187,8 +188,11 @@ def main(): '--disable-mm-preprocessor-cache', action='store_true', help='If True, disables caching of multi-modal preprocessor/mapper.') + return parser.parse_args() + - args = parser.parse_args() +def main(): + args = parse_args() if args.mode == "simple": print("Running simple demo...") diff --git a/examples/offline_inference/mlpspeculator.py b/examples/offline_inference/mlpspeculator.py index a2a984b04e00..53c58a76d9dc 100644 --- a/examples/offline_inference/mlpspeculator.py +++ b/examples/offline_inference/mlpspeculator.py @@ -34,8 +34,7 @@ def time_generation(llm: LLM, prompts: list[str], print("-" * 50) -if __name__ == "__main__": - +def main(): template = ( "Below is an instruction that describes a task. 
Write a response " "that appropriately completes the request.\n\n### Instruction:\n{}" @@ -66,3 +65,7 @@ def time_generation(llm: LLM, prompts: list[str], ) time_generation(llm, prompts, sampling_params, "With speculation") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index 3ae507cac5ce..f97a1f32e621 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -417,6 +417,38 @@ def run_model(input_data, return pred_imgs +def parse_args(): + parser = argparse.ArgumentParser("MAE run inference", add_help=False) + + parser.add_argument( + "--data_file", + type=str, + default="./India_900498_S2Hand.tif", + help="Path to the file.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Path to the directory where to save outputs.", + ) + parser.add_argument( + "--input_indices", + default=[1, 2, 3, 8, 11, 12], + type=int, + nargs="+", + help= + "0-based indices of the six Prithvi channels to be selected from the " + "input. By default selects [1,2,3,8,11,12] for S2L1C data.", + ) + parser.add_argument( + "--rgb_outputs", + action="store_true", + help="If present, output files will only contain RGB channels. " + "Otherwise, all bands will be saved.", + ) + + def main( data_file: str, output_dir: str, @@ -496,35 +528,7 @@ def main( if __name__ == "__main__": - parser = argparse.ArgumentParser("MAE run inference", add_help=False) - parser.add_argument( - "--data_file", - type=str, - default="./India_900498_S2Hand.tif", - help="Path to the file.", - ) - parser.add_argument( - "--output_dir", - type=str, - default="output", - help="Path to the directory where to save outputs.", - ) - parser.add_argument( - "--input_indices", - default=[1, 2, 3, 8, 11, 12], - type=int, - nargs="+", - help= - "0-based indices of the six Prithvi channels to be selected from the " - "input. By default selects [1,2,3,8,11,12] for S2L1C data.", - ) - parser.add_argument( - "--rgb_outputs", - action="store_true", - help="If present, output files will only contain RGB channels. 
" - "Otherwise, all bands will be saved.", - ) - args = parser.parse_args() + args = parse_args() main(**vars(args)) diff --git a/examples/offline_inference/profiling.py b/examples/offline_inference/profiling.py index 6e1d4722440a..9c818d075734 100644 --- a/examples/offline_inference/profiling.py +++ b/examples/offline_inference/profiling.py @@ -359,7 +359,7 @@ def abort_requests(): f" in folder {context.save_chrome_traces_folder}") -if __name__ == "__main__": +def parse_args(): parser = FlexibleArgumentParser(description=""" Profile a model @@ -449,7 +449,10 @@ def abort_requests(): EngineArgs.add_cli_args(parser) - args = parser.parse_args() + return parser.parse_args() + + +def main(args): context = ProfileContext( engine_args=EngineArgs.from_cli_args(args), **{ @@ -458,3 +461,8 @@ def abort_requests(): if k in inspect.signature(ProfileContext).parameters }) run_profile(context, csv_output=args.csv, json_output=args.json) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/offline_inference/qwen2_5_omni/README.md b/examples/offline_inference/qwen2_5_omni/README.md new file mode 100644 index 000000000000..c30541a598ce --- /dev/null +++ b/examples/offline_inference/qwen2_5_omni/README.md @@ -0,0 +1,32 @@ +# Qwen2.5-Omni Offline Inference Examples + +This folder provides several example scripts on how to inference Qwen2.5-Omni offline. + +## Thinker Only + +```bash +# Audio + image + video +python examples/offline_inference/qwen2_5_omni/only_thinker.py -q mixed_modalities + +# Read vision and audio inputs from a single video file +# NOTE: V1 engine does not support interleaved modalities yet. +VLLM_USE_V1=0 python examples/offline_inference/qwen2_5_omni/only_thinker.py -q use_audio_in_video + +# Multiple audios +VLLM_USE_V1=0 python examples/offline_inference/qwen2_5_omni/only_thinker.py -q multi_audios +``` + +This script will run the thinker part of Qwen2.5-Omni, and generate text response. + +You can also test Qwen2.5-Omni on a single modality: + +```bash +# Process audio inputs +python examples/offline_inference/audio_language.py --model-type qwen2_5_omni + +# Process image inputs +python examples/offline_inference/vision_language.py --modality image --model-type qwen2_5_omni + +# Process video inputs +python examples/offline_inference/vision_language.py --modality video --model-type qwen2_5_omni +``` diff --git a/examples/offline_inference/qwen2_5_omni/only_thinker.py b/examples/offline_inference/qwen2_5_omni/only_thinker.py new file mode 100644 index 000000000000..c75a990120e0 --- /dev/null +++ b/examples/offline_inference/qwen2_5_omni/only_thinker.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This example shows how to use vLLM for running offline inference +with the correct prompt format on Qwen2.5-Omni (thinker only). +""" + +from typing import NamedTuple + +import vllm.envs as envs +from vllm import LLM, SamplingParams +from vllm.assets.audio import AudioAsset +from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset +from vllm.utils import FlexibleArgumentParser + + +class QueryResult(NamedTuple): + inputs: dict + limit_mm_per_prompt: dict[str, int] + + +# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on +# lower-end GPUs. +# Unless specified, these settings have been tested to work on a single L4. 
+ +default_system = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " + "Group, capable of perceiving auditory and visual inputs, as well as " + "generating text and speech.") + + +def get_mixed_modalities_query() -> QueryResult: + question = ("What is recited in the audio? " + "What is the content of this image? Why is this video funny?") + prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + "<|vision_bos|><|IMAGE|><|vision_eos|>" + "<|vision_bos|><|VIDEO|><|vision_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n") + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "audio": + AudioAsset("mary_had_lamb").audio_and_sample_rate, + "image": + ImageAsset("cherry_blossom").pil_image.convert("RGB"), + "video": + VideoAsset(name="sample_demo_1.mp4", + num_frames=16).np_ndarrays, + }, + }, + limit_mm_per_prompt={ + "audio": 1, + "image": 1, + "video": 1 + }, + ) + + +def get_use_audio_in_video_query() -> QueryResult: + question = ("Describe the content of the video, " + "then convert what the baby say into text.") + prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n") + asset = VideoAsset(name="sample_demo_1.mp4", num_frames=16) + audio = asset.get_audio(sampling_rate=16000) + assert not envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. " + "Please launch this example with " + "`VLLM_USE_V1=0`.") + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "video": asset.np_ndarrays, + "audio": audio, + }, + "mm_processor_kwargs": { + "use_audio_in_video": True, + }, + }, + limit_mm_per_prompt={ + "audio": 1, + "video": 1 + }, + ) + + +def get_multi_audios_query() -> QueryResult: + question = "Are these two audio clips the same?" + prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + "<|audio_bos|><|AUDIO|><|audio_eos|>" + f"{question}<|im_end|>\n" + f"<|im_start|>assistant\n") + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": { + "audio": [ + AudioAsset("winning_call").audio_and_sample_rate, + AudioAsset("mary_had_lamb").audio_and_sample_rate, + ], + }, + }, + limit_mm_per_prompt={ + "audio": 2, + }, + ) + + +query_map = { + "mixed_modalities": get_mixed_modalities_query, + "use_audio_in_video": get_use_audio_in_video_query, + "multi_audios": get_multi_audios_query, +} + + +def main(args): + model_name = "Qwen/Qwen2.5-Omni-7B" + query_result = query_map[args.query_type]() + + llm = LLM(model=model_name, + max_model_len=5632, + max_num_seqs=5, + limit_mm_per_prompt=query_result.limit_mm_per_prompt, + seed=args.seed) + + # We set temperature to 0.2 so that outputs can be different + # even when all prompts are identical when running batch inference. 
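+ # Set temperature=0 instead for deterministic greedy decoding.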
+ sampling_params = SamplingParams(temperature=0.2, max_tokens=64) + + outputs = llm.generate(query_result.inputs, + sampling_params=sampling_params) + + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Demo on using vLLM for offline inference with ' + 'audio language models') + parser.add_argument('--query-type', + '-q', + type=str, + default="mixed_modalities", + choices=query_map.keys(), + help='Query type.') + parser.add_argument("--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.") + + args = parser.parse_args() + main(args) diff --git a/examples/offline_inference/save_sharded_state.py b/examples/offline_inference/save_sharded_state.py index 6aac9b75c59c..338380cc9684 100644 --- a/examples/offline_inference/save_sharded_state.py +++ b/examples/offline_inference/save_sharded_state.py @@ -29,20 +29,23 @@ from vllm import LLM, EngineArgs from vllm.utils import FlexibleArgumentParser -parser = FlexibleArgumentParser() -EngineArgs.add_cli_args(parser) -parser.add_argument("--output", - "-o", - required=True, - type=str, - help="path to output checkpoint") -parser.add_argument("--file-pattern", - type=str, - help="string pattern of saved filenames") -parser.add_argument("--max-file-size", - type=str, - default=5 * 1024**3, - help="max size (in bytes) of each safetensors file") + +def parse_args(): + parser = FlexibleArgumentParser() + EngineArgs.add_cli_args(parser) + parser.add_argument("--output", + "-o", + required=True, + type=str, + help="path to output checkpoint") + parser.add_argument("--file-pattern", + type=str, + help="string pattern of saved filenames") + parser.add_argument("--max-file-size", + type=str, + default=5 * 1024**3, + help="max size (in bytes) of each safetensors file") + return parser.parse_args() def main(args): @@ -87,5 +90,5 @@ def main(args): if __name__ == "__main__": - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/simple_profiling.py b/examples/offline_inference/simple_profiling.py index 6a8e3a5a3e75..d583110c8e69 100644 --- a/examples/offline_inference/simple_profiling.py +++ b/examples/offline_inference/simple_profiling.py @@ -18,8 +18,8 @@ # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) -if __name__ == "__main__": +def main(): # Create an LLM. llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1) @@ -42,3 +42,7 @@ # Add a buffer to wait for profiler in the background process # (in case MP is on) to finish writing profiling output. 
time.sleep(10) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 0a4b9098da65..a19452561b5f 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -150,7 +150,7 @@ def run_florence2(questions: list[str], modality: str) -> ModelRequestData: engine_args = EngineArgs( model="microsoft/Florence-2-large", - tokenizer="facebook/bart-large", + tokenizer="Isotr0py/Florence-2-tokenizer", max_model_len=4096, max_num_seqs=2, trust_remote_code=True, @@ -364,6 +364,29 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData: ) +# Kimi-VL +def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + prompts = [ + "<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>" + f"<|media_pad|><|media_end|>{question}<|im_end|>" + "<|im_assistant|>assistant<|im_middle|>" for question in questions + ] + + engine_args = EngineArgs( + model="moonshotai/Kimi-VL-A3B-Instruct", + trust_remote_code=True, + max_model_len=4096, + limit_mm_per_prompt={"image": 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # LLaVA-1.5 def run_llava(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -791,10 +814,13 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData: engine_args = EngineArgs( model=model_path, trust_remote_code=True, - max_model_len=4096, + max_model_len=5120, max_num_seqs=2, + max_num_batched_tokens=12800, enable_lora=True, max_lora_rank=320, + # Note - mm_processor_kwargs can also be passed to generate/chat calls + mm_processor_kwargs={"dynamic_hd": 16}, limit_mm_per_prompt={"image": 1}, ) @@ -918,6 +944,42 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData: ) +# Qwen2.5-Omni +def run_qwen2_5_omni(questions: list[str], modality: str): + model_name = "Qwen/Qwen2.5-Omni-7B" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + mm_processor_kwargs={ + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + "fps": [1], + }, + limit_mm_per_prompt={"image": 1}, + ) + + if modality == "image": + placeholder = "<|IMAGE|>" + elif modality == "video": + placeholder = "<|VIDEO|>" + + default_system = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba " + "Group, capable of perceiving auditory and visual inputs, as well as " + "generating text and speech.") + + prompts = [(f"<|im_start|>system\n{default_system}<|im_end|>\n" + f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n") for question in questions] + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # SkyworkR1V def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -966,6 +1028,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: "h2ovl_chat": run_h2ovl, "idefics3": run_idefics3, "internvl_chat": run_internvl, + "kimi_vl": run_kimi_vl, "llava": run_llava, "llava-next": run_llava_next, "llava-next-video": run_llava_next_video, @@ -986,6 +1049,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, "qwen2_5_vl": run_qwen2_5_vl, + "qwen2_5_omni": run_qwen2_5_omni, "skywork_chat": run_skyworkr1v, "smolvlm": 
run_smolvlm, } @@ -1073,6 +1137,59 @@ def time_counter(enable: bool): yield +def parse_args(): + parser = FlexibleArgumentParser( + description='Demo on using vLLM for offline inference with ' + 'vision language models for text generation') + parser.add_argument('--model-type', + '-m', + type=str, + default="llava", + choices=model_example_map.keys(), + help='Huggingface "model_type".') + parser.add_argument('--num-prompts', + type=int, + default=4, + help='Number of prompts to run.') + parser.add_argument('--modality', + type=str, + default="image", + choices=['image', 'video'], + help='Modality of the input.') + parser.add_argument('--num-frames', + type=int, + default=16, + help='Number of frames to extract from the video.') + parser.add_argument("--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.") + + parser.add_argument( + '--image-repeat-prob', + type=float, + default=None, + help='Simulates the hit-ratio for multi-modal preprocessor cache' + ' (if enabled)') + + parser.add_argument( + '--disable-mm-preprocessor-cache', + action='store_true', + help='If True, disables caching of multi-modal preprocessor/mapper.') + + parser.add_argument( + '--time-generate', + action='store_true', + help='If True, then print the total generate() call time') + + parser.add_argument( + '--use-different-prompt-per-request', + action='store_true', + help='If True, then use different prompt (with the same multi-modal ' + 'data) for each request.') + return parser.parse_args() + + def main(args): model = args.model_type if model not in model_example_map: @@ -1151,55 +1268,5 @@ def main(args): if __name__ == "__main__": - parser = FlexibleArgumentParser( - description='Demo on using vLLM for offline inference with ' - 'vision language models for text generation') - parser.add_argument('--model-type', - '-m', - type=str, - default="llava", - choices=model_example_map.keys(), - help='Huggingface "model_type".') - parser.add_argument('--num-prompts', - type=int, - default=4, - help='Number of prompts to run.') - parser.add_argument('--modality', - type=str, - default="image", - choices=['image', 'video'], - help='Modality of the input.') - parser.add_argument('--num-frames', - type=int, - default=16, - help='Number of frames to extract from the video.') - parser.add_argument("--seed", - type=int, - default=None, - help="Set the seed when initializing `vllm.LLM`.") - - parser.add_argument( - '--image-repeat-prob', - type=float, - default=None, - help='Simulates the hit-ratio for multi-modal preprocessor cache' - ' (if enabled)') - - parser.add_argument( - '--disable-mm-preprocessor-cache', - action='store_true', - help='If True, disables caching of multi-modal preprocessor/mapper.') - - parser.add_argument( - '--time-generate', - action='store_true', - help='If True, then print the total generate() call time') - - parser.add_argument( - '--use-different-prompt-per-request', - action='store_true', - help='If True, then use different prompt (with the same multi-modal ' - 'data) for each request.') - - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/offline_inference/vision_language_embedding.py b/examples/offline_inference/vision_language_embedding.py index ad3c5ae0627b..2637949551a1 100644 --- a/examples/offline_inference/vision_language_embedding.py +++ b/examples/offline_inference/vision_language_embedding.py @@ -156,16 +156,13 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]): print("-" * 50) -def main(args: 
Namespace): - run_encode(args.model_name, args.modality, args.seed) - - model_example_map = { "e5_v": run_e5_v, "vlm2vec": run_vlm2vec, } -if __name__ == "__main__": + +def parse_args(): parser = FlexibleArgumentParser( description='Demo on using vLLM for offline inference with ' 'vision language models for multimodal embedding') @@ -184,6 +181,13 @@ def main(args: Namespace): type=int, default=None, help="Set the seed when initializing `vllm.LLM`.") + return parser.parse_args() - args = parser.parse_args() + +def main(args: Namespace): + run_encode(args.model_name, args.modality, args.seed) + + +if __name__ == "__main__": + args = parse_args() main(args) diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 89818f8b33ee..7f6608559f9c 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -326,6 +326,44 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "moonshotai/Kimi-VL-A3B-Instruct" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=4, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [{ + "role": + "user", + "content": [ + *placeholders, + { + "type": "text", + "text": question + }, + ], + }] + + processor = AutoProcessor.from_pretrained(model_name, + trust_remote_code=True) + + prompt = processor.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" @@ -465,11 +503,13 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData: engine_args = EngineArgs( model=model_path, trust_remote_code=True, - max_model_len=10000, + max_model_len=4096, max_num_seqs=2, limit_mm_per_prompt={"image": len(image_urls)}, enable_lora=True, max_lora_rank=320, + # Note - mm_processor_kwargs can also be passed to generate/chat calls + mm_processor_kwargs={"dynamic_hd": 4}, ) placeholders = "".join(f"<|image_{i}|>" @@ -640,6 +680,7 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: "h2ovl_chat": load_h2ovl, "idefics3": load_idefics3, "internvl_chat": load_internvl, + "kimi_vl": load_kimi_vl, "llama4": load_llama4, "mistral3": load_mistral3, "mllama": load_mllama, @@ -727,22 +768,7 @@ def run_chat(model: str, question: str, image_urls: list[str], print("-" * 50) -def main(args: Namespace): - model = args.model_type - method = args.method - seed = args.seed - - image_urls = IMAGE_URLS[:args.num_images] - - if method == "generate": - run_generate(model, QUESTION, image_urls, seed) - elif method == "chat": - run_chat(model, QUESTION, image_urls, seed) - else: - raise ValueError(f"Invalid method: {method}") - - -if __name__ == "__main__": +def parse_args(): parser = FlexibleArgumentParser( description='Demo on using vLLM for offline inference with ' 'vision language models that support multi-image input for text ' @@ -765,9 +791,29 @@ def main(args: Namespace): parser.add_argument( "--num-images", "-n", - choices=list(range(1, 13)), # 12 is 
the max number of images + type=int, + choices=list(range(1, + len(IMAGE_URLS) + 1)), # the max number of images default=2, help="Number of images to use for the demo.") + return parser.parse_args() - args = parser.parse_args() + +def main(args: Namespace): + model = args.model_type + method = args.method + seed = args.seed + + image_urls = IMAGE_URLS[:args.num_images] + + if method == "generate": + run_generate(model, QUESTION, image_urls, seed) + elif method == "chat": + run_chat(model, QUESTION, image_urls, seed) + else: + raise ValueError(f"Invalid method: {method}") + + +if __name__ == "__main__": + args = parse_args() main(args) diff --git a/examples/online_serving/api_client.py b/examples/online_serving/api_client.py index 60e4bccb7517..36079ff11d07 100644 --- a/examples/online_serving/api_client.py +++ b/examples/online_serving/api_client.py @@ -58,6 +58,16 @@ def get_response(response: requests.Response) -> list[str]: return output +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--n", type=int, default=1) + parser.add_argument("--prompt", type=str, default="San Francisco is a") + parser.add_argument("--stream", action="store_true") + return parser.parse_args() + + def main(args: Namespace): prompt = args.prompt api_url = f"http://{args.host}:{args.port}/generate" @@ -82,11 +92,5 @@ def main(args: Namespace): if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--n", type=int, default=1) - parser.add_argument("--prompt", type=str, default="San Francisco is a") - parser.add_argument("--stream", action="store_true") - args = parser.parse_args() + args = parse_args() main(args) diff --git a/examples/online_serving/cohere_rerank_client.py b/examples/online_serving/cohere_rerank_client.py index fc434ada1d15..c2d4ef08ddbb 100644 --- a/examples/online_serving/cohere_rerank_client.py +++ b/examples/online_serving/cohere_rerank_client.py @@ -2,32 +2,46 @@ """ Example of using the OpenAI entrypoint's rerank API which is compatible with the Cohere SDK: https://github.com/cohere-ai/cohere-python +Note that `pip install cohere` is needed to run this example. run: vllm serve BAAI/bge-reranker-base """ +from typing import Union + import cohere +from cohere import Client, ClientV2 + +model = "BAAI/bge-reranker-base" + +query = "What is the capital of France?" 
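The `--num-images` change above also adds `type=int`, which matters because argparse leaves command-line values as strings unless a `type` is given, so a value like `"3"` can never match an integer `choices` list. A small sketch of the corrected pattern (standalone, not part of the diff):

```python
import argparse

parser = argparse.ArgumentParser()
# Without type=int the CLI value stays a string ("3"), which never matches
# the integer choices, so every explicit -n value would be rejected.
parser.add_argument("--num-images", "-n",
                    type=int,
                    choices=list(range(1, 13)),
                    default=2)
print(parser.parse_args(["-n", "3"]).num_images)  # -> 3
```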
+ +documents = [ + "The capital of France is Paris", "Reranking is fun!", + "vLLM is an open-source framework for fast AI serving" +] + + +def cohere_rerank(client: Union[Client, ClientV2], model: str, query: str, + documents: list[str]) -> dict: + return client.rerank(model=model, query=query, documents=documents) + + +def main(): + # cohere v1 client + cohere_v1 = cohere.Client(base_url="http://localhost:8000", + api_key="sk-fake-key") + rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents) + print("-" * 50) + print("rerank_v1_result:\n", rerank_v1_result) + print("-" * 50) + + # or the v2 + cohere_v2 = cohere.ClientV2("sk-fake-key", + base_url="http://localhost:8000") + rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents) + print("rerank_v2_result:\n", rerank_v2_result) + print("-" * 50) + -# cohere v1 client -co = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key") -rerank_v1_result = co.rerank( - model="BAAI/bge-reranker-base", - query="What is the capital of France?", - documents=[ - "The capital of France is Paris", "Reranking is fun!", - "vLLM is an open-source framework for fast AI serving" - ]) - -print(rerank_v1_result) - -# or the v2 -co2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000") - -v2_rerank_result = co2.rerank( - model="BAAI/bge-reranker-base", - query="What is the capital of France?", - documents=[ - "The capital of France is Paris", "Reranking is fun!", - "vLLM is an open-source framework for fast AI serving" - ]) - -print(v2_rerank_result) +if __name__ == "__main__": + main() diff --git a/examples/online_serving/gradio_openai_chatbot_webserver.py b/examples/online_serving/gradio_openai_chatbot_webserver.py index ee01e1eae628..314f1c5b7395 100644 --- a/examples/online_serving/gradio_openai_chatbot_webserver.py +++ b/examples/online_serving/gradio_openai_chatbot_webserver.py @@ -1,52 +1,32 @@ # SPDX-License-Identifier: Apache-2.0 +"""Example for starting a Gradio OpenAI Chatbot Webserver +Start vLLM API server: + vllm serve meta-llama/Llama-2-7b-chat-hf +Start Gradio OpenAI Chatbot Webserver: + python examples/online_serving/gradio_openai_chatbot_webserver.py \ + -m meta-llama/Llama-2-7b-chat-hf + +Note that `pip install --upgrade gradio` is needed to run this example. +More details: https://github.com/gradio-app/gradio + +If your antivirus software blocks the download of frpc for gradio, +you can install it manually by following these steps: + +1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 +2. Rename the downloaded file to: frpc_linux_amd64_v0.3 +3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc +""" import argparse import gradio as gr from openai import OpenAI -# Argument parser setup -parser = argparse.ArgumentParser( - description='Chatbot Interface with Customizable Parameters') -parser.add_argument('--model-url', - type=str, - default='http://localhost:8000/v1', - help='Model URL') -parser.add_argument('-m', - '--model', - type=str, - required=True, - help='Model name for the chatbot') -parser.add_argument('--temp', - type=float, - default=0.8, - help='Temperature for text generation') -parser.add_argument('--stop-token-ids', - type=str, - default='', - help='Comma-separated stop token IDs') -parser.add_argument("--host", type=str, default=None) -parser.add_argument("--port", type=int, default=8001) - -# Parse the arguments -args = parser.parse_args() - -# Set OpenAI's API key and API base to use vLLM's API server. 
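The rerank helpers above print the raw SDK response objects. Assuming the Cohere SDK's rerank responses expose `results` items with `index` and `relevance_score` fields (true for recent `cohere` releases, but verify against your installed version), the ranked documents can be read back as in this sketch:

```python
import cohere

# Reuse the query and documents from the example above.
co = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
documents = [
    "The capital of France is Paris", "Reranking is fun!",
    "vLLM is an open-source framework for fast AI serving"
]
result = co.rerank(model="BAAI/bge-reranker-base",
                   query="What is the capital of France?",
                   documents=documents)

# Each result carries the index of the original document and its score
# (assumed field names; check your installed cohere version).
for item in sorted(result.results, key=lambda r: r.relevance_score,
                   reverse=True):
    print(f"{item.relevance_score:.4f}  {documents[item.index]}")
```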
-openai_api_key = "EMPTY" -openai_api_base = args.model_url - -# Create an OpenAI client to interact with the API server -client = OpenAI( - api_key=openai_api_key, - base_url=openai_api_base, -) - - -def predict(message, history): - # Convert chat history to OpenAI format + +def format_history_to_openai(history): history_openai_format = [{ "role": "system", - "content": "You are a great ai assistant." + "content": "You are a great AI assistant." }] for human, assistant in history: history_openai_format.append({"role": "user", "content": human}) @@ -54,31 +34,92 @@ def predict(message, history): "role": "assistant", "content": assistant }) + return history_openai_format + + +def predict(message, history, client, model_name, temp, stop_token_ids): + # Format history to OpenAI chat format + history_openai_format = format_history_to_openai(history) history_openai_format.append({"role": "user", "content": message}) - # Create a chat completion request and send it to the API server + # Send request to OpenAI API (vLLM server) stream = client.chat.completions.create( - model=args.model, # Model name to use - messages=history_openai_format, # Chat history - temperature=args.temp, # Temperature for text generation - stream=True, # Stream response + model=model_name, + messages=history_openai_format, + temperature=temp, + stream=True, extra_body={ 'repetition_penalty': 1, - 'stop_token_ids': [ - int(id.strip()) for id in args.stop_token_ids.split(',') - if id.strip() - ] if args.stop_token_ids else [] + 'stop_token_ids': + [int(id.strip()) + for id in stop_token_ids.split(',')] if stop_token_ids else [] }) - # Read and return generated text from response stream - partial_message = "" + # Collect all chunks and concatenate them into a full message + full_message = "" for chunk in stream: - partial_message += (chunk.choices[0].delta.content or "") - yield partial_message + full_message += (chunk.choices[0].delta.content or "") + + # Return the full message as a single response + return full_message + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Chatbot Interface with Customizable Parameters') + parser.add_argument('--model-url', + type=str, + default='http://localhost:8000/v1', + help='Model URL') + parser.add_argument('-m', + '--model', + type=str, + required=True, + help='Model name for the chatbot') + parser.add_argument('--temp', + type=float, + default=0.8, + help='Temperature for text generation') + parser.add_argument('--stop-token-ids', + type=str, + default='', + help='Comma-separated stop token IDs') + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8001) + return parser.parse_args() + + +def build_gradio_interface(client, model_name, temp, stop_token_ids): + + def chat_predict(message, history): + return predict(message, history, client, model_name, temp, + stop_token_ids) + + return gr.ChatInterface(fn=chat_predict, + title="Chatbot Interface", + description="A simple chatbot powered by vLLM") + + +def main(): + # Parse the arguments + args = parse_args() + + # Set OpenAI's API key and API base to use vLLM's API server + openai_api_key = "EMPTY" + openai_api_base = args.model_url + + # Create an OpenAI client + client = OpenAI(api_key=openai_api_key, base_url=openai_api_base) + + # Define the Gradio chatbot interface using the predict function + gradio_interface = build_gradio_interface(client, args.model, args.temp, + args.stop_token_ids) + + gradio_interface.queue().launch(server_name=args.host, + 
server_port=args.port, + share=True) -# Create and launch a chat interface with Gradio -gr.ChatInterface(predict).queue().launch(server_name=args.host, - server_port=args.port, - share=True) +if __name__ == "__main__": + main() diff --git a/examples/online_serving/gradio_webserver.py b/examples/online_serving/gradio_webserver.py index 85a9119c6aa2..2e7c2a0c5838 100644 --- a/examples/online_serving/gradio_webserver.py +++ b/examples/online_serving/gradio_webserver.py @@ -1,5 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 +"""Example for starting a Gradio Webserver +Start vLLM API server: + python -m vllm.entrypoints.api_server \ + --model meta-llama/Llama-2-7b-chat-hf +Start Webserver: + python examples/online_serving/gradio_webserver.py + +Note that `pip install --upgrade gradio` is needed to run this example. +More details: https://github.com/gradio-app/gradio + +If your antivirus software blocks the download of frpc for gradio, +you can install it manually by following these steps: + +1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64 +2. Rename the downloaded file to: frpc_linux_amd64_v0.3 +3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc +""" import argparse import json @@ -39,16 +56,23 @@ def build_demo(): return demo -if __name__ == "__main__": +def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) parser.add_argument("--port", type=int, default=8001) parser.add_argument("--model-url", type=str, default="http://localhost:8000/generate") - args = parser.parse_args() + return parser.parse_args() + +def main(args): demo = build_demo() demo.queue().launch(server_name=args.host, server_port=args.port, share=True) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/jinaai_rerank_client.py b/examples/online_serving/jinaai_rerank_client.py index 3e760e171788..3076bba765ce 100644 --- a/examples/online_serving/jinaai_rerank_client.py +++ b/examples/online_serving/jinaai_rerank_client.py @@ -23,12 +23,19 @@ "The capital of France is Paris.", "Horses and cows are both animals" ] } -response = requests.post(url, headers=headers, json=data) - -# Check the response -if response.status_code == 200: - print("Request successful!") - print(json.dumps(response.json(), indent=2)) -else: - print(f"Request failed with status code: {response.status_code}") - print(response.text) + + +def main(): + response = requests.post(url, headers=headers, json=data) + + # Check the response + if response.status_code == 200: + print("Request successful!") + print(json.dumps(response.json(), indent=2)) + else: + print(f"Request failed with status code: {response.status_code}") + print(response.text) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_chat_completion_client.py b/examples/online_serving/openai_chat_completion_client.py index a81562041130..74e0c045d621 100644 --- a/examples/online_serving/openai_chat_completion_client.py +++ b/examples/online_serving/openai_chat_completion_client.py @@ -1,38 +1,49 @@ # SPDX-License-Identifier: Apache-2.0 - +"""Example Python client for OpenAI Chat Completion using vLLM API server +NOTE: start a supported chat completion model server with `vllm serve`, e.g. + vllm serve meta-llama/Llama-2-7b-chat-hf +""" from openai import OpenAI # Modify OpenAI's API key and API base to use vLLM's API server. 
openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" -client = OpenAI( - # defaults to os.environ.get("OPENAI_API_KEY") - api_key=openai_api_key, - base_url=openai_api_base, -) - -models = client.models.list() -model = models.data[0].id - -chat_completion = client.chat.completions.create( - messages=[{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Who won the world series in 2020?" - }, { - "role": - "assistant", - "content": - "The Los Angeles Dodgers won the World Series in 2020." - }, { - "role": "user", - "content": "Where was it played?" - }], - model=model, -) - -print("Chat completion results:") -print(chat_completion) +messages = [{ + "role": "system", + "content": "You are a helpful assistant." +}, { + "role": "user", + "content": "Who won the world series in 2020?" +}, { + "role": "assistant", + "content": "The Los Angeles Dodgers won the World Series in 2020." +}, { + "role": "user", + "content": "Where was it played?" +}] + + +def main(): + client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, + ) + + models = client.models.list() + model = models.data[0].id + + chat_completion = client.chat.completions.create( + messages=messages, + model=model, + ) + + print("-" * 50) + print("Chat completion results:") + print(chat_completion) + print("-" * 50) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py index ecfcf05a90d1..70db4d95e649 100644 --- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py @@ -9,7 +9,7 @@ (multi-image inference with Phi-3.5-vision-instruct) vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ - --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 + --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}' (audio inference with Ultravox) vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096 @@ -303,12 +303,7 @@ def run_audio() -> None: } -def main(args) -> None: - chat_type = args.chat_type - example_function_map[chat_type]() - - -if __name__ == "__main__": +def parse_args(): parser = FlexibleArgumentParser( description='Demo on using OpenAI client for online serving with ' 'multimodal language models served with vLLM.') @@ -318,5 +313,14 @@ def main(args) -> None: default="single-image", choices=list(example_function_map.keys()), help='Conversation type with multimodal data.') - args = parser.parse_args() + return parser.parse_args() + + +def main(args) -> None: + chat_type = args.chat_type + example_function_map[chat_type]() + + +if __name__ == "__main__": + args = parse_args() main(args) diff --git a/examples/online_serving/openai_chat_completion_client_with_tools.py b/examples/online_serving/openai_chat_completion_client_with_tools.py index 416fb61ca8bb..c25203860ff3 100644 --- a/examples/online_serving/openai_chat_completion_client_with_tools.py +++ b/examples/online_serving/openai_chat_completion_client_with_tools.py @@ -17,6 +17,7 @@ --enable-auto-tool-choice --tool-call-parser hermes """ import json +from typing import Any from openai import OpenAI @@ -24,15 +25,6 @@ openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" -client = OpenAI( - # defaults to os.environ.get("OPENAI_API_KEY") - 
api_key=openai_api_key, - base_url=openai_api_base, -) - -models = client.models.list() -model = models.data[0].id - tools = [{ "type": "function", "function": { @@ -78,86 +70,123 @@ "Can you tell me what the temperate will be in Dallas, in fahrenheit?" }] -chat_completion = client.chat.completions.create(messages=messages, - model=model, - tools=tools) - -print("Chat completion results:") -print(chat_completion) -print("\n\n") - -tool_calls_stream = client.chat.completions.create(messages=messages, - model=model, - tools=tools, - stream=True) - -chunks = [] -for chunk in tool_calls_stream: - chunks.append(chunk) - if chunk.choices[0].delta.tool_calls: - print(chunk.choices[0].delta.tool_calls[0]) - else: - print(chunk.choices[0].delta) - -arguments = [] -tool_call_idx = -1 -for chunk in chunks: - - if chunk.choices[0].delta.tool_calls: - tool_call = chunk.choices[0].delta.tool_calls[0] - - if tool_call.index != tool_call_idx: - if tool_call_idx >= 0: - print( - f"streamed tool call arguments: {arguments[tool_call_idx]}" - ) - tool_call_idx = chunk.choices[0].delta.tool_calls[0].index - arguments.append("") - if tool_call.id: - print(f"streamed tool call id: {tool_call.id} ") - - if tool_call.function: - if tool_call.function.name: - print(f"streamed tool call name: {tool_call.function.name}") - - if tool_call.function.arguments: - arguments[tool_call_idx] += tool_call.function.arguments - -if len(arguments): - print(f"streamed tool call arguments: {arguments[-1]}") - -print("\n\n") - -messages.append({ - "role": "assistant", - "tool_calls": chat_completion.choices[0].message.tool_calls -}) - -# Now, simulate a tool call def get_current_weather(city: str, state: str, unit: 'str'): return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is " "partly cloudly, with highs in the 90's.") -available_tools = {"get_current_weather": get_current_weather} - -completion_tool_calls = chat_completion.choices[0].message.tool_calls -for call in completion_tool_calls: - tool_to_call = available_tools[call.function.name] - args = json.loads(call.function.arguments) - result = tool_to_call(**args) - print(result) +def handle_tool_calls_stream( + client: OpenAI, + messages: list[dict[str, str]], + model: str, + tools: list[dict[str, Any]], +) -> list[Any]: + tool_calls_stream = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + stream=True) + chunks = [] + print("chunks: ") + for chunk in tool_calls_stream: + chunks.append(chunk) + if chunk.choices[0].delta.tool_calls: + print(chunk.choices[0].delta.tool_calls[0]) + else: + print(chunk.choices[0].delta) + return chunks + + +def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]: + arguments = [] + tool_call_idx = -1 + print("arguments: ") + for chunk in chunks: + if chunk.choices[0].delta.tool_calls: + tool_call = chunk.choices[0].delta.tool_calls[0] + if tool_call.index != tool_call_idx: + if tool_call_idx >= 0: + print(f"streamed tool call arguments: " + f"{arguments[tool_call_idx]}") + tool_call_idx = chunk.choices[0].delta.tool_calls[0].index + arguments.append("") + if tool_call.id: + print(f"streamed tool call id: {tool_call.id} ") + + if tool_call.function: + if tool_call.function.name: + print( + f"streamed tool call name: {tool_call.function.name}") + + if tool_call.function.arguments: + arguments[tool_call_idx] += tool_call.function.arguments + + return arguments + + +def main(): + # Initialize OpenAI client + client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + 
api_key=openai_api_key, + base_url=openai_api_base, + ) + + # Get available models and select one + models = client.models.list() + model = models.data[0].id + + chat_completion = client.chat.completions.create(messages=messages, + model=model, + tools=tools) + + print("-" * 70) + print("Chat completion results:") + print(chat_completion) + print("-" * 70) + + # Stream tool calls + chunks = handle_tool_calls_stream(client, messages, model, tools) + print("-" * 70) + + # Handle arguments from streamed tool calls + arguments = handle_tool_calls_arguments(chunks) + + if len(arguments): + print(f"streamed tool call arguments: {arguments[-1]}\n") + + print("-" * 70) + + # Add tool call results to the conversation messages.append({ - "role": "tool", - "content": result, - "tool_call_id": call.id, - "name": call.function.name + "role": "assistant", + "tool_calls": chat_completion.choices[0].message.tool_calls }) -chat_completion_2 = client.chat.completions.create(messages=messages, - model=model, - tools=tools, - stream=False) -print("\n\n") -print(chat_completion_2) + # Now, simulate a tool call + available_tools = {"get_current_weather": get_current_weather} + + completion_tool_calls = chat_completion.choices[0].message.tool_calls + for call in completion_tool_calls: + tool_to_call = available_tools[call.function.name] + args = json.loads(call.function.arguments) + result = tool_to_call(**args) + print("tool_to_call result: ", result) + messages.append({ + "role": "tool", + "content": result, + "tool_call_id": call.id, + "name": call.function.name + }) + + chat_completion_2 = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + stream=False) + print("Chat completion2 results:") + print(chat_completion_2) + print("-" * 70) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_chat_completion_client_with_tools_required.py b/examples/online_serving/openai_chat_completion_client_with_tools_required.py index 779369d16344..97d900bb75f1 100644 --- a/examples/online_serving/openai_chat_completion_client_with_tools_required.py +++ b/examples/online_serving/openai_chat_completion_client_with_tools_required.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """ -To run this example, you can start the vLLM server +To run this example, you can start the vLLM server without any specific flags: ```bash @@ -8,7 +8,7 @@ --guided-decoding-backend outlines ``` -This example demonstrates how to generate chat completions +This example demonstrates how to generate chat completions using the OpenAI Python client library. 
""" @@ -18,15 +18,6 @@ openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" -client = OpenAI( - # defaults to os.environ.get("OPENAI_API_KEY") - api_key=openai_api_key, - base_url=openai_api_base, -) - -models = client.models.list() -model = models.data[0].id - tools = [ { "type": "function", @@ -116,21 +107,36 @@ }, ] -chat_completion = client.chat.completions.create( - messages=messages, - model=model, - tools=tools, - tool_choice="required", - stream=True # Enable streaming response -) -for chunk in chat_completion: - if chunk.choices and chunk.choices[0].delta.tool_calls: - print(chunk.choices[0].delta.tool_calls) +def main(): + client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, + ) + + models = client.models.list() + model = models.data[0].id + + chat_completion = client.chat.completions.create( + messages=messages, + model=model, + tools=tools, + tool_choice="required", + stream=True # Enable streaming response + ) + + for chunk in chat_completion: + if chunk.choices and chunk.choices[0].delta.tool_calls: + print(chunk.choices[0].delta.tool_calls) + + chat_completion = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + tool_choice="required") + + print(chat_completion.choices[0].message.tool_calls) -chat_completion = client.chat.completions.create(messages=messages, - model=model, - tools=tools, - tool_choice="required") -print(chat_completion.choices[0].message.tool_calls) +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_chat_completion_structured_outputs.py b/examples/online_serving/openai_chat_completion_structured_outputs.py index 986ff500e586..f71162e36efd 100644 --- a/examples/online_serving/openai_chat_completion_structured_outputs.py +++ b/examples/online_serving/openai_chat_completion_structured_outputs.py @@ -1,43 +1,49 @@ # SPDX-License-Identifier: Apache-2.0 +""" +To run this example, you need to start the vLLM server: + +```bash +vllm serve Qwen/Qwen2.5-3B-Instruct +``` +""" from enum import Enum from openai import BadRequestError, OpenAI from pydantic import BaseModel -client = OpenAI( - base_url="http://localhost:8000/v1", - api_key="-", -) # Guided decoding by Choice (list of possible options) -completion = client.chat.completions.create( - model="Qwen/Qwen2.5-3B-Instruct", - messages=[{ - "role": "user", - "content": "Classify this sentiment: vLLM is wonderful!" - }], - extra_body={"guided_choice": ["positive", "negative"]}, -) -print(completion.choices[0].message.content) +def guided_choice_completion(client: OpenAI, model: str): + completion = client.chat.completions.create( + model=model, + messages=[{ + "role": "user", + "content": "Classify this sentiment: vLLM is wonderful!" + }], + extra_body={"guided_choice": ["positive", "negative"]}, + ) + return completion.choices[0].message.content + # Guided decoding by Regex -prompt = ("Generate an email address for Alan Turing, who works in Enigma." - "End in .com and new line. Example result:" - "alan.turing@enigma.com\n") - -completion = client.chat.completions.create( - model="Qwen/Qwen2.5-3B-Instruct", - messages=[{ - "role": "user", - "content": prompt, - }], - extra_body={ - "guided_regex": "\w+@\w+\.com\n", - "stop": ["\n"] - }, -) -print(completion.choices[0].message.content) +def guided_regex_completion(client: OpenAI, model: str): + prompt = ("Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. 
Example result:" + "alan.turing@enigma.com\n") + + completion = client.chat.completions.create( + model=model, + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={ + "guided_regex": r"\w+@\w+\.com\n", + "stop": ["\n"] + }, + ) + return completion.choices[0].message.content # Guided decoding by JSON using Pydantic schema @@ -54,66 +60,100 @@ class CarDescription(BaseModel): car_type: CarType -json_schema = CarDescription.model_json_schema() - -prompt = ("Generate a JSON with the brand, model and car_type of" - "the most iconic car from the 90's") -completion = client.chat.completions.create( - model="Qwen/Qwen2.5-3B-Instruct", - messages=[{ - "role": "user", - "content": prompt, - }], - extra_body={"guided_json": json_schema}, -) -print(completion.choices[0].message.content) +def guided_json_completion(client: OpenAI, model: str): + json_schema = CarDescription.model_json_schema() -# Guided decoding by Grammar -simplified_sql_grammar = """ - ?start: select_statement + prompt = ("Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's") + completion = client.chat.completions.create( + model=model, + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={"guided_json": json_schema}, + ) + return completion.choices[0].message.content - ?select_statement: "SELECT " column_list " FROM " table_name - ?column_list: column_name ("," column_name)* +# Guided decoding by Grammar +def guided_grammar_completion(client: OpenAI, model: str): + simplified_sql_grammar = """ + root ::= select_statement - ?table_name: identifier + select_statement ::= "SELECT " column " from " table " where " condition - ?column_name: identifier + column ::= "col_1 " | "col_2 " - ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ -""" + table ::= "table_1 " | "table_2 " -prompt = ("Generate an SQL query to show the 'username' and 'email'" - "from the 'users' table.") -completion = client.chat.completions.create( - model="Qwen/Qwen2.5-3B-Instruct", - messages=[{ - "role": "user", - "content": prompt, - }], - extra_body={"guided_grammar": simplified_sql_grammar}, -) -print(completion.choices[0].message.content) + condition ::= column "= " number -# Extra backend options -prompt = ("Generate an email address for Alan Turing, who works in Enigma." - "End in .com and new line. Example result:" - "alan.turing@enigma.com\n") + number ::= "1 " | "2 " + """ -try: - # The no-fallback option forces vLLM to use xgrammar, so when it fails - # you get a 400 with the reason why + prompt = ("Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table.") completion = client.chat.completions.create( - model="Qwen/Qwen2.5-3B-Instruct", + model=model, messages=[{ "role": "user", "content": prompt, }], - extra_body={ - "guided_regex": "\w+@\w+\.com\n", - "stop": ["\n"], - "guided_decoding_backend": "xgrammar:no-fallback" - }, + extra_body={"guided_grammar": simplified_sql_grammar}, ) -except BadRequestError as e: - print("This error is expected:", e) + return completion.choices[0].message.content + + +# Extra backend options +def extra_backend_options_completion(client: OpenAI, model: str): + prompt = ("Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. 
Example result:" + "alan.turing@enigma.com\n") + + try: + # The no-fallback option forces vLLM to use xgrammar, so when it fails + # you get a 400 with the reason why + completion = client.chat.completions.create( + model=model, + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={ + "guided_regex": r"\w+@\w+\.com\n", + "stop": ["\n"], + "guided_decoding_backend": "xgrammar:no-fallback" + }, + ) + return completion.choices[0].message.content + except BadRequestError as e: + print("This error is expected:", e) + + +def main(): + client: OpenAI = OpenAI( + base_url="http://localhost:8000/v1", + api_key="-", + ) + + model = "Qwen/Qwen2.5-3B-Instruct" + + print("Guided Choice Completion:") + print(guided_choice_completion(client, model)) + + print("\nGuided Regex Completion:") + print(guided_regex_completion(client, model)) + + print("\nGuided JSON Completion:") + print(guided_json_completion(client, model)) + + print("\nGuided Grammar Completion:") + print(guided_grammar_completion(client, model)) + + print("\nExtra Backend Options Completion:") + print(extra_backend_options_completion(client, model)) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py b/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py new file mode 100644 index 000000000000..b807bc540526 --- /dev/null +++ b/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +from openai import OpenAI + +# This example demonstrates the `structural_tag` response format. +# It can be used to specify a structured output format that occurs between +# specific tags in the response. This example shows how it could be used +# to enforce the format of a tool call response, but it could be used for +# any structured output within a subset of the response. + + +def main(): + client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="-", + ) + + messages = [{ + "role": + "user", + "content": + """ +You have access to the following function to retrieve the weather in a city: + + { + "name": "get_weather", + "parameters": { + "city": { + "param_type": "string", + "description": "The city to get the weather for", + "required": True + } + } + } + +If a you choose to call a function ONLY reply in the following format: +<{start_tag}={function_name}>{parameters}{end_tag} +where + +start_tag => ` a JSON dict with the function argument name as key and function + argument value as value. +end_tag => `` + +Here is an example, +{"example_name": "example_value"} + +Reminder: +- Function calls MUST follow the specified format +- Required parameters MUST be specified +- Only call one function at a time +- Put the entire function call reply on one line +- Always add your sources when using search results to answer the user query + +You are a helpful assistant. + +Given the previous instructions, what is the weather in New York City, Boston, +and San Francisco? 
+""" + }] + + response = client.chat.completions.create( + model="meta-llama/Llama-3.1-8B-Instruct", + messages=messages, + response_format={ + "type": + "structural_tag", + "structures": [{ + "begin": "", + "schema": { + "type": "object", + "properties": { + "city": { + "type": "string" + } + } + }, + "end": "" + }], + "triggers": [" requests.Response: return response -if __name__ == "__main__": +def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) parser.add_argument("--model", type=str, default="BAAI/bge-reranker-v2-m3") + return parser.parse_args() + - args = parser.parse_args() +def main(args): api_url = f"http://{args.host}:{args.port}/score" model_name = args.model @@ -30,9 +32,9 @@ def post_http_request(prompt: dict, api_url: str) -> requests.Response: text_2 = "The capital of Brazil is Brasilia." prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} score_response = post_http_request(prompt=prompt, api_url=api_url) - print("Prompt when text_1 and text_2 are both strings:") + print("\nPrompt when text_1 and text_2 are both strings:") pprint.pprint(prompt) - print("Score Response:") + print("\nScore Response:") pprint.pprint(score_response.json()) text_1 = "What is the capital of France?" @@ -41,9 +43,9 @@ def post_http_request(prompt: dict, api_url: str) -> requests.Response: ] prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} score_response = post_http_request(prompt=prompt, api_url=api_url) - print("Prompt when text_1 is string and text_2 is a list:") + print("\nPrompt when text_1 is string and text_2 is a list:") pprint.pprint(prompt) - print("Score Response:") + print("\nScore Response:") pprint.pprint(score_response.json()) text_1 = [ @@ -54,7 +56,12 @@ def post_http_request(prompt: dict, api_url: str) -> requests.Response: ] prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} score_response = post_http_request(prompt=prompt, api_url=api_url) - print("Prompt when text_1 and text_2 are both lists:") + print("\nPrompt when text_1 and text_2 are both lists:") pprint.pprint(prompt) - print("Score Response:") + print("\nScore Response:") pprint.pprint(score_response.json()) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/openai_embedding_client.py b/examples/online_serving/openai_embedding_client.py index b7c5651e3bab..bc217f7ca7a0 100644 --- a/examples/online_serving/openai_embedding_client.py +++ b/examples/online_serving/openai_embedding_client.py @@ -6,22 +6,29 @@ openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" -client = OpenAI( - # defaults to os.environ.get("OPENAI_API_KEY") - api_key=openai_api_key, - base_url=openai_api_base, -) - -models = client.models.list() -model = models.data[0].id - -responses = client.embeddings.create( - input=[ - "Hello my name is", - "The best thing about vLLM is that it supports many different models" - ], - model=model, -) - -for data in responses.data: - print(data.embedding) # List of float of len 4096 + +def main(): + client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, + ) + + models = client.models.list() + model = models.data[0].id + + responses = client.embeddings.create( + # ruff: noqa: E501 + input=[ + "Hello my name is", + "The best thing about vLLM is that it supports many different models" + ], + model=model, + ) + + for data in 
responses.data: + print(data.embedding) # List of float of len 4096 + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_embedding_matryoshka_fy.py b/examples/online_serving/openai_embedding_matryoshka_fy.py new file mode 100644 index 000000000000..4544dcfb5ab0 --- /dev/null +++ b/examples/online_serving/openai_embedding_matryoshka_fy.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Example Python client for embedding API dimensions using vLLM API server +NOTE: + start a supported Matryoshka Embeddings model server with `vllm serve`, e.g. + vllm serve jinaai/jina-embeddings-v3 --trust-remote-code +""" + +from openai import OpenAI + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + + +def main(): + client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, + ) + + models = client.models.list() + model = models.data[0].id + + responses = client.embeddings.create( + input=["Follow the white rabbit."], + model=model, + dimensions=32, + ) + + for data in responses.data: + print(data.embedding) # List of float of len 32 + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_pooling_client.py b/examples/online_serving/openai_pooling_client.py index e17f9c5efd65..abcfe27c2769 100644 --- a/examples/online_serving/openai_pooling_client.py +++ b/examples/online_serving/openai_pooling_client.py @@ -17,7 +17,7 @@ def post_http_request(prompt: dict, api_url: str) -> requests.Response: return response -if __name__ == "__main__": +def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) @@ -25,15 +25,20 @@ def post_http_request(prompt: dict, api_url: str) -> requests.Response: type=str, default="jason9693/Qwen2.5-1.5B-apeach") - args = parser.parse_args() + return parser.parse_args() + + +def main(args): api_url = f"http://{args.host}:{args.port}/pooling" model_name = args.model # Input like Completions API prompt = {"model": model_name, "input": "vLLM is great!"} pooling_response = post_http_request(prompt=prompt, api_url=api_url) + print("-" * 50) print("Pooling Response:") pprint.pprint(pooling_response.json()) + print("-" * 50) # Input like Chat API prompt = { @@ -50,3 +55,9 @@ def post_http_request(prompt: dict, api_url: str) -> requests.Response: pooling_response = post_http_request(prompt=prompt, api_url=api_url) print("Pooling Response:") pprint.pprint(pooling_response.json()) + print("-" * 50) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py index 062868dd8adf..5fcb7c526416 100644 --- a/examples/online_serving/openai_transcription_client.py +++ b/examples/online_serving/openai_transcription_client.py @@ -26,7 +26,12 @@ def sync_openai(): model="openai/whisper-large-v3", language="en", response_format="json", - temperature=0.0) + temperature=0.0, + # Additional sampling params not provided by OpenAI API. 
+ extra_body=dict( + seed=4419, + repetition_penalty=1.3, + )) print("transcription result:", transcription.text) diff --git a/examples/online_serving/ray_serve_deepseek.py b/examples/online_serving/ray_serve_deepseek.py new file mode 100644 index 000000000000..f9ef3e2da1a1 --- /dev/null +++ b/examples/online_serving/ray_serve_deepseek.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Example to deploy DeepSeek R1 or V3 with Ray Serve LLM. +See Ray Serve LLM documentation at: +https://docs.ray.io/en/latest/serve/llm/serving-llms.html + +Run `python3 ray_serve_deepseek.py` to deploy the model. +""" + +from ray import serve +from ray.serve.llm import LLMConfig, build_openai_app + +llm_config = LLMConfig( + model_loading_config={ + "model_id": "deepseek", + # Since DeepSeek model is huge, it is recommended to pre-download + # the model to local disk, say /path/to/the/model and specify: + # model_source="/path/to/the/model" + "model_source": "deepseek-ai/DeepSeek-R1", + }, + deployment_config={ + "autoscaling_config": { + "min_replicas": 1, + "max_replicas": 1, + } + }, + # Change to the accelerator type of the node + accelerator_type="H100", + runtime_env={"env_vars": { + "VLLM_USE_V1": "1" + }}, + # Customize engine arguments as needed (e.g. vLLM engine kwargs) + engine_kwargs={ + "tensor_parallel_size": 8, + "pipeline_parallel_size": 2, + "gpu_memory_utilization": 0.92, + "dtype": "auto", + "max_num_seqs": 40, + "max_model_len": 16384, + "enable_chunked_prefill": True, + "enable_prefix_caching": True, + "trust_remote_code": True, + }, +) + +# Deploy the application +llm_app = build_openai_app({"llm_configs": [llm_config]}) +serve.run(llm_app) diff --git a/examples/tool_chat_template_llama4_json.jinja b/examples/tool_chat_template_llama4_json.jinja new file mode 100644 index 000000000000..759f16554436 --- /dev/null +++ b/examples/tool_chat_template_llama4_json.jinja @@ -0,0 +1,116 @@ +{%- macro is_array_of_type_objects(var) -%} + {%- if var is iterable and var is not string -%} + {%- set valid = true -%} + {%- for item in var -%} + {%- if 'type' not in item -%} + {%- set valid = false -%} + {%- break -%} + {%- endif -%} + {%- endfor -%} + {{ valid }} + {%- else -%} + {{ false }} + {%- endif -%} +{%- endmacro %} + +{%- macro render_message(message) %} + {%- if message['content'] is string %} + {{- message['content']|trim }} + {%- elif is_array_of_type_objects(data) == 'True' %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {{- '<|image|>' }} + {%- elif content['type'] == 'text' %} + {{- content['text']|trim }} + {%- endif %} + {%- endfor %} + {%- else %} + {{- message['content']|tojson }} + {%- endif %} +{%- endmacro %} + +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = true %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0] %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = ({ "content": "You are a helpful assistant with tool calling " + "capabilities. Only reply with a tool call if the function exists in the " + "library provided by the user. If it doesn't exist, just reply directly in " + "natural language. 
When you receive a tool call response, use the output to " + "format an answer to the original user question."}) %} +{%- endif %} + +{%- set tool_lib_preamble = 'Tools: You have access to the following tools. You might need to use one ' + 'or more function/tool calls to fulfill the task. \n' + 'If none are needed, then proceed to the response.\n\n' + 'Tool Call Syntax: You can call tools using the following syntax:\n' + '{"name": function name, "parameters": dictionary of argument name and its value}.\n' + 'Separate multiple function calls by "; ". Do not use variables.\n' + 'Do not include anything else when calling the tools with the syntax above.\n\n' + 'Here is a list of functions in JSON format that you can invoke.\n' %} + +{{- "<|header_start|>system<|header_end|>\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- tool_lib_preamble }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- render_message(system_message) }} +{{ "<|eot|>\n" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0] %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} + {%- endif %} + {{- '<|header_start|>user<|header_end|>\n\n' }} + {{- tool_lib_preamble }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- render_message(first_user_message) + "\n<|eot|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} + {{- render_message(message) }} + {{- "\n<|eot|>" }} + {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %} + {{- '\n<|header_start|>assistant<|header_end|>\n\n' -}} + {{- render_message(message) }} + {%- for tool_call in message.tool_calls %} + {{- '{"name": "' + tool_call.function.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.function.arguments | tojson }} + {{- "}" }} + {%- endfor %} + {{- "\n<|eot|>" }} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "\n<|header_start|>ipython<|header_end|>\n\n" }} + {{- render_message(message) }} + {{- "\n<|eom|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '\n<|header_start|>assistant<|header_end|>\n\n' }} +{%- endif %} diff --git a/pyproject.toml b/pyproject.toml index 167e975c70fd..b5f1039b44da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,8 @@ build-backend = "setuptools.build_meta" [project] name = "vllm" authors = [{name = "vLLM Team"}] -license = { "file"= "LICENSE" } +license = "Apache-2.0" +license-files = ["LICENSE"] readme = "README.md" description = "A high-throughput and memory-efficient inference and serving engine for LLMs" classifiers = [ @@ -23,7 +24,6 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", - "License :: OSI Approved :: Apache Software License", "Intended Audience :: Developers", "Intended Audience :: Information Technology", "Intended Audience :: Science/Research", @@ -46,8 +46,7 @@ vllm = "vllm.entrypoints.cli.main:main" 
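The Llama 4 JSON tool-call template above controls where tool definitions are injected (system message versus first user message) and how tool calls are serialized. A hedged sketch for rendering it standalone with `jinja2` to inspect the prompt it produces; the BOS token string is an assumption, and `loopcontrols` is enabled because the template uses `{% break %}`:

```python
import jinja2

# Build an environment that supports {% break %} inside the template's macro.
env = jinja2.Environment(trim_blocks=True,
                         lstrip_blocks=True,
                         extensions=["jinja2.ext.loopcontrols"])

with open("examples/tool_chat_template_llama4_json.jinja") as f:
    template = env.from_string(f.read())

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]
messages = [{"role": "user", "content": "What is the weather in Dallas?"}]

# "<|begin_of_text|>" is assumed here to be the model's BOS token.
print(template.render(messages=messages,
                      tools=tools,
                      bos_token="<|begin_of_text|>",
                      add_generation_prompt=True))
```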
[tool.setuptools.packages.find] where = ["."] -exclude = ["benchmarks", "csrc", "docs", "examples", "tests*"] -namespaces = false +include = ["vllm*"] [tool.yapfignore] ignore_patterns = [ @@ -59,7 +58,8 @@ ignore_patterns = [ line-length = 80 exclude = [ # External file, leaving license intact - "examples/other/fp8/quantizer/quantize.py" + "examples/other/fp8/quantizer/quantize.py", + "vllm/vllm_flash_attn/flash_attn_interface.pyi" ] [tool.ruff.lint.per-file-ignores] diff --git a/requirements/common.txt b/requirements/common.txt index bb1bb2dd994e..c4905b28c05e 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -8,7 +8,7 @@ blake3 py-cpuinfo transformers >= 4.51.1 huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads. -tokenizers >= 0.19.1 # Required for Llama 3. +tokenizers >= 0.21.1 # Required for fast incremental detokenization. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. aiohttp @@ -26,7 +26,7 @@ xgrammar == 0.1.18; platform_machine == "x86_64" or platform_machine == "aarch64 typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs -pyzmq +pyzmq >= 25.0.0 msgspec gguf >= 0.13.0 importlib_metadata @@ -36,7 +36,7 @@ pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. -compressed-tensors == 0.9.3 # required for compressed-tensors +compressed-tensors == 0.9.4 # required for compressed-tensors depyf==0.18.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py watchfiles # required for http server to monitor the updates of TLS files diff --git a/requirements/docs.txt b/requirements/docs.txt index 416ca503b36c..d84fd633ce10 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -7,6 +7,7 @@ sphinx-togglebutton==0.3.2 myst-parser==3.0.1 msgspec cloudpickle +commonmark # Required by sphinx-argparse when using :markdownhelp: # packages to install to build the documentation cachetools @@ -18,6 +19,7 @@ transformers mistral_common >= 1.5.4 aiohttp starlette +scipy openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args diff --git a/requirements/hpu.txt b/requirements/hpu.txt index 830f6ef3f50c..5ac58bc02892 100644 --- a/requirements/hpu.txt +++ b/requirements/hpu.txt @@ -9,4 +9,4 @@ numpy==1.26.4 tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@f1f6624 diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt new file mode 100644 index 000000000000..20372a9b2ef1 --- /dev/null +++ b/requirements/nightly_torch_test.txt @@ -0,0 +1,28 @@ +# Dependency that able to run entrypoints test +# 
pytest and its extensions +pytest +pytest-asyncio +pytest-forked +pytest-mock +pytest-rerunfailures +pytest-shard +pytest-timeout + + +librosa # required by audio tests in entrypoints/openai +sentence-transformers +numba == 0.61.2; python_version > '3.9' +# testing utils +awscli +boto3 +botocore +datasets +ray >= 2.10.0 +peft +runai-model-streamer==0.11.0 +runai-model-streamer-s3==0.11.0 +tensorizer>=2.9.0 +lm-eval==0.4.8 +buildkite-test-collector==0.1.9 + +lm-eval[api]==0.4.8 # required for model evaluation test diff --git a/requirements/rocm-build.txt b/requirements/rocm-build.txt index 29d5647807bb..05de4ff16845 100644 --- a/requirements/rocm-build.txt +++ b/requirements/rocm-build.txt @@ -6,6 +6,7 @@ torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 +triton==3.2 cmake>=3.26,<4 packaging setuptools>=61 diff --git a/requirements/test.in b/requirements/test.in index 95c94dcdbe99..c5d2c4cd4c30 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -10,6 +10,7 @@ pytest-timeout # testing utils awscli backoff # required for phi4mm test +blobfile # required for kimi-vl test einops # required for MPT, qwen-vl and Mamba httpx librosa # required for audio tests @@ -26,14 +27,17 @@ torch==2.6.0 torchaudio==2.6.0 torchvision==0.21.0 transformers_stream_generator # required for qwen-vl test +mamba_ssm # required for plamo2 test matplotlib # required for qwen-vl test mistral_common[opencv] >= 1.5.4 # required for pixtral test num2words # required for smolvlm test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.8 # required for model evaluation test -transformers==4.51.1 +transformers==4.51.3 +tokenizers==0.21.1 huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads. +schemathesis>=3.39.15 # Required for openai schema test. 
# quantization bitsandbytes>=0.45.3 buildkite-test-collector==0.1.9 diff --git a/requirements/test.txt b/requirements/test.txt index 8fd36339b06d..5c27becc3e29 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -20,25 +20,35 @@ aiosignal==1.3.1 annotated-types==0.7.0 # via pydantic anyio==4.6.2.post1 - # via httpx + # via + # httpx + # starlette argcomplete==3.5.1 # via datamodel-code-generator +arrow==1.3.0 + # via isoduration attrs==24.2.0 # via # aiohttp + # hypothesis # jsonlines # jsonschema + # pytest-subtests # referencing audioread==3.0.1 # via librosa awscli==1.35.23 # via -r requirements/test.in backoff==2.2.1 - # via -r requirements/test.in + # via + # -r requirements/test.in + # schemathesis bitsandbytes==0.45.3 # via -r requirements/test.in black==24.10.0 # via datamodel-code-generator +blobfile==3.0.0 + # via -r requirements/test.in boto3==1.35.57 # via tensorizer botocore==1.35.57 @@ -67,11 +77,13 @@ click==8.1.7 # jiwer # nltk # ray + # schemathesis # typer colorama==0.4.6 # via # awscli # sacrebleu + # schemathesis # tqdm-multiprocess contourpy==1.3.0 # via matplotlib @@ -109,6 +121,7 @@ einops==0.8.0 # via # -r requirements/test.in # encodec + # mamba-ssm # vector-quantize-pytorch # vocos einx==0.3.0 @@ -127,6 +140,7 @@ fastsafetensors==0.1.10 # via -r requirements/test.in filelock==3.16.1 # via + # blobfile # datasets # huggingface-hub # ray @@ -134,6 +148,8 @@ filelock==3.16.1 # transformers fonttools==4.54.1 # via matplotlib +fqdn==1.5.1 + # via jsonschema frozendict==2.4.6 # via einx frozenlist==1.5.0 @@ -152,8 +168,12 @@ genai-perf==0.0.8 # via -r requirements/test.in genson==1.3.0 # via datamodel-code-generator +graphql-core==3.2.6 + # via hypothesis-graphql h11==0.14.0 # via httpcore +harfile==0.3.0 + # via schemathesis hf-xet==0.1.4 # via huggingface-hub hiredis==3.0.0 @@ -161,7 +181,9 @@ hiredis==3.0.0 httpcore==1.0.6 # via httpx httpx==0.27.2 - # via -r requirements/test.in + # via + # -r requirements/test.in + # schemathesis huggingface-hub==0.30.1 # via # -r requirements/test.in @@ -176,17 +198,29 @@ huggingface-hub==0.30.1 # vocos humanize==4.11.0 # via runai-model-streamer +hypothesis==6.131.0 + # via + # hypothesis-graphql + # hypothesis-jsonschema + # schemathesis +hypothesis-graphql==0.11.1 + # via schemathesis +hypothesis-jsonschema==0.23.1 + # via schemathesis idna==3.10 # via # anyio # email-validator # httpx + # jsonschema # requests # yarl inflect==5.6.2 # via datamodel-code-generator iniconfig==2.0.0 # via pytest +isoduration==20.11.0 + # via jsonschema isort==5.13.2 # via datamodel-code-generator jinja2==3.1.6 @@ -206,12 +240,18 @@ joblib==1.4.2 # scikit-learn jsonlines==4.0.0 # via lm-eval +jsonpointer==3.0.0 + # via jsonschema jsonschema==4.23.0 # via + # hypothesis-jsonschema # mistral-common # ray + # schemathesis jsonschema-specifications==2024.10.1 # via jsonschema +junit-xml==1.9 + # via schemathesis kaleido==0.2.1 # via genai-perf kiwisolver==1.4.7 @@ -227,11 +267,17 @@ llvmlite==0.44.0 lm-eval==0.4.8 # via -r requirements/test.in lxml==5.3.0 - # via sacrebleu + # via + # blobfile + # sacrebleu +mamba-ssm==2.2.4 + # via -r requirements/test.in markdown-it-py==3.0.0 # via rich markupsafe==3.0.2 - # via jinja2 + # via + # jinja2 + # werkzeug matplotlib==3.9.2 # via -r requirements/test.in mbstrdecoder==1.1.3 @@ -263,6 +309,8 @@ mypy-extensions==1.0.0 # via black networkx==3.2.1 # via torch +ninja==1.11.1.3 + # via mamba-ssm nltk==3.9.1 # via rouge-score num2words==0.5.14 @@ -355,6 +403,7 @@ packaging==24.1 # fastparquet # 
huggingface-hub # lazy-loader + # mamba-ssm # matplotlib # peft # plotly @@ -426,6 +475,8 @@ pybind11==2.13.6 # via lm-eval pycparser==2.22 # via cffi +pycryptodomex==3.22.0 + # via blobfile pydantic==2.9.2 # via # datamodel-code-generator @@ -436,6 +487,8 @@ pygments==2.18.0 # via rich pyparsing==3.2.0 # via matplotlib +pyrate-limiter==3.7.0 + # via schemathesis pytablewriter==1.2.0 # via lm-eval pytest==8.3.3 @@ -448,7 +501,9 @@ pytest==8.3.3 # pytest-mock # pytest-rerunfailures # pytest-shard + # pytest-subtests # pytest-timeout + # schemathesis pytest-asyncio==0.24.0 # via -r requirements/test.in pytest-forked==1.6.0 @@ -459,10 +514,13 @@ pytest-rerunfailures==14.0 # via -r requirements/test.in pytest-shard==0.1.2 # via -r requirements/test.in +pytest-subtests==0.14.1 + # via schemathesis pytest-timeout==2.3.1 # via -r requirements/test.in python-dateutil==2.9.0.post0 # via + # arrow # botocore # matplotlib # pandas @@ -484,6 +542,7 @@ pyyaml==6.0.2 # peft # ray # responses + # schemathesis # timm # transformers # vocos @@ -514,10 +573,16 @@ requests==2.32.3 # pooch # ray # responses + # schemathesis + # starlette-testclient # tiktoken # transformers responses==0.25.3 # via genai-perf +rfc3339-validator==0.1.4 + # via jsonschema +rfc3987==1.3.8 + # via jsonschema rich==13.9.4 # via # genai-perf @@ -546,6 +611,8 @@ safetensors==0.4.5 # peft # timm # transformers +schemathesis==3.39.15 + # via -r requirements/test.in scikit-learn==1.5.2 # via # librosa @@ -564,18 +631,23 @@ sentencepiece==0.2.0 # via mistral-common setuptools==75.8.0 # via + # mamba-ssm # pytablewriter # torch shellingham==1.5.4 # via typer six==1.16.0 # via + # junit-xml # python-dateutil + # rfc3339-validator # rouge-score sniffio==1.3.1 # via # anyio # httpx +sortedcontainers==2.4.0 + # via hypothesis soundfile==0.12.1 # via # -r requirements/test.in @@ -584,6 +656,12 @@ soxr==0.5.0.post1 # via librosa sqlitedict==2.1.0 # via lm-eval +starlette==0.46.2 + # via + # schemathesis + # starlette-testclient +starlette-testclient==0.4.1 + # via schemathesis statsmodels==0.14.4 # via genai-perf sympy==1.13.1 @@ -610,8 +688,14 @@ tiktoken==0.7.0 # mistral-common timm==1.0.11 # via -r requirements/test.in -tokenizers==0.21.0 - # via transformers +tokenizers==0.21.1 + # via + # -r requirements/test.in + # transformers +tomli==2.2.1 + # via schemathesis +tomli-w==1.2.0 + # via schemathesis torch==2.6.0 # via # -r requirements/test.in @@ -620,6 +704,7 @@ torch==2.6.0 # encodec # fastsafetensors # lm-eval + # mamba-ssm # peft # runai-model-streamer # sentence-transformers @@ -652,11 +737,12 @@ tqdm==4.66.6 # transformers tqdm-multiprocess==0.0.11 # via lm-eval -transformers==4.51.1 +transformers==4.51.3 # via # -r requirements/test.in # genai-perf # lm-eval + # mamba-ssm # peft # sentence-transformers # transformers-stream-generator @@ -675,6 +761,8 @@ typepy==1.3.2 # tabledata typer==0.15.2 # via fastsafetensors +types-python-dateutil==2.9.0.20241206 + # via arrow typing-extensions==4.12.2 # via # huggingface-hub @@ -687,8 +775,11 @@ typing-extensions==4.12.2 # typer tzdata==2024.2 # via pandas +uri-template==1.3.0 + # via jsonschema urllib3==2.2.3 # via + # blobfile # botocore # requests # responses @@ -697,6 +788,10 @@ vector-quantize-pytorch==1.21.2 # via -r requirements/test.in vocos==0.1.0 # via -r requirements/test.in +webcolors==24.11.1 + # via jsonschema +werkzeug==3.1.3 + # via schemathesis word2number==1.1 # via lm-eval xxhash==3.5.0 @@ -704,6 +799,8 @@ xxhash==3.5.0 # datasets # evaluate yarl==1.17.1 - # via aiohttp + # 
via + # aiohttp + # schemathesis zstandard==0.23.0 # via lm-eval diff --git a/requirements/tpu.txt b/requirements/tpu.txt index 75ebbc4ed940..b63993ba1ee4 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -17,9 +17,8 @@ ray[data] --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" +torch==2.8.0.dev20250408 +torchvision==0.22.0.dev20250408 torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" diff --git a/setup.py b/setup.py index 19eea7211738..92caf3d93639 100755 --- a/setup.py +++ b/setup.py @@ -269,15 +269,17 @@ def run(self): # First, run the standard build_ext command to compile the extensions super().run() - # copy vllm/vllm_flash_attn/*.py from self.build_lib to current + # copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current # directory so that they can be included in the editable build import glob - files = glob.glob( - os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "*.py")) + files = glob.glob(os.path.join(self.build_lib, "vllm", + "vllm_flash_attn", "**", "*.py"), + recursive=True) for file in files: dst_file = os.path.join("vllm/vllm_flash_attn", - os.path.basename(file)) + file.split("vllm/vllm_flash_attn/")[-1]) print(f"Copying {file} to {dst_file}") + os.makedirs(os.path.dirname(dst_file), exist_ok=True) self.copy_file(file, dst_file) @@ -377,13 +379,22 @@ def run(self) -> None: "vllm/_flashmla_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so", - "vllm/vllm_flash_attn/flash_attn_interface.py", - "vllm/vllm_flash_attn/__init__.py", "vllm/cumem_allocator.abi3.so", # "vllm/_version.py", # not available in nightly wheels yet ] - file_members = filter(lambda x: x.filename in files_to_copy, - wheel.filelist) + + file_members = list( + filter(lambda x: x.filename in files_to_copy, wheel.filelist)) + + # vllm_flash_attn python code: + # Regex from + # `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)` + import re + compiled_regex = re.compile( + r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py") + file_members += list( + filter(lambda x: compiled_regex.match(x.filename), + wheel.filelist)) for file in file_members: print(f"Extracting and including {file.filename} " diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/benchmarks/test_latency_cli.py b/tests/benchmarks/test_latency_cli.py new file mode 
100644 index 000000000000..8537459b9f94 --- /dev/null +++ b/tests/benchmarks/test_latency_cli.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +import subprocess + +import pytest + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" + + +@pytest.mark.benchmark +def test_bench_latency(): + command = [ + "vllm", "bench", "latency", "--model", MODEL_NAME, "--input-len", "32", + "--output-len", "1", "--enforce-eager", "--load-format", "dummy" + ] + result = subprocess.run(command, capture_output=True, text=True) + print(result.stdout) + print(result.stderr) + + assert result.returncode == 0, f"Benchmark failed: {result.stderr}" diff --git a/tests/benchmarks/test_serve_cli.py b/tests/benchmarks/test_serve_cli.py new file mode 100644 index 000000000000..b746d6b7853c --- /dev/null +++ b/tests/benchmarks/test_serve_cli.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +import subprocess + +import pytest + +from ..utils import RemoteOpenAIServer + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy" + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.benchmark +def test_bench_serve(server): + command = [ + "vllm", + "bench", + "serve", + "--model", + MODEL_NAME, + "--host", + server.host, + "--port", + str(server.port), + "--random-input-len", + "32", + "--random-output-len", + "4", + "--num-prompts", + "5", + ] + result = subprocess.run(command, capture_output=True, text=True) + print(result.stdout) + print(result.stderr) + + assert result.returncode == 0, f"Benchmark failed: {result.stderr}" diff --git a/tests/benchmarks/test_throughput_cli.py b/tests/benchmarks/test_throughput_cli.py new file mode 100644 index 000000000000..2045b3629356 --- /dev/null +++ b/tests/benchmarks/test_throughput_cli.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +import subprocess + +import pytest + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" + + +@pytest.mark.benchmark +def test_bench_throughput(): + command = [ + "vllm", "bench", "throughput", "--model", MODEL_NAME, "--input-len", + "32", "--output-len", "1", "--enforce-eager", "--load-format", "dummy" + ] + result = subprocess.run(command, capture_output=True, text=True) + print(result.stdout) + print(result.stderr) + + assert result.returncode == 0, f"Benchmark failed: {result.stderr}" diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 579133ec0c3f..c09406385987 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -20,15 +20,11 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None): ("facebook/opt-125m", {}), ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { "dtype": torch.float16, - "quantization": "compressed-tensors" }), ("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", { "dtype": torch.float16, - "quantization": "compressed-tensors" - }), - ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", { - "quantization": "compressed-tensors" }), + ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), ("meta-llama/Llama-3.2-1B-Instruct", {}), ] diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 9f9b2d06b227..27cd10b77491 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -10,7 +10,7 @@ kFp8DynamicTokenSym, kFp8StaticTensorSym) from 
vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig +from vllm.config import CompilationConfig, VllmConfig from .backend import TestBackend @@ -49,13 +49,15 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, do_fusion: bool): torch.set_default_device("cuda") - config = CompilationConfig.PassConfig(enable_fusion=do_fusion, - enable_noop=True) - noop_pass = NoOpEliminationPass(config) - fusion_pass = FusionPass.instance(config) + vllm_config = VllmConfig() + vllm_config.compilation_config = CompilationConfig(pass_config= \ + CompilationConfig.PassConfig(enable_fusion=do_fusion, + enable_noop=True)) + noop_pass = NoOpEliminationPass(vllm_config) + fusion_pass = FusionPass.instance(vllm_config) passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass] - func_pass = FixFunctionalizationPass(config) + func_pass = FixFunctionalizationPass(vllm_config) backend_func = TestBackend(*passes, func_pass) backend_no_func = TestBackend(*passes) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index efebf05b6b04..6a696fe0226b 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -77,12 +77,13 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, vllm_config = VllmConfig(compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"])) + vllm_config.compilation_config.pass_config = \ + CompilationConfig.PassConfig(enable_fusion=True, + enable_noop=True) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work - config = CompilationConfig.PassConfig(enable_fusion=True, - enable_noop=True) - noop_pass = NoOpEliminationPass(config) - fusion_pass = FusionPass.instance(config) + noop_pass = NoOpEliminationPass(vllm_config) + fusion_pass = FusionPass.instance(vllm_config) backend = TestBackend(noop_pass, fusion_pass) model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled) diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index 2c1ee4dc7480..673ebe8b6fdc 100644 --- a/tests/compile/test_pass_manager.py +++ b/tests/compile/test_pass_manager.py @@ -6,7 +6,7 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.compilation.pass_manager import PostGradPassManager -from vllm.config import CompilationConfig +from vllm.config import VllmConfig # dummy custom pass that doesn't inherit @@ -16,7 +16,7 @@ def simple_callable(graph: torch.fx.Graph): # Should fail to add directly to the pass manager def test_bad_callable(): - config = CompilationConfig().pass_config + config = VllmConfig() pass_manager = PostGradPassManager() pass_manager.configure(config) @@ -43,7 +43,7 @@ def __call__(self, graph: torch.fx.graph.Graph) -> None: ], ) def test_pass_manager_uuid(callable): - config = CompilationConfig().pass_config + config = VllmConfig() pass_manager = PostGradPassManager() pass_manager.configure(config) @@ -64,7 +64,8 @@ def test_pass_manager_uuid(callable): # UUID should be different due to config change config2 = copy.deepcopy(config) - config2.enable_fusion = not config2.enable_fusion + config2.compilation_config.pass_config.enable_fusion = not \ + config2.compilation_config.pass_config.enable_fusion pass_manager3 = PostGradPassManager() pass_manager3.configure(config2) pass_manager3.add(callable) diff --git 
a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py new file mode 100644 index 000000000000..79f5486dadcd --- /dev/null +++ b/tests/compile/test_sequence_parallelism.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +import vllm.envs as envs +from vllm.compilation.fix_functionalization import FixFunctionalizationPass +from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe, + find_specified_fn, + find_specified_fn_maybe, is_func) +from vllm.compilation.sequence_parallelism import SequenceParallelismPass +from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, + VllmConfig) +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import (init_distributed_environment, + initialize_model_parallel) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.platforms import current_platform +from vllm.utils import update_environment_variables + +from ..utils import multi_gpu_test +from .backend import TestBackend + +OPS_IN_MODEL_BEFORE = [ + torch.ops.vllm.all_reduce.default, +] + +OPS_IN_MODEL_AFTER = [ + torch.ops.vllm.reduce_scatter.default, + torch.ops.vllm.all_gather.default, +] + +OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default] + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + + +class TestModel(torch.nn.Module): + + def __init__(self, hidden_size=16, intermediate_size=32): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = torch.nn.Parameter( + torch.empty((intermediate_size, hidden_size))) + self.norm = RMSNorm(hidden_size, 1e-05) + # Initialize weights + torch.nn.init.normal_(self.gate_proj, std=0.02) + + def forward(self, hidden_states, residual): + """ + Forward pass implementing the operations in the FX graph + + Args: + hidden_states: Input tensor + residual: Residual tensor from previous layer + + Returns: + Tuple containing the output tensor + """ + # Reshape input + view = hidden_states.reshape(-1, self.hidden_size) + + #matrix multiplication + permute = self.gate_proj.permute(1, 0) + mm = torch.mm(view, permute) + + # Tensor parallel all-reduce + all_reduce = tensor_model_parallel_all_reduce(mm) + + # layer normalization + norm_output, residual_output = self.norm(all_reduce, residual) + + return norm_output, residual_output + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seq_len", [16]) +@pytest.mark.parametrize("hidden_size", [16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], + reason="Only test on CUDA") +def test_sequence_parallelism_pass(batch_size: int, seq_len: int, + hidden_size: int, dtype: torch.dtype): + num_processes = 2 + + def run_torch_spawn(fn, nprocs): + # need to use torch.mp.spawn otherwise will have problems with + # torch.distributed and cuda + torch.multiprocessing.spawn(fn, + args=(num_processes, batch_size, seq_len, + hidden_size, dtype), + nprocs=nprocs) + + run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes) + + +def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, + batch_size: int, seq_len: int, + hidden_size: int, + dtype: torch.dtype): + current_platform.seed_everything(0) + + device = torch.device(f"cuda:{local_rank}") + 
torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '12345', + }) + + # initialize distributed + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # configure vllm config for SequenceParallelismPass + vllm_config = VllmConfig() + vllm_config.compilation_config = CompilationConfig( + pass_config=CompilationConfig.PassConfig( + enable_sequence_parallelism=True, ), ) + vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) + + # this is a fake model name to construct the model config + # in the vllm_config, it's not really used. + model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" + vllm_config.model_config = ModelConfig(model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype=dtype, + seed=42) + + sequence_parallelism_pass = SequenceParallelismPass(vllm_config) + backend_no_func = TestBackend(sequence_parallelism_pass) + func_pass = FixFunctionalizationPass(vllm_config) + backend_func = TestBackend(sequence_parallelism_pass, func_pass) + + model = TestModel(hidden_size, hidden_size * 2) + hidden_states = torch.randn((batch_size * seq_len, hidden_size), + dtype=dtype) + residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + + compiled_model_no_func = torch.compile(model, backend=backend_no_func) + compiled_model_no_func(hidden_states, residual) + compiled_model_func = torch.compile(model, backend=backend_func) + compiled_model_func(hidden_states, residual) + + # Check substitution worked + pre_nodes = backend_no_func.graph_pre_pass.nodes + post_nodes = backend_no_func.graph_post_pass.nodes + + # In pre-nodes, all reduce should be there, + # reduce scatter and all gather should not + for op in OPS_IN_MODEL_BEFORE: + find_specified_fn(pre_nodes, op) + for op in OPS_IN_MODEL_AFTER: + assert find_specified_fn_maybe(pre_nodes, op) is None + + # In post-nodes, reduce scatter and all gather should be there, + # all reduce should not + for op in OPS_IN_MODEL_AFTER: + find_specified_fn(post_nodes, op) + for op in OPS_IN_MODEL_BEFORE: + assert find_specified_fn_maybe(post_nodes, op) is None + + # check if the functionalization pass is applied + for op in OPS_IN_MODEL: + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, + op) is None # noqa: E501 + + # make sure the ops were all de-functionalized + found = dict() + for node in backend_func.graph_post_pass.nodes: + for op in OPS_IN_MODEL: + if is_func(node, op): + found[op] = True + assert all(found[op] for op in OPS_IN_MODEL) diff --git a/tests/conftest.py b/tests/conftest.py index 69447d3c474d..e62b56cb5825 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,20 +21,20 @@ from tests.models.utils import (TokensTextLogprobs, TokensTextLogprobsPromptLogprobs) from vllm import LLM, SamplingParams +from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset -from vllm.config import TaskOption, TokenizerPoolConfig, _get_and_verify_dtype +from vllm.config import TaskOption, _get_and_verify_dtype from vllm.connections import global_http_connection from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, initialize_model_parallel) from 
vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, - TokensPrompt, to_enc_dec_tuple_list, - zip_enc_dec_prompts) + to_enc_dec_tuple_list, zip_enc_dec_prompts) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams -from vllm.utils import cuda_device_count_stateless, is_list_of +from vllm.utils import cuda_device_count_stateless logger = init_logger(__name__) @@ -104,10 +104,25 @@ def prompts(self, prompts: _VideoAssetPrompts) -> list[str]: return [prompts["sample_demo_1"]] +class _AudioAssetsBase(UserList[AudioAsset]): + pass + + +class _AudioAssets(_AudioAssetsBase): + + def __init__(self) -> None: + super().__init__([ + AudioAsset("mary_had_lamb"), + AudioAsset("winning_call"), + ]) + + IMAGE_ASSETS = _ImageAssets() """Singleton instance of :class:`_ImageAssets`.""" VIDEO_ASSETS = _VideoAssets() """Singleton instance of :class:`_VideoAssets`.""" +AUDIO_ASSETS = _AudioAssets() +"""Singleton instance of :class:`_AudioAssets`.""" @pytest.fixture(scope="function", autouse=True) @@ -264,6 +279,11 @@ def video_assets() -> _VideoAssets: return VIDEO_ASSETS +@pytest.fixture(scope="session") +def audio_assets() -> _AudioAssets: + return AUDIO_ASSETS + + _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict) _R = TypeVar("_R") @@ -391,10 +411,15 @@ def get_inputs( processor_kwargs["images"] = image if videos is not None and (video := videos[i]) is not None: processor_kwargs["videos"] = video - if audios is not None and (audio_tuple := audios[i]) is not None: - audio, sr = audio_tuple - processor_kwargs["audio"] = audio - processor_kwargs["sampling_rate"] = sr + if audios is not None and (audio_inputs := audios[i]) is not None: + # HACK - not all processors take sampling_rate; we should + # clean this up in the future. 
+ if len(audio_inputs) == 2: + audio, sr = audio_inputs + processor_kwargs["audio"] = audio + processor_kwargs["sampling_rate"] = sr + else: + processor_kwargs["audio"] = audio_inputs inputs = self.processor(**processor_kwargs) if isinstance(inputs, BatchFeature): @@ -469,12 +494,19 @@ def generate_beam_search( prompts: list[str], beam_width: int, max_tokens: int, + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, ) -> list[tuple[list[list[int]], list[str]]]: outputs = self.generate(prompts, do_sample=False, max_new_tokens=max_tokens, num_beams=beam_width, - num_return_sequences=beam_width) + num_return_sequences=beam_width, + images=images, + videos=videos, + audios=audios) + for i in range(len(outputs)): output_ids, output_str = outputs[i] for j in range(len(output_ids)): @@ -525,7 +557,10 @@ def _hidden_states_to_seq_logprobs( for _, hidden_state in enumerate(hidden_states): last_hidden_states = hidden_state[-1][0] logits = torch.matmul( - last_hidden_states.to(output_embeddings.weight.device), + last_hidden_states.to( + device=output_embeddings.weight.device, + dtype=output_embeddings.weight.dtype, + ), output_embeddings.weight.t(), ) if getattr(output_embeddings, "bias", None) is not None: @@ -919,6 +954,7 @@ def generate_encoder_decoder_greedy_logprobs( max_tokens: int, num_logprobs: int, num_prompt_logprobs: Optional[int] = None, + skip_special_tokens: bool = True, ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: greedy_logprobs_params = SamplingParams( @@ -926,6 +962,7 @@ def generate_encoder_decoder_greedy_logprobs( max_tokens=max_tokens, logprobs=num_logprobs, prompt_logprobs=(num_prompt_logprobs), + skip_special_tokens=skip_special_tokens, ) ''' Greedy logprobs generation for vLLM encoder/decoder models @@ -936,18 +973,20 @@ def generate_encoder_decoder_greedy_logprobs( def generate_beam_search( self, - prompts: Union[list[str], list[list[int]]], + prompts: list[str], beam_width: int, max_tokens: int, + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, ) -> list[tuple[list[list[int]], list[str]]]: - if is_list_of(prompts, str, check="all"): - prompts = [TextPrompt(prompt=prompt) for prompt in prompts] - else: - prompts = [ - TokensPrompt(prompt_token_ids=tokens) for tokens in prompts - ] + inputs = self.get_inputs(prompts, + images=images, + videos=videos, + audios=audios) + outputs = self.model.beam_search( - prompts, + inputs, BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens)) returned_outputs = [] for output in outputs: @@ -1000,20 +1039,6 @@ def vllm_runner(): return VllmRunner -def get_tokenizer_pool_config(tokenizer_group_type): - if tokenizer_group_type is None: - return None - if tokenizer_group_type == "ray": - return TokenizerPoolConfig(pool_size=1, - pool_type="ray", - extra_config={}) - if isinstance(tokenizer_group_type, type): - return TokenizerPoolConfig(pool_size=1, - pool_type=tokenizer_group_type, - extra_config={}) - raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") - - @pytest.fixture() def temporary_enable_log_propagate(): import logging diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index e9b537ed5150..9e8e315d87b1 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -195,15 +195,15 @@ def 
test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, ]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{ - "block_size": 8, + "block_size": 16, "max_num_batched_tokens": 2, "max_num_seqs": 2, }, { - "block_size": 8, + "block_size": 16, "max_num_batched_tokens": 3, "max_num_seqs": 2, }, { - "block_size": 8, + "block_size": 16, "max_num_batched_tokens": 256, "max_num_seqs": 10, }]) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index ac6d6aae3006..8f4c3537e158 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -14,7 +14,8 @@ from vllm.distributed import (broadcast_tensor_dict, get_pp_group, tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) from ..utils import init_test_distributed_environment, multi_process_parallel @@ -47,6 +48,34 @@ def all_reduce_test_worker( torch.testing.assert_close(t, expected) +@ray.remote(num_gpus=1, max_calls=1) +def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int, + pp_size: int, rank: int, + distributed_init_port: str): + # it is important to delete the CUDA_VISIBLE_DEVICES environment variable + # so that each worker can see all the GPUs + # they will be able to set the device to the correct GPU + monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + + num_elements = 8 + all_tensors = [ + torch.arange(num_elements, dtype=torch.float32, device="cuda") * + (r + 1) for r in range(tp_size) + ] + + index = rank % tp_size + partition_size = num_elements // tp_size + all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0) + expected = all_reduce[index * partition_size:(index + 1) * partition_size] + t = all_tensors[index] + t = tensor_model_parallel_reduce_scatter(t, 0) + torch.testing.assert_close(t, expected) + + @ray.remote(num_gpus=1, max_calls=1) def all_gather_test_worker( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 05e30f855ced..03de8d9b92bf 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -161,12 +161,12 @@ def iter_params(self, model_id: str): "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(), "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(), "tiiuae/falcon-7b": PPTestSettings.fast(), - "google/gemma-2b": PPTestSettings.fast(), + "google/gemma-1.1-2b-it": PPTestSettings.fast(), "google/gemma-2-9b": PPTestSettings.fast(), "gpt2": PPTestSettings.fast(), "bigcode/starcoder": PPTestSettings.fast(), "EleutherAI/gpt-j-6b": PPTestSettings.fast(), - "EleutherAI/pythia-12b": PPTestSettings.fast(), + "EleutherAI/pythia-1.4b": PPTestSettings.fast(), "ibm/PowerLM-3b": PPTestSettings.fast(), "ibm/PowerMoE-3b": PPTestSettings.fast(), # Uses Llama @@ -195,7 +195,7 @@ def iter_params(self, model_id: str): "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(), "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(multi_node_only=True, load_format="dummy"), # noqa: E501 "Qwen/Qwen-7B-Chat": PPTestSettings.fast(), - "Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(), + "Qwen/Qwen2.5-0.5B-Instruct": PPTestSettings.fast(), "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(), "stabilityai/stablelm-3b-4e1t": 
PPTestSettings.fast(), "bigcode/starcoder2-3b": PPTestSettings.fast(), diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py new file mode 100644 index 000000000000..19497ad9c140 --- /dev/null +++ b/tests/distributed/test_sequence_parallel.py @@ -0,0 +1,296 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +WARNING: This test runs in both single-node (4 GPUs) and multi-node + (2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is + important to set the distributed backend to "mp" to avoid Ray scheduling + all workers in a node other than the head node, which can cause the test + to fail. +""" +import json +import os +from dataclasses import dataclass +from typing import Literal, NamedTuple, Optional + +import pytest + +from vllm.config import TaskOption +from vllm.logger import init_logger + +from ..models.registry import HF_EXAMPLE_MODELS +from ..utils import compare_two_settings, create_new_process_for_each_test + +logger = init_logger("test_sequence_parallel") + +VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" + + +class ParallelSetup(NamedTuple): + tp_size: int + sp_enabled: bool + eager_mode: bool + chunked_prefill: bool + + +class SPTestOptions(NamedTuple): + multi_node_only: bool + load_format: Optional[str] = None + + +@dataclass +class SPTestSettings: + parallel_setups: list[ParallelSetup] + # NOTE: the length of distributed_backends and + # vllm_major_versions should be the same, and they + # are first zipped together to iterate over all + # test settings. + distributed_backends: list[str] + # vllm major version: "0" for V0, "1" for V1 + vllm_major_versions: list[str] + task: TaskOption + test_options: SPTestOptions + + def __post_init__(self): + if len(self.distributed_backends) != len(self.vllm_major_versions): + raise ValueError( + f"Length mismatch: distributed_backends " + f"({len(self.distributed_backends)}) != " + f"vllm_major_versions ({len(self.vllm_major_versions)})") + + @staticmethod + def detailed( + *, + tp_base: int = 2, + multi_node_only: bool = False, + task: TaskOption = "auto", + load_format: Optional[str] = None, + ): + return SPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=False), + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=True), + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=True, + chunked_prefill=False), + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=True, + chunked_prefill=True) + ], + distributed_backends=["mp", "ray"], + vllm_major_versions=["1", "1"], + task=task, + test_options=SPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + + @staticmethod + def fast( + *, + tp_base: int = 2, + task: TaskOption = "auto", + multi_node_only: bool = False, + load_format: Optional[str] = None, + ): + return SPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=False), + ], + distributed_backends=["mp", "ray"], + vllm_major_versions=["1", "1"], + task=task, + test_options=SPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + + def iter_params(self, model_id: str): + opts = self.test_options + + for parallel_setup in self.parallel_setups: + for backend, vllm_major_version in zip(self.distributed_backends, + self.vllm_major_versions): + yield (model_id, parallel_setup, backend, vllm_major_version, + self.task, opts) 
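    # --- Illustrative sketch (editorial note, not part of this patch) ----
    # Shows how the SPTestSettings declarations above fan out into pytest
    # parameters: iter_params() pairs every ParallelSetup with each zipped
    # (distributed_backend, vllm_major_version) combination, so
    # SPTestSettings.detailed() yields 4 setups x 2 backends = 8 cases.
    # The model id below is only an example, and importing SPTestSettings
    # from this test module is assumed to work in a checked-out vLLM tree.
    #
    #   settings = SPTestSettings.detailed(tp_base=2)
    #   for params in settings.iter_params("meta-llama/Llama-3.2-1B-Instruct"):
    #       (model_id, parallel_setup, backend,
    #        vllm_major_version, task, opts) = params
    #       print(model_id, parallel_setup, backend, vllm_major_version)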
+ + +def _compare_sp( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + vllm_major_version: str, + task: TaskOption, + test_options: SPTestOptions, + num_gpus_available: int, + *, + method: Literal["generate", "encode"], + is_multimodal: bool, +): + ( + tp_size, + sp_enabled, + eager_mode, + chunked_prefill, + ) = parallel_setup + + multi_node_only, load_format = test_options + + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) + model_info.check_transformers_version(on_fail="skip") + + trust_remote_code = model_info.trust_remote_code + tokenizer_mode = model_info.tokenizer_mode + hf_overrides = model_info.hf_overrides + + if load_format == "dummy": + # Avoid OOM + text_overrides = { + "num_hidden_layers": 4, + "hidden_size": 512, + "intermediate_size": 800, + "num_attention_heads": 4, + "num_key_value_heads": 1, + } + + if is_multimodal: + hf_overrides.update({"text_config": text_overrides}) + else: + hf_overrides.update(text_overrides) + else: + model_info.check_available_online(on_fail="skip") + + pp_size = 1 + if num_gpus_available < tp_size * pp_size: + pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") + if VLLM_MULTI_NODE and distributed_backend == "mp": + pytest.skip("Skipping multi-node pipeline parallel test for " + "multiprocessing distributed backend") + if multi_node_only and not VLLM_MULTI_NODE: + pytest.skip("Not in multi-node setting") + + common_args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--max-model-len", + "2048", + "--max-num-seqs", + "8", + ] + if chunked_prefill: + common_args.append("--enable-chunked-prefill") + if eager_mode: + common_args.append("--enforce-eager") + if task != "auto": + common_args.extend(["--task", task]) + if trust_remote_code: + common_args.append("--trust-remote-code") + if tokenizer_mode: + common_args.extend(["--tokenizer-mode", tokenizer_mode]) + if load_format: + common_args.extend(["--load-format", load_format]) + if hf_overrides: + common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + + compilation_config = { + 'level': 3, + 'custom_ops': ["+rms_norm"], + 'compile_sizes': [4, 8], + 'splitting_ops': [], + 'pass_config': { + 'enable_sequence_parallism': sp_enabled, + 'enable_noop': True, + 'enable_fusion': True, + }, + } + + tp_sp_env = tp_env = { + "VLLM_USE_V1": vllm_major_version, + } + + tp_sp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + distributed_backend, + "--compilation_config", + str(compilation_config), + ] + + tp_env = { + "VLLM_USE_V1": vllm_major_version, + } + tp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + "mp", + ] + + try: + compare_two_settings(model_id, + tp_sp_args, + tp_args, + tp_sp_env, + tp_env, + method=method) + except Exception: + testing_ray_compiled_graph = tp_sp_env is not None + if testing_ray_compiled_graph and vllm_major_version == "0": + # Ray Compiled Graph tests are flaky for V0, + # so we don't want to fail the test + logger.exception("Ray Compiled Graph tests failed") + else: + raise + + +SP_TEXT_GENERATION_MODELS = { + # [Decoder-only] + "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.detailed(), +} + +SP_TEST_MODELS = [ + # TODO support other models + # [LANGUAGE GENERATION] + "meta-llama/Llama-3.2-1B-Instruct", +] + + +@pytest.mark.parametrize( + ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", + "task", "test_options"), + [ + 
params for model_id, settings in SP_TEXT_GENERATION_MODELS.items() + for params in settings.iter_params(model_id) + if model_id in SP_TEST_MODELS + ], +) +@create_new_process_for_each_test() +def test_tp_sp_generation( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + vllm_major_version: str, + task: TaskOption, + test_options: SPTestOptions, + num_gpus_available, +): + _compare_sp(model_id, + parallel_setup, + distributed_backend, + vllm_major_version, + task, + test_options, + num_gpus_available, + method="generate", + is_multimodal=False) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 92387b46425e..052d5793c1b3 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -1,16 +1,120 @@ # SPDX-License-Identifier: Apache-2.0 +import json from argparse import ArgumentError, ArgumentTypeError +from contextlib import nullcontext +from dataclasses import dataclass, field +from typing import Literal, Optional import pytest -from vllm.config import PoolerConfig -from vllm.engine.arg_utils import EngineArgs, nullable_kvs +from vllm.config import PoolerConfig, config +from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, + get_type, is_not_builtin, is_type, + nullable_kvs, optional_type) from vllm.utils import FlexibleArgumentParser +@pytest.mark.parametrize(("type", "value", "expected"), [ + (int, "42", 42), + (int, "None", None), + (float, "3.14", 3.14), + (float, "None", None), + (str, "Hello World!", "Hello World!"), + (str, "None", None), + (json.loads, '{"foo":1,"bar":2}', { + "foo": 1, + "bar": 2 + }), + (json.loads, "foo=1,bar=2", { + "foo": 1, + "bar": 2 + }), + (json.loads, "None", None), +]) +def test_optional_type(type, value, expected): + optional_type_func = optional_type(type) + context = nullcontext() + if value == "foo=1,bar=2": + context = pytest.warns(DeprecationWarning) + with context: + assert optional_type_func(value) == expected + + +@pytest.mark.parametrize(("type_hint", "type", "expected"), [ + (int, int, True), + (int, float, False), + (list[int], list, True), + (list[int], tuple, False), + (Literal[0, 1], Literal, True), +]) +def test_is_type(type_hint, type, expected): + assert is_type(type_hint, type) == expected + + +@pytest.mark.parametrize(("type_hints", "type", "expected"), [ + ({float, int}, int, True), + ({int, tuple[int]}, int, True), + ({int, tuple[int]}, float, False), + ({str, Literal["x", "y"]}, Literal, True), +]) +def test_contains_type(type_hints, type, expected): + assert contains_type(type_hints, type) == expected + + +@pytest.mark.parametrize(("type_hints", "type", "expected"), [ + ({int, float}, int, int), + ({int, float}, str, None), + ({str, Literal["x", "y"]}, Literal, Literal["x", "y"]), +]) +def test_get_type(type_hints, type, expected): + assert get_type(type_hints, type) == expected + + +@config +@dataclass +class DummyConfigClass: + regular_bool: bool = True + """Regular bool with default True""" + optional_bool: Optional[bool] = None + """Optional bool with default None""" + optional_literal: Optional[Literal["x", "y"]] = None + """Optional literal with default None""" + tuple_n: tuple[int, ...] 
= field(default_factory=lambda: (1, 2, 3)) + """Tuple with default (1, 2, 3)""" + tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2)) + """Tuple with default (1, 2)""" + list_n: list[int] = field(default_factory=lambda: [1, 2, 3]) + """List with default [1, 2, 3]""" + + +@pytest.mark.parametrize(("type_hint", "expected"), [ + (int, False), + (DummyConfigClass, True), +]) +def test_is_not_builtin(type_hint, expected): + assert is_not_builtin(type_hint) == expected + + +def test_get_kwargs(): + kwargs = get_kwargs(DummyConfigClass) + print(kwargs) + + # bools should not have their type set + assert kwargs["regular_bool"].get("type") is None + assert kwargs["optional_bool"].get("type") is None + # optional literals should have None as a choice + assert kwargs["optional_literal"]["choices"] == ["x", "y", "None"] + # tuples should have the correct nargs + assert kwargs["tuple_n"]["nargs"] == "+" + assert kwargs["tuple_2"]["nargs"] == 2 + # lists should work + assert kwargs["list_n"]["type"] is int + assert kwargs["list_n"]["nargs"] == "+" + + @pytest.mark.parametrize(("arg", "expected"), [ - (None, None), + (None, dict()), ("image=16", { "image": 16 }), @@ -24,6 +128,10 @@ }), ]) def test_limit_mm_per_prompt_parser(arg, expected): + """This functionality is deprecated and will be removed in the future. + This argument should be passed as JSON string instead. + + TODO: Remove with nullable_kvs.""" parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) if arg is None: args = parser.parse_args([]) @@ -53,12 +161,20 @@ def test_compilation_config(): assert args.compilation_config.level == 3 # set to string form of a dict - args = parser.parse_args(["--compilation-config", "{'level': 3}"]) - assert args.compilation_config.level == 3 + args = parser.parse_args([ + "--compilation-config", + "{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}", + ]) + assert (args.compilation_config.level == 3 and + args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) # set to string form of a dict - args = parser.parse_args(["--compilation-config={'level': 3}"]) - assert args.compilation_config.level == 3 + args = parser.parse_args([ + "--compilation-config=" + "{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}", + ]) + assert (args.compilation_config.level == 3 and + args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) def test_prefix_cache_default(): diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index e96081c167ed..6a4862123b51 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -89,3 +89,31 @@ def test_chat_multi_image(image_urls: list[str]): }] outputs = llm.chat(messages) assert len(outputs) >= 0 + + +def test_llm_chat_tokenization_no_double_bos(): + """ + LLM.chat() should not add special tokens when using chat templates. + Check we get a single BOS token for llama chat. + """ + llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True) + messages = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "Hello!" 
+ }, + ] + outputs = llm.chat(messages) + assert len(outputs) == 1 + prompt_token_ids = getattr(outputs[0], "prompt_token_ids", None) + assert prompt_token_ids is not None + + bos_token = llm.get_tokenizer().bos_token_id + + # Ensure we have a single BOS + assert prompt_token_ids[0] == bos_token + assert prompt_token_ids[1] != bos_token, "Double BOS" diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index e43e9826e8f9..ad726fa8ce51 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -305,7 +305,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm): with pytest.raises( ValueError, match="xgrammar does not support advanced JSON schema features " - "like enums, patterns or numeric ranges."): + "like string length, item limits, or property bounds."): llm.generate(prompts="This should fail", sampling_params=sampling_params, use_tqdm=True) @@ -383,4 +383,118 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str): assert generated_text is not None print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") output_json = json.loads(generated_text) - jsonschema.validate(instance=output_json, schema=json_schema) \ No newline at end of file + jsonschema.validate(instance=output_json, schema=json_schema) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +def test_guided_number_range_json_completion(llm, + guided_decoding_backend: str): + sample_output_schema = { + "type": "object", + "properties": { + "age": { + "type": "integer", + "minimum": 18, + "maximum": 99 + }, + "score": { + "type": "number", + "minimum": 0.0, + "maximum": 100.0 + }, + "zipcode": { + "type": "string", + "pattern": r"^\d{5}(-\d{4})?$" + }, + }, + "required": ["age", "score", "zipcode"], + } + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams(json=sample_output_schema, + backend=guided_decoding_backend), + ) + outputs = llm.generate( + prompts=[ + "Create a JSON object for a user with age, score, and zipcode." + ] * 2, + sampling_params=sampling_params, + use_tqdm=True, + ) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, schema=sample_output_schema) + assert 18 <= output_json["age"] <= 99 + assert 0.0 <= output_json["score"] <= 100.0 + assert (re.fullmatch(r"^\d{5}(-\d{4})?$", output_json["zipcode"]) + is not None) + + +@pytest.mark.skip_global_cleanup +def test_guidance_no_additional_properties(llm): + schema = { + 'type': 'object', + 'properties': { + 'a1': { + 'type': 'string' + }, + 'a2': { + 'type': 'string' + }, + 'a3': { + 'type': 'string' + } + }, + 'required': ['a1', 'a2', 'a3'], + } + + prompt = ( + "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. 
You are a " + "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a " + "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20" + "<|im_end|>\n<|im_start|>assistant\n") + + def generate_with_backend(backend): + guided_params = GuidedDecodingParams(json=schema, backend=backend) + sampling_params = SamplingParams(temperature=0, + max_tokens=256, + guided_decoding=guided_params) + + outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) + assert outputs is not None + generated_text = outputs[0].outputs[0].text + assert generated_text is not None + parsed_json = json.loads(generated_text) + assert isinstance(parsed_json, dict) + jsonschema.validate(instance=parsed_json, schema=schema) + return parsed_json + + base_generated = generate_with_backend('guidance:disable-any-whitespace') + assert "a1" in base_generated + assert "a2" in base_generated + assert "a3" in base_generated + # by default additional keys are generated + assert "a4" in base_generated + assert "a5" in base_generated + assert "a6" in base_generated + + generated = generate_with_backend( + 'guidance:no-additional-properties,disable-any-whitespace') + assert "a1" in generated + assert "a2" in generated + assert "a3" in generated + assert "a4" not in generated + assert "a5" not in generated + assert "a6" not in generated diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py index eca5d184f5d6..642c204b9ff0 100644 --- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -150,6 +150,7 @@ def test_wer_correctness(model_name, expected_wer, n_examples=-1, max_concurrent_request=None): + # TODO refactor to use `ASRDataset` with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server: dataset = load_hf_dataset(dataset_repo) diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index b13002a5b682..72e616656775 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import json + import openai import pytest import pytest_asyncio @@ -27,7 +29,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f"audio={MAXIMUM_AUDIOS}", + json.dumps({"audio": MAXIMUM_AUDIOS}), ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -102,6 +104,35 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) +async def test_error_on_invalid_audio_url_type(client: openai.AsyncOpenAI, + model_name: str, + audio_url: str): + messages = [{ + "role": + "user", + "content": [ + { + "type": "audio_url", + "audio_url": audio_url + }, + { + "type": "text", + "text": "What's happening in this audio?" 
+ }, + ], + }] + + # audio_url should be a dict {"url": "some url"}, not directly a string + with pytest.raises(openai.BadRequestError): + _ = await client.chat.completions.create(model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0) + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index 2cdeb684f75d..50b20e78c4c4 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -11,11 +11,12 @@ from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.transformers_utils.tokenizer import get_tokenizer -from ...models.embedding.utils import check_embeddings_close +from ...models.embedding.utils import correctness_test from ...utils import RemoteOpenAIServer MODEL_NAME = "intfloat/multilingual-e5-small" DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 +DTYPE = "bfloat16" @pytest.fixture(scope="module") @@ -25,7 +26,7 @@ def server(): "embed", # use half precision for speed and memory savings in CI environment "--dtype", - "bfloat16", + DTYPE, "--enforce-eager", "--max-model-len", "512", @@ -43,9 +44,17 @@ async def client(server): yield async_client +@pytest.fixture(scope="module") +def hf_model(hf_runner): + with hf_runner(MODEL_NAME, dtype=DTYPE, + is_sentence_transformer=True) as hf_model: + yield hf_model + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str): +async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, + model_name: str): input_texts = [ "The chef prepared a delicious meal.", ] @@ -66,6 +75,9 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str): assert embeddings.usage.prompt_tokens == 11 assert embeddings.usage.total_tokens == 11 + vllm_outputs = [d.embedding for d in embeddings.data] + correctness_test(hf_model, input_texts, vllm_outputs) + # test using token IDs input_tokens = [1, 1, 1, 1, 1] embedding_response = await client.embeddings.create( @@ -86,7 +98,8 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str): +async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, + model_name: str): # test list[str] input_texts = [ "The cat sat on the mat.", "A feline was resting on a rug.", @@ -107,6 +120,9 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str): assert embeddings.usage.prompt_tokens == 33 assert embeddings.usage.total_tokens == 33 + vllm_outputs = [d.embedding for d in embeddings.data] + correctness_test(hf_model, input_texts, vllm_outputs) + # test list[list[int]] input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], [25, 32, 64, 77]] @@ -181,7 +197,7 @@ async def test_conversation_embedding(server: RemoteOpenAIServer, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_batch_base64_embedding(client: openai.AsyncOpenAI, +async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str): input_texts = [ "Hello my name is", @@ -192,6 +208,7 @@ async def 
test_batch_base64_embedding(client: openai.AsyncOpenAI, model=model_name, encoding_format="float") float_data = [d.embedding for d in responses_float.data] + correctness_test(hf_model, input_texts, float_data) responses_base64 = await client.embeddings.create(input=input_texts, model=model_name, @@ -202,24 +219,13 @@ async def test_batch_base64_embedding(client: openai.AsyncOpenAI, np.frombuffer(base64.b64decode(data.embedding), dtype="float32").tolist()) - check_embeddings_close( - embeddings_0_lst=float_data, - embeddings_1_lst=base64_data, - name_0="float", - name_1="base64", - ) + correctness_test(hf_model, input_texts, base64_data) # Default response is float32 decoded from base64 by OpenAI Client responses_default = await client.embeddings.create(input=input_texts, model=model_name) default_data = [d.embedding for d in responses_default.data] - - check_embeddings_close( - embeddings_0_lst=float_data, - embeddings_1_lst=default_data, - name_0="float", - name_1="default", - ) + correctness_test(hf_model, input_texts, default_data) @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_embedding_dimensions.py b/tests/entrypoints/openai/test_embedding_dimensions.py index 79d43a2231f8..9f5a8c6839bc 100644 --- a/tests/entrypoints/openai/test_embedding_dimensions.py +++ b/tests/entrypoints/openai/test_embedding_dimensions.py @@ -3,80 +3,121 @@ Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`. """ -from typing import NamedTuple +from typing import Optional import openai import pytest from vllm.entrypoints.openai.protocol import EmbeddingResponse +from ...conftest import HfRunner +from ...models.embedding.utils import EmbedModelInfo, correctness_test from ...utils import RemoteOpenAIServer - -class ModelInfo(NamedTuple): - name: str - is_matryoshka: bool - - MODELS = [ - ModelInfo(name="BAAI/bge-m3", is_matryoshka=False), - ModelInfo(name="jinaai/jina-embeddings-v3", is_matryoshka=True), + EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + matryoshka_dimensions=[256]), ] input_texts = [ "The chef prepared a delicious meal.", -] * 3 +] -@pytest.mark.asyncio -@pytest.mark.parametrize("model", MODELS) -async def test_validating_dimensions(model: ModelInfo): +@pytest.fixture(scope="module", params=MODELS) +def model_info(request): + return request.param + + +@pytest.fixture(scope="module", params=["bfloat16"]) +def dtype(request): + return request.param + + +@pytest.fixture(scope="module") +def server(model_info, dtype: str): args = [ "--task", "embed", # use half precision for speed and memory savings in CI environment "--dtype", - "bfloat16", + dtype, "--enforce-eager", "--max-model-len", - "512", - "--trust_remote_code" + "512" ] - with RemoteOpenAIServer(model.name, args) as remote_server: - client = remote_server.get_async_client() - - async def make_request(dimensions): - embedding_response = await client.embeddings.create( - model=model.name, - input=input_texts, - dimensions=dimensions, - encoding_format="float", - ) - embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) - - assert embeddings.id is not None - assert len(embeddings.data) == 3 - assert len(embeddings.data[0].embedding) > 0 - assert embeddings.usage.completion_tokens == 0 - assert embeddings.usage.prompt_tokens > 0 - assert embeddings.usage.total_tokens > 0 - - if dimensions is not None: - assert len(embeddings.data[0].embedding) == dimensions - - if 
model.is_matryoshka: - for dimensions in [None, 16]: - await make_request(dimensions) + if model_info.name == "Snowflake/snowflake-arctic-embed-m-v1.5": + # Manually enable Matryoshka Embeddings + args.extend([ + "--trust_remote_code", "--hf_overrides", + '{"matryoshka_dimensions":[256]}' + ]) + + with RemoteOpenAIServer(model_info.name, args) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def hf_model(hf_runner, model_info, dtype: str): + with hf_runner(model_info.name, dtype=dtype, + is_sentence_transformer=True) as hf_model: + yield hf_model + + +@pytest.mark.asyncio +async def test_matryoshka(model_info: EmbedModelInfo, + server: RemoteOpenAIServer, hf_model: HfRunner): + client = server.get_async_client() + + async def make_request_and_correctness_test(dimensions): + prompts = input_texts * 3 + + embedding_response = await client.embeddings.create( + model=model_info.name, + input=prompts, + dimensions=dimensions, + encoding_format="float", + ) + embeddings = EmbeddingResponse.model_validate( + embedding_response.model_dump(mode="json")) + + assert embeddings.id is not None + assert len(embeddings.data) == 3 + assert len(embeddings.data[0].embedding) > 0 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens > 0 + assert embeddings.usage.total_tokens > 0 + + if dimensions is not None: + assert len(embeddings.data[0].embedding) == dimensions + + vllm_outputs = [d.embedding for d in embeddings.data] + correctness_test(hf_model, prompts, vllm_outputs, dimensions) + + if model_info.is_matryoshka: + valid_dimensions: list[Optional[int]] = [None] + if model_info.matryoshka_dimensions is not None: + valid_dimensions += model_info.matryoshka_dimensions[:2] + + for dimensions in valid_dimensions: + await make_request_and_correctness_test(dimensions) + + invalid_dimensions: list[Optional[int]] = [-1] + if model_info.matryoshka_dimensions is not None: + assert 5 not in model_info.matryoshka_dimensions + invalid_dimensions.append(5) + + for dimensions in invalid_dimensions: with pytest.raises(openai.BadRequestError): - for dimensions in [-1]: - await make_request(dimensions) + await make_request_and_correctness_test(dimensions) - else: - for dimensions in [None]: - await make_request(dimensions) + else: + for dimensions in [None]: + await make_request_and_correctness_test(dimensions) + for dimensions in [-1, 16]: with pytest.raises(openai.BadRequestError): - for dimensions in [-1, 16]: - await make_request(dimensions) + await make_request_and_correctness_test(dimensions) diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py new file mode 100644 index 000000000000..c96151349eb3 --- /dev/null +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -0,0 +1,209 @@ +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import suppress +from dataclasses import dataclass, field +from http import HTTPStatus +from typing import Optional +from unittest.mock import MagicMock + +import pytest + +from vllm.config import MultiModalConfig +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_models import (BaseModelPath, + OpenAIServingModels) +from vllm.lora.request import LoRARequest +from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry +from 
vllm.transformers_utils.tokenizer import get_tokenizer + +MODEL_NAME = "openai-community/gpt2" +BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] + +MOCK_RESOLVER_NAME = "mock_test_resolver" + + +@dataclass +class MockHFConfig: + model_type: str = "any" + + +@dataclass +class MockModelConfig: + """Minimal mock ModelConfig for testing.""" + model: str = MODEL_NAME + tokenizer: str = MODEL_NAME + trust_remote_code: bool = False + tokenizer_mode: str = "auto" + max_model_len: int = 100 + tokenizer_revision: Optional[str] = None + multimodal_config: MultiModalConfig = field( + default_factory=MultiModalConfig) + hf_config: MockHFConfig = field(default_factory=MockHFConfig) + logits_processor_pattern: Optional[str] = None + diff_sampling_param: Optional[dict] = None + allowed_local_media_path: str = "" + encoder_config = None + generation_config: str = "auto" + + def get_diff_sampling_param(self): + return self.diff_sampling_param or {} + + +class MockLoRAResolver(LoRAResolver): + + async def resolve_lora(self, base_model_name: str, + lora_name: str) -> Optional[LoRARequest]: + if lora_name == "test-lora": + return LoRARequest(lora_name="test-lora", + lora_int_id=1, + lora_local_path="/fake/path/test-lora") + elif lora_name == "invalid-lora": + return LoRARequest(lora_name="invalid-lora", + lora_int_id=2, + lora_local_path="/fake/path/invalid-lora") + return None + + +@pytest.fixture(autouse=True) +def register_mock_resolver(): + """Fixture to register and unregister the mock LoRA resolver.""" + resolver = MockLoRAResolver() + LoRAResolverRegistry.register_resolver(MOCK_RESOLVER_NAME, resolver) + yield + # Cleanup: remove the resolver after the test runs + if MOCK_RESOLVER_NAME in LoRAResolverRegistry.resolvers: + del LoRAResolverRegistry.resolvers[MOCK_RESOLVER_NAME] + + +@pytest.fixture +def mock_serving_setup(): + """Provides a mocked engine and serving completion instance.""" + mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + + def mock_add_lora_side_effect(lora_request: LoRARequest): + """Simulate engine behavior when adding LoRAs.""" + if lora_request.lora_name == "test-lora": + # Simulate successful addition + return + elif lora_request.lora_name == "invalid-lora": + # Simulate failure during addition (e.g. 
invalid format) + raise ValueError(f"Simulated failure adding LoRA: " + f"{lora_request.lora_name}") + + mock_engine.add_lora.side_effect = mock_add_lora_side_effect + mock_engine.generate.reset_mock() + mock_engine.add_lora.reset_mock() + + mock_model_config = MockModelConfig() + models = OpenAIServingModels(engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config) + + serving_completion = OpenAIServingCompletion(mock_engine, + mock_model_config, + models, + request_logger=None) + + return mock_engine, serving_completion + + +@pytest.mark.asyncio +async def test_serving_completion_with_lora_resolver(mock_serving_setup, + monkeypatch): + monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true") + + mock_engine, serving_completion = mock_serving_setup + + lora_model_name = "test-lora" + req_found = CompletionRequest( + model=lora_model_name, + prompt="Generate with LoRA", + ) + + # Suppress potential errors during the mocked generate call, + # as we are primarily checking for add_lora and generate calls + with suppress(Exception): + await serving_completion.create_completion(req_found) + + mock_engine.add_lora.assert_called_once() + called_lora_request = mock_engine.add_lora.call_args[0][0] + assert isinstance(called_lora_request, LoRARequest) + assert called_lora_request.lora_name == lora_model_name + + mock_engine.generate.assert_called_once() + called_lora_request = mock_engine.generate.call_args[1]['lora_request'] + assert isinstance(called_lora_request, LoRARequest) + assert called_lora_request.lora_name == lora_model_name + + +@pytest.mark.asyncio +async def test_serving_completion_resolver_not_found(mock_serving_setup, + monkeypatch): + monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true") + + mock_engine, serving_completion = mock_serving_setup + + non_existent_model = "non-existent-lora-adapter" + req = CompletionRequest( + model=non_existent_model, + prompt="what is 1+1?", + ) + + response = await serving_completion.create_completion(req) + + mock_engine.add_lora.assert_not_called() + mock_engine.generate.assert_not_called() + + assert isinstance(response, ErrorResponse) + assert response.code == HTTPStatus.NOT_FOUND.value + assert non_existent_model in response.message + + +@pytest.mark.asyncio +async def test_serving_completion_resolver_add_lora_fails( + mock_serving_setup, monkeypatch): + monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true") + + mock_engine, serving_completion = mock_serving_setup + + invalid_model = "invalid-lora" + req = CompletionRequest( + model=invalid_model, + prompt="what is 1+1?", + ) + + response = await serving_completion.create_completion(req) + + # Assert add_lora was called before the failure + mock_engine.add_lora.assert_called_once() + called_lora_request = mock_engine.add_lora.call_args[0][0] + assert isinstance(called_lora_request, LoRARequest) + assert called_lora_request.lora_name == invalid_model + + # Assert generate was *not* called due to the failure + mock_engine.generate.assert_not_called() + + # Assert the correct error response + assert isinstance(response, ErrorResponse) + assert response.code == HTTPStatus.BAD_REQUEST.value + assert invalid_model in response.message + + +@pytest.mark.asyncio +async def test_serving_completion_flag_not_set(mock_serving_setup): + mock_engine, serving_completion = mock_serving_setup + + lora_model_name = "test-lora" + req_found = CompletionRequest( + model=lora_model_name, + prompt="Generate with LoRA", + ) + + await 
serving_completion.create_completion(req_found) + + mock_engine.add_lora.assert_not_called() + mock_engine.generate.assert_not_called() diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py new file mode 100644 index 000000000000..1ccb803a328d --- /dev/null +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import schemathesis +from schemathesis import GenerationConfig + +from ...utils import RemoteOpenAIServer + +schemathesis.experimental.OPEN_API_3_1.enable() + +MODEL_NAME = "HuggingFaceTB/SmolVLM-256M-Instruct" +MAXIMUM_IMAGES = 2 + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--task", + "generate", + "--max-model-len", + "2048", + "--max-num-seqs", + "5", + "--enforce-eager", + "--trust-remote-code", + "--limit-mm-per-prompt", + f"image={MAXIMUM_IMAGES}", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def get_schema(server): + # avoid generating null (\x00) bytes in strings during test case generation + return schemathesis.openapi.from_uri( + f"{server.url_root}/openapi.json", + generation_config=GenerationConfig(allow_x00=False), + ) + + +schema = schemathesis.from_pytest_fixture("get_schema") + + +@schema.parametrize() +@schema.override(headers={"Content-Type": "application/json"}) +async def test_openapi_stateless(case): + #No need to verify SSL certificate for localhost + await case.call_and_validate(verify=False) diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index 29571bcd7649..5c48df3cebbc 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -192,3 +192,36 @@ async def post_with_stream(*args, **kwargs): else: continuous = continuous and hasattr(chunk, 'usage') assert final and continuous + + +@pytest.mark.asyncio +async def test_sampling_params(mary_had_lamb): + """ + Compare sampling with params and greedy sampling to assert results + are different when extreme sampling parameters values are picked. 
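+ Both requests pin seed=42, so only the sampling parameter values differ
+ between the two transcription calls.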
+ """ + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + transcription = await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + language="en", + temperature=0.8, + extra_body=dict(seed=42, + repetition_penalty=1.9, + top_k=12, + top_p=0.4, + min_p=0.5, + frequency_penalty=1.8, + presence_penalty=2.0)) + + greedy_transcription = await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + language="en", + temperature=0.0, + extra_body=dict(seed=42)) + + assert greedy_transcription.text != transcription.text diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/test_video.py index f9ccce9c1c33..53f057a294c0 100644 --- a/tests/entrypoints/openai/test_video.py +++ b/tests/entrypoints/openai/test_video.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import json + import openai import pytest import pytest_asyncio @@ -31,7 +33,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f"video={MAXIMUM_VIDEOS}", + json.dumps({"video": MAXIMUM_VIDEOS}), ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -106,6 +108,35 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) +async def test_error_on_invalid_video_url_type(client: openai.AsyncOpenAI, + model_name: str, + video_url: str): + messages = [{ + "role": + "user", + "content": [ + { + "type": "video_url", + "video_url": video_url + }, + { + "type": "text", + "text": "What's in this video?" + }, + ], + }] + + # video_url should be a dict {"url": "some url"}, not directly a string + with pytest.raises(openai.BadRequestError): + _ = await client.chat.completions.create(model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0) + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 4b9029ded41b..1ab50b41c7ec 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import json + import openai import pytest import pytest_asyncio @@ -35,7 +37,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f"image={MAXIMUM_IMAGES}", + json.dumps({"image": MAXIMUM_IMAGES}), ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -135,6 +137,36 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_error_on_invalid_image_url_type(client: openai.AsyncOpenAI, + model_name: str, + image_url: str): + content_text = "What's in this image?" 
+ messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": image_url + }, + { + "type": "text", + "text": content_text + }, + ], + }] + + # image_url should be a dict {"url": "some url"}, not directly a string + with pytest.raises(openai.BadRequestError): + _ = await client.chat.completions.create(model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0) + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index 3e6f13e10ac2..26c68e06c199 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import json + import pytest import requests from PIL import Image @@ -37,7 +39,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f"image={MAXIMUM_IMAGES}", + json.dumps({"image": MAXIMUM_IMAGES}), "--chat-template", str(vlm2vec_jinja_path), ] diff --git a/tests/kernels/conftest.py b/tests/kernels/attention/conftest.py similarity index 100% rename from tests/kernels/conftest.py rename to tests/kernels/attention/conftest.py diff --git a/tests/kernels/test_attention.py b/tests/kernels/attention/test_attention.py similarity index 99% rename from tests/kernels/test_attention.py rename to tests/kernels/attention/test_attention.py index 0d7898a900e4..e5650136f258 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -6,13 +6,12 @@ import pytest import torch +from tests.kernels.allclose_default import get_default_atol, get_default_rtol from tests.kernels.utils import opcheck from vllm import _custom_ops as ops from vllm.platforms import current_platform from vllm.utils import get_max_shared_memory_bytes -from .allclose_default import get_default_atol, get_default_rtol - if not current_platform.is_rocm(): from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py new file mode 100644 index 000000000000..b0414244c215 --- /dev/null +++ b/tests/kernels/attention/test_attention_selector.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import patch + +import pytest +import torch + +from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend +from vllm.platforms.cpu import CpuPlatform +from vllm.platforms.cuda import CudaPlatform +from vllm.platforms.rocm import RocmPlatform +from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL + + +@pytest.fixture(autouse=True) +def clear_cache(): + """Clear lru cache to ensure each test case runs without caching. 
+ """ + _cached_get_attn_backend.cache_clear() + + +# Define MLA and non-MLA backends separately +DEVICE_MLA_BACKENDS = { + "cuda": ["TRITON_MLA", "FLASHMLA"], + "hip": ["TRITON_MLA", "ROCM_AITER_MLA"], + "cpu": [], +} + +DEVICE_REGULAR_ATTN_BACKENDS = { + "cuda": ["XFORMERS", "FLASHINFER"], + "hip": ["ROCM_FLASH"], + "cpu": ["TORCH_SDPA"], +} + +DEVICE_MLA_BLOCK_SIZES = { + "cuda": [16, 64], # CUDA supports both standard and extended block sizes + "hip": [16, 1], # HIP requires special handling for block_size=1 + "cpu": [16] # CPU uses fixed block size from test cases +} + + +def generate_params(): + params = [] + for use_mla in [True, False]: + for device in ["cuda", "hip", "cpu"]: + backends = DEVICE_MLA_BACKENDS[ + device] if use_mla else DEVICE_REGULAR_ATTN_BACKENDS[device] + for name in backends: + block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [ + 16 + ] + for block_size in block_sizes: + params.append( + pytest.param( + device, + name, + use_mla, + block_size, + id= + f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}" + )) + return params + + +@pytest.mark.parametrize("device, name, use_mla, block_size", + generate_params()) +@pytest.mark.parametrize("use_v1", [True, False]) +def test_env( + device: str, + name: str, + use_mla: bool, + block_size: int, + use_v1: bool, + monkeypatch: pytest.MonkeyPatch, +): + """Test attention backend selection with valid device-backend pairs.""" + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") + m.setenv(STR_BACKEND_ENV_VAR, name) + m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") + + if device == "cpu": + with patch("vllm.attention.selector.current_platform", + CpuPlatform()): + backend = get_attn_backend(16, torch.float16, torch.float16, + block_size, False) + assert backend.get_name() == "TORCH_SDPA" + + elif device == "hip": + with patch("vllm.attention.selector.current_platform", + RocmPlatform()): + if use_mla: + # Validate HIP MLA backend-block_size combinations + valid_combination = ( + (name == "TRITON_MLA" and block_size != 1) + or (name == "ROCM_AITER_MLA" and block_size == 1)) + + if valid_combination: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + assert backend.get_name() == name + else: + with pytest.raises(ValueError) as exc_info: + get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + assert f"The selected backend, {name}" in str( + exc_info.value) + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH" + assert backend.get_name() == expected + + elif device == "cuda": + with patch("vllm.attention.selector.current_platform", + CudaPlatform()): + if use_mla: + if name == "FLASHMLA" and block_size == 64: + from vllm.attention.backends.flashmla import ( + is_flashmla_supported) + + # only on cuda platforms with specific capability. + is_supported, _ = is_flashmla_supported() + + if not is_supported: + # if platform is not supported then skip this case. 
+ pytest.skip() + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = f"{name}_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = ("TRITON_MLA_VLLM_V1" + if use_v1 else "TRITON_MLA") + assert backend.get_name() == expected + elif name == "FLASHINFER": + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = "FLASHINFER_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected + + +def test_flash_attn(monkeypatch: pytest.MonkeyPatch): + """Test FlashAttn validation.""" + # TODO: When testing for v1, pipe in `use_v1` as an argument to + # get_attn_backend + + with monkeypatch.context() as m: + m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) + + # Unsupported CUDA arch + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda: + (7, 5)) + backend = get_attn_backend(16, torch.float16, None, 16, False) + assert backend.get_name() != STR_FLASH_ATTN_VAL + + # Reset the monkeypatch for subsequent tests + monkeypatch.undo() + + # Unsupported data type + backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False) + assert backend.get_name() != STR_FLASH_ATTN_VAL + + # Unsupported kv cache data type + backend = get_attn_backend(16, torch.float16, "fp8", 16, False) + assert backend.get_name() != STR_FLASH_ATTN_VAL + + # Unsupported block size + backend = get_attn_backend(16, torch.float16, None, 8, False) + assert backend.get_name() != STR_FLASH_ATTN_VAL + + # flash-attn is not installed + import sys + original_module = sys.modules.get('vllm_flash_attn') + monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None) + backend = get_attn_backend(16, torch.float16, None, 16, False) + assert backend.get_name() != STR_FLASH_ATTN_VAL + + # Restore the original module if it existed + if original_module is not None: + monkeypatch.setitem(sys.modules, 'vllm_flash_attn', + original_module) + else: + monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False) + + # Unsupported head size + backend = get_attn_backend(17, torch.float16, None, 16, False) + assert backend.get_name() != STR_FLASH_ATTN_VAL + + # Attention-free models should bypass env and use PlaceholderAttention + backend = get_attn_backend(16, torch.float16, torch.float16, 16, True) + assert backend.get_name() != STR_FLASH_ATTN_VAL + + +@pytest.mark.parametrize("use_v1", [True, False]) +def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch): + + with monkeypatch.context() as m, patch( + "vllm.attention.selector.current_platform", CudaPlatform()): + m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") + m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) + + # Test with head size 32 + backend = get_attn_backend(32, torch.float16, None, 16, False) + EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN" + assert backend.get_name() == EXPECTED + + # when block size == 16, backend will fall back to XFORMERS + # this behavior is not yet supported on V1. + if use_v1: + # TODO: support fallback on V1! 
+ # https://github.com/vllm-project/vllm/issues/14524 + pass + else: + backend = get_attn_backend(16, torch.float16, None, 16, False) + assert backend.get_name() == "XFORMERS" diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/attention/test_blocksparse_attention.py similarity index 99% rename from tests/kernels/test_blocksparse_attention.py rename to tests/kernels/attention/test_blocksparse_attention.py index 3025ae0f921a..82d038257575 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/attention/test_blocksparse_attention.py @@ -6,14 +6,13 @@ import pytest import torch +from tests.kernels.allclose_default import get_default_atol, get_default_rtol from vllm import _custom_ops as ops from vllm.attention.ops.blocksparse_attention.interface import ( LocalStridedBlockSparseAttn) from vllm.platforms import current_platform from vllm.utils import get_max_shared_memory_bytes -from .allclose_default import get_default_atol, get_default_rtol - FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer diff --git a/tests/kernels/test_cache.py b/tests/kernels/attention/test_cache.py similarity index 93% rename from tests/kernels/test_cache.py rename to tests/kernels/attention/test_cache.py index 899122818e0e..2f2212dd2b0e 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -16,6 +16,7 @@ NUM_HEADS = [8] # Arbitrary values for testing HEAD_SIZES = [64, 80, 120, 256] BLOCK_SIZES = [8, 16, 32] +CACHE_LAYOUTS = ["NHD", "HND"] # Parameters for MLA tests. KV_LORA_RANKS = [512] @@ -220,6 +221,7 @@ def test_reshape_and_cache( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS) @torch.inference_mode() def test_reshape_and_cache_flash( kv_cache_factory_flashinfer, @@ -232,17 +234,21 @@ def test_reshape_and_cache_flash( seed: int, device: str, kv_cache_dtype: str, + kv_cache_layout: str, ) -> None: current_platform.seed_everything(seed) torch.set_default_device(device) + # fp8 conversion requires continugous memory buffer. Reduce the number of + # blocks and tokens to consume less memory. + num_tokens = num_tokens // 2 + num_blocks = num_blocks // 2 # Create a random slot mapping. num_slots = block_size * num_blocks slot_mapping_lst = random.sample(range(num_slots), num_tokens) slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) - qkv = torch.randn(num_tokens, 3, num_heads, @@ -261,27 +267,35 @@ def test_reshape_and_cache_flash( kv_cache_dtype, dtype, device=device, + cache_layout=kv_cache_layout, ) - key_cache, value_cache = key_caches[0].contiguous( - ), value_caches[0].contiguous() + key_cache, value_cache = key_caches[0], value_caches[0] del key_caches del value_caches k_scale = (key.amax() / 64.0).to(torch.float32) v_scale = (value.amax() / 64.0).to(torch.float32) + def permute_and_compact(x): + y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3) + return y.contiguous() + + key_cache_compact = permute_and_compact(key_cache) + value_cache_compact = permute_and_compact(value_cache) + # Clone the KV caches. 
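+ # For fp8 caches the copies are dequantized into float16 first so the
+ # reference values can be compared elementwise with the kernel output below.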
if kv_cache_dtype == "fp8": - cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(), - kv_cache_dtype) - cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(), + cloned_key_cache = torch.empty_like(key_cache_compact, + dtype=torch.float16) + ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(), kv_cache_dtype) + cloned_value_cache = torch.empty_like(value_cache_compact, + dtype=torch.float16) + ops.convert_fp8(cloned_value_cache, value_cache_compact, + v_scale.item(), kv_cache_dtype) else: - cloned_key_cache = key_cache.clone() - cloned_value_cache = value_cache.clone() - + cloned_key_cache = key_cache_compact.clone() + cloned_value_cache = value_cache_compact.clone() # Call the reshape_and_cache kernel. opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash, (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, @@ -289,16 +303,20 @@ def test_reshape_and_cache_flash( cond=(head_size == HEAD_SIZES[0])) ops.reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, k_scale, v_scale) + key_cache_compact = permute_and_compact(key_cache) + value_cache_compact = permute_and_compact(value_cache) if kv_cache_dtype == "fp8": - result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) + result_key_cache = torch.empty_like(key_cache_compact, + dtype=torch.float16) ops.convert_fp8(result_key_cache, - key_cache, + key_cache_compact, k_scale.item(), kv_dtype=kv_cache_dtype) - result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) + result_value_cache = torch.empty_like(value_cache_compact, + dtype=torch.float16) ops.convert_fp8(result_value_cache, - value_cache, + value_cache_compact, v_scale.item(), kv_dtype=kv_cache_dtype) @@ -310,8 +328,12 @@ def test_reshape_and_cache_flash( for i in range(num_tokens): block_idx = block_indicies_lst[i] block_offset = block_offsets_lst[i] - cloned_key_cache[block_idx, block_offset, :, :] = key[i] - cloned_value_cache[block_idx, block_offset, :, :] = value[i] + if kv_cache_layout == "NHD": + cloned_key_cache[block_idx, block_offset, :, :] = key[i] + cloned_value_cache[block_idx, block_offset, :, :] = value[i] + else: + cloned_key_cache[block_idx, :, block_offset, :] = key[i] + cloned_value_cache[block_idx, :, block_offset, :] = value[i] if kv_cache_dtype == "fp8": torch.testing.assert_close(result_key_cache, @@ -323,8 +345,8 @@ def test_reshape_and_cache_flash( atol=0.001, rtol=0.1) else: - torch.testing.assert_close(key_cache, cloned_key_cache) - torch.testing.assert_close(value_cache, cloned_value_cache) + torch.testing.assert_close(key_cache_compact, cloned_key_cache) + torch.testing.assert_close(value_cache_compact, cloned_value_cache) @pytest.mark.parametrize("direction", COPYING_DIRECTION) diff --git a/tests/kernels/test_cascade_flash_attn.py b/tests/kernels/attention/test_cascade_flash_attn.py similarity index 100% rename from tests/kernels/test_cascade_flash_attn.py rename to tests/kernels/attention/test_cascade_flash_attn.py diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/attention/test_encoder_decoder_attn.py similarity index 100% rename from tests/kernels/test_encoder_decoder_attn.py rename to tests/kernels/attention/test_encoder_decoder_attn.py diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py similarity index 99% rename from 
tests/kernels/test_flash_attn.py rename to tests/kernels/attention/test_flash_attn.py index 572563c0bd82..88516b75cde2 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/attention/test_flash_attn.py @@ -145,7 +145,7 @@ def test_flash_attn_with_paged_kv( v_descale = None if q_dtype is not None: # QKV are drawn from N(0, 1): no need for a fp8 scaling factor - maybe_quantized_query = query.to(q_dtype) + maybe_quantized_query = q.to(q_dtype) maybe_quantized_key_cache = key_cache.to(q_dtype) maybe_quantized_value_cache = value_cache.to(q_dtype) diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/attention/test_flashinfer.py similarity index 100% rename from tests/kernels/test_flashinfer.py rename to tests/kernels/attention/test_flashinfer.py diff --git a/tests/kernels/test_flashmla.py b/tests/kernels/attention/test_flashmla.py similarity index 100% rename from tests/kernels/test_flashmla.py rename to tests/kernels/attention/test_flashmla.py diff --git a/tests/kernels/test_lightning_attn.py b/tests/kernels/attention/test_lightning_attn.py similarity index 100% rename from tests/kernels/test_lightning_attn.py rename to tests/kernels/attention/test_lightning_attn.py diff --git a/tests/kernels/test_merge_attn_states.py b/tests/kernels/attention/test_merge_attn_states.py similarity index 100% rename from tests/kernels/test_merge_attn_states.py rename to tests/kernels/attention/test_merge_attn_states.py diff --git a/tests/kernels/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py similarity index 100% rename from tests/kernels/test_mha_attn.py rename to tests/kernels/attention/test_mha_attn.py diff --git a/tests/kernels/test_mla_decode_cpu.py b/tests/kernels/attention/test_mla_decode_cpu.py similarity index 100% rename from tests/kernels/test_mla_decode_cpu.py rename to tests/kernels/attention/test_mla_decode_cpu.py diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py similarity index 100% rename from tests/kernels/test_prefix_prefill.py rename to tests/kernels/attention/test_prefix_prefill.py diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py new file mode 100644 index 000000000000..4cf7bcb01d4d --- /dev/null +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend +from vllm.platforms.rocm import RocmPlatform +from vllm.utils import STR_BACKEND_ENV_VAR + + +@pytest.fixture(autouse=True) +def clear_cache(): + """Clear lru cache to ensure each test case runs without caching. 
+ """ + _cached_get_attn_backend.cache_clear() + + +def test_selector(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH") + + # Set the current platform to ROCm using monkeypatch + monkeypatch.setattr("vllm.attention.selector.current_platform", + RocmPlatform()) + + # Test standard ROCm attention + backend = get_attn_backend(16, torch.float16, torch.float16, 16, False) + assert (backend.get_name() == "ROCM_FLASH" + or backend.get_name() == "TRITON_ATTN_VLLM_V1") + + # MLA test for deepseek related + + # change the attention backend to triton MLA + m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA") + backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, + False, True) + assert backend.get_name() == "TRITON_MLA" + + # If attention backend is None + # If use_mla is true + # The selected backend is triton MLA + m.setenv(STR_BACKEND_ENV_VAR, None) + backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, + False, True) + assert backend.get_name() == "TRITON_MLA" + + # change the attention backend to AITER MLA + m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") + backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, + False, True) + assert backend.get_name() == "ROCM_AITER_MLA" + + # If attention backend is None + # If use_mla is true + # If VLLM_ROCM_USE_AITER is enabled + # The selected backend is ROCM_AITER_MLA + m.setenv(STR_BACKEND_ENV_VAR, None) + m.setenv("VLLM_ROCM_USE_AITER", "1") + backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, + False, True) + assert backend.get_name() == "ROCM_AITER_MLA" diff --git a/tests/kernels/test_triton_decode_attention.py b/tests/kernels/attention/test_triton_decode_attention.py similarity index 100% rename from tests/kernels/test_triton_decode_attention.py rename to tests/kernels/attention/test_triton_decode_attention.py diff --git a/tests/kernels/test_activation.py b/tests/kernels/core/test_activation.py similarity index 97% rename from tests/kernels/test_activation.py rename to tests/kernels/core/test_activation.py index cf0f21ce0651..79f838a954e7 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/core/test_activation.py @@ -5,6 +5,7 @@ import pytest import torch +from tests.kernels.allclose_default import get_default_atol, get_default_rtol from tests.kernels.utils import opcheck from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul, GeluAndMul, MulAndSilu, @@ -12,8 +13,6 @@ SiluAndMul) from vllm.platforms import current_platform -from .allclose_default import get_default_atol, get_default_rtol - DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 13824] # Arbitrary values for testing diff --git a/tests/kernels/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py similarity index 100% rename from tests/kernels/test_fused_quant_layernorm.py rename to tests/kernels/core/test_fused_quant_layernorm.py diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/core/test_layernorm.py similarity index 100% rename from tests/kernels/test_layernorm.py rename to tests/kernels/core/test_layernorm.py diff --git a/tests/kernels/core/test_opcheck.py b/tests/kernels/core/test_opcheck.py new file mode 100644 index 000000000000..c9a9679c5d80 --- /dev/null +++ b/tests/kernels/core/test_opcheck.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for miscellaneous utilities +""" + +import torch + +from 
tests.kernels.utils import opcheck + + +def test_convert_fp8_opcheck(): + data = torch.randn((256, 256), dtype=torch.float32, device="cuda") + result = torch.empty_like(data, dtype=torch.float8_e4m3fn) + opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8")) + + +# TODO: Add this back, currently fails with +# csrc/cuda_utils_kernels.cu:15 'invalid argument' +# @pytest.mark.skipif(not current_platform.is_cuda(), +# reason="Only supported for CUDA") +# def test_cuda_utils_opcheck(): +# opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0)) +# opcheck( +# torch.ops._C_cuda_utils. +# get_max_shared_memory_per_block_device_attribute, (0, )) diff --git a/tests/kernels/test_permute_cols.py b/tests/kernels/core/test_permute_cols.py similarity index 100% rename from tests/kernels/test_permute_cols.py rename to tests/kernels/core/test_permute_cols.py diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py similarity index 99% rename from tests/kernels/test_pos_encoding.py rename to tests/kernels/core/test_pos_encoding.py index eb83b4d612c2..2b7bf755ec22 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -6,11 +6,10 @@ import pytest import torch +from tests.kernels.allclose_default import get_default_atol, get_default_rtol from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform -from .allclose_default import get_default_atol, get_default_rtol - IS_NEOX_STYLE = [True, False] DTYPES = [torch.half, torch.bfloat16, torch.float] HEAD_SIZES = [64, 80, 112, 120, 256] diff --git a/tests/kernels/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py similarity index 100% rename from tests/kernels/test_rotary_embedding.py rename to tests/kernels/core/test_rotary_embedding.py diff --git a/tests/kernels/test_uva.py b/tests/kernels/core/test_uva.py similarity index 100% rename from tests/kernels/test_uva.py rename to tests/kernels/core/test_uva.py diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/mamba/test_causal_conv1d.py similarity index 100% rename from tests/kernels/test_causal_conv1d.py rename to tests/kernels/mamba/test_causal_conv1d.py diff --git a/tests/kernels/test_mamba_mixer2.py b/tests/kernels/mamba/test_mamba_mixer2.py similarity index 100% rename from tests/kernels/test_mamba_mixer2.py rename to tests/kernels/mamba/test_mamba_mixer2.py diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py similarity index 100% rename from tests/kernels/test_mamba_ssm.py rename to tests/kernels/mamba/test_mamba_ssm.py diff --git a/tests/kernels/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py similarity index 95% rename from tests/kernels/test_mamba_ssm_ssd.py rename to tests/kernels/mamba/test_mamba_ssm_ssd.py index 8f23a9b216e9..ee908105f557 100644 --- a/tests/kernels/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -5,6 +5,8 @@ import torch.nn.functional as F from einops import rearrange, repeat +from vllm.model_executor.layers.mamba.mamba2_metadata import ( + _seq_idx_to_chunk_indices_offsets) from vllm.model_executor.layers.mamba.ops.ssd_combined import ( mamba_chunk_scan_combined) from vllm.platforms import current_platform @@ -160,14 +162,14 @@ def end_boundary(n: int): # get the metadata cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0) - sed_idx = torch.zeros(cu_seqlens[-1], + seq_idx = torch.zeros(cu_seqlens[-1], dtype=torch.int32, 
device=cu_seqlens.device) for i, (srt, end) in enumerate(zip( cu_seqlens, cu_seqlens[1:], )): - sed_idx[srt:end] = i + seq_idx[srt:end] = i # for cont batch if IND_E is None: @@ -177,7 +179,7 @@ def end_boundary(n: int): IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)], - cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) + cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) @pytest.mark.parametrize("itype", @@ -266,12 +268,15 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, exhausted: dict = {} # map: eg -> boolean indicating example is exhausted states = None - for Y_min, cu_seqlens, sed_idx, (A, dt, X, B, + for Y_min, cu_seqlens, seq_idx, (A, dt, X, B, C) in generate_continous_batched_examples( cases, num_examples, seqlen, last_taken, exhausted, n_heads, d_head, itype): + chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets( + seq_idx, chunk_size) + Y, new_states = mamba_chunk_scan_combined( X, dt, @@ -281,7 +286,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, chunk_size, D=None, cu_seqlens=cu_seqlens, - seq_idx=sed_idx, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, return_varlen_states=True, initial_states=states, ) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py new file mode 100644 index 000000000000..975cd418a171 --- /dev/null +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -0,0 +1,364 @@ +# SPDX-License-Identifier: Apache-2.0 +import dataclasses +from typing import Optional + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, + fused_topk) +from vllm.platforms import current_platform + +NUM_EXPERTS = [40, 64] +TOP_KS = [6, 8] + +MNK_FACTORS = [ + (2, 1024, 1024), + (2, 1024, 1536), + (2, 3072, 1024), + (2, 3072, 1536), + (64, 1024, 1024), + (64, 1024, 1536), + (64, 3072, 1024), + (64, 3072, 1536), + (224, 1024, 1024), + (224, 1024, 1536), + (224, 3072, 1024), + (224, 3072, 1536), +] + + +@dataclasses.dataclass +class MOETensors: + a: torch.Tensor + w1: torch.Tensor + w2: torch.Tensor + ab_strides1: torch.Tensor + c_strides1: torch.Tensor + ab_strides2: torch.Tensor + c_strides2: torch.Tensor + + @staticmethod + def make_moe_tensors(m: int, k: int, n: int, e: int, + dtype: torch.dtype) -> "MOETensors": + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + return MOETensors(a=a, + w1=w1, + w2=w2, + ab_strides1=ab_strides1, + c_strides1=c_strides1, + ab_strides2=ab_strides2, + c_strides2=c_strides2) + + +@dataclasses.dataclass +class MOETensors8Bit(MOETensors): + # quantized + a_q: Optional[torch.Tensor] = None # a -> a_q + w1_q: Optional[torch.Tensor] = None # w1 -> w1_q + w2_q: Optional[torch.Tensor] = None # w2 -> w2_q + a_scale: Optional[torch.Tensor] = None + w1_scale: 
Optional[torch.Tensor] = None + w2_scale: Optional[torch.Tensor] = None + # dequantized + a_d: Optional[torch.Tensor] = None # a -> a_q -> a_d + w1_d: Optional[torch.Tensor] = None # w1 -> w1_q -> w1_d + w2_d: Optional[torch.Tensor] = None # w2 -> w2_q -> w2_d + + @staticmethod + def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, + per_act_token: bool, + per_out_channel: bool) -> "MOETensors8Bit": + dtype = torch.half + q_dtype = torch.float8_e4m3fn + + moe_tensors_fp16 = MOETensors.make_moe_tensors(m, k, n, e, dtype) + + # a -> a_q, w1 -> w1_q, w2 -> w2_q + n_b_scales = 2 * n if per_out_channel else 1 + k_b_scales = k if per_out_channel else 1 + # Get the right scale for tests. + _, a_scale = ops.scaled_fp8_quant( + moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token) + a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a, + a_scale, + use_per_token_if_dynamic=per_act_token) + w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) + w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) + + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + moe_tensors_fp16.w1[expert], + use_per_token_if_dynamic=per_out_channel) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + moe_tensors_fp16.w2[expert], + use_per_token_if_dynamic=per_out_channel) + + # a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d + a_d = a_q.float().mul(a_scale).to(dtype) + w1_d = torch.empty_like(moe_tensors_fp16.w1) + w2_d = torch.empty_like(moe_tensors_fp16.w2) + for expert in range(e): + w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half() + w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half() + + return MOETensors8Bit(a=moe_tensors_fp16.a, + w1=moe_tensors_fp16.w1, + w2=moe_tensors_fp16.w2, + ab_strides1=moe_tensors_fp16.ab_strides1, + c_strides1=moe_tensors_fp16.c_strides1, + ab_strides2=moe_tensors_fp16.ab_strides2, + c_strides2=moe_tensors_fp16.c_strides2, + a_q=a_q, + w1_q=w1_q, + w2_q=w2_q, + a_scale=a_scale, + w1_scale=w1_scale, + w2_scale=w2_scale, + a_d=a_d, + w1_d=w1_d, + w2_d=w2_d) + + +def run_with_expert_maps(num_experts: int, num_local_experts: int, + **cutlass_moe_kwargs): + + def slice_experts(): + slice_params = [ + "w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1", + "c_strides2", "w1_scale", "w2_scale" + ] + full_tensors = { + k: v + for k, v in cutlass_moe_kwargs.items() + if k in slice_params and k in cutlass_moe_kwargs + } + + for i in range(0, num_experts, num_local_experts): + s, e = i, i + num_local_experts + + # make expert map + expert_map = [-1] * num_experts + expert_map[s:e] = list(range(num_local_experts)) + expert_map = torch.tensor(expert_map, + dtype=torch.int32, + device="cuda") + + # update cutlass moe arg with expert_map + cutlass_moe_kwargs["expert_map"] = expert_map + # update cutlass moe arg tensors + for k, t in full_tensors.items(): + cutlass_moe_kwargs[k] = t[s:e] + + yield cutlass_moe_kwargs + + out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"]) + for kwargs in slice_experts(): + out_tensor = out_tensor + cutlass_moe_fp8(**kwargs) + + return out_tensor + + +def run_8_bit(moe_tensors: MOETensors8Bit, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_local_experts: Optional[int] = None) -> torch.Tensor: + assert not any([ + t is None for t in [ + moe_tensors.w1_q, moe_tensors.w2_q, moe_tensors.w1_scale, + 
moe_tensors.w2_scale, moe_tensors.a_scale + ] + ]) + + kwargs = { + 'a': moe_tensors.a, + 'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr] + 'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr] + 'topk_weights': topk_weights, + 'topk_ids_': topk_ids, + 'ab_strides1': moe_tensors.ab_strides1, + 'c_strides1': moe_tensors.c_strides1, + 'ab_strides2': moe_tensors.ab_strides2, + 'c_strides2': moe_tensors.c_strides2, + 'w1_scale': moe_tensors.w1_scale, + 'w2_scale': moe_tensors.w2_scale, + 'a1_scale': moe_tensors.a_scale + } + + num_experts = moe_tensors.w1.size(0) + with_ep = num_local_experts is not None or num_local_experts == num_experts + if not with_ep: + return cutlass_moe_fp8(**kwargs) + + assert num_local_experts is not None + return run_with_expert_maps( + num_experts, + num_local_experts, # type: ignore[arg-type] + **kwargs) + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_8_bit_no_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, + per_out_ch) + + score = torch.randn((m, e), device="cuda", dtype=torch.half) + topk_weights, topk_ids = fused_topk(mt.a, + score, + topk, + renormalize=False) + + # Note that we are using the dequantized versions of the tensors. + # Using a, w1 and w2 directly results in minor output differences. + triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, + topk_ids) + + cutlass_output = run_8_bit(mt, topk_weights, topk_ids) + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=5e-2, + rtol=1e-2) + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_8_bit_cuda_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + dtype = torch.half + + mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, + per_out_ch) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(mt.a, + score, + topk, + renormalize=False) + + # Note that we are using the dequantized versions of the tensors. + # Using a, w1 and w2 directly results in minor output differences. 
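+ # The Triton reference below runs eagerly; the CUTLASS fp8 path is captured
+ # into a CUDA graph and replayed before the two outputs are compared.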
+ triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, + topk_ids) + + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + cutlass_output = run_8_bit(mt, topk_weights, topk_ids) + + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=9e-2, + rtol=1e-2) + + +@pytest.mark.parametrize("m", [64]) +@pytest.mark.parametrize("n", [1024]) +@pytest.mark.parametrize("k", [4096]) +@pytest.mark.parametrize("e", [16]) +@pytest.mark.parametrize("topk", [1, 8]) +@pytest.mark.parametrize("per_act_token", [True]) +@pytest.mark.parametrize("per_out_channel", [True]) +@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_8_bit_EP( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_channel: bool, + ep_size: int, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, + per_out_channel) + + score = torch.randn((m, e), device="cuda", dtype=torch.half) + topk_weights, topk_ids = fused_topk(mt.a, + score, + topk, + renormalize=False) + + # Note that we are using the dequantized versions of the tensors. + # Using a, w1 and w2 directly results in minor output differences. + triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, + topk_ids) + + assert e % ep_size == 0, "Cannot distribute experts evenly" + cutlass_output = run_8_bit(mt, + topk_weights, + topk_ids, + num_local_experts=e // ep_size) + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=5e-2, + rtol=1e-2) diff --git a/tests/kernels/test_moe.py b/tests/kernels/moe/test_moe.py similarity index 72% rename from tests/kernels/test_moe.py rename to tests/kernels/moe/test_moe.py index 3f4dd3cf0e5d..425f36984a33 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -11,16 +11,14 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, - torch_moe, torch_moe_single) -from vllm import _custom_ops as ops +from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, + torch_moe_single) from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - marlin_quantize) + awq_marlin_quantize, marlin_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( quantize_weights) from vllm.model_executor.models.mixtral import MixtralMoE @@ -287,14 +285,17 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, atol=mixtral_moe_tol[dtype]) -@pytest.mark.parametrize("m", [1, 33, 64, 222]) -@pytest.mark.parametrize("n", [128, 2048]) -@pytest.mark.parametrize("k", [128, 1024]) 
-@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("m", [1, 33, 123]) +@pytest.mark.parametrize("n", [128, 1024]) +@pytest.mark.parametrize("k", [256, 2048]) +@pytest.mark.parametrize("e", [4, 12]) +@pytest.mark.parametrize("topk", [2, 3]) +@pytest.mark.parametrize("ep_size", [1, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("group_size", [-1, 32, 128]) @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("is_k_full", [True, False]) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_fused_marlin_moe( @@ -303,9 +304,12 @@ def test_fused_marlin_moe( k: int, e: int, topk: int, + ep_size: int, + dtype: torch.dtype, group_size: int, act_order: bool, num_bits: int, + has_zp: bool, is_k_full: bool, ): current_platform.seed_everything(7) @@ -316,75 +320,110 @@ def test_fused_marlin_moe( return if group_size in (k, n): return + if has_zp: + return else: if not is_k_full: return - quant_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) - dtype = torch.float16 + if has_zp: + # we don't build kernel for int8 with zero + if num_bits == 8: + return + quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 + else: + quant_type = scalar_types.uint4b8 \ + if num_bits == 4 else scalar_types.uint8b128 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + if ep_size > 1: + local_e = e // ep_size + e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e] + e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) + w1 = w1[e_ids] + w2 = w2[e_ids] + else: + e_map = None + w_ref1_l = [] qweight1_l = [] scales1_l = [] + zeros1_l = [] g_idx1_l = [] sort_indices1_l = [] for i in range(w1.shape[0]): - test_perm = torch.randperm(k) - w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( - w1[i].transpose(1, 0), quant_type, group_size, act_order, - test_perm) - w_ref1_l.append(w_ref1) - qweight1_l.append(qweight1) - scales1_l.append(scales1) - g_idx1_l.append(g_idx1) - sort_indices1_l.append(sort_indices1) + if has_zp: + w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size) + + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + zeros1_l.append(zeros1) + else: + test_perm = torch.randperm(k) + quant_res = marlin_quantize(w1[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res + + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + g_idx1_l.append(g_idx1) + sort_indices1_l.append(sort_indices1) w_ref1 = stack_and_dev(w_ref1_l) qweight1 = stack_and_dev(qweight1_l).contiguous() scales1 = stack_and_dev(scales1_l) - g_idx1 = stack_and_dev(g_idx1_l) - sort_indices1 = stack_and_dev(sort_indices1_l) + g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None + zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None + sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None w_ref2_l = [] qweight2_l = [] scales2_l = [] + zeros2_l = [] g_idx2_l = [] sort_indices2_l = [] for i in 
range(w2.shape[0]): - test_perm = torch.randperm(n) - w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( - w2[i].transpose(1, 0), quant_type, group_size, act_order, - test_perm) - w_ref2_l.append(w_ref2) - qweight2_l.append(qweight2) - scales2_l.append(scales2) - g_idx2_l.append(g_idx2) - sort_indices2_l.append(sort_indices2) + if has_zp: + w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size) + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + zeros2_l.append(zeros2) + else: + test_perm = torch.randperm(n) + quant_res = marlin_quantize(w2[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + g_idx2_l.append(g_idx2) + sort_indices2_l.append(sort_indices2) w_ref2 = stack_and_dev(w_ref2_l) qweight2 = stack_and_dev(qweight2_l).contiguous() scales2 = stack_and_dev(scales2_l) - g_idx2 = stack_and_dev(g_idx2_l) - sort_indices2 = stack_and_dev(sort_indices2_l) + g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None + zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None + sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids = fused_topk(a, score, topk, False) - triton_output = fused_moe( - a, - w_ref1.transpose(1, 2).contiguous(), - w_ref2.transpose(1, 2).contiguous(), - score, - topk, - renormalize=False, - ) + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) + marlin_output = torch.ops.vllm.fused_marlin_moe( a, qweight1, @@ -394,111 +433,91 @@ def test_fused_marlin_moe( score, topk_weights, topk_ids, + global_num_experts=e, + expert_map=e_map, g_idx1=g_idx1, g_idx2=g_idx2, sort_indices1=sort_indices1, sort_indices2=sort_indices2, + w1_zeros=zeros1, + w2_zeros=zeros2, num_bits=num_bits, - is_k_full=is_k_full, - ) - - assert compute_max_diff(marlin_output, triton_output) < 4e-2 - - if ops.supports_moe_ops: - token_expert_indicies = torch.empty(m, - topk, - dtype=torch.int32, - device=a.device) - - opcheck(torch.ops._moe_C.topk_softmax, ( - topk_weights, - topk_ids, - token_expert_indicies, - score.float(), - )) - - block_size_m = 4 - - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, - e) + is_k_full=is_k_full) - max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda", - requires_grad=False) - - zp = torch.empty((0, 0), - dtype=dtype, - device="cuda", - requires_grad=False) - opcheck(torch.ops._moe_C.marlin_gemm_moe, - (a, qweight1, sorted_token_ids, topk_weights, topk_ids, - scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id, - m, 2 * n, k, True, e, topk, block_size_m, True, False)) + torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0) @pytest.mark.skip("This test is here for the sake of debugging, " "don't run it in automated tests.") -@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) -@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) -@pytest.mark.parametrize("k", [128, 1024, 512]) -@pytest.mark.parametrize("e", [8, 64]) -@pytest.mark.parametrize("topk", [2, 6]) -@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("m", [1, 33, 123]) +@pytest.mark.parametrize("n", [128, 1024]) +@pytest.mark.parametrize("k", [256, 2048]) 
+@pytest.mark.parametrize("e", [4, 12]) +@pytest.mark.parametrize("topk", [2, 3]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("group_size", [-1, 32, 128]) @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("is_k_full", [True, False]) -@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") -def test_single_marlin_moe_multiply( - m: int, - n: int, - k: int, - e: int, - topk: int, - group_size: int, - act_order: bool, - num_bits: int, - is_k_full: bool, -): - +def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, topk: int, + dtype: torch.dtype, group_size: int, + act_order: bool, num_bits: int, + has_zp: bool, is_k_full: bool): # Filter act_order if act_order: if group_size == -1: return - if group_size == k: + if group_size in (k, n): + return + if has_zp: return else: if not is_k_full: return - quant_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) - dtype = torch.float16 + if has_zp: + quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 + else: + quant_type = scalar_types.uint4b8 \ + if num_bits == 4 else scalar_types.uint8b128 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 w_ref_l = [] - qweights_l = [] + qweight_l = [] scales_l = [] + zeros_l = [] g_idx_l = [] sort_indices_l = [] for i in range(w.shape[0]): - test_perm = torch.randperm(k) - w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( - w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm) - w_ref_l.append(w_ref) - qweights_l.append(qweight) - scales_l.append(scales) - g_idx_l.append(g_idx) - sort_indices_l.append(sort_indices) + if has_zp: + w_ref, qweight, scales, zeros = awq_marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size) + + w_ref_l.append(w_ref.T) + qweight_l.append(qweight) + scales_l.append(scales) + zeros_l.append(zeros) + else: + test_perm = torch.randperm(k) + w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size, act_order, + test_perm) + + w_ref_l.append(w_ref.T) + qweight_l.append(qweight) + scales_l.append(scales) + g_idx_l.append(g_idx) + sort_indices_l.append(sort_indices) w_ref = stack_and_dev(w_ref_l) - qweight = stack_and_dev(qweights_l).contiguous() + qweight = stack_and_dev(qweight_l).contiguous() scales = stack_and_dev(scales_l) - g_idx = stack_and_dev(g_idx_l) - sort_indices = stack_and_dev(sort_indices_l) + g_idx = stack_and_dev(g_idx_l) if g_idx_l else None + zeros = stack_and_dev(zeros_l) if zeros_l else None + sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None score = torch.randn((m, e), device="cuda", dtype=dtype) marlin_output = torch.ops.vllm.single_marlin_moe( @@ -510,13 +529,14 @@ def test_single_marlin_moe_multiply( renormalize=False, g_idx=g_idx, sort_indices=sort_indices, + w_zeros=zeros, num_bits=num_bits, is_k_full=is_k_full, ) - torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) + torch_output = torch_moe_single(a, w_ref, score, topk) - assert compute_max_diff(marlin_output, torch_output) < 1e-2 + torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0) def test_moe_align_block_size_opcheck(): diff --git a/tests/kernels/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py similarity index 
100% rename from tests/kernels/test_triton_moe_ptpc_fp8.py rename to tests/kernels/moe/test_triton_moe_ptpc_fp8.py diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 498da6001ae9..764924f26783 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -87,3 +87,63 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ ref_out = (as_float32_tensor(x) * ref_iscale).clamp( fp8_traits_min, fp8_traits_max).to(FP8_DTYPE) return ref_out, ref_scale.view((1, )) + + +def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, + As: torch.Tensor, Bs: torch.Tensor, block_size, + output_dtype): + """This function performs matrix multiplication with block-wise + quantization using native torch. + It is agnostic to the input data type and can be used for both int8 and + fp8 data types. + + It takes two input tensors `A` and `B` (int8) with scales `As` and + `Bs` (float32). + The output is returned in the specified `output_dtype`. + """ + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N, ) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [ + A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) + ] + B_tiles = [[ + B[ + j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), + ] for i in range(k_tiles) + ] for j in range(n_tiles)] + C_tiles = [ + C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) + ] + As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C diff --git a/tests/kernels/test_allspark_gemm.py b/tests/kernels/quantization/test_allspark_gemm.py similarity index 100% rename from tests/kernels/test_allspark_gemm.py rename to tests/kernels/quantization/test_allspark_gemm.py diff --git a/tests/kernels/test_aqlm.py b/tests/kernels/quantization/test_aqlm.py similarity index 100% rename from tests/kernels/test_aqlm.py rename to tests/kernels/quantization/test_aqlm.py diff --git a/tests/kernels/test_awq.py b/tests/kernels/quantization/test_awq.py similarity index 100% rename from tests/kernels/test_awq.py rename to tests/kernels/quantization/test_awq.py diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/quantization/test_awq_marlin.py similarity index 100% rename from tests/kernels/test_awq_marlin.py rename to tests/kernels/quantization/test_awq_marlin.py diff --git a/tests/kernels/test_awq_triton.py b/tests/kernels/quantization/test_awq_triton.py similarity index 100% rename from tests/kernels/test_awq_triton.py rename to tests/kernels/quantization/test_awq_triton.py diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py similarity index 99% rename from tests/kernels/test_block_fp8.py rename to 
tests/kernels/quantization/test_block_fp8.py index c450048bf665..c57e39f42506 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -6,6 +6,7 @@ import pytest import torch +from tests.kernels.quant_utils import native_w8a8_block_matmul from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe @@ -18,8 +19,6 @@ per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform -from .utils_block import native_w8a8_block_matmul - dg_available = False try: import deep_gemm diff --git a/tests/kernels/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py similarity index 99% rename from tests/kernels/test_block_int8.py rename to tests/kernels/quantization/test_block_int8.py index 9447f9d69165..104f23fd7cd2 100644 --- a/tests/kernels/test_block_int8.py +++ b/tests/kernels/quantization/test_block_int8.py @@ -6,6 +6,7 @@ import pytest import torch +from tests.kernels.quant_utils import native_w8a8_block_matmul from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe @@ -13,8 +14,6 @@ w8a8_block_int8_matmul) from vllm.platforms import current_platform -from .utils_block import native_w8a8_block_matmul - if current_platform.get_device_capability() < (7, 0): pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) diff --git a/tests/kernels/test_cutlass_2of4_sparse.py b/tests/kernels/quantization/test_cutlass_2of4_sparse.py similarity index 99% rename from tests/kernels/test_cutlass_2of4_sparse.py rename to tests/kernels/quantization/test_cutlass_2of4_sparse.py index 2890e15d6cba..d67d2dbb8998 100644 --- a/tests/kernels/test_cutlass_2of4_sparse.py +++ b/tests/kernels/quantization/test_cutlass_2of4_sparse.py @@ -7,13 +7,12 @@ import pytest import torch +from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8 from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( sparse_cutlass_supported) from vllm.platforms import current_platform -from .utils import baseline_scaled_mm, to_fp8, to_int8 - CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py similarity index 99% rename from tests/kernels/test_cutlass.py rename to tests/kernels/quantization/test_cutlass_scaled_mm.py index f11ce6f45a98..8084d9bf2c2d 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py @@ -8,13 +8,11 @@ import pytest import torch -from tests.kernels.utils import opcheck +from tests.kernels.utils import baseline_scaled_mm, opcheck, to_fp8, to_int8 from vllm import _custom_ops as ops from vllm.platforms import current_platform from vllm.utils import cdiv -from .utils import baseline_scaled_mm, to_fp8, to_int8 - MNK_FACTORS = [ (1, 256, 128), (1, 16384, 1024), diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/quantization/test_fp8_quant.py similarity index 100% rename from tests/kernels/test_fp8_quant.py rename to tests/kernels/quantization/test_fp8_quant.py diff --git a/tests/kernels/test_ggml.py b/tests/kernels/quantization/test_ggml.py similarity index 100% rename from tests/kernels/test_ggml.py rename to 
tests/kernels/quantization/test_ggml.py diff --git a/tests/kernels/test_gguf.py b/tests/kernels/quantization/test_gguf.py similarity index 100% rename from tests/kernels/test_gguf.py rename to tests/kernels/quantization/test_gguf.py diff --git a/tests/kernels/test_gptq.py b/tests/kernels/quantization/test_gptq.py similarity index 100% rename from tests/kernels/test_gptq.py rename to tests/kernels/quantization/test_gptq.py diff --git a/tests/kernels/test_int8_kernel.py b/tests/kernels/quantization/test_int8_kernel.py similarity index 100% rename from tests/kernels/test_int8_kernel.py rename to tests/kernels/quantization/test_int8_kernel.py diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/quantization/test_int8_quant.py similarity index 100% rename from tests/kernels/test_int8_quant.py rename to tests/kernels/quantization/test_int8_quant.py diff --git a/tests/kernels/test_machete_mm.py b/tests/kernels/quantization/test_machete_mm.py similarity index 100% rename from tests/kernels/test_machete_mm.py rename to tests/kernels/quantization/test_machete_mm.py diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py similarity index 100% rename from tests/kernels/test_marlin_gemm.py rename to tests/kernels/quantization/test_marlin_gemm.py diff --git a/tests/kernels/test_nvfp4_quant.py b/tests/kernels/quantization/test_nvfp4_quant.py similarity index 100% rename from tests/kernels/test_nvfp4_quant.py rename to tests/kernels/quantization/test_nvfp4_quant.py diff --git a/tests/kernels/test_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_nvfp4_scaled_mm.py similarity index 100% rename from tests/kernels/test_nvfp4_scaled_mm.py rename to tests/kernels/quantization/test_nvfp4_scaled_mm.py diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py new file mode 100644 index 000000000000..622079c39445 --- /dev/null +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +import vllm._custom_ops as ops +from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant +from vllm.platforms import current_platform + +DTYPES = [torch.bfloat16, torch.float16] +M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192] +K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] # k % 8 == 0 +N = [1, 2, 3, 4] +SEEDS = [0] + + +@pytest.mark.parametrize("n", [1]) # only test for batch size 1 +@pytest.mark.parametrize("k", K) +@pytest.mark.parametrize("m", M) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="only test for rocm") +@torch.inference_mode() +def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): + torch.manual_seed(seed) + A = torch.rand(n, k, dtype=dtype, device="cuda") + B = torch.rand(m, k, dtype=dtype, device="cuda") + + ref_out = torch.matmul(A, B.t()) + out = ops.LLMM1(B, A, rows_per_block) + + assert torch.allclose(out, ref_out, rtol=0.01) + + +@pytest.mark.parametrize("n", N) # only test for batch size <= 4 +@pytest.mark.parametrize("k", K + [9216, 10240, 16384]) +@pytest.mark.parametrize("m", [8] + M) # m >= 8 +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="only test for rocm") +def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): + 
torch.manual_seed(seed) + cu_count = current_platform.get_cu_count() + + A = torch.rand(n, k, dtype=dtype, device="cuda") + B = torch.rand(m, k, dtype=dtype, device="cuda") + + ref_out = torch.matmul(A, B.t()) + out = ops.wvSplitK(B, A, cu_count) + + assert torch.allclose(out, ref_out, rtol=0.01) + + +@pytest.mark.parametrize("n", N) # only test for batch size <= 4 +@pytest.mark.parametrize("k", K[1:] + [14336, 24576, 32768]) # k % 16 == 0 +@pytest.mark.parametrize("m", M + [28672]) # m >= 16 +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="only test for rocm") +def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): + torch.manual_seed(seed) + + A = torch.rand(n, k, device="cuda") + B = torch.rand(m, k, device="cuda") + + A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) + B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) + + ref_out = torch._scaled_mm(A, + B.t(), + out_dtype=dtype, + scale_a=scale_a, + scale_b=scale_b) + out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, + current_platform.get_cu_count()) + + assert torch.allclose(out, ref_out, rtol=0.01) diff --git a/tests/kernels/test_triton_scaled_mm.py b/tests/kernels/quantization/test_triton_scaled_mm.py similarity index 100% rename from tests/kernels/test_triton_scaled_mm.py rename to tests/kernels/quantization/test_triton_scaled_mm.py diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py deleted file mode 100644 index a51e70d45ee0..000000000000 --- a/tests/kernels/test_attention_selector.py +++ /dev/null @@ -1,136 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from unittest.mock import patch - -import pytest -import torch - -from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend -from vllm.platforms.cpu import CpuPlatform -from vllm.platforms.cuda import CudaPlatform -from vllm.platforms.rocm import RocmPlatform -from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL - - -@pytest.fixture(autouse=True) -def clear_cache(): - """Clear lru cache to ensure each test case runs without caching. - """ - _cached_get_attn_backend.cache_clear() - - -@pytest.mark.parametrize( - "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"]) -@pytest.mark.parametrize("use_v1", [True, False]) -@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) -def test_env( - name: str, - use_v1: bool, - device: str, - monkeypatch: pytest.MonkeyPatch, -): - """Test that the attention selector can be set via environment variable. - Note that we do not test FlashAttn because it is the default backend. 
- """ - - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") - m.setenv(STR_BACKEND_ENV_VAR, name) - - if device == "cpu": - with patch("vllm.attention.selector.current_platform", - CpuPlatform()): - backend = get_attn_backend(16, torch.float16, torch.float16, - 16, False) - assert backend.get_name() == "TORCH_SDPA" - elif device == "hip": - with patch("vllm.attention.selector.current_platform", - RocmPlatform()): - backend = get_attn_backend(16, torch.float16, torch.float16, - 16, False) - EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH" - assert backend.get_name() == EXPECTED - else: - if name in ["XFORMERS", "FLASHINFER"]: - with patch("vllm.attention.selector.current_platform", - CudaPlatform()): - backend = get_attn_backend(16, torch.float16, - torch.float16, 16, False) - EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else name - assert backend.get_name() == EXPECTED - - -def test_flash_attn(monkeypatch: pytest.MonkeyPatch): - """Test FlashAttn validation.""" - # TODO: When testing for v1, pipe in `use_v1` as an argument to - # get_attn_backend - - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) - - # Unsupported CUDA arch - monkeypatch.setattr(torch.cuda, "get_device_capability", lambda: - (7, 5)) - backend = get_attn_backend(16, torch.float16, None, 16, False) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - # Reset the monkeypatch for subsequent tests - monkeypatch.undo() - - # Unsupported data type - backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - # Unsupported kv cache data type - backend = get_attn_backend(16, torch.float16, "fp8", 16, False) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - # Unsupported block size - backend = get_attn_backend(16, torch.float16, None, 8, False) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - # flash-attn is not installed - import sys - original_module = sys.modules.get('vllm_flash_attn') - monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None) - backend = get_attn_backend(16, torch.float16, None, 16, False) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - # Restore the original module if it existed - if original_module is not None: - monkeypatch.setitem(sys.modules, 'vllm_flash_attn', - original_module) - else: - monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False) - - # Unsupported head size - backend = get_attn_backend(17, torch.float16, None, 16, False) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - # Attention-free models should bypass env and use PlaceholderAttention - backend = get_attn_backend(16, torch.float16, torch.float16, 16, True) - assert backend.get_name() != STR_FLASH_ATTN_VAL - - -@pytest.mark.parametrize("use_v1", [True, False]) -def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch): - - with monkeypatch.context() as m, patch( - "vllm.attention.selector.current_platform", CudaPlatform()): - m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") - m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) - - # Test with head size 32 - backend = get_attn_backend(32, torch.float16, None, 16, False) - EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN" - assert backend.get_name() == EXPECTED - - # when block size == 16, backend will fall back to XFORMERS - # this behavior is not yet supported on V1. - if use_v1: - # TODO: support fallback on V1! 
- # https://github.com/vllm-project/vllm/issues/14524 - pass - else: - backend = get_attn_backend(16, torch.float16, None, 16, False) - assert backend.get_name() == "XFORMERS" diff --git a/tests/kernels/test_cutlass_mla_decode.py b/tests/kernels/test_cutlass_mla_decode.py new file mode 100644 index 000000000000..87e4bd4b096b --- /dev/null +++ b/tests/kernels/test_cutlass_mla_decode.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch +import torch.nn.functional as F +from torch import Tensor + +import vllm._custom_ops as ops +from vllm.platforms import current_platform + +if not current_platform.has_device_capability(100): + pytest.skip( + reason="Cutlass MLA Requires compute capability of 10 or above.", + allow_module_level=True) + + +def ref_mla( + out: Tensor, # (bs, num_heads, v_head_dim) + query: Tensor, # (bs, num_heads, head_dim) + kv_cache: Tensor, # (num_blocks, block_size, head_dim) + scale: float, + block_tables: Tensor, # (bs, max_num_blocks) + seq_lens: Tensor, # (bs,) +): + bs, num_heads, v_head_dim = out.shape + head_dim = query.shape[2] + + for i in range(bs): + # gather and flatten KV-cache + kv = kv_cache[ + block_tables[i]] # (max_num_blocks, block_size, head_dim) + kv = kv.view(1, -1, + head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) + v = kv[:, :, :v_head_dim] + + q = query[i].view(num_heads, 1, head_dim) + o = F.scaled_dot_product_attention(q, + kv, + v, + scale=scale, + enable_gqa=True) + out[i] = o.view(num_heads, v_head_dim) + + return out + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096]) +@pytest.mark.parametrize("bs", [1, 2, 4]) +@pytest.mark.parametrize("varlen", [False, True]) +@pytest.mark.parametrize("block_size", [16, 64, 128]) +def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, + varlen: bool, block_size: int): + torch.set_default_dtype(dtype) + torch.set_default_device('cuda') + torch.manual_seed(42) + + d = 576 + h_q = 128 + dv = 512 + + q_nope_dim = 128 + q_pe_dim = 64 + scale = (q_nope_dim + q_pe_dim)**(-0.5) + if varlen: + seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2) + seq_lens = seq_lens.clip(2).to(torch.int32) + else: + seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32) + max_seq_len = seq_lens.max().item() + block_num = (max_seq_len + block_size - 1) // block_size + + # Pad block_num so that small blocks can be packed into full 128-sized + # CUTLASS tiles. One 128-wide tile can hold (128 // block_size) small + # blocks. 
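+    # Illustrative example: with block_size=16 one 128-wide tile holds 8 blocks, so a block_num of 10 is rounded up to 16.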
+ pack_factor = 128 // block_size + block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor + + q = torch.randn(bs, h_q, d) + block_table = torch.randint(0, + bs * block_num, (bs, block_num), + dtype=torch.int32) + + kv_cache = torch.randn(block_table.numel(), block_size, d) + + out_ref = q.new_zeros(bs, h_q, dv) + ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) + out_ans = torch.zeros_like(out_ref) + q_nope = q[:, :, :dv].clone() + q_pe = q[:, :, dv:].clone() + ops.cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache, seq_lens, + block_table, scale) + + torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py deleted file mode 100644 index 3cfed6ae8538..000000000000 --- a/tests/kernels/test_cutlass_moe.py +++ /dev/null @@ -1,244 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import pytest -import torch - -from vllm import _custom_ops as ops -from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 -from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, - fused_topk) -from vllm.platforms import current_platform - -NUM_EXPERTS = [40, 64] -TOP_KS = [6, 8] - - -def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, - w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor): - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - return cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) - - -@pytest.mark.parametrize("m", [2, 64, 224]) -@pytest.mark.parametrize("n", [1024, 3072]) -@pytest.mark.parametrize("k", [1024, 1536]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("per_act_token", [True, False]) -@pytest.mark.parametrize("per_out_ch", [True, False]) -@pytest.mark.skipif( - (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_moe_no_graph( - m: int, - n: int, - k: int, - e: int, - topk: int, - per_act_token: bool, - per_out_ch: bool, -): - current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - - dtype = torch.half - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - # Get the right scale for tests. 
- _, a_scale1 = ops.scaled_fp8_quant( - a, use_per_token_if_dynamic=per_act_token) - a_q, _ = ops.scaled_fp8_quant(a, - a_scale1, - use_per_token_if_dynamic=per_act_token) - - a_d = a_q.float().mul(a_scale1).to(dtype) - - n_b_scales = 2 * n if per_out_ch else 1 - k_b_scales = k if per_out_ch else 1 - - w1_q = torch.empty((e, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - for expert in range(e): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_ch) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_ch) - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - w1_d = torch.empty_like(w1) - w2_d = torch.empty_like(w2) - for expert in range(e): - w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() - w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() - - score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - - triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) - - cutlass_output = cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale1) - - #print(triton_output) - #print(cutlass_output) - #print("*") - - torch.testing.assert_close(triton_output, - cutlass_output, - atol=5e-2, - rtol=1e-2) - - -@pytest.mark.parametrize("m", [2, 64, 224]) -@pytest.mark.parametrize("n", [1024, 3072]) -@pytest.mark.parametrize("k", [1024, 1536]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("per_act_token", [True, False]) -@pytest.mark.parametrize("per_out_ch", [True, False]) -@pytest.mark.skipif( - (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_moe_cuda_graph( - m: int, - n: int, - k: int, - e: int, - topk: int, - per_act_token: bool, - per_out_ch: bool, -): - current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - - dtype = torch.half - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - # Get the right scale for tests. 
- _, a_scale1 = ops.scaled_fp8_quant( - a, use_per_token_if_dynamic=per_act_token) - a_q, _ = ops.scaled_fp8_quant(a, - a_scale1, - use_per_token_if_dynamic=per_act_token) - - a_d = a_q.float().mul(a_scale1).to(dtype) - - n_b_scales = 2 * n if per_out_ch else 1 - k_b_scales = k if per_out_ch else 1 - - w1_q = torch.empty((e, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - for expert in range(e): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_ch) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_ch) - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - w1_d = torch.empty_like(w1) - w2_d = torch.empty_like(w2) - for expert in range(e): - w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() - w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() - - score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - - triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) - - stream = torch.cuda.Stream() - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, stream=stream): - cutlass_output = run(a, a_scale1, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, ab_strides1, - c_strides1, ab_strides2, c_strides2) - torch.cuda.synchronize() - graph.replay() - torch.cuda.synchronize() - - #print(triton_output) - #print(cutlass_output) - #print("*") - - torch.testing.assert_close(triton_output, - cutlass_output, - atol=9e-2, - rtol=1e-2) diff --git a/tests/kernels/test_rocm_attention_selector.py b/tests/kernels/test_rocm_attention_selector.py deleted file mode 100644 index 90b483b4a41a..000000000000 --- a/tests/kernels/test_rocm_attention_selector.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import pytest -import torch - -from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend -from vllm.platforms.rocm import RocmPlatform -from vllm.utils import STR_BACKEND_ENV_VAR - - -@pytest.fixture(autouse=True) -def clear_cache(): - """Clear lru cache to ensure each test case runs without caching. 
- """ - _cached_get_attn_backend.cache_clear() - - -def test_selector(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH") - - # Set the current platform to ROCm using monkeypatch - monkeypatch.setattr("vllm.attention.selector.current_platform", - RocmPlatform()) - - # Test standard ROCm attention - backend = get_attn_backend(16, torch.float16, torch.float16, 16, False) - assert (backend.get_name() == "ROCM_FLASH" - or backend.get_name() == "TRITON_ATTN_VLLM_V1") - - # mla test for deepseek related - backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, - False, True) - assert backend.get_name() == "TRITON_MLA" diff --git a/tests/kernels/test_triton_flash_attention.py b/tests/kernels/test_triton_flash_attention.py new file mode 100644 index 000000000000..cf2bdc908e42 --- /dev/null +++ b/tests/kernels/test_triton_flash_attention.py @@ -0,0 +1,499 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the triton_flash_attention kernel + +Run `pytest tests/kernels/test_triton_flash_attention.py`. +""" +import pytest +import torch + +from vllm.attention.ops.triton_flash_attention import (SUPPORTED_LAYOUTS, + MetaData, + compute_alibi_tensor, + scale_fp8, + triton_attention_rocm) +from vllm.platforms import current_platform + + +class ReferenceAttention: + + def __init__(self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, + input_metadata): + self.Z = Z + self.HQ = HQ + self.HK = HK + self.N_CTX_Q = N_CTX_Q + self.N_CTX_K = N_CTX_K + self.D_HEAD = D_HEAD + self.use_alibi = use_alibi + self.dtype = dtype + self.input_metadata = input_metadata + + def fwd(self, q, k, v): + scores = torch.einsum('bhqd,bhkd->bhqk', q, + k).float() * self.input_metadata.sm_scale + if self.input_metadata.causal: + mask = torch.tril(torch.ones(self.N_CTX_Q, + self.N_CTX_K, + device="cuda"), + diagonal=self.N_CTX_K - self.N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + + if self.input_metadata.bias is not None: + scores += self.input_metadata.bias + + if self.use_alibi: + scores += compute_alibi_tensor(self.input_metadata.alibi_slopes, + self.N_CTX_Q, self.N_CTX_K) + + p = torch.softmax(scores, dim=-1) + if self.input_metadata.causal: + # If N_CTX_Q > N_CTX_K, there's at least one row of all -infs going + # into softmax. This creates a row of NaNs as -inf - -inf == NaN. + # So we fix this by converting the NaNs to 0s, which is what they + # should be out of the softmax. 
+ nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(self.dtype), v) + # compare + if self.input_metadata.layout == 'bshd': + ref_out = ref_out.transpose(1, 2).clone() + return ref_out + + def fwd_fp8(self, q_quantized, k_quantized, v_quantized): + q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to( + self.dtype) + k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to( + self.dtype) + v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to( + self.dtype) + result = self.fwd(q, k, v) + if self.input_metadata.o_scale is not None: + result, _ = scale_fp8(result, self.input_metadata.o_scale) + return result + + def fwd_fp8_kv(self, q, k_quantized, v_quantized): + k_descale, v_descale = (self.input_metadata.k_descale, + self.input_metadata.v_descale) + k_dequantized = (k_quantized.to(torch.float32) * + k_descale.to(torch.float32)).to(self.dtype) + v_dequantized = (v_quantized.to(torch.float32) * + v_descale.to(torch.float32)).to(self.dtype) + return self.fwd(q, k_dequantized, v_dequantized) + + def varlen_fwd(self, q, k, v, is_mqa=False): + ref_out = torch.empty_like(q) + if is_mqa: + # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so + # the size aligns with Q. + k_ref = k.view(k.shape[0], k.shape[1], 1, + k.shape[2]).expand(-1, -1, self.HQ // self.HK, -1) + v_ref = v.view(v.shape[0], v.shape[1], 1, + v.shape[2]).expand(-1, -1, self.HQ // self.HK, -1) + else: + k_ref = k + v_ref = v + + for i in range(0, self.input_metadata.num_contexts): + start_q, start_k = self.input_metadata.cu_seqlens_q[ + i], self.input_metadata.cu_seqlens_k[i] + end_q, end_k = self.input_metadata.cu_seqlens_q[ + i + 1], self.input_metadata.cu_seqlens_k[i + 1] + k_curr = k_ref[start_k:end_k] + v_curr = v_ref[start_k:end_k] + if is_mqa: + k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) + v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3]) + scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], + k_curr).float() + p = torch.softmax(scores * self.input_metadata.sm_scale, + dim=-1).half() + ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr) + return ref_out + + +def quantize_input(q, k, v, fp8_kv=False, use_o_scale=False): + q_descale = None + if not fp8_kv: + q, q_descale = scale_fp8(q) + k, k_descale = scale_fp8(k) + v, v_descale = scale_fp8(v) + + # In real world use case, the p scale would be a parameter trained by the + # model. + p_scale = None + + o_scale = torch.rand(1, device="cuda", + requires_grad=False) if use_o_scale else None + + return q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale + + +def input_helper( + Z, + HQ, + HK, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + layout=None, + use_alibi=None, + causal=None, + is_fp8=False, + fp8_kv=False, + use_o_scale=False, + use_bias=False, +): + assert layout in SUPPORTED_LAYOUTS, "Got unsupported layout." 
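+    # input_helper builds q/k/v in the requested layout plus a MetaData object; when is_fp8 is set, the tensors are quantized via quantize_input and the resulting descales are attached to the metadata.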
+ + current_platform.seed_everything(0) + + # Initialize q, k, v + if layout == 'bhsd': + q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) + k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) + elif layout == 'bshd': + q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) + k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) + + if use_alibi: + # for n heads the set of slopes is the geometric sequence that starts + # 2^(-8/n) + alibi_slopes = torch.tensor( + [2**(-8 / HQ * i) for i in range(1, HQ + 1)], + dtype=torch.float32, + device="cuda").repeat(Z, 1) + else: + alibi_slopes = None + + if use_bias: + bias = torch.randn((1, HQ, N_CTX_Q, N_CTX_K), + dtype=dtype, + device="cuda", + requires_grad=False) + else: + bias = None + + q = torch.randn(q_tensor_shape, + dtype=dtype, + device="cuda", + requires_grad=False) + k = torch.randn(k_tensor_shape, + dtype=dtype, + device="cuda", + requires_grad=False) + v = torch.randn(k_tensor_shape, + dtype=dtype, + device="cuda", + requires_grad=False) + + if is_fp8: + (q, k, v, q_descale, k_descale, v_descale, p_scale, + o_scale) = quantize_input(q, + k, + v, + use_o_scale=use_o_scale, + fp8_kv=fp8_kv) + else: + q_descale = k_descale = v_descale = p_scale = o_scale = None + + input_metadata = MetaData(sm_scale=D_HEAD**-0.5, + max_seqlens_q=N_CTX_Q, + max_seqlens_k=N_CTX_K, + layout=layout, + alibi_slopes=alibi_slopes, + alibi_batch=Z, + alibi_nheads=HQ, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + p_scale=p_scale, + o_scale=o_scale, + bias=bias, + seqlen_q=N_CTX_Q, + seqlen_k=N_CTX_K) + return q, k, v, input_metadata + + +def varlen_input_helper(Z, + HQ, + HK, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + equal_seqlens=False): + current_platform.seed_everything(0) + + # Random sequence lengths. Using N_CTX as kind of max of sum of individual + # seqs + if not equal_seqlens: + max_seqlens_q = N_CTX_Q // Z + max_seqlens_k = N_CTX_K // Z + seqlens_q = torch.randint(1, + max_seqlens_q + 1, (Z, ), + dtype=torch.int32) + seqlens_k = torch.randint(1, + max_seqlens_k + 1, (Z, ), + dtype=torch.int32) + else: + seqlens_q = torch.full((Z, ), N_CTX_Q // Z) + seqlens_k = torch.full((Z, ), N_CTX_K // Z) + + # Calculate cumulative sequence lengths + cu_seqlens_q = torch.cat([ + torch.tensor([0], dtype=torch.int32), + seqlens_q.cumsum(dim=0, dtype=torch.int32) + ]) + cu_seqlens_k = torch.cat([ + torch.tensor([0], dtype=torch.int32), + seqlens_k.cumsum(dim=0, dtype=torch.int32) + ]) + cu_seqlens_q = cu_seqlens_q.to(device="cuda") + cu_seqlens_k = cu_seqlens_k.to(device="cuda") + + # Initialize q, k, v with variable lengths + total_q = cu_seqlens_q[-1].item() + total_k = cu_seqlens_k[-1].item() + q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, + device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, + device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, + device="cuda").normal_(mean=0., std=0.5).requires_grad_() + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + return q, k, v, input_metadata + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ + (1, 48, 12, 1, 1, 64), + (4, 4, 4, 128, 128, 65), + (16, 48, 48, 1, 1, 128), + (64, 48, 24, 3, 3, 128), + (4, 4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_alibi', [True, False]) +@pytest.mark.parametrize('layout', ['bshd']) +def test_op_fwd(Z, + HQ, + HK, + 
N_CTX_Q, + N_CTX_K, + D_HEAD, + causal, + use_alibi, + layout, + dtype=torch.float16): + current_platform.seed_everything(0) + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, + dtype, layout, use_alibi, causal) + + o = torch.empty_like(q) + + # triton implementation + tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata) + + # Transpose here if layout is bshd so we have same reference code for all + # layouts + if layout == 'bshd': + q = q.transpose(1, 2).clone() + k = k.transpose(1, 2).clone() + v = v.transpose(1, 2).clone() + # Replicate K and V if using MQA/GQA + if HQ != HK: + k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], + k.shape[3]).expand(-1, -1, HQ // HK, -1, + -1).reshape(k.shape[0], -1, k.shape[2], + k.shape[3]) + v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], + v.shape[3]).expand(-1, -1, HQ // HK, -1, + -1).reshape(v.shape[0], -1, v.shape[2], + v.shape[3]) + + ref_impl = ReferenceAttention(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, + use_alibi, dtype, input_metadata) + ref_out = ref_impl.fwd(q, k, v) + + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('layout', ['bhsd']) +@pytest.mark.parametrize('use_o_scale', [True, False]) +@pytest.mark.skipif(torch.cuda.get_device_capability() < (9, 0), + reason="Triton FP8 requires CUDA 9.0 or higher") +def test_op_fwd_fp8(Z, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + causal, + layout, + use_o_scale, + dtype=torch.float32): + current_platform.seed_everything(0) + + # Disable grad to save memory it won't run into OOM on CI machine. 
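+    # The unquantized input_helper call below is kept commented out; the quantized inputs used by this test come from the is_fp8=True call that follows.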
+ # q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, + # dtype, layout) + + q_quantized, k_quantized, v_quantized, input_metadata = input_helper( + Z, + H, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + causal=causal, + layout=layout, + is_fp8=True, + use_o_scale=use_o_scale) + + o = torch.empty_like(q_quantized) if use_o_scale else None + + tri_out, _ = triton_attention_rocm(q_quantized, k_quantized, v_quantized, + o, input_metadata) + + ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, + dtype, input_metadata) + ref_out = ref_impl.fwd_fp8(q_quantized, k_quantized, v_quantized) + + # compare + torch.testing.assert_close(ref_out.to(torch.float32), + tri_out.to(torch.float32), + atol=7e-2, + rtol=2e-1) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), + (4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('layout', ['bhsd']) +def test_op_fwd_fp8_kv(Z, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + causal, + layout, + dtype=torch.float32): + current_platform.seed_everything(0) + + q, k_quantized, v_quantized, input_metadata = input_helper(Z, + H, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + causal=causal, + layout=layout, + is_fp8=True, + fp8_kv=True) + + o = torch.empty_like(q) + + tri_out, _ = triton_attention_rocm(q, k_quantized, v_quantized, o, + input_metadata) + + ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, + dtype, input_metadata) + ref_out = ref_impl.fwd_fp8_kv(q, k_quantized, v_quantized) + + torch.testing.assert_close(ref_out, tri_out, atol=3e-2, rtol=8e-1) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_bias', [True]) +@pytest.mark.parametrize('dtype', [torch.bfloat16]) +def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): + current_platform.seed_everything(0) + q, k, v, input_metadata = input_helper(Z, + H, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + layout='bhsd', + causal=causal, + use_bias=use_bias) + o = torch.empty_like(q) + + # triton implementation + tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata) + + ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, + dtype, input_metadata) + ref_out = ref_impl.fwd(q, k, v) + + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +# NOTE: Uses thd layout, so also tests thd. +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(1, 48, 256, 64), + (4, 48, 512, 64), + (16, 48, 512, 64), + (64, 48, 128, 128)]) +@pytest.mark.parametrize('causal', [True, False]) +def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): + + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, + D_HEAD, dtype) + + tri_out = torch.empty_like(q) + triton_attention_rocm(q, k, v, tri_out, input_metadata) + + ref_impl = ReferenceAttention(Z, H, H, N_CTX, N_CTX, D_HEAD, False, dtype, + input_metadata) + ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=False) + + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +# NOTE: Uses thd layout, so also tests thd. 
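+# The MQA/GQA variant below uses HK < HQ; ReferenceAttention.varlen_fwd(is_mqa=True) expands K/V across the query-head groups before computing the reference.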
+@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), + (4, 48, 12, 256, 64), + (4, 48, 4, 512, 64), + (4, 64, 16, 128, 128)]) +@pytest.mark.parametrize('causal', [False]) +def test_op_varlen_mqa_fwd(Z, + HQ, + HK, + N_CTX, + D_HEAD, + causal, + dtype=torch.float16): + q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, + D_HEAD, dtype) + + tri_out = torch.empty_like(q) + triton_attention_rocm(q, k, v, tri_out, input_metadata) + + ref_impl = ReferenceAttention(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, False, + dtype, input_metadata) + ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=True) + + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) diff --git a/tests/kernels/test_utils.py b/tests/kernels/test_utils.py deleted file mode 100644 index d3f032002651..000000000000 --- a/tests/kernels/test_utils.py +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -Tests for miscellaneous utilities -""" - -import pytest -import torch - -from tests.kernels.utils import opcheck -from vllm.platforms import current_platform - - -def test_convert_fp8_opcheck(): - data = torch.randn((256, 256), dtype=torch.float32, device="cuda") - result = torch.empty_like(data, dtype=torch.float8_e4m3fn) - opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8")) - - -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="Only supported for CUDA") -def test_cuda_utils_opcheck(): - opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0)) - opcheck( - torch.ops._C_cuda_utils. - get_max_shared_memory_per_block_device_attribute, (0, )) diff --git a/tests/kernels/utils_block.py b/tests/kernels/utils_block.py deleted file mode 100644 index c16cba50967e..000000000000 --- a/tests/kernels/utils_block.py +++ /dev/null @@ -1,63 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor, - As: torch.Tensor, Bs: torch.Tensor, block_size, - output_dtype): - """This function performs matrix multiplication with block-wise - quantization using native torch. - It is agnostic to the input data type and can be used for both int8 and - fp8 data types. - - It takes two input tensors `A` and `B` (int8) with scales `As` and - `Bs` (float32). - The output is returned in the specified `output_dtype`. 
- """ - A = A.to(torch.float32) - B = B.to(torch.float32) - assert A.shape[-1] == B.shape[-1] - assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 - assert len(block_size) == 2 - block_n, block_k = block_size[0], block_size[1] - assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] - assert A.shape[:-1] == As.shape[:-1] - - M = A.numel() // A.shape[-1] - N, K = B.shape - origin_C_shape = A.shape[:-1] + (N, ) - A = A.reshape(M, A.shape[-1]) - As = As.reshape(M, As.shape[-1]) - n_tiles = (N + block_n - 1) // block_n - k_tiles = (K + block_k - 1) // block_k - assert n_tiles == Bs.shape[0] - assert k_tiles == Bs.shape[1] - - C_shape = (M, N) - C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) - - A_tiles = [ - A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) - ] - B_tiles = [[ - B[ - j * block_n:min((j + 1) * block_n, N), - i * block_k:min((i + 1) * block_k, K), - ] for i in range(k_tiles) - ] for j in range(n_tiles)] - C_tiles = [ - C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) - ] - As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] - - for i in range(k_tiles): - for j in range(n_tiles): - a = A_tiles[i] - b = B_tiles[j][i] - c = C_tiles[j] - s = As_tiles[i] * Bs[j][i] - c[:, :] += torch.matmul(a, b.t()) * s - - C = C.reshape(origin_C_shape).to(output_dtype) - return C diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index cdb8c893b8bc..e3a054bd6206 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -47,6 +47,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: ] sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256, + skip_special_tokens=False, stop=["[/assistant]"]) outputs = llm.generate( prompts, diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 576d95a47154..52b0834cacb8 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -31,6 +31,8 @@ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] if current_platform.is_cuda_alike() else ["cpu"]) +DEFAULT_DTYPE = torch.get_default_dtype() + @pytest.fixture(scope="function", autouse=True) def use_v0_only(monkeypatch: pytest.MonkeyPatch): @@ -125,8 +127,10 @@ def test_replace_submodules(dist_init, dummy_model): model = dummy_model manager = LoRAModelManager( model, 1, 1, 1, - LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8), - torch.device(DEVICES[0])) + LoRAConfig(max_lora_rank=8, + max_cpu_loras=8, + max_loras=8, + lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0])) model = manager.model assert isinstance(model.get_submodule("dense1"), ColumnParallelLinearWithLoRA) @@ -155,7 +159,8 @@ def test_lora_model_manager(dist_init, dummy_model, device): 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, - max_loras=2), + max_loras=2, + lora_dtype=DEFAULT_DTYPE), device=device) assert all(x is None for x in manager.lora_index_to_id) assert manager.add_adapter(model_lora1) @@ -221,7 +226,8 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, - max_loras=2), + max_loras=2, + lora_dtype=DEFAULT_DTYPE), device=device) assert all(x is None for x in manager.lora_index_to_id) assert manager.add_adapter(model_lora1) @@ -316,7 +322,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=2, - max_loras=2), + max_loras=2, + lora_dtype=DEFAULT_DTYPE), device=device) assert all(x is None for x 
in manager.lora_index_to_id) @@ -424,7 +431,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): @pytest.mark.parametrize("device", DEVICES) def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, sql_lora_files, device): - lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) + lora_config = LoRAConfig(max_lora_rank=8, + max_cpu_loras=4, + max_loras=4, + lora_dtype=DEFAULT_DTYPE) worker_adapter_manager = LRUCacheWorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, device, @@ -504,7 +514,10 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, sql_lora_files, device): # Should remove every LoRA not specified in the request. - lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) + lora_config = LoRAConfig(max_lora_rank=8, + max_cpu_loras=4, + max_loras=4, + lora_dtype=DEFAULT_DTYPE) worker_adapter_manager = WorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, device, @@ -600,7 +613,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device): 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=2, - max_loras=2), + max_loras=2, + lora_dtype=DEFAULT_DTYPE), device=device) model = manager.model diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py index 0b223e5011ff..24242b8a1759 100644 --- a/tests/lora/test_minicpmv_tp.py +++ b/tests/lora/test_minicpmv_tp.py @@ -66,8 +66,12 @@ def test_minicpmv_lora(minicpmv_lora_files): max_loras=2, max_lora_rank=8, enforce_eager=True, + max_model_len=2048, + limit_mm_per_prompt={ + "image": 2, + "video": 0 + }, trust_remote_code=True, - enable_chunked_prefill=True, ) output1 = do_sample(llm, minicpmv_lora_files, lora_id=1) for i in range(len(EXPECTED_OUTPUT)): @@ -91,9 +95,11 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): max_loras=4, max_lora_rank=64, tensor_parallel_size=4, + limit_mm_per_prompt={ + "image": 2, + "video": 0 + }, trust_remote_code=True, - enforce_eager=True, - enable_chunked_prefill=True, ) output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) for i in range(len(EXPECTED_OUTPUT)): @@ -115,8 +121,11 @@ def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files): max_lora_rank=8, tensor_parallel_size=4, trust_remote_code=True, + limit_mm_per_prompt={ + "image": 1, + "video": 0 + }, fully_sharded_loras=True, - enable_chunked_prefill=True, ) output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) for i in range(len(EXPECTED_OUTPUT)): diff --git a/tests/lora/test_resolver.py b/tests/lora/test_resolver.py new file mode 100644 index 000000000000..8ebc2ae98fc4 --- /dev/null +++ b/tests/lora/test_resolver.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import pytest + +from vllm.lora.request import LoRARequest +from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry + + +class DummyLoRAResolver(LoRAResolver): + """A dummy LoRA resolver for testing.""" + + async def resolve_lora(self, base_model_name: str, + lora_name: str) -> Optional[LoRARequest]: + if lora_name == "test_lora": + return LoRARequest( + lora_name=lora_name, + lora_path=f"/dummy/path/{base_model_name}/{lora_name}", + lora_int_id=abs(hash(lora_name))) + return None + + +def test_resolver_registry_registration(): + """Test basic resolver registration 
functionality.""" + registry = LoRAResolverRegistry + resolver = DummyLoRAResolver() + + # Register a new resolver + registry.register_resolver("dummy", resolver) + assert "dummy" in registry.get_supported_resolvers() + + # Get registered resolver + retrieved_resolver = registry.get_resolver("dummy") + assert retrieved_resolver is resolver + + +def test_resolver_registry_duplicate_registration(): + """Test registering a resolver with an existing name.""" + registry = LoRAResolverRegistry + resolver1 = DummyLoRAResolver() + resolver2 = DummyLoRAResolver() + + registry.register_resolver("dummy", resolver1) + registry.register_resolver("dummy", resolver2) + + assert registry.get_resolver("dummy") is resolver2 + + +def test_resolver_registry_unknown_resolver(): + """Test getting a non-existent resolver.""" + registry = LoRAResolverRegistry + + with pytest.raises(KeyError, match="not found"): + registry.get_resolver("unknown_resolver") + + +@pytest.mark.asyncio +async def test_dummy_resolver_resolve(): + """Test the dummy resolver's resolve functionality.""" + dummy_resolver = DummyLoRAResolver() + base_model_name = "base_model_test" + lora_name = "test_lora" + + # Test successful resolution + result = await dummy_resolver.resolve_lora(base_model_name, lora_name) + assert isinstance(result, LoRARequest) + assert result.lora_name == lora_name + assert result.lora_path == f"/dummy/path/{base_model_name}/{lora_name}" + + # Test failed resolution + result = await dummy_resolver.resolve_lora(base_model_name, + "nonexistent_lora") + assert result is None diff --git a/tests/lora/test_tokenizer_group.py b/tests/lora/test_tokenizer_group.py index d605ab734688..8845eb33d207 100644 --- a/tests/lora/test_tokenizer_group.py +++ b/tests/lora/test_tokenizer_group.py @@ -5,17 +5,14 @@ from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizer import get_lora_tokenizer -from vllm.transformers_utils.tokenizer_group import get_tokenizer_group - -from ..conftest import get_tokenizer_pool_config +from vllm.transformers_utils.tokenizer_group import TokenizerGroup @pytest.mark.asyncio @pytest.mark.parametrize("tokenizer_group_type", [None, "ray"]) async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type): reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) - tokenizer_group = get_tokenizer_group( - get_tokenizer_pool_config(tokenizer_group_type), + tokenizer_group = TokenizerGroup( tokenizer_id="gpt2", enable_lora=True, max_num_seqs=1, @@ -60,8 +57,7 @@ def test_get_lora_tokenizer(sql_lora_files, tmp_path): @pytest.mark.parametrize("max_num_seqs", [1, 2]) @pytest.mark.parametrize("max_loras", [1, 2]) def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras): - tokenizer_group = get_tokenizer_group( - get_tokenizer_pool_config(None), + tokenizer_group = TokenizerGroup( tokenizer_id="gpt2", enable_lora=enable_lora, max_num_seqs=max_num_seqs, diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py index 34a26e9edf36..67f3866beff5 100644 --- a/tests/lora/test_utils.py +++ b/tests/lora/test_utils.py @@ -9,7 +9,6 @@ from vllm.lora.utils import (get_adapter_absolute_path, parse_fine_tuned_lora_name, replace_submodule) -from vllm.utils import LRUCache def test_parse_fine_tuned_lora_name_valid(): @@ -40,6 +39,18 @@ def test_parse_fine_tuned_lora_name_valid(): False, False, ), + ( + "language_model.layers.9.mlp.down_proj.lora_A.weight", + "language_model.layers.9.mlp.down_proj", + True, + False, + ), + ( + "language_model.layers.9.mlp.down_proj.lora_B.weight", 
+ "language_model.layers.9.mlp.down_proj", + False, + False, + ), } for name, module_name, is_lora_a, is_bias in fixture: assert (module_name, is_lora_a, @@ -85,114 +96,6 @@ def test_replace_submodule(): assert dict(model.named_modules())["seq1.dense2"] == dense2 -class TestLRUCache(LRUCache): - - def _on_remove(self, key, value): - if not hasattr(self, "_remove_counter"): - self._remove_counter = 0 - self._remove_counter += 1 - - -def test_lru_cache(): - cache = TestLRUCache(3) - - cache.put(1, 1) - assert len(cache) == 1 - - cache.put(1, 1) - assert len(cache) == 1 - - cache.put(2, 2) - assert len(cache) == 2 - - cache.put(3, 3) - assert len(cache) == 3 - assert set(cache.cache) == {1, 2, 3} - - cache.put(4, 4) - assert len(cache) == 3 - assert set(cache.cache) == {2, 3, 4} - assert cache._remove_counter == 1 - assert cache.get(2) == 2 - - cache.put(5, 5) - assert set(cache.cache) == {2, 4, 5} - assert cache._remove_counter == 2 - - assert cache.pop(5) == 5 - assert len(cache) == 2 - assert set(cache.cache) == {2, 4} - assert cache._remove_counter == 3 - - cache.pop(10) - assert len(cache) == 2 - assert set(cache.cache) == {2, 4} - assert cache._remove_counter == 3 - - cache.get(10) - assert len(cache) == 2 - assert set(cache.cache) == {2, 4} - assert cache._remove_counter == 3 - - cache.put(6, 6) - assert len(cache) == 3 - assert set(cache.cache) == {2, 4, 6} - assert 2 in cache - assert 4 in cache - assert 6 in cache - - cache.remove_oldest() - assert len(cache) == 2 - assert set(cache.cache) == {2, 6} - assert cache._remove_counter == 4 - - cache.clear() - assert len(cache) == 0 - assert cache._remove_counter == 6 - - cache._remove_counter = 0 - - cache[1] = 1 - assert len(cache) == 1 - - cache[1] = 1 - assert len(cache) == 1 - - cache[2] = 2 - assert len(cache) == 2 - - cache[3] = 3 - assert len(cache) == 3 - assert set(cache.cache) == {1, 2, 3} - - cache[4] = 4 - assert len(cache) == 3 - assert set(cache.cache) == {2, 3, 4} - assert cache._remove_counter == 1 - assert cache[2] == 2 - - cache[5] = 5 - assert set(cache.cache) == {2, 4, 5} - assert cache._remove_counter == 2 - - del cache[5] - assert len(cache) == 2 - assert set(cache.cache) == {2, 4} - assert cache._remove_counter == 3 - - cache.pop(10) - assert len(cache) == 2 - assert set(cache.cache) == {2, 4} - assert cache._remove_counter == 3 - - cache[6] = 6 - assert len(cache) == 3 - assert set(cache.cache) == {2, 4, 6} - assert 2 in cache - assert 4 in cache - assert 6 in cache - - # Unit tests for get_adapter_absolute_path @patch('os.path.isabs') def test_get_adapter_absolute_path_absolute(mock_isabs): diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index ac2e0f3542e7..2d9cf1d48fd5 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -11,6 +11,8 @@ dispatch_fused_experts_func, dispatch_topk_func, torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts, vllm_topk_softmax) +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled) from vllm.model_executor.layers.layernorm import ( RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm, rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm) @@ -100,11 +102,10 @@ def test_enabled_ops_invalid(env: str): def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) topk_func = dispatch_topk_func() - + is_rocm_aiter_moe_enabled.cache_clear() 
if current_platform.is_rocm() and int(use_rocm_aiter): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( rocm_aiter_topk_softmax) - assert topk_func == rocm_aiter_topk_softmax else: assert topk_func == vllm_topk_softmax @@ -116,11 +117,11 @@ def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) + is_rocm_aiter_moe_enabled.cache_clear() fused_experts_func = dispatch_fused_experts_func(inplace) if current_platform.is_rocm() and int(use_rocm_aiter): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( rocm_aiter_fused_experts) - assert fused_experts_func == rocm_aiter_fused_experts elif inplace: assert fused_experts_func == torch_vllm_inplace_fused_experts diff --git a/tests/models/decoder_only/audio_language/test_granite_speech.py b/tests/models/decoder_only/audio_language/test_granite_speech.py new file mode 100644 index 000000000000..7c14845ec54d --- /dev/null +++ b/tests/models/decoder_only/audio_language/test_granite_speech.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Sequence +from typing import Optional + +import pytest +from transformers import AutoModelForSpeechSeq2Seq + +from vllm.lora.request import LoRARequest +from vllm.sequence import SampleLogprobs + +from ....conftest import HfRunner, PromptAudioInput, VllmRunner, _AudioAssets +from ...registry import HF_EXAMPLE_MODELS +from ...utils import check_logprobs_close + +HF_AUDIO_PROMPT = "<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|><|audio|>can you transcribe the speech into a written format?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" # noqa: E501 + + +def vllm_to_hf_output( + vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], +) -> tuple[list[int], str, Optional[SampleLogprobs]]: + """Sanitize hf output to be comparable with vllm output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "<|end_of_text|>" + + return output_ids, hf_output_str, out_logprobs + + +MODEL_NAME = "ibm-granite/granite-speech-3.3-8b" +# Audio lora co-exists directly in the model directory, but +# currently still needs to be passed directly to vLLM. +audio_lora_path = MODEL_NAME +models = [MODEL_NAME] + + +def run_test( + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + inputs: Sequence[tuple[list[str], PromptAudioInput]], + model: str, + *, + max_model_len: int, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm. + + All the audio fixtures for the test are from AUDIO_ASSETS. + For huggingface runner, we provide the audio as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding MultiModalConfig as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). 
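The ordering note above (run the vLLM runner before the HF runner) exists because CUDA contexts are not fork-safe: once the parent process has initialized CUDA, a forked child cannot reuse that context and PyTorch raises a re-initialization error. A minimal, standalone illustration of that failure mode on a Linux CUDA machine (not part of the test):

    import multiprocessing as mp

    import torch

    def child() -> None:
        # Fails with "Cannot re-initialize CUDA in forked subprocess" because
        # the parent created a CUDA context before forking.
        torch.zeros(1, device="cuda")

    if __name__ == "__main__":
        torch.zeros(1, device="cuda")  # parent initializes CUDA first
        proc = mp.get_context("fork").Process(target=child)
        proc.start()
        proc.join()  # child exits with the CUDA re-initialization error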
+ # max_model_len should be greater than image_feature_size + with vllm_runner( + model, + task="generate", + max_model_len=max_model_len, + max_num_seqs=1, + dtype=dtype, + limit_mm_per_prompt={"audio": 1}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enable_lora=True, + max_lora_rank=64, + enforce_eager=True, + ) as vllm_model: + lora_request = LoRARequest("audio", 1, audio_lora_path) + vllm_outputs_per_case = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + audios=audios, + lora_request=lora_request) + for prompts, audios in inputs + ] + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSpeechSeq2Seq) as hf_model: + + hf_processor = hf_model.processor + eos_token_id = hf_processor.tokenizer.eos_token_id + + hf_outputs_per_case = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + audios=[audios], + eos_token_id=eos_token_id) + for prompts, audios in inputs + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, + vllm_outputs_per_case): + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(output) for output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_model_len", [2048]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [10]) +def test_models(hf_runner, vllm_runner, model: str, audio_assets: _AudioAssets, + dtype: str, max_model_len: int, max_tokens: int, + num_logprobs: int) -> None: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + + audio, sr = audio_assets[0].audio_and_sample_rate + # This model expects 16k sample rate, which our test audio + # already is; if this changes, it may break this test, + # so we check it directly + assert sr == 16000 + run_test( + hf_runner, + vllm_runner, + [ + ([HF_AUDIO_PROMPT], [audio]), + ], + model, + dtype=dtype, + max_model_len=max_model_len, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index a843e41aa26e..1d7de946a3f8 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -1,16 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional +import json +from typing import Any, Optional import numpy as np import pytest import pytest_asyncio from transformers import AutoModel, AutoTokenizer -from vllm.multimodal.audio import resample_audio +from vllm.multimodal.audio import resample_audio_librosa from vllm.sequence import SampleLogprobs -from ....conftest import HfRunner, VllmRunner +from ....conftest import HfRunner, VllmRunner, _AudioAssets from ....utils import RemoteOpenAIServer from ...registry import HF_EXAMPLE_MODELS from ...utils import check_logprobs_close @@ -30,31 +31,34 @@ } -@pytest.fixture(scope="session") -def audio_assets(): - from vllm.assets.audio import AudioAsset - return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] - - @pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call")) def audio(request): from vllm.assets.audio import AudioAsset return 
AudioAsset(request.param) +def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]: + """Convert kwargs to CLI args.""" + args = [] + for key, value in params_kwargs.items(): + if isinstance(value, bool): + if value: + args.append(f"--{key.replace('_','-')}") + else: + args.append(f"--{key.replace('_','-')}={value}") + return args + + @pytest.fixture(params=[ pytest.param({}, marks=pytest.mark.cpu_model), pytest.param(CHUNKED_PREFILL_KWARGS), ]) -def server(request, audio_assets): +def server(request, audio_assets: _AudioAssets): args = [ - "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager", - f"--limit-mm-per-prompt=audio={len(audio_assets)}", - "--trust-remote-code" - ] + [ - f"--{key.replace('_','-')}={value}" - for key, value in request.param.items() - ] + "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager", + "--limit-mm-per-prompt", + json.dumps({"audio": len(audio_assets)}), "--trust-remote-code" + ] + params_kwargs_to_cli_args(request.param) with RemoteOpenAIServer(MODEL_NAME, args, @@ -135,9 +139,9 @@ def run_test( [hf_prompt], max_tokens, num_logprobs=num_logprobs, - audios=[(resample_audio(audio[0], - orig_sr=audio[1], - target_sr=16000), 16000)]) + audios=[(resample_audio_librosa(audio[0], + orig_sr=audio[1], + target_sr=16000), 16000)]) for _, hf_prompt, audio in prompts_and_audios ] @@ -220,8 +224,9 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, pytest.param({}, marks=pytest.mark.cpu_model), pytest.param(CHUNKED_PREFILL_KWARGS), ]) -def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, - max_tokens: int, num_logprobs: int, +def test_models_with_multiple_audios(vllm_runner, audio_assets: _AudioAssets, + dtype: str, max_tokens: int, + num_logprobs: int, vllm_kwargs: dict) -> None: vllm_prompt = _get_prompt(len(audio_assets), @@ -240,7 +245,7 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, @pytest.mark.asyncio -async def test_online_serving(client, audio_assets): +async def test_online_serving(client, audio_assets: _AudioAssets): """Exercises online serving with/without chunked prefill enabled.""" messages = [{ diff --git a/tests/models/decoder_only/language/test_hybrid.py b/tests/models/decoder_only/language/test_hybrid.py index 60eb3830c6d8..5931c25b8d80 100644 --- a/tests/models/decoder_only/language/test_hybrid.py +++ b/tests/models/decoder_only/language/test_hybrid.py @@ -6,75 +6,84 @@ from vllm.engine.arg_utils import EngineArgs from vllm.sampling_params import SamplingParams -from ...utils import check_outputs_equal - -# This test is for the hybrid models -MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"] -# Bamba at Fp32 is too big for the CI (L4 GPU). -# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) +from ...utils import check_logprobs_close, check_outputs_equal + +# NOTE: The first model in each list is taken as the primary model, +# meaning that it will be used in all tests in this file +# The rest of the models will only be tested by test_models + +SSM_MODELS = [ + "state-spaces/mamba-130m-hf", + "tiiuae/falcon-mamba-tiny-dev", + # TODO: Compare to a Mamba2 model. The HF transformers implementation of + # Mamba2 is buggy for Codestral as it doesn't handle n_groups. 
+ # See https://github.com/huggingface/transformers/pull/35943 + # "mistralai/Mamba-Codestral-7B-v0.1", +] + +HYBRID_MODELS = [ + "ai21labs/Jamba-tiny-dev", + # NOTE: Running Plamo2 in transformers implementation requires to install + # causal-conv1d package, which is not listed as a test dependency as it's + # not compatible with pip-compile. + "pfnet/plamo-2-1b", + "Zyphra/Zamba2-1.2B-instruct", + "ibm-ai-platform/Bamba-9B", +] + +# Avoid OOM +MAX_NUM_SEQS = 4 + + +@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) def test_models( hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, + num_logprobs: int, ) -> None: + with hf_runner(model) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) - # numeric error produces different generation - if "Bamba" in model: - example_prompts.pop(3) - - model_kwargs = { - "use_mamba_kernels": False, # mamba kernels are not installed so HF - # don't use them - } - if "Zamba2" in model: - # Zamba2 HF implementation automatically checks if mamba kernels are - # installed - model_kwargs = {} + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) - with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - for i in range(len(example_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = vllm_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) def test_batching( vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, + num_logprobs: int, ) -> None: - # To pass the small model tests, we need full precision. 
for_loop_outputs = [] - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: for prompt in example_prompts: - for_loop_outputs.append( - vllm_model.generate_greedy([prompt], max_tokens)[0]) + single_output, = vllm_model.generate_greedy_logprobs([prompt], + max_tokens, + num_logprobs) + for_loop_outputs.append(single_output) - batched_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) + batched_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) - check_outputs_equal( + check_logprobs_close( outputs_0_lst=for_loop_outputs, outputs_1_lst=batched_outputs, name_0="for_loop_vllm", @@ -82,74 +91,35 @@ def test_batching( ) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float16"]) -@pytest.mark.parametrize("max_tokens", [10]) -def test_mamba_prefill_chunking_with_parallel_sampling( - hf_runner, vllm_runner, example_prompts, model: str, dtype: str, - max_tokens: int) -> None: - # Tests prefill chunking in conjunction with n>1, in this case, - # prefill is populated with decoding tokens and we test that it - # doesn't fail This test might fail if cache is not allocated - # correctly for n > 1 decoding steps inside a - # chunked prefill forward pass (where we have both prefills - # and decoding together ) - sampling_params = SamplingParams(n=3, - temperature=1, - seed=0, - max_tokens=max_tokens) - with vllm_runner( - model, - dtype=dtype, - enable_chunked_prefill=True, - max_num_batched_tokens=30, - max_num_seqs=10 # forces prefill chunks with decoding - ) as vllm_model: - vllm_model.generate(example_prompts, sampling_params) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [7]) -def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, - model: str, dtype: str, - max_tokens: int) -> None: - # numeric error during prefill chunking produces different generation - # compared to w/o prefill chunking for those examples, removed them for now - if "Jamba" in model: - example_prompts.pop(7) - example_prompts.pop(2) - example_prompts.pop(1) - elif "Bamba" in model: - example_prompts.pop(6) - example_prompts.pop(3) - example_prompts.pop(2) - dtype = "half" # use a different dtype for Bamba - elif "Zamba2" in model: - example_prompts.pop(7) - dtype = "half" - - model_kwargs = { - "use_mamba_kernels": False, # mamba kernels are not installed so HF - # don't use them - } - if "Zamba2" in model: - # Zamba2 HF implementation automatically checks if mamba kernels are - # installed - model_kwargs = {} - - with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model: - non_chunked = hf_model.generate_greedy(example_prompts, max_tokens) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +def test_chunked_prefill( + vllm_runner, + example_prompts, + model: str, + max_tokens: int, + num_logprobs: int, + chunked_prefill_token_size: int, +) -> None: + max_num_seqs = chunked_prefill_token_size + max_num_batched_tokens = chunked_prefill_token_size with vllm_runner(model, - dtype=dtype, enable_chunked_prefill=True, - max_num_batched_tokens=5, - max_num_seqs=2) as vllm_model: - chunked = vllm_model.generate_greedy(example_prompts, - max_tokens=max_tokens) + 
max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs) as vllm_model: + chunked = vllm_model.generate_greedy_logprobs(example_prompts, + max_tokens, num_logprobs) - check_outputs_equal( + with vllm_runner(model, + enable_chunked_prefill=False, + max_num_seqs=max_num_seqs) as vllm_model: + non_chunked = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( outputs_0_lst=chunked, outputs_1_lst=non_chunked, name_0="chunked", @@ -157,64 +127,59 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, ) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [15]) -def test_parallel_sampling( +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) +@pytest.mark.parametrize("max_tokens", [10]) +def test_chunked_prefill_with_parallel_sampling( vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, ) -> None: - - with vllm_runner(model, dtype=dtype) as vllm_model: - for_loop_outputs = [] - for _ in range(10): - for_loop_outputs.append( - # using example_prompts index 1 instead of 0 since with 0 the - # logprobs get really close and the test doesn't pass - vllm_model.generate_greedy([example_prompts[1]], max_tokens) - [0]) - sampling_params = SamplingParams(n=10, - temperature=0.001, - seed=0, - max_tokens=max_tokens) - n_lt_1_outputs = vllm_model.generate([example_prompts[1]], - sampling_params) - token_ids, texts = n_lt_1_outputs[0] - n_lt_1_outputs = [(token_id, text) - for token_id, text in zip(token_ids, texts)] - - check_outputs_equal( - outputs_0_lst=n_lt_1_outputs, - outputs_1_lst=for_loop_outputs, - name_0="vllm_n_lt_1_outputs", - name_1="vllm", - ) + """ + Tests chunked prefill in conjunction with n > 1. + + In this case, prefill is populated with decoding tokens and + we test that it doesn't fail. + + This test might fail if cache is not allocated correctly for n > 1 + decoding steps inside a chunked prefill forward pass + (where we have both prefill and decode together) + """ + sampling_params = SamplingParams(n=3, + temperature=1, + seed=0, + max_tokens=max_tokens) + with vllm_runner( + model, + enable_chunked_prefill=True, + # forces prefill chunks with decoding + max_num_batched_tokens=MAX_NUM_SEQS * 3, + max_num_seqs=MAX_NUM_SEQS, + ) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) -@pytest.mark.skip(reason="RE-ENABLE: test is currently failing on main.") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [20]) def test_mamba_cache_cg_padding( vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, ) -> None: - # This test is for verifying that mamba cache is padded to CG captured - # batch size. If it's not, a torch RuntimeError will be raised because - # tensor dimensions aren't compatible - vllm_config = EngineArgs(model=model).create_engine_config() + """ + This test is for verifying that mamba cache is padded to CG captured + batch size. If it's not, a torch RuntimeError will be raised because + tensor dimensions aren't compatible. 
+ """ + vllm_config = EngineArgs(model=model, + trust_remote_code=True).create_engine_config() while len(example_prompts) == vllm_config.pad_for_cudagraph( len(example_prompts)): example_prompts.append(example_prompts[0]) try: - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) except RuntimeError: pytest.fail( @@ -223,28 +188,24 @@ def test_mamba_cache_cg_padding( "Could be related to mamba cache not padded correctly") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [20]) def test_models_preemption_recompute( - hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, ) -> None: - # Tests that outputs are identical with and w/o preemtions (recompute) - assert dtype == "float" - - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_model.model.llm_engine.scheduler[ - 0].ENABLE_ARTIFICIAL_PREEMPT = True + """ + Tests that outputs are identical with and w/o preemptions (recompute). + """ + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + scheduler = vllm_model.model.llm_engine.scheduler[0] + scheduler.ENABLE_ARTIFICIAL_PREEMPT = True preempt_vllm_outputs = vllm_model.generate_greedy( example_prompts, max_tokens) - vllm_model.model.llm_engine.scheduler[ - 0].ENABLE_ARTIFICIAL_PREEMPT = False + scheduler.ENABLE_ARTIFICIAL_PREEMPT = False vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( @@ -255,40 +216,43 @@ def test_models_preemption_recompute( ) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( vllm_runner, - model: str, - dtype: str, example_prompts, + model: str, ) -> None: - # This test is for verifying that the hybrid inner state management doesn't - # collapse in case where the number of incoming requests and - # finished_requests_ids is larger than the maximum mamba block capacity. - # This could generally happen due to the fact that hybrid does support - # statelessness mechanism where it can cleanup new incoming requests in - # a single step. + """ + This test is for verifying that the hybrid inner state management doesn't + collapse in case where the number of incoming requests and + finished_requests_ids is larger than the maximum mamba block capacity. + + This could generally happen due to the fact that hybrid does support + statelessness mechanism where it can cleanup new incoming requests in + a single step. + """ try: - with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: vllm_model.generate_greedy([example_prompts[0]] * 100, 10) except ValueError: pytest.fail("Hybrid inner state wasn't cleaned up properly between" "steps finished requests registered unnecessarily ") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) def test_state_cleanup( vllm_runner, - model: str, - dtype: str, example_prompts, + model: str, ) -> None: - # This test is for verifying that the Hybrid state is cleaned up between - # steps, If its not cleaned, an error would be expected. 
+ """ + This test is for verifying that the Hybrid state is cleaned up between + steps. + + If its not cleaned, an error would be expected. + """ try: - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: for _ in range(10): vllm_model.generate_greedy([example_prompts[0]] * 100, 1) except ValueError: @@ -296,28 +260,14 @@ def test_state_cleanup( "could be related to finished_requests_ids") -@pytest.mark.skip(reason="RE-ENABLE: test is currently failing on main.") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_multistep( +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) +@pytest.mark.parametrize("max_tokens", [64]) +def test_multistep_correctness( vllm_runner, - model: str, - dtype: str, example_prompts, + model: str, + max_tokens: int, ) -> None: - # This test is verifying that multistep works correctly - #on mamba-like models - with vllm_runner(model, num_scheduler_steps=8, - max_num_seqs=2) as vllm_model: - vllm_model.generate_greedy([example_prompts[0]] * 10, 1) - - -@pytest.mark.skip(reason="RE-ENABLE: test is currently failing on main.") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [64]) -def test_multistep_correctness(vllm_runner, model: str, dtype: str, - max_tokens: int, example_prompts) -> None: with vllm_runner(model, num_scheduler_steps=8, max_num_seqs=2) as vllm_model: vllm_outputs_multistep = vllm_model.generate_greedy( @@ -337,18 +287,21 @@ def test_multistep_correctness(vllm_runner, model: str, dtype: str, @multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [64]) def test_hybrid_distributed_produces_identical_generation( - vllm_runner, model: str, dtype: str, max_tokens: int, - example_prompts) -> None: - - with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model: + vllm_runner, + example_prompts, + model: str, + max_tokens: int, +) -> None: + with vllm_runner(model, tensor_parallel_size=2, + max_num_seqs=2) as vllm_model: vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts, max_tokens) - with vllm_runner(model, dtype=dtype, tensor_parallel_size=1) as vllm_model: + with vllm_runner(model, tensor_parallel_size=1, + max_num_seqs=2) as vllm_model: vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py deleted file mode 100644 index 47b9c0f69c36..000000000000 --- a/tests/models/decoder_only/language/test_mamba.py +++ /dev/null @@ -1,337 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba. - -Run `pytest tests/models/test_mamba.py`. -""" -import pytest -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer - -from vllm.engine.arg_utils import EngineArgs -from vllm.sampling_params import SamplingParams - -from ...utils import check_outputs_equal - -MODELS = [ - "state-spaces/mamba-130m-hf", - "tiiuae/falcon-mamba-tiny-dev", - # TODO: Compare to a Mamba2 model. The HF transformers implementation of - # Mamba2 is buggy for Codestral as it doesn't handle n_groups. 
- # See https://github.com/huggingface/transformers/pull/35943 - # "mistralai/Mamba-Codestral-7B-v0.1", -] - - -# Use lower-level interfaces to create this greedy generator, as mamba will -# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used. -def generate_greedy(model_name, example_prompts, max_tokens): - # Create a text generation pipeline - tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained(model_name) - - # Set the device (GPU if available, else CPU) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - - # Generate texts from the prompts - outputs = [] - for prompt in example_prompts: - # Tokenize the input prompt with truncation - inputs = tokenizer(prompt, return_tensors="pt", truncation=True) - input_ids = inputs["input_ids"].to(model.device) - - # Generate text using the model's generate method directly - generated_ids = model.generate(input_ids, - max_new_tokens=max_tokens, - do_sample=False) - generated_text = tokenizer.decode(generated_ids[0], - skip_special_tokens=True) - - outputs.append((generated_ids[0].tolist(), generated_text)) - - return outputs - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -def test_models( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - hf_outputs = generate_greedy(model, example_prompts, max_tokens) - - # Set max_num_seqs to keep Codestral from going OOM at fp32 - with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - for i in range(len(example_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = vllm_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -def test_batching( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - # To pass the small model tests, we need full precision. - for_loop_outputs = [] - with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model: - for prompt in example_prompts: - for_loop_outputs.append( - vllm_model.generate_greedy([prompt], max_tokens)[0]) - - batched_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) - - check_outputs_equal( - outputs_0_lst=for_loop_outputs, - outputs_1_lst=batched_outputs, - name_0="for_loop_vllm", - name_1="batched_vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [10]) -def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts, - model: str, dtype: str, - max_tokens: int) -> None: - # Tests chunked prefill in conjunction with n>1. In this case, prefill is - # populated with decoding tokens and we test that it doesn't fail. 
- # This test might fail if cache is not allocated correctly for n > 1 - # decoding steps inside a chunked prefill forward pass (where we have both - # prefill and decode together ) - sampling_params = SamplingParams(n=3, - temperature=1, - seed=0, - max_tokens=max_tokens) - with vllm_runner( - model, - dtype=dtype, - enable_chunked_prefill=True, - max_num_batched_tokens=30, - max_num_seqs=10 # forces prefill chunks with decoding - ) as vllm_model: - vllm_model.generate(example_prompts, sampling_params) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) -def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str, - max_tokens: int, - chunked_prefill_token_size: int) -> None: - """ - Checks exact match decode between huggingface model and vllm runner with - chunked prefill. - """ - max_num_seqs = chunked_prefill_token_size - max_num_batched_tokens = chunked_prefill_token_size - - non_chunked = generate_greedy(model, example_prompts, max_tokens) - - with vllm_runner(model, - dtype=dtype, - enable_chunked_prefill=True, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs) as vllm_model: - chunked = vllm_model.generate_greedy(example_prompts, - max_tokens=max_tokens) - - check_outputs_equal( - outputs_0_lst=chunked, - outputs_1_lst=non_chunked, - name_0="chunked", - name_1="non_chunked", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [15]) -def test_parallel_sampling( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - - # Numerical differences produce slightly different output for these - if 'state-spaces' in model: - example_prompts.pop(0) - example_prompts.pop(0) - example_prompts.pop(0) - - with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model: - for_loop_outputs = [] - for _ in range(10): - for_loop_outputs.append( - vllm_model.generate_greedy(example_prompts, max_tokens)[0]) - sampling_params = SamplingParams(n=10, - temperature=0.001, - seed=0, - max_tokens=max_tokens) - n_lt_1_outputs = vllm_model.generate(example_prompts, sampling_params) - token_ids, texts = n_lt_1_outputs[0] - n_lt_1_outputs = [(token_id, text) - for token_id, text in zip(token_ids, texts)] - - check_outputs_equal( - outputs_0_lst=n_lt_1_outputs, - outputs_1_lst=for_loop_outputs, - name_0="vllm_n_lt_1_outputs", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [20]) -def test_mamba_cache_cg_padding( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - # This test is for verifying that mamba cache is padded to CG captured - # batch size. If it's not, a torch RuntimeError will be raised because - # tensor dimensions aren't compatible - vllm_config = EngineArgs(model=model).create_engine_config() - while len(example_prompts) == vllm_config.pad_for_cudagraph( - len(example_prompts)): - example_prompts.append(example_prompts[0]) - - try: - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_model.generate_greedy(example_prompts, max_tokens) - except RuntimeError: - pytest.fail( - "Couldn't run batch size which is not equal to a Cuda Graph " - "captured batch size. 
" - "Could be related to mamba cache not padded correctly") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [20]) -def test_models_preemption_recompute( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - # Tests that outputs are identical with and w/o preemtions (recompute) - assert dtype == "float" - - with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model: - vllm_model.model.llm_engine.scheduler[ - 0].ENABLE_ARTIFICIAL_PREEMPT = True - preempt_vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) - - vllm_model.model.llm_engine.scheduler[ - 0].ENABLE_ARTIFICIAL_PREEMPT = False - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=preempt_vllm_outputs, - outputs_1_lst=vllm_outputs, - name_0="vllm_preepmtions", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( - vllm_runner, - model: str, - dtype: str, - example_prompts, -) -> None: - # This test is for verifying that the Mamba inner state management doesn't - # collapse in case where the number of incoming requests and - # finished_requests_ids is larger than the maximum Mamba block capacity. - # This could generally happen due to the fact that Mamba does support - # statelessness mechanism where it can cleanup new incoming requests in - # a single step. - try: - with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: - vllm_model.generate_greedy([example_prompts[0]] * 100, 10) - except ValueError: - pytest.fail("Mamba inner state wasn't cleaned up properly between" - "steps finished requests registered unnecessarily ") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_state_cleanup( - vllm_runner, - model: str, - dtype: str, - example_prompts, -) -> None: - # This test is for verifying that the Mamba state is cleaned up between - # steps, If its not cleaned, an error would be expected. 
- try: - with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model: - for _ in range(10): - vllm_model.generate_greedy([example_prompts[0]] * 100, 1) - except ValueError: - pytest.fail("Mamba inner state wasn't cleaned up between states, " - "could be related to finished_requests_ids") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -def test_multistep( - vllm_runner, - model: str, - dtype: str, - example_prompts, -) -> None: - with vllm_runner(model, num_scheduler_steps=8, - max_num_seqs=2) as vllm_model: - vllm_model.generate_greedy([example_prompts[0]] * 10, 1) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [64]) -def test_multistep_correctness(vllm_runner, model: str, dtype: str, - max_tokens: int, example_prompts) -> None: - with vllm_runner(model, num_scheduler_steps=8, - max_num_seqs=2) as vllm_model: - vllm_outputs_multistep = vllm_model.generate_greedy( - example_prompts, max_tokens) - - with vllm_runner(model, num_scheduler_steps=1, - max_num_seqs=2) as vllm_model: - vllm_outputs_single_step = vllm_model.generate_greedy( - example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=vllm_outputs_multistep, - outputs_1_lst=vllm_outputs_single_step, - name_0="vllm_outputs_multistep", - name_1="vllm_outputs_single_step", - ) diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index ec885386dd94..79778072cc8b 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -10,8 +10,8 @@ import jsonschema.exceptions import pytest -from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa - MistralToolParser) +from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( + MistralToolCall, MistralToolParser) from vllm.sampling_params import GuidedDecodingParams, SamplingParams from ...utils import check_logprobs_close @@ -194,7 +194,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, ) -@pytest.mark.skip("RE-ENABLE: test is currently failing on main.") @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @@ -246,10 +245,8 @@ def test_mistral_symbolic_languages(vllm_runner, model: str, assert "�" not in outputs[0].outputs[0].text.strip() -@pytest.mark.skip("RE-ENABLE: test is currently failing on main.") +@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("model", - MISTRAL_FORMAT_MODELS) # v1 can't do func calling def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: with vllm_runner(model, dtype=dtype, @@ -270,7 +267,8 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: parsed_message = tool_parser.extract_tool_calls(model_output, None) assert parsed_message.tools_called - assert parsed_message.tool_calls[0].id == "0UAqFzWsD" + + assert MistralToolCall.is_valid_id(parsed_message.tool_calls[0].id) assert parsed_message.tool_calls[ 0].function.name == "get_current_weather" assert parsed_message.tool_calls[ @@ -281,28 +279,38 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("guided_backend", ["outlines", "lm-format-enforcer", "xgrammar"]) 
-def test_mistral_guided_decoding(vllm_runner, model: str, - guided_backend: str) -> None: - with vllm_runner(model, dtype='bfloat16', - tokenizer_mode="mistral") as vllm_model: +def test_mistral_guided_decoding( + monkeypatch: pytest.MonkeyPatch, + vllm_runner, + model: str, + guided_backend: str, +) -> None: + with monkeypatch.context() as m: + # Guided JSON not supported in xgrammar + V1 yet + m.setenv("VLLM_USE_V1", "0") - guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA, - backend=guided_backend) - params = SamplingParams(max_tokens=512, - temperature=0.7, - guided_decoding=guided_decoding) - - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - f"Give an example JSON for an employee profile that " - f"fits this schema: {SAMPLE_JSON_SCHEMA}" - }] - outputs = vllm_model.model.chat(messages, sampling_params=params) + with vllm_runner( + model, + dtype='bfloat16', + tokenizer_mode="mistral", + guided_decoding_backend=guided_backend, + ) as vllm_model: + guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA) + params = SamplingParams(max_tokens=512, + temperature=0.7, + guided_decoding=guided_decoding) + + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": + "user", + "content": + f"Give an example JSON for an employee profile that " + f"fits this schema: {SAMPLE_JSON_SCHEMA}" + }] + outputs = vllm_model.model.chat(messages, sampling_params=params) generated_text = outputs[0].outputs[0].text json_response = json.loads(generated_text) diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index 79fa3fa99773..d35d87459cd9 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -9,6 +9,8 @@ from vllm.platforms import current_platform +from ....utils import large_gpu_mark +from ...registry import HF_EXAMPLE_MODELS from ...utils import check_logprobs_close # These have unsupported head_dim for FA. 
We do not @@ -25,7 +27,7 @@ AITER_MODEL_LIST = [ "meta-llama/Llama-3.2-1B-Instruct", "openbmb/MiniCPM3-4B", - "Qwen/Qwen-7B", + "Qwen/Qwen-7B-Chat", "Qwen/Qwen2.5-0.5B-Instruct", "ehristoforu/Falcon3-MoE-2x7B-Insruct", ] @@ -60,7 +62,8 @@ pytest.param( "openbmb/MiniCPM3-4B", # fused_moe not supported on CPU - marks=[pytest.mark.core_model], + marks=[pytest.mark.core_model, + large_gpu_mark(min_gb=32)], ), pytest.param( "facebook/opt-125m", # opt @@ -71,7 +74,7 @@ marks=[pytest.mark.core_model], ), pytest.param( - "Qwen/Qwen-7B", # qwen (text-only) + "Qwen/Qwen-7B-Chat", # qwen (text-only) ), pytest.param( "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 @@ -81,17 +84,21 @@ pytest.param("bigcode/starcoder2-3b"), # starcoder2 pytest.param( "ehristoforu/Falcon3-MoE-2x7B-Insruct", # mixtral - marks=[pytest.mark.cpu_model], + marks=[pytest.mark.cpu_model, + large_gpu_mark(min_gb=48)], ) ]) -@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize( "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) def test_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, num_logprobs: int, - use_rocm_aiter: bool, monkeypatch) -> None: + max_tokens: int, num_logprobs: int, use_rocm_aiter: bool, + monkeypatch) -> None: + + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") if model in REQUIRES_V0: monkeypatch.setenv("VLLM_USE_V1", "0") @@ -105,15 +112,17 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, # in parts of the operators pytest.skip(f"Skipping '{model}' model test with AITER kernel.") - with hf_runner(model, dtype=dtype) as hf_model: - if model.startswith("THUDM/chatglm3"): - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.transformer.output_layer - + with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner( + model, + tokenizer_name=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + max_num_seqs=2, + ) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 5bd10544d81b..9985cb579e10 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -139,6 +139,23 @@ image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), + "qwen2_5_omni": VLMTestInfo( + models=["Qwen/Qwen2.5-Omni-7B"], + test_type=( + VLMTestType.IMAGE, + VLMTestType.MULTI_IMAGE, + VLMTestType.VIDEO + ), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_bos|><|IMAGE|><|vision_eos|>", # noqa: E501 + video_idx_to_prompt=lambda idx: "<|vision_bos|><|VIDEO|><|vision_eos|>", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModelForVision2Seq, + vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, + image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 
0.2, 0.15)], + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), #### Extended model tests "aria": VLMTestInfo( models=["rhymes-ai/Aria"], @@ -318,6 +335,18 @@ use_tokenizer_eos=True, patch_hf_runner=model_utils.internvl_patch_hf_runner, ), + "kimi_vl": VLMTestInfo( + models=["moonshotai/Kimi-VL-A3B-Instruct"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|im_user|>user<|im_middle|>{img_prompt}<|im_end|><|im_assistant|>assistant<|im_middle|>", # noqa: E501 + img_idx_to_prompt=lambda _: "<|media_start|>image<|media_content|><|media_pad|><|media_end|>", # noqa: E501 + max_model_len=8192, + max_num_seqs=2, + dtype="bfloat16", + tensor_parallel_size=1, + vllm_output_post_proc=model_utils.kimiv_vl_vllm_to_hf_output, + marks=[large_gpu_mark(min_gb=48)], + ), "llama4": VLMTestInfo( models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"], prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501 diff --git a/tests/models/decoder_only/vision_language/test_phi4mm.py b/tests/models/decoder_only/vision_language/test_phi4mm.py index 3cd830015076..11460a1a8d2b 100644 --- a/tests/models/decoder_only/vision_language/test_phi4mm.py +++ b/tests/models/decoder_only/vision_language/test_phi4mm.py @@ -181,7 +181,7 @@ def patch_hf_processor(*args, ], ) @pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_model_len", [4096]) +@pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, @@ -225,7 +225,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, ], ) @pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_model_len", [10000]) +@pytest.mark.parametrize("max_model_len", [25600]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, @@ -258,7 +258,7 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, @pytest.mark.parametrize("model", models) @pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_model_len", [10000]) +@pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str, diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 3520345c9679..49305332726e 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -68,6 +68,17 @@ def qwen2_vllm_to_hf_output( return output_ids, hf_output_str, out_logprobs +def kimiv_vl_vllm_to_hf_output( + vllm_output: RunnerOutput, + model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: + """Sanitize vllm output [kimi_vl models] to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "<|im_end|>[EOS]" + + return output_ids, hf_output_str, out_logprobs + + def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: config = AutoConfig.from_pretrained(model) 
diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index d6bf7d270639..87a1dde9381f 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -57,24 +57,25 @@ def test_find_array(monkeypatch: pytest.MonkeyPatch): def server_embedding(): # GritLM embedding implementation is only supported by XFormers backend. args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server + with pytest.MonkeyPatch.context() as m: + m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server @pytest.fixture(scope="module") def server_generate(): args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server + with pytest.MonkeyPatch.context() as m: + m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server @pytest_asyncio.fixture -async def client_embedding(monkeypatch: pytest.MonkeyPatch, - server_embedding: RemoteOpenAIServer): - with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS") - async with server_embedding.get_async_client() as async_client: - yield async_client +async def client_embedding(server_embedding: RemoteOpenAIServer): + async with server_embedding.get_async_client() as async_client: + yield async_client @pytest_asyncio.fixture diff --git a/tests/models/embedding/language/test_jina.py b/tests/models/embedding/language/test_jina.py index 881d0a75b158..1e234368f3b3 100644 --- a/tests/models/embedding/language/test_jina.py +++ b/tests/models/embedding/language/test_jina.py @@ -153,14 +153,24 @@ def test_matryoshka( with vllm_runner(model, task="embed", dtype=dtype, max_model_len=None) as vllm_model: - vllm_outputs = vllm_model.encode( - example_prompts, - pooling_params=PoolingParams(dimensions=dimensions)) - - check_embeddings_close( - embeddings_0_lst=hf_outputs, - embeddings_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - tol=1e-2, - ) + matryoshka_dimensions = ( + vllm_model.model.llm_engine.model_config.matryoshka_dimensions) + assert matryoshka_dimensions is not None + + if dimensions not in matryoshka_dimensions: + with pytest.raises(ValueError): + vllm_model.encode( + example_prompts, + pooling_params=PoolingParams(dimensions=dimensions)) + else: + vllm_outputs = vllm_model.encode( + example_prompts, + pooling_params=PoolingParams(dimensions=dimensions)) + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/tests/models/embedding/language/test_snowflake_arctic_embed.py b/tests/models/embedding/language/test_snowflake_arctic_embed.py new file mode 100644 index 000000000000..2b884fceec80 --- /dev/null +++ b/tests/models/embedding/language/test_snowflake_arctic_embed.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Compare the embedding outputs of HF and vLLM models. + +Run `pytest tests/models/embedding/language/test_snowflake_arctic_embed.py`. +""" +import pytest + +from tests.models.embedding.utils import EmbedModelInfo + +from ..utils import check_embeddings_close + +EMBEDDING_PROMPTS = [ + 'what is snowflake?', 'Where can I get the best tacos?', 'The Data Cloud!', + 'Mexico City of Course!' 
+] + +MODELS = [ + EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", + is_matryoshka=False, + architecture="BertModel", + enable_test=True), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-s", + is_matryoshka=False, + architecture="BertModel", + enable_test=False), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-m", + is_matryoshka=False, + architecture="BertModel", + enable_test=False), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", + is_matryoshka=False, + architecture="NomicBertModel", + enable_test=True), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-l", + is_matryoshka=False, + architecture="BertModel", + enable_test=False), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + architecture="BertModel", + enable_test=True), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0", + is_matryoshka=True, + architecture="XLMRobertaModel", + enable_test=True), + EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0", + is_matryoshka=True, + architecture="GteModel", + enable_test=True), +] + + +@pytest.mark.parametrize("model_info", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model_info: EmbedModelInfo, + dtype: str, + monkeypatch, +) -> None: + if not model_info.enable_test: + # A model family has many models with the same architecture, + # and we don't need to test each one. + pytest.skip("Skipping test.") + + example_prompts = example_prompts + EMBEDDING_PROMPTS + + vllm_extra_kwargs = { + "hf_overrides": { + "is_matryoshka": model_info.is_matryoshka + } + } + + with hf_runner(model_info.name, dtype=dtype, + is_sentence_transformer=True) as hf_model: + hf_outputs = hf_model.encode(example_prompts) + + with vllm_runner(model_info.name, + task="embed", + dtype=dtype, + max_model_len=None, + **vllm_extra_kwargs) as vllm_model: + + assert (vllm_model.model.llm_engine.model_config.is_matryoshka == + model_info.is_matryoshka) + + if model_info.architecture: + assert (model_info.architecture + in vllm_model.model.llm_engine.model_config.architectures) + + vllm_outputs = vllm_model.encode(example_prompts) + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/tests/models/embedding/utils.py b/tests/models/embedding/utils.py index 5aeeb5178540..6d4df2c265c4 100644 --- a/tests/models/embedding/utils.py +++ b/tests/models/embedding/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Sequence +from typing import NamedTuple, Optional import torch import torch.nn.functional as F @@ -37,3 +38,29 @@ def matryoshka_fy(tensor, dimensions): tensor = tensor[..., :dimensions] tensor = F.normalize(tensor, p=2, dim=1) return tensor + + +class EmbedModelInfo(NamedTuple): + name: str + is_matryoshka: bool + matryoshka_dimensions: Optional[list[int]] = None + architecture: str = "" + enable_test: bool = True + + +def correctness_test(hf_model, + inputs, + vllm_outputs: Sequence[list[float]], + dimensions: Optional[int] = None): + + hf_outputs = hf_model.encode(inputs) + if dimensions: + hf_outputs = matryoshka_fy(hf_outputs, dimensions) + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/tests/models/encoder_decoder/vision_language/test_florence2.py b/tests/models/encoder_decoder/vision_language/test_florence2.py index 
a6ec333e2e9b..14b64393bf52 100644 --- a/tests/models/encoder_decoder/vision_language/test_florence2.py +++ b/tests/models/encoder_decoder/vision_language/test_florence2.py @@ -13,12 +13,12 @@ from ...utils import check_logprobs_close MODELS = ["microsoft/Florence-2-base"] -# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer -# Therefore, we borrow the BartTokenizer from the original Bart model -TOKENIZER = "facebook/bart-base" +# Florence-2 model repo's tokenizer config is missing some special tokens. +# Therefore, we use a converted tokenizer from a forked repo +TOKENIZER = "Isotr0py/Florence-2-tokenizer" HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": - "", # special task token + "", # special task token which will output special tokens "cherry_blossom": "Describe in detail what is shown in the image.", }) @@ -45,7 +45,6 @@ def hf_to_vllm_output(hf_output: tuple[list[int], str, output_ids, output_str, out_logprobs = hf_output output_str = output_str.replace("", "").replace("", "") - output_ids = [ids for ids in output_ids if ids not in [0, 2]] return output_ids, output_str, out_logprobs @@ -71,8 +70,11 @@ def run_test( enforce_eager=True) as vllm_model: vllm_outputs_per_case = [ vllm_model.generate_encoder_decoder_greedy_logprobs( - prompts, max_tokens, num_logprobs=num_logprobs) - for prompts in inputs + prompts, + max_tokens, + num_logprobs=num_logprobs, + skip_special_tokens=False, + ) for prompts in inputs ] hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs] @@ -93,6 +95,7 @@ def run_test( outputs_1_lst=vllm_outputs, name_0="hf", name_1="vllm", + num_outputs_0_skip_tokens=1, ) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 8ec7d0887bd4..5a4215a70d24 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -254,10 +254,12 @@ def _test_processing_correctness_mistral( "adept/fuyu-8b", "google/gemma-3-4b-it", "THUDM/glm-4v-9b", + "ibm-granite/granite-speech-3.3-8b", "h2oai/h2ovl-mississippi-800m", "OpenGVLab/InternVL2-1B", "HuggingFaceM4/Idefics3-8B-Llama3", "HuggingFaceTB/SmolVLM2-2.2B-Instruct", + "moonshotai/Kimi-VL-A3B-Instruct", "meta-llama/Llama-4-Scout-17B-16E-Instruct", "llava-hf/llava-1.5-7b-hf", "llava-hf/llava-v1.6-mistral-7b-hf", @@ -273,12 +275,14 @@ def _test_processing_correctness_mistral( "nvidia/NVLM-D-72B", "google/paligemma-3b-mix-224", "google/paligemma2-3b-ft-docci-448", + "microsoft/Phi-4-multimodal-instruct", "mistralai/Pixtral-12B-2409", "mistral-community/pixtral-12b", "Qwen/Qwen-VL-Chat", "Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", + "Qwen/Qwen2.5-Omni-7B", "Skywork/Skywork-R1V-38B", "fixie-ai/ultravox-v0_5-llama-3_2-1b", "openai/whisper-large-v3", diff --git a/tests/models/multimodal/processing/test_phi4mm.py b/tests/models/multimodal/processing/test_phi4mm.py new file mode 100644 index 000000000000..797986adba4a --- /dev/null +++ b/tests/models/multimodal/processing/test_phi4mm.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for phi4mm's multimodal preprocessing kwargs.""" +import pytest + +from vllm.multimodal import MULTIMODAL_REGISTRY + +from ....conftest import _ImageAssets +from ...utils import build_model_context + + +@pytest.mark.parametrize("model_id", ["microsoft/Phi-4-multimodal-instruct"]) +# yapf: disable +@pytest.mark.parametrize( + ("mm_processor_kwargs", "expected_toks_per_img"), + [ + 
({"dynamic_hd": 4}, 1329), + ({"dynamic_hd": 16}, 4433), + # the default num_crops of phi-4-multimodal is 36 + ({}, 9585), + ]) +# yapf: enable +@pytest.mark.parametrize("num_imgs", [1, 2]) +@pytest.mark.parametrize("kwargs_on_init", [True, False]) +def test_processor_override( + image_assets: _ImageAssets, + model_id: str, + mm_processor_kwargs: dict[str, int], + expected_toks_per_img: int, + num_imgs: int, + kwargs_on_init: bool, +): + """Ensure Phi4MMMultiModalProcessor handles dynamic_hd properly.""" + # Avoid initializing CUDA early + from vllm.model_executor.models.phi4mm import _IMAGE_PLACEHOLDER_TOKEN_ID + + ctx = build_model_context( + model_id, + mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None, + limit_mm_per_prompt={"image": num_imgs}, + ) + processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) + hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs + + # Build the image str / prompt based on the number of images we pass + img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) + prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" + + image_size = ctx.get_hf_config( + ).embd_layer["image_embd_layer"]["crop_size"] + dummy_image_size = (image_size * 7, image_size * 7) + dummy_image = image_assets[0].pil_image.resize(dummy_image_size) + mm_data = {"image": [dummy_image] * num_imgs} + + processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + + # Ensure we have the right number of placeholders per num_crops size + img_tok_count = processed_inputs["prompt_token_ids"].count( + _IMAGE_PLACEHOLDER_TOKEN_ID) + assert img_tok_count == expected_toks_per_img * num_imgs diff --git a/tests/models/registry.py b/tests/models/registry.py index 896b6c3bf47b..a08924639b17 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -121,9 +121,11 @@ def check_available_online( "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B"), - "BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"), + "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", + {"1b": "bigscience/bloomz-1b1"}), "ChatGLMModel": _HfExamplesInfo("THUDM/chatglm3-6b", - trust_remote_code=True), + trust_remote_code=True, + max_transformers_version="4.48"), "ChatGLMForConditionalGeneration": _HfExamplesInfo("thu-coai/ShieldLM-6B-chatglm3", # noqa: E501 trust_remote_code=True), "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01", @@ -141,24 +143,26 @@ def check_available_online( "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501 "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), - "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"), + "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), - "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it", - min_transformers_version="4.50"), + "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"), "Glm4ForCausalLM": _HfExamplesInfo( - "THUDM/GLM-4-32B-Chat-0414", + "THUDM/GLM-4-32B-0414", is_available_online=False, min_transformers_version="4.52.dev0" ), - "GPT2LMHeadModel": _HfExamplesInfo("gpt2"), - "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"), - 
"GPTJForCausalLM": _HfExamplesInfo("EleutherAI/gpt-j-6b"), - "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"), + "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", + {"alias": "gpt2"}), + "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder", + {"tiny": "bigcode/tiny_starcoder_py"}), # noqa: E501 + "GPTJForCausalLM": _HfExamplesInfo("Milos/slovak-gpt-j-405M", + {"6b": "EleutherAI/gpt-j-6b"}), + "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m", + {"1b": "EleutherAI/pythia-1.4b"}), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), - "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts", # noqa: E501 - min_transformers_version="4.49"), # noqa: E501 + "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", trust_remote_code=True), "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b", @@ -186,7 +190,8 @@ def check_available_online( "MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01", trust_remote_code=True), "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"), - "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501 + "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", # noqa: E501 + {"falcon3": "ehristoforu/Falcon3-MoE-2x7B-Insruct"}), # noqa: E501 "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501 "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False), "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"), @@ -194,7 +199,8 @@ def check_available_online( "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"), "Olmo2ForCausalLM": _HfExamplesInfo("shanearora/OLMo-7B-1124-hf"), "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"), - "OPTForCausalLM": _HfExamplesInfo("facebook/opt-iml-max-1.3b"), + "OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m", + {"1b": "facebook/opt-iml-max-1.3b"}), "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat", trust_remote_code=True), "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), @@ -204,10 +210,12 @@ def check_available_online( trust_remote_code=True), "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), + "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", + trust_remote_code=True), "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", trust_remote_code=True), - "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct", - extras={"2.5": "Qwen/Qwen2.5-7B-Instruct"}), # noqa: E501 + "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-0.5B-Instruct", + extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}), # noqa: E501 "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen3ForCausalLM": _HfExamplesInfo( "Qwen/Qwen3-8B", @@ -233,8 +241,7 @@ def check_available_online( "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat", is_available_online=False, trust_remote_code=True), - "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct", - min_transformers_version="4.49"), + "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"), # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), @@ -245,11 
+252,15 @@ def check_available_online( "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), + "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", + trust_remote_code=True), "InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward", trust_remote_code=True), "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501 "LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), + "NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long", # noqa: E501 + trust_remote_code=True), "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"), @@ -273,6 +284,7 @@ def check_available_online( "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501 "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501 "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501 + "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501 } _MULTIMODAL_EXAMPLE_MODELS = { @@ -286,10 +298,11 @@ def check_available_online( extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501 max_transformers_version="4.48", # noqa: E501 transformers_version_reason="HF model is not compatible.", # noqa: E501 - hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 + hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), - "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it", - min_transformers_version="4.50"), + "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), + "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-8b", # noqa: E501 + min_transformers_version="4.52.0"), # noqa: E501 "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b", trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 @@ -302,6 +315,9 @@ def check_available_online( trust_remote_code=True), "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501 {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501 + "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501 + extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501 + trust_remote_code=True), "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 min_transformers_version="4.51"), "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf", @@ -322,7 +338,6 @@ def check_available_online( extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501 trust_remote_code=True), "Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501 - min_transformers_version="4.50", # noqa: E501 extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501 "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", max_transformers_version="4.48", @@ -348,8 
+363,9 @@ def check_available_online( hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}), # noqa: E501 "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501 "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 - "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 - min_transformers_version="4.49"), # noqa: E501 + "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct"), # noqa: E501 + "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B", # noqa: E501 + min_transformers_version="4.52"), # noqa: E501 "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"), "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501 "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 @@ -358,7 +374,7 @@ def check_available_online( # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 - tokenizer="facebook/bart-base", + tokenizer="Isotr0py/Florence-2-tokenizer", trust_remote_code=True), # noqa: E501 "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 @@ -378,6 +394,10 @@ def check_available_online( trust_remote_code=True, speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501 + "Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501 + trust_remote_code=True, + speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", + tokenizer="meta-llama/Llama-3.1-8B-Instruct"), } _TRANSFORMERS_MODELS = { diff --git a/tests/models/test_bitblas.py b/tests/models/test_bitblas.py new file mode 100644 index 000000000000..ae4a52214ad0 --- /dev/null +++ b/tests/models/test_bitblas.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Compare the outputs of a GPTQ model to a bitblas model. + +Note: GPTQ and bitblas do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +bitblas/GPTQ models are in the top 3 selections of each other. + +Note: bitblas internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for bitblas. As a result, we re-run the +test up to 3 times to see if we pass. + +Run `pytest tests/models/test_bitblas.py`. 
+""" +from dataclasses import dataclass + +import pytest + +from .utils import check_logprobs_close + + +@dataclass +class ModelPair: + model_bitblas: str + model_gptq: str + + +model_pairs = [ + ModelPair(model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas", + model_gptq="hxbgsyxh/opt-125m-4bit-128g"), +] + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.") +@pytest.mark.parametrize("model_pair", model_pairs) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model_pair: ModelPair, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + with vllm_runner(model_pair.model_bitblas, + dtype=dtype, + quantization="bitblas") as bitblas_model: + bitblas_outputs = bitblas_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model_pair.model_gptq, dtype=dtype, + quantization="gptq") as gptq_model: + gptq_outputs = gptq_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=bitblas_outputs, + name_0="gptq", + name_1="bitblas", + ) diff --git a/tests/models/test_gptq_bitblas.py b/tests/models/test_gptq_bitblas.py new file mode 100644 index 000000000000..d28442120ea6 --- /dev/null +++ b/tests/models/test_gptq_bitblas.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Compare the outputs of a GPTQ model to a bitblas model. + +Note: GPTQ and bitblas do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +bitblas/GPTQ models are in the top 3 selections of each other. + +Note: bitblas internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for bitblas. As a result, we re-run the +test up to 3 times to see if we pass. + +Run `pytest tests/models/test_bitblas.py`. 
+""" +from dataclasses import dataclass + +import pytest + +from .utils import check_logprobs_close + + +@dataclass +class ModelPair: + model_gptq: str + + +model_pairs = [ + ModelPair(model_gptq="hxbgsyxh/opt-125m-4bit-128g"), +] + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.") +@pytest.mark.parametrize("model_pair", model_pairs) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model_pair: ModelPair, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + with vllm_runner(model_pair.model_gptq, + dtype=dtype, + quantization="bitblas") as bitblas_model: + bitblas_outputs = bitblas_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model_pair.model_gptq, dtype=dtype, + quantization="gptq") as gptq_model: + gptq_outputs = gptq_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=bitblas_outputs, + name_0="gptq", + name_1="gptq_bitblas", + ) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index cd2b8f00d521..446c4efbf6af 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -24,10 +24,7 @@ def test_can_initialize(model_arch): def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: hf_config.update(model_info.hf_overrides) - if hasattr(hf_config, "text_config"): - text_config: PretrainedConfig = hf_config.text_config - else: - text_config = hf_config + text_config = hf_config.get_text_config() text_config.update({ "num_layers": 1, diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index f1ed8a04cfa0..b45a87d94b86 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -18,10 +18,9 @@ def test_plugin( m.setenv("VLLM_USE_V1", "0") m.setenv("VLLM_PLUGINS", "") - with pytest.raises(Exception) as excinfo: + match = "Cannot find model module" + with pytest.raises(ValueError, match=match): LLM(model=dummy_opt_path, load_format="dummy") - error_msg = "has no vLLM implementation and the Transformers implementation is not compatible with vLLM" # noqa: E501 - assert (error_msg in str(excinfo.value)) @create_new_process_for_each_test() diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 5c928f27c10d..70f716f95e89 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -261,16 +261,23 @@ def check_model(model): @pytest.mark.parametrize( "wNa16_args", - [ - ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8), - ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8), - ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4), - ], + [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8, + True, False), + ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8, True, + False), + ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4, + True, False), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256", "group", 128, + 8, False, False), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel", + "channel", None, 8, False, False), + 
("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder", + "group", 128, 8, False, True)], ) @pytest.mark.skipif(not current_platform.is_cuda(), reason="The tests are skipped on non-CUDA platform.") def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): - model, strategy, group, pack_factor = wNa16_args + model, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args with vllm_runner(model) as llm: def check_model(model): @@ -286,6 +293,8 @@ def check_model(model): if group is None else group) assert qkv_proj.scheme.pack_factor == pack_factor + assert qkv_proj.scheme.symmetric == symmetric + assert qkv_proj.scheme.has_g_idx == has_g_idx llm.apply_model(check_model) diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index a1a81b3891f6..5de1137eaf68 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -5,6 +5,9 @@ """ import pytest +from transformers import AutoModelForSeq2SeqLM + +from vllm.assets.audio import AudioAsset @pytest.fixture(autouse=True) @@ -19,6 +22,7 @@ def v1(run_with_both_engines): # 3. Use the model "huggyllama/llama-7b". MAX_TOKENS = [64] BEAM_WIDTHS = [4] +MM_BEAM_WIDTHS = [2] MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"] @@ -48,15 +52,90 @@ def test_beam_search_single_input( for i in range(len(example_prompts)): hf_output_ids, hf_output_texts = hf_outputs[i] vllm_output_ids, vllm_output_texts = vllm_outputs[i] - for i, (hf_text, + for j, (hf_text, vllm_text) in enumerate(zip(hf_output_texts, vllm_output_texts)): - print(f">>>{i}-th hf output:") + print(f">>>{j}-th hf output:") print(hf_text) - print(f">>>{i}-th vllm output:") + print(f">>>{j}-th vllm output:") print(vllm_text) assert len(hf_output_ids) == len(vllm_output_ids) for j in range(len(hf_output_ids)): assert hf_output_ids[j] == vllm_output_ids[j], ( f"Test{i} output{j}:\nHF: {hf_output_ids}\n" f"vLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", MAX_TOKENS) +@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS) +def test_beam_search_passes_multimodal_data( + hf_runner, + vllm_runner, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + """Ensure that beam search passes multimodal data through correctly.""" + # NOTE - this test is primarily to check that mm data is passed to beams + # correctly. As such, we just need to check one extra modality to make + # sure things pass through properly. 
+ audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate] + model = "Qwen/Qwen2-Audio-7B-Instruct" + audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>" + prompts = [ + f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n" #noqa: E501 + ] + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + audio_token_id = hf_model.config.audio_token_index + eos_token_id = hf_model.tokenizer.eos_token_id # <|im_end|> + hf_outputs = hf_model.generate_beam_search( + prompts, + beam_width=beam_width, + max_tokens=max_tokens, + audios=audios, + ) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_beam_search( + prompts, + beam_width=beam_width, + max_tokens=max_tokens, + audios=audios, + ) + + seq_with_no_audio_toks = lambda seq: [ + tok for tok in seq if tok != audio_token_id + ] + + for i in range(len(prompts)): + hf_output_ids, hf_output_texts = hf_outputs[i] + vllm_output_ids, vllm_output_texts = vllm_outputs[i] + + for j, (hf_text, + vllm_text) in enumerate(zip(hf_output_texts, + vllm_output_texts)): + print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:") + print(hf_text) + print(f">>>{j}-th vllm output:") + print(vllm_text) + assert len(hf_output_ids) == len(vllm_output_ids) + + for j in range(len(hf_output_ids)): + # Compare everything except for the audio tokens; we do this since + # the IDs returned from the transformers helper expands the audio + # token to match features, while the vLLM helper maintains the + # single audio token in the input text + filtered_hf_output_ids = seq_with_no_audio_toks(hf_output_ids[j]) + filtered_vllm_output_ids = seq_with_no_audio_toks( + vllm_output_ids[j]) + + # HF output IDs may contain the end of sequence + if len(filtered_hf_output_ids + ) == len(filtered_vllm_output_ids) + 1: + assert filtered_hf_output_ids[-1] == eos_token_id + filtered_hf_output_ids = filtered_hf_output_ids[:-1] + + assert filtered_hf_output_ids == filtered_vllm_output_ids diff --git a/tests/spec_decode/test_scorer.py b/tests/spec_decode/test_scorer.py index 161cc9fbf556..f73cf4b345fb 100644 --- a/tests/spec_decode/test_scorer.py +++ b/tests/spec_decode/test_scorer.py @@ -62,9 +62,8 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int, scorer_worker = create_worker(Worker, model_name, block_size, num_gpu_blocks, seed) scorer_worker.model_runner.disable_logprobs = True # accessed by mqa_scorer - scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True - scorer_worker.model_runner.model.sampler.\ - should_modify_greedy_probs_inplace = True + scorer_worker.model_runner.sampler.include_gpu_probs_tensor = True + scorer_worker.model_runner.sampler.should_modify_greedy_probs_inplace = True vocab_size = scorer_worker.vocab_size diff --git a/tests/test_config.py b/tests/test_config.py index 06264c5b99b9..53db91e81c41 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,14 +1,36 @@ # SPDX-License-Identifier: Apache-2.0 -from dataclasses import asdict +from dataclasses import MISSING, Field, asdict, dataclass, field import pytest -from vllm.config import ModelConfig, PoolerConfig +from vllm.config import ModelConfig, PoolerConfig, get_field from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform +def test_get_field(): + + @dataclass + class TestConfig: + a: int + b: dict = field(default_factory=dict) + c: str = "default" + + with pytest.raises(ValueError): + get_field(TestConfig, 
"a") + + b = get_field(TestConfig, "b") + assert isinstance(b, Field) + assert b.default is MISSING + assert b.default_factory is dict + + c = get_field(TestConfig, "c") + assert isinstance(c, Field) + assert c.default == "default" + assert c.default_factory is MISSING + + @pytest.mark.parametrize( ("model_id", "expected_runner_type", "expected_task"), [ diff --git a/tests/test_utils.py b/tests/test_utils.py index b6129a102085..580e65f1f833 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,11 +13,11 @@ from vllm_test_utils.monitor import monitor from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.utils import (FlexibleArgumentParser, MemorySnapshot, - PlaceholderModule, StoreBoolean, bind_kv_cache, - deprecate_kwargs, get_open_port, memory_profiling, - merge_async_iterators, sha256, supports_kw, - swap_dict_values) +from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, + MemorySnapshot, PlaceholderModule, StoreBoolean, + bind_kv_cache, deprecate_kwargs, get_open_port, + memory_profiling, merge_async_iterators, sha256, + supports_kw, swap_dict_values) from .utils import create_new_process_for_each_test, error_on_warning @@ -417,6 +417,129 @@ def test_bind_kv_cache_pp(): assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0] +class TestLRUCache(LRUCache): + + def _on_remove(self, key, value): + if not hasattr(self, "_remove_counter"): + self._remove_counter = 0 + self._remove_counter += 1 + + +def test_lru_cache(): + cache = TestLRUCache(3) + assert cache.stat() == CacheInfo(hits=0, total=0) + assert cache.stat(delta=True) == CacheInfo(hits=0, total=0) + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(2, 2) + assert len(cache) == 2 + + cache.put(3, 3) + assert len(cache) == 3 + assert set(cache.cache) == {1, 2, 3} + + cache.put(4, 4) + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + + assert cache.get(2) == 2 + assert cache.stat() == CacheInfo(hits=1, total=1) + assert cache.stat(delta=True) == CacheInfo(hits=1, total=1) + + assert cache[2] == 2 + assert cache.stat() == CacheInfo(hits=2, total=2) + assert cache.stat(delta=True) == CacheInfo(hits=1, total=1) + + cache.put(5, 5) + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + assert cache.pop(5) == 5 + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + assert cache.get(-1) is None + assert cache.stat() == CacheInfo(hits=2, total=3) + assert cache.stat(delta=True) == CacheInfo(hits=0, total=1) + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.get(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.put(6, 6) + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + assert 6 in cache + + cache.remove_oldest() + assert len(cache) == 2 + assert set(cache.cache) == {2, 6} + assert cache._remove_counter == 4 + + cache.clear() + assert len(cache) == 0 + assert cache._remove_counter == 6 + assert cache.stat() == CacheInfo(hits=0, total=0) + assert cache.stat(delta=True) == CacheInfo(hits=0, total=0) + + cache._remove_counter = 0 + + cache[1] = 1 + assert len(cache) == 1 + + cache[1] = 1 + assert len(cache) == 1 + + cache[2] = 2 + assert len(cache) == 2 + + cache[3] = 3 + assert len(cache) == 3 + assert set(cache.cache) 
== {1, 2, 3} + + cache[4] = 4 + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + assert cache[2] == 2 + + cache[5] = 5 + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + del cache[5] + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache[6] = 6 + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + assert 6 in cache + + def test_placeholder_module_error_handling(): placeholder = PlaceholderModule("placeholder_1234") diff --git a/tests/tokenization/test_cached_tokenizer.py b/tests/tokenization/test_cached_tokenizer.py index cd60cefd7ccd..c740fde42636 100644 --- a/tests/tokenization/test_cached_tokenizer.py +++ b/tests/tokenization/test_cached_tokenizer.py @@ -1,24 +1,43 @@ # SPDX-License-Identifier: Apache-2.0 - +import pickle from copy import deepcopy +import pytest from transformers import AutoTokenizer -from vllm.transformers_utils.tokenizer import get_cached_tokenizer +from vllm.transformers_utils.tokenizer import (AnyTokenizer, + get_cached_tokenizer) -def test_cached_tokenizer(): - reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") +@pytest.mark.parametrize("model_id", ["gpt2", "THUDM/chatglm3-6b"]) +def test_cached_tokenizer(model_id: str): + reference_tokenizer = AutoTokenizer.from_pretrained(model_id, + trust_remote_code=True) reference_tokenizer.add_special_tokens({"cls_token": ""}) reference_tokenizer.add_special_tokens( {"additional_special_tokens": [""]}) + cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer)) + _check_consistency(cached_tokenizer, reference_tokenizer) + + pickled_tokenizer = pickle.dumps(cached_tokenizer) + unpickled_tokenizer = pickle.loads(pickled_tokenizer) + _check_consistency(unpickled_tokenizer, reference_tokenizer) + + +def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer): + assert isinstance(target, type(expected)) + + # Cached attributes + assert target.all_special_ids == expected.all_special_ids + assert target.all_special_tokens == expected.all_special_tokens + assert (target.all_special_tokens_extended == + expected.all_special_tokens_extended) + assert target.get_vocab() == expected.get_vocab() + assert len(target) == len(expected) + + # Other attributes + assert getattr(target, "padding_side", + None) == getattr(expected, "padding_side", None) - assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode( - "prompt") - assert set(reference_tokenizer.all_special_ids) == set( - cached_tokenizer.all_special_ids) - assert set(reference_tokenizer.all_special_tokens) == set( - cached_tokenizer.all_special_tokens) - assert set(reference_tokenizer.all_special_tokens_extended) == set( - cached_tokenizer.all_special_tokens_extended) + assert target.encode("prompt") == expected.encode("prompt") diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index b1860e0bb708..f8e213b9ca48 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -4,14 +4,22 @@ from typing import Any, Optional import pytest -from transformers import AutoTokenizer +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) from vllm.inputs import token_inputs from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup 
-from vllm.transformers_utils.detokenizer import (Detokenizer, - detokenize_incrementally) -from vllm.transformers_utils.tokenizer_group import get_tokenizer_group +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer, + IncrementalDetokenizer, + SlowIncrementalDetokenizer) + +SPECIAL_TOKS_TRUTH = [ + "Some text with adjacent special tokens <|padding|><|padding|>other text", # noqa +] TRUTH = [ "Hello here, this is a simple test", @@ -22,7 +30,8 @@ # incomplete UTF-8 characters # see https://github.com/vllm-project/vllm/pull/9625 "ပုံပြင်လေးပြောပြပါ်", -] +] + SPECIAL_TOKS_TRUTH + TOKENIZERS = [ "facebook/opt-125m", "gpt2", @@ -38,26 +47,37 @@ ] -def _run_incremental_decode(tokenizer, all_input_ids, - skip_special_tokens: bool, starting_index: int): - decoded_text = "" - offset = 0 - token_offset = 0 - prev_tokens = None - for i in range(starting_index, len(all_input_ids)): - new_tokens, text, offset, token_offset = detokenize_incrementally( - tokenizer, - all_input_ids[:i + 1], - prev_tokens, - offset, - token_offset, - skip_special_tokens=skip_special_tokens) - decoded_text += text - if prev_tokens is None: - prev_tokens = new_tokens - else: - prev_tokens += new_tokens - return decoded_text +def _run_incremental_decode(tokenizer, + all_input_ids, + skip_special_tokens: bool, + starting_index: int, + spaces_between_special_tokens: bool = True, + fast: Optional[bool] = None): + + prompt_token_ids = all_input_ids[:starting_index] + + params = SamplingParams( + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + request = EngineCoreRequest("", prompt_token_ids, None, None, None, params, + None, 0.0, None) + + if fast is None: + detokenizer = IncrementalDetokenizer.from_new_request( + tokenizer, request) + elif fast: + detokenizer = FastIncrementalDetokenizer(tokenizer, request) + else: + detokenizer = SlowIncrementalDetokenizer(tokenizer, request) + + output_text = "" + for i, token_id in enumerate(all_input_ids[starting_index:]): + detokenizer.update([token_id], False) + finished = i == len(all_input_ids) - 1 + output_text += detokenizer.get_next_output_text(finished, delta=True) + + return output_text, detokenizer.output_token_ids @pytest.fixture @@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth): starting_index = 0 all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids - decoded_text = _run_incremental_decode(tokenizer, - all_input_ids, - skip_special_tokens=True, - starting_index=starting_index) + decoded_text, out_ids = _run_incremental_decode( + tokenizer, + all_input_ids, + skip_special_tokens=True, + starting_index=starting_index) assert decoded_text == truth + assert out_ids == all_input_ids[starting_index:] @pytest.fixture @@ -106,45 +128,91 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]: @pytest.mark.parametrize("with_prompt", [True, False]) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS) @pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True) -def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens): +@pytest.mark.parametrize("spaces_between_special_tokens", (True, False)) +@pytest.mark.parametrize("fast", (True, False)) +def 
test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens, + spaces_between_special_tokens, fast): + if fast and not isinstance(tokenizer, PreTrainedTokenizerFast): + pytest.skip() + + if skip_special_tokens and not spaces_between_special_tokens: + pytest.skip() + + if not fast and isinstance(tokenizer, PreTrainedTokenizerFast): + # Fix up inconsistency in fast/slow tokenizer behaviour. + tokenizer.add_special_tokens({ + "additional_special_tokens": [ + at for at in + tokenizer._tokenizer.get_added_tokens_decoder().values() + if at.special + ] + }) + + extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \ + else {"spaces_between_special_tokens": spaces_between_special_tokens} + + truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids + if tokenizer.bos_token_id is not None: + truth_tokens.insert(0, tokenizer.bos_token_id) + truth_tokens.append(tokenizer.eos_token_id) + + new_truth = tokenizer.decode(truth_tokens, + skip_special_tokens=skip_special_tokens, + **extra_decode_args) + if with_prompt: - truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids - prompt_input_ids = truth_tokens[:len(truth) // 2] - generated_input_ids = truth_tokens[len(truth) // 2:] + num_prompt_tokens = len( + tokenizer(truth[:len(truth) // 2], + add_special_tokens=False).input_ids) + if tokenizer.bos_token_id is not None: + num_prompt_tokens += 1 + + prompt_input_ids = truth_tokens[:num_prompt_tokens] + generated_input_ids = truth_tokens[num_prompt_tokens:] all_input_ids = prompt_input_ids + generated_input_ids starting_index = len(prompt_input_ids) prompt = tokenizer.decode(prompt_input_ids, - skip_special_tokens=skip_special_tokens) - generated = truth[len(prompt):] + skip_special_tokens=skip_special_tokens, + **extra_decode_args) + + generated = new_truth[len(prompt):] else: - generated = truth + generated = new_truth starting_index = 0 - all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids - if skip_special_tokens: - if tokenizer.bos_token_id is not None: - all_input_ids = [tokenizer.bos_token_id] + all_input_ids - starting_index += 1 - all_input_ids = all_input_ids + [tokenizer.eos_token_id] + all_input_ids = truth_tokens - decoded_text = _run_incremental_decode( + decoded_text, out_ids = _run_incremental_decode( tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens, - starting_index=starting_index) + starting_index=starting_index, + spaces_between_special_tokens=spaces_between_special_tokens, + fast=fast) assert decoded_text == generated + assert out_ids == all_input_ids[starting_index:] + + +@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) +@pytest.mark.parametrize("fast", (True, False)) +def test_oov_decode(tokenizer, fast): + if fast and not isinstance(tokenizer, PreTrainedTokenizerFast): + pytest.skip() - decoded_text = _run_incremental_decode( + decoded_text, out_ids = _run_incremental_decode( tokenizer, [len(tokenizer)], - skip_special_tokens=skip_special_tokens, - starting_index=starting_index) + skip_special_tokens=True, + starting_index=0, + spaces_between_special_tokens=True, + fast=fast) assert decoded_text == '' + assert out_ids == [len(tokenizer)] @pytest.fixture def detokenizer(tokenizer_name: str) -> Detokenizer: - init_kwargs = dict( + tokenizer_group = TokenizerGroup( tokenizer_id=tokenizer_name, enable_lora=False, max_num_seqs=100, @@ -154,26 +222,20 @@ def detokenizer(tokenizer_name: str) -> Detokenizer: revision=None, ) - tokenizer_group = get_tokenizer_group( - None, - **init_kwargs, - ) - 
return Detokenizer(tokenizer_group) @pytest.fixture(name="complete_sequence_token_ids") def create_complete_sequence_token_ids(complete_sequence: str, tokenizer) -> list[int]: - complete_sequence_token_ids = tokenizer(complete_sequence).input_ids - return complete_sequence_token_ids + return tokenizer(complete_sequence, add_special_tokens=False).input_ids def create_sequence(prompt_token_ids=None): - prompt_token_ids = prompt_token_ids or [1] + prompt_token_ids = prompt_token_ids or [] return Sequence( seq_id=0, - inputs=token_inputs(prompt_token_ids, prompt=""), + inputs=token_inputs(prompt_token_ids), block_size=16, ) @@ -224,7 +286,7 @@ def test_decode_sequence_logprobs(complete_sequence: str, assert sequential_result == "".join(sequential_logprobs_text_chosen_token) assert sequential_result != "".join(sequential_logprobs_text_other_token) - if skip_special_tokens: + if not skip_special_tokens: # Text for logprobs for the chosen token should be the same as the # generated text. Note that this will only be true if we skip # special tokens. @@ -233,10 +295,23 @@ def test_decode_sequence_logprobs(complete_sequence: str, @pytest.mark.parametrize("complete_sequence", TRUTH) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS) -def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int], +def test_decode_prompt_logprobs(complete_sequence: str, + complete_sequence_token_ids: list[int], detokenizer: Detokenizer): + + # We want to use skip_special_tokens=False here but Mistral tokenizers + # don't support that. + if complete_sequence not in SPECIAL_TOKS_TRUTH: + skip_special_tokens = True + elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None), + MistralTokenizer): + skip_special_tokens = False + else: + pytest.skip("MistralTokenizers don't support " + "skip_special_tokens=False") + return """Verify Detokenizer decodes prompt logprobs correctly.""" - sampling_params = SamplingParams(skip_special_tokens=True, + sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, prompt_logprobs=1) # Run sequentially. @@ -256,8 +331,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int], # decoded_prompt_logprobs doesn't contain the first token. 
token_ids = complete_sequence_token_ids tokenizer = detokenizer.get_tokenizer_for_seq(seq) - text_full = tokenizer.decode(token_ids, skip_special_tokens=True) - text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True) + text_full = tokenizer.decode(token_ids, + skip_special_tokens=skip_special_tokens) + text_first = tokenizer.decode(token_ids[0], + skip_special_tokens=skip_special_tokens) text = text_full[len(text_first):] # Text for logprobs for the chosen token should be the same as the diff --git a/tests/tokenization/test_tokenizer_group.py b/tests/tokenization/test_tokenizer_group.py index 5b62f992c1be..bcfa78ed41cf 100644 --- a/tests/tokenization/test_tokenizer_group.py +++ b/tests/tokenization/test_tokenizer_group.py @@ -1,40 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 -import asyncio -import os -import sys -from typing import Optional -from unittest.mock import patch - import pytest from transformers import AutoTokenizer, PreTrainedTokenizerBase -from vllm.transformers_utils.tokenizer_group import (TokenizerGroup, - get_tokenizer_group) -from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import ( - RayTokenizerGroupPool) - -from ..conftest import get_tokenizer_pool_config - - -class CustomTokenizerGroup(TokenizerGroup): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._i = 0 - - def encode(self, *args, **kwargs): - self._i += 1 - return super().encode(*args, **kwargs) +from vllm.transformers_utils.tokenizer_group import TokenizerGroup @pytest.mark.asyncio -@pytest.mark.parametrize("tokenizer_group_type", - [None, "ray", CustomTokenizerGroup]) -async def test_tokenizer_group(tokenizer_group_type): +async def test_tokenizer_group(): reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer_group = get_tokenizer_group( - get_tokenizer_pool_config(tokenizer_group_type), + tokenizer_group = TokenizerGroup( tokenizer_id="gpt2", enable_lora=False, max_num_seqs=1, @@ -49,159 +24,3 @@ async def test_tokenizer_group(tokenizer_group_type): PreTrainedTokenizerBase) assert tokenizer_group.get_lora_tokenizer( None) == await tokenizer_group.get_lora_tokenizer_async(None) - if tokenizer_group_type is CustomTokenizerGroup: - assert tokenizer_group._i > 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("tokenizer_group_type", ["ray"]) -async def test_tokenizer_group_pool(tokenizer_group_type): - reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer_group_pool = get_tokenizer_group( - get_tokenizer_pool_config(tokenizer_group_type), - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=None, - ) - # Send multiple requests to the tokenizer group pool - # (more than the pool size) - # and check that all requests are processed correctly. 
- num_requests = tokenizer_group_pool.pool_size * 5 - requests = [ - tokenizer_group_pool.encode_async(prompt=f"prompt {i}", - lora_request=None) - for i in range(num_requests) - ] - results = await asyncio.gather(*requests) - expected_results = [ - reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests) - ] - assert results == expected_results - - -@pytest.mark.asyncio -@pytest.mark.parametrize("tokenizer_group_type", ["ray"]) -async def test_tokenizer_group_ray_pool_env_var_propagation( - tokenizer_group_type): - """Test that env vars from caller process are propagated to - tokenizer Ray actors.""" - env_var = "MY_ENV_VAR" - - class EnvVarCheckerTokenizerGroup(TokenizerGroup): - - def ping(self): - assert os.environ.get(env_var) == "1" - return super().ping() - - class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool): - _worker_cls = EnvVarCheckerTokenizerGroup - - tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) - tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config( - tokenizer_pool_config, - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=None) - with pytest.raises(AssertionError): - tokenizer_pool.ping() - - with patch.dict(os.environ, {env_var: "1"}): - tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) - tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config( - tokenizer_pool_config, - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=None) - tokenizer_pool.ping() - - -@pytest.mark.asyncio -@pytest.mark.parametrize("tokenizer_group_type", ["ray"]) -async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type): - """Test that Ray tokenizer pool group can recover from failures and - if that's not possible, mark itself as unhealthy.""" - - class FailingTokenizerGroup(TokenizerGroup): - - def __init__(self, - *args, - fail_at: Optional[list[int]] = None, - **kwargs): - super().__init__(*args, **kwargs) - self.i = 0 - self.fail_at = fail_at or [] - - def encode(self, *args, **kwargs): - self.i += 1 - if self.i in self.fail_at: - sys.exit(1) - return super().encode(*args, **kwargs) - - class FailingRayTokenizerGroupPool(RayTokenizerGroupPool): - _worker_cls = FailingTokenizerGroup - - # Fail at first iteration - fail_at = [1] - tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type) - tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config( - tokenizer_pool_config, - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=None, - fail_at=fail_at) - tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy() - - # Modify fail at to not fail at all (will be re-read when actor is - # re-initialized). - fail_at[0] = 1000 - - # We should recover successfully. - await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None) - await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None) - - # Check that we have a new actor - assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors) - assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors - - # Fail at first iteration - fail_at = [1] - tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config( - tokenizer_pool_config, - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=None, - fail_at=fail_at) - - # We should fail after re-initialization. 
- with pytest.raises(RuntimeError): - await tokenizer_group_pool.encode_async(prompt="prompt", - lora_request=None) - - # check_health should raise the same thing - with pytest.raises(RuntimeError): - tokenizer_group_pool.check_health() - - # Ensure that non-ActorDiedErrors are still propagated correctly and do not - # cause a re-initialization. - fail_at = [] - tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config( - tokenizer_pool_config, - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=2, - fail_at=fail_at) - tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy() - - # Prompt too long error - with pytest.raises(ValueError): - await tokenizer_group_pool.encode_async(prompt="prompt" * 100, - lora_request=None) - await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None) - # Actors should stay the same. - assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index 7c87c73f04da..c14eaf71e978 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -98,6 +98,20 @@ def ensure_system_prompt(messages: list[dict[str, Any]], "extended": True }, + "llama4_json": { + "model": + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "arguments": [ + "--enforce-eager", "--no-enable-prefix-caching", "-tp", "4", + "--distributed-executor-backend", "mp", "--tool-call-parser", + "llama4_json", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_llama4_json.jinja") + ], + "supports_parallel": + True, + "extended": + True + }, "mistral": { "model": "mistralai/Mistral-7B-Instruct-v0.3", diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index a4a571b180c6..e73e08e74b0d 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -37,7 +37,6 @@ def make_request(request_id, return Request( request_id=request_id, - prompt=None, prompt_token_ids=prompt_token_ids, multi_modal_inputs=multi_modal_inputs, multi_modal_hashes=mm_hashes, @@ -311,7 +310,7 @@ def test_metrics(): def stats(requests, queries, hits): return PrefixCacheStats(requests=requests, queries=queries, hits=hits) - metrics = PrefixCachingMetrics(interval=5) + metrics = PrefixCachingMetrics(max_recent_requests=5) assert metrics.hit_rate == 0.0 metrics.observe(stats(1, 20, 9)) @@ -496,8 +495,7 @@ def test_allocate_with_lookahead(): # Test case 1: Requires additional lookahead tokens kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100, - num_preallocate_tokens=0) + max_model_len=100) blocks = kv_cache_manager.allocate_slots( request, num_tokens=3, @@ -507,25 +505,19 @@ def test_allocate_with_lookahead(): # Test case 2: With precomputed blocks kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100, - num_preallocate_tokens=4) - # num_preallocate_blocks = 4 // 4 - 2 // 4 = 1 + max_model_len=100) # required_blocks = ceil((3 + 2) /4) = 2 - # total_blocks = 1 + 2 = 3 blocks = kv_cache_manager.allocate_slots( request, num_tokens=3, num_lookahead_tokens=2, ) - assert len(blocks) == 3 + assert len(blocks) == 2 # Test case 3: With precomputed blocks - # num_preallocate_blocks = 4 // 4 - 4 // 4 = 0 # required_blocks = ceil((3 + 4) / 4) = 2 - # total_blocks = 0 + 2 = 2 kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100, - num_preallocate_tokens=4) + max_model_len=100) blocks = kv_cache_manager.allocate_slots( request, num_tokens=3, diff --git 
a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 80dd275a90b8..b2e8ff61450c 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -8,7 +8,7 @@ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams -from vllm.utils import cdiv, sha256 +from vllm.utils import sha256 from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, @@ -29,7 +29,6 @@ def make_request(request_id, return Request( request_id=request_id, - prompt=None, prompt_token_ids=prompt_token_ids, multi_modal_inputs=multi_modal_inputs, multi_modal_hashes=mm_hashes, @@ -61,7 +60,6 @@ def test_prefill(hash_algo): max_model_len=8192, enable_caching=True, caching_hash_algo=hash_algo, - num_preallocate_tokens=16, ) # choose the hash function according to the parameter @@ -80,7 +78,7 @@ def test_prefill(hash_algo): assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5] + assert [b.block_id for b in blocks] == [1, 2, 3, 4] # Check full block metadata parent_block_hash = None @@ -92,8 +90,8 @@ def test_prefill(hash_algo): assert manager.block_pool.blocks[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value - # Check partial/preallocated block metadata - for block_id in (4, 5): + # Check partial block metadata + for block_id in (4, ): assert manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].ref_cnt == 1 @@ -107,12 +105,12 @@ def test_prefill(hash_algo): assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) - assert [b.block_id for b in blocks] == [6, 7] + assert [b.block_id for b in blocks] == [5] for block in computed_blocks: assert block.ref_cnt == 2 - # At this point, we should have 3 free blocks left. - assert manager.block_pool.free_block_queue.num_free_blocks == 3 + # At this point, we should have 5 free blocks left. + assert manager.block_pool.free_block_queue.num_free_blocks == 5 manager.free(req0) manager.free(req1) @@ -120,14 +118,14 @@ def test_prefill(hash_algo): # All blocks should be available. assert manager.block_pool.free_block_queue.num_free_blocks == 10 # The order should be - # [unallocated (8, 9, 10)] - # [unique_req0 (5, 4)] - # [unique_req1 (7, 6)] + # [unallocated (6, 7, 8, 9, 10)] + # [unique_req0 (4)] + # [unique_req1 (5)] # [common (3, 2, 1)] assert [ b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() - ] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1] + ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1] # Cache hit in the common prefix when the original block is already free. # Incomplete 1 block (6 tokens) @@ -139,29 +137,29 @@ def test_prefill(hash_algo): assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) - assert [b.block_id for b in blocks] == [8, 9] + assert [b.block_id for b in blocks] == [6] - # Although we only have 5 free blocks, we have 8 blocks in + # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. 
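+    # req2 re-references the three cached common blocks and allocates one
+    # new block (6), so 10 - 3 - 1 = 6 blocks remain free.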
- assert manager.block_pool.free_block_queue.num_free_blocks == 5 + assert manager.block_pool.free_block_queue.num_free_blocks == 6 assert all([ b.ref_cnt == 0 for b in manager.block_pool.free_block_queue.get_all_free_blocks() ]) assert len([ b for b in manager.block_pool.free_block_queue.get_all_free_blocks() - ]) == 5 + ]) == 6 manager.free(req2) # Cache miss and eviction. - req3 = make_request("3", [99] * (16 * 9)) + req3 = make_request("3", [99] * (16 * 10)) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks) + blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks) # This block ID order also checks the eviction order. - assert [b.block_id for b in blocks] == [10, 5, 4, 7, 6, 9, 8, 3, 2, 1] + assert [b.block_id for b in blocks] == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1] assert manager.block_pool.free_block_queue.num_free_blocks == 0 assert manager.block_pool.free_block_queue.free_list_head is None assert manager.block_pool.free_block_queue.free_list_tail is None @@ -178,7 +176,6 @@ def test_prefill_plp(): make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, - num_preallocate_tokens=16, ) # the default hash function is hash hash_fn = hash @@ -197,7 +194,7 @@ def test_prefill_plp(): assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5] + assert [b.block_id for b in blocks] == [1, 2, 3, 4] req0_block_hashes = [b.block_hash for b in blocks] # Check full block metadata @@ -210,8 +207,8 @@ def test_prefill_plp(): assert manager.block_pool.blocks[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value - # Check partial/preallocated block metadata - for block_id in (4, 5): + # Check partial block metadata + for block_id in (4, ): assert manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].ref_cnt == 1 @@ -226,12 +223,12 @@ def test_prefill_plp(): assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) - assert [b.block_id for b in blocks] == [6, 7] + assert [b.block_id for b in blocks] == [5] for block in computed_blocks: assert block.ref_cnt == 2 - # At this point, we should have 3 free blocks left. - assert manager.block_pool.free_block_queue.num_free_blocks == 3 + # At this point, we should have 5 free blocks left. + assert manager.block_pool.free_block_queue.num_free_blocks == 5 manager.free(req0) manager.free(req1) @@ -239,14 +236,14 @@ def test_prefill_plp(): # All blocks should be available. 
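+    # make_kv_cache_config(16, 11) gives 11 blocks, one of which is the
+    # null block, leaving 10 allocatable blocks.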
assert manager.block_pool.free_block_queue.num_free_blocks == 10 # The order should be - # [unallocated (8, 9, 10)] - # [unique_req0 (5, 4)] - # [unique_req1 (7, 6)] + # [unallocated (6, 7, 8, 9, 10)] + # [unique_req0 (4)] + # [unique_req1 (5)] # [common (3, 2, 1)] assert [ b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() - ] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1] + ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1] # Request #2 is a prompt-logprobs request: # NO cache hit in the common prefix; duplicates request #0 cached blocks @@ -262,7 +259,7 @@ def test_prefill_plp(): block_ids = [b.block_id for b in blocks] # Duplicate cached blocks have different ids but same hashes vs request #0 assert [b.block_hash for b in blocks] == req0_block_hashes - assert block_ids != [1, 2, 3, 4, 5] + assert block_ids != [1, 2, 3, 4] # Request #2 block hashes are valid since request #0 hashes are. # Check block reference counts. @@ -277,7 +274,6 @@ def test_decode(): make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, - num_preallocate_tokens=16, ) # Complete 3 blocks (48 tokens) @@ -291,7 +287,7 @@ def test_decode(): assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5] + assert [b.block_id for b in blocks] == [1, 2, 3, 4] # Append slots without allocating a new block. req0.num_computed_tokens = 55 @@ -299,28 +295,18 @@ def test_decode(): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 4) assert new_blocks is not None and len(new_blocks) == 0 - assert manager.req_to_blocks[req0.request_id][-2].block_hash is None + assert manager.req_to_blocks[req0.request_id][-1].block_hash is None - # Append slots without allocating a new block, but start using the - # preallocated block. + # Append slots with allocating a new block. req0.num_computed_tokens = 59 - # 6 tokens to fill the previous block, and 10 tokens to fill + # 9 tokens to fill the previous block, and 10 tokens to fill # the preallocated block. - for _ in range(5 + 10): + for _ in range(9 + 10): req0.append_output_token_ids(7) - new_blocks = manager.allocate_slots(req0, 15) - assert new_blocks is not None and len(new_blocks) == 0 + new_blocks = manager.allocate_slots(req0, 19) + assert new_blocks is not None and len(new_blocks) == 1 assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None - - # Append slots with allocating a new block. - req0.num_computed_tokens = 74 - # 6 tokens to fill the previous block, and 10 tokens to fill - # the preallocated block. - for _ in range(6 + 11): - req0.append_output_token_ids(12) - new_blocks = manager.allocate_slots(req0, 17) - # Plus one preallocated block. - assert new_blocks is not None and len(new_blocks) == 2 + assert manager.req_to_blocks[req0.request_id][-1].block_hash is None def test_evict(): @@ -328,7 +314,6 @@ def test_evict(): make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, - num_preallocate_tokens=16, ) last_token_id = 5 * 16 + 7 @@ -337,7 +322,7 @@ def test_evict(): assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) - assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated + assert len(blocks) == 6 # 5 full + 1 partial # 3 blocks. 
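+    # req1: 3 more full blocks (48 tokens) that do not overlap req0's
+    # token ids.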
req1 = make_request("1", list(range(last_token_id, @@ -349,7 +334,8 @@ def test_evict(): assert len(blocks) == 3 # 3 full blocks last_token_id += 3 * 16 - assert manager.block_pool.free_block_queue.num_free_blocks == 0 + # 10 - (6 + 3) == 1 + assert manager.block_pool.free_block_queue.num_free_blocks == 1 manager.free(req0) manager.free(req1) @@ -357,7 +343,7 @@ def test_evict(): assert [ b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() - ] == [7, 6, 5, 4, 3, 2, 1, 10, 9, 8] + ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7] # Touch the first 2 blocks. req2 = make_request("2", list(range(2 * 16 + 3))) @@ -365,8 +351,8 @@ def test_evict(): assert [b.block_id for b in computed_blocks] == [1, 2] assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, computed_blocks) - assert [b.block_id for b in blocks] == [7, 6] - assert manager.block_pool.free_block_queue.num_free_blocks == 6 + assert [b.block_id for b in blocks] == [10] + assert manager.block_pool.free_block_queue.num_free_blocks == 7 def test_hash_block_correct_reuse(): @@ -379,7 +365,6 @@ def test_hash_block_correct_reuse(): make_kv_cache_config(16, 2), max_model_len=8192, enable_caching=True, - num_preallocate_tokens=0, ) # Allocate 1 block and cache it. @@ -416,7 +401,6 @@ def test_computed_blocks_not_evicted(): make_kv_cache_config(block_size, 3), max_model_len=8192, enable_caching=True, - num_preallocate_tokens=0, ) # Allocate a block and cache it. @@ -465,7 +449,6 @@ def test_basic_prefix_caching_disabled(): make_kv_cache_config(block_size, 5), max_model_len=8192, enable_caching=False, - num_preallocate_tokens=0, ) req1 = make_request("1", list(range(10))) # 2 blocks and some more @@ -496,40 +479,6 @@ def test_basic_prefix_caching_disabled(): assert not blocks -@pytest.mark.parametrize("num_preallocate_tokens", list(range(0, 8))) -@pytest.mark.parametrize("block_size", [4]) -def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): - """ - This tests that the preallocated blocks are correctly added. - """ - manager = KVCacheManager( - make_kv_cache_config(block_size, 11), - max_model_len=8192, - enable_caching=True, - num_preallocate_tokens=num_preallocate_tokens, - ) - num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size) - - req = make_request("0", list(range(block_size * 30))) - computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks - assert num_computed_tokens == 0 - # Just ask for 1 block. - blocks = manager.allocate_slots(req, block_size, computed_blocks) - req.num_computed_tokens = block_size - assert len(blocks) == 1 + num_preallocated_blocks - - # Assume all computed, only when num_preallocate_tokens > 0, we need to - # consume the previously preallocated blocks. - if num_preallocated_blocks > 0: - manager.allocate_slots(req, block_size * (len(blocks) - 1)) - req.num_computed_tokens = block_size * len(blocks) - - # Append 1 block. 
- blocks = manager.allocate_slots(req, block_size) - assert len(blocks) == 1 + num_preallocated_blocks - - @pytest.mark.parametrize("hash_fn", [sha256, hash]) def test_cache_blocks(hash_fn): """ @@ -588,7 +537,6 @@ def test_mm_prefix_caching(): make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, - num_preallocate_tokens=16, ) # Common prompt tokens (T is text tokens and P is image placeholder tokens) @@ -626,7 +574,7 @@ def test_mm_prefix_caching(): assert block_hashes[2].extra_keys == ("bbb", ) blocks = manager.allocate_slots(req0, 59, computed_blocks) - assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5] + assert [b.block_id for b in blocks] == [1, 2, 3, 4] req0.num_computed_tokens = 59 # Append slots without allocating a new block. @@ -667,7 +615,6 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, - num_preallocate_tokens=0, ) # Complete 3 blocks (48 tokens) # | Common-0 | Common-1 | Common-2 | ... | @@ -721,7 +668,6 @@ def test_reset_prefix_cache(): make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, - num_preallocate_tokens=0, ) full_block_token_ids = [i for i in range(3) for _ in range(16)] @@ -751,3 +697,82 @@ def test_reset_prefix_cache(): assert manager.reset_prefix_cache() assert not manager.block_pool.cached_block_hash_to_block assert all([blk.block_hash is None for blk in manager.block_pool.blocks]) + + +def test_prefix_cache_stats_disabled(): + """Test that prefix_cache_stats is None when log_stats is False.""" + manager = KVCacheManager( + make_kv_cache_config(16, 11), + max_model_len=8192, + enable_caching=True, + log_stats=False, # Disable logging stats + ) + assert manager.prefix_cache_stats is None + + # Call all functions that check whether log_stats is disabled. + req = make_request("0", list(range(16))) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) + assert not computed_blocks + assert num_computed_tokens == 0 + manager.allocate_slots(req, 16, computed_blocks) + manager.reset_prefix_cache() + + # Ensure prefix_cache_stats remains None + assert manager.prefix_cache_stats is None + + +def test_eagle_enabled_removes_last_block(): + """Verify Eagle does NOT remove blocks when request + length is divisible by block size.""" + block_size = 16 + manager = KVCacheManager( + make_kv_cache_config(block_size, num_blocks=10), + max_model_len=8192, + enable_caching=True, + use_eagle=True, + ) + + # Request with 3 full blocks (48 tokens) + token_ids = [0] * (3 * block_size) + req = make_request("divisible_request", token_ids) + + # Prime the cache + computed_blocks, _ = manager.get_computed_blocks(req) + manager.allocate_slots(req, len(token_ids), computed_blocks) + manager.free(req) + + # New request with same tokens + Eagle enabled + req_eagle = make_request("eagle_divisible", token_ids) + computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) + + # Should retain 2 blocks: + # 1. Original 3 blocks → pop last hash → 2 matched blocks + # 2. 
last_block_hash is not None → Eagle pop is not SKIPPED + assert len(computed_blocks) == 1 + assert num_tokens == 1 * block_size # 32 tokens + + +def test_eagle_with_partial_blocks(): + """Test Eagle behavior with requests containing partial blocks.""" + block_size = 16 + manager = KVCacheManager( + make_kv_cache_config(block_size, num_blocks=10), + max_model_len=8192, + enable_caching=True, + use_eagle=True, + ) + # 2 full blocks + 5 tokens (non-divisible length) + token_ids = [0] * (2 * block_size + 5) + req = make_request("partial_block_test", token_ids) + + # Prime the cache + computed_blocks, _ = manager.get_computed_blocks(req) + manager.allocate_slots(req, len(token_ids), computed_blocks) + manager.free(req) + + # New request with Eagle enabled + req_eagle = make_request("partial_eagle", token_ids) + computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) + # Original match: 2 full blocks → Eagle removes 1 → 1 remaining + assert len(computed_blocks) == 1 + assert num_tokens == 1 * block_size diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index bc17ca32e5b6..9987688b02fa 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1,10 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Optional +from unittest.mock import Mock import pytest import torch -from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig +from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, + SchedulerConfig, SpeculativeConfig, VllmConfig) from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import SchedulerOutput @@ -25,6 +27,11 @@ def create_scheduler( enable_prefix_caching: Optional[bool] = None, long_prefill_token_threshold: int = 0, disable_chunked_mm_input: bool = False, + use_kv_connector: bool = False, + num_blocks: int = 10000, + block_size: int = 16, + max_model_len: Optional[int] = None, + num_speculative_tokens: Optional[int] = None, ) -> Scheduler: '''Create scheduler under test. 
@@ -39,12 +46,15 @@ def create_scheduler( Returns: :class:`Scheduler` instance ''' + if max_model_len is None: + max_model_len = max_num_batched_tokens scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, - max_model_len=max_num_batched_tokens, + max_model_len=max_model_len, long_prefill_token_threshold=long_prefill_token_threshold, disable_chunked_mm_input=disable_chunked_mm_input, + enable_chunked_prefill=True, ) model_config = ModelConfig( model=model, @@ -60,31 +70,42 @@ def create_scheduler( 'enable_prefix_caching': enable_prefix_caching }) cache_config = CacheConfig( - block_size=16, + block_size=block_size, gpu_memory_utilization=0.9, swap_space=0, cache_dtype="auto", **kwargs_cache, ) + kv_transfer_config = KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ) if use_kv_connector else None + + speculative_config: Optional[SpeculativeConfig] = None + if num_speculative_tokens is not None: + speculative_config = SpeculativeConfig( + model="ngram", num_speculative_tokens=num_speculative_tokens) + vllm_config = VllmConfig( scheduler_config=scheduler_config, model_config=model_config, cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + speculative_config=speculative_config, ) kv_cache_config = KVCacheConfig( - num_blocks=10000, # A large number of blocks to hold all requests + num_blocks=num_blocks, # A large number of blocks to hold all requests tensors={}, kv_cache_groups=[ KVCacheGroupSpec(['layer'], - FullAttentionSpec(16, 1, 1, torch.float32, False)) + FullAttentionSpec(block_size, 1, 1, torch.float32, + False)) ], ) - cache_config.num_gpu_blocks = 10000 + cache_config.num_gpu_blocks = num_blocks return Scheduler( - scheduler_config, - model_config, - cache_config, - lora_config=None, + vllm_config=vllm_config, kv_cache_config=kv_cache_config, log_stats=True, structured_output_manager=StructuredOutputManager(vllm_config), @@ -111,7 +132,6 @@ def create_requests(num_requests: int, mm_inputs = None request = Request( request_id=f"{i}", - prompt=None, prompt_token_ids=[i] * num_tokens, sampling_params=sampling_params, multi_modal_inputs=mm_inputs, @@ -286,6 +306,7 @@ def test_no_mm_input_chunking(): model="llava-hf/llava-1.5-7b-hf", max_num_batched_tokens=1024, disable_chunked_mm_input=True, + max_model_len=2048, ) mm_positions = [[PlaceholderRange(offset=400, length=800)]] requests = create_requests(num_requests=1, @@ -414,7 +435,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): def test_stop_via_update_from_output(): """Test stopping behavior through update_from_output""" - scheduler = create_scheduler() + scheduler = create_scheduler(num_speculative_tokens=1) # Test case 1: Stop on EOS token requests = create_requests(num_requests=2, max_tokens=10) @@ -422,7 +443,6 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler.scheduled_req_ids.add(req.request_id) scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -466,7 +486,7 @@ def test_stop_via_update_from_output(): assert list(requests[1].output_token_ids) == [10, 11] # Test case 2: Stop on custom stop token - scheduler = create_scheduler() + scheduler = create_scheduler(num_speculative_tokens=2) requests = create_requests(num_requests=2, max_tokens=10, stop_token_ids=[42, 43]) @@ -474,7 
+494,6 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler.scheduled_req_ids.add(req.request_id) scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -518,13 +537,12 @@ def test_stop_via_update_from_output(): assert list(requests[1].output_token_ids) == [13, 14] # Test case 3: Stop on max tokens - scheduler = create_scheduler() + scheduler = create_scheduler(num_speculative_tokens=2) requests = create_requests(num_requests=2, max_tokens=2) for req in requests: req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) - scheduler.scheduled_req_ids.add(req.request_id) scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduled_cached_reqs=[], @@ -568,13 +586,12 @@ def test_stop_via_update_from_output(): assert list(requests[1].output_token_ids) == [13] # Test case 4: Ignore EOS flag - scheduler = create_scheduler() + scheduler = create_scheduler(num_speculative_tokens=2) requests = create_requests(num_requests=1, max_tokens=10) requests[0].sampling_params.ignore_eos = True requests[0].num_computed_tokens = requests[0].num_tokens scheduler.requests[requests[0].request_id] = requests[0] scheduler.running.append(requests[0]) - scheduler.scheduled_req_ids.add(requests[0].request_id) scheduler_output = SchedulerOutput( scheduled_new_reqs=[], @@ -671,13 +688,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], @pytest.mark.parametrize( "spec_tokens,output_tokens,expected", [ - ([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)), # perfect match - ([[1, 2, 3]], [[1, 5]], (3, 1)), # early mismatch - ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)), # multiple sequences - ([[1]], [[1, 2]], (1, 1)), # single token sequence - ([[]], [[5]], (0, 0)), # empty sequence + ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match + ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch + ([[1, 2], [3]], [[1, 2, 5], [3, 4]], + (2, 3, 3, [2, 1])), # multiple sequences + ([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence + ([[]], [[5]], (0, 0, 0, [0])), # empty sequence ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]], - (6, 3)), # multiple mismatches + (2, 6, 3, [2, 1, 0])), # multiple mismatches ]) def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): """Test scheduling behavior with speculative decoding. @@ -686,7 +704,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): 1. Speculated tokens get scheduled correctly 2. 
Spec decoding stats properly count number of draft and accepted tokens """ - scheduler = create_scheduler() + num_spec_tokens = max(1, max(len(t) for t in spec_tokens)) + scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens) requests = create_requests(num_requests=len(spec_tokens), num_tokens=1) req_ids = [] req_to_index = {} @@ -759,5 +778,467 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): else: assert scheduler_stats.spec_decoding_stats is not None stats = scheduler_stats.spec_decoding_stats - assert stats.num_draft_tokens == expected[0] - assert stats.num_accepted_tokens == expected[1] + assert stats.num_drafts == expected[0] + assert stats.num_draft_tokens == expected[1] + assert stats.num_accepted_tokens == expected[2] + assert stats.num_accepted_tokens_per_pos == expected[3] + + +def _assert_right_scheduler_output( + output: SchedulerOutput, + num_requests: int, + expected_num_scheduled_tokens: int, +): + """Check if SchedulerOutput is correct after remote KV cache hit.""" + + # We should inject the kv_connector_metadata. + assert len(output.kv_connector_metadata.requests) == num_requests + + # Only num_tokens - matched_num_new_tokens should be scheduled. + for _, num_scheduled_tokens in output.num_scheduled_tokens.items(): + assert num_scheduled_tokens == expected_num_scheduled_tokens + + +def _assert_right_kv_cache_manager( + scheduler: Scheduler, + req_ids: list[str], + num_tokens: int, + block_size: int, + num_requests: int, + num_total_blocks: int, +): + """Check whether KVCacheManager is correct after allocate.""" + + # Make sure the request stats are right. + EXPECTED_TOTAL_BLOCKS = num_tokens // block_size + for req_id in req_ids: + blocks = scheduler.kv_cache_manager.req_to_blocks[req_id] + hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id] + assert (scheduler.kv_cache_manager.num_cached_block[req_id] == + EXPECTED_TOTAL_BLOCKS) + assert len(blocks) == EXPECTED_TOTAL_BLOCKS + assert len(hashes) == EXPECTED_TOTAL_BLOCKS + + # Make sure we actually touched all the blocks. + BLOCKS_PER_REQ = num_tokens / block_size + assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == + num_total_blocks - num_requests * BLOCKS_PER_REQ) + + +def _step_until_done( + scheduler: Scheduler, + output: SchedulerOutput, + model_runner_output: ModelRunnerOutput, +): + """Loop over schedule(), update_from_output() until finished.""" + + all_finished = False + _ = scheduler.update_from_output(output, model_runner_output) + while not all_finished: + # Schedule + a few iterations until stopping. + output = scheduler.schedule() + assert len(scheduler.running) + for _, num_scheduled_tokens in output.num_scheduled_tokens.items(): + # We should be in the decode phase now. + assert num_scheduled_tokens == 1 + assert len(output.kv_connector_metadata.requests) == 0 + ecos = scheduler.update_from_output(output, model_runner_output) + all_done = True + for eco in ecos.outputs: + if eco.finish_reason is None: + all_done = False + all_finished = all_done + + +def test_kv_connector_basic(): + """ + Test whether Scheduler with KVConnector schedules tokens, allocates + memory, and cleans up requests as expected under normal operation. + """ + + # Setup Scheduler. + scheduler = create_scheduler( + enable_prefix_caching=True, + use_kv_connector=True, + ) + NUM_TOTAL_BLOCKS = ( + scheduler.kv_cache_manager.block_pool.get_num_free_blocks()) + BLOCK_SIZE = scheduler.cache_config.block_size + + # Mock External Cache Hit. 
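+    # The mocked connector reports two full blocks' worth of tokens as
+    # already present in the remote KV cache for every request.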
+ NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 + scheduler.connector.get_num_new_matched_tokens = Mock(name="method") + scheduler.connector.get_num_new_matched_tokens.return_value = ( + NUM_MATCHED_NEW_TOKENS) + + ###################################################### + # FIRST SET OF REQUESTS - External Hit Only + NUM_REQUESTS = 2 + NUM_TOKENS = NUM_MATCHED_NEW_TOKENS * 2 + MAX_TOKENS = 3 + requests = create_requests(num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[1000]] * len(req_ids), + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + + # Ensure ScheduleOutput is correct. + output = scheduler.schedule() + _assert_right_scheduler_output( + output=output, + num_requests=NUM_REQUESTS, + # Just the incremental tokens should be scheduled. + expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS, + ) + + # Ensure KVCacheManager is correct. + _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, + NUM_REQUESTS, NUM_TOTAL_BLOCKS) + + # Continue Generation until done. + _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) + _ = scheduler.schedule() + # Confirm we clean up the memory properly. + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_TOTAL_BLOCKS + + ###################################################### + # SECOND SET OF REQUESTS - Local And External Hit + NUM_TOKENS_PREFIX = NUM_TOKENS + # We will get a local prefix cache hit for the first + # NUM_TOKENS_PREFIX tokens since they are used above. + NUM_TOKENS = NUM_TOKENS_PREFIX * 2 + requests = create_requests(num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[1000]] * len(req_ids), + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + + # We should get a local cache hit of NUM_TOKENS_PREFIX and + # a remote KV cache hit of NUM_MATCHED_NEW_TOKENS. + output = scheduler.schedule() + _assert_right_scheduler_output( + output=output, + num_requests=NUM_REQUESTS, + # Just the incremental tokens after local + remote cache hit. + expected_num_scheduled_tokens=(NUM_TOKENS - NUM_TOKENS_PREFIX - + NUM_MATCHED_NEW_TOKENS)) + + # Ensure KVCacheManager is correct. + _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, + NUM_REQUESTS, NUM_TOTAL_BLOCKS) + + # Continue Generation until done. + _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) + _ = scheduler.schedule() + # Confirm we clean up the memory properly. + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_TOTAL_BLOCKS + + +def test_kv_connector_unable_to_allocate(): + """ + Test whether scheduler with KVConnector is able to handle + unable to allocate (run out of blocks in allocate_slots(). + """ + + # Setup Scheduler With Mock External Cache Hit. 
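+    # Use a small block pool so the second request cannot obtain enough
+    # blocks from allocate_slots().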
+ BLOCK_SIZE = 4 + NUM_BLOCKS = 10 + scheduler = create_scheduler( + enable_prefix_caching=True, + use_kv_connector=True, + block_size=BLOCK_SIZE, + num_blocks=NUM_BLOCKS, + ) + NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 + scheduler.connector.get_num_new_matched_tokens = Mock(name="method") + scheduler.connector.get_num_new_matched_tokens.return_value = ( + NUM_MATCHED_NEW_TOKENS) + + # Create two requests. The second request will not be able to + # allocate slots because it will not have enough blocks. + NUM_REQUESTS = 2 + NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE + MAX_TOKENS = 2 + requests = create_requests(num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[1000]] * len(req_ids), + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + + # Just one request should be running. + output = scheduler.schedule() + _assert_right_scheduler_output(output, + num_requests=1, + expected_num_scheduled_tokens=NUM_TOKENS - + NUM_MATCHED_NEW_TOKENS) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + # All memory should be freed, with one request waiting. + _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_BLOCKS - 1 + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 + + # Just one request should be running. + output = scheduler.schedule() + _assert_right_scheduler_output(output, + num_requests=1, + expected_num_scheduled_tokens=NUM_TOKENS - + NUM_MATCHED_NEW_TOKENS) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + + # All memory should be freed, with no requests waiting / running. + _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_BLOCKS - 1 + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 0 + + +def test_kv_connector_handles_preemption(): + """ + Test whether scheduler with KVConnector is able to handle + unable to allocate (run out of blocks in allocate_slots(). + """ + + # Setup Scheduler With Mock External Cache Hit. + BLOCK_SIZE = 2 + # NOTE: there is 1 null block, so this is 6 blocks. + NUM_BLOCKS = 7 + scheduler = create_scheduler( + enable_prefix_caching=True, + use_kv_connector=True, + block_size=BLOCK_SIZE, + num_blocks=NUM_BLOCKS, + ) + + NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE + scheduler.connector.get_num_new_matched_tokens = Mock(name="method") + scheduler.connector.get_num_new_matched_tokens.return_value = ( + NUM_MATCHED_NEW_TOKENS) + + # Create two requests. + # Both can be scheduled at first, but the second request + # will be preempted and re-scheduled. 
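+    # Each 5-token prompt spans 3 blocks of size 2, so the two prompts fill
+    # all 6 usable blocks; once decode needs another block, one request has
+    # to be preempted.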
+ NUM_REQUESTS = 2 + NUM_TOKENS = BLOCK_SIZE * 2 + 1 + MAX_TOKENS = BLOCK_SIZE * 2 + requests = create_requests(num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[1000]] * len(req_ids), + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + + # All can be scheduled - 1st token. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # 2 remote kv cache hits. + num_requests=2, + expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS) + assert len(scheduler.running) == 2 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + + # All can be scheduled - 2nd token. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.running) == 2 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + + # This will generate a new block and cause a preemption - 3rd token. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + # Only 1 can be scheduled - 4th (and last token). + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.waiting) == 1 + assert len(scheduler.running) == 1 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 + # All memory should be freed since nothing is running. + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_BLOCKS - 1 + + # Restarts the preempted request - generate 3rd token. + # This will have a local and remote cache hit. + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # 1 remote kv_cache hit! + num_requests=1, + # Only 1 block was preempted and there is a single + # remote hit. So only single new token is scheduled. + expected_num_scheduled_tokens=1, + ) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + + # Only 1 can be scheduled - 4th (and last token). + output = scheduler.schedule() + _assert_right_scheduler_output( + output, + # no connector_metadata + num_requests=0, + expected_num_scheduled_tokens=1) + assert len(scheduler.running) == 1 + _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + assert len(scheduler.running) == 0 + # All memory should be freed since nothing is running. 
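+    # (NUM_BLOCKS - 1 because one block is the reserved null block.)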
+ assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ + == NUM_BLOCKS - 1 + + +def make_output(scheduler: Scheduler): + return ModelRunnerOutput( + req_ids=[req.request_id for req in scheduler.running], + req_id_to_index={ + req.request_id: i + for i, req in enumerate(scheduler.running) + }, + sampled_token_ids=[[1000]] * len(scheduler.running), + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + + +def assert_scheduler_empty(scheduler: Scheduler): + """Confirm the scheduler is "empty" - i.e. no leaks.""" + # Scheduler Metadata. + assert len(scheduler.requests) == 0 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 0 + assert len(scheduler.finished_req_ids) == 0 + assert len(scheduler._cached_reqs_data) == 0 + + # EncoderCacheManager. + assert len(scheduler.encoder_cache_manager.freed) == 0 + assert len(scheduler.encoder_cache_manager.cached) == 0 + + # KVCache Manager. + assert len(scheduler.kv_cache_manager.req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 + assert len(scheduler.kv_cache_manager.num_cached_block) == 0 + num_free_blocks = ( + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + assert num_free_blocks == ( + scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + + # NOTE(rob): just the ref count on blocks will be 0. The hash + # value, etc will remain since we lazily evict for prefix cache. + for block in scheduler.kv_cache_manager.block_pool.blocks: + assert block.ref_cnt == 0 + # assert block._block_hash is None + # assert ( + # len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block + # ) == 0) + + +def test_memory_leak(): + """Test that we do not have a memory leak.""" + + scheduler = create_scheduler(enable_prefix_caching=True) + + NUM_REQUESTS = 5 + NUM_TOKENS = 10 + MAX_TOKENS = 10 + requests = create_requests(num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS) + + # Add each request. + for request in requests: + scheduler.add_request(request) + scheduler_output = scheduler.schedule() + model_runner_output = make_output(scheduler) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # Iterate until done. + while True: + scheduler_output = scheduler.schedule() + if len(scheduler.running) == 0: + break + model_runner_output = make_output(scheduler) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # Confirm no memory leak. 
+ assert_scheduler_empty(scheduler) diff --git a/tests/v1/e2e/test_cascade_attention.py b/tests/v1/e2e/test_cascade_attention.py index a8079dcce5e2..48c265560348 100644 --- a/tests/v1/e2e/test_cascade_attention.py +++ b/tests/v1/e2e/test_cascade_attention.py @@ -1,13 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 +import pytest + from vllm import LLM, SamplingParams +from ...utils import fork_new_process_for_each_test + -def test_cascade_attention(example_system_message, monkeypatch): +@fork_new_process_for_each_test +@pytest.mark.parametrize("attn_backend", + ["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"]) +def test_cascade_attention(example_system_message, monkeypatch, attn_backend): prompt = "\n: Implement fibonacci sequence in Python.\n:" with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) llm = LLM(model="Qwen/Qwen2-1.5B-Instruct") sampling_params = SamplingParams(temperature=0.0, max_tokens=100) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 673714980592..2fad37d6801b 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -44,18 +44,20 @@ def test_prompts(): @pytest.fixture def sampling_config(): - # Only support greedy for now return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False) @pytest.fixture def model_name(): - return "meta-llama/Meta-Llama-3-8B-Instruct" + return "meta-llama/Llama-3.1-8B-Instruct" -@pytest.fixture def eagle_model_name(): - return "yuhuili/EAGLE-LLaMA3-Instruct-8B" + return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" + + +def eagle3_model_name(): + return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" def test_ngram_correctness( @@ -102,12 +104,13 @@ def test_ngram_correctness( del spec_llm +@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"]) def test_eagle_correctness( monkeypatch: pytest.MonkeyPatch, test_prompts: list[list[dict[str, Any]]], sampling_config: SamplingParams, model_name: str, - eagle_model_name: str, + use_eagle3: bool, ): ''' Compare the outputs of a original LLM and a speculative LLM @@ -116,18 +119,22 @@ def test_eagle_correctness( with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - ref_llm = LLM(model=model_name, max_model_len=1024) + ref_llm = LLM(model=model_name, max_model_len=2048) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm + spec_model_name = eagle3_model_name( + ) if use_eagle3 else eagle_model_name() spec_llm = LLM( model=model_name, + trust_remote_code=True, speculative_config={ - "method": "eagle", - "model": eagle_model_name, + "method": "eagle3" if use_eagle3 else "eagle", + "model": spec_model_name, "num_speculative_tokens": 3, + "max_model_len": 2048, }, - max_model_len=1024, + max_model_len=2048, ) spec_outputs = spec_llm.chat(test_prompts, sampling_config) matches = 0 @@ -140,7 +147,7 @@ def test_eagle_correctness( print(f"ref_output: {ref_output.outputs[0].text}") print(f"spec_output: {spec_output.outputs[0].text}") - # Heuristic: expect at least 70% of the prompts to match exactly + # Heuristic: expect at least 66% of the prompts to match exactly # Upon failure, inspect the outputs to check for inaccuracy. 
- assert matches > int(0.7 * len(ref_outputs)) + assert matches > int(0.66 * len(ref_outputs)) del spec_llm diff --git a/tests/v1/engine/conftest.py b/tests/v1/engine/conftest.py index 8872f0388dd2..f8addd920d57 100644 --- a/tests/v1/engine/conftest.py +++ b/tests/v1/engine/conftest.py @@ -47,7 +47,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: tokenizer=tokenizer, tokenizer_group=init_tokenizer_from_configs( vllm_config.model_config, vllm_config.scheduler_config, - vllm_config.parallel_config, vllm_config.lora_config), + vllm_config.lora_config), vllm_config=vllm_config, full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS], prompt_tokens=prompt_tokens, diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index da0639678af8..5d52ad5f5328 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -3,16 +3,19 @@ import asyncio from contextlib import ExitStack from typing import Optional +from unittest.mock import MagicMock import pytest from vllm import SamplingParams from vllm.assets.image import ImageAsset +from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.inputs import PromptType from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.metrics.loggers import LoggingStatLogger if not current_platform.is_cuda(): pytest.skip(reason="V1 currently only supported on CUDA.", @@ -216,3 +219,33 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int, # Assert only the last output has the finished flag set assert all(not out.finished for out in outputs[:-1]) assert outputs[-1].finished + + +class MockLoggingStatLogger(LoggingStatLogger): + + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): + super().__init__(vllm_config, engine_index) + self.log = MagicMock() + + +@pytest.mark.asyncio +async def test_customize_loggers(monkeypatch): + """Test that we can customize the loggers. + If a customized logger is provided at the init, it should + be used directly. 
+ """ + + with monkeypatch.context() as m, ExitStack() as after: + m.setenv("VLLM_USE_V1", "1") + + engine = AsyncLLM.from_engine_args( + TEXT_ENGINE_ARGS, + stat_loggers=[MockLoggingStatLogger], + ) + after.callback(engine.shutdown) + + await engine.do_log_stats() + + assert len(engine.stat_loggers) == 1 + assert len(engine.stat_loggers[0]) == 1 + engine.stat_loggers[0][0].log.assert_called_once() diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 3f3109c1484c..30fa9e371ad1 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -1,10 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 import copy -import threading import time import uuid -from concurrent.futures import Future +from concurrent.futures import Future, ThreadPoolExecutor import pytest from transformers import AutoTokenizer @@ -32,8 +31,7 @@ def make_request() -> EngineCoreRequest: return EngineCoreRequest( - request_id=uuid.uuid4(), - prompt=PROMPT, + request_id=str(uuid.uuid4()), prompt_token_ids=PROMPT_TOKENS, mm_inputs=None, mm_hashes=None, @@ -244,33 +242,33 @@ def initialize_from_config( self, kv_cache_configs: list[KVCacheConfig]) -> None: super().initialize_from_config(kv_cache_configs) - # This executor actually can only run 1 batch at a time - self.semaphore = threading.Semaphore(1) + # Create a thread pool with a single worker + self.thread_pool = ThreadPoolExecutor(max_workers=1) def execute_model( self, scheduler_output, ) -> Future[ModelRunnerOutput]: """Make execute_model non-blocking.""" - future: Future[ModelRunnerOutput] = Future() - def _thread_wrapper(scheduler_output, future): - with self.semaphore: - output = self.collective_rpc("execute_model", - args=(scheduler_output, )) - # Make a copy because output[0] may be reused - # by the next batch. - output = copy.deepcopy(output[0]) - future.set_result(output) + def _execute(): + output = self.collective_rpc("execute_model", + args=(scheduler_output, )) + # Make a copy because output[0] may be reused + # by the next batch. + return copy.deepcopy(output[0]) - threading.Thread(target=_thread_wrapper, - args=(scheduler_output, future)).start() - return future + # Use the thread pool instead of creating a new thread + return self.thread_pool.submit(_execute) @property def max_concurrent_batches(self) -> int: return 2 + def shutdown(self): + if hasattr(self, 'thread_pool'): + self.thread_pool.shutdown(wait=False) + with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -299,14 +297,77 @@ def max_concurrent_batches(self) -> int: # Schedule Batch 1: (10, req0) assert engine_core.step_with_batch_queue() is None assert engine_core.batch_queue.qsize() == 1 + scheduler_output = engine_core.batch_queue.queue[-1][1] + assert scheduler_output.num_scheduled_tokens[0] == 10 + # num_computed_tokens should have been updated immediately. + assert engine_core.scheduler.requests[ + req0.request_id].num_computed_tokens == 10 + + # Schedule Batch 2: (2, req0), (8, req1) assert engine_core.step_with_batch_queue() is None assert engine_core.batch_queue.qsize() == 2 + scheduler_output = engine_core.batch_queue.queue[-1][1] + assert scheduler_output.num_scheduled_tokens[0] == 2 + assert scheduler_output.num_scheduled_tokens[1] == 8 + # num_computed_tokens should have been updated immediately. 
+ assert engine_core.scheduler.requests[0].num_computed_tokens == 12 + assert engine_core.scheduler.requests[1].num_computed_tokens == 8 + assert engine_core.scheduler.get_num_unfinished_requests() == 2 - # Loop through both requests. - while engine_core.scheduler.get_num_unfinished_requests() == 2: - engine_core.step_with_batch_queue() + # Batch queue is full. Finish Batch 1. + engine_core.step_with_batch_queue() + + # Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled + # because it is in the decoding stage now. + engine_core.step_with_batch_queue() + assert engine_core.batch_queue.qsize() == 2 + scheduler_output = engine_core.batch_queue.queue[-1][1] + assert scheduler_output.num_scheduled_tokens[1] == 4 - # Reaching here when got the result of the first request. - while engine_core.scheduler.get_num_unfinished_requests() == 1: - engine_core.step_with_batch_queue() + # Batch queue is full. Finish Batch 2. Get first token of req0. + output = engine_core.step_with_batch_queue() + assert output is not None + assert len(output.outputs) == 1 + assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13 + + # Schedule Batch 4: (1, req0). + engine_core.step_with_batch_queue() + assert engine_core.batch_queue.qsize() == 2 + scheduler_output = engine_core.batch_queue.queue[-1][1] + assert scheduler_output.num_scheduled_tokens[0] == 1 + + # Batch queue is full. Finish Batch 3. Get first token of req1. + output = engine_core.step_with_batch_queue() + assert output is not None + assert len(output.outputs) == 1 + assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13 + + # Schedule Batch 5: (1, req1). + engine_core.step_with_batch_queue() + assert engine_core.batch_queue.qsize() == 2 + scheduler_output = engine_core.batch_queue.queue[-1][1] + assert scheduler_output.num_scheduled_tokens[1] == 1 + + # Loop until req0 is finished. + step = 0 + req_id = 0 + expected_num_tokens = [ + engine_core.scheduler.requests[0].num_tokens + 1, + engine_core.scheduler.requests[1].num_tokens + 1, + ] + while engine_core.scheduler.get_num_unfinished_requests() == 2: + output = engine_core.step_with_batch_queue() + if step % 2 == 0: + # Even steps consume an output. + assert output is not None + assert len(output.outputs) == 1 + if req_id in engine_core.scheduler.requests: + assert engine_core.scheduler.requests[ + req_id].num_tokens == expected_num_tokens[req_id] + expected_num_tokens[req_id] += 1 + req_id = (req_id + 1) % 2 + else: + # Odd steps schedule a new batch. + assert output is None + step += 1 diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 8ebdaf63b484..8cc36fa163f7 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -35,7 +35,6 @@ def make_request(params: SamplingParams) -> EngineCoreRequest: return EngineCoreRequest( request_id=str(uuid.uuid4()), - prompt=PROMPT, prompt_token_ids=PROMPT_TOKENS, mm_inputs=None, mm_hashes=None, diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 9ac42dbc34a4..d2bb7d88fef2 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -50,7 +50,6 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, # Make N requests.
requests = [ EngineCoreRequest(request_id=f"request-{idx}", - prompt=prompt, prompt_token_ids=prompt_tokens, arrival_time=0, mm_inputs=None, @@ -64,14 +63,13 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, output_kind=request_output_kind, stop=[], include_stop_str_in_output=False, - )) for idx, (prompt, prompt_tokens) in enumerate( - zip(dummy_test_vectors.prompt_strings, - dummy_test_vectors.prompt_tokens)) + )) + for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] # Add requests to the detokenizer. - for request in requests: - output_processor.add_request(request) + for request, prompt in zip(requests, dummy_test_vectors.prompt_strings): + output_processor.add_request(request, prompt) gen_strings = {} gen_tokens = {} @@ -398,7 +396,6 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, ] requests = [ EngineCoreRequest(request_id=request_id_list[idx], - prompt=prompt, prompt_token_ids=prompt_tokens, arrival_time=0, mm_inputs=None, @@ -414,14 +411,13 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, include_stop_str_in_output=False, logprobs=num_sample_logprobs, prompt_logprobs=num_prompt_logprobs, - )) for idx, (prompt, prompt_tokens) in enumerate( - zip(dummy_test_vectors.prompt_strings, - dummy_test_vectors.prompt_tokens)) + )) + for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] # Add requests to the detokenizer. - for request in requests: - output_processor.add_request(request) + for request, prompt in zip(requests, dummy_test_vectors.prompt_strings): + output_processor.add_request(request, prompt) gen_tokens = {} gen_logprobs = {} @@ -562,7 +558,6 @@ def test_stop_token(include_stop_str_in_output: bool, request_id = "request-0" request = EngineCoreRequest( request_id=request_id, - prompt=prompt_string, prompt_token_ids=prompt_tokens, arrival_time=0, mm_inputs=None, @@ -583,7 +578,7 @@ def test_stop_token(include_stop_str_in_output: bool, )) # Add request to the detokenizer. - output_processor.add_request(request) + output_processor.add_request(request, prompt_string) # Loop over engine core steps; run output processor gen_string = "" @@ -659,7 +654,6 @@ def test_stop_string(include_stop_str_in_output: bool, requests = [ EngineCoreRequest( request_id=request_id_list[idx], - prompt=prompt, prompt_token_ids=prompt_tokens, arrival_time=0, mm_inputs=None, @@ -675,14 +669,13 @@ def test_stop_string(include_stop_str_in_output: bool, include_stop_str_in_output=include_stop_str_in_output, logprobs=num_sample_logprobs, prompt_logprobs=None, - )) for idx, (prompt, prompt_tokens) in enumerate( - zip(dummy_test_vectors.prompt_strings, - dummy_test_vectors.prompt_tokens)) + )) + for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] # Add requests to the detokenizer. 
- for request in requests: - output_processor.add_request(request) + for request, prompt in zip(requests, dummy_test_vectors.prompt_strings): + output_processor.add_request(request, prompt) gen_strings = {} gen_tokens = {} @@ -774,7 +767,6 @@ def test_iteration_stats(dummy_test_vectors): requests = [ EngineCoreRequest( request_id=f"request-{idx}", - prompt=prompt, prompt_token_ids=prompt_tokens, arrival_time=0, mm_inputs=None, @@ -783,15 +775,13 @@ def test_iteration_stats(dummy_test_vectors): eos_token_id=None, lora_request=None, sampling_params=SamplingParams(), - ) for idx, (prompt, prompt_tokens) in enumerate( - zip(dummy_test_vectors.prompt_strings, - dummy_test_vectors.prompt_tokens)) + ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] # Add all requests except one to the OutputProcessor. num_active = len(dummy_test_vectors.generation_tokens) - 1 for request in requests[:num_active]: - output_processor.add_request(request) + output_processor.add_request(request, None) inactive_request = requests[num_active] # First iteration has 2 prefills. @@ -817,7 +807,7 @@ def test_iteration_stats(dummy_test_vectors): assert iteration_stats.num_generation_tokens == num_active # Add a new request - prefill and 2 decodes in this step. - output_processor.add_request(inactive_request) + output_processor.add_request(inactive_request, None) num_active += 1 outputs = engine_core.get_outputs()[:num_active] iteration_stats = IterationStats() @@ -921,3 +911,84 @@ def make_outputs() -> list[RequestOutput]: # Cumulative logprobs should be the last one. cumulative_logprob_expected = 1.0 * num_to_put assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected + + +@pytest.mark.asyncio +async def test_cumulative_output_collector_n(): + """Test collector correctly handles multiple outputs by index.""" + collector = RequestOutputCollector(RequestOutputKind.CUMULATIVE) + outputs = [ + RequestOutput( + request_id="my-request-id", + prompt=None, + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[ + CompletionOutput( + index=0, + text="a", + token_ids=[0], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + ), + CompletionOutput( + index=1, + text="b", + token_ids=[1], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + ), + ], + finished=False, + ), + RequestOutput( + request_id="my-request-id", + prompt=None, + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[ + CompletionOutput( + index=0, + text="ab", + token_ids=[0, 1], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + ), + CompletionOutput( + index=2, + text="c", + token_ids=[2], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + ), + ], + finished=False, + ), + ] + for output in outputs: + collector.put(output) + + # Get the output and check that the text and token_ids are correct. 
+ result = await collector.get() + # We are expecting + # [{index: 0, text: "ab"}, {index: 1, text: "b"}, {index: 2, text: "c"}] + assert len(result.outputs) == 3 + # First is the one where index is 0 + first = [k for k in result.outputs if k.index == 0] + assert len(first) == 1 + assert first[0].text == "ab" + + # Second is the one where index is 1 + second = [k for k in result.outputs if k.index == 1] + assert len(second) == 1 + assert second[0].text == "b" + assert second[0].token_ids == [1] + + # Third is the one where index is 2 + third = [k for k in result.outputs if k.index == 2] + assert len(third) == 1 + assert third[0].text == "c" diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py index 1ee93c72cd26..4a23e0c1b212 100644 --- a/tests/v1/engine/utils.py +++ b/tests/v1/engine/utils.py @@ -8,8 +8,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.arg_utils import EngineArgs -from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( - BaseTokenizerGroup) +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreOutput, FinishReason from vllm.v1.outputs import LogprobsLists, LogprobsTensors @@ -296,7 +295,7 @@ def generate_dummy_prompt_logprobs_tensors( class DummyOutputProcessorTestVectors: """Dummy test vectors for output processor tests""" tokenizer: GeneralTokenizerType - tokenizer_group: BaseTokenizerGroup + tokenizer_group: TokenizerGroup vllm_config: EngineArgs full_tokens: list[list[int]] # Prompt + generated tokens prompt_tokens: list[list[int]] diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index 6d4278b4c871..d84b2b22db12 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -47,6 +47,14 @@ def sample_json_schema(): "type": "string", } }, + "grade": { + "type": "string", + "pattern": "^[A-D]$" # Regex pattern + }, + "email": { + "type": "string", + "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$" + }, "work_history": { "type": "array", "items": { @@ -56,17 +64,20 @@ def sample_json_schema(): "type": "string" }, "duration": { - "type": "number" + "type": "number", + "minimum": 0.0, + "maximum": 100.0, # Numeric range }, "position": { "type": "string" } }, - "required": ["company", "position"] + "required": ["company", "duration", "position"] } } }, - "required": ["name", "age", "skills", "work_history"] + "required": + ["name", "age", "skills", "grade", "email", "work_history"] } @@ -78,27 +89,18 @@ def unsupported_json_schema(): "properties": { "score": { "type": "integer", - "minimum": 0, - "maximum": 100 # Numeric range - }, - "grade": { - "type": "string", - "pattern": "^[A-D]$" # Regex pattern - }, - "email": { - "type": "string", - "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$" + "multipleOf": 5 # Numeric multiple }, "tags": { "type": "array", "items": { "type": "string", - "pattern": - "^[a-z]{1,10}$" # Combining length and pattern restrictions + "minLength": 10, + "maxLength": 20 } } }, - "required": ["score", "grade", "email", "tags"] + "required": ["score", "tags"] } diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index b179dc3b4747..19960c13c856 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -13,6 +13,7 @@ from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput +from 
vllm.platforms import current_platform from vllm.sampling_params import GuidedDecodingParams, SamplingParams PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ @@ -63,10 +64,13 @@ def test_structured_output( ): monkeypatch.setenv("VLLM_USE_V1", "1") + # Don't use eager execution on TPUs because we want to test for no + # recompilation at runtime + enforce_eager = bool(not current_platform.is_tpu()) # Use a single LLM instance for several scenarios to # speed up the test suite. llm = LLM(model=model_name, - enforce_eager=True, + enforce_eager=enforce_eager, max_model_len=1024, guided_decoding_backend=guided_decoding_backend, tokenizer_mode=tokenizer_mode) @@ -346,6 +350,7 @@ def test_structured_output( temperature=1.0, max_tokens=1000, guided_decoding=GuidedDecodingParams(json=json_schema)) + outputs = llm.generate( prompts="Generate a description of a frog using 50 characters.", sampling_params=sampling_params, @@ -364,6 +369,106 @@ def test_structured_output( output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=json_schema) + # + # Test 11: Generate structured output using structural_tag format + # + structural_tag_config = { + "type": + "structural_tag", + "structures": [{ + "begin": "<function=get_weather>", + "schema": { + "type": "object", + "properties": { + "city": { + "type": "string" + } + } + }, + "end": "</function>" + }], + "triggers": ["<function="] + } + + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=100, + guided_decoding=GuidedDecodingParams( + structural_tag=json.dumps(structural_tag_config))) + + prompt = """ +You have access to the following function to retrieve the weather in a city: + +{ + "name": "get_weather", + "parameters": { + "city": { + "param_type": "string", + "description": "The city to get the weather for", + "required": True + } + } +} + +If you choose to call a function ONLY reply in the following format: +<{start_tag}={function_name}>{parameters}{end_tag} +where + +start_tag => `<function` +parameters => a JSON dict with the function argument name + as key and function argument value as value. +end_tag => `</function>` + +Here is an example, +<function=example_function_name>{"example_name": "example_value"}</function> + +Reminder: +- Function calls MUST follow the specified format +- Required parameters MUST be specified +- Only call one function at a time +- Put the entire function call reply on one line +- Always add your sources when using search results to answer the user query + +You are a helpful assistant. + +Given the previous instructions, what is the weather in New York City? +""" + + # Change this once other backends support structural_tag + if guided_decoding_backend.startswith("xgrammar"): + outputs = llm.generate(prompts=prompt, + sampling_params=sampling_params, + use_tqdm=True) + assert outputs is not None + else: + outputs = [] + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + generated_text = output.outputs[0].text + assert generated_text is not None + + # Search for function call pattern in the response + function_call_pattern = r'<function=get_weather>(.*?)</function>' + matches = re.findall(function_call_pattern, generated_text) + + if not matches: + print(f"Warning: No function calls found in response: " + f"{generated_text!r}") + continue + + # Take the first function call if multiple are found + json_str = matches[0] + try: + json_content = json.loads(json_str) + assert "city" in json_content + assert isinstance(json_content["city"], str) + print(f"Found valid function call: {generated_text!r}") + except (json.JSONDecodeError, AssertionError) as e: + pytest.fail("Invalid function call format: " + f"{generated_text!r}\nError: {str(e)}") + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("model_name, tokenizer_mode", @@ -386,13 +491,21 @@ def test_structured_output_auto_mode( max_tokens=1000, guided_decoding=GuidedDecodingParams(json=unsupported_json_schema)) + prompts = ("Give an example JSON object for a grade " + "that fits this schema: " + f"{unsupported_json_schema}") # This would fail with the default of "xgrammar", but in "auto" # we will handle fallback automatically.
- outputs = llm.generate(prompts=("Give an example JSON object for a grade " - "that fits this schema: " - f"{unsupported_json_schema}"), + outputs = llm.generate(prompts=prompts, sampling_params=sampling_params, use_tqdm=True) + # Make sure `auto` backend handling doesn't mess up sampling_params + # and that we can reuse it without error. + outputs.extend( + llm.generate(prompts=prompts, + sampling_params=sampling_params, + use_tqdm=True)) + assert outputs is not None for output in outputs: assert output is not None @@ -404,3 +517,59 @@ def test_structured_output_auto_mode( # Parse to verify it is valid JSON parsed_json = json.loads(generated_text) assert isinstance(parsed_json, dict) + + +@pytest.mark.skip_global_cleanup +def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_V1", "1") + + backend = 'guidance:no-additional-properties,disable-any-whitespace' + llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", + max_model_len=1024, + guided_decoding_backend=backend) + + schema = { + 'type': 'object', + 'properties': { + 'a1': { + 'type': 'string' + }, + 'a2': { + 'type': 'string' + }, + 'a3': { + 'type': 'string' + } + }, + 'required': ['a1', 'a2', 'a3'], + } + + prompt = ( + "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a " + "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a " + "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20" + "<|im_end|>\n<|im_start|>assistant\n") + + def generate_with_backend(backend): + guided_params = GuidedDecodingParams(json=schema, backend=backend) + sampling_params = SamplingParams(temperature=0, + max_tokens=256, + guided_decoding=guided_params) + + outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) + assert outputs is not None + generated_text = outputs[0].outputs[0].text + assert generated_text is not None + parsed_json = json.loads(generated_text) + assert isinstance(parsed_json, dict) + jsonschema.validate(instance=parsed_json, schema=schema) + return parsed_json + + generated = generate_with_backend( + 'guidance:no-additional-properties,disable-any-whitespace') + assert "a1" in generated + assert "a2" in generated + assert "a3" in generated + assert "a4" not in generated + assert "a5" not in generated + assert "a6" not in generated diff --git a/tests/v1/shutdown/test_delete.py b/tests/v1/shutdown/test_delete.py new file mode 100644 index 000000000000..ed368fe828d0 --- /dev/null +++ b/tests/v1/shutdown/test_delete.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test that we handle a startup Error and shutdown.""" + +import pytest + +from tests.utils import wait_for_gpu_memory_to_clear +from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES, + SHUTDOWN_TEST_TIMEOUT_SEC) +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.sampling_params import RequestOutputKind +from vllm.utils import cuda_device_count_stateless +from vllm.v1.engine.async_llm import AsyncLLM + +MODELS = ["meta-llama/Llama-3.2-1B"] + + +@pytest.mark.asyncio +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("send_one_request", [False, True]) +async def test_async_llm_delete(model: str, tensor_parallel_size: int, + send_one_request: bool) -> None: + """Test that AsyncLLM frees GPU memory upon deletion. + AsyncLLM always uses an MP client. 
+ + Args: + model: model under test + tensor_parallel_size: degree of tensor parallelism + send_one_request: send one request to engine before deleting + """ + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + engine_args = AsyncEngineArgs(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + + # Instantiate AsyncLLM; make request to complete any deferred + # initialization; then delete instance + async_llm = AsyncLLM.from_engine_args(engine_args) + if send_one_request: + async for _ in async_llm.generate( + "Hello my name is", + request_id="abc", + sampling_params=SamplingParams( + max_tokens=1, output_kind=RequestOutputKind.DELTA)): + pass + del async_llm + + # Confirm all the processes are cleaned up. + wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES, + ) + + +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("enable_multiprocessing", [True]) +@pytest.mark.parametrize("send_one_request", [False, True]) +def test_llm_delete(monkeypatch, model: str, tensor_parallel_size: int, + enable_multiprocessing: bool, + send_one_request: bool) -> None: + """Test that LLM frees GPU memory upon deletion. + TODO(andy) - LLM without multiprocessing. + + Args: + model: model under test + tensor_parallel_size: degree of tensor parallelism + enable_multiprocessing: enable workers in separate process(es) + send_one_request: send one request to engine before deleting + """ + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + with monkeypatch.context() as m: + MP_VALUE = "1" if enable_multiprocessing else "0" + m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE) + + # Instantiate LLM; make request to complete any deferred + # initialization; then delete instance + llm = LLM(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + if send_one_request: + llm.generate("Hello my name is", + sampling_params=SamplingParams(max_tokens=1)) + del llm + + # Confirm all the processes are cleaned up. 
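+ # wait_for_gpu_memory_to_clear polls the listed devices until their usage + # drops below threshold_bytes, so it doubles as the cleanup assertion.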
+ wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES, + ) diff --git a/tests/v1/shutdown/test_forward_error.py b/tests/v1/shutdown/test_forward_error.py new file mode 100644 index 000000000000..9fedbe4f9a01 --- /dev/null +++ b/tests/v1/shutdown/test_forward_error.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test that we handle an Error in model forward and shutdown.""" + +import asyncio + +import pytest + +from tests.utils import wait_for_gpu_memory_to_clear +from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES, + SHUTDOWN_TEST_TIMEOUT_SEC) +from vllm import LLM, AsyncEngineArgs, SamplingParams +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.utils import cuda_device_count_stateless +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.exceptions import EngineDeadError + +MODELS = ["meta-llama/Llama-3.2-1B"] + + +def evil_forward(self, *args, **kwargs): + """Evil forward method that raises an exception after 10 calls.""" + NUMBER_OF_GOOD_PASSES = 10 + + if not hasattr(self, "num_calls"): + self.num_calls = 0 + + if (self.num_calls == NUMBER_OF_GOOD_PASSES + and get_tensor_model_parallel_rank() == 0): + raise Exception("Simulated illegal memory access on Rank 0!") + self.num_calls += 1 + + return self.model(*args, **kwargs) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("model", MODELS) +async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int, + model: str) -> None: + """Test that AsyncLLM propagates a forward pass error and frees memory. + + AsyncLLM always uses an MP client. + """ + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + # Monkeypatch an error in the model. + monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward) + + engine_args = AsyncEngineArgs(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + async_llm = AsyncLLM.from_engine_args(engine_args) + + async def generate(request_id: str): + generator = async_llm.generate("Hello my name is", + request_id=request_id, + sampling_params=SamplingParams()) + try: + async for _ in generator: + pass + except Exception as e: + return e + + NUM_REQS = 3 + tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)] + outputs = await asyncio.gather(*tasks) + + # Every request should get an EngineDeadError. + for output in outputs: + assert isinstance(output, EngineDeadError) + + # AsyncLLM should be errored. + assert async_llm.errored + + # We should not be able to make another request. + with pytest.raises(EngineDeadError): + async for _ in async_llm.generate("Hello my name is", + request_id="abc", + sampling_params=SamplingParams()): + raise Exception("We should not get here.") + + # Confirm all the processes are cleaned up. + wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=2 * 2**30, + timeout_s=60, + ) + + # NOTE: shutdown is handled by the API Server if an exception + # occurs, so it is expected that we need to call this manually here.
+ async_llm.shutdown() + + +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("enable_multiprocessing", [True]) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("model", MODELS) +def test_llm_model_error(monkeypatch, tensor_parallel_size: int, + enable_multiprocessing: bool, model: str) -> None: + """Test that LLM propagates a forward pass error and frees memory. + TODO(andy) - LLM without multiprocessing; LLM with multiprocessing + and >1 rank + """ + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + with monkeypatch.context() as m: + + MP_VALUE = "1" if enable_multiprocessing else "0" + m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE) + + # Monkeypatch an error in the model. + m.setattr(LlamaForCausalLM, "forward", evil_forward) + + llm = LLM(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + + with pytest.raises( + EngineDeadError if enable_multiprocessing else Exception): + llm.generate("Hello my name is Robert and I") + + # Confirm all the processes are cleaned up. + wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES, + ) diff --git a/tests/v1/shutdown/test_processor_error.py b/tests/v1/shutdown/test_processor_error.py new file mode 100644 index 000000000000..0fe48da475c6 --- /dev/null +++ b/tests/v1/shutdown/test_processor_error.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test error handling in Processor. Should not impact other reqs.""" + +import asyncio + +import pytest + +from tests.v1.shutdown.utils import SHUTDOWN_TEST_TIMEOUT_SEC +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.inputs.data import TokensPrompt +from vllm.sampling_params import RequestOutputKind +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.exceptions import EngineGenerateError + +MODELS = ["meta-llama/Llama-3.2-1B"] + + +@pytest.mark.asyncio +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("model", MODELS) +async def test_async_llm_processor_error(model: str) -> None: + """Test that AsyncLLM propagates a processor error. + Test empty tokens prompt (failure) and non-empty prompt (no failure). + AsyncLLM always uses an MP client. + """ + engine_args = AsyncEngineArgs(model=model, enforce_eager=True) + async_llm = AsyncLLM.from_engine_args(engine_args) + + async def generate(request_id: str): + # [] is not allowed and will raise a ValueError in Processor. + generator = async_llm.generate(TokensPrompt([]), + request_id=request_id, + sampling_params=SamplingParams()) + try: + async for _ in generator: + pass + except Exception as e: + return e + + NUM_REQS = 3 + tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)] + outputs = await asyncio.gather(*tasks) + + # Every request should get an EngineGenerateError. + for output in outputs: + with pytest.raises(EngineGenerateError): + raise output + + # AsyncLLM should not be errored. + assert not async_llm.errored + + # This should be no problem.
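+ # A fresh request after the per-request Processor failures should still + # generate tokens normally.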
+ EXPECTED_TOKENS = 5 + outputs = [] + async for out in async_llm.generate( + "Hello my name is", + request_id="abc", + sampling_params=SamplingParams( + max_tokens=EXPECTED_TOKENS, + output_kind=RequestOutputKind.DELTA)): + outputs.append(out) + + generated_tokens = [] + for out in outputs: + generated_tokens.extend(out.outputs[0].token_ids) + assert len(generated_tokens) == EXPECTED_TOKENS + + async_llm.shutdown() diff --git a/tests/v1/shutdown/test_startup_error.py b/tests/v1/shutdown/test_startup_error.py new file mode 100644 index 000000000000..1bba19102ec6 --- /dev/null +++ b/tests/v1/shutdown/test_startup_error.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test that we handle a startup Error and shutdown.""" + +import pytest + +from tests.utils import wait_for_gpu_memory_to_clear +from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES, + SHUTDOWN_TEST_TIMEOUT_SEC) +from vllm import LLM +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.utils import cuda_device_count_stateless +from vllm.v1.engine.async_llm import AsyncLLM + +MODELS = ["meta-llama/Llama-3.2-1B"] + + +def evil_method(self, *args, **kwargs): + """Evil method that raises an exception.""" + + if get_tensor_model_parallel_rank() == 0: + raise Exception("Simulated Error in startup!") + + return self.model(*args, **kwargs, intermediate_tensors=None) + + +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("failing_method", ["forward", "load_weights"]) +def test_async_llm_startup_error(monkeypatch, model: str, + tensor_parallel_size: int, + failing_method: str) -> None: + """Test that AsyncLLM propagates an __init__ error & frees memory. + Test profiling (forward()) and load weights failures. + AsyncLLM always uses an MP client. + """ + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + # Monkeypatch an error in the model. + monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method) + + engine_args = AsyncEngineArgs(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + + # Confirm we get an exception. + with pytest.raises(Exception, match="initialization failed"): + _ = AsyncLLM.from_engine_args(engine_args) + + # Confirm all the processes are cleaned up. + wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES, + ) + + +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("enable_multiprocessing", [True]) +@pytest.mark.parametrize("failing_method", ["forward", "load_weights"]) +def test_llm_startup_error(monkeypatch, model: str, tensor_parallel_size: int, + enable_multiprocessing: bool, + failing_method: str) -> None: + """Test that LLM propagates an __init__ error and frees memory. + Test profiling (forward()) and load weights failures. + TODO(andy) - LLM without multiprocessing. 
+ """ + if model != "meta-llama/Llama-3.2-1B": + pytest.skip(reason="Only test meta-llama/Llama-3.2-1B") + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + with monkeypatch.context() as m: + + MP_VALUE = "1" if enable_multiprocessing else "0" + m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE) + + # Monkeypatch an error in the model. + monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method) + + with pytest.raises( + Exception, + match="initialization failed" + if enable_multiprocessing else "Simulated Error in startup!"): + _ = LLM(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + + # Confirm all the processes are cleaned up. + wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES, + ) diff --git a/tests/v1/shutdown/utils.py b/tests/v1/shutdown/utils.py new file mode 100644 index 000000000000..8f7c0380d407 --- /dev/null +++ b/tests/v1/shutdown/utils.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Shutdown test utils""" + +SHUTDOWN_TEST_TIMEOUT_SEC = 120 +SHUTDOWN_TEST_THRESHOLD_BYTES = 2 * 2**30 diff --git a/tests/v1/spec_decode/test_max_len.py b/tests/v1/spec_decode/test_max_len.py new file mode 100644 index 000000000000..f577fb4ab329 --- /dev/null +++ b/tests/v1/spec_decode/test_max_len.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test whether spec decoding handles the max model length properly.""" + +import pytest + +from vllm import LLM, SamplingParams + +_PROMPTS = [ + "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1", + "Repeat the following sentence 10 times: Consistency is key to mastering any skill.", # noqa: E501 + "Who won the Turing Award in 2018, and for what contribution? Describe in detail.", # noqa: E501 +] + + +@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10]) +def test_ngram_max_len( + monkeypatch: pytest.MonkeyPatch, + num_speculative_tokens: int, +): + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + llm = LLM( + model="facebook/opt-125m", + max_model_len=100, + enforce_eager=True, # For faster initialization. + speculative_config={ + "method": "ngram", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, + "num_speculative_tokens": num_speculative_tokens, + }, + ) + sampling_params = SamplingParams(max_tokens=100, ignore_eos=True) + llm.generate(_PROMPTS, sampling_params) + + +@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10]) +def test_eagle_max_len( + monkeypatch: pytest.MonkeyPatch, + num_speculative_tokens: int, +): + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + llm = LLM( + model="meta-llama/Meta-Llama-3-8B-Instruct", + enforce_eager=True, # For faster initialization. 
+ speculative_config={ + "method": "eagle", + "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", + "num_speculative_tokens": num_speculative_tokens, + }, + max_model_len=100, + ) + sampling_params = SamplingParams(max_tokens=100, ignore_eos=True) + llm.generate(_PROMPTS, sampling_params) diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index a81b4897e5d6..50548219fff0 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -2,6 +2,7 @@ import numpy as np +from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig from vllm.v1.spec_decode.ngram_proposer import (NgramProposer, _find_subarray_kmp, _kmp_lps_array) @@ -39,50 +40,50 @@ def test_find_subarray_kmp(): def test_ngram_proposer(): - proposer = NgramProposer() + + def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: + # Dummy model config. Just to set max_model_len. + model_config = ModelConfig(model="facebook/opt-125m", + task="generate", + max_model_len=100, + tokenizer="facebook/opt-125m", + tokenizer_mode="auto", + dtype="auto", + seed=None, + trust_remote_code=False) + return NgramProposer( + vllm_config=VllmConfig(model_config=model_config, + speculative_config=SpeculativeConfig. + from_dict({ + "prompt_lookup_min": min_n, + "prompt_lookup_max": max_n, + "num_speculative_tokens": k, + "method": "ngram", + }))) # No match. - result = proposer.propose( - context_token_ids=np.array([1, 2, 3, 4, 5]), - min_n=2, - max_n=2, - k=2, - ) + result = ngram_proposer( + 2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5])) assert result is None # No match for 4-gram. - result = proposer.propose( - context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]), - min_n=4, - max_n=4, - k=2, - ) + result = ngram_proposer( + 4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) assert result is None # No match for 4-gram but match for 3-gram. - result = proposer.propose( - context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]), - min_n=3, - max_n=4, - k=2, - ) + result = ngram_proposer( + 3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) assert np.array_equal(result, np.array([4, 1])) # Match for both 4-gram and 3-gram. # In this case, the proposer should return the 4-gram match. - result = proposer.propose( - context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]), - min_n=3, - max_n=4, - k=2, - ) + result = ngram_proposer(3, 4, 2).propose( + context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4])) assert np.array_equal(result, np.array([1, 2])) # Not [5, 1] # Match for 2-gram and 3-gram, but not 4-gram. 
- result = proposer.propose( - context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]), - min_n=2, - max_n=4, - k=2, - ) + result = ngram_proposer( + 2, 4, + 2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4])) assert np.array_equal(result, np.array([1, 2])) # Not [5, 2] diff --git a/tests/v1/structured_output/test_utils.py b/tests/v1/structured_output/test_utils.py index 0929f9901628..1cefe8726df7 100644 --- a/tests/v1/structured_output/test_utils.py +++ b/tests/v1/structured_output/test_utils.py @@ -2,17 +2,13 @@ import pytest -from vllm.v1.structured_output.utils import ( +from vllm.v1.structured_output.backend_xgrammar import ( has_xgrammar_unsupported_json_features) @pytest.fixture def unsupported_string_schemas(): return [ - { - "type": "string", - "pattern": "^[a-zA-Z]+$" - }, { "type": "string", "format": "email" @@ -23,22 +19,6 @@ def unsupported_string_schemas(): @pytest.fixture def unsupported_integer_schemas(): return [ - { - "type": "integer", - "minimum": 0 - }, - { - "type": "integer", - "maximum": 120 - }, - { - "type": "integer", - "exclusiveMinimum": 120 - }, - { - "type": "integer", - "exclusiveMaximum": 120 - }, { "type": "integer", "multipleOf": 120 @@ -49,22 +29,6 @@ def unsupported_integer_schemas(): @pytest.fixture def unsupported_number_schemas(): return [ - { - "type": "number", - "minimum": 0 - }, - { - "type": "number", - "maximum": 120 - }, - { - "type": "number", - "exclusiveMinimum": 120 - }, - { - "type": "number", - "exclusiveMaximum": 120 - }, { "type": "number", "multipleOf": 120 @@ -156,13 +120,28 @@ def supported_schema(): "type": "string", "enum": ["sedan", "suv", "truck"] }, + "car_brand": { + "type": "string", + "pattern": "^[a-zA-Z]+$" + }, "short_description": { "type": "string", "maxLength": 50 }, + "mileage": { + "type": "number", + "minimum": 0, + "maximum": 1000000 + }, + "model_year": { + "type": "integer", + "exclusiveMinimum": 1900, + "exclusiveMaximum": 2100 + }, "long_description": { "type": "string", - "minLength": 50 + "minLength": 50, + "maxLength": 2000 }, "address": { "type": "object", diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/test_async_llm_dp.py index f0e031969e73..ce4c4d198db5 100644 --- a/tests/v1/test_async_llm_dp.py +++ b/tests/v1/test_async_llm_dp.py @@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind): # the engines only synchronize stopping every N steps so # allow a small amount of time here. 
for _ in range(10): - if core_client.num_engines_running == 0: + if not core_client.engines_running: break await asyncio.sleep(0.5) - assert core_client.num_engines_running == 0 + assert not core_client.engines_running assert not core_client.reqs_in_flight diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index bc0e0cbd85e1..b55018ae8ef0 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 from collections import UserDict from dataclasses import dataclass +from typing import Optional +import msgspec import numpy as np import torch +from vllm.multimodal.inputs import (MultiModalBatchedField, + MultiModalFieldElem, MultiModalKwargs, + MultiModalKwargsItem, + MultiModalSharedField, NestedTensors) from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -26,6 +32,7 @@ class MyType: large_f_contig_tensor: torch.Tensor small_non_contig_tensor: torch.Tensor large_non_contig_tensor: torch.Tensor + empty_tensor: torch.Tensor def test_encode_decode(): @@ -41,6 +48,10 @@ def test_encode_decode(): torch.rand((1, 10), dtype=torch.float32), torch.rand((3, 5, 4000), dtype=torch.float64), torch.tensor(1984), # test scalar too + # Make sure to test bf16 which numpy doesn't support. + torch.rand((3, 5, 1000), dtype=torch.bfloat16), + torch.tensor([float("-inf"), float("inf")] * 1024, + dtype=torch.bfloat16), ], numpy_array=np.arange(512), unrecognized=UnrecognizedType(33), @@ -48,9 +59,10 @@ def test_encode_decode(): large_f_contig_tensor=torch.rand(1024, 4).t(), small_non_contig_tensor=torch.rand(2, 4)[:, 1:3], large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20], + empty_tensor=torch.empty(0), ) - encoder = MsgpackEncoder() + encoder = MsgpackEncoder(size_threshold=256) decoder = MsgpackDecoder(MyType) encoded = encoder.encode(obj) @@ -58,7 +70,7 @@ def test_encode_decode(): # There should be the main buffer + 4 large tensor buffers # + 1 large numpy array. "large" is <= 512 bytes. # The two small tensors are encoded inline. 
- assert len(encoded) == 6 + assert len(encoded) == 8 decoded: MyType = decoder.decode(encoded) @@ -70,7 +82,7 @@ def test_encode_decode(): encoded2 = encoder.encode_into(obj, preallocated) - assert len(encoded2) == 6 + assert len(encoded2) == 8 assert encoded2[0] is preallocated decoded2: MyType = decoder.decode(encoded2) @@ -78,6 +90,97 @@ def test_encode_decode(): assert_equal(decoded2, obj) +class MyRequest(msgspec.Struct): + mm: Optional[list[MultiModalKwargs]] + + +def test_multimodal_kwargs(): + d = { + "foo": + torch.zeros(20000, dtype=torch.float16), + "bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)], + "baz": [ + torch.rand((256), dtype=torch.float16), + [ + torch.rand((1, 12), dtype=torch.float32), + torch.rand((3, 5, 7), dtype=torch.float64), + ], [torch.rand((4, 4), dtype=torch.float16)] + ], + } + + # pack mm kwargs into a mock request so that it can be decoded properly + req = MyRequest(mm=[MultiModalKwargs(d)]) + + encoder = MsgpackEncoder() + decoder = MsgpackDecoder(MyRequest) + + encoded = encoder.encode(req) + + assert len(encoded) == 6 + + total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) + + # expected total encoding length, should be 44559, +-20 for minor changes + assert total_len >= 44539 and total_len <= 44579 + decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] + assert all(nested_equal(d[k], decoded[k]) for k in d) + + +def test_multimodal_items_by_modality(): + e1 = MultiModalFieldElem("audio", "a0", + torch.zeros(1000, dtype=torch.bfloat16), + MultiModalBatchedField()) + e2 = MultiModalFieldElem( + "video", + "v0", + [torch.zeros(1000, dtype=torch.int8) for _ in range(4)], + MultiModalBatchedField(), + ) + e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000, + dtype=torch.int32), + MultiModalSharedField(4)) + e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000, + dtype=torch.int32), + MultiModalBatchedField()) + audio = MultiModalKwargsItem.from_elems([e1]) + video = MultiModalKwargsItem.from_elems([e2]) + image = MultiModalKwargsItem.from_elems([e3, e4]) + mm = MultiModalKwargs.from_items([audio, video, image]) + + # pack mm kwargs into a mock request so that it can be decoded properly + req = MyRequest([mm]) + + encoder = MsgpackEncoder() + decoder = MsgpackDecoder(MyRequest) + + encoded = encoder.encode(req) + + assert len(encoded) == 8 + + total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) + + # expected total encoding length, should be 14255, +-20 for minor changes + assert total_len >= 14235 and total_len <= 14275 + decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] + + # check all modalities were recovered and do some basic sanity checks + assert len(decoded.modalities) == 3 + images = decoded.get_items("image") + assert len(images) == 1 + assert len(images[0].items()) == 2 + assert list(images[0].keys()) == ["i0", "i1"] + + # check the tensor contents and layout in the main dict + assert all(nested_equal(mm[k], decoded[k]) for k in mm) + + +def nested_equal(a: NestedTensors, b: NestedTensors): + if isinstance(a, torch.Tensor): + return torch.equal(a, b) + else: + return all(nested_equal(x, y) for x, y in zip(a, b)) + + def assert_equal(obj1: MyType, obj2: MyType): assert torch.equal(obj1.tensor1, obj2.tensor1) assert obj1.a_string == obj2.a_string @@ -92,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType): obj2.small_non_contig_tensor) assert torch.equal(obj1.large_non_contig_tensor, obj2.large_non_contig_tensor) + assert torch.equal(obj1.empty_tensor, obj2.empty_tensor) diff 
--git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py index 8164952fe382..a4571a554572 100644 --- a/tests/v1/tpu/test_basic.py +++ b/tests/v1/tpu/test_basic.py @@ -22,6 +22,7 @@ ] TENSOR_PARALLEL_SIZES = [1] +MAX_NUM_REQS = [16, 1024] # TODO: Enable when CI/CD will have a multi-tpu instance # TENSOR_PARALLEL_SIZES = [1, 4] @@ -32,12 +33,14 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES) +@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS) def test_basic( vllm_runner: type[VllmRunner], monkeypatch: pytest.MonkeyPatch, model: str, max_tokens: int, tensor_parallel_size: int, + max_num_seqs: int, ) -> None: prompt = "The next numbers of the sequence " + ", ".join( str(i) for i in range(1024)) + " are:" @@ -51,9 +54,9 @@ def test_basic( # Note: max_num_batched_tokens == 1024 is needed here to # actually test chunked prompt max_num_batched_tokens=1024, - max_model_len=8196, + max_model_len=8192, gpu_memory_utilization=0.7, - max_num_seqs=16, + max_num_seqs=max_num_seqs, tensor_parallel_size=tensor_parallel_size) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/v1/tpu/test_multimodal.py b/tests/v1/tpu/test_multimodal.py new file mode 100644 index 000000000000..eb62e0e4b201 --- /dev/null +++ b/tests/v1/tpu/test_multimodal.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: Apache-2.0 + +import openai +import pytest + +from vllm import envs +from vllm.multimodal.utils import encode_image_base64, fetch_image +from vllm.platforms import current_platform + +from ...entrypoints.openai.test_vision import TEST_IMAGE_URLS +from ...utils import RemoteOpenAIServer + +if not envs.VLLM_USE_V1: + pytest.skip( + "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.", + allow_module_level=True, + ) + + +@pytest.fixture(scope="session") +def base64_encoded_image() -> dict[str, str]: + return { + image_url: encode_image_base64(fetch_image(image_url)) + for image_url in TEST_IMAGE_URLS + } + + +@pytest.mark.asyncio +@pytest.mark.skipif(not current_platform.is_tpu(), + reason="This test needs a TPU") +@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"]) +async def test_basic_vision(model_name: str, base64_encoded_image: dict[str, + str]): + + def whats_in_this_image_msg(b64): + return [{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{b64}" + }, + }, + ], + }] + + server_args = [ + "--max-model-len", + "1024", + "--max-num-seqs", + "16", + "--gpu-memory-utilization", + "0.95", + "--trust-remote-code", + "--max-num-batched-tokens", + "576", + # NOTE: max-num-batched-tokens>=mm_item_size + "--disable_chunked_mm_input", + "--chat-template", + "examples/template_llava.jinja" + ] + + # Server will pre-compile on first startup (takes a long time). 
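+ # Hence the generous max_wait_seconds below; later requests reuse the + # compiled graphs and should return quickly.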
+ with RemoteOpenAIServer(model_name, server_args, + max_wait_seconds=600) as remote_server: + client: openai.AsyncOpenAI = remote_server.get_async_client() + + # Other requests now should be much faster + for image_url in TEST_IMAGE_URLS: + image_base64 = base64_encoded_image[image_url] + chat_completion_from_base64 = await client.chat.completions\ + .create( + model=model_name, + messages=whats_in_this_image_msg(image_base64), + max_completion_tokens=24, + temperature=0.0) + result = chat_completion_from_base64 + assert result + choice = result.choices[0] + assert choice.finish_reason == "length" + + message = choice.message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py index 0147da533517..c6b492b5a3cc 100644 --- a/tests/v1/tpu/test_sampler.py +++ b/tests/v1/tpu/test_sampler.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import random + import pytest from vllm import LLM, envs @@ -39,3 +41,23 @@ def test_sampler_different(model_name: str): # Unsupported `seed` param. sampling_params = SamplingParams(temperature=0.3, seed=42) output2 = llm.generate(prompts, sampling_params) + + # Batch-case with TopK/P + for B in [4, 16]: + p = prompts * B + sampling_params = [ + SamplingParams( + temperature=0.1, + min_p=0.8, + max_tokens=64, + # Vary number of ks + top_k=random.randint(4, 12), + top_p=random.random()) for _ in range(B) + ] + # Make sure first two reqs have the same K/P + sampling_params[0] = sampling_params[1] + output = llm.generate(p, sampling_params) + # There are natural numerical instabilities that make it difficult + # to have deterministic results over many tokens, so only check that + # the first ~20 tokens match. + assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20] diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py index dce0303e68d5..ff9217f8f3ca 100644 --- a/tests/v1/tpu/test_topk_topp_sampler.py +++ b/tests/v1/tpu/test_topk_topp_sampler.py @@ -5,7 +5,8 @@ import torch from vllm.platforms import current_platform -from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu +from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, + apply_top_k_top_p_tpu) if not current_platform.is_tpu(): pytest.skip("This test needs a TPU.", allow_module_level=True) @@ -16,6 +17,25 @@ TOLERANCE = 1e-6 +def test_topk_equivalence_to_native_impl(): + with torch.device(xm.xla_device()): + xm.set_rng_state(seed=33) + + logits = torch.rand((BATCH_SIZE, VOCAB_SIZE)) + + # Random top-k values between 1 and 9. + k = torch.randint(1, 10, (BATCH_SIZE, )) + + # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
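+ # masked_fill_ with a random boolean mask replaces k with the full vocab + # size on roughly half the rows, making top-k a no-op for those rows.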
+ k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), + VOCAB_SIZE) + + result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None) + + result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None) + assert torch.allclose(result_native, result_tpu) + + def test_topp_result_sums_past_p(): with torch.device(xm.xla_device()): xm.set_rng_state(seed=33) diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 8ea8c890613a..319b38b4ca09 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -77,7 +77,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: NewRequestData( req_id=req_id, prompt_token_ids=[1, 2, 3], - prompt="test", mm_inputs=[], mm_hashes=[], mm_positions=[], @@ -294,8 +293,28 @@ def test_update_states_request_unscheduled(model_runner): def test_get_paddings(): + # Bucketed padding min_token_size, max_token_size, padding_gap = 16, 512, 64 expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512] + actual_paddings = _get_token_paddings(min_token_size, max_token_size, + padding_gap) + + # Bucketed padding with max_token_size not a power of two. + max_token_size = 317 + expected_paddings = [16, 32, 64, 128, 192, 256, 320] + actual_paddings = _get_token_paddings(min_token_size, max_token_size, + padding_gap) + assert actual_paddings == expected_paddings + + # Exponential padding. + max_token_size, padding_gap = 1024, 0 + expected_paddings = [16, 32, 64, 128, 256, 512, 1024] + actual_paddings = _get_token_paddings(min_token_size, max_token_size, + padding_gap) + assert actual_paddings == expected_paddings + # Exponential padding with max_token_size not a power of two. + max_token_size = 317 + expected_paddings = [16, 32, 64, 128, 256, 512] actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap) assert actual_paddings == expected_paddings diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 2486c26c6071..915ec2914a82 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -195,7 +195,6 @@ def _construct_cached_request_state(req_id_suffix: int): return CachedRequestState( req_id=f"req_id_{req_id_suffix}", prompt_token_ids=prompt_token_ids, - prompt=None, sampling_params=_create_sampling_params(), mm_inputs=[], mm_positions=[], diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index dd95a7f53064..68e34cfacc58 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -50,7 +50,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: NewRequestData( req_id=req_id, prompt_token_ids=[1, 2, 3], - prompt="test", mm_inputs=[], mm_hashes=[], mm_positions=[], diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7a4c93ad6f7f..4c577c1c47e7 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1196,6 +1196,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, ssm_states, pad_slot_id) +# ROCm skinny gemms +def LLMM1(a: torch.Tensor, b: torch.Tensor, + rows_per_block: int) -> torch.Tensor: + return torch.ops._rocm_C.LLMM1(a, b, rows_per_block) + + +def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor: + return torch.ops._rocm_C.wvSplitK(a, b, cu_count) + + +def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, 
scale_b: torch.Tensor, + cu_count: int) -> torch.Tensor: + out = torch.empty((b.shape[0], a.shape[0]), + dtype=out_dtype, + device=b.device) + torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count) + return out + + # moe def moe_sum(input: torch.Tensor, output: torch.Tensor): torch.ops._moe_C.moe_sum(input, output) @@ -1245,6 +1265,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, token_expert_indicies, gating_output) +def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], + b_qweight: torch.Tensor, b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_past_padded: torch.Tensor, + topk_weights: torch.Tensor, moe_block_size: int, + top_k: int, mul_topk_weights: bool, is_ep: bool, + b_q_type: ScalarType, size_m: int, size_n: int, + size_k: int, is_k_full: bool, use_atomic_add: bool, + use_fp32_reduce: bool, + is_zp_float: bool) -> torch.Tensor: + return torch.ops._moe_C.moe_wna16_marlin_gemm( + input, output, b_qweight, b_scales, b_qzeros, g_idx, perm, workspace, + sorted_token_ids, expert_ids, num_tokens_past_padded, topk_weights, + moe_block_size, top_k, mul_topk_weights, is_ep, b_q_type.id, size_m, + size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce, + is_zp_float) + + if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): @register_fake("_moe_C::marlin_gemm_moe") @@ -1263,6 +1306,29 @@ def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, dtype=a.dtype, device=a.device) + @register_fake("_moe_C::moe_wna16_marlin_gemm") + def moe_wna16_marlin_gemm_fake(input: torch.Tensor, + output: Optional[torch.Tensor], + b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_past_padded: torch.Tensor, + topk_weights: torch.Tensor, + moe_block_size: int, top_k: int, + mul_topk_weights: bool, is_ep: bool, + b_q_type: ScalarType, size_m: int, + size_n: int, size_k: int, is_k_full: bool, + use_atomic_add: bool, use_fp32_reduce: bool, + is_zp_float: bool) -> torch.Tensor: + return torch.empty((size_m * top_k, size_n), + dtype=input.dtype, + device=input.device) + def reshape_and_cache( key: torch.Tensor, @@ -1459,3 +1525,12 @@ def flash_mla_with_kvcache( num_splits, ) return out, softmax_lse + + +def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, + q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, + seq_lens: torch.Tensor, page_table: torch.Tensor, + scale: float) -> torch.Tensor: + torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, + seq_lens, page_table, scale) + return out diff --git a/vllm/assets/video.py b/vllm/assets/video.py index 32b0b86ba36f..133e18b68e25 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from functools import lru_cache -from typing import Literal +from typing import Literal, Optional import cv2 import numpy as np @@ -10,8 +10,15 @@ from huggingface_hub import hf_hub_download from PIL import Image +from vllm.utils import PlaceholderModule + from .base import get_cache_dir +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + @lru_cache def 
download_video_asset(filename: str) -> str: @@ -85,3 +92,12 @@ def np_ndarrays(self) -> npt.NDArray: video_path = download_video_asset(self.name) ret = video_to_ndarrays(video_path, self.num_frames) return ret + + def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray: + """ + Read audio data from the video asset, used in Qwen2.5-Omni examples. + + See also: examples/offline_inference/qwen2_5_omni/only_thinker.py + """ + video_path = download_video_asset(self.name) + return librosa.load(video_path, sr=sampling_rate)[0] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 82d60f9da7da..f3d6ffaeb8f4 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -77,6 +77,10 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: raise NotImplementedError + @staticmethod + def get_kv_cache_stride_order() -> Tuple[int, ...]: + raise NotImplementedError + @staticmethod @abstractmethod def swap_blocks( @@ -237,6 +241,7 @@ class AttentionLayer(Protocol): _v_scale: torch.Tensor _k_scale_float: float _v_scale_float: float + _prob_scale: torch.Tensor def forward( self, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index f9c5ad4df54e..7f8f720eee0a 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -22,13 +22,13 @@ compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, is_block_tables_empty) +from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, + get_flash_attn_version) from vllm.logger import init_logger from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.vllm_flash_attn import (flash_attn_varlen_func, flash_attn_with_kvcache) -from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version) if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -689,7 +689,7 @@ def forward( assert output is not None, "Output tensor must be provided." # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. 
- if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16: + if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16: assert ( layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), ( "key/v_scale is only supported in FlashAttention 3 with " diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 09717a1121d0..d92177d58a48 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import dataclasses +import os from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass @@ -37,7 +38,7 @@ is_block_tables_empty) from vllm.attention.layer import Attention from vllm.attention.ops.paged_attn import PagedAttention -from vllm.config import VllmConfig, get_current_vllm_config +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, make_tensor_with_pad) @@ -48,6 +49,9 @@ from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) +FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT", + "NHD").upper() + class FlashInferBackend(AttentionBackend): @@ -80,6 +84,14 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: return (num_blocks, 2, block_size, num_kv_heads, head_size) + @staticmethod + def get_kv_cache_stride_order() -> Tuple[int, ...]: + cache_layout = FLASHINFER_KV_CACHE_LAYOUT + assert (cache_layout in ("NHD", "HND")) + stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, + 2, 4) + return stride_order + @staticmethod def swap_blocks( src_kv_cache: torch.Tensor, @@ -128,12 +140,10 @@ def get_per_layer_parameters( to use during `plan`. 
""" - layers = vllm_config.compilation_config.static_forward_context + layers = get_layers_from_vllm_config(vllm_config, Attention) per_layer_params: Dict[str, PerLayerParameters] = {} for key, layer in layers.items(): - assert isinstance(layer, Attention) - impl = layer.impl assert isinstance(impl, FlashInferImpl) @@ -187,7 +197,8 @@ def __init__(self, runner): # Global hyperparameters shared by all attention layers self.global_hyperparameters: Optional[PerLayerParameters] = None - self.vllm_config = get_current_vllm_config() + self.vllm_config = self.runner.vllm_config + self._kv_cache_layout = None def _get_workspace_buffer(self): if self._workspace_buffer is None: @@ -197,10 +208,15 @@ def _get_workspace_buffer(self): device=self.runner.device) return self._workspace_buffer + def get_kv_cache_layout(self): + if self._kv_cache_layout is None: + self._kv_cache_layout = FLASHINFER_KV_CACHE_LAYOUT + return self._kv_cache_layout + def _get_prefill_wrapper(self): if self._prefill_wrapper is None: self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), "NHD") + self._get_workspace_buffer(), self.get_kv_cache_layout()) return self._prefill_wrapper def _get_decode_wrapper(self): @@ -213,7 +229,7 @@ def _get_decode_wrapper(self): num_qo_heads // num_kv_heads > 4) self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self._get_workspace_buffer(), - "NHD", + self.get_kv_cache_layout(), use_tensor_cores=use_tensor_cores) return self._decode_wrapper @@ -274,7 +290,8 @@ def graph_capture_get_metadata_for_batch( self._graph_decode_wrapper = \ CUDAGraphBatchDecodeWithPagedKVCacheWrapper( self._graph_decode_workspace_buffer, _indptr_buffer, - self._graph_indices_buffer, _last_page_len_buffer, "NHD", + self._graph_indices_buffer, _last_page_len_buffer, + self.get_kv_cache_layout(), use_tensor_cores) if self.runner.kv_cache_dtype.startswith("fp8"): kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( @@ -613,7 +630,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): # Global hyperparameters shared by all attention layers self.global_hyperparameters: Optional[PerLayerParameters] = None - self.vllm_config = get_current_vllm_config() + self.vllm_config = self.runner.vllm_config def prepare(self): self.slot_mapping: List[int] = [] @@ -1005,6 +1022,7 @@ def forward( prefill_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None + stride_order = FlashInferBackend.get_kv_cache_stride_order() if prefill_meta := attn_metadata.prefill_metadata: # We will use flash attention for prefill # when kv_cache is not provided. @@ -1036,7 +1054,7 @@ def forward( prefill_output = prefill_meta.prefill_wrapper.run( query, - kv_cache, + kv_cache.permute(*stride_order), k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, ) @@ -1051,7 +1069,7 @@ def forward( decode_output = decode_meta.decode_wrapper.run( decode_query, - kv_cache, + kv_cache.permute(*stride_order), k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, ) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 15625612e08e..55a63a81677f 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -4,14 +4,14 @@ # Copyright (C) 2024 Habana Labs, Ltd. 
an Intel Company ############################################################################### -import os from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Type import torch +import vllm_hpu_extension.kernels as kernels import vllm_hpu_extension.ops as ops -from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax, - VLLMKVCache) +from vllm_hpu_extension.flags import enabled_flags +from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, @@ -126,7 +126,15 @@ def __init__( self.block2batch_matmul = Matmul() self.k_cache = VLLMKVCache() self.v_cache = VLLMKVCache() - ops.pa_impl = ops.pa + self.fused_scaled_dot_product_attention = kernels.fsdpa() + + self.prefill_impl = 'naive' + if "flex_attention" in enabled_flags(): + self.prefill_impl = 'flex' + if "fsdpa" in enabled_flags(): + assert alibi_slopes is None, \ + 'Prefill with FusedSDPA not supported with alibi slopes!' + self.prefill_impl = 'fsdpa' self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window @@ -138,19 +146,9 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', - '0').lower() in ['1', 'true'] - self.fused_scaled_dot_product_attention = None - if self.prefill_usefusedsdpa: + if self.prefill_impl == 'fsdpa': assert alibi_slopes is None, \ 'Prefill with FusedSDPA not supported with alibi slopes!' - try: - from habana_frameworks.torch.hpex.kernels import FusedSDPA - self.fused_scaled_dot_product_attention = ModuleFusedSDPA( - FusedSDPA) - except ImportError: - logger.warning("Could not import HPU FusedSDPA kernel. " - "vLLM will use native implementation.") supported_head_sizes = HPUPagedAttention.get_supported_head_sizes() if head_size not in supported_head_sizes: @@ -158,7 +156,8 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {supported_head_sizes}.") - if attn_type != AttentionType.DECODER: + self.attn_type = attn_type + if self.attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " "are not implemented for " @@ -192,15 +191,18 @@ def forward( batch_size, seq_len, hidden_size = query.shape _, seq_len_kv, _ = key.shape - query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) block_indices = attn_metadata.block_indices block_offsets = attn_metadata.block_offsets - if attn_metadata.is_prompt: + key_cache = None + value_cache = None + if attn_metadata.is_prompt and self.attn_type \ + is not AttentionType.ENCODER_ONLY \ + and attn_metadata.block_list is None: key = key.unflatten(0, (block_indices.size(0), -1)) value = value.unflatten(0, (block_indices.size(0), -1)) - if kv_cache is not None: + if kv_cache is not None and isinstance(kv_cache, tuple): key_cache, value_cache = HPUPagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) @@ -214,36 +216,28 @@ def forward( if attn_metadata.is_prompt: # Prompt run. - if not self.prefill_usefusedsdpa: - # TODO: move this outside of model - assert attn_metadata.attn_bias is not None, \ - 'attn_bias must be set before calling model.forward!' 
- attn_bias = attn_metadata.attn_bias - if self.alibi_slopes is not None: - position_bias = _make_alibi_bias(self.alibi_slopes, - self.num_kv_heads, - attn_bias.dtype, - attn_bias.shape[-1]) - attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) - attn_bias.add_(position_bias) - else: - attn_bias = None - query_shape = (batch_size, seq_len, self.num_heads, self.head_size) kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size) + + attn_bias = attn_metadata.attn_bias + if attn_bias is not None and self.alibi_slopes is not None: + position_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, + attn_bias.dtype, + attn_bias.shape[-1]) + attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) + attn_bias.add_(position_bias) + out = ops.prompt_attention( - query.view(query_shape), - key.view(kv_shape), - value.view(kv_shape), + impl=self.prefill_impl, + query=query.view(query_shape), + key=key.view(kv_shape), + value=value.view(kv_shape), + is_causal=True, attn_bias=attn_bias, - p=0.0, - scale=self.scale, - matmul_qk_op=self.matmul_qk, - softmax_op=self.softmax, - matmul_av_op=self.matmul_av, - fsdpa_op=self.fused_scaled_dot_product_attention, - ) + valid_seq_lengths=attn_metadata.seq_lens_tensor, + **self.common_attention_args()) output = out.reshape(batch_size, seq_len, hidden_size) else: # Decoding run. @@ -254,18 +248,26 @@ def forward( block_list=attn_metadata.block_list, block_mapping=attn_metadata.block_mapping, block_bias=attn_metadata.attn_bias, - block_scales=attn_metadata.block_scales, block_groups=attn_metadata.block_groups, - scale=self.scale, - matmul_qk_op=self.matmul_qk, - matmul_av_op=self.matmul_av, - batch2block_matmul_op=self.batch2block_matmul, - block2batch_matmul_op=self.block2batch_matmul, - keys_fetch_func=self.k_cache.fetch_from_cache, - values_fetch_func=self.v_cache.fetch_from_cache) + **self.common_attention_args()) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) + def common_attention_args(self): + fsdpa_op = self.fused_scaled_dot_product_attention.apply \ + if self.fused_scaled_dot_product_attention is not None else None + return { + 'scale': self.scale, + 'matmul_qk_op': self.matmul_qk, + 'matmul_av_op': self.matmul_av, + 'batch2block_matmul_op': self.batch2block_matmul, + 'block2batch_matmul_op': self.block2batch_matmul, + 'fsdpa_op': fsdpa_op, + 'keys_fetch_func': self.k_cache.fetch_from_cache, + 'values_fetch_func': self.v_cache.fetch_from_cache, + 'softmax_op': self.softmax, + } + def _make_alibi_bias( alibi_slopes: torch.Tensor, diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 99917a92af5f..27959caa651a 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -220,8 +220,8 @@ def forward( value_cache, attn_metadata.slot_mapping.flatten(), self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + layer._k_scale_float, + layer._v_scale_float, ) if attn_metadata.is_prompt: @@ -306,8 +306,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + layer._k_scale_float, + layer._v_scale_float, ) else: # Run PagedAttention V2. @@ -339,8 +339,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + layer._k_scale_float, + layer._v_scale_float, ) # Reshape the output tensor. 
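The ipex_attn.py hunks above swap the tensor-valued `layer._k_scale` / `layer._v_scale` arguments for their `_float` counterparts, presumably because those ops take plain float scales rather than tensors. A minimal sketch of the dual-representation pattern this relies on (illustrative only, not the actual vLLM `Attention` layer; only the attribute names are taken from the diff):

import torch


class ScaledKVAttentionLayer:
    """Sketch: keep KV-cache scales both as tensors (for kernels that
    read them on-device) and as plain floats (for extension ops such as
    the IPEX paged-attention calls above)."""

    def __init__(self, k_scale: float = 1.0, v_scale: float = 1.0) -> None:
        # Tensor form, consumed by kernels that expect a tensor argument.
        self._k_scale = torch.tensor(k_scale, dtype=torch.float32)
        self._v_scale = torch.tensor(v_scale, dtype=torch.float32)
        # Float form, so scalar-only backends never convert per call.
        self._k_scale_float = k_scale
        self._v_scale_float = v_scale

Passing `_k_scale_float` / `_v_scale_float` keeps the IPEX path free of tensor-to-scalar conversions on every forward call.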
diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 54278f5f608e..382a9a6d44d8 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -205,6 +205,7 @@ compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, RowParallelLinear, UnquantizedLinearMethod) @@ -214,7 +215,6 @@ from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down -from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version if HAS_TRITON: from vllm.attention.ops.triton_flash_attention import triton_attention @@ -711,12 +711,24 @@ def advance_step(self, self.seq_lens[i] += 1 self.max_decode_seq_len = max(self.seq_lens) + self._ops_advance_step(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions) + + def _ops_advance_step(self, num_seqs: int, num_queries: int, + block_size: int, input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor) -> None: + # here we use advance_step_flashinfo to update the paged_kv_* tensors ops.advance_step_flashattn(num_seqs=num_seqs, num_queries=num_queries, block_size=block_size, - input_tokens=model_input.input_tokens, + input_tokens=input_tokens, sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, + input_positions=input_positions, seq_lens=self.seq_lens_tensor, slot_mapping=self.slot_mapping, block_tables=self.block_tables) @@ -727,6 +739,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): NOTE: Please read the comment at the top of the file before trying to understand this class """ + BLOCK_TABLE_EXTENDER: list[list[int]] = [] def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.input_builder = input_builder @@ -877,8 +890,10 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_seqs = len(seq_lens) if use_captured_graph: self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) + self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER * + cuda_graph_pad_size) num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( num_seqs, self.block_tables) else: @@ -1043,8 +1058,8 @@ def __init__( self.q_proj = q_proj self.kv_b_proj = kv_b_proj self.o_proj = o_proj - self.triton_fa_func = triton_attention + self.triton_fa_func = triton_attention # Handle the differences between the flash_attn_varlen from flash_attn # and the one from vllm_flash_attn. 
The former is used on RoCM and the # latter has an additional parameter to control FA2 vs FA3 @@ -1055,6 +1070,77 @@ def __init__( functools.partial(flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version) + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim for attention backends that do + # not support different headdims + # We don't need to pad V if we are on a hopper system with FA3 + self._pad_v = self.vllm_flash_attn_version is None or not ( + self.vllm_flash_attn_version == 3 + and current_platform.get_device_capability()[0] == 9) + + def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale, + return_softmax_lse, **kwargs): + maybe_padded_v = v + if self._pad_v: + maybe_padded_v = torch.nn.functional.pad( + v, [0, q.shape[-1] - v.shape[-1]], value=0) + + if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \ + and not return_softmax_lse: + attn_out = self.triton_fa_func( + q, + k, + maybe_padded_v, + None, # output + kwargs["cu_seqlens_q"], + kwargs["cu_seqlens_k"], + kwargs["max_seqlen_q"], + kwargs["max_seqlen_k"], + kwargs["causal"], + softmax_scale, + None, # bias + ) + if is_vllm_fa: + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_softmax_lse=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + else: + # Use return_attn_probs instead of return_softmax_lse for RoCM + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_attn_probs=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + + # Unpack the output if there is multiple results, + # triton always returns (output, softmax_lse), + # vllm_flash_attn returns (output, softmax_lse) when + # `return_softmax_lse = True` + # flash_attn (RoCM) returns (output, softmax_lse, ...) when + # `return_attn_probs = True` + rest = None + if isinstance(attn_out, tuple): + attn_out, *rest = attn_out + + # unpad if necessary + if self._pad_v: + attn_out = attn_out[..., :v.shape[-1]] + + # Remain consistent with old `flash_attn_varlen_func` where there + # is only one output tensor if `return_softmax_lse` is False. + if return_softmax_lse: + assert rest is not None + return attn_out, rest[0] + return attn_out + def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) @@ -1176,40 +1262,19 @@ def _compute_prefill_context( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad - # out v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, - [0, q.shape[-1] - v.shape[-1]], - value=0) - - if is_vllm_fa: - attn_output, attn_softmax_lse = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata. - context_chunk_max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, - ) - else: - attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata. 
- context_chunk_max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_attn_probs=True, - ) + attn_output, attn_softmax_lse = \ + self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) if output is None: output = attn_output @@ -1252,58 +1317,22 @@ def _forward_prefill( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - - if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context: - output = self.triton_fa_func( - q, - k, - v_padded, - None, - prefill_metadata.query_start_loc, - prefill_metadata.query_start_loc, - prefill_metadata.max_prefill_seq_len, - prefill_metadata.max_prefill_seq_len, - True, # causal - self.scale, - None, # attn_mask is None unless applying ALiBi mask - ) - ## triton flash attention always return 2 objects - if not has_context: - output = output[0] - elif is_vllm_fa: - output = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_softmax_lse=has_context, - ) - else: - output = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_attn_probs=has_context, - ) + output = self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=has_context, + ) if has_context: # ROCm flash_attn_varlen_func will return 3 objects instead of 2 - suffix_output, suffix_lse, *rest = output + suffix_output, suffix_lse = output context_output, context_lse = self._compute_prefill_context( \ q, kv_c_and_k_pe_cache, attn_metadata) @@ -1316,12 +1345,7 @@ def _forward_prefill( suffix_lse=suffix_lse, ) - # slice by `:v.shape[-1]` in order to remove v headdim padding - output = output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) - - return self.o_proj(output)[0] + return self.o_proj(output.flatten(start_dim=-2))[0] @abstractmethod def _forward_decode( diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py new file mode 100644 index 000000000000..6e695b78e0e1 --- /dev/null +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -0,0 +1,412 @@ +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Type, Union + +import torch + +import vllm._custom_ops 
as ops +import vllm.envs as envs +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + MLACommonState) +from vllm.attention.backends.utils import (compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, + get_aiter_mla_metadata) + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + + +def is_aiter_mla_enabled() -> bool: + return envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_ROCM_USE_AITER_MLA + + +class AiterMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "ROCM_AITER_MLA" + + @staticmethod + def get_impl_cls() -> Type["AiterMLAImpl"]: + return AiterMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["AiterMLAMetadata"]: + return AiterMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]: + return AiterMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["AiterMLAState"]: + return AiterMLAState + + +@dataclass +class AiterMLAMetadata(MLACommonMetadata): + # The following 4 tensors are for current version of AITER MLA + block_table_bound: Optional[torch.Tensor] = None + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_lens: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self): + prefill_metadata = super().prefill_metadata + self._cached_prefill_metadata = prefill_metadata + + if prefill_metadata is not None: + prefill_metadata.paged_kv_indptr = self.paged_kv_indptr + prefill_metadata.paged_kv_indices = self.paged_kv_indices + prefill_metadata\ + .paged_kv_last_page_lens = self.paged_kv_last_page_lens + prefill_metadata.block_table_bound = self.block_table_bound + + # update the cache + self._cached_prefill_metadata = self.__class__( + **prefill_metadata.__dict__) + + return self._cached_prefill_metadata + + @property + def decode_metadata(self): + decode_metadata = super().decode_metadata + + self._cached_decode_metadata = decode_metadata + + if decode_metadata is not None: + decode_metadata.paged_kv_indptr = self.paged_kv_indptr + decode_metadata.paged_kv_indices = self.paged_kv_indices + decode_metadata\ + .paged_kv_last_page_lens = self.paged_kv_last_page_lens + decode_metadata.block_table_bound = self.block_table_bound + + # update the cache + self._cached_decode_metadata = self.__class__( + **decode_metadata.__dict__) + + return self._cached_decode_metadata + + def _ops_advance_step(self, num_seqs: int, num_queries: int, + block_size: int, input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor) -> None: + + ops.advance_step_flashinfer( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables, + paged_kv_indices=self.paged_kv_indices, + paged_kv_indptr=self.paged_kv_indptr, + paged_kv_last_page_lens=self.paged_kv_last_page_lens, + block_table_bound=self.block_table_bound) + + +class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): 
+ BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + super().__init__(input_builder) + assert self.runner.model_config.max_model_len == 32768,\ + "AITER MLA requires max model len to be set to 32768" + assert self.block_size == 1, "AITER MLA requires only block size 1." + + def prepare(self): + super().prepare() + self.paged_kv_indices: list[int] = [] + self.paged_kv_indptr: list[int] = [0] + self.paged_kv_last_page_lens: list[int] = [] + self.total_blocks = 0 + + def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, + prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block, input_positions) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks, + inter_data.input_positions): + self.input_positions.extend(input_positions) + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + if is_profile_run: + return + + # Update paged_kv_* tensors only for non-profile run + block_table = block_tables[seq_id] + self._update_paged_kv_tensors(block_table, seq_len) + + def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int): + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. 
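        # Worked example (illustration only, not part of the patch; uses
        # block_size = 16 even though this builder asserts block_size == 1):
        #   seq_len = 35 -> block_table_bound = 35 // 16 + 1 = 3,
        #                   last_page_len    = 35 % 16 = 3
        #   seq_len = 32 -> block_table_bound = 32 // 16 = 2,
        #                   last_page_len    = 16 (full last page)
        # After those two sequences, paged_kv_indptr grows
        # [0] -> [0, 3] -> [0, 3, 5] and paged_kv_last_page_lens is [3, 16].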
+ self.total_blocks += len(block_table) + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_lens.append(last_page_len) + + def build(self, seq_lens: list[int], query_lens: list[int], + cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata: + metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size, + batch_size) + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + if use_captured_graph: + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size) + + # For current version of AITER MLA + if len(self.paged_kv_indptr) > 0: + # extend to the maximum number of blocks as returned by the + # scheduler + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device=device, + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device=device, + dtype=torch.int) + paged_kv_last_page_lens_tensor = torch.tensor( + self.paged_kv_last_page_lens, device=device, dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - + 1, + device=device, + dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_lens_tensor = None + block_table_bound_tensor = None + + metadata.paged_kv_indptr = paged_kv_indptr_tensor + metadata.paged_kv_indices = paged_kv_indices_tensor + metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor + metadata.block_table_bound = block_table_bound_tensor + + return metadata + + +class AiterMLAState(MLACommonState[AiterMLAMetadata]): + + @contextmanager + def graph_capture(self, max_batch_size: int): + kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata( + max_batch_size=max_batch_size, + block_size=self.runner.block_size, + max_block_per_batch=self.runner.get_max_block_per_batch(), + device=self.runner.device) + self._paged_kv_indices_tensor = kv_indices + self._paged_kv_indptr_tensor = kv_indptr + self._paged_kv_last_page_lens_tensor = last_page_lens + + with super().graph_capture(max_batch_size): + yield + + del self._paged_kv_indices_tensor + del self._paged_kv_indptr_tensor + del self._paged_kv_last_page_lens_tensor + + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> AiterMLAMetadata: + + metadata = super().graph_capture_get_metadata_for_batch( + batch_size, is_encoder_decoder_model) + + paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1] + paged_kv_indices = self._paged_kv_indices_tensor + paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[: + batch_size] + + metadata.paged_kv_indptr = paged_kv_indptr + metadata.paged_kv_indices = paged_kv_indices + metadata.paged_kv_last_page_lens = paged_kv_last_page_lens + + return metadata + + def get_graph_input_buffers(self, + attn_metadata: AiterMLAMetadata, + is_encoder_decoder_model: bool = False): + input_buffers = super().get_graph_input_buffers( + attn_metadata, is_encoder_decoder_model) + input_buffers[ + 
'paged_kv_indptr'] = attn_metadata.decode_metadata.paged_kv_indptr + input_buffers[ + "paged_kv_indices"] = attn_metadata.\ + decode_metadata.paged_kv_indices + input_buffers[ + "paged_kv_last_page_lens"] = attn_metadata.\ + decode_metadata.paged_kv_last_page_lens + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata: AiterMLAMetadata, + is_encoder_decoder_model: bool = False): + super().prepare_graph_input_buffers(input_buffers, attn_metadata, + is_encoder_decoder_model) + + num_total_blocks = attn_metadata.decode_metadata.paged_kv_indices.shape[ + 0] + input_buffers["paged_kv_indptr"].copy_( + attn_metadata.decode_metadata.paged_kv_indptr, non_blocking=True) + input_buffers["paged_kv_indices"][:num_total_blocks].copy_( + attn_metadata.decode_metadata.paged_kv_indices, non_blocking=True) + input_buffers["paged_kv_last_page_lens"].copy_( + attn_metadata.decode_metadata.paged_kv_last_page_lens, + non_blocking=True) + + +class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "Aiter MLA does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + from aiter import flash_attn_varlen_func + self.flash_attn_varlen_func = flash_attn_varlen_func + + def _flash_attn_varlen_diff_headdims( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + softmax_scale: float, return_softmax_lse: bool, + **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]: + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + return_lse=return_softmax_lse, + **kwargs, + ) + + return output + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: AiterMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.zeros(B, + self.num_heads, + self.kv_lora_rank, + dtype=q.dtype, + device=q.device) + + kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) + + aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_lens) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 7376f9303788..8076c4791d3c 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -2,6 +2,7 @@ """Attention layer ROCm GPUs.""" import itertools from dataclasses import dataclass +from functools import cache from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch @@ -26,7 +27,34 @@ _PARTITION_SIZE_ROCM = 256 +@cache +def 
is_rocm_aiter_paged_attn_enabled() -> bool: + return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \ + and envs.VLLM_ROCM_USE_AITER \ + + +@cache +def _get_paged_attn_module() -> PagedAttention: + """ + Initializes the appropriate PagedAttention module from `attention/ops`, + which is used as helper function + by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`. + + The choice of attention module depends on whether + AITER paged attention is enabled: + - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`. + - Otherwise, it defaults to using the original `PagedAttention`. + """ + if is_rocm_aiter_paged_attn_enabled(): + # Import AITERPagedAttention only when the flag is enabled + from vllm.attention.ops.rocm_aiter_paged_attn import ( + AITERPagedAttention) + return AITERPagedAttention() + return PagedAttention() + + class ROCmFlashAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True @staticmethod def get_name() -> str: @@ -55,8 +83,9 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) + paged_attn = _get_paged_attn_module() + return paged_attn.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -64,14 +93,16 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + paged_attn = _get_paged_attn_module() + paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) + paged_attn = _get_paged_attn_module() + paged_attn.copy_blocks(kv_caches, src_to_dists) @dataclass @@ -495,7 +526,10 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - supported_head_sizes = PagedAttention.get_supported_head_sizes() + self.paged_attn_module = _get_paged_attn_module() + supported_head_sizes = self.paged_attn_module.get_supported_head_sizes( + ) + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " @@ -515,7 +549,7 @@ def __init__( from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 triton_attention) - self.attn_func = triton_attention + self.triton_attn_func = triton_attention logger.debug("Using Triton FA in ROCmBackend") if self.sliding_window != (-1, -1): logger.warning("ROCm Triton FA does not currently support " @@ -531,7 +565,7 @@ def __init__( else: try: from flash_attn import flash_attn_varlen_func # noqa: F401 - self.attn_func = flash_attn_varlen_func + self.fa_attn_func = flash_attn_varlen_func logger.debug("Using CK FA in ROCmBackend") except ModuleNotFoundError: self.use_naive_attn = True @@ -542,9 +576,11 @@ def __init__( "ROCm Naive FlashAttention does not support " "attention logits soft capping.") - self.attn_func = _sdpa_attention + self.sdpa_attn_func = _sdpa_attention logger.debug("Using naive (SDPA) attention in ROCmBackend") + self.aiter_kv_scales_initialized = False + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" tokens, n_kv_heads, head_dim = x.shape @@ -613,6 +649,8 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + assert output is not None, "Output tensor must be provided." 
+ query = query.view(-1, self.num_heads, self.head_size) if key is not None: assert value is not None @@ -621,12 +659,37 @@ def forward( else: assert value is None + paged_attn = self.paged_attn_module + + # Reshaping kv tensors is required for AITER paged attention kernel + # because it works on a different tensor shape, + # when the size of one element is one byte (int8/fp8 dtypes). + # This reshaping is only required on the first forward call + # and the kv cache must not be empty. + if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1 + and not self.aiter_kv_scales_initialized + and kv_cache.shape != torch.Size([0])): + num_blocks = kv_cache.shape[1] + block_size = kv_cache.shape[2] // (self.num_kv_heads * + self.head_size) + k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), + dtype=torch.float32, + device=kv_cache.device) + v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), + dtype=torch.float32, + device=kv_cache.device) + self.aiter_kv_scales_initialized = True + k_scale.fill_(layer._k_scale.item()) + v_scale.fill_(layer._v_scale.item()) + layer._k_scale = k_scale + layer._v_scale = v_scale + # Only update KV cache for decoder self-attention # and encoder-decoder cross-attention if self.attn_type not in [ AttentionType.ENCODER, AttentionType.ENCODER_ONLY ] and kv_cache.numel() > 0: - key_cache, value_cache = PagedAttention.split_kv_cache( + key_cache, value_cache = paged_attn.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) if key is not None and value is not None: @@ -634,7 +697,7 @@ def forward( # cache. If kv_cache is not provided, the new key and value # tensors are not cached. This happens during the initial # memory profiling run. - PagedAttention.write_to_paged_cache( + paged_attn.write_to_paged_cache( key, value, key_cache, @@ -656,7 +719,6 @@ def forward( assert attn_metadata.num_encoder_tokens is not None num_prefill_tokens = attn_metadata.num_encoder_tokens - output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] # QKV for prefill. 
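The hunk above promotes the layer's KV-cache scales from 0-dim scalars to (num_kv_heads, num_blocks * block_size) tensors the first time the AITER paged-attention path runs with a 1-byte (fp8/int8) KV cache. A standalone sketch of that expansion (shapes follow the hunk; all sizes below are made up for illustration):

import torch

num_kv_heads = 8
num_blocks, block_size = 4, 16

# Scalar scales as they exist before the first forward call.
k_scale_scalar = torch.tensor(0.5, dtype=torch.float32)
v_scale_scalar = torch.tensor(0.25, dtype=torch.float32)

# Broadcast each scalar into one entry per (kv head, cached token slot),
# mirroring the empty()/fill_() calls in the hunk above.
k_scale = torch.full((num_kv_heads, num_blocks * block_size),
                     k_scale_scalar.item(), dtype=torch.float32)
v_scale = torch.full((num_kv_heads, num_blocks * block_size),
                     v_scale_scalar.item(), dtype=torch.float32)

assert k_scale.shape == (8, 64)

The `aiter_kv_scales_initialized` flag ensures the expansion runs only once; later forward calls reuse the tensors stored back on the layer.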
@@ -704,11 +766,17 @@ def forward( query.dtype, seq_lens, make_attn_mask=causal_mask) # type: ignore - out, _ = self.attn_func( + use_fp8_scales = (layer._q_scale and layer._k_scale + and layer._v_scale and layer._prob_scale + and self.kv_cache_dtype == "fp8") + full_scales = ( + layer._q_scale, layer._k_scale, layer._v_scale, + layer._prob_scale) if use_fp8_scales else None + self.triton_attn_func( query, key, value, - None, + output[:num_prefill_tokens], query_seq_start_loc, key_seq_start_loc, query_max_seq_len, @@ -717,6 +785,7 @@ def forward( self.scale, attn_masks[0][None] if attn_masks is not None else None, + full_scales, ) elif self.use_naive_attn: if self.num_kv_heads != self.num_heads: @@ -733,10 +802,11 @@ def forward( key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) # sdpa math backend attention - out = self.attn_func( + self.sdpa_attn_func( query, key, value, + output[:num_prefill_tokens], query_seq_start_loc, num_prefill_tokens, self.num_heads, @@ -745,7 +815,8 @@ def forward( attn_masks, ) else: - out = self.attn_func( + # upstream FA does not support an output arg, copy + output[:num_prefill_tokens] = self.fa_attn_func( q=query, k=key, v=value, @@ -760,33 +831,26 @@ def forward( softcap=self.logits_soft_cap, ) - # common code for prefill - assert output[:num_prefill_tokens].shape == out.shape - if output.shape[0] > num_prefill_tokens: - output[:num_prefill_tokens] = out - else: - output = out else: # prefix-enabled attention - # not applicable for encoder-only models if self.attn_type != AttentionType.ENCODER_ONLY: - output[: - num_prefill_tokens] = PagedAttention.forward_prefix( - query, - key, - value, - self.kv_cache_dtype, - key_cache, - value_cache, - prefill_meta.block_tables, - prefill_meta.query_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.max_query_len, - self.alibi_slopes, - self.sliding_window[0], - layer._k_scale, - layer._v_scale, - ) + output[:num_prefill_tokens] = paged_attn.forward_prefix( + query, + key, + value, + self.kv_cache_dtype, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window[0], + layer._k_scale, + layer._v_scale, + ) # Skip decode phase for encoder-only models if (decode_meta := attn_metadata.decode_metadata) and ( self.attn_type != AttentionType.ENCODER_ONLY): @@ -818,14 +882,10 @@ def forward( device=output.device, ) max_logits = torch.empty_like(exp_sums) - if num_prefill_tokens > 0: - out = output[num_prefill_tokens:] - else: - out = output query_start_loc = None ops.paged_attention_rocm( - out, + output[num_prefill_tokens:], exp_sums, max_logits, tmp_output, @@ -849,7 +909,7 @@ def forward( layer._v_scale, ) else: - output[num_prefill_tokens:] = PagedAttention.forward_decode( + output[num_prefill_tokens:] = paged_attn.forward_decode( decode_query, key_cache, value_cache, @@ -878,7 +938,8 @@ def _sdpa_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - seq_lens: List[int], + output: torch.Tensor, + seq_lens: torch.Tensor, num_tokens: int, num_heads: int, head_size: int, @@ -886,9 +947,9 @@ def _sdpa_attention( attn_masks: Optional[List[torch.Tensor]] = None, ) -> torch.Tensor: start = 0 - output = torch.empty((num_tokens, num_heads, head_size), - dtype=query.dtype, - device=query.device) + assert output.shape == (num_tokens, num_heads, head_size) + assert output.dtype == query.dtype + assert output.device == query.device for i, seq_len in 
enumerate(seq_lens): end = start + seq_len diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index b4413c36b64a..89f1ea9b8a57 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -2,8 +2,10 @@ """Attention backend utils""" from collections import defaultdict from contextlib import contextmanager +from dataclasses import dataclass from itertools import accumulate -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, + TypeVar, Union) import numpy as np import torch @@ -11,6 +13,7 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, AttentionState) from vllm.attention.backends.abstract import AttentionType +from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad @@ -583,3 +586,24 @@ def get_num_prefill_decode_query_kv_tokens( return (num_prefill_query_tokens, num_prefill_kv_tokens, num_decode_query_tokens) + + +@dataclass +class MLADims: + q_lora_rank: Optional[int] + kv_lora_rank: int + qk_nope_head_dim: int + qk_rope_head_dim: int + v_head_dim: int + + +def get_mla_dims(model_config: ModelConfig) -> MLADims: + hf_text_config = model_config.hf_text_config + + return MLADims( + q_lora_rank=getattr(hf_text_config, "q_lora_rank", None), + kv_lora_rank=hf_text_config.kv_lora_rank, + qk_nope_head_dim=hf_text_config.qk_nope_head_dim, + qk_rope_head_dim=hf_text_config.qk_rope_head_dim, + v_head_dim=hf_text_config.v_head_dim, + ) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index dbf4723ee1bd..aa218cc37af9 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -10,6 +10,9 @@ from vllm.attention import AttentionType from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import CacheConfig, get_current_vllm_config +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group) from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.base_config import ( @@ -87,6 +90,7 @@ def __init__( # FlashAttn doesn't support quantizing the kv-cache only # but requires q to be quantized as well. 
self._q_scale = torch.tensor(1.0, dtype=torch.float32) + self._prob_scale = torch.tensor(1.0, dtype=torch.float32) # We also keep the float32 versions of k/v_scale for attention # backends that don't support tensors (Flashinfer) @@ -329,17 +333,54 @@ def forward( return out.reshape(bsz, q_len, -1) +def wait_for_kv_layer_from_connector(layer_name: str): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + + connector.wait_for_layer_load(layer_name) + + +def maybe_save_kv_layer_to_connector( + layer_name: str, + kv_cache_layer: List[torch.Tensor], +): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + + connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata) + + def unified_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, layer_name: str, ) -> torch.Tensor: + wait_for_kv_layer_from_connector(layer_name) + forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] - return self.impl.forward(self, query, key, value, kv_cache, attn_metadata) + output = self.impl.forward(self, query, key, value, kv_cache, + attn_metadata) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return output def unified_attention_fake( @@ -367,6 +408,7 @@ def unified_attention_with_output( output: torch.Tensor, layer_name: str, ) -> None: + wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata self = forward_context.no_compile_layers[layer_name] @@ -379,6 +421,8 @@ def unified_attention_with_output( attn_metadata, output=output) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + def unified_attention_with_output_fake( query: torch.Tensor, diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py index 49ea420d092c..1dedd2ffc5fa 100644 --- a/vllm/attention/ops/hpu_paged_attn.py +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -22,7 +22,6 @@ class HPUPagedAttentionMetadata: block_usage: Optional[torch.Tensor] block_indices: Optional[torch.Tensor] block_offsets: Optional[torch.Tensor] - block_scales: Optional[torch.Tensor] block_groups: Optional[torch.Tensor] diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index e0478c2aebda..a8c8d8409620 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -16,831 +16,778 @@ # To check compatibility IS_TURING = current_platform.get_device_capability() == (7, 5) -if triton.__version__ >= "2.1.0": - - @triton.jit - def _fwd_kernel( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - k_scale, - v_scale, - B_Start_Loc, - B_Seqlen, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - 
stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - IN_PRECISION: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, # head size - BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 - BLOCK_N: tl.constexpr, - SLIDING_WINDOW: tl.constexpr, - SKIP_DECODE: tl.constexpr, - ): - - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - cur_batch_in_all_start_index) - cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len - - if SKIP_DECODE and cur_batch_query_len == 1: - return - - # start position inside of the query - # generally, N goes over kv, while M goes over query_len - block_start_loc = BLOCK_M * start_m - - # initialize offsets - # [N]; starts at 0 - offs_n = tl.arange(0, BLOCK_N) - # [D]; starts at 0 - offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) - # [M]; starts at current position in query - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # [M,D] - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, - 0).to(tl.int1) # [D] - - q = tl.load(Q + off_q, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_query_len), - other=0.0) # [M,D] - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M] - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], - dtype=tl.float32) # [M,D] - - # compute query against context (no causal mask here) - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) # [N] - # [D,N] - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - # [N,D] - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k_load = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] - - if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) - else: - k = k_load - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - if SLIDING_WINDOW > 0: - # (cur_batch_ctx_len + offs_m[:, None]) are the positions of - # Q entries in sequence - # (start_n + offs_n[None, :]) are the positions of - # KV entries in sequence - # So the condition makes sure each entry in Q only attends - # to KV entries not more than SLIDING_WINDOW away. 
- # - # We can't use -inf here, because the - # sliding window may lead to the entire row being masked. - # This then makes m_ij contain -inf, which causes NaNs in - # exp(). - qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - - (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, - -10000) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) # [M] - p = tl.exp(qk - m_ij[:, None]) # [M,N] - l_ij = tl.sum(p, 1) # [M] - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) # [M] - alpha = tl.exp(m_i - m_i_new) # [M] - beta = tl.exp(m_ij - m_i_new) # [M] - l_i_new = alpha * l_i + beta * l_ij # [M] - - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v_load = tl.load(V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) # [N,D] - if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) - else: - v = v_load - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) - # # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - # block_mask is 0 when we're already past the current query length - block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) - - # compute query against itself (with causal mask) - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_query_len), - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk *= sm_scale - # apply causal mask - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - if SLIDING_WINDOW > 0: - qk = tl.where( - offs_m[:, None] - (start_n + offs_n[None, :]) - < SLIDING_WINDOW, qk, -10000) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_query_len), - other=0.0) - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_query_len)) + +# Here's an example autotuner config for this kernel. 
This config does provide +# a performance improvement, but dramatically increases first call latency in +# triton 3.2. Because of this tradeoff, it's currently commented out. +# @triton.autotune( +# configs=[ +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ +# "num_unroll_cache": 4, \ +# "num_unroll_request": 1 } | \ +# ({"kpack": 2, "waves_per_eu": 2} \ +# if current_platform.is_rocm() else {}), \ +# num_warps=4, \ +# num_stages=1) +# ], +# key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] +# ) +@triton.jit +def _fwd_kernel(Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + x: tl.constexpr, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl: tl.constexpr, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: tl.constexpr, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_PADDED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + num_unroll_cache: tl.constexpr, + num_unroll_request: tl.constexpr, + SKIP_DECODE: tl.constexpr, + MAX_Q_LEN: tl.constexpr = 0, + MAX_CTX_LEN: tl.constexpr = 0): + + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: return - @triton.jit - def _fwd_kernel_flash_attn_v2( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - q = tl.load(Q + off_q, - mask=offs_m[:, None] - < cur_batch_seq_len - cur_batch_ctx_len, + # start position inside of the query + # generally, N goes over kv, while M goes over query_len + block_start_loc = BLOCK_M * start_m + + # initialize offsets + # 
[BLOCK_SIZE]; starts at 0 + offs_bs_n = tl.arange(0, BLOCK_SIZE) + # [N]; starts at 0 + offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len), + other=0.0) # [M,D] + + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] + + # compute query against context (no causal mask here) + for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ + loop_unroll_factor=num_unroll_cache): + start_n = tl.multiple_of(start_n, BLOCK_SIZE) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + (start_n // BLOCK_SIZE) * stride_b_loc_s) + # [D,BLOCK_SIZE] + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + + # [BLOCK_SIZE,D] + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + offs_bs_n[:, None] * stride_v_cache_bl) + + if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ + BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + k_load = tl.load( + K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + else: + k_load = tl.load(K_cache + off_k) + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_bs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. + # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). 
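The comment above is easiest to see with a small standalone illustration in plain PyTorch (a sketch, not part of the kernel): a row whose scores are all -inf turns into NaNs once the running max is subtracted and exponentiated, while a large finite constant such as -10000 simply underflows to zero as soon as any realistic score is present.

import torch

fully_masked = torch.full((4,), float("-inf"))
finite_masked = torch.tensor([-10000.0, -10000.0, -10000.0, 0.0])

# The max of an all -inf row is -inf, so exp(x - max) evaluates exp(nan) -> NaN.
print(torch.softmax(fully_masked, dim=0))   # tensor([nan, nan, nan, nan])
# With a large finite constant, masked entries underflow to exactly 0 instead.
print(torch.softmax(finite_masked, dim=0))  # tensor([0., 0., 0., 1.])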
+ qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) + + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i - m_ij) + acc = acc * alpha[:, None] + + # update acc + if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ + BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + v_load = tl.load( + V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] + else: + v_load = tl.load(V_cache + off_v) + + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + # block_mask is 0 when we're already past the current query length + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + + # compute query against itself (with causal mask) + for start_n in tl.range(0, \ + block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ + loop_unroll_factor=num_unroll_request): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_query_len), other=0.0) - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - 
offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) - < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) - < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - # acc /= l_i[:, None] - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) - return - - @triton.jit - def _fwd_kernel_alibi( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - k_scale, - v_scale, - B_Start_Loc, - B_Seqlen, - Alibi_slopes, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - IN_PRECISION: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, # head size - BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 - BLOCK_N: tl.constexpr, - SKIP_DECODE: tl.constexpr, - ): - # attn_bias[] - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - # cur_batch_seq_len: the length of prompts - # cur_batch_ctx_len: the length of prefix - # cur_batch_in_all_start_index: the start id of the dim=0 - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - cur_batch_in_all_start_index) - cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len - - if SKIP_DECODE and cur_batch_query_len == 1: - return - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * 
stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) - - q = tl.load(Q + off_q, + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk *= sm_scale + # apply causal mask + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW, + qk, -10000) + + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i - m_ij) + acc = acc * alpha[:, None] + + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + ((start_n + offs_n[:, None]) < cur_batch_query_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)) + return + + +@triton.jit +def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load(Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + 
(offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len, other=0.0) - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) - - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange( - 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = 0 - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k_load = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] - - if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) - else: - k = k_load - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), - alibi, float("-inf")) - qk += alibi - 
alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v_load = tl.load(V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) - if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) - else: - v = v_load - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision='ieee') - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - # init alibi - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange( - 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = cur_batch_ctx_len - # # init debugger - # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc - # offset_db_k = tl.arange(0, BLOCK_N) - # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) - < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision='ieee') - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), - alibi, float("-inf")) - qk += alibi - alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) - < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) - p = p.to(v.dtype) - - acc = tl.dot(p, v, acc=acc, input_precision='ieee') - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - acc = acc / l_i[:, None] - - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, 
:]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + +@triton.jit +def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + SKIP_DECODE: tl.constexpr, +): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: return - @torch.inference_mode() - def context_attention_fwd(q, - k, - v, - o, - kv_cache_dtype: str, - k_cache, - v_cache, - b_loc, - b_start_loc, - b_seq_len, - max_seq_len, - max_input_len, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - alibi_slopes=None, - sliding_window=None, - sm_scale=None, - skip_decode=False): - - q_dtype_is_f32 = q.dtype is torch.float32 + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - 
float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debugger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + qk = 
tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision='ieee') + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + return + + +@torch.inference_mode() +def context_attention_fwd(q, + k, + v, + o, + kv_cache_dtype: str, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + max_seq_len, + max_input_len, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, + skip_decode=False): + + q_dtype_is_f32 = q.dtype is torch.float32 + + # Turing does have tensor core for float32 multiplication + # use ieee as fallback for triton kernels work. There is also + # warning on vllm/config.py to inform users this fallback + # implementation + IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert (k_cache.dtype == torch.uint8) + assert (v_cache.dtype == torch.uint8) + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = current_platform.fp8_dtype() + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + + if (k_cache.dtype == torch.uint8 + or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): + raise ValueError("kv_cache_dtype='auto' unsupported for\ + FP8 KV Cache prefill kernel") + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + # round up Lk to a power of 2 - this is required for Triton block size + Lk_padded = triton.next_power_of_2(Lk) + + if sm_scale is None: + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] + + assert batch + 1 == len(b_start_loc) + + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if alibi_slopes is not None: # need to reduce num. 
blocks when using fp32 # due to increased use of GPU shared memory # if q.dtype is torch.float32: BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK - - # Turing does have tensor core for float32 multiplication - # use ieee as fallback for triton kernels work. There is also - # warning on vllm/config.py to inform users this fallback - # implementation - IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None - - # Conversion of FP8 Tensor from uint8 storage to - # appropriate torch.dtype for interpretation by Triton - if "fp8" in kv_cache_dtype: - assert (k_cache.dtype == torch.uint8) - assert (v_cache.dtype == torch.uint8) - - if kv_cache_dtype in ("fp8", "fp8_e4m3"): - target_dtype = current_platform.fp8_dtype() - elif kv_cache_dtype == "fp8_e5m2": - target_dtype = torch.float8_e5m2 - else: - raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) - - k_cache = k_cache.view(target_dtype) - v_cache = v_cache.view(target_dtype) - - if (k_cache.dtype == torch.uint8 - or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): - raise ValueError("kv_cache_dtype='auto' unsupported for\ - FP8 KV Cache prefill kernel") - - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - # round up Lk to a power of 2 - this is required for Triton block size - Lk_padded = triton.next_power_of_2(Lk) - - if sm_scale is None: - sm_scale = 1.0 / (Lq**0.5) - batch, head = b_seq_len.shape[0], q.shape[1] - num_queries_per_kv = q.shape[1] // k.shape[1] - - assert batch + 1 == len(b_start_loc) - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - # 0 means "disable" - if sliding_window is None or sliding_window <= 0: - sliding_window = 0 - - if alibi_slopes is not None: - _fwd_kernel_alibi[grid]( - q, - k, - v, - k_cache, - v_cache, - b_loc, - sm_scale, - k_scale, - v_scale, - b_start_loc, - b_seq_len, - alibi_slopes, - v_cache.shape[3], - k_cache.shape[4], - o, - b_loc.stride(0), - b_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - k_cache.stride(0), - k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), - k_cache.stride( - 4 - ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] - v_cache.stride(0), - v_cache.stride(1), - v_cache.stride(2), - v_cache.stride( - 3), #[num_blocks, num_kv_heads, head_size, block_size] - num_queries_per_kv=num_queries_per_kv, - IN_PRECISION=IN_PRECISION, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_DMODEL_PADDED=Lk_padded, - BLOCK_N=BLOCK, - SKIP_DECODE=skip_decode, - num_warps=NUM_WARPS, - num_stages=1, - ) - return - - _fwd_kernel[grid]( + # batch, head, + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + _fwd_kernel_alibi[grid]( q, k, v, @@ -852,6 +799,7 @@ def context_attention_fwd(q, v_scale, b_start_loc, b_seq_len, + alibi_slopes, v_cache.shape[3], k_cache.shape[4], o, @@ -886,9 +834,69 @@ def context_attention_fwd(q, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, - SLIDING_WINDOW=sliding_window, SKIP_DECODE=skip_decode, num_warps=NUM_WARPS, num_stages=1, ) return + + max_seq_len = 0 if max_seq_len is None else max_seq_len + extra_kargs = {} + if current_platform.is_rocm(): + extra_kargs = {"kpack": 2, "waves_per_eu": 2} + + grid = lambda META: (batch, head, + triton.cdiv(max_input_len, META["BLOCK_M"])) + _fwd_kernel[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, 
+ b_start_loc, + b_seq_len, + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size] + BLOCK_SIZE=v_cache.shape[3], + num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + SLIDING_WINDOW=sliding_window, + SKIP_DECODE=skip_decode, + BLOCK_M=128, + BLOCK_N=64, + num_unroll_cache=4, + num_unroll_request=1, + num_warps=4, + num_stages=1, + **extra_kargs) + return diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py new file mode 100644 index 000000000000..1c90f8c19b09 --- /dev/null +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch + + +def get_aiter_mla_metadata(max_batch_size: int, block_size: int, + max_block_per_batch: int, + device: torch.device) -> tuple[torch.Tensor, ...]: + paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch, + dtype=torch.int32, + device=device) + paged_kv_indptr = torch.zeros(max_batch_size + 1, + dtype=torch.int32, + device=device) + paged_kv_last_page_lens = torch.full((max_batch_size, ), + block_size, + dtype=torch.int32) + return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens + + +def aiter_mla_decode_fwd( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + sm_scale: float, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + logit_cap: float = 0.0, +): + from aiter.mla import mla_decode_fwd + + mla_decode_fwd(q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale=sm_scale, + logit_cap=logit_cap) diff --git a/vllm/attention/ops/rocm_aiter_paged_attn.py b/vllm/attention/ops/rocm_aiter_paged_attn.py new file mode 100644 index 000000000000..0f3cf1842c80 --- /dev/null +++ b/vllm/attention/ops/rocm_aiter_paged_attn.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import aiter as rocm_aiter +import torch + +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.platforms import current_platform +from vllm.utils import cdiv + +FP8_DTYPE = current_platform.fp8_dtype() + + +class AITERPagedAttention(PagedAttention): + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + ) -> None: + if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]: + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, k_scale, + v_scale) + else: + kv_cache_torch_dtype = (FP8_DTYPE + if "fp8" in kv_cache_dtype else torch.int8) + key_cache = key_cache.view(kv_cache_torch_dtype) + value_cache = value_cache.view(kv_cache_torch_dtype) + + rocm_aiter.reshape_and_cache_with_pertoken_quant( + key, value, key_cache, value_cache, k_scale, v_scale, + 
slot_mapping.flatten(), True) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> torch.Tensor: + if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]: + return PagedAttention.forward_decode( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + kv_cache_dtype=kv_cache_dtype, + num_kv_heads=num_kv_heads, + scale=scale, + alibi_slopes=alibi_slopes, + k_scale=k_scale, + v_scale=v_scale, + tp_rank=tp_rank, + blocksparse_local_blocks=blocksparse_local_blocks, + blocksparse_vert_stride=blocksparse_vert_stride, + blocksparse_block_size=blocksparse_block_size, + blocksparse_head_sliding_step=blocksparse_head_sliding_step) + + if "fp8" in kv_cache_dtype: + key_cache = key_cache.view(torch.float8_e4m3fnuz) + value_cache = value_cache.view(torch.float8_e4m3fnuz) + + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: + # use blocksparse paged attention + block_size = value_cache.size(-1) + assert (blocksparse_block_size > 0 and + blocksparse_block_size % block_size == 0), \ + (f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables.") + + output = torch.empty_like(query) + block_size = value_cache.shape[3] + max_num_blocks_per_seq = cdiv(max_seq_len, block_size) + + rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables, + seq_lens, max_num_blocks_per_seq, k_scale, + v_scale, output) + return output diff --git a/vllm/attention/ops/triton_decode_attention.py b/vllm/attention/ops/triton_decode_attention.py index 40daec3ec124..35ee0835f42a 100644 --- a/vllm/attention/ops/triton_decode_attention.py +++ b/vllm/attention/ops/triton_decode_attention.py @@ -39,11 +39,12 @@ logger = logging.getLogger(__name__) -# TODO: Remove this when triton>=3.2.0. This issue will not affect performance -# and accuracy. -logger.warning( - "The following error message 'operation scheduled before its operands' " - "can be ignored.") +# Only print the following warnings when triton version < 3.2.0. +# The issue won't affect performance or accuracy. 
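The version gate that follows compares version strings lexicographically; that happens to order current Triton releases correctly, but it would misorder, for example, "3.10.0" against "3.2.0". If a semantic comparison is ever preferred, a minimal sketch (assuming the widely available packaging library) could look like this:

import logging

import triton
from packaging import version  # assumption: packaging is installed

logger = logging.getLogger(__name__)

# Parse both sides so that "3.10.0" compares as newer than "3.2.0".
if version.parse(triton.__version__) < version.parse("3.2.0"):
    logger.warning(
        "The following error message 'operation scheduled before its operands' "
        "can be ignored.")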
+if triton.__version__ < '3.2.0': + logger.warning( + "The following error message 'operation scheduled before its operands' " + "can be ignored.") @triton.jit diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 745818eb6cff..e98b5254541b 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -1,31 +1,237 @@ -#!/usr/bin/env python # SPDX-License-Identifier: Apache-2.0 """ Fused Attention =============== -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao -(https://tridao.me/publications/flash2/flash2.pdf) -Credits: OpenAI kernel team, AMD ML Frameworks Triton team +This is a Triton implementation of the Flash Attention v2 algorithm +See https://tridao.me/publications/flash2/flash2.pdf -Features supported: +Credits: +AMD Triton kernels team +OpenAI kernel team -1) Fwd with causal masking -2) Any sequence lengths without padding (currently fwd kernel only) -3) Support for different sequence lengths for q and k -4) Nested tensor API currently does not support dropout or bias. - -Not currently supported: +Currently only the forward kernel is supported, and contains these features: -1) Non power of two head dims +1) Fwd with causal masking +2) Arbitrary Q and KV sequence lengths +3) Arbitrary head sizes +4) Multi and grouped query attention +5) Variable sequence lengths +6) ALiBi and matrix bias +7) FP8 support """ +from typing import Optional + import torch import triton import triton.language as tl -torch_dtype: tl.constexpr = torch.float16 +from vllm import _custom_ops as ops +from vllm.platforms import current_platform + +SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd'] + +default_eight_bit_dtype_triton = tl.float8e4b8 +default_eight_bit_dtype_torch = current_platform.fp8_dtype() +default_float8_info = torch.finfo(default_eight_bit_dtype_torch) + +FP8_MIN = triton.language.constexpr(default_float8_info.min) + +# According to https://github.com/vllm-project/vllm/blob/main +# /csrc/quantization/utils.cuh#L31, +# need to make the max for the uz datatype be 224.0 for accuracy reasons. +FP8_MAX = triton.language.constexpr( + default_float8_info.max if default_eight_bit_dtype_torch != + torch.float8_e4m3fnuz else 224.0) + + +class MetaData: + cu_seqlens_q = None + cu_seqlens_k = None + max_seqlens_q = 0 + max_seqlens_k = 0 + bias = None + alibi_slopes = None + causal = False + num_contexts = 0 + varlen = False + eight_bit = False + layout = None + return_encoded_softmax = False + eight_bit_dtype_triton = default_eight_bit_dtype_triton + eight_bit_dtype_torch = default_eight_bit_dtype_torch + output_dtype = None + + # Note about layouts: + # + # thd - [num_tokens, num_heads, head_size] + # bshd - [batch_size, seq_len, num_heads, head_size] + # bhsd - [batch_size, num_heads, seq_len, head_size] + # + # This is for each tensor, all tensors must have same layout. + # Q can have num_heads and seq_len differ from from K and V, + # however K and V must agree on this. + # + # Notes about varlen and bias: + # Only one or the other is implemented, meaning can't combine + # both varlen and bias right now. + # + # Note about quantization: + # Only 8-bit quantization supported (for now) and specifically fp8. + # Scales must be tensors. + # o_scale: This is 'output scaling', but comes from parameter called + # 'input_scale', this is applied to the output from the kernel. + # o_scale should be None if none of the other quantization parameters + # are used. 
+ # + # NOTE: Object is in a tentatively good state after initialized, however, + # to verify, call check_args(q,k,v,o) where o is the output tensor. + def __init__( + self, + sm_scale=1.0, + layout=None, # layout can be 'bshd', 'bhsd', or 'thd' + output_dtype=None, + max_seqlens_q=0, + max_seqlens_k=0, + # varlen params + cu_seqlens_q=None, # only 'thd' layout supported for varlen + cu_seqlens_k=None, + # quant params + q_descale=None, + k_descale=None, + v_descale=None, + p_scale=None, + o_scale=None, + # bias params + bias=None, # varlen not implemented for bias + seqlen_q=None, + seqlen_k=None, + # alibi params + alibi_slopes=None, + alibi_batch=None, + alibi_nheads=None, + # causal + causal=None, + ): + self.sm_scale = sm_scale + self.output_dtype = output_dtype + self.max_seqlens_q = max_seqlens_q + self.max_seqlens_k = max_seqlens_k + self.layout = layout + if cu_seqlens_q is not None or cu_seqlens_k is not None: + assert cu_seqlens_q is not None and cu_seqlens_k is not None + assert layout is None or layout not in [ + 'bshd', 'bhsd' + ], "Varlen only implemented for thd layout" + self.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + quant_params = [q_descale, k_descale, v_descale, p_scale, o_scale] + if any(x is not None for x in quant_params): + p_descale = 1.0 / p_scale if p_scale is not None else None + self.set_eight_bit_params(q_descale, k_descale, v_descale, p_scale, + p_descale, o_scale) + if bias is not None: + self.need_bias(bias, seqlen_q, seqlen_k) + if alibi_slopes is not None: + self.need_alibi(alibi_slopes, alibi_batch, alibi_nheads) + if causal is not None and causal: + self.need_causal() + + def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): + self.varlen = True + self.layout = 'thd' + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_k = cu_seqlens_k + # Without "varlen", there should still be one sequence. 
+ assert len(cu_seqlens_q) >= 2 + assert len(cu_seqlens_q) == len(cu_seqlens_k) + self.num_contexts = len(cu_seqlens_q) - 1 + for i in range(0, self.num_contexts): + self.max_seqlens_q = max( + cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), + self.max_seqlens_q) + self.max_seqlens_k = max( + cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), + self.max_seqlens_k) + + def set_eight_bit_params(self, q_descale, k_descale, v_descale, p_scale, + p_descale, o_scale): + self.eight_bit = True + self.q_descale = q_descale + self.k_descale = k_descale + self.v_descale = v_descale + self.p_scale = p_scale + self.p_descale = p_descale + self.o_scale = o_scale + self.use_p_scale = (p_scale is not None) and ( + p_descale is not None) and (v_descale is not None) + self.eight_bit_kv = ((q_descale is None) and (k_descale is not None) + and (v_descale is not None)) + self.eight_bit_dtype_torch = default_eight_bit_dtype_torch + + def need_bias(self, bias, seqlen_q, seqlen_k): + assert bias is not None + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.shape[0] == 1 + assert bias.shape[2:] == (seqlen_q, seqlen_k) + self.bias = bias + + def need_alibi(self, alibi_slopes, batch, nheads): + assert alibi_slopes.is_cuda + assert alibi_slopes.dim() == 2 + assert alibi_slopes.shape[0] == batch + assert alibi_slopes.shape[1] == nheads + self.alibi_slopes = alibi_slopes + + def need_causal(self): + self.causal = True + + def check_args(self, q, k, v, o): + assert q.dim() == k.dim() and q.dim() == v.dim() + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout( + q, k, self) + if self.varlen: + assert q.dim() == 3 + assert self.cu_seqlens_q is not None + assert self.cu_seqlens_k is not None + assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) + # TODO: Remove once bias is supported with varlen + assert self.bias is None + assert not self.return_encoded_softmax + else: + assert q.dim() == 4 + assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 + assert self.cu_seqlens_q is None and self.cu_seqlens_k is None + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + if self.eight_bit: + if self.eight_bit_kv: + assert (v.dtype == k.dtype + and k.dtype == self.eight_bit_dtype_torch) + assert q.dtype != k.dtype + assert (self.v_descale is not None) and (self.k_descale + is not None) + else: + assert (q.dtype == k.dtype and q.dtype == v.dtype + and q.dtype == self.eight_bit_dtype_torch) + assert (self.q_descale + is not None) and (self.k_descale + is not None) and (self.v_descale + is not None) + if self.use_p_scale: + assert (self.p_scale is not None) and (self.p_descale + is not None) + else: + assert (q.dtype == k.dtype) and (q.dtype == v.dtype) + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + assert self.layout is not None + assert self.layout == 'thd' or not self.varlen @triton.jit @@ -38,40 +244,85 @@ def max_fn(x, y): return tl.math.max(x, y) +# Convenience function to load with optional boundary checks. +# "First" is the major dim, "second" is the minor dim. 
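In plain PyTorch terms, the masked_load helper defined just below behaves roughly like this sketch (it operates on an already gathered tile for simplicity, whereas the real helper also avoids issuing the out-of-bounds reads themselves):

import torch

def masked_load_ref(tile, offs_first, offs_second, boundary_first,
                    boundary_second):
    # tile has shape [len(offs_first), len(offs_second)]. Entries whose major
    # (first) or minor (second) offset is out of range are replaced by 0.0,
    # mirroring tl.load(..., mask=..., other=0.0).
    keep = (offs_first[:, None] < boundary_first) & \
           (offs_second[None, :] < boundary_second)
    return torch.where(keep, tile, torch.zeros_like(tile))

tile = torch.arange(12.0).reshape(3, 4)
print(masked_load_ref(tile, torch.arange(3), torch.arange(4), 2, 3))
# rows >= 2 and columns >= 3 come back as 0.0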
@triton.jit -def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): - ms = tl.arange(0, m) - ns = tl.arange(0, n) - return philox_offset + ms[:, None] * stride + ns[None, :] +def masked_load(ptrs, offset_first, offset_second, boundary_first, + boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) + else: + tensor = tl.load(ptrs) + return tensor @triton.jit -def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, - stride).to(tl.uint32) - # TODO: use tl.randint for better performance - return tl.rand(philox_seed, rng_offsets) +def compute_alibi_block(alibi_slope, + seqlen_q, + seqlen_k, + offs_m, + offs_n, + transpose=False): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to + # the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is + # masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that + # spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, + # offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = + # [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], [4], [ 4, 3, 2, 1, 0]] + # 5. 
-1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = (offs_m[:, None] + seqlen_k - seqlen_q - + offs_n[None, :]) + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block -@triton.jit -def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, - stride) - rng_keep = rng_output > dropout_p - return rng_keep +def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): + q_idx = torch.arange(seqlen_q, dtype=torch.int32, + device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) + k_idx = torch.arange(seqlen_k, dtype=torch.int32, + device="cuda").unsqueeze(0) # (1, N_CTX_K) + relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - + k_idx) # (N_CTX_Q, N_CTX_K) + return -1 * alibi_slopes.unsqueeze(-1).unsqueeze( + -1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) @triton.jit -def load_fn(block_ptr, first, second, pad): - if first and second: - tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) - elif first: - tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) - elif second: - tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) - else: - tensor = tl.load(block_ptr) - return tensor +def quant_fp8(x, scale): + x *= scale + x = tl.clamp(x, FP8_MIN, FP8_MAX) + return x @triton.jit @@ -80,58 +331,68 @@ def _attn_fwd_inner( l_i, m_i, q, - K_block_ptr, - V_block_ptr, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, start_m, actual_seqlen_k, - dropout_p, + actual_seqlen_q, philox_seed, batch_philox_offset, - encoded_softmax_block_ptr, + encoded_sm_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, - bias_ptr, + alibi_slope, + q_descale, + k_descale, + v_descale, + p_scale, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - PADDED_HEAD: tl.constexpr, + SHOULD_PRE_LOAD_V: tl.constexpr, + SHOULD_MASK_STEPS: tl.constexpr, + SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr, + USE_PADDED_HEAD: tl.constexpr, + IS_ACTUAL_BLOCK_DMODEL: tl.constexpr, + QK_SCALE: tl.constexpr, + IS_EIGHT_BIT_GEMM: tl.constexpr, + USE_P_SCALE: tl.constexpr, + IS_EIGHT_BIT_KV: tl.constexpr, + QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton, ): + # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. - k = load_fn( - K_block_ptr, - PADDED_HEAD, - MASK_STEPS and (n_extra_tokens != 0), - "zero", - ) - if PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) + k_offs_n = start_n + tl.arange(0, + BLOCK_N) if SHOULD_MASK_STEPS else None + k_offs_k = None if not USE_PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + k = masked_load(k_ptrs, k_offs_k, k_offs_n, IS_ACTUAL_BLOCK_DMODEL, + actual_seqlen_k) + if SHOULD_PRE_LOAD_V: + # We can use the same offsets as k, just with dims transposed. 
+ v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, + IS_ACTUAL_BLOCK_DMODEL) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: # noqa: SIM102 + if SHOULD_MASK_STEPS: # noqa: SIM102 # If this is the last block / iteration, we want to # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps - # if not is_modulo_mn. last step might get wasted but that is okay. + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not + # is_modulo_mn. last step might get wasted but that is okay. # check if this masking works for that case. if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): boundary_m = tl.full([BLOCK_M], @@ -144,167 +405,276 @@ def _attn_fwd_inner( causal_boundary = start_n + offs_n_causal causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- - qk += tl.dot(q, k) - if bias_ptr is not None: - bias = load_fn(bias_ptr, False, MASK_STEPS - and (n_extra_tokens != 0), "zero") - # While bias is added after multiplying qk with sm_scale, our - # optimization to use 2^x instead of e^x results in an additional - # scale factor of log2(e) which we must also multiply the bias with. - qk += bias * 1.44269504089 + if IS_EIGHT_BIT_GEMM: + qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) * + QK_SCALE) + else: + if IS_EIGHT_BIT_KV: + k = (k * k_descale).to(q.type.element_ty) + qk += (tl.dot(q, k) * QK_SCALE) + + if bias_ptrs is not None: + bias_offs_n = start_n + tl.arange( + 0, BLOCK_N) if SHOULD_MASK_STEPS else None + bias = masked_load(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, + actual_seqlen_k) + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an + # additional scale factor of log2(e) which we must also multiply + # the bias with. 
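The 1.44269504089 constant is log2(e): the kernel evaluates softmax with exp2 instead of exp, so the scores, the bias and (below) the ALiBi term are all pre-scaled by log2(e). A quick equivalence check, assuming ordinary float32 scores:

import torch

LOG2_E = 1.44269504089
scores = torch.randn(4, 8)   # arbitrary attention scores
bias = torch.randn(4, 8)     # arbitrary additive attention bias

ref = torch.softmax(scores + bias, dim=-1)
# Same softmax, but with base-2 exponentials as the kernel uses:
x = (scores + bias) * LOG2_E
x = x - x.max(dim=-1, keepdim=True).values   # running-max subtraction for stability
p = torch.exp2(x)
out = p / p.sum(dim=-1, keepdim=True)
assert torch.allclose(ref, out, atol=1e-5)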
+ qk += (bias * 1.44269504089) + + if alibi_slope is not None: + # Compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, + actual_seqlen_k, + global_m_positions, + global_n_positions) + qk += (alibi_block * 1.44269504089) # scale factor of log2(e) + + # softmax m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk = qk - m_ij[:, None] p = tl.math.exp2(qk) # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) - if ENABLE_DROPOUT: - philox_offset = (batch_philox_offset + - start_m * BLOCK_M * actual_seqlen_k + start_n - - BLOCK_N) - keep = dropout_mask( - philox_seed, - philox_offset, - dropout_p, - BLOCK_M, - BLOCK_N, - actual_seqlen_k, - ) - if RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - tl.where(keep, p, - -p).to(encoded_softmax_block_ptr.type.element_ty), - ) - p = tl.where(keep, p, 0.0) - elif RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - p.to(encoded_softmax_block_ptr.type.element_ty), - ) + if SHOULD_RETURN_ENCODED_SOFTMAX: + tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) # -- update output accumulator -- alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] - if not PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) + if not SHOULD_PRE_LOAD_V: + v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, + IS_ACTUAL_BLOCK_DMODEL) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, BLOCK_N)) + + if IS_EIGHT_BIT_GEMM: + if USE_P_SCALE: + p = quant_fp8(p, p_scale).to(QUANT_DTYPE) + acc += tl.dot(p, v) + else: + # v is in eight_bit but p is not, we want the gemm in p's type + acc += tl.dot(p, v.to(p.type.element_ty)) + else: + if IS_EIGHT_BIT_KV: + v = (v * v_descale).to(p.type.element_ty) + acc += tl.dot(p.to(v.type.element_ty), v) + + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if bias_ptrs is not None: + bias_ptrs += BLOCK_N * stride_bn + if SHOULD_RETURN_ENCODED_SOFTMAX: + encoded_sm_ptrs += BLOCK_N return acc, l_i, m_i -@triton.autotune( - configs=[ +def get_cdna_autotune_configs(): + return [ + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 128, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 3, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), triton.Config( { - "BLOCK_M": 256, - "BLOCK_N": 64, - "waves_per_eu": 2, - "PRE_LOAD_V": False, + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 1, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=8, - ), + num_warps=4), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, + 'BLOCK_M': 128, + 'BLOCK_N': 32, + 
'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=4, - ), + num_warps=4), + ], [ + 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', + 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' + ] + + +def get_rdna_autotune_configs(): + return [ triton.Config( { - "BLOCK_M": 256, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 4, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=8, - ), + num_warps=2), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 1, - "PRE_LOAD_V": False, + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=4, - ), + num_warps=2), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": True, + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 4, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=4, - ), + num_warps=2), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": False, + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=4, - ), + num_warps=2), triton.Config( { - "BLOCK_M": 64, - "BLOCK_N": 64, - "waves_per_eu": 4, - "PRE_LOAD_V": False, + 'BLOCK_M': 16, + 'BLOCK_N': 16, + 'waves_per_eu': 4, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=8, - ), + num_warps=2), triton.Config( { - "BLOCK_M": 32, - "BLOCK_N": 32, - "waves_per_eu": 4, - "PRE_LOAD_V": False, + 'BLOCK_M': 16, + 'BLOCK_N': 16, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=8, - ), - # TODO: This config fails with head_size not pow2 with data mismatches. - # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, - # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + num_warps=2), + # Fall-back config. 
triton.Config( { - "BLOCK_M": 16, - "BLOCK_N": 16, - "waves_per_eu": 1, - "PRE_LOAD_V": False, + 'BLOCK_M': 16, + 'BLOCK_N': 16, + 'waves_per_eu': 1, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 }, num_stages=1, - num_warps=4, - ), - ], - key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], + num_warps=2), + ], [ + 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', + 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' + ] + + +def get_general_autotune_configs(): + return [ + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 128, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 32, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + ], [ + 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', + 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' + ] + + +def has_cdna_target(): + ROCM_CDNA_TARGETS = ["gfx940", "gfx941", "gfx942", "gfx90a", "gfx908"] + return triton.runtime.driver.active.get_current_target( + ).arch in ROCM_CDNA_TARGETS + + +def is_rocm_cdna(): + return current_platform.is_rocm() and has_cdna_target() + + +def get_autotune_configs(): + if is_rocm_cdna(): + return get_cdna_autotune_configs() + elif current_platform.is_rocm(): + return get_rdna_autotune_configs() + else: + return get_general_autotune_configs() + + +autotune_configs, autotune_keys = get_autotune_configs() + + +@triton.autotune( + configs=autotune_configs, + key=autotune_keys, + use_cuda_graph=True, ) @triton.jit def attn_fwd( @@ -312,38 +682,53 @@ def attn_fwd( K, V, bias, - sm_scale, + SM_SCALE: tl.constexpr, L, Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, + stride_qz: tl.int64, + stride_qh: tl.int64, + stride_qm: tl.int64, + stride_qk: tl.int64, + stride_kz: tl.int64, + stride_kh: tl.int64, + stride_kn: tl.int64, + stride_kk: tl.int64, + stride_vz: tl.int64, + stride_vh: tl.int64, + stride_vk: tl.int64, + stride_vn: tl.int64, + stride_oz: tl.int64, + stride_oh: tl.int64, + stride_om: tl.int64, + stride_on: tl.int64, + stride_bz: tl.int64, + stride_bh: tl.int64, + stride_bm: tl.int64, + stride_bn: tl.int64, + stride_az: tl.int64, + stride_ah: tl.int64, + q_descale_ptr, + k_descale_ptr, + p_scale_ptr, + p_descale_ptr, + o_descale_ptr, + v_descale_ptr, + q_descale_has_singleton: tl.constexpr, + k_descale_has_singleton: tl.constexpr, + p_descale_has_singleton: tl.constexpr, + v_descale_has_singleton: tl.constexpr, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, + NUM_CU: tl.constexpr, + GRID_CU_MULTIP: tl.constexpr, + B: tl.constexpr, philox_offset_base, encoded_softmax, + alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, + IS_ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, @@ -351,24 +736,39 @@ def attn_fwd( BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - BIAS_TYPE: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, + SHOULD_PRE_LOAD_V: tl.constexpr, + USE_BIAS: tl.constexpr, + SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr, + USE_ALIBI: tl.constexpr, 
+ IS_EIGHT_BIT: tl.constexpr, + USE_P_SCALE: tl.constexpr, + IS_EIGHT_BIT_KV: tl.constexpr, + QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton, ): - start_m = tl.program_id(0) - off_h_q = tl.program_id(1) - off_z = tl.program_id(2) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) + + if o_descale_ptr is not None: + o_descale = tl.load(o_descale_ptr) + + start_m: tl.int64 = tl.program_id(0) + off_h_q: tl.int64 = tl.program_id(1) + off_z: tl.int64 = tl.program_id(2) + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64) + offs_n = tl.arange(0, BLOCK_N).to(tl.int64) + offs_d = tl.arange(0, BLOCK_DMODEL).to(tl.int64) + + # as we can't have return statements inside while loop in Triton + continue_condition = True + if VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - # We have a one-size-fits-all grid in id(0). Some seqlens might be too - # small for all start_m so for those we return early. + # We have a one-size-fits-all grid in id(0). Some seqlens might be + # too small for all start_m so for those we return early. if start_m * BLOCK_M > seqlen_q: - return + continue_condition = False + # return cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start @@ -378,444 +778,598 @@ def attn_fwd( seqlen_q = MAX_SEQLENS_Q seqlen_k = MAX_SEQLENS_K - # Now we compute whether we need to exit early due to causal masking. - # This is because for seqlen_q > seqlen_k, M rows of the attn scores - # are completely masked, resulting in 0s written to the output, and - # inf written to LSE. We don't need to do any GEMMs in this case. - # This block of code determines what N is, and if this WG is operating - # on those M rows. - n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if IS_CAUSAL: - # If seqlen_q == seqlen_k, the attn scores are a square matrix. - # If seqlen_q != seqlen_k, attn scores are rectangular which means - # the causal mask boundary is bottom right aligned, and ends at either - # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn - # matrix - n_blocks_seqlen = cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) - # This is what adjusts the block_max for the current WG, only - # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks - n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is - # part of the blocks that are all 0. We exit early. - if n_blocks <= 0: - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) - # We still need to write 0s to the result - # tl.store(O_block_ptr, - # acc.to(Out.type.element_ty), boundary_check=(0,1)) - # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q - # + offs_m - # We store inf to LSE, not -inf because in the bwd pass, - # we subtract this - # from qk which makes it -inf, such that exp(qk - inf) = 0 - # for these masked blocks. 
- # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - # tl.store(l_ptrs, l) - # TODO: Should dropout and return encoded softmax be handled here? - return - - # If MQA / GQA, set the K and V head offsets appropriately. - GROUP_SIZE: tl.constexpr = HQ // HK - off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q - - n_extra_tokens = 0 - if seqlen_k < BLOCK_N: - n_extra_tokens = BLOCK_N - seqlen_k - elif seqlen_k % BLOCK_N: - n_extra_tokens = seqlen_k % BLOCK_N - padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL - - # Compute pointers for all the tensors used in this kernel. - q_offset = (off_z * stride_qz + off_h_q * stride_qh + - cu_seqlens_q_start * stride_qm) - Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - k_offset = (off_z * stride_kz + off_h_k * stride_kh + - cu_seqlens_k_start * stride_kn) - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1), - ) - v_offset = (off_z * stride_vz + off_h_k * stride_vh + - cu_seqlens_k_start * stride_vk) - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), - ) - if BIAS_TYPE != 0: - bias_ptr = tl.make_block_ptr( - base=bias + off_h_q * stride_bh, - shape=(seqlen_q, seqlen_k), - strides=(stride_bm, stride_bn), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - bias_ptr = None - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base \ - + (off_z * HQ + off_h_q) \ - * seqlen_q * seqlen_k - else: - batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. - # In this case, we return an invalid pointer so indicate the mask is not i - # valid. - # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.make_block_ptr( - base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, - shape=(seqlen_q, seqlen_k), - strides=(seqlen_k, 1), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - encoded_softmax_block_ptr = 0 - # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use 2^x in the loop as we do not - # have native e^x support in HW. - qk_scale = sm_scale * 1.44269504089 - # Q is loaded once at the beginning and shared by all N blocks. - q = load_fn(Q_block_ptr, True, padded_head, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) - - # Here we compute how many full and masked blocks we have. - padded_block_k = n_extra_tokens != 0 - is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) - if IS_CAUSAL: - # There are always at least BLOCK_M // BLOCK_N masked blocks. - # Additionally there might be one more due to dissimilar seqlens. - masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) - else: - # Padding on Q does not need to be masked in the FA loop. 
- masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional - # block. In this case we might exceed n_blocks so pick the min. - masked_blocks = min(masked_blocks, n_blocks) - n_full_blocks = n_blocks - masked_blocks - block_min = 0 - block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its - # value because there is no masking. Similarly we do not need padding. - if n_full_blocks > 0: - block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, - block_max, - 0, - 0, - 0, - bias_ptr, - # IS_CAUSAL, .... - False, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - False, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - padded_head, - ) - block_min = block_max - block_max = n_blocks * BLOCK_N - - tl.debug_barrier() - # Remaining blocks, if any, are full / not masked. - if masked_blocks > 0: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 - K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, n_full_blocks)) - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - True, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - padded_head, - ) - # epilogue - acc = acc / l_i[:, None] - if ENABLE_DROPOUT: - acc = acc / (1 - dropout_p) - # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, - # then we have one block with a row of all NaNs which come from computing - # softmax over a row of all -infs (-inf - inf = NaN). We check for that here - # and store 0s where there are NaNs as these rows should've been zeroed out. - end_m_idx = (start_m + 1) * BLOCK_M - start_m_idx = start_m * BLOCK_M - causal_start_idx = seqlen_q - seqlen_k - acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: # noqa: SIM102 - if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL, ), - causal_start_idx, - dtype=tl.int32) - mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = (mask_m_offsets[:, None] - >= out_mask_boundary[None, :]) - z = 0.0 - acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - # write back LSE - # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last - # few rows. This is only true for the last M block. For others, - # overflow_size will be -ve - # overflow_size = end_m_idx - seqlen_q - # if overflow_size > 0: - # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) - # # This is a > check because mask being 0 blocks the store. 
- # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) - # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - # else: - # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) - - # write back O - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - # Need boundary check on this to make sure the padding from the - # Q and KV tensors in both dims are not part of what we store back. - # TODO: Do the boundary check optionally. - tl.store(O_block_ptr, acc, boundary_check=(0, 1)) - - -def check_args( - q, - k, - v, - o, - varlen=True, - max_seqlens=None, - cu_seqlens_q=None, - cu_seqlens_k=None, -): - assert q.dim() == k.dim() and q.dim() == v.dim() - if varlen: - assert q.dim() == 3 - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - assert cu_seqlens_q is not None - assert cu_seqlens_k is not None - assert len(cu_seqlens_q) == len(cu_seqlens_k) - else: - assert q.dim() == 4 - batch, nheads_q, seqlen_q, head_size = q.shape - _, nheads_k, seqlen_k, _ = k.shape - assert max_seqlens > 0 - assert k.shape == v.shape - assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] - # TODO: Change assert if we support qkl f8 and v f16 - assert q.dtype == k.dtype and q.dtype == v.dtype - assert head_size <= 256 - assert o.shape == q.shape - assert (nheads_q % nheads_k) == 0 + if continue_condition: + # Now we compute whether we need to exit early due to causal + # masking. This is because for seqlen_q > seqlen_k, M rows of the + # attn scores are completely masked, resulting in 0s written to the + # output, and inf written to LSE. We don't need to do any GEMMs in + # this case. This block of code determines what N is, and if this + # WG is operating on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if (IS_CAUSAL): + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which + # means the causal mask boundary is bottom right aligned, and + # ends at either the top edge (seqlen_q < seqlen_k) or left + # edge. This captures the decrease in n_blocks if we have a + # rectangular attn matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all + # n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this + # WG is part of the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + + cu_seqlens_q_start * stride_om) + o_ptrs = (o_offset + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_on) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + o_ptrs_mask = (offs_m[:, None] < seqlen_q).broadcast_to( + [BLOCK_M, BLOCK_DMODEL]) + # We still need to write 0s to the result + tl.store(o_ptrs, acc, mask=o_ptrs_mask) + # The tensor allocated for L is based on MAX_SEQLENS_Q as + # that is statically known. + l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + + off_h_q * MAX_SEQLENS_Q + offs_m) + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this from qk which makes it -inf, such that + # exp(qk - inf) = 0 for these masked blocks. 
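The causal early exit above hinges on how many K/V blocks a given M block can actually see when the diagonal is bottom-right aligned. The same arithmetic can be checked on the host; sizes here are illustrative:

from math import ceil

def n_kv_blocks(start_m: int, seqlen_q: int, seqlen_k: int,
                block_m: int, block_n: int, is_causal: bool) -> int:
    # Total K/V blocks ignoring any mask.
    n_blocks = ceil(seqlen_k / block_n)
    if is_causal:
        # Last visible key column for this M block under bottom-right alignment:
        # row i may attend to keys <= i + seqlen_k - seqlen_q.
        visible = (start_m + 1) * block_m + seqlen_k - seqlen_q
        n_blocks = min(n_blocks, ceil(visible / block_n))
    return max(n_blocks, 0)

# With seqlen_q=512, seqlen_k=128 and BLOCK_M=BLOCK_N=64, the first six M blocks
# see no K blocks at all, so the kernel only writes zeros and an inf LSE.
print([n_kv_blocks(m, 512, 128, 64, 64, True) for m in range(8)])   # [0,0,0,0,0,0,1,2]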
+ l_value = tl.full([BLOCK_M], + value=float("inf"), + dtype=tl.float32) + l_ptrs_mask = offs_m < MAX_SEQLENS_Q + tl.store(l_ptrs, l_value, mask=l_ptrs_mask) + # TODO: Should dropout and return encoded softmax be + # handled here too? + continue_condition = False + # return + + if continue_condition: + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + USE_PADDED_HEAD: tl.constexpr = (IS_ACTUAL_BLOCK_DMODEL + != BLOCK_DMODEL) + + # Compute pointers for all the tensors used in this kernel. + q_offset = (Q + off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + q_ptrs = (q_offset + offs_m[:, None] * stride_qm + + offs_d[None, :] * stride_qk) + k_offset = (K + off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + k_ptrs = (k_offset + offs_d[:, None] * stride_kk + + offs_n[None, :] * stride_kn) + v_offset = (V + off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + v_ptrs = (v_offset + offs_n[:, None] * stride_vk + + offs_d[None, :] * stride_vn) + # Compute pointers for all scale tensors used in this kernel. + + IS_EIGHT_BIT_GEMM: tl.constexpr = IS_EIGHT_BIT & ( + not IS_EIGHT_BIT_KV) + if IS_EIGHT_BIT: + if k_descale_has_singleton: + k_descale_ptrs = k_descale_ptr + else: + k_descale_ptrs = k_descale_ptr + off_h_k + + if v_descale_has_singleton: + v_descale_ptrs = v_descale_ptr + else: + v_descale_ptrs = v_descale_ptr + off_h_k + + if not IS_EIGHT_BIT_KV: + if q_descale_has_singleton: + q_descale_ptrs = q_descale_ptr + else: + q_descale_ptrs = q_descale_ptr + off_h_q + if USE_P_SCALE: + if p_descale_has_singleton: + p_scale_ptrs = p_scale_ptr + p_descale_ptrs = p_descale_ptr + else: + p_scale_ptrs = p_scale_ptr + off_h_q + p_descale_ptrs = p_descale_ptr + off_h_q + + if USE_BIAS: + bias_offset = off_h_q * stride_bh + bias_ptrs = (bias + bias_offset + offs_m[:, None] * stride_bm + + offs_n[None, :] * stride_bn) + else: + bias_ptrs = None + + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + batch_philox_offset = 0 + # We can ask to return the dropout mask without doing any + # dropout. In this case, we return an invalid pointer so + # indicate the mask is not valid. + if SHOULD_RETURN_ENCODED_SOFTMAX: + encoded_sm_base = (encoded_softmax + + off_h_q * seqlen_q * seqlen_k) + encoded_sm_ptrs = (encoded_sm_base + + offs_m[:, None] * seqlen_k + + offs_n[None, :]) + else: + encoded_sm_ptrs = None + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do + # not have native e^x support in HW. + QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. 
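The GROUP_SIZE mapping above collapses query heads onto shared K/V heads for MQA/GQA. A tiny illustration with made-up head counts:

# Mirrors off_h_k = off_h_q // GROUP_SIZE with GROUP_SIZE = HQ // HK.
HQ, HK = 8, 2                       # 8 query heads sharing 2 KV heads
group_size = HQ // HK               # 4 query heads per KV head
kv_head = [q_head // group_size for q_head in range(HQ)]
print(kv_head)                      # [0, 0, 0, 0, 1, 1, 1, 1]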
+ q_ptrs_mask = offs_m[:, None] < seqlen_q + if USE_PADDED_HEAD: + q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] + < IS_ACTUAL_BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + + if IS_EIGHT_BIT: + k_descale = tl.load(k_descale_ptrs) + v_descale = tl.load(v_descale_ptrs) + q_descale = None if IS_EIGHT_BIT_KV else tl.load( + q_descale_ptrs) + if USE_P_SCALE: + p_scale = tl.load(p_scale_ptrs) + p_descale = tl.load(p_descale_ptrs) + else: + p_scale = None + p_descale = None + else: + q_descale = None + k_descale = None + v_descale = None + p_scale = None + p_descale = None + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked + # blocks. Additionally there might be one more due to + # dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an + # additional block. In this case we might exceed n_blocks so + # pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false + # regardless of its actual value because there is no masking. + # Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + start_m, + seqlen_k, + seqlen_q, + philox_seed, + batch_philox_offset, + encoded_sm_ptrs, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + alibi_slope, + q_descale, + k_descale, + v_descale, + p_scale, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, SHOULD_MASK_STEPS, ... + SHOULD_PRE_LOAD_V, + False, + SHOULD_RETURN_ENCODED_SOFTMAX, + USE_PADDED_HEAD, + IS_ACTUAL_BLOCK_DMODEL, + QK_SCALE, + IS_EIGHT_BIT_GEMM, + USE_P_SCALE, + IS_EIGHT_BIT_KV, + QUANT_DTYPE) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vk + if USE_BIAS: + bias_ptrs += n_full_blocks * BLOCK_N * stride_bn + if SHOULD_RETURN_ENCODED_SOFTMAX: + encoded_sm_ptrs += n_full_blocks * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + start_m, + seqlen_k, + seqlen_q, + philox_seed, + batch_philox_offset, + encoded_sm_ptrs, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + alibi_slope, + q_descale, + k_descale, + v_descale, + p_scale, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, SHOULD_MASK_STEPS, ... 
+ SHOULD_PRE_LOAD_V, + True, + SHOULD_RETURN_ENCODED_SOFTMAX, + USE_PADDED_HEAD, + IS_ACTUAL_BLOCK_DMODEL, + QK_SCALE, + IS_EIGHT_BIT_GEMM, + USE_P_SCALE, + IS_EIGHT_BIT_KV, + QUANT_DTYPE) + + if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: + if USE_P_SCALE: + acc *= p_descale + acc *= v_descale + + # epilogue + # This helps the compiler do Newton Raphson on l_i vs on acc + # which is much larger. + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + + # If seqlen_q > seqlen_k but the delta is not a multiple of + # BLOCK_M, then we have one block with a row of all NaNs which + # come from computing softmax over a row of all + # -infs (-inf - inf = NaN). We check for that here and store 0s + # where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: # noqa: SIM102 + if o_descale_ptr is not None: + acc = quant_fp8(acc, o_descale) + + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: # noqa: SIM102 + if (causal_start_idx > start_m_idx + and causal_start_idx < end_m_idx): + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = (mask_m_offsets[:, None] + >= out_mask_boundary[None, :]) + z = tl.zeros((1, ), tl.float32) + acc = tl.where(out_ptrs_mask, acc, + z.to(acc.type.element_ty)) + # write back LSE + l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + + off_h_q * MAX_SEQLENS_Q + offs_m) + # If seqlen_q not multiple of BLOCK_M, we need to mask out the + # last few rows. This is only true for the last M block. + # For others, overflow_size will be -ve + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), + BLOCK_M - overflow_size, + dtype=tl.int32) + l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary + tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + else: + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + + cu_seqlens_q_start * stride_om) + o_ptrs = (o_offset + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_on) + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + if overflow_size > 0: + o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) + if USE_PADDED_HEAD: + o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] + < IS_ACTUAL_BLOCK_DMODEL) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) + + +def get_shape_from_layout(q, k, metadata): + assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout." + + if metadata.layout == 'thd': + nheads_q, nheads_k = q.shape[1], k.shape[1] + head_size = q.shape[-1] + batch = metadata.num_contexts + elif metadata.layout == 'bhsd': + batch, nheads_q, _, head_size = q.shape + nheads_k = k.shape[1] + elif metadata.layout == 'bshd': + batch, _, nheads_q, head_size = q.shape + nheads_k = k.shape[2] + return batch, nheads_q, nheads_k, head_size + + +def get_strides_from_layout(q, k, v, o, metadata): + assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout." 
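As a readable reference for the loop and epilogue above, here is a minimal PyTorch version of the streaming softmax the kernel implements: a running max m_i, a running denominator l_i, and an accumulator rescaled whenever the max changes, with the LSE recovered as m_i + log2(l_i). It is a sketch for intuition (single head, no masking, no quantization), not a drop-in replacement:

import torch

def streaming_attention(q, k, v, block_n=64, sm_scale=1.0):
    # q: (M, D), k: (N, D), v: (N, D)
    LOG2_E = 1.44269504089
    M = q.shape[0]
    m_i = torch.full((M,), float("-inf"))
    l_i = torch.zeros(M)
    acc = torch.zeros(M, v.shape[1])
    for start in range(0, k.shape[0], block_n):
        kb, vb = k[start:start + block_n], v[start:start + block_n]
        qk = (q @ kb.T) * sm_scale * LOG2_E
        m_ij = torch.maximum(m_i, qk.max(dim=1).values)
        p = torch.exp2(qk - m_ij[:, None])
        alpha = torch.exp2(m_i - m_ij)            # rescale previous partial sums
        l_i = l_i * alpha + p.sum(dim=1)
        acc = acc * alpha[:, None] + p @ vb
        m_i = m_ij
    out = acc / l_i[:, None]                      # epilogue: divide once at the end
    lse = m_i + torch.log2(l_i)                   # what the kernel stores in L
    return out, lse

q, k, v = torch.randn(16, 32), torch.randn(128, 32), torch.randn(128, 32)
out, _ = streaming_attention(q, k, v, sm_scale=32 ** -0.5)
ref = torch.softmax((q @ k.T) * 32 ** -0.5, dim=-1) @ v
assert torch.allclose(out, ref, atol=1e-5)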
+ + STRIDE_PERMUTATIONS = { + 'thd': (None, 1, 0, 2), + 'bhsd': (0, 1, 2, 3), + 'bshd': (0, 2, 1, 3), + } + + perm = STRIDE_PERMUTATIONS[metadata.layout] + stride = lambda x, p: (0 if p is None else x.stride(p)) + strides = lambda x: (stride(x, p) for p in perm) + + return tuple(strides(x) for x in [q, k, v, o]) class _attention(torch.autograd.Function): @staticmethod - def forward( - ctx, - q, - k, - v, - o, - cu_seqlens_q, - cu_seqlens_k, - max_seqlens_q, - max_seqlens_k, - causal=False, - sm_scale=1.0, - bias=None, - ): + def forward(ctx, q, k, v, o, metadata: MetaData): + # NOTE: a large bias tensor leads to overflow during pointer arithmetic + if (metadata.bias is not None): + assert (metadata.bias.numel() < 2**31) + if o is None: - o = torch.empty_like(q, dtype=v.dtype) + if metadata.eight_bit: + o = torch.empty_like( + q, + dtype=metadata.output_dtype if metadata.output_dtype + is not None else metadata.eight_bit_dtype_torch) + else: + o = torch.empty_like(q, dtype=q.dtype) - check_args( - q, - k, - v, - o, - varlen=True, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - ) - if True: # varlen - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - batch = len(cu_seqlens_q) - 1 - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - else: - batch, seqlen_q, nheads_q, head_size = q.shape - _, seqlen_k, nheads_k, _ = k.shape - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + metadata.check_args(q, k, v, o) + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout( + q, k, metadata) + q_strides, k_strides, v_strides, o_strides = get_strides_from_layout( + q, k, v, o, metadata) # Get closest power of 2 over or equal to 32. - unpadded_head_dims = {32, 64, 128, 256} - if head_size not in unpadded_head_dims: - padded_d_model = None - for i in unpadded_head_dims: - if i > head_size: - padded_d_model = i - break - assert padded_d_model is not None - else: - padded_d_model = head_size + padded_d_model = 1 << (head_size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. + padded_d_model = max(padded_d_model, 16) - grid = lambda META: ( - triton.cdiv(max_seqlens_q, META["BLOCK_M"]), - nheads_q, - batch, - ) + # encoded_softmax is used to validate dropout behavior vs the + # PyTorch SDPA math backend reference. We zero this out to give a + # consistent starting point and then populate it with the output of + # softmax with the sign bit set according to the dropout mask. + # The resulting return allows this mask to be fed into the reference + # implementation for testing only. This return holds no useful output + # aside from debugging. + if metadata.return_encoded_softmax: + encoded_softmax = torch.zeros( + (q.shape[0], q.shape[1], q.shape[2], k.shape[2]), + device=q.device, + dtype=torch.float32) + else: + encoded_softmax = None - encoded_softmax = None + M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), + device=q.device, + dtype=torch.float32) # Seed the RNG so we get reproducible results for testing. 
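get_shape_from_layout / get_strides_from_layout above only permute tensor strides according to the layout string ('thd' varlen, 'bhsd', 'bshd'), substituting 0 for the missing batch stride in the varlen case. The same idea checked with dummy tensors (shapes are illustrative):

import torch

STRIDE_PERMUTATIONS = {
    'thd': (None, 1, 0, 2),    # varlen: no batch dim, so batch stride becomes 0
    'bhsd': (0, 1, 2, 3),
    'bshd': (0, 2, 1, 3),
}

def strides_for(x: torch.Tensor, layout: str):
    return tuple(0 if p is None else x.stride(p)
                 for p in STRIDE_PERMUTATIONS[layout])

q_varlen = torch.empty(1024, 8, 128)   # (total_tokens, heads, head_dim)
q_bshd = torch.empty(2, 512, 8, 128)   # (batch, seq, heads, head_dim)
print(strides_for(q_varlen, 'thd'))    # (0, 128, 1024, 1): zero batch stride for varlen
print(strides_for(q_bshd, 'bshd'))     # (z, h, s, d) ordering expected by the kernel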
philox_seed = 0x1BF52 philox_offset = 0x1D4B42 - if bias is not None: - bias_strides = ( - bias.stride(0), - bias.stride(1), - bias.stride(2), - bias.stride(3), - ) + if metadata.bias is not None: + bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), + metadata.bias.stride(2), metadata.bias.stride(3)) else: bias_strides = (0, 0, 0, 0) + if metadata.alibi_slopes is not None: + alibi_strides = (metadata.alibi_slopes.stride(0), + metadata.alibi_slopes.stride(1)) + else: + alibi_strides = (0, 0) + + if metadata.eight_bit: + q_descale, k_descale, p_scale, p_descale, v_descale, o_scale = ( + metadata.q_descale, metadata.k_descale, metadata.p_scale, + metadata.p_descale, metadata.v_descale, metadata.o_scale) + o_descale = 1.0 / o_scale if o_scale is not None else None + else: + q_descale = k_descale = p_scale = None + p_descale = v_descale = o_descale = None + + # number of compute units available + NUM_CU = torch.cuda.get_device_properties("cuda").multi_processor_count + + grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META[ + 'BLOCK_M']), nheads_q, batch) + attn_fwd[grid]( q, k, v, - bias, - sm_scale, - None, + metadata.bias, + metadata.sm_scale, + M, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, - cu_seqlens_q, - cu_seqlens_k, - dropout_p=0.0, + *alibi_strides, + q_descale, + k_descale, + p_scale, + p_descale, + o_descale, + v_descale, + q_descale.numel() == 1 if q_descale is not None else False, + k_descale.numel() == 1 if k_descale is not None else False, + p_descale.numel() == 1 if p_descale is not None else False, + v_descale.numel() == 1 if v_descale is not None else False, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, + alibi_slopes=metadata.alibi_slopes, HQ=nheads_q, HK=nheads_k, - ACTUAL_BLOCK_DMODEL=head_size, - MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, - IS_CAUSAL=causal, - VARLEN=True, + IS_ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=metadata.max_seqlens_q, + MAX_SEQLENS_K=metadata.max_seqlens_k, + IS_CAUSAL=metadata.causal, + VARLEN=metadata.varlen, BLOCK_DMODEL=padded_d_model, - BIAS_TYPE=0 if bias is None else 1, - ENABLE_DROPOUT=False, - RETURN_ENCODED_SOFTMAX=False, - ) + USE_BIAS=metadata.bias is not None, + USE_ALIBI=metadata.alibi_slopes is not None, + SHOULD_RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, + IS_EIGHT_BIT=metadata.eight_bit, + USE_P_SCALE=metadata.eight_bit and metadata.use_p_scale, + IS_EIGHT_BIT_KV=metadata.eight_bit and metadata.eight_bit_kv, + NUM_CU=NUM_CU, + B=batch, + QUANT_DTYPE=metadata.eight_bit_dtype_triton) ctx.grid = grid - ctx.sm_scale = sm_scale + ctx.sm_scale = metadata.sm_scale ctx.BLOCK_DMODEL = head_size - ctx.causal = causal - ctx.dropout_p = 0.0 + ctx.causal = metadata.causal + ctx.alibi_slopes = metadata.alibi_slopes ctx.philox_seed = philox_seed ctx.philox_offset = philox_offset ctx.encoded_softmax = encoded_softmax - ctx.return_encoded_softmax = False + ctx.return_encoded_softmax = metadata.return_encoded_softmax return o, encoded_softmax -triton_attention = _attention.apply +triton_attention_rocm = _attention.apply + + +def scale_fp8(t, scale=None): + t_scaled, scale_out = ops.scaled_fp8_quant(t.reshape(-1, t.shape[-1]), + scale) + return t_scaled.reshape(t.shape), scale_out + + +def maybe_quantize_fp8(t, scale): + eight_bit_dtype = current_platform.fp8_dtype() + if t.dtype != eight_bit_dtype: + t, _ = scale_fp8(t, scale) + return t + + +def 
check_and_maybe_quantize_qkv(q, k, v, fp8_scales): + (q_scale, k_scale, v_scale, p_scale) = fp8_scales + + q = maybe_quantize_fp8(q, q_scale) + k = maybe_quantize_fp8(k, k_scale) + v = maybe_quantize_fp8(v, v_scale) + + return q, k, v + + +# query - [num_tokens, num_heads, head_size] +# key - [num_tokens, num_kv_heads, head_size] +# value - [num_tokens, num_kv_heads, head_size +# output - [num_tokens, num_heads, head_size] +def triton_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlens_q: int, + max_seqlens_k: int, + causal: bool = False, + sm_scale: float = 1.0, + bias: Optional[torch.Tensor] = None, + fp8_scales: Optional[tuple[float, ...]] = None, + input_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if fp8_scales is not None: + q_descale, k_descale, v_descale, p_scale = fp8_scales + else: + q_descale = k_descale = v_descale = p_scale = None + + attn_metadata = MetaData(sm_scale=sm_scale, + max_seqlens_q=max_seqlens_q, + max_seqlens_k=max_seqlens_k, + causal=causal, + bias=bias, + output_dtype=q.dtype, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + p_scale=p_scale, + o_scale=input_scale) + + if fp8_scales is not None: + q, k, v = check_and_maybe_quantize_qkv(q, k, v, fp8_scales) + + return triton_attention_rocm(q, k, v, o, attn_metadata) diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py index 9671b933f47b..250426d9faa5 100644 --- a/vllm/attention/ops/triton_merge_attn_states.py +++ b/vllm/attention/ops/triton_merge_attn_states.py @@ -66,7 +66,10 @@ def merge_attn_states_kernel( max_lse = tl.maximum(p_lse, s_lse) p_lse = p_lse - max_lse s_lse = s_lse - max_lse - out_se = (tl.exp(p_lse) + tl.exp(s_lse)) + # Will reuse precomputed Exp values for scale factor computation. + p_se = tl.exp(p_lse) + s_se = tl.exp(s_lse) + out_se = (p_se + s_se) if OUTPUT_LSE: out_lse = tl.log(out_se) + max_lse @@ -84,8 +87,8 @@ def merge_attn_states_kernel( # NOTE(woosuk): Be careful with the numerical stability. # We should compute the scale first, and then multiply it with the output. # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. 
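The merge_attn_states_kernel change here just caches the two exp values instead of recomputing them; the math is unchanged. For reference, merging two partial attention results (for example prefix and suffix of the KV cache) together with their log-sum-exps can be written as this PyTorch sketch, with illustrative shapes:

import torch

def merge_attn_states_ref(p_out, p_lse, s_out, s_lse):
    # p_out, s_out: (num_tokens, num_heads, head_size)
    # p_lse, s_lse: (num_tokens, num_heads), natural-log LSE values
    max_lse = torch.maximum(p_lse, s_lse)
    p_se = torch.exp(p_lse - max_lse)     # reused for both the scales and out_lse
    s_se = torch.exp(s_lse - max_lse)
    out_se = p_se + s_se
    out_lse = torch.log(out_se) + max_lse
    # Compute the scales first, then apply them to the outputs, per the
    # numerical-stability note in the kernel.
    p_scale = (p_se / out_se).unsqueeze(-1)
    s_scale = (s_se / out_se).unsqueeze(-1)
    return p_out * p_scale + s_out * s_scale, out_lse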
- p_scale = tl.exp(p_lse) / out_se - s_scale = tl.exp(s_lse) / out_se + p_scale = p_se / out_se + s_scale = s_se / out_se out = p_out * p_scale + s_out * s_scale tl.store(output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange, diff --git a/vllm/vllm_flash_attn/fa_utils.py b/vllm/attention/utils/fa_utils.py similarity index 100% rename from vllm/vllm_flash_attn/fa_utils.py rename to vllm/attention/utils/fa_utils.py diff --git a/vllm/beam_search.py b/vllm/beam_search.py index 5d4ebdb7acbc..967510abaeb9 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -38,9 +38,18 @@ class BeamSearchOutput: class BeamSearchInstance: - def __init__(self, prompt_tokens: list[int]): + def __init__( + self, + prompt_tokens: list[int], + logprobs: Optional[list[dict[int, Logprob]]] = None, + **kwargs, + ): self.beams: list[BeamSearchSequence] = [ - BeamSearchSequence(tokens=prompt_tokens, logprobs=[]) + BeamSearchSequence( + tokens=prompt_tokens, + logprobs=[] if logprobs is None else list(logprobs), + **kwargs, + ) ] self.completed: list[BeamSearchSequence] = [] diff --git a/vllm/benchmarks/__init__.py b/vllm/benchmarks/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py new file mode 100644 index 000000000000..299c888c2e7b --- /dev/null +++ b/vllm/benchmarks/datasets.py @@ -0,0 +1,831 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This module defines a framework for sampling benchmark requests from various +datasets. Each dataset subclass of BenchmarkDataset must implement sample +generation. Supported dataset types include: + - ShareGPT + - Random (synthetic) + - Sonnet + - BurstGPT + - HuggingFace + - VisionArena + +TODO: Implement CustomDataset to parse a JSON file and convert its contents into +SampleRequest instances, similar to the approach used in ShareGPT. +""" + +import base64 +import io +import json +import logging +import random +from abc import ABC, abstractmethod +from collections.abc import Mapping +from dataclasses import dataclass +from functools import cache +from io import BytesIO +from typing import Any, Callable, Optional, Union + +import numpy as np +from PIL import Image +from transformers import PreTrainedTokenizerBase + +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path +from vllm.multimodal import MultiModalDataDict +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer + +logger = logging.getLogger(__name__) + +# ----------------------------------------------------------------------------- +# Data Classes +# ----------------------------------------------------------------------------- + + +@dataclass +class SampleRequest: + """ + Represents a single inference request for benchmarking. + """ + + prompt: Union[str, Any] + prompt_len: int + expected_output_len: int + multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None + lora_request: Optional[LoRARequest] = None + + +# ----------------------------------------------------------------------------- +# Benchmark Dataset Base Class +# ----------------------------------------------------------------------------- + + +class BenchmarkDataset(ABC): + DEFAULT_SEED = 0 + + def __init__( + self, + dataset_path: Optional[str] = None, + random_seed: int = DEFAULT_SEED, + ) -> None: + """ + Initialize the BenchmarkDataset with an optional dataset path and random + seed. + + Args: + dataset_path (Optional[str]): Path to the dataset. 
If None, it + indicates that a default or random dataset might be used. + random_seed (int): Seed value for reproducible shuffling or + sampling. Defaults to DEFAULT_SEED. + """ + self.dataset_path = dataset_path + # Set the random seed, ensuring that a None value is replaced with the + # default seed. + self.random_seed = (random_seed + if random_seed is not None else self.DEFAULT_SEED) + self.data = None + + def apply_multimodal_chat_transformation( + self, + prompt: str, + mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + """ + Transform a prompt and optional multimodal content into a chat format. + This method is used for chat models that expect a specific conversation + format. + """ + content = [{"text": prompt, "type": "text"}] + if mm_content is not None: + content.append(mm_content) + return [{"role": "user", "content": content}] + + def load_data(self) -> None: + """ + Load data from the dataset path into self.data. + + This method must be overridden by subclasses since the method to load + data will vary depending on the dataset format and source. + + Raises: + NotImplementedError: If a subclass does not implement this method. + """ + # TODO (jenniferzhao): add support for downloading data + raise NotImplementedError( + "load_data must be implemented in subclasses.") + + def get_random_lora_request( + self, + tokenizer: PreTrainedTokenizerBase, + max_loras: Optional[int] = None, + lora_path: Optional[str] = None, + ) -> tuple[Optional[LoRARequest], AnyTokenizer]: + """ + Optionally select a random LoRA request and return its associated + tokenizer. + + This method is used when LoRA parameters are provided. It randomly + selects a LoRA based on max_loras and retrieves a cached tokenizer for + that LoRA if available. Otherwise, it returns the base tokenizer. + + Args: + tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no + LoRA is selected. max_loras (Optional[int]): The maximum number of + LoRAs available. If None, LoRA is not used. lora_path + (Optional[str]): Path to the LoRA parameters on disk. If None, LoRA + is not used. + + Returns: + tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first + element is a LoRARequest (or None if not applicable) and the second + element is the tokenizer associated with the LoRA request (or the + base tokenizer). + """ + if max_loras is None or lora_path is None: + return None, tokenizer + + # Generate a random LoRA ID in the range [1, max_loras]. + lora_id = random.randint(1, max_loras) + lora_request = LoRARequest( + lora_name=str(lora_id), + lora_int_id=lora_id, + lora_path=lora_path_on_disk(lora_path), + ) + if lora_id not in lora_tokenizer_cache: + lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) + # Return lora_request and the cached tokenizer if available; otherwise, + # return the base tokenizer + return lora_request, lora_tokenizer_cache[lora_id] or tokenizer + + @abstractmethod + def sample(self, tokenizer: PreTrainedTokenizerBase, + num_requests: int) -> list[SampleRequest]: + """ + Abstract method to generate sample requests from the dataset. + + Subclasses must override this method to implement dataset-specific logic + for generating a list of SampleRequest objects. + + Args: + tokenizer (PreTrainedTokenizerBase): The tokenizer to be used + for processing the dataset's text. + num_requests (int): The number of sample requests to generate. + + Returns: + list[SampleRequest]: A list of sample requests generated from the + dataset. 
+ """ + raise NotImplementedError("sample must be implemented in subclasses.") + + def maybe_oversample_requests(self, requests: list[SampleRequest], + num_requests: int) -> None: + """ + Oversamples the list of requests if its size is less than the desired + number. + + Args: + requests (List[SampleRequest]): The current list of sampled + requests. num_requests (int): The target number of requests. + """ + if len(requests) < num_requests: + random.seed(self.random_seed) + additional = random.choices(requests, + k=num_requests - len(requests)) + requests.extend(additional) + logger.info("Oversampled requests to reach %d total samples.", + num_requests) + + +# ----------------------------------------------------------------------------- +# Utility Functions and Global Caches +# ----------------------------------------------------------------------------- + + +def is_valid_sequence( + prompt_len: int, + output_len: int, + min_len: int = 4, + max_prompt_len: int = 1024, + max_total_len: int = 2048, + skip_min_output_len_check: bool = False, +) -> bool: + """ + Validate a sequence based on prompt and output lengths. + + Default pruning criteria are copied from the original `sample_hf_requests` + and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as + from `sample_requests` in benchmark_throughput.py. + """ + # Check for invalid conditions + prompt_too_short = prompt_len < min_len + output_too_short = (not skip_min_output_len_check) and (output_len + < min_len) + prompt_too_long = prompt_len > max_prompt_len + combined_too_long = (prompt_len + output_len) > max_total_len + + # Return True if none of the invalid conditions are met + return not (prompt_too_short or output_too_short or prompt_too_long + or combined_too_long) + + +@cache +def lora_path_on_disk(lora_path: str) -> str: + return get_adapter_absolute_path(lora_path) + + +# Global cache for LoRA tokenizers. +lora_tokenizer_cache: dict[int, AnyTokenizer] = {} + + +def process_image(image: Any) -> Mapping[str, Any]: + """ + Process a single image input and return a multimedia content dictionary. + + Supports three input types: + + 1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key + containing raw image data. - Loads the bytes as a PIL.Image.Image. + + 2. PIL.Image.Image input: - Converts the image to RGB. - Saves the image as + a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns + a dictionary with the image as a base64 data URL. + + 3. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is not a supported type. + """ + if isinstance(image, dict) and 'bytes' in image: + image = Image.open(BytesIO(image['bytes'])) + if isinstance(image, Image.Image): + image = image.convert("RGB") + with io.BytesIO() as image_data: + image.save(image_data, format="JPEG") + image_base64 = base64.b64encode( + image_data.getvalue()).decode("utf-8") + return { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + + if isinstance(image, str): + image_url = (image if image.startswith( + ("http://", "file://")) else f"file://{image}") + return {"type": "image_url", "image_url": {"url": image_url}} + + raise ValueError(f"Invalid image input {image}. 
Must be a PIL.Image.Image" + " or str or dictionary with raw image bytes.") + + +# ----------------------------------------------------------------------------- +# Random Dataset Implementation (Synthetic Data) +# ----------------------------------------------------------------------------- + + +class RandomDataset(BenchmarkDataset): + # Default values copied from benchmark_serving.py for the random dataset. + DEFAULT_PREFIX_LEN = 0 + DEFAULT_RANGE_RATIO = 0.0 + DEFAULT_INPUT_LEN = 1024 + DEFAULT_OUTPUT_LEN = 128 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + range_ratio: float = DEFAULT_RANGE_RATIO, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + **kwargs, + ) -> list[SampleRequest]: + # Enforce range_ratio < 1 + assert range_ratio < 1.0, ( + "random_range_ratio must be < 1.0 to ensure a valid sampling range" + ) + + vocab_size = tokenizer.vocab_size + + prefix_token_ids = (np.random.randint( + 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) + + # New sampling logic: [X * (1 - b), X * (1 + b)] + input_low = int(input_len * (1 - range_ratio)) + input_high = int(input_len * (1 + range_ratio)) + output_low = int(output_len * (1 - range_ratio)) + output_high = int(output_len * (1 + range_ratio)) + + # Add logging for debugging + logger.info("Sampling input_len from [%s, %s]", input_low, input_high) + logger.info("Sampling output_len from [%s, %s]", output_low, + output_high) + + input_lens = np.random.randint(input_low, + input_high + 1, + size=num_requests) + output_lens = np.random.randint(output_low, + output_high + 1, + size=num_requests) + offsets = np.random.randint(0, vocab_size, size=num_requests) + + requests = [] + for i in range(num_requests): + inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % + vocab_size).tolist() + token_sequence = prefix_token_ids + inner_seq + prompt = tokenizer.decode(token_sequence) + total_input_len = prefix_len + int(input_lens[i]) + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + )) + return requests + + +# ----------------------------------------------------------------------------- +# ShareGPT Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ShareGPTDataset(BenchmarkDataset): + """ + Implements the ShareGPT dataset. Loads data from a JSON file and generates + sample requests based on conversation turns. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + with open(self.dataset_path, encoding="utf-8") as f: + self.data = json.load(f) + # Filter entries with at least two conversation turns. 
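RandomDataset above draws input and output lengths uniformly from [len * (1 - range_ratio), len * (1 + range_ratio)] and builds synthetic prompts from offset token IDs. A compact NumPy sketch of that sampling, with illustrative parameters and a stand-in vocab size:

import numpy as np

input_len, output_len, range_ratio, num_requests = 1024, 128, 0.25, 4
rng = np.random.default_rng(0)

in_lo, in_hi = int(input_len * (1 - range_ratio)), int(input_len * (1 + range_ratio))
out_lo, out_hi = int(output_len * (1 - range_ratio)), int(output_len * (1 + range_ratio))
input_lens = rng.integers(in_lo, in_hi + 1, size=num_requests)
output_lens = rng.integers(out_lo, out_hi + 1, size=num_requests)

vocab_size = 32000                    # stand-in for tokenizer.vocab_size
offsets = rng.integers(0, vocab_size, size=num_requests)
# Token IDs for request i: (offset_i + i + arange(len_i)) % vocab_size; the real
# implementation decodes these with the tokenizer to produce the prompt text.
token_ids = [((offsets[i] + i + np.arange(input_lens[i])) % vocab_size).tolist()
             for i in range(num_requests)]
print(input_lens, output_lens)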
+ self.data = [ + entry for entry in self.data + if "conversations" in entry and len(entry["conversations"]) >= 2 + ] + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + samples: list = [] + for entry in self.data: + if len(samples) >= num_requests: + break + prompt, completion = ( + entry["conversations"][0]["value"], + entry["conversations"][1]["value"], + ) + + lora_request, tokenizer = self.get_random_lora_request( + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + new_output_len = (len(completion_ids) + if output_len is None else output_len) + if not is_valid_sequence(prompt_len, + new_output_len, + skip_min_output_len_check=output_len + is not None): + continue + if enable_multimodal_chat: + prompt = self.apply_multimodal_chat_transformation( + prompt, None) + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=new_output_len, + lora_request=lora_request, + )) + self.maybe_oversample_requests(samples, num_requests) + return samples + + +# ----------------------------------------------------------------------------- +# Sonnet Dataset Implementation +# ----------------------------------------------------------------------------- + + +class SonnetDataset(BenchmarkDataset): + """ + Simplified implementation of the Sonnet dataset. Loads poem lines from a + text file and generates sample requests. Default values here copied from + `benchmark_serving.py` for the sonnet dataset. + """ + + DEFAULT_PREFIX_LEN = 200 + DEFAULT_INPUT_LEN = 550 + DEFAULT_OUTPUT_LEN = 150 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if not self.dataset_path: + raise ValueError("dataset_path must be provided.") + with open(self.dataset_path, encoding="utf-8") as f: + self.data = f.readlines() + + def sample( + self, + tokenizer, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + return_prompt_formatted: bool = False, + **kwargs, + ) -> list: + # Calculate average token length for a poem line. + tokenized_lines = [tokenizer(line).input_ids for line in self.data] + avg_len = sum(len(tokens) + for tokens in tokenized_lines) / len(tokenized_lines) + + # Build the base prompt. + base_prompt = "Pick as many lines as you can from these poem lines:\n" + base_msg = [{"role": "user", "content": base_prompt}] + base_fmt = tokenizer.apply_chat_template(base_msg, + add_generation_prompt=True, + tokenize=False) + base_offset = len(tokenizer(base_fmt).input_ids) + if input_len <= base_offset: + raise ValueError( + f"'input_len' must be higher than the base prompt length " + f"({base_offset}).") + + # Determine how many poem lines to use. 
+ num_input_lines = round((input_len - base_offset) / avg_len) + num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0) + prefix_lines = self.data[:num_prefix_lines] + + samples = [] + while len(samples) < num_requests: + extra_lines = random.choices(self.data, + k=num_input_lines - num_prefix_lines) + prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" + msg = [{"role": "user", "content": prompt}] + prompt_formatted = tokenizer.apply_chat_template( + msg, add_generation_prompt=True, tokenize=False) + prompt_len = len(tokenizer(prompt_formatted).input_ids) + if prompt_len <= input_len: + samples.append( + SampleRequest( + prompt=prompt_formatted + if return_prompt_formatted else prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + return samples + + +# ----------------------------------------------------------------------------- +# BurstGPT Dataset Implementation +# ----------------------------------------------------------------------------- + + +class BurstGPTDataset(BenchmarkDataset): + """ + Implements the BurstGPT dataset. Loads data from a CSV file and generates + sample requests based on synthetic prompt generation. Only rows with Model + "GPT-4" and positive response tokens are used. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self, ): + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + try: + import pandas as pd + except ImportError as e: + raise ImportError( + "Pandas is required for BurstGPTDataset. Please install it " + "using `pip install pandas`.") from e + + df = pd.read_csv(self.dataset_path) + # Filter to keep only GPT-4 rows. + gpt4_df = df[df["Model"] == "GPT-4"] + # Remove failed requests (where Response tokens is 0 or less). + gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0] + # Sample the desired number of rows. + self.data = gpt4_df + + def _sample_loaded_data(self, num_requests: int) -> list: + if num_requests <= len(self.data): + data = self.data.sample(n=num_requests, + random_state=self.random_seed) + else: + data = self.data.sample( + n=num_requests, + random_state=self.random_seed, + replace=True, + ) + # Convert the dataframe to a list of lists. + return data.values.tolist() + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + max_loras: Optional[int] = None, + lora_path: Optional[str] = None, + **kwargs, + ) -> list[SampleRequest]: + samples = [] + data = self._sample_loaded_data(num_requests=num_requests) + for i in range(num_requests): + input_len = int(data[i][2]) + output_len = int(data[i][3]) + lora_req, tokenizer = self.get_random_lora_request( + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + vocab_size = tokenizer.vocab_size + # Generate a synthetic prompt: a list of token IDs computed as (i + + # j) modulo vocab_size. 
+ token_ids = [(i + j) % vocab_size for j in range(input_len)] + prompt = tokenizer.decode(token_ids) + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=output_len, + lora_request=lora_req, + )) + return samples + + +# ----------------------------------------------------------------------------- +# HuggingFace Dataset Base Implementation +# ----------------------------------------------------------------------------- +class HuggingFaceDataset(BenchmarkDataset): + """Base class for datasets hosted on HuggingFace.""" + + SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set() + + def __init__( + self, + dataset_path: str, + dataset_split: str, + dataset_subset: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(dataset_path=dataset_path, **kwargs) + + self.dataset_split = dataset_split + self.dataset_subset = dataset_subset + self.load_data() + + def load_data(self) -> None: + """Load data from HuggingFace datasets.""" + try: + from datasets import load_dataset + except ImportError as e: + raise ImportError( + "Hugging Face datasets library is required for this dataset. " + "Please install it using `pip install datasets`.") from e + + self.data = load_dataset( + self.dataset_path, + name=self.dataset_subset, + split=self.dataset_split, + streaming=True, + ) + self.data = self.data.shuffle(seed=self.random_seed) + + +# ----------------------------------------------------------------------------- +# Conversation Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ConversationDataset(HuggingFaceDataset): + """Dataset for conversation data with multimodal support.""" + SUPPORTED_DATASET_PATHS = { + 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered' + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs) -> list: + # Filter examples with at least 2 conversations + filtered_data = self.data.filter( + lambda x: len(x["conversations"]) >= 2) + sampled_requests = [] + dynamic_output = output_len is None + + for item in filtered_data: + if len(sampled_requests) >= num_requests: + break + conv = item["conversations"] + prompt, completion = conv[0]["value"], conv[1]["value"] + + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + completion_len = len(completion_ids) + output_len = completion_len if dynamic_output else output_len + assert isinstance(output_len, int) and output_len > 0 + if dynamic_output and not is_valid_sequence( + prompt_len, completion_len): + continue + mm_content = process_image( + item["image"]) if "image" in item else None + if enable_multimodal_chat: + # Note: when chat is enabled the request prompt_len is no longer + # accurate and we will be using request output to count the + # actual prompt len and output len + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Vision Arena Dataset Implementation +# ----------------------------------------------------------------------------- + + +class 
VisionArenaDataset(HuggingFaceDataset): + """ + Vision Arena Dataset. + """ + + DEFAULT_OUTPUT_LEN = 128 + SUPPORTED_DATASET_PATHS = { + "lmarena-ai/VisionArena-Chat": + lambda x: x["conversation"][0][0]["content"], + "lmarena-ai/vision-arena-bench-v0.1": + lambda x: x["turns"][0][0]["content"] + } + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) + if parser_fn is None: + raise ValueError( + f"Unsupported dataset path: {self.dataset_path}") + prompt = parser_fn(item) + mm_content = process_image(item["images"][0]) + prompt_len = len(tokenizer(prompt).input_ids) + if enable_multimodal_chat: + # Note: when chat is enabled the request prompt_len is no longer + # accurate and we will be using request output to count the + # actual prompt len + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Instruct Coder Dataset Implementation +# ----------------------------------------------------------------------------- + + +class InstructCoderDataset(HuggingFaceDataset): + """ + InstructCoder Dataset. + https://huggingface.co/datasets/likaixin/InstructCoder + + InstructCoder is the dataset designed for general code editing. It consists + of 114,239 instruction-input-output triplets, and covers multiple distinct + code editing scenario. + """ + + DEFAULT_OUTPUT_LEN = 200 # this is the average default output length + SUPPORTED_DATASET_PATHS = { + "likaixin/InstructCoder", + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = f"{item['instruction']}:\n{item['input']}" + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# AIMO Dataset Implementation +# ----------------------------------------------------------------------------- + + +class AIMODataset(HuggingFaceDataset): + """ + Dataset class for processing a AIMO dataset with reasoning questions. 
+ """ + SUPPORTED_DATASET_PATHS = { + "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5", + "AI-MO/NuminaMath-CoT" + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + **kwargs) -> list: + sampled_requests = [] + dynamic_output = output_len is None + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt, completion = item['problem'], item["solution"] + + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + completion_len = len(completion_ids) + output_len = completion_len if dynamic_output else output_len + assert isinstance(output_len, int) and output_len > 0 + if dynamic_output and not is_valid_sequence(prompt_len, + completion_len, + max_prompt_len=2048, + max_total_len=32000): + continue + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=None, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py new file mode 100644 index 000000000000..06f6848f50cb --- /dev/null +++ b/vllm/benchmarks/latency.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Benchmark the latency of processing a single batch of requests.""" + +import argparse +import dataclasses +import json +import os +import time +from pathlib import Path +from typing import Any, Optional + +import numpy as np +import torch +from tqdm import tqdm + +from vllm import LLM, SamplingParams +from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format, + write_to_json) +from vllm.engine.arg_utils import EngineArgs +from vllm.inputs import PromptType +from vllm.sampling_params import BeamSearchParams + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any]) -> None: + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={"latency": results["latencies"]}, + extra_info={k: results[k] + for k in ["avg_latency", "percentiles"]}) + if pt_records: + pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument("--input-len", type=int, default=32) + parser.add_argument("--output-len", type=int, default=128) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument( + "--n", + type=int, + default=1, + help="Number of generated sequences per prompt.", + ) + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument( + "--num-iters-warmup", + type=int, + default=10, + help="Number of iterations to run for warmup.", + ) + parser.add_argument("--num-iters", + type=int, + default=30, + help="Number of iterations to run.") + parser.add_argument( + "--profile", + action="store_true", + help="profile the generation process of a single batch", + ) + parser.add_argument( + "--profile-result-dir", + type=str, + default=None, + help=("path to save the pytorch profiler output. Can be visualized " + "with ui.perfetto.dev or Tensorboard."), + ) + parser.add_argument( + "--output-json", + type=str, + default=None, + help="Path to save the latency results in JSON format.", + ) + parser.add_argument( + "--disable-detokenize", + action="store_true", + help=("Do not detokenize responses (i.e. 
do not include " + "detokenization time in the latency measurement)"), + ) + + parser = EngineArgs.add_cli_args(parser) + + +def main(args: argparse.Namespace): + print(args) + + engine_args = EngineArgs.from_cli_args(args) + + # NOTE(woosuk): If the request cannot be processed in a single batch, + # the engine will automatically process the request in multiple batches. + llm = LLM(**dataclasses.asdict(engine_args)) + assert llm.llm_engine.model_config.max_model_len >= ( + args.input_len + + args.output_len), ("Please ensure that max_model_len is greater than" + " the sum of input_len and output_len.") + + sampling_params = SamplingParams( + n=args.n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=args.output_len, + detokenize=not args.disable_detokenize, + ) + print(sampling_params) + dummy_prompt_token_ids = np.random.randint(10000, + size=(args.batch_size, + args.input_len)) + dummy_prompts: list[PromptType] = [{ + "prompt_token_ids": batch + } for batch in dummy_prompt_token_ids.tolist()] + + def llm_generate(): + if not args.use_beam_search: + llm.generate(dummy_prompts, + sampling_params=sampling_params, + use_tqdm=False) + else: + llm.beam_search( + dummy_prompts, + BeamSearchParams( + beam_width=args.n, + max_tokens=args.output_len, + ignore_eos=True, + ), + ) + + def run_to_completion(profile_dir: Optional[str] = None): + if profile_dir: + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + on_trace_ready=torch.profiler.tensorboard_trace_handler( + str(profile_dir)), + ) as p: + llm_generate() + print(p.key_averages().table(sort_by="self_cuda_time_total")) + else: + start_time = time.perf_counter() + llm_generate() + end_time = time.perf_counter() + latency = end_time - start_time + return latency + + print("Warming up...") + for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): + run_to_completion(profile_dir=None) + + if args.profile: + profile_dir = args.profile_result_dir + if not profile_dir: + profile_dir = (Path(".") / "vllm_benchmark_result" / + f"latency_result_{time.time()}") + print(f"Profiling (results will be saved to '{profile_dir}')...") + run_to_completion(profile_dir=profile_dir) + return + + # Benchmark. 
+ latencies = [] + for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): + latencies.append(run_to_completion(profile_dir=None)) + latencies = np.array(latencies) + percentages = [10, 25, 50, 75, 90, 99] + percentiles = np.percentile(latencies, percentages) + print(f"Avg latency: {np.mean(latencies)} seconds") + for percentage, percentile in zip(percentages, percentiles): + print(f"{percentage}% percentile latency: {percentile} seconds") + + # Output JSON results if specified + if args.output_json: + results = { + "avg_latency": np.mean(latencies), + "latencies": latencies.tolist(), + "percentiles": dict(zip(percentages, percentiles.tolist())), + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + save_to_pytorch_benchmark_format(args, results) diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py new file mode 100644 index 000000000000..b3e24911cc98 --- /dev/null +++ b/vllm/benchmarks/throughput.py @@ -0,0 +1,608 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Benchmark offline inference throughput.""" +import argparse +import dataclasses +import json +import os +import random +import time +import warnings +from typing import Any, Optional, Union + +import torch +import uvloop +from tqdm import tqdm +from transformers import (AutoModelForCausalLM, AutoTokenizer, + PreTrainedTokenizerBase) + +from vllm.benchmarks.datasets import (AIMODataset, BurstGPTDataset, + ConversationDataset, + InstructCoderDataset, RandomDataset, + SampleRequest, ShareGPTDataset, + SonnetDataset, VisionArenaDataset) +from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format, + write_to_json) +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) +from vllm.inputs import TextPrompt, TokensPrompt +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.sampling_params import BeamSearchParams +from vllm.utils import merge_async_iterators + + +def run_vllm( + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False, +) -> tuple[float, Optional[list[RequestOutput]]]: + from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) + assert all( + llm.llm_engine.model_config.max_model_len >= ( + request.prompt_len + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") + # Add the requests to the engine. 
+ prompts: list[Union[TextPrompt, TokensPrompt]] = [] + sampling_params: list[SamplingParams] = [] + for request in requests: + prompts.append( + TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data) + if "prompt_token_ids" in request.prompt else \ + TextPrompt(prompt=request.prompt, + multi_modal_data=request.multi_modal_data)) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + lora_requests: Optional[list[LoRARequest]] = None + if engine_args.enable_lora: + lora_requests = [request.lora_request for request in requests] + + use_beam_search = False + + outputs = None + if not use_beam_search: + start = time.perf_counter() + outputs = llm.generate(prompts, + sampling_params, + lora_request=lora_requests, + use_tqdm=True) + end = time.perf_counter() + else: + assert lora_requests is None, "BeamSearch API does not support LoRA" + prompts = [request.prompt for request in requests] + # output_len should be the same for all requests. + output_len = requests[0][2] + for request in requests: + assert request.expected_output_len == output_len + start = time.perf_counter() + llm.beam_search( + prompts, + BeamSearchParams( + beam_width=n, + max_tokens=output_len, + ignore_eos=True, + )) + end = time.perf_counter() + return end - start, outputs + + +def run_vllm_chat( + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]: + """ + Run vLLM chat benchmark. This function is recommended ONLY for benchmarking + multimodal models as it properly handles multimodal inputs and chat + formatting. For non-multimodal models, use run_vllm() instead. + """ + from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) + + assert all( + llm.llm_engine.model_config.max_model_len >= ( + request.prompt_len + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of " + "prompt_len and expected_output_len for all requests.") + + prompts = [] + sampling_params: list[SamplingParams] = [] + for request in requests: + prompts.append(request.prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + start = time.perf_counter() + outputs = llm.chat(prompts, sampling_params, use_tqdm=True) + end = time.perf_counter() + return end - start, outputs + + +async def run_vllm_async( + requests: list[SampleRequest], + n: int, + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, + disable_detokenize: bool = False, +) -> float: + from vllm import SamplingParams + + async with build_async_engine_client_from_engine_args( + engine_args, disable_frontend_multiprocessing) as llm: + assert all( + llm.model_config.max_model_len >= (request.prompt_len + + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") + + # Add the requests to the engine. 
+ prompts: list[Union[TextPrompt, TokensPrompt]] = [] + sampling_params: list[SamplingParams] = [] + lora_requests: list[Optional[LoRARequest]] = [] + for request in requests: + prompts.append( + TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data) + if "prompt_token_ids" in request.prompt else \ + TextPrompt(prompt=request.prompt, + multi_modal_data=request.multi_modal_data)) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + lora_requests.append(request.lora_request) + + generators = [] + start = time.perf_counter() + for i, (prompt, sp, + lr) in enumerate(zip(prompts, sampling_params, lora_requests)): + generator = llm.generate(prompt, + sp, + lora_request=lr, + request_id=f"test{i}") + generators.append(generator) + all_gens = merge_async_iterators(*generators) + async for i, res in all_gens: + pass + end = time.perf_counter() + return end - start + + +def run_hf( + requests: list[SampleRequest], + model: str, + tokenizer: PreTrainedTokenizerBase, + n: int, + max_batch_size: int, + trust_remote_code: bool, + disable_detokenize: bool = False, +) -> float: + llm = AutoModelForCausalLM.from_pretrained( + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + if llm.config.model_type == "llama": + # To enable padding in the HF backend. + tokenizer.pad_token = tokenizer.eos_token + llm = llm.cuda() + + pbar = tqdm(total=len(requests)) + start = time.perf_counter() + batch: list[str] = [] + max_prompt_len = 0 + max_output_len = 0 + for i in range(len(requests)): + prompt = requests[i].prompt + prompt_len = requests[i].prompt_len + output_len = requests[i].expected_output_len + # Add the prompt to the batch. + batch.append(prompt) + max_prompt_len = max(max_prompt_len, prompt_len) + max_output_len = max(max_output_len, output_len) + if len(batch) < max_batch_size and i != len(requests) - 1: + # Check if we can add more requests to the batch. + next_prompt_len = requests[i + 1].prompt_len + next_output_len = requests[i + 1].expected_output_len + if (max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len)) <= 2048: + # We can add more requests to the batch. + continue + + # Generate the sequences. + input_ids = tokenizer(batch, return_tensors="pt", + padding=True).input_ids + llm_outputs = llm.generate( + input_ids=input_ids.cuda(), + do_sample=True, + num_return_sequences=n, + temperature=1.0, + top_p=1.0, + use_cache=True, + max_new_tokens=max_output_len, + ) + if not disable_detokenize: + # Include the decoding time. + tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) + pbar.update(len(batch)) + + # Clear the batch. 
+ batch = [] + max_prompt_len = 0 + max_output_len = 0 + end = time.perf_counter() + return end - start + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any]) -> None: + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={ + "requests_per_second": [results["requests_per_second"]], + "tokens_per_second": [results["tokens_per_second"]], + }, + extra_info={ + k: results[k] + for k in ["elapsed_time", "num_requests", "total_num_tokens"] + }) + if pt_records: + # Don't use json suffix here as we don't want CI to pick it up + pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + +def get_requests(args, tokenizer): + # Common parameters for all dataset types. + common_kwargs = { + "dataset_path": args.dataset_path, + "random_seed": args.seed, + } + sample_kwargs = { + "tokenizer": tokenizer, + "lora_path": args.lora_path, + "max_loras": args.max_loras, + "num_requests": args.num_prompts, + "input_len": args.input_len, + "output_len": args.output_len, + } + + if args.dataset_path is None or args.dataset_name == "random": + sample_kwargs["range_ratio"] = args.random_range_ratio + sample_kwargs["prefix_len"] = args.prefix_len + dataset_cls = RandomDataset + elif args.dataset_name == "sharegpt": + dataset_cls = ShareGPTDataset + if args.backend == "vllm-chat": + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_name == "sonnet": + assert tokenizer.chat_template or tokenizer.default_chat_template, ( + "Tokenizer/model must have chat template for sonnet dataset.") + dataset_cls = SonnetDataset + sample_kwargs["prefix_len"] = args.prefix_len + sample_kwargs["return_prompt_formatted"] = True + elif args.dataset_name == "burstgpt": + dataset_cls = BurstGPTDataset + elif args.dataset_name == "hf": + if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = VisionArenaDataset + common_kwargs['dataset_subset'] = None + common_kwargs['dataset_split'] = "train" + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = InstructCoderDataset + common_kwargs['dataset_split'] = "train" + elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = ConversationDataset + common_kwargs['dataset_subset'] = args.hf_subset + common_kwargs['dataset_split'] = args.hf_split + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: + dataset_cls = AIMODataset + common_kwargs['dataset_subset'] = None + common_kwargs['dataset_split'] = "train" + else: + raise ValueError(f"Unknown dataset name: {args.dataset_name}") + # Remove None values + sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None} + return dataset_cls(**common_kwargs).sample(**sample_kwargs) + + +def validate_args(args): + """ + Validate command-line arguments. + """ + + # === Deprecation and Defaulting === + if args.dataset is not None: + warnings.warn( + "The '--dataset' argument will be deprecated in the next release. 
" + "Please use '--dataset-name' and '--dataset-path' instead.", + stacklevel=2) + args.dataset_path = args.dataset + + if not getattr(args, "tokenizer", None): + args.tokenizer = args.model + + # === Backend Validation === + valid_backends = {"vllm", "hf", "mii", "vllm-chat"} + if args.backend not in valid_backends: + raise ValueError(f"Unsupported backend: {args.backend}") + + # === Dataset Configuration === + if not args.dataset and not args.dataset_path: + print( + "When dataset path is not set, it will default to random dataset") + args.dataset_name = 'random' + if args.input_len is None: + raise ValueError("input_len must be provided for a random dataset") + + # === Dataset Name Specific Checks === + # --hf-subset and --hf-split: only used + # when dataset_name is 'hf' + if args.dataset_name != "hf" and ( + getattr(args, "hf_subset", None) is not None + or getattr(args, "hf_split", None) is not None): + warnings.warn("--hf-subset and --hf-split will be ignored \ + since --dataset-name is not 'hf'.", + stacklevel=2) + elif args.dataset_name == "hf": + if args.dataset_path in ( + VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() + | ConversationDataset.SUPPORTED_DATASET_PATHS): + assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501 + elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS + | AIMODataset.SUPPORTED_DATASET_PATHS): + assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501 + else: + raise ValueError( + f"{args.dataset_path} is not supported by hf dataset.") + + # --random-range-ratio: only used when dataset_name is 'random' + if args.dataset_name != 'random' and args.random_range_ratio is not None: + warnings.warn("--random-range-ratio will be ignored since \ + --dataset-name is not 'random'.", + stacklevel=2) + + # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not + # set. 
+ if args.dataset_name not in {"random", "sonnet", None + } and args.prefix_len is not None: + warnings.warn("--prefix-len will be ignored since --dataset-name\ + is not 'random', 'sonnet', or not set.", + stacklevel=2) + + # === LoRA Settings === + if getattr(args, "enable_lora", False) and args.backend != "vllm": + raise ValueError( + "LoRA benchmarking is only supported for vLLM backend") + if getattr(args, "enable_lora", False) and args.lora_path is None: + raise ValueError("LoRA path must be provided when enable_lora is True") + + # === Backend-specific Validations === + if args.backend == "hf" and args.hf_max_batch_size is None: + raise ValueError("HF max batch size is required for HF backend") + if args.backend != "hf" and args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + + if args.backend in {"hf", "mii"} and getattr(args, "quantization", + None) is not None: + raise ValueError("Quantization is only for vLLM backend.") + + if args.backend == "mii" and args.dtype != "auto": + raise ValueError("dtype must be auto for MII backend.") + if args.backend == "mii" and args.n != 1: + raise ValueError("n must be 1 for MII backend.") + if args.backend == "mii" and args.tokenizer != args.model: + raise ValueError( + "Tokenizer must be the same as the model for MII backend.") + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii", "vllm-chat"], + default="vllm") + parser.add_argument( + "--dataset-name", + type=str, + choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], + help="Name of the dataset to benchmark on.", + default="sharegpt") + parser.add_argument( + "--dataset", + type=str, + default=None, + help="Path to the ShareGPT dataset, will be deprecated in\ + the next release. The dataset is expected to " + "be a json in form of list[dict[..., conversations: " + "list[dict[..., value: ]]]]") + parser.add_argument("--dataset-path", + type=str, + default=None, + help="Path to the dataset") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.") + parser.add_argument("--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.") + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument("--async-engine", + action='store_true', + default=False, + help="Use vLLM async engine rather than LLM class.") + parser.add_argument("--disable-frontend-multiprocessing", + action='store_true', + default=False, + help="Disable decoupled async engine frontend.") + parser.add_argument( + "--disable-detokenize", + action="store_true", + help=("Do not detokenize the response (i.e. do not include " + "detokenization time in the measurement)")) + # LoRA + parser.add_argument( + "--lora-path", + type=str, + default=None, + help="Path to the lora adapters to use. 
This can be an absolute path, " + "a relative path, or a Hugging Face model identifier.") + parser.add_argument( + "--prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before the random " + "context in a request (default: 0).", + ) + # random dataset + parser.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range ratio for sampling input/output length, " + "used only for RandomDataset. Must be in the range [0, 1) to define " + "a symmetric sampling range " + "[length * (1 - range_ratio), length * (1 + range_ratio)].", + ) + + # hf dtaset + parser.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + parser.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + + parser = AsyncEngineArgs.add_cli_args(parser) + + +def main(args: argparse.Namespace): + if args.tokenizer is None: + args.tokenizer = args.model + validate_args(args) + if args.seed is None: + args.seed = 0 + print(args) + random.seed(args.seed) + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + requests = get_requests(args, tokenizer) + is_multi_modal = any(request.multi_modal_data is not None + for request in requests) + request_outputs: Optional[list[RequestOutput]] = None + if args.backend == "vllm": + if args.async_engine: + elapsed_time = uvloop.run( + run_vllm_async( + requests, + args.n, + AsyncEngineArgs.from_cli_args(args), + args.disable_frontend_multiprocessing, + args.disable_detokenize, + )) + else: + elapsed_time, request_outputs = run_vllm( + requests, args.n, EngineArgs.from_cli_args(args), + args.disable_detokenize) + elif args.backend == "hf": + assert args.tensor_parallel_size == 1 + elapsed_time = run_hf(requests, args.model, tokenizer, args.n, + args.hf_max_batch_size, args.trust_remote_code, + args.disable_detokenize) + elif args.backend == "vllm-chat": + elapsed_time, request_outputs = run_vllm_chat( + requests, args.n, EngineArgs.from_cli_args(args), + args.disable_detokenize) + else: + raise ValueError(f"Unknown backend: {args.backend}") + + if request_outputs: + # Note: with the vllm and vllm-chat backends, + # we have request_outputs, which we use to count tokens. + total_prompt_tokens = 0 + total_output_tokens = 0 + for ro in request_outputs: + if not isinstance(ro, RequestOutput): + continue + total_prompt_tokens += len( + ro.prompt_token_ids) if ro.prompt_token_ids else 0 + total_output_tokens += sum( + len(o.token_ids) for o in ro.outputs if o) + total_num_tokens = total_prompt_tokens + total_output_tokens + else: + total_num_tokens = sum(r.prompt_len + r.expected_output_len + for r in requests) + total_output_tokens = sum(r.expected_output_len for r in requests) + total_prompt_tokens = total_num_tokens - total_output_tokens + + if is_multi_modal and args.backend != "vllm-chat": + print("\033[91mWARNING\033[0m: Multi-modal request with " + f"{args.backend} backend detected. The " + "following metrics are not accurate because image tokens are not" + " counted. See vllm-project/vllm/issues/9778 for details.") + # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. 
+ # vllm-chat backend counts the image tokens now + + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " + f"{total_output_tokens / elapsed_time:.2f} output tokens/s") + print(f"Total num prompt tokens: {total_prompt_tokens}") + print(f"Total num output tokens: {total_output_tokens}") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + save_to_pytorch_benchmark_format(args, results) diff --git a/collect_env.py b/vllm/collect_env.py similarity index 96% rename from collect_env.py rename to vllm/collect_env.py index e11271a13640..9cfceb7c45cc 100644 --- a/collect_env.py +++ b/vllm/collect_env.py @@ -282,13 +282,21 @@ def get_vllm_version(): if __version__ == "dev": return "N/A (dev)" - - if len(__version_tuple__) == 4: # dev build - git_sha = __version_tuple__[-1][1:] # type: ignore - return f"{__version__} (git sha: {git_sha}" - + version_str = __version_tuple__[-1] + if isinstance(version_str, str) and version_str.startswith('g'): + # it's a dev build + if '.' in version_str: + # it's a dev build containing local changes + git_sha = version_str.split('.')[0][1:] + date = version_str.split('.')[-1][1:] + return f"{__version__} (git sha: {git_sha}, date: {date})" + else: + # it's a dev build without local changes + git_sha = version_str[1:] # type: ignore + return f"{__version__} (git sha: {git_sha})" return __version__ + def summarize_vllm_build_flags(): # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. 
return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format( @@ -502,7 +510,9 @@ def run_with_pip(): print("uv is set") cmd = ["uv", "pip", "list", "--format=freeze"] else: - raise RuntimeError("Could not collect pip list output (pip or uv module not available)") + raise RuntimeError( + "Could not collect pip list output (pip or uv module not available)" + ) out = run_and_read_all(run_lambda, cmd) return "\n".join(line for line in out.splitlines() @@ -535,13 +545,12 @@ def is_xnnpack_available(): else: return "N/A" + def get_env_vars(): env_vars = '' - secret_terms=('secret', 'token', 'api', 'access', 'password') - report_prefix = ("TORCH", "NCCL", "PYTORCH", - "CUDA", "CUBLAS", "CUDNN", - "OMP_", "MKL_", - "NVIDIA") + secret_terms = ('secret', 'token', 'api', 'access', 'password') + report_prefix = ("TORCH", "NCCL", "PYTORCH", "CUDA", "CUBLAS", "CUDNN", + "OMP_", "MKL_", "NVIDIA") for k, v in os.environ.items(): if any(term in k.lower() for term in secret_terms): continue @@ -552,6 +561,7 @@ def get_env_vars(): return env_vars + def get_env_info(): run_lambda = run pip_version, pip_list_output = get_pip_packages(run_lambda) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 45988c2e9b0d..a1d12b517550 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -110,10 +110,14 @@ def compile(self, compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape) if compiled_graph is not None: - if graph_index == 0: - # adds some info logging for the first graph - logger.info("Directly load the compiled graph for shape %s " - "from the cache", str(runtime_shape)) # noqa + if graph_index == num_graphs - 1: + # after loading the last graph for this shape, record the time. + # there can be multiple graphs due to piecewise compilation. + now = time.time() + elapsed = now - compilation_start_time + logger.info( + "Directly load the compiled graph(s) for shape %s " + "from the cache, took %.3f s", str(runtime_shape), elapsed) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -335,7 +339,7 @@ def __init__( def configure_post_pass(self): config = self.compilation_config - self.post_grad_pass_manager.configure(config.pass_config) + self.post_grad_pass_manager.configure(self.vllm_config) # Post-grad custom passes are run using the post_grad_custom_post_pass # hook. If a pass for that hook exists, add it to the pass manager. 
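The `configure_post_pass` hunk above now hands the full `VllmConfig` to the post-grad pass manager, and later hunks in this diff gate each pass on the compile-time `runtime_shape` through a pass context (`pass_context` / `get_pass_context` / `is_applicable_for_shape`). Below is a minimal, standalone Python sketch of that dispatch mechanism under simplified assumptions: `ToyPass`, `ShapeSpecializedPass`, and `run_passes` are illustrative stand-ins, not vLLM APIs, and the divisibility check only mirrors the idea behind `SequenceParallelismPass.is_applicable_for_shape`.

# Standalone sketch of the shape-gated pass dispatch introduced in this diff.
# All names below are toy stand-ins for illustration, not vLLM's real classes.
from contextlib import contextmanager
from typing import Optional

_pass_context: Optional["PassContext"] = None


class PassContext:
    def __init__(self, runtime_shape: Optional[int]):
        self.runtime_shape = runtime_shape


@contextmanager
def pass_context(runtime_shape: Optional[int]):
    # Store the compile-time shape so passes can decide whether to run.
    global _pass_context
    prev = _pass_context
    _pass_context = PassContext(runtime_shape)
    try:
        yield
    finally:
        _pass_context = prev


class ToyPass:
    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
        return True  # default: run for every shape

    def __call__(self, graph: list) -> None:
        graph.append(type(self).__name__)


class ShapeSpecializedPass(ToyPass):
    def __init__(self, tp_size: int = 2):
        self.tp_size = tp_size

    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
        # Analogous to the sequence-parallelism pass: only run for concrete
        # shapes that divide evenly across the (toy) tensor-parallel size.
        return shape is not None and shape % self.tp_size == 0


def run_passes(graph: list, passes: list,
               runtime_shape: Optional[int]) -> list:
    with pass_context(runtime_shape):
        shape = _pass_context.runtime_shape
        for p in passes:
            if p.is_applicable_for_shape(shape):
                p(graph)
    return graph


if __name__ == "__main__":
    passes = [ToyPass(), ShapeSpecializedPass(tp_size=2)]
    print(run_passes([], passes, runtime_shape=None))  # ['ToyPass']
    print(run_passes([], passes, runtime_shape=8))     # ['ToyPass', 'ShapeSpecializedPass']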
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 6c8875916efc..c5454ccdcbf7 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -11,9 +11,12 @@ import torch._inductor.compile_fx import torch.fx as fx +import vllm.envs as envs from vllm.config import VllmConfig from vllm.utils import is_torch_equal_or_newer +from .inductor_pass import pass_context + class CompilerInterface: """ @@ -167,8 +170,7 @@ def compile( compiler_config: Dict[str, Any], runtime_shape: Optional[int] = None ) -> Tuple[Optional[Callable], Optional[Any]]: - from torch._inductor import config - current_config = config.get_config_copy() + current_config = {} from torch._inductor.compile_fx import compile_fx # disable remote cache @@ -196,7 +198,6 @@ def compile( hash_str, file_path = None, None from torch._inductor.codecache import (FxGraphCache, compiled_fx_graph_hash) - if torch.__version__.startswith("2.5"): original_load = FxGraphCache.load original_load_name = "torch._inductor.codecache.FxGraphCache.load" @@ -281,6 +282,16 @@ def _get_shape_env() -> AlwaysHitShapeEnv: patch("torch._inductor.codecache.FxGraphCache._get_shape_env", _get_shape_env)) + from torch._functorch._aot_autograd.autograd_cache import ( + AOTAutogradCache) + + # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache + if hasattr(AOTAutogradCache, "_get_shape_env"): + stack.enter_context( + patch( + "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", + _get_shape_env)) + # for forcing the graph to be cached stack.enter_context( patch( @@ -290,16 +301,34 @@ def _get_shape_env() -> AlwaysHitShapeEnv: # Dynamo metrics context, see method for more details. stack.enter_context(self.metrics_context()) - compiled_graph = compile_fx( - graph, - example_inputs, - inner_compile=hijacked_compile_fx_inner, - config_patches=current_config) - - assert hash_str is not None, ( - "failed to get the hash of the compiled graph") - assert file_path is not None, ( - "failed to get the file path of the compiled graph") + # Disable remote caching. When these are on, on remote cache-hit, + # the monkey-patched functions never actually get called. + # vLLM today assumes and requires the monkey-patched functions to + # get hit. + # TODO(zou3519): we're going to replace this all with + # standalone_compile sometime. + if is_torch_equal_or_newer("2.6"): + stack.enter_context( + torch._inductor.config.patch(fx_graph_remote_cache=False)) + stack.enter_context( + torch._functorch.config.patch( + enable_remote_autograd_cache=False)) + + with pass_context(runtime_shape): + compiled_graph = compile_fx( + graph, + example_inputs, + inner_compile=hijacked_compile_fx_inner, + config_patches=current_config) + + # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch + # compilation cache. So turn off the checks if we disable the + # compilation cache. 
+ if not envs.VLLM_DISABLE_COMPILE_CACHE: + assert hash_str is not None, ( + "failed to get the hash of the compiled graph") + assert file_path is not None, ( + "failed to get the file path of the compiled graph") return compiled_graph, (hash_str, file_path) def load(self, @@ -313,11 +342,19 @@ def load(self, assert isinstance(handle[1], str) hash_str = handle[0] + from torch._functorch._aot_autograd.autograd_cache import ( + AOTAutogradCache) from torch._inductor.codecache import FxGraphCache with ExitStack() as exit_stack: exit_stack.enter_context( patch("torch._inductor.codecache.FxGraphCache._get_shape_env", lambda *args, **kwargs: AlwaysHitShapeEnv())) + # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache + if hasattr(AOTAutogradCache, "_get_shape_env"): + exit_stack.enter_context( + patch( + "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv())) # Dynamo metrics context, see method for more details. exit_stack.enter_context(self.metrics_context()) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index b46f5f52244f..8f32fdb03f8b 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -9,7 +9,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch._ops import OpOverload -from vllm.config import CompilationConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -531,7 +531,7 @@ class FusionPass(VllmInductorPass): _instance: 'Optional[FusionPass]' = None @classmethod - def instance(cls, config: CompilationConfig.PassConfig): + def instance(cls, config: VllmConfig): """ Get the singleton instance of the FusionPass. If the instance exists, the config is updated but @@ -540,10 +540,10 @@ def instance(cls, config: CompilationConfig.PassConfig): if cls._instance is None: cls._instance = FusionPass(config) else: - cls._instance.config = config + cls._instance.pass_config = config.compilation_config.pass_config return cls._instance - def __init__(self, config: CompilationConfig.PassConfig): + def __init__(self, config: VllmConfig): assert self.__class__._instance is None, \ "FusionPass singleton instance already exists" super().__init__(config) diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py index b9a8d3112e77..f9427e48ac31 100644 --- a/vllm/compilation/fx_utils.py +++ b/vllm/compilation/fx_utils.py @@ -12,6 +12,22 @@ def is_func(node: fx.Node, target) -> bool: return node.op == "call_function" and node.target == target +# Returns the first specified node with the given op (if it exists) +def find_specified_fn_maybe(nodes: Iterable[fx.Node], + op: OpOverload) -> Optional[fx.Node]: + for node in nodes: + if node.target == op: + return node + return None + + +# Returns the first specified node with the given op +def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: + node = find_specified_fn_maybe(nodes, op) + assert node is not None, f"Could not find {op} in nodes {nodes}" + return node + + # Returns the first auto_functionalized node with the given op (if it exists) def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]: diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 00a2e89f21ae..6cd7720fca2f 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -4,6 +4,7 @@ import inspect import json import types +from contextlib import 
contextmanager from typing import Any, Callable, Dict, Optional, Union import torch @@ -18,6 +19,34 @@ from .torch25_custom_graph_pass import ( # noqa: yapf Torch25CustomGraphPass as CustomGraphPass) +_pass_context = None + + +class PassContext: + + def __init__(self, runtime_shape: Optional[int]): + self.runtime_shape = runtime_shape + + +def get_pass_context() -> PassContext: + """Get the current pass context.""" + assert _pass_context is not None + return _pass_context + + +@contextmanager +def pass_context(runtime_shape: Optional[int]): + """A context manager that stores the current pass context, + usually it is a list of sizes to specialize. + """ + global _pass_context + prev_context = _pass_context + _pass_context = PassContext(runtime_shape) + try: + yield + finally: + _pass_context = prev_context + class InductorPass(CustomGraphPass): """ @@ -62,6 +91,9 @@ def hash_dict(dict_: Dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() + def is_applicable_for_shape(self, shape: Optional[int]): + return True + class CallableInductorPass(InductorPass): """ diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 530a88b2b09a..f8e8c4971cbb 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -4,13 +4,15 @@ from torch import fx as fx -from vllm.config import CompilationConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from .fix_functionalization import FixFunctionalizationPass from .fusion import FusionPass -from .inductor_pass import CustomGraphPass, InductorPass +from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context from .noop_elimination import NoOpEliminationPass +from .sequence_parallelism import SequenceParallelismPass +from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) @@ -31,24 +33,29 @@ class PostGradPassManager(CustomGraphPass): """ def __init__(self): - self.passes: List[InductorPass] = [] + self.passes: List[VllmInductorPass] = [] def __call__(self, graph: fx.Graph): + shape = get_pass_context().runtime_shape for pass_ in self.passes: - pass_(graph) + if pass_.is_applicable_for_shape(shape): + pass_(graph) # always run fix_functionalization last self.fix_functionalization(graph) - def configure(self, pass_config: CompilationConfig.PassConfig): - self.pass_config = pass_config - if pass_config.enable_noop: - self.passes += [NoOpEliminationPass(pass_config)] + def configure(self, config: VllmConfig): + self.pass_config = config.compilation_config.pass_config + if self.pass_config.enable_noop: + self.passes += [NoOpEliminationPass(config)] - if pass_config.enable_fusion: - self.passes += [FusionPass.instance(pass_config)] + if self.pass_config.enable_fusion: + self.passes += [FusionPass.instance(config)] - self.fix_functionalization = FixFunctionalizationPass(pass_config) + if self.pass_config.enable_sequence_parallelism: + self.passes += [SequenceParallelismPass(config)] + + self.fix_functionalization = FixFunctionalizationPass(config) def add(self, pass_: InductorPass): assert isinstance(pass_, InductorPass) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py new file mode 100644 index 000000000000..95db63d34f7e --- /dev/null +++ b/vllm/compilation/sequence_parallelism.py @@ -0,0 +1,266 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import torch +import torch._inductor.pattern_matcher as pm 
+import torch.fx as fx +from torch._inductor.pattern_matcher import PatternMatcherPass + +from vllm.config import VllmConfig +from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.logger import init_logger + +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class AllReduceRMSNormPattern: + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + self.epsilon = epsilon + self.dtype = dtype + self.device = device + + +class EmbeddingAllReduceRMSNormPattern(AllReduceRMSNormPattern): + + def get_inputs(self): + arg2_1 = torch.empty([16, 4], device=self.device, dtype=self.dtype) + mul_6 = torch.tensor([[3, 7, 1, 4, 9, 2, 5, 0]], + device=self.device, + dtype=torch.long) + unsqueeze = torch.rand([1, 8, 1], device=self.device, \ + dtype=self.dtype) > 0.5 + full_default = torch.zeros([1, 8, 4], device=self.device, \ + dtype=self.dtype) + permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) + arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) + + return [arg2_1, mul_6, unsqueeze, full_default, permute, arg3_1] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + permute: torch.Tensor, + arg3_1: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + all_reduce = tensor_model_parallel_all_reduce(where) + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + result=permute, + input=all_reduce, + weight=arg3_1, + epsilon=self.epsilon, + ) + + return rmsnorm[1], all_reduce + + def replacement( + arg2_1: torch.Tensor, + mul_6: torch.Tensor, + unsqueeze: torch.Tensor, + full_default: torch.Tensor, + permute: torch.Tensor, + arg3_1: torch.Tensor, + ): + embedding = torch.ops.aten.embedding.default(arg2_1, mul_6) + where = torch.ops.aten.where.self(unsqueeze, full_default, + embedding) + + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + where, dim=0, world_size=tp_size, group_name=tp.unique_name) + + rmsnorm_result = torch.empty_like(reduce_scatter) + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + result=rmsnorm_result, + input=reduce_scatter, + weight=arg3_1, + epsilon=self.epsilon, + ) + + all_gather = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + + return all_gather, reduce_scatter + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern): + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + + return [ + residual, + mm_1, + rms_norm_weights, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + rmsnorm = torch.ops.higher_order.auto_functionalized( + 
torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + return rmsnorm[1], rmsnorm[2] + + def replacement( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) + + # TODO is it possible to extract epsilon from somewhere + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=reduce_scatter, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + all_gather = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + return all_gather, rmsnorm[2] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern): + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + + return [ + residual, + mm_1, + rms_norm_weights, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + return rmsnorm[1] + + def replacement( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tp = get_tp_group() + tp_size = get_tensor_model_parallel_world_size() + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm_1, dim=0, world_size=tp_size, group_name=tp.unique_name) + + # TODO is it possible to extract epsilon from somewhere + rmsnorm = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=reduce_scatter, + residual=residual, + weight=rms_norm_weights, + epsilon=self.epsilon, + ) + + normalized = torch.ops.vllm.all_gather.default( + rmsnorm[1], + dim=0, + world_size=tp_size, + group_name=tp.unique_name) + + return normalized + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class SequenceParallelismPass(VllmInductorPass): + + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="sequence_parallelism_pass") + for epsilon in [1e-5, 1e-6]: + EmbeddingAllReduceRMSNormPattern( + epsilon, self.dtype, self.device).register(self.patterns) + + MiddleAllReduceRMSNormPattern(epsilon, self.dtype, + self.device).register(self.patterns) + + LastAllReduceRMSNormPattern(epsilon, self.dtype, + self.device).register(self.patterns) + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. 
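As an aside (this is not the pattern-matcher code itself): the rewrite these patterns perform can be summarized with plain `torch.distributed` collectives. The sketch below assumes an already-initialized process group and a simplified RMSNorm; it also shows why the rewrite only applies when the token dimension is divisible by the world size, which is what `is_applicable_for_shape` further down checks.

```python
# Conceptual sketch of allreduce+rmsnorm -> reduce_scatter+rmsnorm+all_gather
# (assumption: a process group is initialized and num_tokens % world_size == 0).
import torch
import torch.distributed as dist


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


def allreduce_then_norm(partial, weight, eps):
    # Original pattern: every rank reduces the full tensor and then
    # redundantly normalizes every token.
    full = partial.clone()
    dist.all_reduce(full)
    return rms_norm(full, weight, eps)


def sequence_parallel_norm(partial, weight, eps):
    # Replacement: reduce-scatter along the token dimension so each rank
    # normalizes only its slice, then all-gather the normalized slices.
    world_size = dist.get_world_size()
    shard = torch.empty_like(partial.chunk(world_size, dim=0)[0])
    dist.reduce_scatter_tensor(shard, partial)
    shard = rms_norm(shard, weight, eps)
    out = torch.empty_like(partial)
    dist.all_gather_into_tensor(out, shard)
    return out
```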
+ torch._inductor.pattern_matcher._seen_patterns.clear() + + def is_applicable_for_shape(self, shape: Optional[int]) -> bool: + # only do replace for specific shapes + tp_size = get_tensor_model_parallel_world_size() + return shape is not None and shape % tp_size == 0 + + def __call__(self, graph: fx.Graph): + self.dump_graph(graph, "before_sequence_parallelism_pass") + count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", count) + self.dump_graph(graph, "after_sequence_parallelism_pass") diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 98ed6f1472a4..e8bffb406f14 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -4,7 +4,7 @@ import torch -from vllm.config import CompilationConfig +from vllm.config import CompilationConfig, VllmConfig # yapf: disable from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( @@ -24,16 +24,19 @@ class VllmInductorPass(InductorPass): It provides timing, logging, and dumping utilities. """ - def __init__(self, config: CompilationConfig.PassConfig): - self.config = config + def __init__(self, config: VllmConfig): + self.pass_config = config.compilation_config.pass_config + self.dtype = config.model_config.dtype if config.model_config else None + self.device = config.device_config.device if config.device_config \ + else None self.pass_name = self.__class__.__name__ def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False): - if stage in self.config.dump_graph_stages or always: + if stage in self.pass_config.dump_graph_stages or always: # Make sure filename includes rank in the distributed setting parallel = p_is_init() and get_tp_world_size() > 1 rank = f"-{get_tp_rank()}" if parallel else "" - filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" + filepath = self.pass_config.dump_graph_dir / f"{stage}{rank}.py" logger.info("%s printing graph to %s", self.pass_name, filepath) with open(filepath, "w") as f: diff --git a/vllm/config.py b/vllm/config.py index 2912361ee35e..e645103557c1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -6,18 +6,18 @@ import hashlib import inspect import json +import re import sys import textwrap import warnings from collections import Counter -from collections.abc import Mapping from contextlib import contextmanager from dataclasses import (MISSING, dataclass, field, fields, is_dataclass, replace) from importlib.util import find_spec from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal, - Optional, Protocol, TypeVar, Union) + Optional, Protocol, TypeVar, Union, get_args) import torch from pydantic import BaseModel, Field, PrivateAttr @@ -28,6 +28,7 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, + QuantizationMethods, get_quantization_config) from vllm.model_executor.models import ModelRegistry from vllm.platforms import CpuArchEnum, current_platform @@ -52,16 +53,16 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.loader import BaseModelLoader - from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( - BaseTokenizerGroup) - Config = TypeVar("Config", bound=DataclassInstance) + ConfigType = type[DataclassInstance] else: QuantizationConfig = None - Config = 
TypeVar("Config") + ConfigType = type logger = init_logger(__name__) +ConfigT = TypeVar("ConfigT", bound=ConfigType) + # This value is chosen to have a balance between ITL and TTFT. Note it is # not optimized for throughput. _DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 @@ -121,7 +122,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]: def pairwise(iterable): """ Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise - + Can be removed when Python 3.9 support is dropped. """ iterator = iter(iterable) @@ -163,7 +164,7 @@ def pairwise(iterable): return out -def config(cls: type[Config]) -> type[Config]: +def config(cls: ConfigT) -> ConfigT: """ A decorator that ensures all fields in a dataclass have default values and that each field has a docstring. @@ -182,6 +183,23 @@ def config(cls: type[Config]) -> type[Config]: return cls +def get_field(cls: ConfigType, name: str) -> Field: + """Get the default factory field of a dataclass by name. Used for getting + default factory fields in `EngineArgs`.""" + if not is_dataclass(cls): + raise TypeError("The given class is not a dataclass.") + cls_fields = {f.name: f for f in fields(cls)} + if name not in cls_fields: + raise ValueError(f"Field '{name}' not found in {cls.__name__}.") + named_field: Field = cls_fields.get(name) + if (default_factory := named_field.default_factory) is not MISSING: + return field(default_factory=default_factory) + if (default := named_field.default) is not MISSING: + return field(default=default) + raise ValueError( + f"{cls.__name__}.{name} must have a default value or default factory.") + + class ModelConfig: """Configuration for the model. @@ -250,7 +268,7 @@ class ModelConfig: config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. hf_token: The token to use as HTTP bearer authorization for remote files - . If `True`, will use the token generated when running + . If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). hf_overrides: If a dictionary, contains arguments to be forwarded to the HuggingFace config. If a callable, it is called to update the @@ -298,12 +316,20 @@ def compute_hash(self) -> str: factors.append(self.quantization) factors.append(self.revision) factors.append(self.code_revision) + factors.append(self.max_model_len) + factors.append(self.max_logprobs) + factors.append(self.disable_sliding_window) factors.append(self.trust_remote_code) + factors.append(self.mm_processor_kwargs) + factors.append(self.generation_config) + factors.append(self.model_impl) + factors.append(self.override_generation_config) factors.append(self.rope_scaling) factors.append(self.rope_theta) - # rope cos/sin cache depends on the max_position_embeddings - factors.append( - getattr(self.hf_config, "max_position_embeddings", "None")) + # hf_config can control how the model looks! 
+ factors.append(self.hf_config.to_json_string()) + str_factors = str(factors) + assert_hashable(str_factors) return hashlib.sha256(str(factors).encode()).hexdigest() def __init__( @@ -332,7 +358,7 @@ def __init__( disable_cascade_attn: bool = False, skip_tokenizer_init: bool = False, served_model_name: Optional[Union[str, list[str]]] = None, - limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + limit_mm_per_prompt: Optional[dict[str, int]] = None, use_async_output_proc: bool = True, config_format: ConfigFormat = ConfigFormat.AUTO, hf_token: Optional[Union[bool, str]] = None, @@ -417,8 +443,10 @@ def __init__( from vllm.platforms import current_platform - if self.enable_sleep_mode and not current_platform.is_cuda(): - raise ValueError("Sleep mode is only supported on CUDA devices.") + if (self.enable_sleep_mode + and not current_platform.is_sleep_mode_available()): + raise ValueError( + "Sleep mode is not supported on current platform.") hf_config = get_config(self.hf_config_path or self.model, trust_remote_code, revision, code_revision, @@ -553,7 +581,7 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str, self.tokenizer = s3_tokenizer.dir def _init_multimodal_config( - self, limit_mm_per_prompt: Optional[Mapping[str, int]] + self, limit_mm_per_prompt: Optional[dict[str, int]] ) -> Optional["MultiModalConfig"]: if self.registry.is_multimodal_model(self.architectures): return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {}) @@ -725,8 +753,8 @@ def _verify_quantization(self) -> None: supported_quantization = QUANTIZATION_METHODS optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", - "awq_marlin", "fbgemm_fp8", "compressed_tensors", - "compressed-tensors", "experts_int8", "quark", "nvfp4" + "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", + "quark", "nvfp4", "bitblas", "gptq_bitblas" ] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -736,13 +764,47 @@ def _verify_quantization(self) -> None: if quant_cfg is not None: quant_method = quant_cfg.get("quant_method", "").lower() + quant_method = quant_method.replace("compressed_tensors", + "compressed-tensors") + quant_cfg["quant_method"] = quant_method + + # Quantization methods which are overrides (i.e. they have a + # `override_quantization_method` method) must be checked in order + # of preference (this is particularly important for GPTQ). + overrides = [ + "marlin", + "bitblas", + "gptq_marlin_24", + "gptq_marlin", + "gptq_bitblas", + "awq_marlin", + "ipex", + "moe_wna16", + ] + quantization_methods = [ + q for q in supported_quantization if q not in overrides + ] + # Any custom overrides will be in quantization_methods so we place + # them at the start of the list so custom overrides have preference + # over the built in ones. + quantization_methods = quantization_methods + overrides # Detect which checkpoint is it - for name in QUANTIZATION_METHODS: + for name in quantization_methods: method = get_quantization_config(name) quantization_override = method.override_quantization_method( quant_cfg, self.quantization) - if quantization_override: + if quantization_override is not None: + # Raise error if the override is not custom (custom would + # be in QUANTIZATION_METHODS but not QuantizationMethods) + # and hasn't been added to the overrides list. 
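Standalone illustration of the ordering rule described in the comment above (`my_custom_method` is a made-up out-of-tree method): non-override methods keep their relative order and the built-in override detectors are appended last, so custom methods are probed first.

```python
supported_quantization = ["fp8", "gptq", "my_custom_method", "marlin",
                          "gptq_marlin"]
overrides = ["marlin", "gptq_marlin"]

# Same construction as in _verify_quantization above.
quantization_methods = [
    q for q in supported_quantization if q not in overrides
] + overrides
assert quantization_methods == [
    "fp8", "gptq", "my_custom_method", "marlin", "gptq_marlin"
]
```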
+ if (name in get_args(QuantizationMethods) + and name not in overrides): + raise ValueError( + f"Quantization method {name} is an override but " + "has not been added to the `overrides` list " + "above. This is necessary to ensure that the " + "overrides are checked in order of preference.") quant_method = quantization_override self.quantization = quantization_override break @@ -1220,23 +1282,78 @@ def is_matryoshka(self) -> bool: return (hasattr(self.hf_config, "matryoshka_dimensions") or getattr(self.hf_config, "is_matryoshka", False)) + @property + def matryoshka_dimensions(self): + return getattr(self.hf_config, "matryoshka_dimensions", None) -class CacheConfig: - """Configuration for the KV cache. - Args: - block_size: Size of a cache block in number of tokens. - gpu_memory_utilization: Fraction of GPU memory to use for the - vLLM execution. - swap_space: Size of the CPU swap space per GPU (in GiB). - cache_dtype: Data type for kv cache storage. - is_attention_free: Whether the model is attention-free. - num_gpu_blocks_override: Number of GPU blocks to use. This overrides the - profiled num_gpu_blocks if specified. Does nothing if None. - sliding_window: Sliding window size for the KV cache. - enable_prefix_caching: Whether to enable prefix caching. - cpu_offload_gb: Size of the CPU offload buffer in GiB. +BlockSize = Literal[1, 8, 16, 32, 64, 128] +CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] +PrefixCachingHashAlgo = Literal["builtin", "sha256"] + + +@config +@dataclass +class CacheConfig: + """Configuration for the KV cache.""" + + block_size: BlockSize = None # type: ignore + """Size of a contiguous cache block in number of tokens. This is ignored on + neuron devices and set to `--max-model-len`. On CUDA devices, only block + sizes up to 32 are supported. On HPU devices, block size defaults to 128. + + This config has no static default. If left unspecified by the user, it will + be set in `Platform.check_and_update_configs()` based on the current + platform.""" + gpu_memory_utilization: float = 0.9 + """The fraction of GPU memory to be used for the model executor, which can + range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory + utilization. If unspecified, will use the default value of 0.9. This is a + per-instance limit, and only applies to the current vLLM instance. It does + not matter if you have another vLLM instance running on the same GPU. For + example, if you have two vLLM instances running on the same GPU, you can + set the GPU memory utilization to 0.5 for each instance.""" + swap_space: float = 4 + """Size of the CPU swap space per GPU (in GiB).""" + cache_dtype: CacheDType = "auto" + """Data type for kv cache storage. If "auto", will use model data type. + CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports + fp8 (=fp8_e4m3).""" + is_attention_free: bool = False + """Whether the model is attention-free. This is primarily set in + `ModelConfig` and that value should be manually duplicated here.""" + num_gpu_blocks_override: Optional[int] = None + """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks` + if specified. Does nothing if `None`. Used for testing preemption.""" + sliding_window: Optional[int] = None + """Sliding window size for the KV cache. This is primarily set in + `ModelConfig` and that value should be manually duplicated here.""" + enable_prefix_caching: Optional[bool] = None + """Whether to enable prefix caching. Disabled by default for V0.
Enabled by + default for V1.""" + prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin" + """Set the hash algorithm for prefix caching:\n + - "builtin" is Python's built-in hash.\n + - "sha256" is collision resistant but with certain overheads.""" + cpu_offload_gb: float = 0 + """The space in GiB to offload to CPU, per GPU. Default is 0, which means + no offloading. Intuitively, this argument can be seen as a virtual way to + increase the GPU memory size. For example, if you have one 24 GB GPU and + set this to 10, virtually you can think of it as a 34 GB GPU. Then you can + load a 13B model with BF16 weight, which requires at least 26GB GPU memory. + Note that this requires fast CPU-GPU interconnect, as part of the model is + loaded from CPU memory to GPU memory on the fly in each model forward pass. """ + calculate_kv_scales: bool = False + """This enables dynamic calculation of `k_scale` and `v_scale` when + kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model + checkpoint if available. Otherwise, the scales will default to 1.0.""" + + # Will be set after profiling. + num_gpu_blocks: Optional[int] = field(default=None, init=False) + """The number of blocks to allocate for GPU memory.""" + num_cpu_blocks: Optional[int] = field(default=None, init=False) + """The number of blocks to allocate for CPU memory.""" def compute_hash(self) -> str: """ @@ -1257,43 +1374,13 @@ def compute_hash(self) -> str: usedforsecurity=False).hexdigest() return hash_str - def __init__( - self, - block_size: int, - gpu_memory_utilization: float, - swap_space: float, - cache_dtype: str, - is_attention_free: bool = False, - num_gpu_blocks_override: Optional[int] = None, - sliding_window: Optional[int] = None, - enable_prefix_caching: bool = False, - prefix_caching_hash_algo: str = "builtin", - cpu_offload_gb: float = 0, - calculate_kv_scales: Optional[bool] = None, - ) -> None: - self.block_size = block_size - self.gpu_memory_utilization = gpu_memory_utilization - self.swap_space_bytes = swap_space * GiB_bytes - self.num_gpu_blocks_override = num_gpu_blocks_override - self.cache_dtype = cache_dtype - self.is_attention_free = is_attention_free - self.sliding_window = sliding_window - self.enable_prefix_caching = enable_prefix_caching - self.prefix_caching_hash_algo = prefix_caching_hash_algo - self.cpu_offload_gb = cpu_offload_gb - self.calculate_kv_scales = calculate_kv_scales + def __post_init__(self) -> None: + self.swap_space_bytes = self.swap_space * GiB_bytes + self._verify_args() self._verify_cache_dtype() self._verify_prefix_caching() - # Will be set after profiling. - self.num_gpu_blocks: Optional[int] = None - self.num_cpu_blocks: Optional[int] = None - - # Set calculate_kv_scales to False if the value is unset. - if self.calculate_kv_scales is None: - self.calculate_kv_scales = False - def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus # metrics info @@ -1312,7 +1399,7 @@ def _verify_args(self) -> None: def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass - elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): + elif self.cache_dtype in get_args(CacheDType): logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " @@ -1330,12 +1417,12 @@ def _verify_prefix_caching(self) -> None: "Prefix caching is not supported with sliding window. 
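Usage sketch (values illustrative): with the dataclass conversion, `CacheConfig` can be constructed by keyword, and `__post_init__` derives `swap_space_bytes` and runs the verification helpers.

```python
from vllm.config import CacheConfig

cache_config = CacheConfig(block_size=16)
assert cache_config.gpu_memory_utilization == 0.9     # dataclass default
assert cache_config.cache_dtype == "auto"
assert cache_config.swap_space_bytes == 4 * 1024**3   # 4 GiB, set in __post_init__
assert cache_config.num_gpu_blocks is None            # filled in after profiling
```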
" "Run with --disable-sliding-window to use prefix caching.") - if self.enable_prefix_caching and self.prefix_caching_hash_algo not in ( - "builtin", "sha256"): + if (self.enable_prefix_caching and self.prefix_caching_hash_algo + not in get_args(PrefixCachingHashAlgo)): raise ValueError( "Unknown prefix caching hash algorithm: " - f"{self.prefix_caching_hash_algo}. Must be either " - "'builtin' or 'sha256'.") + f"{self.prefix_caching_hash_algo}. Must be one of " + f"{get_args(PrefixCachingHashAlgo)}.") def verify_with_parallel_config( self, @@ -1356,77 +1443,33 @@ def verify_with_parallel_config( logger.warning("Possibly too large swap space. %s", msg) +@config @dataclass class TokenizerPoolConfig: - """Configuration for the tokenizer pool. + """This config is deprecated and will be removed in a future release. - Args: - pool_size: Number of tokenizer workers in the pool. - pool_type: Type of the pool. - extra_config: Additional config for the pool. - The way the config will be used depends on the - pool type. + Passing these parameters will have no effect. Please remove them from your + configurations. """ - pool_size: int - pool_type: Union[str, type["BaseTokenizerGroup"]] - extra_config: dict - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. + pool_size: int = 0 + """This parameter is deprecated and will be removed in a future release. + Passing this parameter will have no effect. Please remove it from your + configurations.""" + pool_type: str = "ray" + """This parameter is deprecated and will be removed in a future release. + Passing this parameter will have no effect. Please remove it from your + configurations.""" + extra_config: dict = field(default_factory=dict) + """This parameter is deprecated and will be removed in a future release. + Passing this parameter will have no effect. Please remove it from your + configurations.""" - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - if self.pool_type not in ("ray", ) and not isinstance( - self.pool_type, type): - raise ValueError(f"Unknown pool type: {self.pool_type}") - if not isinstance(self.extra_config, dict): - raise ValueError("extra_config must be a dictionary.") - - @classmethod - def create_config( - cls, tokenizer_pool_size: int, - tokenizer_pool_type: Union[str, type["BaseTokenizerGroup"]], - tokenizer_pool_extra_config: Optional[Union[str, dict]] - ) -> Optional["TokenizerPoolConfig"]: - """Create a TokenizerPoolConfig from the given parameters. - - If tokenizer_pool_size is 0, return None. - - Args: - tokenizer_pool_size: Number of tokenizer workers in the pool. - tokenizer_pool_type: Type of the pool. - tokenizer_pool_extra_config: Additional config for the pool. - The way the config will be used depends on the - pool type. This can be a JSON string (will be parsed). 
- """ - if tokenizer_pool_size: - if isinstance(tokenizer_pool_extra_config, str): - tokenizer_pool_extra_config_parsed = json.loads( - tokenizer_pool_extra_config) - else: - tokenizer_pool_extra_config_parsed = ( - tokenizer_pool_extra_config or {}) - tokenizer_pool_config = cls(tokenizer_pool_size, - tokenizer_pool_type, - tokenizer_pool_extra_config_parsed) - else: - tokenizer_pool_config = None - return tokenizer_pool_config + def __post_init__(self) -> None: + logger.warning_once( + "TokenizerPoolConfig is deprecated and will be removed in a " + "future release. Passing this parameter will have no effect. " + "Please remove it from your configurations.") class LoadFormat(str, enum.Enum): @@ -1441,6 +1484,7 @@ class LoadFormat(str, enum.Enum): BITSANDBYTES = "bitsandbytes" MISTRAL = "mistral" RUNAI_STREAMER = "runai_streamer" + RUNAI_STREAMER_SHARDED = "runai_streamer_sharded" FASTSAFETENSORS = "fastsafetensors" @@ -1475,7 +1519,7 @@ class LoadConfig: download_dir: Optional[str] = None """Directory to download and load the weights, default to the default cache directory of Hugging Face.""" - model_loader_extra_config: Optional[Union[str, dict]] = None + model_loader_extra_config: dict = field(default_factory=dict) """Extra config for model loader. This will be passed to the model loader corresponding to the chosen load_format. This should be a JSON string that will be parsed into a dictionary.""" @@ -1506,10 +1550,6 @@ def compute_hash(self) -> str: return hash_str def __post_init__(self): - model_loader_extra_config = self.model_loader_extra_config or {} - if isinstance(model_loader_extra_config, str): - self.model_loader_extra_config = json.loads( - model_loader_extra_config) if isinstance(self.load_format, str): load_format = self.load_format.lower() self.load_format = LoadFormat(load_format) @@ -1522,6 +1562,9 @@ def __post_init__(self): self.ignore_patterns = ["original/**/*"] +DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] + + @config @dataclass class ParallelConfig: @@ -1536,8 +1579,21 @@ class ParallelConfig: the product of the tensor parallel size and data parallel size.""" data_parallel_rank: int = 0 """Rank of the data parallel group.""" - data_parallel_rank_local: Optional[int] = None - """Local rank of the data parallel group, defaults to global rank.""" + _data_parallel_rank_local: Optional[int] = field(default=None, init=False) + """Private field to store the local rank of the data parallel group.""" + + @property + def data_parallel_rank_local(self) -> int: + """Local rank of the data parallel group, defaults to global rank.""" + if self._data_parallel_rank_local is None: + return self.data_parallel_rank + return self._data_parallel_rank_local + + @data_parallel_rank_local.setter + def data_parallel_rank_local(self, value: int) -> None: + """Set the local rank of the data parallel group.""" + self._data_parallel_rank_local = value + data_parallel_master_ip: str = "127.0.0.1" """IP of the data parallel master.""" data_parallel_master_port: int = 29500 @@ -1554,8 +1610,8 @@ class ParallelConfig: """Disable the custom all-reduce kernel and fall back to NCCL.""" tokenizer_pool_config: Optional[TokenizerPoolConfig] = None - """Config for the tokenizer pool. If None, will use synchronous - tokenization.""" + """This parameter is deprecated and will be removed in a future release. 
+ Please remove it from your configurations.""" ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" @@ -1563,7 +1619,7 @@ class ParallelConfig: placement_group: Optional["PlacementGroup"] = None """ray distributed model workers placement group.""" - distributed_executor_backend: Optional[Union[str, + distributed_executor_backend: Optional[Union[DistributedExecutorBackend, type["ExecutorBase"]]] = None """Backend to use for distributed model workers, either "ray" or "mp" (multiprocessing). If the product @@ -1577,7 +1633,7 @@ class ParallelConfig: """The full name of the worker class to use. If "auto", the worker class will be determined based on the platform.""" sd_worker_cls: str = "auto" - """The full name of the worker class to use for speculative decofing. + """The full name of the worker class to use for speculative decoding. If "auto", the worker class will be determined based on the platform.""" worker_extension_cls: str = "" """The full name of the worker extension class to use. The worker extension @@ -1646,6 +1702,7 @@ def compute_hash(self): factors: list[Any] = [] factors.append(self.pipeline_parallel_size) factors.append(self.tensor_parallel_size) + factors.append(self.enable_expert_parallel) return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__(self) -> None: @@ -1687,7 +1744,7 @@ def __post_init__(self) -> None: # current node and we aren't in a ray placement group. from vllm.executor import ray_utils - backend = "mp" + backend: DistributedExecutorBackend = "mp" ray_found = ray_utils.ray_is_available() if current_platform.is_neuron(): # neuron uses single process to control multiple devices @@ -1755,92 +1812,125 @@ def _verify_args(self) -> None: "worker_extension_cls must be a string (qualified class name).") +PreemptionMode = Literal["swap", "recompute"] +SchedulerPolicy = Literal["fcfs", "priority"] + + +@config @dataclass class SchedulerConfig: """Scheduler configuration.""" - runner_type: str = "generate" # The runner type to launch for the model. + runner_type: RunnerType = "generate" + """The runner type to launch for the model.""" + + max_num_batched_tokens: int = None # type: ignore + """Maximum number of tokens to be processed in a single iteration. + + This config has no static default. If left unspecified by the user, it will + be set in `EngineArgs.create_engine_config` based on the usage context.""" - # Maximum number of tokens to be processed in a single iteration. - max_num_batched_tokens: int = field(default=None) # type: ignore + max_num_seqs: int = None # type: ignore + """Maximum number of sequences to be processed in a single iteration. - # Maximum number of sequences to be processed in a single iteration. - max_num_seqs: int = 128 + This config has no static default. If left unspecified by the user, it will + be set in `EngineArgs.create_engine_config` based on the usage context.""" - # Maximum length of a sequence (including prompt and generated text). - max_model_len: int = 8192 + max_model_len: int = None # type: ignore + """Maximum length of a sequence (including prompt and generated text).
This + is primarily set in `ModelConfig` and that value should be manually + duplicated here.""" - # Maximum number of sequences that can be partially prefilled concurrently max_num_partial_prefills: int = 1 + """For chunked prefill, the maximum number of sequences that can be + partially prefilled concurrently.""" - # Maximum number of "very long prompt" sequences that can be prefilled - # concurrently (long is defined by long_prefill_threshold) max_long_partial_prefills: int = 1 + """For chunked prefill, the maximum number of prompts longer than + long_prefill_token_threshold that will be prefilled concurrently. Setting + this less than max_num_partial_prefills will allow shorter prompts to jump + the queue in front of longer prompts in some cases, improving latency.""" - # calculate context length that determines which sequences are - # considered "long" long_prefill_token_threshold: int = 0 + """For chunked prefill, a request is considered long if the prompt is + longer than this number of tokens.""" - # The number of slots to allocate per sequence per - # step, beyond the known token ids. This is used in speculative - # decoding to store KV activations of tokens which may or may not be - # accepted. num_lookahead_slots: int = 0 + """The number of slots to allocate per sequence per + step, beyond the known token ids. This is used in speculative + decoding to store KV activations of tokens which may or may not be + accepted. + + NOTE: This will be replaced by speculative config in the future; it is + present to enable correctness tests until then.""" - # Apply a delay (of delay factor multiplied by previous - # prompt latency) before scheduling next prompt. delay_factor: float = 0.0 + """Apply a delay (of delay factor multiplied by previous + prompt latency) before scheduling next prompt.""" - # If True, prefill requests can be chunked based - # on the remaining max_num_batched_tokens. - enable_chunked_prefill: bool = False + enable_chunked_prefill: bool = None # type: ignore + """If True, prefill requests can be chunked based + on the remaining max_num_batched_tokens.""" is_multimodal_model: bool = False + """True if the model is multimodal.""" + + # TODO (ywang96): Make this configurable. + max_num_encoder_input_tokens: int = field(init=False) + """Multimodal encoder compute budget, only used in V1. + + NOTE: This is not currently configurable. It will be overridden by + max_num_batched_tokens in case max multimodal embedding size is larger.""" - # NOTE: The following multimodal encoder budget will be initialized to - # max_num_batched_tokens and overridden in case max multimodal embedding - # size is larger. - # TODO (ywang96): Make these configurable. - # Multimodal encoder compute budget, only used in V1 - max_num_encoder_input_tokens: int = field(default=None) # type: ignore + # TODO (ywang96): Make this configurable. + encoder_cache_size: int = field(init=False) + """Multimodal encoder cache size, only used in V1. - # Multimodal encoder cache size, only used in V1 - encoder_cache_size: int = field(default=None) # type: ignore + NOTE: This is not currently configurable. It will be overridden by + max_num_batched_tokens in case max multimodal embedding size is larger.""" - # Whether to perform preemption by swapping or - # recomputation. If not specified, we determine the mode as follows: - # We use recomputation by default since it incurs lower overhead than - # swapping. 
However, when the sequence group has multiple sequences - # (e.g., beam search), recomputation is not currently supported. In - # such a case, we use swapping instead. - preemption_mode: Optional[str] = None + preemption_mode: Optional[PreemptionMode] = None + """Whether to perform preemption by swapping or + recomputation. If not specified, we determine the mode as follows: + We use recomputation by default since it incurs lower overhead than + swapping. However, when the sequence group has multiple sequences + (e.g., beam search), recomputation is not currently supported. In + such a case, we use swapping instead.""" num_scheduler_steps: int = 1 + """Maximum number of forward steps per scheduler call.""" - multi_step_stream_outputs: bool = False + multi_step_stream_outputs: bool = True + """If False, then multi-step will stream outputs at the end of all steps.""" - # Private API. If used, scheduler sends delta data to - # workers instead of an entire data. It should be enabled only - # when SPMD worker architecture is enabled. I.e., - # VLLM_USE_RAY_SPMD_WORKER=1 send_delta_data: bool = False - - # The scheduling policy to use. "fcfs" (default) or "priority". - policy: str = "fcfs" + """Private API. If used, scheduler sends delta data to + workers instead of the entire data. It should be enabled only + when SPMD worker architecture is enabled. I.e., + VLLM_USE_RAY_SPMD_WORKER=1""" + + policy: SchedulerPolicy = "fcfs" + """The scheduling policy to use:\n + - "fcfs" means first come first served, i.e. requests are handled in order + of arrival.\n + - "priority" means requests are handled based on given priority (lower + value means earlier handling), with time of arrival deciding any ties.""" chunked_prefill_enabled: bool = field(init=False) + """True if chunked prefill is enabled.""" - # If set to true and chunked prefill is enabled, we do not want to - # partially schedule a multimodal item. Only used in V1 - # This ensures that if a request has a mixed prompt - # (like text tokens TTTT followed by image tokens IIIIIIIIII) where only - # some image tokens can be scheduled (like TTTTIIIII, leaving IIIII), - # it will be scheduled as TTTT in one step and IIIIIIIIII in the next. disable_chunked_mm_input: bool = False + """If set to true and chunked prefill is enabled, we do not want to + partially schedule a multimodal item. Only used in V1. + This ensures that if a request has a mixed prompt + (like text tokens TTTT followed by image tokens IIIIIIIIII) where only + some image tokens can be scheduled (like TTTTIIIII, leaving IIIII), + it will be scheduled as TTTT in one step and IIIIIIIIII in the next.""" - # scheduler class or path. "vllm.core.scheduler.Scheduler" (default) - # or "mod.custom_class". scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler" + """The scheduler class to use. "vllm.core.scheduler.Scheduler" is the + default scheduler. Can be a class directly or the path to a class of form + "mod.custom_class".""" def compute_hash(self) -> str: """ @@ -1862,6 +1952,18 @@ def compute_hash(self) -> str: return hash_str def __post_init__(self) -> None: + if self.max_model_len is None: + self.max_model_len = 8192 + logger.warning( + "max_model_len is not set. Defaulting to arbitrary value " + "of %d.", self.max_model_len) + + if self.max_num_seqs is None: + self.max_num_seqs = 128 + logger.warning( + "max_num_seqs is not set.
Defaulting to arbitrary value " + "of %d.", self.max_num_seqs) + if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: if self.num_scheduler_steps > 1: @@ -1974,9 +2076,19 @@ def is_multi_step(self) -> bool: return self.num_scheduler_steps > 1 +Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"] + + +@config +@dataclass class DeviceConfig: - device: Optional[torch.device] - device_type: str + """Configuration for the device to use for vLLM execution.""" + + device: Union[Device, torch.device] = "auto" + """Device type for vLLM execution.""" + device_type: str = field(init=False) + """Device type from the current platform. This is set in + `__post_init__`.""" def compute_hash(self) -> str: """ @@ -1998,8 +2110,8 @@ def compute_hash(self) -> str: usedforsecurity=False).hexdigest() return hash_str - def __init__(self, device: str = "auto") -> None: - if device == "auto": + def __post_init__(self): + if self.device == "auto": # Automated device type detection from vllm.platforms import current_platform self.device_type = current_platform.device_type @@ -2010,7 +2122,7 @@ def __init__(self, device: str = "auto") -> None: "to turn on verbose logging to help debug the issue.") else: # Device type is assigned explicitly - self.device_type = device + self.device_type = self.device # Some device types require processing inputs on CPU if self.device_type in ["neuron"]: @@ -2022,139 +2134,113 @@ def __init__(self, device: str = "auto") -> None: self.device = torch.device(self.device_type) +SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator", + "draft_model"] +SpeculativeAcceptanceMethod = Literal["rejection_sampler", + "typical_acceptance_sampler"] + + +@config @dataclass class SpeculativeConfig: - """ - Configuration for speculative decoding. - Configurable parameters include: - - General Speculative Decoding Control: - - num_speculative_tokens (int): The number of speculative - tokens, if provided. It will default to the number in the draft - model config if present, otherwise, it is required. - - model (Optional[str]): The name of the draft model, eagle head, - or additional weights, if provided. - - method (Optional[str]): The name of the speculative method to use. - If users provide and set the `model` param, the speculative method - type will be detected automatically if possible, if `model` param - is not provided, the method name must be provided. - - Possible values: - - ngram - Related additional configuration: - - prompt_lookup_max (Optional[int]): - Maximum size of ngram token window when using Ngram - proposer, required when method is set to ngram. - - prompt_lookup_min (Optional[int]): - Minimum size of ngram token window when using Ngram - proposer, if provided. Defaults to 1. - - eagle - - medusa - - mlp_speculator - - draft_model - - acceptance_method (str): The method to use for accepting draft - tokens. This can take two possible values: 'rejection_sampler' and - 'typical_acceptance_sampler' for RejectionSampler and - TypicalAcceptanceSampler respectively. If not specified, it - defaults to 'rejection_sampler'. - - Possible values: - - rejection_sampler - - typical_acceptance_sampler - Related additional configuration: - - posterior_threshold (Optional[float]): - A threshold value that sets a lower bound on the - posterior probability of a token in the target model - for it to be accepted. This threshold is used only - when we use the TypicalAcceptanceSampler for token - acceptance. 
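Usage sketch of the converted `DeviceConfig` (an explicit device is shown here; "auto" instead resolves the device type from the current platform):

```python
import torch

from vllm.config import DeviceConfig

device_config = DeviceConfig(device="cpu")
assert device_config.device_type == "cpu"            # set in __post_init__
assert device_config.device == torch.device("cpu")   # normalized to torch.device
```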
- - posterior_alpha (Optional[float]): - Scaling factor for entropy-based threshold, applied - when using TypicalAcceptanceSampler. - - draft_tensor_parallel_size (Optional[int]): The degree of the tensor - parallelism for the draft model. Can only be 1 or the same as the - target model's tensor parallel size. - - disable_logprobs (bool): If set to True, token log probabilities are - not returned during speculative decoding. If set to False, token - log probabilities are returned according to the log probability - settings in SamplingParams. If not specified, it defaults to True. - - - Draft Model Configuration: - - quantization (Optional[str]): Quantization method that was used to - quantize the draft model weights. If None, we assume the - model weights are not quantized. Note that it only takes effect - when using the draft model-based speculative method. - - max_model_len (Optional[int]): The maximum model length of the - draft model. Used when testing the ability to skip - speculation for some sequences. - - revision: The specific model version to use for the draft model. It - can be a branch name, a tag name, or a commit id. If unspecified, - will use the default version. - - code_revision: The specific revision to use for the draft model code - on Hugging Face Hub. It can be a branch name, a tag name, or a - commit id. If unspecified, will use the default version. + """Configuration for speculative decoding.""" - - Advanced Control: - - disable_mqa_scorer (bool): Disable the MQA scorer and fall back to - batch expansion for scoring proposals. If not specified, it - defaults to False. - - disable_by_batch_size (Optional[int]): Disable speculative decoding - for new incoming requests when the number of enqueued requests is - larger than this value, if provided. - - Although the parameters above are structured hierarchically, there is no - need to nest them during configuration. - - Non-configurable internal parameters include: - - Model Configuration: - - target_model_config (ModelConfig): The configuration of the target - model. - - draft_model_config (ModelConfig): The configuration of the draft - model initialized internal. - - Parallelism Configuration: - - target_parallel_config (ParallelConfig): The parallel configuration - for the target model. - - draft_parallel_config (ParallelConfig): The parallel configuration - for the draft model initialized internal. - - Execution Control: - - enable_chunked_prefill (bool): Whether vLLM is configured to use - chunked prefill or not. Used for raising an error since it's not - yet compatible with speculative decode. - - disable_log_stats (bool): Whether to disable the periodic printing of - stage times in speculative decoding. - """ - # speculative configs from cli args + # General speculative decoding control num_speculative_tokens: int = field(default=None, init=True) # type: ignore - method: Optional[str] = None - acceptance_method: str = "rejection_sampler" + """The number of speculative tokens, if provided. It will default to the + number in the draft model config if present, otherwise, it is required.""" + model: Optional[str] = None + """The name of the draft model, eagle head, or additional weights, if + provided.""" + method: Optional[SpeculativeMethod] = None + """The name of the speculative method to use. If users provide and set the + `model` param, the speculative method type will be detected automatically + if possible, if `model` param is not provided, the method name must be + provided. 
+ + If using `ngram` method, the related configuration `prompt_lookup_max` and + `prompt_lookup_min` should be considered.""" + acceptance_method: SpeculativeAcceptanceMethod = "rejection_sampler" + """The method to use for accepting draft tokens:\n + - "rejection_sampler" maps to `RejectionSampler`.\n + - "typical_acceptance_sampler" maps to `TypicalAcceptanceSampler`. + + If using `typical_acceptance_sampler`, the related configuration + `posterior_threshold` and `posterior_alpha` should be considered.""" draft_tensor_parallel_size: Optional[int] = None + """The degree of the tensor parallelism for the draft model. Can only be 1 + or the same as the target model's tensor parallel size.""" disable_logprobs: bool = True + """If set to True, token log probabilities are not returned during + speculative decoding. If set to False, token log probabilities are returned + according to the log probability settings in SamplingParams.""" - model: Optional[str] = None + # Draft model configuration quantization: Optional[str] = None + """Quantization method that was used to quantize the draft model weights. + If `None`, we assume the model weights are not quantized. Note that it only + takes effect when using the draft model-based speculative method.""" max_model_len: Optional[int] = None + """The maximum model length of the draft model. Used when testing the + ability to skip speculation for some sequences.""" revision: Optional[str] = None + """The specific model version to use for the draft model. It can be a + branch name, a tag name, or a commit id. If unspecified, will use the + default version.""" code_revision: Optional[str] = None + """The specific revision to use for the draft model code on Hugging Face + Hub. It can be a branch name, a tag name, or a commit id. If unspecified, + will use the default version.""" + # Advanced control disable_mqa_scorer: bool = False + """Disable the MQA scorer and fall back to batch expansion for scoring + proposals.""" disable_by_batch_size: Optional[int] = None + """Disable speculative decoding for new incoming requests when the number + of enqueued requests is larger than this value, if provided.""" + + # Ngram proposer configuration prompt_lookup_max: Optional[int] = None + """Maximum size of ngram token window when using Ngram proposer, required + when method is set to ngram.""" prompt_lookup_min: Optional[int] = None + """Minimum size of ngram token window when using Ngram proposer, if + provided. Defaults to 1.""" + + # Typical acceptance sampler configuration posterior_threshold: Optional[float] = None + """A threshold value that sets a lower bound on the posterior probability + of a token in the target model for it to be accepted. This threshold is + used only when we use the `TypicalAcceptanceSampler` for token acceptance. + """ posterior_alpha: Optional[float] = None + """Scaling factor for entropy-based threshold, applied when using + `TypicalAcceptanceSampler`.""" # required configuration params passed from engine target_model_config: ModelConfig = field(default=None, init=True) # type: ignore + """The configuration of the target model.""" target_parallel_config: ParallelConfig = field(default=None, init=True) # type: ignore + """The parallel configuration for the target model.""" enable_chunked_prefill: bool = field(default=None, init=True) # type: ignore + """Whether vLLM is configured to use chunked prefill or not. 
Used for + raising an error since it's not yet compatible with speculative decode.""" disable_log_stats: bool = field(default=None, init=True) # type: ignore + """Whether to disable the periodic printing of stage times in speculative + decoding.""" # params generated in the post-init stage draft_model_config: ModelConfig = field(default=None, init=True) # type: ignore + """The configuration of the draft model initialized internally.""" draft_parallel_config: ParallelConfig = field(default=None, init=True) # type: ignore + """The parallel configuration for the draft model initialized internally.""" def compute_hash(self) -> str: """ @@ -2168,9 +2254,10 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - # no factors to consider. - # spec decode does not use `torch.compile` yet. factors: list[Any] = [] + # Eagle3 affects the computation graph because it returns intermediate + # hidden states in addition to the final hidden state. + factors.append(self.method == "eagle3") hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str @@ -2205,7 +2292,8 @@ def __post_init__(self): if self.model is None and self.num_speculative_tokens is not None: # TODO(Shangming): Refactor mtp configuration logic when supporting # mtp acceleration for more models besides deepseek_v3 - if self.target_model_config.hf_text_config.model_type \ + if self.target_model_config and \ + self.target_model_config.hf_text_config.model_type \ == "deepseek_v3": # use the draft model from the same model: self.model = self.target_model_config.model @@ -2286,7 +2374,10 @@ def __post_init__(self): ) # Automatically detect the method - if "eagle-" in self.draft_model_config.model.lower(): + if self.method in ('eagle', 'eagle3'): + pass + elif "eagle-" in self.draft_model_config.model.lower() or \ + "eagle3-" in self.draft_model_config.model.lower(): self.method = "eagle" elif self.draft_model_config.hf_config.model_type == "medusa": self.method = "medusa" @@ -2297,7 +2388,7 @@ def __post_init__(self): self.method = "draft_model" # Replace hf_config for EAGLE draft_model - if self.method == "eagle": + if self.method in ("eagle", "eagle3"): if self.enable_chunked_prefill and not envs.VLLM_USE_V1: raise ValueError( "Chunked prefill and EAGLE are not compatible " @@ -2442,7 +2533,6 @@ def create_draft_parallel_config( max_parallel_loading_workers, disable_custom_all_reduce=target_parallel_config. disable_custom_all_reduce, - tokenizer_pool_config=target_parallel_config.tokenizer_pool_config, ray_workers_use_nsight=target_parallel_config. ray_workers_use_nsight, placement_group=target_parallel_config.placement_group, @@ -2495,6 +2585,12 @@ def _verify_args(self) -> None: "speculative decoding is > 1, but got " f"{self.disable_by_batch_size=}") + if self.method == "eagle3" and self.target_model_config and \ + "llama" not in self.target_model_config.hf_text_config.model_type: + raise ValueError( + "Eagle3 is only supported for Llama models.
" + f"Got {self.target_model_config.hf_text_config.model_type=}") + @property def num_lookahead_slots(self) -> int: """The number of additional slots the scheduler should allocate per @@ -2505,6 +2601,9 @@ def num_lookahead_slots(self) -> int: """ return self.num_speculative_tokens + def use_eagle(self) -> bool: + return self.method in ("eagle", "eagle3") + def __repr__(self) -> str: method = self.method model = None if method == "ngram" else self.draft_model_config.model @@ -2512,18 +2611,41 @@ def __repr__(self) -> str: return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" +LoRADType = Literal["auto", "float16", "bfloat16"] + + +@config @dataclass class LoRAConfig: - max_lora_rank: int - max_loras: int + """Configuration for LoRA.""" + + max_lora_rank: int = 16 + """Max LoRA rank.""" + max_loras: int = 1 + """Max number of LoRAs in a single batch.""" fully_sharded_loras: bool = False + """By default, only half of the LoRA computation is sharded with tensor + parallelism. Enabling this will use the fully sharded layers. At high + sequence length, max rank or tensor parallel size, this is likely faster. + """ max_cpu_loras: Optional[int] = None - lora_dtype: Optional[Union[torch.dtype, str]] = None + """Maximum number of LoRAs to store in CPU memory. Must be >= than + `max_loras`.""" + lora_dtype: Union[torch.dtype, LoRADType] = "auto" + """Data type for LoRA. If auto, will default to base model dtype.""" lora_extra_vocab_size: int = 256 + """Maximum size of extra vocabulary that can be present in a LoRA adapter + (added to the base model vocabulary).""" # This is a constant. lora_vocab_padding_size: ClassVar[int] = 256 - long_lora_scaling_factors: Optional[tuple[float]] = None + long_lora_scaling_factors: Optional[tuple[float, ...]] = None + """Specify multiple scaling factors (which can be different from base model + scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters + trained with those scaling factors to be used at the same time. If not + specified, only adapters trained with the base model scaling factor are + allowed.""" bias_enabled: bool = False + """Enable bias for LoRA adapters.""" def compute_hash(self) -> str: """ @@ -2582,25 +2704,27 @@ def verify_with_model_config(self, model_config: ModelConfig): elif isinstance(self.lora_dtype, str): self.lora_dtype = getattr(torch, self.lora_dtype) - def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): - # Reminder: Please update docs/source/features/compatibility_matrix.md - # If the feature combo become valid - if scheduler_config.chunked_prefill_enabled: - logger.warning("LoRA with chunked prefill is still experimental " - "and may be unstable.") - def verify_lora_support(self): if self.long_lora_scaling_factors is not None and envs.VLLM_USE_V1: raise ValueError( "V1 LoRA does not support long LoRA, please use V0.") +@config @dataclass class PromptAdapterConfig: - max_prompt_adapters: int - max_prompt_adapter_token: int + """Configuration for PromptAdapters.""" + + max_prompt_adapters: int = 1 + """Max number of PromptAdapters in a batch.""" + max_prompt_adapter_token: int = 0 + """Max number of PromptAdapters tokens.""" max_cpu_prompt_adapters: Optional[int] = None - prompt_adapter_dtype: Optional[torch.dtype] = None + """Maximum number of PromptAdapters to store in CPU memory. Must be >= than + `max_prompt_adapters`.""" + prompt_adapter_dtype: Union[torch.dtype, str] = "auto" + """Data type for PromptAdapter. If auto, will default to base model dtype. 
+ """ def compute_hash(self) -> str: """ @@ -2632,20 +2756,26 @@ def __post_init__(self): self.max_cpu_prompt_adapters = self.max_prompt_adapters def verify_with_model_config(self, model_config: ModelConfig): - if self.prompt_adapter_dtype in (None, "auto"): + if self.prompt_adapter_dtype == "auto": self.prompt_adapter_dtype = model_config.dtype elif isinstance(self.prompt_adapter_dtype, str): self.prompt_adapter_dtype = getattr(torch, self.prompt_adapter_dtype) +@config @dataclass class MultiModalConfig: """Controls the behavior of multimodal models.""" - limit_per_prompt: Mapping[str, int] = field(default_factory=dict) + limit_per_prompt: dict[str, int] = field(default_factory=dict) """ The maximum number of input items allowed per prompt for each modality. + This should be a JSON string that will be parsed into a dictionary. + Defaults to 1 (V0) or 999 (V1) for each modality. + + For example, to allow up to 16 images and 2 videos per prompt: + ``{"images": 16, "videos": 2}`` """ def compute_hash(self) -> str: @@ -2667,24 +2797,20 @@ def compute_hash(self) -> str: usedforsecurity=False).hexdigest() return hash_str - def get_default_limit_per_prompt(self) -> int: - """ - Return the default number of input items allowed per prompt - for any modality if not specified by the user. - """ - return 999 if envs.VLLM_USE_V1 else 1 - def get_limit_per_prompt(self, modality: str) -> int: """ Get the maximum number of input items allowed per prompt for the given modality. """ - default = self.get_default_limit_per_prompt() - return self.limit_per_prompt.get(modality, default) + return self.limit_per_prompt.get( + modality, + 999 if envs.VLLM_USE_V1 else 1, + ) # TODO: Add configs to init vision tower or not. +@config @dataclass class PoolerConfig: """Controls the behavior of output pooling in pooling models.""" @@ -2762,12 +2888,10 @@ def _get_and_verify_dtype( ) -> torch.dtype: # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # because config.torch_dtype can be None. - config_dtype = getattr(config, "torch_dtype", None) + config_dtype = getattr(config.get_text_config(), "torch_dtype", None) - # Fallbacks for multi-modal models if the root config + # Fallback for multi-modal models if the root config # does not define torch_dtype - if config_dtype is None and hasattr(config, "text_config"): - config_dtype = getattr(config.text_config, "torch_dtype", None) if config_dtype is None and hasattr(config, "vision_config"): config_dtype = getattr(config.vision_config, "torch_dtype", None) @@ -2783,6 +2907,13 @@ def _get_and_verify_dtype( else: torch_dtype = config_dtype + if config.model_type == "plamo2": + logger.info( + "For PLaMo2, we cast models to bfloat16 instead of using " + "float16 by default. This is because float16 does not work." + ) + torch_dtype = torch.bfloat16 + from vllm.platforms import current_platform if (current_platform.is_cpu() and current_platform.get_cpu_architecture() @@ -2812,6 +2943,11 @@ def _get_and_verify_dtype( "using float16 by default. Please specify `dtype` if you " "want to use float16.") torch_dtype = torch.bfloat16 + elif dtype == "float16" and config.model_type == "plamo2": + logger.warning( + "For PLaMo2, using float16 is unstable and might cause " + "unexpected behavior. 
Please use bfloat16 or float32 instead.") + torch_dtype = torch.float16 else: if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: raise ValueError(f"Unknown dtype: {dtype}") @@ -2997,15 +3133,28 @@ def get_served_model_name(model: str, return served_model_name +GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer", + "xgrammar", "guidance"] +GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"] + + +@config @dataclass class DecodingConfig: - """Dataclass which contains the decoding strategy of the engine""" + """Dataclass which contains the decoding strategy of the engine.""" - # Which guided decoding algo to use. - # 'outlines' / 'lm-format-enforcer' / 'xgrammar' - guided_decoding_backend: str = "auto" if envs.VLLM_USE_V1 else "xgrammar" + guided_decoding_backend: Union[ + GuidedDecodingBackendV0, + GuidedDecodingBackendV1] = "auto" if envs.VLLM_USE_V1 else "xgrammar" + """Which engine will be used for guided decoding (JSON schema / regex etc) + by default. With "auto", we will make opinionated choices based on request + contents and what the backend libraries currently support, so the behavior + is subject to change in each release.""" reasoning_backend: Optional[str] = None + """Select the reasoning parser depending on the model that you're using. + This is used to parse the reasoning content into OpenAI API format. + Required for `--enable-reasoning`.""" def compute_hash(self) -> str: """ @@ -3027,17 +3176,12 @@ def compute_hash(self) -> str: return hash_str def __post_init__(self): - v0_valid_guided_backends = [ - 'outlines', 'lm-format-enforcer', 'xgrammar', 'auto' - ] - v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto'] - backend = GuidedDecodingParams( backend=self.guided_decoding_backend).backend_name if envs.VLLM_USE_V1: - valid_guided_backends = v1_valid_guided_backends + valid_guided_backends = get_args(GuidedDecodingBackendV1) else: - valid_guided_backends = v0_valid_guided_backends + valid_guided_backends = get_args(GuidedDecodingBackendV0) if backend not in valid_guided_backends: raise ValueError(f"Invalid guided_decoding_backend '{backend}'," f" must be one of {valid_guided_backends}") @@ -3297,11 +3441,13 @@ class PassConfig(BaseModel): - enable_fusion: whether to enable the custom fusion pass. - enable_noop: whether to enable the custom no-op elimination pass. TODO(luka) better pass enabling system. + - enable_sequence_parallelism: whether to enable sequence parallelism. """ dump_graph_stages: list[str] = Field(default_factory=list) dump_graph_dir: Path = Field(default=Path(".")) enable_fusion: bool = True enable_noop: bool = True + enable_sequence_parallelism: bool = False def uuid(self): """ @@ -3310,7 +3456,8 @@ def uuid(self): Do not include dump_graph_* in the hash - they don't affect compilation. """ - dict_ = self.model_dump(include={"enable_fusion", "enable_noop"}) + dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \ + "enable_sequence_parallelism"}) return InductorPass.hash_dict(dict_) def model_post_init(self, __context: Any) -> None: @@ -3337,7 +3484,8 @@ def model_post_init(self, __context: Any) -> None: compilation_time: float = PrivateAttr # Per-model forward context - # Map from layer name to the attention cls + # Map from layer name to layer objects that need to be accessed outside + # model code, e.g., Attention, FusedMOE when dp_size>1. 
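The switch from hard-coded backend lists to Literal types relies on typing.get_args() to recover the allow-list at runtime; a minimal standalone sketch of that pattern:

# Sketch only: validating a backend name against a Literal allow-list.
from typing import Literal, get_args

GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]

def validate_backend(backend: str) -> str:
    valid = get_args(GuidedDecodingBackendV1)
    if backend not in valid:
        raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                         f"must be one of {valid}")
    return backend

assert validate_backend("xgrammar") == "xgrammar"
assert get_args(GuidedDecodingBackendV1) == ("auto", "xgrammar", "guidance")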
static_forward_context: dict[str, Any] = PrivateAttr def compute_hash(self) -> str: @@ -3668,6 +3816,17 @@ def _get_quantization_config( return quant_config return None + @staticmethod + def get_quantization_config( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional[QuantizationConfig]: + import copy + + # For some reason, the _ version of this modifies the model_config + # object, so using deepcopy to avoid this problem. + return VllmConfig._get_quantization_config(copy.deepcopy(model_config), + load_config) + def with_hf_config( self, hf_config: PretrainedConfig, @@ -3697,8 +3856,6 @@ def __post_init__(self): if self.lora_config: self.lora_config.verify_with_cache_config(self.cache_config) self.lora_config.verify_with_model_config(self.model_config) - self.lora_config.verify_with_scheduler_config( - self.scheduler_config) self.lora_config.verify_lora_support() if self.prompt_adapter_config: self.prompt_adapter_config.verify_with_model_config( @@ -3722,6 +3879,8 @@ def __post_init__(self): if self.compilation_config is None: self.compilation_config = CompilationConfig() + if self.compilation_config.pass_config.enable_sequence_parallelism: + self.compilation_config.custom_ops.append("+rms_norm") if envs.VLLM_USE_V1 and self.model_config is not None and \ not self.model_config.enforce_eager: # NOTE(woosuk): Currently, we use inductor because the piecewise @@ -3729,7 +3888,8 @@ def __post_init__(self): # FIXME(woosuk): Disable inductor to reduce the compilation time # and avoid any potential issues with the inductor. # FIXME(rob): Add function to set all of these. - self.compilation_config.custom_ops = ["none"] + if not self.compilation_config.custom_ops: + self.compilation_config.custom_ops = ["none"] self.compilation_config.use_cudagraph = True self.compilation_config.use_inductor = True self.compilation_config.cudagraph_num_of_warmups = 1 @@ -3738,6 +3898,18 @@ def __post_init__(self): self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1() + if self.parallel_config is not None and \ + self.parallel_config.tensor_parallel_size > 1 and \ + self.parallel_config.pipeline_parallel_size > 1 and \ + self.compilation_config is not None and \ + self.compilation_config.pass_config is not None and \ + self.compilation_config.pass_config.enable_sequence_parallelism: + logger.warning_once( + "Sequence parallelism is not supported with pipeline " + "parallelism. 
Disabling sequence parallelism.") + self.compilation_config.pass_config.\ + enable_sequence_parallelism = False + self._set_cudagraph_sizes() if self.cache_config is not None and \ @@ -3777,6 +3949,26 @@ def __post_init__(self): if not self.instance_id: self.instance_id = random_uuid()[:5] + def update_sizes_for_sequence_parallelism(self, + possible_sizes: list) -> list: + # remove the sizes that not multiple of tp_size when + # enable sequence parallelism + removed_sizes = [ + size for size in possible_sizes + if size % self.parallel_config.tensor_parallel_size != 0 + ] + if removed_sizes: + logger.warning( + "Batch sizes %s are removed because they are not " + "multiple of tp_size %d when " + "sequence parallelism is enabled", removed_sizes, + self.parallel_config.tensor_parallel_size) + + return [ + size for size in possible_sizes + if size % self.parallel_config.tensor_parallel_size == 0 + ] + def _set_cudagraph_sizes(self): """ cudagraph batchsize padding logic: @@ -3814,6 +4006,11 @@ def _set_cudagraph_sizes(self): not self.model_config.enforce_eager: possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)] + if self.parallel_config.tensor_parallel_size > 1 and \ + self.compilation_config.pass_config.enable_sequence_parallelism: + possible_sizes = self.update_sizes_for_sequence_parallelism( + possible_sizes) + # find the minimum size that is larger than max_num_seqs, # which then becomes the max_batchsize_to_capture larger_sizes = [ @@ -3837,6 +4034,11 @@ def _set_cudagraph_sizes(self): not self.model_config.enforce_eager: batch_size_capture_list = [1, 2, 4 ] + [i for i in range(8, 513, 8)] + if self.parallel_config.tensor_parallel_size > 1 and \ + self.compilation_config.pass_config.enable_sequence_parallelism: + batch_size_capture_list = \ + self.update_sizes_for_sequence_parallelism(batch_size_capture_list) + max_num_tokens = self.scheduler_config.max_num_batched_tokens batch_size_capture_list = [ size for size in batch_size_capture_list @@ -3935,3 +4137,43 @@ def get_current_vllm_config() -> VllmConfig: from vllm.config import VllmConfig return VllmConfig() return _current_vllm_config + + +def contains_object_print(text): + """ + Check if the text looks like a printed Python object, e.g. + contains any substring matching the pattern: "at 0xFFFFFFF>" + We match against 0x followed by 2-16 hex chars (there's + a max of 16 on a 64 bit system). + + Args: + text (str): The text to check + + Returns: + bool: True if a match is found, False otherwise + """ + pattern = r'at 0x[a-fA-F0-9]{2,16}>' + match = re.search(pattern, text) + return match is not None + + +def assert_hashable(text): + if not contains_object_print(text): + return True + raise AssertionError( + f"vLLM tried to hash some configs that may have Python objects ids " + f"in them. This is a bug, please file an issue. 
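The effect of update_sizes_for_sequence_parallelism above is easiest to see on a concrete list: capture sizes that are not divisible by the tensor parallel size are dropped (and warned about). A minimal sketch:

# Sketch only: filtering cudagraph capture sizes when sequence parallelism is on.
def update_sizes_for_sequence_parallelism(possible_sizes, tp_size):
    return [size for size in possible_sizes if size % tp_size == 0]

possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 9)]
print(update_sizes_for_sequence_parallelism(possible_sizes, tp_size=4))
# -> [4, 8, 16, 24, 32, 40, 48, 56, 64]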
" + f"Text being hashed: {text}") + + +T = TypeVar("T") + + +def get_layers_from_vllm_config(vllm_config: VllmConfig, + layer_type: type[T]) -> dict[str, T]: + return { + layer_name: layer + for layer_name, layer in + vllm_config.compilation_config.static_forward_context.items() + if isinstance(layer, layer_type) + } diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index cf85a2135c81..97d03d5e3b40 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1596,7 +1596,6 @@ def schedule( multi_modal_placeholders=( seq_group.multi_modal_placeholders if scheduler_outputs.num_prefill_groups > 0 else None), - mm_processor_kwargs=seq_group.mm_processor_kwargs, prompt_adapter_request=seq_group.prompt_adapter_request, ) else: diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 0228264f91f9..894a0fafb640 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -19,6 +19,12 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor, return get_tp_group().all_gather(input_, dim) +def tensor_model_parallel_reduce_scatter(input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + """Reduce-Scatter the input tensor across model parallel group.""" + return get_tp_group().reduce_scatter(input_, dim) + + def tensor_model_parallel_gather(input_: torch.Tensor, dst: int = 0, dim: int = -1) -> Optional[torch.Tensor]: diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index eb12f8834b41..240313b98c88 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -61,6 +61,40 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: input_size[dim + 1:]) return output_tensor + def reduce_scatter(self, + input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? 
+ input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size, ) + input_tensor.shape[1:] + + output_tensor = torch.empty(output_shape, + dtype=input_tensor.dtype, + device=input_tensor.device) + + # Perform reduce-scatter operation + torch.distributed.reduce_scatter_tensor(output_tensor, + input_tensor, + group=self.device_group) + + # Reshape before returning + return output_tensor.movedim(0, dim).contiguous() + def gather(self, input_: torch.Tensor, dst: int = 0, diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 07c9ff506092..8bca278f3888 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -70,6 +70,31 @@ def all_reduce(self, input_): torch.distributed.all_reduce(out, group=self.device_group) return out + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): + world_size = self.world_size + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? + input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size, ) + input_tensor.shape[1:] + + output = torch.empty(output_shape, + dtype=input_tensor.dtype, + device=input_tensor.device) + + pynccl_comm.reduce_scatter(output, input_) + + # Reshape before returning + return output.movedim(0, dim).contiguous() + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: """Sends a tensor to the destination rank in a non-blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 11ed7c084377..723719c79e9c 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -7,11 +7,13 @@ from contextlib import contextmanager from dataclasses import dataclass, field from multiprocessing import shared_memory -from typing import List, Optional, Tuple, Union +from threading import Event +from typing import Any, List, Optional, Tuple, Union from unittest.mock import patch import torch import torch.distributed as dist +import zmq from torch.distributed import ProcessGroup from zmq import IPV6 # type: ignore from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore @@ -239,7 +241,7 @@ def __init__( self.remote_socket.setsockopt(IPV6, 1) remote_addr_ipv6 = True connect_ip = f"[{connect_ip}]" - socket_addr = f"tcp://*:{remote_subscribe_port}" + socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" self.remote_socket.bind(socket_addr) remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" else: @@ -400,7 +402,9 @@ def acquire_write(self, timeout: Optional[float] = None): break @contextmanager - def acquire_read(self, timeout: Optional[float] = None): + def acquire_read(self, + timeout: Optional[float] = None, + cancel: Optional[Event] = None): assert self._is_local_reader, "Only readers can acquire read" start_time = time.monotonic() n_warning = 1 @@ -430,6 
+434,9 @@ def acquire_read(self, timeout: Optional[float] = None): ) n_warning += 1 + if cancel is not None and cancel.is_set(): + raise RuntimeError("cancelled") + # if we time out, raise an exception if (timeout is not None and time.monotonic() - start_time > timeout): @@ -464,10 +471,12 @@ def enqueue(self, obj, timeout: Optional[float] = None): if self.n_remote_reader > 0: self.remote_socket.send(serialized_obj) - def dequeue(self, timeout: Optional[float] = None): + def dequeue(self, + timeout: Optional[float] = None, + cancel: Optional[Event] = None): """ Read from message queue with optional timeout (in seconds) """ if self._is_local_reader: - with self.acquire_read(timeout) as buf: + with self.acquire_read(timeout, cancel) as buf: overflow = buf[0] == 1 if not overflow: # no need to know the size of serialized object @@ -475,15 +484,21 @@ def dequeue(self, timeout: Optional[float] = None): # see https://docs.python.org/3/library/pickle.html obj = pickle.loads(buf[1:]) if overflow: - recv = self.local_socket.recv() - obj = pickle.loads(recv) + obj = MessageQueue.recv(self.local_socket, timeout) elif self._is_remote_reader: - recv = self.remote_socket.recv() - obj = pickle.loads(recv) + obj = MessageQueue.recv(self.remote_socket, timeout) else: raise RuntimeError("Only readers can dequeue") return obj + @staticmethod + def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any: + timeout_ms = None if timeout is None else int(timeout * 1000) + if not socket.poll(timeout=timeout_ms): + raise TimeoutError + recv = socket.recv(copy=False) + return pickle.loads(recv.buffer) + def broadcast_object(self, obj=None): if self._is_writer: self.enqueue(obj) diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py index e69de29bb2d1..ec07c6fe0d12 100644 --- a/vllm/distributed/kv_transfer/__init__.py +++ b/vllm/distributed/kv_transfer/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.distributed.kv_transfer.kv_transfer_state import ( + ensure_kv_transfer_initialized, get_kv_transfer_group, + has_kv_transfer_group, is_v1_kv_transfer_group) + +__all__ = [ + "get_kv_transfer_group", "has_kv_transfer_group", + "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized", + "KVConnectorBaseType" +] diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py index 57c764b481c2..0d1a3d40af41 100644 --- a/vllm/distributed/kv_transfer/kv_connector/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -12,6 +12,7 @@ import torch +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.sequence import IntermediateTensors if TYPE_CHECKING: @@ -121,3 +122,6 @@ def recv_kv_caches_and_hidden_states( """ raise NotImplementedError + + +KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1] diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index e37ce6dc75b0..6532c101a4f6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -3,14 +3,22 @@ import importlib from typing import TYPE_CHECKING, Callable, Dict, Type +import vllm.envs as envs +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole) +from vllm.logger import init_logger + from .base import KVConnectorBase if 
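The cancel argument added to acquire_read()/dequeue() follows a standard cooperative-cancellation pattern: the polling loop periodically checks a threading.Event and bails out when it is set. A minimal standalone sketch of that pattern:

# Sketch only: a cancellable polling wait, mirroring the cancel-Event check above.
import time
from threading import Event
from typing import Callable, Optional

def wait_for_data(is_ready: Callable[[], bool],
                  timeout: Optional[float] = None,
                  cancel: Optional[Event] = None) -> None:
    start = time.monotonic()
    while not is_ready():
        if cancel is not None and cancel.is_set():
            raise RuntimeError("cancelled")
        if timeout is not None and time.monotonic() - start > timeout:
            raise TimeoutError
        time.sleep(0.001)

cancel_event = Event()
cancel_event.set()
try:
    wait_for_data(lambda: False, timeout=1.0, cancel=cancel_event)
except RuntimeError as exc:
    print(exc)  # -> cancelled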
TYPE_CHECKING: from vllm.config import VllmConfig +logger = init_logger(__name__) + class KVConnectorFactory: - _registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {} + _registry: Dict[str, Callable[[], Type[KVConnectorBaseType]]] = {} @classmethod def register_connector(cls, name: str, module_path: str, @@ -19,22 +27,51 @@ def register_connector(cls, name: str, module_path: str, if name in cls._registry: raise ValueError(f"Connector '{name}' is already registered.") - def loader() -> Type[KVConnectorBase]: + def loader() -> Type[KVConnectorBaseType]: module = importlib.import_module(module_path) return getattr(module, class_name) cls._registry[name] = loader @classmethod - def create_connector(cls, rank: int, local_rank: int, - config: "VllmConfig") -> KVConnectorBase: + def create_connector_v0(cls, rank: int, local_rank: int, + config: "VllmConfig") -> KVConnectorBase: + if envs.VLLM_USE_V1: + raise ValueError("Attempting to initialize a V0 Connector, " + f"but found {envs.VLLM_USE_V1=}") + connector_name = config.kv_transfer_config.kv_connector if connector_name not in cls._registry: raise ValueError(f"Unsupported connector type: {connector_name}") connector_cls = cls._registry[connector_name]() + assert issubclass(connector_cls, KVConnectorBase) return connector_cls(rank, local_rank, config) + @classmethod + def create_connector_v1( + cls, + config: "VllmConfig", + role: KVConnectorRole, + ) -> KVConnectorBase_V1: + if not envs.VLLM_USE_V1: + raise ValueError("Attempting to initialize a V1 Connector, " + f"but found {envs.VLLM_USE_V1=}") + + connector_name = config.kv_transfer_config.kv_connector + connector_cls = cls._registry[connector_name]() + assert issubclass(connector_cls, KVConnectorBase_V1) + logger.info("Creating v1 connector with name: %s", connector_name) + # NOTE(Kuntai): v1 connector is explicitly separated into two roles. + # Scheduler connector: + # - Co-locate with scheduler process + # - Should only be used inside the Scheduler class + # Worker connector: + # - Co-locate with worker process + # - Should only be used inside the forward context & attention layer + # We build separately to enforce strict separation + return connector_cls(config, role) + # Register various connectors here. 
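Out-of-tree connectors can reuse the same factory hooks as the built-in registrations below; the module path and class name in this sketch are hypothetical placeholders:

# Sketch only: registering a custom V1 connector (names are made up).
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory

KVConnectorFactory.register_connector(
    "MyStorageConnector",               # name looked up via kv_transfer_config.kv_connector
    "my_package.my_storage_connector",  # module imported lazily by the loader
    "MyStorageConnector")               # class implementing KVConnectorBase_V1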
# The registration should not be done in each individual file, as we want to @@ -57,4 +94,14 @@ def create_connector(cls, rank: int, local_rank: int, KVConnectorFactory.register_connector( "MooncakeStoreConnector", "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector", - "MooncakeStoreConnector") \ No newline at end of file + "MooncakeStoreConnector") + +KVConnectorFactory.register_connector( + "SharedStorageConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector", + "SharedStorageConnector") + +KVConnectorFactory.register_connector( + "LMCacheConnectorV1", + "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", + "LMCacheConnectorV1") diff --git a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py index c5135dab23eb..7b26aec23239 100644 --- a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """ MooncakeStore Connector for Distributed Machine Learning Inference - The MooncakeStoreConnector transfers KV caches between prefill vLLM workers (KV cache producer) and decode vLLM workers (KV cache consumer) using a database-style KVStore. @@ -11,9 +10,10 @@ import torch -from vllm import _custom_ops as ops from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.utils import ( + model_aware_kv_ops_helper as kv_helper) from vllm.logger import init_logger from vllm.sequence import IntermediateTensors @@ -32,8 +32,7 @@ def __init__( config: VllmConfig, ): self.config = config.kv_transfer_config - self.tp_size = config.parallel_config.tensor_parallel_size - + self.kv_helper = kv_helper(config) self.local_tp_rank = local_rank # Init kv_store @@ -80,12 +79,7 @@ def send_kv_caches_and_hidden_states( slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() start_layer = model_executable.model.start_layer end_layer = model_executable.model.end_layer - - model_config = model_executable.model.config - num_heads = int(model_config.num_key_value_heads / self.tp_size) - hidden_size = model_config.hidden_size - num_attention_heads = model_config.num_attention_heads - head_size = int(hidden_size / num_attention_heads) + num_heads, head_size = self.kv_helper.get_model_args(model_executable) for idx, slen in enumerate(seq_lens): start_pos = sum(seq_lens[:idx]) @@ -97,10 +91,8 @@ def send_kv_caches_and_hidden_states( for layer_id in range(start_layer, end_layer): kv_cache = kv_caches[layer_id - start_layer] - - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) - + key_cache, value_cache = self.kv_helper.get_kv_from_cache( + kv_cache, num_heads, head_size) current_slot_mapping = slot_mapping_flat[start_pos:end_pos] keys.append(key_cache[current_slot_mapping].unsqueeze(0)) @@ -173,22 +165,15 @@ def recv_kv_caches_and_hidden_states( layer = model_executable.model.layers[layer_id] # get kvcache object kv_cache = kv_caches[layer_id - start_layer] - key_cache, value_cache = kv_cache[0], kv_cache[1] - # get remote kvcache + # get remote kvcache remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][ layer_id] - # use ops.reshape_and_cache_flash to put kv into kvcache - ops.reshape_and_cache_flash( - remote_k.to(key_cache.device), - 
remote_v.to(value_cache.device), - key_cache, - value_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - layer.self_attn.attn._v_scale, - ) + + self.kv_helper.put_kv_to_cache(model_executable, remote_k, + remote_v, layer, kv_cache, + slot_mapping, start_pos, + end_pos) hidden_or_intermediate_states_for_one_req.append(hidden) diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py index 49b97d7b5889..0464a7585138 100644 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -12,10 +12,10 @@ import torch -import vllm.envs as envs -from vllm import _custom_ops as ops from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.utils import ( + model_aware_kv_ops_helper as kv_helper) from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( SimpleBuffer) from vllm.logger import init_logger @@ -37,9 +37,7 @@ def __init__( ): self.config = config.kv_transfer_config - self.tp_size = config.parallel_config.tensor_parallel_size - self.is_deepseek_mla = config.model_config.is_deepseek_mla - self.use_mla_opt = not envs.VLLM_MLA_DISABLE + self.kv_helper = kv_helper(config) if self.config.kv_connector == "PyNcclConnector": from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( @@ -165,31 +163,7 @@ def send_kv_caches_and_hidden_states( num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens start_layer = model_executable.model.start_layer end_layer = model_executable.model.end_layer - - model_config = model_executable.model.config - num_heads = int(model_config.num_key_value_heads / self.tp_size) - hidden_size = model_config.hidden_size - num_attention_heads = model_config.num_attention_heads - - # Deepseek's MLA (Multi-head Latent Attention) uses two different - # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0. - # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied, - # resulting in a kv_cache shape of [num_blks, blk_size, 1, - # kv_lora_rank + qk_rope_head_dim]. - # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading - # to a kv_cache shape of [2, num_blks, blk_size, - # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. - # For more details, see vllm/attention/backends/mla/common.py. - if self.is_deepseek_mla and self.use_mla_opt: - head_size = model_config.kv_lora_rank + \ - model_config.qk_rope_head_dim - num_heads = 1 - elif self.is_deepseek_mla and not self.use_mla_opt: - head_size = model_config.qk_nope_head_dim + \ - model_config.qk_rope_head_dim - else: - head_size = getattr(model_config, "head_dim", - int(hidden_size // num_attention_heads)) + num_heads, head_size = self.kv_helper.get_model_args(model_executable) # query_lens contains new KV caches that are added to vLLM. 
# so we will send them to decode instance @@ -212,13 +186,8 @@ def send_kv_caches_and_hidden_states( for layer_id in range(start_layer, end_layer): kv_cache = kv_caches[layer_id - start_layer] - - if self.is_deepseek_mla and self.use_mla_opt: - key_cache = kv_cache.reshape(-1, num_heads, head_size) - value_cache = kv_cache.reshape(-1, num_heads, head_size) - else: - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + key_cache, value_cache = self.kv_helper.get_kv_from_cache( + kv_cache, num_heads, head_size) current_slot_mapping = slot_mapping_flat[start_pos:end_pos] @@ -248,12 +217,12 @@ def recv_kv_caches_and_hidden_states( # and hidden states. bypass_model_exec = True - model_config = model_executable.model.config - input_tokens_tensor = model_input.input_tokens seq_lens = model_input.attn_metadata.seq_lens num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer hidden_or_intermediate_states_for_one_req = [] @@ -312,41 +281,19 @@ def recv_kv_caches_and_hidden_states( end_pos = start_pos + num_computed_tokens # put received KV caches into paged memory - for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): - - kv_cache = kv_caches[i - model_executable.model.start_layer] - layer = model_executable.model.layers[i] - - if self.is_deepseek_mla and self.use_mla_opt: - layer.self_attn.attn = layer.self_attn.mla_attn - k_c_normed_k_pe = keys[ - i - model_executable.model.start_layer].to( - kv_cache.device).squeeze(1) - k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank] - k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:] - ops.concat_and_cache_mla( - k_c_normed, - k_pe, - kv_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - ) - else: - key_cache, value_cache = kv_cache[0], kv_cache[1] - ops.reshape_and_cache_flash( - keys[i - model_executable.model.start_layer].to( - key_cache.device), - values[i - model_executable.model.start_layer].to( - value_cache.device), - key_cache, - value_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - layer.self_attn.attn._v_scale, - ) + for cur_layer in range(start_layer, end_layer): + + layer_id = cur_layer - start_layer + kv_cache = kv_caches[layer_id] + layer = model_executable.model.layers[cur_layer] + + # get remote kvcache + remote_k, remote_v = keys[layer_id], values[layer_id] + + self.kv_helper.put_kv_to_cache(model_executable, remote_k, + remote_v, layer, kv_cache, + slot_mapping, start_pos, + end_pos) hidden_or_intermediate_states_for_one_req.append(hidden) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py new file mode 100644 index 000000000000..0b0ce9828a74 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +KV cache helper for store. 
+""" +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.config import VllmConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class model_aware_kv_ops_helper: + + def __init__(self, config: VllmConfig): + self.is_deepseek_mla = config.model_config.is_deepseek_mla + self.use_mla_opt = not envs.VLLM_MLA_DISABLE + self.tp_size = config.parallel_config.tensor_parallel_size + + def get_model_args(self, model_executable: torch.nn.Module): + + model_config = model_executable.model.config + self.model_executable = model_executable + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + + # Deepseek's MLA (Multi-head Latent Attention) uses two different + # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0. + # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied, + # resulting in a kv_cache shape of [num_blks, blk_size, 1, + # kv_lora_rank + qk_rope_head_dim]. + # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading + # to a kv_cache shape of [2, num_blks, blk_size, + # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. + # For more details, see vllm/attention/backends/mla/common.py. + if self.is_deepseek_mla and self.use_mla_opt: + head_size = model_config.kv_lora_rank + \ + model_config.qk_rope_head_dim + num_heads = 1 + elif self.is_deepseek_mla and not self.use_mla_opt: + head_size = model_config.qk_nope_head_dim + \ + model_config.qk_rope_head_dim + else: + head_size = getattr(model_config, "head_dim", + int(hidden_size // num_attention_heads)) + + return num_heads, head_size + + def get_kv_from_cache(self, kv_cache, num_heads, head_size): + if self.is_deepseek_mla and self.use_mla_opt: + key_cache = kv_cache.reshape(-1, num_heads, head_size) + value_cache = kv_cache.reshape(-1, num_heads, head_size) + else: + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + return key_cache, value_cache + + def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values, + layer, kv_cache, slot_mapping, start_pos, end_pos): + + model_config = model_executable.model.config + + if self.is_deepseek_mla and self.use_mla_opt: + layer.self_attn.attn = layer.self_attn.mla_attn + k_c_normed_k_pe = keys.squeeze(1) + k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank] + k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:] + ops.concat_and_cache_mla( + k_c_normed.to(kv_cache.device), + k_pe.to(kv_cache.device), + kv_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + ) + else: + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys.to(key_cache.device), + values.to(value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py new file mode 100644 index 000000000000..a017b140e090 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorRole) + +__all__ = [ + "KVConnectorRole", + 
"KVConnectorBase_V1", +] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py new file mode 100644 index 000000000000..95967d2ca919 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -0,0 +1,209 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State +communication in vLLM v1 + +The class provides the following primitives: + Scheduler-side: runs in the scheduler, binds metadata, which + is used by the worker-side to load/save KV cache. + get_num_new_matched_tokens() - get number of new tokens + that exist in the remote KV cache + update_state_after_alloc() - update KVConnector state after + temporary buffer alloc by the CacheManager. + + Worker-side: runs in each worker, loads/saves KV cache to/from + the Connector based on the metadata. + start_load_kv() - starts loading all KVs (maybe async) + wait_for_layer_load() - blocks until layer i load is done + + save_kv_layer() - starts saving KV for layer i (maybe async) + wait_for_save() - blocks until all saves are done +""" + +import enum +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch + +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.config import VllmConfig + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class KVConnectorRole(enum.Enum): + # Connector running in the scheduler process + SCHEDULER = 0 + + # Connector running in the worker process + WORKER = 1 + + +@dataclass +class KVConnectorMetadata: + pass + + +class KVConnectorBase_V1(ABC): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + logger.warning( + "Initializing KVConnectorBase_V1. This API is experimental and " + "subject to change in the future as we iterate the design.") + self._connector_metadata = KVConnectorMetadata() + self._vllm_config = vllm_config + self._role = role + + @property + def role(self) -> KVConnectorRole: + return self._role + + def bind_connector_metadata( + self, connector_metadata: KVConnectorMetadata) -> None: + """Set the connector metadata from the scheduler. + + This function should be called by the model runner every time + before the model execution. The metadata will be used for runtime + KV cache loading and saving. + + Args: + connector_metadata (dict): the connector metadata. + """ + self._connector_metadata = connector_metadata + + def clear_connector_metadata(self) -> None: + """Clear the connector metadata. + + This function should be called by the model runner every time + after the model execution. + """ + self._connector_metadata = KVConnectorMetadata() + + def _get_connector_metadata(self) -> KVConnectorMetadata: + """Get the connector metadata. + + This function should only be called inside the connector. + + Returns: + ConnectorMetadata: the connector metadata. + """ + return self._connector_metadata + + # ============================== + # Worker-side methods + # ============================== + + @abstractmethod + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. 
This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + pass + + @abstractmethod + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + pass + + @abstractmethod + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ + Start saving a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + pass + + @abstractmethod + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + pass + + # ============================== + # Scheduler-side methods + # ============================== + @abstractmethod + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> int: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + pass + + @abstractmethod + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + """ + pass + + @abstractmethod + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. 
+ """ + pass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py new file mode 100644 index 000000000000..e07f185f0dd8 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import TYPE_CHECKING + +import torch +from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class LMCacheConnectorV1(KVConnectorBase_V1): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self) + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + self._lmcache_engine.start_load_kv(forward_context, **kwargs) + + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + self._lmcache_engine.wait_for_layer_load(layer_name) + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ + Start saving the a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, + **kwargs) + + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. 
+ """ + self._lmcache_engine.wait_for_save() + + # ============================== + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> int: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + return self._lmcache_engine.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + """ + self._lmcache_engine.update_state_after_alloc(request, + num_external_tokens) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + return self._lmcache_engine.build_connector_meta(scheduler_output) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py new file mode 100644 index 000000000000..f91ffbc720e7 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -0,0 +1,383 @@ +# SPDX-License-Identifier: Apache-2.0 +import hashlib +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import safetensors +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import MLACommonMetadata +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class ReqMeta: + # Request tokens + token_ids: torch.Tensor + # Slot mappings, should have the same length as token_ids + slot_mapping: torch.Tensor + # Is store or load + is_store: bool + + @staticmethod + def make_meta(token_ids: list[int], block_ids: list[int], block_size: int, + is_store: bool) -> "ReqMeta": + valid_num_tokens = align_to_block_size(len(token_ids), block_size) + token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens] + block_ids_tensor = torch.tensor(block_ids) + num_blocks = block_ids_tensor.shape[0] + block_offsets = torch.arange(0, block_size) + slot_mapping = block_offsets.reshape((1, block_size)) + \ + block_ids_tensor.reshape((num_blocks, 1)) * block_size + slot_mapping = slot_mapping.flatten()[:valid_num_tokens] + return ReqMeta( + token_ids=token_ids_tensor, + slot_mapping=slot_mapping, + is_store=is_store, + ) + + +@dataclass +class SharedStorageConnectorMetadata(KVConnectorMetadata): + requests: list[ReqMeta] + + def __init__(self): + self.requests = [] + + def add_request( + self, + token_ids: list[int], + block_ids: list[int], + block_size: int, + is_store: bool, + ) -> 
None: + self.requests.append( + ReqMeta.make_meta(token_ids, block_ids, block_size, is_store)) + + +class SharedStorageConnector(KVConnectorBase_V1): + # NOTE: This is Simple debug implementation of the KV connector. + # It save / load the KV cache to / from the disk. + # It does extra work which will overwrite the existing prefix-cache in GPU + # - to remove the overhead, need to add some "mask" in the ReqMeta class + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._block_size = vllm_config.cache_config.block_size + self._requests_need_load: dict[str, Request] = {} + transfer_config = vllm_config.kv_transfer_config + self._storage_path = transfer_config.get_from_extra_config( + "shared_storage_path", "/tmp") + logger.info(vllm_config.kv_transfer_config) + logger.info("Shared storage path is %s", self._storage_path) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """Start loading the KV cache from the connector buffer to vLLM's + paged KV buffer. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + """ + attn_metadata = forward_context.attn_metadata + + def inject_kv_into_layer( + dst_kv_cache_layer: torch.Tensor, + src_kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> None: + """Inject the KV cache into the layer. + + Args: + dst_kv_cache_layer (torch.Tensor): the destination KV cache + layer. In shape [2, num_pages, page_size, xxx] if not + using MLA, [num_pages, page_size, xxx] otherwise. + src_kv_cache (torch.Tensor): the source KV cache. In shape + [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] + otherwise. + slot_mapping (torch.Tensor): the slot mapping. In shape + [num_tokens]. + """ + dst_kv_cache_layer_shape = dst_kv_cache_layer.shape + if isinstance(attn_metadata, MLACommonMetadata): + num_pages = dst_kv_cache_layer_shape[0] + page_size = dst_kv_cache_layer_shape[1] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + num_pages * page_size, -1) + dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + else: + num_pages = dst_kv_cache_layer_shape[1] + page_size = dst_kv_cache_layer_shape[2] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + 2, num_pages * page_size, -1) + dst_kv_cache_layer[:, slot_mapping, ...] 
= src_kv_cache + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + + # Get the metadata + metadata: KVConnectorMetadata = \ + self._get_connector_metadata() + assert isinstance(metadata, SharedStorageConnectorMetadata) + + if metadata is None: + logger.warning( + "In connector.start_load_kv, but the connector metadata is None" + ) + return + + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + logger.warning( + "In connector.start_load_kv, but the attn_metadata is None") + return + + # Load the KV for each request each layer + for request in metadata.requests: + if request.is_store: + continue + logger.info("Inject KV cache of %d tokens to the paged memory", + len(request.slot_mapping)) + for layer_name in forward_context.no_compile_layers: + attn_layer = forward_context.no_compile_layers[layer_name] + kv_cache_layer = attn_layer.kv_cache[\ + forward_context.virtual_engine] + + filename = self._generate_filename_debug( + layer_name, request.token_ids) + kv_cache = safetensors.torch.load_file( + filename)["kv_cache"].cuda() + inject_kv_into_layer(kv_cache_layer, kv_cache, + request.slot_mapping) + + def wait_for_layer_load(self, layer_name: str) -> None: + """Blocking until the KV for a specific layer is loaded into vLLM's + paged buffer. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + return + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """Start saving the KV cache of the layer from vLLM's paged buffer + to the connector. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + + def extract_kv_from_layer( + layer: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> torch.Tensor: + """Extract the KV cache from the layer. + + Assume the shape of the layer is (2, num_pages, page_size, xxx) + if MLA is not used, and (num_pages, page_size, xxx) otherwise. + """ + if isinstance(attn_metadata, MLACommonMetadata): + num_pages, page_size = layer.shape[0], layer.shape[1] + return layer.reshape(num_pages * page_size, -1)[slot_mapping, + ...] + num_pages, page_size = layer.shape[1], layer.shape[2] + return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, + ...] + + connector_metadata = self._get_connector_metadata() + assert isinstance(connector_metadata, SharedStorageConnectorMetadata) + for request in connector_metadata.requests: + if request.is_store: + filename = self._generate_filename_debug( + layer_name, request.token_ids) + kv_cache = extract_kv_from_layer(kv_layer, + request.slot_mapping) + tensors = {"kv_cache": kv_cache.detach().cpu()} + safetensors.torch.save_file(tensors, filename) + + def wait_for_save(self): + return + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> int: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. 
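The slot mapping used by ReqMeta.make_meta and the inject/extract helpers above is plain block arithmetic: slot = block_id * block_size + offset_within_block. A small check:

# Sketch only: deriving the flat slot_mapping from block_ids and block_size.
import torch

block_size = 4
block_ids = torch.tensor([7, 2])       # pages assigned to the request
offsets = torch.arange(block_size)     # positions within a page

slot_mapping = (offsets.reshape(1, block_size)
                + block_ids.reshape(-1, 1) * block_size).flatten()
print(slot_mapping.tolist())
# -> [28, 29, 30, 31, 8, 9, 10, 11]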
+ """ + + # NOTE: in this debug implementation, we assume that the prompt is + # cached_prompt + newly_generated_single_token + # Therefore, we use prompt_token_ids[:-1] to determine the folder name + + # NOTE: in current v1 scheduler, the num_computed_tokens is aligned + # with the block granularity. And it expects the returned blocks and + # num_computed_tokens to also be aligned with the block granularity. + if not self._found_match_for_request(request): + return 0 + + logger.info("External Cache Hit!") + + # Now, first num_tokens_to_check tokens are hit, we need to prepare + # the metadata for the worker connector to correctly load the KV + num_tokens_to_check = align_to_block_size( + len(request.prompt_token_ids) - 1, self._block_size) + + return num_tokens_to_check - num_computed_tokens + + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + + If blocks were allocated, add to _requests_need_load, + such that we load the KVs in the next forward pass. + """ + if num_external_tokens > 0: + self._requests_need_load[request.request_id] = request + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + """Build the connector metadata for this step. + + This function should NOT modify any fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + meta = SharedStorageConnectorMetadata() + + total_need_load = 0 + for new_req in scheduler_output.scheduled_new_reqs: + if new_req.req_id in self._requests_need_load: + meta.add_request(token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids, + block_size=self._block_size, + is_store=False) + total_need_load += 1 + else: + # NOTE: here, we set the store and load being exclusive, + # but a single request can have both store and load. + # NOTE(rob): for this debug implementation, we only cache + # the original prompt tokens. + if not self._found_match_for_request(new_req): + meta.add_request(token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids, + block_size=self._block_size, + is_store=True) + + for cached_req in scheduler_output.scheduled_cached_reqs: + # NOTE(rob): here we rely on the resumed requests being + # the first N requests in the list scheduled_cache_reqs. + if not cached_req.resumed_from_preemption: + break + if cached_req.req_id in self._requests_need_load: + # NOTE(rob): cached_req_data does not have the full + # list of token ids (only new tokens). So we look it + # up in the actual request object. + request = self._requests_need_load[cached_req.req_id] + total_tokens = (len(cached_req.new_token_ids) + + cached_req.num_computed_tokens) + token_ids = request.all_token_ids[:total_tokens] + + # NOTE(rob): For resumed req, new_block_ids is all + # of the block_ids for the request. + block_ids = cached_req.new_block_ids + + meta.add_request(token_ids=token_ids, + block_ids=block_ids, + block_size=self._block_size, + is_store=False) + total_need_load += 1 + + assert total_need_load == len(self._requests_need_load) + self._requests_need_load.clear() + return meta + + # ============================== + # Helper functions + # ============================== + + def _found_match_for_request( + self, + request: "Request", + ) -> bool: + """Check if the cache is hit for the request. 
+ """ + num_tokens_to_check = align_to_block_size( + len(request.prompt_token_ids) - 1, self._block_size) + foldername = self._generate_foldername_debug(torch.tensor( + request.prompt_token_ids)[:num_tokens_to_check], + create_folder=False) + return os.path.exists(foldername) + + def _generate_foldername_debug( + self, + input_ids: torch.Tensor, + create_folder=False, + ) -> str: + """Generate a folder name based on the hash of the bytes of the input + ids. + """ + input_ids_bytes = input_ids.numpy().tobytes() + input_ids_hash = hashlib.md5(input_ids_bytes, + usedforsecurity=False).hexdigest() + foldername = os.path.join(self._storage_path, input_ids_hash) + if create_folder: + os.makedirs(foldername, exist_ok=True) + return foldername + + def _generate_filename_debug( + self, + layer_name: str, + input_ids: torch.Tensor, + ) -> str: + """Generate a file name based on the layer name and the hash + of the bytes of the input ids. + """ + foldername = self._generate_foldername_debug(input_ids, + create_folder=True) + return os.path.join(foldername, f"{layer_name}.safetensors") + + +def align_to_block_size(num_tokens: int, block_size) -> int: + """Align the number of tokens to the block size. + """ + return (num_tokens - 1) // block_size * block_size diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_connector_agent.py similarity index 97% rename from vllm/distributed/kv_transfer/kv_transfer_agent.py rename to vllm/distributed/kv_transfer/kv_connector_agent.py index 1e80e0bd7de8..9d7145098105 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_agent.py +++ b/vllm/distributed/kv_transfer/kv_connector_agent.py @@ -46,7 +46,7 @@ def __init__( assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ "TransferAgent should only be used when kv_connector is set." 
- self.connector = KVConnectorFactory.create_connector( + self.connector = KVConnectorFactory.create_connector_v0( rank, local_rank, config) def send_kv_caches_and_hidden_states( diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py index 7fd5967293f2..5bb711021676 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py @@ -70,7 +70,7 @@ def __init__( ): try: - from mooncake_vllm_adaptor import MooncakeDistributedStore + from mooncake.store import MooncakeDistributedStore except ImportError as e: raise ImportError( "Please install mooncake by following the instructions at " diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py index ec46d4045447..aa4b1ba71492 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py @@ -2,6 +2,7 @@ import json import os +import struct from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from typing import Optional, Union @@ -57,14 +58,14 @@ class MooncakeTransferEngine: def __init__(self, kv_rank: int, local_rank: int): try: - import mooncake_vllm_adaptor as mva + from mooncake.engine import TransferEngine except ImportError as e: raise ImportError( "Please install mooncake by following the instructions at " "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 "to run vLLM with MooncakeConnector.") from e - self.engine = mva.mooncake_vllm_adaptor() + self.engine = TransferEngine() self.local_rank = local_rank try: @@ -115,14 +116,14 @@ def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str, p_rank_offset = int(p_port) + 8 + self.local_rank * 2 d_rank_offset = int(d_port) + 8 + self.local_rank * 2 if kv_rank == 0: - self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}") + self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}") self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}") self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}") - self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}") + self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}") else: self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}") - self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}") - self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}") + self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}") + self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}") self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}") def initialize(self, local_hostname: str, metadata_server: str, @@ -140,12 +141,12 @@ def initialize(self, local_hostname: str, metadata_server: str, "Mooncake Configuration error. 
`metadata_backend`" f" should be one of {supported_backend}.") - self.engine.initializeExt(local_hostname, metadata_server, - protocol, device_name, metadata_backend) + self.engine.initialize_ext(local_hostname, metadata_server, + protocol, device_name, metadata_backend) def allocate_managed_buffer(self, length: int) -> int: """Allocate a managed buffer of the specified length.""" - ret = self.engine.allocateManagedBuffer(length) + ret = self.engine.allocate_managed_buffer(length) if ret <= 0: logger.error("Allocation Return Error") raise Exception("Allocation Return Error") @@ -153,13 +154,13 @@ def allocate_managed_buffer(self, length: int) -> int: def free_managed_buffer(self, buffer: int, length: int) -> int: """Free a previously allocated managed buffer.""" - return self.engine.freeManagedBuffer(buffer, length) + return self.engine.free_managed_buffer(buffer, length) def transfer_sync(self, buffer: int, peer_buffer_address: int, length: int) -> int: """Synchronously transfer data to the specified address.""" - ret = self.engine.transferSync(self.remote_url, buffer, - peer_buffer_address, length) + ret = self.engine.transfer_sync_read(self.remote_url, buffer, + peer_buffer_address, length) if ret < 0: logger.error("Transfer Return Error") raise Exception("Transfer Return Error") @@ -168,15 +169,15 @@ def transfer_sync(self, buffer: int, peer_buffer_address: int, def write_bytes_to_buffer(self, buffer: int, user_data: bytes, length: int) -> int: """Write bytes to the allocated buffer.""" - return self.engine.writeBytesToBuffer(buffer, user_data, length) + return self.engine.write_bytes_to_buffer(buffer, user_data, length) def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes: """Read bytes from the allocated buffer.""" - return self.engine.readBytesFromBuffer(buffer, length) + return self.engine.read_bytes_from_buffer(buffer, length) def wait_for_ack(self, src_ptr: int, length: int) -> None: """Asynchronously wait for ACK from the receiver.""" - ack = self.sender_ack.recv_pyobj() + ack = self.sender_ack.recv() if ack != b'ACK': logger.error("Failed to receive ACK from the receiver") @@ -187,18 +188,22 @@ def send_bytes(self, user_data: bytes) -> None: length = len(user_data) src_ptr = self.allocate_managed_buffer(length) self.write_bytes_to_buffer(src_ptr, user_data, length) - self.sender_socket.send_pyobj((src_ptr, length)) + self.sender_socket.send_multipart( + [struct.pack("!Q", src_ptr), + struct.pack("!Q", length)]) self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length) def recv_bytes(self) -> bytes: """Receive bytes from the remote process.""" - src_ptr, length = self.receiver_socket.recv_pyobj() + data = self.receiver_socket.recv_multipart() + src_ptr = struct.unpack("!Q", data[0])[0] + length = struct.unpack("!Q", data[1])[0] dst_ptr = self.allocate_managed_buffer(length) self.transfer_sync(dst_ptr, src_ptr, length) ret = self.read_bytes_from_buffer(dst_ptr, length) # Buffer cleanup - self.receiver_ack.send_pyobj(b'ACK') + self.receiver_ack.send(b'ACK') self.free_managed_buffer(dst_ptr, length) return ret diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py new file mode 100644 index 000000000000..25d2f2cf5c6e --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import TYPE_CHECKING, Optional + +from vllm import envs +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType +from 
vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole) +from vllm.distributed.parallel_state import get_world_group + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None + + +def get_kv_transfer_group() -> KVConnectorBaseType: + assert _KV_CONNECTOR_AGENT is not None, ( + "disaggregated KV cache transfer parallel group is not initialized") + return _KV_CONNECTOR_AGENT + + +def has_kv_transfer_group() -> bool: + return _KV_CONNECTOR_AGENT is not None + + +def is_v1_kv_transfer_group( + connector: Optional[KVConnectorBaseType] = None) -> bool: + """Check if the KV connector is the v1 connector. + If the argument is None, it will check the global KV connector + + Args: + connector: The KV connector to check. If None, it will check the + global KV connector. + + Note: + This function will no-longer be needed after the v1 KV connector + becomes the default. + """ + if connector is None: + connector = _KV_CONNECTOR_AGENT + + if connector is None: + return False + + return isinstance(connector, KVConnectorBase_V1) + + +def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: + """ + Initialize KV cache transfer parallel group. + """ + + global _KV_CONNECTOR_AGENT + + if vllm_config.kv_transfer_config is None: + return + + if (vllm_config.kv_transfer_config.is_kv_transfer_instance + and _KV_CONNECTOR_AGENT is None): + if envs.VLLM_USE_V1: + _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1( + config=vllm_config, role=KVConnectorRole.WORKER) + else: + _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0( + rank=get_world_group().rank, + local_rank=get_world_group().local_rank, + config=vllm_config, + ) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e0eeeffb88a7..cb9658ce1004 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -29,15 +29,13 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, - Union) +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch import torch import torch.distributed from torch.distributed import Backend, ProcessGroup -import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( DeviceCommunicatorBase) @@ -46,9 +44,6 @@ from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname, supports_custom_op) -if TYPE_CHECKING: - from vllm.config import VllmConfig - @dataclass class GraphCaptureContext: @@ -118,6 +113,38 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: return torch.empty_like(tensor) +def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." 
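The new `kv_transfer_state` module shown above takes over the `_KV_TRANSFER` global that this diff removes from `parallel_state.py` further down. A hedged sketch of how calling code might guard on it follows; only the four accessor functions are taken from the diff, while `maybe_do_kv_transfer_work` and both branch bodies are illustrative placeholders.

from vllm.distributed.kv_transfer.kv_transfer_state import (
    ensure_kv_transfer_initialized, get_kv_transfer_group,
    has_kv_transfer_group, is_v1_kv_transfer_group)


def maybe_do_kv_transfer_work(vllm_config) -> None:
    # Creates the per-process connector singleton on first use; this is a
    # no-op when kv_transfer_config is unset or the agent already exists.
    ensure_kv_transfer_initialized(vllm_config)

    if not has_kv_transfer_group():
        return  # disaggregated KV transfer is not enabled for this instance

    connector = get_kv_transfer_group()
    if is_v1_kv_transfer_group(connector):
        ...  # v1 path: KVConnectorBase_V1 scheduler/worker-side hooks
    else:
        ...  # v0 path: send/recv KV caches through the legacy agent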
+ group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group.reduce_scatter(tensor, dim) + + +def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + new_shape = list(tensor.shape) + new_shape[dim] = tensor.shape[dim] // world_size + return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) + + +def all_gather(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group.all_gather(tensor, dim) + + +def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + new_shape = list(tensor.shape) + new_shape[dim] = tensor.shape[dim] * world_size + return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) + + if supports_custom_op(): from vllm.platforms import current_platform direct_register_custom_op( @@ -128,6 +155,20 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="reduce_scatter", + op_func=reduce_scatter, + mutates_args=[], + fake_impl=reduce_scatter_fake, + ) + + direct_register_custom_op( + op_name="all_gather", + op_func=all_gather, + mutates_args=[], + fake_impl=all_gather_fake, + ) + class GroupCoordinator: """ @@ -327,6 +368,18 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: return self.device_communicator.all_gather(input_, dim) + def reduce_scatter(self, + input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + return self.device_communicator.reduce_scatter(input_, dim) + def gather(self, input_: torch.Tensor, dst: int = 0, @@ -772,14 +825,6 @@ def get_pp_group() -> GroupCoordinator: # kept for backward compatibility get_pipeline_model_parallel_group = get_pp_group -_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None - - -def get_kv_transfer_group() -> kv_transfer.KVTransferAgent: - assert _KV_TRANSFER is not None, ( - "disaggregated KV cache transfer parallel group is not initialized") - return _KV_TRANSFER - @contextmanager def graph_capture(device: torch.device): @@ -962,26 +1007,6 @@ def initialize_model_parallel( _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group) -def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: - """ - Initialize KV cache transfer parallel group. 
- """ - - global _KV_TRANSFER - - if vllm_config.kv_transfer_config is None: - return - - if all([ - vllm_config.kv_transfer_config.is_kv_transfer_instance, - _KV_TRANSFER is None - ]): - _KV_TRANSFER = kv_transfer.KVTransferAgent( - rank=get_world_group().rank, - local_rank=get_world_group().local_rank, - config=vllm_config) - - def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 2cb57afd4566..e4d4008cd0a6 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -7,6 +7,7 @@ import dataclasses import datetime import pickle +import socket import time from collections import deque from typing import Any, Deque, Dict, Optional, Sequence, Tuple @@ -123,6 +124,10 @@ class StatelessProcessGroup: rank: int world_size: int store: torch._C._distributed_c10d.Store + + # stores a reference to the socket so that the file descriptor stays alive + socket: Optional[socket.socket] + data_expiration_seconds: int = 3600 # 1 hour # dst rank -> counter @@ -234,18 +239,33 @@ def create( can call `StatelessProcessGroup.create` to form a group, and then process A, B, C, and D can call `StatelessProcessGroup.create` to form another group. """ # noqa + launch_server = rank == 0 + if launch_server: + # listen on the specified interface (instead of 0.0.0.0) + listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + listen_socket.bind((host, port)) + listen_socket.listen() + listen_fd = listen_socket.fileno() + else: + listen_socket = None + listen_fd = None + store = TCPStore( host_name=host, port=port, world_size=world_size, - is_master=(rank == 0), + is_master=launch_server, timeout=datetime.timedelta(seconds=store_timeout), + use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215 + master_listen_fd=listen_fd, ) return StatelessProcessGroup( rank=rank, world_size=world_size, store=store, + socket=listen_socket, data_expiration_seconds=data_expiration_seconds) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 975afe5ada83..5d735103fc03 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,25 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 +# yapf: disable import argparse import dataclasses import json import re import threading from dataclasses import MISSING, dataclass, fields -from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, - Tuple, Type, Union, cast, get_args, get_origin) +from typing import (Any, Callable, Dict, List, Literal, Optional, Type, + TypeVar, Union, cast, get_args, get_origin) import torch +from typing_extensions import TypeIs, deprecated import vllm.envs as envs from vllm import version -from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat, - DecodingConfig, DeviceConfig, HfOverrides, +from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, + ConfigFormat, ConfigType, DecodingConfig, Device, + DeviceConfig, DistributedExecutorBackend, + GuidedDecodingBackendV1, HfOverrides, KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, - ModelConfig, ModelImpl, ObservabilityConfig, - ParallelConfig, PoolerConfig, PromptAdapterConfig, - SchedulerConfig, SpeculativeConfig, TaskOption, - TokenizerPoolConfig, VllmConfig, get_attr_docs) + ModelConfig, ModelImpl, MultiModalConfig, + ObservabilityConfig, ParallelConfig, PoolerConfig, + PrefixCachingHashAlgo, 
PromptAdapterConfig, + SchedulerConfig, SchedulerPolicy, SpeculativeConfig, + TaskOption, TokenizerPoolConfig, VllmConfig, + get_attr_docs, get_field) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -28,33 +34,42 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.transformers_utils.utils import check_gguf_file from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor +from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor -if TYPE_CHECKING: - from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +# yapf: enable logger = init_logger(__name__) ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"] -DEVICE_OPTIONS = [ - "auto", - "cuda", - "neuron", - "cpu", - "tpu", - "xpu", - "hpu", -] +# object is used to allow for special typing forms +T = TypeVar("T") +TypeHint = Union[type[Any], object] +TypeHintT = Union[type[T], object] -def nullable_str(val: str): - if not val or val == "None": - return None - return val +def optional_type( + return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: + + def _optional_type(val: str) -> Optional[T]: + if val == "" or val == "None": + return None + try: + if return_type is json.loads and not re.match("^{.*}$", val): + return cast(T, nullable_kvs(val)) + return return_type(val) + except ValueError as e: + raise argparse.ArgumentTypeError( + f"Value {val} cannot be converted to {return_type}.") from e + + return _optional_type -def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: +@deprecated( + "Passing a JSON argument as a string containing comma separated key=value " + "pairs is deprecated. This will be removed in v0.10.0. Please use a JSON " + "string instead.") +def nullable_kvs(val: str) -> dict[str, int]: """Parses a string containing comma separate key [str] to value [int] pairs into a dictionary. @@ -64,10 +79,7 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: Returns: Dictionary with parsed values. 
""" - if len(val) == 0: - return None - - out_dict: Dict[str, int] = {} + out_dict: dict[str, int] = {} for item in val.split(","): kv_parts = [part.lower().strip() for part in item.split("=")] if len(kv_parts) != 2: @@ -89,6 +101,105 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: return out_dict +def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]: + """Check if the type hint is a specific type.""" + return type_hint is type or get_origin(type_hint) is type + + +def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool: + """Check if the type hints contain a specific type.""" + return any(is_type(type_hint, type) for type_hint in type_hints) + + +def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT: + """Get the specific type from the type hints.""" + return next((th for th in type_hints if is_type(th, type)), None) + + +def is_not_builtin(type_hint: TypeHint) -> bool: + """Check if the class is not a built-in type.""" + return type_hint.__module__ != "builtins" + + +def get_kwargs(cls: ConfigType) -> dict[str, Any]: + cls_docs = get_attr_docs(cls) + kwargs = {} + for field in fields(cls): + # Get the default value of the field + default = field.default + if field.default_factory is not MISSING: + default = field.default_factory() + + # Get the help text for the field + name = field.name + help = cls_docs[name] + # Escape % for argparse + help = help.replace("%", "%%") + + # Initialise the kwargs dictionary for the field + kwargs[name] = {"default": default, "help": help} + + # Get the set of possible types for the field + type_hints: set[TypeHint] = set() + if get_origin(field.type) is Union: + type_hints.update(get_args(field.type)) + else: + type_hints.add(field.type) + + # Set other kwargs based on the type hints + if contains_type(type_hints, bool): + # Creates --no- and -- flags + kwargs[name]["action"] = argparse.BooleanOptionalAction + elif contains_type(type_hints, Literal): + # Creates choices from Literal arguments + type_hint = get_type(type_hints, Literal) + choices = sorted(get_args(type_hint)) + kwargs[name]["choices"] = choices + choice_type = type(choices[0]) + assert all(type(c) is choice_type for c in choices), ( + "All choices must be of the same type. " + f"Got {choices} with types {[type(c) for c in choices]}") + kwargs[name]["type"] = choice_type + elif contains_type(type_hints, tuple): + type_hint = get_type(type_hints, tuple) + types = get_args(type_hint) + tuple_type = types[0] + assert all(t is tuple_type for t in types if t is not Ellipsis), ( + "All non-Ellipsis tuple elements must be of the same " + f"type. Got {types}.") + kwargs[name]["type"] = tuple_type + kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types) + elif contains_type(type_hints, list): + type_hint = get_type(type_hints, list) + types = get_args(type_hint) + assert len(types) == 1, ( + "List type must have exactly one type. 
Got " + f"{type_hint} with types {types}") + kwargs[name]["type"] = types[0] + kwargs[name]["nargs"] = "+" + elif contains_type(type_hints, int): + kwargs[name]["type"] = int + elif contains_type(type_hints, float): + kwargs[name]["type"] = float + elif contains_type(type_hints, dict): + # Dict arguments will always be optional + kwargs[name]["type"] = optional_type(json.loads) + elif (contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints)): + kwargs[name]["type"] = str + else: + raise ValueError( + f"Unsupported type {type_hints} for argument {name}.") + + # If None is in type_hints, make the argument optional. + # But not if it's a bool, argparse will handle this better. + if type(None) in type_hints and not contains_type(type_hints, bool): + kwargs[name]["type"] = optional_type(kwargs[name]["type"]) + if kwargs[name].get("choices"): + kwargs[name]["choices"].append("None") + return kwargs + + @dataclass class EngineArgs: """Arguments for vLLM engine.""" @@ -105,14 +216,15 @@ class EngineArgs: load_format: str = LoadConfig.load_format config_format: ConfigFormat = ConfigFormat.AUTO dtype: str = 'auto' - kv_cache_dtype: str = 'auto' + kv_cache_dtype: CacheDType = CacheConfig.cache_dtype seed: Optional[int] = None max_model_len: Optional[int] = None # Note: Specifying a custom executor backend by passing a class # is intended for expert use only. The API may change without # notice. distributed_executor_backend: Optional[Union[ - str, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend + DistributedExecutorBackend, + Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size @@ -120,20 +232,23 @@ class EngineArgs: enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers - block_size: Optional[int] = None - enable_prefix_caching: Optional[bool] = None - prefix_caching_hash_algo: str = "builtin" + block_size: Optional[BlockSize] = CacheConfig.block_size + enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching + prefix_caching_hash_algo: PrefixCachingHashAlgo = \ + CacheConfig.prefix_caching_hash_algo disable_sliding_window: bool = False disable_cascade_attn: bool = False use_v2_block_manager: bool = True - swap_space: float = 4 # GiB - cpu_offload_gb: float = 0 # GiB - gpu_memory_utilization: float = 0.90 - max_num_batched_tokens: Optional[int] = None - max_num_partial_prefills: Optional[int] = 1 - max_long_partial_prefills: Optional[int] = 1 - long_prefill_token_threshold: Optional[int] = 0 - max_num_seqs: Optional[int] = None + swap_space: float = CacheConfig.swap_space + cpu_offload_gb: float = CacheConfig.cpu_offload_gb + gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization + max_num_batched_tokens: Optional[ + int] = SchedulerConfig.max_num_batched_tokens + max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills + max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills + long_prefill_token_threshold: int = \ + SchedulerConfig.long_prefill_token_threshold + max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs max_logprobs: int = 20 # Default value for OpenAI Chat Completions API disable_log_stats: bool = False revision: Optional[str] = None @@ -147,42 +262,51 @@ class 
EngineArgs: enforce_eager: Optional[bool] = None max_seq_len_to_capture: int = 8192 disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce - tokenizer_pool_size: int = 0 - # Note: Specifying a tokenizer pool by passing a class - # is intended for expert use only. The API may change without - # notice. - tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray" - tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None - limit_mm_per_prompt: Optional[Mapping[str, int]] = None + # The following three fields are deprecated and will be removed in a future + # release. Setting them will have no effect. Please remove them from your + # configurations. + tokenizer_pool_size: int = TokenizerPoolConfig.pool_size + tokenizer_pool_type: str = TokenizerPoolConfig.pool_type + tokenizer_pool_extra_config: dict = \ + get_field(TokenizerPoolConfig, "extra_config") + limit_mm_per_prompt: dict[str, int] = \ + get_field(MultiModalConfig, "limit_per_prompt") mm_processor_kwargs: Optional[Dict[str, Any]] = None disable_mm_preprocessor_cache: bool = False + # LoRA fields enable_lora: bool = False - enable_lora_bias: bool = False - max_loras: int = 1 - max_lora_rank: int = 16 + enable_lora_bias: bool = LoRAConfig.bias_enabled + max_loras: int = LoRAConfig.max_loras + max_lora_rank: int = LoRAConfig.max_lora_rank + fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras + max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras + lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype + lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size + long_lora_scaling_factors: Optional[tuple[float, ...]] = \ + LoRAConfig.long_lora_scaling_factors + # PromptAdapter fields enable_prompt_adapter: bool = False - max_prompt_adapters: int = 1 - max_prompt_adapter_token: int = 0 - fully_sharded_loras: bool = False - lora_extra_vocab_size: int = 256 - long_lora_scaling_factors: Optional[Tuple[float]] = None - lora_dtype: Optional[Union[str, torch.dtype]] = 'auto' - max_cpu_loras: Optional[int] = None - device: str = 'auto' - num_scheduler_steps: int = 1 - multi_step_stream_outputs: bool = True + max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters + max_prompt_adapter_token: int = \ + PromptAdapterConfig.max_prompt_adapter_token + + device: Device = DeviceConfig.device + num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps + multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight - num_gpu_blocks_override: Optional[int] = None - num_lookahead_slots: int = 0 - model_loader_extra_config: Optional[ - dict] = LoadConfig.model_loader_extra_config + num_gpu_blocks_override: Optional[ + int] = CacheConfig.num_gpu_blocks_override + num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots + model_loader_extra_config: dict = \ + get_field(LoadConfig, "model_loader_extra_config") ignore_patterns: Optional[Union[str, List[str]]] = LoadConfig.ignore_patterns - preemption_mode: Optional[str] = None + preemption_mode: Optional[str] = SchedulerConfig.preemption_mode - scheduler_delay_factor: float = 0.0 - enable_chunked_prefill: Optional[bool] = None - disable_chunked_mm_input: bool = False + scheduler_delay_factor: float = SchedulerConfig.delay_factor + enable_chunked_prefill: Optional[ + bool] = SchedulerConfig.enable_chunked_prefill + disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input guided_decoding_backend: str = 
DecodingConfig.guided_decoding_backend logits_processor_pattern: Optional[str] = None @@ -194,8 +318,8 @@ class EngineArgs: otlp_traces_endpoint: Optional[str] = None collect_detailed_traces: Optional[str] = None disable_async_output_proc: bool = False - scheduling_policy: Literal["fcfs", "priority"] = "fcfs" - scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler" + scheduling_policy: SchedulerPolicy = SchedulerConfig.policy + scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls override_neuron_config: Optional[Dict[str, Any]] = None override_pooler_config: Optional[PoolerConfig] = None @@ -210,11 +334,11 @@ class EngineArgs: enable_sleep_mode: bool = False model_impl: str = "auto" - calculate_kv_scales: Optional[bool] = None + calculate_kv_scales: bool = CacheConfig.calculate_kv_scales additional_config: Optional[Dict[str, Any]] = None enable_reasoning: Optional[bool] = None - reasoning_parser: Optional[str] = None + reasoning_parser: Optional[str] = DecodingConfig.reasoning_backend use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load def __post_init__(self): @@ -236,38 +360,6 @@ def __post_init__(self): def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: """Shared CLI arguments for vLLM engine.""" - def is_type_in_union(cls: type[Any], type: type[Any]) -> bool: - """Check if the class is a type in a union type.""" - return get_origin(cls) is Union and type in get_args(cls) - - def is_optional(cls: type[Any]) -> bool: - """Check if the class is an optional type.""" - return is_type_in_union(cls, type(None)) - - def get_kwargs(cls: type[Any]) -> Dict[str, Any]: - cls_docs = get_attr_docs(cls) - kwargs = {} - for field in fields(cls): - name = field.name - # One of these will always be present - default = (field.default_factory - if field.default is MISSING else field.default) - kwargs[name] = {"default": default, "help": cls_docs[name]} - # When using action="store_true" - # add_argument doesn't accept type - if field.type is bool: - continue - # Handle optional fields - if is_optional(field.type): - kwargs[name]["type"] = nullable_str - continue - # Handle str in union fields - if is_type_in_union(field.type, str): - kwargs[name]["type"] = str - continue - kwargs[name]["type"] = field.type - return kwargs - # Model arguments parser.add_argument( '--model', @@ -285,13 +377,13 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: 'which task to use.') parser.add_argument( '--tokenizer', - type=nullable_str, + type=optional_type(str), default=EngineArgs.tokenizer, help='Name or path of the huggingface tokenizer to use. ' 'If unspecified, model name or path will be used.') parser.add_argument( "--hf-config-path", - type=nullable_str, + type=optional_type(str), default=EngineArgs.hf_config_path, help='Name or path of the huggingface config to use. ' 'If unspecified, model name or path will be used.') @@ -303,21 +395,21 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: 'the input. The generated output will contain token ids.') parser.add_argument( '--revision', - type=nullable_str, + type=optional_type(str), default=None, help='The specific model version to use. It can be a branch ' 'name, a tag name, or a commit id. If unspecified, will use ' 'the default version.') parser.add_argument( '--code-revision', - type=nullable_str, + type=optional_type(str), default=None, help='The specific revision to use for the model code on ' 'Hugging Face Hub. It can be a branch name, a tag name, or a ' 'commit id. 
If unspecified, will use the default version.') parser.add_argument( '--tokenizer-revision', - type=nullable_str, + type=optional_type(str), default=None, help='Revision of the huggingface tokenizer to use. ' 'It can be a branch name, a tag name, or a commit id. ' @@ -357,7 +449,6 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: load_group.add_argument('--model-loader-extra-config', **load_kwargs["model_loader_extra_config"]) load_group.add_argument('--use-tqdm-on-load', - action=argparse.BooleanOptionalAction, **load_kwargs["use_tqdm_on_load"]) parser.add_argument( @@ -382,14 +473,6 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: '* "bfloat16" for a balance between precision and range.\n' '* "float" is shorthand for FP32 precision.\n' '* "float32" for FP32 precision.') - parser.add_argument( - '--kv-cache-dtype', - type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], - default=EngineArgs.kv_cache_dtype, - help='Data type for kv cache storage. If "auto", will use model ' - 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument('--max-model-len', type=human_readable_int, default=EngineArgs.max_model_len, @@ -399,21 +482,25 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: 'Examples:\n' '- 1k → 1000\n' '- 1K → 1024\n') - parser.add_argument( + + # Guided decoding arguments + guided_decoding_kwargs = get_kwargs(DecodingConfig) + guided_decoding_group = parser.add_argument_group( + title="DecodingConfig", + description=DecodingConfig.__doc__, + ) + guided_decoding_group.add_argument( '--guided-decoding-backend', - type=str, - default=DecodingConfig.guided_decoding_backend, - help='Which engine will be used for guided decoding' - ' (JSON schema / regex etc) by default. Currently support ' - 'https://github.com/mlc-ai/xgrammar and ' - 'https://github.com/guidance-ai/llguidance.' - 'Valid backend values are "xgrammar", "guidance", and "auto". 
' - 'With "auto", we will make opinionated choices based on request ' - 'contents and what the backend libraries currently support, so ' - 'the behavior is subject to change in each release.') + **guided_decoding_kwargs["guided_decoding_backend"]) + guided_decoding_group.add_argument( + "--reasoning-parser", + # This choices is a special case because it's not static + choices=list(ReasoningParserManager.reasoning_parsers), + **guided_decoding_kwargs["reasoning_backend"]) + parser.add_argument( '--logits-processor-pattern', - type=nullable_str, + type=optional_type(str), default=None, help='Optional regex pattern specifying valid logits processor ' 'qualified names that can be passed with the `logits_processors` ' @@ -439,7 +526,6 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: ) parallel_group.add_argument( '--distributed-executor-backend', - choices=['ray', 'mp', 'uni', 'external_launcher'], **parallel_kwargs["distributed_executor_backend"]) parallel_group.add_argument( '--pipeline-parallel-size', '-pp', @@ -450,46 +536,40 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: **parallel_kwargs["data_parallel_size"]) parallel_group.add_argument( '--enable-expert-parallel', - action='store_true', **parallel_kwargs["enable_expert_parallel"]) parallel_group.add_argument( '--max-parallel-loading-workers', **parallel_kwargs["max_parallel_loading_workers"]) parallel_group.add_argument( '--ray-workers-use-nsight', - action='store_true', **parallel_kwargs["ray_workers_use_nsight"]) parallel_group.add_argument( '--disable-custom-all-reduce', - action='store_true', **parallel_kwargs["disable_custom_all_reduce"]) - # KV cache arguments - parser.add_argument('--block-size', - type=int, - default=EngineArgs.block_size, - choices=[8, 16, 32, 64, 128], - help='Token block size for contiguous chunks of ' - 'tokens. This is ignored on neuron devices and ' - 'set to ``--max-model-len``. On CUDA devices, ' - 'only block sizes up to 32 are supported. ' - 'On HPU devices, block size defaults to 128.') - parser.add_argument( - "--enable-prefix-caching", - action=argparse.BooleanOptionalAction, - default=EngineArgs.enable_prefix_caching, - help="Enables automatic prefix caching. " - "Use ``--no-enable-prefix-caching`` to disable explicitly.", - ) - parser.add_argument( - "--prefix-caching-hash-algo", - type=str, - choices=["builtin", "sha256"], - default=EngineArgs.prefix_caching_hash_algo, - help="Set the hash algorithm for prefix caching. 
" - "Options are 'builtin' (Python's built-in hash) or 'sha256' " - "(collision resistant but with certain overheads).", + # KV cache arguments + cache_kwargs = get_kwargs(CacheConfig) + cache_group = parser.add_argument_group( + title="CacheConfig", + description=CacheConfig.__doc__, ) + cache_group.add_argument('--block-size', **cache_kwargs["block_size"]) + cache_group.add_argument('--gpu-memory-utilization', + **cache_kwargs["gpu_memory_utilization"]) + cache_group.add_argument('--swap-space', **cache_kwargs["swap_space"]) + cache_group.add_argument('--kv-cache-dtype', + **cache_kwargs["cache_dtype"]) + cache_group.add_argument('--num-gpu-blocks-override', + **cache_kwargs["num_gpu_blocks_override"]) + cache_group.add_argument("--enable-prefix-caching", + **cache_kwargs["enable_prefix_caching"]) + cache_group.add_argument("--prefix-caching-hash-algo", + **cache_kwargs["prefix_caching_hash_algo"]) + cache_group.add_argument('--cpu-offload-gb', + **cache_kwargs["cpu_offload_gb"]) + cache_group.add_argument('--calculate-kv-scales', + **cache_kwargs["calculate_kv_scales"]) + parser.add_argument('--disable-sliding-window', action='store_true', help='Disables sliding window, ' @@ -502,86 +582,11 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: 'block manager v2) is now the default. ' 'Setting this flag to True or False' ' has no effect on vLLM behavior.') - parser.add_argument( - '--num-lookahead-slots', - type=int, - default=EngineArgs.num_lookahead_slots, - help='Experimental scheduling config necessary for ' - 'speculative decoding. This will be replaced by ' - 'speculative config in the future; it is present ' - 'to enable correctness tests until then.') parser.add_argument('--seed', type=int, default=EngineArgs.seed, help='Random seed for operations.') - parser.add_argument('--swap-space', - type=float, - default=EngineArgs.swap_space, - help='CPU swap space size (GiB) per GPU.') - parser.add_argument( - '--cpu-offload-gb', - type=float, - default=0, - help='The space in GiB to offload to CPU, per GPU. ' - 'Default is 0, which means no offloading. Intuitively, ' - 'this argument can be seen as a virtual way to increase ' - 'the GPU memory size. For example, if you have one 24 GB ' - 'GPU and set this to 10, virtually you can think of it as ' - 'a 34 GB GPU. Then you can load a 13B model with BF16 weight, ' - 'which requires at least 26GB GPU memory. Note that this ' - 'requires fast CPU-GPU interconnect, as part of the model is ' - 'loaded from CPU memory to GPU memory on the fly in each ' - 'model forward pass.') - parser.add_argument( - '--gpu-memory-utilization', - type=float, - default=EngineArgs.gpu_memory_utilization, - help='The fraction of GPU memory to be used for the model ' - 'executor, which can range from 0 to 1. For example, a value of ' - '0.5 would imply 50%% GPU memory utilization. If unspecified, ' - 'will use the default value of 0.9. This is a per-instance ' - 'limit, and only applies to the current vLLM instance.' - 'It does not matter if you have another vLLM instance running ' - 'on the same GPU. For example, if you have two vLLM instances ' - 'running on the same GPU, you can set the GPU memory utilization ' - 'to 0.5 for each instance.') - parser.add_argument( - '--num-gpu-blocks-override', - type=int, - default=None, - help='If specified, ignore GPU profiling result and use this number' - ' of GPU blocks. 
Used for testing preemption.') - parser.add_argument('--max-num-batched-tokens', - type=int, - default=EngineArgs.max_num_batched_tokens, - help='Maximum number of batched tokens per ' - 'iteration.') - parser.add_argument( - "--max-num-partial-prefills", - type=int, - default=EngineArgs.max_num_partial_prefills, - help="For chunked prefill, the max number of concurrent \ - partial prefills.") - parser.add_argument( - "--max-long-partial-prefills", - type=int, - default=EngineArgs.max_long_partial_prefills, - help="For chunked prefill, the maximum number of prompts longer " - "than --long-prefill-token-threshold that will be prefilled " - "concurrently. Setting this less than --max-num-partial-prefills " - "will allow shorter prompts to jump the queue in front of longer " - "prompts in some cases, improving latency.") - parser.add_argument( - "--long-prefill-token-threshold", - type=float, - default=EngineArgs.long_prefill_token_threshold, - help="For chunked prefill, a request is considered long if the " - "prompt is longer than this number of tokens.") - parser.add_argument('--max-num-seqs', - type=int, - default=EngineArgs.max_num_seqs, - help='Maximum number of sequences per iteration.') parser.add_argument( '--max-logprobs', type=int, @@ -594,7 +599,7 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: # Quantization settings. parser.add_argument('--quantization', '-q', - type=nullable_str, + type=optional_type(str), choices=[*QUANTIZATION_METHODS, None], default=EngineArgs.quantization, help='Method used to quantize the weights. If ' @@ -645,154 +650,108 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: 'Additionally for encoder-decoder models, if the ' 'sequence length of the encoder input is larger ' 'than this, we fall back to the eager mode.') - parser.add_argument('--tokenizer-pool-size', - type=int, - default=EngineArgs.tokenizer_pool_size, - help='Size of tokenizer pool to use for ' - 'asynchronous tokenization. If 0, will ' - 'use synchronous tokenization.') - parser.add_argument('--tokenizer-pool-type', - type=str, - default=EngineArgs.tokenizer_pool_type, - help='Type of tokenizer pool to use for ' - 'asynchronous tokenization. Ignored ' - 'if tokenizer_pool_size is 0.') - parser.add_argument('--tokenizer-pool-extra-config', - type=nullable_str, - default=EngineArgs.tokenizer_pool_extra_config, - help='Extra config for tokenizer pool. ' - 'This should be a JSON string that will be ' - 'parsed into a dictionary. Ignored if ' - 'tokenizer_pool_size is 0.') + + # Tokenizer arguments + tokenizer_kwargs = get_kwargs(TokenizerPoolConfig) + tokenizer_group = parser.add_argument_group( + title="TokenizerPoolConfig", + description=TokenizerPoolConfig.__doc__, + ) + tokenizer_group.add_argument('--tokenizer-pool-size', + **tokenizer_kwargs["pool_size"]) + tokenizer_group.add_argument('--tokenizer-pool-type', + **tokenizer_kwargs["pool_type"]) + tokenizer_group.add_argument('--tokenizer-pool-extra-config', + **tokenizer_kwargs["extra_config"]) # Multimodal related configs - parser.add_argument( - '--limit-mm-per-prompt', - type=nullable_kvs, - default=EngineArgs.limit_mm_per_prompt, - # The default value is given in - # MultiModalConfig.get_default_limit_per_prompt - help=('For each multimodal plugin, limit how many ' - 'input instances to allow for each prompt. ' - 'Expects a comma-separated list of items, ' - 'e.g.: `image=16,video=2` allows a maximum of 16 ' - 'images and 2 videos per prompt. 
Defaults to ' - '1 (V0) or 999 (V1) for each modality.')) + multimodal_kwargs = get_kwargs(MultiModalConfig) + multimodal_group = parser.add_argument_group( + title="MultiModalConfig", + description=MultiModalConfig.__doc__, + ) + multimodal_group.add_argument('--limit-mm-per-prompt', + **multimodal_kwargs["limit_per_prompt"]) + parser.add_argument( '--mm-processor-kwargs', default=None, type=json.loads, - help=('Overrides for the multimodal input mapping/processing, ' - 'e.g., image processor. For example: ``{"num_crops": 4}``.')) + help=('Overrides for the multi-modal processor obtained from ' + '``AutoProcessor.from_pretrained``. The available overrides ' + 'depend on the model that is being run.' + 'For example, for Phi-3-Vision: ``{"num_crops": 4}``.')) parser.add_argument( '--disable-mm-preprocessor-cache', action='store_true', - help='If true, then disables caching of the multi-modal ' - 'preprocessor/mapper. (not recommended)') + help='If True, disable caching of the processed multi-modal ' + 'inputs.') # LoRA related configs - parser.add_argument('--enable-lora', - action='store_true', - help='If True, enable handling of LoRA adapters.') - parser.add_argument('--enable-lora-bias', - action='store_true', - help='If True, enable bias for LoRA adapters.') - parser.add_argument('--max-loras', - type=int, - default=EngineArgs.max_loras, - help='Max number of LoRAs in a single batch.') - parser.add_argument('--max-lora-rank', - type=int, - default=EngineArgs.max_lora_rank, - help='Max LoRA rank.') - parser.add_argument( - '--lora-extra-vocab-size', - type=int, - default=EngineArgs.lora_extra_vocab_size, - help=('Maximum size of extra vocabulary that can be ' - 'present in a LoRA adapter (added to the base ' - 'model vocabulary).')) - parser.add_argument( + lora_kwargs = get_kwargs(LoRAConfig) + lora_group = parser.add_argument_group( + title="LoRAConfig", + description=LoRAConfig.__doc__, + ) + lora_group.add_argument( + '--enable-lora', + action=argparse.BooleanOptionalAction, + help='If True, enable handling of LoRA adapters.') + lora_group.add_argument('--enable-lora-bias', + **lora_kwargs["bias_enabled"]) + lora_group.add_argument('--max-loras', **lora_kwargs["max_loras"]) + lora_group.add_argument('--max-lora-rank', + **lora_kwargs["max_lora_rank"]) + lora_group.add_argument('--lora-extra-vocab-size', + **lora_kwargs["lora_extra_vocab_size"]) + lora_group.add_argument( '--lora-dtype', - type=str, - default=EngineArgs.lora_dtype, - choices=['auto', 'float16', 'bfloat16'], - help=('Data type for LoRA. If auto, will default to ' - 'base model dtype.')) - parser.add_argument( - '--long-lora-scaling-factors', - type=nullable_str, - default=EngineArgs.long_lora_scaling_factors, - help=('Specify multiple scaling factors (which can ' - 'be different from base model scaling factor ' - '- see eg. Long LoRA) to allow for multiple ' - 'LoRA adapters trained with those scaling ' - 'factors to be used at the same time. If not ' - 'specified, only adapters trained with the ' - 'base model scaling factor are allowed.')) - parser.add_argument( - '--max-cpu-loras', - type=int, - default=EngineArgs.max_cpu_loras, - help=('Maximum number of LoRAs to store in CPU memory. ' - 'Must be >= than max_loras.')) - parser.add_argument( - '--fully-sharded-loras', - action='store_true', - help=('By default, only half of the LoRA computation is ' - 'sharded with tensor parallelism. ' - 'Enabling this will use the fully sharded layers. 
' - 'At high sequence length, max rank or ' - 'tensor parallel size, this is likely faster.')) - parser.add_argument('--enable-prompt-adapter', - action='store_true', - help='If True, enable handling of PromptAdapters.') - parser.add_argument('--max-prompt-adapters', - type=int, - default=EngineArgs.max_prompt_adapters, - help='Max number of PromptAdapters in a batch.') - parser.add_argument('--max-prompt-adapter-token', - type=int, - default=EngineArgs.max_prompt_adapter_token, - help='Max number of PromptAdapters tokens') - parser.add_argument("--device", - type=str, - default=EngineArgs.device, - choices=DEVICE_OPTIONS, - help='Device type for vLLM execution.') - parser.add_argument('--num-scheduler-steps', - type=int, - default=1, - help=('Maximum number of forward steps per ' - 'scheduler call.')) + **lora_kwargs["lora_dtype"], + ) + lora_group.add_argument('--long-lora-scaling-factors', + **lora_kwargs["long_lora_scaling_factors"]) + lora_group.add_argument('--max-cpu-loras', + **lora_kwargs["max_cpu_loras"]) + lora_group.add_argument('--fully-sharded-loras', + **lora_kwargs["fully_sharded_loras"]) + + # PromptAdapter related configs + prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig) + prompt_adapter_group = parser.add_argument_group( + title="PromptAdapterConfig", + description=PromptAdapterConfig.__doc__, + ) + prompt_adapter_group.add_argument( + '--enable-prompt-adapter', + action=argparse.BooleanOptionalAction, + help='If True, enable handling of PromptAdapters.') + prompt_adapter_group.add_argument( + '--max-prompt-adapters', + **prompt_adapter_kwargs["max_prompt_adapters"]) + prompt_adapter_group.add_argument( + '--max-prompt-adapter-token', + **prompt_adapter_kwargs["max_prompt_adapter_token"]) + + # Device arguments + device_kwargs = get_kwargs(DeviceConfig) + device_group = parser.add_argument_group( + title="DeviceConfig", + description=DeviceConfig.__doc__, + ) + device_group.add_argument("--device", **device_kwargs["device"]) + + # Speculative arguments + speculative_group = parser.add_argument_group( + title="SpeculativeConfig", + description=SpeculativeConfig.__doc__, + ) + speculative_group.add_argument( + '--speculative-config', + type=json.loads, + default=None, + help='The configurations for speculative decoding.' + ' Should be a JSON string.') - parser.add_argument( - '--multi-step-stream-outputs', - action=StoreBoolean, - default=EngineArgs.multi_step_stream_outputs, - nargs="?", - const="True", - help='If False, then multi-step will stream outputs at the end ' - 'of all steps') - parser.add_argument( - '--scheduler-delay-factor', - type=float, - default=EngineArgs.scheduler_delay_factor, - help='Apply a delay (of delay factor multiplied by previous ' - 'prompt latency) before scheduling next prompt.') - parser.add_argument( - '--enable-chunked-prefill', - action=StoreBoolean, - default=EngineArgs.enable_chunked_prefill, - nargs="?", - const="True", - help='If set, the prefill requests can be chunked based on the ' - 'max_num_batched_tokens.') - parser.add_argument('--speculative-config', - type=json.loads, - default=None, - help='The configurations for speculative decoding.' - ' Should be a JSON string.') parser.add_argument( '--ignore-patterns', action="append", @@ -801,13 +760,6 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: help="The pattern(s) to ignore when loading the model." 
"Default to `original/**/*` to avoid repeated loading of llama's " "checkpoints.") - parser.add_argument( - '--preemption-mode', - type=str, - default=None, - help='If \'recompute\', the engine performs preemption by ' - 'recomputing; If \'swap\', the engine performs preemption by ' - 'block swapping.') parser.add_argument( "--served-model-name", @@ -863,22 +815,47 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: help="Disable async output processing. This may result in " "lower performance.") - parser.add_argument( - '--scheduling-policy', - choices=['fcfs', 'priority'], - default="fcfs", - help='The scheduling policy to use. "fcfs" (first come first served' - ', i.e. requests are handled in order of arrival; default) ' - 'or "priority" (requests are handled based on given ' - 'priority (lower value means earlier handling) and time of ' - 'arrival deciding any ties).') - - parser.add_argument( - '--scheduler-cls', - default=EngineArgs.scheduler_cls, - help='The scheduler class to use. "vllm.core.scheduler.Scheduler" ' - 'is the default scheduler. Can be a class directly or the path to ' - 'a class of form "mod.custom_class".') + # Scheduler arguments + scheduler_kwargs = get_kwargs(SchedulerConfig) + scheduler_group = parser.add_argument_group( + title="SchedulerConfig", + description=SchedulerConfig.__doc__, + ) + scheduler_group.add_argument( + '--max-num-batched-tokens', + **scheduler_kwargs["max_num_batched_tokens"]) + scheduler_group.add_argument('--max-num-seqs', + **scheduler_kwargs["max_num_seqs"]) + scheduler_group.add_argument( + "--max-num-partial-prefills", + **scheduler_kwargs["max_num_partial_prefills"]) + scheduler_group.add_argument( + "--max-long-partial-prefills", + **scheduler_kwargs["max_long_partial_prefills"]) + scheduler_group.add_argument( + "--long-prefill-token-threshold", + **scheduler_kwargs["long_prefill_token_threshold"]) + scheduler_group.add_argument('--num-lookahead-slots', + **scheduler_kwargs["num_lookahead_slots"]) + scheduler_group.add_argument('--scheduler-delay-factor', + **scheduler_kwargs["delay_factor"]) + scheduler_group.add_argument('--preemption-mode', + **scheduler_kwargs["preemption_mode"]) + scheduler_group.add_argument('--num-scheduler-steps', + **scheduler_kwargs["num_scheduler_steps"]) + scheduler_group.add_argument( + '--multi-step-stream-outputs', + **scheduler_kwargs["multi_step_stream_outputs"]) + scheduler_group.add_argument('--scheduling-policy', + **scheduler_kwargs["policy"]) + scheduler_group.add_argument( + '--enable-chunked-prefill', + **scheduler_kwargs["enable_chunked_prefill"]) + scheduler_group.add_argument( + "--disable-chunked-mm-input", + **scheduler_kwargs["disable_chunked_mm_input"]) + parser.add_argument('--scheduler-cls', + **scheduler_kwargs["scheduler_cls"]) parser.add_argument( '--override-neuron-config', @@ -905,10 +882,11 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: 'testing only. level 3 is the recommended level ' 'for production.\n' 'To specify the full compilation config, ' - 'use a JSON string.\n' + 'use a JSON string, e.g. ``{"level": 3, ' + '"cudagraph_capture_sizes": [1, 2, 4, 8]}``\n' 'Following the convention of traditional ' - 'compilers, using -O without space is also ' - 'supported. -O3 is equivalent to -O 3.') + 'compilers, using ``-O`` without space is also ' + 'supported. 
``-O3`` is equivalent to ``-O 3``.') parser.add_argument('--kv-transfer-config', type=KVTransferConfig.from_cli, @@ -930,7 +908,7 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: 'class without changing the existing functions.') parser.add_argument( "--generation-config", - type=nullable_str, + type=optional_type(str), default="auto", help="The folder path to the generation config. " "Defaults to 'auto', the generation config will be loaded from " @@ -957,15 +935,6 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: help="Enable sleep mode for the engine. " "(only cuda platform is supported)") - parser.add_argument( - '--calculate-kv-scales', - action='store_true', - help='This enables dynamic calculation of ' - 'k_scale and v_scale when kv-cache-dtype is fp8. ' - 'If calculate-kv-scales is false, the scales will ' - 'be loaded from the model checkpoint if available. ' - 'Otherwise, the scales will default to 1.0.') - parser.add_argument( "--additional-config", type=json.loads, @@ -983,16 +952,6 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: "If enabled, the model will be able to generate reasoning content." ) - parser.add_argument( - "--reasoning-parser", - type=str, - choices=list(ReasoningParserManager.reasoning_parsers), - default=None, - help= - "Select the reasoning parser depending on the model that you're " - "using. This is used to parse the reasoning content into OpenAI " - "API format. Required for ``--enable-reasoning``.") - parser.add_argument( "--disable-cascade-attn", action="store_true", @@ -1003,20 +962,6 @@ def get_kwargs(cls: type[Any]) -> Dict[str, Any]: "Note that even if this is set to False, cascade attention will be " "only used when the heuristic tells that it's beneficial.") - parser.add_argument( - "--disable-chunked-mm-input", - action=StoreBoolean, - default=EngineArgs.disable_chunked_mm_input, - nargs="?", - const="True", - help="Disable multimodal input chunking attention for V1. " - "If set to true and chunked prefill is enabled, we do not want to" - " partially schedule a multimodal item. 
This ensures that if a " - "request has a mixed prompt (like text tokens TTTT followed by " - "image tokens IIIIIIIIII) where only some image tokens can be " - "scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled " - "as TTTT in one step and IIIIIIIIII in the next.") - return parser @classmethod @@ -1210,11 +1155,6 @@ def create_engine_config( enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, - tokenizer_pool_config=TokenizerPoolConfig.create_config( - self.tokenizer_pool_size, - self.tokenizer_pool_type, - self.tokenizer_pool_extra_config, - ), ray_workers_use_nsight=self.ray_workers_use_nsight, placement_group=placement_group, distributed_executor_backend=self.distributed_executor_backend, @@ -1288,8 +1228,6 @@ def create_engine_config( if self.qlora_adapter_name_or_path is not None and \ self.qlora_adapter_name_or_path != "": - if self.model_loader_extra_config is None: - self.model_loader_extra_config = {} self.model_loader_extra_config[ "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path @@ -1370,7 +1308,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - if self.preemption_mode != EngineArgs.preemption_mode: + if self.preemption_mode != SchedulerConfig.preemption_mode: _raise_or_fallback(feature_name="--preemption-mode", recommend_to_remove=True) return False @@ -1381,34 +1319,28 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=True) return False - if self.scheduling_policy != EngineArgs.scheduling_policy: + if self.scheduling_policy != SchedulerConfig.policy: _raise_or_fallback(feature_name="--scheduling-policy", recommend_to_remove=False) return False - if self.num_scheduler_steps != EngineArgs.num_scheduler_steps: + if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps: _raise_or_fallback(feature_name="--num-scheduler-steps", recommend_to_remove=True) return False - if self.scheduler_delay_factor != EngineArgs.scheduler_delay_factor: + if self.scheduler_delay_factor != SchedulerConfig.delay_factor: _raise_or_fallback(feature_name="--scheduler-delay-factor", recommend_to_remove=True) return False - if self.additional_config != EngineArgs.additional_config: - _raise_or_fallback(feature_name="--additional-config", - recommend_to_remove=False) - return False - - # Xgrammar and Guidance are supported. - SUPPORTED_GUIDED_DECODING = [ - "xgrammar", "xgrammar:disable-any-whitespace", "guidance", - "guidance:disable-any-whitespace", "auto" - ] - if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING: - _raise_or_fallback(feature_name="--guided-decoding-backend", - recommend_to_remove=False) + # remove backend options when doing this check + if self.guided_decoding_backend.split(':')[0] \ + not in get_args(GuidedDecodingBackendV1): + _raise_or_fallback( + feature_name= + f"--guided-decoding-backend={self.guided_decoding_backend}", + recommend_to_remove=False) return False # Need at least Ampere for now (FA support required). 
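The V1 oracle above now derives the allowed guided-decoding backends from the `GuidedDecodingBackendV1` literal instead of a hardcoded list, stripping any `:option` suffix before comparing. A small sketch of that prefix check; the literal's members are assumed here for illustration and may not match `vllm/config.py` exactly.

from typing import Literal, get_args

# Assumed to mirror vllm.config.GuidedDecodingBackendV1; the real Literal
# lives in vllm/config.py and may differ.
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]


def backend_supported_on_v1(guided_decoding_backend: str) -> bool:
    # Options such as ":disable-any-whitespace" ride along after a colon,
    # so only the backend prefix is compared against the Literal's choices.
    return guided_decoding_backend.split(":")[0] in get_args(
        GuidedDecodingBackendV1)


assert backend_supported_on_v1("xgrammar:disable-any-whitespace")
assert not backend_supported_on_v1("outlines")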
@@ -1432,7 +1364,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" supported = False if fp8_attention and will_use_fa: - from vllm.vllm_flash_attn.fa_utils import ( + from vllm.attention.utils.fa_utils import ( flash_attn_supports_fp8) supported = flash_attn_supports_fp8() if not supported: @@ -1475,9 +1407,9 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: # No Concurrent Partial Prefills so far. if (self.max_num_partial_prefills - != EngineArgs.max_num_partial_prefills + != SchedulerConfig.max_num_partial_prefills or self.max_long_partial_prefills - != EngineArgs.max_long_partial_prefills): + != SchedulerConfig.max_long_partial_prefills): _raise_or_fallback(feature_name="Concurrent Partial Prefill", recommend_to_remove=False) return False @@ -1497,7 +1429,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: if speculative_method: if speculative_method in ("ngram", "[ngram]"): is_ngram_enabled = True - elif speculative_method == "eagle": + elif speculative_method in ("eagle", "eagle3"): is_eagle_enabled = True else: speculative_model = self.speculative_config.get("model") @@ -1509,16 +1441,17 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - # No Disaggregated Prefill so far. - if self.kv_transfer_config != EngineArgs.kv_transfer_config: - _raise_or_fallback(feature_name="--kv-transfer-config", - recommend_to_remove=False) - return False - - # No FlashInfer or XFormers so far. + # No XFormers so far. V1_BACKENDS = [ - "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1", - "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA" + "FLASH_ATTN_VLLM_V1", + "FLASH_ATTN", + "PALLAS", + "PALLAS_VLLM_V1", + "TRITON_ATTN_VLLM_V1", + "TRITON_MLA", + "FLASHMLA", + "FLASHINFER", + "FLASHINFER_VLLM_V1", ] if (envs.is_set("VLLM_ATTENTION_BACKEND") and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): @@ -1620,9 +1553,7 @@ def _set_default_args_v0(self, model_config: ModelConfig) -> None: self.enable_prefix_caching = False # VLLM_V0 only supports builtin hash algo for prefix caching. - if self.prefix_caching_hash_algo is None: - self.prefix_caching_hash_algo = "builtin" - elif self.prefix_caching_hash_algo == "sha256": + if self.prefix_caching_hash_algo == "sha256": raise ValueError( "sha256 is not supported for prefix caching in V0 engine. " "Please use 'builtin'.") @@ -1641,10 +1572,6 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None: if self.enable_prefix_caching is None: self.enable_prefix_caching = True - # if using prefix caching, we must set a hash algo - if self.enable_prefix_caching and self.prefix_caching_hash_algo is None: - self.prefix_caching_hash_algo = "builtin" - # V1 should use the new scheduler by default. # Swap it only if this arg is set to the original V0 default if self.scheduler_cls == EngineArgs.scheduler_cls: @@ -1661,13 +1588,13 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None: # values for non-H100/H200 GPUs. try: from vllm.platforms import current_platform - device_name = current_platform.get_device_name().lower() + device_memory = current_platform.get_device_total_memory() except Exception: # This is only used to set default_max_num_batched_tokens - device_name = "no-device" + device_memory = 0 - if "h100" in device_name or "h200" in device_name: - # For H100 and H200, we use larger default values. 
+ if device_memory >= 70 * GiB_bytes: + # For GPUs like H100 and MI300x, use larger default values. default_max_num_batched_tokens = { UsageContext.LLM_CLASS: 16384, UsageContext.OPENAI_API_SERVER: 8192, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7f9f85e1f93f..6cc9b881464e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -493,12 +493,11 @@ async def add_request_async( tokenizer = await self.get_tokenizer_async(lora_request) self._validate_token_prompt(prompt, tokenizer=tokenizer) - preprocessed_inputs = await self.input_preprocessor.preprocess_async( + processed_inputs = await self.input_preprocessor.preprocess_async( prompt, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) - processed_inputs = self.input_processor(preprocessed_inputs) if isinstance(params, SamplingParams) and \ params.guided_decoding is not None: @@ -526,10 +525,15 @@ async def add_request_async( ) async def check_health_async(self) -> None: - if self.tokenizer: - self.tokenizer.check_health() self.model_executor.check_health() + async def collective_rpc_async(self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None): + raise NotImplementedError + async def build_guided_decoding_logits_processor_async( sampling_params: SamplingParams, tokenizer: AnyTokenizer, @@ -1167,6 +1171,10 @@ def _abort(self, request_id: str) -> None: exception=asyncio.CancelledError, verbose=self.log_requests) + async def get_vllm_config(self) -> VllmConfig: + """Get the vllm configuration of the vLLM engine.""" + return self.engine.get_vllm_config() + async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" return self.engine.get_model_config() @@ -1234,6 +1242,17 @@ async def is_sleeping(self) -> bool: async def add_lora(self, lora_request: LoRARequest) -> None: self.engine.add_lora(lora_request) + async def collective_rpc(self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None): + """ + Perform a collective RPC call to the given path. + """ + return await self.engine.collective_rpc_async(method, timeout, args, + kwargs) + # TODO(v1): Remove this class proxy when V1 goes default. 
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 54f7b8fb69b5..c23530990611 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -29,8 +29,7 @@ from vllm.entrypoints.openai.logits_processors import ( get_logits_processors as get_openai_logits_processors) from vllm.executor.executor_base import ExecutorBase -from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, - PromptType, SingletonInputs) +from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger @@ -55,7 +54,7 @@ from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import ( - BaseTokenizerGroup, init_tokenizer_from_configs) + TokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import (Counter, Device, deprecate_kwargs, @@ -66,7 +65,6 @@ logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 -_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _O = TypeVar("_O", RequestOutput, PoolingRequestOutput) _R = TypeVar("_R", default=Any) @@ -205,7 +203,7 @@ def validate_outputs( return outputs_ - tokenizer: Optional[BaseTokenizerGroup] + tokenizer: Optional[TokenizerGroup] def __init__( self, @@ -214,7 +212,6 @@ def __init__( log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, ) -> None: @@ -275,11 +272,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.tokenizer, mm_registry) - self.input_registry = input_registry - self.input_processor = input_registry.create_input_processor( - self.model_config) - - self.model_executor = executor_class(vllm_config=vllm_config, ) + self.model_executor = executor_class(vllm_config=vllm_config) if self.model_config.runner_type != "pooling": self._initialize_kv_caches() @@ -321,11 +314,6 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.parallel_config.disable_custom_all_reduce, }) - if self.tokenizer: - # Ping the tokenizer to ensure liveness if it runs in a - # different process. - self.tokenizer.ping() - self.cached_scheduler_outputs = [ SchedulerOutputState() for _ in range(self.parallel_config.pipeline_parallel_size) @@ -537,21 +525,12 @@ def __del__(self): if model_executor := getattr(self, "model_executor", None): model_executor.shutdown() - def get_tokenizer_group( - self, - group_type: Type[_G] = BaseTokenizerGroup, - ) -> _G: - tokenizer_group = self.tokenizer - - if tokenizer_group is None: + def get_tokenizer_group(self) -> TokenizerGroup: + if self.tokenizer is None: raise ValueError("Unable to get tokenizer because " "skip_tokenizer_init is True") - if not isinstance(tokenizer_group, group_type): - raise TypeError("Invalid type of tokenizer group. 
" - f"Expected type: {group_type}, but " - f"found type: {type(tokenizer_group)}") - return tokenizer_group + return self.tokenizer def get_tokenizer( self, @@ -559,11 +538,10 @@ def get_tokenizer( ) -> AnyTokenizer: return self.get_tokenizer_group().get_lora_tokenizer(lora_request) - def _init_tokenizer(self) -> BaseTokenizerGroup: + def _init_tokenizer(self) -> TokenizerGroup: return init_tokenizer_from_configs( model_config=self.model_config, scheduler_config=self.scheduler_config, - parallel_config=self.parallel_config, lora_config=self.lora_config) def _verify_args(self) -> None: @@ -778,12 +756,11 @@ def add_request( prompt, tokenizer=self.get_tokenizer(lora_request=lora_request)) - preprocessed_inputs = self.input_preprocessor.preprocess( + processed_inputs = self.input_preprocessor.preprocess( prompt, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) - processed_inputs = self.input_processor(preprocessed_inputs) self._add_processed_request( request_id=request_id, @@ -914,6 +891,10 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: scheduler.abort_seq_group( request_id, seq_id_to_seq_group=self.seq_id_to_seq_group) + def get_vllm_config(self) -> VllmConfig: + """Gets the vllm configuration.""" + return self.vllm_config + def get_model_config(self) -> ModelConfig: """Gets the model configuration.""" return self.model_config @@ -1948,8 +1929,6 @@ def is_sleeping(self) -> bool: return self.model_executor.is_sleeping def check_health(self) -> None: - if self.tokenizer: - self.tokenizer.check_health() self.model_executor.check_health() def is_tracing_enabled(self) -> bool: @@ -2058,7 +2037,7 @@ def _validate_model_input( raise ValueError(f"The {prompt_type} prompt cannot be empty") max_prompt_len = self.model_config.max_model_len - if len(prompt_ids) >= max_prompt_len: + if len(prompt_ids) > max_prompt_len: if prompt_type == "encoder" and model_config.is_multimodal_model: mm_registry = self.input_preprocessor.mm_registry mm_processor = mm_registry.create_processor( diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 7c4265fac20b..033551d07c39 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -140,16 +140,13 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig): name="vllm:generation_tokens_total", documentation="Number of generation tokens processed.", labelnames=labelnames) - buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096] - if not vllm_config.model_config.enforce_eager: - buckets = vllm_config.compilation_config.\ - cudagraph_capture_sizes.copy() - buckets.sort() self.histogram_iteration_tokens = self._histogram_cls( name="vllm:iteration_tokens_total", documentation="Histogram of number of tokens per engine_step.", labelnames=labelnames, - buckets=buckets) + buckets=[ + 1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384 + ]) self.histogram_time_to_first_token = self._histogram_cls( name="vllm:time_to_first_token_seconds", documentation="Histogram of time to first token in seconds.", diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index f058b13297bb..eb3ae89394ec 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -93,6 +93,7 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig, self._errored_with: Optional[BaseException] = None # Get the configs. 
+ self.vllm_config = engine_config self.model_config = engine_config.model_config self.decoding_config = engine_config.decoding_config @@ -100,7 +101,6 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig, self.tokenizer = init_tokenizer_from_configs( model_config=self.model_config, scheduler_config=engine_config.scheduler_config, - parallel_config=engine_config.parallel_config, lora_config=engine_config.lora_config) self.input_preprocessor = InputPreprocessor(self.model_config, self.tokenizer) @@ -377,6 +377,9 @@ async def get_input_preprocessor(self) -> InputPreprocessor: async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): return await self.tokenizer.get_lora_tokenizer_async(lora_request) + async def get_vllm_config(self) -> VllmConfig: + return self.vllm_config + async def get_decoding_config(self) -> DecodingConfig: return self.decoding_config diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 5f126c7571dc..126e7da70216 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -178,7 +178,7 @@ def _process_seq_outputs(self, seq: Sequence, # generates a fixed number of tokens without evaluating stopping # conditions within the block. This can cause an eos token to be # unintentionally ignored. - if not sampling_params.ignore_eos: + if not sampling_params.ignore_eos and self.detokenizer: eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id # Avoiding .index calls as exception throwing in the happy path # is expensive. diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index e2974b02c5ba..7e5ac3a28452 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -5,7 +5,7 @@ from typing import AsyncGenerator, List, Mapping, Optional from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function -from vllm.config import DecodingConfig, ModelConfig +from vllm.config import DecodingConfig, ModelConfig, VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.parse import is_explicit_encoder_decoder_prompt @@ -220,6 +220,11 @@ async def abort(self, request_id: str) -> None: """ ... + @abstractmethod + async def get_vllm_config(self) -> VllmConfig: + """Get the vllm configuration of the vLLM engine.""" + ... 
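Editor's note: get_vllm_config() is added here as an abstract coroutine on the engine client protocol, and the concrete clients earlier in this diff implement it by handing back the config the engine was built with. A hypothetical offline usage sketch; the model name and the printed field are illustrative, not taken from this diff:

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    # New accessor: returns the full VllmConfig, not just the ModelConfig.
    vllm_config = await engine.get_vllm_config()
    print(vllm_config.model_config.max_model_len)

asyncio.run(main())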
+ @abstractmethod async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index c81ff958531b..1c027181156f 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -111,7 +111,7 @@ async def init_app( engine = (llm_engine if llm_engine is not None else AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.API_SERVER)) - + app.state.engine_client = engine return app diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 6fb7dc2c9763..fcaa24eec8c8 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -27,10 +27,11 @@ ChatCompletionToolMessageParam) from openai.types.chat.chat_completion_content_part_input_audio_param import ( InputAudio) +from pydantic import TypeAdapter # yapf: enable -# pydantic needs the TypedDict from typing_extensions from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin) +# pydantic needs the TypedDict from typing_extensions from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig @@ -482,11 +483,8 @@ def _placeholder_str(self, modality: ModalityStr, if modality in ("image", "image_embeds"): if model_type == "chatglm": return "<|begin_of_image|><|endoftext|><|end_of_image|>" - if model_type == "phi3_v": - # Workaround since this token is not defined in the tokenizer + if model_type in ("phi3_v", "phi4mm"): return f"<|image_{current_count}|>" - if model_type == "phi4mm": - return "<|endoftext10|>" # 200010 (see vocab.json in hf model) if model_type in ("minicpmo", "minicpmv"): return "(./)" if model_type in ("blip-2", "florence2", "fuyu", "paligemma", @@ -506,20 +504,24 @@ def _placeholder_str(self, modality: ModalityStr, return "<|image|>" if model_type in ("qwen2_vl", "qwen2_5_vl"): return "<|vision_start|><|image_pad|><|vision_end|>" + if model_type == "qwen2_5_omni": + return "<|vision_start|><|IMAGE|><|vision_end|>" if model_type == "molmo": return "" if model_type == "aria": return "<|fim_prefix|><|img|><|fim_suffix|>" if model_type == "gemma3": return "" + if model_type == "kimi_vl": + return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>" # noqa: E501 raise TypeError(f"Unknown {modality} model type: {model_type}") elif modality == "audio": - if model_type == "ultravox": + if model_type in ("ultravox", "granite_speech"): return "<|audio|>" if model_type == "phi4mm": - return "<|endoftext11|>" # 200011 (see vocab.json in hf model) - if model_type == "qwen2_audio": + return f"<|audio_{current_count}|>" + if model_type in ("qwen2_audio", "qwen2_5_omni"): return (f"Audio {current_count}: " f"<|audio_bos|><|AUDIO|><|audio_eos|>") if model_type == "minicpmo": @@ -528,6 +530,8 @@ def _placeholder_str(self, modality: ModalityStr, elif modality == "video": if model_type in ("qwen2_vl", "qwen2_5_vl"): return "<|vision_start|><|video_pad|><|vision_end|>" + if model_type == "qwen2_5_omni": + return "<|vision_start|><|VIDEO|><|vision_end|>" if model_type in ("minicpmo", "minicpmv"): return "()" if model_type.startswith("llava"): @@ -876,12 +880,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int], # No need to validate using Pydantic again _TextParser = partial(cast, ChatCompletionContentPartTextParam) -_ImageParser = partial(cast, ChatCompletionContentPartImageParam) _ImageEmbedsParser = partial(cast, 
ChatCompletionContentPartImageEmbedsParam) -_AudioParser = partial(cast, ChatCompletionContentPartAudioParam) _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) -_VideoParser = partial(cast, ChatCompletionContentPartVideoParam) +# Need to validate url objects +_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python +_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python +_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python _ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio] @@ -1092,7 +1097,11 @@ def _parse_chat_message_content( if role == 'assistant': parsed_msg = _AssistantParser(message) - if "tool_calls" in parsed_msg: + # The 'tool_calls' is not None check ensures compatibility. + # It's needed only if downstream code doesn't strictly + # follow the OpenAI spec. + if ("tool_calls" in parsed_msg + and parsed_msg["tool_calls"] is not None): result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) elif role == "tool": parsed_msg = _ToolParser(message) @@ -1189,14 +1198,25 @@ def apply_hf_chat_template( "allowed, so you must provide a chat template if the tokenizer " "does not define one.") - return tokenizer.apply_chat_template( - conversation=conversation, # type: ignore[arg-type] - tools=tools, # type: ignore[arg-type] - chat_template=hf_chat_template, - tokenize=tokenize, - **kwargs, - ) + try: + + return tokenizer.apply_chat_template( + conversation=conversation, # type: ignore[arg-type] + tools=tools, # type: ignore[arg-type] + chat_template=hf_chat_template, + tokenize=tokenize, + **kwargs, + ) + # External library exceptions can sometimes occur despite the framework's + # internal exception management capabilities. + except Exception as e: + + # Log and report any library-related exceptions for further + # investigation. + logger.exception( + "An error occurred in `transformers` while applying chat template") + raise ValueError from e def apply_mistral_chat_template( tokenizer: MistralTokenizer, @@ -1205,6 +1225,8 @@ def apply_mistral_chat_template( tools: Optional[list[dict[str, Any]]], **kwargs: Any, ) -> list[int]: + from mistral_common.exceptions import MistralCommonException + # The return value of resolve_mistral_chat_template is always None, # and we won't use it. resolve_mistral_chat_template( @@ -1222,5 +1244,16 @@ def apply_mistral_chat_template( # if input does not comply with the expected format. # We convert those assertion errors to ValueErrors so they can be # are properly caught in the preprocessing_input step - except AssertionError as e: + except (AssertionError, MistralCommonException) as e: + raise ValueError from e + + # External library exceptions can sometimes occur despite the framework's + # internal exception management capabilities. + except Exception as e: + + # Log and report any library-related exceptions for further + # investigation. 
+ logger.exception( + "An error occurred in `mistral_common` while applying chat " + "template") raise ValueError from e diff --git a/vllm/entrypoints/cli/benchmark/latency.py b/vllm/entrypoints/cli/benchmark/latency.py new file mode 100644 index 000000000000..5aca16e0b640 --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/latency.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +import argparse + +from vllm.benchmarks.latency import add_cli_args, main +from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase +from vllm.entrypoints.cli.types import CLISubcommand + + +class BenchmarkLatencySubcommand(BenchmarkSubcommandBase): + """ The `latency` subcommand for vllm bench. """ + + def __init__(self): + self.name = "latency" + super().__init__() + + @property + def help(self) -> str: + return "Benchmark the latency of a single batch of requests." + + def add_cli_args(self, parser: argparse.ArgumentParser) -> None: + add_cli_args(parser) + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + main(args) + + +def cmd_init() -> list[CLISubcommand]: + return [BenchmarkLatencySubcommand()] diff --git a/vllm/entrypoints/cli/benchmark/main.py b/vllm/entrypoints/cli/benchmark/main.py index 1bcb25be2fca..9e857af7d6db 100644 --- a/vllm/entrypoints/cli/benchmark/main.py +++ b/vllm/entrypoints/cli/benchmark/main.py @@ -1,14 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import vllm.entrypoints.cli.benchmark.latency import vllm.entrypoints.cli.benchmark.serve +import vllm.entrypoints.cli.benchmark.throughput from vllm.entrypoints.cli.types import CLISubcommand from vllm.utils import FlexibleArgumentParser -# TODO: Add the rest of the benchmark subcommands here, -# e.g., throughput, latency, etc. BENCHMARK_CMD_MODULES = [ + vllm.entrypoints.cli.benchmark.latency, vllm.entrypoints.cli.benchmark.serve, + vllm.entrypoints.cli.benchmark.throughput, ] diff --git a/vllm/entrypoints/cli/benchmark/throughput.py b/vllm/entrypoints/cli/benchmark/throughput.py new file mode 100644 index 000000000000..88ee6aa03857 --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/throughput.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +import argparse + +from vllm.benchmarks.throughput import add_cli_args, main +from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase +from vllm.entrypoints.cli.types import CLISubcommand + + +class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase): + """ The `throughput` subcommand for vllm bench. """ + + def __init__(self): + self.name = "throughput" + super().__init__() + + @property + def help(self) -> str: + return "Benchmark offline inference throughput." + + def add_cli_args(self, parser: argparse.ArgumentParser) -> None: + add_cli_args(parser) + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + main(args) + + +def cmd_init() -> list[CLISubcommand]: + return [BenchmarkThroughputSubcommand()] diff --git a/vllm/entrypoints/cli/collect_env.py b/vllm/entrypoints/cli/collect_env.py new file mode 100644 index 000000000000..d5f9f7e729f0 --- /dev/null +++ b/vllm/entrypoints/cli/collect_env.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse + +from vllm.collect_env import main as collect_env_main +from vllm.entrypoints.cli.types import CLISubcommand +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.utils import FlexibleArgumentParser + + +class CollectEnvSubcommand(CLISubcommand): + """The `collect-env` subcommand for the vLLM CLI.
""" + + def __init__(self): + self.name = "collect-env" + super().__init__() + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + """Collect information about the environment.""" + collect_env_main() + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + serve_parser = subparsers.add_parser( + "collect-env", + help="Start collecting environment information.", + description="Start collecting environment information.", + usage="vllm collect-env") + return make_arg_parser(serve_parser) + + +def cmd_init() -> list[CLISubcommand]: + return [CollectEnvSubcommand()] diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index aa54bd66bed6..b7c1afce7118 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -5,6 +5,7 @@ import sys import vllm.entrypoints.cli.benchmark.main +import vllm.entrypoints.cli.collect_env import vllm.entrypoints.cli.openai import vllm.entrypoints.cli.serve import vllm.version @@ -15,6 +16,7 @@ vllm.entrypoints.cli.openai, vllm.entrypoints.cli.serve, vllm.entrypoints.cli.benchmark.main, + vllm.entrypoints.cli.collect_env, ] diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index b09ee526f14a..a4f70a51ebaf 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -12,9 +12,11 @@ from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError from vllm.engine.multiprocessing import MQEngineDeadError +from vllm.engine.protocol import EngineClient from vllm.entrypoints.ssl import SSLCertRefresher from vllm.logger import init_logger from vllm.utils import find_process_using_port +from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError logger = init_logger(__name__) @@ -40,6 +42,8 @@ async def serve_http(app: FastAPI, loop = asyncio.get_running_loop() + watchdog_task = loop.create_task( + watchdog_loop(server, app.state.engine_client)) server_task = loop.create_task( server.serve(sockets=[sock] if sock else None)) @@ -52,6 +56,7 @@ def signal_handler() -> None: # prevents the uvicorn signal handler to exit early server_task.cancel() + watchdog_task.cancel() if ssl_cert_refresher: ssl_cert_refresher.stop() @@ -73,48 +78,69 @@ async def dummy_shutdown() -> None: port, process, " ".join(process.cmdline())) logger.info("Shutting down FastAPI HTTP server.") return server.shutdown() + finally: + watchdog_task.cancel() + + +async def watchdog_loop(server: uvicorn.Server, engine: EngineClient): + """ + # Watchdog task that runs in the background, checking + # for error state in the engine. Needed to trigger shutdown + # if an exception arises in the StreamingResponse() generator. + """ + VLLM_WATCHDOG_TIME_S = 5.0 + while True: + await asyncio.sleep(VLLM_WATCHDOG_TIME_S) + terminate_if_errored(server, engine) + + +def terminate_if_errored(server: uvicorn.Server, engine: EngineClient): + """ + See discussions here on shutting down a uvicorn server + https://github.com/encode/uvicorn/discussions/1103 + In this case we cannot await the server shutdown here + because handler must first return to close the connection + for this request.
+ """ + engine_errored = engine.errored and not engine.is_running + if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored: + server.should_exit = True def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: - """Adds handlers for fatal errors that should crash the server""" + """ + VLLM V1 AsyncLLM catches exceptions and returns + only two types: EngineGenerateError and EngineDeadError. + + EngineGenerateError is raised by the per request generate() + method. This error could be request specific (and therefore + recoverable - e.g. if there is an error in input processing). + + EngineDeadError is raised by the background output_handler + method. This error is global and therefore not recoverable. + + We register these @app.exception_handlers to return nice + responses to the end user if they occur and shut down if needed. + See https://fastapi.tiangolo.com/tutorial/handling-errors/ + for more details on how exception handlers work. + + If an exception is encountered in a StreamingResponse + generator, the exception is not raised, since we already sent + a 200 status. Rather, we send an error message as the next chunk. + Since the exception is not raised, this means that the server + will not automatically shut down. Instead, we use the watchdog + background task to check for errored state. + """ @app.exception_handler(RuntimeError) - async def runtime_error_handler(request: Request, __): - """On generic runtime error, check to see if the engine has died. - It probably has, in which case the server will no longer be able to - handle requests. Trigger a graceful shutdown with a SIGTERM.""" - engine = request.app.state.engine_client - if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored - and not engine.is_running): - logger.fatal("AsyncLLMEngine has failed, terminating server " - "process") - # See discussions here on shutting down a uvicorn server - # https://github.com/encode/uvicorn/discussions/1103 - # In this case we cannot await the server shutdown here because - # this handler must first return to close the connection for - # this request. - server.should_exit = True - - return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) - @app.exception_handler(AsyncEngineDeadError) - async def async_engine_dead_handler(_, __): - """Kill the server if the async engine is already dead. It will - not handle any further requests.""" - if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: - logger.fatal("AsyncLLMEngine is already dead, terminating server " - "process") - server.should_exit = True - - return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) - @app.exception_handler(MQEngineDeadError) - async def mq_engine_dead_handler(_, __): - """Kill the server if the mq engine is already dead.
It will - not handle any further requests.""" - if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: - logger.fatal("MQLLMEngine is already dead, terminating server " - "process") - server.should_exit = True + @app.exception_handler(EngineDeadError) + @app.exception_handler(EngineGenerateError) + async def runtime_exception_handler(request: Request, __): + terminate_if_errored( + server=server, + engine=request.app.state.engine_client, + ) return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a707087a2e28..653e61a11ebd 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -40,7 +40,6 @@ RequestOutputKind, SamplingParams) from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) -from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.usage.usage_lib import UsageContext from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs, is_list_of) @@ -118,7 +117,7 @@ class LLM: disable_async_output_proc: Disable async output processing. This may result in lower performance. hf_token: The token to use as HTTP bearer authorization for remote files - . If `True`, will use the token generated when running + . If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). hf_overrides: If a dictionary, contains arguments to be forwarded to the HuggingFace config. If a callable, it is called to update the @@ -252,11 +251,15 @@ def __init__( self.request_counter = Counter() self.default_sampling_params: Union[dict[str, Any], None] = None - def get_tokenizer(self) -> AnyTokenizer: - return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer + def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return self.llm_engine.get_tokenizer_group().get_lora_tokenizer( + lora_request) def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: - tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup) + tokenizer_group = self.llm_engine.get_tokenizer_group() # While CachedTokenizer is dynamic, have no choice but # compare class name. Misjudgment will arise from @@ -520,11 +523,9 @@ def beam_search( prompts: A list of prompts. Each prompt can be a string or a list of token IDs. params: The beam search parameters. - - TODO: how does beam search work together with length penalty, frequency - penalty, and stopping criteria, etc.? """ - + # TODO: how does beam search work together with length penalty, + # frequency, penalty, and stopping criteria, etc.? beam_width = params.beam_width max_tokens = params.max_tokens temperature = params.temperature @@ -536,15 +537,18 @@ def sort_beams_key(x: BeamSearchSequence) -> float: tokenizer.eos_token_id, length_penalty) - # TODO - fix handling of multimodal data for beam search; we pass it - # through in the async version on the abstract EngineClient, but not - # here. 
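Editor's note: this hunk removes the warning that multimodal data is dropped by LLM.beam_search and, just below, adds create_tokens_prompt_from_beam so the data rides along with each beam. A hypothetical usage sketch; the model name, chat format, and image file are illustrative assumptions, not taken from this diff:

from PIL import Image

from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # any multimodal model
prompt = {
    "prompt": "USER: <image>\nDescribe the picture briefly. ASSISTANT:",
    "multi_modal_data": {"image": Image.open("example.jpg")},
}
# The image is now carried through beam search instead of being dropped.
outputs = llm.beam_search([prompt], BeamSearchParams(beam_width=2, max_tokens=32))
print(outputs[0].sequences[0].text)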
- if any("multi_modal_data" in prompt - and prompt["multi_modal_data"] is not None - for prompt in prompts): - logger.warning( - "Multimodal data appears to have been provided, but is not" - " currently being passed through in LLM.beam_search()!") + def create_tokens_prompt_from_beam( + beam: BeamSearchSequence) -> TokensPrompt: + token_prompt_kwargs: TokensPrompt = { + "prompt_token_ids": beam.tokens + } + if beam.multi_modal_data is not None: + token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data + + if beam.mm_processor_kwargs is not None: + token_prompt_kwargs[ + "mm_processor_kwargs"] = beam.mm_processor_kwargs + return TokensPrompt(**token_prompt_kwargs) tokenizer = self.get_tokenizer() # generate 2 * beam_width candidates at each step @@ -556,11 +560,20 @@ def sort_beams_key(x: BeamSearchSequence) -> float: instances: list[BeamSearchInstance] = [] for prompt in prompts: + # Add multimodal processor kwargs & data + mm_kwargs = {} + if "multi_modal_data" in prompt: + mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"] + if "mm_processor_kwargs" in prompt: + mm_kwargs["mm_processor_kwargs"] = prompt[ + "mm_processor_kwargs"] + if is_token_prompt(prompt): prompt_tokens = prompt["prompt_token_ids"] else: prompt_tokens = tokenizer.encode(prompt["prompt"]) - instances.append(BeamSearchInstance(prompt_tokens)) + instances.append( + BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs)) for _ in range(max_tokens): all_beams: list[BeamSearchSequence] = list( @@ -575,8 +588,7 @@ def sort_beams_key(x: BeamSearchSequence) -> float: break prompts_batch = [ - TokensPrompt(prompt_token_ids=beam.tokens) - for beam in all_beams + create_tokens_prompt_from_beam(beam) for beam in all_beams ] # only runs for one step @@ -602,7 +614,10 @@ def sort_beams_key(x: BeamSearchSequence) -> float: tokens=current_beam.tokens + [token_id], logprobs=current_beam.logprobs + [logprobs], cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob) + logprob_obj.logprob, + multi_modal_data=current_beam.multi_modal_data, + mm_processor_kwargs=current_beam. + mm_processor_kwargs) if token_id == tokenizer.eos_token_id and \ not ignore_eos: @@ -701,7 +716,7 @@ def chat( cast(list[ChatCompletionMessageParam], messages) ] - tokenizer = self.get_tokenizer() + tokenizer = self.get_tokenizer(lora_request) model_config = self.llm_engine.get_model_config() resolved_content_format = resolve_chat_template_content_format( chat_template, @@ -724,9 +739,8 @@ def chat( content_format=resolved_content_format, ) - prompt_data: Union[str, list[int]] if isinstance(tokenizer, MistralTokenizer): - prompt_data = apply_mistral_chat_template( + prompt_token_ids = apply_mistral_chat_template( tokenizer, messages=msgs, chat_template=chat_template, @@ -735,7 +749,7 @@ def chat( continue_final_message=continue_final_message, ) else: - prompt_data = apply_hf_chat_template( + prompt_str = apply_hf_chat_template( tokenizer, trust_remote_code=model_config.trust_remote_code, conversation=conversation, @@ -744,12 +758,12 @@ def chat( add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message, ) + # Special tokens are already included in chat templates so + # should not be added by the tokenizer in this case. 
+ prompt_token_ids = tokenizer.encode(prompt_str, + add_special_tokens=False) - prompt: Union[TokensPrompt, TextPrompt] - if is_list_of(prompt_data, int): - prompt = TokensPrompt(prompt_token_ids=prompt_data) - else: - prompt = TextPrompt(prompt=prompt_data) + prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) if mm_data is not None: prompt["multi_modal_data"] = mm_data @@ -1048,8 +1062,6 @@ def _embedding_score( if len(encoded_output_1) == 1: encoded_output_1 = encoded_output_1 * len(encoded_output_2) - scores: list[PoolingRequestOutput] = [] - scores = _cosine_similarity(tokenizer=tokenizer, embed_1=encoded_output_1, embed_2=encoded_output_2) @@ -1384,7 +1396,9 @@ def _add_guided_params( grammar=guided_options.guided_grammar, json_object=guided_options.guided_json_object, backend=guided_options.guided_decoding_backend, - whitespace_pattern=guided_options.guided_whitespace_pattern) + whitespace_pattern=guided_options.guided_whitespace_pattern, + structural_tag=guided_options.structural_tag, + ) return params def _run_engine( diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 6a8bdd060228..136819580897 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -30,7 +30,7 @@ from typing_extensions import assert_never import vllm.envs as envs -from vllm.config import ModelConfig +from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore from vllm.engine.multiprocessing.client import MQLLMEngineClient @@ -310,32 +310,33 @@ def mount_metrics(app: FastAPI): # We need to set PROMETHEUS_MULTIPROC_DIR environment variable # before prometheus_client is imported. # See https://prometheus.github.io/client_python/multiprocess/ - from prometheus_client import (CollectorRegistry, make_asgi_app, + from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app, multiprocess) from prometheus_fastapi_instrumentator import Instrumentator + registry = REGISTRY + prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) if prometheus_multiproc_dir_path is not None: logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", prometheus_multiproc_dir_path) registry = CollectorRegistry() multiprocess.MultiProcessCollector(registry) - Instrumentator( - excluded_handlers=[ - "/metrics", - "/health", - "/load", - "/ping", - "/version", - ], - registry=registry, - ).add().instrument(app).expose(app) - - # Add prometheus asgi middleware to route /metrics requests - metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) - else: - # Add prometheus asgi middleware to route /metrics requests - metrics_route = Mount("/metrics", make_asgi_app()) + + Instrumentator( + excluded_handlers=[ + "/metrics", + "/health", + "/load", + "/ping", + "/version", + "/server_info", + ], + registry=registry, + ).add().instrument(app).expose(app) + + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) # Workaround for 307 Redirect for /metrics metrics_route.path_regex = re.compile("^/metrics(?P.*)$") @@ -687,6 +688,11 @@ async def do_rerank_v2(request: RerankRequest, raw_request: Request): if envs.VLLM_SERVER_DEV_MODE: + @router.get("/server_info") + async def show_server_info(raw_request: Request): + server_info = {"vllm_config": str(raw_request.app.state.vllm_config)} + return JSONResponse(content=server_info) + 
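Editor's note: the new dev-mode route simply returns the stringified VllmConfig stored on the app state. A client-side sketch, assuming the server was started with VLLM_SERVER_DEV_MODE=1 and listens on the default localhost:8000 (the route is not registered otherwise):

import json
import urllib.request

with urllib.request.urlopen("http://localhost:8000/server_info") as resp:
    info = json.load(resp)

print(info["vllm_config"])  # stringified VllmConfig of the running server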
@router.post("/reset_prefix_cache") async def reset_prefix_cache(raw_request: Request): """ @@ -875,7 +881,8 @@ async def log_response(request: Request, call_next): section async for section in response.body_iterator ] response.body_iterator = iterate_in_threadpool(iter(response_body)) - logger.info("response_body={%s}", response_body[0].decode()) + logger.info("response_body={%s}", + response_body[0].decode() if response_body else None) return response for middleware in args.middleware: @@ -894,7 +901,7 @@ async def log_response(request: Request, call_next): async def init_app_state( engine_client: EngineClient, - model_config: ModelConfig, + vllm_config: VllmConfig, state: State, args: Namespace, ) -> None: @@ -915,6 +922,8 @@ async def init_app_state( state.engine_client = engine_client state.log_stats = not args.disable_log_stats + state.vllm_config = vllm_config + model_config = vllm_config.model_config resolved_chat_template = load_chat_template(args.chat_template) if resolved_chat_template is not None: @@ -1069,8 +1078,8 @@ def signal_handler(*_) -> None: async with build_async_engine_client(args) as engine_client: app = build_app(args) - model_config = await engine_client.get_model_config() - await init_app_state(engine_client, model_config, app.state, args) + vllm_config = await engine_client.get_vllm_config() + await init_app_state(engine_client, vllm_config, app.state, args) def _listen_addr(a: str) -> str: if is_valid_ipv6_address(a): diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 218a8fbe10b7..b3824013f055 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -11,7 +11,7 @@ from collections.abc import Sequence from typing import Optional, Union, get_args -from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, validate_chat_template) from vllm.entrypoints.openai.serving_models import (LoRAModulePath, @@ -79,7 +79,7 @@ def __call__( def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument("--host", - type=nullable_str, + type=optional_type(str), default=None, help="Host name.") parser.add_argument("--port", type=int, default=8000, help="Port number.") @@ -108,13 +108,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=["*"], help="Allowed headers.") parser.add_argument("--api-key", - type=nullable_str, + type=optional_type(str), default=None, help="If provided, the server will require this key " "to be presented in the header.") parser.add_argument( "--lora-modules", - type=nullable_str, + type=optional_type(str), default=None, nargs='+', action=LoRAParserAction, @@ -126,14 +126,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "\"base_model_name\": \"id\"}``") parser.add_argument( "--prompt-adapters", - type=nullable_str, + type=optional_type(str), default=None, nargs='+', action=PromptAdapterParserAction, help="Prompt adapter configurations in the format name=path. " "Multiple adapters can be specified.") parser.add_argument("--chat-template", - type=nullable_str, + type=optional_type(str), default=None, help="The file path to the chat template, " "or the template in single-line form " @@ -151,20 +151,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'similar to OpenAI schema. 
' 'Example: ``[{"type": "text", "text": "Hello world!"}]``') parser.add_argument("--response-role", - type=nullable_str, + type=optional_type(str), default="assistant", help="The role name to return if " "``request.add_generation_prompt=true``.") parser.add_argument("--ssl-keyfile", - type=nullable_str, + type=optional_type(str), default=None, help="The file path to the SSL key file.") parser.add_argument("--ssl-certfile", - type=nullable_str, + type=optional_type(str), default=None, help="The file path to the SSL cert file.") parser.add_argument("--ssl-ca-certs", - type=nullable_str, + type=optional_type(str), default=None, help="The CA certificates file.") parser.add_argument( @@ -180,13 +180,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) parser.add_argument( "--root-path", - type=nullable_str, + type=optional_type(str), default=None, help="FastAPI root_path when app is behind a path based routing proxy." ) parser.add_argument( "--middleware", - type=nullable_str, + type=optional_type(str), action="append", default=[], help="Additional ASGI middleware to apply to the app. " diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 4639b4cea06b..015943762ab1 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -2,6 +2,7 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py +import json import re import time from argparse import Namespace @@ -139,12 +140,30 @@ class JsonSchemaResponseFormat(OpenAIBaseModel): strict: Optional[bool] = None +class StructuralTag(OpenAIBaseModel): + begin: str + # schema is the field, but that causes conflicts with pydantic so + # instead use structural_tag_schema with an alias + structural_tag_schema: Optional[dict[str, Any]] = Field(default=None, + alias="schema") + end: str + + +class StructuralTagResponseFormat(OpenAIBaseModel): + type: Literal["structural_tag"] + structures: list[StructuralTag] + triggers: list[str] + + class ResponseFormat(OpenAIBaseModel): - # type must be "json_schema", "json_object" or "text" + # type must be "json_schema", "json_object", or "text" type: Literal["text", "json_object", "json_schema"] json_schema: Optional[JsonSchemaResponseFormat] = None +AnyResponseFormat = Union[ResponseFormat, StructuralTagResponseFormat] + + class StreamOptions(OpenAIBaseModel): include_usage: Optional[bool] = True continuous_usage_stats: Optional[bool] = False @@ -227,7 +246,7 @@ class ChatCompletionRequest(OpenAIBaseModel): max_completion_tokens: Optional[int] = None n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 - response_format: Optional[ResponseFormat] = None + response_format: Optional[AnyResponseFormat] = None seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) stop: Optional[Union[str, list[str]]] = Field(default_factory=list) stream: Optional[bool] = False @@ -340,6 +359,11 @@ class ChatCompletionRequest(OpenAIBaseModel): description=( "If specified, the output will follow the context free grammar."), ) + structural_tag: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the structural tag schema."), + ) guided_decoding_backend: Optional[str] = Field( default=None, description=( @@ -476,6 +500,12 @@ def to_sampling_params( json_schema = self.response_format.json_schema assert json_schema is not None self.guided_json = json_schema.json_schema + elif 
self.response_format.type == "structural_tag": + structural_tag = self.response_format + assert structural_tag is not None and isinstance( + structural_tag, StructuralTagResponseFormat) + s_tag_obj = structural_tag.model_dump(by_alias=True) + self.structural_tag = json.dumps(s_tag_obj) guided_decoding = GuidedDecodingParams.from_optional( json=self._get_guided_json_from_tool() or self.guided_json, @@ -485,6 +515,7 @@ def to_sampling_params( json_object=guided_json_object, backend=self.guided_decoding_backend, whitespace_pattern=self.guided_whitespace_pattern, + structural_tag=self.structural_tag, ) return SamplingParams.from_optional( @@ -742,12 +773,13 @@ class CompletionRequest(OpenAIBaseModel): "If true (the default), special tokens (e.g. BOS) will be added to " "the prompt."), ) - response_format: Optional[ResponseFormat] = Field( + response_format: Optional[AnyResponseFormat] = Field( default=None, - description= - ("Similar to chat completion, this parameter specifies the format of " - "output. Only {'type': 'json_object'}, {'type': 'json_schema'} or " - "{'type': 'text' } is supported."), + description=( + "Similar to chat completion, this parameter specifies the format " + "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}" + ", {'type': 'structural_tag'}, or {'type': 'text' } is supported." + ), ) guided_json: Optional[Union[str, dict, BaseModel]] = Field( default=None, @@ -1577,14 +1609,6 @@ class TranscriptionRequest(OpenAIBaseModel): """ ## TODO (varun) : Support if set to 0, certain thresholds are met !! - temperature: float = Field(default=0.0) - """The sampling temperature, between 0 and 1. - - Higher values like 0.8 will make the output more random, while lower values - like 0.2 will make it more focused / deterministic. If set to 0, the model - will use [log probability](https://en.wikipedia.org/wiki/Log_probability) - to automatically increase the temperature until certain thresholds are hit. - """ timestamp_granularities: list[Literal["word", "segment"]] = Field( alias="timestamp_granularities[]", default=[]) @@ -1596,6 +1620,7 @@ class TranscriptionRequest(OpenAIBaseModel): timestamps incurs additional latency. """ + # doc: begin-transcription-extra-params stream: Optional[bool] = False """Custom field not present in the original OpenAI definition. When set, it will enable output to be streamed in a similar fashion as the Chat @@ -1604,10 +1629,51 @@ class TranscriptionRequest(OpenAIBaseModel): # Flattened stream option to simplify form data. stream_include_usage: Optional[bool] = False stream_continuous_usage_stats: Optional[bool] = False + # doc: end-transcription-extra-params + + # doc: begin-transcription-sampling-params + temperature: float = Field(default=0.0) + """The sampling temperature, between 0 and 1. + + Higher values like 0.8 will make the output more random, while lower values + like 0.2 will make it more focused / deterministic. If set to 0, the model + will use [log probability](https://en.wikipedia.org/wiki/Log_probability) + to automatically increase the temperature until certain thresholds are hit. + """ + + top_p: Optional[float] = None + """Enables nucleus (top-p) sampling, where tokens are selected from the + smallest possible set whose cumulative probability exceeds `p`. + """ + + top_k: Optional[int] = None + """Limits sampling to the `k` most probable tokens at each step.""" + + min_p: Optional[float] = None + """Filters out tokens with a probability lower than `min_p`, ensuring a + minimum likelihood threshold during sampling. 
+ """ + + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + """The seed to use for sampling.""" + + frequency_penalty: Optional[float] = 0.0 + """The frequency penalty to use for sampling.""" + + repetition_penalty: Optional[float] = None + """The repetition penalty to use for sampling.""" + + presence_penalty: Optional[float] = 0.0 + """The presence penalty to use for sampling.""" + # doc: end-transcription-sampling-params # Default sampling parameters for transcription requests. _DEFAULT_SAMPLING_PARAMS: dict = { - "temperature": 0, + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "min_p": 0.0, } def to_sampling_params( @@ -1619,13 +1685,35 @@ def to_sampling_params( if default_sampling_params is None: default_sampling_params = {} + # Default parameters if (temperature := self.temperature) is None: temperature = default_sampling_params.get( "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + if (min_p := self.min_p) is None: + min_p = default_sampling_params.get( + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + + if (repetition_penalty := self.repetition_penalty) is None: + repetition_penalty = default_sampling_params.get( + "repetition_penalty", + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"]) return SamplingParams.from_optional(temperature=temperature, max_tokens=max_tokens, + seed=self.seed, + top_p=top_p, + top_k=top_k, + min_p=min_p, + frequency_penalty=self.frequency_penalty, + repetition_penalty=repetition_penalty, + presence_penalty=self.presence_penalty, output_kind=RequestOutputKind.DELTA if self.stream \ else RequestOutputKind.FINAL_ONLY) diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 0d06ba3df23f..fccf459f17dc 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -12,7 +12,7 @@ from prometheus_client import start_http_server from tqdm import tqdm -from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.logger import RequestLogger, logger # yapf: disable @@ -61,7 +61,7 @@ def parse_args(): "to the output URL.", ) parser.add_argument("--response-role", - type=nullable_str, + type=optional_type(str), default="assistant", help="The role name to return if " "`request.add_generation_prompt=True`.") diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index bbc8eddd8b1b..49b346a23baf 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -10,6 +10,7 @@ from pydantic import Field from starlette.datastructures import Headers +import vllm.envs as envs from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient # yapf conflicts with isort for this block @@ -125,18 +126,29 @@ async def _check_model( self, request: AnyRequest, ) -> Optional[ErrorResponse]: + + error_response = None + if self._is_model_supported(request.model): return None if request.model in [ lora.lora_name for lora in self.models.lora_requests ]: return None + if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and ( + load_result := 
await self.models.resolve_lora(request.model)): + if isinstance(load_result, LoRARequest): + return None + if isinstance(load_result, ErrorResponse) and \ + load_result.code == HTTPStatus.BAD_REQUEST.value: + error_response = load_result if request.model in [ prompt_adapter.prompt_adapter_name for prompt_adapter in self.models.prompt_adapter_requests ]: return None - return self.create_error_response( + + return error_response or self.create_error_response( message=f"The model `{request.model}` does not exist.", err_type="NotFoundError", status_code=HTTPStatus.NOT_FOUND) diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py index 7a68452efc65..74433a1a3c3f 100644 --- a/vllm/entrypoints/openai/serving_models.py +++ b/vllm/entrypoints/openai/serving_models.py @@ -2,6 +2,8 @@ import json import pathlib +from asyncio import Lock +from collections import defaultdict from dataclasses import dataclass from http import HTTPStatus from typing import Optional, Union @@ -15,6 +17,7 @@ UnloadLoRAAdapterRequest) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.utils import AtomicCounter @@ -63,11 +66,19 @@ def __init__( self.base_model_paths = base_model_paths self.max_model_len = model_config.max_model_len self.engine_client = engine_client + self.model_config = model_config self.static_lora_modules = lora_modules self.lora_requests: list[LoRARequest] = [] self.lora_id_counter = AtomicCounter(0) + self.lora_resolvers: list[LoRAResolver] = [] + for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers( + ): + self.lora_resolvers.append( + LoRAResolverRegistry.get_resolver(lora_resolver_name)) + self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock) + self.prompt_adapter_requests = [] if prompt_adapters is not None: for i, prompt_adapter in enumerate(prompt_adapters, start=1): @@ -234,6 +245,65 @@ async def _check_unload_lora_adapter_request( return None + async def resolve_lora( + self, lora_name: str) -> Union[LoRARequest, ErrorResponse]: + """Attempt to resolve a LoRA adapter using available resolvers. + + Args: + lora_name: Name/identifier of the LoRA adapter + + Returns: + LoRARequest if found and loaded successfully. + ErrorResponse (404) if no resolver finds the adapter. + ErrorResponse (400) if adapter(s) are found but none load. + """ + async with self.lora_resolver_lock[lora_name]: + # First check if this LoRA is already loaded + for existing in self.lora_requests: + if existing.lora_name == lora_name: + return existing + + base_model_name = self.model_config.model + unique_id = self.lora_id_counter.inc(1) + found_adapter = False + + # Try to resolve using available resolvers + for resolver in self.lora_resolvers: + lora_request = await resolver.resolve_lora( + base_model_name, lora_name) + + if lora_request is not None: + found_adapter = True + lora_request.lora_int_id = unique_id + + try: + await self.engine_client.add_lora(lora_request) + self.lora_requests.append(lora_request) + logger.info( + "Resolved and loaded LoRA adapter '%s' using %s", + lora_name, resolver.__class__.__name__) + return lora_request + except BaseException as e: + logger.warning( + "Failed to load LoRA '%s' resolved by %s: %s. 
" + "Trying next resolver.", lora_name, + resolver.__class__.__name__, e) + continue + + if found_adapter: + # An adapter was found, but all attempts to load it failed. + return create_error_response( + message=(f"LoRA adapter '{lora_name}' was found " + "but could not be loaded."), + err_type="BadRequestError", + status_code=HTTPStatus.BAD_REQUEST) + else: + # No adapter was found + return create_error_response( + message=f"LoRA adapter {lora_name} does not exist", + err_type="NotFoundError", + status_code=HTTPStatus.NOT_FOUND) + def create_error_response( message: str, diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 20c3238fb3df..5c181616aa01 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -27,6 +27,7 @@ @ToolParserManager.register_module("llama3_json") +@ToolParserManager.register_module("llama4_json") class Llama3JsonToolParser(ToolParser): """ Tool call parser for Llama 3.1 models intended for use with the diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 0661445639d7..9dbfe85ecc68 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -38,6 +38,10 @@ def generate_random_id(): # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 return "".join(choices(ALPHANUMERIC, k=9)) + @staticmethod + def is_valid_id(id: str) -> bool: + return id.isalnum() and len(id) == 9 + @ToolParserManager.register_module("mistral") class MistralToolParser(ToolParser): @@ -70,6 +74,19 @@ def __init__(self, tokenizer: AnyTokenizer): "Mistral Tool Parser could not locate the tool call token in " "the tokenizer!") + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if not isinstance( + self.model_tokenizer, MistralTokenizer + ) and request.tools and request.tool_choice != 'none': + # Do not skip special tokens when using chat template + # with Mistral parser as TOOL_CALL token is needed + # for tool detection. + # Note: we don't want skip_special_tokens=False + # with MistralTokenizer as it is incompatible + request.skip_special_tokens = False + return request + def extract_tool_calls( self, model_output: str, diff --git a/vllm/env_override.py b/vllm/env_override.py index 0fa5b70c2ef9..71f031d1e231 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -8,8 +8,21 @@ # that interact with vllm workers. # they are executed whenever `import vllm` is called. -# see https://github.com/NVIDIA/nccl/issues/1234 -os.environ['NCCL_CUMEM_ENABLE'] = '0' +if not os.path.exists('/dev/nvidia-caps-imex-channels'): + # normally, we disable NCCL_CUMEM_ENABLE because it + # will cost 1~2 GiB GPU memory with cudagraph+allreduce, + # see https://github.com/NVIDIA/nccl/issues/1234 + # for more details. + # However, NCCL requires NCCL_CUMEM_ENABLE to work with + # multi-node NVLink, typically on GB200-NVL72 systems. + # The ultimate way to detect multi-node NVLink is to use + # NVML APIs, which are too expensive to call here. + # As an approximation, we check the existence of + # /dev/nvidia-caps-imex-channels, used by + # multi-node NVLink to communicate across nodes. 
+ # This will still cost some GPU memory, but it is worthwhile + # because we can get very fast cross-node bandwidth with NVLink. + os.environ['NCCL_CUMEM_ENABLE'] = '0' # see https://github.com/vllm-project/vllm/pull/15951 # it avoids unintentional cuda initialization from torch.cuda.is_available() diff --git a/vllm/envs.py b/vllm/envs.py index f80bf878f79c..ea40bfff11b5 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -75,10 +75,12 @@ VLLM_DISABLED_KERNELS: list[str] = [] VLLM_USE_V1: bool = True VLLM_ROCM_USE_AITER: bool = False + VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False VLLM_ROCM_USE_AITER_LINEAR: bool = True VLLM_ROCM_USE_AITER_MOE: bool = True - VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False VLLM_ROCM_USE_AITER_RMSNORM: bool = True + VLLM_ROCM_USE_AITER_MLA: bool = True + VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True @@ -96,6 +98,7 @@ VLLM_RAY_BUNDLE_INDICES: str = "" VLLM_CUDART_SO_PATH: Optional[str] = None VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True + VLLM_HPU_USE_DELAYED_SAMPLING: bool = False VLLM_DP_RANK: int = 0 VLLM_DP_RANK_LOCAL: int = -1 VLLM_DP_SIZE: int = 1 @@ -103,10 +106,10 @@ VLLM_DP_MASTER_PORT: int = 0 VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_V0_USE_OUTLINES_CACHE: bool = False - VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_USE_DEEP_GEMM: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 + VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 def get_default_cache_root(): @@ -533,6 +536,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1")), + # Whether to use aiter paged attention. + # By default is disabled. + "VLLM_ROCM_USE_AITER_PAGED_ATTN": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in + ("true", "1")), + # use aiter linear op if aiter ops are enabled # The following list of related ops # - scaled_mm (per-tensor / rowwise) @@ -546,18 +555,21 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in ("true", "1")), - # Whether to use aiter block scaled moe kernel. - # By default this is disabled. - "VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE": - lambda: - (os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in - ("true", "1")), - # use aiter rms norm op if aiter ops are enabled. "VLLM_ROCM_USE_AITER_RMSNORM": lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1")), + # Whether to use aiter mla ops. + # By default is enabled. + "VLLM_ROCM_USE_AITER_MLA": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in + ("true", "1")), + # use rocm skinny gemms + "VLLM_ROCM_USE_SKINNY_GEMM": + lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in + ("true", "1")), + # Pad the fp8 weights to 256 bytes for ROCm "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), @@ -639,6 +651,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in ("1", "true"), + # Use delayed sampling for HPU to reduce host cpu overhead + # between each step. 
+ "VLLM_HPU_USE_DELAYED_SAMPLING": + lambda: os.environ.get("VLLM_DELAYED_SAMPLING", "false").lower() in + ("1", "true"), + # Rank of the process in the data parallel setting "VLLM_DP_RANK": lambda: int(os.getenv("VLLM_DP_RANK", "0")), @@ -684,11 +702,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_V0_USE_OUTLINES_CACHE": lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1", - # If set, disables TPU-specific optimization for top-k & top-p sampling - "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION": - lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"])) - if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None, - # Gap between padding buckets for the forward pass. So we have # 8, we will run forward pass with [16, 24, 32, ...]. "VLLM_TPU_BUCKET_PADDING_GAP": @@ -704,6 +717,16 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # It can be changed with this variable if needed for some reason. "VLLM_XGRAMMAR_CACHE_MB": lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")), + + # Control the threshold for msgspec to use 'zero copy' for + # serialization/deserialization of tensors. Tensors below + # this limit will be encoded into the msgpack buffer, and + # tensors above will instead be sent via a separate message. + # While the sending side still actually copies the tensor + # in all cases, on the receiving side, tensors above this + # limit will actually be zero-copy decoded. + "VLLM_MSGPACK_ZERO_COPY_THRESHOLD": + lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")), } # end-env-vars-definition @@ -742,7 +765,7 @@ def compute_hash() -> str: variables, ensure that it is included in the factors list if it affects the computation graph. For example, different values of VLLM_PP_LAYER_PARTITION will generate different computation - graphs, so it is included in the factors list. The env vars that + graphs, so it is included in the factors list. The env vars that affect the choice of different kernels or attention backends should also be included in the factors list. 
""" @@ -771,6 +794,7 @@ def factorize(name: str): if key in environment_variables: factorize(key) - hash_str = hashlib.md5(str(factors).encode()).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() return hash_str diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 8c004c790fcb..2e4b47c1e24a 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -34,13 +34,13 @@ def _init_executor(self) -> None: if len(device_info) > 1: local_rank = int(device_info[1]) rank = 0 + is_driver_worker = True kwargs = dict( vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, - is_driver_worker=(not self.parallel_config) - or (rank % self.parallel_config.tensor_parallel_size == 0), + is_driver_worker=is_driver_worker, ) self.collective_rpc("init_worker", args=([kwargs], )) self.collective_rpc("init_device") diff --git a/vllm/forward_context.py b/vllm/forward_context.py index e195a03c5cac..06790d8ee2f8 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -11,6 +11,10 @@ import vllm.envs as envs from vllm.config import VllmConfig +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.logger import init_logger if TYPE_CHECKING: @@ -98,6 +102,17 @@ def set_forward_context(attn_metadata: Any, virtual_engine=virtual_engine, attn_metadata=attn_metadata, dp_metadata=dp_metadata) + + # KVConnector: trigger (possibly async) load before forward. + # Each attn layer will block until the reading is complete. + trigger_kv_transfer = (attn_metadata is not None + and has_kv_transfer_group() + and is_v1_kv_transfer_group()) + if trigger_kv_transfer: + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + kv_connector.start_load_kv(_forward_context) + try: yield finally: @@ -133,4 +148,12 @@ def set_forward_context(attn_metadata: Any, logger.info(("Batchsize forward time stats " "(batchsize, count, median_time(ms)): %s"), forward_stats) + + # KVConnector: each attn layer triggers (possibly async) save. + # Ensure all those operations complete before forward() is done. 
+ if trigger_kv_transfer: + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + kv_connector.wait_for_save() + _forward_context = prev_context diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 6f8f2cd758f7..ca706e202836 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -2,10 +2,9 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, - SingletonInputs, SingletonInputsAdapter, SingletonPrompt, - TextPrompt, TokenInputs, TokensPrompt, - build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, - token_inputs, zip_enc_dec_prompts) + SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, + TokensPrompt, build_explicit_enc_dec_prompt, + to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) from .registry import (DummyData, InputContext, InputProcessingContext, InputRegistry) @@ -27,7 +26,6 @@ "EncoderDecoderInputs", "ProcessorInputs", "SingletonInputs", - "SingletonInputsAdapter", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 138a8f61107b..970b36bca9be 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,17 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 - from collections.abc import Iterable -from dataclasses import dataclass -from functools import cached_property from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast -import torch -from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never +from typing_extensions import NotRequired, TypedDict, TypeVar if TYPE_CHECKING: - from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs, - MultiModalPlaceholderDict) - from vllm.multimodal.inputs import MultiModalInputs + from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs class TextPrompt(TypedDict): @@ -147,46 +141,11 @@ class TokenInputs(TypedDict): The original prompt text corresponding to the token IDs, if available. """ - multi_modal_data: NotRequired["MultiModalDataDict"] - """ - Optional multi-modal data to pass to the model, - if the model supports it. - """ - - multi_modal_inputs: NotRequired["MultiModalKwargs"] - """ - Optional multi-modal inputs to pass to the model, - if the model supports it. - """ - - multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"] - """ - Placeholder ranges for the multi-modal data. - """ - - multi_modal_hashes: NotRequired[list[str]] - """ - The hashes of the multi-modal data. - """ - - mm_processor_kwargs: NotRequired[dict[str, Any]] - """ - Optional multi-modal processor kwargs to be forwarded to the - multimodal input mapper & processor. Note that if multiple modalities - have registered mappers etc for the model being considered, we attempt - to pass the mm_processor_kwargs to each of them. 
- """ - def token_inputs( prompt_token_ids: list[int], token_type_ids: Optional[list[int]] = None, prompt: Optional[str] = None, - multi_modal_data: Optional["MultiModalDataDict"] = None, - multi_modal_inputs: Optional["MultiModalKwargs"] = None, - multi_modal_hashes: Optional[list[str]] = None, - multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, - mm_processor_kwargs: Optional[dict[str, Any]] = None, ) -> TokenInputs: """Construct :class:`TokenInputs` from optional values.""" inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) @@ -195,16 +154,6 @@ def token_inputs( inputs["prompt"] = prompt if token_type_ids is not None: inputs["token_type_ids"] = token_type_ids - if multi_modal_data is not None: - inputs["multi_modal_data"] = multi_modal_data - if multi_modal_inputs is not None: - inputs["multi_modal_inputs"] = multi_modal_inputs - if multi_modal_hashes is not None: - inputs["multi_modal_hashes"] = multi_modal_hashes - if multi_modal_placeholders is not None: - inputs["multi_modal_placeholders"] = multi_modal_placeholders - if mm_processor_kwargs is not None: - inputs["mm_processor_kwargs"] = mm_processor_kwargs return inputs @@ -237,112 +186,6 @@ class EncoderDecoderInputs(TypedDict): :class:`vllm.sequence.Sequence`. """ - -@dataclass -class SingletonInputsAdapter: - """ - Unified interface to access the components of :class:`SingletonInputs`. - """ - inputs: SingletonInputs - - @cached_property - def prompt(self) -> Optional[str]: - inputs = self.inputs - - if inputs["type"] == "token" or inputs["type"] == "multimodal": - return inputs.get("prompt") - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def prompt_token_ids(self) -> list[int]: - inputs = self.inputs - - if inputs["type"] == "token" or inputs["type"] == "multimodal": - return inputs.get("prompt_token_ids", []) - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def token_type_ids(self) -> list[int]: - inputs = self.inputs - - if inputs["type"] == "token" or inputs["type"] == "multimodal": - return inputs.get("token_type_ids", []) - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def prompt_embeds(self) -> Optional[torch.Tensor]: - inputs = self.inputs - - if inputs["type"] == "token" or inputs["type"] == "multimodal": - return None - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def multi_modal_data(self) -> "MultiModalDataDict": - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("multi_modal_data", {}) - - if inputs["type"] == "multimodal": - return inputs.get("mm_kwargs", {}) - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def multi_modal_inputs(self) -> Union[dict, "MultiModalKwargs"]: - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("multi_modal_inputs", {}) - - if inputs["type"] == "multimodal": - return inputs.get("mm_kwargs", {}) - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def multi_modal_hashes(self) -> list[str]: - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("multi_modal_hashes", []) - - if inputs["type"] == "multimodal": - # only the case when we use MultiModalInputs - return inputs.get("mm_hashes", []) # type: ignore[return-value] - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": - inputs = self.inputs - - if inputs["type"] == "token": - return 
inputs.get("multi_modal_placeholders", {}) - - if inputs["type"] == "multimodal": - return inputs.get("mm_placeholders", {}) - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def mm_processor_kwargs(self) -> dict[str, Any]: - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("mm_processor_kwargs", {}) - - if inputs["type"] == "multimodal": - return {} - - assert_never(inputs) # type: ignore[arg-type] - - ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] """ The inputs to :data:`vllm.inputs.InputProcessor`. diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 669fb96e6653..0edb6da06209 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -13,7 +13,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalInputs) from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, PromptType, SingletonInputs, SingletonPrompt, token_inputs) @@ -27,7 +27,7 @@ class InputPreprocessor: def __init__( self, model_config: ModelConfig, - tokenizer: Optional[BaseTokenizerGroup], + tokenizer: Optional[TokenizerGroup], mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ) -> None: super().__init__() @@ -36,7 +36,7 @@ def __init__( self.tokenizer = tokenizer self.mm_registry = mm_registry - def get_tokenizer_group(self) -> BaseTokenizerGroup: + def get_tokenizer_group(self) -> TokenizerGroup: if self.tokenizer is None: raise ValueError("You cannot pass text prompts when " "`skip_tokenizer_init` is True") @@ -223,28 +223,6 @@ async def _tokenize_prompt_async( lora_request=lora_request, add_special_tokens=add_special_tokens) - def _can_process_multimodal(self) -> bool: - model_config = self.model_config - - if not model_config.is_multimodal_model: - raise ValueError("Your model does not support multi-modal inputs") - - # Interim measure so we can handle models that have yet to be - # updated to use the new multi-modal processor - can_process_multimodal = self.mm_registry.has_processor(model_config) - if not can_process_multimodal: - from vllm.model_executor.models.registry import _VLLM_MODELS - if not any(arch in _VLLM_MODELS - for arch in model_config.architectures): - logger.warning_once( - "Your model uses the legacy input pipeline, which will be " - "removed in an upcoming release. " - "Please upgrade to the new multi-modal processing pipeline " - "(https://docs.vllm.ai/en/latest/design/mm_processing.html)" - ) - - return can_process_multimodal - def _process_multimodal( self, prompt: Union[str, list[int]], @@ -258,8 +236,7 @@ def _process_multimodal( returning the corresponding token IDs and metadata. """ # At the moment on model (PrithviGeoSpatialMAE) requires to be - # initialized without a tokenizer while using also multi-modal - # input. + # initialized without a tokenizer while using also multi-modal input if not self.tokenizer: tokenizer = object() # Dummy else: @@ -285,8 +262,7 @@ async def _process_multimodal_async( ) -> MultiModalInputs: """Async version of :meth:`_process_multimodal`.""" # At the moment on model (PrithviGeoSpatialMAE) requires to be - # initialized without a tokenizer while using also multi-modal - # input. 
+ # initialized without a tokenizer while using also multi-modal input if not self.tokenizer: tokenizer = object() # Dummy else: @@ -343,7 +319,7 @@ def _prompt_to_llm_inputs( multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") - if multi_modal_data is not None and self._can_process_multimodal(): + if multi_modal_data is not None: return self._process_multimodal( prompt_token_ids, multi_modal_data, @@ -355,8 +331,6 @@ def _prompt_to_llm_inputs( return token_inputs( prompt_token_ids=prompt_token_ids, token_type_ids=token_type_ids, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, ) if parsed["type"] == "text": @@ -366,7 +340,7 @@ def _prompt_to_llm_inputs( multi_modal_data = text_content.get("multi_modal_data") mm_processor_kwargs = text_content.get("mm_processor_kwargs") - if multi_modal_data is not None and self._can_process_multimodal(): + if multi_modal_data is not None: return self._process_multimodal( prompt_text, multi_modal_data, @@ -383,8 +357,6 @@ def _prompt_to_llm_inputs( return token_inputs( prompt=prompt_text, prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, ) assert_never(parsed) @@ -417,7 +389,7 @@ async def _prompt_to_llm_inputs_async( multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") - if multi_modal_data is not None and self._can_process_multimodal(): + if multi_modal_data is not None: return await self._process_multimodal_async( prompt_token_ids, multi_modal_data, @@ -426,11 +398,7 @@ async def _prompt_to_llm_inputs_async( return_mm_hashes=return_mm_hashes, ) - return token_inputs( - prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, - ) + return token_inputs(prompt_token_ids=prompt_token_ids) if parsed["type"] == "text": text_content = parsed["content"] @@ -439,7 +407,7 @@ async def _prompt_to_llm_inputs_async( multi_modal_data = text_content.get("multi_modal_data") mm_processor_kwargs = text_content.get("mm_processor_kwargs") - if multi_modal_data is not None and self._can_process_multimodal(): + if multi_modal_data is not None: return await self._process_multimodal_async( prompt_text, multi_modal_data, @@ -456,8 +424,6 @@ async def _prompt_to_llm_inputs_async( return token_inputs( prompt=prompt_text, prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, ) assert_never(parsed) @@ -594,15 +560,13 @@ def _process_encoder_decoder_prompt( decoder_inputs = self._prompt_to_llm_inputs(decoder_input) # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. - if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + if self.model_config.is_multimodal_model: encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( encoder_inputs, decoder_inputs)) else: inputs = self._prompt_to_llm_inputs(prompt) - if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( @@ -637,15 +601,13 @@ async def _process_encoder_decoder_prompt_async( # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. 
- if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + if self.model_config.is_multimodal_model: encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( encoder_inputs, decoder_inputs)) else: inputs = await self._prompt_to_llm_inputs_async(prompt) - if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 0579893e5d76..4c334ab62d3e 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,24 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 - -import functools -from collections import UserDict from collections.abc import Mapping from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, NamedTuple, Optional, - Protocol, Union) +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union -from torch import nn from transformers import BatchFeature, PretrainedConfig, ProcessorMixin -from typing_extensions import TypeVar, assert_never +from typing_extensions import TypeVar -from vllm.logger import init_logger from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, - resolve_mm_processor_kwargs) - -from .data import ProcessorInputs, SingletonInputs -from .parse import split_enc_dec_inputs +from vllm.utils import resolve_mm_processor_kwargs if TYPE_CHECKING: from vllm.config import ModelConfig @@ -26,8 +16,6 @@ MultiModalRegistry) from vllm.sequence import SequenceData -logger = init_logger(__name__) - _T = TypeVar("_T") _C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig) _P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin) @@ -172,142 +160,23 @@ def call_hf_processor( raise RuntimeError(msg) from exc -N = TypeVar("N", bound=type[nn.Module]) - - class DummyData(NamedTuple): - """Dummy data used for profiling.""" + """ + Dummy data used for profiling. + + Note: This is only used in V0. + """ seq_data: "SequenceData" multi_modal_data: Optional["MultiModalDataDict"] = None multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None -class DummyDataFactory(Protocol): - - def __call__( - self, - ctx: InputContext, - seq_len: int, - mm_counts: Mapping[str, int], - **mm_processor_kwargs: Any, - ) -> DummyData: - """ - Create dummy data to be inputted into the model. - - Note: - :data:`InputProcessor` is not applied to the dummy data. - - The :code:`mm_processor_kwargs` are overrides provided at - initialization time to values in the config whose values - may affect the number of tokens per instance. - """ - ... - - -class _MultiModalCounts(UserDict[str, int]): - """ - Wraps `mm_counts` for a more informative error message - when attempting to access a plugin that does not exist. - """ - - def __getitem__(self, key: str) -> int: - try: - return super().__getitem__(key) - except KeyError as exc: - msg = (f"There is no multi-modal plugin with the key: {key}. 
" - f"Available keys: {set(self.keys())}") - raise KeyError(msg) from exc - - -InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs] -"""Preprocess the inputs to the model.""" - - class InputRegistry: """ - A registry to dispatch data processing - according to the target model. + Note: This is only used in V0. """ - def __init__(self) -> None: - self._dummy_factories_by_model_type = \ - ClassRegistry[nn.Module, DummyDataFactory]() - self._dummy_encoder_factories_by_model_type = \ - ClassRegistry[nn.Module, DummyDataFactory]() - self._input_processors_by_model_type = \ - ClassRegistry[nn.Module, InputProcessor]() - - def _default_dummy_data_factory( - self, - ctx: InputContext, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> DummyData: - """ - The default dummy data factory represents the longest possible text - that can be inputted to the model. - - Note: - :data:`InputProcessor` is not applied to the dummy data. - """ - # Avoid circular import - from vllm.sequence import SequenceData - - return DummyData(SequenceData.from_prompt_token_counts((0, seq_len))) - - def register_dummy_data(self, factory: DummyDataFactory): - """ - Register a dummy data factory to a model class. - - During memory profiling, the provided function is invoked to create - dummy data to be inputted into the model. The resulting memory usage - should be an upper bound of what the model would use at inference time. - """ - - def wrapper(model_cls: N) -> N: - if self._dummy_factories_by_model_type.contains(model_cls, - strict=True): - logger.warning( - "Model class %s already has dummy data " - "registered to %s. It is overwritten by the new one.", - model_cls, self) - - self._dummy_factories_by_model_type[model_cls] = factory - - return model_cls - - return wrapper - - def _get_dummy_data_factory(self, model_cls: type[nn.Module]): - return self._dummy_factories_by_model_type \ - .get(model_cls, self._default_dummy_data_factory) - - def register_dummy_encoder_data(self, factory: DummyDataFactory): - """ - Register a dummy encoder data factory to a model class - - This is similar to :meth:`~register_dummy_data`, but for encoder input. - """ - - def wrapper(model_cls: N) -> N: - if self._dummy_encoder_factories_by_model_type.contains( - model_cls, strict=True): - logger.warning( - "Model class %s already has dummy encoder data " - "registered to %s. It is overwritten by the new one.", - model_cls, self) - - self._dummy_encoder_factories_by_model_type[model_cls] = factory - - return model_cls - - return wrapper - - def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]): - return self._dummy_encoder_factories_by_model_type \ - .get(model_cls, self._default_dummy_data_factory) - def dummy_data_for_profiling( self, model_config: "ModelConfig", @@ -319,169 +188,25 @@ def dummy_data_for_profiling( Create dummy data for profiling the memory usage of a model. The model is identified by ``model_config``. - - Note: - This should be called after - :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`. 
""" # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - from vllm.multimodal import MultiModalKwargs - from vllm.multimodal.profiling import MultiModalProfiler from vllm.sequence import SequenceData - if mm_registry.has_processor(model_config): - processor = mm_registry.create_processor(model_config, - disable_cache=True) - profiler = MultiModalProfiler(processor) - - dummy_data_v1 = (profiler.get_encoder_dummy_data(seq_len) - if is_encoder_data else - profiler.get_decoder_dummy_data(seq_len)) - _seq_data = SequenceData.from_seqs( - dummy_data_v1.prompt_token_ids) # type: ignore[attr-defined] - - dummy_data = DummyData( - seq_data=_seq_data, - multi_modal_data=getattr(dummy_data_v1, "multi_modal_data", - None), - multi_modal_placeholders=getattr(dummy_data_v1, - "multi_modal_placeholders", - None), - ) - else: - model_cls, _ = get_model_architecture(model_config) - if is_encoder_data: - dummy_factory = self._get_dummy_encoder_data_factory(model_cls) - else: - dummy_factory = self._get_dummy_data_factory(model_cls) - mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) - mm_processor_kwargs = get_allowed_kwarg_only_overrides( - dummy_factory, - overrides=model_config.mm_processor_kwargs, - requires_kw_only=False, - allow_var_kwargs=True, - ) - - dummy_data = dummy_factory(InputContext(model_config), seq_len, - _MultiModalCounts(mm_counts), - **mm_processor_kwargs) - - # Having more tokens is over-conservative but otherwise fine - num_tokens = dummy_data.seq_data.prompt_token_ids - if len(num_tokens) < seq_len: - if is_encoder_data: - logger.warning_once( - f"Expected at least {seq_len} dummy encoder tokens for " - f"profiling, but found {len(num_tokens)} tokens instead.") - else: - raise AssertionError( - f"Expected at least {seq_len} dummy tokens for profiling, " - f"but found {len(num_tokens)} tokens instead.") - - if (dummy_data.multi_modal_data is not None and - not isinstance(dummy_data.multi_modal_data, MultiModalKwargs)): - for k, v in dummy_data.multi_modal_data.items(): - num_items = len(v) if isinstance(v, list) else 1 - num_expected = mm_counts[k] - assert num_items >= num_expected, ( - f"Expected at least {num_expected} dummy '{k}' instances " - f"for profiling, but found {num_items} instances instead.") - - return dummy_data - - def _default_input_processor( - self, - ctx: InputContext, - inputs: ProcessorInputs, - **kwargs: object, - ) -> ProcessorInputs: - """The default input processor is a no-op.""" - return inputs - - def register_input_processor(self, processor: InputProcessor): - """ - Register an input processor to a model class. - - The provided function is invoked on each input to the model. This - happens before - :meth:`~vllm.multimodal.registry.MultiModalRegistry.map_input`. - """ - - def wrapper(model_cls: N) -> N: - if self._input_processors_by_model_type.contains(model_cls, - strict=True): - logger.warning( - "Model class %s already has input processor " - "registered to %s. 
It is overwritten by the new one.", - model_cls, self) - - self._input_processors_by_model_type[model_cls] = processor - - return model_cls + if not model_config.is_multimodal_model: + seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) + return DummyData(seq_data=seq_data) - return wrapper + # Encoder dummy data does not contain multi-modal data + if is_encoder_data: + enc_data = mm_registry.get_encoder_dummy_data( + model_config, seq_len) + seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids) + return DummyData(seq_data=seq_data) - def _get_model_input_processor(self, model_cls: type[nn.Module]): - return self._input_processors_by_model_type \ - .get(model_cls, self._default_input_processor) - - def _ensure_mm_kwargs( - self, - inputs: SingletonInputs, - mm_processor_kwargs: dict[str, Any], - ): - if inputs["type"] == "token": - # In case the input processor for that model fails to set it - if "mm_processor_kwargs" not in inputs: - inputs["mm_processor_kwargs"] = mm_processor_kwargs - elif inputs["type"] == "multimodal": - # Be more strict in V2 - assert "mm_kwargs" in inputs - else: - assert_never(inputs["type"]) # type: ignore[arg-type] - - def process_input(self, model_config: "ModelConfig", - inputs: ProcessorInputs) -> ProcessorInputs: - """ - Apply an input processor to an instance of model inputs. - - The model is identified by ``model_config``. - """ - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - - model_cls, _ = get_model_architecture(model_config) - processor = self._get_model_input_processor(model_cls) - - # Handle multimodal processor kwargs with priority: - # Inference kwargs -> Init kwargs -> {} - # If it's empty, it'll fall back to the default kwarg values - mm_processor_kwargs = resolve_mm_processor_kwargs( - model_config.mm_processor_kwargs, - inputs.get("mm_processor_kwargs", {}), # type: ignore - processor, - requires_kw_only=False, - allow_var_kwargs=True, - ) + dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len) - processed_inputs = processor( - InputContext(model_config), - inputs, - **mm_processor_kwargs, + return DummyData( + seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), + multi_modal_data=dec_data.multi_modal_data, + multi_modal_placeholders=dec_data.multi_modal_placeholders, ) - - encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) - if encoder_inputs is not None: - self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs) - if decoder_inputs is not None: - self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs) - - return processed_inputs - - def create_input_processor(self, model_config: "ModelConfig"): - """ - Create an input processor (see :meth:`_process_input`) for a - specific model. 
- """ - return functools.partial(self.process_input, model_config) diff --git a/vllm/lora/ops/triton_ops/lora_expand.py b/vllm/lora/ops/triton_ops/lora_expand.py index eacc6fb46ebd..e41ae1d9594a 100644 --- a/vllm/lora/ops/triton_ops/lora_expand.py +++ b/vllm/lora/ops/triton_ops/lora_expand.py @@ -204,7 +204,6 @@ def _lora_expand( NUM_WARPS = 4 NUM_CTAS = 1 NUM_STAGES = 2 - MAX_NREG = None EVEN_K = K % BLOCK_K == 0 # type: ignore @@ -258,7 +257,6 @@ def _lora_expand( num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, - maxnreg=MAX_NREG, ) return diff --git a/vllm/lora/ops/triton_ops/lora_shrink.py b/vllm/lora/ops/triton_ops/lora_shrink.py index 82331939d859..fb0422cf0b0e 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink.py +++ b/vllm/lora/ops/triton_ops/lora_shrink.py @@ -168,7 +168,6 @@ def _lora_shrink( NUM_WARPS = 4 NUM_CTAS = 1 NUM_STAGES = 2 - MAX_NREG = None EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore @@ -213,7 +212,6 @@ def _lora_shrink( num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, - maxnreg=MAX_NREG, ) return diff --git a/vllm/lora/resolver.py b/vllm/lora/resolver.py new file mode 100644 index 000000000000..6726ca9a903f --- /dev/null +++ b/vllm/lora/resolver.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import AbstractSet, Dict, Optional + +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest + +logger = init_logger(__name__) + + +class LoRAResolver(ABC): + """Base class for LoRA adapter resolvers. + + This class defines the interface for resolving and fetching LoRA adapters. + Implementations of this class should handle the logic for locating and + downloading LoRA adapters from various sources (e.g. S3, cloud storage, + etc.). + """ + + @abstractmethod + async def resolve_lora(self, base_model_name: str, + lora_name: str) -> Optional[LoRARequest]: + """Abstract method to resolve and fetch a LoRA model adapter. + + Implements logic to locate and download LoRA adapter based on the name. + Implementations might fetch from a blob storage or other sources. + + Args: + base_model_name: The name/identifier of the base model to resolve. + lora_name: The name/identifier of the LoRA model to resolve. + + Returns: + Optional[LoRARequest]: The resolved LoRA model information, or None + if the LoRA model cannot be found. + """ + pass + + +@dataclass +class _LoRAResolverRegistry: + resolvers: Dict[str, LoRAResolver] = field(default_factory=dict) + + def get_supported_resolvers(self) -> AbstractSet[str]: + """Get all registered resolver names.""" + return self.resolvers.keys() + + def register_resolver( + self, + resolver_name: str, + resolver: LoRAResolver, + ) -> None: + """Register a LoRA resolver. + Args: + resolver_name: Name to register the resolver under. + resolver: The LoRA resolver instance to register. + """ + if resolver_name in self.resolvers: + logger.warning( + "LoRA resolver %s is already registered, and will be " + "overwritten by the new resolver instance %s.", resolver_name, + resolver) + + self.resolvers[resolver_name] = resolver + + def get_resolver(self, resolver_name: str) -> LoRAResolver: + """Get a registered resolver instance by name. + Args: + resolver_name: Name of the resolver to get. + Returns: + The resolver instance. + Raises: + KeyError: If the resolver is not found in the registry. + """ + if resolver_name not in self.resolvers: + raise KeyError( + f"LoRA resolver '{resolver_name}' not found. 
" + f"Available resolvers: {list(self.resolvers.keys())}") + return self.resolvers[resolver_name] + + +LoRAResolverRegistry = _LoRAResolverRegistry() diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 610cbf87f66a..883ca938ea1a 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -114,7 +114,7 @@ def parse_fine_tuned_lora_name( is_bias whether the tensor is lora bias. """ - # LoRA weight qualified name always starts with `base_model.model.`, + # LoRA weight qualified name usually starts with `base_model.model.`, # so we remove the prefix `base_model.model.` to make the following # mapping correctly. if "base_model.model." in name: @@ -123,18 +123,23 @@ def parse_fine_tuned_lora_name( # recover the prefix `base_model.model.` name = "base_model.model." + name + # In some situations, we may not start with `base_model.model.`. + # If we don't (e.g., ibm-granite/granite-speech-3.3-8b), + # we should keep the prefix intact. + start_index = 2 if "base_model.model." in name else 0 + parts = name.split(".") if parts[-1] == "weight" and (parts[-2] == "lora_A" or parts[-2] == "lora_B"): - new_name = ".".join(parts[2:-2]) + new_name = ".".join(parts[start_index:-2]) return new_name, parts[-2] == "lora_A", False if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": - new_name = ".".join(parts[2:-1]) + new_name = ".".join(parts[start_index:-1]) return new_name, parts[-1] == "lora_embedding_A", False if parts[-1] == "bias": - new_name = ".".join(parts[2:-2]) + new_name = ".".join(parts[start_index:-2]) return new_name, False, True raise ValueError(f"{name} is unsupported LoRA weight") diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index d4ee1be9a445..8fdcdcafa980 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -65,7 +65,7 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str, fallback_or_error( guided_params, "xgrammar does not support advanced JSON schema features like " - "enums, patterns or numeric ranges.", "outlines") + "string length, item limits, or property bounds.", "outlines") # xgrammar only supports GBNF grammars, so we must convert Lark. 
# We must check if the grammar is likely Lark and if that diff --git a/vllm/model_executor/guided_decoding/guidance_decoding.py b/vllm/model_executor/guided_decoding/guidance_decoding.py index f19ebcbe420e..95b7c71107aa 100644 --- a/vllm/model_executor/guided_decoding/guidance_decoding.py +++ b/vllm/model_executor/guided_decoding/guidance_decoding.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +import json from re import escape as regex_escape import llguidance @@ -7,6 +8,8 @@ from vllm.model_executor.guided_decoding.guidance_logits_processors import ( GuidanceLogitsProcessor) from vllm.sampling_params import GuidedDecodingParams +from vllm.v1.structured_output.backend_guidance import ( + process_for_additional_properties) def get_local_guidance_guided_decoding_logits_processor( @@ -20,9 +23,17 @@ def get_local_guidance_guided_decoding_logits_processor( grm = "" any_whitespace = 'disable-any-whitespace' not in \ guided_params.backend_options() - if guided_params.json: + if (guide_json := guided_params.json) is not None: + # Optionally set additionalProperties to False at the top-level + # By default, other backends do not allow additional top-level + # properties, so this makes guidance more similar to other backends + if 'no-additional-properties' in guided_params.backend_options(): + if not isinstance(guide_json, str): + guide_json = json.dumps(guide_json) + guide_json = process_for_additional_properties(guide_json) + grm = llguidance.LLMatcher.grammar_from_json_schema( - guided_params.json, + guide_json, overrides={"whitespace_pattern": guided_params.whitespace_pattern}, defaults={ "whitespace_flexible": any_whitespace, diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py index db4ce26806c1..1593868a164a 100644 --- a/vllm/model_executor/guided_decoding/guided_fields.py +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -27,14 +27,15 @@ class GuidedDecodingRequest: guided_decoding_backend: Optional[str] = None guided_whitespace_pattern: Optional[str] = None guided_json_object: Optional[bool] = None + structural_tag: Optional[str] = None def __post_init__(self): """Validate that some fields are mutually exclusive.""" - guide_count = sum([ - self.guided_json is not None, self.guided_regex is not None, - self.guided_choice is not None, self.guided_grammar is not None, - self.guided_json_object is not None - ]) + guide_count = sum(x is not None + for x in (self.guided_json, self.guided_regex, + self.guided_choice, self.guided_grammar, + self.guided_json_object, + self.structural_tag)) if guide_count > 1: raise ValueError( "You can only use one kind of guided decoding but multiple are " diff --git a/vllm/model_executor/guided_decoding/utils.py b/vllm/model_executor/guided_decoding/utils.py index ba7c10252699..1ad1ef8fbf16 100644 --- a/vllm/model_executor/guided_decoding/utils.py +++ b/vllm/model_executor/guided_decoding/utils.py @@ -10,16 +10,8 @@ def check_object(obj: dict) -> bool: if not isinstance(obj, dict): return False - # Check for pattern restrictions - if "pattern" in obj: - return True - # Check for numeric ranges - if obj.get("type") in ("integer", "number") and any( - key in obj for key in [ - "minimum", "maximum", "exclusiveMinimum", - "exclusiveMaximum", "multipleOf" - ]): + if obj.get("type") in ("integer", "number") and ("multipleOf" in obj): return True # Check for array unsupported keywords diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py 
index 1de0f499c1a6..f082afb7e9c0 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -354,6 +354,7 @@ def get_act_fn(act_fn_name: str) -> nn.Module: _ACTIVATION_AND_MUL_REGISTRY = LazyDict({ "gelu": lambda: GeluAndMul(), "silu": lambda: SiluAndMul(), + "gelu_and_mul": lambda: GeluAndMul(), }) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..555d17364452 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000000..5de5605d401c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20.json new file mode 100644 index 000000000000..2221e99cd1ad --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 
5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H200.json new file mode 100644 index 000000000000..74374c573f3f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + 
"num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..b34b6e4e8a8e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20.json new file mode 100644 index 000000000000..ab169a0183dd --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..324ad7b22fed --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ 
No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200.json new file mode 100644 index 000000000000..ab6e15552909 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000000..249359fb93d7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..b4efc9b7e44c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 
128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20.json new file mode 100644 index 000000000000..03dfc73b6c0a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 
16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..9c07695ba910 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + 
"1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200.json new file mode 100644 index 000000000000..beaac7f641e4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git 
a/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json new file mode 100644 index 000000000000..ebff99e26dc7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..857d11e48891 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + 
"GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/README b/vllm/model_executor/layers/fused_moe/configs/README index 787bd0611664..85970e2d1cea 100644 --- a/vllm/model_executor/layers/fused_moe/configs/README +++ b/vllm/model_executor/layers/fused_moe/configs/README @@ -9,5 +9,4 @@ The example configurations provided are for the Mixtral model for TP2 on H100 and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have N = 7168 and for TP4 we have N = 3584. -Please feel free to tune the configurations using scripts in `benchmarks/kernels/benchmark_moe.py` -Some of the configurations files are copied from the SGLang repository. Thank you! +See `benchmark/kernels/benchmark_moe.py` on how to generate these config files. 
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d6a27aa0ddc4..960c7f834857 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -15,7 +15,7 @@ def cutlass_moe_fp8( w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, + topk_ids_: torch.Tensor, ab_strides1: torch.Tensor, c_strides1: torch.Tensor, ab_strides2: torch.Tensor, @@ -23,6 +23,7 @@ def cutlass_moe_fp8( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.half, + expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, ) -> torch.Tensor: """ @@ -57,12 +58,19 @@ def cutlass_moe_fp8( quantize the intermediate result between the gemms. Shape: scalar or [M] - out_dtype (torch.Tensor): The output tensor type. + - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, + every Rank is responsible for a subset of experts. expert_map is a + mapping from global expert-id to local expert-id. When expert_map[i] + is -1, it means that this Rank is not responsible for global + expert-id i. + - apply_router_weight_on_input (bool): When true, the topk weights are + applied directly on the inputs. This is only applicable when topk is 1. Returns: - torch.Tensor: The fp16 output tensor after applying the MoE layer. """ - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch" assert w1_q.dtype == torch.float8_e4m3fn assert w2_q.dtype == torch.float8_e4m3fn assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" @@ -96,7 +104,13 @@ def cutlass_moe_fp8( k = w1_q.size(1) n = w2_q.size(1) - topk = topk_ids.size(1) + local_topk_ids = topk_ids_ + if expert_map is not None: + "Translate info from expert_map to topk_ids" + local_topk_ids = torch.where(expert_map[topk_ids_] != -1, + expert_map[topk_ids_], -1) + + topk = local_topk_ids.size(1) per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) @@ -120,10 +134,23 @@ def cutlass_moe_fp8( dtype=torch.int32, device=device) - a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - - ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, + a_map_initializer = torch.empty + c2_initializer = torch.empty + if expert_map is not None: + # With expert_map each Rank processes only a subset of experts. As + # a result not all of a_map and c2 tensors are filled. We fill it + # zeros for correctness. 
+ a_map_initializer = torch.zeros + c2_initializer = torch.zeros + + a_map = a_map_initializer((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + c_map = torch.empty((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + + ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1, problem_sizes2, a_map, c_map, num_experts, n, k) @@ -131,7 +158,7 @@ def cutlass_moe_fp8( rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) - c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype) + c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype) ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale, expert_offsets[:-1], problem_sizes1, ab_strides1, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index ee158d7ee474..62614a59cbe9 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -5,17 +5,16 @@ import torch +import vllm._custom_ops as ops from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size, try_get_optimal_moe_config) -from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import direct_register_custom_op def get_scalar_type(num_bits: int, has_zp: bool): if has_zp: - assert num_bits == 4 - return scalar_types.uint4 + return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 else: return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 @@ -27,9 +26,12 @@ def single_marlin_moe( gating_output: torch.Tensor, topk: int, renormalize: bool, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -62,7 +64,7 @@ def single_marlin_moe( assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w.is_contiguous(), "Expert weights must be contiguous" - assert hidden_states.dtype == torch.float16 + assert hidden_states.dtype in [torch.float16, torch.bfloat16] assert num_bits in [4, 8] M, K = hidden_states.shape @@ -83,39 +85,54 @@ def single_marlin_moe( block_size_m = config['BLOCK_SIZE_M'] - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - - max_workspace_size = (N // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device=hidden_states.device, - requires_grad=False) - - has_zero_point = w_zeros is not None - if w_zeros is None: - w_zeros = torch.empty((0, 0), - dtype=hidden_states.dtype, - device=hidden_states.device, - requires_grad=False) - - if g_idx is None: - g_idx = torch.empty((0, 0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - - if sort_indices is None: - sort_indices = torch.empty((0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - - scalar_type = get_scalar_type(num_bits, has_zero_point) + if global_num_experts == -1: + global_num_experts = E + sorted_token_ids, expert_ids, num_tokens_post_padded = \ + moe_align_block_size(topk_ids, block_size_m, E, expert_map) + + if workspace is None: + max_workspace_size = (max(2 * N, K) // 
64) * \ + (sorted_token_ids.size(0) // block_size_m) + device = hidden_states.device + sms = torch.cuda.get_device_properties(device).multi_processor_count + max_workspace_size = min(max_workspace_size, sms) + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + scalar_type = get_scalar_type(num_bits, w_zeros is not None) + intermediate_cache = torch.empty( + (M * topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) - intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K, - is_k_full, E, topk, block_size_m, True, False) + ops.moe_wna16_marlin_gemm(hidden_states, + intermediate_cache, + w, + scales, + w_zeros, + g_idx, + sort_indices, + workspace, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + topk_weights, + moe_block_size=block_size_m, + top_k=topk, + mul_topk_weights=False, + is_ep=expert_map is not None, + b_q_type=scalar_type, + size_m=M, + size_n=N, + size_k=K, + is_k_full=is_k_full, + use_atomic_add=False, + use_fp32_reduce=True, + is_zp_float=False) + intermediate_cache = intermediate_cache.view(-1, topk, N) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -127,9 +144,12 @@ def single_marlin_moe_fake( gating_output: torch.Tensor, topk: int, renormalize: bool, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, g_idx: Optional[torch.Tensor] = None, sort_indices: Optional[torch.Tensor] = None, w_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, num_bits: int = 8, is_k_full: bool = True, ) -> torch.Tensor: @@ -144,24 +164,26 @@ def single_marlin_moe_fake( ) -def fused_marlin_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - g_idx1: Optional[torch.Tensor] = None, - g_idx2: Optional[torch.Tensor] = None, - sort_indices1: Optional[torch.Tensor] = None, - sort_indices2: Optional[torch.Tensor] = None, - w1_zeros: Optional[torch.Tensor] = None, - w2_zeros: Optional[torch.Tensor] = None, - num_bits: int = 8, - is_k_full: bool = True, -) -> torch.Tensor: +def fused_marlin_moe(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + inplace: bool = False) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. 
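Editor's note: both the CUTLASS and Marlin MoE paths in this diff gain `expert_map`/`global_num_experts` arguments for expert parallelism: `expert_map[g]` is the local index of global expert `g` on the current rank, or -1 if that expert is owned by another rank. The sketch below is illustrative only (the 8-expert layout with 4 local experts is an assumed example); the `torch.where` translation mirrors the one added to `cutlass_moe_fp8` earlier in this diff.

# Sketch only: build an expert_map for one rank and translate global top-k
# expert ids into local ids, with -1 marking experts owned by other ranks.
import torch

global_num_experts = 8
local_experts = [4, 5, 6, 7]  # assumed: this rank owns the last four experts

expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
expert_map[local_experts] = torch.arange(len(local_experts), dtype=torch.int32)

topk_ids = torch.tensor([[1, 5], [6, 7]])  # global expert ids per token
local_topk_ids = torch.where(expert_map[topk_ids] != -1, expert_map[topk_ids], -1)
print(local_topk_ids)  # [[-1, 1], [2, 3]]; -1 entries are skipped by this rank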
@@ -196,27 +218,12 @@ def fused_marlin_moe( 1] == w1.shape[1] * 16, "Hidden size mismatch w1" assert hidden_states.shape[1] == w2.shape[2] // ( num_bits // 2), "Hidden size mismatch w2" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype == torch.float16 + assert hidden_states.dtype in [torch.float16, torch.bfloat16] assert num_bits in [4, 8] - has_no_act_order = (g_idx1 is None and g_idx2 is None - and sort_indices1 is None and sort_indices2 is None) - has_all_act_order = (g_idx1 is not None and g_idx2 is not None - and sort_indices1 is not None - and sort_indices2 is not None) - assert has_no_act_order or has_all_act_order, ( - "g_idx and sorted_indices " - "must be all not None or must be all None") - - has_no_zp = w1_zeros is None and w2_zeros is None - has_all_zp = w1_zeros is not None and w2_zeros is not None - assert has_no_zp or has_all_zp, ("zero points must be both not None or " - "must be both None") - M, K = hidden_states.shape E = w1.shape[0] N = w2.shape[1] * 16 @@ -234,122 +241,128 @@ def fused_marlin_moe( block_size_m = config["BLOCK_SIZE_M"] - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - - max_workspace_size = (max(2 * N, K) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device=current_platform.device_type, - requires_grad=False) - - if has_no_zp: - w1_zeros = torch.empty((0, 0), - dtype=hidden_states.dtype, - device=hidden_states.device, - requires_grad=False) - w2_zeros = torch.empty((0, 0), - dtype=hidden_states.dtype, - device=hidden_states.device, - requires_grad=False) - - if has_no_act_order: - g_idx1 = torch.empty((0, 0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - g_idx2 = torch.empty((0, 0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - sort_indices1 = torch.empty((0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - sort_indices2 = torch.empty((0, 0), - dtype=torch.int32, - device=hidden_states.device, - requires_grad=False) - - scalar_type1 = get_scalar_type(num_bits, has_all_zp) - scalar_type2 = get_scalar_type(num_bits, has_all_zp) + if global_num_experts == -1: + global_num_experts = E + sorted_token_ids, expert_ids, num_tokens_post_padded = \ + moe_align_block_size(topk_ids, block_size_m, global_num_experts, + expert_map) + + if workspace is None: + max_workspace_size = (max(2 * N, K) // 64) * \ + (sorted_token_ids.size(0) // block_size_m) + device = hidden_states.device + sms = torch.cuda.get_device_properties(device).multi_processor_count + max_workspace_size = min(max_workspace_size, sms * 4) + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None) + scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), device=hidden_states.device, dtype=hidden_states.dtype, ) + intermediate_cache13 = torch.empty( + (M * topk_ids.shape[1] * max(2 * N, K), ), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N] + intermediate_cache1 = intermediate_cache1.view(-1, 2 * N) + intermediate_cache3 = 
intermediate_cache13[:M * topk_ids.shape[1] * K] + intermediate_cache3 = intermediate_cache3.view(-1, K) + + use_atomic_add = hidden_states.dtype == torch.half or \ + torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 - intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( + intermediate_cache1 = ops.moe_wna16_marlin_gemm( hidden_states, + intermediate_cache1, w1, - sorted_token_ids, - topk_weights, - topk_ids, w1_scale, w1_zeros, g_idx1, sort_indices1, workspace, - scalar_type1.id, - M, - 2 * N, - K, - is_k_full, - E, - topk, - block_size_m, - True, - False, - ) + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + topk_weights, + moe_block_size=block_size_m, + top_k=topk, + mul_topk_weights=False, + is_ep=expert_map is not None, + b_q_type=scalar_type1, + size_m=M, + size_n=2 * N, + size_k=K, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=True, + is_zp_float=False) torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) - intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( + if expert_map is not None: + intermediate_cache3.zero_() + + intermediate_cache3 = ops.moe_wna16_marlin_gemm( intermediate_cache2, + intermediate_cache3, w2, - sorted_token_ids, - topk_weights, - topk_ids, w2_scale, w2_zeros, g_idx2, sort_indices2, workspace, - scalar_type2.id, - M, - K, - N, - is_k_full, - E, - topk, - block_size_m, - False, - True, - ) - + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + topk_weights, + moe_block_size=block_size_m, + top_k=1, + mul_topk_weights=True, + is_ep=expert_map is not None, + b_q_type=scalar_type2, + size_m=M * topk, + size_n=K, + size_k=N, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=True, + is_zp_float=False).view(-1, topk, K) + + output = hidden_states if inplace else torch.empty_like(hidden_states) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) - - -def fused_marlin_moe_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - g_idx1: Optional[torch.Tensor] = None, - g_idx2: Optional[torch.Tensor] = None, - sort_indices1: Optional[torch.Tensor] = None, - sort_indices2: Optional[torch.Tensor] = None, - w1_zeros: Optional[torch.Tensor] = None, - w2_zeros: Optional[torch.Tensor] = None, - num_bits: int = 8, - is_k_full: bool = True, -) -> torch.Tensor: + dim=1, + out=output) + + +def fused_marlin_moe_fake(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + num_bits: int = 8, + is_k_full: bool = True, + inplace: bool = False) -> torch.Tensor: return torch.empty_like(hidden_states) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a6d7b426717a..4936d6c527c0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -23,9 +23,7 @@ from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op -from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled, - rocm_aiter_fused_experts, - rocm_aiter_topk_softmax) +from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled logger = init_logger(__name__) @@ -773,6 +771,18 @@ def get_default_config( config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1} else: config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1} + elif is_marlin: + for block_size_m in [8, 16, 32, 48, 64]: + if M * topk / E / block_size_m < 0.9: + break + return {"BLOCK_SIZE_M": block_size_m} + elif M <= E: + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } else: config = { "BLOCK_SIZE_M": 64, @@ -780,14 +790,6 @@ def get_default_config( "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): - config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - } return config @@ -842,6 +844,7 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: if is_rocm_aiter_moe_enabled(): + from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax return rocm_aiter_topk_softmax return vllm_topk_softmax @@ -1098,6 +1101,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: if is_rocm_aiter_moe_enabled(): + from .rocm_aiter_fused_moe import rocm_aiter_fused_experts return rocm_aiter_fused_experts if inplace: return torch_vllm_inplace_fused_experts diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 8ff33a158cf9..3cdf3c97a7d3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -113,12 +113,9 @@ def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) - layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight( - layer.w13_weight.data), - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight( - layer.w2_weight.data), - requires_grad=False) + # Padding the weight for better performance on ROCm + layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) + layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) # Lazy import to avoid importing triton. 
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled, shuffle_weights) @@ -127,10 +124,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: shuffled_w13, shuffled_w2 = shuffle_weights( layer.w13_weight.data, layer.w2_weight.data) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, - requires_grad=False) + layer.w13_weight.data = shuffled_w13 + layer.w2_weight.data = shuffled_w2 if current_platform.is_cpu(): if current_platform.get_cpu_architecture() == CpuArchEnum.X86: @@ -473,6 +468,7 @@ def __init__( self.global_num_experts = num_experts assert intermediate_size % self.tp_size == 0 + self.hidden_size = hidden_size self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index ac158a7eee53..acaa93f5a23e 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -1,126 +1,385 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional +from functools import cache +from typing import List, Optional, Tuple import torch -import vllm.envs as envs +from vllm import envs from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op +@cache def is_rocm_aiter_moe_enabled() -> bool: return current_platform.is_rocm() \ and envs.VLLM_ROCM_USE_AITER_MOE \ - and envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_ROCM_USE_AITER -def is_rocm_aiter_block_scaled_moe_enabled() -> bool: - return is_rocm_aiter_moe_enabled() and \ - envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE - - -def rocm_aiter_fused_experts( - *, +def rocm_aiter_asm_moe_tkw1_impl( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + per_tensor_quant_scale: Optional[torch.Tensor] = None, + expert_mask: Optional[torch.Tensor] = None, + activation_str: str = "silu") -> torch.Tensor: + + from aiter import ActivationType + from aiter.fused_moe_bf16_asm import asm_moe_tkw1 + + activation = \ + ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu + + return asm_moe_tkw1(hidden_states, + w1, + w2, + topk_weight, + topk_ids, + fc1_scale=fc1_scale, + fc2_scale=fc2_scale, + fc1_smooth_scale=fc1_smooth_scale, + fc2_smooth_scale=fc2_smooth_scale, + a16=a16, + per_tensor_quant_scale=per_tensor_quant_scale, + expert_mask=expert_mask, + activation=activation) + + +def rocm_aiter_asm_moe_tkw1_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, topk_ids: torch.Tensor, - use_fp8_w8a8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + per_tensor_quant_scale: Optional[torch.Tensor] = None, 
expert_mask: Optional[torch.Tensor] = None, - **kwagrs # Ignore additional keyword arguments -) -> torch.Tensor: + activation_str: str = "silu") -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor, + w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: + from aiter import ck_moe + return ck_moe(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids) + - import aiter as rocm_aiter +def rocm_aiter_ck_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor, + w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl( + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + hidden_states_dtype: torch.dtype, + expert_mask: torch.Tensor, + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: torch.Tensor, + block_shape: List[int], + smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor: + from aiter import fmoe_fp8_blockscale_g1u1 + from aiter.fused_moe_bf16_asm import moe_sorting_ck + + topk = topk_ids.shape[1] + model_dim = w1.shape[-1] + local_E = E = w1.shape[0] + if expert_mask is not None: + E = expert_mask.numel() + + ( + sorted_token_ids, + sorted_weight_buf, + sorted_expert_ids, + num_valid_ids, + out_asm, + ) = moe_sorting_ck(topk_ids, + topk_weights, + E, + model_dim, + hidden_states_dtype, + expert_mask=expert_mask) + + fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids, + sorted_weight_buf, sorted_expert_ids, + num_valid_ids, topk, w1_scale.view(local_E, -1), + w2_scale.view(local_E, -1), + a1_scale.t().contiguous(), *block_shape, + smooth_scale) + + return out_asm + + +def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake( + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + hidden_states_dtype: torch.dtype, + expert_mask: torch.Tensor, + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: torch.Tensor, + block_shape: List[int], + smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor: + + return torch.empty_like(a1, dtype=torch.bf16) + + +def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + activation: str = "silu") -> torch.Tensor: import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe + from aiter import ActivationType + + assert activation in ["silu", "gelu"], "The given activation:" \ + f" {activation}" \ + " is not supported in" \ + " AITER." 
+ if activation == "silu": + aiter_activation = ActivationType.Silu + else: + aiter_activation = ActivationType.Gelu + + return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weight=topk_weight, + topk_ids=topk_ids, + fc1_scale=fc1_scale, + fc2_scale=fc2_scale, + fc1_smooth_scale=fc1_smooth_scale, + fc2_smooth_scale=fc2_smooth_scale, + a16=a16, + activation=aiter_activation) + + +def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + activation: str = "silu") -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool) -> None: + from aiter import topk_softmax + topk_softmax(topk_weights, topk_indices, token_expert_indices, + gating_output, renormalize) + + +def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool) -> None: + pass + + +if current_platform.is_rocm(): + + direct_register_custom_op( + op_name="rocm_aiter_asm_moe_tkw1", + op_func=rocm_aiter_asm_moe_tkw1_impl, + mutates_args=[], + fake_impl=rocm_aiter_asm_moe_tkw1_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_ck_moe", + op_func=rocm_aiter_ck_moe_impl, + mutates_args=[], + fake_impl=rocm_aiter_ck_moe_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_fmoe_fp8_blockscale_g1u1", + op_func=rocm_aiter_fmoe_fp8_blockscale_g1u1_impl, + mutates_args=[], + fake_impl=rocm_aiter_fmoe_fp8_blockscale_g1u1_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_asm_moe", + op_func=rocm_aiter_asm_moe_impl, + mutates_args=[], + fake_impl=rocm_aiter_asm_moe_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_topk_softmax", + op_func=rocm_aiter_topk_softmax_impl, + mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], + fake_impl=rocm_aiter_topk_softmax_fake, + dispatch_key=current_platform.dispatch_key, + ) + + +def rocm_aiter_fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False) -> torch.Tensor: from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) - if 
envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8: + # All AITER Fused MoE kernels are expecting the following datatypes + topk_weights = topk_weights.to(torch.float32) + topk_ids = topk_ids.to(torch.int32) + + # w8a8 block-scaled + if block_shape is not None and use_fp8_w8a8: + assert not apply_router_weight_on_input, ( + "apply_router_weight_on_input is not supported for block scaled moe" + ) assert w1_scale is not None assert w2_scale is not None - local_E = E = w1.shape[0] - if expert_mask is not None: - E = expert_mask.numel() - - topk = topk_ids.shape[1] - model_dim = w1.shape[-1] - dtype = hidden_states.dtype # The default block sizes are 128 in AITER. - if block_shape is None: - block_shape = [128, 128] - - scale_blk_k = block_shape[1] - - ( - sorted_token_ids, - sorted_weight_buf, - sorted_expert_ids, - num_valid_ids, - out_asm, - ) = rocm_aiter_asm_fmoe.moe_sorting_ck(topk_ids, - topk_weights, - E, - model_dim, - dtype, - expert_mask=expert_mask) - - a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k) - rocm_aiter.fmoe_fp8_blockscale_g1u1( - out_asm, - a1, + block_shape = [128, 128] if block_shape is None else block_shape + + a1, a1_scale = per_token_group_quant_fp8(hidden_states, block_shape[1]) + + return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1( + topk_ids, topk_weights, hidden_states.dtype, expert_map, a1, w1, + w2, w1_scale, w2_scale, a1_scale, block_shape, None) + + # w8a8 per-channel quantization + elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: + # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input` + # This applies topk_weights on the GEMM output of the first FC layer + # rather than the second FC. + assert (topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + assert topk_weights.shape[-1] == 1, ( + "Only support topk=1 when" + " `apply_router_weight_on_input` is True") + + return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( + hidden_states, w1, w2, - sorted_token_ids, - sorted_weight_buf, - sorted_expert_ids, - num_valid_ids, - topk, - w1_scale.view(local_E, -1), - w2_scale.view(local_E, -1), - a1_scale.t().contiguous(), - block_shape[0], - block_shape[1], - None, - ) - return out_asm - + topk_weights, + topk_ids, + fc1_scale=w1_scale, + fc2_scale=w2_scale, + fc1_smooth_scale=None, + fc2_smooth_scale=None, + a16=False, + per_tensor_quant_scale=None, + expert_mask=expert_map, + activation_str=activation) + + # w8a8 per-tensor activation per-tensor weight elif use_fp8_w8a8: - return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weight=topk_weights, - topk_ids=topk_ids, - fc1_scale=w1_scale, - fc2_scale=w2_scale, - fc1_smooth_scale=None, - fc2_smooth_scale=None, - a16=False) - - return rocm_aiter.ck_moe(hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids) + assert not apply_router_weight_on_input, ( + "apply_router_weight_on_input is not supported for fp8_w8a8") + return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weight=topk_weights, + topk_ids=topk_ids, + fc1_scale=w1_scale, + fc2_scale=w2_scale, + fc1_smooth_scale=None, + fc2_smooth_scale=None, + a16=False, + activation=activation) + if apply_router_weight_on_input: + assert (topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + + 
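+        # apply_router_weight_on_input with topk=1: fold the router weight into
+        # the activations here and neutralize topk_weights so the downstream
+        # kernel does not apply it a second time.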
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) + topk_ids = topk_ids.to(torch.int32) + topk_weights = torch.ones_like(topk_weights, dtype=torch.float32) + + # w16a16 fallback to rocm_aiter_ck_moe w16a16 + return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids) def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, token_expert_indices: torch.Tensor, gating_output: torch.Tensor, - renormalize: bool) -> tuple[torch.Tensor, ...]: - import aiter as rocm_aiter - rocm_aiter.topk_softmax(topk_weights, topk_indices, token_expert_indices, - gating_output, renormalize) - + renormalize: bool) -> Tuple[torch.Tensor, ...]: + torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices, + token_expert_indices, gating_output, + renormalize) return topk_weights, topk_indices -def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: +def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]: """ Applies shuffle_weight function from AITER to each input tensor and returns them. @@ -129,15 +388,14 @@ def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]: *tensors: Variable number of torch.Tensor objects. Returns: - A tuple of shuffled tensors. + A Tuple of shuffled tensors. """ from aiter.ops.shuffle import shuffle_weight - return tuple(shuffle_weight(tensor) for tensor in tensors) def expand_weights(*tensors: torch.Tensor, - expansion_dims: list[int]) -> tuple[torch.Tensor, ...]: + expansion_dims: list[int]) -> Tuple[torch.Tensor, ...]: """ Expands the dimensions of input tensors. @@ -147,7 +405,7 @@ def expand_weights(*tensors: torch.Tensor, corresponding to each tensor. Returns: - A tuple of tensors with expanded dimensions. + A Tuple of tensors with expanded dimensions. 
""" assert len(tensors) == len(expansion_dims), \ diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 5e8eb6c54c89..75a5317b10ba 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -168,7 +168,8 @@ def forward_hpu( x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - from vllm_hpu_extension.ops import HPUFusedRMSNorm + from vllm_hpu_extension.kernels import rms_norm + HPUFusedRMSNorm = rms_norm() if HPUFusedRMSNorm is None: return self.forward_native(x, residual) if residual is not None: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 21035a9e5dbe..794de4c383b0 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -6,7 +6,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -17,6 +16,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.utils import dispatch_unquantized_gemm # yapf: disable from vllm.model_executor.parameter import (BasevLLMParameter, BlockQuantScaleParameter, @@ -31,6 +31,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", + "BitBLASLinearMethod", + "GPTQBitBLASLinearMethod", "AWQMarlinLinearMethod", "AWQLinearMethod", "GPTQMarlinLinearMethod", @@ -50,6 +52,15 @@ ] +def adjust_bitblas_shard(param, shard_size, shard_offset): + bitblas_tile_size = getattr(param, "bitblas_tile_size", None) + if bitblas_tile_size is not None: + return (shard_size // bitblas_tile_size, + shard_offset // bitblas_tile_size) + + return shard_size, shard_offset + + def adjust_marlin_shard(param, shard_size, shard_offset): marlin_tile_size = getattr(param, "marlin_tile_size", None) if marlin_tile_size is None: @@ -188,7 +199,7 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return F.linear(x, layer.weight, bias) + return dispatch_unquantized_gemm()(x, layer.weight, bias) class LinearBase(torch.nn.Module): @@ -615,6 +626,9 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) + if use_bitsandbytes_4bit: index = list(itertools.accumulate([0] + self.output_sizes)) orig_offsets = { @@ -646,6 +660,8 @@ def weight_loader(self, # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) @@ -913,6 +929,15 @@ def weight_loader_v2(self, shard_offset = self._get_shard_offset_mapping(loaded_shard_id) shard_size = self._get_shard_size_mapping(loaded_shard_id) + # Note(simon): This is needed for Qwen3's fp8 quantization. 
+ if isinstance(param, BlockQuantScaleParameter): + assert self.quant_method is not None + assert hasattr(self.quant_method, "quant_config") + weight_block_size = self.quant_method.quant_config.weight_block_size + block_n, _ = weight_block_size[0], weight_block_size[1] + shard_offset = (shard_offset + block_n - 1) // block_n + shard_size = (shard_size + block_n - 1) // block_n + param.load_qkv_weight(loaded_weight=loaded_weight, num_heads=self.num_kv_head_replicas, shard_id=loaded_shard_id, diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index b31b980fbe84..9fbad9d2f91e 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -10,8 +10,10 @@ from vllm import _custom_ops as ops from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.triton_utils import HAS_TRITON -TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") +TRITON3 = HAS_TRITON and (version.parse(triton.__version__) + >= version.parse("3.0.0")) if TRITON3: diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 89533955fd76..15e08220b7b5 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Type +from typing import Literal, Type, get_args from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -QUANTIZATION_METHODS: List[str] = [ +QuantizationMethods = Literal[ "aqlm", "awq", "deepspeedfp", @@ -15,12 +15,12 @@ "fbgemm_fp8", "modelopt", "nvfp4", - # The order of gptq methods is important for config.py iteration over - # override_quantization_method(..) "marlin", + "bitblas", "gguf", "gptq_marlin_24", "gptq_marlin", + "gptq_bitblas", "awq_marlin", "gptq", "compressed-tensors", @@ -34,6 +34,7 @@ "moe_wna16", "torchao", ] +QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) # The customized quantization methods which will be added to this dict. 
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {} @@ -85,6 +86,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .aqlm import AQLMConfig from .awq import AWQConfig from .awq_marlin import AWQMarlinConfig + from .bitblas import BitBLASConfig from .bitsandbytes import BitsAndBytesConfig from .compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsConfig) @@ -94,6 +96,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .fp8 import Fp8Config from .gguf import GGUFConfig from .gptq import GPTQConfig + from .gptq_bitblas import GPTQBitBLASConfig from .gptq_marlin import GPTQMarlinConfig from .gptq_marlin_24 import GPTQMarlin24Config from .hqq_marlin import HQQMarlinConfig @@ -107,7 +110,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .torchao import TorchAOConfig from .tpu_int8 import Int8TpuConfig - method_to_config: Dict[str, Type[QuantizationConfig]] = { + method_to_config: dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, "awq": AWQConfig, "deepspeedfp": DeepSpeedFPConfig, @@ -116,12 +119,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: "fbgemm_fp8": FBGEMMFp8Config, "modelopt": ModelOptFp8Config, "nvfp4": ModelOptNvFp4Config, - # The order of gptq methods is important for config.py iteration over - # override_quantization_method(..) "marlin": MarlinConfig, + "bitblas": BitBLASConfig, "gguf": GGUFConfig, "gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin": GPTQMarlinConfig, + "gptq_bitblas": GPTQBitBLASConfig, "awq_marlin": AWQMarlinConfig, "gptq": GPTQConfig, "compressed-tensors": CompressedTensorsConfig, @@ -144,6 +147,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: __all__ = [ "QuantizationConfig", + "QuantizationMethods", "get_quantization_config", "QUANTIZATION_METHODS", -] +] \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index cb1d5400f3a0..ef4a7765d61e 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -17,14 +17,13 @@ is_layer_skipped_awq) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - check_marlin_supports_layer, marlin_make_empty_g_idx, - marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales, - moe_awq_to_marlin_zero_points, verify_marlin_supported, - verify_marlin_supports_shape) + check_marlin_supports_layer, check_moe_marlin_supports_layer, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, moe_awq_to_marlin_zero_points, + verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) @@ -136,12 +135,15 @@ def get_quant_method(self, layer: torch.nn.Module, self.full_config).get_quant_method(layer, prefix) return AWQMarlinLinearMethod(self) elif isinstance(layer, FusedMoE): - if layer.local_num_experts > 32: - # For MoEs 
with many experts the moe_wna16 kernel is faster + from vllm.model_executor.layers.quantization.moe_wna16 import ( + MoeWNA16Config) + if not check_moe_marlin_supports_layer(layer, self.group_size): + logger.warning_one( + f"Layer '{prefix}' is not supported by AWQMoeMarlin. " + "Falling back to Moe WNA16 kernels.") return MoeWNA16Config.from_config( self.full_config).get_quant_method(layer, prefix) - else: - return AWQMoEMethod(self) + return AWQMoEMethod(self) return None @classmethod @@ -391,6 +393,13 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) + device = layer.w13_qweight.device + sms = torch.cuda.get_device_properties(device).multi_processor_count + layer.workspace = torch.zeros((sms * 4, ), + dtype=torch.int, + device=device, + requires_grad=False) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: num_experts = layer.w13_qweight.shape[0] device = layer.w13_qweight.device @@ -473,10 +482,7 @@ def apply( activation: str = "silu", ) -> torch.Tensor: assert activation == "silu", "Only SiLU activation is supported." - if expert_map is not None: - raise NotImplementedError( - "Expert Parallelism is not supported for " - "fused Marlin MoE method.") + if apply_router_weight_on_input: raise NotImplementedError( "Apply router weight on input is not supported for" @@ -503,7 +509,10 @@ def apply( router_logits, topk_weights, topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map, w1_zeros=layer.w13_qzeros, w2_zeros=layer.w2_qzeros, + workspace=layer.workspace, num_bits=self.quant_config.weight_bits, ) diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py new file mode 100644 index 000000000000..3eaaa6c252ce --- /dev/null +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -0,0 +1,459 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Optional + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS, + BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + + +class BitBLASConfig(QuantizationConfig): + """Config class for BitBLAS. + + Reference: https://github.com/Microsoft/BitBLAS + """ + TORCH_DTYPE = torch.float16 + STORAGE_DTYPE = "int8" # assume int8 storage + TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE) + # "original" or "rescale" or "quantized", + # gptq_with_bitblas prefer "quantized implementation" + ZEROS_MODE = "quantized" + + def __init__( + self, + weight_bits: int, + group_size: Optional[int], + desc_act: Optional[bool], + is_sym: Optional[bool], + quant_method: Optional[str], + lm_head_quantized: bool, + ) -> None: + try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError( + "bitblas version is wrong. 
Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + except ImportError as e: + bitblas_import_exception = e + raise ValueError( + "Trying to use the bitblas backend, but could not import" + f"with the following error: {bitblas_import_exception}. " + "Please install bitblas through the following command: " + f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" + ) from bitblas_import_exception + + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.is_sym = is_sym + self.quant_method = quant_method + self.lm_head_quantized = lm_head_quantized + + # Verify + if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS: + raise ValueError( + f"BitBLAS does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} " + "are supported.") + + if self.is_sym not in BITBLAS_SUPPORTED_SYM: + raise ValueError( + f"BitBLAS does not support is_sym = {self.is_sym}. " + f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.") + + storage_dtype = self.STORAGE_DTYPE + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + + self.storage_dtype = storage_dtype + self.storage_torch_dtype = self.TORCH_STORAGE_DTYPE + # 4 Bits packed into 32 bit datatype. + self.pack_factor = storage_nbit // weight_bits + self.nbits = weight_bits + + # Zeros type for the quantized weights. + self.zeros_mode = self.ZEROS_MODE + + def __repr__(self) -> str: + return (f"BitBLASConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}, " + f"is_sym={self.is_sym}, " + f"quant_method={self.quant_method})") + + @classmethod + def get_name(cls) -> str: + return "bitblas" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @staticmethod + def get_from_keys(config: Dict[str, Any], + keys: List[str], + default: Any = None) -> Any: + """Get a value from the model's quantization config.""" + for key in keys: + if key in config: + return config[key] + return default + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"], -1) + desc_act = cls.get_from_keys(config, ["desc_act"], False) + is_sym = cls.get_from_keys(config, ["sym"], False) + quant_method = cls.get_from_keys(config, ["quant_method"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, is_sym, quant_method, + lm_head_quantized) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + # compat: autogptq >=0.8.0 use checkpoint_format: str + # compat: autogptq <=0.7.1 is_bitblas_format: bool + is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas" + or hf_quant_cfg.get("is_bitblas_format", False)) + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "bitblas") + + if is_bitblas_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. Using {} kernel.". 
+ format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["BitBLASLinearMethod"]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): + return BitBLASLinearMethod(self) + return None + + +class BitBLASLinearMethod(LinearMethodBase): + """Linear method for BitBLAS. + + Args: + quant_config: The BitBLAS quantization config. + """ + # USE BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS + # Instead of BITBLAS_OPTIMIZE_FEATURES + # If you want to high contiguous batching + # performance + OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES + ENABLE_TUNING = True + BITBLAS_DTYPES = { + torch.float32: "float32", + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.half: "float16", + torch.int8: "int8", + } + + def __init__(self, quant_config: BitBLASConfig): + self.quant_config = quant_config + + def create_weights_gptq( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """Creates quantized weights for use in linear operations. + + The function initializes and returns a dictionary containing quantized + weights, scales, and zeros + for performing quantized matrix multiplication operations. + + Args: + input_size_per_partition: The size of the input partition. + output_size_per_partition: The size of the output partition. + input_size: The total size of the input (unused). + output_size: The total size of the output (unused). + params_dtype: + The data type of the parameters (expected to be torch.float16). + + Returns: + A dictionary containing the quantized weights ('qweight'), + scales ('scales'), and zeros ('zeros'). + + Raises: + ValueError: If `params_dtype` is not `torch.float16` or if the + input size per partition is not divisible by the group size in + `quant_config`. + """ + del input_size, output_size # Unused arguments. + weight_loader = extra_weight_attrs["weight_loader"] + + if params_dtype not in self.quant_config.get_supported_act_dtypes(): + raise ValueError("Parameter data type must be torch.float16, " + f"but got {params_dtype}") + group_size = self.quant_config.group_size + if group_size is None: + group_size = -1 + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if (group_size != -1 and input_size_per_partition % group_size != 0): + raise ValueError( + f"Input size per partition ({input_size_per_partition}) must " + f"be divisible by group size ({group_size}).") + + # Initialize or retrieve the BitBLAS matrix multiplication operator. + self._configure_bitblas_matmul( + input_size_per_partition, + output_size_per_partition, + params_dtype=params_dtype, + enable_tuning=self.ENABLE_TUNING, + bias=False, + layout="nt", + bits=self.quant_config.weight_bits, + ) + + # Initialize quantized weights with dimensions + # Quantized 4Bit weights packed. 
+ qweight = PackedvLLMParameter( + data=torch.empty( + self.bitblas_matmul.retrieve_weight_shape(), + device="cuda", + dtype=self.quant_config.storage_torch_dtype, + requires_grad=False, + ), + input_dim=1, + output_dim=0, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + bitblas_tile_size=(self.bitblas_matmul.retrieve_weight_shape()[-2] + if self.bitblas_matmul.propagate_b else None), + weight_loader=weight_loader, + ) + + # Compute the number of input groups for channel-wise quantization. + input_groups = (1 if group_size == -1 else input_size_per_partition // + group_size) + + # Initialize scales and zeros for the quantized weights. + weight_scale_args = { + "data": + torch.empty( + output_size_per_partition, + input_groups, + device="cuda", + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + if input_groups == 1: + scales = ChannelQuantScaleParameter(output_dim=0, + **weight_scale_args) + else: + scales = GroupQuantScaleParameter(output_dim=0, + input_dim=1, + **weight_scale_args) + + if self.quant_config.zeros_mode == "quantized": + zeros = PackedvLLMParameter( + data=torch.empty( + input_groups, + output_size_per_partition // self.quant_config.pack_factor, + device="cuda", + dtype=self.quant_config.storage_torch_dtype, + requires_grad=False, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader, + ) + + else: + zeros = BasevLLMParameter( + torch.empty(output_size_per_partition, + input_groups, + device="cuda", + dtype=params_dtype), + weight_loader=weight_loader, + ) + # Set attributes to indicate how scales and zeros are applied. + set_weight_attrs(zeros, { + "input_dim": None if input_groups == 1 else 1, + "output_dim": 0, + }) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("scales", scales) + layer.register_parameter("zeros", zeros) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + if self.quant_config.quant_method == "gptq": + return self.create_weights_gptq(layer, input_size_per_partition, + output_partition_sizes, input_size, + output_size, params_dtype, + **extra_weight_attrs) + else: + raise ValueError( + f"Unsupported quant_method {self.quant_config.quant_method}") + + def _configure_bitblas_matmul( + self, + infeatures, + outfeatures, + params_dtype, + enable_tuning, + bias, + layout, + bits, + out_dtype="float16", + ): + from bitblas import MatmulConfig + bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] + + with_scaling = False + with_zeros = False + group_size = self.quant_config.group_size + zeros_mode = self.quant_config.zeros_mode + if self.quant_config.quant_method == "gptq": + with_scaling = True + with_zeros = True + W_dtype = f"uint{bits}" + if self.quant_config.is_sym: + with_zeros = False + W_dtype = f"int{bits}" + else: + raise ValueError( + f"Unsupported quant_method {self.quant_config.quant_method}") + + matmul_config = MatmulConfig( + N=outfeatures, + K=infeatures, + A_dtype=bitblas_dtype, + W_dtype=W_dtype, + out_dtype=out_dtype, + accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, + storage_dtype=self.quant_config.STORAGE_DTYPE, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + with_bias=bias, + layout=layout, + zeros_mode=zeros_mode, + ) + self.bitblas_matmul = self._get_or_create_bitblas_operator( + 
matmul_config, enable_tuning) + + def _get_or_create_bitblas_operator(self, config, enable_tuning): + from bitblas import Matmul, auto_detect_nvidia_target + from bitblas.cache import get_database_path, global_operator_cache + BITBLAS_DATABASE_PATH = get_database_path() + BITBLAS_TARGET = auto_detect_nvidia_target() + if global_operator_cache.size() == 0: + global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, + BITBLAS_TARGET) + + bitblas_matmul = global_operator_cache.get(config) + if bitblas_matmul is None: + bitblas_matmul = Matmul(config, + target=BITBLAS_TARGET, + enable_tuning=False) + if enable_tuning: + TUNING_MESSAGE = (f"BitBLAS Operator {config} is tuning ...") + logger.info(TUNING_MESSAGE) + bitblas_matmul.hardware_aware_finetune(topk=20) + global_operator_cache.add(config, bitblas_matmul) + global_operator_cache.save_into_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + TUNED_MESSAGE = ( + f"BitBLAS Operator {config} tuned and saved to database.") + logger.info(TUNED_MESSAGE) + else: + _message = f"BitBLAS Operator {config} created." + logger.info(_message) + else: + _message = ( + f"BitBLAS Operator {config} found in global_operator_cache.") + logger.info(_message) + return bitblas_matmul + + def apply_gptq( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.qweight + scales = layer.scales + qzeros = layer.zeros + + x_2d = x.view(-1, x.shape[-1]) + + if self.quant_config.is_sym: + output_2d = self.bitblas_matmul(x_2d, qweight, scales) + else: + output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output + + def apply( + self, + *args: Any, + **kwargs: Any, + ) -> torch.Tensor: + if self.quant_config.quant_method == "gptq": + return self.apply_gptq(*args, **kwargs) + else: + raise ValueError( + f"Unsupported quant_method {self.quant_config.quant_method}") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index b714d95b6025..7b0032572ecf 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -72,7 +72,7 @@ def get_min_capability(cls) -> int: return 70 def get_name(self) -> str: - return "compressed_tensors" + return "compressed-tensors" def get_quant_method( self, @@ -302,14 +302,12 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel, def _is_wNa16_group_channel(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: input_quant_none = input_quant is None - is_symmetric = weight_quant.symmetric is_channel_group = ( weight_quant.strategy == QuantizationStrategy.CHANNEL.value or weight_quant.strategy == QuantizationStrategy.GROUP.value) is_static = not weight_quant.dynamic - return (is_channel_group and input_quant_none and is_symmetric - and is_static) + return (is_channel_group and input_quant_none and is_static) def _get_scheme_from_parts( self, weight_quant: BaseModel, @@ -319,6 +317,7 @@ def _get_scheme_from_parts( if self._is_wNa16_group_channel(weight_quant, input_quant): if (self.quant_format == CompressionFormat.marlin_24.value and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): + assert weight_quant.symmetric return CompressedTensorsW4A16Sparse24( 
strategy=weight_quant.strategy, num_bits=weight_quant.num_bits, @@ -328,6 +327,7 @@ def _get_scheme_from_parts( return CompressedTensorsWNA16( num_bits=weight_quant.num_bits, strategy=weight_quant.strategy, + symmetric=weight_quant.symmetric, group_size=weight_quant.group_size, actorder=weight_quant.actorder) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 628724c5b7d6..721e36af2b28 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -67,7 +67,7 @@ def get_moe_method( else: return CompressedTensorsWNA16MarlinMoEMethod(quant_config) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) - and layer.activation == "silu" and layer.expert_map is None): + and layer.activation == "silu"): return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoEMethod(quant_config) @@ -250,6 +250,28 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled) + + # Property to determine if AITER is used + if is_rocm_aiter_moe_enabled(): + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 + rocm_aiter_fused_experts, shuffle_weights) + + # reshaping weights is required for aiter moe kernel. + shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight.data, layer.w2_weight.data) + + layer.w13_weight = torch.nn.Parameter(shuffled_w13, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, + requires_grad=False) + + self.fused_experts_func = rocm_aiter_fused_experts + else: + from vllm.model_executor.layers.fused_moe import fused_experts + self.fused_experts_func = fused_experts + def apply( self, layer: torch.nn.Module, @@ -268,7 +290,6 @@ def apply( apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -282,10 +303,10 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, + return self.fused_experts_func( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, @@ -489,8 +510,6 @@ def apply( ) -> torch.Tensor: assert activation == "silu" - assert global_num_experts == layer.w13_weight.shape[0] - assert expert_map is None topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -521,6 +540,7 @@ def apply( a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, out_dtype=x.dtype, + expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 38df09ff3937..3535dd3f3f14 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -12,11 +12,15 @@ MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_repeat_scales_on_all_ranks) +# yapf conflicts with isort for this block +# yapf: disable from vllm.model_executor.parameter import (BasevLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, + PackedColumnParameter, PackedvLLMParameter, RowvLLMParameter) +# yapf: enable from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -26,6 +30,7 @@ 4: scalar_types.uint4b8, 8: scalar_types.uint8b128 } +WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8} WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) @@ -36,10 +41,12 @@ def __init__(self, strategy: str, num_bits: int, group_size: Optional[int] = None, + symmetric: Optional[bool] = True, actorder: Optional[ActivationOrdering] = None): self.pack_factor = 32 // num_bits self.strategy = strategy + self.symmetric = symmetric self.group_size = -1 if group_size is None else group_size self.has_g_idx = actorder == ActivationOrdering.GROUP @@ -53,7 +60,9 @@ def __init__(self, f"Unsupported num_bits = {num_bits}. " f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}") - self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] + self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits] + if not self.symmetric else + WNA16_SUPPORTED_TYPES_MAP[num_bits]) @classmethod def get_min_capability(cls) -> int: @@ -75,7 +84,7 @@ def create_weights(self, layer: torch.nn.Module, output_size: int, weight_type=self.quant_type, act_type=params_dtype, group_size=self.group_size, - zero_points=False, + zero_points=not self.symmetric, has_g_idx=self.has_g_idx ) @@ -120,13 +129,37 @@ def create_weights(self, layer: torch.nn.Module, output_size: int, dtype=params_dtype, ) } + + zeros_args = { + "weight_loader": + weight_loader, + "data": + torch.zeros( + output_size_per_partition // self.pack_factor, + scales_and_zp_size, + dtype=torch.int32, + ) + } + if not partition_scales: weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args) + + if not self.symmetric: + qzeros = PackedColumnParameter(output_dim=0, + packed_dim=0, + packed_factor=self.pack_factor, + **zeros_args) else: weight_scale = GroupQuantScaleParameter(output_dim=0, input_dim=1, **weight_scale_args) + if not self.symmetric: + qzeros = PackedvLLMParameter(input_dim=1, + output_dim=0, + packed_dim=0, + packed_factor=self.pack_factor, + **zeros_args) # A 2D array defining the original shape of the weights # before packing @@ -138,6 +171,9 @@ def create_weights(self, layer: torch.nn.Module, output_size: int, layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_shape", weight_shape) + if not self.symmetric: + layer.register_parameter("weight_zero_point", qzeros) + # group index (for activation reordering) if self.has_g_idx: weight_g_idx = RowvLLMParameter(data=torch.empty( @@ -151,7 +187,7 @@ def create_weights(self, layer: torch.nn.Module, output_size: int, self.kernel = kernel_type(mp_linear_kernel_config, w_q_param_name="weight_packed", w_s_param_name="weight_scale", - w_zp_param_name=None, + w_zp_param_name="weight_zero_point", w_gidx_param_name="weight_g_idx") # Checkpoints are serialized in compressed-tensors format, which is diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 
b7327f47733b..01056c37b86c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -140,6 +140,11 @@ def get_cache_scale(self, name: str) -> Optional[str]: return name.replace(".k_proj.output_scale", ".attn.k_scale") if name.endswith(".output_scale") and ".v_proj" in name: return name.replace(".v_proj.output_scale", ".attn.v_scale") + if name.endswith(".output_scale") and ".q_proj" in name: + return name.replace(".q_proj.output_scale", ".attn.q_scale") + if name.endswith("self_attn.prob_output_scale"): + return name.replace(".prob_output_scale", ".attn.prob_scale") + # If no matches, return None return None @@ -575,8 +580,7 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, def process_weights_after_loading(self, layer: Module) -> None: # Lazy import to avoid importing triton too early. from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - expand_weights, is_rocm_aiter_block_scaled_moe_enabled, - is_rocm_aiter_moe_enabled, shuffle_weights) + expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights) # TODO (rob): refactor block quant into separate class. if self.block_quant: @@ -603,7 +607,7 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_weight = Parameter(w2_weight, requires_grad=False) layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv, requires_grad=False) - if is_rocm_aiter_block_scaled_moe_enabled(): + if is_rocm_aiter_moe_enabled(): # reshaping weights is required for aiter moe kernel. shuffled_w13, shuffled_w2 = shuffle_weights( layer.w13_weight.data, layer.w2_weight.data) diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py new file mode 100644 index 000000000000..88cada4c61b8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -0,0 +1,438 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Optional, Set + +import torch +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( + BitBLASLinearKernel, MPLinearLayerConfig) +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + BITBLAS_SUPPORTED_NUM_BITS as GPTQ_BITBLAS_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + BITBLAS_SUPPORTED_SYM as GPTQ_BITBLAS_SUPPORTED_SYM) +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + MINIMUM_BITBLAS_VERSION, bitblas_repeat_scales_on_all_ranks, + check_bitblas_supported, verify_bitblas_supported) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +class GPTQBitBLASConfig(QuantizationConfig): + """Config class for GPTQ BitBLAS""" + + # (num_bits, is_sym) -> quant_type + TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + + TORCH_DTYPE = torch.float16 + GPTQ_CKPT_STORAGE_DTYPE = ( + "int32" # GPTQ Default Checkpoints use int32 as storage dtype + 
) + GPTQ_BITBLAS_STORAGE_DTYPE = "int8" # BitBLAS uses int8 as storage dtype + TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, GPTQ_BITBLAS_STORAGE_DTYPE) + # "original" or "rescale" or "quantized", + # the gptq_bitblas prefer "quantized" + ZEROS_MODE = "quantized" + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + quant_method: Optional[str], + lm_head_quantized: bool, + ) -> None: + + try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError( + "bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + except ImportError as e: + bitblas_import_exception = e + raise ValueError( + "Trying to use the bitblas backend, but could not import" + f"with the following error: {bitblas_import_exception}. " + "Please install bitblas through the following command: " + f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" + ) from bitblas_import_exception + + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.is_sym = is_sym + self.quant_method = quant_method + self.lm_head_quantized = lm_head_quantized + + # Verify + if self.weight_bits not in GPTQ_BITBLAS_SUPPORTED_NUM_BITS: + raise ValueError( + f"BitBLAS does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {GPTQ_BITBLAS_SUPPORTED_NUM_BITS} " + "are supported.") + + if self.is_sym not in GPTQ_BITBLAS_SUPPORTED_SYM: + raise ValueError( + f"BitBLAS does not support is_sym = {self.is_sym}. " + f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported.") + + self.storage_dtype = self.GPTQ_BITBLAS_STORAGE_DTYPE + + storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE + if c.isdigit())) + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = storage_nbit // weight_bits + self.nbits = weight_bits + + # Zeros type for the quantized weights. 
+ self.zeros_mode = self.ZEROS_MODE + + if (weight_bits, is_sym) not in self.TYPE_MAP: + raise ValueError("Unsupported quantization config: " + f"bits={weight_bits}, sym={is_sym}") + + self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] + + def __repr__(self) -> str: + return (f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})" + f"is_sym={self.is_sym}, " + f"quant_method={self.quant_method})") + + @classmethod + def get_name(cls) -> str: + return "gptq_bitblas" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQBitBLASConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + is_sym = cls.get_from_keys(config, ["sym"]) + quant_method = cls.get_from_keys(config, ["quant_method"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, is_sym, quant_method, + lm_head_quantized) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg) + + is_valid_user_quant = (user_quant is None or user_quant == "bitblas" + or user_quant == "gptq_bitblas") + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "gptq": + logger.info("Detected that the model can run with gptq_bitblas" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. Use quantization=gptq_bitblas for" + " faster inference") + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQBitBLASLinearMethod"]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): + return GPTQBitBLASLinearMethod(self) + return None + + @property + def torch_storage_dtype(self) -> torch.dtype: + return self.TORCH_BITBLAS_STORAGE_DTYPE + + @classmethod + def is_gptq_bitblas_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + sym = quant_config.get("sym") + desc_act = quant_config.get("desc_act") + + # If we cannot find the info needed in the config, cannot convert. + if (num_bits is None or group_size is None or sym is None + or desc_act is None): + return False + + if (num_bits, sym) not in cls.TYPE_MAP: + return False + + # If the capability of the device is too low, cannot convert. + major, minor = torch.cuda.get_device_capability() + device_capability = major * 10 + minor + if device_capability < cls.get_min_capability(): + return False + + # Otherwise, can convert if model satisfies bitblas constraints. + return check_bitblas_supported(quant_type=cls.TYPE_MAP[(num_bits, + sym)], + group_size=group_size) + + +class GPTQBitBLASLinearMethod(LinearMethodBase): + """Linear method for GPTQ BitBLAS. + + Args: + quant_config: The GPTQ BitBLAS quantization config. 
+ """ + + kernel_type = BitBLASLinearKernel + _kernel_backends_being_used: Set[str] = set() + + def __init__(self, quant_config: GPTQBitBLASConfig) -> None: + self.quant_config = quant_config + # Verify supported on platform. + verify_bitblas_supported(quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + """Creates quantized weights for use in linear operations. + + The function initializes and returns a dictionary containing + quantized weights, scales, and zeros + for performing quantized matrix multiplication operations. + + Args: + input_size_per_partition: The size of the input partition. + output_partition_sizes: The size of the output partition. + input_size: The total size of the input (unused). + output_size: The total size of the output (unused). + params_dtype: + The data type of the parameters (expected to be torch.float16). + + Returns: + A dictionary containing the quantized weights ('qweight'), + scales ('scales'), and zeros ('zeros'). + + Raises: + ValueError: If `params_dtype` is not `torch.float16` or + if the input size per partition is not divisible by the + group size in `quant_config`. + """ + if params_dtype != torch.float16: + raise ValueError("Parameter data type must be torch.float16, " + f"but got {params_dtype}") + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + if input_size_per_partition % group_size != 0: + raise ValueError( + f"Input size per partition ({input_size_per_partition}) must " + f"be divisible by group size ({self.quant_config.group_size})." + ) + + kernel_type = self.kernel_type + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + + is_row_parallel = input_size != input_size_per_partition + weight_loader = extra_weight_attrs.get("weight_loader") + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_config.quant_type, + act_type=params_dtype, + group_size=self.quant_config.group_size, + zero_points=False, + has_g_idx=self.quant_config.desc_act + ) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for GPTQBitBLASLinearMethod", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + # Determine sharding + if bitblas_repeat_scales_on_all_ranks(self.quant_config.desc_act, + self.quant_config.group_size, + is_row_parallel): + # By setting scale_dim == None, weight_loader will + # repeat the scales on each GPU in TP>1 case. + scales_and_zp_input_dim = None + scales_and_zp_size = input_size // group_size + else: + # By setting scale_dim == 0, weight_loader will + # shard the scales in TP>1 case. 
+ scales_and_zp_input_dim = 0 + scales_and_zp_size = input_size_per_partition // group_size + + # Init buffers + # Quantized weights + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + # Activation order + # Ignore warning from fused linear layers such as QKVParallelLinear. + g_idx = RowvLLMParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + + # Scales + scales = Parameter( + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + **extra_weight_attrs, + "input_dim": scales_and_zp_input_dim, + "output_dim": 1, + }, + ) + + # Quantized zero-points + qzeros_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": + weight_loader + } + weight_scale_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + + if scales_and_zp_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("scales", scales) + layer.register_parameter("qzeros", qzeros) + + self.kernel = kernel_type( + mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + w_gidx_param_name="g_idx", + bitblas_quant_config=self.quant_config, + ) + + # Initialize or retrieve the BitBLAS matrix multiplication operator. 
+ self.kernel.configure_bitblas_matmul( + input_size_per_partition, + output_size_per_partition, + params_dtype=params_dtype, + bias=False, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + out = self.kernel.apply_gptq_bitblas_linear(layer, x) + if bias is not None: + out.add_(bias) + return out diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 0615bb4ab4df..52cd0a5b6975 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -15,13 +15,13 @@ QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( MPLinearLayerConfig, choose_mp_linear_kernel) -from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.gptq_utils import ( get_linear_quant_method) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supported, marlin_moe_permute_scales, - marlin_repeat_scales_on_all_ranks, verify_marlin_supported) + check_marlin_supported, check_moe_marlin_supports_layer, + marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks, + verify_marlin_supported) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, PackedColumnParameter, @@ -153,12 +153,15 @@ def override_quantization_method(cls, hf_quant_cfg, def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: if isinstance(layer, FusedMoE): - if layer.local_num_experts > 32: - # For MoEs with many experts the moe_wna16 kernel is faster + from vllm.model_executor.layers.quantization.moe_wna16 import ( + MoeWNA16Config) + if not check_moe_marlin_supports_layer(layer, self.group_size): + logger.warning_one( + f"Layer '{prefix}' is not supported by GPTQMoeMarlin. 
" + "Falling back to Moe WNA16 kernels.") return MoeWNA16Config.from_config( self.full_config).get_quant_method(layer, prefix) - else: - return GPTQMarlinMoEMethod(self) + return GPTQMarlinMoEMethod(self) return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod) @@ -408,7 +411,7 @@ def create_weights( torch.empty(num_experts, scales_size13, 2 * intermediate_size_per_partition, - dtype=torch.half), + dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w13_scales", w13_scales) @@ -418,7 +421,7 @@ def create_weights( torch.empty(num_experts, scales_size2, hidden_size, - dtype=torch.half), + dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w2_scales", w2_scales) @@ -493,6 +496,13 @@ def create_weights( w2_g_idx_sort_indices) set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + device = layer.w13_qweight.device + sms = torch.cuda.get_device_properties(device).multi_processor_count + layer.workspace = torch.zeros((sms * 4, ), + dtype=torch.int, + device=device, + requires_grad=False) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Process act_order @@ -601,10 +611,6 @@ def apply( "Apply router weight on input is not supported for" "fused Marlin MoE method.") - # The input must currently be float16 - orig_dtype = x.dtype - x = x.half() - topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -626,9 +632,12 @@ def apply( router_logits, topk_weights, topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map, g_idx1=layer.w13_g_idx, g_idx2=layer.w2_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, num_bits=self.quant_config.quant_type.size_bits, - is_k_full=self.is_k_full).to(orig_dtype) + workspace=layer.workspace, + is_k_full=self.is_k_full) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 520e1bc96721..d144bb436104 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -5,6 +5,8 @@ import vllm.envs as envs from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501 AllSparkLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501 + BitBLASLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 ExllamaLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 @@ -20,6 +22,7 @@ MacheteLinearKernel, AllSparkLinearKernel, MarlinLinearKernel, + BitBLASLinearKernel, ExllamaLinearKernel, ] @@ -76,4 +79,4 @@ def choose_mp_linear_kernel( raise ValueError( "Failed to find a kernel that can implement the "\ "WNA16 linear layer. 
Reasons: \n" - + '\n'.join(failure_reasons)) + + '\n'.join(failure_reasons)) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py new file mode 100644 index 000000000000..21452d08b8a1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py @@ -0,0 +1,299 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES, + MINIMUM_BITBLAS_VERSION, bitblas_make_empty_g_idx, bitblas_sort_g_idx, + check_bitblas_supports_shape, query_bitblas_supported_quant_types, + unpack_gptq_qweight, unpack_gptq_qzeros) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + +logger = init_logger(__name__) + + +class BitBLASLinearKernel(MPLinearKernel): + + OPT_FEATURES: List[int] = BITBLAS_OPTIMIZE_FEATURES + ENABLE_TUNING: bool = True + MATMUL_LAYOUT: str = "nt" + BITBLAS_DTYPES: Dict[torch.dtype, str] = { + torch.float32: "float32", + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.half: "float16", + torch.int8: "int8", + } + bitblas_matmul: object = None + + def __init__( + self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, + w_gidx_param_name: Optional[str] = None, + bitblas_quant_config: Optional[QuantizationConfig] = None, + ): + self.quant_config = bitblas_quant_config + super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name, + w_gidx_param_name) + + def repack_bitblas_from_gptq( + self, + b_q_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: Optional[torch.Tensor] = None, + ): + from bitblas.quantization.utils import general_compress + assert self.bitblas_matmul is not None, "bitblas_matmul is None" + + quant_config = self.quant_config + # qweight in gptq old quant linear stored with + # (outfeatures, infeatures), should be transposed. + qweight = b_q_weight.T.contiguous().view( + quant_config.torch_storage_dtype) # type: ignore[union-attr] + intweight = unpack_gptq_qweight( + qweight, + quant_config.weight_bits).contiguous() # type: ignore[union-attr] + if self.bitblas_matmul.weight_transform is not None: # type: ignore[attr-defined] + qweight = self.bitblas_matmul.weight_transform( # type: ignore[attr-defined] + intweight.cpu()).cuda() + # scales in gptq old quant linear stored with + # (infeatures // group_size, outfeatures), should be transposed. + scales = scales.T.contiguous() + + if qzeros is None: + return qweight, scales, None + + # qzeros should be de-quantized to int zeros. 
+ weight_bits = quant_config.weight_bits # type: ignore[union-attr] + intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous() + zeros: Optional[torch.Tensor] = None + zeros_mode = self.bitblas_matmul.config.zeros_mode # type: ignore[attr-defined] + if zeros_mode == "original": + zeros = intzeros.to(torch.float16).contiguous() + elif zeros_mode == "rescale": + assert zeros is not None, "zeros should not be None" + zeros[:, :] = intzeros.to(torch.float16)[:, :] * scales[:, :] + elif zeros_mode == "quantized": + zeros = ( + torch.Tensor( + general_compress( + intzeros.T.contiguous().cpu().numpy(), + weight_bits, + )).to(qweight.device). + to(quant_config.torch_storage_dtype # type: ignore[union-attr] + ).contiguous()) + else: + raise ValueError("Unsupported zeros type: {}".format(zeros_mode)) + + return qweight, scales, zeros + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + is_bitblas_installed = True + + try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError( + "bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + except ImportError: + is_bitblas_installed = False + + if not is_bitblas_installed: + return False, "bitblas is not installed. Please install bitblas "\ + "by running `pip install bitblas>="\ + f"{MINIMUM_BITBLAS_VERSION}`" + + quant_types = query_bitblas_supported_quant_types(c.zero_points) + if c.weight_type not in quant_types: + return False, (f"Quant type ({c.weight_type}) not supported by" + f" BitBLAS, supported types are: {quant_types}") + + if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES: + return False, (f"Group size ({c.group_size}) not supported by " + "BitBLAS, supported group sizes are: " + f"{BITBLAS_SUPPORTED_GROUP_SIZES}") + + return check_bitblas_supports_shape( + c.partition_weight_shape[1], # out_features + c.partition_weight_shape[0], # in_features + c.full_weight_shape[0], # in_features + c.group_size) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + quant_config = self.quant_config + + # Default names since bitblas requires empty parameters for these, + # TODO: remove this requirement from bitblas (allow optional tensors) + if self.w_gidx_name is None: + self.w_gidx_name = "g_idx" + if self.w_zp_name is None: + self.w_zp_name = "qzeros" + + if c.has_g_idx: + g_idx, g_idx_sort_indices = bitblas_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, bitblas_make_empty_g_idx(device)) + layer.g_idx_sort_indices = bitblas_make_empty_g_idx(device) + + if c.zero_points: + raise NotImplementedError("Zero points not supported by BitBLAS") + else: + setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device)) + + # Repack weights + bitblas_qweight, bitblas_scales, bitblas_qzeros = ( + self.repack_bitblas_from_gptq( + layer.qweight, + layer.scales, + None if quant_config.is_sym else # type: ignore[union-attr] + layer.qzeros, # type: ignore[union-attr] + )) + replace_parameter(layer, self.w_q_name, bitblas_qweight) + replace_parameter(layer, self.w_s_name, bitblas_scales) + if 
bitblas_qzeros is not None: + replace_parameter(layer, self.w_zp_name, bitblas_qzeros) + + def configure_bitblas_matmul( + self, + infeatures: int, + outfeatures: int, + params_dtype: torch.dtype, + bias: bool, + ) -> None: + enable_tuning = self.ENABLE_TUNING + layout = self.MATMUL_LAYOUT + bits = self.quant_config.weight_bits # type: ignore[union-attr] + self._configure_bitblas_matmul( + infeatures, + outfeatures, + params_dtype, + enable_tuning, + bias, + layout, + bits, + ) + + def _configure_bitblas_matmul( + self, + infeatures, + outfeatures, + params_dtype, + enable_tuning, + bias, + layout, + bits, + ): + from bitblas import MatmulConfig + bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] + quant_config = self.quant_config + with_scaling = False + with_zeros = False + group_size = quant_config.group_size # type: ignore[union-attr] + zeros_mode = quant_config.zeros_mode # type: ignore[union-attr] + if quant_config.quant_method == "gptq": # type: ignore[union-attr] + with_scaling = True + with_zeros = True + W_dtype = f"uint{bits}" + if quant_config.is_sym: # type: ignore[union-attr] + with_zeros = False + W_dtype = f"int{bits}" + else: + raise ValueError( + f"Unsupported quant_method {quant_config.quant_method}" # type: ignore[union-attr] + ) # type: ignore[union-attr] + + matmul_config = MatmulConfig( + M=self.OPT_FEATURES, + N=outfeatures, + K=infeatures, + A_dtype=bitblas_dtype, + W_dtype=W_dtype, + out_dtype=bitblas_dtype, + accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, + storage_dtype=quant_config. # type: ignore[union-attr] + storage_dtype, # type: ignore[union-attr] + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + with_bias=bias, + layout=layout, + zeros_mode=zeros_mode, + ) + self.bitblas_matmul = self._get_or_create_bitblas_operator( + matmul_config, enable_tuning) + + def _get_or_create_bitblas_operator(self, config, enable_tuning): + from bitblas import Matmul, auto_detect_nvidia_target + from bitblas.cache import get_database_path, global_operator_cache + BITBLAS_DATABASE_PATH = get_database_path() + BITBLAS_TARGET = auto_detect_nvidia_target() + + if global_operator_cache.size() == 0: + global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, + BITBLAS_TARGET) + + bitblas_matmul = global_operator_cache.get(config) + if bitblas_matmul is None: + bitblas_matmul = Matmul(config, + target=BITBLAS_TARGET, + enable_tuning=False) + if enable_tuning: + bitblas_matmul.hardware_aware_finetune(topk=20) + global_operator_cache.add(config, bitblas_matmul) + global_operator_cache.save_into_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + TUNING_MESSAGE = ( + f"BitBLAS Operator {config} tuned and saved to database.") + logger.info(TUNING_MESSAGE) + else: + _message = f"BitBLAS Operator {config} created without tuning. " + logger.info(_message) + else: + _message = f"BitBLAS Operator {config} retrieved from cache." 
+ logger.info(_message) + return bitblas_matmul + + def apply_gptq_bitblas_linear( + self, + layer: torch.nn.Module, + x: torch.Tensor, + ) -> torch.Tensor: + output_size_per_partition = self.config.partition_weight_shape[1] + out_shape = x.shape[:-1] + (output_size_per_partition, ) + args = [x, layer.qweight, layer.scales] + if self.bitblas_matmul.config.with_zeros: # type: ignore[attr-defined] + args.append(layer.qzeros) + output = self.bitblas_matmul(*args) # type: ignore[operator] + return output.view(out_shape) + + def apply_weights(self, layer, x, bias=None): + NOT_IMPLEMENT_MESSAGE = ( + f"{self.__class__.__name__}.apply_weights is not implemented. " + "Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead") + raise NotImplementedError(NOT_IMPLEMENT_MESSAGE) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py index 3f0586f6e30d..b3ffeca4f100 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py @@ -26,17 +26,14 @@ def get_min_capability(cls) -> int: @classmethod def can_implement(cls, c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx and\ c.partition_weight_shape[0] != c.full_weight_shape[0]: return False, "Act reordering currently not supported by Machete, "\ "when the input features are partitioned across "\ "devices" - if c.zero_points: - return False, "Zero points currently not supported by "\ - " Compressed Tensors + Machete. (Kernel supports it"\ - " but CompressedTensorsWNA16 does not so support has"\ - " not been added to MacheteWNA16Kernel yet" + return False, "Zero points currently not supported by Machete" if c.weight_type not in query_machete_supported_quant_types( c.zero_points): diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py index e21801cf6a78..7bd824ff9e55 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -9,7 +9,7 @@ MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx, - query_marlin_supported_quant_types) + marlin_zero_points, query_marlin_supported_quant_types, unpack_cols) from vllm.model_executor.parameter import (BasevLLMParameter, permute_param_layout_) @@ -25,10 +25,6 @@ def get_min_capability(cls) -> int: @classmethod def can_implement(cls, c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: - if c.zero_points: - return False, "Zero points currently not supported by "\ - " MarlinLinearKernel. 
Will be added when AWQMarlin "\ - "is migrated over to using MPLinearKernel backend" quant_types = query_marlin_supported_quant_types(c.zero_points) if c.weight_type not in quant_types: @@ -67,28 +63,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.w_zp_name is None: self.w_zp_name = "w_zp" - if c.has_g_idx: - g_idx, g_idx_sort_indices = marlin_sort_g_idx( - getattr(layer, self.w_gidx_name)) - self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) - layer.g_idx_sort_indices = g_idx_sort_indices - else: - setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - if c.zero_points: - pass - # TODO (lucas): add the following when AWQMarlin is migrated over to - # using MPLinearKernel backend - # self._transform_param(layer, self.w_zp_name, lambda x: \ - # marlin_zero_points( - # x, - # size_k=c.partition_weight_shape[0], - # size_n=c.partition_weight_shape[1], - # num_bits=c.weight_type.size_bits)) - else: - setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) - def transform_w_q(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) @@ -108,6 +82,28 @@ def transform_w_s(x): group_size=c.group_size) return x + if c.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + if c.zero_points: + grouped_k = (c.partition_weight_shape[0] // + c.group_size if c.group_size != -1 else 1) + self._transform_param(layer, self.w_zp_name, lambda x: \ + marlin_zero_points( + unpack_cols(x.t(), c.weight_type.size_bits, + grouped_k, + c.partition_weight_shape[1]), + size_k=grouped_k, + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits)) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_s_name, transform_w_s) @@ -131,5 +127,6 @@ def apply_weights(self, wtype=c.weight_type, input_size_per_partition=c.partition_weight_shape[0], output_size_per_partition=c.partition_weight_shape[1], + has_zp=self.config.zero_points, is_k_full=self.is_k_full, bias=bias) diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 5d766c2c27ac..5dff8b09693c 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -38,6 +38,9 @@ def create_weights(self, layer: torch.nn.Module): requires_grad=False) layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + # Initialize P = softmax(QK^T) scales + layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError( @@ -97,5 +100,38 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: "may cause accuracy issues. 
Please make sure k/v_scale " "scaling factors are available in the fp8 checkpoint.") + if layer.q_scale > 0.0: + q_scale = layer.q_scale + if current_platform.is_fp8_fnuz(): + q_scale *= 2 + layer.calculate_kv_scales = False + else: + q_scale = 1.0 + if layer.prob_scale > 0.0: + prob_scale = layer.prob_scale + if current_platform.is_fp8_fnuz(): + prob_scale *= 2 + else: + prob_scale = 1.0 + + is_singleton_float = lambda x: isinstance(x, float) or isinstance( + x, torch.Tensor) and x.numel() == 1 and x.is_floating_point() + if not is_singleton_float(q_scale) or not is_singleton_float( + prob_scale): + raise ValueError("Only support per-tensor scaling factor" + "for fp8-quantized Q/prob") + + # These are used in the final Attention.forward() + layer._q_scale.copy_(q_scale) + layer._prob_scale.copy_(prob_scale) + if q_scale == 1.0 or prob_scale == 1.0: + logger.warning_once( + f"Using Q scale {q_scale} and prob scale {prob_scale} " + "with fp8 attention. This may cause accuracy issues. " + "Please make sure Q/prob scaling factors are " + "available in the fp8 checkpoint.") + del layer.k_scale del layer.v_scale + del layer.q_scale + del layer.prob_scale diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index ca71da8b736a..cf9108ea72c3 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import fnmatch -import re from typing import Any, Dict, List, Optional, cast import torch @@ -125,6 +124,13 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig": for q_config in q_configs: q_config["output_tensors"] = None + # In case q_proj output is also quantized, remove the configuration + # to keep qkv consistency. + q_proj_q_config = cast(Dict[str, Any], + layer_quant_config.get("*q_proj")) + if q_proj_q_config is not None: + q_proj_q_config["output_tensors"] = None + return cls(quant_config=config, kv_cache_group=kv_cache_group, kv_cache_config=kv_cache_config, @@ -289,25 +295,14 @@ def get_cache_scale(self, name: str) -> Optional[str]: :param name: param name :return: matching param name for KV cache scale in vLLM """ - if self.kv_cache_group is None or len(self.kv_cache_group) == 0: - return None - - kv_proj_names = [ - re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group - ] - if name.endswith(".output_scale"): - if len(kv_proj_names) == 1 and kv_proj_names[0] in name: - kv_output_scale_name = "." 
+ kv_proj_names[0] + ".output_scale" - return name.replace(kv_output_scale_name, ".attn.k_scale") - - elif len(kv_proj_names) == 2: - for kv_proj_name in kv_proj_names: - if kv_proj_name in name and kv_proj_name == "k_proj": - return name.replace(".k_proj.output_scale", - ".attn.k_scale") - elif kv_proj_name in name and kv_proj_name == "v_proj": - return name.replace(".v_proj.output_scale", - ".attn.v_scale") + if name.endswith(".output_scale") and ".k_proj" in name: + return name.replace(".k_proj.output_scale", ".attn.k_scale") + if name.endswith(".output_scale") and ".v_proj" in name: + return name.replace(".v_proj.output_scale", ".attn.v_scale") + if name.endswith(".output_scale") and ".q_proj" in name: + return name.replace(".q_proj.output_scale", ".attn.q_scale") + if name.endswith("self_attn.prob_output_scale"): + return name.replace(".prob_output_scale", ".attn.prob_scale") # If no matches, return None return None diff --git a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py new file mode 100644 index 000000000000..5d28d327e8a2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import torch + +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +MINIMUM_BITBLAS_VERSION = "0.1.0" + +BITBLAS_MIN_WEIGHT_SIZE_N = 16 +BITBLAS_MIN_WEIGHT_SIZE_K = 16 +GPTQ_BITBLAS_MAX_PARALLEL = 16 + +BITBLAS_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# For dynamic shape code generation +BITBLAS_OPTIMIZE_FEATURES = [1, 16, 32, 64, 128, 256, 512, 1024] +# If want to enable high performance for contiguous batching +# Please use the following values +BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS = [16, 32, 64, 128, 256, 512, 1024] + +BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8] +BITBLAS_SUPPORTED_SYM = [False, True] + + +# Determines the supported quantization types for BitBLAS based on the +# device's capability and whether zero-point (zp) is used. +def query_bitblas_supported_quant_types(has_zp: bool, + device_capability: Optional[int] = None + ): + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + if device_capability < 70: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_bitblas is merged into "gptq_bitblas" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_bitblas_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + supported_types = query_bitblas_supported_quant_types( + has_zp, device_capability) + + if quant_type not in supported_types: + return (False, f"BitBLAS does not support weight_bits = {quant_type}. 
" + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in BITBLAS_SUPPORTED_GROUP_SIZES): + return (False, f"BitBLAS does not support group_size = {group_size}. " + f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} " + "are supported.") + + return True, None + + +def check_bitblas_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_bitblas_supported(quant_type, group_size, has_zp, + device_capability) + return cond + + +def verify_bitblas_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: + cond, err_msg = _check_bitblas_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_bitblas_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_N != 0: + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {BITBLAS_MIN_WEIGHT_SIZE_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + # Validate input_size_per_partition + if input_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_K != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {BITBLAS_MIN_WEIGHT_SIZE_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." 
+ "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + +def check_bitblas_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> Tuple[bool, Optional[str]]: + try: + verify_bitblas_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) + except ValueError as e: + return False, e.__str__() + return True, None + + +def bitblas_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def bitblas_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def bitblas_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def bitblas_sort_g_idx( + g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def unpack_gptq_qzeros(qzeros, bits, is_gptq_v2=False) -> torch.Tensor: + qzeros = qzeros.view(torch.int32) + elems_per_int32 = 32 // bits + unpacked_zeros = torch.zeros( + (qzeros.shape[0], qzeros.shape[1] * elems_per_int32), + dtype=torch.int8, + device=qzeros.device, + requires_grad=False, + ) + + for col in range(unpacked_zeros.shape[1]): + i = col % elems_per_int32 + unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> + (bits * i)) & 0xF + if not is_gptq_v2: + return unpacked_zeros + 1 + return unpacked_zeros + + +def unpack_gptq_qweight(qweight, bits): + qweight = qweight.view(torch.int8) + elems_per_int8 = 8 // bits + unpacked_weight = torch.zeros( + (qweight.shape[0], qweight.shape[1] * elems_per_int8), + dtype=torch.int8, + device=qweight.device, + requires_grad=False, + ) + for col in range(unpacked_weight.shape[1]): + i = col % elems_per_int8 + unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >> + (bits * i)) + + return torch.bitwise_and(unpacked_weight, 2**bits - 1) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 5b2e3ca2c799..4a190480d35b 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -151,6 +151,19 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \ group_size=group_size)[0] +def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \ + -> bool: + hidden_size = layer.hidden_size + intermediate_size_per_partition = layer.intermediate_size_per_partition + + # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size) + # down: (n, k) = (hidden_size, intermediate_size_per_partition) + # moe marlin requires n % 128 == 0 and k % 64 == 0 + return hidden_size % 128 == 0 and \ + intermediate_size_per_partition % max(64, group_size) == 0 and \ + group_size in [-1, 32, 64, 128] + + def marlin_make_workspace(output_size_per_partition: int, device: torch.device) -> torch.Tensor: max_workspace_size = 
(output_size_per_partition // @@ -319,6 +332,7 @@ def apply_gptq_marlin_linear( wtype: ScalarType, output_size_per_partition: int, input_size_per_partition: int, + has_zp: bool, is_k_full: bool, bias: Optional[torch.Tensor] = None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: @@ -343,8 +357,8 @@ def apply_gptq_marlin_linear( size_n=output_size_per_partition, size_k=input_size_per_partition, is_k_full=is_k_full, - has_zp=False, use_atomic_add=use_atomic_add, + has_zp=has_zp, use_fp32_reduce=use_fp32_reduce, is_zp_float=False) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index b8e6384d7359..8ab45d610053 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from vllm import _custom_ops as ops +from vllm import envs from vllm.config import CompilationLevel, get_current_vllm_config from vllm.platforms import current_platform @@ -17,6 +18,7 @@ # The condition is determined once as the operations # are time consuming. USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm() + and torch.__version__[0:3] >= "2.7" and current_platform.has_device_capability(94)) @@ -131,6 +133,160 @@ def maybe_create_device_identity(): TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) +def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, + out_dtype: torch.dtype, scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + output_shape: List, **kwargs) -> torch.Tensor: + + # Fused GEMM_DQ + output = ops.cutlass_scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + return output.view(*output_shape) + + +def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: List) -> torch.Tensor: + from vllm.platforms.rocm import on_mi250_mi300 + if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300( + ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: + output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, + current_platform.get_cu_count()) + else: + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + + return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + + +def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: List) -> torch.Tensor: + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + + return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + + +def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: List) -> 
torch.Tensor: + # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM + # when using it. + # For now it has only been validated on ROCm platform. + # fp8 rowwise scaling in torch._scaled_mm is introduced in + # https://github.com/pytorch/pytorch/pull/144432 using + # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above. + # + # For CUDA platform please validate if the torch._scaled_mm supports + # rowwise scaled GEMM before using it + + # Fused GEMM_DQ Rowwise GEMM + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b.t(), + bias=bias) + + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + output = output.view(*output_shape) + return output + + +def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: List, + **kwargs) -> torch.Tensor: + # Use unfused DQ due to limitations with scaled_mm + + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. + # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. + + # GEMM + # This computes C = (X * W). + # Output in fp32 to allow subsequent ops to happen in-place + output = torch._scaled_mm(qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + # Unpad (undo num_token_padding) + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0]) + + # DQ + # C = sw * sx * (X * W) + bias + output = output * x_scale * scale_b.t() + if bias is not None: + output = output + bias + return output.to(out_dtype).view(*output_shape) + + +def dispatch_w8a8_scaled_mm( + cutlass_fp8_supported: bool, per_tensor_weights: bool, + per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool] +) -> Callable[..., torch.Tensor]: + + if cutlass_fp8_supported: + return cutlass_w8a8_scaled_mm + if per_tensor_weights and per_tensor_activations: + if current_platform.is_rocm(): + return rocm_per_tensor_w8a8_scaled_mm + return torch_per_tensor_w8a8_scaled_mm + # torch.scaled_mm supports per tensor weights + activations only + # so fallback to naive if per channel or per token + if (use_per_token_if_dynamic and not per_tensor_weights + and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM): + return torch_per_token_w8a8_scaled_mm + return torch_channelwise_w8a8_scaled_mm + + # TODO(luka): follow similar pattern for marlin and block-fp8-linear # https://github.com/vllm-project/vllm/issues/14397 class Fp8LinearOp: @@ -156,7 +312,8 @@ def __init__(self, if pad_output is None: config = get_current_vllm_config().compilation_config pad_output = config.level < CompilationLevel.PIECEWISE - self.output_padding = 17 if pad_output else None + self.output_padding = 17 if ( + pad_output and not current_platform.is_rocm()) else None def apply( self, @@ -195,18 +352,6 @@ def apply( input_scale, 
scale_ub=input_scale_ub, use_per_token_if_dynamic=use_per_token_if_dynamic) - - # Fused GEMM_DQ - output = ops.cutlass_scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) - return output.view(*output_shape) - - # torch.scaled_mm supports per tensor weights + activations only - # so fallback to naive if per channel or per token else: if input.dtype != current_platform.fp8_dtype(): # Maybe apply padding to output, see comment in __init__ @@ -218,84 +363,21 @@ def apply( else: qinput, x_scale = input_2d, input_scale - per_tensor_weights = (weight_scale.numel() == 1) - per_tensor_activations = (x_scale.numel() == 1) - - if per_tensor_weights and per_tensor_activations: - # Fused GEMM_DQ - output = torch._scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] - - return torch.narrow(output, 0, 0, - input_2d.shape[0]).view(*output_shape) - - elif (use_per_token_if_dynamic and not per_tensor_weights - and not per_tensor_activations - and USE_ROWWISE_TORCH_SCALED_MM): - # For now validated on ROCm platform - # fp8 rowwise scaling in torch._scaled_mm is introduced in - # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt - # and ROCm 6.3, which only exists in torch 2.7 and above. - # For CUDA platform please validate if the - # torch._scaled_mm support rowwise scaled GEMM - # Fused GEMM_DQ Rowwise GEMM - output = torch._scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale.t(), - bias=bias) - - output = torch.narrow(output, 0, 0, input_2d.shape[0]) - output = output.view(*output_shape) - return output - - else: - # Fallback for channelwise case, where we use unfused DQ - # due to limitations with scaled_mm - - # Symmetric quantized GEMM by definition computes the following: - # C = (s_x * X) (s_w * W) + bias - # This is equivalent to dequantizing the weights and activations - # before applying a GEMM. - # - # In order to compute quantized operands, a quantized kernel - # will rewrite the above like so: - # C = s_w * s_x * (X * W) + bias - # - # For the scaled_mm fallback case, we break this down, since it - # does not support s_w being a vector. - - # GEMM - # This computes C = (X * W). 
- # Output in fp32 to allow subsequent ops to happen in-place - output = torch._scaled_mm(qinput, - weight, - scale_a=TORCH_DEVICE_IDENTITY, - scale_b=TORCH_DEVICE_IDENTITY, - out_dtype=torch.float32) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] - # Unpad (undo num_token_padding) - output = torch.narrow(output, 0, 0, input_2d.shape[0]) - x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0]) - - # DQ - # C = sw * sx * (X * W) + bias - output = output * x_scale * weight_scale.t() - if bias is not None: - output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) + per_tensor_weights = (weight_scale.numel() == 1) + per_tensor_activations = (x_scale.numel() == 1) + + w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( + self.cutlass_fp8_supported, per_tensor_weights, + per_tensor_activations, use_per_token_if_dynamic) + + return w8a8_scaled_mm_func(qinput=qinput, + weight=weight, + out_dtype=out_dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + input_2d=input_2d, + output_shape=output_shape) def normalize_e4m3fn_to_e4m3fnuz( diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 624ed63ab8b4..c5970c71c539 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -46,20 +46,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: return x.flatten(-2) -def _apply_rotary_emb( +def _apply_rotary_emb_torch( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool, ) -> torch.Tensor: - """ - Args: - x: [num_tokens, num_heads, head_size] - cos: [num_tokens, head_size // 2] - sin: [num_tokens, head_size // 2] - is_neox_style: Whether to use the Neox-style or GPT-J-style rotary - positional embeddings. - """ cos = cos.unsqueeze(-2).to(x.dtype) sin = sin.unsqueeze(-2).to(x.dtype) if is_neox_style: @@ -75,6 +67,24 @@ def _apply_rotary_emb( return torch.stack((o1, o2), dim=-1).flatten(-2) +def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, + is_neox_style: bool) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. 
+ """ + if current_platform.is_cuda_alike(): + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + return apply_rotary_emb(x.unsqueeze(0), cos, sin, + not is_neox_style).squeeze(0) + else: + return _apply_rotary_emb_torch(x, cos, sin, is_neox_style) + + @CustomOp.register("rotary_embedding") class RotaryEmbedding(CustomOp): """Original rotary positional embedding.""" @@ -141,14 +151,16 @@ def forward_native( query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., :self.rotary_dim] query_pass = query[..., self.rotary_dim:] - query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query_rot = _apply_rotary_emb_torch(query_rot, cos, sin, + self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., :self.rotary_dim] key_pass = key[..., self.rotary_dim:] - key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key_rot = _apply_rotary_emb_torch(key_rot, cos, sin, + self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key @@ -988,8 +1000,9 @@ def forward( key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key - @staticmethod + @classmethod def get_input_positions( + cls, input_tokens: List[int], hf_config: PretrainedConfig, image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]], @@ -997,6 +1010,8 @@ def get_input_positions( second_per_grid_ts: Optional[List[float]], context_len: int = 0, seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, ) -> Tuple[List[List[int]], int]: """Get mrope input positions and delta value.""" @@ -1006,7 +1021,7 @@ def get_input_positions( second_per_grid_ts llm_positions, mrope_position_delta = \ - MRotaryEmbedding.get_input_positions_tensor( + cls.get_input_positions_tensor( input_tokens=input_tokens, hf_config=hf_config, image_grid_thw=image_grid_thw, @@ -1014,12 +1029,52 @@ def get_input_positions( second_per_grid_ts=second_per_grid_ts, context_len=context_len, seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, ) return llm_positions.tolist(), mrope_position_delta - @staticmethod + @classmethod def get_input_positions_tensor( + cls, + input_tokens: List[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[List[List[int]], torch.Tensor], + video_grid_thw: Union[List[List[int]], torch.Tensor], + second_per_grid_ts: List[float], + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> Tuple[torch.Tensor, int]: + from vllm.transformers_utils.config import thinker_uses_mrope + if thinker_uses_mrope(hf_config): + return cls._omni_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + else: + return cls._vl_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + ) + + @classmethod + def _vl_get_input_positions_tensor( + cls, input_tokens: 
List[int], hf_config: PretrainedConfig, image_grid_thw: Union[List[List[int]], torch.Tensor], @@ -1037,11 +1092,6 @@ def get_input_positions_tensor( tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) - if isinstance(image_grid_thw, torch.Tensor): - image_grid_thw = image_grid_thw.tolist() - if isinstance(video_grid_thw, torch.Tensor): - video_grid_thw = video_grid_thw.tolist() - input_tokens_tensor = torch.tensor(input_tokens) vision_start_indices = torch.argwhere( input_tokens_tensor == vision_start_token_id).squeeze(1) @@ -1121,6 +1171,224 @@ def get_input_positions_tensor( return llm_positions, mrope_position_delta + @classmethod + def _omni_get_input_positions_tensor( + cls, + input_tokens: List[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[List[List[int]], torch.Tensor], + video_grid_thw: Union[List[List[int]], torch.Tensor], + second_per_grid_ts: Optional[List[float]] = None, + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> Tuple[torch.Tensor, int]: + """Get mrope input positions and delta value (Qwen2.5-Omni version). + + Differences from MRotaryEmbedding: + 1. Add audio support (and related `audio_feature_lengths`). + 2. Add `use_audio_in_video` option to read audio from video inputs. + In this case, audio and vision position ids will be split into + chunks and interleaved. + + Example: + + (V_i are vision position ids, A_i are audio position ids) + + |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... + |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... + """ + + # TODO(fyabc): refactor and share more code with + # _vl_get_input_positions_tensor. + + thinker_config = hf_config.thinker_config + audio_token_id = thinker_config.audio_token_index + image_token_id = thinker_config.image_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + vision_start_token_id = thinker_config.vision_start_token_id + vision_end_token_id = thinker_config.vision_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, + "tokens_per_second", 25) + + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + + src_item = input_tokens + audio_seqlens = audio_feature_lengths + if not second_per_grid_ts: + second_per_grid_ts = [1] * video_grid_thw.shape[0] + audio_idx = 0 + video_idx = 0 + image_idx = 0 + new_src_item: list[int] = [] + llm_pos_ids_list: list[torch.Tensor] = [] + + idx = 0 + while idx < len(src_item): + new_src_item_len = len(new_src_item) + start_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + if src_item[idx] not in [ + audio_token_id, video_token_id, image_token_id + ]: + if use_audio_in_video and idx > 0: + if src_item[idx] == vision_end_token_id and \ + src_item[idx - 1] == audio_end_token_id: + # processing the <|audio_eos|> before <|vision_eos|> + start_idx -= 1 + elif src_item[idx] == audio_start_token_id and \ + src_item[idx - 1] == vision_start_token_id: + # processing the <|audio_bos|> after <|vision_eos|> + start_idx -= 1 + new_src_item.append(src_item[idx]) + llm_pos_ids = torch.tensor([start_idx], + 
dtype=torch.long).expand(3, -1) + llm_pos_ids_list.append(llm_pos_ids) + elif src_item[idx] == audio_token_id: + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + new_src_item.extend([audio_token_id] * place_num) + llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx + llm_pos_ids_list.append(llm_pos_ids) + audio_idx += 1 + elif src_item[idx] == image_token_id: + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + start_idx, image_idx, spatial_merge_size, t_index, grid_hs, + grid_ws) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = image_grid_thw[image_idx].prod() // ( + spatial_merge_size**2) + new_src_item.extend([image_token_id] * vision_seqlen) + image_idx += 1 + elif src_item[idx] == video_token_id and not use_audio_in_video: + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * + second_per_grid_ts[video_idx] * + tokens_per_second).long() + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_index, grid_hs, + grid_ws) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2) + new_src_item.extend([video_token_id] * vision_seqlen) + video_idx += 1 + else: + # read audio from video + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2) + grid_t = video_grid_thw[video_idx][0] + grid_h = video_grid_thw[video_idx][1] + grid_w = video_grid_thw[video_idx][2] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * + second_per_grid_ts[video_idx] * + tokens_per_second).long() + t_index_split_chunk = cls._split_list_into_ranges( + t_index, t_ntoken_per_chunk) + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 + pure_audio_len = place_num - 2 + added_audio_len = 0 + audio_llm_pos_ids_list: List[torch.Tensor] = [] + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len( + t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + new_src_item.extend([video_token_id] * + vision_ntoken_per_chunk) + vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_chunk, + grid_hs, grid_ws).split(1, dim=1) + llm_pos_ids_list.extend(vision_llm_pos_ids_list) + new_src_item.extend( + min(t_ntoken_per_chunk, pure_audio_len - + added_audio_len) * [audio_token_id]) + audio_start_idx = start_idx if len( + audio_llm_pos_ids_list + ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1 + if min(t_ntoken_per_chunk, + pure_audio_len - added_audio_len) > 0: + audio_llm_pos_ids_list = (torch.arange( + min(t_ntoken_per_chunk, pure_audio_len - + added_audio_len)).expand(3, -1) + + audio_start_idx).split(1, + dim=1) + else: + audio_llm_pos_ids_list = [] + added_audio_len += min(t_ntoken_per_chunk, + pure_audio_len - added_audio_len) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + if added_audio_len < pure_audio_len: + new_src_item.extend( + (pure_audio_len - added_audio_len) * [audio_token_id]) + audio_llm_pos_ids_list = ( + torch.arange(pure_audio_len - added_audio_len).expand( + 3, 
-1) + llm_pos_ids_list[-1].max() + 1).split( + 1, dim=1) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + audio_idx += 1 + video_idx += 1 + # move to the next token + idx += len(new_src_item) - new_src_item_len + + llm_positions = torch.cat(llm_pos_ids_list, dim=1) + mrope_position_delta = torch.cat(llm_pos_ids_list, + dim=1).max() + 1 - len(src_item) + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + @staticmethod + def _get_llm_pos_ids_for_vision( + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: List[int], + grid_hs: torch.Tensor, + grid_ws: torch.Tensor, + ) -> torch.Tensor: + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand( + len(t_index), -1, llm_grid_w).flatten()) + w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand( + len(t_index), llm_grid_h, -1).flatten()) + t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view( + -1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten() + _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids + + @staticmethod + def _split_list_into_ranges(lst: torch.Tensor, + interval: int) -> List[List[int]]: + ranges: List[List[int]] = [[] + for _ in range((max(lst) // interval) + 1)] + for num in lst: + index = num // interval + ranges[index].append(num) + return ranges + @staticmethod def get_next_input_positions( mrope_position_delta: int, @@ -1144,6 +1412,58 @@ def get_next_input_positions_tensor( mrope_position_delta + seq_len, ).expand(3, -1) + @classmethod + def omni_get_updates_use_audio_in_video( + cls, + thinker_config: PretrainedConfig, + audio_len: int, + video_grid_thw: Union[List[int], torch.Tensor], + video_second_per_grid_t: float, + ) -> List[int]: + """Get video prompt updates when `use_audio_in_video` is True. + + In this case, audio and vision update ids will be split into + chunks and interleaved (details in `_omni_get_input_positions_tensor`). + + <|video_bos|><|VIDEO|><|video_eos|> => + <|video_bos|><|audio_bos|>(... 
chunks ...)<|audio_eos|><|video_eos|> + """ + + audio_token_id = thinker_config.audio_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, + "tokens_per_second", 25) + + grid_t = video_grid_thw[0] + grid_h = video_grid_thw[1] + grid_w = video_grid_thw[2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * video_second_per_grid_t * + tokens_per_second).long() + t_index_split_chunk = cls._split_list_into_ranges( + t_index, t_ntoken_per_chunk) + + updates = [audio_start_token_id] + added_audio_len = 0 + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // ( + spatial_merge_size**2) + updates.extend([video_token_id] * vision_ntoken_per_chunk) + + audio_chunk_size = min(t_ntoken_per_chunk, + audio_len - added_audio_len) + updates.extend(audio_chunk_size * [audio_token_id]) + added_audio_len += audio_chunk_size + if added_audio_len < audio_len: + updates.extend((audio_len - added_audio_len) * [audio_token_id]) + updates.extend([audio_end_token_id]) + + return updates + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index a9ef973917e1..adb966c4b1c0 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -1,9 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 """Utility methods for model layers.""" -from typing import Tuple +from typing import Callable, Optional, Tuple import torch +from vllm import _custom_ops as ops +from vllm import envs +from vllm.platforms import current_platform + def get_token_bin_counts_and_mask( tokens: torch.Tensor, @@ -47,12 +51,49 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, output_tokens_tensor, vocab_size, num_seqs) repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( 1, vocab_size) - logits[logits > 0] /= torch.where(prompt_mask | output_mask, - repetition_penalties, 1.0)[logits > 0] - logits[logits <= 0] *= torch.where(prompt_mask | output_mask, - repetition_penalties, 1.0)[logits <= 0] + + # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. + penalties = torch.where(prompt_mask | output_mask, repetition_penalties, + 1.0) + + # If logits are positive, divide by penalty, otherwise multiply by penalty. + scaling = torch.where(logits > 0, 1.0 / penalties, penalties) + logits *= scaling + # We follow the definition in OpenAI API. 
# Refer to https://platform.openai.com/docs/api-reference/parameter-details logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts logits -= presence_penalties.unsqueeze(dim=1) * output_mask return logits + + +def rocm_unquantized_gemm(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): + from vllm.platforms.rocm import on_mi250_mi300 + k = weight.shape[1] + use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300() and \ + x.dtype in [torch.float16, torch.bfloat16] \ + and k % 8 == 0 and bias is None) + + if use_skinny is not True: + return torch.nn.functional.linear(x, weight, bias) + + x_view = x.view(-1, x.size(-1)) + n = x_view.shape[0] + m = weight.shape[0] + cu_count = current_platform.get_cu_count() + + if m > 8 and 0 < n < 4: + out = ops.wvSplitK(weight, x_view, cu_count) + return out.view(*x.shape[:-1], weight.shape[0]) + elif m % 4 == 0 and n == 1 and k <= 8192: + out = ops.LLMM1(weight, x_view, 4) + return out.view(*x.shape[:-1], weight.shape[0]) + return torch.nn.functional.linear(x, weight, bias) + + +def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]: + if current_platform.is_rocm(): + return rocm_unquantized_gemm + return torch.nn.functional.linear diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 1eb0c8c2ef4e..d5eaeec1ae24 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -12,6 +12,7 @@ tensor_model_parallel_all_reduce) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) +from vllm.model_executor.layers.utils import dispatch_unquantized_gemm from vllm.model_executor.parameter import BasevLLMParameter from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -40,7 +41,7 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return F.linear(x, layer.weight, bias) + return dispatch_unquantized_gemm()(x, layer.weight, bias) def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index b0a0a20aa76f..cb9100e35594 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -611,8 +611,12 @@ class ShardedStateLoader(BaseModelLoader): DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" - def __init__(self, load_config: LoadConfig): + def __init__(self, + load_config: LoadConfig, + runai_model_streamer: bool = False): super().__init__(load_config) + + self.runai_model_streamer = runai_model_streamer extra_config = ({} if load_config.model_loader_extra_config is None else load_config.model_loader_extra_config.copy()) self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) @@ -659,7 +663,7 @@ def get_end_ptr(tensor: torch.Tensor) -> int: def _prepare_weights(self, model_name_or_path: str, revision: Optional[str]): - if os.path.isdir(model_name_or_path): + if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path): return model_name_or_path else: allow_patterns = ["*.safetensors"] @@ -678,12 +682,13 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: device_config = vllm_config.device_config model_config = vllm_config.model_config target_device = 
torch.device(device_config.device) - from safetensors.torch import safe_open from vllm.distributed import get_tensor_model_parallel_rank - local_model_path = self._prepare_weights(model_config.model, - model_config.revision) + model_weights = model_config.model + if hasattr(model_config, "model_weights"): + model_weights = model_config.model_weights + local_model_path = model_weights with set_default_torch_dtype(model_config.dtype): with target_device: @@ -695,40 +700,56 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: local_model_path, self.pattern.format(rank=rank, part="*"), ) - filepaths = glob.glob(pattern) + + filepaths = [] + if is_s3(local_model_path): + file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}" + filepaths = s3_glob(path=local_model_path, + allow_pattern=[file_pattern]) + else: + filepaths = glob.glob(pattern) if not filepaths: # TODO: support un-sharded checkpoints too raise ValueError( f"Could not find checkpoint files '{pattern}', only " f"pre-sharded checkpoints are currently supported!") state_dict = self._filter_subtensors(model.state_dict()) - for path in filepaths: - with safe_open(path, framework="pt") as f: - for key in f.keys(): # noqa: SIM118 - tensor = f.get_tensor(key) - # If loading with LoRA enabled, additional padding may - # be added to certain parameters. We only load into a - # narrowed view of the parameter data. - param_data = state_dict[key].data - param_shape = state_dict[key].shape - for dim, size in enumerate(tensor.shape): - if size < param_shape[dim]: - param_data = param_data.narrow(dim, 0, size) - if tensor.shape != param_shape: - logger.warning( - "loading tensor of shape %s into " - "parameter '%s' of shape %s", - tensor.shape, - key, - param_shape, - ) - param_data.copy_(tensor) - state_dict.pop(key) + for key, tensor in self.iterate_over_files(filepaths): + # If loading with LoRA enabled, additional padding may + # be added to certain parameters. We only load into a + # narrowed view of the parameter data. 
+ param_data = state_dict[key].data + param_shape = state_dict[key].shape + for dim, size in enumerate(tensor.shape): + if size < param_shape[dim]: + param_data = param_data.narrow(dim, 0, size) + if tensor.shape != param_shape: + logger.warning( + "loading tensor of shape %s into " + "parameter '%s' of shape %s", + tensor.shape, + key, + param_shape, + ) + param_data.copy_(tensor) + state_dict.pop(key) if state_dict: raise ValueError( f"Missing keys {tuple(state_dict)} in loaded state!") return model.eval() + def iterate_over_files( + self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]: + if self.runai_model_streamer: + yield from runai_safetensors_weights_iterator(paths, True) + else: + from safetensors.torch import safe_open + for path in paths: + with safe_open(path, framework="pt") as f: + for key in f.keys(): # noqa: SIM118 + tensor = f.get_tensor(key) + yield key, tensor + @staticmethod def save_model( model: torch.nn.Module, @@ -1515,4 +1536,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_format == LoadFormat.RUNAI_STREAMER: return RunaiModelStreamerLoader(load_config) + if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED: + return ShardedStateLoader(load_config, runai_model_streamer=True) + return DefaultModelLoader(load_config) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 15f37aad6d8c..0ca6b6fd88b6 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -30,15 +30,6 @@ def set_default_torch_dtype(dtype: torch.dtype): torch.set_default_dtype(old_dtype) -def is_transformers_impl_compatible( - arch: str, - module: Optional["transformers.PreTrainedModel"] = None) -> bool: - mod = module or getattr(transformers, arch, None) - if mod is None: - return False - return mod.is_backend_compatible() - - def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str]): for i, arch in enumerate(architectures): @@ -55,20 +46,32 @@ def resolve_transformers_arch(model_config: ModelConfig, # "AutoModelFor": "--", # }, auto_modules = { - name: get_class_from_dynamic_module(module, model_config.model) + name: + get_class_from_dynamic_module(module, + model_config.model, + revision=model_config.revision) for name, module in sorted(auto_map.items(), key=lambda x: x[0]) } - custom_model_module = auto_modules.get("AutoModel") + model_module = getattr(transformers, arch, None) + if model_module is None: + if "AutoModel" not in auto_map: + raise ValueError( + f"Cannot find model module. '{arch}' is not a registered " + "model in the Transformers library (only relevant if the " + "model is meant to be in Transformers) and 'AutoModel' is " + "not present in the model config's 'auto_map' (relevant " + "if the model is custom).") + model_module = auto_modules["AutoModel"] # TODO(Isotr0py): Further clean up these raises. # perhaps handled them in _ModelRegistry._raise_for_unsupported? 
if model_config.model_impl == ModelImpl.TRANSFORMERS: - if not is_transformers_impl_compatible(arch, custom_model_module): + if not model_module.is_backend_compatible(): raise ValueError( f"The Transformers implementation of {arch} is not " "compatible with vLLM.") architectures[i] = "TransformersForCausalLM" if model_config.model_impl == ModelImpl.AUTO: - if not is_transformers_impl_compatible(arch, custom_model_module): + if not model_module.is_backend_compatible(): raise ValueError( f"{arch} has no vLLM implementation and the Transformers " "implementation is not compatible with vLLM. Try setting " @@ -97,10 +100,10 @@ def get_model_architecture( architectures = ["QuantMixtralForCausalLM"] vllm_supported_archs = ModelRegistry.get_supported_archs() - is_vllm_supported = any(arch in vllm_supported_archs - for arch in architectures) - if (not is_vllm_supported - or model_config.model_impl == ModelImpl.TRANSFORMERS): + vllm_not_supported = not any(arch in vllm_supported_archs + for arch in architectures) + if (model_config.model_impl == ModelImpl.TRANSFORMERS or + model_config.model_impl != ModelImpl.VLLM and vllm_not_supported): architectures = resolve_transformers_arch(model_config, architectures) model_cls, arch = ModelRegistry.resolve_model_cls(architectures) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 065715cbde4e..dfe8f20c70d6 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -24,7 +24,6 @@ from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig, DeepSpeedFPParameter) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -435,7 +434,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.unpadded_vocab_size = config.vocab_size self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -462,14 +460,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index edf67c860e97..7c716efab8ef 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -15,11 +15,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import (SamplerOutput, - SamplingMetadata, get_sampler) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargs) 
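A brief aside on the reworked condition in `get_model_architecture`: because `and` binds tighter than `or`, the Transformers fallback is taken when the user forces `ModelImpl.TRANSFORMERS`, or when the impl is not explicitly `ModelImpl.VLLM` and none of the architectures has a native vLLM implementation. A minimal standalone sketch of that decision, using plain strings as stand-ins for the `ModelImpl` enum (the function name and values are illustrative, not part of vLLM):

def falls_back_to_transformers(model_impl: str, vllm_supported: bool) -> bool:
    # Mirrors `impl == TRANSFORMERS or impl != VLLM and vllm_not_supported`;
    # `and` is evaluated before `or`.
    vllm_not_supported = not vllm_supported
    return (model_impl == "transformers"
            or (model_impl != "vllm" and vllm_not_supported))

# Forcing the Transformers implementation always resolves through it.
assert falls_back_to_transformers("transformers", vllm_supported=True)
# "auto" only falls back when no architecture is natively supported by vLLM.
assert not falls_back_to_transformers("auto", vllm_supported=True)
assert falls_back_to_transformers("auto", vllm_supported=False)
# Explicitly requesting the vLLM implementation never triggers the fallback here.
assert not falls_back_to_transformers("vllm", vllm_supported=False)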
@@ -527,7 +526,6 @@ def __init__( logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.vocab_size, logit_scale) - self.sampler = get_sampler() def _validate_image_sizes( self, images: List[torch.Tensor]) -> List[torch.Tensor]: @@ -653,14 +651,6 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 8700f24d2bd2..d152287e8fa3 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 Adapted from # https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision -from functools import cached_property from typing import (Iterable, Literal, Mapping, Optional, Sequence, Set, Tuple, TypedDict, Union, cast) @@ -17,7 +16,6 @@ from vllm.config import VllmConfig from vllm.jsontree import json_map_leaves -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs @@ -461,17 +459,3 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states, sampling_metadata) - - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 6a3112b5f769..444ed38d05c0 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -39,7 +39,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -396,7 +395,6 @@ def __init__( if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -423,14 +421,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git 
a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index dfb8f49cc014..16dac6123d66 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -24,7 +24,6 @@ MambaMixer2, extra_groups_for_head_shards) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -462,7 +461,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -538,14 +536,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 04d6cde555e2..bcfbe92c3a11 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -37,7 +37,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -791,7 +790,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() def forward( self, @@ -828,14 +826,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - stacked_params_mapping = { "q_proj": { "param_name": "qkv_proj", diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index e1d77646f47e..76a529c93343 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -11,8 +11,10 @@ from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.activation import (get_act_and_mul_fn, + get_act_fn) from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler, @@ -108,6 +110,7 @@ class BertEncoder(nn.Module): def __init__(self, vllm_config: VllmConfig, + bias: bool = True, rotary_kwargs: Optional[dict] = None, prefix: str = ""): super().__init__() @@ -118,6 +121,7 
@@ def __init__(self, BertLayer(config=config, cache_config=cache_config, quant_config=quant_config, + bias=bias, rotary_kwargs=rotary_kwargs, prefix=f"{prefix}.layer.{layer_idx}") for layer_idx in range(config.num_hidden_layers) @@ -139,6 +143,7 @@ def __init__(self, config: BertConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + bias: bool = True, rotary_kwargs: Optional[dict] = None, prefix: str = ""): super().__init__() @@ -149,19 +154,31 @@ def __init__(self, layer_norm_eps=config.layer_norm_eps, cache_config=cache_config, quant_config=quant_config, + bias=bias, rotary_kwargs=rotary_kwargs, prefix=f"{prefix}.attention") - self.intermediate = BertIntermediate( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.intermediate") + if config.hidden_act in ["silu", "gelu_and_mul"]: + self.intermediate = BertGatedIntermediate( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.intermediate") + else: + self.intermediate = BertIntermediate( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.intermediate") self.output = BertOutput(hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, layer_norm_eps=config.layer_norm_eps, + bias=bias, quant_config=quant_config, prefix=f"{prefix}.output") @@ -181,6 +198,7 @@ def __init__( layer_norm_eps: float, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + bias: bool = True, rotary_kwargs: Optional[dict] = None, prefix: str = "", ): @@ -190,11 +208,13 @@ def __init__( num_attention_heads=num_attention_heads, cache_config=cache_config, quant_config=quant_config, + bias=bias, rotary_kwargs=rotary_kwargs, prefix=f"{prefix}.output") self.output = BertSelfOutput(hidden_size=hidden_size, layer_norm_eps=layer_norm_eps, + bias=bias, quant_config=quant_config, prefix=f"{prefix}.output") @@ -215,6 +235,7 @@ def __init__( num_attention_heads: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + bias: bool = True, rotary_kwargs: Optional[dict] = None, prefix: str = "", ): @@ -240,7 +261,7 @@ def __init__( head_size=self.head_dim, total_num_heads=self.total_num_heads, total_num_kv_heads=self.total_num_kv_heads, - bias=True, + bias=bias, quant_config=quant_config, prefix=f"{prefix}.qkv_proj") @@ -278,12 +299,13 @@ class BertSelfOutput(nn.Module): def __init__(self, hidden_size: int, layer_norm_eps: float, + bias: bool = True, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): super().__init__() self.dense = RowParallelLinear(input_size=hidden_size, output_size=hidden_size, - bias=True, + bias=bias, quant_config=quant_config, prefix=f"{prefix}.dense") self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) @@ -301,12 +323,13 @@ def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str, + bias: bool = True, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): super().__init__() self.dense = ColumnParallelLinear(input_size=hidden_size, output_size=intermediate_size, - bias=True, + bias=bias, quant_config=quant_config, prefix=f"{prefix}.dense") self.intermediate_act_fn = get_act_fn(hidden_act) @@ 
-317,19 +340,46 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states +class BertGatedIntermediate(nn.Module): + # For NomicBert and GteModel + + def __init__(self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.act_fn = get_act_and_mul_fn(hidden_act) + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(hidden_states) + hidden_states = self.act_fn(gate_up) + return hidden_states + + class BertOutput(nn.Module): def __init__(self, hidden_size: int, intermediate_size: int, layer_norm_eps: float, + bias: bool = True, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): super().__init__() self.dense = RowParallelLinear(input_size=intermediate_size, output_size=hidden_size, - bias=True, + bias=bias, quant_config=quant_config, prefix=f"{prefix}.dense") @@ -343,19 +393,32 @@ def forward(self, hidden_states: torch.Tensor, class BertModel(nn.Module, SupportsQuant): - packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} + packed_modules_mapping = { + "qkv_proj": ["query", "key", "value"], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", embedding_class: type = BertEmbedding, + bias: bool = True, rotary_kwargs: Optional[dict] = None, add_pooling_layer: bool = False): super().__init__() + """ + For BertModel, all linear layers have bias. + For NomicBertModel, none of the linear layers have bias. 
+ """ + config = vllm_config.model_config.hf_config self.embeddings = embedding_class(config) self.encoder = BertEncoder(vllm_config=vllm_config, + bias=bias, rotary_kwargs=rotary_kwargs, prefix=f"{prefix}.encoder") self.pooler = BertPooler(config) if add_pooling_layer else None @@ -387,6 +450,8 @@ def load_weights(self, weights: Iterable[Tuple[str, ("qkv_proj", "query", "q"), ("qkv_proj", "key", "k"), ("qkv_proj", "value", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) @@ -546,3 +611,115 @@ def forward( inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors, token_type_ids=token_type_ids) + + +class NomicBertEmbeddingModel(BertEmbeddingModel): + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "emb_ln": "embeddings.LayerNorm", + "layers": "layer", + "attn.Wqkv": "attention.self.qkv_proj", + "attn.out_proj": "attention.output.dense", + 'norm1': "attention.output.LayerNorm", + 'mlp.fc11': "intermediate.up_proj", + 'mlp.fc12': "intermediate.gate_proj", + 'mlp.fc2': "output.dense", + 'norm2': "output.LayerNorm", + }) + + def _build_model(self, + vllm_config: VllmConfig, + prefix: str = "") -> BertModel: + config = vllm_config.model_config.hf_config + + assert config.__class__.__name__ == "NomicBertConfig" + assert config.activation_function == "swiglu" + + # Assume NomicBertModel all linear layers do not have bias + assert not config.mlp_fc1_bias + assert not config.mlp_fc2_bias + assert not config.qkv_proj_bias + + config.layer_norm_eps = config.layer_norm_epsilon + config.position_embedding_type = "rotary" + config.intermediate_size = config.n_inner + config.hidden_act = "silu" + config.hidden_size = config.n_embd + config.num_hidden_layers = config.n_layer + + head_dim = config.hidden_size // config.num_attention_heads + rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), + "max_position": config.max_trained_positions, + "base": config.rotary_emb_base, + "rope_scaling": { + "rope_type": "dynamic", + "factor": config.rotary_scaling_factor + } + } + + return BertModel(vllm_config=vllm_config, + prefix=prefix, + bias=False, + rotary_kwargs=rotary_kwargs, + embedding_class=BertEmbedding) + + +class GteEmbeddingModel(BertEmbeddingModel): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "attention.qkv_proj": "attention.self.qkv_proj", + "attention.o_proj": "attention.output.dense", + 'attn_ln': "attention.output.LayerNorm", + 'mlp.down_proj': "output.dense", + 'mlp_ln': "output.LayerNorm", + }) + + def _build_model(self, + vllm_config: VllmConfig, + prefix: str = "") -> BertModel: + config = vllm_config.model_config.hf_config + + assert config.__class__.__name__ == "GteConfig" + assert config.position_embedding_type == "rope" + assert config.hidden_act == "gelu" + + config.position_embedding_type = "rotary" + config.hidden_act = "gelu_and_mul" + + head_dim = config.hidden_size // config.num_attention_heads + rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), + "max_position": config.max_position_embeddings, + "base": config.rope_theta, + } + + model = BertModel(vllm_config=vllm_config, + prefix=prefix, + rotary_kwargs=rotary_kwargs, + embedding_class=BertEmbedding) + + # GteModel only gate_up_proj does not have bias. 
+ # Hack method learned from vllm/model_executor/models/glm.py + for layer in model.encoder.layer: + layer.intermediate.gate_up_proj.bias = None + layer.intermediate.skip_bias_add = True + return model + + def split_up_gate_proj(self, weights: Iterable[Tuple[str, torch.Tensor]]): + n = "mlp.up_gate_proj" + for name, weight in weights: + if n in name: + up, gate = weight.chunk(2, dim=0) + yield name.replace(n, "intermediate.up_proj"), up + yield name.replace(n, "intermediate.gate_proj"), gate + else: + yield name, weight + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + weights = self.hf_to_vllm_mapper.apply(weights) + weights = self.split_up_gate_proj(weights) + self.model.load_weights(weights) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index a6f00f999773..eed49e74ac9f 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import Literal, Optional, Set, Tuple, TypedDict, Union import torch @@ -12,7 +11,6 @@ from vllm.config import CacheConfig, VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -62,6 +60,7 @@ def __init__( quant_config: Optional[QuantizationConfig], cache_config: Optional[CacheConfig], is_cross_attention: bool = False, + prefix: str = "", ) -> None: super().__init__() @@ -141,7 +140,7 @@ def forward( class Blip2QFormerSelfOutput(nn.Module): - def __init__(self, config: Blip2QFormerConfig) -> None: + def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -169,6 +168,7 @@ def __init__( quant_config: Optional[QuantizationConfig], cache_config: Optional[CacheConfig], is_cross_attention: bool = False, + prefix: str = "", ) -> None: super().__init__() @@ -177,9 +177,10 @@ def __init__( quant_config=quant_config, cache_config=cache_config, is_cross_attention=is_cross_attention, + prefix=f"{prefix}.attention", ) - self.output = Blip2QFormerSelfOutput(config) + self.output = Blip2QFormerSelfOutput(config, prefix=f"{prefix}.output") def forward( self, @@ -197,7 +198,7 @@ def forward( class Blip2QFormerIntermediate(nn.Module): - def __init__(self, config: Blip2QFormerConfig) -> None: + def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) @@ -211,7 +212,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Blip2QFormerOutput(nn.Module): - def __init__(self, config: Blip2QFormerConfig) -> None: + def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None: super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) @@ -239,6 +240,7 @@ def __init__( quant_config: Optional[QuantizationConfig], cache_config: Optional[CacheConfig], layer_idx: int, + prefix: str = "", ) -> None: super().__init__() @@ -246,7 +248,8 @@ def __init__( self.seq_len_dim = 1 self.attention = Blip2QFormerAttention(config, 
quant_config=quant_config, - cache_config=cache_config) + cache_config=cache_config, + prefix=f"{prefix}.attention") self.layer_idx = layer_idx @@ -255,13 +258,16 @@ def __init__( config, quant_config=quant_config, cache_config=cache_config, - is_cross_attention=True) + is_cross_attention=True, + prefix=f"{prefix}.crossattention") self.has_cross_attention = True else: self.has_cross_attention = False - self.intermediate_query = Blip2QFormerIntermediate(config) - self.output_query = Blip2QFormerOutput(config) + self.intermediate_query = Blip2QFormerIntermediate( + config, prefix=f"{prefix}.intermediate_query") + self.output_query = Blip2QFormerOutput(config, + prefix=f"{prefix}.output_query") def forward( self, @@ -327,6 +333,7 @@ def __init__( *, quant_config: Optional[QuantizationConfig], cache_config: Optional[CacheConfig], + prefix: str = "", ) -> None: super().__init__() @@ -336,7 +343,8 @@ def __init__( Blip2QFormerLayer(config, quant_config=quant_config, cache_config=cache_config, - layer_idx=layer_idx) + layer_idx=layer_idx, + prefix=f"{prefix}.layer.{layer_idx}") for layer_idx in range(config.num_hidden_layers) ]) @@ -367,6 +375,7 @@ def __init__( *, quant_config: Optional[QuantizationConfig], cache_config: Optional[CacheConfig], + prefix: str = "", ) -> None: super().__init__() @@ -378,7 +387,8 @@ def __init__( self.encoder = Blip2QFormerEncoder(config, quant_config=quant_config, - cache_config=cache_config) + cache_config=cache_config, + prefix=f"{prefix}.encoder") def forward( self, @@ -513,7 +523,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.qformer = Blip2QFormerModel(config.qformer_config, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.qformer") self.language_projection = nn.Linear( config.qformer_config.hidden_size, @@ -530,13 +541,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -649,7 +653,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, - ) -> Union[SamplerOutput, IntermediateTensors]: + ) -> IntermediateTensors: """Run forward pass for BLIP-2. 
One key thing to understand is the `input_ids` already accounts for the @@ -707,13 +711,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index f960075b98bc..74d401b295ce 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -35,7 +35,6 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -297,7 +296,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -324,14 +322,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 0ad5e89df2e2..e2c275300f8c 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -22,7 +22,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -950,7 +949,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -1054,14 +1052,6 @@ def compute_logits( return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 1b1738f882b7..233e9ee0a258 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -21,7 +21,6 @@ from vllm.model_executor.layers.logits_processor import 
LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -429,7 +428,6 @@ def __init__( self.transformer.embedding.weight) self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config.padded_vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -442,14 +440,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index bb8d9bf8a03c..25b1d5a1955f 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -38,7 +38,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -89,6 +88,7 @@ def __init__( self, config: CohereConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -99,12 +99,14 @@ def __init__( [self.intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( self.intermediate_size, self.hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.down_proj", ) self.act_fn = SiluAndMul() @@ -158,12 +160,14 @@ def __init__( self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=False, quant_config=quant_config, + prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( self.head_dim, @@ -244,7 +248,9 @@ def __init__(self, quant_config=quant_config, prefix=f"{prefix}.self_attn") - self.mlp = CohereMLP(config, quant_config=quant_config) + self.mlp = CohereMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), eps=config.layer_norm_eps) @@ -365,7 +371,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): scale=config.logit_scale) self.model = CohereModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -399,14 +404,6 @@ def compute_logits( return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens 
- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index b66529860bc2..40c0a73f52d5 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -16,7 +16,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -390,7 +389,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -417,14 +415,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: expert_params_mapping = [( diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 5e036d049a8a..c6421143dd68 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -43,7 +43,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -453,7 +452,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -480,14 +478,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index e7fde76cd0ba..b50175cf764f 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -10,7 +10,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( 
ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -154,8 +153,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix( prefix, "model")) - self.sampler = get_sampler() - def forward( self, input_ids: torch.Tensor, @@ -179,14 +176,6 @@ def compute_logits( return self.model.compute_logits(hidden_states, sampling_metadata, spec_step_idx) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 23b450aeddac..ffa5840b4604 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -44,7 +44,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -686,7 +685,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -713,14 +711,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def make_empty_intermediate_tensors( self, batch_size: int, dtype: torch.dtype, device: torch.device) -> IntermediateTensors: diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index c3dbadb29276..ac136698ee17 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -4,7 +4,6 @@ """Inference-only Deepseek-VL2 model compatible with HuggingFace weights.""" import math from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import torch @@ -16,7 +15,6 @@ from vllm.config import VllmConfig from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -393,13 +391,6 @@ def _init_vision_module( model = model.to(dtype=torch.get_default_dtype()) return model - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_pixel_values( self, data: Union[torch.Tensor, List[torch.Tensor]] ) -> Union[torch.Tensor, List[torch.Tensor]]: @@ -647,13 +638,6 @@ def 
compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 3e4a5040b7c8..4ff1e785494f 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -9,7 +9,6 @@ from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -131,10 +130,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # checkpoint file has token_map tensor. self.token_map = None - @property - def sampler(self): - return self.model.sampler - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.model.get_input_embeddings(input_ids) @@ -188,14 +183,6 @@ def compute_logits(self, hidden_states: torch.Tensor, return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # This implementation is incompitable with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B # due to missing lm_head weights and its config being that of a diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 553c524ebc37..4a6490cd127a 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -510,8 +509,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = PPMissingLayer() - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -538,14 +535,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 0e67b1ec94f6..e7e03fc09972 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -40,7 +40,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import 
QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -473,7 +472,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=quant_config, ) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -500,14 +498,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 359cc7f37731..d1a36c3f481a 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -3,7 +3,6 @@ import math from collections import OrderedDict from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import torch @@ -14,7 +13,6 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, BartParallelLMHead, @@ -673,7 +671,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.vocab_size, config.vocab_size) - self.sampler = get_sampler() def forward( self, @@ -716,11 +713,6 @@ def compute_logits( sampling_metadata) return logits - def sample(self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata) -> SamplerOutput: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ @@ -929,12 +921,6 @@ def _build_image_projection_layers(self, config: PretrainedConfig): raise NotImplementedError( 'Florence2 only supports COSINE as temporal embedding.') - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - return get_sampler() - def _validate_pixel_values( self, data: Union[torch.Tensor, List[torch.Tensor]] ) -> Union[torch.Tensor, List[torch.Tensor]]: @@ -1110,13 +1096,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 27cd8d0986a5..d6bd6155a447 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -27,7 +27,6 @@ from vllm.config import VllmConfig from 
vllm.model_executor.layers.linear import ColumnParallelLinear -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -270,10 +269,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @property - def sampler(self): - return self.language_model.sampler - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.patch_size @@ -387,14 +382,6 @@ def compute_logits( self.language_model.lm_head, hidden_states, sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.language_model.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 92d99883c774..c1cc0df11178 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -35,7 +35,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -388,7 +387,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = GemmaModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -415,14 +413,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index d125c666f3cd..7fb2e9948c06 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -146,8 +145,8 @@ def __init__(self, # reference: # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa layer_idx = extract_layer_index(prefix) - use_sliding_window = (layer_idx % 2 == 0 and - config.interleaved_sliding_window is not None) + use_sliding_window = 
(layer_idx % 2 == 0 and getattr( + config, "interleaved_sliding_window", None) is not None) sliding_window = config.interleaved_sliding_window if \ use_sliding_window else None self.attn = Attention(self.num_heads, @@ -388,7 +387,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "model")) self.logits_processor = LogitsProcessor( config.vocab_size, soft_cap=config.final_logit_softcapping) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -415,14 +413,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index fb8eccc55078..4e0d4f84ca6b 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -147,7 +146,9 @@ def __init__(self, # TODO(woosuk): Add reference to the original HF implementation. layer_idx = extract_layer_index(prefix) - self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) + self.is_sliding = (getattr( + config, "interleaved_sliding_window", None) is not None and bool( + (layer_idx + 1) % config.sliding_window_pattern)) # Initialize the rotary embedding. if self.is_sliding: # Local attention. Override the values in config.json. 
@@ -493,7 +494,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "model")) self.logits_processor = LogitsProcessor( config.vocab_size, soft_cap=config.final_logit_softcapping) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -521,14 +521,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index e5a3d6762fff..65c177f8c5ad 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import math from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Any, Literal, Optional, Set, Tuple, TypedDict import torch from torch import nn @@ -12,7 +12,6 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import GemmaRMSNorm -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -479,7 +478,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config - self.sliding_window = config.text_config.interleaved_sliding_window + self.sliding_window = getattr(config.text_config, + "interleaved_sliding_window", None) self.vision_tower = SiglipVisionModel(config.vision_config, quant_config, @@ -503,10 +503,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def dtype(self): return next(self.parameters()).dtype - @property - def sampler(self): - return self.language_model.sampler - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -607,7 +603,7 @@ def forward(self, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: + **kwargs: object) -> IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -685,13 +681,14 @@ def prepare_attn_masks( global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) global_attn_masks.append(global_attn_mask) - # Create a local causal mask with sliding window (1024). - local_attn_mask = torch.ones_like(global_attn_mask) - local_attn_mask = torch.tril(local_attn_mask, - diagonal=-self.sliding_window) - local_attn_mask = torch.where(local_attn_mask == 0, - global_attn_mask, float("-inf")) - local_attn_masks.append(local_attn_mask) + if self.sliding_window is not None: + # Create a local causal mask with sliding window (1024). 
+ local_attn_mask = torch.ones_like(global_attn_mask) + local_attn_mask = torch.tril(local_attn_mask, + diagonal=-self.sliding_window) + local_attn_mask = torch.where(local_attn_mask == 0, + global_attn_mask, float("-inf")) + local_attn_masks.append(local_attn_mask) kwargs["global_attn_masks"] = global_attn_masks kwargs["local_attn_masks"] = local_attn_masks return kwargs @@ -704,13 +701,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/glm.py b/vllm/model_executor/models/glm.py index 8d52da8b7482..6269ebcee5c0 100644 --- a/vllm/model_executor/models/glm.py +++ b/vllm/model_executor/models/glm.py @@ -3,13 +3,13 @@ from vllm.config import VllmConfig from vllm.model_executor.models.llama import LlamaForCausalLM -from .interfaces import SupportsV0Only from .utils import PPMissingLayer -class GlmForCausalLM(LlamaForCausalLM, SupportsV0Only): +class GlmForCausalLM(LlamaForCausalLM): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + vllm_config.model_config.hf_config.partial_rotary_factor = 0.5 super().__init__(vllm_config=vllm_config, prefix=prefix) # Hack Llama model to fit HF format GLM implementation # Attention difference between GLM and Llama: @@ -17,7 +17,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # 2. There is no bias for o_proj in attention for layer in self.model.layers: if not isinstance(layer, PPMissingLayer): - layer.self_attn.rotary_emb.rotary_dim //= 2 layer.self_attn.rotary_emb.is_neox_style = False layer.self_attn.o_proj.bias = None layer.self_attn.o_proj.skip_bias_add = True diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py index cba093cbfef7..290be968cb54 100644 --- a/vllm/model_executor/models/glm4.py +++ b/vllm/model_executor/models/glm4.py @@ -37,7 +37,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -82,7 +81,7 @@ def __init__(self, partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = head_dim or hidden_size // self.total_num_heads - self.rotary_dim = int(partial_rotary_factor * self.head_dim) + self.rotary_dim = self.head_dim self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -110,6 +109,7 @@ def __init__(self, base=self.rope_theta, rope_scaling=rope_scaling, partial_rotary_factor=partial_rotary_factor, + is_neox_style=False, ) self.attn = Attention(self.num_heads, self.head_dim, @@ -197,13 +197,12 @@ def forward( ) hidden_states = self.post_self_attn_layernorm(hidden_states) - hidden_states = residual + hidden_states # Fully Connected - hidden_states = self.post_attention_layernorm(hidden_states, residual) + 
hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) hidden_states = self.mlp(hidden_states) hidden_states = self.post_mlp_layernorm(hidden_states) - hidden_states = residual + hidden_states return hidden_states, residual @@ -267,7 +266,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -295,14 +293,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 776c03f652bd..e3219333915e 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -35,7 +35,6 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -255,7 +254,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = self.lm_head.tie_weights(self.transformer.wte) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -282,14 +280,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 43f3d4f6dc9c..def6b1544d8c 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -35,7 +35,6 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -43,7 +42,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -244,6 +243,30 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = 
set() + for name, loaded_weight in weights: + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method + if "c_attn.input_scale" in name or "c_attn.weight_scale" in name: + weight_loader(param, loaded_weight, 'q') + weight_loader(param, loaded_weight, 'k') + weight_loader(param, loaded_weight, 'v') + else: + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = {"c_attn": ["c_attn"]} @@ -278,7 +301,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -305,36 +327,10 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "lm_head.weight" in name: - continue - if ".attn.bias" in name: - # Skip attention mask. - # NOTE: "c_attn.bias" should not be skipped. - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method - if "c_attn.input_scale" in name or "c_attn.weight_scale" in name: - weight_loader(param, loaded_weight, 'q') - weight_loader(param, loaded_weight, 'k') - weight_loader(param, loaded_weight, 'v') - else: - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."]), + ) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 752aec0b223d..3db96fb8e187 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -43,7 +42,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -188,6 +187,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config + 
self.quant_config = quant_config self.embed_dim = config.n_embd self.wte = VocabParallelEmbedding( config.vocab_size, @@ -228,61 +228,6 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states - -class GPTJForCausalLM(nn.Module, SupportsPP): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - assert not config.tie_word_embeddings - self.transformer = GPTJModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) - self.lm_head = ParallelLMHead( - config.vocab_size, - config.n_embd, - bias=True, - quant_config=quant_config, - ) - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.transformer.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata, self.lm_head.bias) - return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ @@ -339,3 +284,54 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params + + +class GPTJForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + assert not config.tie_word_embeddings + self.transformer = GPTJModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.n_embd, + bias=True, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + 
sampling_metadata, self.lm_head.bias) + return logits + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 582b2ff7e755..620ee66f57e7 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -299,7 +298,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if self.config.tie_word_embeddings: self.embed_out.weight = self.gpt_neox.embed_in.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.gpt_neox.make_empty_intermediate_tensors) @@ -326,14 +324,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 3bd6332c11ca..0696a7245c22 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -441,8 +440,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = PPMissingLayer() - self.sampler = get_sampler() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -464,14 +461,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def make_empty_intermediate_tensors( self, batch_size: int, dtype: torch.dtype, device: torch.device) -> IntermediateTensors: diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py new file mode 100644 index 000000000000..b43b59da6d11 --- /dev/null +++ b/vllm/model_executor/models/granite_speech.py @@ -0,0 +1,777 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2025 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
+# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only IBM Granite speech model.""" +import math +from typing import Iterable, Mapping, Optional, Set, Tuple, TypedDict, Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import BatchFeature, PretrainedConfig + +from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import get_sampler +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +from .blip2 import Blip2QFormerModel +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .utils import (AutoWeightsLoader, embed_multimodal, + init_vllm_registered_model, maybe_prefix) + + +### Audio Input +class GraniteSpeechAudioInputs(TypedDict): + + input_features: torch.Tensor + """Shape: `(bsz, num_features, 160)`""" + + input_features_mask: torch.Tensor + """Shape: `(bsz, num_features)`""" + + audio_embed_sizes: list[int] + """List of length `bsz`""" + + +class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": 1} + + # There is no limit to the maximum number of audio tokens that can be + # encoded as features; we pick ~5000 as a number that is probably higher + # than we would expect to encounter. The sequence of length + # get_max_audio_len() produces get_max_audio_tokens(). 
+ def get_max_audio_tokens(self): + return 5001 + + def get_max_audio_len(self): + return 8000000 + + +### Input Processing & Multimodal utils +class GraniteSpeechMultiModalProcessor( + BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_hf_processor().audio_processor + sampling_rate = feature_extractor.melspec_kwargs["sample_rate"] + return MultiModalDataParser(target_sr=sampling_rate) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_features=MultiModalFieldConfig.batched("audio"), + audio_embed_sizes=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + feature_extractor = processor.audio_processor + vocab = tokenizer.get_vocab() + + # Use getattr with default to be compatible with transformers<4.48 + audio_token = getattr(processor, "audio_token", "<|audio|>") + audio_token_id = vocab[audio_token] + + def get_replacement(item_idx: int): + audios = mm_items.get_items("audio", AudioProcessorItems) + audio = audios.get(item_idx) + audio_length = audio.shape[-1] + num_projector_features = feature_extractor._get_num_audio_features( + [audio_length])[0] + return [audio_token_id] * num_projector_features + + return [ + PromptReplacement( + modality="audio", + target=[audio_token_id], + replacement=get_replacement, + ) + ] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) + + if audios: + # GraniteSpeechFeatureExtractor accepts "audio" + mm_data["audio"] = audios + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + if "audio" in mm_data: + # Calculate the number of audio tokens per entry in the batch; + # This is used to split the batch back out after padding. 
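+ # Each size is the count of audio placeholder tokens (audio_token_index) + # in the corresponding row of the processed input_ids.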
+ audio_token_index = self.info.get_hf_config().audio_token_index + processed_outputs["audio_embed_sizes"] = [ + torch.sum(indices == audio_token_index).item() + for indices in processed_outputs["input_ids"] + ] + + return processed_outputs + + +class GraniteSpeechDummyInputsBuilder( + BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]): + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_audios = mm_counts.get("audio", 0) + return { + "audio": + self._get_dummy_audios( + length=self.info.get_max_audio_len(), + num_audios=num_audios, + ) + } + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + hf_processor = self.info.get_hf_processor() + audio_token = getattr(hf_processor, "audio_token", "<|audio|>") + return audio_token * num_audios + + +### QFormer Projector +class GraniteSpeechEncoderProjector(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: CacheConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.projector_config.hidden_size + self.downsample_rate = config.downsample_rate + self.window_size = config.window_size + self.num_queries = config.window_size // config.downsample_rate + + self.query = nn.Parameter( + torch.zeros(1, self.num_queries, + config.projector_config.hidden_size)) + + # NOTE - this is implemented generically in transformers, + # but for now we create the QFormer model directly since + # all existing models use this for the projector. + self.qformer = Blip2QFormerModel( + config.projector_config, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.qformer", + ) + self.linear = nn.Linear(config.projector_config.hidden_size, + config.text_config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = hidden_states.size() + nblocks = math.ceil(seq_len / self.window_size) + pad = nblocks * self.window_size - seq_len + hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad), + "constant", 0) + hidden_states = hidden_states.view(batch_size * nblocks, + self.window_size, dim) + + last_hidden_state = self.qformer( + query_embeds=self.query.data, + encoder_hidden_states=hidden_states, + ) + + query_proj = self.linear( + last_hidden_state.view( + batch_size, + nblocks * self.window_size // self.downsample_rate, + -1, + )) + return query_proj + + +# Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git +# NOTE - it would be nice to see if we can align this with other models using +# conformer in vLLM, e.g., phi4mm audio. 
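+# Each conformer block applies, with residual connections: a half-step +# feedforward, block attention with Shaw's relative positional embeddings, +# a convolution module, and a second half-step feedforward, followed by a +# final LayerNorm (see GraniteSpeechConformerBlock.forward below).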
+class GraniteSpeechConformerFeedForward(nn.Module): + """Feedforward module for conformer encoder blocks.""" + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.pre_norm = nn.LayerNorm(config.hidden_dim) + + self.up_proj = ColumnParallelLinear( + input_size=config.hidden_dim, + output_size=config.hidden_dim * config.feedforward_mult, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + self.silu = nn.SiLU() + + self.down_proj = RowParallelLinear( + input_size=config.hidden_dim * config.feedforward_mult, + output_size=config.hidden_dim, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(hidden_states) + hidden_states, _ = self.up_proj(hidden_states) + hidden_states = self.silu(hidden_states) + hidden_states, _ = self.down_proj(hidden_states) + return hidden_states + + +class GraniteSpeechConformerAttention(nn.Module): + """Attention for conformer blocks using Shaw's relative positional + embeddings. See the following [paper](https://arxiv.org/pdf/1803.02155) + for more details. + """ + + def __init__(self, config: PretrainedConfig, prefix: str = ""): + super().__init__() + + inner_dim = config.dim_head * config.num_heads + self.max_pos_emb = config.max_pos_emb + self.context_size = config.context_size + self.num_heads = config.num_heads + self.dim_head = config.dim_head + self.scale = self.dim_head**-0.5 + self.pre_norm = nn.LayerNorm(config.hidden_dim) + self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False) + self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, config.hidden_dim) + self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, + self.dim_head) + + if self.context_size <= 0 or self.context_size > self.max_pos_emb: + raise ValueError( + "Context size is either less than 0 or exceeds the max_pos_emb" + ) + + def forward(self, hidden_states: torch.Tensor, + attention_dists: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(hidden_states) + bsz, num_features, _ = hidden_states.shape + + num_blocks = math.ceil(num_features / self.context_size) + remainder = num_features % self.context_size + if remainder > 0: + # right padding to reach block size + hidden_states = torch.nn.functional.pad( + hidden_states, (0, 0, 0, self.context_size - remainder)) + + # NOTE: would be nice to try to use qkvparallellinear + # here for this block attention implementation if possible + query_states = self.to_q(hidden_states) + key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1) + + query_states = query_states.reshape(bsz, num_blocks, self.context_size, + self.num_heads, + -1).transpose(2, 3) + key_states = key_states.reshape(bsz, num_blocks, self.context_size, + self.num_heads, -1).transpose(2, 3) + value_states = value_states.reshape(bsz, num_blocks, self.context_size, + self.num_heads, + -1).transpose(2, 3) + + # shaw's relative positional embedding + dist = attention_dists.to(hidden_states.device) + rel_pos_emb = self.rel_pos_emb(dist) + rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + + list(rel_pos_emb.shape)) + pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, + dim=-1) * self.scale + + if remainder > 0: + # masked attention in the extended block + mask = torch.ones(self.context_size, + self.context_size, + dtype=bool, + device=hidden_states.device) + 
mask[:remainder, :remainder] = 0 + mask_value = -torch.finfo(pos_attn.dtype).max + pos_attn[:, -1, :].masked_fill_(mask, mask_value) + + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.MATH): + out = F.scaled_dot_product_attention(query_states, + key_states, + value_states, + attn_mask=pos_attn, + scale=self.scale) + out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1) + return self.to_out(out[:, :num_features, :]) + + +class GraniteSpeechConformerDepthWiseConv1d(nn.Module): + """Wrapper for padded 1D pointwise convolution.""" + + def __init__(self, + chan_in: int, + chan_out: int, + kernel_size: int, + prefix: str = ""): + super().__init__() + # Padding for the 1D conv is symmetric or close (i.e., offset by one). + pad = kernel_size // 2 + pad_offset = (kernel_size + 1) % 2 + self.padding = (pad, pad - pad_offset) + + self.conv = nn.Conv1d(chan_in, + chan_out, + kernel_size, + groups=chan_in, + bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.padding) + return self.conv(hidden_states) + + +class GraniteSpeechConformerConvModule(nn.Module): + """Conformer conv module consisting of several 1D/depthwise 1D + convolutional layers. + """ + + def __init__(self, config: PretrainedConfig, prefix: str = ""): + super().__init__() + inner_dim = config.hidden_dim * config.conv_expansion_factor + + self.norm = nn.LayerNorm(config.hidden_dim) + self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1) + self.glu = nn.GLU(dim=1) + self.depth_conv = GraniteSpeechConformerDepthWiseConv1d( + inner_dim, + inner_dim, + kernel_size=config.conv_kernel_size, + prefix=f"{prefix}.depth_conv", + ) + self.silu = nn.SiLU() + self.batch_norm = nn.BatchNorm1d(inner_dim) + self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm(hidden_states) + hidden_states = self.up_conv(hidden_states.permute(0, 2, 1)) + hidden_states = self.glu(hidden_states) + hidden_states = self.depth_conv(hidden_states) + hidden_states = self.silu(self.batch_norm(hidden_states)) + hidden_states = self.down_conv(hidden_states).permute(0, 2, 1) + return hidden_states + + +class GraniteSpeechConformerBlock(nn.Module): + """Conformer block, consisting largely of linear layers, + attention, and convolutional layers.""" + + def __init__(self, config: PretrainedConfig, prefix: str = ""): + super().__init__() + self.ff1 = GraniteSpeechConformerFeedForward(config, + prefix=f"{prefix}.ff1") + self.attn = GraniteSpeechConformerAttention(config, + prefix=f"{prefix}.attn") + self.conv = GraniteSpeechConformerConvModule(config, + prefix=f"{prefix}.conv") + self.ff2 = GraniteSpeechConformerFeedForward(config, + prefix=f"{prefix}.ff2") + self.post_norm = nn.LayerNorm(config.hidden_dim) + + def forward(self, hidden_states: torch.Tensor, + attention_dists: torch.Tensor) -> torch.Tensor: + hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states + hidden_states = self.attn( + hidden_states, attention_dists=attention_dists) + hidden_states + hidden_states = self.conv(hidden_states) + hidden_states + hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states + hidden_states = self.post_norm(hidden_states) + return hidden_states + + +class GraniteSpeechCTCEncoder(nn.Module): + """CTC Encoder comprising conformer blocks and additional linear layers.""" + + def __init__(self, + config: PretrainedConfig, + prefix: str, + quant_config: 
Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + + # Precompute clamped relative positional encoding distances + seq = torch.arange(config.context_size) + relpos_dist = seq.view(-1, 1) - seq.view(1, -1) + self.attention_dists = torch.clamp( + relpos_dist, -config.context_size, + config.context_size) + config.max_pos_emb + + self.input_linear = nn.Linear(config.input_dim, + config.hidden_dim, + bias=True) + self.layers = nn.ModuleList([ + GraniteSpeechConformerBlock( + config, + prefix=f"{prefix}.layers.{idx}", + ) for idx in range(config.num_layers) + ]) + + self.out = ColumnParallelLinear( + input_size=config.hidden_dim, + output_size=config.output_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out", + ) + + self.out_mid = RowParallelLinear( + input_size=config.output_dim, + output_size=config.hidden_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_mid", + ) + self.softmax = nn.Softmax(dim=-1) + self.num_layers = config.num_layers + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.input_linear(hidden_states) + for idx, layer in enumerate(self.layers, start=1): + hidden_states = layer(hidden_states, + attention_dists=self.attention_dists) + + if idx == self.num_layers // 2: + hidden_states_mid = hidden_states.clone() + hidden_states_mid, _ = self.out(hidden_states_mid) + hidden_states_mid = self.softmax(hidden_states_mid) + hidden_states_mid, _ = self.out_mid(hidden_states_mid) + hidden_states += hidden_states_mid + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + GraniteSpeechMultiModalProcessor, + info=GraniteSpeechMultiModalProcessingInfo, + dummy_inputs=GraniteSpeechDummyInputsBuilder) +class GraniteSpeechForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsLoRA, +): + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + + self.config = config + self.quant_config = quant_config + self.cache_config = cache_config + self.sampler = get_sampler() + + # The language model is typically a Granite LLM + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + # Conformer encoder + self.encoder = GraniteSpeechCTCEncoder( + config=config.encoder_config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + ) + + # Blip2 QFormer + self.projector = GraniteSpeechEncoderProjector( + config=config, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.projector", + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _parse_and_validate_audio_input( + self, + **kwargs: object, + ) -> Optional[GraniteSpeechAudioInputs]: + input_features = kwargs.pop("input_features", None) + input_features_mask = kwargs.pop("input_features_mask", None) + audio_embed_sizes = kwargs.pop("audio_embed_sizes", None) + if input_features is None: + return None + + # If we have a batch of variable feature length audio clips, we need + # to mask the features; usually we would get an input_features_mask + # from the processor, but we handle rebuilding it here since + # vLLM 
generally processes everything independently + batches. + if input_features_mask is None: + input_features_mask = self._build_input_features_mask( + audio_embed_sizes) + + if not isinstance(input_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio input features. " + f"Got type: {type(input_features)}") + + if input_features_mask is not None and not isinstance( + input_features_mask, torch.Tensor): + raise ValueError("Incorrect type of audio input features mask. " + f"Got type: {type(input_features_mask)}") + + if isinstance(input_features, torch.Tensor): + # Granite speech currently only allows one audio token per instance + # and features are already unsqueezed in the processor, so one + # instance will have shape [1, {num_features}, 160]. As such, + # input features will usually be of shape + # [bsz, 1, num_features, 160], which we squeeze to be 3D here. + if len(input_features.shape) == 4: + input_features = input_features.squeeze(1) + if len(input_features.shape) != 3: + raise ValueError( + "Squeezed input features should be 3D but are of shape " + f"{input_features.shape}") + input_features = input_features.to( + self.encoder.input_linear.weight.dtype) + + else: + # Otherwise we have a list of tensors, which are almost certainly + # differing in their respective numbers of audio features; + # stack them into a 3D tensor of size [bsz, most_num_features, 160]. + input_features = self._pad_and_stack_input_features( + input_features, ).to(self.encoder.input_linear.weight.dtype) + + return GraniteSpeechAudioInputs( + input_features=input_features, + input_features_mask=input_features_mask, + audio_embed_sizes=audio_embed_sizes.flatten().tolist(), + ) + + def _build_input_features_mask( + self, + audio_embed_sizes: torch.Tensor, + ) -> torch.Tensor: + """Calculate the input features mask, which will generally be used + to mask the padded features for all entries in the batch except + for those with the most audio features. + + Args: + audio_embed_sizes: torch.Tensor + Tensor of num features in each seq in the batch. + Returns: + torch.Tensor: Mask of shape (bsz, num_features) to be applied to + the audio features prior to splitting the audio embeddings. + """ + most_audio_features = torch.max(audio_embed_sizes).item() + mask_indices = torch.arange( + most_audio_features, + device=audio_embed_sizes.device, + ).view(1, -1) + input_features_mask = mask_indices < audio_embed_sizes.view(-1, 1) + return input_features_mask + + def _pad_and_stack_input_features( + self, + input_features: list[torch.Tensor], + ) -> torch.Tensor: + """Given a list of input features of varying length, pad them to the + same length and stack them into a torch.Tensor. + + NOTE: Usually, padding is done in the input processor/feature extractor + and zero padded prior to the computation of the Mel features; the + resulting values are only constant within a batch and generally nonzero + (i.e., slightly negative nums); we should validate that this is okay + since we don't use a feature attention mask, but the more important + thing is that we apply the input_features_mask with variable len + batches. + + Args: + input_features: list[torch.Tensor] + Input features to be coerced into a tensor. + Returns: + torch.Tensor: Tensor of shape [bsz, num_features, 160], where + num_features is the max number of features of any entry in the + batch. 
+ """ + # Input features are of shape [bsz, num_features, 160] + feat_lens = [feats.shape[1] for feats in input_features] + padding = [max(feat_lens) - length for length in feat_lens] + # TODO (Alex) - Validate that it's okay to zero pad like this; + # in transformers we zero pad prior to calculating the speech features, + # so the value is not zero and is dependent on the batched features. + padded = [ + torch.nn.functional.pad(feats, (0, 0, 0, pad, 0, 0)) + for feats, pad in zip(input_features, padding) + ] + stacked_features = torch.cat(padded, dim=0).to(input_features[0]) + return stacked_features + + def _process_audio_input( + self, + audio_input: GraniteSpeechAudioInputs, + ) -> tuple[torch.Tensor]: + """Compute the audio features to be merged into the LLM embeddings. + + Args: + audio_input: GraniteSpeechAudioInputs + Audio inputs object containing Mel features, an input features + mask, and the (flattened) number of audio tokens per instance. + Returns: + tuple[torch.Tensor]: List of length bsz. + """ + # TODO (Alex) - support embedding inputs + encoder_embeds = self.encoder(audio_input["input_features"]) + # [bsz, , 4096] + projected_embeds = self.projector(encoder_embeds) + # Apply mask on variable length audio features + masked_embeds = projected_embeds[audio_input["input_features_mask"]] + # Split variable length features into a tuple + return torch.split(masked_embeds, audio_input["audio_embed_sizes"]) + + def get_multimodal_embeddings( + self, + **kwargs: object, + ) -> Optional[MultiModalEmbeddings]: + """Compute the audio embeddings if audio inputs are present.""" + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is None: + return None + audio_features = self._process_audio_input(audio_input) + return audio_features + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + """Compute the merged LLM / audio embeddings.""" + if multimodal_embeddings is None: + return self.language_model.get_input_embeddings(input_ids) + + inputs_embeds = embed_multimodal( + input_ids, + self.config.audio_token_index, + self.language_model.model.get_input_embeddings, + multimodal_embeddings, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + audio_embeds = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds) + input_ids = None + + model_output = self.language_model(input_ids, positions, + intermediate_tensors, inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits( + hidden_states, + sampling_metadata, + ) + + def load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + ) -> Set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_mm_mapping(self) -> MultiModelKeys: + """Get the module prefix in multimodal models.""" + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="projector", + tower_model="encoder", + ) diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 367722126e56..7fff14cb9f12 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -391,8 +390,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): scale=1 / self.config.logits_scaling) - self.sampler = get_sampler() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -428,14 +425,6 @@ def make_empty_intermediate_tensors( device=device), }) - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index cf8c969e118f..4e660cbf667b 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -20,7 +20,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -295,8 +294,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): scale=1 / self.config.logits_scaling) - self.sampler = get_sampler() - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -332,14 +329,6 @@ def make_empty_intermediate_tensors( device=device), }) - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def 
load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 2984f2241286..e4692c458088 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -170,7 +170,8 @@ def forward( mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze( 1) - pooled_data = self.head(mean_embeddings) + pooled_data = self.head(mean_embeddings, + pooling_metadata=pooling_metadata) pooled_outputs = [ PoolingSequenceGroupOutput(data) for data in pooled_data diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index ef96257ba4bb..c48cb157084d 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -39,7 +39,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -521,7 +520,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size, self.output_multiplier_scale) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -551,14 +549,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: skip_prefixes = ["rotary_emb.inv_freq"] diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index c31870461b4c..961954c2b584 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -28,7 +28,6 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -603,7 +602,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if self.config.text_config.tie_word_embeddings: self.lm_head.weight = self.model.text_model.wte.weight self.logits_processor = LogitsProcessor(config.text_config.vocab_size) - self.sampler = get_sampler() def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size @@ -754,14 +752,6 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff 
--git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 22c9287509ed..f141dcf3cd4f 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -13,7 +13,6 @@ if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import PoolerOutput - from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -103,14 +102,6 @@ def compute_logits( """Return `None` if TP rank > 0.""" ... - def sample( - self, - logits: T, - sampling_metadata: "SamplingMetadata", - ) -> "SamplerOutput": - """Only called on TP rank 0.""" - ... - @overload def is_text_generation_model( diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 520b85c0cdfb..c3d7cbfcddbb 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -23,7 +23,6 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -336,7 +335,6 @@ def __init__(self, if self.config.tie_word_embeddings: self.output.weight = self.model.tok_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -363,14 +361,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ @@ -423,7 +413,7 @@ def __init__( prefix=prefix, model_type=model_type) - for attr in ("output", "logits_processor", "sampler"): + for attr in ("output", "logits_processor"): delattr(self, attr) config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 8f5f454cbf60..23b92ad2bbf6 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -8,7 +8,6 @@ # -------------------------------------------------------- from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union import torch @@ -20,7 +19,6 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.intern_vit import (InternVisionModel, InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -698,13 +696,6 @@ def _patch_quant_config(self, config: PretrainedConfig, (llm_quant_config is not None): 
quant_config.modules_to_not_convert.append("vision_model") - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _init_vision_model( self, config: PretrainedConfig, @@ -903,7 +894,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, - ) -> Union[SamplerOutput, IntermediateTensors]: + ) -> IntermediateTensors: if intermediate_tensors is not None: input_ids = None @@ -941,13 +932,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: # unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 78fe6588eddc..e1e3f0f199c5 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -36,7 +36,6 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -308,7 +307,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.mup_width_scale) self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, scale=self.output_logits_scale) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -335,14 +333,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 6fabc8228e18..46335c2b3930 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -19,7 +19,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -409,7 +408,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -466,14 +464,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: 
SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py new file mode 100644 index 000000000000..0629266860fd --- /dev/null +++ b/vllm/model_executor/models/kimi_vl.py @@ -0,0 +1,577 @@ +# SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: E501 +# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py +# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved. +# +# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL. +# +# Licensing Information: +# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0. +# - Other parts of the code are licensed under the MIT License. +# +# Apache License, Version 2.0: +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License: +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
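The new kimi_vl.py introduced here wires a MoonViT vision tower and a small projector into a DeepseekV2 language model. Its least obvious piece is KimiVLProcessingInfo.get_num_image_tokens further down in this file, which first rescales an image so the raw patch grid fits under the processor's in_token_limit and then pads each side to a multiple of merge_kernel_size * patch_size before counting merged tokens. The following standalone sketch of that arithmetic is not part of the patch; the patch_size=14, merge_kernel_size=(2, 2) and in_token_limit=4096 defaults are illustrative assumptions, since the real values come from the Hugging Face image processor config.

import math

def kimi_vl_num_image_tokens(width: int,
                             height: int,
                             patch_size: int = 14,        # assumed value
                             merge_kernel_size=(2, 2),    # assumed value
                             in_token_limit: int = 4096,  # assumed value
                             ) -> int:
    # Rescale so the raw patch grid stays under the token budget.
    if (width // patch_size) * (height // patch_size) > in_token_limit:
        scale = math.sqrt(in_token_limit /
                          ((width // patch_size) * (height // patch_size)))
        width, height = int(width * scale), int(height * scale)

    kernel_h, kernel_w = merge_kernel_size
    # Pad each dimension up to a multiple of kernel * patch_size so the merge
    # kernel tiles the patch grid exactly.
    pad_h = (kernel_h * patch_size - height %
             (kernel_h * patch_size)) % (kernel_h * patch_size)
    pad_w = (kernel_w * patch_size - width %
             (kernel_w * patch_size)) % (kernel_w * patch_size)

    token_h = (height + pad_h) // (kernel_h * patch_size)
    token_w = (width + pad_w) // (kernel_w * patch_size)
    return token_h * token_w

# With the assumed defaults, a 1024x1024 image (the dummy-input size used by
# MaxImageTokenMeta below) is rescaled to 897x897, padded to 924x924, and
# yields 33 * 33 = 1089 image tokens.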
+ +import copy +import math +from collections.abc import Mapping +from dataclasses import dataclass +from typing import (Any, Iterable, List, Literal, Optional, Sequence, Tuple, + TypedDict, Union) + +import torch +from torch import nn +from transformers import BatchFeature +from transformers.activations import GELUActivation + +from vllm.config import VllmConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.moonvit import MoonVitPretrainedModel +from vllm.model_executor.models.utils import merge_multimodal_embeddings +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig +from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config + +from .utils import is_pp_missing_parameter, maybe_prefix + + +# For dummy input only +@dataclass +class MaxImageTokenMeta: + width: int = 1024 + height: int = 1024 + + +class KimiVLMultiModalProjector(nn.Module): + + def __init__(self, config: KimiVLConfig): + super().__init__() + + self.hidden_size = (config.vision_config.hidden_size * + config.vision_config.merge_kernel_size[0] * + config.vision_config.merge_kernel_size[1]) + + self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, + eps=1e-5) + self.linear_1 = nn.Linear(self.hidden_size, + self.hidden_size, + bias=True) + self.act = GELUActivation() + self.linear_2 = nn.Linear(self.hidden_size, + config.text_config.hidden_size, + bias=True) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(image_features).view( + -1, self.hidden_size) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class KimiVLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: Union[torch.Tensor, List[torch.Tensor]] + """ + Shape:`(num_patches, num_channels, patch_size, patch_size)` + """ + + image_grid_hws: torch.Tensor + """Shape:`(num_images, 2)`""" + + +# TODO: support embeds too +# We only support pixel input for kimi-vl now +KimiVLImageInputs = KimiVLImagePixelInputs + + +class KimiVLProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(KimiVLConfig) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + 
) -> int: + hf_processor = self.get_hf_processor() + patch_size = hf_processor.image_processor.patch_size + kernel_size = hf_processor.image_processor.merge_kernel_size + in_token_limit = hf_processor.image_processor.in_token_limit + height = image_height + width = image_width + assert isinstance(height, + int), f"height must be int, current height {height}" + assert isinstance(width, + int), f"width must be int, current width {width}" + assert kernel_size is not None, "kernel_size must be specified" + + if (width // patch_size) * (height // patch_size) > in_token_limit: + scale = math.sqrt(in_token_limit / ((width // patch_size) * + (height // patch_size))) + new_w, new_h = int(width * scale), int(height * scale) + width, height = new_w, new_h + + kernel_height, kernel_width = kernel_size + + pad_height = (kernel_height * patch_size - height % + (kernel_height * patch_size)) % (kernel_height * + patch_size) + pad_width = (kernel_width * patch_size - width % + (kernel_width * patch_size)) % (kernel_width * patch_size) + + # Calculate new dimensions after padding and patching + token_height = (height + pad_height) // (kernel_size[0] * patch_size) + token_width = (width + pad_width) // (kernel_size[1] * patch_size) + return int(token_height * token_width) + + @property + def image_token_id(self) -> int: + return self.get_hf_config().media_placeholder_token_id + + +class KimiVLDummyInputsBuilder(BaseDummyInputsBuilder[KimiVLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + return { + "image": + self._get_dummy_images(width=MaxImageTokenMeta.width, + height=MaxImageTokenMeta.height, + num_images=num_images) + } + + +class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_grid_hws = hf_inputs.get("image_grid_hws", torch.empty((0, 2))) + image_grid_sizes = image_grid_hws.prod(-1) + + # pixel_values is merged as a single large tensor + # image_grid_hws is shapes for each subtensor in pixel_values + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_hws=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + image_token_id = self.info.image_token_id + + def get_replacement(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement, + ), + ] + + +@MULTIMODAL_REGISTRY.register_processor(KimiVLMultiModalProcessor, + info=KimiVLProcessingInfo, + 
dummy_inputs=KimiVLDummyInputsBuilder) +class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + super().__init__() + model_config = vllm_config.model_config + config: KimiVLConfig = model_config.hf_config + self.config = config + quant_config = vllm_config.quant_config + + assert isinstance(config.vision_config, MoonViTConfig) + + self.vision_tower = MoonVitPretrainedModel(config.vision_config) + + self.multi_modal_projector = KimiVLMultiModalProjector(config=config) + + self.quant_config = quant_config + sub_vllm_config = copy.deepcopy(vllm_config) + sub_vllm_config.model_config.hf_config = sub_vllm_config.model_config.hf_config.text_config + self.language_model = DeepseekV2Model( + vllm_config=sub_vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + self.unpadded_vocab_size = config.text_config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.text_config.hidden_size, + org_num_embeddings=self.config.text_config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.media_placeholder: int = self.config.media_placeholder_token_id + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_world_size = get_tensor_model_parallel_world_size() + + # ref: qwen2_vl.py + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. 
" + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return mm_input.reshape(-1, mm_input.shape[-1]) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[KimiVLImageInputs]: + # image input type must be pixel values now + pixel_values = kwargs.pop("pixel_values", None) + image_grid_hws = kwargs.pop("image_grid_hws", None) + + if pixel_values is None: + return None + + image_grid_hws = self._validate_and_reshape_mm_tensor( + image_grid_hws, "image grid hws") + # pixel_values may have complex shapes + num_channels = 3 + patch_size = self.config.vision_config.patch_size + if isinstance(pixel_values, list): + pixel_values = torch.cat([ + x.reshape(-1, num_channels, patch_size, patch_size) + for x in pixel_values + ]) + else: + pixel_values = pixel_values.reshape(-1, num_channels, patch_size, + patch_size) + pixel_values = pixel_values.to(self.vision_tower.dtype) + # image_grid_hws.shape = (N, 2) + assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}" + + return KimiVLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_hws=image_grid_hws, + ) + + # perform vt on processored pixel_values + @torch.inference_mode() + def _process_image_pixels(self, + inputs: KimiVLImagePixelInputs) -> torch.Tensor: + assert self.vision_tower is not None + + pixel_values = inputs["pixel_values"] + image_grid_hws = inputs["image_grid_hws"] + return self.vision_tower(pixel_values, image_grid_hws) + + def _process_image_input(self, + image_input: KimiVLImageInputs) -> torch.Tensor: + assert image_input["type"] == "pixel_values" + image_features = self._process_image_pixels(image_input) + assert isinstance(image_features, list) + lengths = [x.shape[0] for x in image_features] + return self.multi_modal_projector( + torch.cat(image_features)).split(lengths) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, + **kwargs: object) -> Optional[NestedTensors]: + # Validate the multimodal input keyword arguments + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + + # Run multimodal inputs through encoder and projector + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + + # `get_input_embeddings` should already be implemented for the language + # model as one of the requirements of basic vLLM model implementation. + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + placeholder_token_id=self.config.media_placeholder_token_id) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. 
+ elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + inputs_embeds = None + else: + inputs_embeds = self.get_input_embeddings(input_ids) + image_embeds = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config. + media_placeholder_token_id, + ) + input_ids = None + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + **kwargs) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata, **kwargs) + return logits + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + config = self.config.text_config + _KEYS_TO_MODIFY_MAPPING = { + "language_model.lm_head": "lm_head", + "language_model.model": "language_model", + } + # only doing this for language model part for now. + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + if not config.use_mla: + stacked_params_mapping += [ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + if getattr(config, "n_routed_experts", None): + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=config.n_routed_experts) + else: + expert_params_mapping = [] + + params_dict = dict(self.named_parameters()) + for args in weights: + name, loaded_weight = args[:2] + kwargs = args[2] if len(args) > 2 else {} + if "rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in name: + name = name.replace(key_to_modify, new_key) + use_default_weight_loading = False + if "vision" in name: + if self.vision_tower is not None: + # We only do sharding for language model and + # not vision model for now. + use_default_weight_loading = True + else: + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id, **kwargs) + break + else: + for idx, (param_name, weight_name, expert_id, + shard_id) in enumerate(expert_params_mapping): + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + expert_id=expert_id, + shard_id=shard_id, + **kwargs) + break + else: + use_default_weight_loading = True + if use_default_weight_loading: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight, **kwargs) + + +def get_spec_layer_idx_from_weight_name(config: DeepseekV2Config, + weight_name: str) -> Optional[int]: + if hasattr(config, + "num_nextn_predict_layers") and (config.num_nextn_predict_layers + > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx+i}."): + return layer_idx + i + return None diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index caa4a5108a92..38a18180e234 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -40,7 +40,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -131,8 +130,8 @@ def __init__(self, self.head_dim = getattr(config, "head_dim", self.hidden_size // self.total_num_heads) # Phi models introduced a partial_rotary_factor parameter in the config - partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) - self.rotary_dim = int(partial_rotary_factor * self.head_dim) + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", + 1) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -164,11 +163,12 @@ def __init__(self, self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.rotary_dim, + rotary_dim=self.head_dim, max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, is_neox_style=is_neox_style, + partial_rotary_factor=self.partial_rotary_factor, ) if hasattr(config, "interleaved_sliding_window"): @@ -331,6 +331,8 @@ def __init__(self, else: self.norm = PPMissingLayer() + self.aux_hidden_state_layers: tuple[int] = tuple() + self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) @@ -344,7 +346,8 @@ def forward( positions: torch.Tensor, intermediate_tensors: 
Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, + list[torch.Tensor]]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -356,7 +359,11 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + aux_hidden_states = [] + for idx, layer in enumerate( + self.layers[self.start_layer:self.end_layer]): + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: @@ -366,6 +373,9 @@ def forward( }) hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states return hidden_states def load_weights(self, weights: Iterable[Tuple[str, @@ -515,11 +525,16 @@ def __init__(self, else: self.lm_head = PPMissingLayer() - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def _init_model(self, vllm_config: VllmConfig, prefix: str = "", @@ -551,11 +566,6 @@ def compute_logits( sampling_metadata) return logits - def sample(self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index e5d1a671f5d6..0fdc30f36f9b 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -273,8 +273,8 @@ def __init__( cache_config=cache_config, prefix=f"{prefix}.self_attn", ) - is_moe_layer = (self.layer_idx + - 1) % config.interleave_moe_layer_step == 0 + is_moe_layer = config.interleave_moe_layer_step > 0 and ( + self.layer_idx + 1) % config.interleave_moe_layer_step == 0 if is_moe_layer: self.feed_forward = Llama4MoE( config=config, diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 28ad6128c4f1..56e53ac2b815 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -70,7 +70,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: input_embeds = self.embed_tokens(input_ids) hidden_states = self.fc( torch.cat((input_embeds, hidden_states), dim=-1)) @@ -82,7 +82,8 @@ def forward( hidden_states, residual, ) - return hidden_states + residual + hidden_states = hidden_states + residual + return hidden_states, hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: @@ -132,7 +133,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: return self.model(input_ids, positions, hidden_states) def load_weights(self, 
weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py new file mode 100644 index 000000000000..0b18e4a8fe2f --- /dev/null +++ b/vllm/model_executor/models/llama_eagle3.py @@ -0,0 +1,232 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterable, Optional, Set, Tuple + +import torch +import torch.nn as nn +from transformers import LlamaConfig + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import QKVParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import (LlamaDecoderLayer, + LlamaForCausalLM) +from vllm.v1.sample.metadata import SamplingMetadata + +from .utils import AutoWeightsLoader, maybe_prefix + +logger = init_logger(__name__) + + +class LlamaDecoderLayer(LlamaDecoderLayer): + + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, quant_config=quant_config, prefix=prefix) + + # override qkv + self.self_attn.qkv_proj = QKVParallelLinear( + 2 * self.hidden_size, + self.self_attn.head_dim, + self.self_attn.total_num_heads, + self.self_attn.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "qkv_proj"), + ) + + self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + embeds: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + + residual = hidden_states + embeds = self.input_layernorm(embeds) + hidden_states = self.hidden_norm(hidden_states) + + hidden_states = torch.cat([embeds, hidden_states], dim=-1) + # Self Attention + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + # Fully Connected + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +class LlamaModel(nn.Module): + + def __init__( + self, + *, + model_config: ModelConfig, + start_layer_id: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + self.config = model_config.hf_config + self.vocab_size = self.config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + self.layers = nn.ModuleList([ + LlamaDecoderLayer( + self.config, + prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), + ) + ]) + if hasattr(self.config, "target_hidden_size"): + self.fc = torch.nn.Linear(self.config.target_hidden_size * 3, + self.config.hidden_size, + bias=False) + else: + self.fc = torch.nn.Linear(self.config.hidden_size * 3, + self.config.hidden_size, + bias=False) + self.norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> 
tuple[torch.Tensor, torch.Tensor]: + input_embeds = self.embed_tokens(input_ids) + if (hidden_states.shape[-1] != input_embeds.shape[-1]): + hidden_states = self.fc(hidden_states) + + residual = None + hidden_states, residual = self.layers[0]( + positions, + input_embeds, + hidden_states, + residual, + ) + + hidden_states, hidden_prenorm = self.norm(hidden_states, residual) + return hidden_states, hidden_prenorm + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if 'midlayer.' in name: + name = name.replace('midlayer.', 'layers.0.') + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Eagle3LlamaForCausalLM(LlamaForCausalLM): + + def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0): + nn.Module.__init__(self) + self.config = model_config.hf_config + self.model = LlamaModel(model_config=model_config, + start_layer_id=start_layer_id, + prefix="model") + + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.lm_head = ParallelLMHead( + self.config.draft_vocab_size, + self.config.hidden_size, + org_num_embeddings=self.config.draft_vocab_size, + padding_size=(DEFAULT_VOCAB_PADDING_SIZE), + prefix="") + self.logits_processor = LogitsProcessor(self.config.draft_vocab_size, + scale=logit_scale) + self.draft_id_to_target_id = nn.Parameter( + torch.zeros((self.config.draft_vocab_size), + dtype=torch.long).type(torch.LongTensor), + requires_grad=False, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.model(input_ids, positions, hidden_states) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + base = torch.arange(self.config.draft_vocab_size, device=logits.device) + targets = base + self.draft_id_to_target_id + logits_new = logits.new_full(( + logits.shape[0], + self.config.vocab_size, + ), float('-inf')) + logits_new[:, targets] = logits + return logits_new + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=None, + ) + + model_weights = {} + for name, loaded_weight in weights: + if "t2d" in name: + continue + if "d2t" in name: + name = name.replace("d2t", "draft_id_to_target_id") + elif "lm_head" not in name: + name = "model." 
+ name + model_weights[name] = loaded_weight + + return loader.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index fbd212d17004..8862b2679f93 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -2,7 +2,6 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, TypeVar, Union, cast) @@ -23,7 +22,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -546,13 +544,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -763,13 +754,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 9c4d0e1fc275..c646c0f03d1e 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from abc import abstractmethod -from functools import cached_property from typing import (Final, Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, TypedDict, TypeVar, Union) @@ -13,7 +12,6 @@ from typing_extensions import NotRequired from vllm.config import VllmConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig @@ -250,13 +248,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) @@ -585,13 +576,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) 
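Before moving on to the next files: the part of the new llama_eagle3.py above that is easiest to lose in the flattened hunks is Eagle3LlamaForCausalLM.compute_logits, which produces logits over the smaller draft vocabulary and then scatters them into a full-size target-vocabulary tensor using the "d2t" offset table loaded as draft_id_to_target_id. A toy-sized sketch of that remapping follows; the vocabulary sizes and offsets below are made up purely for illustration, while the real code takes them from the HF config and the EAGLE-3 checkpoint.

import torch

# Toy sizes for illustration only.
draft_vocab_size, target_vocab_size = 4, 10
draft_id_to_target_id = torch.tensor([0, 3, 4, 5])  # offset of each draft id

draft_logits = torch.randn(2, draft_vocab_size)     # [num_tokens, draft_vocab]

# base + offset maps each draft-vocab index to its target-vocab index.
base = torch.arange(draft_vocab_size)
targets = base + draft_id_to_target_id              # tensor([0, 4, 6, 8])

# Positions the draft head cannot produce stay at -inf, so the engine can
# still sample over the full target vocabulary as usual.
full_logits = draft_logits.new_full((draft_logits.shape[0], target_vocab_size),
                                    float("-inf"))
full_logits[:, targets] = draft_logits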
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 0221c6b237cb..a5ff189cfdb5 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -2,7 +2,6 @@ import math from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import torch @@ -12,7 +11,6 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -301,13 +299,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.make_empty_intermediate_tensors = ( self.language_model.model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_video_pixel_values( self, data: Union[torch.Tensor, List[torch.Tensor]] ) -> Union[torch.Tensor, List[torch.Tensor]]: @@ -469,13 +460,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 60d32c924694..5c2b388e403d 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -2,7 +2,6 @@ import math from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple, TypedDict, Union) @@ -16,7 +15,6 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -455,13 +453,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.make_empty_intermediate_tensors = ( self.language_model.model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) @@ -583,21 +574,21 @@ def _parse_and_validate_video_input( ) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: - modalities = {} + mm_input_by_modality = {} # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
for input_key in kwargs: - if input_key in ("pixel_values", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("pixel_values_videos", - "video_embeds") and "videos" not in modalities: - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) + if input_key in ("pixel_values", "image_embeds" + ) and "image" not in mm_input_by_modality: + mm_input_by_modality[ + "image"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos", "video_embeds" + ) and "video" not in mm_input_by_modality: + mm_input_by_modality[ + "video"] = self._parse_and_validate_video_input(**kwargs) - return modalities + return mm_input_by_modality def _select_image_features(self, image_features: torch.Tensor, *, strategy: str) -> torch.Tensor: @@ -848,8 +839,9 @@ def get_language_model(self) -> torch.nn.Module: def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - modalities = self._parse_and_validate_multimodal_inputs(**kwargs) - if not modalities: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs( + **kwargs) + if not mm_input_by_modality: return None # The result multimodal_embeddings is tuple of tensors, with each @@ -858,14 +850,13 @@ def get_multimodal_embeddings( # NOTE: It is important to iterate over the keys in this dictionary # to preserve the order of the modalities. - for modality in modalities: - if modality == "images": - image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + vision_embeddings = self._process_image_input(multimodal_input) multimodal_embeddings += tuple(vision_embeddings) - if modality == "videos": - video_input = modalities["videos"] - video_embeddings = self._process_video_pixels(video_input) + if modality == "video": + video_embeddings = self._process_video_pixels(multimodal_input) multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings @@ -957,13 +948,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 7a525ad8e494..af78ece66bbe 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -14,7 +14,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -27,7 +26,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -154,6 +153,26 @@ def forward( 
return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "A_log" in name: + name = name.replace("A_log", "A") + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, SupportsV0Only): @@ -193,7 +212,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.backbone.make_empty_intermediate_tensors) @@ -247,30 +265,7 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "A_log" in name: - name = name.replace("A_log", "A") - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 526dec46ff29..78303733f6bb 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -19,7 +19,6 @@ MambaMixer2, extra_groups_for_head_shards) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -208,7 +207,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.backbone.make_empty_intermediate_tensors) @@ -282,14 +280,6 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index cf03396a9ca9..866dc3f466e7 100644 --- 
a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -45,7 +45,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -553,7 +552,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -584,14 +582,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 1a91cf9bab47..65a26eadd5c8 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -25,7 +25,7 @@ import math from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property, partial +from functools import partial from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, Union) @@ -40,7 +40,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, get_2d_sincos_pos_embed) -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM @@ -758,13 +757,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.llm.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.llm, "sampler"): - return self.llm.sampler - - return get_sampler() - def _parse_and_validate_vision_input( self, modality: str, @@ -946,14 +938,6 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.llm.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 7562aa678d5a..74be08159cd8 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -33,7 +33,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler from 
vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -994,7 +993,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.config.vocab_size) - self.sampler = Sampler() else: self.lm_head = PPMissingLayer() @@ -1030,16 +1028,6 @@ def compute_logits(self, hidden_states: torch.Tensor, return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ): - - next_tokens = self.sampler(logits, sampling_metadata) - - return next_tokens - def make_empty_intermediate_tensors( self, batch_size: int, dtype: torch.dtype, device: torch.device) -> IntermediateTensors: diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 8b1a1d68fc3f..f8e9e3181367 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -2,7 +2,6 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, TypeVar, Union) @@ -19,7 +18,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -274,6 +272,9 @@ def _get_prompt_updates( vision_config = hf_config.vision_config assert isinstance(vision_config, PixtralVisionConfig) + # Need to sneak in spatial_merge_size for Mistral3 + vision_config.spatial_merge_size = getattr(hf_config, + "spatial_merge_size", 1) encoder_info = PixtralHFEncoderInfo(vision_config) def get_replacement(item_idx: int): @@ -435,13 +436,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -598,13 +592,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index b0ac99f21ead..1513c8dad097 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -40,7 +40,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from 
vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -454,7 +453,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -481,14 +479,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"]) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 96eb925cf894..7c022a5b8f68 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -42,7 +42,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -372,7 +371,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -399,14 +397,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 7bfb3ada6bb4..0c1d61c01f91 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -47,7 +47,6 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -1211,7 +1210,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.logits_processor = LogitsProcessor(config.output_hidden_states, config.text_config.vocab_size) - self.sampler = get_sampler() def compute_logits( self, @@ -1222,14 +1220,6 @@ def compute_logits( hidden_states, sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = 
self.sampler(logits, sampling_metadata) - return next_tokens - def unpack_data(self, image_data: Union[List[torch.Tensor], torch.Tensor], padding_value=0) -> torch.Tensor: diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 0966f546ddf9..56a7f02c4159 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -17,7 +17,6 @@ # limitations under the License. import math from collections.abc import Iterable, Mapping -from functools import cached_property from itertools import tee from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union @@ -38,7 +37,6 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.loader import _initialize_model from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -672,9 +670,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config, None, prefix=maybe_prefix(prefix, "multi_modal_projector")) - self.language_model = _initialize_model( - vllm_config=vllm_config.with_hf_config(config.text_config), + vllm_config=vllm_config.with_hf_config(config.text_config, + ["LlamaForCausalLM"]), prefix=maybe_prefix(prefix, "language_model"), model_class=Llama4ForCausalLM, ) @@ -682,13 +680,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]: # num_images, 1, num_chunks, channel, image_size, image_size @@ -785,10 +776,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample(self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def separate_weights( self, weights: Iterable[Tuple[str, torch.Tensor]], @@ -824,7 +811,7 @@ def load_weights(self, weights: Iterable[Tuple[str, # language_model is an Llama4ForCausalLM instance. We load it's # using llama4's load_weights routine. 
language_model_weights, other_weights = self.separate_weights( - weights, prefix="language_model.model.") + weights, prefix="language_model.") loader = AutoWeightsLoader(self) loaded_language_model_params = loader.load_weights( language_model_weights) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py new file mode 100644 index 000000000000..2190241f0ba3 --- /dev/null +++ b/vllm/model_executor/models/modernbert.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Iterable, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import ModernBertConfig + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.linear import (QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.pooler import CrossEncodingPooler +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + +from .interfaces import SupportsCrossEncoding +from .utils import WeightsMapper, maybe_prefix + + +class ModernBertEmbeddings(nn.Module): + + def __init__(self, config: ModernBertConfig): + + super().__init__() + self.config = config + self.tok_embeddings = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.norm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps, + bias=config.norm_bias) + + def forward( + self, + input_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if inputs_embeds: + return self.norm(inputs_embeds) + else: + inputs_embeds = self.tok_embeddings(input_ids) + embeddings = self.norm(inputs_embeds) + return embeddings + + +class ModernBertRotaryEmbedding(RotaryEmbedding): + + def __init__(self, config: ModernBertConfig, head_size: int, dim: int, + base: float): + super().__init__( + head_size=head_size, + rotary_dim=dim, + max_position_embeddings=config.max_position_embeddings, + base=base, + is_neox_style=True, + dtype=torch.float16) + self.config = config + + +class ModernBertAttention(nn.Module): + + def __init__(self, + config: ModernBertConfig, + layer_id: Optional[int] = None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.layer_id = layer_id + self.deterministic_flash_attn = config.deterministic_flash_attn + self.num_heads = config.num_attention_heads + assert self.num_heads % tp_size == 0 + self.head_dim = config.hidden_size // config.num_attention_heads + self.all_head_size = self.head_dim * self.num_heads + self.scaling = self.head_dim**-0.5 + self.Wqkv = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.num_heads, + bias=config.attention_bias, + ) + + if layer_id % config.global_attn_every_n_layers != 0: + self.local_attention = (config.local_attention // 2, + config.local_attention // 2) + else: + self.local_attention = (-1, -1) + + rope_theta = config.global_rope_theta + if self.local_attention != ( + -1, -1) and config.local_rope_theta is not None: + rope_theta = config.local_rope_theta + 
self.rotary_emb = ModernBertRotaryEmbedding(config=config, + head_size=self.head_dim, + dim=self.head_dim, + base=rope_theta) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + prefix=f"{layer_id}.attn", + attn_type=AttentionType.ENCODER_ONLY) + self.Wo = RowParallelLinear(config.hidden_size, + config.hidden_size, + bias=config.attention_bias) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + qkv, _ = self.Wqkv(hidden_states) + q, k, v = qkv.split([self.all_head_size] * 3, dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + attn_outputs = self.attn(q, k, v) + hidden_states = attn_outputs + hidden_states, _ = self.Wo(hidden_states) + return hidden_states + + +class ModernBertMLP(nn.Module): + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.config = config + self.Wi = nn.Linear(config.hidden_size, + int(config.intermediate_size) * 2, + bias=config.mlp_bias) + self.act = nn.GELU() + self.Wo = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=config.mlp_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input, gate = self.Wi(hidden_states).chunk(2, dim=-1) + return self.Wo(self.act(input) * gate)[0] + + +class ModernBertLayer(nn.Module): + + def __init__(self, + config: ModernBertConfig, + prefix: str = "", + layer_id: Optional[int] = None): + super().__init__() + self.config = config + if layer_id == 0: + self.attn_norm = nn.Identity() + else: + self.attn_norm = nn.LayerNorm(config.hidden_size, + eps=config.norm_eps, + bias=config.norm_bias) + self.attn = ModernBertAttention(config=config, layer_id=layer_id) + self.mlp_norm = nn.LayerNorm(config.hidden_size, + eps=config.norm_eps, + bias=config.norm_bias) + self.mlp = ModernBertMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + ): + attn_outputs = self.attn(self.attn_norm(hidden_states), + position_ids=position_ids) + hidden_states = hidden_states + attn_outputs + mlp_output = self.mlp(self.mlp_norm(hidden_states)) + hidden_states = hidden_states + mlp_output + return hidden_states + + +class ModernBertEncoderLayer(nn.Module): + + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.layers = nn.ModuleList([ + ModernBertLayer(config=config, layer_id=layer_id) + for layer_id in range(config.num_hidden_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + for i, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, position_ids) + return hidden_states + + +@support_torch_compile +class ModernBertModel(nn.Module): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={"layers.": "encoder_layer.layers."}) + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + config = vllm_config.model_config.hf_config + self.config = config + self.embeddings = ModernBertEmbeddings(config) + self.encoder_layer = ModernBertEncoderLayer(vllm_config) + self.final_norm = nn.LayerNorm(config.hidden_size, + eps=config.norm_eps, + bias=config.norm_bias) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + weights = self.hf_to_vllm_mapper.apply(weights) + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, 
loaded_weight in weights: + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embeddings(input_ids=input_ids, + inputs_embeds=inputs_embeds) + + outputs = self.encoder_layer( + hidden_states=hidden_states, + position_ids=position_ids, + ) + norm_outputs = self.final_norm(outputs) + return norm_outputs + + +class ModernBertPooler(nn.Module): + + def __init__(self, config: ModernBertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size, + config.classifier_bias) + self.act = nn.GELU() + self.norm = nn.LayerNorm(config.hidden_size, + eps=config.norm_eps, + bias=config.norm_bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + pooled_output = hidden_states + pooled_output = pooled_output.mean(dim=0, keepdim=False) + pooled_output = self.norm(self.act(self.dense(pooled_output))) + return pooled_output + + +class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.config = config + self.model = ModernBertModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "modernbert")) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self._pooler = CrossEncodingPooler(config, self.classifier, + ModernBertPooler(config)) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + self_weights = [] + + def weight_filter(): + for name, weight in weights: + if name.startswith("model."): + yield name[len("model."):], weight + else: + self_weights.append((name, weight)) + + self.model.load_weights(weight_filter()) + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in self_weights: + if name.startswith("classifier"): + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + if name.startswith("head"): + param = params_dict["_pooler.pooler." 
+ name[len("head") + 1:]] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def forward( + self, + input_ids: Optional[torch.LongTensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=positions, + ) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index d75845b45e73..46147a333b06 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -35,7 +35,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -1394,7 +1393,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.embedding_size or config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -1506,7 +1504,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> torch.Tensor: if intermediate_tensors is not None: inputs_embeds = None @@ -1532,14 +1530,6 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py new file mode 100644 index 000000000000..c367d90f847b --- /dev/null +++ b/vllm/model_executor/models/moonvit.py @@ -0,0 +1,628 @@ +# SPDX-License-Identifier: Apache-2.0 +# ruff: noqa: E501 +# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py +# This file is meant to be used in kimi_vl.py only +# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved. +# +# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL. +# +# Licensing Information: +# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0. +# - Other parts of the code are licensed under the MIT License. +# +# Apache License, Version 2.0: +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License: +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import math +from copy import deepcopy +from functools import cached_property +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.activations import ACT2FN, PytorchGELUTanh +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import is_flash_attn_2_available + +from vllm.transformers_utils.configs.moonvit import MoonViTConfig + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func +else: + flash_attn_varlen_func = None + + +def multihead_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q_cu_seqlens: Optional[torch.Tensor] = None, + k_cu_seqlens: Optional[torch.Tensor] = None, +): + """Multi-head attention using flash attention 2. + + Args: + q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q. + The first element should be 0 and the last element should be q.shape[0]. + k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k. + The first element should be 0 and the last element should be k.shape[0]. 
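+ Example (an illustrative packing of two self-attention sequences of
+ lengths 3 and 5): q_cu_seqlens = k_cu_seqlens = torch.tensor([0, 3, 8],
+ dtype=torch.int32), with q, k, v each of shape (8, num_heads, head_dim).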
+ + Returns: + output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing, + where dim = num_heads * head_dim + """ + # Unified format legal check + assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims" + assert q_cu_seqlens[-1] == q.shape[ + 0], "q_cu_seqlens must sum to q.shape[0]" + assert (k_cu_seqlens[-1] == k.shape[0] == + v.shape[0]), "k_cu_seqlens must sum to k.shape[0]" + assert q.dtype in [ + torch.bfloat16, + torch.float16, + ], f"unsupported dtype {q.dtype} for multihead attn" + + max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item() + max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item() + attn_out = flash_attn_varlen_func( + q, + k, + v, + q_cu_seqlens, + k_cu_seqlens, + max_seqlen_q, + max_seqlen_k, + causal=False, + ) + attn_out = attn_out.flatten(start_dim=-2) + + return attn_out + + +def sdpa_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q_cu_seqlens: Optional[torch.Tensor] = None, + k_cu_seqlens: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """SDPA attention. + + Args: + q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + """ + seq_length = q.shape[0] + attention_mask = torch.zeros([1, seq_length, seq_length], + device=q.device, + dtype=torch.bool) + for i in range(1, len(q_cu_seqlens)): + attention_mask[ + ..., + q_cu_seqlens[i - 1]:q_cu_seqlens[i], + q_cu_seqlens[i - 1]:q_cu_seqlens[i], + ] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = F.scaled_dot_product_attention(q, + k, + v, + attention_mask, + dropout_p=0.0) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + return attn_output + + +VL_VISION_ATTENTION_FUNCTIONS = { + "flash_attention_2": multihead_attention, + "sdpa": sdpa_attention, +} + + +def _apply_rope_input_validation(x, freqs_cis): + assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape) + assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape) + assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape) + assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype + + +def apply_rope(xq: torch.Tensor, xk: torch.Tensor, + freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: (The leading dimensions of all inputs should be the same) + xq: query, tensor of shape (..., num_heads, head_dim) + xk: key, tensor of shape (..., num_heads, head_dim) + freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid. 
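+ Example (illustrative, head_dim = 4): the pair (xq[..., 0], xq[..., 1]) is
+ viewed as the complex number xq[..., 0] + 1j * xq[..., 1] and multiplied by
+ freqs_cis[..., 0] = cos(t0) + 1j * sin(t0), i.e. rotated by angle t0;
+ likewise (xq[..., 2], xq[..., 3]) is rotated by the angle in freqs_cis[..., 1].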
+ Returns: + xq_out, xk_out: tensors of shape (..., num_heads, head_dim) + """ + _apply_rope_input_validation(xq, freqs_cis) + _apply_rope_input_validation(xk, freqs_cis) + + freqs_cis = freqs_cis.unsqueeze(-2) # ..., 1, head_dim/2 + # ..., num_heads, head_dim/2 + xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().view(*xq.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten( + -2) # ..., num_heads, head_dim + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten( + -2) # ..., num_heads, head_dim + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class Learnable2DInterpPosEmb(nn.Module): + + def __init__(self, + height: int, + width: int, + dim: int, + interpolation_mode: str = "bicubic") -> None: + super().__init__() + self.height = height + self.width = width + self.interpolation_mode = interpolation_mode + self.weight = nn.Parameter(torch.empty(height, width, dim)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.normal_(self.weight) + + def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor: + pos_embs = [] + for shape in grid_hws.tolist(): + if shape == self.weight.shape[:-1]: + pos_embs.append(self.weight.flatten(end_dim=1)) + else: + pos_embs.append( + F.interpolate( + self.weight.permute((2, 0, 1)).unsqueeze(0), + size=shape, + mode=self.interpolation_mode, + ).squeeze(0).permute((1, 2, 0)).flatten(end_dim=1)) + out = x + torch.cat(pos_embs) + return out + + +class MoonVisionPatchEmbed(nn.Module): + + def __init__( + self, + out_dim: int, + in_dim: int = 3, + patch_size: Union[int, Tuple[int, int]] = (14, 14), + pos_emb_height: int = 14, + pos_emb_width: int = 14, + ): + super().__init__() + assert isinstance( + patch_size, + (int, Sequence)), f"Invalid patch_size type: {type(patch_size)}" + if isinstance(patch_size, int): + patch_size = (patch_size, patch_size) + assert (len(patch_size) == 2 + ), f"Expected patch_size to be a tuple of 2, got {patch_size}" + self.patch_size = patch_size + + self.proj = nn.Conv2d(in_dim, + out_dim, + kernel_size=patch_size, + stride=patch_size) + + self.pos_emb = Learnable2DInterpPosEmb(height=pos_emb_height, + width=pos_emb_width, + dim=out_dim) + + def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor: + """ + Args: + x (L, Channels): input tensor + grid_hw (N, 2): grid height and width + + Returns: + (L, Cout) tensor + """ + x = self.proj(x).view(x.size(0), -1) + # apply positional embedding + x = self.pos_emb(x, grid_hw) + return x + + +class Rope2DPosEmb(nn.Module): + """2D rotary position embedding with multi-resolution support. + + This class is intended to be used in the following way: + 1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis. + 2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration. + 3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation. + The rope is shared across all attention layers and all heads. 
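+ Illustrative sketch of this call pattern (head_dim and grid_hws are
+ placeholders for the per-head dimension and the (h, w) grid sizes):
+ rope_2d = Rope2DPosEmb(dim=head_dim, max_height=512, max_width=512)
+ freqs_cis = rope_2d.get_freqs_cis_by_seqlens(grid_hws) # (sum(h*w), dim//2)
+ xq, xk = apply_rope(xq, xk, freqs_cis) # just before each attention op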
+ + Refs: + - RoFormer: https://arxiv.org/abs/2104.09864 + - VisionLLaMA: https://arxiv.org/abs/2403.00522 + - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py + + Args: + dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed) + max_height (int): the maximum height of the 2D grid + max_width (int): the maximum width of the 2D grid + theta_base (float): the base of the theta + device (str): the device to store the precomputed cis + """ + + def __init__(self, + dim: int, + max_height: int, + max_width: int, + theta_base=10000, + device="cuda"): + super().__init__() + self.dim = dim + assert self.dim % 4 == 0, "dim must be divisible by 4" + self.max_height = max_height + self.max_width = max_width + self.theta_base = theta_base + self.device = device + + def extra_repr(self): + return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}" + + @cached_property + def precomputed_freqs_cis(self) -> torch.Tensor: + """Calculate the cis(freqs) for each position in the 2D grid. + + Return: complex tensor of shape (max_height, max_width, dim//2) and value: + height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim)) + weight axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4)) + note: `cis` is a mathematical notation defined by cis x = cos x + i sin x, + """ + N = self.max_height * self.max_width + flat_pos = torch.arange(0, N).float().to(self.device) + x_pos = flat_pos % self.max_width + y_pos = flat_pos // self.max_width + dim_range = (torch.arange(0, self.dim, + 4)[:(self.dim // 4)].float().to(self.device) + ) # C/4 + freqs = 1.0 / (self.theta_base**(dim_range / self.dim)) + x_freqs = torch.outer(x_pos, freqs).float() # N, C/4 + y_freqs = torch.outer(y_pos, freqs).float() # N, C/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) # N, C/4 + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) # N, C/4 + # N, C/4, 2 + freqs_cis = torch.cat( + [x_cis.unsqueeze(dim=-1), + y_cis.unsqueeze(dim=-1)], dim=-1) + # max_height, max_width, C/2 + freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1) + return freqs_cis + + def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor: + """ + Args: + grid_hws (torch.Tensor): containing list of (height, width) or (t, height, width) tuples. + Returns: + freqs_cis: tensor of shape (sum(t * height * width), dim//2) + """ + shapes = grid_hws.tolist() + assert all(1 <= h <= self.max_height and 1 <= w <= self.max_width + for h, w in shapes), ( + shapes, + self.max_height, + self.max_width, + ) + freqs_cis = torch.cat( + [ + self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2) + for h, w in shapes + ], + dim=0, + ) + return freqs_cis + + def get_freqs_cis_by_idx(self, pos_idx: torch.Tensor, + pos_idx_mask: torch.Tensor) -> torch.Tensor: + """ + Args: + pos_idx: tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token. + pos_idx_mask: a mask of shape (...), the leading dimensions should be the same as pos_idx. + Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask with be ones. 
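+ Example (illustrative; rope_2d is an existing Rope2DPosEmb instance):
+ pos_idx = torch.tensor([[0, 0], [0, 1], [1, 0]]) # (h, w) index per token
+ pos_idx_mask = torch.tensor([True, True, False])
+ freqs_cis = rope_2d.get_freqs_cis_by_idx(pos_idx, pos_idx_mask) # (3, dim//2)
+ # freqs_cis[2] is all ones, so the third (masked-out) token is left unrotated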
+ Return: + freqs_cis: tensor of shape (..., dim//2) + """ + assert (pos_idx.shape[:-1] == pos_idx_mask.shape + and pos_idx.shape[-1] == 2 and pos_idx.ndim + == pos_idx_mask.ndim + 1), (pos_idx.shape, pos_idx_mask.shape) + assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype + + shp = pos_idx_mask.shape + (self.dim // 2, ) # ..., head_dim/2 + freqs_cis = torch.ones(shp, dtype=torch.complex64, + device=self.device) # ..., head_dim/2 + freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[pos_idx[ + ..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]] + return freqs_cis + + +class MLP2(nn.Module): + """ + Args: + dims: [in_dim, hidden_dim, out_dim] + bias: whether to use bias in linear layer. + """ + + def __init__(self, dims: list[int], activation, bias=True): + super().__init__() + assert len(dims) == 3 + self.fc0 = nn.Linear(dims[0], dims[1], bias=bias) + self.fc1 = nn.Linear(dims[1], dims[2], bias=bias) + self.activation = activation + for m in [self.fc0, self.fc1]: + nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features)) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc0(x) + x = self.activation(x) + return self.fc1(x) + + +class MoonVitEncoderLayer(nn.Module): + + def __init__( + self, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + *, + attn_implementation: str = "sdpa", + activation=F.gelu, + attn_bias: bool = False, + ): + super().__init__() + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads + self.attn_implementation = attn_implementation + # use fa2 in vllm by default + if is_flash_attn_2_available(): + self.attn_implementation = "flash_attention_2" + + self.norm0 = nn.LayerNorm(hidden_dim) + self.norm1 = nn.LayerNorm(hidden_dim) + self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation) + self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias) + self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias) + + def attention_qkvpacked( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rope_freqs_cis: Optional[torch.Tensor] = None, + ): + """ + Args: + x (torch.Tensor): (batch_size, seqlen, hidden_dim) + cu_seqlens (torch.Tensor): + """ + xqkv = self.wqkv(x) + + qkv_shape = xqkv.size()[:-1] + ( + 3, + self.num_heads, + self.hidden_size_per_attention_head, + ) + # xqkv: (batch_size, seqlen, 3, nheads, headdim) + xqkv = xqkv.view(*qkv_shape) + xq, xk, xv = torch.unbind(xqkv, dim=-3) + + xq, xk = apply_rope(xq, xk, rope_freqs_cis) + + attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation] + attn_out = attn_func(xq, + xk, + xv, + q_cu_seqlens=cu_seqlens, + k_cu_seqlens=cu_seqlens) + + attn_out = self.wo(attn_out) + return attn_out + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rope_freqs_cis: Union[torch.Tensor, None] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states: non-packed (B, N, D) or packed (L, D). 
if non-packed, seqlens should be None, if packed, seqlens should be set + + Returns: + output: same shape of input, non-packed (B, N, D) for non-packed input, (L, D) for packed input + """ + residual = hidden_states + hidden_states = self.norm0(hidden_states) + attn_out = self.attention_qkvpacked(hidden_states, + cu_seqlens, + rope_freqs_cis=rope_freqs_cis) + hidden_states = residual + attn_out + + residual = hidden_states + hidden_states = self.mlp(self.norm1(hidden_states)) + hidden_states = residual + hidden_states + return hidden_states + + +class MoonVitEncoder(nn.Module): + + def __init__( + self, + hidden_dim: int, + num_layers: int, + block_cfg: dict, + ) -> None: + super().__init__() + + self.rope_2d = Rope2DPosEmb( + block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512) + self.blocks = nn.ModuleList( + [MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)]) + self.final_layernorm = nn.LayerNorm(hidden_dim) + + def forward(self, hidden_states: torch.Tensor, + grid_hw: torch.Tensor) -> torch.Tensor: + rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens( + grid_hws=grid_hw) + + lengths = torch.cat(( + torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype), + grid_hw[:, 0] * grid_hw[:, 1], + )) + cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32) + + for _, block in enumerate(self.blocks): + hidden_states = block(hidden_states, + cu_seqlens, + rope_freqs_cis=rope_freqs_cis) + + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +def patch_merger( + x: torch.Tensor, + grid_hw: torch.Tensor, + merge_kernel_size: list[int, int] = (2, 2), +) -> List[torch.Tensor]: + d_model = x.size(-1) + + outputs = [] + pre_sum = 0 + for x_shape in grid_hw.tolist(): + height, width = x_shape[0], x_shape[1] + # Get the current sequence + seq = x[pre_sum:pre_sum + height * width] + # Reshape along self.merge_kernel_size and concat to the last dimension + kernel_height, kernel_width = merge_kernel_size + new_height, new_width = height // kernel_height, width // kernel_width + reshaped_seq = seq.view(new_height, kernel_height, new_width, + kernel_width, d_model) + reshaped_seq = reshaped_seq.permute(0, 2, 1, 3, 4).contiguous() + padded_seq = reshaped_seq.view(new_height * new_width, + kernel_height * kernel_width, -1) + outputs.append(padded_seq) + pre_sum += height * width + + return outputs + + +class MoonVitVLProjector(nn.Module): + + def __init__( + self, + in_channels: int, + merge_kernel_size: list[int, int], + hidden_act: str = "gelu", + ln_eps: float = 1e-5, + out_dim: int = 4096, + ): + super().__init__() + self.hidden_size = in_channels * merge_kernel_size[ + 0] * merge_kernel_size[1] + + self.pre_norm = nn.nn.LayerNorm(in_channels, eps=ln_eps) + self.linear_1 = nn.Linear(self.hidden_size, + self.hidden_size, + bias=True) + self.act = ACT2FN[hidden_act] + self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class MoonVitPretrainedModel(PreTrainedModel): + config_class = MoonViTConfig + model_type = "moonvit" + _no_split_modules = ["PackingTransformer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + + def __init__(self, config: MoonViTConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + config 
= deepcopy(config) + self.merge_kernel_size = config.merge_kernel_size + self.patch_size = config.patch_size + self.patch_embed = MoonVisionPatchEmbed( + out_dim=config.hidden_size, + patch_size=config.patch_size, + pos_emb_height=config.init_pos_emb_height, + pos_emb_width=config.init_pos_emb_width, + ) + + self.encoder = MoonVitEncoder( + hidden_dim=config.hidden_size, + num_layers=config.num_hidden_layers, + block_cfg={ + "num_heads": config.num_attention_heads, + "hidden_dim": config.hidden_size, + "mlp_dim": config.intermediate_size, + "activation": PytorchGELUTanh(), + "attn_bias": True, + "attn_implementation": config._attn_implementation, + }, + ) + + def forward(self, pixel_values: torch.Tensor, + grid_hw: torch.Tensor) -> torch.Tensor: + """ + Args: + pixel_values (torch.Tensor): The input pixel values. + grid_hw (torch.Tensor): The grid height and width. + + Returns: + torch.Tensor: The output tokens. + """ + hidden_states = self.patch_embed(pixel_values, grid_hw) + hidden_states = self.encoder(hidden_states, grid_hw) + hidden_states = patch_merger(hidden_states, + grid_hw, + merge_kernel_size=self.merge_kernel_size) + return hidden_states diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index b30f3ee37997..77bd794058cd 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -18,7 +18,6 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -298,7 +297,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "transformer")) self.lm_head = self.transformer.wte self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -325,14 +323,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 0ea296b2f93d..5208c0796c8d 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -38,7 +38,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -416,8 +415,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = PPMissingLayer() - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -444,14 +441,6 @@ def compute_logits( sampling_metadata) return 
logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index 5c9b04cab180..264999496876 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -408,8 +407,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = PPMissingLayer() - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -439,11 +436,6 @@ def compute_logits( sampling_metadata) return logits - def sample(self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 4a341c97d6cd..0781ca168f84 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -39,7 +39,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -309,7 +308,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=quant_config, ) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -340,14 +338,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index f9427cdadf7a..44beae5726dc 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -28,6 +28,7 @@ import torch from torch import nn +from transformers import Olmo2Config from vllm.attention import Attention from vllm.config import VllmConfig @@ -42,7 +43,6 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import 
get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -52,7 +52,6 @@ make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs.olmo2 import Olmo2Config class Olmo2Attention(nn.Module): @@ -339,7 +338,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -367,14 +365,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 6cf3f1f82645..9bed29d0132f 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -31,7 +31,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -39,7 +38,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -255,7 +254,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.vocab_size = config.vocab_size - + self.config = config self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -308,56 +307,6 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - -class OlmoeForCausalLM(nn.Module, SupportsPP): - - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = OlmoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() - - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = 
None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ @@ -380,8 +329,6 @@ def load_weights(self, weights: Iterable[Tuple[str, params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: @@ -453,3 +400,50 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params + + +class OlmoeForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = OlmoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["rotary_emb.inv_freq"], + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 4a12f36d90e8..d258eddae25d 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -35,7 +35,6 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -43,7 +42,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import 
SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -313,6 +312,43 @@ def forward( intermediate_tensors, inputs_embeds=inputs_embeds) + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class OPTForCausalLM(nn.Module, SupportsPP): packed_modules_mapping = { @@ -320,6 +356,10 @@ class OPTForCausalLM(nn.Module, SupportsPP): "gate_up_proj": ["gate_proj", "up_proj"] } + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ + "decoder.": "model.decoder.", + }) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -334,7 +374,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = ParallelLMHead(config.vocab_size, config.word_embed_proj_dim) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -361,52 +400,11 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "lm_head.weight" in name and self.config.tie_word_embeddings: - continue - if name.startswith("decoder."): - name = "model." + name - - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head.weight"] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 0b42666e02d6..8d9c000750d7 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -22,7 +22,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -30,7 +29,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -260,6 +259,45 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class OrionForCausalLM(nn.Module, SupportsPP): @@ -277,7 +315,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -304,56 +341,16 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + loader = AutoWeightsLoader( + self, + skip_prefixes=([ + "rotary_emb.inv_freq", # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + "rotary_emb.cos_cached", + "rotary_emb.sin_cached" + ]), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 6c1bd499f639..8699ae52622d 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -8,7 +8,6 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -260,10 +259,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @property - def sampler(self): - return self.language_model.sampler - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -369,7 +364,7 @@ def forward(self, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: + **kwargs: object) -> IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -396,13 +391,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index db8d170a8c91..eacf02433b57 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -38,7 +38,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -46,7 +45,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -116,9 +115,10 @@ def __init__(self, self.rotary_emb = get_rope( self.head_dim, - rotary_dim=int(self.partial_rotary_factor * self.head_dim), + rotary_dim=self.head_dim, max_position=self.max_position_embeddings, base=self.rope_theta, + partial_rotary_factor=self.partial_rotary_factor, ) self.scaling = self.head_dim**-0.5 self.attn = Attention(self.num_heads, @@ -221,7 +221,7 @@ def __init__(self, *, vllm_config: VllmConfig, 
prefix: str = ""): quant_config = vllm_config.quant_config self.vocab_size = config.vocab_size - + self.config = config self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.start_layer, self.end_layer, self.layers = make_layers( @@ -260,6 +260,38 @@ def forward( hidden_states = self.final_layernorm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + if "query_key_value" in name: + # copy from vllm/model_executor/models/bloom.py + # NOTE: Persimmon's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. + output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class PersimmonForCausalLM(nn.Module, SupportsPP): @@ -274,7 +306,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.hidden_size, bias=False) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -305,49 +336,7 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - - if "query_key_value" in name: - # copy from vllm/model_executor/models/bloom.py - # NOTE: Persimmon's fused QKV's output_dim has the shape of - # (num_heads * 3 * head_size), while the - # required shape is (3 * num_heads * head_size). - # Thus, we need weight conversion. 
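The weight conversion described in the NOTE above is easier to see on toy shapes. The snippet below uses hypothetical dimensions and a 1-D weight for brevity (the real loader handles the general `output_dim` case) and shows why the `view`/`transpose`/`reshape` sequence works:

```python
import torch

# Checkpoint layout along output_dim: (num_heads, 3, head_size) flattened, i.e.
# q/k/v interleaved per head. vLLM expects (3, num_heads, head_size) flattened.
num_heads, head_size = 4, 8
per_head = torch.arange(num_heads * 3 * head_size).reshape(num_heads, 3, head_size)
fused_ckpt = per_head.reshape(-1)               # what the checkpoint stores

converted = (fused_ckpt.view(num_heads, 3, head_size)
             .transpose(0, 1)                   # swap the head and q/k/v axes
             .reshape(fused_ckpt.shape))        # flatten back to original shape

q_block = converted[:num_heads * head_size].view(num_heads, head_size)
assert torch.equal(q_block, per_head[:, 0, :])  # all q heads now come first
```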
- output_dim = getattr(param, "output_dim", None) - num_heads = self.config.num_attention_heads - if output_dim is not None: - loaded_weight_shape = loaded_weight.shape - loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + - loaded_weight_shape[output_dim + 1:]) - loaded_weight = loaded_weight.transpose( - output_dim, output_dim + 1) - loaded_weight = loaded_weight.reshape(loaded_weight_shape) - - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index fdf7734595a5..fc2b108bad97 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -53,7 +53,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -322,7 +321,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): bias=True, quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -350,14 +348,6 @@ def compute_logits( sampling_metadata, self.lm_head.bias) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 33984f54ae27..338e87b4285f 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -17,7 +17,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -26,7 +25,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -353,10 +352,29 @@ def forward( hidden_states = self.final_layernorm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = 
params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class Phi3SmallForCausalLM(nn.Module, SupportsPP): _tied_weights_keys = ["lm_head.weight"] + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_suffix={"rotary_emb.inv_freq": None}) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -377,7 +395,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -418,6 +435,7 @@ def compute_logits( sampling_metadata) if self.dummy_token_indices is not None and logits is not None: logits.index_fill_(-1, self.dummy_token_indices, -torch.inf) + logits = logits / self.mup_width_multiplier return logits def forward( @@ -436,33 +454,10 @@ def forward( output_hidden_states = output_hidden_states return output_hidden_states - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - - next_tokens = self.sampler(logits / self.mup_width_multiplier, - sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - if "lm_head.weight" in name and self.config.tie_word_embeddings: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head.weight"] + if self.config.tie_word_embeddings else None)) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 7f41ad2359df..a1442251b992 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -16,7 +16,6 @@ # limitations under the License. 
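Two behavioural notes on the Phi-3-Small changes: the `mup_width_multiplier` division now happens once in `compute_logits` (equivalent to the old in-`sample()` scaling), and mapping the `rotary_emb.inv_freq` suffix to `None` is what makes the loader drop those buffers, replacing the explicit `continue`. A hypothetical sketch of the suffix-drop semantics (not vLLM's actual implementation):

```python
from typing import Optional

def remap_suffix(name: str,
                 orig_to_new_suffix: dict[str, Optional[str]]) -> Optional[str]:
    # Rewrite (or drop, when mapped to None) a weight name by suffix.
    for old, new in orig_to_new_suffix.items():
        if name.endswith(old):
            return None if new is None else name[:-len(old)] + new
    return name

mapper = {"rotary_emb.inv_freq": None}
assert remap_suffix("model.layers.0.self_attn.rotary_emb.inv_freq", mapper) is None
assert remap_suffix("model.layers.0.self_attn.qkv_proj.weight", mapper) == \
    "model.layers.0.self_attn.qkv_proj.weight"
```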
import re from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union import torch @@ -27,7 +26,6 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -327,7 +325,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - processor: Optional[ProcessorMixin], + processor: Optional[ProcessorMixin] = None, ) -> int: if processor is None: processor = self.get_hf_processor() @@ -555,13 +553,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) @@ -716,13 +707,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index ec19797f8875..6035994f4336 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -1,41 +1,41 @@ # SPDX-License-Identifier: Apache-2.0 import math -import re -from functools import lru_cache -from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Tuple, - TypedDict, Union) +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union import numpy as np -import scipy.signal import torch import torch.nn as nn -import torchvision.transforms as T -from PIL import Image -from transformers import PretrainedConfig, SiglipVisionConfig -from transformers.utils import logging +from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin, + SequenceFeatureExtractor, SiglipVisionConfig) from vllm.config import VllmConfig from vllm.distributed import get_pp_group -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext) -from vllm.inputs.data import TokenInputs, token_inputs from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors -from vllm.sequence import IntermediateTensors, SequenceData -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from 
vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) +from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems, + ImageProcessorItems, ImageSize, + MultiModalDataItems, MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of from .idefics2_vision_model import Idefics2VisionTransformer -from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsV0Only +from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal from .phi4mm_audio import AudioEmbedding -from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, + merge_multimodal_embeddings) # <|endoftext10|> (see vocab.json in hf model) _IMAGE_PLACEHOLDER_TOKEN_ID = 200010 @@ -43,115 +43,19 @@ _AUDIO_PLACEHOLDER_TOKEN_ID = 200011 _AUDIO_MAX_SOUNDFILE_SIZE = 241_000 -DUMMY_SAMPLING_FREQUENCY = 16_000 # kHz - -DYNAMIC_HD = 16 -AUDIO_TOKEN_PATTERN = r"<\|audio_(\d+)\|>" -IMAGE_TOKEN_PATTERN = r"<\|image_(\d+)\|>" SIGLIP_NAME = "siglip-so400m-patch14-448" VISION_ENCODER_TO_PROCESSING_CONFIG = { 'siglip-so400m-patch14-448': { - 'dynamic_hd': 16, 'vit_image_size': 448, 'vit_patch_size': 14, 'token_compression_factor': 2, }, } -logger = logging.get_logger(__name__) -# This is a workaround to prevent text (user input) + audio + image -# from being used in the same prompt. -# It includes token ids for "/n" and tokens in added_tokens_decoder -# from the tokenizer_confg.json file. -NON_USER_INPUT_TOKENS = { - 198, 200010, 200011, 199999, 200018, 200019, 200020, 200021, 200022, - 200023, 200024, 200025, 200026, 200027, 200028 -} -def get_max_dummy_image(ctx: InputContext): - hf_config = ctx.get_hf_config() - vision_encoder_name = hf_config.img_processor - if vision_encoder_name is None: - vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] - dynamic_hd_size = prepro_config['dynamic_hd'] - vit_image_size = prepro_config['vit_image_size'] - - max_side = vit_image_size * dynamic_hd_size - dummy_image = dummy_image_for_phi4mm(vit_image_size, max_side) - return dummy_image - - -# image token length -def get_max_phi4mm_image_tokens(ctx: InputContext): - dummy_image = get_max_dummy_image(ctx) - - hf_config = ctx.get_hf_config() - vision_encoder_name = hf_config.img_processor - if vision_encoder_name is None: - vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] - dynamic_hd_size = prepro_config['dynamic_hd'] - vit_image_size = prepro_config['vit_image_size'] - vit_patch_size = prepro_config['vit_patch_size'] - token_compression_factor = prepro_config['token_compression_factor'] - - image_num_tokens = _compute_num_image_tokens(dummy_image, dynamic_hd_size, - vit_image_size, - vit_patch_size, - token_compression_factor) - return image_num_tokens - - -def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, - image_size): - best_ratio_diff = float('inf') - best_ratio = (1, 1) - area = width * height - for ratio in target_ratios: - target_aspect_ratio = ratio[0] / ratio[1] - ratio_diff = abs(aspect_ratio - target_aspect_ratio) - if ratio_diff < best_ratio_diff: - best_ratio_diff = ratio_diff - best_ratio = ratio - elif 
ratio_diff == best_ratio_diff: - if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: - best_ratio = ratio - return best_ratio - - -def _find_target_aspect_ratio(image, image_size, max_num, min_num): - orig_width, orig_height = image.size - - w_crop_num = math.ceil(orig_width / float(image_size)) - h_crop_num = math.ceil(orig_height / float(image_size)) - if w_crop_num * h_crop_num > max_num: - aspect_ratio = orig_width / orig_height - - # calculate the existing image aspect ratio - target_ratios = set((i, j) for i in range(1, max_num + 1) - for j in range(1, max_num + 1) - if i * j <= max_num and i * j >= min_num) - target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) - - # find the closest aspect ratio to the target - target_aspect_ratio = find_closest_aspect_ratio( - aspect_ratio, target_ratios, orig_width, orig_height, image_size) - - # calculate the target width and height - target_width = image_size * target_aspect_ratio[0] - target_height = image_size * target_aspect_ratio[1] - logger.debug("target_aspect_ratio: %s", target_aspect_ratio) - else: - target_width = image_size * w_crop_num - target_height = image_size * h_crop_num - target_aspect_ratio = (w_crop_num, h_crop_num) - return target_aspect_ratio, target_height, target_width - - -def _get_padding_size(image, target_height, target_width): - orig_width, orig_height = image.size +def _get_padding_size(orig_width: int, orig_height: int, target_height: int, + target_width: int): ratio_width = target_width / orig_width ratio_height = target_height / orig_height @@ -164,181 +68,6 @@ def _get_padding_size(image, target_height, target_width): return padding_height, padding_width -def dynamic_preprocess(image, - min_num=1, - max_num=12, - image_size=384, - mask_size=27): - target_aspect_ratio, target_height, target_width =\ - _find_target_aspect_ratio( - image, image_size, max_num, min_num) - padding_height, padding_width = _get_padding_size(image, target_height, - target_width) - - # Calculate the ratio - orig_width, orig_height = image.size - ratio_width = target_width / orig_width - ratio_height = target_height / orig_height - if ratio_width < ratio_height: - new_size = (target_width, int(orig_height * ratio_width)) - else: - new_size = (int(orig_width * ratio_height), target_height) - - attention_mask = torch.ones((int(mask_size * target_aspect_ratio[1]), - int(mask_size * target_aspect_ratio[0]))) - if padding_width >= 14: - attention_mask[:, -math.floor(padding_width / 14):] = 0 - if padding_height >= 14: - attention_mask[-math.floor(padding_height / 14):, :] = 0 - assert attention_mask.sum( - ) > 0, f'attention mask is empty {attention_mask}' - - if min(new_size[1], target_height) < 10 or min(new_size[0], - target_width) < 10: - raise ValueError(f'the aspect ratio is very extreme {new_size}') - - image = T.functional.resize( - image, - [new_size[1], new_size[0]], - ) - - resized_img = T.functional.pad(image, - [0, 0, padding_width, padding_height], - fill=[255, 255, 255]) - - return resized_img, attention_mask - - -def pad_to_max_num_crops(images, max_crops=5): - """ - images: B x 3 x H x W, B<=max_crops - """ - B, _, H, W = images.shape - if max_crops > B: - pad = torch.zeros(max_crops - B, - 3, - H, - W, - dtype=images.dtype, - device=images.device) - images = torch.cat([images, pad], dim=0) - return images - - -def pad_mask_to_max_num_crops(masks, max_crops=5): - B, H, W = masks.shape - if max_crops > B: - pad = torch.ones(max_crops - B, - H, - W, - dtype=masks.dtype, - device=masks.device) - masks = 
torch.cat([masks, pad], dim=0) - return masks - - -def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size): - - # Basic settings. - img_processor = T.Compose([ - T.ToTensor(), - T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ]) - # Dynamic HD - base_resolution = vit_resolution - images = [image.convert('RGB') for image in images] - # cover 384 and 448 resolution - mask_resolution = base_resolution // vit_patch_size - elems, image_attention_masks = [], [] - for im in images: - elem, attention_mask = dynamic_preprocess(im, - max_num=dynamic_hd_size, - image_size=base_resolution, - mask_size=mask_resolution) - elems.append(elem) - image_attention_masks.append(attention_mask) - hd_images = [img_processor(im) for im in elems] - global_image = [ - torch.nn.functional.interpolate( - im.unsqueeze(0).float(), - size=(base_resolution, base_resolution), - mode='bicubic', - ).to(im.dtype) for im in hd_images - ] - shapes = [[im.size(1), im.size(2)] for im in hd_images] - mask_shapes = [[mask.size(0), mask.size(1)] - for mask in image_attention_masks] - global_attention_mask = [ - torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images - ] - hd_images_reshape = [ - im.reshape(1, 3, h // base_resolution, base_resolution, - w // base_resolution, base_resolution).permute( - 0, 2, 4, 1, 3, 5).reshape(-1, 3, base_resolution, - base_resolution).contiguous() - for im, (h, w) in zip(hd_images, shapes) - ] - attention_masks_reshape = [ - mask.reshape(1, h // mask_resolution, mask_resolution, - w // mask_resolution, mask_resolution).permute( - 0, 1, 3, 2, 4).reshape(-1, mask_resolution, - mask_resolution).contiguous() - for mask, (h, w) in zip(image_attention_masks, mask_shapes) - ] - # NOTE token compression is hard coded here, and odd numbers seems to fail - downsample_attention_masks = [ - mask[:, 0::2, - 0::2].reshape(1, h // mask_resolution, w // mask_resolution, - mask_resolution // 2 + mask_resolution % 2, - mask_resolution // 2 + mask_resolution % 2).permute( - 0, 1, 3, 2, 4) - for mask, (h, w) in zip(attention_masks_reshape, mask_shapes) - ] - downsample_attention_masks = [ - mask.reshape(mask.size(1) * mask.size(2), - mask.size(3) * mask.size(4)) - for mask in downsample_attention_masks - ] - # NOTE hard coded number of tokens - num_img_tokens = [ - 256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16 - for mask in downsample_attention_masks - ] - - hd_images_reshape = [ - torch.cat([_global_image] + [_im], dim=0) - for _global_image, _im in zip(global_image, hd_images_reshape) - ] - hd_masks_reshape = [ - torch.cat([_global_mask] + [_mask], - dim=0) for _global_mask, _mask in zip( - global_attention_mask, attention_masks_reshape) - ] - max_crops = max([img.size(0) for img in hd_images_reshape]) - image_transformed = [ - pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape - ] - image_transformed = torch.stack(image_transformed, dim=0) - mask_transformed = [ - pad_mask_to_max_num_crops(mask, max_crops) \ - for mask in hd_masks_reshape - ] - mask_transformed = torch.stack(mask_transformed, dim=0) - - returned_input_image_embeds = image_transformed - returned_image_sizes = torch.tensor(shapes, dtype=torch.long) - returned_image_attention_mask = mask_transformed - returned_num_img_tokens = num_img_tokens - - data = { - "pixel_values": returned_input_image_embeds, - "image_sizes": returned_image_sizes, - "image_attention_mask": returned_image_attention_mask, - "num_img_tokens": returned_num_img_tokens, - } - return data - - def 
get_navit_vision_model(layer_idx: int = -1, **kwargs): vision_config = { "hidden_size": 1152, @@ -492,7 +221,7 @@ def get_img_features(self, def forward(self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor, - image_attention_mask: torch.Tensor) -> torch.FloatTensor: + image_attention_mask: torch.Tensor) -> list[torch.FloatTensor]: """ process image and return vision embeddings. @@ -656,785 +385,505 @@ def forward(self, pixel_values: torch.FloatTensor, for _output_img in output_imgs: img_feature_proj = self.img_projection( _output_img.to(target_device).to(target_dtype)) - img_set_tensor.append(img_feature_proj) + img_set_tensor.append(img_feature_proj.squeeze(0)) return img_set_tensor -class Phi4MMAudioFeatureInputs(TypedDict): - type: Literal["audio_features"] - data: Tuple[NestedTensors] - """Shape: `((batch_size, num_audios, 80, M), )""" - - -class Phi4MMAudioEmbeddingInputs(TypedDict): - type: Literal["audio_embeds"] - data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" - - -Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] - - -def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None): - """Create a Mel filter-bank the same as SpeechLib FbankFC. - - Args: - sample_rate (int): Sample rate in Hz. number > 0 [scalar] - n_fft (int): FFT size. int > 0 [scalar] - n_mel (int): Mel filter size. int > 0 [scalar] - fmin (float): lowest frequency (in Hz). If None use 0.0. - float >= 0 [scalar] - fmax: highest frequency (in Hz). If None use sample_rate / 2. - float >= 0 [scalar] - - Returns - out (numpy.ndarray): Mel transform matrix - [shape=(n_mels, 1 + n_fft/2)] +class Phi4MMImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: Union[torch.Tensor, List[torch.Tensor]] """ + Shape: + `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - bank_width = int(n_fft // 2 + 1) - if fmax is None: - fmax = sample_rate / 2 - if fmin is None: - fmin = 0 - assert fmin >= 0, "fmin cannot be negative" - assert (fmin < fmax <= - sample_rate / 2), "fmax must be between (fmin, samplerate / 2]" - - def mel(f): - return 1127.0 * np.log(1.0 + f / 700.0) - - def bin2mel(fft_bin): - return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0)) - - def f2bin(f): - return int((f * n_fft / sample_rate) + 0.5) - - # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1] - klo = f2bin(fmin) + 1 - khi = f2bin(fmax) - - khi = max(khi, klo) - - # Spec 2: SpeechLib uses triangles in Mel space - mlo = mel(fmin) - mhi = mel(fmax) - m_centers = np.linspace(mlo, mhi, n_mels + 2) - ms = (mhi - mlo) / (n_mels + 1) - - matrix = np.zeros((n_mels, bank_width), dtype=np.float32) - for m in range(0, n_mels): - left = m_centers[m] - center = m_centers[m + 1] - right = m_centers[m + 2] - for fft_bin in range(klo, khi): - mbin = bin2mel(fft_bin) - if left < mbin < right: - matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms - - return matrix - - -class LogFbankProcessor: - - def __init__(self): - - self._eightk_method = "fillzero" - self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T - - self._hamming400 = np.hamming(400) # for 16k audio - self._hamming200 = np.hamming(200) # for 8k audio + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ - def extract_spectrogram(self, wav, fs): - """Extract spectrogram features from waveform. 
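The new `Phi4MMImagePixelInputs` TypedDict spells out what the processor hands to the model per image. A toy instantiation for orientation, where every shape and value is a placeholder rather than real processor output:

```python
import torch

# Toy dict mirroring the Phi4MMImagePixelInputs schema (placeholders throughout).
num_images, num_patches = 1, 4        # num_patches varies per image in practice
pixel_inputs = {
    "type": "pixel_values",
    # (num_images, 1 + num_patches, C, H, W): global crop followed by HD tiles
    "data": torch.zeros(num_images, 1 + num_patches, 3, 448, 448),
    "image_sizes": torch.tensor([[896, 896]]),        # (height, width) per image
    "num_img_tokens": [0],                            # placeholder count per image
    "image_attention_mask": torch.ones(num_images, 64, 64),
}
```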
- Args: - wav (1D array): waveform of the input - fs (int): sampling rate of the waveform, 16000 or 8000. - If fs=8000, the waveform will be resampled to 16000Hz. - Output: - log_fbank (2D array): a TxD matrix of log Mel filterbank features. - D=80, and T is the number of frames. - """ - if wav.ndim > 1: - wav = np.squeeze(wav) + image_sizes: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` - # by default, we extract the mean if stereo - if len(wav.shape) == 2: - wav = wav.mean(1) + This should be in `(height, width)` format. + """ - # Resample to 16000 or 8000 if needed - if fs > 16000: - wav = scipy.signal.resample_poly(wav, 1, fs // 16000) - fs = 16000 - elif 8000 < fs < 16000: - wav = scipy.signal.resample_poly(wav, 1, fs // 8000) - fs = 8000 - elif fs < 8000: - raise RuntimeError(f"Unsupported sample rate {fs}") - - if fs == 8000: - if self._eightk_method == "resample": - # Input audio is 8 kHz. Convert to 16 kHz before feature - # extraction - wav = scipy.signal.resample_poly(wav, 2, 1) - fs = 16000 - # Do nothing here for fillzero method - elif fs != 16000: - # Input audio is not a supported sample rate. - raise RuntimeError( - f"Input data using an unsupported sample rate: {fs}") - - preemphasis = 0.97 - - if fs == 8000: - n_fft = 256 - win_length = 200 - hop_length = 80 - fft_window = self._hamming200 - elif fs == 16000: - n_fft = 512 - win_length = 400 - hop_length = 160 - fft_window = self._hamming400 - - # Spec 1: SpeechLib cut remaining sample insufficient for a hop - n_batch = (wav.shape[0] - win_length) // hop_length + 1 - # Here we don't use stride_tricks since the input array may not satisfy - # memory layout requirement and we need writeable output - # Here we only use list of views before copy to destination - # so it is more efficient than broadcasting - y_frames = np.array( - [ - wav[_stride:_stride + win_length] - for _stride in range(0, hop_length * n_batch, hop_length) - ], - dtype=np.float32, - ) + num_img_tokens: list[int] + """Shape: `(batch_size * num_images)`""" - # Spec 2: SpeechLib applies preemphasis within each batch - y_frames_prev = np.roll(y_frames, 1, axis=1) - y_frames_prev[:, 0] = y_frames_prev[:, 1] - y_frames = (y_frames - preemphasis * y_frames_prev) * 32768 + image_attention_mask: torch.Tensor + """Shape: `(batch_size * num_images, H_mask, W_mask)`""" - S = np.fft.rfft(fft_window * y_frames, n=n_fft, - axis=1).astype(np.complex64) - if fs == 8000: - # Need to pad the output to look like 16 kHz data but with zeros in - # the 4 to 8 kHz bins. - frames, bins = S.shape - padarray = np.zeros((frames, bins)) - S = np.concatenate((S[:, 0:-1], padarray), - axis=1) # Nyquist bin gets set to zero +class Phi4MMImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: Union[torch.Tensor, List[torch.Tensor]] + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - spec = np.abs(S).astype(np.float32) - return spec + `hidden_size` must match the hidden size of language model backbone. + """ - def extract_features(self, wav, fs): - """Extract log filterbank features from waveform. - Args: - wav (1D array): waveform of the input - fs (int): sampling rate of the waveform, 16000 or 8000. - If fs=8000, the waveform will be resampled to 16000Hz. - Output: - log_fbank (2D array): a TxD matrix of log Mel filterbank features. - D=80, and T is the number of frames. 
- """ - spec = self.extract_spectrogram(wav, fs) - spec_power = spec**2 - fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None) - log_fbank = np.log(fbank_power).astype(np.float32) +class Phi4MMAudioFeatureInputs(TypedDict): + type: Literal["audio_features"] + data: Union[torch.Tensor, List[torch.Tensor]] + """Shape: `(batch_size * num_audios, 80, M)""" - return log_fbank +class Phi4MMAudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + data: NestedTensors + """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" -@lru_cache -def audio_feature_extractor() -> LogFbankProcessor: - # Creates an instance of the audio processor, needed to extract the - # the audio features from the sound file - # LRU cache ensures that we only make one copy - return LogFbankProcessor() +Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs] +Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] -def _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size, - vit_patch_size, token_compression_factor): - """ - compute the number of tokens an image is expected to take up considering - the image encoder architecture and exclude output features containing - only padding pixels - for siglip, vit_image_size=448, vit_patch_size=14, so output will be - 32x32 feature map - NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 - """ - assert vit_image_size % vit_patch_size == 0, \ - "vit_image_size must be divisible by vit_patch_size" - assert vit_image_size // vit_patch_size % token_compression_factor == 0, \ - "vit_image_size // vit_patch_size must be divisible by "\ - "token_compression_factor" - - target_aspect_ratio, target_height, target_width = ( - _find_target_aspect_ratio(image, - vit_image_size, - dynamic_hd_size, - min_num=1)) - assert target_aspect_ratio[ - 0] * vit_image_size == target_width, \ - f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}" - assert target_aspect_ratio[ - 1] * vit_image_size == target_height, \ - f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}" - assert (target_height % vit_image_size == 0 - and target_width % vit_image_size == 0) - - padding_height, padding_width = _get_padding_size(image, target_height, - target_width) - assert padding_width == 0 or padding_height == 0, \ - "padding_width or padding_height must be 0" - - target_feat_width = target_width // vit_patch_size - target_feat_height = target_height // vit_patch_size - if padding_width >= vit_patch_size: - assert padding_height == 0, "padding_height not 0" - non_pad_feat_width = target_feat_width - math.floor( - padding_width / vit_patch_size) - non_pad_feat_height = target_feat_height - elif padding_height >= vit_patch_size: - assert padding_width == 0, "padding_width not 0" - non_pad_feat_height = target_feat_height - math.floor( - padding_height / vit_patch_size) - non_pad_feat_width = target_feat_width - else: - # small padding shorter than a vit patch - non_pad_feat_width = target_feat_width - non_pad_feat_height = target_feat_height - - feat_width = non_pad_feat_width // token_compression_factor - feat_height = non_pad_feat_height // token_compression_factor - # NOTE it's possible that the non-padding feature is not divisible - if non_pad_feat_width % token_compression_factor != 0: - feat_width += 1 - if non_pad_feat_height % token_compression_factor != 0: - feat_height += 1 - num_hd_patch_tokens = feat_width * feat_height - num_hd_newline_tokens = feat_height - vit_feature_size = 
vit_image_size // vit_patch_size - num_global_image_tokens = (vit_feature_size // token_compression_factor)**2 - num_sep_tokens = 1 - num_global_image_newline_tokens = \ - vit_feature_size // token_compression_factor - - return (num_global_image_tokens + num_sep_tokens + num_hd_patch_tokens + - num_hd_newline_tokens + num_global_image_newline_tokens) - - -def compute_logfbank_output_size(wav_length: int, fs: int) -> Tuple[int, int]: +def cat_with_pad(tensors, dim, padding_value=0): """ - Compute the output size of the `extract_features` method. - - Args: - wav_length (int): Length of the input waveform in samples. - fs (int): Sampling rate of the waveform, either 16000 or 8000. - - Returns: - tuple (int, int): Output size as (T, D), where: - T: Number of time frames. - D: Number of Mel filterbank bins (80). + cat along dim, while pad to max for all other dims """ + ndim = tensors[0].dim() + assert all( + t.dim() == ndim for t in + tensors[1:]), "All tensors must have the same number of dimensions" - # Resample to 16000 or 8000 if needed - if fs > 16000: - wav_length //= fs // 16000 - fs = 16000 - elif 8000 <= fs < 16000: - # We'll resample to 16K from 8K - wav_length *= 2 - fs = 16000 - elif fs < 8000: - raise RuntimeError(f"Unsupported sample rate {fs}") - - # Spectrogram parameters for 16 kHz - win_length = 400 # Frame length in samples - hop_length = 160 # Frame shift in samples - mel_bins = 80 # Number of mel filterbank bins - - # Calculate number of frames (T) - T = (wav_length - win_length) // hop_length + 1 - if T < 1: - raise ValueError("Waveform too short for given parameters.") - - # Return time frames (T) and mel bins (D) - return T, mel_bins - + out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] + out_size[dim] = sum(t.shape[dim] for t in tensors) + output = tensors[0].new_full(out_size, padding_value) -def _get_audio_embed_sizes(audios, ctx: InputContext): - """ - Get the audio embedding sizes for each audio file. + index = 0 + for t in tensors: + # Create a slice list where every dimension except dim is full slice + slices = [slice(0, t.shape[d]) for d in range(ndim)] + # Update only the concat dimension slice + slices[dim] = slice(index, index + t.shape[dim]) - Args: - audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of - waveform and sample rate. - ctx (InputContext): Input context. + output[slices] = t + index += t.shape[dim] - Returns: - List[int]: List of audio embedding sizes. - """ - audio_embed_sizes = [] - for audio in audios: - audio_data, sf = audio - audio_frames, _ = compute_logfbank_output_size(len(audio_data), sf) - audio_embed_size = _compute_audio_embed_size(ctx.get_hf_config(), - audio_frames) - audio_embed_sizes.append(audio_embed_size) - return audio_embed_sizes + return output -def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""): - """ - The following will search for `<|audio_{idx}|>` tokens and - return a mapping of audio placeholder tokens to audio placeholder token ids - based on the size of the audio embeddings. +class Phi4MMProcessingInfo(BaseProcessingInfo): - Args: - audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of - waveform and sample rate. - ctx (InputContext): Input context. - prompt_str (str): The prompt string. 
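For reference, the relocated `cat_with_pad` helper concatenates along one dimension while padding every other dimension to the largest size. A small self-contained usage sketch (the body is repeated almost verbatim, except that the slice list is converted to a tuple before indexing):

```python
import torch

def cat_with_pad(tensors, dim, padding_value=0):
    # Concatenate along `dim`, padding every other dim to the max size.
    ndim = tensors[0].dim()
    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)
    index = 0
    for t in tensors:
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        slices[dim] = slice(index, index + t.shape[dim])
        output[tuple(slices)] = t
        index += t.shape[dim]
    return output

a, b = torch.ones(2, 3), torch.ones(4, 2)
out = cat_with_pad([a, b], dim=0)
assert out.shape == (6, 3)        # rows stacked, narrower tensor zero-padded
assert out[2:, 2].eq(0).all()     # padding column contributed by `b`
```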
+ def get_hf_processor( + self, + *, + dynamic_hd: Optional[int] = None, + **kwargs: object, + ) -> ProcessorMixin: + if dynamic_hd is not None: + kwargs["dynamic_hd"] = dynamic_hd - Returns: - Dict[str, List[int]]: Mapping of audio placeholder tokens to audio - placeholder token ids. + return self.ctx.get_hf_processor(**kwargs) - """ - if len(audios) == 0: - return {} - - audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) - audio_ids = re.findall(AUDIO_TOKEN_PATTERN, prompt_str) - audio_ids = [int(audio_id) for audio_id in audio_ids] - assert len(audio_ids) == len( - audio_embed_sizes - ), "Number of audio tokens and audio features do not match" - assert tuple(audio_ids) == tuple(range(1, - len(audio_ids) + - 1)), "Audio ids are not in order!" - audio_id_to_input_ids = { - f"<|audio_{audio_id}|>": - [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size - for audio_id, audio_embed_size in zip(audio_ids, audio_embed_sizes) - } + @property + def image_tokens(self) -> list[str]: + return [f"<|image_{i+1}|>" for i in range(100)] - return audio_id_to_input_ids - - -def _count_image_tokens(images, ctx: InputContext): - hf_config = ctx.get_hf_config() - vision_encoder_name = hf_config.img_processor - if vision_encoder_name is None: - vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] - dynamic_hd_size = prepro_config['dynamic_hd'] - vit_image_size = prepro_config['vit_image_size'] - vit_patch_size = prepro_config['vit_patch_size'] - token_compression_factor = prepro_config['token_compression_factor'] - - image_token_counts = [ - _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size, - vit_patch_size, token_compression_factor) - for image in images - ] - return image_token_counts - - -def _get_image_id_to_input_ids(images, prompt, ctx: InputContext): - if len(images) == 0: - return {} - - image_ids = re.findall(IMAGE_TOKEN_PATTERN, prompt) - image_ids = [int(image_id) for image_id in image_ids] - assert len(image_ids) == len( - set(image_ids)), "Duplicate image tokens in prompt" - assert len(images) == len( - image_ids), "Number of images and image tokens in prompt do not match" - - # NOTE the following assertion is not strictly necessary - assert tuple(image_ids) == tuple(range(1, - len(image_ids) + - 1)), "Image ids are not in order" - - image_token_counts = _count_image_tokens(images, ctx) - image_id_to_input_ids = { - f"<|image_{image_id}|>": [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_tokens - for image_id, num_tokens in zip(image_ids, image_token_counts) - } - return image_id_to_input_ids + @property + def audio_tokens(self) -> list[str]: + return [f"<|audio_{i+1}|>" for i in range(100)] + def get_dynamic_hd( + self, + processor: Optional[ProcessorMixin] = None, + ) -> int: + if processor is None: + processor = self.get_hf_processor() + image_processor = processor.image_processor + return image_processor.dynamic_hd -def input_processor_for_phi4mm(ctx: InputContext, - inputs: DecoderOnlyInputs) -> TokenInputs: - """ - Implements the input processor, which transforms the input prompt ids - to include the audio placeholder token. This will become the `input_ids` - in `forward` for the model. + def get_feature_extractor(self) -> SequenceFeatureExtractor: + return self.get_hf_processor().audio_processor - Args: - ctx (InputContext): Input context. - inputs (DecoderOnlyInputs): The inputs (e.g. prompt, prompt_token_ids) - to process. 
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None, "image": None} - Returns: - TokenInputs: Processed inputs - """ - multi_modal_data = inputs.get("multi_modal_data") - if (multi_modal_data is None or - ("audio" not in multi_modal_data and "image" not in multi_modal_data)): - # pure text input, so no need to do pre-processing - return inputs - - prompt_str = inputs.get("prompt") - prompt_token_ids = inputs.get("prompt_token_ids") - # for offline_inference, we will get str input and we parse MM special - # tokens from it - # (ignore prompt_token_ids) - # for OAI server, we will get prompt_token_ids, where MM special tokens - # are already parsed - - if 'audio' in multi_modal_data: - audios = multi_modal_data["audio"] - - if not isinstance(audios, list): - audios = [audios] - if prompt_str is not None: - audio_id_to_input_ids = _get_audio_id_to_input_ids( - audios, ctx, prompt_str=prompt_str) - audio_embed_sizes = [] - elif prompt_token_ids is not None: - audio_id_to_input_ids = {} - audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) - else: - audio_id_to_input_ids = {} - audio_embed_sizes = [] - - if 'image' in multi_modal_data: - # PIL Image or list of PIL Images - images = multi_modal_data["image"] - if not isinstance(images, list): - images = [images] - if prompt_str is not None: - image_id_to_input_ids = _get_image_id_to_input_ids( - images, prompt_str, ctx) - image_token_counts = [] - elif prompt_token_ids is not None: - image_id_to_input_ids = {} - image_token_counts = _count_image_tokens(images, ctx) - else: - image_id_to_input_ids = {} - image_token_counts = [] - - # Handle the case where the prompt is a string and we need to manually - # tokenize it. - # In this case, the `audio_id_to_input_ids` dict will be mapping from - # an audio placeholder - # string (e.g. `<|audio_1|>`) to the audio placeholder tokens for the - # given audio length. 
- if prompt_str: - pattern = r"(<\|image_\d+\|>|<\|audio_\d+\|>)" - prompt_chunk_strings = re.split(pattern, prompt_str) - prompt_chunk_strings = [s for s in prompt_chunk_strings if s != ""] - - # Create the new input_ids with the placeholder image and audio - # tokens inserted - tokenizer = cached_tokenizer_from_config(ctx.model_config) - input_ids = [] - has_imag, has_audio, has_user_text_input = False, False, False - for prompt_chunk_string in prompt_chunk_strings: - if re.match(IMAGE_TOKEN_PATTERN, prompt_chunk_string): - input_ids.extend(image_id_to_input_ids[prompt_chunk_string]) - has_imag = True - elif re.match(AUDIO_TOKEN_PATTERN, prompt_chunk_string): - input_ids.extend(audio_id_to_input_ids[prompt_chunk_string]) - has_audio = True - else: - curr_token_ids = tokenizer(prompt_chunk_string).input_ids - if not has_user_text_input: - for token_id in curr_token_ids: - if token_id not in NON_USER_INPUT_TOKENS: - has_user_text_input = True - break - input_ids.extend(curr_token_ids) - if has_audio and has_imag and has_user_text_input: - raise ValueError( - "Phi4MMForCausalLM does not support text + audio + image" + - " inputs in the same prompt") - # Handle the case where the prompt is already tokenized - else: - assert prompt_token_ids is not None, \ - "If string prompt isn't provided, prompt_token_ids must be" - - i = 0 - input_ids = prompt_token_ids - # only needed for later assertion - img_cnt, audio_cnt, user_text_input_cnt = 0, 0, 0 - image_token_count_iter = iter(image_token_counts) - audio_embed_size_iter = iter(audio_embed_sizes) - while i < len(input_ids): - token_id = input_ids[i] - if token_id == _AUDIO_PLACEHOLDER_TOKEN_ID: - token_count = next(audio_embed_size_iter) - audio_cnt += 1 - elif token_id == _IMAGE_PLACEHOLDER_TOKEN_ID: - token_count = next(image_token_count_iter) - img_cnt += 1 - else: - user_text_input_cnt += 1 if token_id not in \ - NON_USER_INPUT_TOKENS else 0 - i += 1 - continue - tokens = [token_id] * token_count - input_ids = input_ids[:i] + tokens + input_ids[i + 1:] - i += token_count - - if audio_cnt > 0 and img_cnt > 0 and user_text_input_cnt > 0: - raise ValueError( - "Phi4MMForCausalLM does not support text + audio + image" + - " inputs in the same prompt") - # If the below assertion fails, it might be that input pure-text - # messages contain image/audio special tokens literally - # (<|endoftext10|>, <|endoftext11|>). 
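To make `_find_target_aspect_ratio` concrete, here is the arithmetic for the simple branch where the naive crop grid already fits within `max_num` (illustrative numbers; the other branch defers to the HF image processor's `find_closest_aspect_ratio`):

```python
import math

image_size, max_num = 448, 16                 # vit_image_size and dynamic_hd (assumed)
orig_width, orig_height = 1000, 700

w_crop_num = math.ceil(orig_width / float(image_size))    # 3
h_crop_num = math.ceil(orig_height / float(image_size))   # 2

assert w_crop_num * h_crop_num <= max_num     # takes the simple branch
target_aspect_ratio = (w_crop_num, h_crop_num)             # (3, 2)
target_width = image_size * w_crop_num                     # 1344
target_height = image_size * h_crop_num                    # 896
assert (target_width, target_height) == (1344, 896)
```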
- assert (img_cnt == len(image_token_counts)), ( - f"Number of image tokens in prompt_token_ids ({img_cnt}) " - f"does not match number of images ({len(image_token_counts)})") - assert (audio_cnt == len(audio_embed_sizes)), ( - f"Number of audio tokens in prompt_token_ids ({audio_cnt}) " - f"does not match number of audios ({len(audio_embed_sizes)})") - - # NOTE: Create a defensive copy of the original inputs - return token_inputs( - prompt_token_ids=input_ids, - prompt=prompt_str, - multi_modal_data=multi_modal_data, - ) + def _find_target_aspect_ratio( + self, + orig_width: int, + orig_height: int, + image_size: int, + max_num: int, + min_num: int, + ): + w_crop_num = math.ceil(orig_width / float(image_size)) + h_crop_num = math.ceil(orig_height / float(image_size)) + if w_crop_num * h_crop_num > max_num: + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set((i, j) for i in range(1, max_num + 1) + for j in range(1, max_num + 1) + if i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + image_processor = self.get_hf_processor().image_processor + target_aspect_ratio = image_processor.find_closest_aspect_ratio( + aspect_ratio, + target_ratios, + orig_width, + orig_height, + image_size, + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + else: + target_width = image_size * w_crop_num + target_height = image_size * h_crop_num + target_aspect_ratio = (w_crop_num, h_crop_num) + return target_aspect_ratio, target_height, target_width + def _compute_num_image_tokens( + self, + orig_width: int, + orig_height: int, + dynamic_hd_size: int, + vit_image_size: int, + vit_patch_size: int, + token_compression_factor: int = 2, + ): + """ + compute the number of tokens an image is expected to take up considering + the image encoder architecture and exclude output features containing + only padding pixels -def _compute_audio_embed_size(hf_config, audio_frames): - """ - Compute the audio embedding size based on the audio frames and - compression rate. 
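As a worked example of the token arithmetic implemented by the `_compute_num_image_tokens` method introduced here (its body continues below), a single 448x448 image with the SigLIP settings from `VISION_ENCODER_TO_PROCESSING_CONFIG` takes the no-padding path:

```python
# Illustrative token count for one 448x448 image (patch 14, compression factor 2).
vit_image_size, vit_patch_size, token_compression_factor = 448, 14, 2

# A 448x448 image fits a single crop exactly, so there is no padding.
target_feat = vit_image_size // vit_patch_size                   # 32
feat = target_feat // token_compression_factor                   # 16

num_hd_patch_tokens = feat * feat                                # 256
num_hd_newline_tokens = feat                                     # 16
num_global_image_tokens = (target_feat // token_compression_factor) ** 2   # 256
num_sep_tokens = 1
num_global_image_newline_tokens = target_feat // token_compression_factor  # 16

total = (num_global_image_tokens + num_sep_tokens + num_hd_patch_tokens
         + num_hd_newline_tokens + num_global_image_newline_tokens)
assert total == 545
```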
- """ - compression_rate = hf_config.embd_layer['audio_embd_layer'][ - 'compression_rate'] - # NOTE: this is a hard-coded value but might be configurable in the future - qformer_compression_rate = 1 - integer = audio_frames // compression_rate - remainder = audio_frames % compression_rate + for siglip, vit_image_size=448, vit_patch_size=14, so output will be + 32x32 feature map + NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 + """ + assert vit_image_size % vit_patch_size == 0, ( + "vit_image_size must be divisible by vit_patch_size") + assert (vit_image_size // vit_patch_size % + token_compression_factor == 0), ( + "vit_image_size // vit_patch_size must be divisible by " + "token_compression_factor") + + target_aspect_ratio, target_height, target_width = ( + self._find_target_aspect_ratio(orig_width, + orig_height, + vit_image_size, + dynamic_hd_size, + min_num=1)) + assert target_aspect_ratio[0] * vit_image_size == target_width, ( + f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}") + assert target_aspect_ratio[1] * vit_image_size == target_height, ( + f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}") + assert (target_height % vit_image_size == 0 + and target_width % vit_image_size == 0) + + padding_height, padding_width = _get_padding_size( + orig_width, orig_height, target_height, target_width) + assert padding_width == 0 or padding_height == 0, \ + "padding_width or padding_height must be 0" + + target_feat_width = target_width // vit_patch_size + target_feat_height = target_height // vit_patch_size + if padding_width >= vit_patch_size: + assert padding_height == 0, "padding_height not 0" + non_pad_feat_width = target_feat_width - math.floor( + padding_width / vit_patch_size) + non_pad_feat_height = target_feat_height + elif padding_height >= vit_patch_size: + assert padding_width == 0, "padding_width not 0" + non_pad_feat_height = target_feat_height - math.floor( + padding_height / vit_patch_size) + non_pad_feat_width = target_feat_width + else: + # small padding shorter than a vit patch + non_pad_feat_width = target_feat_width + non_pad_feat_height = target_feat_height + + feat_width = non_pad_feat_width // token_compression_factor + feat_height = non_pad_feat_height // token_compression_factor + # NOTE it's possible that the non-padding feature is not divisible + if non_pad_feat_width % token_compression_factor != 0: + feat_width += 1 + if non_pad_feat_height % token_compression_factor != 0: + feat_height += 1 + num_hd_patch_tokens = feat_width * feat_height + num_hd_newline_tokens = feat_height + vit_feature_size = vit_image_size // vit_patch_size + num_global_image_tokens = (vit_feature_size // + token_compression_factor)**2 + num_sep_tokens = 1 + num_global_image_newline_tokens = \ + vit_feature_size // token_compression_factor + + return (num_global_image_tokens + num_sep_tokens + + num_hd_patch_tokens + num_hd_newline_tokens + + num_global_image_newline_tokens) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[ProcessorMixin] = None, + ) -> int: + hf_config = self.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[ + vision_encoder_name] + vit_image_size = prepro_config['vit_image_size'] + vit_patch_size = prepro_config['vit_patch_size'] + token_compression_factor = prepro_config['token_compression_factor'] + + dynamic_hd_size = 
self.get_dynamic_hd(processor=processor) + + image_num_tokens = self._compute_num_image_tokens( + image_width, + image_height, + dynamic_hd_size=dynamic_hd_size, + vit_image_size=vit_image_size, + vit_patch_size=vit_patch_size, + token_compression_factor=token_compression_factor, + ) - result = integer if remainder == 0 else integer + 1 + return image_num_tokens - integer = result // qformer_compression_rate - remainder = result % qformer_compression_rate - result = integer if remainder == 0 else integer + 1 # qformer compression + def get_image_size_with_most_features( + self, + processor: Optional[ProcessorMixin] = None, + ) -> ImageSize: + hf_config = self.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[ + vision_encoder_name] + vit_image_size = prepro_config['vit_image_size'] + + max_side = vit_image_size * self.get_dynamic_hd(processor=processor) + return ImageSize(height=max_side, width=vit_image_size) + + def get_audio_num_frames(self, audio_len: int, sr: float) -> int: + """ + Compute the output size of the `extract_features` method. - return result + Args: + audio_len (int): Length of the input waveform in samples. + sr (float): Sampling rate of the waveform, either 16000 or 8000. + Returns: + tuple (int, int): Output size as (T, D), where: + T: Number of time frames. + D: Number of Mel filterbank bins (80). + """ -def get_max_phi4mm_audio_tokens(ctx: InputContext) -> int: - return 10000 + # Resample to 16000 or 8000 if needed + if sr > 16000: + audio_len //= sr // 16000 + elif 8000 <= sr < 16000: + # We'll resample to 16K from 8K + audio_len *= 2 + elif sr < 8000: + raise RuntimeError(f"Unsupported sample rate {sr}") + + # Spectrogram parameters for 16 kHz + win_length = 400 # Frame length in samples + hop_length = 160 # Frame shift in samples + + # Calculate number of frames (T) + num_frames = (audio_len - win_length) // hop_length + 1 + if num_frames < 1: + raise ValueError("Waveform too short for given parameters.") + + # Return time frames (T) + return num_frames + + def _compute_audio_embed_size(self, audio_frames: int) -> int: + """ + Compute the audio embedding size based on the audio frames and + compression rate. + """ + hf_config = self.get_hf_config() + compression_rate = hf_config.embd_layer['audio_embd_layer'][ + 'compression_rate'] + # NOTE: this is a hard-coded value but might be configurable + # in the future + qformer_compression_rate = 1 + integer = audio_frames // compression_rate + remainder = audio_frames % compression_rate + result = integer if remainder == 0 else integer + 1 -def dummy_audio_for_phi4mm(audio_count: int) -> dict: - """ - Create dummy audio data for the Phi4MM model, which is used for profiling. + integer = result // qformer_compression_rate + remainder = result % qformer_compression_rate + # qformer compression + result = integer if remainder == 0 else integer + 1 - Args: - audio_count (int): Number of audio samples. + return result - Returns: - dict: Dummy audio data. 
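A quick sanity check of the audio bookkeeping introduced above (`get_audio_num_frames` followed by `_compute_audio_embed_size`), using an assumed `compression_rate` of 8; the real value comes from `hf_config.embd_layer['audio_embd_layer']`:

```python
# One second of 16 kHz audio -> log-mel frames -> audio placeholder count.
audio_len, win_length, hop_length = 16_000, 400, 160
num_frames = (audio_len - win_length) // hop_length + 1      # 98

compression_rate = 8            # assumed; read from the HF config in practice
qformer_compression_rate = 1    # hard-coded in the model code above

embed_size = -(-num_frames // compression_rate)              # ceil division -> 13
embed_size = -(-embed_size // qformer_compression_rate)      # unchanged -> 13
assert (num_frames, embed_size) == (98, 13)
```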
- """ - dummy_audio = np.full((_AUDIO_MAX_SOUNDFILE_SIZE, ), 0.0) - return [(dummy_audio, DUMMY_SAMPLING_FREQUENCY)] * audio_count +class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): -def dummy_image_for_phi4mm(width: int, height: int): - image = Image.new('RGB', (width, height), color='black') - return image + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + num_images = mm_counts.get("image", 0) + image_tokens: list[str] = self.info.image_tokens[:num_images] + audio_tokens: list[str] = self.info.audio_tokens[:num_audios] -def dummy_data_for_phi4mm(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]) -> DummyData: - """ - Create dummy sequence (input_ids) and audio data for the Phi4MM model, - which is used for profiling. + return "".join(image_tokens + audio_tokens) - In this case, the sequence data is a bunch of 0s with a number of audio - tokens that correspond to the audio embed size of the - _AUDIO_MAX_SOUNDFILE_SIZE. + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_audios = mm_counts.get("audio", 0) + num_images = mm_counts.get("image", 0) - Args: - ctx (InputContext): Input context. - seq_len (int): Length of the sequence. - mm_counts (Mapping[str, int]): Multi-modal counts. + target_width, target_height = \ + self.info.get_image_size_with_most_features() - Returns: - Tuple: Dummy sequence data and dummy audio data. - """ - audio_count = mm_counts["audio"] - audio_frames, _ = compute_logfbank_output_size(_AUDIO_MAX_SOUNDFILE_SIZE, - DUMMY_SAMPLING_FREQUENCY) - audio_feature_size = _compute_audio_embed_size(ctx.get_hf_config(), - audio_frames) - - image_count = mm_counts["image"] - dummy_image = get_max_dummy_image(ctx) - max_image_tokens = get_max_phi4mm_image_tokens(ctx) - total_image_tokens = image_count * max_image_tokens - - if seq_len - audio_feature_size * audio_count - total_image_tokens < 0: - raise RuntimeError( - f"Phi4MM cannot process {audio_count} audios and {image_count}" - f"images in a prompt, please increase max_model_len to be at" - f" larger than " - f"{audio_feature_size * audio_count + total_image_tokens}" - " or reduce audio/image limit by --limit-mm-per-prompt.") - - if audio_feature_size * audio_count > total_image_tokens: - seq_data = SequenceData.from_prompt_token_counts( - (_AUDIO_PLACEHOLDER_TOKEN_ID, audio_feature_size * audio_count), - (0, seq_len - audio_feature_size * audio_count), - ) mm_data = { - "audio": dummy_audio_for_phi4mm(audio_count), + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "audio": + self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE, + num_audios=num_audios), } - else: - seq_data = SequenceData.from_prompt_token_counts( - (_IMAGE_PLACEHOLDER_TOKEN_ID, total_image_tokens), - (0, seq_len - total_image_tokens), - ) - mm_data = { - "image": [dummy_image] * image_count, - } - return DummyData(seq_data, mm_data) + return mm_data -def input_mapper_for_phi4mm_audio(ctx: InputContext, - data: object) -> MultiModalKwargs: - """ - This function is used to create the MultiModalKwargs for the Phi4MM - (audio) model. - Specifically, for audio, we extract the audio features from the sound - file and create pairs of audio features and audio embed lengths (the - latter of which is used to repeat the audio placeholder token in the - input prompt IDs). 
- These pairs are used, downstream, in `_audio_features_to_embeddings` - (via `_process_audio_input`). - - Note that the incoming audio data (each entry in `data`) is a tuple of - the audio data and the sampling frequency (e.g. from soundfile.read). - - Args: - ctx (InputContext): Input context. - data (object): Audio data. - - Returns: - MultiModalKwargs: Multi-modal inputs. - """ - if not isinstance(data, list): - data = [data] - - if len(data) == 0: - return MultiModalKwargs() - - audio_features = [] - for audio_input in data: - if not isinstance(audio_input, tuple): - raise NotImplementedError( - f"Unsupported data type: {type(audio_input)}") - - audio, sf = audio_input - feature_extractor = audio_feature_extractor() - single_audio_features = feature_extractor.extract_features(audio, sf) - feat_stride = (1 if not hasattr(feature_extractor, "stride") else - feature_extractor.stride) - audio_frames = len(single_audio_features) * feat_stride - single_audio_embed_size = _compute_audio_embed_size( - ctx.get_hf_config(), audio_frames) - single_audio_feature_audio_len_pair = ( - single_audio_features, - [single_audio_embed_size], - ) - audio_features.append(single_audio_feature_audio_len_pair) - return MultiModalKwargs({"audio_features": audio_features}) - - -def input_mapper_for_phi4mm_image(ctx: InputContext, data: object): - if not isinstance(data, list): - data = [data] - # data: list of PIL images - if len(data) == 0: - return MultiModalKwargs() - hf_config = ctx.get_hf_config() - vision_encoder_name = hf_config.img_processor - if vision_encoder_name is None: - vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] - dynamic_hd_size = prepro_config['dynamic_hd'] - vit_image_size = prepro_config['vit_image_size'] - vit_patch_size = prepro_config['vit_patch_size'] - - image_input_dict = preprocess(data, dynamic_hd_size, vit_image_size, - vit_patch_size) - return MultiModalKwargs({ - "pixel_values": - image_input_dict["pixel_values"], - "image_sizes": - image_input_dict["image_sizes"], - "image_attention_mask": - image_input_dict["image_attention_mask"], - "num_img_tokens": - image_input_dict["num_img_tokens"], - }) +class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): -def cat_with_pad(tensors, dim, padding_value=0): - """ - cat along dim, while pad to max for all other dims - """ - ndim = tensors[0].dim() - assert all( - t.dim() == ndim for t in - tensors[1:]), "All tensors must have the same number of dimensions" + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate, + audio_resample_method="scipy") - out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] - out_size[dim] = sum(t.shape[dim] for t in tensors) - output = tensors[0].new_full(out_size, padding_value) + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + sr = self.info.get_feature_extractor().sampling_rate + if (audio_data := mm_data.get("audios", [])): + mm_data['audios'] = [(data, sr) for data in audio_data] + + processed_outputs = super()._call_hf_processor(prompt, mm_data, + mm_kwargs) + + num_img_tokens = [ + 
self.info.get_num_image_tokens(image_width=img_size[0], + image_height=img_size[1]) + for img_size in processed_outputs["image_sizes"] + ] + processed_outputs["num_img_tokens"] = num_img_tokens - index = 0 - for t in tensors: - # Create a slice list where every dimension except dim is full slice - slices = [slice(0, t.shape[d]) for d in range(ndim)] - # Update only the concat dimension slice - slices[dim] = slice(index, index + t.shape[dim]) + audio_features = processed_outputs['input_audio_embeds'] + feature_sizes = [ + self.info.get_audio_num_frames(len(audio), sr) + for audio in audio_data + ] + processed_outputs['input_audio_embeds'] = [ + audio_features[idx, :size] + for idx, size in enumerate(feature_sizes) + ] - output[slices] = t - index += t.shape[dim] + return processed_outputs - return output + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_image_embeds=MultiModalFieldConfig.batched("image"), + image_attention_mask=MultiModalFieldConfig.batched("image"), + image_sizes=MultiModalFieldConfig.batched("image"), + num_img_tokens=MultiModalFieldConfig.batched("image"), + input_audio_embeds=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + image_tokens: list[str] = self.info.image_tokens # type: ignore + audio_tokens: list[str] = self.info.audio_tokens # type: ignore + feature_extractor = self.info.get_feature_extractor() + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + def get_image_replacement_phi4mm(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + ) + + image_tokens = [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens + + return image_tokens + + def get_audio_replacement_phi4mm(item_idx: int): + audios = mm_items.get_items("audio", AudioProcessorItems) + # TODO(Isotr0py): support embedding inputs + audio_len = audios.get_audio_length(item_idx) + audio_frames = self.info.get_audio_num_frames( + audio_len, feature_extractor.sampling_rate) + audio_embed_size = self.info._compute_audio_embed_size( + audio_frames) + + audio_tokens = [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size + + return audio_tokens + + num_images = mm_items.get_count("image", strict=False) + num_audios = mm_items.get_count("audio", strict=False) + + image_repl = [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_image_replacement_phi4mm, + ) for image_token in image_tokens[:num_images] + ] + audio_repl = [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_audio_replacement_phi4mm, + ) for audio_token in audio_tokens[:num_audios] + ] + return image_repl + audio_repl -@MULTIMODAL_REGISTRY.register_input_mapper("audio", - input_mapper_for_phi4mm_audio) -@MULTIMODAL_REGISTRY.register_input_mapper("image", - input_mapper_for_phi4mm_image) -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "audio", get_max_phi4mm_audio_tokens) -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( 
- "image", get_max_phi4mm_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4mm) -@INPUT_REGISTRY.register_input_processor(input_processor_for_phi4mm) -class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, - SupportsV0Only): +@MULTIMODAL_REGISTRY.register_processor( + Phi4MMMultiModalProcessor, + info=Phi4MMProcessingInfo, + dummy_inputs=Phi4MMDummyInputsBuilder, +) +class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): """ Implements the Phi-4-multimodal-instruct model in vLLM. """ @@ -1518,48 +967,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = Sampler() - - def _audio_features_to_embeddings( - self, - input_ids: torch.Tensor, - input_features: List[torch.Tensor], - audio_input_sizes: torch.Tensor, - audio_projection_mode: str, - ) -> torch.Tensor: - """ - Convert audio features to embeddings, which are used as input to the - model (via `inputs_embeds`). - - Args: - input_ids (torch.Tensor): Input IDs (the prompt in this case). - input_features (list[torch.Tensor]): Input features (the audio - embeddings). - audio_input_sizes (list[torch.Tensor]): Audio input sizes (the - audio embed lengths to use for padding the audio placeholder token - in the input prompt IDs). - """ - # The audio projection can either be a single linear or Sequential, - # so handle both cases - if isinstance(self.embed_tokens_extend.audio_projection, - nn.Sequential): - target_dtype = self.embed_tokens_extend.audio_projection[ - 0].bias.dtype - else: - target_dtype = self.embed_tokens_extend.audio_projection.bias.dtype - - audio_input = [ - input.unsqueeze(0).to(target_dtype) for input in input_features - ] - kwargs = { - "wte": self.model.embed_tokens, - 'audio_projection_mode': audio_projection_mode - } - audio_embeddings = self.embed_tokens_extend(input_ids, audio_input, - audio_input_sizes, - **kwargs) - audio_embeddings = audio_embeddings.to(target_dtype) - return audio_embeddings def _parse_and_validate_audio_input( self, **kwargs: object) -> Optional[Phi4MMAudioInputs]: @@ -1574,7 +981,7 @@ def _parse_and_validate_audio_input( Returns: Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs. """ - audio_features = kwargs.pop("audio_features", None) + audio_features = kwargs.pop("input_audio_embeds", None) audio_embeds = kwargs.pop("audio_embeds", None) if audio_features is None and audio_embeds is None: @@ -1586,7 +993,7 @@ def _parse_and_validate_audio_input( f"Got type: {type(audio_features)}") return Phi4MMAudioFeatureInputs(type="audio_features", - data=audio_features) + data=flatten_bn(audio_features)) if audio_embeds is not None: if not isinstance(audio_embeds, (torch.Tensor, list)): @@ -1598,8 +1005,7 @@ def _parse_and_validate_audio_input( raise AssertionError("This line should be unreachable.") - def _process_audio_input(self, input_ids: torch.Tensor, - audio_input: Phi4MMAudioInputs, + def _process_audio_input(self, audio_input: Phi4MMAudioInputs, audio_projection_mode: str) -> NestedTensors: """ Create the audio embeddings from the audio input, where the audio input @@ -1607,8 +1013,6 @@ def _process_audio_input(self, input_ids: torch.Tensor, created by `input_mapper_for_phi4mm_audio`. Args: - input_ids (torch.Tensor): Input IDs (the prompt in this case, - before the audio token replication). audio_input (Phi4MMAudioInputs): Audio input. 
Returns: @@ -1620,21 +1024,20 @@ def _process_audio_input(self, input_ids: torch.Tensor, audio_features = audio_input["data"] # (e.g. multiple examples) and the second dim is the multi-audio dim # (e.g. multiple audios in the same example) - audio_feature = [i[0] for j in audio_features for i in j] - audio_feature_len = [i[1].item() for j in audio_features for i in j] - # Add the batch dim via `squeeze` - return self._audio_features_to_embeddings( - input_ids.unsqueeze(0), - audio_feature, - audio_feature_len, - audio_projection_mode, - ).squeeze(0) + dtype = next(self.embed_tokens_extend.parameters()).dtype + audio_embeds = [ + self.embed_tokens_extend( + features.to(dtype), + audio_projection_mode=audio_projection_mode, + ) for features in audio_features + ] + return audio_embeds def _parse_and_validate_image_input(self, **kwargs: object) -> Optional[Dict]: - pixel_values: Optional[Dict] = kwargs.get("pixel_values") - if pixel_values is None: + input_image_embeds: NestedTensors = kwargs.get("input_image_embeds") + if input_image_embeds is None: return None image_sizes = kwargs.get("image_sizes") @@ -1643,23 +1046,24 @@ def _parse_and_validate_image_input(self, assert image_sizes is not None and image_attention_mask is not None\ and num_img_tokens is not None, "Missing image inputs" - if isinstance(pixel_values, list): - assert pixel_values[0].dim() == 5, "Incorrect image inputs" + if is_list_of(input_image_embeds, torch.Tensor): + assert all(p.dim() == 5 + for p in input_image_embeds), "Incorrect image inputs" # list len is batch_size. # each tensor has dimension: num_img_per_example, num_hd_patches, # channels, height, width. # need to pad along num_hd_patches. # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w. - pixel_values = cat_with_pad(pixel_values, dim=0) - elif isinstance(pixel_values, torch.Tensor): + input_image_embeds = cat_with_pad(input_image_embeds, dim=0) + elif isinstance(input_image_embeds, torch.Tensor): # dimension: batch_size, num_img_per_example, num_hd_patches, # channels, height, width. # we flatten first 2 dims to make it a single large batch for # SigLIP Encoder. - assert pixel_values.dim() == 6, "Incorrect image inputs" - pixel_values = pixel_values.flatten(0, 1) + assert input_image_embeds.dim() == 6, "Incorrect image inputs" + input_image_embeds = input_image_embeds.flatten(0, 1) else: - raise ValueError("Incorrect pixel_values inputs") + raise ValueError("Incorrect input_image_embeds inputs") if isinstance(image_attention_mask, list): image_attention_mask = cat_with_pad(image_attention_mask, dim=0) @@ -1685,80 +1089,140 @@ def _parse_and_validate_image_input(self, else: raise ValueError("Incorrect image_attention_mask inputs") - return { - 'pixel_values': pixel_values, - 'image_sizes': image_sizes, - 'image_attention_mask': image_attention_mask, - 'num_img_tokens': num_img_tokens, - } + return Phi4MMImagePixelInputs( + type="pixel_values", + data=input_image_embeds, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + num_img_tokens=num_img_tokens, + ) - def merge_image_features_to_inputs_embeds( + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
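The image-input validation above accepts either a list of per-prompt 5-D tensors, which are padded along the HD-patch dimension and concatenated, or a single pre-batched 6-D tensor whose first two dims are flattened. A minimal sketch of that shape normalization, using a stand-in `pad_and_cat` helper and small illustrative sizes (real crops are 3x448x448):

```python
import torch

def pad_and_cat(tensors: list[torch.Tensor], dim: int = 0,
                padding_value: float = 0.0) -> torch.Tensor:
    # Concatenate along `dim`, padding every other dim to the max size
    # (mirrors the cat_with_pad helper used by the model code).
    ndim = tensors[0].dim()
    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    out = tensors[0].new_full(out_size, padding_value)
    index = 0
    for t in tensors:
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        slices[dim] = slice(index, index + t.shape[dim])
        out[tuple(slices)] = t
        index += t.shape[dim]
    return out

# List input: per-prompt (num_img, num_hd_patches, C, H, W), patch counts differ.
a, b = torch.zeros(1, 5, 3, 8, 8), torch.zeros(2, 9, 3, 8, 8)
print(pad_and_cat([a, b], dim=0).shape)   # torch.Size([3, 9, 3, 8, 8])

# Tensor input: (batch, num_img, num_hd_patches, C, H, W) -> one flat batch.
c = torch.zeros(2, 2, 9, 3, 8, 8)
print(c.flatten(0, 1).shape)              # torch.Size([4, 9, 3, 8, 8])
```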
+ for input_key in kwargs: + if input_key in ("input_image_embeds", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("input_audio_embeds", + "audio_embeds") and "audios" not in modalities: + modalities["audios"] = self._parse_and_validate_audio_input( + **kwargs) + + return modalities + + def _process_image_input( + self, image_input: Phi4MMImagePixelInputs) -> list[torch.Tensor]: + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + dtype = next(self.vision_encoder.parameters()).dtype + pixel_values = image_input['data'].to(dtype) + image_sizes = image_input['image_sizes'] + image_attention_mask = image_input['image_attention_mask'] + image_embeds = self.vision_encoder(pixel_values, image_sizes, + image_attention_mask) + return image_embeds + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + audio_projection_mode = 'speech' + for modality in modalities: + # make sure process images first + if modality == "images": + audio_projection_mode = "vision" + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(vision_embeddings) + if modality == "audios": + audio_input = modalities["audios"] + audio_embeddings = self._process_audio_input( + audio_input, audio_projection_mode=audio_projection_mode) + multimodal_embeddings += tuple(audio_embeddings) + + return multimodal_embeddings + + def get_input_embeddings( self, input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - image_set_tensors: List[torch.Tensor], - ): - position_tuple = (input_ids == _IMAGE_PLACEHOLDER_TOKEN_ID).nonzero( - as_tuple=True) - - assert all([t.shape[0] == 1 for t in image_set_tensors - ]), 'img_set_tensor should have shape (1, N_tokens, C)' - # Shape: (merged_N_tokens, C) - image_set_tensor = torch.cat(image_set_tensors, dim=1).squeeze(0) - image_set_tensor = image_set_tensor.to(inputs_embeds.dtype).to( - inputs_embeds.device) - merged_embeds = inputs_embeds.index_put( - indices=position_tuple, - values=image_set_tensor, - accumulate=False, - ) - return merged_embeds + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.embed_tokens(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID]) + return inputs_embeds + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + image_input: Optional[Phi4MMImagePixelInputs] = None, + audio_input: Optional[Phi4MMAudioFeatureInputs] = None, + ) -> torch.Tensor: + audio_projection_mode = 'speech' + inputs_embeds = self.get_input_embeddings(input_ids) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + 
placeholder_token_id=_IMAGE_PLACEHOLDER_TOKEN_ID, + ) + audio_projection_mode = 'vision' + + if audio_input is not None: + audio_embeds = self._process_audio_input( + audio_input, audio_projection_mode=audio_projection_mode) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + audio_embeds, + placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN_ID, + ) + return inputs_embeds def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> torch.Tensor: if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - # Each entry in this is a pair of audio_features and audio_embed - # lengths + + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs) - image_inputs = self._parse_and_validate_image_input(**kwargs) - - has_audio = audio_input is not None - has_image = image_inputs is not None - - if has_audio: - audio_projection_mode = 'vision' if has_image else 'speech' - inputs_embeds = self._process_audio_input( - input_ids, audio_input, audio_projection_mode) - - if has_image: - dtype = self.vision_encoder.img_processor.embeddings.\ - patch_embedding.weight.dtype - pixel_values = image_inputs['pixel_values'].to(dtype) - image_sizes = image_inputs['image_sizes'] - image_attention_mask = image_inputs['image_attention_mask'] - image_set_tensors = self.vision_encoder( - pixel_values, image_sizes, image_attention_mask) - if not has_audio: - inputs_embeds = self.model.embed_tokens(input_ids) - - inputs_embeds = self.merge_image_features_to_inputs_embeds( - input_ids, inputs_embeds, image_set_tensors) - - if has_image or has_audio: - # multi-modal input, we have set inputs_embeds properly in - # previous steps - input_ids = None - else: - # text-only, we keep using original input_ids + + if image_input is None and audio_input is None: inputs_embeds = None + else: + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, + audio_input=audio_input) + input_ids = None hidden_states = self.model( input_ids, @@ -1778,14 +1242,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: weights = ((name, data) for name, data in weights diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py index db90848f9809..34a7a73d057a 100644 --- a/vllm/model_executor/models/phi4mm_audio.py +++ b/vllm/model_executor/models/phi4mm_audio.py @@ -1159,8 +1159,11 @@ def get_audio_features( input_embeds: torch.FloatTensor, audio_attention_mask: torch.Tensor = None, audio_projection_mode: str = "speech", - ): - + ) -> torch.FloatTensor: + """ + arguments: + input_embeds: audio features (B, T, D) B: num audios in a sequence + """ if self.freeze_audio_processor: with torch.no_grad(): audio_features, masks = self.encoder(input_embeds, @@ -1210,62 +1213,20 @@ def get_audio_features( def forward( self, - input_ids: torch.LongTensor, - 
input_embeds: torch.FloatTensor, - audio_embed_sizes, - **kwargs, + audio_features: torch.FloatTensor, + audio_attention_mask: torch.Tensor = None, + audio_projection_mode: str = "speech", ) -> torch.FloatTensor: """ arguments: - input_ids: input text ids (B, U) - input_embeds: audio features (B, T, D) B: num audios in a sequence + audio_features: audio features (T, D) + + returns: + audio_embeds: audio embeddings (num_audio_tokens, hidden_dim) """ - assert input_embeds is not None and len(input_embeds) == len( - audio_embed_sizes) - - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - with torch.no_grad(): - positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero( - as_tuple=False) - - if not isinstance(input_embeds, list): - input_embeds = [input_embeds] - - audio_projection_mode = kwargs.get("audio_projection_mode", "speech") - audio_set_tensor = [ - self.get_audio_features( - input_embed, audio_projection_mode=audio_projection_mode) - for input_embed in input_embeds - ] - - with torch.no_grad(): - input_ids.clamp_min_(0).clamp_max_(self.vocab_size) - - if "wte" in kwargs: - # we use the token embedding layer from the huggingface model, this - # is REQUIRED to make sure we are using the loaded weights. - hidden_states = kwargs["wte"](input_ids) - else: - # otherwise, we use token embedding in pretrained mixformer from - # phi team - hidden_states = self.wte(input_ids) - - if len(positions.tolist()) > 0: - assert sum(audio_embed_sizes) == len( - positions - ), "please ensure the encoder outputs have the same length as"\ - " defined in input_ids!" - idx = 0 - for i in range(len(audio_embed_sizes)): - cnt = audio_embed_sizes[i] - assert audio_set_tensor[i].shape[0] == 1 - hidden_states[ - positions[idx, 0], - positions[idx, 1]:positions[idx, 1] + cnt, - ] = (audio_set_tensor[i][0, :audio_embed_sizes[i], :].to( - hidden_states.dtype).to(hidden_states.device)) - idx += cnt - - return hidden_states + audio_embeds = self.get_audio_features( + audio_features.unsqueeze(0), + audio_attention_mask=audio_attention_mask, + audio_projection_mode=audio_projection_mode, + ) + return audio_embeds.squeeze(0) diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 381a33d98b9c..2dc55e4c352e 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -40,7 +40,6 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -634,7 +633,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -659,14 +657,6 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git 
a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 38e140a91ecf..73fd80146955 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -28,7 +28,6 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs @@ -331,13 +330,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[PixtralImagePixelInputs]: images = kwargs.pop("images", None) @@ -441,13 +433,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]): @@ -926,9 +911,8 @@ def get_image_size(self) -> int: return self.vision_config.image_size def get_patch_size(self) -> int: - spatial_merge_size = getattr(self.vision_config, "spatial_merge_size", - 1) - return (self.vision_config.patch_size * spatial_merge_size) + return (self.vision_config.patch_size * + self.vision_config.spatial_merge_size) def get_patch_grid_length(self) -> int: image_size, patch_size = self.get_image_size(), self.get_patch_size() diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py new file mode 100644 index 000000000000..790c48ccd216 --- /dev/null +++ b/vllm/model_executor/models/plamo2.py @@ -0,0 +1,736 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only PLaMo2 model.""" +import math +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig, PreTrainedModel + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + 
composed_weight_loader, default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, + SupportsV0Only) +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.models.utils import maybe_prefix +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType + + +# Only used for type hinting. +class Plamo2Config(PretrainedConfig): # type: ignore + model_type: str = "plamo2" + + hidden_size: int + num_hidden_layers: int + rms_norm_eps: float + # Attention + num_attention_heads: int + hidden_size_per_head: int + num_key_value_heads: int + # Mamba + mamba_d_state: int + mamba_d_conv: int + mamba_num_heads: int + mamba_step: int + # MLP + intermediate_size: int + # Tokenizer + vocab_size: int + + +class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore + + def _init_weights(self, module: torch.nn.Module) -> None: + std = 0.02 + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +def get_initial_dt_bias(num_heads: int) -> torch.Tensor: + dt_min = 0.001 + dt_max = 0.1 + dt = torch.exp( + torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min)) + dt = torch.clamp(dt, 1e-4) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + return inv_dt + + +def is_mamba(config: Plamo2Config, i: int) -> bool: + assert config.mamba_step > 1 + + if config.num_hidden_layers <= (config.mamba_step // 2): + # use attention in last layer + return i != config.num_hidden_layers - 1 + return (i % config.mamba_step) != (config.mamba_step // 2) + + +# TODO(Shinichi): Replace this with RMSNorm. +def _rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, + eps: float) -> torch.Tensor: + input_shape = hidden_states.shape + hidden_states = hidden_states.reshape(input_shape[:-1] + weight.shape) + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + eps) + hidden_states = hidden_states.to(input_dtype) + hidden_states = weight * hidden_states + return hidden_states.reshape(input_shape) + + +def _swiglu(h: torch.Tensor) -> torch.Tensor: + h0, h1 = h.chunk(2, dim=-1) + return torch.nn.functional.silu(h0) * h1 + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +class Plamo2MambaMixer(nn.Module): + # TODO(Shinichi): Rebase on Mamba2 implementation. 
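`is_mamba` above interleaves attention into the Mamba stack: a layer uses attention exactly when `i % mamba_step == mamba_step // 2`, with a fallback that keeps only the last layer as attention for very shallow models. A small sketch with assumed, illustrative config values (not taken from a real PLaMo2 checkpoint):

```python
# Assumed values for illustration only.
num_hidden_layers, mamba_step = 32, 8

def uses_mamba(i: int) -> bool:
    if num_hidden_layers <= mamba_step // 2:
        # Very shallow stack: everything is Mamba except the last layer.
        return i != num_hidden_layers - 1
    return (i % mamba_step) != (mamba_step // 2)

attention_layers = [i for i in range(num_hidden_layers) if not uses_mamba(i)]
print(attention_layers)   # [4, 12, 20, 28] -> one attention layer per block of 8
```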
+ + def __init__(self, + config: Plamo2Config, + cache_config: CacheConfig, + quant_config: QuantizationConfig, + max_model_len: int, + prefix: str = "", + **kwargs) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = (config.mamba_num_heads * + config.hidden_size_per_head) + self.hidden_size_per_head = config.hidden_size_per_head + self.num_heads = config.mamba_num_heads + self.time_step_rank = max(64, self.hidden_size // 16) + self.use_conv_bias = False + self.use_bias = False + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, + bias=self.use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias, + prefix=f"{prefix}.in_proj", + ) + # selective projection used to make dt, B and C input dependent + self.bcdt_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False, + prefix=f"{prefix}.bcdt_proj", + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear( + self.time_step_rank, + self.num_heads, + bias=False, + prefix=f"{prefix}.dt_proj", + ) + self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads)) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=self.use_bias, + input_is_parallel=True, + prefix=f"{prefix}.out_proj", + ) + # The activation function is fixed to SiLU. + self.activation = "silu" + + self.dt_norm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps) + self.B_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + self.C_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + mamba_cache_params: MambaCacheParams, + **kwargs, + ) -> torch.Tensor: + + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0] + # Reshaping the projected states as in modeling_plamo.py. + length = len(hidden_states) + projected_states = projected_states.reshape(length, self.num_heads, -1) + gate, hidden_states = torch.split( + projected_states, + [self.hidden_size_per_head, self.hidden_size_per_head], + dim=-1) + hidden_states = hidden_states.reshape(length, -1).transpose(0, 1) + gate = gate.reshape(length, -1).transpose(0, 1) + + # 2. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=attn_metadata.context_lens_tensor > 0, + cache_indices=mamba_cache_params.state_indices_tensor, + query_start_loc=attn_metadata.query_start_loc) + else: + hidden_states = causal_conv1d_update( + hidden_states.transpose(0, 1), + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=mamba_cache_params.state_indices_tensor) + hidden_states = hidden_states.transpose(0, 1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.bcdt_proj(hidden_states.transpose(-2, -1))[0] + + # Splitting the ssm_parameters as in modeling_plamo.py. + B, C, time_step = torch.split( + ssm_parameters, + [self.ssm_state_size, self.ssm_state_size, self.time_step_rank], + dim=-1, + ) + time_step = self.dt_norm(time_step.contiguous()) + B = self.B_norm(B.contiguous()) + C = self.C_norm(C.contiguous()) + + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_bias.float() if hasattr( + self.dt_proj, "bias") else None) + + # Broadcasting as in modeling_plamo.py. + discrete_time_step = discrete_time_step.transpose( + 0, 1)[..., None].expand(-1, -1, self.hidden_size_per_head) + discrete_time_step = discrete_time_step.reshape( + -1, self.intermediate_size).transpose(0, 1) + time_proj_bias = time_proj_bias[..., + None].expand(-1, + self.hidden_size_per_head) + time_proj_bias = time_proj_bias.reshape(self.intermediate_size) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + scan_outputs = selective_scan_fn( + hidden_states, + mamba_cache_params.ssm_state, + discrete_time_step, + self.A, + B.transpose(-2, -1), + C.transpose(-2, -1), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + cache_indices=mamba_cache_params.state_indices_tensor, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + scan_outputs = selective_state_update( + mamba_cache_params.ssm_state, + hidden_states.transpose(0, 1), + discrete_time_step.transpose(0, 1), + self.A, + B, + C, + self.D, + gate.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + state_batch_indices=mamba_cache_params.state_indices_tensor) + scan_outputs = scan_outputs.transpose(0, 1) + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(-2, + -1))[0] + return contextualized_states + + +class DenseMLP(nn.Module): + + def __init__( + self, + config: Plamo2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, [self.intermediate_size] * 2, + bias=False, + prefix=f"{prefix}.gate_up_proj", + quant_config=quant_config) + self.down_proj = RowParallelLinear(self.intermediate_size, + self.hidden_size, + bias=False, + prefix=f"{prefix}.down_proj", + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + h = self.gate_up_proj(hidden_states)[0] + h = _swiglu(h) + output, _ = self.down_proj(h) + return output # type: ignore + + +class Plamo2AttentionMixer(nn.Module): + + def __init__(self, + config: Plamo2Config, + cache_config: CacheConfig, + quant_config: QuantizationConfig, + max_model_len: int | None = None, + prefix: str = "", + **kwargs) -> None: + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
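The branch above either partitions the KV heads across tensor-parallel ranks or replicates them when there are fewer KV heads than ranks; a minimal sketch of that sizing rule, with assumed head counts:

```python
def kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> int:
    # Mirrors the sizing logic above (head counts here are assumed examples).
    if total_num_kv_heads >= tp_size:
        # Partition the KV heads across ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Replicate the KV heads so every rank holds at least one.
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)

print(kv_heads_per_rank(8, 2))   # 4 -> each rank owns 4 distinct KV heads
print(kv_heads_per_rank(4, 8))   # 1 -> each KV head is shared by 2 ranks
```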
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size_per_head + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config) + + self.rope_theta = config.rope_theta if hasattr(config, + "rope_theta") else 10000 + self.rope_scaling = config.rope_scaling if hasattr( + config, "rope_scaling") else None + + assert max_model_len is not None, "max_model_len must be provided" + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_model_len, + base=self.rope_theta, + rope_scaling=self.rope_scaling, + ) + self.q_weight = torch.nn.Parameter( + torch.ones((self.num_heads, config.hidden_size_per_head))) + self.k_weight = torch.nn.Parameter( + torch.ones((self.num_kv_heads, config.hidden_size_per_head))) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = _rms_norm(q, self.q_weight, 1e-6) + k = _rms_norm(k, self.k_weight, 1e-6) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Plamo2DecoderLayer(nn.Module): + + def __init__(self, + vllm_config: VllmConfig, + layer_idx: int, + max_model_len: int | None = None, + prefix: str = "", + **kwargs) -> None: + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + max_model_len = vllm_config.scheduler_config.max_model_len + + self.is_mamba = is_mamba(config, layer_idx) + if self.is_mamba: + self.mixer = Plamo2MambaMixer(config=config, + cache_config=cache_config, + quant_config=quant_config, + max_model_len=max_model_len, + prefix=f"{prefix}.mixer") + else: + self.mixer = Plamo2AttentionMixer(config=config, + cache_config=cache_config, + quant_config=quant_config, + max_model_len=max_model_len, + prefix=f"{prefix}.mixer") + + self.mlp = DenseMLP(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.pre_mixer_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_mixer_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_mlp_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_mlp_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.pre_mixer_norm(hidden_states) + else: + hidden_states, residual = self.pre_mixer_norm( + hidden_states, residual) + + hidden_states = self.mixer(positions=positions, + hidden_states=hidden_states, + residual=residual, + 
mamba_cache_params=mamba_cache_params) + hidden_states = self.post_mixer_norm(hidden_states) + # Fully Connected + hidden_states, residual = self.pre_mlp_norm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_mlp_norm(hidden_states) + return hidden_states, residual + + +class Plamo2Decoder(torch.nn.Module): + + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers + + self.layers = nn.ModuleList([ + Plamo2DecoderLayer(vllm_config=vllm_config, + layer_idx=i, + prefix=f"{prefix}.layers.{i}") + for i in range(num_hidden_layers) + ]) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + ) -> torch.Tensor: + mamba_cache_index = 0 + for layer in self.layers: + layer_mamba_cache_params = None + if layer.is_mamba: + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + mamba_cache_index) + mamba_cache_index += 1 + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + mamba_cache_params=layer_mamba_cache_params) + return hidden_states, residual + + +class Plamo2Model(Plamo2PreTrainedModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config.model_config.hf_config) + + config = vllm_config.model_config.hf_config + + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + prefix=f"{prefix}.embed_tokens", + ) + self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers") + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_init() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # TODO(Shinichi): Implement pipeline parallelism. + hidden_states = self.embed_tokens(input_ids) + residual = None + + hidden_states, residual = self.layers( + positions=positions, + hidden_states=hidden_states, + residual=residual, + mamba_cache_params=mamba_cache_params) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid, + SupportsV0Only): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + config = vllm_config.model_config.hf_config + scheduler_config = vllm_config.scheduler_config + assert not vllm_config.cache_config.enable_prefix_caching, \ + "PLaMo2 currently does not support prefix caching" + + super().__init__(config) + self.config = config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.scheduler_config = scheduler_config + + # ModelConfig.get_head_size assumes head_dim is set or calculated as + # hidden_size // num_attention_heads. However, this is not always + # the case for PLaMo2, as indicated by the FIXME comment. 
+ self.config.head_dim = self.config.hidden_size_per_head + + self.model = Plamo2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.vocab_size = self.config.vocab_size + self.unpadded_vocab_size = self.config.vocab_size + num_embeddings = ((self.vocab_size + 15) // 16) * 16 + self.lm_head = ParallelLMHead( + num_embeddings, + self.config.hidden_size, + org_num_embeddings=self.config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=f"{prefix}.lm_head", + ) + if self.config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs): + if self.mamba_cache is None: + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, + *self._get_mamba_cache_shape()) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + hidden_states = self.model(input_ids, positions, mamba_cache_params, + intermediate_tensors, inputs_embeds) + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = (self.config.mamba_num_heads * + self.config.hidden_size_per_head) + conv_state_shape = ( + hidden_size // world_size, + self.config.mamba_d_conv - 1, + ) + temporal_state_shape = ( + hidden_size // world_size, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + + # Both tie_word_embeddings=True and lm_head.weight in the safetensor + # at the same time causes dict key access error. + if name == "lm_head.weight" and self.config.tie_word_embeddings: + assert "lm_head.weight" not in params_dict + continue + + # Update the weight names to be compatible with the vllm version + # of the model. + # Do not change the order of the replacements. + replacements = { + # Rename incompatible weight names. + ".A_log": ".A", + ".B_norm_weight": ".B_norm.weight", + ".C_norm_weight": ".C_norm.weight", + ".dt_norm_weight": ".dt_norm.weight", + } + # Apply replacements based on the defined mappings + for old, new in replacements.items(): + if old in name: + name = name.replace(old, new) + + # Broadcast the loaded weight to match the model's parameter shape. 
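The `replacements` table above renames checkpoint parameter names into the vLLM module layout before any shape fix-ups are applied. A quick sketch of the mapping on hypothetical parameter names:

```python
# Hypothetical parameter names, shown only to illustrate the rename table above.
replacements = {
    ".A_log": ".A",
    ".B_norm_weight": ".B_norm.weight",
    ".C_norm_weight": ".C_norm.weight",
    ".dt_norm_weight": ".dt_norm.weight",
}

def remap(name: str) -> str:
    for old, new in replacements.items():
        if old in name:
            name = name.replace(old, new)
    return name

print(remap("model.layers.layers.0.mixer.A_log"))           # ...mixer.A
print(remap("model.layers.layers.0.mixer.dt_norm_weight"))  # ...mixer.dt_norm.weight
```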
+ if ".A" in name: + loaded_weight = loaded_weight[:, None, None].expand( + -1, self.config.hidden_size_per_head, + self.config.mamba_d_state) + loaded_weight = loaded_weight.reshape( + -1, self.config.mamba_d_state) + elif ".D" in name: + loaded_weight = loaded_weight[:, None].expand( + -1, self.config.hidden_size_per_head) + loaded_weight = loaded_weight.reshape(-1) + # Offset parameter with vllm's RMSNorm haven't been supported yet. + if ".pre_mixer_norm" in name: + loaded_weight += 1.0 + elif ".post_mixer_norm" in name: + loaded_weight += 1.0 / 5 + elif ".pre_mlp_norm" in name: + loaded_weight += 1.0 + elif ".post_mlp_norm" in name: + loaded_weight += 1.0 / (5**1.5) + elif "model.norm.weight" in name: + loaded_weight += 1.0 + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index a33739a8eef9..e75294bc6cba 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -24,7 +24,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -273,7 +272,6 @@ def __init__( if self.config.tie_word_embeddings: self.lm_head.weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) @@ -286,14 +284,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 2831a5a12330..f76f31c9fc8d 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -43,7 +43,6 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -450,7 +449,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -478,14 +476,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, 
torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py new file mode 100644 index 000000000000..039f528db13b --- /dev/null +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -0,0 +1,901 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2.5-Omni model (thinker part).""" + +from copy import copy +from functools import partial +from typing import (Any, Dict, Iterable, List, Mapping, Optional, Sequence, + Set, Tuple, Union) + +import torch +import torch.nn as nn +from transformers.feature_extraction_utils import BatchFeature +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniConfig, Qwen2_5OmniThinkerConfig) +from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniAudioEncoder) +from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import ( + Qwen2_5OmniProcessor) +from transformers.models.whisper import WhisperFeatureExtractor + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.models.qwen2_5_vl import ( + Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs, + Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs, + Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs, + Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs) +from vllm.model_executor.models.qwen2_audio import ( + Qwen2AudioInputs, Qwen2AudioProcessingInfo, + _get_feat_extract_output_lengths) +from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (ImageItem, ModalityData, + MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) +from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems, + ModalityDataItems, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + PlaceholderFeaturesInfo, + PromptReplacement, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .utils import (AutoWeightsLoader, WeightsMapper, + 
init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) + +try: + import flash_attn +except (ImportError, ModuleNotFoundError): + flash_attn = None + +logger = init_logger(__name__) + + +def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): + audio_feature_lengths = hf_inputs.get("audio_feature_lengths", + torch.empty((0, ))) + + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + + return dict( + input_audio_features=MultiModalFieldConfig.flat_from_sizes( + "audio", audio_feature_lengths, dim=1), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + audio_feature_lengths=MultiModalFieldConfig.batched("audio"), + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + second_per_grid_ts=MultiModalFieldConfig.batched("video"), + ) + + +class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): + + def _parse_audio_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="audio", + required_fields={ + "input_audio_features", "audio_feature_lengths" + }, + fields_factory=_qwen2_5_omni_thinker_field_config, + ) + + return super()._parse_audio_data(data) + + +class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo, + Qwen2_5_VLProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen2_5OmniConfig).thinker_config + + def get_hf_processor( + self, + *, + sampling_rate: Optional[int] = None, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + size: Optional[dict[str, int]] = None, + fps: Optional[Union[float, List[float]]] = None, + **kwargs: object, + ) -> Qwen2_5OmniProcessor: + if fps is not None: + kwargs["fps"] = fps + processor = self.ctx.get_hf_processor( + Qwen2_5OmniProcessor, + image_processor=self.get_image_processor(min_pixels=min_pixels, + max_pixels=max_pixels, + size=size), + **kwargs, + ) + if not hasattr(processor, "audio_token"): + processor.audio_token = "<|AUDIO|>" + if not hasattr(processor, "image_token"): + processor.image_token = "<|IMAGE|>" + if not hasattr(processor, "video_token"): + processor.video_token = "<|VIDEO|>" + return processor + + def get_feature_extractor( + self, + *, + sampling_rate: Optional[int] = None, + **kwargs: object, + ): + hf_processor = self.get_hf_processor(sampling_rate=sampling_rate) + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, WhisperFeatureExtractor) + return feature_extractor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None, "image": None, "video": None} + + +class Qwen2_5OmniThinkerDummyInputsBuilder( + BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + num_images = mm_counts.get("image", 
0) + num_videos = mm_counts.get("video", 0) + + hf_processor = self.info.get_hf_processor() + + audio_token: str = hf_processor.audio_token + image_token: str = hf_processor.image_token + video_token: str = hf_processor.video_token + + return (audio_token * num_audios + image_token * num_images + + video_token * num_videos) + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_audios = mm_counts.get("audio", 0) + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + feature_extractor = self.info.get_feature_extractor() + + target_audio_length = min( + feature_extractor.chunk_length, + 30, + ) * feature_extractor.sampling_rate + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + + mm_data = { + "audio": + self._get_dummy_audios(length=target_audio_length, + num_audios=num_audios), + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos(width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos), + } + + return mm_data + + +class Qwen2_5OmniThinkerMultiModalProcessor( + BaseMultiModalProcessor[Qwen2_5OmniThinkerProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return Qwen2_5OmniThinkerMultiModalDataParser( + target_sr=feature_extractor.sampling_rate) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) + + # NOTE: WhisperFeatureExtractor cannot handle empty list of audios + if audios: + # NOTE: Qwen2.5-Omni processor accept "audio" + mm_data["audio"] = audios + mm_kwargs = dict(**mm_kwargs, ) + + hf_inputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + input_features = hf_inputs.pop('input_features', None) + feature_attention_mask = hf_inputs.get('feature_attention_mask', None) + if ('input_audio_features' not in hf_inputs + and input_features is not None): + if feature_attention_mask is not None: + input_features = input_features.permute( + 0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) + hf_inputs['input_audio_features'] = input_features + if ('audio_feature_lengths' not in hf_inputs + and feature_attention_mask is not None): + hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1) + return hf_inputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _qwen2_5_omni_thinker_field_config(hf_inputs) + + def _maybe_apply_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + prompt_ids: list[int], + mm_kwargs: MultiModalKwargs, + is_update_applied: bool, + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + """ + Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. 
+ """ + unbound_prompt_updates = self._get_prompt_updates( + mm_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + mm_prompt_updates = self._bind_and_group_updates( + unbound_prompt_updates) + + mm_item_counts = mm_items.get_all_counts() + self._validate_mm_kwargs(mm_kwargs, mm_item_counts) + + use_audio_in_video = hf_processor_mm_kwargs.get( + "use_audio_in_video", False) + + if is_update_applied: + mm_placeholders = self._find_mm_placeholders( + mm_prompt_updates, + prompt_ids, + mm_item_counts, + ) + self._validate_mm_placeholders( + mm_placeholders, + mm_item_counts, + use_audio_in_video=use_audio_in_video) + + tokenizer = self.info.get_tokenizer() + prompt = decode_tokens(tokenizer, prompt_ids) + else: + ( + prompt_ids, + prompt, + mm_placeholders, + ) = self._apply_prompt_updates( + prompt_ids, + mm_prompt_updates, + mm_item_counts, + ) + self._validate_mm_placeholders( + mm_placeholders, + mm_item_counts, + use_audio_in_video=use_audio_in_video) + + tokenizer = self.info.get_tokenizer() + prompt = decode_tokens(tokenizer, prompt_ids) + + if use_audio_in_video: + mm_kwargs["use_audio_in_video"] = True + + return prompt_ids, prompt, mm_placeholders + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + vocab = tokenizer.get_vocab() + + audio_token = processor.audio_token + image_token = processor.image_token + video_token = processor.video_token + audio_token_id = vocab[audio_token] + image_token_id = vocab[image_token] + video_token_id = vocab[video_token] + + audio_feature_lengths = out_mm_kwargs.get("audio_feature_lengths") + feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") + if audio_feature_lengths is None and feature_attention_mask is None: + audio_output_lengths = [] + elif audio_feature_lengths is not None: + _, audio_output_lens = _get_feat_extract_output_lengths( + audio_feature_lengths) + audio_output_lengths = audio_output_lens.tolist() + elif feature_attention_mask is not None: + assert isinstance(feature_attention_mask, torch.Tensor) + _, audio_output_lens = _get_feat_extract_output_lengths( + feature_attention_mask.sum(-1)) + audio_output_lengths = audio_output_lens.tolist() + + # number of audios read from video. 
+ audio_in_video_item_idx = 0 + + def get_replacement_qwen2_audio(item_idx: int): + item_idx += audio_in_video_item_idx + + num_features = audio_output_lengths[item_idx] + if num_features == 0: + audios = mm_items.get_items("audio", AudioProcessorItems) + audio = audios.get(item_idx) + raise ValueError( + f"The audio {audio} (len={len(audio)}) is too short " + "to be represented inside the model") + + return [audio_token_id] * num_features + + def get_replacement_qwen2_vision(item_idx: int, modality: str): + grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + assert isinstance(grid_thw, torch.Tensor) + merge_length = image_processor.merge_size**2 + + token_id = image_token_id if modality == "image" else video_token_id + return [token_id] * (int(grid_thw.prod()) // merge_length) + + use_audio_in_video = hf_processor_mm_kwargs.get( + "use_audio_in_video", False) + thinker_config = self.info.get_hf_config() + + def get_replacement_qwen2_use_audio_in_video(item_idx: int): + nonlocal audio_in_video_item_idx + + audio_num_features = audio_output_lengths[audio_in_video_item_idx + + item_idx] + video_grid_thw = out_mm_kwargs["video_grid_thw"][item_idx] + + audio_in_video_item_idx += 1 + + second_per_grid_ts = hf_processor_mm_kwargs.get( + "second_per_grid_ts", None) + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[item_idx] + else: + video_second_per_grid_t = 1.0 + + return MRotaryEmbedding.omni_get_updates_use_audio_in_video( + thinker_config=thinker_config, + audio_len=audio_num_features, + video_grid_thw=video_grid_thw, + video_second_per_grid_t=video_second_per_grid_t, + ) + + video_replacement_fn = ( + get_replacement_qwen2_use_audio_in_video if use_audio_in_video else + partial(get_replacement_qwen2_vision, modality="video")) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_qwen2_audio, + ), + PromptReplacement( + modality="image", + target=image_token, + replacement=partial(get_replacement_qwen2_vision, + modality="image"), + ), + PromptReplacement( + modality="video", + target=video_token, + replacement=video_replacement_fn, + ), + ] + + def _apply_hf_processor_main( + self, + prompt: Union[str, list[int]], + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + *, + enable_hf_prompt_update: bool, + ) -> tuple[list[int], MultiModalKwargs, bool]: + """ + Qwen2.5-Omni reimplements this function to handle text only. + """ + if isinstance(prompt, str): + if enable_hf_prompt_update: + return self._apply_hf_processor_text_mm( + prompt_text=prompt, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + tokenizer = self.info.get_tokenizer() + prompt_ids = encode_tokens(tokenizer, prompt) + else: + prompt_ids = self._apply_hf_processor_tokens_only(prompt) + + mm_kwargs = self._apply_hf_processor_mm_only( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + return prompt_ids, mm_kwargs, False + + def _apply_hf_processor_mm_only( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalKwargs: + """ + Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. 
+ """ + mm_counts = mm_items.get_all_counts() + + use_audio_in_video = hf_processor_mm_kwargs.get( + "use_audio_in_video", False) + if use_audio_in_video and "video" in mm_counts: + assert "audio" in mm_counts + mm_counts["audio"] -= mm_counts["video"] + + _, mm_kwargs, _ = self._apply_hf_processor_text_mm( + prompt_text=self.dummy_inputs.get_dummy_text(mm_counts), + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + return mm_kwargs + + def _validate_mm_placeholders( + self, + mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], + mm_item_counts: Mapping[str, int], + use_audio_in_video: bool = False, + ) -> None: + if use_audio_in_video: + mm_item_counts = copy(mm_item_counts) + if "video" in mm_item_counts: + assert "audio" in mm_item_counts + mm_item_counts["audio"] -= mm_item_counts["video"] + super()._validate_mm_placeholders(mm_placeholders, mm_item_counts) + + +class Qwen2_5OmniConditionalGenerationMixin: + + def _validate_and_reshape_mm_tensor(self, + mm_input: object, + name: str, + dim: int = 0) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + return torch.concat(list(mm_input), dim=dim) + else: + return torch.concat(mm_input, dim=dim) + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[Qwen2AudioInputs]: + input_audio_features = kwargs.pop('input_audio_features', None) + audio_feature_lengths = kwargs.pop('audio_feature_lengths', None) + feature_attention_mask = kwargs.pop('feature_attention_mask', None) + if input_audio_features is None: + return None + input_audio_features = self._validate_and_reshape_mm_tensor( + input_audio_features, 'input_audio_features', dim=1) + if feature_attention_mask is not None: + feature_attention_mask = self._validate_and_reshape_mm_tensor( + feature_attention_mask, 'feature_attention_mask') + if not isinstance(input_audio_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio input features. " + f"Got type: {type(input_audio_features)}") + return Qwen2AudioInputs(input_features=input_audio_features, + audio_feature_lengths=audio_feature_lengths, + feature_attention_mask=feature_attention_mask) + + def _parse_and_validate_image_input( + self, + **kwargs: Dict[str, Any], + ) -> Optional[Qwen2_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return Qwen2_5_VLImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. 
" + f"Got type: {type(image_embeds)}") + return Qwen2_5_VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw) + + def _parse_and_validate_video_input( + self, + **kwargs: Dict[str, Any], + ) -> Optional[Qwen2_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Qwen2_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, "video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + if not isinstance(video_embeds, torch.Tensor): + raise ValueError("Incorrect type of video embeddings. " + f"Got type: {type(video_embeds)}") + return Qwen2_5_VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw) + + def _process_audio_input( + self, + audio_input: Qwen2AudioInputs, + audio_hashes: List[str] = None, + cached_audio_features: torch.Tensor = None, + ) -> torch.Tensor: + + input_features = audio_input["input_features"] + audio_feature_lengths = audio_input["audio_feature_lengths"] + if input_features.ndim == 3: + assert input_features.shape[0] == 1 + input_features = input_features.squeeze(0) + if audio_feature_lengths.ndim == 2: + assert audio_feature_lengths.shape[ + 0] == 1 or audio_feature_lengths.shape[1] == 1 + if audio_feature_lengths.shape[0] == 1: + audio_feature_lengths = audio_feature_lengths.squeeze(0) + else: + audio_feature_lengths = audio_feature_lengths.squeeze(1) + + audio_feat_lengths, audio_output_lengths = ( + self.audio_tower._get_feat_extract_output_lengths( + audio_feature_lengths)) + + audio_outputs = self.audio_tower( + input_features.to(self.audio_tower.dtype), + feature_lens=audio_feature_lengths, + aftercnn_lens=audio_feat_lengths, + ) + audio_features = audio_outputs.last_hidden_state + return audio_features.split(audio_output_lengths.tolist()) + + def _process_image_input( + self, + image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + if image_input["type"] == "image_embeds": + return image_input["image_embeds"].type(self.visual.dtype) + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + # Split concatenated embeddings for each image item. 
+ merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, + video_input: Qwen2_5_VLVideoInputs, + video_hashes: List[str] = None, + cached_video_embeds: torch.Tensor = None) -> torch.Tensor: + if video_input["type"] == "video_embeds": + return video_input["video_embeds"].type(self.visual.dtype) + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5OmniThinkerMultiModalProcessor, + info=Qwen2_5OmniThinkerProcessingInfo, + dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, +) +class Qwen2_5OmniThinkerForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, + Qwen2_5OmniConditionalGenerationMixin): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "thinker.lm_head.": "language_model.lm_head.", + "thinker.model.": "language_model.model.", + "thinker.": "", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + thinker_config: Qwen2_5OmniThinkerConfig = ( + vllm_config.model_config.hf_config.thinker_config) + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = thinker_config + self.multimodal_config = multimodal_config + + # force "use_flash_attention_2=True" to audio tower to align + # the results. + if flash_attn is not None: + audio_config = thinker_config.audio_config + audio_config._attn_implementation_autoset = True + audio_config._attn_implementation = "flash_attention_2" + else: + logger.warning( + "flash_attn is not available, the model may not yield the " + "exactly same result as the transformers implementation " + "in the audio tower part.") + + self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config) + self.visual = Qwen2_5_VisionTransformer( + vision_config=thinker_config.vision_config, + norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) + self.quant_config = quant_config + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + hf_config=thinker_config.text_config, + architectures=["Qwen2ForCausalLM"], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
+ for input_key in kwargs: + if input_key in ("pixel_values", "image_embeds" + ) and "image" not in mm_input_by_modality: + mm_input_by_modality[ + "image"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos", "video_embeds" + ) and "video" not in mm_input_by_modality: + mm_input_by_modality[ + "video"] = self._parse_and_validate_video_input(**kwargs) + if input_key in ("input_audio_features" + ) and "audio" not in mm_input_by_modality: + mm_input_by_modality[ + "audio"] = self._parse_and_validate_audio_input(**kwargs) + return mm_input_by_modality + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + mm_input_by_modality = self._parse_and_validate_multimodal_inputs( + **kwargs) + if not mm_input_by_modality: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + vision_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += vision_embeddings + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) + multimodal_embeddings += video_embeddings + if modality == "audio": + audio_embeddings = self._process_audio_input(multimodal_input) + multimodal_embeddings += audio_embeddings + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + + # TODO (ywang96): support overlapping modalitiy embeddings so that + # `use_audio_in_video` will work on V1. 
+ inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, [ + self.config.image_token_index, + self.config.video_token_index, + self.config.audio_token_index + ]) + return inputs_embeds + + def get_multimodal_embeddings_v0( + self, **kwargs: object) -> Optional[NestedTensors]: + audio_input = self._parse_and_validate_audio_input(**kwargs) + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if audio_input is None and image_input is None and video_input is None: + return None + + multimodal_embeddings: List[Tuple[NestedTensors, str]] = [] + + if audio_input is not None: + audio_embeds = self._process_audio_input(audio_input) + multimodal_embeddings.append((audio_embeds, "audio")) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + multimodal_embeddings.append((image_embeds, "image")) + if video_input is not None: + video_embeds = self._process_video_input(video_input) + multimodal_embeddings.append((video_embeds, "video")) + return multimodal_embeddings + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is None: + return inputs_embeds + + for embeddings, modality in multimodal_embeddings: + if modality == "audio": + placeholder_token_id = self.config.audio_token_index + if modality == "image": + placeholder_token_id = self.config.image_token_index + if modality == "video": + placeholder_token_id = self.config.video_token_index + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, embeddings, placeholder_token_id) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings_v0(**kwargs) + inputs_embeds = self.get_input_embeddings_v0( + input_ids, multimodal_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["talker.", "token2wav."], + ) + loaded_weights = loader.load_weights(weights, + mapper=self.hf_to_vllm_mapper) + + return loaded_weights diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 84b7e59c8a0a..84108200e914 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -24,7 +24,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" -from functools import cached_property, partial +from functools import partial from typing import (Callable, Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -38,19 +38,19 @@ Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) from vllm.config import VllmConfig -from vllm.distributed import parallel_state, tensor_model_parallel_all_gather +from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY @@ -195,6 +195,25 @@ def forward(self, x: torch.Tensor): return x_down +def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): + """All-gather the input tensor interleavely across model parallel group.""" + import torch.distributed as dist + gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)] + dist.all_gather(gathered_tensors, + local_tensor, + group=parallel_state.get_tp_group().device_group) + + gathered_tensors_split = [ + torch.split(tensor, hidden_size // tp_size, -1) + for tensor in gathered_tensors + ] + ordered_tensors = [ + tensor for pair in zip(*gathered_tensors_split) for tensor in pair + ] + result_tensor = torch.cat(ordered_tensors, dim=-1) + return result_tensor + + class Qwen2_5_VisionAttention(nn.Module): def __init__( @@ -214,10 +233,14 @@ def __init__( self.num_attention_heads_per_partition = dist_utils.divide( num_heads, self.tp_size) - self.qkv = ColumnParallelLinear(input_size=embed_dim, - output_size=3 * projection_size, - quant_config=quant_config, - prefix=f"{prefix}.qkv") + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv") self.proj = RowParallelLinear(input_size=projection_size, output_size=embed_dim, quant_config=quant_config, @@ -236,7 +259,8 @@ def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape if self.tp_size > 1: - qkv = tensor_model_parallel_all_gather(qkv) + qkv = all_gather_interleave(qkv, self.qkv.hidden_size, + self.tp_size) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] q, k, v = qkv.chunk(3, dim=2) @@ -694,9 +718,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), + ("attn.qkv.", "attn.q.", "q"), + ("attn.qkv.", "attn.k.", "k"), + ("attn.qkv.", "attn.v.", "v"), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) 
loaded_params: Set[str] = set() @@ -808,13 +832,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): # GPTQ configs do not have a list of ignored modules, however AutoGPTQ # seems to avoid vision encoder sections for some models. @@ -952,20 +969,20 @@ def _process_video_input( return video_embeds.split(sizes.tolist()) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: - modalities = {} + mm_input_by_modality = {} # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key in ("pixel_values", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("pixel_values_videos", - "video_embeds") and "videos" not in modalities: - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) - return modalities + if input_key in ("pixel_values", "image_embeds" + ) and "image" not in mm_input_by_modality: + mm_input_by_modality[ + "image"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos", "video_embeds" + ) and "video" not in mm_input_by_modality: + mm_input_by_modality[ + "video"] = self._parse_and_validate_video_input(**kwargs) + return mm_input_by_modality def get_language_model(self) -> torch.nn.Module: return self.language_model @@ -973,8 +990,9 @@ def get_language_model(self) -> torch.nn.Module: def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - modalities = self._parse_and_validate_multimodal_inputs(**kwargs) - if not modalities: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs( + **kwargs) + if not mm_input_by_modality: return None # The result multimodal_embeddings is tuple of tensors, with each @@ -983,14 +1001,13 @@ def get_multimodal_embeddings( # NOTE: It is important to iterate over the keys in this dictionary # to preserve the order of the modalities. 
- for modality in modalities: - if modality == "images": - image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + vision_embeddings = self._process_image_input(multimodal_input) multimodal_embeddings += vision_embeddings - if modality == "videos": - video_input = modalities["videos"] - video_embeddings = self._process_video_input(video_input) + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) multimodal_embeddings += video_embeddings return multimodal_embeddings @@ -1102,13 +1119,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: @@ -1121,5 +1131,6 @@ def get_mm_mapping(self) -> MultiModelKeys: """ return MultiModelKeys.from_string_field( language_model="language_model", - connector="visual.", - tower_model="visual.merger.") + connector="visual.merger.", + tower_model="visual.", + ) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 280cda0f68f1..0cb541c6cbb2 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -22,7 +22,6 @@ # limitations under the License. """Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import Any, Optional, Set, Tuple, TypedDict, Union import torch @@ -34,7 +33,6 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -267,13 +265,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): @@ -405,13 +396,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 2700c706b972..62696678b7af 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -47,7 +47,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import 
get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -497,7 +496,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -524,14 +522,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 8c24b8f7df52..ef84becd269c 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -24,7 +24,7 @@ # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property, partial +from functools import partial from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, Union) @@ -51,7 +51,6 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY @@ -1112,13 +1111,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): # GPTQ configs do not have a list of ignored modules, however AutoGPTQ # seems to avoid vision encoder sections for some models. 
@@ -1400,13 +1392,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: @@ -1419,5 +1404,6 @@ def get_mm_mapping(self) -> MultiModelKeys: """ return MultiModelKeys.from_string_field( language_model="language_model", - connector="visual.", - tower_model="visual.merger.") + connector="visual.merger.", + tower_model="visual.", + ) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 9c14038e6113..73d2838f461e 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -38,7 +38,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -283,7 +282,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -311,14 +309,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index f0ef79dfdfe2..70f9956e3efc 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -44,7 +44,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -494,7 +493,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -521,14 +519,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/registry.py 
b/vllm/model_executor/models/registry.py index 02ee3a857443..79be5b0e6529 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -100,6 +100,7 @@ "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), + "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), @@ -122,13 +123,11 @@ _EMBEDDING_MODELS = { # [Text-only] "BertModel": ("bert", "BertEmbeddingModel"), - "RobertaModel": ("roberta", "RobertaEmbeddingModel"), - "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), - "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "GritLM": ("gritlm", "GritLM"), + "GteModel": ("bert", "GteEmbeddingModel"), "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"), "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501 "LlamaModel": ("llama", "LlamaForCausalLM"), @@ -138,12 +137,16 @@ if arch == "LlamaForCausalLM" }, "MistralModel": ("llama", "LlamaForCausalLM"), + "NomicBertModel": ("bert", "NomicBertEmbeddingModel"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"), + "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), + "RobertaModel": ("roberta", "RobertaEmbeddingModel"), "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), + "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), # [Multimodal] "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), @@ -162,6 +165,8 @@ "RobertaForSequenceClassification"), "XLMRobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), + "ModernBertForSequenceClassification": ("modernbert", + "ModernBertForSequenceClassification"), } _MULTIMODAL_MODELS = { @@ -174,10 +179,12 @@ "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), + "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501 "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501 + "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 @@ -195,6 +202,7 @@ "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 
"Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 + "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), # [Encoder-decoder] @@ -208,6 +216,7 @@ _SPECULATIVE_DECODING_MODELS = { "EAGLEModel": ("eagle", "EAGLE"), "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"), + "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index 19a23162aa84..e78c37b65f87 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -8,7 +8,6 @@ # -------------------------------------------------------- from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union import torch @@ -21,7 +20,6 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.intern_vit import (InternVisionModel, InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -699,13 +697,6 @@ def _patch_quant_config(self, config: PretrainedConfig, (llm_quant_config is not None): quant_config.modules_to_not_convert.append("vision_model") - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def _init_vision_model( self, config: PretrainedConfig, @@ -908,7 +899,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, - ) -> Union[SamplerOutput, IntermediateTensors]: + ) -> IntermediateTensors: if intermediate_tensors is not None: input_ids = None @@ -946,13 +937,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: skip_prefixes = [ diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 1cae0a7fe0dc..f86aff7ba7ef 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from 
vllm.model_executor.model_loader.weight_utils import ( @@ -418,8 +417,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = PPMissingLayer() - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -440,14 +437,6 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 53f520304abc..1cbda7267e4c 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -36,7 +36,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -105,9 +104,8 @@ def __init__(self, 1, self.total_num_key_value_heads // tp_size) self.head_dim = self.hidden_size // self.total_num_heads self.max_position_embeddings = config.max_position_embeddings - rope_pct = getattr(config, "rope_pct", - getattr(config, "partial_rotary_factor", 1)) - self.rotary_ndims = int(self.head_dim * rope_pct) + self.partial_rotary_factor = getattr( + config, "rope_pct", getattr(config, "partial_rotary_factor", 1)) self.scaling = self.head_dim**-0.5 self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_key_value_heads * self.head_dim @@ -131,9 +129,10 @@ def __init__(self, prefix=f"{prefix}.o_proj") self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.rotary_ndims, + rotary_dim=self.head_dim, max_position=self.config.max_position_embeddings, base=self.config.rope_theta, + partial_rotary_factor=self.partial_rotary_factor, ) self.attn = Attention(self.num_heads, self.head_dim, @@ -310,7 +309,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -337,14 +335,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 8b9fb7cb7bc6..6eebe4c4d614 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -36,7 +36,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from 
vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -317,7 +316,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -344,14 +342,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index a1f233e04892..a37e88a387fd 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -35,7 +35,6 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -84,7 +83,7 @@ def replace_linear_class( ) -> Union[ColumnParallelLinear, RowParallelLinear]: """ Replace nn.Linear with one of vLLM's tensor parallel linear classes. - + Args: linear (nn.Linear): `nn.Linear` to be replaced. style (str): Tensor parallel style of the new linear, e.g. "colwise". 
@@ -396,8 +395,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = PPMissingLayer() - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -435,12 +432,6 @@ def compute_logits( sampling_metadata) return logits - def sample(self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: - - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index cb5ff4ed6365..bfa48099b741 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -3,7 +3,6 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union import torch @@ -18,7 +17,6 @@ from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -438,13 +436,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - def get_mm_mapping(self) -> MultiModelKeys: """ Get the module prefix in multimodal models @@ -628,13 +619,6 @@ def compute_logits(self, hidden_states: torch.Tensor, return self.language_model.compute_logits(hidden_states, sampling_metadata) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 63e71f268805..908cd7885aa8 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -21,7 +21,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -669,7 +668,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = Sampler() def forward( self, @@ -724,14 +722,6 @@ def compute_logits(self, hidden_states: 
torch.Tensor, sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."]) diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index ea21fffaede5..d34033e3ac90 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -31,7 +31,6 @@ MambaMixer2, extra_groups_for_head_shards) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -870,7 +869,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # Initialize logits processing and sampling self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: """Convert input token IDs to embeddings. @@ -1004,23 +1002,6 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - """Sample next tokens from computed logits. - - Args: - logits: Computed logits for next token prediction - sampling_metadata: Metadata for sampling process - - Returns: - Sampled tokens and related sampling information - """ - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 2b1294bf7baa..34a0b527b585 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -282,10 +282,12 @@ def __init__(self, packed_factor: Union[int, Fraction], packed_dim: int, marlin_tile_size: Optional[int] = None, + bitblas_tile_size: Optional[int] = None, **kwargs): self._packed_factor = packed_factor self._packed_dim = packed_dim self._marlin_tile_size = marlin_tile_size + self._bitblas_tile_size = bitblas_tile_size super().__init__(**kwargs) @property @@ -300,12 +302,17 @@ def packed_factor(self): def marlin_tile_size(self): return self._marlin_tile_size + @property + def bitblas_tile_size(self): + return self._bitblas_tile_size + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): return _adjust_shard_indexes_for_packing( shard_size=shard_size, shard_offset=shard_offset, packed_factor=self.packed_factor, - marlin_tile_size=self.marlin_tile_size) + marlin_tile_size=self.marlin_tile_size, + bitblas_tile_size=self.bitblas_tile_size) class PackedvLLMParameter(ModelWeightParameter): @@ -323,10 +330,12 @@ def __init__(self, packed_factor: Union[int, Fraction], packed_dim: int, marlin_tile_size: Optional[int] = None, + bitblas_tile_size: Optional[int] = None, **kwargs): self._packed_factor = packed_factor self._packed_dim = packed_dim self._marlin_tile_size = marlin_tile_size + self._bitblas_tile_size = 
bitblas_tile_size super().__init__(**kwargs) @property @@ -341,12 +350,17 @@ def packed_factor(self): def marlin_tile_size(self): return self._marlin_tile_size + @property + def bitblas_tile_size(self): + return self._bitblas_tile_size + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): return _adjust_shard_indexes_for_packing( shard_size=shard_size, shard_offset=shard_offset, packed_factor=self.packed_factor, - marlin_tile_size=self.marlin_tile_size) + marlin_tile_size=self.marlin_tile_size, + bitblas_tile_size=self.bitblas_tile_size) class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): @@ -421,8 +435,13 @@ def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, return shard_size * marlin_tile_size, shard_offset * marlin_tile_size +def _adjust_shard_indexes_for_bitblas(shard_size, shard_offset, + bitblas_tile_size): + return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size + + def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, - marlin_tile_size): + marlin_tile_size, bitblas_tile_size): shard_size = shard_size // packed_factor shard_offset = shard_offset // packed_factor if marlin_tile_size is not None: @@ -430,4 +449,10 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, shard_size=shard_size, shard_offset=shard_offset, marlin_tile_size=marlin_tile_size) - return shard_size, shard_offset + elif bitblas_tile_size is not None: + return _adjust_shard_indexes_for_bitblas( + shard_size=shard_size, + shard_offset=shard_offset, + bitblas_tile_size=bitblas_tile_size) + + return shard_size, shard_offset \ No newline at end of file diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 741bd1a6a1c1..c65d9407dcd1 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 - -from .base import MultiModalPlaceholderMap, MultiModalPlugin +from .base import MultiModalPlaceholderMap from .hasher import MultiModalHashDict, MultiModalHasher from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins, MultiModalDataDict, MultiModalKwargs, @@ -26,7 +25,6 @@ "MultiModalKwargs", "MultiModalPlaceholderDict", "MultiModalPlaceholderMap", - "MultiModalPlugin", "NestedTensors", "MULTIMODAL_REGISTRY", "MultiModalRegistry", diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index f379ec1682a3..1fd2ab7f87d1 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -1,17 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 - import base64 from io import BytesIO from pathlib import Path +from typing import Literal, Optional import numpy as np import numpy.typing as npt -from vllm.inputs.registry import InputContext from vllm.utils import PlaceholderModule -from .base import MediaIO, MultiModalPlugin -from .inputs import AudioItem, ModalityData, MultiModalKwargs +from .base import MediaIO try: import librosa @@ -24,26 +22,7 @@ soundfile = PlaceholderModule("soundfile") # type: ignore[assignment] -class AudioPlugin(MultiModalPlugin): - """Plugin for audio data.""" - - def get_data_key(self) -> str: - return "audio" - - def _default_input_mapper( - self, - ctx: InputContext, - data: ModalityData[AudioItem], - **mm_processor_kwargs, - ) -> MultiModalKwargs: - raise NotImplementedError("There is no default audio input mapper") - - def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: - raise NotImplementedError( - "There is no default maximum multimodal 
tokens") - - -def resample_audio( +def resample_audio_librosa( audio: npt.NDArray[np.floating], *, orig_sr: float, @@ -52,6 +31,55 @@ def resample_audio( return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) +def resample_audio_scipy( + audio: npt.NDArray[np.floating], + *, + orig_sr: float, + target_sr: float, +): + # lazy import scipy.signal, otherwise it will crash doc build. + import scipy.signal + + if orig_sr > target_sr: + return scipy.signal.resample_poly(audio, 1, orig_sr // target_sr) + elif orig_sr < target_sr: + return scipy.signal.resample_poly(audio, target_sr // orig_sr, 1) + return audio + + +class AudioResampler: + """Resample audio data to a target sample rate.""" + + def __init__( + self, + target_sr: Optional[float] = None, + method: Literal["librosa", "scipy"] = "librosa", + ): + self.target_sr = target_sr + self.method = method + + def resample( + self, + audio: npt.NDArray[np.floating], + *, + orig_sr: float, + ) -> npt.NDArray[np.floating]: + if self.target_sr is None: + raise RuntimeError("Audio resampling is not supported when " + "`target_sr` is not provided") + if self.method == "librosa": + return resample_audio_librosa(audio, + orig_sr=orig_sr, + target_sr=self.target_sr) + elif self.method == "scipy": + return resample_audio_scipy(audio, + orig_sr=orig_sr, + target_sr=self.target_sr) + else: + raise ValueError(f"Invalid resampling method: {self.method}. " + "Supported methods are 'librosa' and 'scipy'.") + + class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]: diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index ad95b982499c..2f93922fcedb 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,247 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from collections import defaultdict from collections.abc import Sequence from pathlib import Path -from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple, - Optional, TypeVar, Union) - -from torch import nn - -from vllm.inputs import InputContext -from vllm.logger import init_logger -from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, - resolve_mm_processor_kwargs) +from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar if TYPE_CHECKING: - from vllm.config import ModelConfig from vllm.sequence import SequenceGroupMetadata -from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs, - PlaceholderRange) - -logger = init_logger(__name__) - -MultiModalInputMapper = Callable[[InputContext, ModalityData[object]], - MultiModalKwargs] -""" -Return a dictionary to be passed as keyword arguments to -:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers -and processors in HuggingFace Transformers. - -If the data is not supported, throw :exc:`TypeError`. -""" - -MultiModalTokensCalc = Union[int, Callable[[InputContext], int]] -""" -Calculate the maximum number of multimodal tokens input to the language -model. This does not include tokens that correspond to the input text. -""" +from .inputs import MultiModalKwargs, PlaceholderRange _T = TypeVar("_T") -N = TypeVar("N", bound=type[nn.Module]) - - -class MultiModalPlugin(ABC): - """ - Base class that defines data processing logic for a specific modality. - - In particular, we adopt a registry pattern to dispatch data processing - according to the model being used (considering that different models may - process the same data differently). 
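# Stepping back to the _adjust_shard_indexes_for_packing change in
# vllm/model_executor/parameter.py earlier in this diff: a condensed, standalone sketch of
# the combined adjustment with illustrative numbers (the real code threads these values
# through PackedvLLMParameter / PackedColumnParameter):
from fractions import Fraction
from typing import Optional, Union

def adjust_for_packing(shard_size: int,
                       shard_offset: int,
                       packed_factor: Union[int, Fraction],
                       marlin_tile_size: Optional[int] = None,
                       bitblas_tile_size: Optional[int] = None):
    # Packed weights store `packed_factor` logical elements per physical element,
    # so shard extents shrink by that factor first.
    shard_size = shard_size // packed_factor
    shard_offset = shard_offset // packed_factor
    if marlin_tile_size is not None:
        # Marlin repacking scales the physical extents up by the tile size.
        return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
    if bitblas_tile_size is not None:
        # BitBLAS instead folds tiles into the packed dimension, scaling them down.
        return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size
    return shard_size, shard_offset

print(adjust_for_packing(4096, 8192, packed_factor=8, marlin_tile_size=16))   # (8192, 16384)
print(adjust_for_packing(4096, 8192, packed_factor=8, bitblas_tile_size=16))  # (32, 64)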
This registry is in turn used by - :class:`~MultiModalRegistry` which acts at a higher level - (i.e., the modality of the data). - """ - - def __init__(self) -> None: - self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]() - self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]() - - @abstractmethod - def get_data_key(self) -> str: - """ - Get the data key corresponding to the modality. - """ - raise NotImplementedError - - @abstractmethod - def _default_input_mapper( - self, - ctx: InputContext, - data: ModalityData[Any], - **mm_processor_kwargs, - ) -> MultiModalKwargs: - """ - Return a dictionary to be passed as keyword arguments to - :meth:`~torch.nn.Module.forward`. This is similar in concept to - tokenizers and processors in HuggingFace Transformers. - - If the data is not supported, throw :exc:`TypeError`. - """ - raise NotImplementedError - - def register_input_mapper( - self, - mapper: Optional[MultiModalInputMapper] = None, - ): - """ - Register an input mapper to a model class. - - When the model receives input data that matches the modality served by - this plugin (see :meth:`get_data_key`), the provided function is - invoked to transform the data into a dictionary of model inputs. - - If `None` is provided, then the default input mapper is used instead. - """ - - def wrapper(model_cls: N) -> N: - if self._input_mappers.contains(model_cls, strict=True): - logger.warning( - "Model class %s already has an input mapper " - "registered to %s. It is overwritten by the new one.", - model_cls, - self, - ) - - self._input_mappers[model_cls] = (mapper - or self._default_input_mapper) - - return model_cls - - return wrapper - - def map_input( - self, - model_config: "ModelConfig", - data: ModalityData[Any], - mm_processor_kwargs: Optional[dict[str, Any]], - ) -> MultiModalKwargs: - """ - Transform the data into a dictionary of model inputs using the - input mapper registered for that model. - - The model is identified by ``model_config``. - - Raises: - TypeError: If the data type is not supported. - """ - - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - - model_cls, _ = get_model_architecture(model_config) - - mapper = self._input_mappers.get(model_cls) - - if mapper is None: - raise KeyError(f"No input mapper in {self} is registered for " - f"model class {model_cls.__name__}.") - - if mm_processor_kwargs is None: - mm_processor_kwargs = {} - - # In the case of the default mapper, we have to get resource - # processor through its HuggingFace autoclass; since this goes - # through **kwargs, we can't inspect it the same way, so we allow - # drop mm_processor_kwargs based on signature inspection - # if we're using the default mapper. - # - # This should be safe in general due to the sanitation, since the - # transformers resource should filter unused kwargs anyway. - uses_default_mapper = mapper == self._default_input_mapper - mm_processor_kwargs = resolve_mm_processor_kwargs( - model_config.mm_processor_kwargs, - mm_processor_kwargs, - callable=mapper, - allow_var_kwargs=uses_default_mapper, - ) - return mapper(InputContext(model_config), data, **mm_processor_kwargs) - - @abstractmethod - def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: - """ - Calculate the maximum number of tokens, corresponding to a single - instance of multimodal data, that are passed to the language model. 
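# Regarding the scipy resampling path added in vllm/multimodal/audio.py above: it uses
# polyphase resampling with integer up/down factors, so it assumes one sample rate evenly
# divides the other (e.g. 48 kHz <-> 16 kHz); non-integer ratios such as 44.1 kHz -> 16 kHz
# would still need the librosa path. A quick usage sketch of the same scipy call:
import numpy as np
import scipy.signal

orig_sr, target_sr = 48_000, 16_000
t = np.arange(orig_sr) / orig_sr                       # 1 second of audio
audio = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)

# Downsampling branch of resample_audio_scipy (orig_sr > target_sr).
resampled = scipy.signal.resample_poly(audio, 1, orig_sr // target_sr)
print(resampled.shape)  # (16000,)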
- """ - raise NotImplementedError - - def _validate_max_multimodal_tokens(self, max_mm_tokens: int): - if max_mm_tokens < 1: - raise ValueError("You should set the number of tokens to a " - f"positive integer. Found: {max_mm_tokens}") - - def register_max_multimodal_tokens( - self, - max_mm_tokens: Optional[MultiModalTokensCalc] = None, - ): - """ - Register the maximum number of tokens, corresponding to a single - instance of multimodal data, that are passed to the language model - for a model class. - - If `None` is provided, then the default calculation is used instead. - """ - - def wrapper(model_cls: N) -> N: - if self._max_mm_tokens.contains(model_cls, strict=True): - logger.warning( - "Model class %s already calculates maximum number of " - "tokens in %s. It is overwritten by the new one.", - model_cls, - self, - ) - - if isinstance(max_mm_tokens, int): - self._validate_max_multimodal_tokens(max_mm_tokens) - - self._max_mm_tokens[model_cls] = ( - max_mm_tokens or self._default_max_multimodal_tokens) - - return model_cls - - return wrapper - - def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: - """ - Get the maximum number of multi-modal tokens - for profiling the memory usage of a model. - - If this registry is not applicable to the model, `0` is returned. - - The model is identified by ``model_config``. - """ - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - from vllm.model_executor.models import supports_multimodal - - model_cls, _ = get_model_architecture(model_config) - - if not supports_multimodal(model_cls): - return 0 - - max_mm_tokens = self._max_mm_tokens.get(model_cls) - if max_mm_tokens is None: - return 0 - - if callable(max_mm_tokens): - mm_processor_kwargs = get_allowed_kwarg_only_overrides( - max_mm_tokens, - overrides=model_config.mm_processor_kwargs, - requires_kw_only=False, - allow_var_kwargs=True, - ) - max_mm_tokens = max_mm_tokens(InputContext(model_config), - **mm_processor_kwargs) - - self._validate_max_multimodal_tokens(max_mm_tokens) - - return max_mm_tokens class MultiModalPlaceholderMap: """ Relates multi-modal embeddings to their corresponding placeholders. + + Note: This is only used in V0. 
""" class IndexMap(NamedTuple): @@ -279,8 +55,7 @@ def __init__(self): @classmethod def from_seq_group( cls, seq_group: "SequenceGroupMetadata", positions: range - ) -> tuple[Optional[MultiModalDataDict], dict[str, - "MultiModalPlaceholderMap"]]: + ) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]: """ Returns the multi-modal items that intersect with the portion of a prompt (``seq_group``) represented by ``positions``, as well as a @@ -323,48 +98,24 @@ def from_seq_group( seq_mm_placeholders = seq_group.multi_modal_placeholders if not seq_mm_data or not seq_mm_placeholders: - return seq_mm_data, {} - - # For merged processor, we directly use mm_kwargs as mm_data - if isinstance(seq_mm_data, MultiModalKwargs): - placeholder_maps = dict[str, MultiModalPlaceholderMap]() - - for modality, placeholders in seq_mm_placeholders.items(): - placeholder_map = MultiModalPlaceholderMap() + return MultiModalKwargs({}), {} - if positions: - placeholder_map.append_items_from_seq_group( - positions, - # Dummy, since we don't care about intersecting items - [None] * len(placeholders), - placeholders, - ) - - placeholder_maps[modality] = placeholder_map - - return seq_mm_data, placeholder_maps - - mm_data = {**seq_mm_data} - placeholder_maps = defaultdict[str, MultiModalPlaceholderMap]( - MultiModalPlaceholderMap) + placeholder_maps = dict[str, MultiModalPlaceholderMap]() for modality, placeholders in seq_mm_placeholders.items(): - mm_items = mm_data.pop(modality) - if not isinstance(mm_items, list): - mm_items = [mm_items] + placeholder_map = MultiModalPlaceholderMap() if positions: - intersecting_items = placeholder_maps[modality] \ - .append_items_from_seq_group( - positions, - mm_items, - placeholders, - ) + placeholder_map.append_items_from_seq_group( + positions, + # Dummy, since we don't care about intersecting items + [None] * len(placeholders), + placeholders, + ) - if intersecting_items: - mm_data[modality] = intersecting_items + placeholder_maps[modality] = placeholder_map - return mm_data, placeholder_maps + return seq_mm_data, placeholder_maps def append_items_from_seq_group( self, @@ -445,8 +196,7 @@ def index_map(self) -> "IndexMap": f"The number of source ({len(src_indices)}) and destination " f"indices ({len(dest_indices)}) must be the same.") - return MultiModalPlaceholderMap.IndexMap(src=src_indices, - dest=dest_indices) + return self.IndexMap(src=src_indices, dest=dest_indices) class MediaIO(ABC, Generic[_T]): diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 0c5a84c6508a..939928bbf108 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -3,89 +3,11 @@ import base64 from io import BytesIO from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional import torch from PIL import Image -from vllm.inputs.registry import InputContext -from vllm.logger import init_logger -from vllm.transformers_utils.processor import cached_get_image_processor -from vllm.utils import is_list_of - -from .base import MediaIO, MultiModalPlugin -from .inputs import ImageItem, ModalityData, MultiModalKwargs - -if TYPE_CHECKING: - from vllm.config import ModelConfig - -logger = init_logger(__name__) - - -class ImagePlugin(MultiModalPlugin): - """Plugin for image data.""" - - def get_data_key(self) -> str: - return "image" - - def _get_hf_image_processor( - self, - model_config: "ModelConfig", - mm_processor_kwargs: Optional[dict[str, Any]] = None, - ): - if mm_processor_kwargs is None: - mm_processor_kwargs = {} - return 
cached_get_image_processor( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - **mm_processor_kwargs) - - def _default_input_mapper( - self, - ctx: InputContext, - data: ModalityData[ImageItem], - **mm_processor_kwargs, - ) -> MultiModalKwargs: - model_config = ctx.model_config - - # PIL image - if isinstance(data, Image.Image) or is_list_of(data, Image.Image): - image_processor = self._get_hf_image_processor( - model_config, - mm_processor_kwargs, - ) - - if image_processor is None: - raise RuntimeError("No HuggingFace processor is available " - "to process the image object") - try: - # NOTE: It may make sense to forward the mm_processor_kwargs - # here too. For now, to keep it simple, we only allow it be - # used for the initialization call though, just in case the - # signatures of the preprocessor initializer don't match - # preprocess() - batch_data = image_processor \ - .preprocess(data, return_tensors="pt") \ - .data - except Exception: - logger.error( - "Failed to process image (%s) with the default mapper. " - "This is most likely an edge-case with this model's image " - "processor in transformers (type: %s), and not vLLM.", - data, - type(image_processor).__name__) - raise - - return MultiModalKwargs(batch_data) - - # Image embedding - elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor): - return MultiModalKwargs({"image_embeds": data}) - - raise TypeError(f"Invalid image type: {type(data)}") - - def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: - return 3000 +from .base import MediaIO def rescale_image_size(image: Image.Image, diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 53729799b629..6855808e8e44 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -320,7 +320,8 @@ class MultiModalFlatField(BaseMultiModalField): :func:`MultiModalFieldConfig.flat` :func:`MultiModalFieldConfig.flat_from_sizes` """ - slices: Sequence[slice] + slices: Union[Sequence[slice], Sequence[Sequence[slice]]] + dim: int = 0 def build_elems( self, @@ -329,7 +330,10 @@ def build_elems( data: NestedTensors, ) -> Sequence[MultiModalFieldElem]: field_factory = self._field_factory(modality=modality, key=key) - return [field_factory(data[s]) for s in self.slices] + if not is_list_of(self.slices, slice, check="all"): + assert isinstance(data, torch.Tensor), \ + "torch.Tensor is required for multiple slices" + return [field_factory(data[cast(slice, s)]) for s in self.slices] def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): @@ -338,10 +342,16 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: # - produce exactly same result as `torch.concat(batch)` # - will achieve zero-copy if the tensor is contiguous return batch[0].contiguous() - first_shape = batch[0].shape - if all(elem.shape[1:] == first_shape[1:] for elem in batch): - return torch.concat(batch) + def _expect_same_shape(tensor: torch.Tensor): + return tensor.shape[:self.dim] + tensor.shape[self.dim + 1:] + + first_shape = _expect_same_shape(batch[0]) + + if all(_expect_same_shape(elem) == first_shape for elem in batch): + return torch.concat(batch, dim=self.dim) + + assert self.dim == 0, "dim == 0 is required for nested list" return [e for elem in batch for e in elem] @@ -398,7 +408,9 @@ def batched(modality: str): ) @staticmethod - def flat(modality: str, slices: Sequence[slice]): + def flat(modality: str, + slices: Union[Sequence[slice], 
Sequence[Sequence[slice]]], + dim: int = 0): """ Defines a field where an element in the batch is obtained by slicing along the first dimension of the underlying data. @@ -406,8 +418,10 @@ def flat(modality: str, slices: Sequence[slice]): Args: modality: The modality of the multi-modal item that uses this keyword argument. - slices: For each multi-modal item, a slice that is used to extract - the data corresponding to it. + slices: For each multi-modal item, a slice (dim=0) or a tuple of + slices (dim>0) that is used to extract the data corresponding + to it. + dim: The dimension to extract data, default to 0. Example: @@ -423,14 +437,33 @@ def flat(modality: str, slices: Sequence[slice]): Element 1: [AAA] Element 2: [BBBB] Element 3: [CC] + + .. code-block:: + + Given: + slices: [ + (slice(None), slice(0, 3)), + (slice(None), slice(3, 7)), + (slice(None), slice(7, 9))] + dim: 1 + + Input: + Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]] + + Output: + Element 1: [[A],[A],[A]] + Element 2: [[B],[B],[B],[B]] + Element 3: [[C],[C]] """ return MultiModalFieldConfig( - field=MultiModalFlatField(slices=slices), + field=MultiModalFlatField(slices=slices, dim=dim), modality=modality, ) @staticmethod - def flat_from_sizes(modality: str, size_per_item: torch.Tensor): + def flat_from_sizes(modality: str, + size_per_item: torch.Tensor, + dim: int = 0): """ Defines a field where an element in the batch is obtained by slicing along the first dimension of the underlying data. @@ -440,6 +473,7 @@ def flat_from_sizes(modality: str, size_per_item: torch.Tensor): keyword argument. slices: For each multi-modal item, the size of the slice that is used to extract the data corresponding to it. + dim: The dimension to slice, default to 0. Example: @@ -455,6 +489,21 @@ def flat_from_sizes(modality: str, size_per_item: torch.Tensor): Element 1: [AAA] Element 2: [BBBB] Element 3: [CC] + + + .. 
code-block:: + + Given: + slices: [3, 4, 2] + dim: 1 + + Input: + Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]] + + Output: + Element 1: [[A],[A],[A]] + Element 2: [[B],[B],[B],[B]] + Element 3: [[C],[C]] See also: :func:`MultiModalFieldConfig.flat` @@ -465,12 +514,11 @@ def flat_from_sizes(modality: str, size_per_item: torch.Tensor): f"but found shape: {size_per_item.shape}") slice_idxs = [0, *accumulate(size_per_item)] - slices = [ - slice(slice_idxs[i], slice_idxs[i + 1]) - for i in range(len(size_per_item)) - ] + slices = [(slice(None, None, None), ) * dim + + (slice(slice_idxs[i], slice_idxs[i + 1]), ) + for i in range(len(size_per_item))] - return MultiModalFieldConfig.flat(modality, slices) + return MultiModalFieldConfig.flat(modality, slices, dim=dim) @staticmethod def shared(modality: str, batch_size: int): diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index fc5a294564e3..9707b9cfcf8b 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -3,8 +3,8 @@ from abc import ABC, abstractmethod from collections import UserDict from collections.abc import Callable, Iterator, Mapping, Sequence -from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar, - Union) +from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional, + TypeVar, Union) import numpy as np import torch @@ -14,7 +14,7 @@ from vllm.utils import is_list_of -from .audio import resample_audio +from .audio import AudioResampler from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargs, VideoItem) @@ -308,10 +308,18 @@ class MultiModalDataParser: items to the model's expected sampling rate. """ - def __init__(self, *, target_sr: Optional[float] = None) -> None: + def __init__( + self, + *, + target_sr: Optional[float] = None, + audio_resample_method: Literal["librosa", "scipy"] = "librosa", + ) -> None: super().__init__() - self.target_sr = target_sr + self.audio_resampler = AudioResampler( + target_sr=target_sr, + method=audio_resample_method, + ) def _is_embeddings( self, data: object @@ -374,15 +382,8 @@ def _parse_audio_data( if orig_sr is None: new_audio = audio else: - target_sr = self.target_sr - if target_sr is None: - raise RuntimeError( - "Audio resampling is not supported when " - "`target_sr` is not provided") - - new_audio = resample_audio(audio, - orig_sr=orig_sr, - target_sr=target_sr) + new_audio = self.audio_resampler.resample(audio, + orig_sr=orig_sr) new_audios.append(new_audio) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 7f289426d349..87131122e6f2 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +import json import re import sys from abc import ABC, abstractmethod @@ -1117,8 +1118,9 @@ def _to_mm_items( if num_items > allowed_limit: raise ValueError( - f"You set or defaulted to {modality}={allowed_limit} " - f"in --limit-mm-per-prompt`, but passed {num_items} " + "You set or defaulted to " + f"'{json.dumps({modality: allowed_limit})}' in " + f"`--limit-mm-per-prompt`, but passed {num_items} " f"{modality} items in the same prompt.") return mm_items @@ -1567,56 +1569,35 @@ def _validate_mm_placeholders( "model (usually arising from an inconsistency between " "`_call_hf_processor` and `_get_prompt_updates`).") - def apply( + def _hash_mm_items( self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, + 
mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - return_mm_hashes: bool = False, - ) -> MultiModalInputs: - """ - Process multi-modal inputs to be used in vLLM. + ) -> dict[str, list[str]]: + """Create MM hashes to be returned (only used in V1).""" - The main steps are: - - 1. Apply HF Processor on prompt text and multi-modal data together, - outputting token IDs and processed tensors. - 2. Find and update sequences in the token IDs with placeholder tokens. - The number of placeholder tokens equals the feature size of the - multi-modal data outputted by the multi-modal encoder. - 3. Extract information about the placeholder tokens from the - processed token IDs. - """ - mm_items = self._to_mm_items(mm_data) - - # Create MM hashes to be returned (only used in V1) # TODO: Use these hash keys for caching operations in apply_hf_processor # instead of rehashing. + model_id = self.info.model_id - if return_mm_hashes: - model_id = self.info.model_id - mm_hashes = { - modality: [ - MultiModalHasher.hash_kwargs(model_id=model_id, - **{modality: item}, - **hf_processor_mm_kwargs) - for item in items - ] - for modality, items in mm_items.items() - } - else: - mm_hashes = None - - ( - prompt_ids, - mm_kwargs, - is_update_applied, - ) = self._cached_apply_hf_processor( - prompt, - mm_items, - hf_processor_mm_kwargs, - ) + return { + modality: [ + MultiModalHasher.hash_kwargs(model_id=model_id, + **{modality: item}, + **hf_processor_mm_kwargs) + for item in items + ] + for modality, items in mm_items.items() + } + def _maybe_apply_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + prompt_ids: list[int], + mm_kwargs: MultiModalKwargs, + is_update_applied: bool, + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: unbound_prompt_updates = self._get_prompt_updates( mm_items, hf_processor_mm_kwargs, @@ -1650,6 +1631,51 @@ def apply( ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) + return prompt_ids, prompt, mm_placeholders + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + return_mm_hashes: bool = False, + ) -> MultiModalInputs: + """ + Process multi-modal inputs to be used in vLLM. + + The main steps are: + + 1. Apply HF Processor on prompt text and multi-modal data together, + outputting token IDs and processed tensors. + 2. Find and update sequences in the token IDs with placeholder tokens. + The number of placeholder tokens equals the feature size of the + multi-modal data outputted by the multi-modal encoder. + 3. Extract information about the placeholder tokens from the + processed token IDs. 
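# The flat_from_sizes/dim change in vllm/multimodal/inputs.py above builds one tuple of
# slices per multi-modal item so batched data can be split along a non-leading dimension.
# A self-contained sketch of that slice construction with illustrative sizes:
from itertools import accumulate

import torch

data = torch.arange(2 * 9).reshape(2, 9)        # e.g. (hidden_dim, num_tokens)
size_per_item = [3, 4, 2]
dim = 1

slice_idxs = [0, *accumulate(size_per_item)]
slices = [(slice(None, None, None), ) * dim +
          (slice(slice_idxs[i], slice_idxs[i + 1]), )
          for i in range(len(size_per_item))]

items = [data[s] for s in slices]
print([tuple(item.shape) for item in items])    # [(2, 3), (2, 4), (2, 2)]
assert torch.equal(torch.cat(items, dim=dim), data)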
+ """ + mm_items = self._to_mm_items(mm_data) + + mm_hashes = (self._hash_mm_items(mm_items, hf_processor_mm_kwargs) + if return_mm_hashes else None) + + ( + prompt_ids, + mm_kwargs, + is_update_applied, + ) = self._cached_apply_hf_processor( + prompt, + mm_items, + hf_processor_mm_kwargs, + ) + + prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + prompt_ids=prompt_ids, + mm_kwargs=mm_kwargs, + is_update_applied=is_update_applied, + ) + mm_placeholder_ranges = { modality: [item.to_range() for item in placeholders] for modality, placeholders in mm_placeholders.items() diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index def0595013b8..ec4f15681019 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,12 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 - -import functools -from collections import UserDict -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar +from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar import torch.nn as nn +from typing_extensions import deprecated from vllm.envs import VLLM_MM_INPUT_CACHE_GIB from vllm.inputs import InputProcessingContext @@ -15,15 +13,10 @@ cached_tokenizer_from_config) from vllm.utils import ClassRegistry -from .audio import AudioPlugin -from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc -from .image import ImagePlugin -from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, ProcessingCache) from .profiling import (BaseDummyInputsBuilder, DummyDecoderData, DummyEncoderData, MultiModalProfiler) -from .video import VideoPlugin if TYPE_CHECKING: from vllm.config import ModelConfig @@ -84,169 +77,23 @@ def build_processor( return self.processor(info, dummy_inputs_builder, cache=cache) -class _MultiModalLimits(UserDict["ModelConfig", dict[str, int]]): - """ - Wraps `_limits_by_model` for a more informative error message - when attempting to access a model that does not exist. - """ - - def __getitem__(self, key: "ModelConfig") -> dict[str, int]: - try: - return super().__getitem__(key) - except KeyError as exc: - msg = (f"Cannot find `mm_limits` for model={key.model}. Did you " - "forget to call `init_mm_limits_per_prompt`?") - raise KeyError(msg) from exc - - class MultiModalRegistry: """ A registry that dispatches data processing according to the model. """ - DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin()) - - def __init__( - self, - *, - plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: - self._plugins = {p.get_data_key(): p for p in plugins} - + def __init__(self) -> None: self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]() - # This is used for non-multimodal models - self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} - - self._limits_by_model = _MultiModalLimits() - self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB) - def register_plugin(self, plugin: MultiModalPlugin) -> None: - """ - Register a multi-modal plugin so it can be recognized by vLLM. 
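# _hash_mm_items above produces one hash per multi-modal item (per modality) so the V1
# engine can use them as cache keys. A conceptual, self-contained sketch of such
# content-addressed keys; the real MultiModalHasher has its own serialization rules for
# images/tensors, and the model id and processor kwarg below are purely illustrative:
import hashlib
import pickle

def hash_kwargs_sketch(**kwargs) -> str:
    payload = pickle.dumps(sorted(kwargs.items()), protocol=pickle.HIGHEST_PROTOCOL)
    return hashlib.sha256(payload).hexdigest()

mm_hashes = {
    "image": [
        hash_kwargs_sketch(model_id="org/model", image=image_bytes, num_crops=4)
        for image_bytes in (b"<img-0>", b"<img-1>")
    ],
}
print(mm_hashes["image"][0][:16])  # stable for identical inputs, differs per item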
- """ - data_type_key = plugin.get_data_key() - - if data_type_key in self._plugins: - logger.warning( - "A plugin is already registered for data type %s, " - "and will be overwritten by the new plugin %s.", data_type_key, - plugin) - - self._plugins[data_type_key] = plugin - - def _get_plugin(self, data_type_key: str): - plugin = self._plugins.get(data_type_key) - if plugin is not None: - return plugin - - msg = f"Unknown multi-modal data type: {data_type_key}" - raise NotImplementedError(msg) - - def register_input_mapper( - self, - data_type_key: str, - mapper: Optional[MultiModalInputMapper] = None, - ): - """ - Register an input mapper for a specific modality to a model class. - - See :meth:`MultiModalPlugin.register_input_mapper` for more details. - """ - return self._get_plugin(data_type_key).register_input_mapper(mapper) - - def register_image_input_mapper( - self, - mapper: Optional[MultiModalInputMapper] = None, - ): - """ - Register an input mapper for image data to a model class. - - See :meth:`MultiModalPlugin.register_input_mapper` for more details. - """ - return self.register_input_mapper("image", mapper) - - def map_input( - self, - model_config: "ModelConfig", - data: MultiModalDataDict, - mm_processor_kwargs: Optional[dict[str, Any]] = None, - ) -> MultiModalKwargs: - """ - Apply an input mapper to the data passed to the model. - - The data belonging to each modality is passed to the corresponding - plugin which in turn converts the data into into keyword arguments - via the input mapper registered for that model. - - See :meth:`MultiModalPlugin.map_input` for more details. - - Note: - This should be called after :meth:`init_mm_limits_per_prompt`. - """ - merged_dict = dict[str, NestedTensors]() - - for data_key, data_value in data.items(): - plugin = self._get_plugin(data_key) - - num_items = len(data_value) if isinstance(data_value, list) else 1 - max_items = self._limits_by_model[model_config][data_key] - if num_items > max_items: - raise ValueError( - f"You set {data_key}={max_items} (or defaulted to 1) in " - f"`--limit-mm-per-prompt`, but found {num_items} items " - "in the same prompt.") - - input_dict = plugin.map_input(model_config, data_value, - mm_processor_kwargs) - for input_key, input_tensor in input_dict.items(): - if input_key in merged_dict: - raise ValueError(f"The input mappers (keys={set(data)}) " - f"resulted in a conflicting keyword " - f"argument to `forward()`: {input_key}") - - merged_dict[input_key] = input_tensor - - return MultiModalKwargs(merged_dict) - + @deprecated("Legacy input processor/mapper pipeline has been removed. " + "Please update your model runner to use " + "`seq_group_metadata.multi_modal_data` directly without " + "further processing.") def create_input_mapper(self, model_config: "ModelConfig"): - """ - Create an input mapper (see :meth:`map_input`) for a specific model. - """ - # NOTE - we currently make the assumption that if a model has multiple - # supported modalities, they take the same kwargs. For the default, - # this could be an issue in the future if it falls back to two HF - # resources and we can't inspect the signature easily since it's - # getting initialized through the autoclass. 
- # - # If this is a problem in the future, we should revisit it, but since - # it potentially introduces a lot of complexity for a currently - # uncommon case, we do not for simplicity of both use & implementation - return functools.partial(self.map_input, model_config) - - def register_max_multimodal_tokens( - self, - data_type_key: str, - max_mm_tokens: Optional[MultiModalTokensCalc] = None, - ): - """ - Register the maximum number of tokens, corresponding to a single - instance of multimodal data belonging to a specific modality, that are - passed to the language model for a model class. - """ - return self._get_plugin(data_type_key) \ - .register_max_multimodal_tokens(max_mm_tokens) - - def register_max_image_tokens( - self, - max_mm_tokens: Optional[MultiModalTokensCalc] = None, - ): - """ - Register the maximum number of image tokens, corresponding to a single - image, that are passed to the language model for a model class. - """ - return self.register_max_multimodal_tokens("image", max_mm_tokens) + return lambda data, mm_processor_kwargs: data def get_max_tokens_per_item_by_modality( self, @@ -256,25 +103,22 @@ def get_max_tokens_per_item_by_modality( Get the maximum number of tokens per data item from each modality based on underlying model configuration. """ - if self.has_processor(model_config): - processor = self.create_processor(model_config, disable_cache=True) - profiler = MultiModalProfiler(processor) - - seq_len = model_config.max_model_len - mm_limits = self.get_mm_limits_per_prompt(model_config) - - return profiler.get_mm_max_tokens( - seq_len, - { - modality: 1 - for modality, limit in mm_limits.items() if limit > 0 - }, - ) + if not model_config.is_multimodal_model: + return {} - return { - key: plugin.get_max_multimodal_tokens(model_config) - for key, plugin in self._plugins.items() - } + processor = self.create_processor(model_config, disable_cache=True) + profiler = MultiModalProfiler(processor) + + seq_len = model_config.max_model_len + mm_limits = self.get_mm_limits_per_prompt(model_config) + + return profiler.get_mm_max_tokens( + seq_len, + { + modality: 1 + for modality, limit in mm_limits.items() if limit > 0 + }, + ) def get_max_tokens_per_item_by_nonzero_modality( self, @@ -307,9 +151,6 @@ def get_max_tokens_by_modality( for profiling the memory usage of a model. See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. - - Note: - This should be called after :meth:`init_mm_limits_per_prompt`. """ mm_limits = self.get_mm_limits_per_prompt(model_config) @@ -325,47 +166,18 @@ def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: for profiling the memory usage of a model. See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. - - Note: - This should be called after :meth:`init_mm_limits_per_prompt`. """ return sum(self.get_max_tokens_by_modality(model_config).values()) + @deprecated("Legacy input processor/mapper pipeline has been removed. " + "Please update your model runner to use " + "`seq_group_metadata.multi_modal_data` directly without " + "further processing.") def init_mm_limits_per_prompt( self, model_config: "ModelConfig", ) -> None: - """ - Initialize the maximum number of multi-modal input instances for each - modality that are allowed per prompt for a model class. 
- """ - if model_config in self._limits_by_model: - logger.warning( - "`mm_limits` has already been set for model=%s, and will " - "be overwritten by the new values.", model_config.model) - - multimodal_config = model_config.multimodal_config - if multimodal_config is None: - limits_per_plugin = self._disabled_limits_per_plugin - else: - config_limits_per_plugin = multimodal_config.limit_per_prompt - - extra_keys = config_limits_per_plugin.keys() - self._plugins.keys() - if extra_keys: - logger.warning( - "Detected extra keys in `--limit-mm-per-prompt` which " - "are not registered as multi-modal plugins: %s. " - "They will be ignored.", extra_keys) - - # NOTE: Currently the default is set to 1 for each plugin - # TODO: Automatically determine the limits based on budget - # once more models support multi-image inputs - limits_per_plugin = { - key: multimodal_config.get_limit_per_prompt(key) - for key in self._plugins - } - - self._limits_by_model[model_config] = limits_per_plugin + pass def get_mm_limits_per_prompt( self, @@ -374,16 +186,13 @@ def get_mm_limits_per_prompt( """ Get the maximum number of multi-modal input instances for each modality that are allowed per prompt for a model class. - - Note: - This should be called after :meth:`init_mm_limits_per_prompt`. """ - if self.has_processor(model_config): - processor = self.create_processor(model_config, disable_cache=True) - profiler = MultiModalProfiler(processor) - return profiler.get_mm_limits() + if not model_config.is_multimodal_model: + return {} - return self._limits_by_model[model_config] + processor = self.create_processor(model_config, disable_cache=True) + profiler = MultiModalProfiler(processor) + return profiler.get_mm_limits() def register_processor( self, @@ -427,14 +236,12 @@ def _get_model_cls(self, model_config: "ModelConfig"): model_cls, _ = get_model_architecture(model_config) return model_cls + @deprecated("Legacy input processor/mapper pipeline has been removed. " + "Please update your model runner to use " + "`seq_group_metadata.multi_modal_data` directly without " + "further processing.") def has_processor(self, model_config: "ModelConfig") -> bool: - """ - Test whether a multi-modal processor is defined for a specific model. 
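# The legacy registry hooks above are kept only as deprecated no-ops. The
# typing_extensions.deprecated decorator used in this hunk flags call sites for type
# checkers and emits a DeprecationWarning at runtime; a minimal sketch of the behaviour:
import warnings

from typing_extensions import deprecated

@deprecated("Legacy input mapper pipeline has been removed; this is now a passthrough.")
def create_input_mapper_sketch(model_config=None):
    # Mirrors the deprecated method above: returns the data unchanged.
    return lambda data, mm_processor_kwargs: data

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    mapper = create_input_mapper_sketch()
    print(mapper({"image": "raw"}, {}))   # {'image': 'raw'}
print(caught[0].category.__name__)        # DeprecationWarning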
- - See also: - :ref:`mm-processing` - """ - return self._get_model_cls(model_config) in self._processor_factories + return True def create_processor( self, @@ -449,6 +256,9 @@ def create_processor( See also: :ref:`mm-processing` """ + if not model_config.is_multimodal_model: + raise ValueError(f"{model_config.model} is not a multimodal model") + if tokenizer is None: tokenizer = cached_tokenizer_from_config(model_config) if disable_cache is None: diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index f7c3f1052954..6d875a1c651e 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -4,80 +4,13 @@ from functools import partial from io import BytesIO from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional import numpy as np import numpy.typing as npt from PIL import Image -from vllm.inputs.registry import InputContext -from vllm.logger import init_logger -from vllm.transformers_utils.processor import cached_get_video_processor -from vllm.utils import is_list_of - -from .base import MediaIO, ModalityData -from .image import ImageMediaIO, ImagePlugin -from .inputs import MultiModalKwargs, VideoItem - -if TYPE_CHECKING: - from vllm.config import ModelConfig - -logger = init_logger(__name__) - - -class VideoPlugin(ImagePlugin): - """Plugin for video data.""" - - def get_data_key(self) -> str: - return "video" - - def _get_hf_video_processor( - self, - model_config: "ModelConfig", - mm_processor_kwargs: Optional[dict[str, Any]] = None, - ): - if mm_processor_kwargs is None: - mm_processor_kwargs = {} - return cached_get_video_processor( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - **mm_processor_kwargs) - - def _default_input_mapper( - self, - ctx: InputContext, - data: ModalityData[VideoItem], - **mm_processor_kwargs, - ) -> MultiModalKwargs: - model_config = ctx.model_config - - if isinstance(data, list) and len(data) == 1: - data = data[0] # type: ignore - - if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray): - video_processor = self._get_hf_video_processor( - model_config, - mm_processor_kwargs, - ) - if video_processor is None: - raise RuntimeError("No HuggingFace processor is available " - "to process the video object") - try: - # NOTE: Similar to image; it may be a good idea to filter and - # pass mm_processor_kwargs here too, but for now we don't to - # avoid extra complexity if the initializer and preprocess - # signatures of the processor don't align - batch_data = video_processor(data, return_tensors="pt").data - except Exception: - logger.error("Failed to process video (%s)", data) - raise - - return MultiModalKwargs(batch_data) - - raise TypeError(f"Invalid video type: {type(data)}") - - def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: - return 4096 +from .base import MediaIO +from .image import ImageMediaIO def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray: diff --git a/vllm/outputs.py b/vllm/outputs.py index 014e8d5d8823..65a6ed01451d 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -134,26 +134,32 @@ def __init__( self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens - def add(self, next_output: "RequestOutput") -> None: + def add(self, next_output: "RequestOutput", aggregate: bool) -> None: """Merge subsequent RequestOutput into this one""" self.finished |= next_output.finished for next_completion in next_output.outputs: - for completion in self.outputs: + for i, completion in 
enumerate(self.outputs): if completion.index == next_completion.index: - # Merge outputs with same index - completion.text += next_completion.text - if not isinstance(completion.token_ids, MutableSequence): - completion.token_ids = list(completion.token_ids) - completion.token_ids.extend(next_completion.token_ids) - if next_completion.logprobs: - assert completion.logprobs is not None - completion.logprobs.extend(next_completion.logprobs) - completion.cumulative_logprob = ( - next_completion.cumulative_logprob) - completion.finish_reason = next_completion.finish_reason - completion.stop_reason = next_completion.stop_reason + if aggregate: + # Merge outputs with same index + completion.text += next_completion.text + if not isinstance(completion.token_ids, + MutableSequence): + completion.token_ids = list(completion.token_ids) + completion.token_ids.extend(next_completion.token_ids) + if next_completion.logprobs: + assert completion.logprobs is not None + completion.logprobs.extend( + next_completion.logprobs) + completion.cumulative_logprob = ( + next_completion.cumulative_logprob) + completion.finish_reason = next_completion.finish_reason + completion.stop_reason = next_completion.stop_reason + else: + # Replace the output with the new one + self.outputs[i] = next_completion break else: self.outputs.append(next_completion) @@ -173,6 +179,13 @@ def from_seq_group( group.finish_seq(seq_group) if assembled_seq_group is None: return None + + # clear finished seq in seq_id_to_seq_group + if len(group.to_be_finished) == 0: + for sub_request_id in list(group.seq_id_to_index.keys()): + if sub_request_id in seq_id_to_seq_group: + del seq_id_to_seq_group[sub_request_id] + return cls.from_seq_group(assembled_seq_group, use_cache, seq_id_to_seq_group) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 0576022be448..f82af426b5a8 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -21,9 +21,6 @@ if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig -else: - ModelConfig = None - VllmConfig = None logger = init_logger(__name__) @@ -109,7 +106,7 @@ def log_warnings(cls): pass @classmethod - def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config compilation_config = vllm_config.compilation_config @@ -213,6 +210,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, return ("vllm.attention.backends." "flashmla.FlashMLABackend") if use_v1: + if selected_backend == _Backend.FLASHINFER: + logger.info_once("Using FlashInfer backend on V1 engine.") + return "vllm.v1.attention.backends.flashinfer.FlashInferBackend" if selected_backend == _Backend.TRITON_ATTN_VLLM_V1: logger.info_once("Using Triton backend on V1 engine.") return ("vllm.v1.attention.backends." 
@@ -305,7 +305,7 @@ def supports_fp8(cls) -> bool: return cls.has_device_capability(89) @classmethod - def supports_v1(cls, model_config: ModelConfig) -> bool: + def supports_v1(cls, model_config: "ModelConfig") -> bool: return True @classmethod diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 31a7ffbd910d..c5555aba1a3e 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -8,7 +8,7 @@ import numpy as np import torch -from vllm.inputs import PromptType +from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger if TYPE_CHECKING: @@ -39,6 +39,7 @@ class _Backend(enum.Enum): TRITON_ATTN_VLLM_V1 = enum.auto() XFORMERS = enum.auto() ROCM_FLASH = enum.auto() + ROCM_AITER_MLA = enum.auto() TORCH_SDPA = enum.auto() FLASHINFER = enum.auto() TRITON_MLA = enum.auto() # Supported by V1 @@ -148,6 +149,9 @@ def is_cuda_alike(self) -> bool: """Stateless version of :func:`torch.cuda.is_available`.""" return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) + def is_sleep_mode_available(self) -> bool: + return self._enum == PlatformEnum.CUDA + @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], @@ -397,9 +401,26 @@ def validate_request( cls, prompt: PromptType, params: Union[SamplingParams, PoolingParams], + processed_inputs: ProcessorInputs, ) -> None: """Raises if this request is unsupported on this platform""" + def __getattr__(self, key: str): + device = getattr(torch, self.device_name, None) + if device is not None and hasattr(device, key): + return getattr(device, key) + else: + logger.warning("Current platform %s does not have '%s'" \ + " attribute.", self.device_name, key) + return None + + @classmethod + def get_cu_count(cls, device_id: int = 0) -> int: + """ + Returns the total number of compute units (CU) on single GPU. 
+ """ + raise NotImplementedError + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index c1f426e5b880..e37a3a578cf2 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -50,7 +50,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if cache_config: # neuron needs block_size = max_model_len vllm_config.cache_config.block_size = \ - vllm_config.model_config.max_model_len + vllm_config.model_config.max_model_len # type: ignore @classmethod def is_pin_memory_available(cls) -> bool: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index d18b7c26f7ec..f6be3b0e814a 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -13,9 +13,6 @@ if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig -else: - ModelConfig = None - VllmConfig = None logger = init_logger(__name__) @@ -98,24 +95,29 @@ def device_id_to_physical_device_id(device_id: int) -> int: return device_id +def on_mi250_mi300() -> bool: + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName + return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"]) + + @cache def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, block_size: int, gqa_ratio: int, max_seq_len: int, sliding_window: int) -> bool: - GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName - ON_NAVI = "gfx1" in GPU_ARCH - ON_MI250_MI300 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"]) - - # rocm custom page attention not support on navi (gfx1*) - return (ON_MI250_MI300 and not ON_NAVI - and (sliding_window == 0 or sliding_window == (-1, -1)) + # rocm custom page attention not support on gfx1* + # custom paged attn always supported on V0. On V1, requires sliding window + # disabled due to observed numerical discrepancy. 
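# Back in vllm/platforms/interface.py, the new Platform.__getattr__ forwards unknown
# attributes to the matching torch device module (torch.cuda, torch.xpu, ...). A minimal
# standalone sketch of that delegation pattern (class name and logging are illustrative):
import torch

class DevicePlatformSketch:
    device_name = "cuda"

    def __getattr__(self, key: str):
        # Only reached when normal attribute lookup fails.
        device = getattr(torch, self.device_name, None)
        if device is not None and hasattr(device, key):
            return getattr(device, key)
        print(f"Current platform {self.device_name} does not have '{key}' attribute.")
        return None

platform = DevicePlatformSketch()
print(platform.is_available())    # delegates to torch.cuda.is_available()
print(platform.no_such_attr)      # falls through and returns None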
+ return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0 + or sliding_window == (-1, -1)) and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768 - and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN + and envs.VLLM_ROCM_USE_AITER)) class RocmPlatform(Platform): @@ -128,8 +130,8 @@ class RocmPlatform(Platform): device_control_env_var: str = "CUDA_VISIBLE_DEVICES" supported_quantization: list[str] = [ - "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", - "fbgemm_fp8", "gguf", "quark", "ptpc_fp8" + "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", + "quark", "ptpc_fp8" ] @classmethod @@ -137,8 +139,36 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla) -> str: if use_mla: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" + from vllm.attention.backends.rocm_aiter_mla import ( + is_aiter_mla_enabled) + + if selected_backend is None: + selected_backend = (_Backend.ROCM_AITER_MLA if + is_aiter_mla_enabled() or block_size == 1 + else _Backend.TRITON_MLA) + + if selected_backend == _Backend.TRITON_MLA: + if block_size != 1: + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 + else: + raise ValueError( + f" The selected backend, {selected_backend.name}," + f"does not support block size {block_size}.") + elif selected_backend == _Backend.ROCM_AITER_MLA: + if block_size == 1: + logger.info("Using AITER MLA backend.") + return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 + else: + raise ValueError( + f" The selected backend, {selected_backend.name}," + f"does not support block size {block_size}." 
+ "(currently only supports block size 1)") + else: + raise ValueError( + f" The selected backend, {selected_backend.name}," + f"is not MLA type while requested for MLA backend.") + selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if envs.VLLM_USE_V1: @@ -210,7 +240,7 @@ def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: return True @classmethod - def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config = vllm_config.cache_config if cache_config and cache_config.block_size is None: cache_config.block_size = 16 @@ -299,7 +329,7 @@ def fp8_dtype(cls) -> torch.dtype: return torch.float8_e4m3fn @classmethod - def supports_v1(cls, model_config: ModelConfig) -> bool: + def supports_v1(cls, model_config: "ModelConfig") -> bool: # V1 support on AMD gpus is experimental return True @@ -309,3 +339,8 @@ def use_custom_allreduce(cls) -> bool: gcn_arch = torch.cuda.get_device_properties(0).gcnArchName supported_archs = ['gfx94'] return any(gfx in gcn_arch for gfx in supported_archs) + + @classmethod + def get_cu_count(cls, device_id: int = 0) -> int: + return torch.cuda.get_device_properties( + device_id).multi_processor_count diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index d8807a72ba2f..d5923557a211 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -5,7 +5,7 @@ import torch import vllm.envs as envs -from vllm.inputs import PromptType +from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger from vllm.sampling_params import SamplingParams, SamplingType @@ -30,9 +30,7 @@ class TpuPlatform(Platform): ray_device_key: str = "TPU" device_control_env_var: str = "TPU_VISIBLE_CHIPS" - supported_quantization: list[str] = [ - "tpu_int8", "compressed-tensors", "compressed_tensors" - ] + supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"] additional_env_vars: list[str] = [ "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS" @@ -97,6 +95,20 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "Using bfloat16 instead.", vllm_config.model_config.dtype) vllm_config.model_config.dtype = torch.bfloat16 + if envs.VLLM_USE_V1: + from vllm.v1.attention.backends.pallas import ( + PallasAttentionBackend) + min_page_size = PallasAttentionBackend.get_min_page_size( + vllm_config) + if min_page_size > vllm_config.cache_config.block_size: + logger.warning( + "Increase the page size from %s to %s to make sure there's" + "no SMEM OOM", + vllm_config.cache_config.block_size, + min_page_size, + ) + vllm_config.cache_config.block_size = min_page_size + parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config if parallel_config.worker_cls == "auto": @@ -150,12 +162,13 @@ def validate_request( cls, prompt: PromptType, params: Union[SamplingParams, PoolingParams], + processed_inputs: ProcessorInputs, ) -> None: """Raises if this request is unsupported on this platform""" if isinstance(params, SamplingParams): - if params.guided_decoding is not None: + if params.guided_decoding is not None and not envs.VLLM_USE_V1: raise ValueError("Structured output is not supported on " - f"{cls.device_name}.") + f"{cls.device_name} V0.") if params.sampling_type == SamplingType.RANDOM_SEED: raise ValueError( "Torch XLA does not support per-request seed.") diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index f71daf0c1955..9a3b254f9b68 
100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -35,7 +35,16 @@ def verify(self, model_config: "ModelConfig") -> None: f'Model "{model_config.served_model_name}" does not ' f'support matryoshka representation, ' f'changing output dimensions will lead to poor results.') - if self.dimensions < 1: + + mds = model_config.matryoshka_dimensions + if mds is not None: + if self.dimensions not in mds: + raise ValueError( + f'Model "{model_config.served_model_name}" ' + f'only supports {str(mds)} matryoshka dimensions, ' + f'using other output dimensions will ' + f'lead to poor results.') + elif self.dimensions < 1: raise ValueError("Dimensions must be greater than 0") def __repr__(self) -> str: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 68ed99664947..c430b74a9db9 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -38,6 +38,7 @@ class GuidedDecodingParams: """These are other options that can be set""" backend: Optional[str] = None whitespace_pattern: Optional[str] = None + structural_tag: Optional[str] = None @staticmethod def from_optional( @@ -48,9 +49,10 @@ def from_optional( json_object: Optional[bool] = None, backend: Optional[str] = None, whitespace_pattern: Optional[str] = None, + structural_tag: Optional[str] = None, ) -> Optional["GuidedDecodingParams"]: - if all(arg is None - for arg in (json, regex, choice, grammar, json_object)): + if all(arg is None for arg in (json, regex, choice, grammar, + json_object, structural_tag)): return None # Extract json schemas from pydantic models if isinstance(json, (BaseModel, type(BaseModel))): @@ -63,6 +65,7 @@ def from_optional( json_object=json_object, backend=backend, whitespace_pattern=whitespace_pattern, + structural_tag=structural_tag, ) @property @@ -79,6 +82,17 @@ def backend_options(self) -> list[str]: return [] return self.backend.split(":")[1].split(",") + def add_option(self, opt_name: str) -> None: + """Adds an option to the backend options.""" + if not self.backend: + self.backend = f":{opt_name}" + elif ":" not in self.backend: + self.backend += f":{opt_name}" + else: + options = set(self.backend_options()) + options.add(opt_name) + self.backend = f"{self.backend_name}:{','.join(sorted(options))}" + def no_fallback(self) -> bool: """Returns True if the "no-fallback" option is supplied for the guided decoding backend""" @@ -423,6 +437,10 @@ def _verify_args(self) -> None: and self.truncate_prompt_tokens < 1): raise ValueError(f"truncate_prompt_tokens must be >= 1, " f"got {self.truncate_prompt_tokens}") + assert isinstance(self.stop_token_ids, list) + if not all(isinstance(st_id, int) for st_id in self.stop_token_ids): + raise ValueError(f"stop_token_ids must contain only integers, " + f"got {self.stop_token_ids}.") assert isinstance(self.stop, list) if any(not stop_str for stop_str in self.stop): raise ValueError("stop cannot contain an empty string.") diff --git a/vllm/sequence.py b/vllm/sequence.py index 61867b025315..a97409523c94 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -14,9 +14,9 @@ import msgspec import torch -from vllm.inputs import SingletonInputs, SingletonInputsAdapter +from vllm.inputs import SingletonInputs from vllm.lora.request import LoRARequest -from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict +from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind,
SamplingParams @@ -419,7 +419,7 @@ def __init__( prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.seq_id = seq_id - self.inputs = SingletonInputsAdapter(inputs) + self.inputs = inputs self.block_size = block_size self.eos_token_id = eos_token_id self.lora_request = lora_request @@ -448,31 +448,29 @@ def n_blocks(self) -> int: @property def prompt(self) -> Optional[str]: - return self.inputs.prompt + return self.inputs.get("prompt") @property def prompt_token_ids(self) -> list[int]: - return self.inputs.prompt_token_ids - - @property - def prompt_embeds(self) -> Optional[torch.Tensor]: - return self.inputs.prompt_embeds + return self.inputs["prompt_token_ids"] @property def token_type_ids(self) -> list[int]: - return self.inputs.token_type_ids + return self.inputs.get("token_type_ids", []) @property - def multi_modal_data(self) -> "MultiModalDataDict": - return self.inputs.multi_modal_data + def multi_modal_data(self) -> MultiModalKwargs: + if self.inputs["type"] == "multimodal": + return self.inputs["mm_kwargs"] + + return MultiModalKwargs({}) @property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: - return self.inputs.multi_modal_placeholders + if self.inputs["type"] == "multimodal": + return self.inputs["mm_placeholders"] - @property - def mm_processor_kwargs(self) -> dict[str, Any]: - return self.inputs.mm_processor_kwargs + return {} @property def lora_int_id(self) -> int: @@ -723,12 +721,12 @@ def token_type_ids(self) -> Optional[list[int]]: return self.first_seq.token_type_ids @property - def multi_modal_data(self) -> MultiModalDataDict: + def multi_modal_data(self) -> MultiModalKwargs: if self.first_seq.multi_modal_data: return self.first_seq.multi_modal_data elif self.encoder_seq is not None: return self.encoder_seq.multi_modal_data - return {} + return MultiModalKwargs({}) @property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: @@ -738,14 +736,6 @@ def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: return self.encoder_seq.multi_modal_placeholders return {} - @property - def mm_processor_kwargs(self) -> dict[str, Any]: - if self.first_seq.multi_modal_data: - return self.first_seq.mm_processor_kwargs - elif self.encoder_seq is not None: - return self.encoder_seq.mm_processor_kwargs - return {} - @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -969,12 +959,9 @@ class SequenceGroupMetadata( computed_block_nums: Optional[list[int]] = None state: Optional[SequenceGroupState] = msgspec.field( default_factory=lambda: SequenceGroupState()) - # "MultiModalDataDict" types. We have to use Any due to msgspec - # doesn't allow to have union of 2 different dicts. token_type_ids: Optional[list[int]] = None - multi_modal_data: Optional[Any] = None + multi_modal_data: Optional[MultiModalKwargs] = None multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None - mm_processor_kwargs: Optional[dict[str, Any]] = None encoder_seq_data: Optional[SequenceData] = None cross_block_table: Optional[list[int]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 3ad9b4993327..24095ef2a567 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -295,7 +295,7 @@ def execute_model( if not self.is_driver_worker: return [] # Sample the next token. 
- output = self.model.sample( + output = self.model_runner.sampler( logits=logits, sampling_metadata=model_input.sampling_metadata, ) diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index bc0e0a121cd5..0bb8d602ec8f 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -8,6 +8,7 @@ from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) +from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available @@ -89,14 +90,14 @@ def init_tensors(self, self._rank = rank if isinstance(device_type, torch.device): device_type = device_type.type - if device_type == 'cuda': - self._copy_stream = torch.cuda.Stream() + stream = current_platform.Stream + if stream is not None: + self._copy_stream = stream() def maybe_collect_rejsample_metrics( self, k: int) -> Optional[SpecDecodeWorkerMetrics]: - # currently using cuda.Event, skip for any non_cuda_alike platform - from vllm.platforms import current_platform - if not current_platform.is_cuda_alike(): + # Skip for any platform that doesn't have device Event + if current_platform.Event is None: return None # If a copy was initiated in the previous call, collect and return. diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index d8d54918fa98..6473740ae512 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -50,11 +50,10 @@ def init_device(self) -> None: def set_include_gpu_probs_tensor(self) -> None: # Need include_gpu_probs_tensor for MultiStepWorker - self.model_runner.model.sampler.include_gpu_probs_tensor = True + self.model_runner.sampler.include_gpu_probs_tensor = True def set_should_modify_greedy_probs_inplace(self) -> None: - self.model_runner.model.sampler.should_modify_greedy_probs_inplace = ( - True) + self.model_runner.sampler.should_modify_greedy_probs_inplace = True @torch.inference_mode() def sampler_output( diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index a724beade129..4e79003de391 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -410,9 +410,9 @@ def _configure_model_sampler_for_spec_decode(self): NOTE(cade): This will require a special check if the proposer worker does not have a sampler (e.g. ngram speculation). """ - (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor + (self.scorer_worker.model_runner.sampler.include_gpu_probs_tensor ) = True - (self.scorer_worker.model_runner.model.sampler. + (self.scorer_worker.model_runner.sampler. 
should_modify_greedy_probs_inplace) = True self.proposer_worker.set_include_gpu_probs_tensor() self.proposer_worker.set_should_modify_greedy_probs_inplace() diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index fe0319c9b033..e062afd68208 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -33,10 +33,10 @@ EAGLEConfig, ExaoneConfig, H2OVLChatConfig, InternVLChatConfig, JAISConfig, - MedusaConfig, MllamaConfig, - MLPSpeculatorConfig, MPTConfig, - NemotronConfig, NVLM_D_Config, - Olmo2Config, RWConfig, + KimiVLConfig, MedusaConfig, + MllamaConfig, MLPSpeculatorConfig, + MPTConfig, NemotronConfig, + NVLM_D_Config, RWConfig, SkyworkR1VChatConfig, SolarConfig, Telechat2Config, UltravoxConfig) # yapf: enable @@ -62,6 +62,7 @@ "cohere2": Cohere2Config, "dbrx": DbrxConfig, "deepseek_vl_v2": DeepseekVLV2Config, + "kimi_vl": KimiVLConfig, "mpt": MPTConfig, "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) @@ -74,7 +75,6 @@ "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, "NVLM_D": NVLM_D_Config, - "olmo2": Olmo2Config, "solar": SolarConfig, "skywork_chat": SkyworkR1VChatConfig, "telechat": Telechat2Config, @@ -220,8 +220,7 @@ def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None: logger.warning("Replacing legacy rope_type 'mrope' with 'default'") -def uses_mrope(config: PretrainedConfig) -> bool: - """Detect if the model with this config uses M-ROPE.""" +def _uses_mrope(config: PretrainedConfig) -> bool: rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is None: return False @@ -229,6 +228,24 @@ def uses_mrope(config: PretrainedConfig) -> bool: return "mrope_section" in rope_scaling +def uses_mrope(config: PretrainedConfig) -> bool: + """Detect if the model with this config uses M-ROPE.""" + return _uses_mrope(config) or thinker_uses_mrope(config) + + +def thinker_uses_mrope(config: PretrainedConfig) -> bool: + """Detect if the model contains a thinker config and it uses M-ROPE.""" + thinker_config = getattr(config, "thinker_config", None) + if thinker_config is None: + return False + + thinker_text_config = getattr(thinker_config, "text_config", None) + if thinker_text_config is None: + return False + + return uses_mrope(thinker_text_config) + + def is_encoder_decoder(config: PretrainedConfig) -> bool: """Detect if the model with this config is used as an encoder/decoder.""" text_config = getattr(config, "text_config", None) @@ -633,6 +650,11 @@ def load_params_config(model: Union[str, Path], revision: Optional[str], config_file_name = "params.json" config_dict = get_hf_file_to_dict(config_file_name, model, revision) + if config_dict is None: + raise ValueError( + f"Failed to load mistral '{config_file_name}' config for model " + f"{model}. 
Please check if the model is a mistral-format model " + f"and if the config file exists.") assert isinstance(config_dict, dict) config_mapping = { @@ -671,6 +693,9 @@ def recurse_elems(elem: Any): "quant_method": "fp8", "activation_scheme": "static" } + elif quantization.get("quant_method") == "compressed-tensors": + # Pass through the quantization config to compressed-tensors + quantization_config = quantization else: raise ValueError( f"Found unknown quantization='{quantization}' in config") @@ -688,6 +713,7 @@ def recurse_elems(elem: Any): if config_type == "multimodal": multimodal_config = config_dict.pop("vision_encoder") + quantization_config = config_dict.get("quantization_config", {}) config_dict = { "text_config": config_dict, @@ -695,6 +721,8 @@ def recurse_elems(elem: Any): } config_dict["architectures"] = ["PixtralForConditionalGeneration"] config_dict["model_type"] = "pixtral" + if quantization_config: + config_dict["quantization_config"] = quantization_config config_dict.update(kwargs) @@ -732,14 +760,22 @@ def get_hf_text_config(config: PretrainedConfig): """Get the "sub" config relevant to llm for multi modal models. No op for pure text models. """ - if hasattr(config, "text_config"): + # This block should be unnecessary after https://github.com/huggingface/transformers/pull/37517 + if hasattr(config, "thinker_config"): + # TODO(suyang.fy): Refactor code. + # For Qwen2.5-Omni, change hf_text_config to + # thinker_config.text_config. + return config.thinker_config.text_config + + text_config = config.get_text_config() + + if text_config is not config: # The code operates under the assumption that text_config should have # `num_attention_heads` (among others). Assert here to fail early # if transformers config doesn't align with this assumption. 
- assert hasattr(config.text_config, "num_attention_heads") - return config.text_config - else: - return config + assert hasattr(text_config, "num_attention_heads") + + return text_config def try_get_generation_config( diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 53699341bfba..8812d4c484b1 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -13,13 +13,14 @@ from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig +from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig from vllm.transformers_utils.configs.medusa import MedusaConfig from vllm.transformers_utils.configs.mllama import MllamaConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig +from vllm.transformers_utils.configs.moonvit import MoonViTConfig from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config -from vllm.transformers_utils.configs.olmo2 import Olmo2Config from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig from vllm.transformers_utils.configs.solar import SolarConfig from vllm.transformers_utils.configs.telechat2 import Telechat2Config @@ -40,9 +41,10 @@ "ExaoneConfig", "MllamaConfig", "MLPSpeculatorConfig", + "MoonViTConfig", + "KimiVLConfig", "NemotronConfig", "NVLM_D_Config", - "Olmo2Config", "SkyworkR1VChatConfig", "SolarConfig", "Telechat2Config", diff --git a/vllm/transformers_utils/configs/kimi_vl.py b/vllm/transformers_utils/configs/kimi_vl.py new file mode 100644 index 000000000000..97ff44bb9c1c --- /dev/null +++ b/vllm/transformers_utils/configs/kimi_vl.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py +from typing import Optional, Union + +from transformers.configuration_utils import PretrainedConfig + +from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config +from vllm.transformers_utils.configs.moonvit import MoonViTConfig + + +class KimiVLConfig(PretrainedConfig): + model_type = "kimi_vl" + + def __init__(self, + vision_config: Optional[Union[dict, MoonViTConfig]] = None, + text_config: Optional[Union[dict, DeepseekV2Config]] = None, + ignore_index: int = -100, + media_placeholder_token_id: int = 163605, + pad_token_id: int = 0, + **kwargs): + if vision_config is None: + vision_config = MoonViTConfig() + elif isinstance(vision_config, dict): + vision_config = MoonViTConfig(**vision_config) + self.vision_config = vision_config + + if text_config is None: + text_config = DeepseekV2Config() + elif isinstance(text_config, dict): + text_config = DeepseekV2Config(**text_config) + self.text_config = text_config + + self.ignore_index = ignore_index + self.media_placeholder_token_id = media_placeholder_token_id + + super().__init__(pad_token_id=pad_token_id, **kwargs) diff --git a/vllm/transformers_utils/configs/moonvit.py b/vllm/transformers_utils/configs/moonvit.py new file mode 100644 index 000000000000..a2b4059a63ef --- /dev/null +++ b/vllm/transformers_utils/configs/moonvit.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from 
https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py +from transformers.configuration_utils import PretrainedConfig + + +class MoonViTConfig(PretrainedConfig): + model_type = "moonvit" + + def __init__( + self, + patch_size: int = 14, + init_pos_emb_height: int = 64, + init_pos_emb_width: int = 64, + num_attention_heads: int = 16, + num_hidden_layers: int = 27, + hidden_size: int = 1152, + intermediate_size: int = 4304, + merge_kernel_size: tuple[int, int] = (2, 2), + **kwargs, + ): + super().__init__(**kwargs) + self.patch_size = patch_size + # Positional embedding config + self.init_pos_emb_height = init_pos_emb_height + self.init_pos_emb_width = init_pos_emb_width + # Transformer config + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + # Patch merger config + self.merge_kernel_size = merge_kernel_size diff --git a/vllm/transformers_utils/configs/olmo2.py b/vllm/transformers_utils/configs/olmo2.py deleted file mode 100644 index c6e446333b43..000000000000 --- a/vllm/transformers_utils/configs/olmo2.py +++ /dev/null @@ -1,168 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# yapf: disable -# ruff: noqa: E501 -# coding=utf-8 -# Copied from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/configuration_olmo2.py -"""OLMo 2 configuration.""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - -logger = logging.get_logger(__name__) - - -class Olmo2Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Olmo2Model`]. It is used to instantiate an OLMo2 - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf). - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 50304): - Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Olmo2Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. 
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*, defaults to 1): - Padding token id. - bos_token_id (`int`, *optional*): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 50279): - End of stream token id. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is - `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. See the following thread for more information on how - these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an - experimental feature, subject to breaking API changes in future versions. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - rms_norm_eps (`float`, *optional*, defaults to 1e-05): - The epsilon used by the rms normalization layers. 
- - ```python - >>> from transformers import Olmo2Model, Olmo2Config - - >>> # Initializing a Olmo2 7B style configuration - >>> configuration = Olmo2Config() - - >>> # Initializing a model from the Olmo2 7B style configuration - >>> model = Olmo2Model(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ``` - """ - - model_type = "olmo2" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=50304, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - use_cache=True, - pad_token_id=1, - bos_token_id=None, - eos_token_id=50279, - tie_word_embeddings=False, - rope_theta=10000.0, - rope_scaling=None, - attention_bias=False, - attention_dropout=0.0, - rms_norm_eps=1e-5, - **kwargs, - ): - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.use_cache = use_cache - self.rope_theta = rope_theta - self.rope_scaling = rope_scaling - self._rope_scaling_validation() - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - - self.rms_norm_eps = rms_norm_eps - - def _rope_scaling_validation(self): - """ - Validate the `rope_scaling` configuration. 
- """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 9d1d4bb92e4a..991d5631e64e 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -8,13 +8,13 @@ from .detokenizer_utils import (convert_prompt_ids_to_tokens, detokenize_incrementally) from .tokenizer import AnyTokenizer -from .tokenizer_group import BaseTokenizerGroup +from .tokenizer_group import TokenizerGroup class Detokenizer: """Provides methods to decode the output of a model into text.""" - def __init__(self, tokenizer_group: BaseTokenizerGroup): + def __init__(self, tokenizer_group: TokenizerGroup): self.tokenizer_group = tokenizer_group def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer: diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 1d09b99d50c0..4f06950c42e2 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -111,20 +111,20 @@ def cached_processor_from_config( ) -def get_image_processor( +def get_feature_extractor( processor_name: str, *args: Any, trust_remote_code: bool = False, **kwargs: Any, ): - """Load an image processor for the given model name via HuggingFace.""" + """Load an audio feature extractor for the given model name + via HuggingFace.""" # don't put this import at the top level # it will call torch.cuda.device_count() - from transformers import AutoImageProcessor - from transformers.image_processing_utils import BaseImageProcessor - + from transformers import AutoFeatureExtractor + from transformers.feature_extraction_utils import FeatureExtractionMixin try: - processor = AutoImageProcessor.from_pretrained( + feature_extractor = AutoFeatureExtractor.from_pretrained( processor_name, *args, trust_remote_code=trust_remote_code, @@ -135,61 +135,75 @@ def get_image_processor( # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors if not trust_remote_code: err_msg = ( - "Failed to load the image processor. If the image processor is " - "a custom processor not yet available in the HuggingFace " - "transformers library, consider setting " + "Failed to load the feature extractor. 
If the feature " + "extractor is a custom extractor not yet available in the " + "HuggingFace transformers library, consider setting " "`trust_remote_code=True` in LLM or using the " "`--trust-remote-code` flag in the CLI.") raise RuntimeError(err_msg) from e else: raise e + return cast(FeatureExtractionMixin, feature_extractor) - return cast(BaseImageProcessor, processor) +cached_get_feature_extractor = lru_cache(get_feature_extractor) -cached_get_image_processor = lru_cache(get_image_processor) - -def cached_image_processor_from_config( +def cached_feature_extractor_from_config( model_config: "ModelConfig", **kwargs: Any, ): - return cached_get_image_processor( + return cached_get_feature_extractor( model_config.model, trust_remote_code=model_config.trust_remote_code, **_merge_mm_kwargs(model_config, **kwargs), ) -def get_video_processor( +def get_image_processor( processor_name: str, *args: Any, trust_remote_code: bool = False, **kwargs: Any, ): - """Load a video processor for the given model name via HuggingFace.""" + """Load an image processor for the given model name via HuggingFace.""" # don't put this import at the top level # it will call torch.cuda.device_count() + from transformers import AutoImageProcessor from transformers.image_processing_utils import BaseImageProcessor - processor = get_processor( - processor_name, - *args, - trust_remote_code=trust_remote_code, - **kwargs, - ) + try: + processor = AutoImageProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the image processor. If the image processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e - return cast(BaseImageProcessor, processor.video_processor) + return cast(BaseImageProcessor, processor) -cached_get_video_processor = lru_cache(get_video_processor) +cached_get_image_processor = lru_cache(get_image_processor) -def cached_video_processor_from_config( +def cached_image_processor_from_config( model_config: "ModelConfig", **kwargs: Any, ): - return cached_get_video_processor( + return cached_get_image_processor( model_config.model, trust_remote_code=model_config.trust_remote_code, **_merge_mm_kwargs(model_config, **kwargs), diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 1bfb50328338..da5bec856662 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib +import copy import os import warnings from functools import lru_cache @@ -70,18 +71,17 @@ def encode_tokens( def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: - """Get tokenizer with cached properties. - - This will patch the tokenizer object in place. - + """ By default, transformers will recompute multiple tokenizer properties - each time they are called, leading to a significant slowdown. This - function caches these properties for faster access.""" + each time they are called, leading to a significant slowdown. 
+ This proxy caches these properties for faster access. + """ + cached_tokenizer = copy.copy(tokenizer) - tokenizer_all_special_ids = set(tokenizer.all_special_ids) + tokenizer_all_special_ids = tokenizer.all_special_ids + tokenizer_all_special_tokens = tokenizer.all_special_tokens tokenizer_all_special_tokens_extended = ( tokenizer.all_special_tokens_extended) - tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) tokenizer_vocab = tokenizer.get_vocab() tokenizer_len = len(tokenizer) @@ -97,31 +97,34 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: class CachedTokenizer(tokenizer.__class__): # type: ignore @property - def all_special_ids(self): + def all_special_ids(self) -> list[int]: return tokenizer_all_special_ids @property - def all_special_tokens(self): + def all_special_tokens(self) -> list[str]: return tokenizer_all_special_tokens @property - def all_special_tokens_extended(self): + def all_special_tokens_extended(self) -> list[str]: return tokenizer_all_special_tokens_extended @property - def max_token_id(self): + def max_token_id(self) -> int: return max_token_id - def get_vocab(self): + def get_vocab(self) -> dict[str, int]: return tokenizer_vocab - def __len__(self): + def __len__(self) -> int: return tokenizer_len + def __reduce__(self): + return get_cached_tokenizer, (tokenizer, ) + CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" - tokenizer.__class__ = CachedTokenizer - return tokenizer + cached_tokenizer.__class__ = CachedTokenizer + return cached_tokenizer def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None: diff --git a/vllm/transformers_utils/tokenizer_base.py b/vllm/transformers_utils/tokenizer_base.py index bb5ddaf88b21..b4eb081c9b99 100644 --- a/vllm/transformers_utils/tokenizer_base.py +++ b/vllm/transformers_utils/tokenizer_base.py @@ -2,7 +2,7 @@ import importlib from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Optional, Union if TYPE_CHECKING: from vllm.entrypoints.chat_utils import ChatCompletionMessageParam @@ -12,17 +12,17 @@ class TokenizerBase(ABC): @property @abstractmethod - def all_special_tokens_extended(self) -> List[str]: + def all_special_tokens_extended(self) -> list[str]: raise NotImplementedError() @property @abstractmethod - def all_special_tokens(self) -> List[str]: + def all_special_tokens(self) -> list[str]: raise NotImplementedError() @property @abstractmethod - def all_special_ids(self) -> List[int]: + def all_special_ids(self) -> list[int]: raise NotImplementedError() @property @@ -66,7 +66,7 @@ def __len__(self) -> int: @abstractmethod def __call__( self, - text: Union[str, List[str], List[int]], + text: Union[str, list[str], list[int]], text_pair: Optional[str] = None, add_special_tokens: bool = False, truncation: bool = False, @@ -75,11 +75,11 @@ def __call__( raise NotImplementedError() @abstractmethod - def get_vocab(self) -> Dict[str, int]: + def get_vocab(self) -> dict[str, int]: raise NotImplementedError() @abstractmethod - def get_added_vocab(self) -> Dict[str, int]: + def get_added_vocab(self) -> dict[str, int]: raise NotImplementedError() @abstractmethod @@ -88,44 +88,44 @@ def encode_one( text: str, truncation: bool = False, max_length: Optional[int] = None, - ) -> List[int]: + ) -> list[int]: raise NotImplementedError() @abstractmethod def encode(self, text: str, - add_special_tokens: Optional[bool] = None) -> List[int]: + add_special_tokens: Optional[bool] = None) 
-> list[int]: raise NotImplementedError() @abstractmethod def apply_chat_template(self, - messages: List["ChatCompletionMessageParam"], - tools: Optional[List[Dict[str, Any]]] = None, - **kwargs) -> List[int]: + messages: list["ChatCompletionMessageParam"], + tools: Optional[list[dict[str, Any]]] = None, + **kwargs) -> list[int]: raise NotImplementedError() @abstractmethod - def convert_tokens_to_string(self, tokens: List[str]) -> str: + def convert_tokens_to_string(self, tokens: list[str]) -> str: raise NotImplementedError() @abstractmethod def decode(self, - ids: Union[List[int], int], + ids: Union[list[int], int], skip_special_tokens: bool = True) -> str: raise NotImplementedError() @abstractmethod def convert_ids_to_tokens( self, - ids: List[int], + ids: list[int], skip_special_tokens: bool = True, - ) -> List[str]: + ) -> list[str]: raise NotImplementedError() class TokenizerRegistry: # Tokenizer name -> (tokenizer module, tokenizer class) - REGISTRY: Dict[str, Tuple[str, str]] = {} + REGISTRY: dict[str, tuple[str, str]] = {} @staticmethod def register(name: str, module: str, class_name: str) -> None: diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group.py similarity index 84% rename from vllm/transformers_utils/tokenizer_group/tokenizer_group.py rename to vllm/transformers_utils/tokenizer_group.py index b6e9005bcd24..a829985cb459 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group.py @@ -2,7 +2,7 @@ from typing import List, Optional -from vllm.config import TokenizerPoolConfig +from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens, get_lora_tokenizer, @@ -10,10 +10,8 @@ get_tokenizer) from vllm.utils import LRUCache -from .base_tokenizer_group import BaseTokenizerGroup - -class TokenizerGroup(BaseTokenizerGroup): +class TokenizerGroup: """A group of tokenizers that can be used for LoRA adapters.""" def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, @@ -27,15 +25,6 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, self.lora_tokenizers = LRUCache[int, AnyTokenizer]( capacity=max(max_loras, max_num_seqs) if enable_lora else 0) - @classmethod - def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], - **init_kwargs) -> "TokenizerGroup": - return cls(**init_kwargs) - - def ping(self) -> bool: - """Check if the tokenizer group is alive.""" - return True - def get_max_input_len(self, lora_request: Optional[LoRARequest] = None ) -> Optional[int]: @@ -104,3 +93,18 @@ async def get_lora_tokenizer_async( return tokenizer else: return self.lora_tokenizers[lora_request.lora_int_id] + + +def init_tokenizer_from_configs(model_config: ModelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig]): + return TokenizerGroup( + tokenizer_id=model_config.tokenizer, + enable_lora=bool(lora_config), + max_num_seqs=scheduler_config.max_num_seqs, + max_loras=lora_config.max_loras if lora_config else 0, + max_input_length=None, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.tokenizer_revision, + truncation_side=model_config.truncation_side) diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py deleted file mode 100644 index 
9d2209575bd3..000000000000 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ /dev/null @@ -1,56 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional, Type - -from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, TokenizerPoolConfig) -from vllm.executor.ray_utils import ray - -from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup -from .tokenizer_group import TokenizerGroup - -if ray: - from .ray_tokenizer_group import RayTokenizerGroupPool -else: - RayTokenizerGroupPool = None # type: ignore - - -def init_tokenizer_from_configs(model_config: ModelConfig, - scheduler_config: SchedulerConfig, - parallel_config: ParallelConfig, - lora_config: Optional[LoRAConfig]): - init_kwargs = dict(tokenizer_id=model_config.tokenizer, - enable_lora=bool(lora_config), - max_num_seqs=scheduler_config.max_num_seqs, - max_loras=lora_config.max_loras if lora_config else 0, - max_input_length=None, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.tokenizer_revision, - truncation_side=model_config.truncation_side) - - return get_tokenizer_group(parallel_config.tokenizer_pool_config, - **init_kwargs) - - -def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], - **init_kwargs) -> BaseTokenizerGroup: - tokenizer_cls: Type[BaseTokenizerGroup] - if tokenizer_pool_config is None: - tokenizer_cls = TokenizerGroup - elif isinstance(tokenizer_pool_config.pool_type, type) and issubclass( - tokenizer_pool_config.pool_type, BaseTokenizerGroup): - tokenizer_cls = tokenizer_pool_config.pool_type - elif tokenizer_pool_config.pool_type == "ray": - if RayTokenizerGroupPool is None: - raise ImportError( - "RayTokenizerGroupPool is not available. 
Please install " - "the ray package to use the Ray tokenizer group pool.") - tokenizer_cls = RayTokenizerGroupPool - else: - raise ValueError( - f"Unknown pool type: {tokenizer_pool_config.pool_type}") - return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs) - - -__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"] diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py deleted file mode 100644 index c5108a7fc6eb..000000000000 --- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from abc import ABC, abstractmethod -from typing import List, Optional - -from vllm.config import TokenizerPoolConfig -from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import AnyTokenizer - - -class BaseTokenizerGroup(ABC): - """A group of tokenizers that can be used for LoRA adapters.""" - - @classmethod - @abstractmethod - def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], - **init_kwargs) -> "BaseTokenizerGroup": - pass - - @abstractmethod - def ping(self) -> bool: - """Check if the tokenizer group is alive.""" - pass - - @abstractmethod - def get_max_input_len( - self, - lora_request: Optional[LoRARequest] = None, - ) -> Optional[int]: - """Get the maximum input length for the LoRA request.""" - pass - - @abstractmethod - def encode(self, - prompt: str, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> List[int]: - """Encode a prompt using the tokenizer group.""" - pass - - @abstractmethod - async def encode_async( - self, - prompt: str, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> List[int]: - """Encode a prompt using the tokenizer group.""" - pass - - @abstractmethod - def get_lora_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - """Get a tokenizer for a LoRA request.""" - pass - - @abstractmethod - async def get_lora_tokenizer_async( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - """Get a tokenizer for a LoRA request.""" - pass - - def check_health(self): - """Raise exception if the tokenizer group is unhealthy.""" - return diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py deleted file mode 100644 index b048b8094174..000000000000 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ /dev/null @@ -1,244 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import asyncio -import os -from typing import List, Optional - -try: - from ray.exceptions import ActorDiedError # type: ignore -except ImportError: - # For older versions of Ray - from ray.exceptions import RayActorError as ActorDiedError # type: ignore -from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy - -from vllm.config import TokenizerPoolConfig -from vllm.executor.ray_utils import ray -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import AnyTokenizer - -from .base_tokenizer_group import BaseTokenizerGroup -from .tokenizer_group import TokenizerGroup - -logger = init_logger(__name__) - - -class RayTokenizerGroupPool(BaseTokenizerGroup): - """A Ray-based pool of TokenizerGroups for async tokenization.""" - - # Class to use for 
workers making up the pool. - _worker_cls = TokenizerGroup - - @classmethod - def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], - **init_kwargs) -> "RayTokenizerGroupPool": - if not tokenizer_pool_config: - raise ValueError("tokenizer_pool_config must not be None.") - ray_actor_options = (tokenizer_pool_config.extra_config or { - "num_cpus": 0 - }) - ray_actor_options.setdefault( - "scheduling_strategy", - NodeAffinitySchedulingStrategy( - node_id=ray.get_runtime_context().get_node_id(), soft=True)) - - # Carry over the env vars to the actors. - # This is necessary for API keys and such. - ray_actor_options.setdefault("runtime_env", {}) - _carry_over_env_vars_to_runtime_env(ray_actor_options["runtime_env"]) - - init_kwargs["num_actors"] = tokenizer_pool_config.pool_size - init_kwargs["ray_actor_options"] = ray_actor_options - - return cls(**init_kwargs) - - def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int], num_actors: int, - ray_actor_options: dict, **tokenizer_config): - # Store a local copy of the TokenizerGroup for quick access - # to underlying HF tokenizers. - self._tokenizer_config = { - "tokenizer_id": tokenizer_id, - "enable_lora": enable_lora, - "max_num_seqs": max_num_seqs, - "max_input_length": max_input_length, - **tokenizer_config - } - self._local_tokenizer_group = self._worker_cls( - **self._tokenizer_config, ) - - self._ray_tokenizer_group_cls = ray.remote( - self._worker_cls).options(**ray_actor_options) # type: ignore - self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)] - self._idle_actors: Optional[asyncio.Queue] = None - - # If set, actor is unhealthy. Will reraise on the next - # check_health call. - self._exception: Optional[ActorDiedError] = None - - def _init_actor(self) -> ray.ObjectRef: - return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config) - - @property - def pool_size(self) -> int: - return len(self.tokenizer_actors) - - def ping(self): - return ray.get([ - actor.ping.remote() # type: ignore - for actor in self.tokenizer_actors - ]) - - def _ensure_queue_initialized(self): - if self._idle_actors is None: - self._idle_actors = asyncio.Queue() - for actor in self.tokenizer_actors: - self._idle_actors.put_nowait(actor) - - def _finalize_encode(self, actor: ray.ObjectRef, - original_actor: ray.ObjectRef, actor_is_alive: bool): - assert self._idle_actors is not None - # Cleanup the dead actor. - if not actor_is_alive or original_actor is not actor: - self.tokenizer_actors.remove(original_actor) - if actor_is_alive: - # Put the actor back in the queue. - # This is done in a finally block to ensure that the actor is - # always put back in the queue, even if an exception/cancellation - # is raised. - self._idle_actors.put_nowait(actor) - # Add back the new actor. - if original_actor is not actor: - self.tokenizer_actors.append(actor) - - def encode(self, - prompt: str, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> List[int]: - """Encode a prompt using the tokenizer group. - - We pick an idle actor and use it to encode the prompt. - The actor is then put back in the queue for future use. - This is blocking. 
- """ - self.check_health() - self._ensure_queue_initialized() - assert self._idle_actors is not None - - if self._idle_actors.empty(): - raise RuntimeError("No idle actors available.") - actor = self._idle_actors.get_nowait() - actor_is_alive = True - original_actor = actor - try: - ret = ray.get( - actor.encode.remote(prompt=prompt, - lora_request=lora_request, - add_special_tokens=add_special_tokens)) - except ActorDiedError as e: - # If the actor is dead, we first try to reinitialize it. - logger.warning("%s died with ActorDiedError, reinitializing.", - actor, - exc_info=e) - actor = self._init_actor() - try: - ret = ray.get( - actor.encode.remote(prompt=prompt, - lora_request=lora_request, - add_special_tokens=add_special_tokens)) - except ActorDiedError as e: - logger.error( - "%s died for second time in a row, marking " - "RayTokenizerGroupPool as unhealthy.", actor) - actor_is_alive = False - if not self._exception: - self._exception = e - self.check_health() - finally: - self._finalize_encode(actor, original_actor, actor_is_alive) - return ret - - async def encode_async( - self, - prompt: str, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> List[int]: - """Encode a prompt using the tokenizer group. - - We pick an idle actor and use it to encode the prompt. - If there are no idle actors, we wait until one becomes - available. - The actor is then put back in the queue for future use. - This is non-blocking. - """ - self.check_health() - self._ensure_queue_initialized() - assert self._idle_actors is not None - - actor = await self._idle_actors.get() - actor_is_alive = True - original_actor = actor - try: - ret = await actor.encode.remote( - prompt=prompt, - lora_request=lora_request, - add_special_tokens=add_special_tokens) - except ActorDiedError as e: - # If the actor is dead, we first try to reinitialize it. - logger.warning("%s died with ActorDiedError, reinitializing.", - actor, - exc_info=e) - actor = self._init_actor() - try: - ret = await actor.encode.remote( - prompt=prompt, - lora_request=lora_request, - add_special_tokens=add_special_tokens) - except ActorDiedError as e: - logger.error( - "%s died for second time in a row, marking " - "RayTokenizerGroupPool as unhealthy.", actor) - actor_is_alive = False - if not self._exception: - self._exception = e - self.check_health() - finally: - self._finalize_encode(actor, original_actor, actor_is_alive) - return ret - - def get_max_input_len(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: - """Get the maximum input length for the LoRA request.""" - return self._local_tokenizer_group.get_max_input_len(lora_request) - - def get_lora_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return self._local_tokenizer_group.get_lora_tokenizer(lora_request) - - async def get_lora_tokenizer_async( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return await self._local_tokenizer_group.get_lora_tokenizer_async( - lora_request) - - def check_health(self): - if self._exception: - raise RuntimeError( - "TokenizerGroupPool is unhealthy.") from self._exception - - -def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None: - """Copy over all current process environment variables to the runtime_env. - - The variables in runtime_env will take precedence over the current process - environment variables. 
- - runtime_env will be modified in place.""" - env_vars = os.environ.copy() - runtime_env.setdefault("env_vars", {}) - env_vars.update(runtime_env["env_vars"]) - runtime_env["env_vars"] = env_vars diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 58a114fa3a32..296149a45695 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -257,7 +257,7 @@ def _download_mistral_tokenizer_from_hf(tokenizer_name: str, # the following attributes are set to fit vLLM's design and are used # by the guided structured output backends. @property - def all_special_tokens_extended(self) -> List[str]: + def all_special_tokens_extended(self) -> list[str]: from mistral_common.tokens.tokenizers.base import SpecialTokens # tekken defines its own extended special tokens list @@ -271,11 +271,11 @@ def all_special_tokens_extended(self) -> List[str]: ] @property - def all_special_tokens(self) -> List[str]: + def all_special_tokens(self) -> list[str]: return self.all_special_tokens_extended @property - def all_special_ids(self) -> List[int]: + def all_special_ids(self) -> list[int]: return [ self.all_special_tokens.index(t) for t in self.all_special_tokens ] @@ -335,12 +335,12 @@ def __call__( input_ids = self.encode_one(text, truncation, max_length) return Encoding(input_ids=input_ids) - def get_vocab(self) -> Dict[str, int]: + def get_vocab(self) -> dict[str, int]: # NB: the dictionary form of the vocabulary collapses token ids that map # to the same string but have different bytes return self._vocab_dict - def get_added_vocab(self) -> Dict[str, int]: + def get_added_vocab(self) -> dict[str, int]: # Mistral tokenizers have no added vocabulary return {} diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index 43918bcd7c55..bffc56a2e75c 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -2,4 +2,4 @@ from vllm.triton_utils.importing import HAS_TRITON -__all__ = ["HAS_TRITON"] \ No newline at end of file +__all__ = ["HAS_TRITON"] diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py index a20700248c26..fa29efbf6b2d 100644 --- a/vllm/triton_utils/importing.py +++ b/vllm/triton_utils/importing.py @@ -1,17 +1,53 @@ # SPDX-License-Identifier: Apache-2.0 +import sys +import types from importlib.util import find_spec from vllm.logger import init_logger -from vllm.platforms import current_platform logger = init_logger(__name__) HAS_TRITON = ( find_spec("triton") is not None - and not current_platform.is_xpu() # Not compatible + or find_spec("pytorch-triton-xpu") is not None # Not compatible ) if not HAS_TRITON: logger.info("Triton not installed or not compatible; certain GPU-related" " functions will not be available.") + + class TritonPlaceholder(types.ModuleType): + + def __init__(self): + super().__init__("triton") + self.jit = self._dummy_decorator("jit") + self.autotune = self._dummy_decorator("autotune") + self.heuristics = self._dummy_decorator("heuristics") + self.language = TritonLanguagePlaceholder() + logger.warning_once( + "Triton is not installed. Using dummy decorators. 
" + "Install it via `pip install triton` to enable kernel" + "compilation.") + + def _dummy_decorator(self, name): + + def decorator(func=None, **kwargs): + if func is None: + return lambda f: f + return func + + return decorator + + class TritonLanguagePlaceholder(types.ModuleType): + + def __init__(self): + super().__init__("triton.language") + self.constexpr = None + self.dtype = None + + sys.modules['triton'] = TritonPlaceholder() + sys.modules['triton.language'] = TritonLanguagePlaceholder() + +if 'triton' in sys.modules: + logger.info("Triton module has been replaced with a placeholder.") diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 2ee3f9104d19..67b834533b7d 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -19,6 +19,7 @@ import vllm.envs as envs from vllm.connections import global_http_connection +from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties from vllm.version import __version__ as VLLM_VERSION _config_home = envs.VLLM_CONFIG_ROOT @@ -168,12 +169,20 @@ def _report_usage_once(self, model_architecture: str, # Platform information from vllm.platforms import current_platform if current_platform.is_cuda_alike(): - device_property = torch.cuda.get_device_properties(0) - self.gpu_count = torch.cuda.device_count() - self.gpu_type = device_property.name - self.gpu_memory_per_device = device_property.total_memory + self.gpu_count = cuda_device_count_stateless() + self.gpu_type, self.gpu_memory_per_device = ( + cuda_get_device_properties(0, ("name", "total_memory"))) if current_platform.is_cuda(): self.cuda_runtime = torch.version.cuda + if current_platform.is_tpu(): + try: + import torch_xla + self.gpu_count = torch_xla.runtime.world_size() + self.gpu_type = torch_xla.tpu.get_tpu_type() + self.gpu_memory_per_device = ( + torch_xla.core.xla_model.get_memory_info()["bytes_limit"]) + except Exception: + pass self.provider = _detect_cloud_provider() self.architecture = platform.machine() self.platform = platform.platform() diff --git a/vllm/utils.py b/vllm/utils.py index c2aad04941b8..73726bb9a346 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -38,11 +38,13 @@ from collections import UserDict, defaultdict from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, Iterable, Iterator, KeysView, Mapping) +from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps from types import MappingProxyType from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, - Optional, Tuple, Type, TypeVar, Union, cast, overload) + Optional, Sequence, Tuple, Type, TypeVar, Union, cast, + overload) from uuid import uuid4 import cachetools @@ -61,6 +63,9 @@ from typing_extensions import Never, ParamSpec, TypeIs, assert_never import vllm.envs as envs +# NOTE: import triton_utils to make TritonPlaceholderModule work +# if triton is unavailable +import vllm.triton_utils # noqa: F401 from vllm.logger import enable_trace_function_call, init_logger if TYPE_CHECKING: @@ -236,6 +241,12 @@ def hit_ratio(self) -> float: return self.hits / self.total + def __sub__(self, other: CacheInfo): + return CacheInfo( + hits=self.hits - other.hits, + total=self.total - other.total, + ) + class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): @@ -243,15 +254,26 @@ def __init__(self, capacity: float, getsizeof: Optional[Callable[[_V], float]] = None): super().__init__(capacity, getsizeof) + self.pinned_items = set[_K]() - 
self.capacity = capacity self._hits = 0 self._total = 0 + self._last_info = CacheInfo(hits=0, total=0) + + def __getitem__(self, key: _K, *, update_info: bool = True) -> _V: + value = super().__getitem__(key) + + if update_info: + self._hits += 1 + self._total += 1 + + return value def __delitem__(self, key: _K) -> None: run_on_remove = key in self - value = self.__getitem__(key) + value = self.__getitem__(key, + update_info=False) # type: ignore[call-arg] super().__delitem__(key) if key in self.pinned_items: # Todo: add warning to inform that del pinned item @@ -271,11 +293,38 @@ def order(self) -> Mapping[_K, None]: """Return the internal order dictionary (read-only).""" return MappingProxyType(self._LRUCache__order) # type: ignore - def stat(self) -> CacheInfo: - return CacheInfo(hits=self._hits, total=self._total) + @property + def capacity(self) -> float: + return self.maxsize + + @property + def usage(self) -> float: + if self.maxsize == 0: + return 0 + + return self.currsize / self.maxsize + + def stat(self, *, delta: bool = False) -> CacheInfo: + """ + Gets the cumulative number of hits and queries against this cache. + + If :code:`delta=True`, instead gets these statistics + since the last call that also passed :code:`delta=True`. + """ + info = CacheInfo(hits=self._hits, total=self._total) + + if delta: + info_delta = info - self._last_info + self._last_info = info + info = info_delta + + return info def touch(self, key: _K) -> None: - self._LRUCache__update(key) # type: ignore + try: + self._LRUCache__order.move_to_end(key) # type: ignore + except KeyError: + self._LRUCache__order[key] = None # type: ignore @overload def get(self, key: _K, /) -> Optional[_V]: @@ -292,7 +341,8 @@ def get(self, _T]] = None) -> Optional[Union[_V, _T]]: value: Optional[Union[_V, _T]] if key in self: - value = self.__getitem__(key) + value = self.__getitem__( + key, update_info=False) # type: ignore[call-arg] self._hits += 1 else: @@ -317,8 +367,9 @@ def pop(self, if key not in self: return default - value = self[key] - del self[key] + value = self.__getitem__(key, + update_info=False) # type: ignore[call-arg] + self.__delitem__(key) return value def put(self, key: _K, value: _V) -> None: @@ -353,10 +404,6 @@ def _remove_old_if_needed(self) -> None: while self.currsize > self.capacity: self.remove_oldest() - def clear(self) -> None: - while len(self) > 0: - self.remove_oldest(remove_pinned=True) - def popitem(self, remove_pinned: bool = False): """Remove and return the `(key, value)` pair least recently used.""" if not remove_pinned: @@ -372,6 +419,14 @@ def popitem(self, remove_pinned: bool = False): value = self.pop(cast(_K, lru_key)) return (lru_key, value) + def clear(self) -> None: + while len(self) > 0: + self.remove_oldest(remove_pinned=True) + + self._hits = 0 + self._total = 0 + self._last_info = CacheInfo(hits=0, total=0) + class PyObjectCache: """Used to cache python objects to avoid object allocations @@ -578,12 +633,12 @@ def get_open_port() -> int: process. Currently it uses 2 ports. 
""" if "VLLM_DP_MASTER_PORT" in os.environ: - dp_port = envs.VLLM_DP_MASTER_PORT + dp_master_port = envs.VLLM_DP_MASTER_PORT + reserved_port_range = range(dp_master_port, dp_master_port + 10) while True: - port = _get_open_port() - if dp_port <= port < dp_port + 10: - continue - return port + candidate_port = _get_open_port() + if candidate_port not in reserved_port_range: + return candidate_port return _get_open_port() @@ -710,21 +765,28 @@ def create_kv_caches_with_random_flash( model_dtype: Optional[Union[str, torch.dtype]] = None, seed: Optional[int] = None, device: Optional[str] = "cuda", + cache_layout: Optional[str] = "NHD", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: from vllm.platforms import current_platform current_platform.seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) - key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + assert cache_layout in ("NHD", "HND") + stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, + 4) + + kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] + for i in stride_order) scale = head_size**-0.5 key_caches: list[torch.Tensor] = [] value_caches: list[torch.Tensor] = [] for _ in range(num_layers): - key_value_cache = torch.empty(size=key_value_cache_shape, + key_value_cache = torch.empty(size=kv_cache_allocation_shape, dtype=torch_dtype, - device=device) + device=device).permute(*stride_order) if cache_dtype in ["auto", "half", "bfloat16", "float"]: key_value_cache.uniform_(-scale, scale) elif cache_dtype == 'fp8': @@ -1185,6 +1247,22 @@ def cuda_is_initialized() -> bool: return torch.cuda.is_initialized() +def cuda_get_device_properties(device, + names: Sequence[str], + init_cuda=False) -> tuple[Any, ...]: + """Get specified CUDA device property values without initializing CUDA in + the current process.""" + if init_cuda or cuda_is_initialized(): + props = torch.cuda.get_device_properties(device) + return tuple(getattr(props, name) for name in names) + + # Run in subprocess to avoid initializing CUDA as a side effect. 
+ mp_ctx = multiprocessing.get_context("fork") + with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor: + return executor.submit(cuda_get_device_properties, device, names, + True).result() + + def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: """Make an instance method that weakly references its associated instance and no-ops once that diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index b4c7708daab9..f5ad2334bd19 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -10,12 +10,14 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType, is_quantized_kv_cache) +from vllm.attention.layer import Attention from vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, + get_flash_attn_version) +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv -from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -23,7 +25,8 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner if current_platform.is_cuda(): - from vllm.vllm_flash_attn import flash_attn_varlen_func + from vllm.vllm_flash_attn import (flash_attn_varlen_func, + get_scheduler_metadata) logger = init_logger(__name__) @@ -63,10 +66,6 @@ def get_kv_cache_shape( raise ValueError("Block size must be a multiple of 16.") return (2, num_blocks, block_size, num_kv_heads, head_size) - @staticmethod - def use_cascade_attention(*args, **kwargs) -> bool: - return use_cascade_attention(*args, **kwargs) - @dataclass class FlashAttentionMetadata: @@ -93,6 +92,10 @@ class FlashAttentionMetadata: prefix_kv_lens: Optional[torch.Tensor] suffix_kv_lens: Optional[torch.Tensor] + # Optional aot scheduling + scheduler_metadata: Optional[torch.Tensor] = None + prefix_scheduler_metadata: Optional[torch.Tensor] = None + # For logging. num_input_tokens: int = 0 # Number of tokens including padding. 
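# ---------------------------------------------------------------------------
# Illustrative sketch (not taken from the patch) of the pattern behind
# cuda_get_device_properties above: run the query in a short-lived forked
# worker so the parent process never initializes CUDA.  Standalone and
# hypothetical; it assumes an NVIDIA GPU and a platform with the "fork"
# start method.
import multiprocessing
from concurrent.futures.process import ProcessPoolExecutor


def _query_props(device, names):
    import torch

    # get_device_properties() initializes CUDA, but only in this child process.
    props = torch.cuda.get_device_properties(device)
    return tuple(getattr(props, name) for name in names)


def query_device_properties(device=0, names=("name", "total_memory")):
    mp_ctx = multiprocessing.get_context("fork")
    with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor:
        return executor.submit(_query_props, device, names).result()


if __name__ == "__main__":
    print(query_device_properties())  # -> (gpu_name, total_memory_in_bytes)
# ---------------------------------------------------------------------------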
@@ -104,6 +107,7 @@ class LocalAttentionMetadata: local_block_table: torch.Tensor local_max_query_len: int local_max_seq_len: int + local_scheduler_metadata: Optional[torch.Tensor] local_attn_metadata: Optional[LocalAttentionMetadata] = None @@ -274,10 +278,34 @@ def make_local_attention_virtual_batches( block_table_local +def _get_sliding_window_configs( + vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]: + """Get the set of all sliding window configs used in the model.""" + sliding_window_configs: set[Optional[tuple[int, int]]] = set() + layers = get_layers_from_vllm_config(vllm_config, Attention) + for layer in layers.values(): + assert isinstance(layer.impl, FlashAttentionImpl) + sliding_window_configs.add(layer.impl.sliding_window) + return sliding_window_configs + + class FlashAttentionMetadataBuilder: def __init__(self, runner: "GPUModelRunner"): + model_config = runner.model_config + self.runner = runner + self.num_heads_q = model_config.get_num_attention_heads( + runner.parallel_config) + self.num_heads_kv = model_config.get_num_kv_heads( + runner.parallel_config) + self.headdim = model_config.get_head_size() + self.page_size = self.runner.block_size + + self.aot_schedule = (get_flash_attn_version() == 3) + # Sliding window size to be used with the AOT scheduler will be + # populated on first build() call. + self.aot_sliding_window: Optional[tuple[int, int]] = None def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -296,6 +324,40 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( self.runner.device, non_blocking=True).long() + if self.aot_sliding_window is None: + self.aot_sliding_window = (-1, -1) + # For the AOT scheduler we need the sliding window value to be + # constant for all layers to. We have to populate this on the first + # build() call so the layers are constructed (cannot populate) + # in __init__. 
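# ---------------------------------------------------------------------------
# Illustrative sketch (not taken from the patch) of the decision rule applied
# just below: ahead-of-time (AOT) scheduling is only kept when every attention
# layer reports the same sliding-window config.  Standalone helper over plain
# tuples; the function name and sample values are hypothetical.
from typing import Optional


def resolve_aot_sliding_window(
        configs: set[Optional[tuple[int, int]]]) -> tuple[bool, tuple[int, int]]:
    """Return (keep_aot_schedule, window_size_for_the_scheduler)."""
    if len(configs) > 1:
        # Mixed sliding windows: one schedule cannot serve all layers.
        return False, (-1, -1)
    if len(configs) == 1:
        only = next(iter(configs))
        # None means "no sliding window"; keep the (-1, -1) sentinel.
        return True, only if only is not None else (-1, -1)
    return True, (-1, -1)


print(resolve_aot_sliding_window({None}))             # (True, (-1, -1))
print(resolve_aot_sliding_window({(4095, 0)}))        # (True, (4095, 0))
print(resolve_aot_sliding_window({None, (4095, 0)}))  # (False, (-1, -1))
# ---------------------------------------------------------------------------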
+ if self.aot_schedule: + sliding_window_configs = _get_sliding_window_configs( + self.runner.vllm_config) + if len(sliding_window_configs) == 1: + sliding_window_config = sliding_window_configs.pop() + if sliding_window_config is not None: + self.aot_sliding_window = sliding_window_config + elif len(sliding_window_configs) > 1: + self.aot_schedule = False + + def schedule(batch_size, cu_query_lens, max_query_len, seqlens, + max_seq_len, causal): + if self.aot_schedule: + return get_scheduler_metadata( + batch_size=batch_size, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + cache_seqlens=seqlens, + num_heads_q=self.num_heads_q, + num_heads_kv=self.num_heads_kv, + headdim=self.headdim, + page_size=self.page_size, + cu_seqlens_q=cu_query_lens, + causal=causal, + window_size=self.aot_sliding_window, + ) + return None + # for local attention local_attn_metadata = None if self.runner.attention_chunk_size is not None: @@ -307,18 +369,31 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, block_table, self.runner.block_size, ) + local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( + self.runner.device, non_blocking=True) + local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to( + self.runner.device, non_blocking=True) + local_max_query_len = seqlens_q_local_np.max() + local_max_seq_len = virt_k_seqlens_np.max() + local_scheduler_metadata = schedule( + batch_size=local_query_start_loc.shape[0] - 1, + cu_query_lens=local_query_start_loc, + max_query_len=local_max_query_len, + seqlens=local_seqused_k, + max_seq_len=local_max_seq_len, + causal=True) + local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata( - local_query_start_loc=torch.from_numpy( - virt_q_cu_seqlens_np).to(self.runner.device, - non_blocking=True), - local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to( - self.runner.device, non_blocking=True), + local_query_start_loc=local_query_start_loc, + local_seqused_k=local_seqused_k, local_block_table=virt_block_table, - local_max_query_len=seqlens_q_local_np.max(), - local_max_seq_len=virt_k_seqlens_np.max(), + local_max_query_len=local_max_query_len, + local_max_seq_len=local_max_seq_len, + local_scheduler_metadata=local_scheduler_metadata, ) use_cascade = common_prefix_len > 0 + if use_cascade: cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], dtype=torch.int32, @@ -330,10 +405,31 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len) suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to( self.runner.device) + prefix_scheduler_metadata = schedule( + batch_size=1, + cu_query_lens=cu_prefix_query_lens, + max_query_len=num_actual_tokens, + seqlens=prefix_kv_lens, + max_seq_len=common_prefix_len, + causal=False) + scheduler_metadata = schedule(batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=suffix_kv_lens, + max_seq_len=max_seq_len - + common_prefix_len, + causal=True) else: cu_prefix_query_lens = None prefix_kv_lens = None suffix_kv_lens = None + prefix_scheduler_metadata = None + scheduler_metadata = schedule(batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=seq_lens, + max_seq_len=max_seq_len, + causal=True) attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_actual_tokens, @@ -345,13 +441,18 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, slot_mapping=slot_mapping, use_cascade=use_cascade, common_prefix_len=common_prefix_len, + 
scheduler_metadata=scheduler_metadata, cu_prefix_query_lens=cu_prefix_query_lens, prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, local_attn_metadata=local_attn_metadata, + prefix_scheduler_metadata=prefix_scheduler_metadata, ) return attn_metadata + def use_cascade_attention(self, *args, **kwargs) -> bool: + return use_cascade_attention(*args, **kwargs) + class FlashAttentionImpl(AttentionImpl): @@ -491,12 +592,14 @@ def forward( max_seqlen_q = local_metadata.local_max_query_len max_seqlen_k = local_metadata.local_max_seq_len block_table = local_metadata.local_block_table + scheduler_metadata = local_metadata.local_scheduler_metadata else: cu_seqlens_q = attn_metadata.query_start_loc seqused_k = attn_metadata.seq_lens max_seqlen_q = attn_metadata.max_query_len max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table + scheduler_metadata = attn_metadata.scheduler_metadata descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) @@ -515,6 +618,7 @@ def forward( window_size=self.sliding_window, block_table=block_table, softcap=self.logits_soft_cap, + scheduler_metadata=scheduler_metadata, fa_version=self.vllm_flash_attn_version, q_descale=layer._q_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape), @@ -543,6 +647,8 @@ def forward( block_table=attn_metadata.block_table, common_prefix_len=attn_metadata.common_prefix_len, fa_version=self.vllm_flash_attn_version, + prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, + suffix_scheduler_metadata=attn_metadata.scheduler_metadata, q_descale=layer._q_scale, k_descale=layer._k_scale, v_descale=layer._v_scale, @@ -636,6 +742,8 @@ def cascade_attention( block_table: torch.Tensor, common_prefix_len: int, fa_version: int, + prefix_scheduler_metadata: Optional[torch.Tensor] = None, + suffix_scheduler_metadata: Optional[torch.Tensor] = None, q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, @@ -667,6 +775,7 @@ def cascade_attention( block_table=block_table[:1], softcap=logits_soft_cap, return_softmax_lse=True, + scheduler_metadata=prefix_scheduler_metadata, fa_version=fa_version, q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, @@ -693,6 +802,7 @@ def cascade_attention( block_table=block_table[:, num_common_kv_blocks:], softcap=logits_soft_cap, return_softmax_lse=True, + scheduler_metadata=suffix_scheduler_metadata, fa_version=fa_version, q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py new file mode 100755 index 000000000000..bce446bd2b82 --- /dev/null +++ b/vllm/v1/attention/backends/flashinfer.py @@ -0,0 +1,638 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with FlashInfer.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch +from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + MultiLevelCascadeAttentionWrapper) + +import vllm.envs as envs +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionType) +from vllm.attention.layer import Attention +from vllm.config import (VllmConfig, get_current_vllm_config, + get_layers_from_vllm_config) +from vllm.logger import init_logger +from vllm.v1.attention.backends.flash_attn import use_cascade_attention + +if TYPE_CHECKING: + 
from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 + +logger = init_logger(__name__) + + +class FlashInferBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [64, 128, 256] + + @staticmethod + def get_name() -> str: + return "FLASHINFER_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type[FlashInferImpl]: + return FlashInferImpl + + @staticmethod + def get_metadata_cls() -> type[FlashInferMetadata]: + return FlashInferMetadata + + @staticmethod + def get_builder_cls() -> type[FlashInferMetadataBuilder]: + return FlashInferMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, ...]: + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + +@dataclass +class PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters. + """ + + window_left: int + logits_soft_cap: Optional[float] + sm_scale: float + + +def get_per_layer_parameters( + vllm_config: VllmConfig) -> dict[str, PerLayerParameters]: + """ + Scan all attention layers and determine some hyperparameters + to use during `plan`. + """ + + layers = get_layers_from_vllm_config(vllm_config, Attention) + per_layer_params: dict[str, PerLayerParameters] = {} + + for key, layer in layers.items(): + impl = layer.impl + assert isinstance(impl, FlashInferImpl) + + # Infer hyperparameters from the attention layer + window_size = impl.sliding_window + window_left = window_size[0] if window_size is not None else -1 + logits_soft_cap = impl.logits_soft_cap + sm_scale = impl.scale + + per_layer_params[key] = PerLayerParameters(window_left, + logits_soft_cap, sm_scale) + + return per_layer_params + + +def infer_global_hyperparameters( + per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters: + - `window_left` + - `logits_soft_cap` + - `sm_scale` + + So this function asserts that all layers share the same values for these + hyperparameters and returns the global values. + """ + + assert len(per_layer_params) > 0, "No attention layers found in the model." + + param_sets = list(per_layer_params.values()) + global_params = param_sets[0] + for params in param_sets: + assert params == global_params, ( + "FlashInfer backend currently only supports models in which all " + "layers share the same values for the following hyperparameters: " + "`window_left`, `logits_soft_cap`, `sm_scale`.") + + return global_params + + +@dataclass +class FlashInferMetadata: + + num_actual_tokens: int # Number of tokens excluding padding. + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. 
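# ---------------------------------------------------------------------------
# Illustrative sketch (not taken from the patch) of two of the fields declared
# below, on a tiny hypothetical batch (3 requests, page_size=16).  Only the
# shapes and semantics are meant to match; the numbers are made up.
import torch

query_lens = torch.tensor([4, 6, 1], dtype=torch.int32)
seq_lens = torch.tensor([40, 47, 32], dtype=torch.int32)
page_size = 16

# qo_indptr is the exclusive cumulative sum of per-request query lengths, so
# request i owns query rows qo_indptr[i]:qo_indptr[i + 1].
qo_indptr = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(query_lens, dim=0, dtype=torch.int32),
])

# paged_kv_last_page_len is how many slots of each request's final page are in
# use; a completely full final page is reported as page_size, never as 0.
last_page_len = seq_lens % page_size
last_page_len[last_page_len == 0] = page_size

print(qo_indptr.tolist())      # [0, 4, 10, 11]
print(last_page_len.tolist())  # [8, 15, 16]
# ---------------------------------------------------------------------------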
+ qo_indptr: torch.Tensor + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: torch.Tensor + # The page indices of the paged kv cache + paged_kv_indices: torch.Tensor + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: torch.Tensor + # The number of query/output heads + num_qo_heads: int + # The number of key/value heads + num_kv_heads: int + # The dimension of the attention heads + head_dim: int + # Block size of vllm + page_size: int + # The data type of the paged kv cache + data_type: torch.dtype + # The data type of the query + q_data_type: torch.dtype + + slot_mapping: torch.Tensor + + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + + # For cascade attention. + use_cascade: bool + shared_qo_indptr: Optional[torch.Tensor] = None + shared_kv_page_indptr: Optional[torch.Tensor] = None + shared_kv_page_indices: Optional[torch.Tensor] = None + shared_kv_last_page_len: Optional[torch.Tensor] = None + + prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None + decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None + cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None + + # For logging. + num_input_tokens: int = 0 # Number of tokens including padding. + + @property + def query_start_loc(self): + # The GPUModelRunner expects to be able to access this property. + return self.qo_indptr + + def __post_init__(self): + # Refer to + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + supported_head_sizes = FlashInferBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f" received {self.head_dim}.") + + +class FlashInferMetadataBuilder: + + def __init__(self, runner: GPUModelRunner): + self.runner = runner + self._workspace_buffer = None + self._prefill_wrapper = None # Wrapper for prefill/append + self._decode_wrapper = None # Wrapper for decode + self._cascade_wrapper = None # Wrapper for cascade attention + + # Global hyperparameters shared by all attention layers + self.global_hyperparameters: Optional[PerLayerParameters] = None + + self.vllm_config = get_current_vllm_config() + + def reorder_batch(self, input_batch: InputBatch, + scheduler_output: SchedulerOutput) -> bool: + # We now want to reorder the batch so that the "decode" requests are and + # the front and the "prefill" requests are at the using the least amount + # swaps possible. 
(NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + num_decode_tokens = 0 + num_prefill_tokens = 0 + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + # for now treat 1 scheduled token as "decode" even if its not, + # we should update this to something like < 8 in the future but + # currently the decode run only supports num_tokens = 1 + if num_tokens == 1: + decodes.append(i) + num_decode_tokens += num_tokens + else: + prefills.append(i) + num_prefill_tokens += num_tokens + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reodorered with + # prefills from the front of the batch. + # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + decode_idx = decodes[num_decodes - i] + if decode_idx < num_decodes: + break + + input_batch.swap_states(prefills[i - 1], decode_idx) + modified_batch = True + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + self._num_decodes = num_decodes + self._num_prefills = num_prefills + self._num_decode_tokens = num_decode_tokens + self._num_prefill_tokens = num_prefill_tokens + + return modified_batch + + def _get_workspace_buffer(self): + if self._workspace_buffer is None: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) + return self._workspace_buffer + + def _get_prefill_wrapper(self): + if self._prefill_wrapper is None: + self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), "NHD") + return self._prefill_wrapper + + def _get_decode_wrapper(self): + if self._decode_wrapper is None: + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( + num_qo_heads // num_kv_heads > 4) + self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self._get_workspace_buffer(), + "NHD", + use_tensor_cores=use_tensor_cores) + return self._decode_wrapper + + def _get_cascade_wrapper(self): + if self._cascade_wrapper is None: + self._cascade_wrapper = MultiLevelCascadeAttentionWrapper( + 2, self._get_workspace_buffer(), "NHD") + return self._cascade_wrapper + + def _plan(self, attn_metadata: FlashInferMetadata): + if self.global_hyperparameters is None: + self.global_hyperparameters = infer_global_hyperparameters( + get_per_layer_parameters(self.vllm_config)) + if attn_metadata.use_cascade: + attn_metadata.cascade_wrapper = 
self._get_cascade_wrapper() + attn_metadata.cascade_wrapper.plan( + [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr], + [ + attn_metadata.shared_kv_page_indptr, + attn_metadata.paged_kv_indptr + ], + [ + attn_metadata.shared_kv_page_indices, + attn_metadata.paged_kv_indices + ], + [ + attn_metadata.shared_kv_last_page_len, + attn_metadata.paged_kv_last_page_len + ], + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + causal=True, + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters.logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + ) + else: + # Regular attention (common case). + # Decodes are at the front and prefills are at the back, + # according to reorder_batch() + if self._num_prefills > 0: + # Decodes are first so prefills start after the last decode + prefill_start = self._num_decodes + attn_metadata.prefill_wrapper = self._get_prefill_wrapper() + assert attn_metadata.qo_indptr[prefill_start:].shape[ + 0] == self._num_prefills + 1 + assert attn_metadata.paged_kv_indptr[prefill_start:].shape[ + 0] == self._num_prefills + 1 + assert attn_metadata.paged_kv_last_page_len[ + prefill_start:].shape[0] == self._num_prefills + # Since prefill_wrapper.run() will be called with + # query[num_decode_tokens:] we need to adjust the qo_indptr + # to be relative to the start of the prefill queries. + qo_indptr = attn_metadata.qo_indptr[ + prefill_start:] - attn_metadata.qo_indptr[prefill_start] + attn_metadata.prefill_wrapper.plan( + qo_indptr, + attn_metadata.paged_kv_indptr[prefill_start:], + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_len[prefill_start:], + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + causal=True, + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters. + logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + kv_data_type=attn_metadata.data_type, + ) + + if self._num_decodes > 0: + attn_metadata.decode_wrapper = self._get_decode_wrapper() + attn_metadata.decode_wrapper.plan( + attn_metadata.paged_kv_indptr[:self._num_decodes + 1], + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_len[:self._num_decodes], + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters. 
+ logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + kv_data_type=attn_metadata.data_type, + ) + + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + common_prefix_len: int): + assert self._num_decodes + self._num_prefills == num_reqs + assert (self._num_decode_tokens + + self._num_prefill_tokens == num_actual_tokens) + page_size = self.runner.block_size + device = self.runner.device + qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to( + self.runner.device, non_blocking=True) + seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device, + non_blocking=True) + block_table = ( + self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + self.runner.device, non_blocking=True).long() + + block_table_bounds = (seq_lens + page_size - 1) // page_size + + use_cascade = common_prefix_len > 0 + if use_cascade: + # Grab the blocks of the shared prefix from the first request. + assert common_prefix_len % page_size == 0 + num_common_kv_blocks = common_prefix_len // page_size + shared_qo_indptr = torch.tensor([0, num_actual_tokens], + dtype=torch.int32, + device=device) + shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks], + dtype=torch.int32, + device=device) + shared_kv_page_indices = block_table[0, :num_common_kv_blocks] + shared_kv_last_page_len = torch.tensor([page_size], + dtype=torch.int32, + device=device) + # Remove the blocks of the shared prefix from all requests. + block_table = block_table[:, num_common_kv_blocks:] + block_table_bounds -= num_common_kv_blocks + else: + shared_qo_indptr = None + shared_kv_page_indptr = None + shared_kv_page_indices = None + shared_kv_last_page_len = None + + mask = (torch.arange(block_table.size(1), + dtype=block_table.dtype, + device=block_table.device).unsqueeze(0) + < block_table_bounds.unsqueeze(1)) + paged_kv_indices = block_table[mask] + + paged_kv_indptr = torch.cat([ + torch.zeros(1, + dtype=block_table_bounds.dtype, + device=block_table_bounds.device), + block_table_bounds.cumsum(dim=0, dtype=torch.int32) + ]) + + paged_kv_last_page_len = seq_lens % page_size + paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, + page_size, paged_kv_last_page_len) + + attn_metadata = FlashInferMetadata( + num_actual_tokens=num_actual_tokens, + qo_indptr=qo_indptr, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_kv_last_page_len, + num_qo_heads=self.runner.num_query_heads, + num_kv_heads=self.runner.num_kv_heads, + head_dim=self.runner.head_size, + page_size=page_size, + data_type=self.runner.kv_cache_dtype, + q_data_type=self.runner.dtype, + slot_mapping=slot_mapping, + num_decodes=self._num_decodes, + num_decode_tokens=self._num_decode_tokens, + num_prefills=self._num_prefills, + num_prefill_tokens=self._num_prefill_tokens, + use_cascade=use_cascade, + shared_qo_indptr=shared_qo_indptr, + shared_kv_page_indptr=shared_kv_page_indptr, + shared_kv_page_indices=shared_kv_page_indices, + shared_kv_last_page_len=shared_kv_last_page_len, + ) + + self._plan(attn_metadata) + + return attn_metadata + + def use_cascade_attention(self, *args, **kwargs) -> bool: + if self.runner.kv_cache_dtype != self.runner.model_config.dtype: + # TODO: The cascade wrapper currently does not support setting + # kv cache dtype to something different from query dtype. 
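# ---------------------------------------------------------------------------
# Illustrative sketch (not taken from the patch) of the mask-based flattening
# used in build() above: the padded (num_reqs, max_pages) block table is turned
# into the ragged paged_kv_indices list by keeping only the first
# block_table_bounds[i] entries of each row.  Values below are hypothetical.
import torch

block_table = torch.tensor([[0, 5, 8, -1],
                            [1, 6, 7, -1],
                            [3, 4, -1, -1]])   # -1 marks padding
block_table_bounds = torch.tensor([3, 3, 2])   # pages actually in use per row

mask = (torch.arange(block_table.size(1),
                     dtype=block_table.dtype,
                     device=block_table.device).unsqueeze(0)
        < block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table[mask]

paged_kv_indptr = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    block_table_bounds.cumsum(dim=0, dtype=torch.int32),
])

print(paged_kv_indices.tolist())  # [0, 5, 8, 1, 6, 7, 3, 4]
print(paged_kv_indptr.tolist())   # [0, 3, 6, 8]
# ---------------------------------------------------------------------------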
+ return False + return use_cascade_attention(*args, **kwargs) + + +class FlashInferImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.kv_cache_dtype = kv_cache_dtype + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl") + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashInfer. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache = [num_blocks, 2, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run. + return output + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + # Reshape the input keys and values and store them in the cache. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] and + # value[:num_actual_tokens] because the reshape_and_cache_flash op uses + # the slot_mapping's shape to determine the number of actual tokens. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + window_left = (self.sliding_window[0] + if self.sliding_window is not None else -1) + + # Inputs and outputs may be padded for CUDA graphs + query = query[:num_actual_tokens] + output_padded = output + output = output[:num_actual_tokens] + + if attn_metadata.use_cascade: + # Cascade attention (rare case). 
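# ---------------------------------------------------------------------------
# Illustrative sketch (not taken from the patch).  The decode-first /
# prefill-last split used a few lines below (query[:num_decode_tokens] vs
# query[num_decode_tokens:]) relies on reorder_batch() having moved every
# decode request in front of every prefill request with as few swaps as
# possible.  The core of that reordering, reduced to plain Python lists with
# hypothetical data:
def reorder_decodes_first(num_scheduled_tokens: list[int]) -> list[int]:
    """Return the request order after the minimal-swap reordering."""
    order = list(range(len(num_scheduled_tokens)))
    decodes = [i for i, n in enumerate(num_scheduled_tokens) if n == 1]
    prefills = [i for i, n in enumerate(num_scheduled_tokens) if n > 1]
    num_decodes = len(decodes)
    # Walk the decodes from the back; swap any that sit past position
    # num_decodes - 1 with the front-most prefills.
    for i in range(1, min(num_decodes, len(prefills)) + 1):
        decode_idx = decodes[num_decodes - i]
        if decode_idx < num_decodes:
            break  # every remaining decode is already in the front region
        order[prefills[i - 1]], order[decode_idx] = (
            order[decode_idx], order[prefills[i - 1]])
    return order


# Requests 0 and 2 are prefills (many tokens); 1 and 3 are decodes (1 token).
print(reorder_decodes_first([30, 1, 12, 1]))
# [3, 1, 2, 0] -> the two decodes now occupy the first two slots, one swap
# ---------------------------------------------------------------------------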
+ assert attn_metadata.cascade_wrapper is not None + output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) + return output + + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefill_tokens = attn_metadata.num_prefill_tokens + + # Regular attention (common case). + # Decodes are at the front and prefills are at the back, + # according to reorder_batch() + if prefill_wrapper := attn_metadata.prefill_wrapper: + prefill_query = query[num_decode_tokens:] + assert prefill_query.shape[0] == num_prefill_tokens + assert prefill_wrapper is not None + assert prefill_wrapper._causal + assert prefill_wrapper._window_left == window_left + assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap + or 0.0) + assert prefill_wrapper._sm_scale == self.scale + prefill_wrapper.run( + prefill_query, + kv_cache, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[num_decode_tokens:], + ) + + if decode_wrapper := attn_metadata.decode_wrapper: + decode_query = query[:num_decode_tokens] + assert decode_query.shape[0] == num_decode_tokens + assert decode_wrapper is not None + assert decode_wrapper._window_left == window_left + assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap + or 0.0) + assert decode_wrapper._sm_scale == self.scale + decode_wrapper.run( + decode_query, + kv_cache, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[:num_decode_tokens], + ) + + return output_padded diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 8c7179ba0a8a..e6e483bae2bc 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -195,7 +195,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) +from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, RowParallelLinear, @@ -203,13 +205,14 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform from vllm.utils import cdiv, round_down -from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version try: from vllm.vllm_flash_attn import flash_attn_varlen_func + is_vllm_fa = True except ImportError: # For rocm use upstream flash attention from flash_attn import flash_attn_varlen_func + is_vllm_fa = False if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -248,10 +251,6 @@ def get_kv_cache_shape( def get_supported_head_sizes() -> list[int]: return [576] - @staticmethod - def use_cascade_attention(*args, **kwargs) -> bool: - return False - @dataclass class MLACommonPrefillMetadata: @@ -350,6 +349,14 @@ def __init__(self, model_config = runner.model_config cache_config = runner.cache_config self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + self.num_heads = model_config.get_num_attention_heads( + runner.parallel_config) + self.mla_dims = get_mla_dims(model_config) + self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3) + + # Dont try to access the runner on AMD + if self.aot_schedule: + self.page_size = self.runner.block_size if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -375,7 +382,6 @@ def __init__(self, dtype=model_config.dtype, 
device=runner.device, ) - self.page_size = self.runner.block_size def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -415,20 +421,18 @@ def reorder_batch(self, input_batch: "InputBatch", # the above loop num_decodes = len(decodes) num_prefills = len(prefills) - first_prefill = 0 modified_batch = False for i in range(1, min(num_decodes, num_prefills) + 1): # If the decode is at the "back" of the batch, i, we can swap it # with the prefill closest to the front of the batch - if decodes[num_decodes - i] >= num_decodes: - input_batch.swap_states(prefills[first_prefill], - decodes[num_decodes - i]) - first_prefill += 1 - modified_batch = True - else: + decode_idx = decodes[num_decodes - i] + if decode_idx < num_decodes: break + input_batch.swap_states(prefills[i - 1], decode_idx) + modified_batch = True + # Save for next `build` call # TODO(lucas): this is a bit of a hack, we should probably have a # better way of doing this @@ -466,7 +470,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] seq_lens = seq_lens_cpu.to(device, non_blocking=True) - max_query_len = seq_lens_cpu.max().item() prefill_metadata = None if self._num_prefills > 0: @@ -477,6 +480,8 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, num_computed_tokens_cpu_tensor[reqs_start:num_reqs] max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] chunked_context_metadata = None if self.chunked_prefill_enabled and self._num_prefills > 0 \ @@ -539,8 +544,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, prefill_metadata = MLACommonPrefillMetadata( input_positions=input_positions[tokens_start:], block_table=block_table[reqs_start:, ...], - query_start_loc=query_start_loc[reqs_start:] - - query_start_loc[reqs_start], + query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, ) @@ -566,6 +570,9 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, decode=decode_metadata, ) + def use_cascade_attention(self, *args, **kwargs) -> bool: + return False + class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): """ @@ -630,11 +637,56 @@ def __init__( # and the one from vllm_flash_attn. 
The former is used on RoCM and the # latter has an additional parameter to control FA2 vs FA3 self.flash_attn_varlen_func = flash_attn_varlen_func + self.vllm_flash_attn_version = get_flash_attn_version() if self.vllm_flash_attn_version is not None: self.flash_attn_varlen_func = \ functools.partial(flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version) + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim for attention backends that do + # not support different headdims + # We don't need to pad V if we are on a hopper system with FA3 + self._pad_v = self.vllm_flash_attn_version is None or not ( + self.vllm_flash_attn_version == 3 + and current_platform.get_device_capability()[0] == 9) + + def _flash_attn_varlen_diff_headdims(self, + q, + k, + v, + return_softmax_lse=False, + softmax_scale=None, + **kwargs): + maybe_padded_v = v + if self._pad_v: + maybe_padded_v = torch.nn.functional.pad( + v, [0, q.shape[-1] - v.shape[-1]], value=0) + + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_softmax_lse=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + + # Unpack the output if there is multiple results + lse = None + if isinstance(attn_out, tuple): + attn_out, lse = attn_out[0], attn_out[1] + + # unpad if necessary + if self._pad_v: + attn_out = attn_out[..., :v.shape[-1]] + + # Remain consistent with old `flash_attn_varlen_func` where there + # is only one output tensor if `return_softmax_lse` is False. + if return_softmax_lse: + return attn_out, lse + return attn_out + def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) @@ -747,16 +799,11 @@ def _compute_prefill_context( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad - # out v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, - [0, q.shape[-1] - v.shape[-1]], - value=0) - - attn_output, attn_softmax_lse = self.flash_attn_varlen_func( + attn_output, attn_softmax_lse = \ + self._flash_attn_varlen_diff_headdims( q=q, k=k, - v=v_padded, + v=v, cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], max_seqlen_q=prefill_metadata.max_query_len, @@ -803,15 +850,10 @@ def _forward_prefill( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - - output = self.flash_attn_varlen_func( + output = self._flash_attn_varlen_diff_headdims( q=q, k=k, - v=v_padded, + v=v, cu_seqlens_q=attn_metadata.prefill.query_start_loc, cu_seqlens_k=attn_metadata.prefill.query_start_loc, max_seqlen_q=attn_metadata.prefill.max_query_len, @@ -835,12 +877,7 @@ def _forward_prefill( suffix_lse=suffix_lse, ) - # slice by `:v.shape[-1]` in order to remove v headdim padding - output = output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) - - return self.o_proj(output)[0] + return self.o_proj(output.flatten(start_dim=-2))[0] @abstractmethod def _forward_decode( diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 3e8149a24ebf..05b97172bc6c 100644 --- a/vllm/v1/attention/backends/pallas.py +++ 
b/vllm/v1/attention/backends/pallas.py @@ -10,7 +10,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState +from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.utils import cdiv logger = init_logger(__name__) @@ -50,6 +52,19 @@ def swap_blocks( ) -> None: raise RuntimeError("swap_blocks is not used for the TPU backend.") + # In recent TPU generations, up to v6e, the SMEM size is 1MB. The + # block_tables within the PallasMetadata constitute almost the entire SMEM + # requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here + # we simply make sure that the size is smaller than half of SMEM capacity. + @staticmethod + def get_min_page_size(vllm_config: VllmConfig) -> int: + max_num_page_per_req = (1024 * 1024 // 2 // + vllm_config.scheduler_config.max_num_seqs // 4) + min_page_size = cdiv(vllm_config.model_config.max_model_len, + max_num_page_per_req) + min_page_size = 1 << (min_page_size - 1).bit_length() + return min_page_size + @dataclass class PallasMetadata: diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 33761cf7f9c0..0830d8433d89 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -25,7 +25,7 @@ def __init__( max_model_len: int, enable_caching: bool = True, caching_hash_algo: str = "builtin", - num_preallocate_tokens: int = 64, + use_eagle: bool = False, log_stats: bool = False, ) -> None: assert len(kv_cache_config.kv_cache_groups) == 1, ( @@ -39,24 +39,12 @@ def __init__( self.enable_caching = enable_caching self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash - # FIXME: make prefix cache stats conditional on log_stats + self.use_eagle = use_eagle self.log_stats = log_stats - # NOTE(woosuk): To avoid frequent block allocation, we preallocate some - # blocks for each request. For example, when a request reaches the end - # of its block table, we preallocate N blocks in advance. This way, we - # reduce the overhead of updating free_block_ids and ref_cnts for each - # request every step (at the cost of some memory waste). - # NOTE(woosuk): This is different from the "lookahead" slots since this - # does not guarantee that the request always has N empty blocks. After - # the request gets N empty blocks, it starts to use the blocks without - # further allocation. When it uses up all the N empty blocks, it gets - # N new empty blocks. - self.num_preallocate_tokens = num_preallocate_tokens - self.num_preallocate_blocks = cdiv(num_preallocate_tokens, - self.block_size) + # FIXME: make prefix cache stats conditional on log_stats + self.prefix_cache_stats = PrefixCacheStats() if log_stats else None self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching) - self.specialized_manager = get_specialized_manager( kv_cache_spec=kv_cache_spec, block_pool=self.block_pool, @@ -79,7 +67,6 @@ def __init__( # This is only used to track the RUNNING requests, we do not track the # data for reempted ones. self.num_cached_block: dict[str, int] = {} - self.prefix_cache_stats = PrefixCacheStats() @property def usage(self) -> float: @@ -90,12 +77,14 @@ def usage(self) -> float: """ return self.block_pool.get_usage() - def make_prefix_cache_stats(self) -> PrefixCacheStats: + def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: """Get (and reset) the prefix cache stats. Returns: - The current prefix caching stats. 
+ The current prefix caching stats, or None if logging is disabled. """ + if not self.log_stats: + return None stats = self.prefix_cache_stats self.prefix_cache_stats = PrefixCacheStats() return stats @@ -125,7 +114,9 @@ def get_computed_blocks( self.block_size, request) self.req_to_block_hashes[request.request_id] = block_hashes - self.prefix_cache_stats.requests += 1 + if self.log_stats: + assert self.prefix_cache_stats is not None + self.prefix_cache_stats.requests += 1 # When the request requires prompt logprobs, we skip prefix caching. if request.sampling_params.prompt_logprobs is not None: return [], 0 @@ -145,8 +136,18 @@ def get_computed_blocks( computed_blocks = ( self.specialized_manager.find_longest_cache_hit(block_hashes)) - self.prefix_cache_stats.queries += len(block_hashes) - self.prefix_cache_stats.hits += len(computed_blocks) + + if self.use_eagle and len(computed_blocks) > 0: + # Drop the last matched block if (1) eagle is enabled and + # (2) there is a cache hit. + # This is to recompute the last block to get the required + # hidden states for eagle drafting head. + computed_blocks.pop() + + if self.log_stats: + assert self.prefix_cache_stats is not None + self.prefix_cache_stats.queries += len(block_hashes) + self.prefix_cache_stats.hits += len(computed_blocks) if last_block_hash is not None: # Add back the last block hash if it was removed. @@ -171,8 +172,9 @@ def allocate_slots( Args: request: The request to allocate slots. - num_tokens: The number of tokens to allocate. Note that this does - not include the tokens that have already been computed. + num_tokens: The number of tokens to allocate, including external + tokens. Note that this does not include tokens that have + already been computed locally (i.e. new_computed_blocks). new_computed_blocks: A list of new computed blocks just hitting the prefix caching. num_lookahead_tokens: The number of speculative tokens to allocate. @@ -249,13 +251,9 @@ def allocate_slots( # No new block is needed. new_blocks = [] else: - # Get new blocks from the free block pool considering - # preallocated blocks. - num_preallocate_blocks = max( - 0, self.num_preallocate_blocks - - num_lookahead_tokens // self.block_size) + # Get new blocks from the free block pool. num_new_blocks = min( - num_new_blocks + num_preallocate_blocks, + num_new_blocks, self.block_pool.get_num_free_blocks(), # Should not exceed the maximum number of blocks per request. # This is especially because the block table has the shape @@ -316,17 +314,19 @@ def free(self, request: Request) -> None: def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF - flows to invalid prefix caching after the weights are updated, + flows to invalidate prefix caching after the weights are updated, or used for resetting prefix caching status for benchmarking. Returns: bool: True if the prefix cache is successfully reset, False otherwise. 
""" - if self.block_pool.reset_prefix_cache(): + if not self.block_pool.reset_prefix_cache(): + return False + if self.log_stats: + assert self.prefix_cache_stats is not None self.prefix_cache_stats.reset = True - return True - return False + return True def get_num_common_prefix_blocks( self, diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index bd0e01d045d1..3026ecc1c968 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -43,19 +43,19 @@ class BlockHashType(NamedTuple): # This aligns with the behavior of Python's hash() function, which also uses # a random seed if PYTHONHASHSEED is not set. NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv( - 'PYTHONHASHSEED') is not None else sha256(os.getenv('PYTHONHASHSEED')) + 'PYTHONHASHSEED') is None else sha256(os.getenv('PYTHONHASHSEED')) class PrefixCachingMetrics: - """Metrics for prefix caching with a hit rate of the most recent N requests. + """Metrics for prefix caching with a hit rate of the max recent N requests. Args: - interval: The number of the most recent requests to aggregate. + max_recent_requests: The number of the max recent requests to aggregate. Defaults to 1000. """ - def __init__(self, interval: int = 1000): - self.interval = interval + def __init__(self, max_recent_requests: int = 1000): + self.max_recent_requests = max_recent_requests # The current aggregated values. self.aggregated_requests = 0 self.aggregated_query_total = 0 @@ -70,7 +70,7 @@ def observe(self, stats: PrefixCacheStats): are being scheduled and are looking for computed blocks. When there are more than `interval` requests, the oldest set of - requestsare removed from the metrics. + requests are removed from the metrics. Args: stats: The prefix cache stats. @@ -87,7 +87,7 @@ def observe(self, stats: PrefixCacheStats): self.aggregated_query_hit += stats.hits # Remove the oldest stats if the number of requests exceeds. - if self.aggregated_requests > self.interval: + if self.aggregated_requests > self.max_recent_requests: old_requests, old_queries, old_hits = self.query_queue.popleft() self.aggregated_requests -= old_requests self.aggregated_query_total -= old_queries diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index bfed44f9d58c..1de236d42f02 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -117,11 +117,6 @@ def has_requests(self) -> bool: not yet returned in SchedulerOutputs.""" return self.has_unfinished_requests() or self.has_finished_requests() - @abstractmethod - def get_num_unscheduled_requests(self) -> int: - """Number of requests that are not being processed by the executor.""" - raise NotImplementedError - @abstractmethod def reset_prefix_cache(self) -> bool: """Reset the prefix cache for KV cache. 
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index dc0d2d59fea7..928fb231a1f2 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -9,6 +9,8 @@ import numpy as np import numpy.typing as npt + from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata) from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams @@ -20,7 +22,6 @@ class NewRequestData: req_id: str prompt_token_ids: list[int] - prompt: Optional[str] mm_inputs: list[MultiModalKwargs] mm_hashes: list[str] mm_positions: list[PlaceholderRange] @@ -38,7 +39,6 @@ def from_request( return cls( req_id=request.request_id, prompt_token_ids=request.prompt_token_ids, - prompt=request.prompt, mm_inputs=request.mm_inputs, mm_hashes=request.mm_hashes, mm_positions=request.mm_positions, @@ -121,3 +121,6 @@ class SchedulerOutput: structured_output_request_ids: dict[str, int] # the bitmask for the whole batch grammar_bitmask: Optional[npt.NDArray[np.int32]] + + # KV Cache Connector metadata. + kv_connector_metadata: Optional[KVConnectorMetadata] = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a81574875a5c..60c6a0f00600 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -3,12 +3,14 @@ from __future__ import annotations import time -from collections import deque +from collections import defaultdict, deque from collections.abc import Iterable from typing import Optional, Union -from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, @@ -34,20 +36,17 @@ class Scheduler(SchedulerInterface): def __init__( self, - scheduler_config: SchedulerConfig, - model_config: ModelConfig, - cache_config: CacheConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_config: KVCacheConfig, structured_output_manager: StructuredOutputManager, - speculative_config: SpeculativeConfig = None, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, include_finished_set: bool = False, log_stats: bool = False, ) -> None: - self.scheduler_config = scheduler_config - self.cache_config = cache_config - self.lora_config = lora_config + self.vllm_config = vllm_config + self.scheduler_config = vllm_config.scheduler_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config self.kv_cache_config = kv_cache_config self.log_stats = log_stats self.structured_output_manager = structured_output_manager @@ -64,13 +63,17 @@ def __init__( self.scheduler_config.max_num_batched_tokens self.max_model_len = self.scheduler_config.max_model_len - # Create the KV cache manager. - self.kv_cache_manager = KVCacheManager( - kv_cache_config=kv_cache_config, - max_model_len=self.max_model_len, - enable_caching=cache_config.enable_prefix_caching, - caching_hash_algo=self.cache_config.prefix_caching_hash_algo, - log_stats=self.log_stats) + # Create KVConnector for the Scheduler. Note that each Worker + # will have a corresponding KVConnector with Role=WORKER. 
+ # KV Connector pushes/pull of remote KVs for P/D and offloading. + self.connector = None + if self.vllm_config.kv_transfer_config is not None: + self.connector = KVConnectorFactory.create_connector_v1( + config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + + num_gpu_blocks = self.cache_config.num_gpu_blocks + assert num_gpu_blocks is not None and num_gpu_blocks > 0 + self.block_size = self.cache_config.block_size # req_id -> Request @@ -78,9 +81,6 @@ def __init__( # Priority queues for requests. self.waiting: deque[Request] = deque() self.running: list[Request] = [] - # The requests that have been scheduled and are being executed - # by the executor. - self.scheduled_req_ids: set[str] = set() # The request IDs that are finished in between the previous and the # current steps. This is used to notify the workers about the finished @@ -90,8 +90,9 @@ def __init__( # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. - # Request id -> CachedRequestData - self._cached_reqs_data: dict[str, CachedRequestData] = {} + # Request id -> deque of CachedRequestData + self._cached_reqs_data: dict[ + str, deque[CachedRequestData]] = defaultdict(deque) # Encoder-related. # Calculate encoder cache size if applicable @@ -99,8 +100,8 @@ def __init__( # This can be changed when we make encoder cache for embedding caching # across requests. encoder_compute_budget, encoder_cache_size = compute_encoder_budget( - model_config=model_config, - scheduler_config=scheduler_config, + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, mm_registry=mm_registry, ) @@ -114,10 +115,24 @@ def __init__( self.encoder_cache_manager = EncoderCacheManager( cache_size=encoder_cache_size) - self.num_lookahead_tokens = 0 - if speculative_config and speculative_config.method == "eagle": - self.num_lookahead_tokens = \ - speculative_config.num_speculative_tokens + speculative_config = vllm_config.speculative_config + + self.use_eagle = False + self.num_spec_tokens = self.num_lookahead_tokens = 0 + if speculative_config: + self.num_spec_tokens = speculative_config.num_speculative_tokens + if speculative_config.use_eagle(): + self.use_eagle = True + self.num_lookahead_tokens = self.num_spec_tokens + + # Create the KV cache manager. + self.kv_cache_manager = KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=self.max_model_len, + enable_caching=self.cache_config.enable_prefix_caching, + caching_hash_algo=self.cache_config.prefix_caching_hash_algo, + use_eagle=self.use_eagle, + log_stats=self.log_stats) def schedule(self) -> SchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: @@ -160,10 +175,6 @@ def schedule(self) -> SchedulerOutput: req_index = 0 while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - if request.request_id in self.scheduled_req_ids: - # This request has already been scheduled. - req_index += 1 - continue num_new_tokens = (request.num_tokens_with_spec - request.num_computed_tokens) @@ -172,26 +183,35 @@ def schedule(self) -> SchedulerOutput: num_new_tokens = ( self.scheduler_config.long_prefill_token_threshold) num_new_tokens = min(num_new_tokens, token_budget) - assert num_new_tokens > 0 + + # Make sure the input position does not exceed the max model len. + # This is necessary when using spec decoding. + num_new_tokens = min( + num_new_tokens, + self.max_model_len - request.num_computed_tokens) # Schedule encoder inputs. 
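The clamp on num_new_tokens added just above can be illustrated with concrete numbers; the values below are made up for the example and are not taken from the patch.

    # With speculative decoding, a step may propose more tokens than fit
    # before the context window ends, so the schedule is clamped.
    max_model_len = 4096
    token_budget = 512
    num_computed_tokens = 4090
    num_new_tokens = 8  # e.g. 1 verified token plus 7 draft tokens

    num_new_tokens = min(num_new_tokens, token_budget)
    num_new_tokens = min(num_new_tokens, max_model_len - num_computed_tokens)
    assert num_new_tokens == 6  # positions never exceed max_model_len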
+ encoder_inputs_to_schedule = None + new_encoder_budget = encoder_budget if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget) = self._try_schedule_encoder_inputs( request, request.num_computed_tokens, num_new_tokens, encoder_budget) - if num_new_tokens == 0: - # The request cannot be scheduled because the encoder budget - # or the encoder cache is exhausted. - # NOTE(woosuk): By using `continue` instead of `break` here, - # we intentionally relax the strict FCFS scheduling policy - # to allow lower-priority requests to be scheduled when a - # higher-priority request is blocked by encoder constraints. - req_index += 1 - continue - else: - encoder_inputs_to_schedule = None - new_encoder_budget = encoder_budget + + if num_new_tokens == 0: + # The request cannot be scheduled because one of the following + # reasons: + # 1. No new tokens to schedule. This may happen when PP>1 and + # we have already scheduled all prompt tokens but they are + # not finished yet. + # 2. The encoder budget is exhausted. + # 3. The encoder cache is exhausted. + # NOTE(woosuk): Here, by doing `continue` instead of `break`, + # we do not strictly follow the FCFS scheduling policy and + # allow the lower-priority requests to be scheduled. + req_index += 1 + continue while True: new_blocks = self.kv_cache_manager.allocate_slots( @@ -225,7 +245,6 @@ def schedule(self) -> SchedulerOutput: # Schedule the request. scheduled_running_reqs.append(request) - self.scheduled_req_ids.add(request.request_id) if request.use_structured_output: # PERF: in case of chunked prefill, # request might not include any new tokens. @@ -303,7 +322,18 @@ def schedule(self) -> SchedulerOutput: # Get already-cached tokens. computed_blocks, num_computed_tokens = \ - self.kv_cache_manager.get_computed_blocks(request) + self.kv_cache_manager.get_computed_blocks( + request) + + # Get externally-cached tokens if using a KVConnector. + num_external_tokens = ( + 0 if self.connector is None else + self.connector.get_num_new_matched_tokens( + request, num_computed_tokens)) + + # Total computed tokens (local + external). + num_computed_tokens += num_external_tokens + # Number of tokens to be scheduled. # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed requests, @@ -330,18 +360,30 @@ def schedule(self) -> SchedulerOutput: new_encoder_budget = encoder_budget new_blocks = self.kv_cache_manager.allocate_slots( - request, num_new_tokens, computed_blocks) + request, + num_new_tokens + num_external_tokens, + computed_blocks, + num_lookahead_tokens=self.num_lookahead_tokens, + ) if new_blocks is None: # The request cannot be scheduled. break + # KVConnector: update internal state after allocation. + # This information is used to determine if a load is + # needed for this request. + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + num_external_tokens, + ) + self.waiting.popleft() if request.use_structured_output: structured_output_request_ids[ request.request_id] = req_index req_index += 1 self.running.append(request) - self.scheduled_req_ids.add(request.request_id) if self.log_stats: request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) @@ -443,6 +485,14 @@ def schedule(self) -> SchedulerOutput: grammar_bitmask=grammar_bitmask, ) + # NOTE(Kuntai): this function is designed for multiple purposes: + # 1. Plan the KV cache store + # 2. Wrap up all the KV cache load / save ops into an opaque object + # 3. 
Clear the internal states of the connector + if self.connector is not None: + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + # Advance the number of computed tokens for the request AFTER # the request is scheduled. # 1. The scheduler_output of the current step has to include the @@ -472,18 +522,21 @@ def _make_cached_request_data( num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens new_token_ids = request.all_token_ids[ num_computed_tokens:num_computed_tokens + num_regular_tokens] - req_data = self._cached_reqs_data.get(request.request_id) - if req_data is not None: + + req_data_queue = self._cached_reqs_data.get(request.request_id) + if req_data_queue: + req_data = req_data_queue.popleft() req_data.resumed_from_preemption = resumed_from_preemption req_data.new_token_ids = new_token_ids req_data.new_block_ids = new_block_ids req_data.num_computed_tokens = num_computed_tokens else: + # No cached request data, or all cached request data has been + # used by the scheduled requests. req_data = CachedRequestData.from_request(request, resumed_from_preemption, new_token_ids, new_block_ids) - self._cached_reqs_data[request.request_id] = req_data return req_data def _try_schedule_encoder_inputs( @@ -508,7 +561,12 @@ def _try_schedule_encoder_inputs( If an encoder input cannot be scheduled due to cache or budget limitations, the method adjusts `num_new_tokens` to schedule only the decoder tokens up to just before the unschedulable encoder input. + + Note that num_computed_tokens includes both locally cached + blocks and externally cached blocks (via KVConnector). """ + if num_new_tokens == 0 or not request.has_encoder_inputs: + return [], num_new_tokens, encoder_budget encoder_inputs_to_schedule: list[int] = [] mm_positions = request.mm_positions assert mm_positions is not None @@ -676,10 +734,16 @@ def update_from_output( # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors - self.scheduled_req_ids.remove(req_id) if not stopped: new_running.append(request) + # Return the cached request data to the queue so they can be reused. + for req_data in scheduler_output.scheduled_cached_reqs: + # NOTE(rob): since we free stopped reqs above, adding stopped reqs + # to _cached_reqs_data will cause a memory leak. 
+ if req_data.req_id not in self.finished_req_ids: + self._cached_reqs_data[req_data.req_id].append(req_data) + self.running = new_running engine_core_outputs = EngineCoreOutputs( outputs=outputs, @@ -722,7 +786,6 @@ def finish_requests( if request.status == RequestStatus.RUNNING: self.running.remove(request) - self.scheduled_req_ids.discard(request.request_id) else: self.waiting.remove(request) request.status = finished_status @@ -743,10 +806,6 @@ def get_num_unfinished_requests(self) -> int: def has_finished_requests(self) -> bool: return len(self.finished_req_ids) > 0 - def get_num_unscheduled_requests(self) -> int: - """Number of requests that are not being processed by the executor.""" - return self.get_num_unfinished_requests() - len(self.scheduled_req_ids) - def reset_prefix_cache(self) -> bool: return self.kv_cache_manager.reset_prefix_cache() @@ -756,11 +815,13 @@ def make_stats( ) -> Optional[SchedulerStats]: if not self.log_stats: return None + prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() + assert prefix_cache_stats is not None return SchedulerStats( num_running_reqs=len(self.running), num_waiting_reqs=len(self.waiting), gpu_cache_usage=self.kv_cache_manager.usage, - prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(), + prefix_cache_stats=prefix_cache_stats, spec_decoding_stats=spec_decoding_stats, ) @@ -773,7 +834,8 @@ def make_spec_decoding_stats( if not self.log_stats: return None if spec_decoding_stats is None: - spec_decoding_stats = SpecDecodingStats() - spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted_tokens) + spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) + spec_decoding_stats.observe_draft( + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted_tokens) return spec_decoding_stats diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 1264e43c79d9..0474669610cd 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -49,9 +49,6 @@ class EngineCoreRequest( # due to circular imports and typing we have in data.py request_id: str - # NOTE(ywang96): original text prompt is needed when a request is added to - # Detokenizer, but set to None when it is added to EngineCoreClient. - prompt: Optional[str] prompt_token_ids: list[int] mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] mm_hashes: Optional[list[str]] @@ -61,6 +58,11 @@ class EngineCoreRequest( arrival_time: float lora_request: Optional[LoRARequest] + # Used in DP case to indicate which wave of requests this is expected to + # belong to, to cover a race condition where the request is sent before + # a wave finished notification is received. + current_wave: int = 0 + class EngineCoreEventType(enum.IntEnum): """The type of engine core request event.""" @@ -139,8 +141,12 @@ class EngineCoreOutputs( utility_output: Optional[UtilityOutput] = None finished_requests: Optional[set[str]] = None - # In DP case, used to signal that the engine is paused. - engine_paused: bool = False + # In DP case, used to signal that the current wave of requests + # has finished and the engines are paused. + wave_complete: Optional[int] = None + # In DP case, used to signal that a request was received for an + # "old" wave, so the next wave needs to be started in other engines. 
+ start_wave: Optional[int] = None def __post_init__(self): if self.timestamp == 0.0: @@ -154,5 +160,7 @@ class EngineCoreRequestType(enum.Enum): """ ADD = b'\x00' ABORT = b'\x01' - START_DP = b'\x02' + START_DP_WAVE = b'\x02' UTILITY = b'\x03' + # Sentinel used within EngineCoreProc. + EXECUTOR_FAILED = b'\x04' diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index b77a6824cddb..1334fb789aa4 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,8 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 - import asyncio -import logging -import os from collections.abc import AsyncGenerator, Mapping from copy import copy from typing import Optional, Union @@ -26,16 +23,17 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.utils import Device, cdiv, kill_process_tree +from vllm.utils import Device, cdiv from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.core_client import EngineCoreClient +from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient +from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm.v1.engine.output_processor import (OutputProcessor, RequestOutputCollector) from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor -from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger, - StatLoggerBase) +from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory, + setup_default_loggers) from vllm.v1.metrics.stats import IterationStats, SchedulerStats logger = init_logger(__name__) @@ -53,7 +51,28 @@ def __init__( use_cached_outputs: bool = False, log_requests: bool = True, start_engine_loop: bool = True, + stat_loggers: Optional[list[StatLoggerFactory]] = None, ) -> None: + """ + Create an AsyncLLM. + + Args: + vllm_config: global configuration. + executor_class: an Executor impl, e.g. MultiprocExecutor. + log_stats: Whether to log stats. + usage_context: Usage context of the LLM. + mm_registry: Multi-modal registry. + use_cached_outputs: Whether to use cached outputs. + log_requests: Whether to log requests. + start_engine_loop: Whether to start the engine loop. + stat_loggers: customized stat loggers for the engine. + If not provided, default stat loggers will be used. + PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE + IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE. + + Returns: + None + """ if not envs.VLLM_USE_V1: raise ValueError( "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. " @@ -61,31 +80,24 @@ def __init__( "AsyncLLMEngine.from_vllm_config(...) or explicitly set " "VLLM_USE_V1=0 or 1 and report this issue on Github.") - assert start_engine_loop - self.model_config = vllm_config.model_config - + self.vllm_config = vllm_config self.log_requests = log_requests self.log_stats = log_stats # Set up stat loggers; independent set for each DP rank. 
- self.stat_loggers: list[list[StatLoggerBase]] = [] - if self.log_stats: - for i in range(vllm_config.parallel_config.data_parallel_size): - loggers: list[StatLoggerBase] = [] - if logger.isEnabledFor(logging.INFO): - loggers.append(LoggingStatLogger(engine_index=i)) - loggers.append( - PrometheusStatLogger(vllm_config, engine_index=i)) - self.stat_loggers.append(loggers) + self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers( + vllm_config=vllm_config, + log_stats=self.log_stats, + engine_num=vllm_config.parallel_config.data_parallel_size, + custom_stat_loggers=stat_loggers, + ) # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, lora_config=vllm_config.lora_config) - self.tokenizer.ping() # Processor (converts Inputs --> EngineCoreRequests). self.processor = Processor( @@ -99,15 +111,23 @@ def __init__( log_stats=self.log_stats) # EngineCore (starts the engine in background process). - self.engine_core = EngineCoreClient.make_client( - multiprocess_mode=True, - asyncio_mode=True, + core_client_class = AsyncMPClient if ( + vllm_config.parallel_config.data_parallel_size + == 1) else DPAsyncMPClient + + self.engine_core = core_client_class( vllm_config=vllm_config, executor_class=executor_class, log_stats=self.log_stats, ) self.output_handler: Optional[asyncio.Task] = None + try: + # Start output handler eagerly if we are in the asyncio eventloop. + asyncio.get_running_loop() + self._run_output_handler() + except RuntimeError: + pass @classmethod def from_vllm_config( @@ -115,7 +135,7 @@ def from_vllm_config( vllm_config: VllmConfig, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_requests: bool = False, disable_log_stats: bool = False, ) -> "AsyncLLM": @@ -126,17 +146,12 @@ def from_vllm_config( "AsyncLLMEngine.from_vllm_config(...) or explicitly set " "VLLM_USE_V1=0 or 1 and report this issue on Github.") - # FIXME(rob): refactor VllmConfig to include the StatLoggers - # include StatLogger in the Oracle decision. - if stat_loggers is not None: - raise ValueError("Custom StatLoggers are not yet supported on V1. " - "Explicitly set VLLM_USE_V1=0 to disable V1.") - # Create the LLMEngine. 
return cls( vllm_config=vllm_config, executor_class=Executor.get_class(vllm_config), start_engine_loop=start_engine_loop, + stat_loggers=stat_loggers, log_requests=not disable_log_requests, log_stats=not disable_log_stats, usage_context=usage_context, @@ -148,6 +163,7 @@ def from_engine_args( engine_args: AsyncEngineArgs, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[list[StatLoggerFactory]] = None, ) -> "AsyncLLM": """Create an AsyncLLM from the EngineArgs.""" @@ -163,8 +179,12 @@ def from_engine_args( log_stats=not engine_args.disable_log_stats, start_engine_loop=start_engine_loop, usage_context=usage_context, + stat_loggers=stat_loggers, ) + def __del__(self): + self.shutdown() + def shutdown(self): """Shutdown, cleaning up the background proc and IPC.""" @@ -187,6 +207,9 @@ async def add_request( ) -> RequestOutputCollector: """Add new request to the AsyncLLM.""" + if self.errored: + raise EngineDeadError() + assert isinstance(params, SamplingParams), \ "Pooling is not supported in V1" @@ -194,14 +217,12 @@ async def add_request( queue = RequestOutputCollector(output_kind=params.output_kind) # Convert Input --> Request. - request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - trace_headers, - prompt_adapter_request, - priority) + prompt_str, request = self.processor.process_inputs( + request_id, prompt, params, arrival_time, lora_request, + trace_headers, prompt_adapter_request, priority) if params.n == 1: - await self._add_request(request, None, 0, queue) + await self._add_request(request, prompt_str, None, 0, queue) return queue # Fan out child requests (for n>1). @@ -211,15 +232,18 @@ async def add_request( child_request = request if idx == params.n - 1 else copy(request) child_request.request_id = request_id child_request.sampling_params = params - await self._add_request(child_request, parent_request, idx, queue) + await self._add_request(child_request, prompt_str, parent_request, + idx, queue) return queue async def _add_request(self, request: EngineCoreRequest, + prompt: Optional[str], parent_req: Optional[ParentRequest], index: int, queue: RequestOutputCollector): # Add the request to OutputProcessor (this process). - self.output_processor.add_request(request, parent_req, index, queue) + self.output_processor.add_request(request, prompt, parent_req, index, + queue) # Add the EngineCoreRequest to EngineCore (separate process). await self.engine_core.add_request_async(request) @@ -261,9 +285,7 @@ async def generate( # We start the output_handler on the first call to generate() so # we can call __init__ before the event loop, which enables us # to handle startup failure gracefully in the OpenAI server. - if self.output_handler is None: - self.output_handler = asyncio.create_task( - self._run_output_handler()) + self._run_output_handler() q = await self.add_request( request_id, @@ -288,62 +310,96 @@ async def generate( finished = out.finished yield out - # If the request is disconnected by the client, the - # generate() task will be canceled. So, we abort the - # request if we end up here. + # If the request is disconnected by the client, generate() + # is cancelled. So, we abort the request if we end up here. 
except asyncio.CancelledError: await self.abort(request_id) + if self.log_requests: + logger.info("Request %s aborted.", request_id) raise - async def _run_output_handler(self): - """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" + # Engine is dead. Do not abort since we shut down. + except EngineDeadError: + if self.log_requests: + logger.info("Request %s failed (engine dead).", request_id) + raise - try: - while True: - # 1) Pull EngineCoreOutputs from the EngineCore. - outputs = await self.engine_core.get_output_async() - num_outputs = len(outputs.outputs) - - iteration_stats = IterationStats() if ( - self.log_stats and num_outputs) else None - - # Split outputs into chunks of at most - # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the - # event loop for too long. - if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: - slices = (outputs.outputs, ) - else: - slices = np.array_split( - outputs.outputs, - cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE)) - - for i, outputs_slice in enumerate(slices): - # 2) Process EngineCoreOutputs. - processed_outputs = self.output_processor.process_outputs( - outputs_slice, outputs.timestamp, iteration_stats) - # NOTE: RequestOutputs are pushed to their queues. - assert not processed_outputs.request_outputs - - # Allow other asyncio tasks to run between chunks - if i + 1 < len(slices): - await asyncio.sleep(0) - - # 3) Abort any reqs that finished due to stop strings. - await self.engine_core.abort_requests_async( - processed_outputs.reqs_to_abort) - - # 4) Logging. - # TODO(rob): make into a coroutine and launch it in - # background thread once Prometheus overhead is non-trivial. - self._record_stats( - engine_index=outputs.engine_index, - scheduler_stats=outputs.scheduler_stats, - iteration_stats=iteration_stats, - ) + # Request validation error. + except ValueError: + if self.log_requests: + logger.info("Request %s failed (bad request).", request_id) + raise + # Unexpected error in the generate() task (possibly recoverable). except Exception as e: - logger.exception("EngineCore output handler hit an error: %s", e) - kill_process_tree(os.getpid()) + await self.abort(request_id) + if self.log_requests: + logger.info("Request %s failed.", request_id) + raise EngineGenerateError() from e + + def _run_output_handler(self): + """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" + + if self.output_handler is not None: + return + + # Ensure that the task doesn't have a circular ref back to the AsyncLLM + # object, or else it won't be garbage collected and cleaned up properly. + engine_core = self.engine_core + output_processor = self.output_processor + log_stats = self.log_stats + stat_loggers = self.stat_loggers if log_stats else None + + async def output_handler(): + try: + while True: + # 1) Pull EngineCoreOutputs from the EngineCore. + outputs = await engine_core.get_output_async() + num_outputs = len(outputs.outputs) + + iteration_stats = IterationStats() if ( + log_stats and num_outputs) else None + + # Split outputs into chunks of at most + # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the + # event loop for too long. + if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: + slices = (outputs.outputs, ) + else: + slices = np.array_split( + outputs.outputs, + cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE)) + + for i, outputs_slice in enumerate(slices): + # 2) Process EngineCoreOutputs. 
+ processed_outputs = output_processor.process_outputs( + outputs_slice, outputs.timestamp, iteration_stats) + # NOTE: RequestOutputs are pushed to their queues. + assert not processed_outputs.request_outputs + + # Allow other asyncio tasks to run between chunks + if i + 1 < len(slices): + await asyncio.sleep(0) + + # 3) Abort any reqs that finished due to stop strings. + await engine_core.abort_requests_async( + processed_outputs.reqs_to_abort) + + # 4) Logging. + # TODO(rob): make into a coroutine and launch it in + # background thread once Prometheus overhead is non-trivial. + if stat_loggers: + assert outputs.scheduler_stats is not None + AsyncLLM._record_stats( + stat_loggers[outputs.engine_index], + scheduler_stats=outputs.scheduler_stats, + iteration_stats=iteration_stats, + ) + except Exception as e: + logger.exception("AsyncLLM output_handler failed.") + output_processor.propagate_error(e) + + self.output_handler = asyncio.create_task(output_handler()) async def abort(self, request_id: str) -> None: """Abort RequestId in OutputProcessor and EngineCore.""" @@ -354,17 +410,15 @@ async def abort(self, request_id: str) -> None: if self.log_requests: logger.info("Aborted request %s.", request_id) + @staticmethod def _record_stats( - self, - scheduler_stats: Optional[SchedulerStats], + stat_loggers: list[StatLoggerBase], + scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats], - engine_index: int = 0, ): - if not self.log_stats: - return - - assert scheduler_stats is not None - for stat_logger in self.stat_loggers[engine_index]: + """static so that it can be used from the output_handler task + without a circular ref to AsyncLLM.""" + for stat_logger in stat_loggers: stat_logger.record(scheduler_stats=scheduler_stats, iteration_stats=iteration_stats) @@ -379,6 +433,9 @@ def encode( ): raise ValueError("Not Supported on V1 yet.") + async def get_vllm_config(self) -> VllmConfig: + return self.vllm_config + async def get_model_config(self) -> ModelConfig: return self.model_config @@ -446,18 +503,30 @@ async def pin_lora(self, lora_id: int) -> bool: """Prevent an adapter from being evicted.""" return await self.engine_core.pin_lora_async(lora_id) + async def collective_rpc(self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None): + """ + Perform a collective RPC call to the given path. + """ + return await self.engine_core.collective_rpc_async( + method, timeout, args, kwargs) + @property def is_running(self) -> bool: - return True + # Is None before the loop is started. 
+ return self.output_handler is None or not self.output_handler.done() @property def is_stopped(self) -> bool: - return False + return self.errored @property def errored(self) -> bool: - return False + return self.engine_core.resources.engine_dead or not self.is_running @property def dead_error(self) -> BaseException: - return Exception() # TODO: implement + return EngineDeadError() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f642e51001a8..80807665e779 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -5,15 +5,14 @@ import sys import threading import time +from collections import deque from concurrent.futures import Future from inspect import isclass, signature from logging import DEBUG from typing import Any, Callable, Optional, TypeVar, Union import msgspec -import psutil import zmq -import zmq.asyncio from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import stateless_destroy_torch_distributed_process_group @@ -22,8 +21,7 @@ from vllm.lora.request import LoRARequest from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname, - zmq_socket_ctx) +from vllm.utils import resolve_obj_by_qualname, zmq_socket_ctx from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface @@ -50,12 +48,11 @@ class EngineCore: """Inner loop of vLLM's Engine.""" - def __init__( - self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - ): + def __init__(self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + executor_fail_callback: Optional[Callable] = None): assert vllm_config.model_config.runner_type != "pooling" logger.info("Initializing a V1 LLM engine (v%s) with config: %s", @@ -65,6 +62,9 @@ def __init__( # Setup Model. self.model_executor = executor_class(vllm_config) + if executor_fail_callback is not None: + self.model_executor.register_failure_callback( + executor_fail_callback) # Setup KV Caches and update CacheConfig after profiling. num_gpu_blocks, num_cpu_blocks, kv_cache_config = \ @@ -93,12 +93,8 @@ def __init__( vllm_config.scheduler_config.scheduler_cls) self.scheduler: SchedulerInterface = Scheduler( - scheduler_config=vllm_config.scheduler_config, - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - lora_config=vllm_config.lora_config, + vllm_config=vllm_config, kv_cache_config=kv_cache_config, - speculative_config=vllm_config.speculative_config, structured_output_manager=self.structured_output_manager, include_finished_set=vllm_config.parallel_config.data_parallel_size > 1, @@ -215,10 +211,10 @@ def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: Note that if nothing to output in this step, None is returned. The execution flow is as follows: - 1. Try to schedule a new batch if there are unscheduled requests - and the job queue is not full. If a new batch is scheduled, directly - return an empty engine core output. In other words, we won't check - and return model outputs before the batch queue is full. + 1. Try to schedule a new batch if the batch queue is not full. + If a new batch is scheduled, directly return an empty engine core + output. In other words, fulfilling the batch queue has a higher priority + than getting model outputs. 2. 
If there is no new scheduled batch, meaning that the batch queue is full or no other requests can be scheduled, we block until the first batch in the job queue is finished. @@ -228,10 +224,10 @@ def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: engine_core_outputs = None scheduler_output = None - # If there are unscheduled requests and the job queue - # is not full, schedule a new batch. Note that this is not blocking. - if (self.scheduler.get_num_unscheduled_requests() > 0 - and not self.batch_queue.full()): + # Try to schedule a new batch if the batch queue is not full, but + # the scheduler may return an empty batch if all requests are scheduled. + # Note that this is not blocking. + if not self.batch_queue.full(): scheduler_output = self.scheduler.schedule() if scheduler_output.total_num_scheduled_tokens > 0: future = self.model_executor.execute_model(scheduler_output) @@ -243,6 +239,10 @@ def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: # If no more requests can be scheduled and the job queue is not empty, # block until the first batch in the job queue is finished. + # TODO(comaniac): Ideally we should peek the first batch in the + # job queue to check if it's finished before scheduling a new batch, + # but peeking the first element in a queue is not thread-safe, + # so we need more work. if not scheduled_batch and not self.batch_queue.empty(): future, scheduler_output = self.batch_queue.get_nowait() # Blocking until the first result is available. @@ -254,7 +254,9 @@ def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: return engine_core_outputs def shutdown(self): - self.model_executor.shutdown() + self.structured_output_manager.clear_backend() + if self.model_executor: + self.model_executor.shutdown() def profile(self, is_start: bool = True): self.model_executor.profile(is_start) @@ -308,6 +310,8 @@ def collective_rpc(self, class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" + ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD' + def __init__( self, input_path: str, @@ -317,27 +321,33 @@ def __init__( log_stats: bool, engine_index: int = 0, ): - super().__init__(vllm_config, executor_class, log_stats) + input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() + + executor_fail_callback = lambda: input_queue.put_nowait( + (EngineCoreRequestType.EXECUTOR_FAILED, b'')) + + super().__init__(vllm_config, executor_class, log_stats, + executor_fail_callback) self.step_fn = (self.step if self.batch_queue is None else self.step_with_batch_queue) - - self.global_unfinished_reqs = False + self.engines_running = False # Background Threads and Queues for IO. These enable us to # overlap ZMQ socket IO with GPU since they release the GIL, # and to overlap some serialization/deserialization with the # model forward pass. # Threads handle Socket <-> Queues and core_busy_loop uses Queue. 
- self.input_queue: queue.Queue[tuple[EngineCoreRequestType, - Any]] = queue.Queue() - self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() + self.input_queue = input_queue + self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]() threading.Thread(target=self.process_input_socket, args=(input_path, engine_index), daemon=True).start() - threading.Thread(target=self.process_output_socket, - args=(output_path, engine_index), - daemon=True).start() + self.output_thread = threading.Thread( + target=self.process_output_socket, + args=(output_path, engine_index), + daemon=True) + self.output_thread.start() @staticmethod def run_engine_core(*args, @@ -364,7 +374,6 @@ def signal_handler(signum, frame): signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) - parent_process = psutil.Process().parent() engine_core: Optional[EngineCoreProc] = None try: parallel_config: ParallelConfig = kwargs[ @@ -380,13 +389,15 @@ def signal_handler(signum, frame): engine_core.run_busy_loop() except SystemExit: - logger.debug("EngineCore interrupted.") - - except Exception: - traceback = get_exception_traceback() - logger.error("EngineCore hit an exception: %s", traceback) - parent_process.send_signal(signal.SIGUSR1) - + logger.debug("EngineCore exiting.") + raise + except Exception as e: + if engine_core is None: + logger.exception("EngineCore failed to start.") + else: + logger.exception("EngineCore encountered a fatal error.") + engine_core._send_engine_dead() + raise e finally: if engine_core is not None: engine_core.shutdown() @@ -405,8 +416,7 @@ def _process_input_queue(self): """Exits when an engine step needs to be performed.""" waited = False - while not self.global_unfinished_reqs and not ( - self.scheduler.has_requests()): + while not self.engines_running and not (self.scheduler.has_requests()): if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): logger.debug("EngineCore waiting for work.") waited = True @@ -414,10 +424,7 @@ def _process_input_queue(self): self._handle_client_request(*req) if waited: - logger.debug( - "EngineCore loop active - local unfinished: %s, finished: %s.", - self.scheduler.has_unfinished_requests(), - self.scheduler.has_finished_requests()) + logger.debug("EngineCore loop active.") # Handle any more client requests. 
while not self.input_queue.empty(): @@ -441,10 +448,6 @@ def _handle_client_request(self, request_type: EngineCoreRequestType, self.add_request(request) elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) - elif request_type == EngineCoreRequestType.START_DP: - if not self.global_unfinished_reqs: - logger.debug("EngineCore starting idle loop.") - self.global_unfinished_reqs = True elif request_type == EngineCoreRequestType.UTILITY: call_id, method_name, args = request output = UtilityOutput(call_id) @@ -458,6 +461,11 @@ def _handle_client_request(self, request_type: EngineCoreRequestType, f" failed: {str(e)}") self.output_queue.put_nowait( EngineCoreOutputs(utility_output=output)) + elif request_type == EngineCoreRequestType.EXECUTOR_FAILED: + raise RuntimeError("Executor failed.") + else: + logger.error("Unrecognized input request type encountered: %s", + request_type) @staticmethod def _convert_msgspec_args(method, args): @@ -473,6 +481,18 @@ def _convert_msgspec_args(method, args): and not isinstance(v, p.annotation) else v for v, p in zip(args, arg_types)) + def _send_engine_dead(self): + """Send EngineDead status to the EngineCoreClient.""" + + # Put ENGINE_CORE_DEAD in the queue. + self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD) + + # Wait until msg sent by the daemon before shutdown. + self.output_thread.join(timeout=5.0) + if self.output_thread.is_alive(): + logger.fatal("vLLM shutdown signal from EngineCore failed " + "to send. Please report this issue.") + def process_input_socket(self, input_path: str, engine_index: int): """Input socket IO thread.""" @@ -508,18 +528,40 @@ def process_output_socket(self, output_path: str, engine_index: int): # Msgpack serialization encoding. encoder = MsgpackEncoder() - # Reuse send buffer. - buffer = bytearray() - - with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket: + # Send buffers to reuse. + reuse_buffers: list[bytearray] = [] + # Keep references to outputs and buffers until zmq is finished + # with them (outputs may contain tensors/np arrays whose + # backing buffers were extracted for zero-copy send). + pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]() + + # We must set linger to ensure the ENGINE_CORE_DEAD + # message is sent prior to closing the socket. + with zmq_socket_ctx(output_path, zmq.constants.PUSH, + linger=4000) as socket: while True: outputs = self.output_queue.get() + if outputs == EngineCoreProc.ENGINE_CORE_DEAD: + socket.send(outputs, copy=False) + break + assert not isinstance(outputs, bytes) outputs.engine_index = engine_index - buffers = encoder.encode_into(outputs, buffer) - socket.send_multipart(buffers, copy=False) + # Reclaim buffers that zmq is finished with. + while pending and pending[-1][0].done: + reuse_buffers.append(pending.pop()[2]) -ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True) + buffer = reuse_buffers.pop() if reuse_buffers else bytearray() + buffers = encoder.encode_into(outputs, buffer) + tracker = socket.send_multipart(buffers, + copy=False, + track=True) + if not tracker.done: + ref = outputs if len(buffers) > 1 else None + pending.appendleft((tracker, ref, buffer)) + elif len(reuse_buffers) < 2: + # Keep at most 2 buffers to reuse. 
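The output-socket loop above sends encoded outputs with copy=False and track=True, keeps a reference to each buffer until pyzmq reports the send as done, and then recycles the buffer. A compressed, self-contained sketch of that pattern is shown below; the payload handling is a stand-in for the MsgpackEncoder logic in the patch, and the endpoint name is made up.

    from collections import deque
    import zmq

    ctx = zmq.Context()
    push = ctx.socket(zmq.PUSH)
    push.bind("inproc://outputs")
    pull = ctx.socket(zmq.PULL)
    pull.connect("inproc://outputs")

    reuse_buffers: list[bytearray] = []
    pending: deque[tuple[zmq.MessageTracker, bytearray]] = deque()

    def send_encoded(payload: bytes) -> None:
        # Reclaim buffers whose zero-copy sends have completed.
        while pending and pending[-1][0].done:
            reuse_buffers.append(pending.pop()[1])
        # Buffers in reuse_buffers are guaranteed not to be in flight.
        buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
        buffer[:] = payload  # stand-in for encoder.encode_into(outputs, buffer)
        tracker = push.send_multipart([buffer], copy=False, track=True)
        if not tracker.done:
            pending.appendleft((tracker, buffer))
        elif len(reuse_buffers) < 2:
            reuse_buffers.append(buffer)

    send_encoded(b"hello")
    print(pull.recv_multipart())
    ctx.destroy(linger=0)

Holding the buffer (and, in the patch, the outputs object whose tensors back extra frames) until the MessageTracker is done is what makes the zero-copy send safe to combine with buffer reuse.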
+ reuse_buffers.append(buffer) class DPEngineCoreProc(EngineCoreProc): @@ -558,7 +600,9 @@ def __init__( for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * tp_size)) + self.local_dp_rank = local_dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() + self.current_wave = 0 # Initialize the engine after setting up environment. super().__init__(input_path, output_path, vllm_config, executor_class, @@ -573,6 +617,31 @@ def shutdown(self): if dp_group := getattr(self, "dp_group", None): stateless_destroy_torch_distributed_process_group(dp_group) + def add_request(self, request: EngineCoreRequest): + if request.current_wave != self.current_wave: + if request.current_wave > self.current_wave: + self.current_wave = request.current_wave + elif not self.engines_running: + # Request received for an already-completed wave, notify + # front-end that we need to start the next one. + self.output_queue.put_nowait( + EngineCoreOutputs(start_wave=self.current_wave)) + + super().add_request(request) + + def _handle_client_request(self, request_type: EngineCoreRequestType, + request: Any) -> None: + if request_type == EngineCoreRequestType.START_DP_WAVE: + new_wave: int = request + if new_wave >= self.current_wave: + self.current_wave = new_wave + if not self.engines_running: + logger.debug("EngineCore starting idle loop for wave %d.", + new_wave) + self.engines_running = True + else: + super()._handle_client_request(request_type, request) + def run_busy_loop(self): """Core busy loop of the EngineCore for data parallel case.""" @@ -599,7 +668,7 @@ def run_busy_loop(self): # up-to-date state is returned in the engine outputs. self._process_engine_step() - if not self.global_unfinished_reqs: + if not self.engines_running: # All engines are idle. continue @@ -608,18 +677,23 @@ def run_busy_loop(self): self.execute_dummy_batch() # 3) All-reduce operation to determine global unfinished reqs. - self.global_unfinished_reqs = self._has_global_unfinished_reqs( + self.engines_running = self._has_global_unfinished_reqs( local_unfinished_reqs) - if not self.global_unfinished_reqs: - # Notify client that we are pausing the loop. - self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS) + if not self.engines_running: + if self.local_dp_rank == 0: + # Notify client that we are pausing the loop. + logger.debug("Wave %d finished, pausing engine loop.", + self.current_wave) + self.output_queue.put_nowait( + EngineCoreOutputs(wave_complete=self.current_wave)) + self.current_wave += 1 def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: - # Optimization - only perform finish-sync all-reduce every 16 steps. + # Optimization - only perform finish-sync all-reduce every 24 steps. 
self.counter += 1 - if self.counter != 16: + if self.counter != 24: return True self.counter = 0 diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index a96ebc7edb53..dd5190996196 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 - import asyncio -import os +import contextlib import queue -import signal -import threading import uuid import weakref from abc import ABC, abstractmethod -from collections.abc import Awaitable +from collections import deque +from collections.abc import Awaitable, Sequence from concurrent.futures import Future from dataclasses import dataclass, field from threading import Thread @@ -21,10 +19,11 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path, - kill_process_tree, make_zmq_socket) + make_zmq_socket) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.core import EngineCore, EngineCoreProc +from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr from vllm.v1.utils import BackgroundProcHandle @@ -305,14 +304,23 @@ class BackgroundResources: core_engines: list[CoreEngine] = field(default_factory=list) output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None + output_queue_task: Optional[asyncio.Task] = None shutdown_path: Optional[str] = None + # Set if any of the engines are dead. Here so that the output + # processing threads can access it without holding a ref to the client. + engine_dead: bool = False + def __call__(self): """Clean up background resources.""" + self.engine_dead = True for core_engine in self.core_engines: core_engine.close() + if self.output_queue_task is not None: + self.output_queue_task.cancel() + # ZMQ context termination can hang if the sockets # aren't explicitly closed first. if self.output_socket is not None: @@ -327,6 +335,12 @@ def __call__(self): # Send shutdown signal. shutdown_sender.send(b'') + def validate_alive(self, frames: Sequence[zmq.Frame]): + if len(frames) == 1 and (frames[0].buffer + == EngineCoreProc.ENGINE_CORE_DEAD): + self.engine_dead = True + raise EngineDeadError() + class MPClient(EngineCoreClient): """ @@ -348,27 +362,6 @@ def __init__( executor_class: type[Executor], log_stats: bool, ): - # The child processes will send SIGUSR1 when unrecoverable - # errors happen. We kill the process tree here so that the - # stack trace is very evident. - # TODO(rob): rather than killing the main process, we should - # figure out how to raise an AsyncEngineDeadError and - # handle at the API server level so we can return a better - # error code to the clients calling vLLM. - def sigusr1_handler(signum, frame): - logger.fatal("Got fatal signal from worker processes, shutting " - "down. See stack trace above for root cause issue.") - kill_process_tree(os.getpid()) - - if threading.current_thread() == threading.main_thread(): - signal.signal(signal.SIGUSR1, sigusr1_handler) - else: - logger.warning("SIGUSR1 handler not installed because we are not " - "running in the main thread. 
In this case the " - "forked engine process may not be killed when " - "an exception is raised, and you need to handle " - "the engine process shutdown manually.") - # Serialization setup. self.encoder = MsgpackEncoder() self.decoder = MsgpackDecoder(EngineCoreOutputs) @@ -378,32 +371,43 @@ def sigusr1_handler(signum, frame): self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx # This will ensure resources created so far are closed - # when the client is garbage collected, even if an + # when the client is garbage collected, even if an # exception is raised mid-construction. self.resources = BackgroundResources(ctx=sync_ctx) self._finalizer = weakref.finalize(self, self.resources) - - # Paths and sockets for IPC. - self.output_path = get_open_zmq_ipc_path() - input_path = get_open_zmq_ipc_path() - self.input_socket = make_zmq_socket(self.ctx, - input_path, - zmq.ROUTER, - bind=True) - self.resources.input_socket = self.input_socket - - new_core_engine = lambda index, local_dp_rank=None: CoreEngine( - vllm_config, executor_class, log_stats, input_path, self. - output_path, index, local_dp_rank) - - # Start engine core process(es). - self._init_core_engines(vllm_config, new_core_engine, - self.resources.core_engines) - - # Wait for engine core process(es) to start. - self._wait_for_engine_startup() - - self.utility_results: dict[int, AnyFuture] = {} + success = False + try: + # Paths and sockets for IPC. + self.output_path = get_open_zmq_ipc_path() + input_path = get_open_zmq_ipc_path() + self.input_socket = make_zmq_socket(self.ctx, + input_path, + zmq.ROUTER, + bind=True) + self.resources.input_socket = self.input_socket + + new_core_engine = lambda index, local_dp_rank=None: CoreEngine( + vllm_config, executor_class, log_stats, input_path, self. + output_path, index, local_dp_rank) + + # Start engine core process(es). + self._init_core_engines(vllm_config, new_core_engine, + self.resources.core_engines) + + # Wait for engine core process(es) to start. + self._wait_for_engine_startup() + + self.utility_results: dict[int, AnyFuture] = {} + + # Request objects which may contain pytorch-allocated tensors + # that we need to keep references to until zmq is done with the + # underlying data. + self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]() + + success = True + finally: + if not success: + self._finalizer() def _wait_for_engine_startup(self): # Get a sync handle to the socket which can be sync or async. @@ -443,16 +447,34 @@ def _init_core_engines( ) -> None: # Default case - single core engine. - dp_rank = vllm_config.parallel_config.data_parallel_rank - local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local core_engine = new_core_engine( - dp_rank, local_dp_rank if local_dp_rank is not None else dp_rank) + vllm_config.parallel_config.data_parallel_rank, + vllm_config.parallel_config.data_parallel_rank_local, + ) core_engines.append(core_engine) self.core_engine = core_engine def shutdown(self): + # Terminate background resources. 
self._finalizer() + def _format_exception(self, e: Exception) -> Exception: + """If errored, use EngineDeadError so root cause is clear.""" + return EngineDeadError( + suppress_context=True) if self.resources.engine_dead else e + + def ensure_alive(self): + if self.resources.engine_dead: + raise EngineDeadError() + + def add_pending_message(self, tracker: zmq.MessageTracker, msg: Any): + if not tracker.done: + self.pending_messages.appendleft((tracker, msg)) + + def free_pending_messages(self): + while self.pending_messages and self.pending_messages[-1][0].done: + self.pending_messages.pop() + def _process_utility_output(output: UtilityOutput, utility_results: dict[int, AnyFuture]): @@ -476,7 +498,7 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats=log_stats, ) - self.outputs_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() + self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]() # Ensure that the outputs socket processing thread does not have # a ref to the client which prevents gc. @@ -487,7 +509,8 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], outputs_queue = self.outputs_queue shutdown_path = get_open_zmq_inproc_path() - self.resources.shutdown_path = shutdown_path + resources = self.resources + resources.shutdown_path = shutdown_path def process_outputs_socket(): shutdown_socket = ctx.socket(zmq.PAIR) @@ -506,12 +529,15 @@ def process_outputs_socket(): break frames = out_socket.recv_multipart(copy=False) + resources.validate_alive(frames) outputs = decoder.decode(frames) if outputs.utility_output: _process_utility_output(outputs.utility_output, utility_results) else: outputs_queue.put_nowait(outputs) + except Exception as e: + outputs_queue.put_nowait(e) finally: # Close sockets. shutdown_socket.close(linger=0) @@ -524,13 +550,28 @@ def process_outputs_socket(): self.output_queue_thread.start() def get_output(self) -> EngineCoreOutputs: - return self.outputs_queue.get() + # If an exception arises in process_outputs_socket task, + # it is forwarded to the outputs_queue so we can raise it + # from this (run_output_handler) task to shut down the server. + outputs = self.outputs_queue.get() + if isinstance(outputs, Exception): + raise self._format_exception(outputs) from None + return outputs def _send_input(self, request_type: EngineCoreRequestType, request: Any): + self.ensure_alive() + self.free_pending_messages() # (Identity, RequestType, SerializedRequest) msg = (self.core_engine.identity, request_type.value, *self.encoder.encode(request)) - self.input_socket.send_multipart(msg, copy=False) + + if len(msg) <= 3: + # No auxiliary buffers => no tensor backing buffers in request. + self.input_socket.send_multipart(msg, copy=False) + return + + tracker = self.input_socket.send_multipart(msg, copy=False, track=True) + self.add_pending_message(tracker, request) def call_utility(self, method: str, *args) -> Any: call_id = uuid.uuid1().int >> 64 @@ -542,13 +583,10 @@ def call_utility(self, method: str, *args) -> Any: return future.result() def add_request(self, request: EngineCoreRequest) -> None: - # NOTE: text prompt is not needed in the core engine as it has been - # tokenized. 
- request.prompt = None self._send_input(EngineCoreRequestType.ADD, request) def abort_requests(self, request_ids: list[str]) -> None: - if len(request_ids) > 0: + if request_ids and not self.resources.engine_dead: self._send_input(EngineCoreRequestType.ABORT, request_ids) def profile(self, is_start: bool = True) -> None: @@ -608,71 +646,111 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats=log_stats, ) - self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None - self.queue_task: Optional[asyncio.Task] = None - - self.outputs_handler: Optional[Callable[ - [AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None + self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, + Exception]]() + try: + # If we are running in an asyncio event loop, start the queue task. + # Otherwise, it will be started lazily. If it is not started here, + # we could miss EXECUTOR_FAILED messages from engine core if they + # occur prior to any requests being sent. + asyncio.get_running_loop() + self._ensure_output_queue_task() + except RuntimeError: + pass def _ensure_output_queue_task(self): - if self.outputs_queue is not None: + resources = self.resources + if resources.output_queue_task is not None: return # Perform IO in separate task to parallelize as much as possible. # Avoid task having direct reference back to the client. - self.outputs_queue = asyncio.Queue() decoder = self.decoder utility_results = self.utility_results outputs_queue = self.outputs_queue - output_handler = self.outputs_handler + output_handler: Optional[Callable[[AsyncMPClient, EngineCoreOutputs], + Awaitable[None]]] = getattr( + self.__class__, + "process_engine_outputs", None) _self_ref = weakref.ref(self) if output_handler else None output_path = self.output_path output_socket = make_zmq_socket(self.ctx, output_path, zmq.constants.PULL) - self.resources.output_socket = output_socket + resources.output_socket = output_socket async def process_outputs_socket(): - while True: - frames = await output_socket.recv_multipart(copy=False) - outputs: EngineCoreOutputs = decoder.decode(frames) - if outputs.utility_output: - _process_utility_output(outputs.utility_output, - utility_results) - continue - - if output_handler is not None: - assert _self_ref is not None - _self = _self_ref() - if not _self: - # Client has been garbage collected, abort. - return - await output_handler(_self, outputs) - - if outputs.outputs or outputs.scheduler_stats: - outputs_queue.put_nowait(outputs) - - self.queue_task = asyncio.create_task(process_outputs_socket(), - name="EngineCoreOutputQueueTask") + try: + while True: + frames = await output_socket.recv_multipart(copy=False) + resources.validate_alive(frames) + outputs: EngineCoreOutputs = decoder.decode(frames) + if outputs.utility_output: + _process_utility_output(outputs.utility_output, + utility_results) + continue + + if output_handler is not None: + assert _self_ref is not None + _self = _self_ref() + if not _self: + # Client has been garbage collected, abort. 
+ return + await output_handler(_self, outputs) + + if outputs.outputs or outputs.scheduler_stats: + outputs_queue.put_nowait(outputs) + except Exception as e: + outputs_queue.put_nowait(e) + + resources.output_queue_task = asyncio.create_task( + process_outputs_socket(), name="EngineCoreOutputQueueTask") async def get_output_async(self) -> EngineCoreOutputs: self._ensure_output_queue_task() + # If an exception arises in process_outputs_socket task, + # it is forwarded to the outputs_queue so we can raise it + # from this (run_output_handler) task to shut down the server. assert self.outputs_queue is not None - return await self.outputs_queue.get() + outputs = await self.outputs_queue.get() + if isinstance(outputs, Exception): + raise self._format_exception(outputs) from None + return outputs def _send_input(self, request_type: EngineCoreRequestType, request: Any, - engine: Optional[CoreEngine] = None) -> Awaitable[None]: + engine: Optional[CoreEngine] = None) -> Awaitable[Any]: + self.ensure_alive() if engine is None: engine = self.core_engine message = (request_type.value, *self.encoder.encode(request)) - return self._send_input_message(message, engine) + return self._send_input_message(message, engine, request) + + def _send_input_message(self, message: tuple[bytestr, + ...], engine: CoreEngine, + objects: Any) -> Awaitable[Any]: + """ + objects is a reference to retain until zmq is finished with the + buffers, in case they were extracted from tensors in the request. + """ + self.ensure_alive() + self.free_pending_messages() + + msg = (engine.identity, ) + message + if not objects or len(msg) <= 3: + # No auxiliary buffers => no tensor backing buffers in request. + return self.input_socket.send_multipart(msg, copy=False) + + future: asyncio.Future[zmq.MessageTracker] + future = self.input_socket.send_multipart(msg, copy=False, track=True) - def _send_input_message(self, message: tuple[bytestr, ...], - engine: CoreEngine) -> Awaitable[None]: - message = (engine.identity, ) + message - return self.input_socket.send_multipart(message, copy=False) + def add_pending(f: asyncio.Future[zmq.MessageTracker]): + with contextlib.suppress(BaseException): + self.add_pending_message(f.result(), objects) + + future.add_done_callback(add_pending) + return future async def call_utility_async(self, method: str, *args) -> Any: return await self._call_utility_async(method, @@ -686,19 +764,16 @@ async def _call_utility_async(self, method: str, *args, self.utility_results[call_id] = future message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( (call_id, method, args))) - await self._send_input_message(message, engine) + await self._send_input_message(message, engine, args) self._ensure_output_queue_task() return await future async def add_request_async(self, request: EngineCoreRequest) -> None: - # NOTE: text prompt is not needed in the core engine as it has been - # tokenized. 
- request.prompt = None await self._send_input(EngineCoreRequestType.ADD, request) self._ensure_output_queue_task() async def abort_requests_async(self, request_ids: list[str]) -> None: - if len(request_ids) > 0: + if request_ids and not self.resources.engine_dead: await self._send_input(EngineCoreRequestType.ABORT, request_ids) async def profile_async(self, is_start: bool = True) -> None: @@ -754,18 +829,14 @@ class DPAsyncMPClient(AsyncMPClient): def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool): - super().__init__(vllm_config, executor_class, log_stats) - - assert len(self.core_engines) > 1 - - # Control message used for triggering dp idle mode loop. - self.start_dp_msg = (EngineCoreRequestType.START_DP.value, - *self.encoder.encode(None)) - self.num_engines_running = 0 + self.current_wave = 0 + self.engines_running = False self.reqs_in_flight: dict[str, CoreEngine] = {} - self.outputs_handler = DPAsyncMPClient.process_engine_outputs # type: ignore[assignment] + super().__init__(vllm_config, executor_class, log_stats) + + assert len(self.core_engines) > 1 def _init_core_engines( self, @@ -790,26 +861,23 @@ async def call_utility_async(self, method: str, *args) -> Any: ]))[0] async def add_request_async(self, request: EngineCoreRequest) -> None: - # NOTE: text prompt is not needed in the core engine as it has been - # tokenized. - request.prompt = None - - msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request)) + request.current_wave = self.current_wave chosen_engine = self.get_core_engine_for_request() self.reqs_in_flight[request.request_id] = chosen_engine chosen_engine.num_reqs_in_flight += 1 - if self.num_engines_running >= len(self.core_engines): - await self._send_input_message(msg, chosen_engine) - else: + + to_await = self._send_input(EngineCoreRequestType.ADD, request, + chosen_engine) + if not self.engines_running: # Send request to chosen engine and dp start loop # control message to all other engines. - self.num_engines_running += len(self.core_engines) - await asyncio.gather(*[ - self._send_input_message( - msg if engine is chosen_engine else self.start_dp_msg, - engine) for engine in self.core_engines - ]) + self.engines_running = True + to_await = asyncio.gather( + to_await, # type: ignore[assignment] + *self._start_wave_coros(exclude_index=chosen_engine.index)) + + await to_await self._ensure_output_queue_task() @@ -824,21 +892,31 @@ async def process_engine_outputs(self: "DPAsyncMPClient", if engine := self.reqs_in_flight.pop(req_id, None): engine.num_reqs_in_flight -= 1 - if outputs.engine_paused: - assert self.num_engines_running >= 1 - self.num_engines_running -= 1 - if not self.num_engines_running and self.reqs_in_flight: - # If there are requests in flight here, they must have - # been sent after the engines paused. We must make - # sure to start the other engines: - self.num_engines_running = len(self.core_engines) - coros = [ - self._send_input_message(self.start_dp_msg, engine) - for engine in self.core_engines - if not engine.num_reqs_in_flight - ] - if coros: - await asyncio.gather(*coros) + if outputs.wave_complete is not None: + # Current wave is complete, move to next wave number + # and mark engines as paused. 
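Illustrative aside: the tracked send_multipart above exists so that buffers extracted from tensors stay alive until zmq has actually finished transmitting them. A minimal synchronous pyzmq sketch of the same idea; the PUSH/PULL pair and the inproc endpoint are invented for the example:

import zmq

ctx = zmq.Context.instance()
pull = ctx.socket(zmq.PULL)
pull.bind("inproc://bufs")
push = ctx.socket(zmq.PUSH)
push.connect("inproc://bufs")

payload = bytearray(b"pretend this aliases a tensor's storage")
# copy=False + track=True returns a MessageTracker for the zero-copy frames.
tracker = push.send_multipart([b"ADD", payload], copy=False, track=True)

pending = []
if not tracker.done:
    # zmq may still reference `payload`; retain it until the send completes.
    pending.append((tracker, payload))

pull.recv_multipart()
tracker.wait(1)  # after this, it is safe to reuse or free `payload`
pending = [(t, obj) for t, obj in pending if not t.done]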
+ if self.current_wave <= outputs.wave_complete: + self.current_wave = outputs.wave_complete + 1 + self.engines_running = False + + elif outputs.start_wave is not None and ( + outputs.start_wave > self.current_wave or + (outputs.start_wave == self.current_wave + and not self.engines_running)): + # Engine received request for a non-current wave so we must ensure + # that other engines progress to the next wave. + self.current_wave = outputs.start_wave + self.engines_running = True + await asyncio.gather(*self._start_wave_coros( + exclude_index=outputs.engine_index)) + + def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]: + logger.debug("Sending start DP wave %d.", self.current_wave) + return [ + self._send_input(EngineCoreRequestType.START_DP_WAVE, + self.current_wave, engine) + for engine in self.core_engines if engine.index != exclude_index + ] async def abort_requests_async(self, request_ids: list[str]) -> None: if not request_ids: @@ -859,5 +937,6 @@ async def abort_requests_async(self, request_ids: list[str]) -> None: async def _abort_requests(self, request_ids: list[str], engine: CoreEngine) -> None: - await self._send_input(EngineCoreRequestType.ABORT, request_ids, - engine) + if not self.resources.engine_dead: + await self._send_input(EngineCoreRequestType.ABORT, request_ids, + engine) diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index bf06a17507b2..dca327cc5d07 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -1,8 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 - -from dataclasses import dataclass, field +from abc import ABC, abstractmethod from typing import Optional +import tokenizers +from packaging import version +from tokenizers import Tokenizer +from tokenizers.decoders import DecodeStream +from transformers import PreTrainedTokenizerFast + from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.transformers_utils.detokenizer_utils import ( @@ -12,39 +17,22 @@ logger = init_logger(__name__) -@dataclass class IncrementalDetokenizer: - # Generation data - token_ids: list[int] - output_text: str = "" - tokens: list[str] = field(default_factory=list) - prompt_len: int = 0 - - # Stop strings - stop: list[str] = field(default_factory=list) - include_stop_str_in_output: bool = False - - # Metadata for incremental detokenization - prefix_offset: int = 0 - read_offset: int = 0 - - # Parameters for detokenization - skip_special_tokens: bool = True - spaces_between_special_tokens: bool = True - - # Tokenizer for this request, - # None if detokenization is disabled. - tokenizer: Optional[AnyTokenizer] = None - - # Accounting for stop string buffering - stop_buffer_length: int = 0 - _last_output_text_offset: int = 0 + def __init__(self): + self.token_ids: list[int] = [] @property def output_token_ids(self) -> list[int]: - return self.token_ids if not self.prompt_len else ( - self.token_ids[self.prompt_len:]) + return self.token_ids + + def update(self, new_token_ids: list[int], + stop_terminated: bool) -> Optional[str]: + self.token_ids.extend(new_token_ids) + return None + + def get_next_output_text(self, finished: bool, delta: bool) -> str: + return "" @classmethod def from_new_request( @@ -54,39 +42,39 @@ def from_new_request( ) -> "IncrementalDetokenizer": if tokenizer is None: - return cls(token_ids=[]) + # No tokenizer => skipping detokenization. 
+ return IncrementalDetokenizer() + + if (isinstance(tokenizer, PreTrainedTokenizerFast) and version.parse( + tokenizers.__version__) >= version.parse("0.21.1")): + # Fast tokenizer => use tokenizers library DecodeStream. + # And only tokenizers >= 0.21.1 supports Fast Detokenizer. + return FastIncrementalDetokenizer(tokenizer, request) + + # Fall back to slow python-based incremental detokenization. + return SlowIncrementalDetokenizer(tokenizer, request) + + +class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): + + def __init__(self, request: EngineCoreRequest): + super().__init__() - tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens( - tokenizer=tokenizer, - prompt_ids=request.prompt_token_ids, - skip_special_tokens=request.sampling_params.skip_special_tokens, - ) + # Stop strings + params = request.sampling_params + self.stop = stop = params.stop + self.include_stop_str_in_output = params.include_stop_str_in_output - stops = request.sampling_params.stop # Number of chars to hold back when stop strings are to be excluded # from streamed output. - if stops and not request.sampling_params.include_stop_str_in_output: - stop_buffer_length = max(len(s) for s in stops) - 1 + if stop and not self.include_stop_str_in_output: + self.stop_buffer_length = max(len(s) for s in stop) - 1 else: - stop_buffer_length = 0 - - return cls( - tokens=tokens, - # Detokenizer mutates this list, so need a unique copy. - # NOTE(Nick): could we take ownership of it though? - token_ids=request.prompt_token_ids.copy(), - stop=stops, - include_stop_str_in_output=request.sampling_params. - include_stop_str_in_output, - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=request.sampling_params.skip_special_tokens, - spaces_between_special_tokens=request.sampling_params. - spaces_between_special_tokens, - prompt_len=len(request.prompt_token_ids), - tokenizer=tokenizer, - stop_buffer_length=stop_buffer_length, - ) + self.stop_buffer_length = 0 + self._last_output_text_offset: int = 0 + + # Generation data + self.output_text = "" def update(self, new_token_ids: list[int], stop_terminated: bool) -> Optional[str]: @@ -98,11 +86,7 @@ def update(self, new_token_ids: list[int], Return matched stop string or None. """ if not new_token_ids: - # Skip detokenization if no new token ids - return None - if self.tokenizer is None: - # Skip detokenization if no tokenizer - self.token_ids.extend(new_token_ids) + # Skip detokenization if no new token ids. return None if stop_terminated and not self.include_stop_str_in_output: @@ -116,34 +100,16 @@ def update(self, new_token_ids: list[int], # 1) Detokenize the new token ids incrementally. # TODO(woosuk): This method becomes very inefficient when the number of # new_token_ids is more than 1. We need to optimize this. - decoded_text = "" + offset_before = len(self.output_text) for new_token_id in new_token_ids: self.token_ids.append(new_token_id) - (new_tokens, new_decoded_token_text, prefix_offset, - read_offset) = detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=self.token_ids, - prev_tokens=self.tokens, - prefix_offset=self.prefix_offset, - read_offset=self.read_offset, - skip_special_tokens=self.skip_special_tokens, - spaces_between_special_tokens=self. 
- spaces_between_special_tokens, - ) - - self.tokens.extend(new_tokens) - self.prefix_offset = prefix_offset - self.read_offset = read_offset - - decoded_text += new_decoded_token_text - - self.output_text += decoded_text + self.output_text += self.decode_next(new_token_id) if stop_terminated: if skipped_stop_token_id is not None: - # Cleanup after skipping detokenization + # Cleanup after skipping detokenization. self.token_ids.append(skipped_stop_token_id) - # Stop token triggered; skip stop string check + # Stop token triggered; skip stop string check. return None # 2) Evaluate stop strings. @@ -151,7 +117,7 @@ def update(self, new_token_ids: list[int], if self.stop: stop = StopChecker.check_stop_strings( output_text=self.output_text, - new_char_count=len(decoded_text), + new_char_count=len(self.output_text) - offset_before, stop=self.stop, include_in_output=self.include_stop_str_in_output, ) @@ -162,6 +128,10 @@ def update(self, new_token_ids: list[int], return stop_string + @abstractmethod + def decode_next(self, next_token_id: int) -> str: + raise NotImplementedError + def get_next_output_text(self, finished: bool, delta: bool) -> str: """If delta is True, only new text since the last call to this method is returned""" @@ -177,3 +147,114 @@ def get_next_output_text(self, finished: bool, delta: bool) -> str: self._last_output_text_offset = length return self.output_text[last_offset:length] return "" + + +class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): + + def __init__(self, tokenizer: PreTrainedTokenizerFast, + request: EngineCoreRequest): + super().__init__(request) + + sampling_params = request.sampling_params + self.stream = DecodeStream( + skip_special_tokens=sampling_params.skip_special_tokens) + + self.tokenizer: Tokenizer = tokenizer._tokenizer + + # Find a safe place to start. + prompt_suffix = request.prompt_token_ids + prompt_len = len(prompt_suffix) + if prompt_len > 4: + for i in range(4, min(prompt_len + 1, 24)): + suffix = request.prompt_token_ids[-i:] + if '�' not in self.tokenizer.decode(suffix): + prompt_suffix = suffix + break + + # Prime the stream. + for tid in prompt_suffix: + self.stream.step(self.tokenizer, tid) + + self.spaces_between_special_tokens = ( + sampling_params.skip_special_tokens + or sampling_params.spaces_between_special_tokens) + + if not self.spaces_between_special_tokens: + # Store dict of added token ids so that we can suppress + # the spaces between them. + if (added_token_ids := getattr(self.tokenizer, "added_token_ids", + None)) is None: + self.tokenizer.added_token_ids = added_token_ids = { + tid: tok.content + for tid, tok in + self.tokenizer.get_added_tokens_decoder().items() + } + + if added_token_ids: + self.last_special = False + self.added_token_ids = added_token_ids + else: + # No added tokens. + self.spaces_between_special_tokens = True + + def decode_next(self, next_token_id: int) -> str: + token = self.stream.step(self.tokenizer, next_token_id) + + if not self.spaces_between_special_tokens: + special_token = self.added_token_ids.get(next_token_id) + is_special = special_token is not None + if is_special and self.last_special: + # Return raw token string without any prefixed spaces. + token = special_token + self.last_special = is_special + + return token or "" + + +class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer): + + def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest): + super().__init__(request) + + self.tokenizer = tokenizer + + # Metadata for incremental detokenization. 
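Illustrative aside: FastIncrementalDetokenizer above is a thin wrapper around the DecodeStream API from the tokenizers library (>= 0.21.1, as checked in from_new_request). A minimal standalone sketch of that API; the "gpt2" checkpoint is only an example:

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tok = Tokenizer.from_pretrained("gpt2")  # any fast tokenizer works here
stream = DecodeStream(skip_special_tokens=True)

text = ""
for token_id in tok.encode("incremental detokenization, one id at a time").ids:
    piece = stream.step(tok, token_id)  # may return None until enough ids arrive
    text += piece or ""
print(text)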
+ self.tokens, self.prefix_offset, self.read_offset = ( + convert_prompt_ids_to_tokens( + tokenizer=tokenizer, + prompt_ids=request.prompt_token_ids, + skip_special_tokens=request.sampling_params. + skip_special_tokens, + )) + + self.token_ids.extend(request.prompt_token_ids) + self.prompt_len = len(request.prompt_token_ids) + + params = request.sampling_params + self.skip_special_tokens = params.skip_special_tokens + self.spaces_between_special_tokens = ( + params.spaces_between_special_tokens) + + @property + def output_token_ids(self) -> list[int]: + return self.token_ids if not self.prompt_len else ( + self.token_ids[self.prompt_len:]) + + def decode_next(self, next_token_id: int) -> str: + new_tokens, decoded_text, prefix_offset, read_offset = ( + detokenize_incrementally( + tokenizer=self.tokenizer, + all_input_ids=self.token_ids, + prev_tokens=self.tokens, + prefix_offset=self.prefix_offset, + read_offset=self.read_offset, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self. + spaces_between_special_tokens, + )) + + self.tokens.extend(new_tokens) + self.prefix_offset = prefix_offset + self.read_offset = read_offset + + return decoded_text diff --git a/vllm/v1/engine/exceptions.py b/vllm/v1/engine/exceptions.py new file mode 100644 index 000000000000..97dd31d5e521 --- /dev/null +++ b/vllm/v1/engine/exceptions.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +class EngineGenerateError(Exception): + """Raised when a AsyncLLM.generate() fails. Recoverable.""" + pass + + +class EngineDeadError(Exception): + """Raised when the EngineCore dies. Unrecoverable.""" + + def __init__(self, *args, suppress_context: bool = False, **kwargs): + ENGINE_DEAD_MESSAGE = "EngineCore encountered an issue. See stack trace (above) for the root cause." # noqa: E501 + + super().__init__(ENGINE_DEAD_MESSAGE, *args, **kwargs) + # Make stack trace clearer when using with LLMEngine by + # silencing irrelevant ZMQError. 
+ self.__suppress_context__ = suppress_context diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 4c67186f7040..85da58451c78 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -10,7 +10,6 @@ from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.engine.arg_utils import EngineArgs -from vllm.engine.metrics_types import StatLoggerBase from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -20,7 +19,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import ( - BaseTokenizerGroup, init_tokenizer_from_configs) + TokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import UsageContext from vllm.utils import Device from vllm.v1.engine.core_client import EngineCoreClient @@ -28,10 +27,10 @@ from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor +from vllm.v1.metrics.loggers import StatLoggerFactory logger = init_logger(__name__) -_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _R = TypeVar("_R", default=Any) @@ -44,7 +43,7 @@ def __init__( executor_class: type[Executor], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, multiprocess_mode: bool = False, @@ -56,6 +55,11 @@ def __init__( "LLMEngine.from_vllm_config(...) or explicitly set " "VLLM_USE_V1=0 or 1 and report this issue on Github.") + if stat_loggers is not None: + raise NotImplementedError( + "Passing StatLoggers to LLMEngine in V1 is not yet supported. " + "Set VLLM_USE_V1=0 and file an issue on GitHub.") + self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -73,9 +77,7 @@ def __init__( self.tokenizer = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, lora_config=vllm_config.lora_config) - self.tokenizer.ping() # Processor (convert Inputs --> EngineCoreRequests) self.processor = Processor(vllm_config=vllm_config, @@ -104,14 +106,9 @@ def from_vllm_config( cls, vllm_config: VllmConfig, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_stats: bool = False, ) -> "LLMEngine": - if stat_loggers is not None: - raise NotImplementedError( - "Passing StatLoggers to V1 is not yet supported.
" - "Set VLLM_USE_V1=0 and file and issue on Github.") - return cls(vllm_config=vllm_config, executor_class=Executor.get_class(vllm_config), log_stats=(not disable_log_stats), @@ -124,7 +121,7 @@ def from_engine_args( cls, engine_args: EngineArgs, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, enable_multiprocessing: bool = False, ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" @@ -183,17 +180,15 @@ def add_request( priority: int = 0, ) -> None: # Process raw inputs into the request. - request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - trace_headers, - prompt_adapter_request, - priority) + prompt_str, request = self.processor.process_inputs( + request_id, prompt, params, arrival_time, lora_request, + trace_headers, prompt_adapter_request, priority) n = params.n if isinstance(params, SamplingParams) else 1 if n == 1: # Make a new RequestState and queue. - self.output_processor.add_request(request, None, 0) + self.output_processor.add_request(request, prompt_str, None, 0) # Add the request to EngineCore. self.engine_core.add_request(request) return @@ -207,7 +202,8 @@ def add_request( child_request.sampling_params = params # Make a new RequestState and queue. - self.output_processor.add_request(child_request, parent_req, idx) + self.output_processor.add_request(child_request, prompt_str, + parent_req, idx) # Add the request to EngineCore. self.engine_core.add_request(child_request) @@ -230,6 +226,9 @@ def step(self) -> list[RequestOutput]: return processed_outputs.request_outputs + def get_vllm_config(self): + return self.vllm_config + def get_model_config(self): return self.model_config @@ -251,21 +250,12 @@ def wake_up(self, tags: Optional[list[str]] = None): def is_sleeping(self) -> bool: return self.engine_core.is_sleeping() - def get_tokenizer_group( - self, - group_type: type[_G] = BaseTokenizerGroup, - ) -> _G: - tokenizer_group = self.tokenizer - - if tokenizer_group is None: + def get_tokenizer_group(self) -> TokenizerGroup: + if self.tokenizer is None: raise ValueError("Unable to get tokenizer because " "skip_tokenizer_init is True") - if not isinstance(tokenizer_group, group_type): - raise TypeError("Invalid type of tokenizer group. 
" - f"Expected type: {group_type}, but " - f"found type: {type(tokenizer_group)}") - return tokenizer_group + return self.tokenizer def add_lora(self, lora_request: LoRARequest) -> bool: """Load a new LoRA adapter into the engine for future requests.""" diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py index ef5a2e5acb15..c765c1bbffcf 100644 --- a/vllm/v1/engine/mm_input_cache.py +++ b/vllm/v1/engine/mm_input_cache.py @@ -50,7 +50,7 @@ def get_and_update_p0( full_mm_inputs = list[Optional[MultiModalKwargs]]() for mm_input, mm_hash in zip(mm_inputs, mm_hashes): - if mm_hash in self.mm_cache: + if self.mm_cache.get(mm_hash) is not None: mm_input = None else: self.mm_cache[mm_hash] = mm_input diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 70f072d3c939..f76c44cb8bca 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -8,7 +8,7 @@ from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.logprobs import LogprobsProcessor @@ -28,32 +28,37 @@ class RequestOutputCollector: def __init__(self, output_kind: RequestOutputKind): self.aggregate = output_kind == RequestOutputKind.DELTA - self.output: Optional[RequestOutput] = None + self.output: Optional[Union[RequestOutput, Exception]] = None self.ready = asyncio.Event() - def put(self, output: RequestOutput) -> None: - if self.output is None: + def put(self, output: Union[RequestOutput, Exception]) -> None: + """Non-blocking put operation.""" + if self.output is None or isinstance(output, Exception): self.output = output self.ready.set() - elif self.aggregate: - # Coalesce the outputs in delta case. - self.output.add(output) - else: - # Just replace latest in non-delta case. - self.output = output + elif isinstance(self.output, RequestOutput): + # This ensures that request outputs with different request indexes + # (if n > 1) do not override each other. 
+ self.output.add(output, aggregate=self.aggregate) async def get(self) -> RequestOutput: + """Get operation blocks on put event.""" while (output := self.output) is None: await self.ready.wait() self.output = None self.ready.clear() + if isinstance(output, Exception): + raise output return output def get_nowait(self) -> Optional[RequestOutput]: + """Non-blocking get operation.""" output = self.output if output is not None: self.output = None self.ready.clear() + if isinstance(output, Exception): + raise output return output @@ -104,6 +109,7 @@ def from_new_request( cls, tokenizer: AnyTokenizer, request: EngineCoreRequest, + prompt: Optional[str], parent_req: Optional[ParentRequest], request_index: int, queue: Optional[RequestOutputCollector], @@ -118,7 +124,7 @@ def from_new_request( lora_name=(request.lora_request.name if request.lora_request is not None else None), output_kind=request.sampling_params.output_kind, - prompt=request.prompt, + prompt=prompt, prompt_token_ids=request.prompt_token_ids, logprobs_processor=LogprobsProcessor.from_new_request( tokenizer=tokenizer, @@ -220,7 +226,7 @@ class OutputProcessor: def __init__( self, - tokenizer: BaseTokenizerGroup, + tokenizer: TokenizerGroup, log_stats: bool, ): self.log_stats = log_stats @@ -235,6 +241,13 @@ def get_num_unfinished_requests(self): def has_unfinished_requests(self) -> bool: return len(self.request_states) > 0 + def propagate_error(self, e: Exception): + """Propagate error to all generate() tasks.""" + + for _, state in self.request_states.items(): + assert state.queue is not None + state.queue.put(e) + def abort_requests( self, request_ids: Iterable[str], @@ -255,6 +268,7 @@ def abort_requests( def add_request( self, request: EngineCoreRequest, + prompt: Optional[str], parent_req: Optional[ParentRequest] = None, request_index: int = 0, queue: Optional[RequestOutputCollector] = None, @@ -266,6 +280,7 @@ def add_request( req_state = RequestState.from_new_request( tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request), request=request, + prompt=prompt, parent_req=parent_req, request_index=request_index, queue=queue, diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 6d3290f16565..fa334302e781 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -17,13 +17,13 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) -from vllm.v1.structured_output.utils import ( - validate_structured_output_request_xgrammar) +from vllm.v1.structured_output.backend_xgrammar import ( + validate_xgrammar_grammar) class Processor: @@ -31,7 +31,7 @@ class Processor: def __init__( self, vllm_config: VllmConfig, - tokenizer: BaseTokenizerGroup, + tokenizer: TokenizerGroup, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): @@ -145,48 +145,52 @@ def _validate_structured_output(self, params: SamplingParams) -> None: if not params.guided_decoding or not self.decoding_config: return - supported_backends = [ - "xgrammar", "xgrammar:disable-any-whitespace", "guidance", - "guidance:disable-any-whitespace", "auto" - ] engine_level_backend = 
self.decoding_config.guided_decoding_backend - if engine_level_backend not in supported_backends: - raise ValueError(f"Only {supported_backends} structured output is " - "supported in V1.") if params.guided_decoding.backend: - if params.guided_decoding.backend != engine_level_backend: - raise ValueError("Request-level structured output backend " - "must match engine-level backend. " - f"{params.guided_decoding.backend}" - f" != {engine_level_backend}") + # Request-level backend selection is not supported in V1. + # The values may differ if `params` is reused and was set + # to a specific backend based on `auto` behavior in a previous + # request. We remember that it was set as a result of `auto` + # using the `_auto` option set on the backend in the params. + if (params.guided_decoding.backend != engine_level_backend + and not (engine_level_backend == "auto" and "_auto" + in params.guided_decoding.backend_options())): + raise ValueError( + "Request-level structured output backend selection is no " + "longer supported. The request specified " + f"'{params.guided_decoding.backend}', but vLLM was " + f"initialised with '{engine_level_backend}'. This error " + "can be resolved by removing backend selection from the " + "request.") else: params.guided_decoding.backend = engine_level_backend # Request content validation if engine_level_backend.startswith("xgrammar"): # xgrammar with no fallback - validate_structured_output_request_xgrammar(params) - params.guided_decoding.backend = engine_level_backend - elif engine_level_backend == "auto": + validate_xgrammar_grammar(params) + elif engine_level_backend.startswith("guidance"): + # TODO: ideally we would have the LLTokenizer here as Lark syntax + # allows <|special_token|> and similar, see + # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens + # Without tokenizer these are disallowed in grammars. + validate_guidance_grammar(params, tokenizer=None) + else: + # NOTE: engine_level_backend must be "auto" here, because we have + # checked supported_backends above. # "auto" is an opt-in to opinionated behavior where we try to # choose a backend based on request contents. This is not the # default as it is less predictable and subject to change # between releases as feature support changes. try: - validate_structured_output_request_xgrammar(params) + validate_xgrammar_grammar(params) params.guided_decoding.backend = "xgrammar" except ValueError: # The request includes some jsonschema feature(s) that # are not supported in xgrammar. Fall back to guidance. params.guided_decoding.backend = "guidance" - - if engine_level_backend.startswith("guidance"): - # TODO ideally we would have the LLTokenizer here as Lark syntax - # allows <|special_token|> and similar, see - # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens - # Without tokenizer these are disallowed in grammars. - validate_guidance_grammar(params, tokenizer=None) - params.guided_decoding.backend = engine_level_backend + # Remember that this backend was set automatically + params.guided_decoding.add_option("_auto") def process_inputs( self, @@ -198,16 +202,10 @@ def process_inputs( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, - ) -> EngineCoreRequest: + ) -> tuple[Optional[str], EngineCoreRequest]: # TODO(woosuk): Support pooling models. # TODO(woosuk): Support encoder-decoder models. 
- - from vllm.platforms import current_platform - current_platform.validate_request( - prompt=prompt, - params=params, - ) self._validate_lora(lora_request) self._validate_params(params) if priority != 0: @@ -231,6 +229,12 @@ def process_inputs( prompt_adapter_request=prompt_adapter_request, return_mm_hashes=self.use_hash, ) + from vllm.platforms import current_platform + current_platform.validate_request( + prompt=prompt, + params=params, + processed_inputs=processed_inputs, + ) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) self._validate_model_inputs(processed_inputs, lora_request) @@ -302,9 +306,8 @@ def process_inputs( else: sorted_mm_inputs = orig_sorted_mm_inputs - return EngineCoreRequest( + return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id, - prompt=decoder_inputs.get("prompt"), prompt_token_ids=decoder_inputs["prompt_token_ids"], mm_inputs=sorted_mm_inputs, mm_hashes=sorted_mm_hashes, @@ -351,7 +354,7 @@ def _validate_model_input( raise ValueError(f"Token id {max_input_id} is out of vocabulary") max_prompt_len = self.model_config.max_model_len - if len(prompt_ids) >= max_prompt_len: + if len(prompt_ids) > max_prompt_len: if prompt_type == "encoder" and model_config.is_multimodal_model: mm_registry = self.input_preprocessor.mm_registry mm_processor = mm_registry.create_processor( diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index e3a4cd98c1f8..3b9feb0d3298 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from concurrent.futures import Future -from typing import Union +from typing import Callable, Union import torch import torch.distributed as dist @@ -15,6 +15,8 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput +FailureCallback = Callable[[], None] + class Executor(ExecutorBase): """ @@ -62,6 +64,13 @@ def initialize_from_config(self, args=(kv_cache_configs, )) self.collective_rpc("compile_or_warm_up_model") + def register_failure_callback(self, callback: FailureCallback): + """ + Register a function to be called if the executor enters a permanent + failed state. 
+ """ + pass + def determine_available_memory(self) -> list[int]: # in bytes output = self.collective_rpc("determine_available_memory") return output diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index e854c2a44ff9..cb125bf4bf17 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -1,21 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 - +import multiprocessing import os import pickle import signal import sys +import threading import time import traceback import weakref +from concurrent.futures import Future from dataclasses import dataclass from enum import Enum, auto from functools import partial +from multiprocessing.connection import Connection from multiprocessing.process import BaseProcess -from typing import Any, Callable, Optional, Union +from threading import Thread +from typing import Any, Callable, Optional, Union, cast import cloudpickle -import psutil -import zmq from vllm.config import VllmConfig from vllm.distributed import (destroy_distributed_environment, @@ -26,8 +28,9 @@ _add_prefix, set_multiprocessing_worker_envs) from vllm.logger import init_logger from vllm.utils import (get_distributed_init_method, get_mp_context, - get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx) -from vllm.v1.executor.abstract import Executor + get_open_port) +from vllm.v1.executor.abstract import Executor, FailureCallback +from vllm.v1.outputs import ModelRunnerOutput from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -35,6 +38,8 @@ POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 +EXECUTE_MODEL_TIMEOUT_S = 40 + class MultiprocExecutor(Executor): @@ -42,19 +47,9 @@ def _init_executor(self) -> None: # Call self.shutdown at exit to clean up # and ensure workers will be terminated. self._finalizer = weakref.finalize(self, self.shutdown) - - # The child processes will send SIGUSR1 when unrecoverable - # errors happen. - def sigusr1_handler(signum, frame): - logger.fatal( - "MulitprocExecutor got fatal signal from worker processes, " - "shutting down. See stack trace above for root cause issue.") - # Propagate error up to parent process. - parent_process = psutil.Process().parent() - parent_process.send_signal(signal.SIGUSR1) - self.shutdown() - - signal.signal(signal.SIGUSR1, sigusr1_handler) + self.is_failed = False + self.shutdown_event = threading.Event() + self.failure_callback: Optional[FailureCallback] = None self.world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size @@ -78,26 +73,92 @@ def sigusr1_handler(signum, frame): scheduler_output_handle = self.rpc_broadcast_mq.export_handle() # Create workers - self.workers: list[WorkerProcHandle] = [] - for rank in range(self.world_size): - worker = WorkerProc.make_worker_process(self.vllm_config, rank, - rank, - distributed_init_method, - scheduler_output_handle) - self.workers.append(worker) - - # Ensure message queues are ready. 
Will deadlock if re-ordered - # Must be kept consistent with the WorkerProc - self.rpc_broadcast_mq.wait_until_ready() - for w in self.workers: - w.worker_response_mq.wait_until_ready() + unready_workers: list[UnreadyWorkerProcHandle] = [] + success = False + try: + for rank in range(self.world_size): + unready_workers.append( + WorkerProc.make_worker_process( + vllm_config=self.vllm_config, + local_rank=rank, + rank=rank, + distributed_init_method=distributed_init_method, + input_shm_handle=scheduler_output_handle, + )) + + # Workers must be created before wait_for_ready to avoid + # deadlock, since worker.init_device() does a device sync. + self.workers = WorkerProc.wait_for_ready(unready_workers) + + # Ensure message queues are ready. Will deadlock if re-ordered + # Must be kept consistent with the WorkerProc. + self.rpc_broadcast_mq.wait_until_ready() + for w in self.workers: + w.worker_response_mq.wait_until_ready() + + self.start_worker_monitor() + success = True + finally: + if not success: + # Clean up the worker procs if there was a failure. + self._ensure_worker_termination( + [w.proc for w in unready_workers]) + + def start_worker_monitor(self): + workers = self.workers + self_ref = weakref.ref(self) + + # Monitors worker process liveness. If any die unexpectedly, + # logs an error, shuts down the executor and invokes the failure + # callback to inform the engine. + def monitor_workers(): + sentinels = [h.proc.sentinel for h in workers] + died = multiprocessing.connection.wait(sentinels) + _self = self_ref() + if not _self or getattr(_self, 'shutting_down', False): + return + _self.is_failed = True + proc_name = next(h.proc.name for h in workers + if h.proc.sentinel == died[0]) + logger.error( + "Worker proc %s died unexpectedly, " + "shutting down executor.", proc_name) + _self.shutdown() + callback = _self.failure_callback + if callback is not None: + _self.failure_callback = None + callback() + + Thread(target=monitor_workers, + daemon=True, + name="MultiprocWorkerMonitor").start() + + def register_failure_callback(self, callback: FailureCallback): + if self.is_failed: + callback() + else: + self.failure_callback = callback + + def execute_model( + self, + scheduler_output, + ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: + (output, ) = self.collective_rpc("execute_model", + args=(scheduler_output, ), + rank0_reply_only=True, + timeout=EXECUTE_MODEL_TIMEOUT_S) + return output def collective_rpc(self, method: Union[str, Callable], timeout: Optional[float] = None, args: tuple = (), - kwargs: Optional[dict] = None) -> list[Any]: - start_time = time.monotonic() + kwargs: Optional[dict] = None, + rank0_reply_only: bool = False) -> list[Any]: + if self.is_failed: + raise RuntimeError("Executor failed.") + + deadline = None if timeout is None else time.monotonic() + timeout kwargs = kwargs or {} # NOTE: If the args are heterogeneous, then we pack them into a list, @@ -109,30 +170,30 @@ def collective_rpc(self, else: send_method = cloudpickle.dumps( method, protocol=pickle.HIGHEST_PROTOCOL) - self.rpc_broadcast_mq.enqueue((send_method, args, kwargs)) - - responses = [None] * self.world_size - for w in self.workers: - dequeue_timeout = timeout - (time.monotonic() - start_time - ) if timeout is not None else None + self.rpc_broadcast_mq.enqueue( + (send_method, args, kwargs, rank0_reply_only)) + + workers = (self.workers[0], ) if rank0_reply_only else self.workers + responses = [None] * len(workers) + for w in workers: + dequeue_timeout = None if deadline is None else ( + 
deadline - time.monotonic()) status, result = w.worker_response_mq.dequeue( - timeout=dequeue_timeout) + timeout=dequeue_timeout, cancel=self.shutdown_event) if status != WorkerProc.ResponseStatus.SUCCESS: raise RuntimeError( - "Worker failed with error %s, please check the" - " stack trace above for the root cause", result) + f"Worker failed with error '{result}', please check the" + " stack trace above for the root cause") responses[w.rank] = result return responses except TimeoutError as e: raise TimeoutError(f"RPC call to {method} timed out.") from e - except Exception as e: - # Re-raise any other exceptions - raise e - def _ensure_worker_termination(self): + @staticmethod + def _ensure_worker_termination(worker_procs: list[BaseProcess]): """Ensure that all worker processes are terminated. Assumes workers have received termination requests. Waits for processing, then sends termination and kill signals if needed.""" @@ -150,7 +211,7 @@ def wait_for_termination(procs, timeout): return False # Send SIGTERM if still running - active_procs = [w.proc for w in self.workers if w.proc.is_alive()] + active_procs = [proc for proc in worker_procs if proc.is_alive()] for p in active_procs: p.terminate() if not wait_for_termination(active_procs, 4): @@ -159,22 +220,14 @@ def wait_for_termination(procs, timeout): for p in active_procs: p.kill() - self._cleanup_sockets() - - def _cleanup_sockets(self): - for w in self.workers: - # Remove the zmq ipc socket file - socket_path = w.ready_path.replace("ipc://", "") - if os and os.path.exists(socket_path): - os.remove(socket_path) - def shutdown(self): """Properly shut down the executor and its workers""" if not getattr(self, 'shutting_down', False): self.shutting_down = True + self.shutdown_event.set() for w in self.workers: w.worker_response_mq = None - self._ensure_worker_termination() + self._ensure_worker_termination([w.proc for w in self.workers]) self.rpc_broadcast_mq = None @@ -183,13 +236,30 @@ def check_health(self) -> None: return +@dataclass +class UnreadyWorkerProcHandle: + """WorkerProcess handle before READY.""" + proc: BaseProcess + rank: int + ready_pipe: Connection + + @dataclass class WorkerProcHandle: proc: BaseProcess rank: int - ready_path: str worker_response_mq: MessageQueue # The worker process writes to this MQ + @classmethod + def from_unready_handle( + cls, unready_handle: UnreadyWorkerProcHandle, + worker_response_mq: MessageQueue) -> "WorkerProcHandle": + return cls( + proc=unready_handle.proc, + rank=unready_handle.rank, + worker_response_mq=worker_response_mq, + ) + class WorkerProc: """Wrapper that runs one Worker in a separate process.""" @@ -203,7 +273,6 @@ def __init__( rank: int, distributed_init_method: str, input_shm_handle: Handle, - ready_path: str, ): self.rank = rank wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank) @@ -231,18 +300,8 @@ def __init__( # Initializes a message queue for sending the model output self.worker_response_mq = MessageQueue(1, 1) - worker_response_mq_handle = self.worker_response_mq.export_handle() - - # Send Readiness signal to EngineCore process. - # Set linger here because we want to ensure the message has - # been sent before the context is closed. 
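Illustrative aside: start_worker_monitor above detects dead workers by blocking on multiprocessing.connection.wait() over the process sentinels instead of signals or polling. A simplified standalone sketch of that mechanism; the worker body and names are invented for the example:

import multiprocessing
import time
from multiprocessing.connection import wait
from threading import Thread

def worker() -> None:
    time.sleep(1)  # pretend to do some work, then exit (or crash)

if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    procs = [ctx.Process(target=worker, name=f"Worker-{i}") for i in range(2)]
    for p in procs:
        p.start()

    def monitor() -> None:
        # Blocks until at least one sentinel is ready, i.e. a process exited.
        died = wait([p.sentinel for p in procs])
        name = next(p.name for p in procs if p.sentinel == died[0])
        print(f"{name} exited; a real executor would shut down and "
              "invoke its failure callback here")

    Thread(target=monitor, daemon=True, name="WorkerMonitor").start()
    for p in procs:
        p.join()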
- with zmq_socket_ctx(ready_path, zmq.constants.PUSH, - linger=10000) as ready_socket: - payload = pickle.dumps(worker_response_mq_handle, - protocol=pickle.HIGHEST_PROTOCOL) - ready_socket.send_string(WorkerProc.READY_STR) - ready_socket.send(payload) + # Initialize device and loads weights self.worker.init_device() self.worker.load_model() @@ -253,12 +312,10 @@ def make_worker_process( rank: int, distributed_init_method: str, input_shm_handle, # Receive SchedulerOutput - ) -> WorkerProcHandle: + ) -> UnreadyWorkerProcHandle: context = get_mp_context() - - # ZMQ path for worker to send ready message and shm_broadcast handle - # back to core process. - ready_path = get_open_zmq_ipc_path() + # (reader, writer) + reader, writer = context.Pipe(duplex=False) process_kwargs = { "vllm_config": vllm_config, @@ -266,24 +323,57 @@ def make_worker_process( "rank": rank, "distributed_init_method": distributed_init_method, "input_shm_handle": input_shm_handle, - "ready_path": ready_path, + "ready_pipe": (reader, writer), } # Run EngineCore busy loop in background process. proc = context.Process(target=WorkerProc.worker_main, kwargs=process_kwargs, + name=f"VllmWorker-{rank}", daemon=True) - with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket: - proc.start() - - # Wait for startup - worker_response_mq_handle = WorkerProc.wait_for_startup( - proc, ready_socket) - - worker_response_mq = MessageQueue.create_from_handle( - worker_response_mq_handle, 0) + proc.start() + writer.close() + return UnreadyWorkerProcHandle(proc, rank, reader) - return WorkerProcHandle(proc, rank, ready_path, worker_response_mq) + @staticmethod + def wait_for_ready( + unready_proc_handles: list[UnreadyWorkerProcHandle] + ) -> list[WorkerProcHandle]: + + e = Exception("WorkerProc initialization failed due to " + "an exception in a background process. " + "See stack trace for root cause.") + + pipes = {handle.ready_pipe: handle for handle in unready_proc_handles} + ready_proc_handles: list[Optional[WorkerProcHandle]] = ( + [None] * len(unready_proc_handles)) + while pipes: + ready = multiprocessing.connection.wait(pipes.keys()) + for pipe in ready: + assert isinstance(pipe, Connection) + try: + # Wait until the WorkerProc is ready. + unready_proc_handle = pipes.pop(pipe) + response: dict[str, Any] = pipe.recv() + if response["status"] != "READY": + raise e + + # Extract the message queue handle. + worker_response_mq = MessageQueue.create_from_handle( + response["handle"], 0) + ready_proc_handles[unready_proc_handle.rank] = ( + WorkerProcHandle.from_unready_handle( + unready_proc_handle, worker_response_mq)) + + except EOFError: + e.__suppress_context__ = True + raise e from None + + finally: + # Close connection. + pipe.close() + + return cast(list[WorkerProcHandle], ready_proc_handles) def shutdown(self): self.rpc_broadcast_mq = None @@ -312,51 +402,51 @@ def signal_handler(signum, frame): signal.signal(signal.SIGINT, signal_handler) worker = None + # tuple[Connection, Connection] + reader, ready_writer = kwargs.pop("ready_pipe") try: + reader.close() worker = WorkerProc(*args, **kwargs) + # Send READY once we know everything is loaded + ready_writer.send({ + "status": + WorkerProc.READY_STR, + "handle": + worker.worker_response_mq.export_handle(), + }) + # Ensure message queues are ready. Will deadlock if re-ordered. 
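Illustrative aside: make_worker_process and wait_for_ready above replace the old ZMQ readiness socket with a one-shot multiprocessing.Pipe handshake. A minimal standalone sketch of such a handshake; the payload dict is illustrative only:

import multiprocessing
from multiprocessing.connection import Connection

def child(ready_writer: Connection) -> None:
    # ... heavy initialisation would happen here ...
    ready_writer.send({"status": "READY", "handle": "opaque-queue-handle"})
    ready_writer.close()

if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    reader, writer = ctx.Pipe(duplex=False)
    proc = ctx.Process(target=child, args=(writer,))
    proc.start()
    writer.close()  # parent keeps only the read end

    msg = reader.recv()  # EOFError here means the child died before signalling
    assert msg["status"] == "READY"
    reader.close()
    proc.join()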
# Must be kept consistent with the Executor worker.rpc_broadcast_mq.wait_until_ready() worker.worker_response_mq.wait_until_ready() + ready_writer.close() + ready_writer = None worker.worker_busy_loop() - except SystemExit: - logger.debug("Worker interrupted.") - except Exception: - # worker_busy_loop sends exceptions to Executor - # for shutdown, but if there is an error in startup or an - # error with IPC itself, we need to alert the parent. - psutil.Process().parent().send_signal(signal.SIGUSR1) - raise + # NOTE: if an Exception arises in busy_loop, we send + # a FAILURE message over the MQ RPC to notify the Executor, + # which triggers system shutdown. + # TODO(rob): handle case where the MQ itself breaks. + + if ready_writer is not None: + logger.exception("WorkerProc failed to start.") + else: + logger.exception("WorkerProc failed.") + + # The parent sends a SIGTERM to all worker processes if + # any worker dies. Set this value so we don't re-throw + # SystemExit() to avoid zmq exceptions in __del__. + shutdown_requested = True finally: + if ready_writer is not None: + ready_writer.close() # Clean up once worker exits busy loop if worker is not None: worker.shutdown() - worker = None - - @staticmethod - def wait_for_startup( - proc: BaseProcess, - ready_socket: zmq.Socket, - ) -> Optional[Handle]: - """Wait until the Worker is ready.""" - - # Wait for Worker to send READY. - while ready_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: - logger.debug("Waiting for WorkerProc to startup.") - - if not proc.is_alive(): - raise RuntimeError("WorkerProc failed to start.") - - message = ready_socket.recv_string() - assert message == WorkerProc.READY_STR - handle_frame = ready_socket.recv(copy=False) - handle = pickle.loads(handle_frame.buffer) - return handle class ResponseStatus(Enum): SUCCESS = auto() @@ -365,7 +455,7 @@ class ResponseStatus(Enum): def worker_busy_loop(self): """Main busy loop for Multiprocessing Workers""" while True: - method, args, kwargs = self.rpc_broadcast_mq.dequeue() + method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue() try: if isinstance(method, str): @@ -377,12 +467,14 @@ def worker_busy_loop(self): # Notes have been introduced in python 3.11 if hasattr(e, "add_note"): e.add_note(traceback.format_exc()) - logger.exception("WorkerProc hit an exception: %s", exc_info=e) + logger.exception("WorkerProc hit an exception.") # exception might not be serializable, so we convert it to # string, only for logging purpose. 
- self.worker_response_mq.enqueue( - (WorkerProc.ResponseStatus.FAILURE, str(e))) + if not rank0_only or self.rank == 0: + self.worker_response_mq.enqueue( + (WorkerProc.ResponseStatus.FAILURE, str(e))) continue - self.worker_response_mq.enqueue( - (WorkerProc.ResponseStatus.SUCCESS, output)) + if not rank0_only or self.rank == 0: + self.worker_response_mq.enqueue( + (WorkerProc.ResponseStatus.SUCCESS, output)) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 3959be40b725..7051c681b1a0 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 +import logging import time from abc import ABC, abstractmethod -from typing import Optional +from typing import Callable, Optional import numpy as np import prometheus_client @@ -12,14 +13,26 @@ from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason from vllm.v1.metrics.stats import IterationStats, SchedulerStats -from vllm.v1.spec_decode.metrics import SpecDecodingMetrics +from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5.0 +StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] + class StatLoggerBase(ABC): + """Interface for logging metrics. + + API users may define custom loggers that implement this interface. + However, note that the `SchedulerStats` and `IterationStats` classes + are not considered stable interfaces and may change in future versions. + """ + + @abstractmethod + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): + ... @abstractmethod def record(self, scheduler_stats: SchedulerStats, @@ -32,14 +45,16 @@ def log(self): # noqa class LoggingStatLogger(StatLoggerBase): - def __init__(self, engine_index: int = 0): + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.engine_index = engine_index self._reset(time.monotonic()) self.last_scheduler_stats = SchedulerStats() # Prefix cache metrics. This cannot be reset. # TODO: Make the interval configurable. self.prefix_caching_metrics = PrefixCachingMetrics() - self.spec_decoding_metrics = SpecDecodingMetrics() + self.spec_decoding_logging = SpecDecodingLogging() + self.last_prompt_throughput: float = 0.0 + self.last_generation_throughput: float = 0.0 def _reset(self, now): self.last_log_time = now @@ -68,7 +83,7 @@ def record(self, scheduler_stats: SchedulerStats, self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_metrics.observe( + self.spec_decoding_logging.observe( scheduler_stats.spec_decoding_stats) self.last_scheduler_stats = scheduler_stats @@ -83,8 +98,17 @@ def log(self): scheduler_stats = self.last_scheduler_stats + log_fn = logger.info + if not any( + (prompt_throughput, generation_throughput, + self.last_prompt_throughput, self.last_generation_throughput)): + # Avoid log noise on an idle production system + log_fn = logger.debug + self.last_generation_throughput = generation_throughput + self.last_prompt_throughput = prompt_throughput + # Format and print output. 
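Illustrative aside: the new StatLoggerBase docstring says API users may define custom loggers, created through the StatLoggerFactory callable added in this diff. A minimal sketch of what such a logger might look like; the class name is made up, the record() signature is abridged in this hunk so the iteration_stats parameter is an assumption, and how the factory list is passed in depends on which entry points accept list[StatLoggerFactory]:

from typing import Optional

from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats


class PrintingStatLogger(StatLoggerBase):
    """Toy logger: stash the latest scheduler stats and print them on log()."""

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        self.engine_index = engine_index
        self.last_scheduler_stats: Optional[SchedulerStats] = None

    def record(self, scheduler_stats: SchedulerStats,
               iteration_stats: Optional[IterationStats]) -> None:
        # SchedulerStats/IterationStats fields are not a stable interface,
        # so this sketch does not rely on any particular attribute.
        self.last_scheduler_stats = scheduler_stats

    def log(self):
        print(f"engine {self.engine_index}: {self.last_scheduler_stats}")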
- logger.info( + log_fn( "Engine %03d: " "Avg prompt throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s, " @@ -101,7 +125,7 @@ def log(self): ) if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_metrics.log() + self.spec_decoding_logging.log(log_fn=log_fn) class PrometheusStatLogger(StatLoggerBase): @@ -122,6 +146,9 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): max_model_len = vllm_config.model_config.max_model_len + self.spec_decoding_prom = SpecDecodingProm( + vllm_config.speculative_config, labelnames, labelvalues) + # # Scheduler state # @@ -205,7 +232,10 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): prometheus_client.Histogram( name="vllm:iteration_tokens_total", documentation="Histogram of number of tokens per engine_step.", - buckets=build_cudagraph_buckets(vllm_config), + buckets=[ + 1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, + 16384 + ], labelnames=labelnames).labels(*labelvalues) self.histogram_max_num_generation_tokens_request = \ @@ -312,24 +342,6 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.labelname_running_lora_adapters, ]) - # - # Speculative Decoding metrics - # The acceptance rate can be calculated using a PromQL query: - # - # rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / - # rate(vllm:spec_decode_num_draft_tokens_total[$interval]) - # - self.counter_spec_decode_num_draft_tokens = \ - prometheus_client.Counter( - name="vllm:spec_decode_num_draft_tokens_total", - documentation="Number of draft tokens.", - labelnames=labelnames).labels(*labelvalues) - self.counter_spec_decode_num_accepted_tokens = \ - prometheus_client.Counter( - name="vllm:spec_decode_num_accepted_tokens_total", - documentation="Number of accepted tokens.", - labelnames=labelnames).labels(*labelvalues) - # # Cache config info metric # @@ -367,10 +379,8 @@ def record(self, scheduler_stats: SchedulerStats, scheduler_stats.prefix_cache_stats.hits) if scheduler_stats.spec_decoding_stats is not None: - self.counter_spec_decode_num_draft_tokens.inc( - scheduler_stats.spec_decoding_stats.num_draft_tokens) - self.counter_spec_decode_num_accepted_tokens.inc( - scheduler_stats.spec_decoding_stats.num_accepted_tokens) + self.spec_decoding_prom.observe( + scheduler_stats.spec_decoding_stats) if iteration_stats is None: return @@ -460,11 +470,29 @@ def build_1_2_5_buckets(max_value: int) -> list[int]: return build_buckets([1, 2, 5], max_value) -def build_cudagraph_buckets(vllm_config: VllmConfig) -> list[int]: - if not vllm_config.model_config.enforce_eager: - buckets = vllm_config.compilation_config.\ - cudagraph_capture_sizes.copy() - buckets.sort() - return buckets +def setup_default_loggers( + vllm_config: VllmConfig, + log_stats: bool, + engine_num: int, + custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, +) -> list[list[StatLoggerBase]]: + """Setup logging and prometheus metrics.""" + if not log_stats: + return [] + + factories: list[StatLoggerFactory] + if custom_stat_loggers is not None: + factories = custom_stat_loggers else: - return [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096] + factories = [PrometheusStatLogger] + if logger.isEnabledFor(logging.INFO): + factories.append(LoggingStatLogger) + + stat_loggers: list[list[StatLoggerBase]] = [] + for i in range(engine_num): + per_engine_stat_loggers: list[StatLoggerBase] = [] + for logger_factory in factories: + per_engine_stat_loggers.append(logger_factory(vllm_config, i)) + 
stat_loggers.append(per_engine_stat_loggers) + + return stat_loggers diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 6be72431dde5..3b9b666f936a 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -20,7 +20,6 @@ class Request: def __init__( self, request_id: str, - prompt: Optional[str], prompt_token_ids: list[int], multi_modal_inputs: Optional[list[MultiModalKwargs]], multi_modal_hashes: Optional[list[str]], @@ -46,7 +45,6 @@ def __init__( assert sampling_params.max_tokens is not None self.max_tokens = sampling_params.max_tokens - self.prompt = prompt self.prompt_token_ids = prompt_token_ids self.num_prompt_tokens = len(self.prompt_token_ids) self._output_token_ids: list[int] = [] @@ -81,7 +79,6 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": return cls( request_id=request.request_id, - prompt=request.prompt, prompt_token_ids=request.prompt_token_ids, multi_modal_inputs=request.mm_inputs, multi_modal_hashes=request.mm_hashes, diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index f69623edd632..745b81ded3f1 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -72,14 +72,7 @@ def __init__(self): "best performance, please install FlashInfer.") self.forward = self.forward_native elif current_platform.is_tpu(): - if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: - logger.warning( - "TPU-specific optimization for top-k & top-p sampling are " - "disabled, falling back to PyTorch-native implementation " - "which could be very slow.") - self.forward = self.forward_native - else: - self.forward = self.forward_tpu + self.forward = self.forward_tpu else: self.forward = self.forward_native @@ -146,12 +139,22 @@ def apply_top_k_top_p_tpu( chance of being chosen during final sampling, so we can consider the tie being broken then. """ + probs = logits.softmax(dim=-1) + probs_sort, _ = probs.sort(dim=-1, descending=False) + if k is not None: - logits = apply_top_k_only(logits, k) + top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, ) + top_k_count = top_k_count.unsqueeze(dim=1) + top_k_cutoff = probs_sort.gather(-1, top_k_count) + + # Make sure the no top-k rows are no-op. + no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) + top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) + + elements_to_discard = probs < top_k_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) if p is not None: - probs = logits.softmax(dim=-1) - probs_sort, _ = probs.sort(dim=-1, descending=False) cumprob = torch.cumsum(probs_sort, dim=-1) top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) top_p_mask[:, -1] = False # at least one @@ -224,7 +227,7 @@ def apply_top_k_only( max_top_k = k.max() # topk.values tensor has shape [batch_size, max_top_k]. # Convert top k to 0-based index in range [0, max_top_k). - k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1) + k_index = k.sub_(1).unsqueeze(1) top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long()) # Handle non-topk rows. 
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf")) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 3cf7fde5cd0e..9061a64db57c 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -226,7 +226,7 @@ def rejection_sample( is_greedy, max_spec_len, vocab_size, - IS_NGRAM=draft_probs is None, + NO_DRAFT_PROBS=draft_probs is None, num_warps=1, ) return output_token_ids @@ -423,7 +423,7 @@ def sample_recovered_tokens( q, vocab_size, triton.next_power_of_2(vocab_size), - IS_NGRAM=draft_probs is None, + NO_DRAFT_PROBS=draft_probs is None, ) return recovered_token_ids @@ -490,7 +490,7 @@ def rejection_random_sample_kernel( is_greedy_ptr, # [batch_size] max_spec_len, vocab_size, - IS_NGRAM: tl.constexpr, + NO_DRAFT_PROBS: tl.constexpr, ): req_idx = tl.program_id(0) is_greedy = tl.load(is_greedy_ptr + req_idx) @@ -509,7 +509,7 @@ def rejection_random_sample_kernel( for pos in range(num_draft_tokens): if not rejected: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) - if IS_NGRAM: + if NO_DRAFT_PROBS: draft_prob = 1 else: draft_prob = tl.load(draft_probs_ptr + @@ -575,7 +575,7 @@ def sample_recovered_tokens_kernel( q_ptr, # [batch_size, vocab_size] vocab_size, PADDED_VOCAB_SIZE: tl.constexpr, - IS_NGRAM: tl.constexpr, + NO_DRAFT_PROBS: tl.constexpr, ): req_idx = tl.program_id(0) if req_idx == 0: @@ -591,7 +591,7 @@ def sample_recovered_tokens_kernel( return vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) - if IS_NGRAM: + if NO_DRAFT_PROBS: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) @@ -624,7 +624,7 @@ def sample_recovered_tokens_kernel( recovered_id = tl.argmax(prob / q, axis=-1) tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) - if IS_NGRAM: + if NO_DRAFT_PROBS: # Restore the original probability. tl.store( target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py index 3950fda3e5ea..d4ea8c2dee07 100644 --- a/vllm/v1/sample/tpu/metadata.py +++ b/vllm/v1/sample/tpu/metadata.py @@ -10,8 +10,8 @@ temperature=-1.0, min_p=0.0, # strictly disabled for now - # top_k=-1, - # top_p=0.0, + top_k=0, + top_p=1.0, # frequency_penalties=0.0, # presence_penalties=0.0, # repetition_penalties=0.0, @@ -26,11 +26,9 @@ class TPUSupportedSamplingMetadata: temperature: torch.Tensor = None min_p: torch.Tensor = None - # Still too slow on forward_native! top_k: torch.Tensor = None top_p: torch.Tensor = None - # Greedy sampling flag for compiling single xla graph. all_greedy: bool = True # unsupported, you need to return an extra tensor of static size BxV @@ -99,11 +97,12 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor: fill_slice(input_batch.temperature_cpu_tensor, DEFAULT_SAMPLING_PARAMS["temperature"]) - # TODO Temporarily disabled until sampling options are enabled - # fill_slice(input_batch.top_p_cpu_tensor) - # fill_slice(input_batch.top_k_cpu_tensor) fill_slice(input_batch.min_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["min_p"]) + fill_slice(input_batch.top_k_cpu_tensor, + DEFAULT_SAMPLING_PARAMS["top_k"]) + fill_slice(input_batch.top_p_cpu_tensor, + DEFAULT_SAMPLING_PARAMS["top_p"]) # Slice persistent device tensors to a fixed pre-compiled padded shape. 
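The IS_NGRAM -> NO_DRAFT_PROBS rename reflects what the flag actually controls: when no draft distribution is available, the kernel fixes draft_prob at 1. A host-side conceptual sketch (names illustrative, assuming the standard speculative-decoding acceptance rule; the real logic lives in the Triton kernels above):

from typing import Optional

import torch

def accept_draft(target_prob: torch.Tensor,
                 draft_prob: Optional[torch.Tensor],
                 uniform: torch.Tensor) -> torch.Tensor:
    # Standard rejection test: accept a draft token with probability
    # min(1, p_target / p_draft). With no draft probs, p_draft is treated
    # as 1, so the test reduces to uniform < p_target.
    q = torch.ones_like(target_prob) if draft_prob is None else draft_prob
    return uniform < torch.clamp(target_prob / q, max=1.0)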
return cls( @@ -111,7 +110,9 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor: to(xla_device), all_greedy=input_batch.all_greedy, # TODO enable more and avoid returning None values - top_p=None, # input_batch.top_p[:padded_num_reqs], - top_k=None, # input_batch.top_k[:padded_num_reqs], + top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to( + xla_device), + top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to( + xla_device), min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to( xla_device)) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 3af6793fde74..a3ad8cb92096 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import dataclasses import pickle from collections.abc import Sequence from inspect import isclass @@ -12,12 +13,26 @@ import zmq from msgspec import msgpack +from vllm import envs +from vllm.multimodal.inputs import (BaseMultiModalField, + MultiModalBatchedField, + MultiModalFieldConfig, MultiModalFieldElem, + MultiModalFlatField, MultiModalKwargs, + MultiModalKwargsItem, + MultiModalSharedField, NestedTensors) + CUSTOM_TYPE_PICKLE = 1 CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_RAW_VIEW = 3 -# TODO calibrate this size -MIN_NOCOPY_BUF_SIZE = 512 +# MultiModalField class serialization type map. +# These need to list all possible field types and match them +# to factory methods in `MultiModalFieldConfig`. +MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = { + MultiModalFlatField: "flat", + MultiModalSharedField: "shared", + MultiModalBatchedField: "batched", +} bytestr = Union[bytes, bytearray, memoryview, zmq.Frame] @@ -27,14 +42,20 @@ class MsgpackEncoder: Note that unlike vanilla `msgspec` Encoders, this interface is generally not thread-safe when encoding tensors / numpy arrays. + + By default, arrays below 256B are serialized inline Larger will get sent + via dedicated messages. Note that this is a per-tensor limit. """ - def __init__(self): + def __init__(self, size_threshold: Optional[int] = None): + if size_threshold is None: + size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) # This is used as a local stash of buffers that we can then access from # our custom `msgspec` hook, `enc_hook`. We don't have a way to # pass custom data to the hook otherwise. self.aux_buffers: Optional[list[bytestr]] = None + self.size_threshold = size_threshold def encode(self, obj: Any) -> Sequence[bytestr]: try: @@ -59,12 +80,31 @@ def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]: def enc_hook(self, obj: Any) -> Any: if isinstance(obj, torch.Tensor): - return self._encode_ndarray(obj.numpy()) + return self._encode_tensor(obj) # Fall back to pickle for object or void kind ndarrays. if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): return self._encode_ndarray(obj) + if isinstance(obj, MultiModalKwargs): + mm: MultiModalKwargs = obj + if not mm.modalities: + # just return the main dict if there are no modalities. + return dict(mm) + + # ignore the main dict, it will be re-indexed. + # Encode a list of MultiModalKwargsItems as plain dicts + # + special handling for .field. + # Any tensors *not* indexed by modality will be ignored. 
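A usage sketch for the threshold-based zero-copy path added here. MsgpackEncoder(size_threshold=...) and the VLLM_MSGPACK_ZERO_COPY_THRESHOLD default come from this diff; the decoder constructor and decode() signature are assumptions based on the existing MsgpackDecoder class, not shown in the hunk.

import torch
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder

encoder = MsgpackEncoder(size_threshold=512)
frames = encoder.encode(torch.arange(1024, dtype=torch.int32))
# frames[0] is the msgpack message; this 4 KiB tensor exceeds the 512-byte
# threshold, so its payload travels as a separate zero-copy frame referenced
# by index rather than being inlined in the message.
decoder = MsgpackDecoder(torch.Tensor)
restored = decoder.decode(frames)
assert torch.equal(restored, torch.arange(1024, dtype=torch.int32))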
+ return [[{ + "modality": elem.modality, + "key": elem.key, + "data": self._encode_nested_tensors(elem.data), + "field": self._encode_mm_field(elem.field), + } for elem in item.values()] + for itemlist in mm._items_by_modality.values() + for item in itemlist] + if isinstance(obj, FunctionType): # `pickle` is generally faster than cloudpickle, but can have # problems serializing methods. @@ -77,8 +117,9 @@ def _encode_ndarray( self, obj: np.ndarray ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: assert self.aux_buffers is not None + # If the array is non-contiguous, we need to copy it first arr_data = obj.data if obj.data.c_contiguous else obj.tobytes() - if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE: + if not obj.shape or obj.nbytes < self.size_threshold: # Encode small arrays and scalars inline. Using this extension type # ensures we can avoid copying when decoding. data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data) @@ -92,6 +133,44 @@ def _encode_ndarray( # backing buffers that we've stashed in `aux_buffers`. return obj.dtype.str, obj.shape, data + def _encode_tensor( + self, obj: torch.Tensor + ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: + assert self.aux_buffers is not None + # this creates a copy of the tensor if it's not already contiguous + obj = obj.contiguous() + # view the tensor as a 1D array of bytes + arr = obj.view((obj.numel(), )).view(torch.uint8).numpy() + if obj.nbytes < self.size_threshold: + # Smaller tensors are encoded inline, just like ndarrays. + data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) + else: + # Otherwise encode index of backing buffer to avoid copy. + data = len(self.aux_buffers) + self.aux_buffers.append(arr.data) + dtype = str(obj.dtype)[6:] # remove 'torch.' prefix + return dtype, obj.shape, data + + def _encode_nested_tensors(self, nt: NestedTensors) -> Any: + if isinstance(nt, torch.Tensor): + return self._encode_tensor(nt) + if isinstance(nt, (int, float)): + # Although it violates NestedTensors type, MultiModalKwargs + # values are sometimes floats. + return nt + return [self._encode_nested_tensors(x) for x in nt] + + def _encode_mm_field(self, field: BaseMultiModalField): + # Figure out the factory name for the field type. + name = MMF_CLASS_TO_FACTORY.get(field.__class__) + if not name: + raise TypeError(f"Unsupported field type: {field.__class__}") + # We just need to copy all of the field values in order + # which will be then used to reconstruct the field. + field_values = (getattr(field, f.name) + for f in dataclasses.fields(field)) + return name, *field_values + class MsgpackDecoder: """Decoder with custom torch tensor and numpy array serialization. @@ -125,14 +204,64 @@ def dec_hook(self, t: type, obj: Any) -> Any: if issubclass(t, np.ndarray): return self._decode_ndarray(obj) if issubclass(t, torch.Tensor): - return torch.from_numpy(self._decode_ndarray(obj)) + return self._decode_tensor(obj) + if issubclass(t, MultiModalKwargs): + if isinstance(obj, list): + return MultiModalKwargs.from_items( + self._decode_mm_items(obj)) + return MultiModalKwargs({ + k: self._decode_nested_tensors(v) + for k, v in obj.items() + }) return obj def _decode_ndarray(self, arr: Any) -> np.ndarray: dtype, shape, data = arr + # zero-copy decode. We assume the ndarray will not be kept around, + # as it now locks the whole received message buffer in memory. 
buffer = self.aux_buffers[data] if isinstance(data, int) else data return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) + def _decode_tensor(self, arr: Any) -> torch.Tensor: + dtype, shape, data = arr + # Copy from inline representation, to decouple the memory storage + # of the message from the original buffer. And also make Torch + # not complain about a readonly memoryview. + buffer = self.aux_buffers[data] if isinstance(data, int) \ + else bytearray(data) + # Create numpy wrapper around the bytes + arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer), )) + torch_dtype = getattr(torch, dtype) + assert isinstance(torch_dtype, torch.dtype) + # Convert back to proper shape & type + return torch.from_numpy(arr).view(torch_dtype).view(shape) + + def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: + decoded_items = [] + for item in obj: + elems = [] + for v in item: + v["data"] = self._decode_nested_tensors(v["data"]) + # Reconstruct the field processor using MultiModalFieldConfig + factory_meth_name, *field_args = v["field"] + factory_meth = getattr(MultiModalFieldConfig, + factory_meth_name) + v["field"] = factory_meth(None, *field_args).field + elems.append(MultiModalFieldElem(**v)) + decoded_items.append(MultiModalKwargsItem.from_elems(elems)) + return decoded_items + + def _decode_nested_tensors(self, obj: Any) -> NestedTensors: + if isinstance(obj, (int, float)): + # Although it violates NestedTensors type, MultiModalKwargs + # values are sometimes floats. + return obj + if not isinstance(obj, list): + raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}") + if obj and isinstance(obj[0], str): + return self._decode_tensor(obj) + return [self._decode_nested_tensors(x) for x in obj] + def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_RAW_VIEW: return data diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 2322463c0713..1de14584d396 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -6,12 +6,18 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context +from vllm.logger import init_logger from vllm.model_executor.model_loader.loader import get_model_loader from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM +from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata +logger = init_logger(__name__) + +PADDING_SLOT_ID = -1 + class EagleProposer: @@ -23,6 +29,7 @@ def __init__( self.vllm_config = vllm_config self.num_speculative_tokens = ( vllm_config.speculative_config.num_speculative_tokens) + self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size # We need +1 here because the arange is used to set query_start_loc, # which has one more element than batch_size. 
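The tensor wire format introduced by _encode_tensor/_decode_tensor above is just a (dtype-without-"torch."-prefix, shape, raw bytes) triple. A standalone round trip mirroring that path (illustrative, outside the encoder/decoder classes):

import numpy as np
import torch

t = torch.randn(2, 3, dtype=torch.float16)
dtype = str(t.dtype)[6:]                      # "float16"
shape = tuple(t.shape)
payload = t.contiguous().view((t.numel(), )).view(torch.uint8).numpy().tobytes()
# In the real encoder, payloads >= size_threshold are shipped as separate
# zero-copy frames; smaller ones are inlined as a msgpack Ext, as here.
buf = bytearray(payload)                      # writable copy for torch
arr = np.ndarray(buffer=buf, dtype=np.uint8, shape=(len(buf), ))
restored = torch.from_numpy(arr).view(getattr(torch, dtype)).view(shape)
assert torch.equal(restored, t)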
@@ -48,7 +55,7 @@ def propose( # [batch_size, max_num_blocks_per_req] block_table: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] last_token_indices = cu_num_tokens[1:] - 1 @@ -84,27 +91,25 @@ def propose( ) with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( + hidden_states_logits, hidden_states_fwd = self.model( input_ids=input_ids, hidden_states=target_hidden_states, positions=target_positions, ) - sample_hidden_states = hidden_states[last_token_indices] + sample_hidden_states = hidden_states_logits[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) - draft_token_ids, draft_probs = compute_probs_and_sample_next_token( - logits, sampling_metadata) + draft_token_ids = logits.argmax(dim=-1) # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: - # [batch_size, 1] and [batch_size, 1, vocab_size] - return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1) + # [batch_size, 1] + return draft_token_ids.view(-1, 1) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] - draft_probs_list = [draft_probs] positions = target_positions[last_token_indices] - hidden_states = sample_hidden_states + hidden_states = hidden_states_fwd[last_token_indices] attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size + 1] @@ -112,34 +117,56 @@ def propose( # Update the inputs. input_ids = draft_token_ids_list[-1] positions += 1 + + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. Since it is complex + # to remove such requests from the batch, we keep them in the batch + # but adjust the position ids and slot mappings to avoid the + # out-of-range access during the model execution. The draft tokens + # generated with this adjustment should be ignored. + exceeds_max_model_len = positions >= self.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions = torch.where(exceeds_max_model_len, 0, + positions) + + # Increment the sequence lengths. attn_metadata.max_seq_len += 1 attn_metadata.seq_lens += 1 + # Consider max model length. + attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, + self.max_model_len) + # For the requests that exceed the max model length, we set the + # sequence length to 1 to minimize their overheads in attention. + attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) + # Compute the slot mapping. - block_numbers = positions // self.block_size + block_numbers = clamped_positions // self.block_size block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) attn_metadata.slot_mapping = (block_ids * self.block_size + - positions % self.block_size) + clamped_positions % self.block_size) + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. + attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, + PADDING_SLOT_ID) # Run the model. 
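A toy illustration of the new max-model-len guard: positions that run past the limit are clamped to 0 and their slots redirected to the padding slot, so RoPE never sees an out-of-range position and the KV cache is never written out of bounds. Numbers are made up; PADDING_SLOT_ID and the slot-mapping arithmetic mirror the code above.

import torch

PADDING_SLOT_ID = -1
max_model_len, block_size = 8, 4
positions = torch.tensor([6, 7, 8, 9])          # last two exceed the limit
exceeds = positions >= max_model_len
clamped = torch.where(exceeds, 0, positions)
block_table = torch.tensor([[0, 1], [2, 3], [4, 5], [6, 7]])
block_numbers = clamped // block_size
block_ids = block_table.gather(1, block_numbers.view(-1, 1)).view(-1)
slot_mapping = block_ids * block_size + clamped % block_size
slot_mapping.masked_fill_(exceeds, PADDING_SLOT_ID)
print(slot_mapping)  # tensor([ 6, 15, -1, -1])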
with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( + hidden_states_logits, hidden_states = self.model( input_ids=input_ids, hidden_states=hidden_states, - positions=positions, + positions=clamped_positions, ) - logits = self.model.compute_logits(hidden_states, None) - draft_token_ids, probs = compute_probs_and_sample_next_token( - logits, sampling_metadata) + logits = self.model.compute_logits(hidden_states_logits, None) + draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) - draft_probs_list.append(probs) # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - # [batch_size, num_speculative_tokens, vocab_size] - draft_probs = torch.stack(draft_probs_list, dim=1) - return draft_token_ids, draft_probs + return draft_token_ids @staticmethod def prepare_inputs( @@ -198,17 +225,34 @@ def load_model(self, target_model: nn.Module) -> None: with set_default_torch_dtype( draft_model_config.dtype), set_current_vllm_config( self.vllm_config): - self.model = EagleLlamaForCausalLM( - model_config=draft_model_config, - start_layer_id=target_layer_num).to(target_device) - - self.model.load_weights( + if self.vllm_config.speculative_config.method == "eagle": + self.model = EagleLlamaForCausalLM( + model_config=draft_model_config, + start_layer_id=target_layer_num).to(target_device) + else: + assert self.vllm_config.speculative_config.method == "eagle3" + self.model = Eagle3LlamaForCausalLM( + model_config=draft_model_config, + start_layer_id=target_layer_num).to(target_device) + + loaded_weights = self.model.load_weights( loader.get_all_weights( self.vllm_config.speculative_config.draft_model_config, self.model)) - self.model.lm_head = target_model.lm_head - - + if self.vllm_config.speculative_config.method == "eagle3": + if "model.embed_tokens.weight" not in loaded_weights: + logger.info( + "Loading EAGLE embedding weights from the target model.") + self.model.model.embed_tokens = target_model.model.embed_tokens + else: + logger.info("Loading EAGLE LM head weights from the target model.") + self.model.lm_head = target_model.lm_head + + +# NOTE(woosuk): Currently, the below code is not used and we always use argmax +# to sample the draft tokens. We will use this after we find a way to manage +# the draft prob tensor. +# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details. # FIXME(woosuk): The logic here is duplicated with the main sampling code. # We should refactor this to reuse the same sampling implementation. def compute_probs_and_sample_next_token( @@ -235,7 +279,9 @@ def compute_probs_and_sample_next_token( # TODO(woosuk): Consider seeds. q = torch.empty_like(probs) q.exponential_() - next_token_ids = probs.div_(q).argmax(dim=-1).view(-1) + # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs + # will be used later for rejection sampling. 
+ next_token_ids = probs.div(q).argmax(dim=-1).view(-1) if not sampling_metadata.all_random: greedy_token_ids = probs.argmax(dim=-1) next_token_ids = torch.where( diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 7bb3c209d1dc..33ce98284e20 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Optional import numpy as np +import prometheus_client +from vllm.config import SpeculativeConfig from vllm.logger import init_logger logger = init_logger(__name__) @@ -11,52 +14,151 @@ @dataclass class SpecDecodingStats: + """Per-step iteration decoding stats from scheduler. + + Each scheduler step, statistics on spec decoding performance are + aggregated across requests by the scheduler and returned to the + frontend in EngineCoreOutputs->SchedulerStats. + """ + + num_spec_tokens: int + num_drafts: int = 0 num_draft_tokens: int = 0 num_accepted_tokens: int = 0 + num_accepted_tokens_per_pos: list[int] = field(default_factory=list) - def take(self): - copied = SpecDecodingStats(self.num_draft_tokens, - self.num_accepted_tokens) - self.reset() - return copied - - def reset(self): - self.num_draft_tokens = 0 - self.num_accepted_tokens = 0 + @classmethod + def new(cls, num_spec_tokens: int) -> "SpecDecodingStats": + return cls(num_spec_tokens=num_spec_tokens, + num_accepted_tokens_per_pos=[0] * num_spec_tokens) - def observe(self, num_draft_tokens: int, num_accepted_tokens: int): + def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int): + self.num_drafts += 1 self.num_draft_tokens += num_draft_tokens self.num_accepted_tokens += num_accepted_tokens + assert num_accepted_tokens <= self.num_spec_tokens + for i in range(num_accepted_tokens): + self.num_accepted_tokens_per_pos[i] += 1 + +class SpecDecodingLogging: + """Aggregate and log spec decoding metrics. -class SpecDecodingMetrics: + LoggingStatLogger aggregates per-iteration metrics over a set + time interval using observe() and then logs them using log() + before resetting to zero. 
+ """ def __init__(self): self.reset() def reset(self): + self.num_drafts: list[int] = [] self.num_draft_tokens: list[int] = [] self.num_accepted_tokens: list[int] = [] + self.accepted_tokens_per_pos_lists: list[list[int]] = [] def observe(self, spec_decoding_stats: SpecDecodingStats): + self.num_drafts.append(spec_decoding_stats.num_drafts) self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens) self.num_accepted_tokens.append( spec_decoding_stats.num_accepted_tokens) + self.accepted_tokens_per_pos_lists.append( + spec_decoding_stats.num_accepted_tokens_per_pos) - def log(self): + def log(self, log_fn=logger.info): + num_drafts = np.sum(self.num_drafts) num_draft_tokens = np.sum(self.num_draft_tokens) num_accepted_tokens = np.sum(self.num_accepted_tokens) draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens * 100 if num_draft_tokens > 0 else float("nan")) + mean_acceptance_length = (num_accepted_tokens / num_drafts) - logger.info( + pos_matrix = np.array(self.accepted_tokens_per_pos_lists) + acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts + rates_str = ", ".join(f"{p:.3f}" for p in acceptance_rates) + + log_fn( "SpecDecoding metrics: " "Draft acceptance rate: %.1f%%, " + "Mean acceptance length: %.2f, " "Accepted: %d tokens, " - "Drafted: %d tokens", + "Drafted: %d tokens, " + "Per-position acceptance rate: %s", draft_acceptance_rate, + mean_acceptance_length, num_accepted_tokens, num_draft_tokens, + rates_str, ) self.reset() + + +class SpecDecodingProm: + """Record spec decoding metrics in Prometheus. + + The acceptance rate can be calculated using a PromQL query: + + rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / + rate(vllm:spec_decode_num_draft_tokens_total[$interval]) + + The mean acceptance length can be calculated using: + + rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / + rate(vllm:spec_decode_num_drafts[$interval]) + + A per-position acceptance rate vector can be computed using + + vllm:spec_decode_num_accepted_tokens_per_pos[$interval] / + vllm:spec_decode_num_drafts[$interval] + """ + + def __init__(self, speculative_config: Optional[SpeculativeConfig], + labelnames: list[str], labelvalues: list[str]): + self.spec_decoding_enabled = speculative_config is not None + if not self.spec_decoding_enabled: + return + + self.counter_spec_decode_num_drafts = \ + prometheus_client.Counter( + name="vllm:spec_decode_num_drafts_total", + documentation="Number of spec decoding drafts.", + labelnames=labelnames).labels(*labelvalues) + self.counter_spec_decode_num_draft_tokens = \ + prometheus_client.Counter( + name="vllm:spec_decode_num_draft_tokens_total", + documentation="Number of draft tokens.", + labelnames=labelnames).labels(*labelvalues) + self.counter_spec_decode_num_accepted_tokens = \ + prometheus_client.Counter( + name="vllm:spec_decode_num_accepted_tokens_total", + documentation="Number of accepted tokens.", + labelnames=labelnames).labels(*labelvalues) + + assert speculative_config is not None + num_spec_tokens = (speculative_config.num_speculative_tokens + if self.spec_decoding_enabled else 0) + pos_labelnames = labelnames + ["position"] + base_counter = prometheus_client.Counter( + name="vllm:spec_decode_num_accepted_tokens_per_pos", + documentation="Accepted tokens per draft position.", + labelnames=pos_labelnames) + self.counter_spec_decode_num_accepted_tokens_per_pos: \ + list[prometheus_client.Counter] = [] + for pos in range(num_spec_tokens): + pos_labelvalues = labelvalues + [str(pos)] + 
self.counter_spec_decode_num_accepted_tokens_per_pos.append( + base_counter.labels(*pos_labelvalues)) + + def observe(self, spec_decoding_stats: SpecDecodingStats): + if not self.spec_decoding_enabled: + return + self.counter_spec_decode_num_drafts.inc(spec_decoding_stats.num_drafts) + self.counter_spec_decode_num_draft_tokens.inc( + spec_decoding_stats.num_draft_tokens) + self.counter_spec_decode_num_accepted_tokens.inc( + spec_decoding_stats.num_accepted_tokens) + for pos, counter in enumerate( + self.counter_spec_decode_num_accepted_tokens_per_pos): + counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos]) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 7e548bb48b57..704153d43a2b 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -18,6 +18,9 @@ def __init__(self, vllm_config: VllmConfig): # tokens follow the match, we will return the maximum amount of # tokens until the end. self.k = vllm_config.speculative_config.num_speculative_tokens + # Maximum length of the model. + self.max_model_len = vllm_config.model_config.max_model_len + # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. self.propose(np.zeros(1024, dtype=np.int32)) @@ -50,9 +53,14 @@ def propose( followed that pattern. Here we will return [4,2,3] because we only have three tokens after the match. """ + # Do not generate draft tokens beyond the max model length. + k = min(self.k, self.max_model_len - context_token_ids.shape[0]) + if k <= 0: + return None + # TODO(woosuk): Optimize this. for n in range(self.max_n, self.min_n - 1, -1): - result = _find_subarray_kmp(context_token_ids, n, self.k) + result = _find_subarray_kmp(context_token_ids, n, k) if result is not None: return result return None diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 218af43deb67..0fd66c072960 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -107,3 +107,7 @@ def grammar_bitmask( # np.ndarray, because that is much more efficient for serialization # and deserialization when sending this to the GPU workers. 
return bitmask_tensor.numpy() + + def clear_backend(self) -> None: + if self.backend is not None: + self.backend.destroy() diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 9150a28570bd..1453e284b013 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -1,14 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 +import copy +import json import os from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional, Union import torch from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, @@ -29,6 +31,29 @@ logger = init_logger(__name__) +def _walk_json_for_additional_properties(data: object): + if isinstance(data, dict): + for value in data.values(): + _walk_json_for_additional_properties(value) + if 'additionalProperties' not in data and \ + ('properties' in data or 'patternProperties' in data): + data['additionalProperties'] = False + elif isinstance(data, list): + for item in data: + _walk_json_for_additional_properties(item) + + +def process_for_additional_properties( + guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]: + if isinstance(guide_json, str): + guide_json_obj = json.loads(guide_json) + else: + # copy for modifications + guide_json_obj = copy.deepcopy(guide_json) + _walk_json_for_additional_properties(guide_json_obj) + return guide_json_obj + + class GuidanceBackend(StructuredOutputBackend): def __init__(self, vllm_config: VllmConfig): @@ -36,14 +61,23 @@ def __init__(self, vllm_config: VllmConfig): tokenizer_group = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, lora_config=vllm_config.lora_config) # type: ignore[arg-type] - tokenizer_group.ping() self.vllm_config = vllm_config self.vocab_size = vllm_config.model_config.get_vocab_size() - self.disable_any_whitespace = ( - "disable-any-whitespace" - in vllm_config.decoding_config.guided_decoding_backend) + + self.disable_any_whitespace = False + self.no_additional_properties = False + backend_options = GuidedDecodingParams( + backend=vllm_config.decoding_config.guided_decoding_backend + ).backend_options() + for option in backend_options: + if option == "disable-any-whitespace": + self.disable_any_whitespace = True + elif option == "no-additional-properties": + self.no_additional_properties = True + else: + raise ValueError( + f"Unsupported option for the guidance backend: {option}") tokenizer = tokenizer_group.get_lora_tokenizer(None) self.ll_tokenizer = llguidance_hf.from_tokenizer( @@ -52,7 +86,8 @@ def __init__(self, vllm_config: VllmConfig): def compile_grammar(self, request_type: StructuredOutputOptions, grammar_spec: str) -> StructuredOutputGrammar: self.serialized_grammar = serialize_guidance_grammar( - request_type, grammar_spec, self.disable_any_whitespace) + request_type, grammar_spec, self.disable_any_whitespace, + self.no_additional_properties) ll_matcher = llguidance.LLMatcher( self.ll_tokenizer, @@ -73,6 +108,9 @@ def allocate_token_bitmask(self, max_num_seqs: int): return 
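Example of the "no-additional-properties" transform added above: objects that declare properties but leave additionalProperties unset get it pinned to False before the schema reaches llguidance, and the input is deep-copied rather than mutated. The import path follows the diff's file location.

from vllm.v1.structured_output.backend_guidance import (
    process_for_additional_properties)

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "address": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
        },
    },
}
processed = process_for_additional_properties(schema)
assert processed["additionalProperties"] is False
assert processed["properties"]["address"]["additionalProperties"] is False
assert "additionalProperties" not in schema   # original dict left untouched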
llguidance_torch.allocate_token_bitmask( max_num_seqs, self.ll_tokenizer.vocab_size) + def destroy(self): + pass + @dataclass class GuidanceGrammar(StructuredOutputGrammar): @@ -129,10 +167,15 @@ def reset(self): self.ll_matcher.reset() -def serialize_guidance_grammar(request_type: StructuredOutputOptions, - grammar_spec: str, - disable_any_whitespace: bool = False) -> str: +def serialize_guidance_grammar( + request_type: StructuredOutputOptions, + grammar_spec: Union[str, dict[str, Any]], + disable_any_whitespace: bool = False, + no_additional_properties: bool = False, +) -> str: if request_type == StructuredOutputOptions.JSON: + if no_additional_properties: + grammar_spec = process_for_additional_properties(grammar_spec) return llguidance.LLMatcher.grammar_from_json_schema( grammar_spec, defaults={ @@ -151,6 +194,9 @@ def serialize_guidance_grammar(request_type: StructuredOutputOptions, tp = "grammar" elif request_type == StructuredOutputOptions.CHOICE: tp = "choice" + elif request_type == StructuredOutputOptions.STRUCTURAL_TAG: + raise ValueError("Structural tag is not supported " + "for guidance backend yet") else: logger.error("Validation should have already occurred. " "Please file an issue.") diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index 6dc2a92411de..6330bcbf20c3 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -12,6 +12,7 @@ class StructuredOutputOptions(enum.Enum): REGEX = enum.auto() GRAMMAR = enum.auto() CHOICE = enum.auto() + STRUCTURAL_TAG = enum.auto() StructuredOutputKey = tuple[StructuredOutputOptions, str] @@ -87,3 +88,9 @@ def allocate_token_bitmask(self, max_num_seqs: int): max_num_seqs (int): The maximum number of sequences for which to allocate the bitmask. """ + + @abstractmethod + def destroy(self): + """ + Backend-specific cleanup. 
+ """ diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 83f2c6436ed2..ecaeb6e4ee80 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -1,19 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 +import json from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch import vllm.envs from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, StructuredOutputGrammar, StructuredOutputOptions) +from vllm.v1.structured_output.utils import (choice_as_grammar, + convert_lark_to_ebnf, + grammar_is_likely_lark) if TYPE_CHECKING: import xgrammar as xgr @@ -27,15 +32,21 @@ class XgrammarBackend(StructuredOutputBackend): def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config - self.disable_any_whitespace = ( - "disable-any-whitespace" - in vllm_config.decoding_config.guided_decoding_backend) tokenizer_group = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, - parallel_config=vllm_config.parallel_config, lora_config=vllm_config.lora_config) # type: ignore[arg-type] - tokenizer_group.ping() + + self.disable_any_whitespace = False + backend_options = GuidedDecodingParams( + backend=vllm_config.decoding_config.guided_decoding_backend + ).backend_options() + for option in backend_options: + if option == "disable-any-whitespace": + self.disable_any_whitespace = True + else: + raise ValueError( + f"Unsupported option for the xgrammar backend: {option}") tokenizer = tokenizer_group.get_lora_tokenizer(None) self.vocab_size = vllm_config.model_config.get_vocab_size() @@ -97,6 +108,16 @@ def compile_grammar(self, request_type: StructuredOutputOptions, ctx = self.compiler.compile_grammar(grammar_spec) elif request_type == StructuredOutputOptions.REGEX: ctx = self.compiler.compile_regex(grammar_spec) + elif request_type == StructuredOutputOptions.STRUCTURAL_TAG: + s_tag = json.loads(grammar_spec) + tags = [ + xgr.StructuralTagItem( + begin=s["begin"], + schema=json.dumps(s["schema"]), + end=s["end"], + ) for s in s_tag["structures"] + ] + ctx = self.compiler.compile_structural_tag(tags, s_tag["triggers"]) else: logger.error( "Validation should have already occurred. Please file an issue." 
@@ -113,6 +134,9 @@ def compile_grammar(self, request_type: StructuredOutputOptions, def allocate_token_bitmask(self, max_num_seqs: int): return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size) + def destroy(self): + del self.compiler + @dataclass class XgrammarGrammar(StructuredOutputGrammar): @@ -156,3 +180,120 @@ def is_terminated(self) -> bool: def reset(self): self.num_processed_tokens = 0 self.matcher.reset() + + +def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: + """Check if JSON schema contains features unsupported by xgrammar.""" + + def check_object(obj: dict[str, Any]) -> bool: + if not isinstance(obj, dict): + return False + + # Check for numeric ranges + if obj.get("type") in ("integer", "number") and ("multipleOf" in obj): + return True + + # Check for array unsupported keywords + if obj.get("type") == "array" and any( + key in obj + for key in ("uniqueItems", "contains", "minContains", + "maxContains", "minItems", "maxItems")): + return True + + # Unsupported keywords for strings + if obj.get("type") == "string" and "format" in obj: + return True + + # Unsupported keywords for objects + if obj.get("type") == "object" and any( + key in obj for key in ("minProperties", "maxProperties", + "propertyNames", "patternProperties")): + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) + + +def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: + """Validate that the request is supported by structured output. + + Raises ValueError if the request is not supported. + """ + if sampling_params.guided_decoding is None: + return + + gd_params = sampling_params.guided_decoding + + if gd_params.regex: + try: + xgr.Grammar.from_regex(gd_params.regex) + except Exception as err: + raise ValueError("Failed to transform regex into a grammar: " + f"{err}") from err + + if gd_params.choice: + choice_grammar = choice_as_grammar(gd_params.choice) + try: + xgr.Grammar.from_ebnf(choice_grammar) + except Exception as err: + raise ValueError("Failed to transform choices into a grammar: " + "{err}") from err + gd_params.choice = None + gd_params.grammar = choice_grammar + return + + if gd_params.json: + if isinstance(gd_params.json, str): + try: + schema = json.loads(gd_params.json) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + else: + schema = gd_params.json + + if has_xgrammar_unsupported_json_features(schema): + raise ValueError("The provided JSON schema contains features not " + "supported by xgrammar.") + return + + if gd_params.grammar: + if grammar_is_likely_lark(gd_params.grammar): + # xgrammar supports EBNF grammars only + try: + gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar) + except ValueError as e: + raise ValueError( + "Failed to convert the grammar from Lark to EBNF. ") from e + + # Test parsing EBNF grammar, possibly already converted from Lark + try: + # parse the grammar, but we aren't compiling it. 
+ xgr.Grammar.from_ebnf(gd_params.grammar) + except Exception as e: + raise ValueError("Invalid grammar specification.") from e + return + + if gd_params.structural_tag: + try: + s_tag = json.loads(gd_params.structural_tag) + tags = [ + xgr.StructuralTagItem( + begin=s["begin"], + schema=json.dumps(s["schema"]), + end=s["end"], + ) for s in s_tag["structures"] + ] + xgr.Grammar.from_structural_tag(tags, s_tag["triggers"]) + except Exception as e: + raise ValueError("Invalid structural tag specification.") from e diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py index 9e54b8bf028d..6ef472eb896c 100644 --- a/vllm/v1/structured_output/request.py +++ b/vllm/v1/structured_output/request.py @@ -78,5 +78,7 @@ def get_structured_output_key( return (StructuredOutputOptions.CHOICE, json_str) elif params.grammar is not None: return (StructuredOutputOptions.GRAMMAR, params.grammar) + elif params.structural_tag is not None: + return (StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag) else: raise ValueError("No valid structured output parameter found") diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index 56eed95944e2..f33f4972e103 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -2,67 +2,7 @@ from __future__ import annotations -import json import re -from typing import TYPE_CHECKING, Any - -from vllm.sampling_params import SamplingParams -from vllm.utils import LazyLoader - -if TYPE_CHECKING: - import xgrammar as xgr -else: - xgr = LazyLoader("xgr", globals(), "xgrammar") - - -def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: - """Check if JSON schema contains features unsupported by xgrammar.""" - - def check_object(obj: dict[str, Any]) -> bool: - if not isinstance(obj, dict): - return False - - # Check for pattern restrictions - if "pattern" in obj: - return True - - # Check for numeric ranges - if obj.get("type") in ("integer", "number") and any( - key in obj - for key in ("minimum", "maximum", "exclusiveMinimum", - "exclusiveMaximum", "multipleOf")): - return True - - # Check for array unsupported keywords - if obj.get("type") == "array" and any( - key in obj - for key in ("uniqueItems", "contains", "minContains", - "maxContains", "minItems", "maxItems")): - return True - - # Unsupported keywords for strings - if obj.get("type") == "string" and "format" in obj: - return True - - # Unsupported keywords for objects - if obj.get("type") == "object" and any( - key in obj for key in ("minProperties", "maxProperties", - "propertyNames", "patternProperties")): - return True - - # Recursively check all nested objects and arrays - for value in obj.values(): - if isinstance(value, dict): - if check_object(value): - return True - elif isinstance(value, list): - for item in value: - if isinstance(item, dict) and check_object(item): - return True - - return False - - return check_object(schema) def grammar_is_likely_lark(grammar_str: str) -> bool: @@ -232,63 +172,3 @@ def escape_ebnf_string(s: str) -> str: escaped_choices = (escape_ebnf_string(c) for c in choice) grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices)) return grammar - - -def validate_structured_output_request_xgrammar( - sampling_params: SamplingParams) -> None: - """Validate that the request is supported by structured output. - - Raises ValueError if the request is not supported. 
- """ - if sampling_params.guided_decoding is None: - return - - gd_params = sampling_params.guided_decoding - - if gd_params.regex: - try: - xgr.Grammar.from_regex(gd_params.regex) - except Exception as err: - raise ValueError("Failed to transform regex into a grammar: " - f"{err}") from err - - if gd_params.choice: - choice_grammar = choice_as_grammar(gd_params.choice) - try: - xgr.Grammar.from_ebnf(choice_grammar) - except Exception as err: - raise ValueError("Failed to transform choices into a grammar: " - "{err}") from err - gd_params.choice = None - gd_params.grammar = choice_grammar - return - - if gd_params.json: - if isinstance(gd_params.json, str): - try: - schema = json.loads(gd_params.json) - except json.JSONDecodeError as e: - raise ValueError("Invalid JSON grammar specification.") from e - else: - schema = gd_params.json - - if has_xgrammar_unsupported_json_features(schema): - raise ValueError("The provided JSON schema contains features not " - "supported by xgrammar.") - return - - if gd_params.grammar: - if grammar_is_likely_lark(gd_params.grammar): - # xgrammar supports EBNF grammars only - try: - gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar) - except ValueError as e: - raise ValueError( - "Failed to convert the grammar from Lark to EBNF. ") from e - - # Test parsing EBNF grammar, possibly already converted from Lark - try: - # parse the grammar, but we aren't compiling it. - xgr.Grammar.from_ebnf(gd_params.grammar) - except Exception as e: - raise ValueError("Invalid grammar specification.") from e diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 32d8101f681d..9c238c3aad8e 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -12,6 +12,8 @@ from vllm.logger import init_logger from vllm.model_executor.models.utils import extract_layer_index +from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, + usage_message) from vllm.utils import get_mp_context, kill_process_tree if TYPE_CHECKING: @@ -134,8 +136,8 @@ def shutdown(proc: Process, input_path: str, output_path: str): proc.terminate() proc.join(5) - if proc.is_alive(): - kill_process_tree(proc.pid) + if proc.is_alive() and (pid := proc.pid) is not None: + kill_process_tree(pid) # Remove zmq ipc socket files. ipc_sockets = [output_path, input_path] @@ -201,3 +203,47 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, Returns the sliced target tensor. 
""" return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True) + + +def report_usage_stats( + vllm_config, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None: + """Report usage statistics if enabled.""" + + if not is_usage_stats_enabled(): + return + + from vllm.model_executor.model_loader import get_architecture_class_name + + usage_message.report_usage( + get_architecture_class_name(vllm_config.model_config), + usage_context, + extra_kvs={ + # Common configuration + "dtype": + str(vllm_config.model_config.dtype), + "tensor_parallel_size": + vllm_config.parallel_config.tensor_parallel_size, + "block_size": + vllm_config.cache_config.block_size, + "gpu_memory_utilization": + vllm_config.cache_config.gpu_memory_utilization, + + # Quantization + "quantization": + vllm_config.model_config.quantization, + "kv_cache_dtype": + str(vllm_config.cache_config.cache_dtype), + + # Feature flags + "enable_lora": + bool(vllm_config.lora_config), + "enable_prompt_adapter": + bool(vllm_config.prompt_adapter_config), + "enable_prefix_caching": + vllm_config.cache_config.enable_prefix_caching, + "enforce_eager": + vllm_config.model_config.enforce_eager, + "disable_custom_all_reduce": + vllm_config.parallel_config.disable_custom_all_reduce, + }) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index a64cb97e0123..c00424dfea73 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -24,7 +24,6 @@ class CachedRequestState: req_id: str prompt_token_ids: list[int] - prompt: Optional[str] mm_inputs: list[MultiModalKwargs] mm_positions: list[PlaceholderRange] sampling_params: SamplingParams diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 70e8bd75ec94..e3d8b94fe9d7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -12,11 +12,13 @@ from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention -from vllm.config import CompilationLevel, VllmConfig +from vllm.config import (CompilationLevel, VllmConfig, + get_layers_from_vllm_config) +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group, graph_capture from vllm.forward_context import set_forward_context from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY @@ -36,6 +38,7 @@ ModelRunnerOutput) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler +from vllm.v1.sample.sampler import Sampler from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer @@ -151,6 +154,9 @@ def __init__( self.max_num_encoder_input_tokens = encoder_compute_budget self.encoder_cache_size = encoder_cache_size + # Sampler + self.sampler = Sampler() + # Lazy initialization # self.model: nn.Module # Set after load_model self.kv_caches: list[torch.Tensor] = [] @@ -159,14 +165,17 @@ def __init__( # Set up speculative decoding. 
self.use_spec_decode = False + self.use_aux_hidden_state_outputs = False if self.speculative_config: self.use_spec_decode = True if get_pp_group().is_last_rank: if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) - elif self.speculative_config.method == "eagle": + elif self.speculative_config.use_eagle(): self.drafter = EagleProposer(self.vllm_config, self.device) # type: ignore + if self.speculative_config.method == "eagle3": + self.use_aux_hidden_state_outputs = True else: raise ValueError("Unknown speculative decoding method: " f"{self.speculative_config.method}") @@ -239,10 +248,11 @@ def __init__( device=self.device) # OPTIMIZATION: Cache the tensors rather than creating them every step. + # Keep in int64 to avoid overflow with long context self.arange_np = np.arange(max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), - dtype=np.int32) + dtype=np.int64) # NOTE(woosuk): These tensors are "stateless", i.e., they are literally # a faster version of creating a new tensor every time. Thus, we should # not make any assumptions about the values in these tensors. @@ -337,7 +347,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, - prompt=new_req_data.prompt, mm_inputs=new_req_data.mm_inputs, mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, @@ -353,6 +362,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: image_grid_thw = [] video_grid_thw = [] second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False for mm_input in self.requests[req_id].mm_inputs: if mm_input.get("image_grid_thw") is not None: image_grid_thw.extend( @@ -363,6 +374,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if mm_input.get("second_per_grid_ts") is not None: second_per_grid_ts.extend( mm_input["second_per_grid_ts"]) + if mm_input.get("audio_feature_lengths") is not None: + audio_feature_lengths.extend( + mm_input["audio_feature_lengths"]) + if mm_input.get("use_audio_in_video") is True: + use_audio_in_video = True hf_config = self.model_config.hf_config @@ -374,6 +390,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, ) req_ids_to_add.append(req_id) @@ -443,7 +461,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. - removed_req_indices = sorted(removed_req_indices, reverse=True) + removed_req_indices.sort(reverse=True) for req_id in req_ids_to_add: req_state = self.requests[req_id] if removed_req_indices: @@ -458,7 +476,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if removed_req_indices: self.input_batch.condense(removed_req_indices) - if batch_changed: + # Some attention backends (namely MLA) may want to separate requests + # based on if the attention computation will be compute-bound or + # memory-bound. This gives them a hook to do that. 
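Back-of-the-envelope for the int32 -> int64 switch in the cached arange above: index math that scales a request index by a per-request stride (such as max_model_len or max_num_blocks_per_req) can leave the int32 range once contexts reach the million-token scale. The numbers below are purely illustrative.

import numpy as np

stride = 1_048_576                  # hypothetical 1M-token max_model_len
req_index = np.int64(2_100)         # hypothetical batch slot
assert req_index * stride > np.iinfo(np.int32).max   # would wrap in int32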
+ batch_reordered = self.attn_metadata_builder.reorder_batch( + self.input_batch, scheduler_output) + + if batch_changed or batch_reordered: self.input_batch.refresh_sampling_metadata() def _prepare_inputs( @@ -471,14 +495,6 @@ def _prepare_inputs( num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - # Some attention backends (namely MLA) may want to separate requests - # based on if the attention computation will be compute-bound or - # memory-bound. This gives them a hook to do that. - modified_batch = self.attn_metadata_builder.reorder_batch( - self.input_batch, scheduler_output) - if modified_batch: - self.input_batch.refresh_sampling_metadata() - # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. self.input_batch.block_table.commit(num_reqs) @@ -540,9 +556,6 @@ def _prepare_inputs( # because M (max_model_len) is not necessarily divisible by block_size. block_table_indices = (req_indices * self.max_num_blocks_per_req + positions_np // self.block_size) - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. block_table_cpu = self.input_batch.block_table.get_cpu_tensor() block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() block_offsets = positions_np % self.block_size @@ -690,7 +703,7 @@ def _compute_cascade_attn_prefix_len( # common_prefix_len should be a multiple of the block size. common_prefix_len = (common_prefix_len // self.block_size * self.block_size) - use_cascade = self.attn_backend.use_cascade_attention( + use_cascade = self.attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, num_query_heads=self.num_query_heads, @@ -992,18 +1005,16 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, torch.Tensor]: + # Update KVConnector with the KVConnector metadata forward(). + if has_kv_transfer_group(): + get_kv_transfer_group().bind_connector_metadata( + scheduler_output.kv_connector_metadata) + self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: # Return empty ModelRunnerOutput if there's no work to do. return EMPTY_MODEL_RUNNER_OUTPUT - if self.is_multimodal_model: - # Run the multimodal encoder if any. - self._execute_mm_encoder(scheduler_output) - mm_embeds = self._gather_mm_embeddings(scheduler_output) - else: - mm_embeds = [] - # Prepare the decoder inputs. attn_metadata, logits_indices, spec_decode_metadata = ( self._prepare_inputs(scheduler_output)) @@ -1016,9 +1027,26 @@ def execute_model( num_scheduled_tokens) else: # Eager mode. - num_input_tokens = num_scheduled_tokens + # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if self.vllm_config.compilation_config.pass_config. \ + enable_sequence_parallelism and tp_size > 1: + from vllm.utils import round_up + num_input_tokens = round_up(num_scheduled_tokens, tp_size) + else: + num_input_tokens = num_scheduled_tokens attn_metadata.num_input_tokens = num_input_tokens + # _prepare_inputs may reorder the batch, so we must gather multi + # modal outputs after that to ensure the correct order + if self.is_multimodal_model: + # Run the multimodal encoder if any. 
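The sequence-parallel padding above in one line: with TP=4 and 13 scheduled tokens, the batch is padded to 16 so it splits evenly across ranks (round_up comes from vllm.utils, as imported in the diff).

from vllm.utils import round_up

assert round_up(13, 4) == 16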
+ self._execute_mm_encoder(scheduler_output) + mm_embeds = self._gather_mm_embeddings(scheduler_output) + else: + mm_embeds = [] + if self.is_multimodal_model: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) @@ -1061,12 +1089,18 @@ def execute_model( # Run the decoder. # Use persistent buffers for CUDA graphs. with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( + output = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = output + else: + hidden_states = output + if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. return hidden_states @@ -1082,7 +1116,7 @@ def execute_model( # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata if spec_decode_metadata is None: - sampler_output = self.model.sample( + sampler_output = self.sampler( logits=logits, sampling_metadata=sampling_metadata, ) @@ -1092,7 +1126,7 @@ def execute_model( # logits tensor. This means any in-place operations on bonus_logits # won't affect the original logits tensor. bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] - sampler_output = self.model.sample( + sampler_output = self.sampler( logits=bonus_logits, sampling_metadata=sampling_metadata, ) @@ -1164,7 +1198,7 @@ def execute_model( assert isinstance(self.drafter, NgramProposer) spec_token_ids = self.generate_draft_token_ids( valid_sampled_token_ids, sampling_metadata) - elif self.speculative_config.method == "eagle": + elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. next_token_ids: list[int] = [] @@ -1192,7 +1226,12 @@ def execute_model( # not include padding. target_token_ids = self.input_ids[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens] - target_hidden_states = hidden_states[:num_scheduled_tokens] + if self.use_aux_hidden_state_outputs: + target_hidden_states = [ + h[:num_scheduled_tokens] for h in aux_hidden_states + ] + else: + target_hidden_states = hidden_states[:num_scheduled_tokens] target_slot_mapping = attn_metadata.slot_mapping cu_num_tokens = attn_metadata.query_start_loc else: @@ -1213,10 +1252,17 @@ def execute_model( ) target_token_ids = self.input_ids[token_indices] target_positions = positions[token_indices] - target_hidden_states = hidden_states[token_indices] + if self.use_aux_hidden_state_outputs: + target_hidden_states = [ + h[token_indices] for h in aux_hidden_states + ] + else: + target_hidden_states = hidden_states[token_indices] target_slot_mapping = attn_metadata.slot_mapping[token_indices] - draft_token_ids, draft_probs = self.drafter.propose( + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat(target_hidden_states, dim=-1) + draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, @@ -1227,9 +1273,10 @@ def execute_model( sampling_metadata=sampling_metadata, ) spec_token_ids = draft_token_ids.tolist() - # TODO(woosuk): Cache draft_probs and use it for rejection sampling - # in the next step. - del draft_probs + + # Clear KVConnector state after all KVs are generated. 
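Sketch of the EAGLE-3 path above: the auxiliary hidden states captured from several target-model layers are sliced per token and concatenated along the hidden dimension before being handed to the drafter. Shapes are illustrative.

import torch

num_tokens, hidden_size = 5, 4
aux_hidden_states = [torch.randn(num_tokens, hidden_size) for _ in range(3)]
token_indices = torch.tensor([0, 2, 4])
target_hidden_states = torch.cat(
    [h[token_indices] for h in aux_hidden_states], dim=-1)
assert target_hidden_states.shape == (3, 3 * hidden_size)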
+ if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() return ModelRunnerOutput( req_ids=self.input_batch.req_ids, @@ -1254,7 +1301,8 @@ def generate_draft_token_ids( draft_token_ids.append([]) continue - # Skip requests that require top-p, top-k, etc. + # Skip requests that require sampling parameters that are not + # supported with speculative decoding. req_id = self.input_batch.req_ids[i] if not is_spec_decode_supported(req_id, self.input_batch): draft_token_ids.append([]) @@ -1263,6 +1311,11 @@ def generate_draft_token_ids( # Add sampled_token_ids to token_ids_cpu. start_idx = self.input_batch.num_tokens_no_spec[i] end_idx = start_idx + num_sampled_ids + if end_idx >= self.max_model_len: + # Skip requests that have already reached the max model length. + draft_token_ids.append([]) + continue + self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids drafter_output = self.drafter.propose( self.input_batch.token_ids_cpu[i, :end_idx]) @@ -1286,6 +1339,9 @@ def load_model(self) -> None: if hasattr(self, "drafter"): logger.info("Loading drafter model...") self.drafter.load_model(self.model) + if self.use_aux_hidden_state_outputs: + self.model.set_aux_hidden_state_layers( + self.model.get_eagle3_aux_hidden_state_layers()) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory logger.info("Model loading took %.4f GiB and %.6f seconds", @@ -1362,8 +1418,8 @@ def _get_prompt_logprobs_dict( tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] # Compute prompt logprobs. - logprobs = self.model.sampler.compute_logprobs(logits) - token_ids, logprobs, ranks = self.model.sampler.gather_logprobs( + logprobs = self.sampler.compute_logprobs(logits) + token_ids, logprobs, ranks = self.sampler.gather_logprobs( logprobs, num_prompt_logprobs, tgt_token_ids) # Transfer GPU->CPU async. @@ -1438,12 +1494,16 @@ def _dummy_run( with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): - hidden_states = model( + outputs = model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + if self.use_aux_hidden_state_outputs: + hidden_states, _ = outputs + else: + hidden_states = outputs logit_indices = np.cumsum(num_scheduled_tokens) - 1 return hidden_states[logit_indices] @@ -1481,8 +1541,8 @@ def _dummy_sampler_run( bad_words_token_ids={}, ) try: - sampler_output = self.model.sample( - logits=logits, sampling_metadata=dummy_metadata) + sampler_output = self.sampler(logits=logits, + sampling_metadata=dummy_metadata) except RuntimeError as e: if 'out of memory' in str(e): raise RuntimeError( @@ -1681,17 +1741,12 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. 
""" - forward_ctx = self.vllm_config.compilation_config.static_forward_context + layers = get_layers_from_vllm_config(self.vllm_config, Attention) block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in forward_ctx.items(): - if isinstance(attn_module, FusedMoE): - continue - - # TODO: Support other attention modules, e.g., sliding window, - # cross-attention - assert isinstance(attn_module, Attention) + for layer_name, attn_module in layers.items(): + # TODO: Support other attention modules, e.g., cross-attention if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 2972e0ffb3ba..68c4e94fcd73 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -9,11 +9,12 @@ import torch.nn as nn import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) +from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -22,6 +23,7 @@ from vllm.utils import GiB_bytes from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase @@ -53,6 +55,9 @@ def __init__( from vllm.utils import init_cached_hf_modules init_cached_hf_modules() + # Buffers saved before sleep + self._sleep_saved_buffers: dict[str, torch.Tensor] = {} + # Torch profiler. Enabled and configured through env vars: # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace if envs.VLLM_TORCH_PROFILER_DIR: @@ -72,6 +77,15 @@ def __init__( def sleep(self, level: int = 1) -> None: free_bytes_before_sleep = torch.cuda.mem_get_info()[0] + + # Save the buffers before level 2 sleep + if level == 2: + model = self.model_runner.model + self._sleep_saved_buffers = { + name: buffer.cpu().clone() + for name, buffer in model.named_buffers() + } + allocator = CuMemAllocator.get_instance() allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) free_bytes_after_sleep, total = torch.cuda.mem_get_info() @@ -87,6 +101,14 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None: allocator = CuMemAllocator.get_instance() allocator.wake_up(tags) + # Restore the buffers after level 2 sleep + if len(self._sleep_saved_buffers): + model = self.model_runner.model + for name, buffer in model.named_buffers(): + if name in self._sleep_saved_buffers: + buffer.data.copy_(self._sleep_saved_buffers[name].data) + self._sleep_saved_buffers = {} + def init_device(self): if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until @@ -110,7 +132,7 @@ def init_device(self): raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. 
- init_worker_distributed_environment(self.parallel_config, self.rank, + init_worker_distributed_environment(self.vllm_config, self.rank, self.distributed_init_method, self.local_rank) # Set random seed. @@ -120,6 +142,10 @@ def init_device(self): self.model_runner: GPUModelRunner = GPUModelRunner( self.vllm_config, self.device) + if self.rank == 0: + # If usage stat is enabled, collect relevant info. + report_usage_stats(self.vllm_config) + # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool # to hijack tensor allocation. def load_model(self) -> None: @@ -285,12 +311,13 @@ def save_sharded_state( def init_worker_distributed_environment( - parallel_config: ParallelConfig, + vllm_config: VllmConfig, rank: int, distributed_init_method: Optional[str] = None, local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" + parallel_config = vllm_config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_distributed_environment(parallel_config.world_size, rank, @@ -299,6 +326,8 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + ensure_kv_transfer_initialized(vllm_config) + def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index a8a19e0e6206..3cbab840e969 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -28,20 +28,16 @@ def load_lora_model(self, model: nn.Module, model_config: ModelConfig, scheduler_config: SchedulerConfig, lora_config: LoRAConfig, device: str) -> nn.Module: - assert supports_lora( - model), f"{model.__class__.__name__} does not support LoRA yet." + if not supports_lora(model): + raise ValueError( + f"{model.__class__.__name__} does not support LoRA yet.") if supports_multimodal(model): logger.warning("Regarding multimodal models, vLLM currently " "only supports adding LoRA to language model.") - # It's necessary to distinguish between the max_position_embeddings - # of VLMs and LLMs. 
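The replacement in this hunk (hasattr fallback to get_text_config()) boils down to the following; hf_config stands for the model's Hugging Face config, whose get_text_config() returns the nested text config for multimodal models and the config itself for plain LLMs:

def max_pos_embeddings_old(hf_config):
    # Old path: manually distinguish VLM configs (nested text_config) from LLMs.
    if hasattr(hf_config, "max_position_embeddings"):
        return hf_config.max_position_embeddings
    return hf_config.text_config.max_position_embeddings

def max_pos_embeddings_new(hf_config):
    # New path: let the config resolve its own text sub-config.
    return hf_config.get_text_config().max_position_embeddings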
- if hasattr(model.config, "max_position_embeddings"): - max_pos_embeddings = model.config.max_position_embeddings - else: - max_pos_embeddings = ( - model.config.text_config.max_position_embeddings) + # Use get_text_config() in case of multimodal models + text_config = model_config.hf_config.get_text_config() # Add LoRA Manager to the Model Runner self.lora_manager = LRUCacheWorkerLoRAManager( @@ -52,7 +48,7 @@ def load_lora_model(self, model: nn.Module, model_config: ModelConfig, device, model.embedding_modules, model.embedding_padding_modules, - max_position_embeddings=max_pos_embeddings, + max_position_embeddings=text_config.max_position_embeddings, ) return self.lora_manager.create_lora_manager(model) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 69251d8bbb31..67f8af29db0e 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import bisect +import gc import time from typing import TYPE_CHECKING, Optional, cast from unittest.mock import patch @@ -16,20 +17,22 @@ from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, + PlaceholderRange) from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, PallasMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheSpec, SlidingWindowSpec) +from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, + KVCacheConfig, KVCacheSpec, + SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata @@ -37,8 +40,7 @@ from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch -from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, - scatter_mm_placeholders) +from .utils import sanity_check_mm_encoder_outputs if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -53,6 +55,41 @@ MIN_NUM_SEQS = 8 +######################################################### +# Ways to avoid recompilation +######################################################### +# +# The model executor has two primary components: +# 1. preparing the model and sampler inputs +# 2. executing the model and sampler. +# The core idea is to avoid any TPU computation during input preparation. For +# better compilation tracking and increased flexibility, the model execution and +# sampler are divided into several distinct components. +# +# Below are the detailed steps: +# +# Step 1 +# It is recommended to avoid TPU operations when preparing the model and sampler +# inputs. 
CPU tensors can be prepared and transferred to the XLA device using +# cpu_tensor.to(xla_device), which only triggers CPU to TPU transfers and avoids +# compilation. +# +# Step 2 +# The TPU execution should be decomposed into subgraphs (4 at the moment): +# 1. the main model +# 2. selecting hidden states for each request +# 3. sampler +# 4. encoder. +# Each subgraph should be decorated in a torch.compile. This is used to make +# sure that we have the same subgraph topology in both dummy_run and +# xecute_model. The results from these subgraphs should either be passed to +# other subgraphs, or transferred from TPU to CPU using xla_tensor.cpu() for +# subsequent processing on the CPU. +# +# Step 3 +# The dummy_run should be comprehensive, ensuring all potential input shapes and +# branch predictions are included as subgraph inputs to facilitate +# pre-compilation. class TPUModelRunner: def __init__( @@ -93,10 +130,16 @@ def __init__( self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.max_num_tokens = scheduler_config.max_num_batched_tokens # InputBatch needs to work with sampling tensors greater than padding # to avoid dynamic shapes. Also, avoid suboptimal alignment. self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS) + self.num_tokens_paddings = _get_token_paddings( + min_token_size=16, + max_token_size=scheduler_config.max_num_batched_tokens, + padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP) + # In case `max_num_tokens < max(num_tokens_paddings)` use the actual + # padded max value to pre-allocate data structures and pre-compile. + self.max_num_tokens = self.num_tokens_paddings[-1] # Model-related. self.num_attn_layers = model_config.get_num_layers_by_block_type( @@ -106,6 +149,7 @@ def __init__( self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() + self.vocab_size = model_config.get_vocab_size() # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY @@ -136,7 +180,7 @@ def __init__( max_num_blocks_per_req=self.max_num_blocks_per_req, device=self.device, pin_memory=self.pin_memory, - vocab_size=model_config.get_vocab_size(), + vocab_size=self.vocab_size, ) # Cached torch/numpy tensor @@ -157,7 +201,7 @@ def __init__( device="cpu") self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.block_table_cpu = torch.zeros( - (self.max_num_tokens, self.max_num_blocks_per_req), + (self.max_num_reqs, self.max_num_blocks_per_req), dtype=self.input_batch.block_table.get_cpu_tensor().dtype, device="cpu") @@ -175,14 +219,56 @@ def __init__( # Range tensor with values [0 .. self.max_num_tokens - 1]. 
# Used to initialize positions / context_lens / seq_lens - self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32) - self.num_tokens_paddings = _get_token_paddings( - min_token_size=16, - max_token_size=self.max_num_tokens, - padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP) + # Keep in int64 to avoid overflow with long context + self.arange_np = np.arange(self.max_num_tokens, dtype=np.int64) self.num_reqs_paddings = _get_req_paddings( min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs) + # tensors for structured decoding + self.grammar_bitmask_cpu = torch.zeros( + (self.max_num_reqs, cdiv(self.vocab_size, 32)), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.require_structured_out_cpu = torch.zeros( + (self.max_num_reqs, 1), + dtype=torch.bool, + device="cpu", + pin_memory=self.pin_memory) + self.structured_decode_arange = torch.arange( + 0, 32, device="cpu", pin_memory=self.pin_memory) + + # Get maximum number of mm items per modality (batch size). + self.max_num_mm_items_by_modality = dict() + if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 + and self.encoder_cache_size > 0): + max_tokens_by_modality_dict = ( + MULTIMODAL_REGISTRY. + get_max_tokens_per_item_by_nonzero_modality(self.model_config)) + for modality, max_tokens in max_tokens_by_modality_dict.items(): + # Check how many items of this modality can be supported by + # the encoder budget. + encoder_budget = min(self.max_num_encoder_input_tokens, + self.encoder_cache_size) + + max_num_mm_items_encoder_budget = cdiv(encoder_budget, + max_tokens) + + # Check how many items of this modality can be supported by + # the decoder budget. + max_mm_items_per_req = self.mm_registry.\ + get_mm_limits_per_prompt(self.model_config)[modality] + + # NOTE: We do not consider max_num_batched_tokens on purpose + # because the multimodal embeddings can be generated in advance + # and chunked prefilled. + max_num_mm_items_decoder_budget = self.max_num_reqs * \ + max_mm_items_per_req + + max_num_mm_items = min(max_num_mm_items_encoder_budget, + max_num_mm_items_decoder_budget) + self.max_num_mm_items_by_modality[modality] = max_num_mm_items + def _update_num_xla_graphs(self, case_str): check_comp = self.check_recompilation and not self.enforce_eager if not check_comp: @@ -270,7 +356,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, - prompt=new_req_data.prompt, mm_inputs=new_req_data.mm_inputs, mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, @@ -344,11 +429,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. """ - forward_ctx = self.vllm_config.compilation_config.static_forward_context + layers = get_layers_from_vllm_config(self.vllm_config, Attention) block_size = self.vllm_config.cache_config.block_size kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in forward_ctx.items(): - assert isinstance(attn_module, Attention) + for layer_name, attn_module in layers.items(): if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( @@ -569,29 +653,36 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): # 2. 
A list or tuple (length: num_items) of tensors, each of shape # (feature_size, hidden_size) in case the feature size is dynamic # depending on the input multimodal items. + xm.mark_step() curr_group_outputs = self.model.get_multimodal_embeddings( **batched_mm_inputs) + xm.mark_step() sanity_check_mm_encoder_outputs( curr_group_outputs, expected_num_items=len(grouped_mm_inputs), ) - for output in curr_group_outputs: - encoder_outputs.append(output) + if isinstance(curr_group_outputs, torch.Tensor): + encoder_outputs.append(curr_group_outputs) + else: + assert isinstance(curr_group_outputs, (list, tuple)) + for output in curr_group_outputs: + encoder_outputs.append(output) # Cache the encoder outputs. + # NOTE (NickLucche) here we diverge from logic in other runners, as we + # assume to only have whole mm items to process. Hence we avoid the + # intrinsic dynamism that `scatter_mm_placeholders` introduces. for (req_id, input_id, pos_info), output in zip( req_ids_pos, encoder_outputs, ): if req_id not in self.encoder_cache: self.encoder_cache[req_id] = {} - - self.encoder_cache[req_id][input_id] = scatter_mm_placeholders( - output, - is_embed=pos_info.is_embed, - ) + assert pos_info.is_embed is None, "Expected all positions to be"\ + " contiguous and embeddings." + self.encoder_cache[req_id][input_id] = output def _gather_mm_embeddings( self, @@ -604,6 +695,10 @@ def _gather_mm_embeddings( req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens mm_positions = req_state.mm_positions + # TODO unroll loop and assume/enforce --disable_chunked_mm_input + # NOTE (NickLucche) here we diverge from logic in other runners, as + # we assume to only have whole mm items to process. Hence we avoid + # the intrinsic dynamism that `gather_mm_placeholders` introduces. for i, pos_info in enumerate(mm_positions): start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -620,25 +715,33 @@ def _gather_mm_embeddings( # in the decoder's KV cache. continue - start_idx = max(num_computed_tokens - start_pos, 0) - end_idx = min( - num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens) - assert start_idx < end_idx assert req_id in self.encoder_cache assert i in self.encoder_cache[req_id] + assert pos_info.is_embed is None, "Expected all positions to"\ + " be contiguous and embeddings." encoder_output = self.encoder_cache[req_id][i] - - if (is_embed := pos_info.is_embed) is not None: - is_embed = is_embed[start_idx:end_idx] - - mm_embeds_item = gather_mm_placeholders( - encoder_output[start_idx:end_idx], - is_embed=is_embed, - ) - mm_embeds.append(mm_embeds_item) + mm_embeds.append(encoder_output) return mm_embeds + def _get_model_inputs(self, input_ids: torch.Tensor, + mm_embeds: list[torch.Tensor]): + if self.is_multimodal_model: + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. + if mm_embeds: + inputs_embeds = self.model.get_input_embeddings( + input_ids, mm_embeds) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + return None, inputs_embeds + else: + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the CUDA graph. 
+ return input_ids, None + @torch.no_grad() def execute_model( self, @@ -657,27 +760,13 @@ def execute_model( mm_embeds = self._gather_mm_embeddings(scheduler_output) else: mm_embeds = [] - + xm.mark_step() # Prepare inputs attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs( scheduler_output) - if self.is_multimodal_model: - # NOTE(woosuk): To unify token ids and soft tokens (vision - # embeddings), we always use embeddings (rather than token ids) - # as input to the multimodal model, even when the input is text. - if mm_embeds: - inputs_embeds = self.model.get_input_embeddings( - self.input_ids, mm_embeds) - else: - inputs_embeds = self.model.get_input_embeddings(self.input_ids) - input_ids = None - else: - # For text-only models, we use token ids as input. - # While it is possible to use embeddings as input just like the - # multimodal models, it is not desirable for performance since - # then the embedding layer is not included in the CUDA graph. - input_ids = self.input_ids - inputs_embeds = None + input_ids, inputs_embeds = self._get_model_inputs( + self.input_ids, mm_embeds) + xm.mark_step() num_reqs = self.input_batch.num_reqs # Run the decoder with set_forward_context(attn_metadata, self.vllm_config): @@ -688,9 +777,16 @@ def execute_model( ) hidden_states = self.select_hidden_states(hidden_states, logits_indices) + logits = self.compute_logits(hidden_states) tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ from_input_batch(self.input_batch, padded_num_reqs, self.device) - selected_token_ids = self.sample_from_hidden(hidden_states, + if scheduler_output.grammar_bitmask is not None: + require_struct_decoding, grammar_bitmask_padded, arange = \ + self.prepare_structured_decoding_input(logits, scheduler_output) + logits = self.structured_decode(require_struct_decoding, + grammar_bitmask_padded, logits, + arange) + selected_token_ids = self.sample_from_logits(logits, tpu_sampling_metadata) # Remove padding on cpu and keep dynamic op outside of xla graph. selected_token_ids = selected_token_ids.cpu()[:num_reqs] @@ -853,16 +949,77 @@ def _dummy_run(self, num_tokens: int) -> None: inputs_embeds=inputs_embeds) self._hidden_states_dtype = out.dtype + def _precompile_mm_encoder(self) -> None: + # Pre-compile MM encoder for all supported data modalities. + hf_config = self.vllm_config.model_config.hf_config + for mode, max_items_by_mode in \ + self.max_num_mm_items_by_modality.items(): + logger.info( + "Compiling Multimodal %s Encoder with different input" + " shapes.", mode) + start = time.perf_counter() + # No padding for MM encoder just yet. + for num_items in range(1, max_items_by_mode + 1): + logger.info(" -- mode: %s items: %d", mode, num_items) + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + mode, num_items) + # Run multimodal encoder. + xm.mark_step() + mm_embeds = self.model.\ + get_multimodal_embeddings(**batched_dummy_mm_inputs) + xm.mark_step() + num_patches = mm_embeds[0].shape[0] + items_size = num_patches * num_items + + # NOTE (NickLucche) pre-compile `get_input_embeddings` when mm + # embeddings are present. We assume `--disable-mm-chunked`, + # hence only whole items can be scheduled. This implies we just + # need to compile when `num_items` fit the (padded) `input_ids` + for num_tokens in self.num_tokens_paddings: + if num_tokens >= items_size: + # XLA Workaround: if torch.zeros(..device) is used, XLA + # compiles a scalar+expansion op, which won't match + # the graph generated at runtime. 
CPU->TPU must be used + placeholders_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device="cpu") + # Align placeholders and actual num mm_embeddings. + placeholders_ids[:items_size] = \ + hf_config.image_token_index + + placeholders_ids = placeholders_ids.to(self.device) + # Assign outputs or the graph will be cut short. + a, b = self._get_model_inputs(placeholders_ids, + [mm_embeds]) + assert a is None + xm.mark_step() + + # Pre-compile `get_input_embeddings` when mm_embeddings are not + # present. Chunk is only made of text, no mm_placeholders. + for num_tokens in self.num_tokens_paddings: + placeholders_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device="cpu") + placeholders_ids = placeholders_ids.to(self.device) + a, b = self._get_model_inputs(placeholders_ids, []) + assert a is None + xm.mark_step() + + xm.wait_device_ops() + end = time.perf_counter() + logger.info( + "Multimodal %s Encoder compilation finished in in %.2f " + "[secs].", mode, end - start) + def _precompile_backbone(self) -> None: logger.info("Compiling the model with different input shapes.") - start = time.perf_counter() for num_tokens in self.num_tokens_paddings: logger.info(" -- num_tokens: %d", num_tokens) self._dummy_run(num_tokens) xm.wait_device_ops() end = time.perf_counter() - logger.info("Compilation finished in in %.2f [secs].", end - start) + logger.info("Compilation finished in %.2f [secs].", end - start) self._update_num_xla_graphs("model backbone") def _precompile_select_hidden_states(self) -> None: @@ -883,22 +1040,67 @@ def _precompile_select_hidden_states(self) -> None: device=self.device) torch._dynamo.mark_dynamic(indices, 0) self.select_hidden_states(dummy_hidden, indices) - logger.info(" -- num_tokens: %d", num_tokens) + logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, + num_reqs) + # Requests can't be more than tokens. But do compile for the + # next bigger value in case num_tokens uses bucketed padding. + if num_reqs >= min(num_tokens, self.max_num_reqs): + break xm.wait_device_ops() end = time.perf_counter() - logger.info("Compilation finished in in %.2f [secs].", end - start) + logger.info("Compilation finished in %.2f [secs].", end - start) self._update_num_xla_graphs("select_hidden_states") - def _precompile_sample_from_hidden(self) -> None: - logger.info("Compiling sampling with different input shapes.") + def _precompile_compute_logits(self) -> None: + logger.info("Compiling compute_logits with different input shapes.") start = time.perf_counter() hsize = self.model_config.get_hidden_size() for num_reqs in self.num_reqs_paddings: dummy_hidden = torch.zeros((num_reqs, hsize), device=self.device, dtype=self._hidden_states_dtype) - # The first dimension of dummy_hidden cannot be mark_dynamic because - # some operations in the sampler require it to be static. 
+ torch._dynamo.mark_dynamic(dummy_hidden, 0) + self.compute_logits(dummy_hidden) + logger.info(" -- num_seqs: %d", num_reqs) + xm.wait_device_ops() + end = time.perf_counter() + logger.info("Compilation finished in %.2f [secs].", end - start) + self._update_num_xla_graphs("compute_logits") + + def _precompile_structured_decoding(self) -> None: + logger.info( + "Compiling structured_decoding with different input shapes.") + start = time.perf_counter() + for num_reqs in self.num_reqs_paddings: + dummy_logits = torch.zeros((num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype) + dummy_require_struct_decoding = \ + self.require_structured_out_cpu[:num_reqs].to(self.device) + dummy_grammar_bitmask = \ + self.grammar_bitmask_cpu[:num_reqs].to(self.device) + # The first dimension of the above 3 dummy tensors cannot be + # mark_dynamic because some operations in structured_decode require + # them to be static. + arange = self.structured_decode_arange.to(self.device) + self.structured_decode(dummy_require_struct_decoding, + dummy_grammar_bitmask, dummy_logits, arange) + logger.info(" -- num_seqs: %d", num_reqs) + xm.wait_device_ops() + end = time.perf_counter() + logger.info("Compilation finished in %.2f [secs].", end - start) + self._update_num_xla_graphs("structured_decoding") + + def _precompile_sample_from_logits(self) -> None: + logger.info( + "Compiling sample_from_logits with different input shapes.") + start = time.perf_counter() + for num_reqs in self.num_reqs_paddings: + dummy_logits = torch.zeros((num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype) + # The first dimension of dummy_logits cannot be mark_dynamic + # because some operations in the sampler require it to be static. for all_greedy in [False, True]: generate_params_if_all_greedy = not all_greedy sampling_metadata = ( @@ -909,21 +1111,82 @@ def _precompile_sample_from_hidden(self) -> None: generate_params_if_all_greedy, )) sampling_metadata.all_greedy = all_greedy - self.sample_from_hidden(dummy_hidden, sampling_metadata) + self.sample_from_logits(dummy_logits, sampling_metadata) logger.info(" -- num_seqs: %d", num_reqs) xm.wait_device_ops() end = time.perf_counter() - logger.info("Compilation finished in in %.2f [secs].", end - start) - self._update_num_xla_graphs("sampling") + logger.info("Compilation finished in %.2f [secs].", end - start) + self._update_num_xla_graphs("sample_from_logits") def capture_model(self) -> None: """ Precompile all the subgraphs with possible input shapes. """ - # TODO: precompile encoder + self._precompile_mm_encoder() self._precompile_backbone() self._precompile_select_hidden_states() - self._precompile_sample_from_hidden() + self._precompile_compute_logits() + self._precompile_structured_decoding() + self._precompile_sample_from_logits() + + def profile_run( + self, + num_tokens: int, + ) -> None: + # Profile with multimodal encoder & encoder cache. + # TODO: handle encoder-decoder models once we support them. + if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 + and self.encoder_cache_size > 0): + + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. 
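The per-modality item counts consulted here were computed in __init__ as the minimum of an encoder budget and a decoder budget. A worked example with illustrative numbers (576 encoder tokens per image, 16 max requests, 1 image per request):

def cdiv(a: int, b: int) -> int:
    return -(a // -b)   # ceiling division, as in vllm.utils

max_num_encoder_input_tokens, encoder_cache_size = 8192, 4096
max_tokens_per_item = 576
max_num_reqs, max_mm_items_per_req = 16, 1

encoder_budget = min(max_num_encoder_input_tokens, encoder_cache_size)      # 4096
max_items_encoder_budget = cdiv(encoder_budget, max_tokens_per_item)        # 8
max_items_decoder_budget = max_num_reqs * max_mm_items_per_req              # 16
max_num_mm_items = min(max_items_encoder_budget, max_items_decoder_budget)  # 8
assert max_num_mm_items == 8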
+ dummy_data_modality, max_num_mm_items = max( + self.max_num_mm_items_by_modality.items(), key=lambda t: t[1]) + + encoder_budget = min(self.max_num_encoder_input_tokens, + self.encoder_cache_size) + + logger.info( + "Encoder cache will be initialized with a budget of %d tokens," + " and profiled with %s %s items of the maximum feature size.", + encoder_budget, max_num_mm_items, dummy_data_modality) + + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_data_modality, max_num_mm_items) + + # Run multimodal encoder. + # Isolate encoder graph from post-processing to minimize + # impact of recompilation until it's fixed. + start = time.perf_counter() + xm.mark_step() + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + xm.mark_step() + xm.wait_device_ops() + end = time.perf_counter() + logger.info( + "Multimodal Encoder profiling finished in in %.2f [secs].", + end - start) + + assert len(dummy_encoder_outputs) == max_num_mm_items, ( + "Expected dimension 0 of encoder outputs to match the number " + f"of multimodal data items: {max_num_mm_items}, got " + f"{len(dummy_encoder_outputs)=} instead. This is most likely " + "due to the 'get_multimodal_embeddings' method of the model " + "not implemented correctly.") + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) + + # Trigger compilation for general shape. + self._dummy_run(num_tokens) + + xm.mark_step() + xm.wait_device_ops() + self.encoder_cache.clear() + gc.collect() def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -945,7 +1208,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: tensor_config = kv_cache_config.tensors[layer_name] assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes - if isinstance(kv_cache_spec, FullAttentionSpec): + if isinstance(kv_cache_spec, AttentionSpec): kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) @@ -980,16 +1243,14 @@ def select_hidden_states(self, hidden_states, indices_do_sample): return hidden_states[indices_do_sample] @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def sample_from_hidden( - self, - sample_hidden_states: torch.Tensor, - sampling_metadata: TPUSupportedSamplingMetadata, - ) -> torch.Tensor: - """ - Sample with xla-friendly function. This function is to be traced - separately from `forward` for lighter compilation overhead. 
- """ - logits = self.model.compute_logits(sample_hidden_states, None) + def compute_logits(self, + sample_hidden_states: torch.Tensor) -> torch.Tensor: + return self.model.compute_logits(sample_hidden_states, None) + + @torch.compile(backend="openxla", fullgraph=True, dynamic=False) + def sample_from_logits( + self, logits: torch.Tensor, + sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor: if sampling_metadata.all_greedy: out_tokens = torch.argmax(logits, dim=-1, keepdim=True) else: @@ -997,12 +1258,101 @@ def sample_from_hidden( sampling_metadata).sampled_token_ids return out_tokens + @torch.compile(backend="openxla", fullgraph=True, dynamic=False) + def structured_decode(self, require_struct_decoding: torch.Tensor, + grammar_bitmask: torch.Tensor, logits: torch.Tensor, + arange: torch.Tensor) -> torch.Tensor: + return torch.where( + require_struct_decoding, + self.apply_grammar_bitmask(logits, grammar_bitmask, arange), + logits) + + def apply_grammar_bitmask(self, logits: torch.Tensor, + grammar_bitmask: torch.Tensor, + arange: torch.Tensor): + assert (logits.shape[0] == grammar_bitmask.shape[0]) + logits_cloned = logits.clone() + for i in range(logits.shape[0]): + unpacked_bitmask = (torch.bitwise_right_shift( + grammar_bitmask[i][:, None], arange[None, :]) & 1) == 0 + unpacked_bitmask = unpacked_bitmask.reshape(-1)[:self.vocab_size] + logits_cloned[i] = logits_cloned[i].masked_fill( + unpacked_bitmask, -float("inf")) + return logits_cloned + def get_multimodal_embeddings(self, *args, **kwargs): return self.model.get_multimodal_embeddings(*args, **kwargs) def get_input_embeddings(self, *args, **kwargs): return self.model.get_input_embeddings(*args, **kwargs) + def prepare_structured_decoding_input( + self, logits: torch.Tensor, scheduler_output: "SchedulerOutput" + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + grammar_bitmask = scheduler_output.grammar_bitmask + assert grammar_bitmask is not None + num_reqs, _ = logits.shape + + # Reset pre-allocated tensors + self.grammar_bitmask_cpu.zero_() + self.require_structured_out_cpu.zero_() + + # We receive the structured output bitmask from the scheduler, but the + # indices of the requests in the batch may not match the indices of + # the bitmask since the scheduler doesn't know how the tpu runner is + # ordering the requests in the batch. We need to match the order of + # bitmask with the order of requests + struct_out_indices: list[int] = [] + mask_indices: list[int] = [] + for req_id in self.input_batch.req_ids: + mask_index = scheduler_output.structured_output_request_ids.get( + req_id) + if mask_index is None: + continue + batch_index = self.input_batch.req_id_to_index[req_id] + struct_out_indices.append(batch_index) + mask_indices.append(mask_index) + self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy( + grammar_bitmask[mask_indices]) + # It's not guaranteed that all requests in this batch require + # structured output, so create a bool tensor to represent + # the requests that need structured output. + struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long) + self.require_structured_out_cpu[struct_out_indices] = True + return self.require_structured_out_cpu[:num_reqs].to(logits.device), \ + self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \ + self.structured_decode_arange.to(logits.device) + + def _get_mm_dummy_batch(self, modality: str, + batch_size: int) -> BatchedTensorInputs: + # Dummy data for pre-compiling multimodal models. 
+ dummy_request_data = self.mm_registry.get_decoder_dummy_data( + model_config=self.model_config, + seq_len=self.max_num_tokens, + ) + dummy_mm_data = dummy_request_data.multi_modal_data + + # Dummy data definition in V0 may contain multiple multimodal items + # (e.g, multiple images) for a single request, therefore here we + # always replicate first item by max_num_mm_items times since in V1 + # they are scheduled to be processed separately. + assert isinstance(dummy_mm_data, MultiModalKwargs), ( + "Expected dummy multimodal data to be of type " + f"MultiModalKwargs, got {type(dummy_mm_data)=} instead. " + "This is most likely due to the model not having a merged " + "processor.") + + # When models have a merged processor, their dummy data is + # already batched `MultiModalKwargs`, therefore we take the first + # `MultiModalKwargsItem` from the desired modality to profile on. + dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) + dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item]) + + batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * + batch_size) + return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs, + device=self.device) + def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: logger.info("Preparing request paddings:") @@ -1040,11 +1390,12 @@ def _get_token_paddings(min_token_size: int, max_token_size: int, if padding_gap == 0: logger.info("Using exponential token paddings:") - while num <= max_token_size: + while True: logger.info(" %d", num) paddings.append(num) + if num >= max_token_size: + break num *= 2 - else: logger.info("Using incremental token paddings:") while num <= padding_gap: diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 73c43969b87b..de676541effa 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -21,7 +21,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.utils import bind_kv_cache +from vllm.v1.utils import bind_kv_cache, report_usage_stats from vllm.v1.worker.tpu_model_runner import TPUModelRunner logger = init_logger(__name__) @@ -133,6 +133,10 @@ def init_device(self): # Init ModelRunner here, so that we have access to self.device. self.model_runner = TPUModelRunner(self.vllm_config, self.device) + if rank == 0: + # If usage stat is enabled, collect relevant info. + report_usage_stats(self.vllm_config) + def determine_available_memory(self) -> int: kv_caches: dict[str, torch.Tensor] = {} kv_cache_spec = self.model_runner.get_kv_cache_spec() @@ -156,8 +160,8 @@ def determine_available_memory(self) -> int: self.vllm_config.compilation_config.static_forward_context, runner_kv_caches) - self.model_runner._dummy_run( - self.scheduler_config.max_num_batched_tokens) + # `max_num_tokens >= max_num_batched_tokens` due to padding. + self.model_runner.profile_run(self.model_runner.max_num_tokens) # Synchronize before measuring the memory usage. 
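The padding_gap == 0 branch of _get_token_paddings above now keeps doubling until the bucket covers max_token_size, so the last bucket can exceed max_num_batched_tokens; that is why the runner takes max_num_tokens from num_tokens_paddings[-1]. A condensed restatement:

def exponential_token_paddings(min_token_size: int, max_token_size: int) -> list[int]:
    # Double from the minimum bucket until max_token_size is covered.
    paddings, num = [], min_token_size
    while True:
        paddings.append(num)
        if num >= max_token_size:
            break
        num *= 2
    return paddings

# With a 100-token budget the buckets are [16, 32, 64, 128]; the final bucket
# (128) is what max_num_tokens is derived from.
assert exponential_token_paddings(16, 100) == [16, 32, 64, 128]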
xm.wait_device_ops() diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 85ebe8121e52..d48a6957c5dd 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -71,19 +71,32 @@ def _allocate_kv_cache( device: str, ) -> List[torch.Tensor]: """Allocates KV cache on the specified device.""" - kv_cache_shape = self.attn_backend.get_kv_cache_shape( + kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size) pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] + try: + kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( + ) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape))) + + # The allocation respects the backend-defined stride order to ensure + # the semantic remains consistent for each backend. We first obtain the + # generic kv cache shape and then permute it according to the stride + # order which could result in a non-contiguous tensor. + kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i] + for i in kv_cache_stride_order) for _ in range(self.num_attention_layers): # null block in CpuGpuBlockAllocator requires at least that # block to be zeroed-out. # We zero-out everything for simplicity. - layer_kv_cache = torch.zeros(kv_cache_shape, - dtype=self.dtype, - pin_memory=pin_memory, - device=device) + layer_kv_cache = torch.zeros( + kv_cache_allocation_shape, + dtype=self.dtype, + pin_memory=pin_memory, + device=device).permute(*kv_cache_stride_order) # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases # when entry_shape is higher than 1D diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py index ac7c93e48395..c2120c035175 100644 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -316,7 +316,7 @@ def execute_model( return [] # Sample the next token. - output = self.model.sample( + output = self.sampler( logits=logits, sampling_metadata=model_input.sampling_metadata, ) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 9f4b18869bdf..710ca1a13b0c 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -19,11 +19,11 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_lora, supports_multimodal -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalKwargs, MultiModalPlaceholderMap) +from vllm.multimodal import (BatchedTensorInputs, MultiModalKwargs, + MultiModalPlaceholderMap) from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) from vllm.worker.model_runner_base import ( @@ -154,7 +154,6 @@ def __init__(self, self.sliding_window = self.runner.sliding_window self.block_size = self.runner.block_size self.device = self.runner.device - self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper self.enable_lora = self.runner.lora_config is not None if self.runner.attn_backend is not None: # spec decode (e.g. 
Medusa) does not have atten backend @@ -359,22 +358,14 @@ def _compute_multi_modal_input(self, computed_len = seq_data.get_num_computed_tokens() seq_len = self.input_data.seq_lens[-1] - # NOTE: mm_data only includes the subset of multi-modal items that + # NOTE: mm_kwargs only includes the subset of multi-modal items that # intersect with the current prefill positions. - mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( + mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( seq_group_metadata, range(computed_len, seq_len)) - if not mm_data: + if not mm_kwargs: return - if self.runner.mm_registry.has_processor(self.runner.model_config): - mm_kwargs = mm_data - else: - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - seq_group_metadata.mm_processor_kwargs, - ) - # special processing for mrope position deltas. if self.runner.model_config.uses_mrope: assert not self.chunked_prefill, \ @@ -382,11 +373,17 @@ def _compute_multi_modal_input(self, image_grid_thw = mm_kwargs.get("image_grid_thw", None) video_grid_thw = mm_kwargs.get("video_grid_thw", None) - assert image_grid_thw is not None or video_grid_thw is not None, ( - "mrope embedding type requires multi-modal input mapper " - "returns 'image_grid_thw' or 'video_grid_thw'.") + audio_feature_lengths = mm_kwargs.get("audio_feature_lengths", + None) + assert ( + image_grid_thw is not None or video_grid_thw is not None + or audio_feature_lengths is not None), ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw' or " + "'audio_feature_lengths'.") second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) + use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) hf_config = self.runner.model_config.hf_config token_ids = seq_data.get_token_ids() @@ -398,6 +395,8 @@ def _compute_multi_modal_input(self, video_grid_thw=video_grid_thw, second_per_grid_ts=second_per_grid_ts, context_len=computed_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, ) seq_data.mrope_position_delta = mrope_position_delta @@ -472,16 +471,11 @@ def __init__( use_mla=self.model_config.use_mla, ) if needs_attn_backend else None - # Multi-modal data support - self.mm_registry = MULTIMODAL_REGISTRY - self.multi_modal_input_mapper = self.mm_registry \ - .create_input_mapper(self.model_config) - self.mm_registry.init_mm_limits_per_prompt(self.model_config) - # Lazy initialization. self.model: nn.Module # Set after init_Model # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + self.sampler = get_sampler() if hasattr(self, "_builder_cls"): # multi-step model runner does not have `_builder_cls` @@ -499,13 +493,8 @@ def load_model(self) -> None: logger.warning("Regarding multimodal models, vLLM currently " "only supports adding LoRA to language model.") - # It's necessary to distinguish between the max_position_embeddings - # of VLMs and LLMs. 
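The _allocate_kv_cache change in cache_engine.py above materializes the buffer in the backend's preferred stride order and then permutes the view back to the generic logical shape, which can leave the tensor non-contiguous. A minimal sketch with made-up shapes and a hypothetical stride order that swaps two dimensions (a swap is its own inverse, so permuting by the same order restores the logical layout):

import torch

generic_shape = (2, 4, 16, 8, 64)   # illustrative generic KV-cache shape
stride_order = (0, 1, 3, 2, 4)      # hypothetical backend order: swap dims 2 and 3
alloc_shape = tuple(generic_shape[i] for i in stride_order)

kv = torch.zeros(alloc_shape).permute(*stride_order)
assert tuple(kv.shape) == generic_shape   # logical shape is the generic one
assert not kv.is_contiguous()             # memory layout follows the backend order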
- if hasattr(self.model.config, "max_position_embeddings"): - max_pos_embeddings = self.model.config.max_position_embeddings - else: - max_pos_embeddings = ( - self.model.config.text_config.max_position_embeddings) + # Use get_text_config() in case of multimodal models + text_config = self.model_config.hf_config.get_text_config() self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, @@ -515,7 +504,7 @@ def load_model(self) -> None: self.device, self.model.embedding_modules, self.model.embedding_padding_modules, - max_position_embeddings=max_pos_embeddings, + max_position_embeddings=text_config.max_position_embeddings, ) self.model = self.lora_manager.create_lora_manager(self.model) @@ -537,11 +526,6 @@ def _prepare_model_input_tensors( return self.builder.build() # type: ignore - # sampler property will be used by spec_decode_worker - @property - def sampler(self): - return self.model.sampler - @property def vocab_size(self) -> int: return self.model_config.get_vocab_size() @@ -669,7 +653,7 @@ def execute_model( return [] # Sample the next token. - output = self.model.sample( + output = self.sampler( logits=logits, sampling_metadata=model_input.sampling_metadata, ) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 72ff9d66a689..4df192a8727c 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -100,6 +100,8 @@ def __init__( vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, + input_registry=input_registry, + mm_registry=mm_registry, ) # Crash for unsupported encoder/scenarios @@ -205,7 +207,7 @@ def execute_model( model_input.async_callback() # Sample the next token. - output: SamplerOutput = self.model.sample( + output: SamplerOutput = self.sampler( logits=logits, sampling_metadata=model_input.sampling_metadata, ) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 7a346b34cef5..e25864349e28 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -11,11 +11,9 @@ import gc import itertools import math -import operator import os import time from array import array -from dataclasses import dataclass, field from enum import IntEnum from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union) @@ -24,8 +22,9 @@ import habana_frameworks.torch.internal.bridge_config as bc import torch import torch.nn as nn +import vllm_hpu_extension.environment as environment +from vllm_hpu_extension.bucketing.common import get_bucketing_context from vllm_hpu_extension.ops import LoraMask as LoraMask -from vllm_hpu_extension.ops import batch2block, block2batch from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, HabanaMemoryProfiler, format_bytes) @@ -41,13 +40,12 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader import get_model from vllm.model_executor.sampling_metadata import SequenceGroupToSample -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalKwargs) +from vllm.multimodal import BatchedTensorInputs, 
MultiModalKwargs from vllm.sampling_params import SamplingParams from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceData, SequenceGroupMetadata, @@ -74,24 +72,7 @@ LORA_WARMUP_RANK = 8 - -class Singleton(type): - _instances: Dict[type, object] = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super().__call__(*args, **kwargs) - return cls._instances[cls] - - -@dataclass -class HPUBucketingGlobalState(metaclass=Singleton): - prompt_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) - decode_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) - prompt_seq_bucket_cfg: Tuple[int, int, int] = field(init=False) - decode_block_bucket_cfg: Tuple[int, int, int] = field(init=False) - prompt_buckets: List[Tuple[int, int]] = field(init=False) - decode_buckets: List[Tuple[int, int]] = field(init=False) +DUMMY_TOKEN_ID = -1 def subtuple(obj: object, @@ -113,134 +94,10 @@ def subtuple(obj: object, return _TYPE_CACHE[typename](**values) -def read_bucket_settings(phase: str, dim: str, **defaults): - """Read bucketing configuration from env variables. - - phase is either 'prompt' or 'decode' - dim is either 'bs', 'seq' or 'block' - param is either 'min', 'step' or 'max' - example env variable: VLLM_DECODE_BS_BUCKET_STEP=128 - """ - params = ['min', 'step', 'max'] - env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params] - default_values = [defaults[p] for p in params] - values = [ - int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values) - ] - for e, v, d in zip(env_vars, values, default_values): - logger.info('%s=%s (default:%s)', e, v, d) - return values - - -def warmup_range(config: Tuple[int, int, int]): - """Generate a warmup range. - - Start from bmin and multiply by 2 until you reach bstep. - Then, increase the values in the range by the value of bstep until you - reach bmax. - - Example: - bmin = 2, bstep = 32, bmax = 64 - => ramp_up = (2, 4, 8, 16) - => stable = (32, 64) - => return ramp_up + stable => (2, 4, 8, 16, 32, 64) - """ - bmin, bstep, bmax = config - assert bmin <= bmax, ("Min. batch size cannot be greater than max. " - "batch size. 
If you want to skip warmup, " - "set VLLM_SKIP_WARMUP=true") - base = itertools.repeat(2) - ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin) - ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \ - ramp_up_acc) - stable = range(bstep, bmax + 1, bstep) - buckets = list(ramp_up_tw) + list(stable) - return list(filter(lambda bucket: bucket >= bmin, buckets)) - - -def generate_prompt_buckets(bs_bucket_config, - seq_bucket_config, - max_num_batched_tokens=None): - buckets = list( - itertools.product(warmup_range(bs_bucket_config), - warmup_range(seq_bucket_config))) - if len(buckets) == 0: - msg = ("No buckets could be captured with following config " - f"(min, step, max_warmup): " - f"bs:{bs_bucket_config}, " - f"seq:{seq_bucket_config}") - raise ValueError(msg) - - filtered_buckets = buckets - if max_num_batched_tokens is not None: - # Remove buckets exceeding batch token budget - filtered_buckets = list( - filter( - lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens, - buckets)) - - if len(filtered_buckets) == 0: - # we can handle this if we ignore max_num_batched_tokens - min_bucket_bs, min_bucket_seq = min(buckets, - key=lambda b: (b[0] * b[1])) - min_reqd_budget = min_bucket_bs * min_bucket_seq - msg = ( - "The current bucketing configuration " - f"(min, step, max_warmup): " - f"bs:{bs_bucket_config}, " - f"seq:{seq_bucket_config} cannot be used with specified " - f"max_num_batched_tokens ({max_num_batched_tokens}), as the " - f"smallest bucket ({min_reqd_budget}) would exceed token " - "budget. Please increase max_num_batched_tokens or decrease " - "bucket minimum Ignoring max_num_batched_tokens at risk of " - "out-of-memory errors.") - logger.error(msg) - return list( - sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))), [] - - captured_buckets = list( - sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) - omitted_buckets = list( - sorted([x for x in buckets if x not in filtered_buckets])) - return captured_buckets, omitted_buckets - - -def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, - max_blocks): - buckets = [] - bs_buckets = warmup_range(bs_bucket_config) - block_buckets = warmup_range(blocks_bucket_config) - bmin, bstep, bmax = blocks_bucket_config - last_bucket = round_up(max_blocks, bstep) - for bs in bs_buckets: - for blocks in block_buckets: - if blocks < bs: - continue - if blocks > last_bucket: - break - buckets.append((bs, blocks)) - return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) - - -def next_pow2(value: int, base: int): - res = base - while value > 1: - value = (value + 1) // 2 - res *= 2 - return res - - def round_up(value: int, k: int): return (value + k - 1) // k * k -def find_bucket(value: int, config: Tuple[int, int, int]): - bmin, bstep, _ = config - next_step = round_up(value, bstep) - next_pow = next_pow2(value, bmin) - return max(bmin, min(next_step, next_pow)) - - def align_workers(value, op): group = get_world_group().cpu_group world_size = torch.distributed.get_world_size() @@ -314,6 +171,7 @@ class HpuModelAdapter: def __init__(self, model, vllm_config): self.model = model + self.sampler = get_sampler() self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', '0').lower() in ['1', 'true'] self.vllm_config = vllm_config @@ -403,16 +261,6 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype): attn_bias=attn_bias) return metadata - def _set_block_scales(self, metadata, device): - block_mapping = metadata.block_mapping - 
ones = torch.ones((block_mapping.size(0), ), - device=device, - dtype=block_mapping.dtype) - sums = batch2block(block2batch(ones, block_mapping), block_mapping) - block_scales = torch.reciprocal(torch.maximum(ones, sums)) - metadata = metadata._replace(block_scales=block_scales) - return metadata - def _update_metadata(self, attn_metadata, batch_size, seq_len, device, dtype): if attn_metadata.is_prompt: @@ -423,7 +271,6 @@ def _update_metadata(self, attn_metadata, batch_size, seq_len, device, meta = attn_metadata attn_metadata = self._set_block_mapping(meta, batch_size, device, dtype) - attn_metadata = self._set_block_scales(attn_metadata, device) return attn_metadata def forward(self, *args, **kwargs): @@ -452,7 +299,7 @@ def compute_logits(self, *args, **kwargs): return self.model.compute_logits(*args, **kwargs) def sample(self, *args, **kwargs): - return self.model.sample(*args, **kwargs) + return self.sampler(*args, **kwargs) class PreparePromptMetadata(NamedTuple): @@ -622,6 +469,7 @@ def __init__( return_hidden_states: bool = False, ): ModelRunnerBase.__init__(self, vllm_config=vllm_config) + environment.set_model_config(self.model_config) self.is_driver_worker = is_driver_worker self.return_hidden_states = return_hidden_states @@ -661,13 +509,21 @@ def __init__( self.profiler_counter_helper = HabanaProfilerCounterHelper() self.seen_configs: set = set() self._mem_margin: Optional[int] = None - self.bucketing_global_state = HPUBucketingGlobalState() - self._setup_buckets() + HPUBucketingContext = get_bucketing_context() + self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs, + self.max_num_prefill_seqs, + self.block_size, + self.max_num_batched_tokens, + False, self.max_model_len) + self.graphed_buckets: Set[Any] = set() self._set_gc_threshold() self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH # For multi-step scheduling self.cached_step_outputs: List[torch.Tensor] = [] + # For delayed sampling + self.cached_step_inputs: List[ + ModelInputForHPUWithSamplingMetadata] = [] def _set_gc_threshold(self) -> None: # Read https://docs.python.org/3/library/gc.html#gc.set_threshold @@ -688,10 +544,6 @@ def _set_gc_threshold(self) -> None: ] gc.set_threshold(*requested_gc_thrs) - # Multi-modal data support - self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ - .create_input_mapper(self.model_config) - self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP', 'false').lower() == 'true' @@ -718,14 +570,9 @@ def load_model(self) -> None: "Bias support in LoRA is not enabled in HPU yet." assert not self.lora_config.fully_sharded_loras, \ "Fully sharded LoRAs is not enabled in HPU yet." - # It's necessary to distinguish between the - # max_position_embeddings of VLMs and LLMs. - if hasattr(self.model.config, "max_position_embeddings"): - max_pos_embeddings = ( - self.model.config.max_position_embeddings) - else: - max_pos_embeddings = ( - self.model.config.text_config.max_position_embeddings) + + # Use get_text_config() in case of multimodal models + text_config = self.model_config.hf_config.get_text_config() self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, @@ -735,7 +582,8 @@ def load_model(self) -> None: self.device, self.model.embedding_modules, self.model.embedding_padding_modules, - max_position_embeddings=max_pos_embeddings, + max_position_embeddings=text_config. 
+ max_position_embeddings, ) self.model = self.lora_manager.create_lora_manager(self.model) @@ -771,6 +619,27 @@ def load_model(self) -> None: msg = f"Loading model weights took in total {m.get_summary_string()}" logger.info(msg) + def _add_dummy_seq(self, seq_group_metadata_list, is_prompt): + real_batch_size = len(seq_group_metadata_list) + batch_size_padded = self.bucketing_ctx.get_padded_batch_size( + real_batch_size, is_prompt) + batch_size_padding = batch_size_padded - real_batch_size + + seq_group_metadata_list = seq_group_metadata_list.copy() + + if batch_size_padding > 0: + dummy_seq_group_metadata = self.create_dummy_seq_group_metadata( + 0, 0, is_prompt) + seq_group_metadata_list.extend(dummy_seq_group_metadata + for _ in range(batch_size_padding)) + return seq_group_metadata_list, real_batch_size, batch_size_padded + + def _maybe_wrap_in_hpu_graph(self, *args, **kwargs): + return htorch.hpu.wrap_in_hpu_graph( + HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True + ) if htorch.utils.internal.is_lazy() else HpuModelAdapter( + *args, **kwargs) + def get_model(self) -> nn.Module: return self.model @@ -784,46 +653,6 @@ def _use_graphs(self, batch_size, seq_len, is_prompt): def _is_valid_bucket(self, bucket): return bucket[0] * bucket[1] <= self.max_num_batched_tokens - def _setup_buckets(self) -> None: - align_bs = lambda x: min(self.max_num_seqs, x) - #FIXME: The default values should be max_model_len - max_prompt_seq = 1024 - max_decode_seq = 2048 - self.bucketing_global_state.prompt_bs_bucket_cfg = read_bucket_settings( - 'prompt', - 'bs', - min=1, - step=align_bs(32), - max=self.max_num_prefill_seqs) - self.bucketing_global_state.decode_bs_bucket_cfg = read_bucket_settings( - 'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs) - self.bucketing_global_state.prompt_seq_bucket_cfg = \ - read_bucket_settings( - 'prompt', - 'seq', - min=self.block_size, - step=self.block_size, - max=max_prompt_seq) - self.bucketing_global_state.decode_block_bucket_cfg = \ - read_bucket_settings( - 'decode', - 'block', - min=self.block_size, - step=self.block_size, - max=max(self.block_size, - self.max_num_seqs * max_decode_seq // self.block_size)) - self.graphed_buckets: Set[Any] = set() - - msg = ("Prompt bucket config (min, step, max_warmup) " - f"bs:{self.bucketing_global_state.prompt_bs_bucket_cfg}, " - f"seq:{self.bucketing_global_state.prompt_seq_bucket_cfg}") - logger.info(msg) - - msg = ("Decode bucket config (min, step, max_warmup) " - f"bs:{self.bucketing_global_state.decode_bs_bucket_cfg}, " - f"block:{self.bucketing_global_state.decode_block_bucket_cfg}") - logger.info(msg) - def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -897,9 +726,8 @@ def _prepare_prompt( # is always the first token in the sequence. 
input_positions.append(list(range(context_len, seq_len))) - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - mm_kwargs = self.multi_modal_input_mapper(mm_data) + mm_kwargs = seq_group_metadata.multi_modal_data + if mm_kwargs: multi_modal_kwargs_list.append(mm_kwargs) if seq_group_metadata.block_tables is None: @@ -939,8 +767,7 @@ def _prepare_prompt( assert max_query_len > 0 max_prompt_len = max( - find_bucket(max(seq_lens), - self.bucketing_global_state.prompt_seq_bucket_cfg), + self.bucketing_ctx.get_padded_prompt_seq_len(max_query_len), self.block_size) lora_ids: List[int] = [] @@ -989,7 +816,6 @@ def _prepare_prompt( block_usage=None, block_indices=block_indices, block_offsets=block_offsets, - block_scales=None, block_groups=None, attn_bias=None, seq_lens_tensor=seq_lens_tensor, @@ -1116,9 +942,8 @@ def _prepare_decode( padding_fn = None if self.use_contiguous_pa: block_bucket_size = max(max(block_list) + 1, len(block_list)) - block_bucket_size = find_bucket( - block_bucket_size, - self.bucketing_global_state.decode_block_bucket_cfg) + block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks( + block_bucket_size) indices: List[Any] indices = [None] * block_bucket_size for i, bid in enumerate(block_list): @@ -1126,9 +951,9 @@ def _prepare_decode( padding_fn = lambda tensor, pad_value: gather_list( tensor, indices, pad_value) else: - block_bucket_size = find_bucket( - len(block_list), - self.bucketing_global_state.decode_block_bucket_cfg) + block_bucket_size = \ + self.bucketing_ctx.get_padded_decode_num_blocks( + len(block_list)) padding_fn = lambda tensor, pad_value: pad_list( tensor, block_bucket_size, pad_value) @@ -1159,7 +984,6 @@ def _prepare_decode( block_usage=block_usage, block_indices=block_indices, block_offsets=block_offsets, - block_scales=None, block_groups=block_groups, attn_bias=None, seq_lens_tensor=None, @@ -1202,17 +1026,8 @@ def prepare_input_tensors( base_event_name = 'prompt' if is_prompt else 'decode' self.profiler.start('internal', base_event_name) - real_batch_size = len(seq_group_metadata_list) - bucket_cfg = self.bucketing_global_state.prompt_bs_bucket_cfg \ - if is_prompt else self.bucketing_global_state.decode_bs_bucket_cfg - batch_size_padded = find_bucket(real_batch_size, bucket_cfg) - batch_size_padding = batch_size_padded - real_batch_size - seq_group_metadata_list = seq_group_metadata_list.copy() - if batch_size_padding > 0: - dummy_seq_group_metadata = self.create_dummy_seq_group_metadata( - 0, 0, is_prompt) - seq_group_metadata_list.extend(dummy_seq_group_metadata - for _ in range(batch_size_padding)) + seq_group_metadata_list, real_batch_size, batch_size_padded = ( + self._add_dummy_seq(seq_group_metadata_list, is_prompt)) prefill_reqs = [] decode_reqs = [] @@ -1374,7 +1189,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ 'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', 'block_usage', 'slot_mapping', 'is_prompt', 'block_indices', - 'block_offsets', 'block_scales', 'block_groups' + 'block_offsets', 'block_groups' ]) return attention_metadata @@ -1412,16 +1227,18 @@ def profile_run(self) -> None: bind_kv_cache( self.vllm_config.compilation_config.static_forward_context, [kv_caches]) - max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1] - max_batch_size = min(self.max_num_batched_tokens // max_seq_len, - self.scheduler_config.max_num_seqs) - self.warmup_scenario(max_batch_size, max_seq_len, True, False, 
True) + _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() + max_batch_size = min(self.max_num_seqs, + self.max_num_batched_tokens // max_seq_len) + self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, + False, True) return def warmup_scenario(self, batch_size, seq_len, is_prompt, + kv_caches, is_pt_profiler_run=False, is_lora_profile_run=False) -> None: use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) @@ -1557,16 +1374,17 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len): f"free_mem:{free_mem}") logger.info(msg) - def warmup_all_buckets(self, buckets, is_prompt): + def warmup_all_buckets(self, buckets, is_prompt, kv_caches): for i, (batch_size, seq_len) in enumerate(reversed(buckets)): self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len) - self.warmup_scenario(batch_size, seq_len, is_prompt) + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) def warmup_graphs(self, strategy, buckets, is_prompt, + kv_caches, available_mem, starting_mem=0, total_batch_seq=0.001): @@ -1598,7 +1416,7 @@ def warmup_graphs(self, self.graphed_buckets.add(graphed_bucket) self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) with HabanaMemoryProfiler() as mem_prof: - self.warmup_scenario(batch_size, seq_len, is_prompt) + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX) available_mem -= used_mem @@ -1622,50 +1440,21 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): @torch.inference_mode() def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: + max_blocks = kv_caches[0][0].size(0) + self.bucketing_ctx.generate_decode_buckets(max_blocks) if profile := os.environ.get('VLLM_PT_PROFILE', None): phase, bs, seq_len, graph = profile.split('_') is_prompt = phase == 'prompt' graphs = graph == 't' if graphs: self.graphed_buckets.add((int(bs), int(seq_len), is_prompt)) - self.warmup_scenario(int(bs), int(seq_len), is_prompt, True) + self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, + True) raise AssertionError("Finished profiling") - if self.skip_warmup: - logger.info("Skipping warmup...") - return - self.profiler.start('internal', 'warmup') - max_blocks = kv_caches[0][0].size(0) - - self.bucketing_global_state.prompt_buckets, prompt_omitted_buckets = \ - generate_prompt_buckets( - self.bucketing_global_state.prompt_bs_bucket_cfg, - self.bucketing_global_state.prompt_seq_bucket_cfg, - self.max_num_batched_tokens) - - msg = (f"Generated {len(self.bucketing_global_state.prompt_buckets)} " - f"prompt buckets [bs, seq]: \ - {list(sorted(self.bucketing_global_state.prompt_buckets))}") - logger.info(msg) - - msg = (f"Omitted {len(prompt_omitted_buckets)} " - "prompt buckets due to exceeded token budget " - f"(max_num_batched_tokens={self.max_num_batched_tokens})") - logger.info(msg) - - msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}" - logger.debug(msg) - - self.bucketing_global_state.decode_buckets = generate_decode_buckets( - self.bucketing_global_state.decode_bs_bucket_cfg, - self.bucketing_global_state.decode_block_bucket_cfg, max_blocks) - logger.info("Generated %d decode buckets [bs, total_blocks]: %s", - len(self.bucketing_global_state.decode_buckets), - list(sorted(self.bucketing_global_state.decode_buckets))) - if not htorch.utils.internal.is_lazy() and not self.enforce_eager: cache_size_limit = 1 + 3 * ( - 
len(self.bucketing_global_state.prompt_buckets) + - len(self.bucketing_global_state.decode_buckets)) + len(self.bucketing_ctx.prompt_buckets) + + len(self.bucketing_ctx.decode_buckets)) torch._dynamo.config.cache_size_limit = max( cache_size_limit, torch._dynamo.config.cache_size_limit) # Multiply by 8 to follow the original default ratio between @@ -1673,7 +1462,10 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: torch._dynamo.config.accumulated_cache_size_limit = max( cache_size_limit * 8, torch._dynamo.config.accumulated_cache_size_limit) - + if self.skip_warmup: + logger.info("Skipping warmup...") + return + self.profiler.start('internal', 'warmup') start_mem = HabanaMemoryProfiler.current_device_memory_usage() start_time = time.perf_counter() @@ -1692,10 +1484,10 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: 'Please update Gaudi Software Suite.') with compile_only_mode_context( ) if can_use_compile_only_mode else contextlib.nullcontext(): - self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets, - True) - self.warmup_all_buckets(self.bucketing_global_state.decode_buckets, - False) + self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True, + kv_caches) + self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False, + kv_caches) if not self.enforce_eager and htorch.utils.internal.is_lazy(): assert self.mem_margin is not None, \ @@ -1725,12 +1519,12 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: 'max_bs') mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ self.warmup_graphs( - prompt_strategy, self.bucketing_global_state.prompt_buckets, - True, prompt_available_memory) + prompt_strategy, self.bucketing_ctx.prompt_buckets, + True, kv_caches, prompt_available_memory) mem_post_decode, decode_batch_seq, decode_captured_all = \ self.warmup_graphs( - decode_strategy, self.bucketing_global_state.decode_buckets, - False, decode_available_memory) + decode_strategy, self.bucketing_ctx.decode_buckets, + False, kv_caches, decode_available_memory) # Not all prompt buckets were captured, but all decode buckets # were captured and we have some free graph-allocated space @@ -1739,8 +1533,8 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: and not prompt_captured_all and decode_captured_all): mem_post_prompt, _, prompt_captured_all = ( self.warmup_graphs( - prompt_strategy, - self.bucketing_global_state.prompt_buckets, True, + prompt_strategy, self.bucketing_ctx.prompt_buckets, + True, kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_prompt, prompt_batch_seq)) @@ -1751,17 +1545,15 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: and not decode_captured_all \ and prompt_captured_all: mem_post_decode, _, _ = self.warmup_graphs( - decode_strategy, - self.bucketing_global_state.decode_buckets, False, + decode_strategy, self.bucketing_ctx.decode_buckets, + False, kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_decode, decode_batch_seq) self.log_graph_warmup_summary( - self.bucketing_global_state.prompt_buckets, True, - mem_post_prompt) + self.bucketing_ctx.prompt_buckets, True, mem_post_prompt) self.log_graph_warmup_summary( - self.bucketing_global_state.decode_buckets, False, - mem_post_decode) + self.bucketing_ctx.decode_buckets, False, mem_post_decode) end_time = time.perf_counter() end_mem = HabanaMemoryProfiler.current_device_memory_usage() @@ -2020,6 +1812,21 @@ def create_lora_mask(self, input_tokens:
torch.Tensor, lora_ids: List[int], return lora_mask, lora_logits_mask + def _get_seq_ids(self, model_input): + return ([ + sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups + ]) + + def _pad_to_max_num_seqs(self, tensor, value): + padding_needed = self.max_num_seqs - tensor.size(0) + if padding_needed: + padding = torch.full((padding_needed, *tensor.shape[1:]), + value, + device=tensor.device, + dtype=tensor.dtype) + tensor = torch.cat([tensor, padding]) + return tensor + @torch.inference_mode() def execute_model( self, @@ -2030,6 +1837,37 @@ def execute_model( warmup_mode=False, seqs=None, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING + use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode + assert not (use_delayed_sampling and num_steps != 1), \ + 'Delayed sampling is not compatible with MSS!' + assert model_input.input_tokens is not None + if use_delayed_sampling and not model_input.is_prompt and \ + self.is_driver_worker: + num_cached = len(self.cached_step_outputs) + assert num_cached > 0 + cur_seq_ids = self._get_seq_ids(model_input) + cur_seq_id_pos = { + sid: idx + for idx, sid in enumerate(cur_seq_ids) if sid >= 0 + } + htorch.core.mark_step() + for i in range(num_cached): + prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i]) + target_indices = [ + cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids + ] + padding = self.cached_step_outputs[i].size(0) - len( + target_indices) + target_indices.extend([-1] * padding) + target_indices = torch.tensor( + target_indices, + device=model_input.input_tokens.device, + dtype=model_input.input_tokens.dtype) + model_input.input_tokens.index_copy_( + 0, target_indices, self.cached_step_outputs[i]) + htorch.core.mark_step() + if not model_input.is_first_multi_step: if not model_input.is_last_step: # not first or last multi-step @@ -2045,7 +1883,21 @@ def execute_model( assert model_input.lora_mapping is not None self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) - input_tokens = model_input.input_tokens + # Rank!=0 workers have is_prompt==None + if use_delayed_sampling and not model_input.is_prompt and \ + model_input.input_tokens.size(1) == 1: + if self.is_driver_worker: + model_kwargs_broadcast_data = { + "input_tokens": model_input.input_tokens + } + broadcast_tensor_dict(model_kwargs_broadcast_data, src=0) + input_tokens = model_input.input_tokens + + else: + model_kwargs_broadcast_data = broadcast_tensor_dict(src=0) + input_tokens = model_kwargs_broadcast_data["input_tokens"] + else: + input_tokens = model_input.input_tokens input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata sampling_metadata = model_input.sampling_metadata @@ -2092,11 +1944,11 @@ def execute_model( f"graphs{'T' if use_graphs else 'F'}") else: model_event_name = 'model_executable' - if num_steps > 1: + if num_steps > 1 or use_delayed_sampling: # in case of multi-step scheduling # we only want to pythonize in the last step sampling_metadata.skip_sampler_cpu_output = True - self.model.model.sampler.include_gpu_probs_tensor = True + self.model.sampler.include_gpu_probs_tensor = True cache_orig_output_tokens_len: List[Dict] = [] def try_revert_dummy_output_tokens(): @@ -2152,9 +2004,9 @@ def try_revert_dummy_output_tokens(): if not self.is_driver_worker: continue - if model_input.async_callback is not None: - model_input.async_callback() - # Sample the next token.
+ if use_delayed_sampling: + fake_output = self._delayed_sampler_outputs(model_input) + with self.profiler.record_event( 'internal', ('sample_' f'{"prompt" if is_prompt else "decode"}_' @@ -2166,9 +2018,16 @@ def try_revert_dummy_output_tokens(): ) if num_steps > 1: output = output.sampled_token_ids - self.cached_step_outputs.append( - output.detach().clone()) + self.cached_step_outputs.append(output) + if use_delayed_sampling and self.is_driver_worker: + self._patch_prev_output() + output = self._pad_to_max_num_seqs( + output.sampled_token_ids, DUMMY_TOKEN_ID) + self.cached_step_outputs.append(output) + self.cached_step_inputs.append(model_input) htorch.core.mark_step() + if model_input.async_callback is not None: + model_input.async_callback() if i < num_steps - 1: if i == 0: if model_input.async_callback is not None: @@ -2241,11 +2100,30 @@ def try_revert_dummy_output_tokens(): is_prompt=is_prompt) self.profiler.record_counter(self.event_start, counters) if num_steps == 1: + if self.return_hidden_states: + # we only need to pass hidden states of most recent token + assert model_input.sampling_metadata is not None + if model_input.is_prompt: + output.prefill_hidden_states = hidden_states + output.hidden_states = hidden_states + if use_delayed_sampling: + if self.is_driver_worker: + return [fake_output] + else: + return [] + return [output] if self.is_driver_worker else [] else: return [] return output if type(output) is list else [output] + def _delayed_sampler_outputs(self, model_input): + next_token_ids = [[DUMMY_TOKEN_ID]] * len( + model_input.sampling_metadata.seq_groups) + sampler_output = self._make_decode_output( + next_token_ids, model_input.sampling_metadata.seq_groups) + return sampler_output + def _decode_sampler_outputs(self, model_input): use_async_out_proc = model_input.async_callback is not None sampler_outputs = [] @@ -2312,3 +2190,32 @@ def shutdown_inc(self): def __del__(self): self.shutdown_inc() + + def _patch_prev_output(self): + assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \ + f'''Inputs and outputs are out of sync! + {len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}''' + if len(self.cached_step_inputs) == 0: + return + model_input = self.cached_step_inputs.pop(0) + delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze( + -1).tolist() + ctx = model_input.async_callback.keywords["ctx"] # type: ignore + # If there's no output to patch with, skip; this is usually the case + # when we're starting a new request after all requests have completed. + if len(ctx.output_queue) == 0: + return + assert len( + ctx.output_queue) == 1, 'There should be exactly 1 output waiting!' + output_data = ctx.output_queue[0] + assert len(output_data.outputs) == 1 + for fake_out, real_out in zip(output_data.outputs[0], delayed_output): + fake_out.samples[0].output_token = real_out + for sg, real_out in zip(output_data.seq_group_metadata_list, + delayed_output): + assert len(sg.seq_data) == 1 + seq_data = list(sg.seq_data.values())[0] + # This is a hack.
Assigning output_token_ids triggers + # a cache recomputation and we only need to update the last token + seq_data.output_token_ids_array[-1] = real_out + seq_data._cached_all_token_ids[-1] = real_out diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index ccb175d88fd3..8d7d5d7adc10 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -245,6 +245,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: cache_block_size) num_hpu_blocks = max(num_hpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) + self.model_runner.bucketing_ctx.num_hpu_blocks = num_hpu_blocks if self.model_runner.lora_manager: self.model_runner.remove_all_loras() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9524a69f6b3a..73e0eff9a8b7 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,7 +23,8 @@ from vllm.attention.backends.utils import CommonAttentionState from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import get_kv_transfer_group, get_pp_group +from vllm.distributed import get_pp_group +from vllm.distributed.kv_transfer import get_kv_transfer_group from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, graph_capture) from vllm.forward_context import get_forward_context, set_forward_context @@ -34,7 +35,7 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata, SamplingMetadataCache from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import supports_lora, supports_multimodal @@ -456,7 +457,6 @@ def __init__(self, self.enable_lora = self.runner.lora_config is not None self.enable_prompt_adapter = (self.runner.prompt_adapter_config is not None) - self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper # Attention metadata inputs. if self.attn_backend is not None: @@ -674,23 +674,15 @@ def _compute_prompt_adapter_input( def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, seq_group_metadata: SequenceGroupMetadata): """If multi-modal data is given, add it to the input.""" - # NOTE: mm_data only includes the subset of multi-modal items that + # NOTE: mm_kwargs only includes the subset of multi-modal items that # intersect with the current prefill positions. 
positions = inter_data.input_positions[0] - mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( + mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( seq_group_metadata, range(positions[0], positions[0] + len(positions))) - if not mm_data: + if not mm_kwargs: return - if self.runner.mm_registry.has_processor(self.runner.model_config): - mm_kwargs = mm_data - else: - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - seq_group_metadata.mm_processor_kwargs, - ) - inter_data.multi_modal_kwargs = mm_kwargs inter_data.multi_modal_placeholder_maps = placeholder_maps @@ -698,11 +690,17 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, if self.runner.model_config.uses_mrope: image_grid_thw = mm_kwargs.get("image_grid_thw", None) video_grid_thw = mm_kwargs.get("video_grid_thw", None) - assert image_grid_thw is not None or video_grid_thw is not None, ( - "mrope embedding type requires multi-modal input mapper " - "returns 'image_grid_thw' or 'video_grid_thw'.") + audio_feature_lengths = mm_kwargs.get("audio_feature_lengths", + None) + assert ( + image_grid_thw is not None or video_grid_thw is not None + or audio_feature_lengths is not None), ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw' or " + "'audio_feature_lengths'.") second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) + use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) hf_config = self.runner.model_config.hf_config inter_data.mrope_input_positions = [None] * inter_data.n_seqs @@ -720,6 +718,8 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, second_per_grid_ts=second_per_grid_ts, context_len=inter_data.context_lens[seq_idx], seq_len=inter_data.seq_lens[seq_idx], + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, ) seq_data.mrope_position_delta = mrope_position_delta @@ -1076,15 +1076,13 @@ def __init__( # Multi-modal data support self.input_registry = input_registry self.mm_registry = mm_registry - self.multi_modal_input_mapper = mm_registry \ - .create_input_mapper(model_config) - self.mm_registry.init_mm_limits_per_prompt(self.model_config) # Lazy initialization self.model: nn.Module # Set after load_model # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None + self.sampler = get_sampler() set_cpu_offload_max_bytes( int(self.cache_config.cpu_offload_gb * 1024**3)) @@ -1120,14 +1118,9 @@ def load_model(self) -> None: logger.warning( "Regarding multimodal models, vLLM currently " "only supports adding LoRA to language model.") - # It's necessary to distinguish between the - # max_position_embeddings of VLMs and LLMs. - if hasattr(self.model.config, "max_position_embeddings"): - max_pos_embeddings = ( - self.model.config.max_position_embeddings) - else: - max_pos_embeddings = ( - self.model.config.text_config.max_position_embeddings) + + # Use get_text_config() in case of multimodal models + text_config = self.model_config.hf_config.get_text_config() self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, @@ -1137,7 +1130,8 @@ def load_model(self) -> None: self.device, self.model.embedding_modules, self.model.embedding_padding_modules, - max_position_embeddings=max_pos_embeddings, + max_position_embeddings=text_config. 
+ max_position_embeddings, ) self.model = self.lora_manager.create_lora_manager(self.model) time_after_load = time.perf_counter() @@ -1321,8 +1315,8 @@ def _dummy_run(self, dummy_data = self.input_registry \ .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry) + seq_len, + self.mm_registry) seq = SequenceGroupMetadata( request_id=str(group_id), @@ -1823,7 +1817,7 @@ def execute_model( model_input.async_callback() # Sample the next token. - output: SamplerOutput = self.model.sample( + output: SamplerOutput = self.sampler( logits=logits, sampling_metadata=model_input.sampling_metadata, ) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 7ddf382079c6..a6f5ec825635 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -488,8 +488,7 @@ def execute_model( device="cpu", pin_memory=True) - self._base_model_runner.model.sampler.include_gpu_probs_tensor = ( - True) + self._base_model_runner.sampler.include_gpu_probs_tensor = True if frozen_model_input.sampling_metadata: frozen_model_input.sampling_metadata.skip_sampler_cpu_output = ( True) diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index f2093fc42ad1..e046ebc449de 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -15,8 +15,7 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.neuron import get_neuron_model -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalKwargs) +from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase @@ -69,11 +68,6 @@ def __init__( self.device = self.device_config.device self.pin_memory = is_pin_memory_available() - # Multi-modal data support - self.mm_registry = MULTIMODAL_REGISTRY - self.multi_modal_input_mapper = self.mm_registry \ - .create_input_mapper(self.model_config) - # Lazy initialization. self.model: nn.Module # initialize after load_model. 
@@ -149,16 +143,8 @@ def _prepare_prompt( assert len(block_table) == 1 input_block_ids.append(block_table[0]) - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - if self.mm_registry.has_processor(self.model_config): - mm_kwargs = mm_data - else: - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - seq_group_metadata.mm_processor_kwargs, - ) - + mm_kwargs = seq_group_metadata.multi_modal_data + if mm_kwargs: multi_modal_kwargs_list.append(mm_kwargs) max_seq_len = max(seq_lens) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 71b4b38fb9d6..bbcc4d59ae1c 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -163,8 +163,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: usable_memory_size = int(total_memory_size * self.cache_config.gpu_memory_utilization) tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) - dtype_btyes = get_dtype_size(self.cache_dtype) - block_size_bytes = (dtype_btyes * self.cache_config.block_size * + dtype_bytes = get_dtype_size(self.cache_dtype) + block_size_bytes = (dtype_bytes * self.cache_config.block_size * num_layers * 2 * head_size * num_kv_heads) num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d59f20f49996..78ea990de820 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -10,10 +10,10 @@ import vllm.envs as envs from vllm.config import VllmConfig from vllm.device_allocator.cumem import CuMemAllocator -from vllm.distributed import (ensure_kv_transfer_initialized, - ensure_model_parallel_initialized, +from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) +from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed @@ -95,6 +95,9 @@ def __init__( self.gpu_cache: Optional[List[List[torch.Tensor]]] = None self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} + # Buffers saved before sleep + self._sleep_saved_buffers: Dict[str, torch.Tensor] = {} + # Torch profiler. 
Enabled and configured through env vars: # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace if envs.VLLM_TORCH_PROFILER_DIR: @@ -124,6 +127,15 @@ def stop_profile(self): def sleep(self, level: int = 1) -> None: free_bytes_before_sleep = torch.cuda.mem_get_info()[0] + + # Save the buffers before level 2 sleep + if level == 2: + model = self.model_runner.model + self._sleep_saved_buffers = { + name: buffer.cpu().clone() + for name, buffer in model.named_buffers() + } + allocator = CuMemAllocator.get_instance() allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) free_bytes_after_sleep, total = torch.cuda.mem_get_info() @@ -139,6 +151,14 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None: allocator = CuMemAllocator.get_instance() allocator.wake_up(tags=tags) + # Restore the buffers after level 2 sleep + if len(self._sleep_saved_buffers): + model = self.model_runner.model + for name, buffer in model.named_buffers(): + if name in self._sleep_saved_buffers: + buffer.data.copy_(self._sleep_saved_buffers[name].data) + self._sleep_saved_buffers = {} + def init_device(self) -> None: if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 9d49b4385dca..7042b575aa78 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -18,7 +18,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadataCache -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalKwargs, MultiModalPlaceholderMap, @@ -188,20 +188,11 @@ def _prepare_prompt( input_positions.extend(list(positions_range)) if seq_group_metadata.multi_modal_data: - # NOTE: mm_data only includes the subset of multi-modal items + # NOTE: mm_kwargs only includes the subset of multi-modal items # that intersect with the current prefill positions. - mm_data, placeholder_maps = MultiModalPlaceholderMap \ + mm_kwargs, placeholder_maps = MultiModalPlaceholderMap \ .from_seq_group(seq_group_metadata, positions_range) - if self.runner.mm_registry.has_processor( - self.runner.model_config): - mm_kwargs = mm_data - else: - mm_kwargs = self.runner.multi_modal_input_mapper( - mm_data, - seq_group_metadata.mm_processor_kwargs, - ) - multi_modal_kwargs_list.append(mm_kwargs) for modality, placeholder_map in placeholder_maps.items(): @@ -404,12 +395,10 @@ def __init__( # Multi-modal data support self.input_registry = input_registry self.mm_registry = mm_registry - self.multi_modal_input_mapper = mm_registry \ - .create_input_mapper(model_config) - self.mm_registry.init_mm_limits_per_prompt(self.model_config) # Lazy initialization. self.model: nn.Module # Set after init_Model + self.sampler = get_sampler() self.sampling_metadata_cache: SamplingMetadataCache = \ SamplingMetadataCache() \ @@ -596,7 +585,7 @@ def execute_model( model_input.async_callback() # Sample the next token. - output: SamplerOutput = self.model.sample( + output: SamplerOutput = self.sampler( logits=logits, sampling_metadata=model_input.sampling_metadata, )
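
Note on the recurring sampler change in this patch: instead of calling self.model.sample() or reaching into self.model.sampler, each runner now builds its own sampler via get_sampler() at init time and invokes it on the computed logits. Below is a minimal sketch of that pattern under a stripped-down runner; SketchRunner and its methods are illustrative only, not vLLM code, and only the get_sampler()/self.sampler wiring mirrors the hunks above.

from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler

class SketchRunner:
    # Illustrative runner skeleton; not part of vLLM.

    def __init__(self):
        # The runner owns the sampler, so flags such as
        # include_gpu_probs_tensor can be toggled on self.sampler directly
        # instead of depending on the model's internal module layout.
        self.sampler = get_sampler()

    def sample(self, logits, sampling_metadata) -> SamplerOutput:
        # Same call shape the runners use after this patch.
        return self.sampler(logits=logits,
                            sampling_metadata=sampling_metadata)

One practical effect, visible in the HPU delayed-sampling and multi-step paths above, is that the runners flip include_gpu_probs_tensor on the runner-owned sampler rather than on self.model.sampler.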