@@ -33,10 +33,6 @@ class BaseEncoderTester(unittest.TestCase):
3333 depth_to_test = [3 , 4 , 5 ]
3434 strides_to_test = [8 , 16 ] # 32 is a default one
3535
36- # enable/disable tests
37- do_test_torch_compile = True
38- do_test_torch_export = True
39-
4036 def get_tiny_encoder (self ):
4137 return smp .encoders .get_encoder (self .encoder_names [0 ], encoder_weights = None )
4238
@@ -208,28 +204,25 @@ def test_dilated(self):
208204
209205 @pytest .mark .compile
210206 def test_compile (self ):
211- if not self .do_test_torch_compile :
212- self .skipTest (
213- f"torch_compile test is disabled for { self .encoder_names [0 ]} ."
214- )
215-
216207 if not check_run_test_on_diff_or_main (self .files_for_diff ):
217208 self .skipTest ("No diff and not on `main`." )
218209
219210 sample = self ._get_sample ().to (default_device )
220211
221- encoder = self .get_tiny_encoder ().eval ().to (default_device )
212+ encoder = self .get_tiny_encoder ()
213+ encoder = encoder .eval ().to (default_device )
214+
222215 compiled_encoder = torch .compile (encoder , fullgraph = True , dynamic = True )
223216
224- with torch . inference_mode () :
217+ if encoder . _is_torch_compilable :
225218 compiled_encoder (sample )
219+ else :
220+ with self .assertRaises (Exception ):
221+ compiled_encoder (sample )
226222
227223 @pytest .mark .torch_export
228224 @requires_torch_greater_or_equal ("2.4.0" )
229225 def test_torch_export (self ):
230- if not self .do_test_torch_export :
231- self .skipTest (f"torch_export test is disabled for { self .encoder_names [0 ]} ." )
232-
233226 if not check_run_test_on_diff_or_main (self .files_for_diff ):
234227 self .skipTest ("No diff and not on `main`." )
235228
0 commit comments