Skip to content

Commit 330e6e5

Browse files
committed
Fix ruff style
1 parent 8b0fece commit 330e6e5

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

segmentation_models_pytorch/encoders/timm_universal.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
11
"""
2-
TimmUniversalEncoder provides a unified feature extraction interface built on the
3-
`timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style
2+
TimmUniversalEncoder provides a unified feature extraction interface built on the
3+
`timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style
44
models (e.g., Swin Transformer, ConvNeXt).
55
6-
This encoder produces consistent multi-level feature maps for semantic segmentation tasks.
7-
It allows configuring the number of feature extraction stages (`depth`) and adjusting
6+
This encoder produces consistent multi-level feature maps for semantic segmentation tasks.
7+
It allows configuring the number of feature extraction stages (`depth`) and adjusting
88
`output_stride` when supported.
99
1010
Key Features:
1111
- Flexible model selection using `timm.create_model`.
12-
- Unified multi-level output across different model hierarchies.
12+
- Unified multi-level output across different model hierarchies.
1313
- Automatic alignment for inconsistent feature scales:
14-
- Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale.
15-
- VGG-style models (include scale-1 features): Align outputs for compatibility.
14+
- Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale.
15+
- VGG-style models (include scale-1 features): Align outputs for compatibility.
1616
- Easy access to feature scale information via the `reduction` property.
1717
1818
Feature Scale Differences:
19-
- Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32.
20-
- Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale.
19+
- Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32.
20+
- Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale.
2121
- VGG-style models: Include scale-1 features (input resolution).
2222
2323
Notes:
24-
- `output_stride` is unsupported in some models, especially transformer-based architectures.
25-
- Special handling for models like TResNet and DLA to ensure correct feature indexing.
26-
- VGG-style models use `_is_skip_first` to align scale-1 features with standard outputs.
24+
- `output_stride` is unsupported in some models, especially transformer-based architectures.
25+
- Special handling for models like TResNet and DLA to ensure correct feature indexing.
26+
- VGG-style models use `_is_vgg_style` to align scale-1 features with standard outputs.
2727
"""
2828

2929
from typing import Any
@@ -35,7 +35,7 @@
3535

3636
class TimmUniversalEncoder(nn.Module):
3737
"""
38-
A universal encoder leveraging the `timm` library for feature extraction from
38+
A universal encoder leveraging the `timm` library for feature extraction from
3939
various model architectures, including traditional-style and transformer-style models.
4040
4141
Features:
@@ -92,15 +92,15 @@ def __init__(
9292
if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]:
9393
# Transformer-style downsampling: scales (4, 8, 16, 32)
9494
self._is_transformer_style = True
95-
self._is_skip_first = False
95+
self._is_vgg_style = False
9696
elif reduction_scales == [2 ** (i + 1) for i in range(encoder_stage)]:
9797
# Traditional-style downsampling: scales (2, 4, 8, 16, 32)
9898
self._is_transformer_style = False
99-
self._is_skip_first = False
100-
elif reduction_scales == [2 ** i for i in range(encoder_stage)]:
101-
# Models including scale 1: scales (1, 2, 4, 8, 16, 32)
99+
self._is_vgg_style = False
100+
elif reduction_scales == [2**i for i in range(encoder_stage)]:
101+
# Vgg-style models including scale 1: scales (1, 2, 4, 8, 16, 32)
102102
self._is_transformer_style = False
103-
self._is_skip_first = True
103+
self._is_vgg_style = True
104104
else:
105105
raise ValueError("Unsupported model downsampling pattern.")
106106

@@ -125,14 +125,14 @@ def __init__(
125125
if "dla" in name:
126126
# For 'dla' models, out_indices starts at 0 and matches the input size.
127127
common_kwargs["out_indices"] = tuple(range(1, depth + 1))
128-
if self._is_skip_first:
128+
if self._is_vgg_style:
129129
common_kwargs["out_indices"] = tuple(range(depth + 1))
130130

131131
self.model = timm.create_model(
132132
name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
133133
)
134134

135-
if self._is_skip_first:
135+
if self._is_vgg_style:
136136
self._out_channels = self.model.feature_info.channels()
137137
else:
138138
self._out_channels = [in_channels] + self.model.feature_info.channels()
@@ -164,9 +164,9 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
164164
B, _, H, W = x.shape
165165
dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)
166166
features = [dummy] + features
167-
168-
# Add input tensor as scale 1 feature if `self._is_skip_first` is False
169-
if not self._is_skip_first:
167+
168+
# Add input tensor as scale 1 feature if `self._is_vgg_style` is False
169+
if not self._is_vgg_style:
170170
features = [x] + features
171171

172172
return features

0 commit comments

Comments
 (0)