Skip to content

Commit e8b171f

Browse files
committed
fix tests.
1 parent 47e7162 commit e8b171f

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

crates/cubecl-std/src/tests/tensor/permute.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ pub fn test_permute_2d_transpose<R: Runtime, C: Float + CubeElement>(
2020
let handle = client.create_from_slice(C::as_bytes(&input_data));
2121
let dtype = C::as_type_native().unwrap();
2222
let input = TensorHandle::<R>::new_contiguous(vec![height, width], handle, dtype);
23-
let output = tensor::permute::launch_alloc(&client, &input, &[1, 0]);
23+
let output = tensor::permute::launch_alloc::<R, C>(&client, &input, &[1, 0]);
2424

2525
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
2626
&output.shape,
@@ -63,7 +63,7 @@ pub fn test_permute_3d_batch_transpose<R: Runtime, C: Float + CubeElement>(
6363
let handle = client.create_from_slice(C::as_bytes(&input_data));
6464
let dtype = C::as_type_native().unwrap();
6565
let input = TensorHandle::<R>::new_contiguous(vec![batch, height, width], handle, dtype);
66-
let output = tensor::permute::launch_alloc(&client, &input, &[0, 2, 1]);
66+
let output = tensor::permute::launch_alloc::<R, C>(&client, &input, &[0, 2, 1]);
6767

6868
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
6969
&output.shape,
@@ -109,7 +109,7 @@ pub fn test_permute_3d_complex<R: Runtime, C: Float + CubeElement>(
109109
let handle = client.create_from_slice(C::as_bytes(&input_data));
110110
let dtype = C::as_type_native().unwrap();
111111
let input = TensorHandle::<R>::new_contiguous(vec![dim0, dim1, dim2], handle, dtype);
112-
let output = tensor::permute::launch_alloc(&client, &input, &[2, 0, 1]);
112+
let output = tensor::permute::launch_alloc::<R, C>(&client, &input, &[2, 0, 1]);
113113

114114
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
115115
&output.shape,
@@ -158,7 +158,7 @@ pub fn test_permute_single_element<R: Runtime, C: Float + CubeElement>(device: &
158158
let handle = client.create_from_slice(C::as_bytes(&[C::from(42.0).unwrap()]));
159159
let dtype = C::as_type_native().unwrap();
160160
let input = TensorHandle::<R>::new_contiguous(vec![1, 1], handle, dtype);
161-
let output = tensor::permute::launch_alloc(&client, &input, &[1, 0]);
161+
let output = tensor::permute::launch_alloc::<R, C>(&client, &input, &[1, 0]);
162162

163163
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
164164
&output.shape,
@@ -188,7 +188,7 @@ pub fn test_permute_4d_last_two_transpose<R: Runtime, C: Float + CubeElement>(
188188
let dtype = C::as_type_native().unwrap();
189189
let input =
190190
TensorHandle::<R>::new_contiguous(vec![batch, channels, height, width], handle, dtype);
191-
let output = tensor::permute::launch_alloc(&client, &input, &[0, 1, 3, 2]);
191+
let output = tensor::permute::launch_alloc::<R, C>(&client, &input, &[0, 1, 3, 2]);
192192

193193
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
194194
&output.shape,
@@ -245,7 +245,7 @@ pub fn test_permute_4d_complex<R: Runtime, C: Float + CubeElement>(
245245
let dtype = C::as_type_native().unwrap();
246246
let input =
247247
TensorHandle::<R>::new_contiguous(vec![batch, channels, height, width], handle, dtype);
248-
let output = tensor::permute::launch_alloc(&client, &input, &[0, 3, 1, 2]);
248+
let output = tensor::permute::launch_alloc::<R, C>(&client, &input, &[0, 3, 1, 2]);
249249

250250
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
251251
&output.shape,
@@ -300,7 +300,7 @@ pub fn test_permute_channel_shuffle<R: Runtime, C: Float + CubeElement>(
300300
let dtype = C::as_type_native().unwrap();
301301
let input =
302302
TensorHandle::<R>::new_contiguous(vec![batch, channels, height, width], handle, dtype);
303-
let output = tensor::permute::launch_alloc(&client, &input, &[0, 2, 3, 1]);
303+
let output = tensor::permute::launch_alloc::<R, C>(&client, &input, &[0, 2, 3, 1]);
304304

305305
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
306306
&output.shape,
@@ -355,7 +355,7 @@ pub fn test_permute_attention_transpose<R: Runtime, C: Float + CubeElement>(
355355
let dtype = C::as_type_native().unwrap();
356356
let input =
357357
TensorHandle::<R>::new_contiguous(vec![batch, heads, seq_len, head_dim], handle, dtype);
358-
let output = tensor::permute::launch_alloc(&client, &input, &[0, 2, 1, 3]);
358+
let output = tensor::permute::launch_alloc::<R, C>(&client, &input, &[0, 2, 1, 3]);
359359

360360
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
361361
&output.shape,
@@ -408,7 +408,7 @@ pub fn test_permute_small_transpose<R: Runtime, C: Float + CubeElement>(
408408
let handle = client.create_from_slice(C::as_bytes(&input_data));
409409
let dtype = C::as_type_native().unwrap();
410410
let input = TensorHandle::<R>::new_contiguous(vec![size, size], handle, dtype);
411-
let output = tensor::permute::launch_alloc(&client, &input, &[1, 0]);
411+
let output = tensor::permute::launch_alloc::<R, C>(&client, &input, &[1, 0]);
412412

413413
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
414414
&output.shape,

0 commit comments

Comments
 (0)