Skip to content

Commit ef36af1

Browse files
committed
Fixes from review
1 parent 24f2ca3 commit ef36af1

File tree

11 files changed

+109
-147
lines changed

11 files changed

+109
-147
lines changed

pyproject.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,8 @@ dependencies = [
7878

7979
[project.optional-dependencies]
8080
perf = ["orjson", "msgpack", "msgspec", "uvloop"]
81-
recommended = [
82-
"tiktoken>=0.11.0", # For OpenAI tokenizer
83-
"blobfile>=3.1.0", # For OpenAI tokenizer
84-
]
81+
openai = ["tiktoken>=0.11.0", "blobfile>=3.1.0"]
82+
recommended = ["guidellm[perf,openai]"]
8583
dev = [
8684
# build
8785
"build>=1.0.0",

src/guidellm/backends/response_handlers.py

Lines changed: 63 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,10 @@
1010

1111
from __future__ import annotations
1212

13-
import json
14-
from typing import Any, Protocol, cast
13+
from typing import Any, Protocol
1514

1615
from guidellm.schemas import GenerationRequest, GenerationResponse, UsageMetrics
17-
from guidellm.utils import RegistryMixin
18-
19-
try:
20-
import orjson
21-
except ImportError:
22-
orjson = None # type: ignore[assignment]
16+
from guidellm.utils import RegistryMixin, json
2317

2418
__all__ = [
2519
"AudioResponseHandler",
@@ -115,8 +109,7 @@ def compile_non_streaming(
115109
:param response: Complete API response containing choices and usage data
116110
:return: Standardized GenerationResponse with extracted text and metrics
117111
"""
118-
choices = cast("list[dict]", response.get("choices", []))
119-
usage = cast("dict[str, int | dict[str, int]]", response.get("usage", {}))
112+
choices, usage = self.extract_choices_and_usage(response)
120113
input_metrics, output_metrics = self.extract_metrics(usage)
121114

122115
return GenerationResponse(
@@ -139,26 +132,17 @@ def add_streaming_line(self, line: str) -> int | None:
139132
:param line: Raw SSE line from the streaming response
140133
:return: 1 if text content was extracted, 0 if line ignored, None if done
141134
"""
142-
if line == "data: [DONE]":
143-
return None
135+
if not (data := self.extract_line_data(line)):
136+
return None if data is None else 0
144137

145-
if not line or not (line := line.strip()) or not line.startswith("data:"):
146-
return 0
147-
148-
line = line[len("data:") :].strip()
149-
data = cast(
150-
"dict[str, Any]",
151-
json.loads(line) if orjson is None else orjson.loads(line),
152-
)
153138
updated = False
139+
choices, usage = self.extract_choices_and_usage(data)
154140

155-
if (choices := cast("list[dict]", data.get("choices"))) and (
156-
text := choices[0].get("text")
157-
):
141+
if text := choices[0].get("text"):
158142
self.streaming_texts.append(text)
159143
updated = True
160144

161-
if usage := cast("dict[str, int | dict[str, int]]", data.get("usage")):
145+
if usage:
162146
self.streaming_usage = usage
163147

164148
return 1 if updated else 0
@@ -182,6 +166,34 @@ def compile_streaming(self, request: GenerationRequest) -> GenerationResponse:
182166
output_metrics=output_metrics,
183167
)
184168

169+
def extract_line_data(self, line: str) -> dict[str, Any] | None:
170+
"""
171+
Extract JSON data from a streaming response line.
172+
173+
:param line: Raw line from the streaming response
174+
:return: Parsed JSON data as a dictionary, or None if line is invalid
175+
"""
176+
if line == "data: [DONE]":
177+
return None
178+
179+
if not line or not (line := line.strip()) or not line.startswith("data:"):
180+
return {}
181+
182+
line = line[len("data:") :].strip()
183+
184+
return json.loads(line)
185+
186+
def extract_choices_and_usage(
187+
self, response: dict
188+
) -> tuple[list[dict], dict[str, int | dict[str, int]]]:
189+
"""
190+
Extract choices and usage data from the API response.
191+
192+
:param response: Complete API response containing choices and usage data
193+
:return: Tuple of (choices list, usage dictionary)
194+
"""
195+
return response.get("choices", []), response.get("usage", {})
196+
185197
def extract_metrics(
186198
self, usage: dict[str, int | dict[str, int]] | None
187199
) -> tuple[UsageMetrics, UsageMetrics]:
@@ -194,15 +206,14 @@ def extract_metrics(
194206
if not usage:
195207
return UsageMetrics(), UsageMetrics()
196208

197-
input_details = cast("dict[str, int]", usage.get("prompt_tokens_details", {}))
198-
output_details = cast(
199-
"dict[str, int]", usage.get("completion_tokens_details", {})
209+
input_details: dict[str, int] = usage.get("prompt_tokens_details", {}) or {}
210+
output_details: dict[str, int] = (
211+
usage.get("completion_tokens_details", {}) or {}
200212
)
201213

202214
return UsageMetrics(
203215
text_tokens=(
204-
input_details.get("prompt_tokens")
205-
or cast("int", usage.get("prompt_tokens"))
216+
input_details.get("prompt_tokens") or usage.get("prompt_tokens")
206217
),
207218
image_tokens=input_details.get("image_tokens"),
208219
video_tokens=input_details.get("video_tokens"),
@@ -211,7 +222,7 @@ def extract_metrics(
211222
), UsageMetrics(
212223
text_tokens=(
213224
output_details.get("completion_tokens")
214-
or cast("int", usage.get("completion_tokens"))
225+
or usage.get("completion_tokens")
215226
),
216227
image_tokens=output_details.get("image_tokens"),
217228
video_tokens=output_details.get("video_tokens"),
@@ -243,18 +254,15 @@ def compile_non_streaming(
243254
:param response: Complete API response containing choices and usage data
244255
:return: Standardized GenerationResponse with extracted content and metrics
245256
"""
246-
choices = cast("list[dict]", response.get("choices", []))
247-
usage = cast("dict[str, int | dict[str, int]]", response.get("usage", {}))
257+
choices, usage = self.extract_choices_and_usage(response)
248258
input_metrics, output_metrics = self.extract_metrics(usage)
249259

250260
return GenerationResponse(
251261
request_id=request.request_id,
252262
request_args=str(
253263
request.arguments.model_dump() if request.arguments else None
254264
),
255-
text=cast("dict", choices[0].get("message", {})).get("content", "")
256-
if choices
257-
else "",
265+
text=(choices[0].get("message", {}).get("content", "") if choices else ""),
258266
input_metrics=input_metrics,
259267
output_metrics=output_metrics,
260268
)
@@ -269,27 +277,17 @@ def add_streaming_line(self, line: str) -> int | None:
269277
:param line: Raw SSE line from the streaming response
270278
:return: 1 if content was extracted, 0 if line ignored, None if done
271279
"""
272-
if line == "data: [DONE]":
273-
return None
280+
if not (data := self.extract_line_data(line)):
281+
return None if data is None else 0
274282

275-
if not line or not (line := line.strip()) or not line.startswith("data:"):
276-
return 0
277-
278-
line = line[len("data:") :].strip()
279-
data = cast(
280-
"dict[str, Any]",
281-
json.loads(line) if orjson is None else orjson.loads(line),
282-
)
283283
updated = False
284+
choices, usage = self.extract_choices_and_usage(data)
284285

285-
# Extract delta content for chat completion chunks
286-
if choices := cast("list[dict]", data.get("choices")):
287-
delta = choices[0].get("delta", {})
288-
if content := delta.get("content"):
289-
self.streaming_texts.append(content)
286+
if choices and (content := choices[0].get("delta", {}).get("content")):
287+
self.streaming_texts.append(content)
290288
updated = True
291289

292-
if usage := cast("dict[str, int | dict[str, int]]", data.get("usage")):
290+
if usage:
293291
self.streaming_usage = usage
294292

295293
return 1 if updated else 0
@@ -355,10 +353,10 @@ def compile_non_streaming(
355353
:param response: Complete API response containing text and usage data
356354
:return: Standardized GenerationResponse with extracted text and metrics
357355
"""
358-
usage = cast("dict[str, int]", response.get("usage", {}))
359-
input_details = cast("dict[str, int]", usage.get("input_token_details", {}))
360-
output_details = cast("dict[str, int]", usage.get("output_token_details", {}))
361-
text = response.get("text", "")
356+
usage: dict[str, int | dict[str, int]] = response.get("usage", {})
357+
input_details: dict[str, int] = usage.get("input_token_details", {}) or {}
358+
output_details: dict[str, int] = usage.get("output_token_details", {}) or {}
359+
text: str = response.get("text", "")
362360

363361
return GenerationResponse(
364362
request_id=request.request_id,
@@ -396,17 +394,16 @@ def add_streaming_line(self, line: str) -> int | None:
396394
if not line or not (line := line.strip()) or not line.startswith("{"):
397395
return 0
398396

399-
data = cast(
400-
"dict[str, Any]",
401-
json.loads(line) if orjson is None else orjson.loads(line),
402-
)
397+
data: dict[str, Any] = json.loads(line)
398+
text: str
399+
usage: dict[str, int | dict[str, int]]
403400
updated = False
404401

405402
if text := data.get("text"):
406403
self.streaming_texts.append(text)
407404
updated = True
408405

409-
if usage := cast("dict[str, int | dict[str, int]]", data.get("usage")):
406+
if usage := data.get("usage"):
410407
self.streaming_usage = usage
411408

412409
return 1 if updated else 0
@@ -445,22 +442,15 @@ def extract_metrics(
445442
if not usage:
446443
return UsageMetrics(), UsageMetrics()
447444

448-
input_details = cast("dict[str, int]", usage.get("input_token_details", {}))
449-
output_details = cast("dict[str, int]", usage.get("output_token_details", {}))
445+
input_details: dict[str, int] = usage.get("input_token_details", {}) or {}
446+
output_details: dict[str, int] = usage.get("output_token_details", {}) or {}
450447

451448
return UsageMetrics(
452-
text_tokens=(
453-
input_details.get("text_tokens")
454-
or cast("int", usage.get("input_tokens"))
455-
),
449+
text_tokens=(input_details.get("text_tokens") or usage.get("input_tokens")),
456450
audio_tokens=(
457-
input_details.get("audio_tokens")
458-
or cast("int", usage.get("audio_tokens"))
459-
),
460-
audio_seconds=(
461-
input_details.get("seconds") or cast("int", usage.get("seconds"))
451+
input_details.get("audio_tokens") or usage.get("audio_tokens")
462452
),
453+
audio_seconds=(input_details.get("seconds") or usage.get("seconds")),
463454
), UsageMetrics(
464-
text_tokens=output_details.get("text_tokens")
465-
or cast("int", usage.get("output_tokens")),
455+
text_tokens=output_details.get("text_tokens") or usage.get("output_tokens"),
466456
)

src/guidellm/benchmark/entrypoints.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
from typing import Any, Literal
66

77
from torch.utils.data import Sampler
8-
from transformers import PreTrainedTokenizerBase
9-
from typing_extensions import TypeAliasType
108

119
from guidellm.backends import Backend, BackendType
1210
from guidellm.benchmark.benchmarker import Benchmarker
1311
from guidellm.benchmark.output import GenerativeBenchmarkerOutput
1412
from guidellm.benchmark.profile import Profile, ProfileType
15-
from guidellm.benchmark.progress import BenchmarkerProgress, BenchmarkerProgressGroup
13+
from guidellm.benchmark.progress import BenchmarkerProgressGroup
1614
from guidellm.benchmark.schemas import GenerativeBenchmark, GenerativeBenchmarksReport
15+
from guidellm.benchmark.types import OutputFormatT, ProcessorInputT, ProgressInputT
1716
from guidellm.data import (
1817
DataLoader,
1918
DatasetPreprocessor,
@@ -40,20 +39,6 @@
4039

4140
_CURRENT_WORKING_DIR = Path.cwd()
4241

43-
OutputFormatT = TypeAliasType(
44-
"OutputFormatT",
45-
tuple[str, ...]
46-
| list[str]
47-
| dict[str, str | dict[str, Any] | GenerativeBenchmarkerOutput]
48-
| None,
49-
)
50-
51-
ProcessorInputT = TypeAliasType("ProcessorInputT", str | Path | PreTrainedTokenizerBase)
52-
53-
ProgressInputT = TypeAliasType(
54-
"ProgressInputT", tuple[str, ...] | list[str] | list[BenchmarkerProgress]
55-
)
56-
5742

5843
# Helper Functions
5944

src/guidellm/benchmark/types.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,7 @@
99
from guidellm.benchmark.output import GenerativeBenchmarkerOutput
1010
from guidellm.benchmark.progress import BenchmarkerProgress
1111

12-
__all__ = [
13-
"AggregatorInputT",
14-
"DataInputT",
15-
"OutputFormatT",
16-
"ProcessorInputT",
17-
"ProgressInputT",
18-
]
12+
__all__ = ["OutputFormatT", "ProcessorInputT", "ProgressInputT"]
1913

2014

2115
OutputFormatT = TypeAliasType(

src/guidellm/data/deserializers/synthetic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,9 @@ def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]:
209209

210210
# Create prefix list maintaining the correct distribution
211211
prefixes = []
212-
for bucket, weight in zip(self.config.prefix_buckets, unnorm_weights, strict=False):
212+
for bucket, weight in zip(
213+
self.config.prefix_buckets, unnorm_weights, strict=False
214+
):
213215
bucket_prefixes = [
214216
self._create_prompt(bucket.prefix_tokens, faker)
215217
for _ in range(bucket.prefix_count)

src/guidellm/data/preprocessors/mappers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,16 @@ def datasets_mappings(
120120
for index, dataset in enumerate(datasets)
121121
}
122122

123+
# Parse out user mappings that were passed in and validate them
124+
# Must be in the format of:
125+
# {<column_type>: [<column_names>]}
126+
# where <column_names> can be a single string or list of strings
127+
# and each string can be any of:
128+
# - a column name (assumes the first dataset was intended)
129+
# - <int>.<column_name> where <int> is the dataset index
130+
# - <str>.<column_name> where <str> is the dataset name
123131
for column_type, names in input_mappings.items():
124132
mappings[column_type] = []
125-
126133
for name in names if isinstance(names, list) else [names]:
127134
if "." in name:
128135
dataset, column_name = name.split(".", 1)

src/guidellm/scheduler/worker.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -232,18 +232,18 @@ async def _processing_startup(self):
232232
self.backend_started = True
233233
await self.backend.validate()
234234

235-
# Wait for all processes to be ready
236-
await wait_for_sync_barrier(
237-
self.startup_barrier,
238-
poll_interval=self.messaging.poll_interval,
239-
)
240-
241235
# Get messaging system ready
242236
await self.messaging.start(
243237
receive_stop_criteria=[self.requests_generated_event]
244238
)
245239
self.messaging_started = True
246240

241+
# Wait for all processes to be ready
242+
await wait_for_sync_barrier(
243+
self.startup_barrier,
244+
poll_interval=self.messaging.poll_interval,
245+
)
246+
247247
self.startup_completed = True
248248

249249
async def _processing_shutdown(self):

src/guidellm/settings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class LoggingSettings(BaseModel):
4646

4747
disabled: bool = False
4848
clear_loggers: bool = True
49-
console_log_level: str = "DEBUG"
49+
console_log_level: str = "WARNING"
5050
log_file: str | None = None
5151
log_file_level: str | None = None
5252

@@ -145,7 +145,7 @@ class Settings(BaseSettings):
145145
mp_max_pending_buffer_percent: float = 0.5
146146
mp_max_worker_buffer_percent: float = 0.2
147147
max_concurrency: int = 512
148-
max_worker_processes: int = 2
148+
max_worker_processes: int = 10
149149
scheduler_start_delay_non_distributed: float = 1.0
150150
constraint_error_window_size: float = 30
151151
constraint_error_min_processed: float = 30

0 commit comments

Comments (0)