|
1 | | -from typing import Any, List, Optional, Union |
| 1 | +from typing import Any, Optional, Union, Tuple, Callable |
2 | 2 |
|
3 | 3 | from segmentation_models_pytorch.base import ( |
4 | 4 | ClassificationHead, |
5 | 5 | SegmentationHead, |
6 | 6 | SegmentationModel, |
7 | 7 | ) |
8 | 8 | from segmentation_models_pytorch.encoders import get_encoder |
| 9 | +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading |
9 | 10 |
|
10 | 11 | from .decoder import UnetDecoder |
11 | 12 |
|
@@ -55,17 +56,18 @@ class Unet(SegmentationModel): |
55 | 56 |
|
56 | 57 | """ |
57 | 58 |
|
| 59 | + @supports_config_loading |
58 | 60 | def __init__( |
59 | 61 | self, |
60 | 62 | encoder_name: str = "resnet34", |
61 | 63 | encoder_depth: int = 5, |
62 | 64 | encoder_weights: Optional[str] = "imagenet", |
63 | 65 | decoder_use_batchnorm: bool = True, |
64 | | - decoder_channels: List[int] = (256, 128, 64, 32, 16), |
| 66 | + decoder_channels: Tuple[int, ...] = (256, 128, 64, 32, 16), |
65 | 67 | decoder_attention_type: Optional[str] = None, |
66 | 68 | in_channels: int = 3, |
67 | 69 | classes: int = 1, |
68 | | - activation: Optional[Union[str, callable]] = None, |
| 70 | + activation: Optional[Union[str, Callable]] = None, |
69 | 71 | aux_params: Optional[dict] = None, |
70 | 72 | **kwargs: dict[str, Any], |
71 | 73 | ): |
|
0 commit comments