99from typing import Tuple
1010
1111import torch
12-
13- from executorch .backends .arm .test import common
12+ from executorch .backends .arm .quantizer .arm_quantizer import (
13+ get_symmetric_a16w8_quantization_config ,
14+ TOSAQuantizer ,
15+ )
16+ from executorch .backends .arm .test import common , conftest
1417
1518from executorch .backends .arm .test .tester .test_pipeline import (
1619 EthosU55PipelineINT ,
1922 TosaPipelineINT ,
2023 VgfPipeline ,
2124)
22- from torchvision .ops import Permute
25+ from executorch .backends .arm .tosa import TosaSpecification
26+ from executorch .backends .xnnpack .test .tester import Quantize
2327
2428input_t1 = Tuple [torch .Tensor ] # Input x
2529
@@ -42,10 +46,10 @@ class SimplePermute(torch.nn.Module):
4246 def __init__ (self , dims : list [int ]):
4347 super ().__init__ ()
4448
45- self .permute = Permute ( dims = dims )
49+ self .dims = dims
4650
4751 def forward (self , x ):
48- return self .permute (x )
52+ return torch .permute (x , self . dims )
4953
5054
5155@common .parametrize ("test_data" , test_data_suite )
@@ -128,3 +132,98 @@ def test_permute_vgf_INT(test_data):
128132 tosa_version = "TOSA-1.0+INT" ,
129133 )
130134 pipeline .run ()
135+
136+
137+ def get_symmetric_a16w8_permute_quantizer (
138+ u55_config = False , per_channel_quantization = False
139+ ):
140+ tosa_version = conftest .get_option ("tosa_version" )
141+ tosa_profiles = {
142+ "1.0" : TosaSpecification .create_from_string ("TOSA-1.0+INT+int16" ),
143+ }
144+
145+ quantizer = TOSAQuantizer (tosa_profiles [tosa_version ])
146+ quantizer .set_global (
147+ get_symmetric_a16w8_quantization_config (is_per_channel = per_channel_quantization )
148+ )
149+
150+ return Quantize (
151+ quantizer ,
152+ get_symmetric_a16w8_quantization_config (
153+ is_per_channel = per_channel_quantization
154+ ),
155+ )
156+
157+
158+ @common .parametrize ("test_data" , test_data_suite )
159+ def test_permute_int16_tosa_INT (test_data : torch .Tensor ):
160+ """Test permute operation with int16 quantization"""
161+ test_data , dims = test_data ()
162+ pipeline = TosaPipelineINT [input_t1 ](
163+ SimplePermute (dims = dims ),
164+ (test_data ,),
165+ aten_op ,
166+ exir_op = [],
167+ per_channel_quantization = False ,
168+ use_to_edge_transform_and_lower = True ,
169+ tosa_extensions = ["int16" ],
170+ )
171+
172+ pipeline .change_args (
173+ "quantize" ,
174+ get_symmetric_a16w8_permute_quantizer (per_channel_quantization = False ),
175+ )
176+ # Run the pipeline
177+ pipeline .run ()
178+
179+
180+ test_data_suite_exact = {
181+ x : test_data_suite [x ] for x in test_data_suite if x != "rank_4_3"
182+ }
183+
184+
185+ @common .parametrize ("test_data" , test_data_suite_exact )
186+ @common .XfailIfNoCorstone300
187+ def test_permute_int16_u55_INT16 (test_data : torch .Tensor ):
188+ """Test permute operation with int16 quantization on U55"""
189+ test_data , dims = test_data ()
190+ pipeline = EthosU55PipelineINT [input_t1 ](
191+ SimplePermute (dims = dims ),
192+ (test_data ,),
193+ aten_op ,
194+ exir_ops = [],
195+ per_channel_quantization = True ,
196+ use_to_edge_transform_and_lower = True ,
197+ atol = 1e-02 ,
198+ rtol = 1e-02 ,
199+ run_on_fvp = True ,
200+ )
201+
202+ pipeline .change_args (
203+ "quantize" ,
204+ get_symmetric_a16w8_permute_quantizer (per_channel_quantization = False ),
205+ )
206+ pipeline .run ()
207+
208+
209+ @common .parametrize ("test_data" , test_data_suite )
210+ @common .XfailIfNoCorstone320
211+ def test_permute_int16_u85_INT16 (test_data : torch .Tensor ):
212+ """Test permute operation with int16 quantization on U85"""
213+ test_data , dims = test_data ()
214+ pipeline = EthosU85PipelineINT [input_t1 ](
215+ SimplePermute (dims = dims ),
216+ (test_data ,),
217+ aten_op ,
218+ exir_ops = [],
219+ use_to_edge_transform_and_lower = True ,
220+ atol = 1e-03 ,
221+ rtol = 1e-03 ,
222+ run_on_fvp = True ,
223+ )
224+
225+ pipeline .change_args (
226+ "quantize" ,
227+ get_symmetric_a16w8_permute_quantizer (per_channel_quantization = False ),
228+ )
229+ pipeline .run ()
0 commit comments