Commit f07e107

Fix ruff style and typing

1 parent 363a361 commit f07e107

File tree

1 file changed (+7, −9 lines)


segmentation_models_pytorch/encoders/timm_universal.py

Lines changed: 7 additions & 9 deletions
@@ -1,7 +1,7 @@
 """
 TimmUniversalEncoder provides a unified feature extraction interface built on the
 `timm` library, supporting various backbone architectures, including traditional
-CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy 
+CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy
 (e.g., Swin Transformer, ConvNeXt).
 
 This encoder produces standardized multi-level feature maps, facilitating integration
@@ -22,16 +22,16 @@
 - Traditional CNNs (e.g., ResNet) typically provide features at 1/2, 1/4, 1/8, 1/16,
   and 1/32 scales.
 - Transformer-style or next-generation models (e.g., Swin Transformer, ConvNeXt) often
-  start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale 
+  start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale
   feature. TimmUniversalEncoder compensates for this omission to ensure a unified
   multi-stage output.
 
 Notes:
-- Not all models support modifying `output_stride` (especially transformer-based or 
+- Not all models support modifying `output_stride` (especially transformer-based or
   transformer-like models).
 - Certain models (e.g., TResNet, DLA) require special handling to ensure correct
   feature indexing.
-- Most `timm` models output features in (B, C, H, W) format. However, some 
+- Most `timm` models output features in (B, C, H, W) format. However, some
   (e.g., MambaOut and certain Swin/SwinV2 variants) use (B, H, W, C) format, which is
   currently unsupported.
 """
@@ -46,7 +46,7 @@
 class TimmUniversalEncoder(nn.Module):
     """
     A universal encoder built on the `timm` library, designed to adapt to a wide variety of
-    model architectures, including both traditional CNNs and those that follow a 
+    model architectures, including both traditional CNNs and those that follow a
     transformer-like hierarchy.
 
     Features:
@@ -94,10 +94,8 @@ def __init__(
         # Determine if this model uses a transformer-like hierarchy (i.e., starting at 1/4 scale)
         # rather than a traditional CNN hierarchy (starting at 1/2 scale).
         if len(self.model.feature_info.channels()) == 5:
-            # This indicates a traditional hierarchy: (1/2, 1/4, 1/8, 1/16, 1/32)
             self._is_transformer_style = False
         else:
-            # This indicates a transformer-like hierarchy: (1/4, 1/8, 1/16, 1/32)
             self._is_transformer_style = True
 
         if self._is_transformer_style:
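The hunk above keys the hierarchy check on how many feature stages timm reports via `feature_info.channels()`. A standalone sketch of that predicate (hypothetical helper name; the channel lists in the usage note are typical ResNet-50 / Swin-T values, assumed here for illustration):

```python
# Hypothetical helper mirroring the check in the diff: timm's
# feature_info.channels() returns one entry per feature stage, and the
# encoder treats exactly five stages as a traditional CNN pyramid.
def is_transformer_style(stage_channels: list[int]) -> bool:
    """True when the backbone omits the 1/2-scale stage (not 5 stages)."""
    return len(stage_channels) != 5
```

For example, `is_transformer_style([64, 256, 512, 1024, 2048])` is False for a ResNet-50-like pyramid, while `is_transformer_style([96, 192, 384, 768])` is True for a Swin-T-like one.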
@@ -138,7 +136,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
             x (torch.Tensor): Input tensor of shape (B, C, H, W).
 
         Returns:
-            List[torch.Tensor]: A list of feature maps extracted at various scales.
+            list[torch.Tensor]: A list of feature maps extracted at various scales.
         """
         features = self.model(x)
 
@@ -158,7 +156,7 @@
     def out_channels(self) -> list[int]:
         """
         Returns:
-            List[int]: A list of output channels for each stage of the encoder,
+            list[int]: A list of output channels for each stage of the encoder,
             including the input channels at the first stage.
         """
         return self._out_channels
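The `List` → `list` edits in the last two hunks are the typing half of the commit: ruff's pyupgrade rules rewrite `typing.List[...]` to the PEP 585 builtin generic `list[...]`, which is valid in annotations on Python 3.9+. A small before/after sketch (the function names are illustrative, not from the diff):

```python
from typing import List  # needed only for the pre-commit spelling

def out_channels_before() -> List[int]:  # older typing.List form
    return [3, 64, 256]

def out_channels_after() -> list[int]:  # builtin generic, per PEP 585
    return [3, 64, 256]
```

Both annotations describe the same type; the builtin form simply drops the `typing` import, which is why the fix is a pure style/typing cleanup with no runtime change.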
