Skip to content

Commit 74f28d9

Browse files
authored
fix: Align shift right behaviour to align with most other compilers (#894)
1 parent 0e8ee73 commit 74f28d9

File tree

39 files changed

+74
-77
lines changed

39 files changed

+74
-77
lines changed

crates/cubecl-attention/src/components/tile/dummy/flash_matmul/accelerated/setup.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ impl FlashMatmulFamily for AcceleratedFlashMatmul {
3333
1,
3434
line_sizes.query as u32,
3535
line_sizes.key as u32,
36-
problem.seq_kv as u32 % selection.attention_tile_size.seq_kv != 0,
36+
!(problem.seq_kv as u32).is_multiple_of(selection.attention_tile_size.seq_kv),
3737
)
3838
}
3939
}

crates/cubecl-attention/src/components/tile/dummy/flash_matmul/dummy_register/setup.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ impl FlashMatmulFamily for DummyRegisterFlashMatmul {
3535
1,
3636
line_sizes.query as u32,
3737
line_sizes.key as u32,
38-
problem.seq_kv as u32 % selection.attention_tile_size.seq_kv != 0,
38+
!(problem.seq_kv as u32).is_multiple_of(selection.attention_tile_size.seq_kv),
3939
)
4040
}
4141
}

crates/cubecl-attention/src/components/tile/dummy/writer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ impl<EO: Numeric> DummyWriter<EO> {
4848

4949
let unit_step = config.plane_dim() * output_line_size;
5050
let num_unit_writes = comptime!(div_ceil(tile_size, unit_step));
51-
let balanced_workload = comptime!(tile_size % unit_step == 0);
51+
let balanced_workload = comptime!(tile_size.is_multiple_of(unit_step));
5252

5353
#[unroll(num_unit_writes == 1)]
5454
for i in 0..num_unit_writes {

crates/cubecl-core/src/runtime_tests/cmma.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -597,8 +597,8 @@ pub fn test_simple_tf32<R: Runtime>(
597597
return;
598598
}
599599

600-
let lhs: Vec<f32> = (0..128).map(|i| (i as f32)).collect();
601-
let rhs: Vec<f32> = (0..128).map(|i| ((i % 8) as f32)).collect();
600+
let lhs: Vec<f32> = (0..128).map(|i| i as f32).collect();
601+
let rhs: Vec<f32> = (0..128).map(|i| (i % 8) as f32).collect();
602602

603603
let lhs = client.create(f32::as_bytes(&lhs));
604604
let rhs = client.create(f32::as_bytes(&rhs));

crates/cubecl-cpp/src/shared/item.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ impl<D: Dialect> Item<D> {
5151
}
5252

5353
pub fn optimized(&self) -> Item<D> {
54-
if !self.can_be_optimized() || self.vectorization % 2 != 0 {
54+
if !self.can_be_optimized() || !self.vectorization.is_multiple_of(2) {
5555
return *self;
5656
}
5757

crates/cubecl-cuda/src/compute/server.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ use cudarc::driver::sys::{
4444
use cudarc::driver::sys::{CUfunc_st, CUtensorMapInterleave};
4545
#[cfg(feature = "cuda-12080")]
4646
use cudarc::driver::sys::{CUtensorMapIm2ColWideMode, cuTensorMapEncodeIm2colWide};
47-
use serde::{Deserialize, Serialize};
4847
use std::collections::HashMap;
4948
use std::ffi::c_char;
5049
use std::path::PathBuf;
@@ -78,7 +77,8 @@ pub(crate) struct CudaContext {
7877
compilation_options: CompilationOptions,
7978
}
8079

81-
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
80+
#[cfg(feature = "compilation-cache")]
81+
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Eq, Clone)]
8282
pub struct PtxCacheEntry {
8383
entrypoint_name: String,
8484
cube_dim: (u32, u32, u32),
@@ -336,7 +336,7 @@ impl ComputeServer for CudaServer {
336336
.expect("Failed to find resource");
337337
let device_ptr = resource.ptr as *mut c_void;
338338
debug_assert!(
339-
device_ptr as usize % 16 == 0,
339+
(device_ptr as usize).is_multiple_of(16),
340340
"Tensor pointer must be 16 byte aligned"
341341
);
342342
let mut map_ptr = MaybeUninit::zeroed();

crates/cubecl-matmul/src/components/batch/partitioned_matmul/hypercube/base.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,13 @@ impl HypercubeConfig {
7676
match self.global_order {
7777
RowMajor | ColMajor => Ok(()),
7878

79-
SwizzleRowMajor(w) if m_cubes % w != 0 => {
79+
SwizzleRowMajor(w) if !m_cubes.is_multiple_of(w) => {
8080
Err(MatmulSetupError::InvalidConfig(Box::new(format!(
8181
"In swizzle row major, number of cubes in m {m_cubes:?} must be divisible by swizzle step length {w:?}."
8282
))))
8383
}
8484

85-
SwizzleColMajor(w) if n_cubes % w != 0 => {
85+
SwizzleColMajor(w) if !n_cubes.is_multiple_of(w) => {
8686
Err(MatmulSetupError::InvalidConfig(Box::new(format!(
8787
"In swizzle col major, number of cubes in n {n_cubes:?} must be divisible by swizzle step length {w:?}."
8888
))))

crates/cubecl-matmul/src/components/batch/partitioned_matmul/hypercube/sm_allocation.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ impl SmAllocation {
5151
let mut i = 1;
5252

5353
while i * i <= n {
54-
if n % i == 0 {
54+
if n.is_multiple_of(i) {
5555
divs.push(i);
5656
if i != n / i {
5757
divs.push(n / i);

crates/cubecl-matmul/src/components/global/load/strategy/async_full_cyclic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ impl<T: TilingOrder> LoadingValidation for AsyncFullCyclicLoading<T> {
2626
let num_slices = config.tiling_scheme().elements_in_tile_row(ident)
2727
* config.tiling_scheme().tiles_in_stage(ident);
2828

29-
if num_slices >= total_units && num_slices % total_units != 0 {
29+
if num_slices >= total_units && !num_slices.is_multiple_of(total_units) {
3030
return Err(Box::new(format!(
3131
"Number of units ({total_units:?}) must divide number of slices ({num_slices:?}). Would require units doing different numbers of slices"
3232
)));

crates/cubecl-matmul/src/components/global/load/strategy/async_full_maximize_slice_length.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ impl<IP: InputPrecision> AsyncLoadingJob<IP, StridedTilingLayout>
8383
let nth_slice = this.unit_count * task_id + UNIT_POS;
8484

8585
#[allow(clippy::collapsible_else_if)]
86-
if comptime!(this.num_slices % this.unit_count == 0) {
86+
if comptime!(this.num_slices.is_multiple_of(this.unit_count)) {
8787
load_nth_slice::<IP::Global, IP::Stage, CM, G>(
8888
nth_slice,
8989
tensor_reader,

0 commit comments

Comments
 (0)