1- from typing import Any , Optional
1+ from typing import Any , Optional , Union
22
33import timm
44import torch
@@ -15,17 +15,17 @@ class TimmViTEncoder(nn.Module):
1515 - Ensures consistent multi-level feature extraction across all ViT models.
1616 """
1717
18- _is_torch_scriptable = True
18+ _is_torch_scriptable = False
1919 _is_torch_exportable = True
20- _is_torch_compilable = True
20+ _is_torch_compilable = False
2121
2222 def __init__ (
2323 self ,
2424 name : str ,
2525 pretrained : bool = True ,
2626 in_channels : int = 3 ,
2727 depth : int = 4 ,
28- output_indices : Optional [list [int ] | int ] = None ,
28+ output_indices : Optional [Union [ list [int ], int ] ] = None ,
2929 ** kwargs : dict [str , Any ],
3030 ):
3131 """
@@ -49,16 +49,14 @@ def __init__(
4949 super ().__init__ ()
5050 self .name = name
5151
52- output_stride = kwargs .pop ("output_stride" ,None )
52+ output_stride = kwargs .pop ("output_stride" , None )
5353 if output_stride is not None :
54- raise ValueError (
55- "Dilated mode not supported, set output stride to None"
56- )
54+ raise ValueError ("Dilated mode not supported, set output stride to None" )
5755
5856 # Default model configuration for feature extraction
5957 common_kwargs = dict (
6058 in_chans = in_channels ,
61- features_only = True ,
59+ features_only = False ,
6260 pretrained = pretrained ,
6361 out_indices = tuple (range (depth )),
6462 )
@@ -76,6 +74,23 @@ def __init__(
7674 feature_info = tmp_model .feature_info
7775 model_num_blocks = len (feature_info )
7876
77+ if output_indices is not None :
78+ if isinstance (output_indices , int ):
79+ output_indices = list (output_indices )
80+
81+ for output_index in output_indices :
82+ if output_indices < 0 or output_indices > model_num_blocks :
83+ raise ValueError (
84+ f"Output indices for feature extraction should be greater than 0 and less \
85+ than the number of blocks in the model ({ model_num_blocks } ), got { output_index } "
86+ )
87+
88+ if len (output_indices ) != depth :
89+ raise ValueError (
90+ f"Length of output indices for feature extraction should be equal to the depth of the encoder\
91+ architecture, got output indices length - { len (output_indices )} , encoder depth - { depth } "
92+ )
93+
7994 if depth > model_num_blocks :
8095 raise ValueError (
8196 f"Depth of the encoder cannot exceed the number of blocks in the model \
@@ -87,9 +102,6 @@ def __init__(
87102 int ((model_num_blocks / 4 ) * index ) - 1 for index in range (1 , depth + 1 )
88103 ]
89104
90- if isinstance (output_indices ,int ):
91- output_indices = list (output_indices )
92-
93105 common_kwargs ["out_indices" ] = self .out_indices = output_indices
94106 feature_info_obj = timm .models .FeatureInfo (
95107 feature_info = feature_info , out_indices = output_indices
@@ -109,18 +121,16 @@ def __init__(
109121 self ._output_stride = reduction_scales [0 ]
110122
111123 if (
112- int (self ._output_stride ).bit_count ( ) != 1
124+ bin (self ._output_stride ).count ( "1" ) != 1
113125 and not allow_output_stride_not_power_of_two
114126 ):
115127 raise ValueError (
116128 f"Models with stride which is not a power of 2 are not supported, \
117129 got output stride { self ._output_stride } "
118130 )
119131
120- self .prefix_token_supported = getattr (tmp_model , "has_class_token" , False )
132+ self .cls_token_supported = getattr (tmp_model , "has_class_token" , False )
121133 self .num_prefix_tokens = getattr (tmp_model , "num_prefix_tokens" , 0 )
122- if self .prefix_token_supported :
123- common_kwargs ["features_only" ] = False
124134
125135 self .model = timm .create_model (
126136 name , ** _merge_kwargs_no_duplicates (common_kwargs , kwargs )
@@ -131,47 +141,40 @@ def __init__(
131141 self ._depth = depth
132142 self ._embed_dim = tmp_model .embed_dim
133143
def forward(
    self, x: torch.Tensor
) -> tuple[list[torch.Tensor], list[Optional[torch.Tensor]]]:
    """
    Forward pass to extract multi-stage features.

    Args:
        x (torch.Tensor): Input tensor of shape (B, C, H, W).

    Returns:
        tuple[list[torch.Tensor], list[Optional[torch.Tensor]]]:
            ``(features, cls_tokens)``. ``features`` holds one entry per
            requested intermediate output. ``cls_tokens`` holds the matching
            token for each entry, or ``None`` placeholders (one per entry of
            ``self.out_indices``) when the model has no prefix tokens or
            ``self.cls_token_supported`` is false.
    """
    intermediate_outputs = self.model.forward_intermediates(
        x,
        indices=self.out_indices,
        return_prefix_tokens=True,
        intermediates_only=True,
    )

    if self.num_prefix_tokens <= 0:
        # Without prefix tokens the outputs are used directly as the features
        # (they are not unpacked as (feature, prefix_tokens) pairs).
        return list(intermediate_outputs), [None] * len(self.out_indices)

    # Each output is a (feature, prefix_tokens) pair. Splitting with
    # comprehensions (instead of ``zip(*...)``) yields real lists, matching
    # the annotated return type, and tolerates an empty output sequence.
    features = [feature for feature, _ in intermediate_outputs]
    prefix_tokens = [tokens for _, tokens in intermediate_outputs]

    if not self.cls_token_supported:
        # Prefix tokens exist but none is treated as a class token.
        return features, [None] * len(self.out_indices)

    if self.num_prefix_tokens > 1:
        # Keep only the token at index 0 of the prefix axis.
        # NOTE(review): presumably the class token comes first - confirm
        # against the backbone's token layout.
        cls_tokens = [tokens[:, 0, :] for tokens in prefix_tokens]
    else:
        # A single prefix token is passed through unsliced, so it keeps its
        # prefix axis (unlike the multi-token branch above).
        cls_tokens = prefix_tokens

    return features, cls_tokens
175178
176179 @property
177180 def embed_dim (self ) -> int :
0 commit comments