File tree Expand file tree Collapse file tree 6 files changed +91
-1
lines changed
segmentation_models_pytorch/base Expand file tree Collapse file tree 6 files changed +91
-1
lines changed Original file line number Diff line number Diff line change 9090 - name : Test with PyTest
9191 run : uv run pytest -v -rsx -n 2 -m "compile"
9292
93+ test_torch_export :
94+ runs-on : ubuntu-latest
95+ steps :
96+ - uses : actions/checkout@v4
97+ - name : Set up Python
98+ uses : astral-sh/setup-uv@v5
99+ with :
100+ python-version : " 3.10"
101+ - name : Install dependencies
102+ run : uv pip install -r requirements/required.txt -r requirements/test.txt
103+ - name : Show installed packages
104+ run : uv pip list
105+ - name : Test with PyTest
106+ run : uv run pytest -v -rsx -n 2 -m "torch_export"
107+
93108 minimum :
94109 runs-on : ubuntu-latest
95110 steps :
Original file line number Diff line number Diff line change @@ -65,6 +65,7 @@ include = ['segmentation_models_pytorch*']
6565markers = [
6666 " logits_match" ,
6767 " compile" ,
68+ " torch_export" ,
6869]
6970
7071[tool .coverage .run ]
Original file line number Diff line number Diff line change 33
44from . import initialization as init
55from .hub_mixin import SMPHubMixin
6+ from .utils import is_torch_compiling
67
78T = TypeVar ("T" , bound = "SegmentationModel" )
89
@@ -50,7 +51,11 @@ def check_input_shape(self, x):
5051 def forward (self , x ):
5152 """Sequentially pass `x` trough model`s encoder, decoder and heads"""
5253
53- if not torch .jit .is_tracing () and self .requires_divisible_input_shape :
54+ if (
55+ not torch .jit .is_tracing ()
56+ and not is_torch_compiling ()
57+ and self .requires_divisible_input_shape
58+ ):
5459 self .check_input_shape (x )
5560
5661 features = self .encoder (x )
Original file line number Diff line number Diff line change 1+ import torch
2+
3+
4+ def is_torch_compiling ():
5+ try :
6+ return torch .compiler .is_compiling ()
7+ except Exception :
8+ try :
9+ import torch ._dynamo as dynamo # noqa: F401
10+
11+ return dynamo .is_compiling ()
12+ except Exception :
13+ return False
Original file line number Diff line number Diff line change @@ -231,3 +231,31 @@ def test_compile(self):
231231
232232 with torch .inference_mode ():
233233 compiled_encoder (sample )
234+
235+ @pytest .mark .torch_export
236+ def test_torch_export (self ):
237+ if not check_run_test_on_diff_or_main (self .files_for_diff ):
238+ self .skipTest ("No diff and not on `main`." )
239+
240+ sample = self ._get_sample (
241+ batch_size = self .default_batch_size ,
242+ num_channels = self .default_num_channels ,
243+ height = self .default_height ,
244+ width = self .default_width ,
245+ ).to (default_device )
246+
247+ encoder = self .get_tiny_encoder ()
248+ encoder = encoder .eval ().to (default_device )
249+
250+ exported_encoder = torch .export .export (
251+ encoder ,
252+ args = (sample ,),
253+ strict = True ,
254+ )
255+
256+ with torch .inference_mode ():
257+ eager_output = encoder (sample )
258+ exported_output = exported_encoder .module ().forward (sample )
259+
260+ for eager_feature , exported_feature in zip (eager_output , exported_output ):
261+ torch .testing .assert_close (eager_feature , exported_feature )
Original file line number Diff line number Diff line change @@ -254,3 +254,31 @@ def test_compile(self):
254254
255255 with torch .inference_mode ():
256256 compiled_model (sample )
257+
258+ @pytest .mark .torch_export
259+ def test_torch_export (self ):
260+ if not check_run_test_on_diff_or_main (self .files_for_diff ):
261+ self .skipTest ("No diff and not on `main`." )
262+
263+ sample = self ._get_sample (
264+ batch_size = self .default_batch_size ,
265+ num_channels = self .default_num_channels ,
266+ height = self .default_height ,
267+ width = self .default_width ,
268+ ).to (default_device )
269+
270+ model = self .get_default_model ()
271+ model .eval ()
272+
273+ exported_model = torch .export .export (
274+ model ,
275+ args = (sample ,),
276+ strict = True ,
277+ )
278+
279+ with torch .inference_mode ():
280+ eager_output = model (sample )
281+ exported_output = exported_model .module ().forward (sample )
282+
283+ self .assertEqual (eager_output .shape , exported_output .shape )
284+ torch .testing .assert_close (eager_output , exported_output )
You can’t perform that action at this time.
0 commit comments