Skip to content

Commit 538c194

Browse files
[GuideLLM Refactor] fix util package types (#393)
## Summary

This PR fixes all type errors in the utils package. Only a few were ignored.

## Details

- Many of these changes reflect that values can be None and add the associated None checks.
- Others fix incorrect type annotations.
- Others use `cast` to assert a type that we know for certain is correct.
- Plus other minor changes.

## Test Plan

Run the tests and look through the changes to make sure the logic is equivalent to or better than the original code.

---

- [x] "I certify that all code in this PR is my own, except as noted below."

## Use of AI

- [x] Includes AI-assisted code completion
- [ ] Includes code generated by an AI application
- [ ] Includes AI-generated tests (NOTE: AI-written tests should have a docstring that includes `## WRITTEN BY AI ##`)
2 parents 0a64090 + bde3ae8 commit 538c194

File tree

11 files changed

+251
-143
lines changed

11 files changed

+251
-143
lines changed

src/guidellm/utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@
8181
"EndlessTextCreator",
8282
"InfoMixin",
8383
"IntegerRangeSampler",
84-
"camelize_str",
8584
"InterProcessMessaging",
8685
"InterProcessMessagingManagerQueue",
8786
"InterProcessMessagingPipe",
@@ -107,14 +106,15 @@
107106
"ThreadSafeSingletonMixin",
108107
"TimeRunningStats",
109108
"all_defined",
109+
"camelize_str",
110110
"check_load_processor",
111111
"clean_text",
112112
"filter_text",
113113
"format_value_display",
114114
"get_literal_vals",
115115
"is_punctuation",
116-
"recursive_key_update",
117116
"load_text",
117+
"recursive_key_update",
118118
"safe_add",
119119
"safe_divide",
120120
"safe_format_timestamp",

src/guidellm/utils/console.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def print_update_details(self, details: Any | None):
155155
block = Padding(
156156
Text.from_markup(str(details)),
157157
(0, 0, 0, 2),
158-
style=StatusStyles.get("debug"),
158+
style=StatusStyles.get("debug", "dim"),
159159
)
160160
self.print(block)
161161

src/guidellm/utils/encoding.py

Lines changed: 62 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212

1313
import json
1414
from collections.abc import Mapping
15-
from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar
15+
from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar, cast
1616

1717
try:
18-
import msgpack
18+
import msgpack # type: ignore[import-untyped] # Optional dependency
1919
from msgpack import Packer, Unpacker
2020

2121
HAS_MSGPACK = True
@@ -24,16 +24,20 @@
2424
HAS_MSGPACK = False
2525

2626
try:
27-
from msgspec.msgpack import Decoder as MsgspecDecoder
28-
from msgspec.msgpack import Encoder as MsgspecEncoder
27+
from msgspec.msgpack import ( # type: ignore[import-not-found] # Optional dependency
28+
Decoder as MsgspecDecoder,
29+
)
30+
from msgspec.msgpack import ( # type: ignore[import-not-found] # Optional dependency
31+
Encoder as MsgspecEncoder,
32+
)
2933

3034
HAS_MSGSPEC = True
3135
except ImportError:
3236
MsgspecDecoder = MsgspecEncoder = None
3337
HAS_MSGSPEC = False
3438

3539
try:
36-
import orjson
40+
import orjson # type: ignore[import-not-found] # Optional dependency
3741

3842
HAS_ORJSON = True
3943
except ImportError:
@@ -116,7 +120,7 @@ def encode_message(
116120
"""
117121
serialized = serializer.serialize(obj) if serializer else obj
118122

119-
return encoder.encode(serialized) if encoder else serialized
123+
return cast("MsgT", encoder.encode(serialized) if encoder else serialized)
120124

121125
@classmethod
122126
def decode_message(
@@ -137,7 +141,9 @@ def decode_message(
137141
"""
138142
serialized = encoder.decode(message) if encoder else message
139143

140-
return serializer.deserialize(serialized) if serializer else serialized
144+
return cast(
145+
"ObjT", serializer.deserialize(serialized) if serializer else serialized
146+
)
141147

142148
def __init__(
143149
self,
@@ -296,6 +302,15 @@ def _get_available_encoder_decoder(
296302
return None, None, None
297303

298304

305+
PayloadType = Literal[
306+
"pydantic",
307+
"python",
308+
"collection_tuple",
309+
"collection_sequence",
310+
"collection_mapping",
311+
]
312+
313+
299314
class Serializer:
300315
"""
301316
Object serialization with specialized Pydantic model support.
@@ -474,6 +489,7 @@ def to_sequence(self, obj: Any) -> str | Any:
474489
:param obj: Object to serialize to sequence format
475490
:return: Serialized sequence string or bytes
476491
"""
492+
payload_type: PayloadType
477493
if isinstance(obj, BaseModel):
478494
payload_type = "pydantic"
479495
payload = self.to_sequence_pydantic(obj)
@@ -515,7 +531,9 @@ def to_sequence(self, obj: Any) -> str | Any:
515531
payload_type = "python"
516532
payload = self.to_sequence_python(obj)
517533

518-
return self.pack_next_sequence(payload_type, payload, None)
534+
return self.pack_next_sequence(
535+
payload_type, payload if payload is not None else "", None
536+
)
519537

520538
def from_sequence(self, data: str | Any) -> Any: # noqa: C901, PLR0912
521539
"""
@@ -529,6 +547,7 @@ def from_sequence(self, data: str | Any) -> Any: # noqa: C901, PLR0912
529547
:raises ValueError: If sequence format is invalid or contains multiple
530548
packed sequences
531549
"""
550+
payload: str | bytes | None
532551
type_, payload, remaining = self.unpack_next_sequence(data)
533552
if remaining is not None:
534553
raise ValueError("Data contains multiple packed sequences; expected one.")
@@ -540,16 +559,16 @@ def from_sequence(self, data: str | Any) -> Any: # noqa: C901, PLR0912
540559
return self.from_sequence_python(payload)
541560

542561
if type_ in {"collection_sequence", "collection_tuple"}:
543-
items = []
562+
c_items = []
544563
while payload:
545564
type_, item_payload, payload = self.unpack_next_sequence(payload)
546565
if type_ == "pydantic":
547-
items.append(self.from_sequence_pydantic(item_payload))
566+
c_items.append(self.from_sequence_pydantic(item_payload))
548567
elif type_ == "python":
549-
items.append(self.from_sequence_python(item_payload))
568+
c_items.append(self.from_sequence_python(item_payload))
550569
else:
551570
raise ValueError("Invalid type in collection sequence")
552-
return items
571+
return c_items
553572

554573
if type_ != "collection_mapping":
555574
raise ValueError(f"Invalid type for mapping sequence: {type_}")
@@ -604,6 +623,7 @@ def from_sequence_pydantic(self, data: str | bytes) -> BaseModel:
604623
:param data: Sequence data containing class metadata and JSON
605624
:return: Reconstructed Pydantic model instance
606625
"""
626+
json_data: str | bytes | bytearray
607627
if isinstance(data, bytes):
608628
class_name_end = data.index(b"|")
609629
class_name = data[:class_name_end].decode()
@@ -647,13 +667,7 @@ def from_sequence_python(self, data: str | bytes) -> Any:
647667

648668
def pack_next_sequence( # noqa: C901, PLR0912
649669
self,
650-
type_: Literal[
651-
"pydantic",
652-
"python",
653-
"collection_tuple",
654-
"collection_sequence",
655-
"collection_mapping",
656-
],
670+
type_: PayloadType,
657671
payload: str | bytes,
658672
current: str | bytes | None,
659673
) -> str | bytes:
@@ -672,9 +686,11 @@ def pack_next_sequence( # noqa: C901, PLR0912
672686
raise ValueError("Payload and current must be of the same type")
673687

674688
payload_len = len(payload)
675-
689+
payload_len_output: str | bytes
690+
payload_type: str | bytes
691+
delimiter: str | bytes
676692
if isinstance(payload, bytes):
677-
payload_len = payload_len.to_bytes(
693+
payload_len_output = payload_len.to_bytes(
678694
length=(payload_len.bit_length() + 7) // 8 if payload_len > 0 else 1,
679695
byteorder="big",
680696
)
@@ -692,7 +708,7 @@ def pack_next_sequence( # noqa: C901, PLR0912
692708
raise ValueError(f"Unknown type for packing: {type_}")
693709
delimiter = b"|"
694710
else:
695-
payload_len = str(payload_len)
711+
payload_len_output = str(payload_len)
696712
if type_ == "pydantic":
697713
payload_type = "P"
698714
elif type_ == "python":
@@ -707,20 +723,16 @@ def pack_next_sequence( # noqa: C901, PLR0912
707723
raise ValueError(f"Unknown type for packing: {type_}")
708724
delimiter = "|"
709725

710-
next_sequence = payload_type + delimiter + payload_len + delimiter + payload
711-
712-
return current + next_sequence if current else next_sequence
726+
# Type ignores because types are enforced at runtime
727+
next_sequence = (
728+
payload_type + delimiter + payload_len_output + delimiter + payload # type: ignore[operator]
729+
)
730+
return current + next_sequence if current else next_sequence # type: ignore[operator]
713731

714732
def unpack_next_sequence( # noqa: C901, PLR0912
715733
self, data: str | bytes
716734
) -> tuple[
717-
Literal[
718-
"pydantic",
719-
"python",
720-
"collection_tuple",
721-
"collection_sequence",
722-
"collection_mapping",
723-
],
735+
PayloadType,
724736
str | bytes,
725737
str | bytes | None,
726738
]:
@@ -731,57 +743,58 @@ def unpack_next_sequence( # noqa: C901, PLR0912
731743
:return: Tuple of (type, payload, remaining_data)
732744
:raises ValueError: If sequence format is invalid or unknown type character
733745
"""
746+
type_: PayloadType
734747
if isinstance(data, bytes):
735748
if len(data) < len(b"T|N") or data[1:2] != b"|":
736749
raise ValueError("Invalid packed data format")
737750

738-
type_char = data[0:1]
739-
if type_char == b"P":
751+
type_char_b = data[0:1]
752+
if type_char_b == b"P":
740753
type_ = "pydantic"
741-
elif type_char == b"p":
754+
elif type_char_b == b"p":
742755
type_ = "python"
743-
elif type_char == b"T":
756+
elif type_char_b == b"T":
744757
type_ = "collection_tuple"
745-
elif type_char == b"S":
758+
elif type_char_b == b"S":
746759
type_ = "collection_sequence"
747-
elif type_char == b"M":
760+
elif type_char_b == b"M":
748761
type_ = "collection_mapping"
749762
else:
750763
raise ValueError("Unknown type character in packed data")
751764

752765
len_end = data.index(b"|", 2)
753766
payload_len = int.from_bytes(data[2:len_end], "big")
754-
payload = data[len_end + 1 : len_end + 1 + payload_len]
755-
remaining = (
767+
payload_b = data[len_end + 1 : len_end + 1 + payload_len]
768+
remaining_b = (
756769
data[len_end + 1 + payload_len :]
757770
if len_end + 1 + payload_len < len(data)
758771
else None
759772
)
760773

761-
return type_, payload, remaining
774+
return type_, payload_b, remaining_b
762775

763776
if len(data) < len("T|N") or data[1] != "|":
764777
raise ValueError("Invalid packed data format")
765778

766-
type_char = data[0]
767-
if type_char == "P":
779+
type_char_s = data[0]
780+
if type_char_s == "P":
768781
type_ = "pydantic"
769-
elif type_char == "p":
782+
elif type_char_s == "p":
770783
type_ = "python"
771-
elif type_char == "S":
784+
elif type_char_s == "S":
772785
type_ = "collection_sequence"
773-
elif type_char == "M":
786+
elif type_char_s == "M":
774787
type_ = "collection_mapping"
775788
else:
776789
raise ValueError("Unknown type character in packed data")
777790

778791
len_end = data.index("|", 2)
779792
payload_len = int(data[2:len_end])
780-
payload = data[len_end + 1 : len_end + 1 + payload_len]
781-
remaining = (
793+
payload_s = data[len_end + 1 : len_end + 1 + payload_len]
794+
remaining_s = (
782795
data[len_end + 1 + payload_len :]
783796
if len_end + 1 + payload_len < len(data)
784797
else None
785798
)
786799

787-
return type_, payload, remaining
800+
return type_, payload_s, remaining_s

src/guidellm/utils/functions.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,19 +96,20 @@ def safe_add(
9696
if not values:
9797
return default
9898

99-
values = list(values)
99+
values_list = list(values)
100100

101101
if signs is None:
102-
signs = [1] * len(values)
102+
signs = [1] * len(values_list)
103103

104-
if len(signs) != len(values):
104+
if len(signs) != len(values_list):
105105
raise ValueError("Length of signs must match length of values")
106106

107-
result = values[0] if values[0] is not None else default
107+
result = values_list[0] if values_list[0] is not None else default
108108

109-
for ind in range(1, len(values)):
110-
val = values[ind] if values[ind] is not None else default
111-
result += signs[ind] * val
109+
for ind in range(1, len(values_list)):
110+
value = values_list[ind]
111+
checked_value = value if value is not None else default
112+
result += signs[ind] * checked_value
112113

113114
return result
114115

0 commit comments

Comments
 (0)