@@ -18,7 +18,8 @@ pub fn test_permute_2d_transpose<R: Runtime, C: Float + CubeElement>(
1818 let input_data: Vec < C > = ( 0 ..numel) . map ( |i| C :: from ( i as f32 ) . unwrap ( ) ) . collect ( ) ;
1919
2020 let handle = client. create ( C :: as_bytes ( & input_data) ) ;
21- let input = TensorHandle :: < R , C > :: new_contiguous ( vec ! [ height, width] , handle) ;
21+ let dtype = C :: as_type_native ( ) . unwrap ( ) ;
22+ let input = TensorHandle :: < R > :: new_contiguous ( vec ! [ height, width] , handle, dtype) ;
2223 let output = tensor:: permute:: launch_alloc ( & client, & input, & [ 1 , 0 ] ) ;
2324
2425 let actual = client. read_one_tensor ( output. handle . clone ( ) . copy_descriptor (
@@ -60,7 +61,8 @@ pub fn test_permute_3d_batch_transpose<R: Runtime, C: Float + CubeElement>(
6061 let input_data: Vec < C > = ( 0 ..numel) . map ( |i| C :: from ( i as f32 ) . unwrap ( ) ) . collect ( ) ;
6162
6263 let handle = client. create ( C :: as_bytes ( & input_data) ) ;
63- let input = TensorHandle :: < R , C > :: new_contiguous ( vec ! [ batch, height, width] , handle) ;
64+ let dtype = C :: as_type_native ( ) . unwrap ( ) ;
65+ let input = TensorHandle :: < R > :: new_contiguous ( vec ! [ batch, height, width] , handle, dtype) ;
6466 let output = tensor:: permute:: launch_alloc ( & client, & input, & [ 0 , 2 , 1 ] ) ;
6567
6668 let actual = client. read_one_tensor ( output. handle . clone ( ) . copy_descriptor (
@@ -105,7 +107,8 @@ pub fn test_permute_3d_complex<R: Runtime, C: Float + CubeElement>(
105107 let input_data: Vec < C > = ( 0 ..numel) . map ( |i| C :: from ( i as f32 ) . unwrap ( ) ) . collect ( ) ;
106108
107109 let handle = client. create ( C :: as_bytes ( & input_data) ) ;
108- let input = TensorHandle :: < R , C > :: new_contiguous ( vec ! [ dim0, dim1, dim2] , handle) ;
110+ let dtype = C :: as_type_native ( ) . unwrap ( ) ;
111+ let input = TensorHandle :: < R > :: new_contiguous ( vec ! [ dim0, dim1, dim2] , handle, dtype) ;
109112 let output = tensor:: permute:: launch_alloc ( & client, & input, & [ 2 , 0 , 1 ] ) ;
110113
111114 let actual = client. read_one_tensor ( output. handle . clone ( ) . copy_descriptor (
@@ -141,7 +144,8 @@ pub fn test_permute_3d_complex<R: Runtime, C: Float + CubeElement>(
141144pub fn test_permute_empty < R : Runtime , C : Float + CubeElement > ( device : & R :: Device ) {
142145 let client = R :: client ( device) ;
143146
144- let input = TensorHandle :: < R , C > :: empty ( & client, vec ! [ 0 , 5 ] ) ;
147+ let dtype = C :: as_type_native ( ) . unwrap ( ) ;
148+ let input = TensorHandle :: < R > :: empty ( & client, vec ! [ 0 , 5 ] , dtype) ;
145149 let output = tensor:: permute:: launch_alloc ( & client, & input, & [ 1 , 0 ] ) ;
146150
147151 assert_eq ! ( output. shape, vec![ 5 , 0 ] ) ;
@@ -152,7 +156,8 @@ pub fn test_permute_single_element<R: Runtime, C: Float + CubeElement>(device: &
152156 let client = R :: client ( device) ;
153157
154158 let handle = client. create ( C :: as_bytes ( & [ C :: from ( 42.0 ) . unwrap ( ) ] ) ) ;
155- let input = TensorHandle :: < R , C > :: new_contiguous ( vec ! [ 1 , 1 ] , handle) ;
159+ let dtype = C :: as_type_native ( ) . unwrap ( ) ;
160+ let input = TensorHandle :: < R > :: new_contiguous ( vec ! [ 1 , 1 ] , handle, dtype) ;
156161 let output = tensor:: permute:: launch_alloc ( & client, & input, & [ 1 , 0 ] ) ;
157162
158163 let actual = client. read_one_tensor ( output. handle . clone ( ) . copy_descriptor (
@@ -180,7 +185,9 @@ pub fn test_permute_4d_last_two_transpose<R: Runtime, C: Float + CubeElement>(
180185 let input_data: Vec < C > = ( 0 ..numel) . map ( |i| C :: from ( i as f32 ) . unwrap ( ) ) . collect ( ) ;
181186
182187 let handle = client. create ( C :: as_bytes ( & input_data) ) ;
183- let input = TensorHandle :: < R , C > :: new_contiguous ( vec ! [ batch, channels, height, width] , handle) ;
188+ let dtype = C :: as_type_native ( ) . unwrap ( ) ;
189+ let input =
190+ TensorHandle :: < R > :: new_contiguous ( vec ! [ batch, channels, height, width] , handle, dtype) ;
184191 let output = tensor:: permute:: launch_alloc ( & client, & input, & [ 0 , 1 , 3 , 2 ] ) ;
185192
186193 let actual = client. read_one_tensor ( output. handle . clone ( ) . copy_descriptor (
@@ -235,7 +242,9 @@ pub fn test_permute_4d_complex<R: Runtime, C: Float + CubeElement>(
235242 let input_data: Vec < C > = ( 0 ..numel) . map ( |i| C :: from ( i as f32 ) . unwrap ( ) ) . collect ( ) ;
236243
237244 let handle = client. create ( C :: as_bytes ( & input_data) ) ;
238- let input = TensorHandle :: < R , C > :: new_contiguous ( vec ! [ batch, channels, height, width] , handle) ;
245+ let dtype = C :: as_type_native ( ) . unwrap ( ) ;
246+ let input =
247+ TensorHandle :: < R > :: new_contiguous ( vec ! [ batch, channels, height, width] , handle, dtype) ;
239248 let output = tensor:: permute:: launch_alloc ( & client, & input, & [ 0 , 3 , 1 , 2 ] ) ;
240249
241250 let actual = client. read_one_tensor ( output. handle . clone ( ) . copy_descriptor (
@@ -288,7 +297,9 @@ pub fn test_permute_channel_shuffle<R: Runtime, C: Float + CubeElement>(
288297 let input_data: Vec < C > = ( 0 ..numel) . map ( |i| C :: from ( i as f32 ) . unwrap ( ) ) . collect ( ) ;
289298
290299 let handle = client. create ( C :: as_bytes ( & input_data) ) ;
291- let input = TensorHandle :: < R , C > :: new_contiguous ( vec ! [ batch, channels, height, width] , handle) ;
300+ let dtype = C :: as_type_native ( ) . unwrap ( ) ;
301+ let input =
302+ TensorHandle :: < R > :: new_contiguous ( vec ! [ batch, channels, height, width] , handle, dtype) ;
292303 let output = tensor:: permute:: launch_alloc ( & client, & input, & [ 0 , 2 , 3 , 1 ] ) ;
293304
294305 let actual = client. read_one_tensor ( output. handle . clone ( ) . copy_descriptor (
@@ -341,7 +352,9 @@ pub fn test_permute_attention_transpose<R: Runtime, C: Float + CubeElement>(
341352 let input_data: Vec < C > = ( 0 ..numel) . map ( |i| C :: from ( i as f32 ) . unwrap ( ) ) . collect ( ) ;
342353
343354 let handle = client. create ( C :: as_bytes ( & input_data) ) ;
344- let input = TensorHandle :: < R , C > :: new_contiguous ( vec ! [ batch, heads, seq_len, head_dim] , handle) ;
355+ let dtype = C :: as_type_native ( ) . unwrap ( ) ;
356+ let input =
357+ TensorHandle :: < R > :: new_contiguous ( vec ! [ batch, heads, seq_len, head_dim] , handle, dtype) ;
345358 let output = tensor:: permute:: launch_alloc ( & client, & input, & [ 0 , 2 , 1 , 3 ] ) ;
346359
347360 let actual = client. read_one_tensor ( output. handle . clone ( ) . copy_descriptor (
@@ -393,7 +406,8 @@ pub fn test_permute_small_transpose<R: Runtime, C: Float + CubeElement>(
393406 let input_data: Vec < C > = ( 0 ..numel) . map ( |i| C :: from ( i as f32 ) . unwrap ( ) ) . collect ( ) ;
394407
395408 let handle = client. create ( C :: as_bytes ( & input_data) ) ;
396- let input = TensorHandle :: < R , C > :: new_contiguous ( vec ! [ size, size] , handle) ;
409+ let dtype = C :: as_type_native ( ) . unwrap ( ) ;
410+ let input = TensorHandle :: < R > :: new_contiguous ( vec ! [ size, size] , handle, dtype) ;
397411 let output = tensor:: permute:: launch_alloc ( & client, & input, & [ 1 , 0 ] ) ;
398412
399413 let actual = client. read_one_tensor ( output. handle . clone ( ) . copy_descriptor (
0 commit comments