
Commit 4caf540

dhuangnm and dsikka authored
Allow e2e tests to run vllm in a separate python environment (#1802)
SUMMARY: Currently the e2e tests assume vllm is installed in the same python environment as llmcompressor and run the vllm validation after the model is optimized. This can be problematic for the following reasons:
- vllm might have dependencies that conflict with llmcompressor's, so the two cannot co-install into the same python env.
- in the upcoming RHAIIS release, llmcompressor will ship its own image, and initially only llmcompressor and its dependencies will be installed in that image.

This PR addresses the issues above and allows the e2e tests to run the vllm validation using a different python env. This is achieved by packaging the vllm code into a separate run_vllm.py file and using an env variable VLLM_PYTHON_ENV to control whether the vllm code runs in the same or a different python env through subprocess.Popen(). By default we still assume vllm and llmcompressor are installed in the same python env, so there are no changes to the way the e2e tests are set up or run, i.e.:

```
# install llmcompressor and vllm in the current python env and run tests
bash tests/e2e/vLLM/run_tests.sh -c tests/e2e/vLLM/configs -t tests/e2e/vLLM/test_vllm.py
```

However, if vllm is installed into a separate python env (e.g. through virtualenv, uv venv, etc.), the env variable VLLM_PYTHON_ENV needs to be set to the path of that python env, i.e.:

```
export VLLM_PYTHON_ENV=<path of the python env where vllm is installed separately>
# run tests in the llmcompressor python env
bash tests/e2e/vLLM/run_tests.sh -c tests/e2e/vLLM/configs -t tests/e2e/vLLM/test_vllm.py
```

Here is an example of setting up and running the e2e tests using a separate vllm python env:

```
# create vllm python env and get its path
uv venv vllm-venv --python 3.12
source vllm-venv/bin/activate
uv pip install vllm
which python  # get the path from this command and use the output for the VLLM_PYTHON_ENV env var
deactivate

# set up llmcompressor python env to run tests
uv venv llmcompressor-venv --python 3.12
source llmcompressor-venv/bin/activate
uv pip install llmcompressor[dev]
cd llm-compressor
export VLLM_PYTHON_ENV=<path from the `which python` command above>
bash tests/e2e/vLLM/run_tests.sh -c tests/e2e/vLLM/configs -t tests/e2e/vLLM/test_vllm.py
```

This PR also removes the skip check for vllm in the tests, so the tests will fail on any vllm issue, whether it is an installation or a runtime problem. This lets us track vllm setup issues instead of silently skipping them and reporting no errors.

TEST PLAN: Run the e2e tests with vllm in the same and in a separate python env and make sure all tests pass.
- Run with vllm in the same python env: https://github.com/neuralmagic/llm-compressor-testing/actions/runs/17650195538
- Run with vllm in a separate python env: https://github.com/neuralmagic/llm-compressor-testing/actions/runs/17647525082

---------

Signed-off-by: Dan Huang <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
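As a quick illustration of the VLLM_PYTHON_ENV dispatch described in the summary, here is a minimal sketch of how the interpreter gets selected (not the actual test code; the printed message is illustrative):

```python
import os
import sys

# Default to "same": run vLLM with the interpreter that is running the tests.
# Otherwise VLLM_PYTHON_ENV should point at the python of the separate vLLM env.
vllm_python = os.environ.get("VLLM_PYTHON_ENV", "same")
if vllm_python.lower() == "same":
    vllm_python = sys.executable

print(f"vLLM validation would run with: {vllm_python}")
```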
1 parent 3466ce8 commit 4caf540

File tree: 2 files changed (+89, −33 lines)


tests/e2e/vLLM/run_vllm.py

Lines changed: 51 additions & 0 deletions
```python
import json
import sys

import torch
from vllm import LLM, SamplingParams


def parse_args():
    """Parse JSON arguments passed via command line."""
    if len(sys.argv) < 4:
        msg = "Usage: python script.py '<scheme>' '<llm_kwargs_json>' '<prompts_json>'"
        raise ValueError(msg)

    try:
        scheme = json.loads(sys.argv[1])
        llm_kwargs = json.loads(sys.argv[2])
        prompts = json.loads(sys.argv[3])
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON input: {e}")

    if "W4A16_2of4" in scheme:
        # required by the kernel
        llm_kwargs["dtype"] = torch.float16

    return llm_kwargs, prompts


def run_vllm(llm_kwargs: dict, prompts: list[str]) -> None:
    """Run vLLM with given kwargs and prompts, then print outputs."""
    sampling_params = SamplingParams(temperature=0.80, top_p=0.95)

    llm = LLM(**llm_kwargs)
    outputs = llm.generate(prompts, sampling_params)

    print("================= vLLM GENERATION =================")
    for output in outputs:
        if not output or not output.outputs:
            print("[Warning] Empty output for prompt:", output.prompt)
            continue

        print(f"\nPROMPT:\n{output.prompt}")
        print(f"GENERATED TEXT:\n{output.outputs[0].text}")


def main():
    llm_kwargs, prompts = parse_args()
    run_vllm(llm_kwargs, prompts)


if __name__ == "__main__":
    main()
```
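For context, a hypothetical direct invocation of this helper from another Python process could look like the sketch below; the scheme, model directory, and prompt are placeholders, not values taken from the test configs:

```python
import json
import subprocess
import sys

# run_vllm.py expects three JSON-encoded positional arguments:
# scheme, llm_kwargs, and prompts (placeholder values shown here).
args = [
    json.dumps("FP8_DYNAMIC"),
    json.dumps({"model": "./some-compressed-model-dir"}),
    json.dumps(["The capital of France is"]),
]

proc = subprocess.run(
    [sys.executable, "tests/e2e/vLLM/run_vllm.py", *args],
    capture_output=True,
    text=True,
)
print(proc.stdout)
if proc.returncode != 0:
    raise RuntimeError(f"vLLM run failed: {proc.stderr}")
```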

tests/e2e/vLLM/test_vllm.py

Lines changed: 38 additions & 33 deletions
```diff
@@ -1,6 +1,7 @@
 import os
 import re
 import shutil
+import sys
 from pathlib import Path
 
 import pandas as pd
@@ -14,21 +15,14 @@
 from tests.examples.utils import requires_gpu_count
 from tests.test_timer.timer_utils import get_singleton_manager, log_time
 
-try:
-    from vllm import LLM, SamplingParams
-
-    vllm_installed = True
-except ImportError:
-    vllm_installed = False
-    logger.warning("vllm is not installed. This test will be skipped")
-
-
 HF_MODEL_HUB_NAME = "nm-testing"
 
 TEST_DATA_FILE = os.environ.get(
     "TEST_DATA_FILE", "tests/e2e/vLLM/configs/int8_dynamic_per_token.yaml"
 )
 SKIP_HF_UPLOAD = os.environ.get("SKIP_HF_UPLOAD", "")
+# vllm python environment
+VLLM_PYTHON_ENV = os.environ.get("VLLM_PYTHON_ENV", "same")
 TIMINGS_DIR = os.environ.get("TIMINGS_DIR", "timings/e2e-test_vllm")
 os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 EXPECTED_SAVED_FILES = [
@@ -45,7 +39,6 @@
 @pytest.mark.parametrize(
     "test_data_file", [pytest.param(TEST_DATA_FILE, id=TEST_DATA_FILE)]
 )
-@pytest.mark.skipif(not vllm_installed, reason="vLLM is not installed, skipping test")
 class TestvLLM:
     """
     The following test quantizes a model using a preset scheme or recipe,
@@ -83,6 +76,12 @@ def set_up(self, test_data_file: str):
         self.max_seq_length = eval_config.get("max_seq_length", 2048)
         # GPU memory utilization - only set if explicitly provided in config
         self.gpu_memory_utilization = eval_config.get("gpu_memory_utilization")
+        # vllm python env - if same, use the current python env, otherwise use
+        # the python passed in VLLM_PYTHON_ENV
+        if VLLM_PYTHON_ENV.lower() != "same":
+            self.vllm_env = VLLM_PYTHON_ENV
+        else:
+            self.vllm_env = sys.executable
 
         if not self.save_dir:
             self.save_dir = self.model.split("/")[1] + f"-{self.scheme}"
@@ -152,20 +151,12 @@ def test_vllm(self, test_data_file: str):
                 folder_path=self.save_dir,
             )
 
-        logger.info("================= RUNNING vLLM =========================")
+        if VLLM_PYTHON_ENV.lower() == "same":
+            logger.info("========== RUNNING vLLM in the same python env ==========")
+        else:
+            logger.info("========== RUNNING vLLM in a separate python env ==========")
 
-        outputs = self._run_vllm()
-
-        logger.info("================= vLLM GENERATION ======================")
-        for output in outputs:
-            assert output
-            prompt = output.prompt
-            generated_text = output.outputs[0].text
-
-            logger.info("PROMPT")
-            logger.info(prompt)
-            logger.info("GENERATED TEXT")
-            logger.info(generated_text)
+        self._run_vllm(logger)
 
         self.tear_down()
 
@@ -193,22 +184,36 @@ def _save_compressed_model(self, oneshot_model, tokenizer):
         tokenizer.save_pretrained(self.save_dir)
 
     @log_time
-    def _run_vllm(self):
-        import torch
+    def _run_vllm(self, logger):
+        import json
+        import subprocess
 
-        sampling_params = SamplingParams(temperature=0.80, top_p=0.95)
         llm_kwargs = {"model": self.save_dir}
 
-        if "W4A16_2of4" in self.scheme:
-            # required by the kernel
-            llm_kwargs["dtype"] = torch.float16
-
         if self.gpu_memory_utilization is not None:
             llm_kwargs["gpu_memory_utilization"] = self.gpu_memory_utilization
 
-        llm = LLM(**llm_kwargs)
-        outputs = llm.generate(self.prompts, sampling_params)
-        return outputs
+        json_scheme = json.dumps(self.scheme)
+        json_llm_kwargs = json.dumps(llm_kwargs)
+        json_prompts = json.dumps(self.prompts)
+
+        test_file_dir = os.path.dirname(os.path.abspath(__file__))
+        run_file_path = os.path.join(test_file_dir, "run_vllm.py")
+
+        logger.info("Run vllm in subprocess.Popen() using python env:")
+        logger.info(self.vllm_env)
+
+        result = subprocess.Popen(
+            [self.vllm_env, run_file_path, json_scheme, json_llm_kwargs, json_prompts],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+        )
+        stdout, stderr = result.communicate()
+        logger.info(stdout)
+
+        error_msg = f"ERROR: vLLM failed with exit code {result.returncode}: {stderr}"
+        assert result.returncode == 0, error_msg
 
     def _check_session_contains_recipe(self) -> None:
         session = active_session()
```
