@@ -36,39 +36,45 @@ def forward(self, arg1, arg2, arg3, arg4, arg5):
             return x + x  # Quantize by propagation.

     def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
-        tester = Tester(module, inputs)
-
-        if quant:
-            tester.quantize()
-
-        tester.export().check_count({"torch.ops.aten.cat": 1})
-        tester.dump_artifact()
-
-        if quant:
-            # Expect multiple quantize ops - one per input, cat, and add.
-            tester.check_node_count(
-                {
-                    # Q/DQ pair for each input and quantized op. For most tests, there are
-                    # two quantized ops - cat and add.
-                    torch.ops.quantized_decomposed.quantize_per_tensor.default: (
-                        cat_num + quant_ops
-                    )
-                }
+        for legacy_mode in (True, False):
+            tester = Tester(module, inputs)
+
+            if quant:
+                tester.quantize()
+
+            tester.export().check_count({"torch.ops.aten.cat": 1})
+            tester.dump_artifact()
+
+            if quant:
+                # Expect multiple quantize ops - one per input, cat, and add.
+                tester.check_node_count(
+                    {
+                        # Q/DQ pair for each input and quantized op. For most tests, there are
+                        # two quantized ops - cat and add.
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default: (
+                            cat_num + quant_ops
+                        )
+                    }
+                )
+
+
+            if legacy_mode:
+                tester.to_edge()
+                tester.partition()
+            else:
+                tester.to_edge_transform_and_lower()
+
+            if quant:
+                tester.check_not(["torch.ops.quantized_decomposed"])
+
+            (
+                tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+                .check_not(["executorch_exir_dialects_edge__ops_aten_cat"])
+                .to_executorch()
+                .serialize()
+                .run_method_and_compare_outputs()
             )

-        tester.to_edge_transform_and_lower()
-
-        if quant:
-            tester.check_not(["torch.ops.quantized_decomposed"])
-
-        (
-            tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-            .check_not(["executorch_exir_dialects_edge__ops_aten_cat"])
-            .to_executorch()
-            .serialize()
-            .run_method_and_compare_outputs()
-        )
-
     def test_fp16_cat2(self):
         """
         Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first.
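For context: the loop added above drives each cat test through two lowering pipelines. A minimal sketch of the two paths, assuming the Tester import path used by these tests (the module and inputs below are illustrative stand-ins, not taken from the suite):

    import torch
    from executorch.backends.xnnpack.test.tester import Tester

    module = torch.nn.ReLU()  # illustrative stand-in for the Cat modules in this file
    inputs = (torch.randn(1, 2, 3),)

    # Legacy two-step path: lower to the edge dialect, then partition for XNNPACK.
    Tester(module, inputs).export().to_edge().partition()

    # Newer single-step path: transform and lower to the delegate in one pass.
    Tester(module, inputs).export().to_edge_transform_and_lower()

Running both paths in one test body keeps the legacy flow covered while the suite migrates to to_edge_transform_and_lower().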
@@ -155,6 +161,26 @@ def test_fp32_cat_unsupported(self):
             .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
         )

+    def test_fp32_cat_unsupported_legacy_mode(self):
+        """
+        XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
+        """
+        inputs = (
+            torch.randn(1, 2, 3),
+            torch.randn(3, 2, 3),
+            torch.randn(2, 2, 3),
+            torch.randn(5, 2, 3),
+            torch.randn(1, 2, 3),
+        )
+        (
+            Tester(self.Cat5(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.cat": 1})
+            .to_edge()
+            .partition()
+            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
+        )
+
     class CatNegativeDim(torch.nn.Module):
         def __init__(self):
             super().__init__()
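The Cat5 module referenced by both unsupported-cat tests is defined earlier in the file and does not appear in this hunk. A hypothetical reconstruction for illustration, following the five-argument forward shown in the first hunk's context line (the cat dimension is an assumption; the varying leading dims of the test inputs are consistent with the default dim 0):

    import torch

    class Cat5(torch.nn.Module):
        # Hypothetical sketch; the actual definition lives earlier in this file.
        # Concatenating five tensors exceeds XNNPACK's four-input limit for cat,
        # so the partitioner leaves aten.cat in the edge graph instead of
        # delegating it.
        def forward(self, arg1, arg2, arg3, arg4, arg5):
            x = torch.cat([arg1, arg2, arg3, arg4, arg5])
            return x + x  # Quantize by propagation.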