1414
1515class  TestCat (unittest .TestCase ):
1616    class  Cat (torch .nn .Module ):
17+         def  __init__ (self , dim = 0 ):
18+             super ().__init__ ()
19+             self .dim  =  dim 
20+ 
1721        def  forward (self , * args ):
1822            xs  =  [* args ]
19-             x  =  torch .cat (xs )
23+             x  =  torch .cat (xs ,  dim = self . dim )
2024            return  x  +  x   # Quantize by propagation. 
2125
2226    def  _test_cat (self , module , inputs , cat_num = 1 , quant = False , quant_ops = 2 ):
@@ -27,7 +31,6 @@ def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
2731                tester .quantize ()
2832
2933            tester .export ().check_count ({"torch.ops.aten.cat" : 1 })
30-             tester .dump_artifact ()
3134
3235            if  quant :
3336                # Expect multiple quantize ops - one per input, cat, and add. 
@@ -93,6 +96,29 @@ def test_fp16_cat4(self):
9396        )
9497        self ._test_cat (self .Cat (), inputs )
9598
99+     def  test_fp16_cat5 (self ):
100+         """ 
101+         Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first. 
102+         """ 
103+         inputs  =  (
104+             torch .randn (1 , 2 , 3 ).to (torch .float16 ),
105+             torch .randn (3 , 2 , 3 ).to (torch .float16 ),
106+             torch .randn (2 , 2 , 3 ).to (torch .float16 ),
107+             torch .randn (5 , 2 , 3 ).to (torch .float16 ),
108+             torch .randn (5 , 2 , 3 ).to (torch .float16 ),
109+         )
110+         self ._test_cat (self .Cat (), inputs )
111+ 
112+     def  test_fp16_cat_gt_5 (self ):
113+         """ 
114+         Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first. 
115+         """ 
116+         for  num_inputs  in  range (6 , 10 ):
117+             inputs  =  []
118+             for  _  in  range (num_inputs ):
119+                 inputs .append (torch .randn (1 , 2 , 3 ).to (torch .float16 ))
120+             self ._test_cat (self .Cat (), tuple (inputs ))
121+ 
96122    def  test_fp32_cat2 (self ):
97123        inputs  =  (torch .randn (1 , 2 , 3 ), torch .randn (3 , 2 , 3 ))
98124        self ._test_cat (self .Cat (), inputs )
@@ -120,6 +146,13 @@ def test_fp32_cat5(self):
120146        )
121147        self ._test_cat (self .Cat (), inputs )
122148
149+     def  test_fp32_cat_gt_5 (self ):
150+         for  num_inputs  in  range (6 , 10 ):
151+             inputs  =  []
152+             for  _  in  range (num_inputs ):
153+                 inputs .append (torch .randn (1 , 2 , 3 ))
154+             self ._test_cat (self .Cat (), tuple (inputs ))
155+ 
123156    def  test_qs8_cat2 (self ):
124157        inputs  =  (torch .randn (1 , 2 , 3 ), torch .randn (3 , 2 , 3 ))
125158        self ._test_cat (self .Cat (), inputs , cat_num = 2 , quant = True )
@@ -137,46 +170,22 @@ def test_qs8_cat4(self):
137170        )
138171        self ._test_cat (self .Cat (), inputs , cat_num = 4 , quant = True )
139172
140-     def  test_fp32_cat_unsupported (self ):
141-         """ 
142-         XNNPACK only supports concatenating up to 4 values, so it should not delegate here. 
143-         """ 
173+     def  test_qs8_cat5 (self ):
144174        inputs  =  (
145175            torch .randn (1 , 2 , 3 ),
146176            torch .randn (3 , 2 , 3 ),
147177            torch .randn (2 , 2 , 3 ),
148178            torch .randn (5 , 2 , 3 ),
149-             torch .randn (1 , 2 , 3 ),
150-             torch .randn (2 , 2 , 3 ),
151-         )
152-         (
153-             Tester (self .Cat (), inputs )
154-             .export ()
155-             .check_count ({"torch.ops.aten.cat" : 1 })
156-             .to_edge_transform_and_lower ()
157-             .check_count ({"executorch_exir_dialects_edge__ops_aten_cat" : 1 })
158-         )
159- 
160-     def  test_fp32_cat_unsupported_legacy_mode (self ):
161-         """ 
162-         XNNPACK only supports concatenating up to 5 values, so it should not delegate here. 
163-         """ 
164-         inputs  =  (
165-             torch .randn (1 , 2 , 3 ),
166-             torch .randn (3 , 2 , 3 ),
167-             torch .randn (2 , 2 , 3 ),
168179            torch .randn (5 , 2 , 3 ),
169-             torch .randn (1 , 2 , 3 ),
170-             torch .randn (6 , 2 , 3 ),
171-         )
172-         (
173-             Tester (self .Cat (), inputs )
174-             .export ()
175-             .check_count ({"torch.ops.aten.cat" : 1 })
176-             .to_edge ()
177-             .partition ()
178-             .check_count ({"executorch_exir_dialects_edge__ops_aten_cat" : 1 })
179180        )
181+         self ._test_cat (self .Cat (), inputs , cat_num = 5 , quant = True )
182+ 
183+     def  test_qs8_cat_gt_5 (self ):
184+         for  num_inputs  in  range (6 , 10 ):
185+             inputs  =  []
186+             for  _  in  range (num_inputs ):
187+                 inputs .append (torch .randn (1 , 2 , 3 ))
188+             self ._test_cat (self .Cat (), tuple (inputs ), cat_num = num_inputs , quant = True )
180189
181190    class  CatNegativeDim (torch .nn .Module ):
182191        def  __init__ (self ):
0 commit comments