11
11
from typing import Any , Dict , Optional , Tuple
12
12
13
13
import torch
14
+ from torch import nn
14
15
from torchtnt .utils .module_summary import (
15
16
_get_human_readable_count ,
16
17
get_module_summary ,
17
18
ModuleSummary ,
18
19
prune_module_summary ,
19
20
)
20
- from torchvision import models
21
- from torchvision .models .alexnet import AlexNet_Weights
22
- from torchvision .models .resnet import ResNet18_Weights
23
21
24
22
25
23
def get_summary_and_prune (
@@ -156,54 +154,6 @@ def test_lazy_tensor_flops(self) -> None:
156
154
self .assertEqual (ms .flops_backward , "?" )
157
155
self .assertEqual (ms .flops_forward , "?" )
158
156
159
- def test_resnet_max_depth (self ) -> None :
160
- """Test the behavior of max_depth on a layered model like ResNet"""
161
- pretrained_model = models .resnet .resnet18 (
162
- weights = ResNet18_Weights .IMAGENET1K_V1
163
- )
164
-
165
- # max_depth = None
166
- ms1 = get_module_summary (pretrained_model )
167
-
168
- self .assertEqual (len (ms1 .submodule_summaries ), 10 )
169
- self .assertEqual (len (ms1 .submodule_summaries ["layer2" ].submodule_summaries ), 2 )
170
- self .assertEqual (
171
- len (
172
- ms1 .submodule_summaries ["layer2" ]
173
- .submodule_summaries ["layer2.0" ]
174
- .submodule_summaries
175
- ),
176
- 6 ,
177
- )
178
-
179
- ms2 = get_summary_and_prune (pretrained_model , max_depth = 1 )
180
- self .assertEqual (len (ms2 .submodule_summaries ), 0 )
181
- self .assertNotIn ("layer2" , ms2 .submodule_summaries )
182
-
183
- ms3 = get_summary_and_prune (pretrained_model , max_depth = 2 )
184
- self .assertEqual (len (ms3 .submodule_summaries ), 10 )
185
- self .assertEqual (len (ms1 .submodule_summaries ["layer2" ].submodule_summaries ), 2 )
186
- self .assertNotIn (
187
- "layer2.0" , ms3 .submodule_summaries ["layer2" ].submodule_summaries
188
- )
189
- inp = torch .randn (1 , 3 , 224 , 224 )
190
- ms4 = get_summary_and_prune (pretrained_model , max_depth = 2 , module_args = (inp ,))
191
-
192
- self .assertEqual (len (ms4 .submodule_summaries ), 10 )
193
- self .assertEqual (ms4 .flops_forward , 1814073344 )
194
- self .assertEqual (ms4 .flops_backward , 3510132736 )
195
- self .assertEqual (ms4 .submodule_summaries ["layer2" ].flops_forward , 411041792 )
196
- self .assertEqual (ms4 .submodule_summaries ["layer2" ].flops_backward , 822083584 )
197
-
198
- # These should stay constant for all max_depth values
199
- ms_list = [ms1 , ms2 , ms3 , ms4 ]
200
- for ms in ms_list :
201
- self .assertEqual (ms .module_name , "" )
202
- self .assertEqual (ms .module_type , "ResNet" )
203
- self .assertEqual (ms .num_parameters , 11689512 )
204
- self .assertEqual (ms .num_trainable_parameters , 11689512 )
205
- self .assertFalse (ms .has_uninitialized_param )
206
-
207
157
def test_module_summary_layer_print (self ) -> None :
208
158
model = torch .nn .Conv2d (3 , 8 , 3 )
209
159
ms1 = get_module_summary (model )
@@ -215,69 +165,6 @@ def test_module_summary_layer_print(self) -> None:
215
165
"""
216
166
self ._test_module_summary_text (summary_table , str (ms1 ))
217
167
218
- def test_alexnet_print (self ) -> None :
219
- pretrained_model = models .alexnet (weights = AlexNet_Weights .IMAGENET1K_V1 )
220
- ms1 = get_summary_and_prune (pretrained_model , max_depth = 1 )
221
- ms2 = get_summary_and_prune (pretrained_model , max_depth = 2 )
222
- ms3 = get_summary_and_prune (pretrained_model , max_depth = 3 )
223
- ms4 = get_module_summary (pretrained_model )
224
-
225
- summary_table1 = """
226
- Name | Type | # Parameters | # Trainable Parameters | Size (bytes) | Contains Uninitialized Parameters?
227
- ----------------------------------------------------------------------------------------------------------
228
- | AlexNet | 61.1 M | 61.1 M | 244 M | No
229
- """
230
- summary_table2 = """
231
- Name | Type | # Parameters | # Trainable Parameters | Size (bytes) | Contains Uninitialized Parameters?
232
- --------------------------------------------------------------------------------------------------------------------------
233
- | AlexNet | 61.1 M | 61.1 M | 244 M | No
234
- features | Sequential | 2.5 M | 2.5 M | 9.9 M | No
235
- avgpool | AdaptiveAvgPool2d | 0 | 0 | 0 | No
236
- classifier | Sequential | 58.6 M | 58.6 M | 234 M | No
237
- """
238
-
239
- self ._test_module_summary_text (summary_table1 , str (ms1 ))
240
- self ._test_module_summary_text (summary_table2 , str (ms2 ))
241
- self .assertEqual (str (ms3 ), str (ms4 ))
242
-
243
- def test_alexnet_with_input_tensor (self ) -> None :
244
- pretrained_model = models .alexnet (weights = AlexNet_Weights .IMAGENET1K_V1 )
245
- inp = torch .randn (1 , 3 , 224 , 224 )
246
- ms1 = get_summary_and_prune (pretrained_model , max_depth = 1 , module_args = (inp ,))
247
- ms2 = get_summary_and_prune (pretrained_model , max_depth = 2 , module_args = (inp ,))
248
-
249
- self .assertEqual (ms1 .module_type , "AlexNet" )
250
- self .assertEqual (ms1 .num_parameters , 61100840 )
251
- self .assertFalse (ms1 .has_uninitialized_param )
252
- self .assertEqual (ms1 .flops_forward , 714188480 )
253
- self .assertEqual (ms1 .flops_backward , 1358100160 )
254
- self .assertEqual (ms1 .in_size , [1 , 3 , 224 , 224 ])
255
- self .assertEqual (ms1 .out_size , [1 , 1000 ])
256
-
257
- ms_features = ms2 .submodule_summaries ["features" ]
258
- self .assertEqual (ms_features .module_type , "Sequential" )
259
- self .assertFalse (ms_features .has_uninitialized_param )
260
- self .assertEqual (ms_features .flops_forward , 655566528 )
261
- self .assertEqual (ms_features .flops_backward , 1240856256 )
262
- self .assertEqual (ms_features .in_size , [1 , 3 , 224 , 224 ])
263
- self .assertEqual (ms_features .out_size , [1 , 256 , 6 , 6 ])
264
-
265
- ms_avgpool = ms2 .submodule_summaries ["avgpool" ]
266
- self .assertEqual (ms_avgpool .module_type , "AdaptiveAvgPool2d" )
267
- self .assertFalse (ms_avgpool .has_uninitialized_param )
268
- self .assertEqual (ms_avgpool .flops_forward , 0 )
269
- self .assertEqual (ms_avgpool .flops_backward , 0 )
270
- self .assertEqual (ms_avgpool .in_size , [1 , 256 , 6 , 6 ])
271
- self .assertEqual (ms_avgpool .out_size , [1 , 256 , 6 , 6 ])
272
-
273
- ms_classifier = ms2 .submodule_summaries ["classifier" ]
274
- self .assertEqual (ms_classifier .module_type , "Sequential" )
275
- self .assertFalse (ms_classifier .has_uninitialized_param )
276
- self .assertEqual (ms_classifier .flops_forward , 58621952 )
277
- self .assertEqual (ms_classifier .flops_backward , 117243904 )
278
- self .assertEqual (ms_classifier .in_size , [1 , 9216 ])
279
- self .assertEqual (ms_classifier .out_size , [1 , 1000 ])
280
-
281
168
def test_get_human_readable_count (self ) -> None :
282
169
with self .assertRaisesRegex (ValueError , "received -1" ):
283
170
_get_human_readable_count (- 1 )
@@ -350,7 +237,9 @@ def forward(self, x, y, offset=1):
350
237
self .assertEqual (ms_classifier .out_size , [1 , 1 , 224 , 224 ])
351
238
352
239
def test_forward_elapsed_time (self ) -> None :
353
- pretrained_model = models .alexnet (weights = AlexNet_Weights .IMAGENET1K_V1 )
240
+ pretrained_model = nn .Sequential (
241
+ nn .Conv2d (3 , 20 , 5 ), nn .ReLU (), nn .Conv2d (20 , 64 , 5 ), nn .ReLU ()
242
+ )
354
243
inp = torch .randn (1 , 3 , 224 , 224 )
355
244
ms1 = get_summary_and_prune (pretrained_model , module_args = (inp ,), max_depth = 4 )
356
245
stack = [ms1 ] + [
0 commit comments