Skip to content

Commit c84d160

Browse files
committed
Fix type errors in utils package
Signed-off-by: Jared O'Connell <[email protected]>
1 parent 11d585c commit c84d160

File tree

7 files changed

+181
-133
lines changed

7 files changed

+181
-133
lines changed

src/guidellm/utils/console.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def print_update_details(self, details: Any | None):
155155
block = Padding(
156156
Text.from_markup(str(details)),
157157
(0, 0, 0, 2),
158-
style=StatusStyles.get("debug"),
158+
style=StatusStyles.get("debug", "dim"),
159159
)
160160
self.print(block)
161161

src/guidellm/utils/encoding.py

Lines changed: 45 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212

1313
import json
1414
from collections.abc import Mapping
15-
from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar
15+
from typing import Annotated, Any, cast, ClassVar, Generic, Literal, Optional, TypeVar
1616

1717
try:
18-
import msgpack
18+
import msgpack # type: ignore[import-untyped] # Optional dependency
1919
from msgpack import Packer, Unpacker
2020

2121
HAS_MSGPACK = True
@@ -24,16 +24,16 @@
2424
HAS_MSGPACK = False
2525

2626
try:
27-
from msgspec.msgpack import Decoder as MsgspecDecoder
28-
from msgspec.msgpack import Encoder as MsgspecEncoder
27+
from msgspec.msgpack import Decoder as MsgspecDecoder # type: ignore[import-not-found] # Optional dependency
28+
from msgspec.msgpack import Encoder as MsgspecEncoder # type: ignore[import-not-found] # Optional dependency
2929

3030
HAS_MSGSPEC = True
3131
except ImportError:
3232
MsgspecDecoder = MsgspecEncoder = None
3333
HAS_MSGSPEC = False
3434

3535
try:
36-
import orjson
36+
import orjson # type: ignore[import-not-found] # Optional dependency
3737

3838
HAS_ORJSON = True
3939
except ImportError:
@@ -116,7 +116,7 @@ def encode_message(
116116
"""
117117
serialized = serializer.serialize(obj) if serializer else obj
118118

119-
return encoder.encode(serialized) if encoder else serialized
119+
return cast(MsgT, encoder.encode(serialized) if encoder else serialized)
120120

121121
@classmethod
122122
def decode_message(
@@ -137,7 +137,7 @@ def decode_message(
137137
"""
138138
serialized = encoder.decode(message) if encoder else message
139139

140-
return serializer.deserialize(serialized) if serializer else serialized
140+
return cast(ObjT, serializer.deserialize(serialized) if serializer else serialized)
141141

142142
def __init__(
143143
self,
@@ -296,6 +296,8 @@ def _get_available_encoder_decoder(
296296
return None, None, None
297297

298298

299+
PayloadType = Literal['pydantic', 'python', 'collection_tuple', 'collection_sequence', 'collection_mapping']
300+
299301
class Serializer:
300302
"""
301303
Object serialization with specialized Pydantic model support.
@@ -474,6 +476,7 @@ def to_sequence(self, obj: Any) -> str | Any:
474476
:param obj: Object to serialize to sequence format
475477
:return: Serialized sequence string or bytes
476478
"""
479+
payload_type: PayloadType
477480
if isinstance(obj, BaseModel):
478481
payload_type = "pydantic"
479482
payload = self.to_sequence_pydantic(obj)
@@ -515,7 +518,7 @@ def to_sequence(self, obj: Any) -> str | Any:
515518
payload_type = "python"
516519
payload = self.to_sequence_python(obj)
517520

518-
return self.pack_next_sequence(payload_type, payload, None)
521+
return self.pack_next_sequence(payload_type, payload if payload is not None else "", None)
519522

520523
def from_sequence(self, data: str | Any) -> Any: # noqa: C901, PLR0912
521524
"""
@@ -529,6 +532,7 @@ def from_sequence(self, data: str | Any) -> Any: # noqa: C901, PLR0912
529532
:raises ValueError: If sequence format is invalid or contains multiple
530533
packed sequences
531534
"""
535+
payload: str | bytes | None
532536
type_, payload, remaining = self.unpack_next_sequence(data)
533537
if remaining is not None:
534538
raise ValueError("Data contains multiple packed sequences; expected one.")
@@ -540,16 +544,16 @@ def from_sequence(self, data: str | Any) -> Any: # noqa: C901, PLR0912
540544
return self.from_sequence_python(payload)
541545

542546
if type_ in {"collection_sequence", "collection_tuple"}:
543-
items = []
547+
c_items = []
544548
while payload:
545549
type_, item_payload, payload = self.unpack_next_sequence(payload)
546550
if type_ == "pydantic":
547-
items.append(self.from_sequence_pydantic(item_payload))
551+
c_items.append(self.from_sequence_pydantic(item_payload))
548552
elif type_ == "python":
549-
items.append(self.from_sequence_python(item_payload))
553+
c_items.append(self.from_sequence_python(item_payload))
550554
else:
551555
raise ValueError("Invalid type in collection sequence")
552-
return items
556+
return c_items
553557

554558
if type_ != "collection_mapping":
555559
raise ValueError(f"Invalid type for mapping sequence: {type_}")
@@ -604,6 +608,7 @@ def from_sequence_pydantic(self, data: str | bytes) -> BaseModel:
604608
:param data: Sequence data containing class metadata and JSON
605609
:return: Reconstructed Pydantic model instance
606610
"""
611+
json_data: str | bytes | bytearray
607612
if isinstance(data, bytes):
608613
class_name_end = data.index(b"|")
609614
class_name = data[:class_name_end].decode()
@@ -647,13 +652,7 @@ def from_sequence_python(self, data: str | bytes) -> Any:
647652

648653
def pack_next_sequence( # noqa: C901, PLR0912
649654
self,
650-
type_: Literal[
651-
"pydantic",
652-
"python",
653-
"collection_tuple",
654-
"collection_sequence",
655-
"collection_mapping",
656-
],
655+
type_: PayloadType,
657656
payload: str | bytes,
658657
current: str | bytes | None,
659658
) -> str | bytes:
@@ -672,9 +671,11 @@ def pack_next_sequence( # noqa: C901, PLR0912
672671
raise ValueError("Payload and current must be of the same type")
673672

674673
payload_len = len(payload)
675-
674+
payload_len_output: str | bytes
675+
payload_type: str | bytes
676+
delimiter: str | bytes
676677
if isinstance(payload, bytes):
677-
payload_len = payload_len.to_bytes(
678+
payload_len_output = payload_len.to_bytes(
678679
length=(payload_len.bit_length() + 7) // 8 if payload_len > 0 else 1,
679680
byteorder="big",
680681
)
@@ -692,7 +693,7 @@ def pack_next_sequence( # noqa: C901, PLR0912
692693
raise ValueError(f"Unknown type for packing: {type_}")
693694
delimiter = b"|"
694695
else:
695-
payload_len = str(payload_len)
696+
payload_len_output = str(payload_len)
696697
if type_ == "pydantic":
697698
payload_type = "P"
698699
elif type_ == "python":
@@ -707,20 +708,14 @@ def pack_next_sequence( # noqa: C901, PLR0912
707708
raise ValueError(f"Unknown type for packing: {type_}")
708709
delimiter = "|"
709710

710-
next_sequence = payload_type + delimiter + payload_len + delimiter + payload
711-
712-
return current + next_sequence if current else next_sequence
711+
# Type ignores because types are enforced at runtime
712+
next_sequence = payload_type + delimiter + payload_len_output + delimiter + payload # type: ignore[operator]
713+
return current + next_sequence if current else next_sequence # type: ignore[operator]
713714

714715
def unpack_next_sequence( # noqa: C901, PLR0912
715716
self, data: str | bytes
716717
) -> tuple[
717-
Literal[
718-
"pydantic",
719-
"python",
720-
"collection_tuple",
721-
"collection_sequence",
722-
"collection_mapping",
723-
],
718+
PayloadType,
724719
str | bytes,
725720
str | bytes | None,
726721
]:
@@ -731,57 +726,58 @@ def unpack_next_sequence( # noqa: C901, PLR0912
731726
:return: Tuple of (type, payload, remaining_data)
732727
:raises ValueError: If sequence format is invalid or unknown type character
733728
"""
729+
type_: PayloadType
734730
if isinstance(data, bytes):
735731
if len(data) < len(b"T|N") or data[1:2] != b"|":
736732
raise ValueError("Invalid packed data format")
737733

738-
type_char = data[0:1]
739-
if type_char == b"P":
734+
type_char_b = data[0:1]
735+
if type_char_b == b"P":
740736
type_ = "pydantic"
741-
elif type_char == b"p":
737+
elif type_char_b == b"p":
742738
type_ = "python"
743-
elif type_char == b"T":
739+
elif type_char_b == b"T":
744740
type_ = "collection_tuple"
745-
elif type_char == b"S":
741+
elif type_char_b == b"S":
746742
type_ = "collection_sequence"
747-
elif type_char == b"M":
743+
elif type_char_b == b"M":
748744
type_ = "collection_mapping"
749745
else:
750746
raise ValueError("Unknown type character in packed data")
751747

752748
len_end = data.index(b"|", 2)
753749
payload_len = int.from_bytes(data[2:len_end], "big")
754-
payload = data[len_end + 1 : len_end + 1 + payload_len]
755-
remaining = (
750+
payload_b = data[len_end + 1 : len_end + 1 + payload_len]
751+
remaining_b = (
756752
data[len_end + 1 + payload_len :]
757753
if len_end + 1 + payload_len < len(data)
758754
else None
759755
)
760756

761-
return type_, payload, remaining
757+
return type_, payload_b, remaining_b
762758

763759
if len(data) < len("T|N") or data[1] != "|":
764760
raise ValueError("Invalid packed data format")
765761

766-
type_char = data[0]
767-
if type_char == "P":
762+
type_char_s = data[0]
763+
if type_char_s == "P":
768764
type_ = "pydantic"
769-
elif type_char == "p":
765+
elif type_char_s == "p":
770766
type_ = "python"
771-
elif type_char == "S":
767+
elif type_char_s == "S":
772768
type_ = "collection_sequence"
773-
elif type_char == "M":
769+
elif type_char_s == "M":
774770
type_ = "collection_mapping"
775771
else:
776772
raise ValueError("Unknown type character in packed data")
777773

778774
len_end = data.index("|", 2)
779775
payload_len = int(data[2:len_end])
780-
payload = data[len_end + 1 : len_end + 1 + payload_len]
781-
remaining = (
776+
payload_s = data[len_end + 1 : len_end + 1 + payload_len]
777+
remaining_s = (
782778
data[len_end + 1 + payload_len :]
783779
if len_end + 1 + payload_len < len(data)
784780
else None
785781
)
786782

787-
return type_, payload, remaining
783+
return type_, payload_s, remaining_s

src/guidellm/utils/functions.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,19 +96,20 @@ def safe_add(
9696
if not values:
9797
return default
9898

99-
values = list(values)
99+
values_list = list(values)
100100

101101
if signs is None:
102-
signs = [1] * len(values)
102+
signs = [1] * len(values_list)
103103

104-
if len(signs) != len(values):
104+
if len(signs) != len(values_list):
105105
raise ValueError("Length of signs must match length of values")
106106

107-
result = values[0] if values[0] is not None else default
107+
result = values_list[0] if values_list[0] is not None else default
108108

109-
for ind in range(1, len(values)):
110-
val = values[ind] if values[ind] is not None else default
111-
result += signs[ind] * val
109+
for ind in range(1, len(values_list)):
110+
value = values_list[ind]
111+
checked_value = value if value is not None else default
112+
result += signs[ind] * checked_value
112113

113114
return result
114115

0 commit comments

Comments
 (0)