11# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
22
3+ import logging
34import math
45import unittest
6+ from typing import cast
57
68import executorch .backends .cadence .aot .ops_registrations # noqa
79import torch
@@ -110,7 +112,121 @@ def forward(self, x):
110112
111113
112114class TestMemTransform (unittest .TestCase ):
113- def test_optimize_cat (self ):
def _verify_cat_nop_memory_alloc(self, node: torch.fx.Node) -> None:
    """Check that the inputs of a nop-cat node are packed inside its output.

    Reads the "spec" entry from ``node.meta`` and from the meta of every
    input listed in ``node.args[0]``, and asserts that:

    * the output dimensions before the cat dim multiply to 1,
    * every input has the same ``mem_id`` as the output,
    * every input's ``mem_offset`` equals the output's ``mem_offset`` plus
      the byte size of all inputs that precede it along the cat dim.

    Args:
        node: The cat node to check. ``node.args[1]`` is the cat dim when
            present; otherwise dim 0 is used.

    Raises:
        AssertionError: Through the ``self.assert*`` helpers when any of
            the checks above fails or a spec is missing.
    """
    spec = node.meta.get("spec", None)
    self.assertIsNotNone(spec)
    dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0
    # A negative dim counts from the end. Normalize it, otherwise the
    # slices below would pick the wrong outer/inner dimensions
    # (e.g. shape[dim + 1:] with dim == -1 is the whole shape).
    if dim < 0:
        dim += len(spec.shape)
    outer_size = math.prod(spec.shape[:dim])
    self.assertEqual(
        outer_size,
        1,
        f"{ node = } has wrong outer size: { outer_size = } , expected 1." ,
    )
    # Despite the name this is a byte count: the number of elements in one
    # step along `dim`, multiplied by the element size.
    inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
    # Running extent (along `dim`) of the inputs seen so far.
    dim_offset = 0
    for arg in cast(list[torch.fx.Node], node.args[0]):
        arg_spec = arg.meta.get("spec", None)
        # Fail with a readable assertion instead of an AttributeError when
        # an input carries no spec.
        self.assertIsNotNone(arg_spec, f"input {arg} of cat {node} has no spec")
        self.assertEqual(arg_spec.mem_id, spec.mem_id)
        self.assertEqual(
            arg_spec.mem_offset,
            spec.mem_offset + dim_offset * inner_dim_elements,
            f"{ arg = } for node { node = } has wrong memory offset: { arg_spec .mem_offset = } { dim_offset = } for cat on { dim = } , but output has { spec .mem_offset = } " ,
        )
        dim_offset += arg_spec.shape[dim]
136+
def _verify_slice_nop_memory_alloc(self, node: torch.fx.Node) -> None:
    """Check that a nop-slice node's output aliases the right part of its input.

    Reads the "spec" entry from ``node.meta`` and from the meta of the
    sliced input ``node.args[0]``, and asserts that:

    * the output dimensions before the slice dim multiply to 1,
    * input and output have the same ``mem_id``,
    * the output's ``mem_offset`` equals the input's ``mem_offset`` plus
      ``start`` steps along the slice dim, measured in bytes.

    Args:
        node: The slice node to check. ``node.args[1]`` is the slice dim
            (default 0) and ``node.args[2]`` the start index (default 0,
            also used when it is ``None``).

    Raises:
        AssertionError: Through the ``self.assert*`` helpers when any of
            the checks above fails or a spec is missing.
    """
    spec = node.meta.get("spec", None)
    self.assertIsNotNone(spec)
    dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0
    # A negative dim counts from the end. Normalize it so the slices below
    # select the intended outer/inner dimensions.
    if dim < 0:
        dim += len(spec.shape)
    outer_size = math.prod(spec.shape[:dim])
    self.assertEqual(
        outer_size,
        1,
        f"{ node = } has wrong outer size: { outer_size = } , expected 1." ,
    )
    # Despite the name this is a byte count: the number of elements in one
    # step along `dim`, multiplied by the element size.
    inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
    start: int = (
        cast(int, node.args[2])
        if (len(node.args) > 2 and node.args[2] is not None)
        else 0
    )
    arg = cast(torch.fx.Node, node.args[0])
    arg_spec = arg.meta.get("spec", None)
    # Fail with a readable assertion instead of an AttributeError when the
    # input carries no spec.
    self.assertIsNotNone(arg_spec, f"input {arg} of slice {node} has no spec")
    # A negative start is relative to the end of the sliced dimension and
    # is clamped at 0; without this the expected offset would be negative.
    if start < 0:
        start = max(start + arg_spec.shape[dim], 0)
    self.assertEqual(arg_spec.mem_id, spec.mem_id)
    self.assertEqual(
        spec.mem_offset,
        arg_spec.mem_offset + start * inner_dim_elements,
        f"{ arg = } for node { node = } has wrong memory offset: { arg_spec .mem_offset = } { start = } for slice on { dim = } , but output has { spec .mem_offset = } " ,
    )
161+
def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None:
    """Check that a nop-select node's output aliases the right part of its input.

    Reads the "spec" entry from ``node.meta`` and from the meta of the
    input ``node.args[0]``, and asserts that:

    * the output dimensions before the select dim multiply to 1,
    * input and output have the same ``mem_id``,
    * the output's ``mem_offset`` equals the input's ``mem_offset`` plus
      ``index`` times the byte size of the output dimensions from ``dim``
      onwards.

    Args:
        node: The select node to check. ``node.args[1]`` is the select dim
            (default 0) and ``node.args[2]`` the selected index (default 0,
            also used when it is ``None``).

    Raises:
        AssertionError: Through the ``self.assert*`` helpers when any of
            the checks above fails or a spec is missing.
    """
    spec = node.meta.get("spec", None)
    self.assertIsNotNone(spec)
    arg = cast(torch.fx.Node, node.args[0])
    arg_spec = arg.meta.get("spec", None)
    # Fail with a readable assertion instead of an AttributeError when the
    # input carries no spec.
    self.assertIsNotNone(arg_spec, f"input {arg} of select {node} has no spec")
    dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0
    # A negative dim is relative to the *input* shape, which is what the
    # select indexes into, so normalize against the input's rank.
    if dim < 0:
        dim += len(arg_spec.shape)
    outer_size = math.prod(spec.shape[:dim])
    self.assertEqual(
        outer_size,
        1,
        f"{ node = } has wrong outer size: { outer_size = } , expected 1." ,
    )
    # Despite the name this is a byte count. The output shape is sliced from
    # `dim` (not `dim + 1`) here.
    inner_dim_elements = math.prod(spec.shape[dim:]) * spec.dtype.itemsize
    index: int = (
        cast(int, node.args[2])
        if (len(node.args) > 2 and node.args[2] is not None)
        else 0
    )
    # A negative index counts from the end of the selected dimension.
    if index < 0:
        index += arg_spec.shape[dim]
    self.assertEqual(arg_spec.mem_id, spec.mem_id)
    self.assertEqual(
        spec.mem_offset,
        arg_spec.mem_offset + index * inner_dim_elements,
        f"{ arg = } for node { node = } has wrong memory offset: { arg_spec .mem_offset = } for select on { dim = } { index = } , "
        f"but output has { spec .mem_offset = } "
        f"{ spec = } { arg_spec = } " ,
    )
188+
def verify_nop_memory_alloc(self, graph_module):
    """Run the per-op memory-layout checks on every nop node in the graph.

    Looks up the ``call_function`` nodes targeting the nop variants of cat,
    slice_copy and select_copy (in that order) and hands each one to the
    matching ``_verify_*_nop_memory_alloc`` helper. A graph without such
    nodes passes trivially.
    """

    def _check_each(target, verify):
        # Apply `verify` to every call_function node with the given target.
        for nop_node in graph_module.graph.find_nodes(
            op="call_function", target=target
        ):
            verify(nop_node)

    _check_each(torch.ops.aten._cat_nop.out, self._verify_cat_nop_memory_alloc)
    _check_each(
        torch.ops.aten._slice_copy_nop.Tensor_out,
        self._verify_slice_nop_memory_alloc,
    )
    _check_each(
        torch.ops.aten._select_copy_nop.int_out,
        self._verify_select_nop_memory_alloc,
    )
204+
def test_optimize_cat_on_placeholders(self):
    """A cat whose operands are both graph inputs is replaced by its nop variant."""

    class Cat(torch.nn.Module):
        def forward(self, x, y):
            return torch.ops.aten.cat((x, y))

    example_inputs = (torch.ones(3, 6), torch.ones(2, 6))
    # Optimizing cat ops is only at opt_level 2+, and requires the memory planning
    # pass to run:
    program_manager = compiler.export_to_executorch_gen_etrecord(
        Cat(), example_inputs, opt_level=2, mem_algo=1
    )
    graph_module = program_manager.exported_program().graph_module
    logging.info(f"graph_module: { graph_module .print_readable (print_output = False )} ")
    graph_module.graph.eliminate_dead_code()

    # Assert that cat op is optimized away
    remaining_cats = count_node(graph_module, torch.ops.aten.cat.out)
    self.assertEqual(remaining_cats, 0)
    # Assert that cat op is replaced by its nop version post optimization
    nop_cats = count_node(graph_module, torch.ops.aten._cat_nop.out)
    self.assertEqual(nop_cats, 1)
    self.verify_nop_memory_alloc(graph_module)
228+
229+ def test_optimize_cat_outermost (self ):
114230 class OptimizeCatFeasible1 (torch .nn .Module ):
115231 def forward (self , x , y ):
116232 x1 = torch .add (x , 2.4 , 3.1 )
@@ -135,7 +251,9 @@ def forward(self, x, y):
135251 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 0 )
136252 # Assert that cat op is replaced by its nop version post optimization
137253 self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 1 )
254+ self .verify_nop_memory_alloc (graph_module )
138255
256+ def test_optimize_cat_non_outermost (self ):
139257 class OptimizeCatFeasible2 (torch .nn .Module ):
140258 def forward (self , x , y ):
141259 x1 = torch .add (x , 2.4 , 3.1 )
@@ -160,7 +278,9 @@ def forward(self, x, y):
160278 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 0 )
161279 # Assert that cat op is replaced by its nop version post optimization
162280 self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 1 )
281+ self .verify_nop_memory_alloc (graph_module )
163282
283+ def test_no_optimize_cat_non_outermost (self ):
164284 class OptimizeCatInfeasible1 (torch .nn .Module ):
165285 def forward (self , x , y ):
166286 x1 = torch .add (x , 2.4 , 3.1 )
@@ -184,7 +304,9 @@ def forward(self, x, y):
184304 # Assert that cat op is not optimized away, since the concat is not
185305 # along the outermost dim
186306 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 1 )
307+ self .verify_nop_memory_alloc (graph_module )
187308
309+ def test_no_optimize_cat_non_outermost1 (self ):
188310 class OptimizeCatInfeasible2 (torch .nn .Module ):
189311 def forward (self , x , y ):
190312 x1 = torch .add (x , 2.4 , 3.1 )
@@ -209,6 +331,7 @@ def forward(self, x, y):
209331 # offsets are not multiple of 8 bytes, and the cat is not the output
210332 # of the graph.
211333 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 1 )
334+ self .verify_nop_memory_alloc (graph_module )
212335
213336 def test_optimize_cat_with_slice (self ):
214337 class OptimizeCatSliceFeasible (torch .nn .Module ):
@@ -237,6 +360,7 @@ def forward(self, x):
237360 graph_module .graph .eliminate_dead_code ()
238361 # Assert that cat op is optimized away
239362 self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 1 )
363+ self .verify_nop_memory_alloc (graph_module )
240364
241365 def test_optimize_cat_with_slice_infeasible (self ):
242366 class OptimizeCatSliceInfeasible (torch .nn .Module ):
@@ -262,6 +386,7 @@ def forward(self, x, y):
262386 graph_module .graph .eliminate_dead_code ()
263387 # Assert that cat op is not optimized away
264388 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 1 )
389+ self .verify_nop_memory_alloc (graph_module )
265390
266391 def test_optimize_slice_Tensor (self ):
267392 class SliceTensor (torch .nn .Module ):
@@ -323,6 +448,7 @@ def forward(self, x, y, z):
323448 self .assertEqual (
324449 count_node (graph_module , torch .ops .aten ._slice_copy_nop .Tensor_out ), 3
325450 )
451+ self .verify_nop_memory_alloc (graph_module )
326452
327453 def test_optimize_select_Tensor (self ):
328454 class SelectTensor (torch .nn .Module ):
@@ -387,6 +513,7 @@ def forward(self, x, y, z):
387513 self .assertEqual (
388514 count_node (graph_module , torch .ops .aten ._select_copy_nop .int_out ), 3
389515 )
516+ self .verify_nop_memory_alloc (graph_module )
390517
391518 # TODO: Test fails due to memory planning
392519 @unittest .expectedFailure
@@ -416,6 +543,32 @@ def forward(self, x, y):
416543 graph_module .graph .eliminate_dead_code ()
417544 # Assert that cat op is not optimized away
418545 self .assertEqual (count_node (graph_module , exir_ops .edge .aten .cat .default ), 1 )
546+ self .verify_nop_memory_alloc (graph_module )
547+
def test_optimize_cat_then_slice_on_mutable_buffer(self):
    """A cat feeding a slice that is copied back into a registered buffer becomes a nop cat."""

    class CatWithPadding(torch.nn.Module):
        def __init__(self, padding_shape):
            super().__init__()
            zeros = torch.zeros(padding_shape)
            self.register_buffer("padding", zeros)

        def forward(self, x, y):
            x = x.view(3, 5)
            cat = torch.ops.aten.cat((x, self.padding.clone()))
            # Slice off everything after the rows that came from x and write
            # it back into the buffer.
            slice_copy = torch.ops.aten.slice(cat, dim=0, start=x.shape[0])
            self.padding.copy_(slice_copy)
            return cat.view(-1) + y

    example_inputs = (torch.ones(15), torch.ones(1))
    model = CatWithPadding((1, 5))
    et_prog_manager = compiler.export_to_executorch_gen_etrecord(
        model, example_inputs, opt_level=3
    )
    graph_module = et_prog_manager.exported_program().graph_module
    logging.info(f"graph_module: { graph_module .print_readable (print_output = False )} ")

    # The regular cat must be gone and exactly one nop cat must remain.
    remaining_cats = count_node(graph_module, torch.ops.aten.cat.out)
    self.assertEqual(remaining_cats, 0)
    nop_cats = count_node(graph_module, torch.ops.aten._cat_nop.out)
    self.assertEqual(nop_cats, 1)
    self.verify_nop_memory_alloc(graph_module)
419572
420573 def test_optimize_cat_with_view (self ):
421574 class CatViewFeasible (torch .nn .Module ):
@@ -442,6 +595,7 @@ def forward(self, x, y):
442595 # Assert that cat op is optimized away
443596 self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 1 )
444597 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 0 )
598+ self .verify_nop_memory_alloc (graph_module )
445599
446600 def test_no_optimize_cat_with_repeated_args (self ):
447601 class CatViewInfeasible (torch .nn .Module ):
@@ -465,6 +619,7 @@ def forward(self, x):
465619 # Assert that cat op is not optimized away
466620 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 1 )
467621 self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 0 )
622+ self .verify_nop_memory_alloc (graph_module )
468623
469624 def test_no_optimize_cat_with_placeholder (self ):
470625 class CatViewInfeasible (torch .nn .Module ):
@@ -492,6 +647,7 @@ def forward(self, x, y):
492647 # Assert that cat op is not optimized away
493648 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 1 )
494649 self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 0 )
650+ self .verify_nop_memory_alloc (graph_module )
495651
496652 def test_no_optimize_cat (self ) -> None :
497653 class Model (torch .nn .Module ):
@@ -522,6 +678,7 @@ def forward(self, x) -> torch.Tensor:
522678 count_node (graph_module , torch .ops .aten ._slice_copy_nop .Tensor_out ), 2
523679 )
524680 self .assertEqual (count_node (graph_module , memory .view ), 2 )
681+ self .verify_nop_memory_alloc (graph_module )
525682
526683 def test_optimize_slice_copy (self ) -> None :
527684 class Model (torch .nn .Module ):
@@ -553,6 +710,7 @@ def forward(self, x) -> torch.Tensor:
553710 count_node (graph_module , torch .ops .aten ._slice_copy_nop .Tensor_out ), 0
554711 )
555712 self .assertEqual (count_node (graph_module , memory .view ), 2 )
713+ self .verify_nop_memory_alloc (graph_module )
556714
557715 def test_cat_then_cat (self ) -> None :
558716 class Model (torch .nn .Module ):
@@ -579,6 +737,7 @@ def forward(self, x) -> torch.Tensor:
579737 graph_module .print_readable ()
580738 self .assertEqual (count_node (graph_module , torch .ops .aten ._cat_nop .out ), 2 )
581739 self .assertEqual (count_node (graph_module , torch .ops .aten .cat .out ), 0 )
740+ self .verify_nop_memory_alloc (graph_module )
582741
583742 def test_view_for_unallocated_output (self ):
584743 class Model (torch .nn .Module ):
@@ -602,3 +761,4 @@ def forward(self, x, y):
602761 .graph_module
603762 )
604763 self .assertEqual (count_node (graph_module , memory .view ), 1 )
764+ self .verify_nop_memory_alloc (graph_module )
0 commit comments