Skip to content

Commit 5e6db7e

Browse files
committed
add typing and fix ruff style
1 parent e8a1825 commit 5e6db7e

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

segmentation_models_pytorch/decoders/pan/decoder.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from collections.abc import Sequence
2+
from typing import Literal
3+
14
import torch
25
import torch.nn as nn
36
import torch.nn.functional as F
@@ -44,7 +47,9 @@ def forward(self, x):
4447

4548

4649
class FPABlock(nn.Module):
47-
def __init__(self, in_channels, out_channels, upscale_mode="bilinear"):
50+
def __init__(
51+
self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear"
52+
):
4853
super(FPABlock, self).__init__()
4954

5055
self.upscale_mode = upscale_mode
@@ -118,7 +123,9 @@ def forward(self, x):
118123
mid = self.mid(x)
119124
x1 = self.down1(x)
120125
x2 = self.down2(x1)
126+
print(x2.shape)
121127
x3 = self.down3(x2)
128+
print(x3.shape)
122129
x3 = F.interpolate(x3, size=(h // 4, w // 4), **upscale_parameters)
123130

124131
x2 = self.conv2(x2)
@@ -176,9 +183,9 @@ def forward(self, x, y):
176183
class PANDecoder(nn.Module):
177184
def __init__(
178185
self,
179-
encoder_channels,
180-
encoder_depth,
181-
decoder_channels,
186+
encoder_channels: Sequence[int],
187+
encoder_depth: Literal[3, 4, 5],
188+
decoder_channels: int,
182189
upscale_mode: str = "bilinear",
183190
):
184191
super().__init__()
@@ -197,11 +204,14 @@ def __init__(
197204
)
198205

199206
for i in range(1, len(encoder_channels)):
200-
self.add_module(f"gau{len(encoder_channels)-i}", GAUBlock(
201-
in_channels=encoder_channels[i],
202-
out_channels=decoder_channels,
203-
upscale_mode=upscale_mode,
204-
))
207+
self.add_module(
208+
f"gau{len(encoder_channels)-i}",
209+
GAUBlock(
210+
in_channels=encoder_channels[i],
211+
out_channels=decoder_channels,
212+
upscale_mode=upscale_mode,
213+
),
214+
)
205215

206216
def forward(self, *features):
207217
features = features[2:] # remove first and second skip

segmentation_models_pytorch/decoders/pan/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional, Union
1+
from typing import Any, Callable, Literal, Optional, Union
22

33
from segmentation_models_pytorch.base import (
44
ClassificationHead,
@@ -56,13 +56,13 @@ class PAN(SegmentationModel):
5656
def __init__(
5757
self,
5858
encoder_name: str = "resnet34",
59-
encoder_depth: int = 5,
59+
encoder_depth: Literal[3, 4, 5] = 5,
6060
encoder_weights: Optional[str] = "imagenet",
61-
encoder_output_stride: int = 16,
61+
encoder_output_stride: Literal[16, 32] = 16,
6262
decoder_channels: int = 32,
6363
in_channels: int = 3,
6464
classes: int = 1,
65-
activation: Optional[Union[str, callable]] = None,
65+
activation: Optional[Union[str, Callable]] = None,
6666
upsampling: int = 4,
6767
aux_params: Optional[dict] = None,
6868
**kwargs: dict[str, Any],

0 commit comments

Comments
 (0)