
Commit 42e0c1d

[Quality] Add CI for formatting (#343)
1 parent e41f067 commit 42e0c1d

File tree

6 files changed: +113 -31 lines changed


.github/workflows/pylint.yml

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+name: pylint
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the main branch
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  pylint:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install pylint==2.8.2
+    - name: Analysing the code with pylint
+      run: |
+        pylint vllm

.github/workflows/yapf.yml

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+name: yapf
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the main branch
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+jobs:
+  yapf:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install yapf==0.32.0
+        pip install toml==0.10.2
+    - name: Running yapf
+      run: |
+        yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'
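
These workflows only report violations; the same pinned checks can be reproduced locally before pushing. A minimal sketch, assuming pylint 2.8.2, yapf 0.32.0, and toml 0.10.2 are installed as in the workflows above; the script itself is illustrative and not part of this commit:

# local_checks.py -- illustrative helper, not part of this commit.
# Mirrors the commands run by the pylint and yapf workflows above;
# check=True raises CalledProcessError on a non-zero exit, just as a
# non-zero exit fails the corresponding CI step.
import subprocess

subprocess.run(["pylint", "vllm"], check=True)
subprocess.run(
    [
        "yapf", "--diff", "--recursive", "vllm",
        "--exclude", "vllm/model_executor/parallel_utils/**",
    ],
    check=True,
)
print("pylint and yapf checks passed")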

vllm/engine/async_llm_engine.py

Lines changed: 8 additions & 0 deletions
@@ -2,6 +2,7 @@
 import time
 from typing import Dict, List, Optional

+from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.engine.ray_utils import initialize_cluster, ray
@@ -206,6 +207,13 @@ async def abort(self, request_id: str) -> None:
         self.is_engine_running = False
         self.kicking_request_id = None

+    async def get_model_config(self) -> ModelConfig:
+        """Get the model configuration of the vLLM engine."""
+        if self.engine_use_ray:
+            return await self.engine.get_model_config.remote()
+        else:
+            return self.engine.get_model_config()
+
     @classmethod
     def from_engine_args(cls,
                          engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":

vllm/engine/llm_engine.py

Lines changed: 4 additions & 0 deletions
@@ -210,6 +210,10 @@ def abort_request(self, request_id: str) -> None:
         """
         self.scheduler.abort_seq_group(request_id)

+    def get_model_config(self) -> ModelConfig:
+        """Gets the model configuration."""
+        return self.model_config
+
     def get_num_unfinished_requests(self) -> int:
         """Gets the number of unfinished requests."""
         return self.scheduler.get_num_unfinished_seq_groups()
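
Both engines now expose the model configuration through a public accessor: LLMEngine.get_model_config() returns self.model_config directly, and AsyncLLMEngine.get_model_config() awaits the Ray remote call when engine_use_ray is set. A minimal usage sketch; the model name below is illustrative and the call is made outside any running event loop:

# Usage sketch for the new get_model_config() accessors.
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# The model name here is illustrative.
engine_args = AsyncEngineArgs(model="facebook/opt-125m")
engine = AsyncLLMEngine.from_engine_args(engine_args)

# AsyncLLMEngine.get_model_config is a coroutine, so drive it with
# asyncio.run() when no event loop is running (mirrors api_server.py below).
model_config = asyncio.run(engine.get_model_config())
print(model_config.hf_config)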

vllm/entrypoints/openai/api_server.py

Lines changed: 37 additions & 30 deletions
@@ -2,28 +2,30 @@
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py

 import argparse
+import asyncio
 from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional, Union, Any
+from typing import AsyncGenerator, Dict, List, Optional

 import fastapi
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
+from fastchat.conversation import (Conversation, SeparatorStyle,
+                                   get_conv_template)
 import uvicorn

 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     CompletionRequest, CompletionResponse, CompletionResponseChoice,
     CompletionResponseStreamChoice, CompletionStreamResponse,
-    ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
-    ChatMessage, DeltaMessage, ErrorResponse, LogProbs,
-    ModelCard, ModelList, ModelPermission, UsageInfo)
-from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
+    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
+    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -95,15 +97,15 @@ async def get_gen_prompt(request) -> str:
     return prompt


-async def check_length(request, prompt, engine):
-    if hasattr(engine.engine.model_config.hf_config, "max_sequence_length"):
-        context_len = engine.engine.model_config.hf_config.max_sequence_length
-    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
-        context_len = engine.engine.model_config.hf_config.seq_length
-    elif hasattr(engine.engine.model_config.hf_config, "max_position_embeddings"):
-        context_len = engine.engine.model_config.hf_config.max_position_embeddings
-    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
-        context_len = engine.engine.model_config.hf_config.seq_length
+async def check_length(request, prompt, model_config):
+    if hasattr(model_config.hf_config, "max_sequence_length"):
+        context_len = model_config.hf_config.max_sequence_length
+    elif hasattr(model_config.hf_config, "seq_length"):
+        context_len = model_config.hf_config.seq_length
+    elif hasattr(model_config.hf_config, "max_position_embeddings"):
+        context_len = model_config.hf_config.max_position_embeddings
+    elif hasattr(model_config.hf_config, "seq_length"):
+        context_len = model_config.hf_config.seq_length
     else:
         context_len = 2048

@@ -182,7 +184,7 @@ async def create_chat_completion(raw_request: Request):
                                      "logit_bias is not currently supported")

     prompt = await get_gen_prompt(request)
-    error_check_ret = await check_length(request, prompt, engine)
+    error_check_ret = await check_length(request, prompt, engine_model_config)
     if error_check_ret is not None:
         return error_check_ret

@@ -206,15 +208,16 @@ async def create_chat_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

-    result_generator = engine.generate(prompt, sampling_params,
-                                       request_id)
+    result_generator = engine.generate(prompt, sampling_params, request_id)

     async def abort_request() -> None:
         await engine.abort(request_id)

-    def create_stream_response_json(index: int,
-                                    text: str,
-                                    finish_reason: Optional[str] = None) -> str:
+    def create_stream_response_json(
+        index: int,
+        text: str,
+        finish_reason: Optional[str] = None,
+    ) -> str:
         choice_data = ChatCompletionResponseStreamChoice(
             index=index,
             delta=DeltaMessage(content=text),
@@ -238,10 +241,11 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
                 delta=DeltaMessage(role="assistant"),
                 finish_reason=None,
             )
-            chunk = ChatCompletionStreamResponse(
-                id=request_id, choices=[choice_data], model=model_name
-            )
-            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            chunk = ChatCompletionStreamResponse(id=request_id,
+                                                 choices=[choice_data],
+                                                 model=model_name)
+            data = chunk.json(exclude_unset=True, ensure_ascii=False)
+            yield f"data: {data}\n\n"

         previous_texts = [""] * request.n
         previous_num_tokens = [0] * request.n
@@ -295,8 +299,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
         choices.append(choice_data)

     num_prompt_tokens = len(final_res.prompt_token_ids)
-    num_generated_tokens = sum(len(output.token_ids)
-                               for output in final_res.outputs)
+    num_generated_tokens = sum(
+        len(output.token_ids) for output in final_res.outputs)
     usage = UsageInfo(
         prompt_tokens=num_prompt_tokens,
         completion_tokens=num_generated_tokens,
@@ -314,9 +318,11 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
         # When user requests streaming but we don't stream, we still need to
         # return a streaming response with a single event.
         response_json = response.json(ensure_ascii=False)
+
         async def fake_stream_generator() -> AsyncGenerator[str, None]:
             yield f"data: {response_json}\n\n"
             yield "data: [DONE]\n\n"
+
         return StreamingResponse(fake_stream_generator(),
                                  media_type="text/event-stream")

@@ -367,9 +373,9 @@ async def create_completion(raw_request: Request):
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")
         if len(request.prompt) > 1:
-            return create_error_response(HTTPStatus.BAD_REQUEST,
-                                         "multiple prompts in a batch is not "
-                                         "currently supported")
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                "multiple prompts in a batch is not currently supported")
         prompt = request.prompt[0]
     else:
         prompt = request.prompt
@@ -571,6 +577,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:

     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
+    engine_model_config = asyncio.run(engine.get_model_config())

     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(engine_args.tokenizer,
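
The refactored check_length now receives the ModelConfig fetched once at server startup instead of reaching through engine.engine, and it resolves the model's context length by probing the Hugging Face config. A standalone sketch of that resolution order, assuming a transformers-style config object; the helper name is illustrative and not part of the commit:

# Illustrative helper showing the attribute-resolution order used by
# check_length to find the model's context length.
from transformers import PretrainedConfig


def get_context_len(hf_config: PretrainedConfig) -> int:
    if hasattr(hf_config, "max_sequence_length"):
        return hf_config.max_sequence_length
    if hasattr(hf_config, "seq_length"):
        return hf_config.seq_length
    if hasattr(hf_config, "max_position_embeddings"):
        return hf_config.max_position_embeddings
    # The committed code re-checks seq_length here as well, but that branch
    # is unreachable because the earlier check already handles the attribute.
    return 2048  # default when the config exposes none of the attributes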

vllm/model_executor/models/bloom.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 # coding=utf-8
-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
 # Copyright 2023 The CacheFlow team.
 # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
 #
