Skip to content

Commit 1d3f06e

Browse files
committed
fix lints
1 parent a44cbe1 commit 1d3f06e

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

crates/cubecl-std/src/tests/tensor/permute.rs

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ pub fn test_permute_2d_transpose<R: Runtime, C: Float + CubeElement>(
1818
let input_data: Vec<C> = (0..numel).map(|i| C::from(i as f32).unwrap()).collect();
1919

2020
let handle = client.create(C::as_bytes(&input_data));
21-
let input = TensorHandle::<R, C>::new_contiguous(vec![height, width], handle);
21+
let dtype = C::as_type_native().unwrap();
22+
let input = TensorHandle::<R>::new_contiguous(vec![height, width], handle, dtype);
2223
let output = tensor::permute::launch_alloc(&client, &input, &[1, 0]);
2324

2425
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
@@ -60,7 +61,8 @@ pub fn test_permute_3d_batch_transpose<R: Runtime, C: Float + CubeElement>(
6061
let input_data: Vec<C> = (0..numel).map(|i| C::from(i as f32).unwrap()).collect();
6162

6263
let handle = client.create(C::as_bytes(&input_data));
63-
let input = TensorHandle::<R, C>::new_contiguous(vec![batch, height, width], handle);
64+
let dtype = C::as_type_native().unwrap();
65+
let input = TensorHandle::<R>::new_contiguous(vec![batch, height, width], handle, dtype);
6466
let output = tensor::permute::launch_alloc(&client, &input, &[0, 2, 1]);
6567

6668
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
@@ -105,7 +107,8 @@ pub fn test_permute_3d_complex<R: Runtime, C: Float + CubeElement>(
105107
let input_data: Vec<C> = (0..numel).map(|i| C::from(i as f32).unwrap()).collect();
106108

107109
let handle = client.create(C::as_bytes(&input_data));
108-
let input = TensorHandle::<R, C>::new_contiguous(vec![dim0, dim1, dim2], handle);
110+
let dtype = C::as_type_native().unwrap();
111+
let input = TensorHandle::<R>::new_contiguous(vec![dim0, dim1, dim2], handle, dtype);
109112
let output = tensor::permute::launch_alloc(&client, &input, &[2, 0, 1]);
110113

111114
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
@@ -141,7 +144,8 @@ pub fn test_permute_3d_complex<R: Runtime, C: Float + CubeElement>(
141144
pub fn test_permute_empty<R: Runtime, C: Float + CubeElement>(device: &R::Device) {
142145
let client = R::client(device);
143146

144-
let input = TensorHandle::<R, C>::empty(&client, vec![0, 5]);
147+
let dtype = C::as_type_native().unwrap();
148+
let input = TensorHandle::<R>::empty(&client, vec![0, 5], dtype);
145149
let output = tensor::permute::launch_alloc(&client, &input, &[1, 0]);
146150

147151
assert_eq!(output.shape, vec![5, 0]);
@@ -152,7 +156,8 @@ pub fn test_permute_single_element<R: Runtime, C: Float + CubeElement>(device: &
152156
let client = R::client(device);
153157

154158
let handle = client.create(C::as_bytes(&[C::from(42.0).unwrap()]));
155-
let input = TensorHandle::<R, C>::new_contiguous(vec![1, 1], handle);
159+
let dtype = C::as_type_native().unwrap();
160+
let input = TensorHandle::<R>::new_contiguous(vec![1, 1], handle, dtype);
156161
let output = tensor::permute::launch_alloc(&client, &input, &[1, 0]);
157162

158163
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
@@ -180,7 +185,9 @@ pub fn test_permute_4d_last_two_transpose<R: Runtime, C: Float + CubeElement>(
180185
let input_data: Vec<C> = (0..numel).map(|i| C::from(i as f32).unwrap()).collect();
181186

182187
let handle = client.create(C::as_bytes(&input_data));
183-
let input = TensorHandle::<R, C>::new_contiguous(vec![batch, channels, height, width], handle);
188+
let dtype = C::as_type_native().unwrap();
189+
let input =
190+
TensorHandle::<R>::new_contiguous(vec![batch, channels, height, width], handle, dtype);
184191
let output = tensor::permute::launch_alloc(&client, &input, &[0, 1, 3, 2]);
185192

186193
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
@@ -235,7 +242,9 @@ pub fn test_permute_4d_complex<R: Runtime, C: Float + CubeElement>(
235242
let input_data: Vec<C> = (0..numel).map(|i| C::from(i as f32).unwrap()).collect();
236243

237244
let handle = client.create(C::as_bytes(&input_data));
238-
let input = TensorHandle::<R, C>::new_contiguous(vec![batch, channels, height, width], handle);
245+
let dtype = C::as_type_native().unwrap();
246+
let input =
247+
TensorHandle::<R>::new_contiguous(vec![batch, channels, height, width], handle, dtype);
239248
let output = tensor::permute::launch_alloc(&client, &input, &[0, 3, 1, 2]);
240249

241250
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
@@ -288,7 +297,9 @@ pub fn test_permute_channel_shuffle<R: Runtime, C: Float + CubeElement>(
288297
let input_data: Vec<C> = (0..numel).map(|i| C::from(i as f32).unwrap()).collect();
289298

290299
let handle = client.create(C::as_bytes(&input_data));
291-
let input = TensorHandle::<R, C>::new_contiguous(vec![batch, channels, height, width], handle);
300+
let dtype = C::as_type_native().unwrap();
301+
let input =
302+
TensorHandle::<R>::new_contiguous(vec![batch, channels, height, width], handle, dtype);
292303
let output = tensor::permute::launch_alloc(&client, &input, &[0, 2, 3, 1]);
293304

294305
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
@@ -341,7 +352,9 @@ pub fn test_permute_attention_transpose<R: Runtime, C: Float + CubeElement>(
341352
let input_data: Vec<C> = (0..numel).map(|i| C::from(i as f32).unwrap()).collect();
342353

343354
let handle = client.create(C::as_bytes(&input_data));
344-
let input = TensorHandle::<R, C>::new_contiguous(vec![batch, heads, seq_len, head_dim], handle);
355+
let dtype = C::as_type_native().unwrap();
356+
let input =
357+
TensorHandle::<R>::new_contiguous(vec![batch, heads, seq_len, head_dim], handle, dtype);
345358
let output = tensor::permute::launch_alloc(&client, &input, &[0, 2, 1, 3]);
346359

347360
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(
@@ -393,7 +406,8 @@ pub fn test_permute_small_transpose<R: Runtime, C: Float + CubeElement>(
393406
let input_data: Vec<C> = (0..numel).map(|i| C::from(i as f32).unwrap()).collect();
394407

395408
let handle = client.create(C::as_bytes(&input_data));
396-
let input = TensorHandle::<R, C>::new_contiguous(vec![size, size], handle);
409+
let dtype = C::as_type_native().unwrap();
410+
let input = TensorHandle::<R>::new_contiguous(vec![size, size], handle, dtype);
397411
let output = tensor::permute::launch_alloc(&client, &input, &[1, 0]);
398412

399413
let actual = client.read_one_tensor(output.handle.clone().copy_descriptor(

0 commit comments

Comments
 (0)