From 742430ff44a7be9c22eb7626be831a2c629285b1 Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 22 Nov 2024 15:48:35 +0000 Subject: [PATCH 1/7] Safe indexing in wgsl --- crates/cubecl-core/src/ir/operation.rs | 39 +- crates/cubecl-core/src/ir/variable.rs | 16 + crates/cubecl-cpp/src/shared/base.rs | 54 +-- crates/cubecl-wgpu/src/compiler/base.rs | 1 - .../cubecl-wgpu/src/compiler/wgsl/compiler.rs | 338 ++++++++++-------- .../src/compiler/wgsl/instructions.rs | 111 +++--- crates/cubecl-wgpu/src/compute/server.rs | 2 +- 7 files changed, 318 insertions(+), 243 deletions(-) diff --git a/crates/cubecl-core/src/ir/operation.rs b/crates/cubecl-core/src/ir/operation.rs index f2e604008..0cf511fd6 100644 --- a/crates/cubecl-core/src/ir/operation.rs +++ b/crates/cubecl-core/src/ir/operation.rs @@ -1,8 +1,11 @@ use std::fmt::Display; -use crate::prelude::AtomicOp; - -use super::{Branch, CoopMma, Item, Plane, Select, Synchronization, Variable}; +use super::{Branch, CoopMma, Item, Plane, Scope, Select, Synchronization, Variable}; +use crate::{ + cpa, + ir::{Elem, UIntKind}, + prelude::AtomicOp, +}; use serde::{Deserialize, Serialize}; /// All operations that can be used in a GPU compute shader. 
@@ -358,6 +361,36 @@ pub struct FmaOperator { pub c: Variable, } +#[allow(missing_docs)] +pub struct CheckedIndexAssign { + pub lhs: Variable, + pub rhs: Variable, + pub out: Variable, +} + +impl CheckedIndexAssign { + #[allow(missing_docs)] + pub fn expand(self, scope: &mut Scope) { + let lhs = self.lhs; + let rhs = self.rhs; + let out = self.out; + let array_len = scope.create_local(Item::new(Elem::UInt(UIntKind::U32))); + let inside_bound = scope.create_local(Item::new(Elem::Bool)); + + if out.has_buffer_length() { + cpa!(scope, array_len = buffer_len(out)); + } else { + cpa!(scope, array_len = len(out)); + } + + cpa!(scope, inside_bound = lhs < array_len); + + cpa!(scope, if(inside_bound).then(|scope| { + cpa!(scope, unchecked(out[lhs]) = rhs); + })); + } +} + impl From for Operation { fn from(val: Operator) -> Self { Operation::Operator(val) diff --git a/crates/cubecl-core/src/ir/variable.rs b/crates/cubecl-core/src/ir/variable.rs index e3b1f5b65..17e0e7efc 100644 --- a/crates/cubecl-core/src/ir/variable.rs +++ b/crates/cubecl-core/src/ir/variable.rs @@ -114,6 +114,22 @@ impl Variable { ) } + pub fn has_length(&self) -> bool { + matches!( + self.kind, + VariableKind::GlobalInputArray { .. } + | VariableKind::GlobalOutputArray { .. } + | VariableKind::Slice { .. } + ) + } + + pub fn has_buffer_length(&self) -> bool { + matches!( + self.kind, + VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. 
} + ) + } + /// Determines if the value is a constant with the specified value (converted if necessary) pub fn is_constant(&self, value: i64) -> bool { match self.kind { diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index 706913afc..af4cc739c 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -1,8 +1,8 @@ use std::hash::Hash; use std::{collections::HashSet, fmt::Debug, num::NonZero}; +use cubecl_core::ir::CheckedIndexAssign; use cubecl_core::{ - cpa, ir::{self as gpu}, prelude::CubePrimitive, Compiler, Feature, @@ -522,14 +522,14 @@ impl CppCompiler { out: self.compile_variable(out), }), gpu::Operator::Index(op) => { - if matches!(self.strategy, ExecutionMode::Checked) && has_length(&op.lhs) { + if matches!(self.strategy, ExecutionMode::Checked) && op.lhs.has_length() { let lhs = op.lhs; let rhs = op.rhs; let array_len = scope.create_local(gpu::Item::new(u32::as_elem())); instructions.extend(self.compile_scope(scope)); - let length = match has_buffer_length(&lhs) { + let length = match lhs.has_buffer_length() { true => gpu::Metadata::BufferLength { var: lhs }, false => gpu::Metadata::Length { var: lhs }, }; @@ -550,7 +550,7 @@ impl CppCompiler { } gpu::Operator::IndexAssign(op) => { if let ExecutionMode::Checked = self.strategy { - if has_length(&out) { + if out.has_length() { CheckedIndexAssign { lhs: op.lhs, rhs: op.rhs, @@ -972,52 +972,6 @@ impl CppCompiler { } } -#[allow(missing_docs)] -struct CheckedIndexAssign { - pub lhs: gpu::Variable, - pub rhs: gpu::Variable, - pub out: gpu::Variable, -} - -impl CheckedIndexAssign { - #[allow(missing_docs)] - fn expand(self, scope: &mut gpu::Scope) { - let lhs = self.lhs; - let rhs = self.rhs; - let out = self.out; - let array_len = scope.create_local(gpu::Item::new(u32::as_elem())); - let inside_bound = scope.create_local(gpu::Item::new(gpu::Elem::Bool)); - - if has_buffer_length(&out) { - cpa!(scope, array_len = buffer_len(out)); - } 
else { - cpa!(scope, array_len = len(out)); - } - - cpa!(scope, inside_bound = lhs < array_len); - - cpa!(scope, if(inside_bound).then(|scope| { - cpa!(scope, unchecked(out[lhs]) = rhs); - })); - } -} - -fn has_length(var: &gpu::Variable) -> bool { - matches!( - var.kind, - gpu::VariableKind::GlobalInputArray { .. } - | gpu::VariableKind::GlobalOutputArray { .. } - | gpu::VariableKind::Slice { .. } - ) -} - -fn has_buffer_length(var: &gpu::Variable) -> bool { - matches!( - var.kind, - gpu::VariableKind::GlobalInputArray { .. } | gpu::VariableKind::GlobalOutputArray { .. } - ) -} - pub fn register_supported_types(props: &mut DeviceProperties) { let supported_types = [ gpu::Elem::UInt(gpu::UIntKind::U8), diff --git a/crates/cubecl-wgpu/src/compiler/base.rs b/crates/cubecl-wgpu/src/compiler/base.rs index e040190de..0089bcfc4 100644 --- a/crates/cubecl-wgpu/src/compiler/base.rs +++ b/crates/cubecl-wgpu/src/compiler/base.rs @@ -18,7 +18,6 @@ pub trait WgpuCompiler: Compiler { fn create_pipeline( server: &mut WgpuServer, kernel: CompiledKernel, - mode: ExecutionMode, ) -> Arc; #[allow(async_fn_in_trait)] diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index a7355ec75..910742ed2 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -1,14 +1,18 @@ use std::{borrow::Cow, sync::Arc}; -use super::{shader::ComputeShader, ConstantArray, Item, SharedMemory}; -use super::{LocalArray, Subgroup}; +use super::Subgroup; +use super::{shader::ComputeShader, ConstantArray}; +use super::{Item, LocalArray, SharedMemory}; use crate::{ compiler::{base::WgpuCompiler, wgsl}, WgpuServer, }; + +use cubecl_core::ir::CheckedIndexAssign; use cubecl_core::{ ir::{self as cube, HybridAllocator, UIntKind}, prelude::CompiledKernel, + prelude::CubePrimitive, server::ComputeServer, Feature, Metadata, }; @@ -45,6 +49,7 @@ pub struct WgslCompiler { local_arrays: Vec, 
#[allow(dead_code)] compilation_options: CompilationOptions, + strategy: ExecutionMode, } impl core::fmt::Debug for WgslCompiler { @@ -60,7 +65,6 @@ impl cubecl_core::Compiler for WgslCompiler { fn compile( shader: cube::KernelDefinition, compilation_options: &Self::CompilationOptions, - _mode: ExecutionMode, ) -> Self::Representation { let mut compiler = Self { compilation_options: compilation_options.clone(), @@ -86,22 +90,22 @@ impl WgpuCompiler for WgslCompiler { fn create_pipeline( server: &mut WgpuServer, kernel: CompiledKernel, - mode: ExecutionMode, ) -> Arc { let source = &kernel.source; - let module = match mode { - ExecutionMode::Checked => server.device.create_shader_module(ShaderModuleDescriptor { - label: None, - source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), - }), - ExecutionMode::Unchecked => unsafe { - server - .device - .create_shader_module_unchecked(ShaderModuleDescriptor { - label: None, - source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), - }) - }, + // Cube always in principle uses unchecked modules. Certain operations like + // indexing are instead checked by cube. The WebGPU specification only makes + // incredibly loose guarantees that Cube can't rely on. Additionally, kernels + // can opt in/out per operation whether checks should be performed which can be faster. + // + // SAFETY: Cube guarantees OOB safety when launching in checked mode. Launching in unchecked mode + // is only available through the use of unsafe code. 
+ let module = unsafe { + server + .device + .create_shader_module_unchecked(ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), + }) }; let layout = kernel.repr.map(|repr| { @@ -225,7 +229,13 @@ fn register_types(props: &mut DeviceProperties) { } impl WgslCompiler { - fn compile_shader(&mut self, mut value: cube::KernelDefinition) -> wgsl::ComputeShader { + fn compile_shader( + &mut self, + mut value: cube::KernelDefinition, + mode: ExecutionMode, + ) -> wgsl::ComputeShader { + self.strategy = mode; + self.num_inputs = value.inputs.len(); self.num_outputs = value.outputs.len(); let num_meta = value.inputs.len() + value.outputs.len(); @@ -475,10 +485,10 @@ impl WgslCompiler { } } - fn compile_scope(&mut self, value: &mut cube::Scope) -> Vec { + fn compile_scope(&mut self, scope: &mut cube::Scope) -> Vec { let mut instructions = Vec::new(); - let const_arrays = value + let const_arrays = scope .const_arrays .drain(..) .map(|(var, values)| ConstantArray { @@ -493,7 +503,7 @@ impl WgslCompiler { .collect::>(); self.const_arrays.extend(const_arrays); - let processing = value.process(); + let processing = scope.process(); for var in processing.variables { // We don't declare slices. 
@@ -509,7 +519,7 @@ impl WgslCompiler { processing .operations .into_iter() - .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out)); + .for_each(|op| self.compile_operation(&mut instructions, op.operation, op.out, scope)); instructions } @@ -519,13 +529,14 @@ impl WgslCompiler { instructions: &mut Vec, operation: cube::Operation, out: Option, + scope: &mut cube::Scope, ) { match operation { cube::Operation::Copy(variable) => instructions.push(wgsl::Instruction::Assign { input: self.compile_variable(variable), out: self.compile_variable(out.unwrap()), }), - cube::Operation::Operator(op) => instructions.push(self.compile_instruction(op, out)), + cube::Operation::Operator(op) => self.compile_instruction(op, out, instructions, scope), cube::Operation::Atomic(op) => instructions.push(self.compile_atomic(op, out)), cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)), cube::Operation::Branch(val) => self.compile_branch(instructions, val), @@ -720,267 +731,310 @@ impl WgslCompiler { &mut self, value: cube::Operator, out: Option, - ) -> wgsl::Instruction { + instructions: &mut Vec, + scope: &mut cube::Scope, + ) { let out = out.unwrap(); match value { - cube::Operator::Max(op) => wgsl::Instruction::Max { + cube::Operator::Max(op) => instructions.push(wgsl::Instruction::Max { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::Min(op) => wgsl::Instruction::Min { + }), + cube::Operator::Min(op) => instructions.push(wgsl::Instruction::Min { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::Add(op) => wgsl::Instruction::Add { + }), + cube::Operator::Add(op) => instructions.push(wgsl::Instruction::Add { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::Fma(op) => wgsl::Instruction::Fma { + }), + 
cube::Operator::Fma(op) => instructions.push(wgsl::Instruction::Fma { a: self.compile_variable(op.a), b: self.compile_variable(op.b), c: self.compile_variable(op.c), out: self.compile_variable(out), - }, - cube::Operator::Index(op) => wgsl::Instruction::Index { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(out), - }, - cube::Operator::UncheckedIndex(op) => wgsl::Instruction::Index { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(out), - }, - cube::Operator::Modulo(op) => wgsl::Instruction::Modulo { + }), + cube::Operator::Modulo(op) => instructions.push(wgsl::Instruction::Modulo { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::Sub(op) => wgsl::Instruction::Sub { + }), + cube::Operator::Sub(op) => instructions.push(wgsl::Instruction::Sub { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::Mul(op) => wgsl::Instruction::Mul { + }), + cube::Operator::Mul(op) => instructions.push(wgsl::Instruction::Mul { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::Div(op) => wgsl::Instruction::Div { + }), + cube::Operator::Div(op) => instructions.push(wgsl::Instruction::Div { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::Abs(op) => wgsl::Instruction::Abs { + }), + cube::Operator::Abs(op) => instructions.push(wgsl::Instruction::Abs { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Exp(op) => wgsl::Instruction::Exp { + }), + cube::Operator::Exp(op) => instructions.push(wgsl::Instruction::Exp { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Log(op) => 
wgsl::Instruction::Log { + }), + cube::Operator::Log(op) => instructions.push(wgsl::Instruction::Log { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Log1p(op) => wgsl::Instruction::Log1p { + }), + cube::Operator::Log1p(op) => instructions.push(wgsl::Instruction::Log1p { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Cos(op) => wgsl::Instruction::Cos { + }), + cube::Operator::Cos(op) => instructions.push(wgsl::Instruction::Cos { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Sin(op) => wgsl::Instruction::Sin { + }), + cube::Operator::Sin(op) => instructions.push(wgsl::Instruction::Sin { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Tanh(op) => wgsl::Instruction::Tanh { + }), + cube::Operator::Tanh(op) => instructions.push(wgsl::Instruction::Tanh { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Powf(op) => wgsl::Instruction::Powf { + }), + cube::Operator::Powf(op) => instructions.push(wgsl::Instruction::Powf { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::Sqrt(op) => wgsl::Instruction::Sqrt { + }), + cube::Operator::Sqrt(op) => instructions.push(wgsl::Instruction::Sqrt { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Round(op) => wgsl::Instruction::Round { + }), + cube::Operator::Round(op) => instructions.push(wgsl::Instruction::Round { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Floor(op) => wgsl::Instruction::Floor { + }), + cube::Operator::Floor(op) => instructions.push(wgsl::Instruction::Floor { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Ceil(op) => wgsl::Instruction::Ceil { + }), + 
cube::Operator::Ceil(op) => instructions.push(wgsl::Instruction::Ceil { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Erf(op) => wgsl::Instruction::Erf { + }), + cube::Operator::Erf(op) => instructions.push(wgsl::Instruction::Erf { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Recip(op) => wgsl::Instruction::Recip { + }), + cube::Operator::Recip(op) => instructions.push(wgsl::Instruction::Recip { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Equal(op) => wgsl::Instruction::Equal { + }), + cube::Operator::Equal(op) => instructions.push(wgsl::Instruction::Equal { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::Lower(op) => wgsl::Instruction::Lower { + }), + cube::Operator::Lower(op) => instructions.push(wgsl::Instruction::Lower { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::Clamp(op) => wgsl::Instruction::Clamp { + }), + cube::Operator::Clamp(op) => instructions.push(wgsl::Instruction::Clamp { input: self.compile_variable(op.input), min_value: self.compile_variable(op.min_value), max_value: self.compile_variable(op.max_value), out: self.compile_variable(out), - }, - cube::Operator::Greater(op) => wgsl::Instruction::Greater { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(out), - }, - cube::Operator::LowerEqual(op) => wgsl::Instruction::LowerEqual { + }), + cube::Operator::Greater(op) => instructions.push(wgsl::Instruction::Greater { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::GreaterEqual(op) => wgsl::Instruction::GreaterEqual { + }), + cube::Operator::LowerEqual(op) => instructions.push(wgsl::Instruction::LowerEqual { 
lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::NotEqual(op) => wgsl::Instruction::NotEqual { + }), + cube::Operator::GreaterEqual(op) => { + instructions.push(wgsl::Instruction::GreaterEqual { + lhs: self.compile_variable(op.lhs), + rhs: self.compile_variable(op.rhs), + out: self.compile_variable(out), + }) + } + cube::Operator::NotEqual(op) => instructions.push(wgsl::Instruction::NotEqual { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::Cast(op) => wgsl::Instruction::Assign { + }), + cube::Operator::Cast(op) => instructions.push(wgsl::Instruction::Assign { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::IndexAssign(op) => wgsl::Instruction::IndexAssign { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(out), - }, - cube::Operator::UncheckedIndexAssign(op) => wgsl::Instruction::IndexAssign { + }), + cube::Operator::Index(op) => { + if matches!(self.strategy, ExecutionMode::Checked) && op.lhs.has_length() { + let lhs = op.lhs; + let rhs = op.rhs; + let array_len = scope.create_local(cube::Item::new(u32::as_elem())); + + instructions.extend(self.compile_scope(scope)); + + let length = match lhs.has_buffer_length() { + true => cube::Metadata::BufferLength { var: lhs }, + false => cube::Metadata::Length { var: lhs }, + }; + + instructions.push(self.compile_metadata(length, Some(array_len))); + instructions.push(wgsl::Instruction::CheckedIndex { + len: self.compile_variable(array_len), + lhs: self.compile_variable(lhs), + rhs: self.compile_variable(rhs), + out: self.compile_variable(out), + }); + } else { + instructions.push(wgsl::Instruction::Index { + lhs: self.compile_variable(op.lhs), + rhs: self.compile_variable(op.rhs), + out: self.compile_variable(out), + }); + } + } + 
cube::Operator::UncheckedIndex(op) => instructions.push(wgsl::Instruction::Index { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::And(op) => wgsl::Instruction::And { + }), + cube::Operator::IndexAssign(op) => { + if let ExecutionMode::Checked = self.strategy { + if out.has_length() { + CheckedIndexAssign { + lhs: op.lhs, + rhs: op.rhs, + out, + } + .expand(scope); + instructions.extend(self.compile_scope(scope)); + return; + } + }; + instructions.push(wgsl::Instruction::IndexAssign { + lhs: self.compile_variable(op.lhs), + rhs: self.compile_variable(op.rhs), + out: self.compile_variable(out), + }) + } + cube::Operator::UncheckedIndexAssign(op) => { + instructions.push(wgsl::Instruction::IndexAssign { + lhs: self.compile_variable(op.lhs), + rhs: self.compile_variable(op.rhs), + out: self.compile_variable(out), + }) + } + cube::Operator::And(op) => instructions.push(wgsl::Instruction::And { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::Or(op) => wgsl::Instruction::Or { + }), + cube::Operator::Or(op) => instructions.push(wgsl::Instruction::Or { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::Not(op) => wgsl::Instruction::Not { + }), + cube::Operator::Not(op) => instructions.push(wgsl::Instruction::Not { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::BitwiseOr(op) => wgsl::Instruction::BitwiseOr { + }), + cube::Operator::BitwiseOr(op) => instructions.push(wgsl::Instruction::BitwiseOr { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::BitwiseAnd(op) => wgsl::Instruction::BitwiseAnd { + }), + cube::Operator::BitwiseAnd(op) => instructions.push(wgsl::Instruction::BitwiseAnd { lhs: 
self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::BitwiseXor(op) => wgsl::Instruction::BitwiseXor { + }), + cube::Operator::BitwiseXor(op) => instructions.push(wgsl::Instruction::BitwiseXor { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::ShiftLeft(op) => wgsl::Instruction::ShiftLeft { + }), + cube::Operator::ShiftLeft(op) => instructions.push(wgsl::Instruction::ShiftLeft { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::ShiftRight(op) => wgsl::Instruction::ShiftRight { + }), + cube::Operator::ShiftRight(op) => instructions.push(wgsl::Instruction::ShiftRight { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::Remainder(op) => wgsl::Instruction::Remainder { + }), + cube::Operator::Remainder(op) => instructions.push(wgsl::Instruction::Remainder { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::Slice(op) => wgsl::Instruction::Slice { + }), + cube::Operator::Slice(op) => instructions.push(wgsl::Instruction::Slice { input: self.compile_variable(op.input), start: self.compile_variable(op.start), end: self.compile_variable(op.end), out: self.compile_variable(out), - }, + }), - cube::Operator::Bitcast(op) => wgsl::Instruction::Bitcast { + cube::Operator::Bitcast(op) => instructions.push(wgsl::Instruction::Bitcast { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, + }), - cube::Operator::Neg(op) => wgsl::Instruction::Negate { + cube::Operator::Neg(op) => instructions.push(wgsl::Instruction::Negate { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Magnitude(op) => wgsl::Instruction::Magnitude { + }), + 
cube::Operator::Magnitude(op) => instructions.push(wgsl::Instruction::Magnitude { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Normalize(op) => wgsl::Instruction::Normalize { + }), + cube::Operator::Normalize(op) => instructions.push(wgsl::Instruction::Normalize { input: self.compile_variable(op.input), out: self.compile_variable(out), - }, - cube::Operator::Dot(op) => wgsl::Instruction::Dot { + }), + cube::Operator::Dot(op) => instructions.push(wgsl::Instruction::Dot { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), - }, - cube::Operator::InitLine(op) => wgsl::Instruction::VecInit { + }), + cube::Operator::InitLine(op) => instructions.push(wgsl::Instruction::VecInit { inputs: op .inputs .into_iter() .map(|var| self.compile_variable(var)) .collect(), out: self.compile_variable(out), - }, - cube::Operator::CopyMemory(op) => wgsl::Instruction::Copy { + }), + cube::Operator::CopyMemory(op) => instructions.push(wgsl::Instruction::Copy { input: self.compile_variable(op.input), in_index: self.compile_variable(op.in_index), out: self.compile_variable(out), out_index: self.compile_variable(op.out_index), - }, - cube::Operator::CopyMemoryBulk(op) => wgsl::Instruction::CopyBulk { + }), + cube::Operator::CopyMemoryBulk(op) => instructions.push(wgsl::Instruction::CopyBulk { input: self.compile_variable(op.input), in_index: self.compile_variable(op.in_index), out: self.compile_variable(out), out_index: self.compile_variable(op.out_index), len: op.len, - }, - cube::Operator::Select(op) => wgsl::Instruction::Select { + }), + cube::Operator::Select(op) => instructions.push(wgsl::Instruction::Select { cond: self.compile_variable(op.cond), then: self.compile_variable(op.then), or_else: self.compile_variable(op.or_else), out: self.compile_variable(out), - }, + }), } } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs 
b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index efbcbcfb6..6be961b6e 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -62,6 +62,13 @@ pub enum Instruction { rhs: Variable, out: Variable, }, + // Index handles casting to correct local variable. + CheckedIndex { + len: Variable, + lhs: Variable, + rhs: Variable, + out: Variable, + }, // Index assign handles casting to correct output variable. IndexAssign { lhs: Variable, @@ -437,10 +444,28 @@ impl Display for Instruction { item: *item, is_array: true, }; - index(f, &lhs, rhs, out, Some(offset)) + index(f, &lhs, rhs, out, Some(offset), None) } - _ => index(f, lhs, rhs, out, None), + _ => index(f, lhs, rhs, out, None, None), }, + Instruction::IndexAssign { lhs, rhs, out } => { + if let Variable::Slice { item, .. } = out { + let offset = Variable::Named { + name: format!("{out}_offset"), + item: Item::Scalar(Elem::U32), + is_array: false, + }; + let out = Variable::Named { + name: format!("(*{out}_ptr)"), + item: *item, + is_array: true, + }; + + index_assign(f, lhs, rhs, &out, Some(offset)) + } else { + index_assign(f, lhs, rhs, out, None) + } + } Instruction::Copy { input, in_index, @@ -654,24 +679,22 @@ for (var {i}: {i_ty} = {start}; {i} {cmp} {end}; {increment}) {{ f.write_str("}\n") } - Instruction::IndexAssign { lhs, rhs, out } => { - if let Variable::Slice { item, .. } = out { + Instruction::CheckedIndex { len, lhs, rhs, out } => match lhs { + Variable::Slice { item, .. 
} => { let offset = Variable::Named { - name: format!("{out}_offset"), + name: format!("{lhs}_offset"), item: Item::Scalar(Elem::U32), is_array: false, }; - let out = Variable::Named { - name: format!("(*{out}_ptr)"), + let lhs = Variable::Named { + name: format!("(*{lhs}_ptr)"), item: *item, is_array: true, }; - - index_assign(f, lhs, rhs, &out, Some(offset)) - } else { - index_assign(f, lhs, rhs, out, None) + index(f, &lhs, rhs, out, Some(offset), Some(len)) } - } + _ => index(f, lhs, rhs, out, None, Some(len)), + }, Instruction::If { cond, instructions } => { writeln!(f, "if {cond} {{")?; for i in instructions { @@ -1012,6 +1035,7 @@ fn index( rhs: &Variable, out: &Variable, offset: Option, + len: Option<&Variable>, ) -> core::fmt::Result { let is_scalar = match lhs { Variable::Local { item, .. } => item.vectorization_factor() == 1, @@ -1019,43 +1043,38 @@ fn index( _ => false, }; - if out.item().elem().is_atomic() { - match offset { - Some(offset) => writeln!(f, "let {out} = &{lhs}[{rhs} + {offset}];"), - None => writeln!(f, "let {out} = &{lhs}[{rhs}];"), + let atomic_ref = if out.item().elem().is_atomic() { + "&" + } else { + "" + }; + + let read_exp = if is_scalar { + format!("{atomic_ref}{lhs}") + } else { + let index = if let Some(offset) = offset { + format!("{rhs} + {offset}") + } else { + format!("{rhs}") + }; + + let item = lhs.item(); + + if let Some(len) = len { + format!("select({atomic_ref}{lhs}[{index}], {item}(0), {index} < {len}") + } else { + format!("{atomic_ref}{lhs}[{index}]") } - } else if lhs.elem() != out.elem() { + }; + + let out_bind = out.fmt_left(); + + if lhs.elem() != out.elem() { let item = out.item(); - let out = out.fmt_left(); - match offset { - Some(offset) => { - let value = lhs - .item() - .fmt_cast_to(item, format!("{lhs}[{rhs}+{offset}]")); - writeln!(f, "{out} = {value};") - } - None => { - if is_scalar { - let value = lhs.item().fmt_cast_to(item, format!("{lhs}")); - writeln!(f, "{out} = {value};") - } else { - let value = 
lhs.item().fmt_cast_to(item, format!("{lhs}[{rhs}]")); - writeln!(f, "{out} = {value};") - } - } - } + let value = lhs.item().fmt_cast_to(item, read_exp); + writeln!(f, "{out_bind} = {value};") } else { - let out = out.fmt_left(); - match offset { - Some(offset) => writeln!(f, "{out} = {lhs}[{rhs} + {offset}];"), - None => { - if is_scalar { - writeln!(f, "{out} = {lhs};") - } else { - writeln!(f, "{out} = {lhs}[{rhs}];") - } - } - } + writeln!(f, "{out_bind} = {read_exp};") } } diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs index 8777913e0..a3981a813 100644 --- a/crates/cubecl-wgpu/src/compute/server.rs +++ b/crates/cubecl-wgpu/src/compute/server.rs @@ -85,7 +85,7 @@ impl WgpuServer { } let compile = self.logger.debug(compile); - let pipeline = C::create_pipeline(self, compile, mode); + let pipeline = C::create_pipeline(self, compile); self.pipelines.insert(kernel_id.clone(), pipeline.clone()); From 072474b1a0c6bf3e110dec27c02ec16021016aee Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 22 Nov 2024 16:53:46 +0000 Subject: [PATCH 2/7] Fix indexing op --- .../src/compiler/wgsl/instructions.rs | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index 6be961b6e..ce874d277 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -1043,39 +1043,39 @@ fn index( _ => false, }; - let atomic_ref = if out.item().elem().is_atomic() { - "&" - } else { - "" - }; - - let read_exp = if is_scalar { - format!("{atomic_ref}{lhs}") + let (mut value, index) = if is_scalar { + (format!("{lhs}"), None) } else { let index = if let Some(offset) = offset { - format!("{rhs} + {offset}") + format!("{rhs}+{offset}") } else { format!("{rhs}") }; - let item = lhs.item(); - - if let Some(len) = len { - 
format!("select({atomic_ref}{lhs}[{index}], {item}(0), {index} < {len}") - } else { - format!("{atomic_ref}{lhs}[{index}]") - } + (format!("{lhs}[{index}]"), Some(index)) }; - let out_bind = out.fmt_left(); + if out.item().elem().is_atomic() { + value = format!("atomicLoad(&{value})") + }; if lhs.elem() != out.elem() { - let item = out.item(); - let value = lhs.item().fmt_cast_to(item, read_exp); - writeln!(f, "{out_bind} = {value};") - } else { - writeln!(f, "{out_bind} = {read_exp};") + value = lhs.item().fmt_cast_to(out.item(), value) + }; + + if let Some(ind) = index { + if let Some(len) = len { + // Note: This is technically not 100% allowed. According to the WebGPU specification, + // indexing OOB is a "dynamic error" which allows "many possible outcomes". In practice, + // both wgpu and Dawn handle this by either returning 0s or clamping the index + // to valid bounds. In practice this means it's harmless to use in a select. + let out_item = out.item(); + value = format!("select({out_item}(0), {value}, {ind} < {len})"); + }; } + + let out = out.fmt_left(); + writeln!(f, "{out} = {value};") } fn index_assign( From b26b183881f4132b8918200c1b7355e62eb36824 Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 22 Nov 2024 18:02:17 +0000 Subject: [PATCH 3/7] Spirv shader, update comment --- crates/cubecl-wgpu/src/compiler/spirv.rs | 31 +++++++++---------- .../src/compiler/wgsl/instructions.rs | 6 ++-- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/crates/cubecl-wgpu/src/compiler/spirv.rs b/crates/cubecl-wgpu/src/compiler/spirv.rs index 92d8d819c..8de503a44 100644 --- a/crates/cubecl-wgpu/src/compiler/spirv.rs +++ b/crates/cubecl-wgpu/src/compiler/spirv.rs @@ -102,23 +102,20 @@ impl WgpuCompiler for SpirvCompiler { }) .unwrap_or_else(|| { let source = &kernel.source; - let module = match mode { - ExecutionMode::Checked => { - server - .device - .create_shader_module(wgpu::ShaderModuleDescriptor { - label: None, - source: 
wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), - }) - } - ExecutionMode::Unchecked => unsafe { - server - .device - .create_shader_module_unchecked(wgpu::ShaderModuleDescriptor { - label: None, - source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), - }) - }, + // Cube always in principle uses unchecked modules. Certain operations like + // indexing are instead checked by cube. The WebGPU specification only makes + // incredibly loose guarantees that Cube can't rely on. Additionally, kernels + // can opt in/out per operation whether checks should be performed which can be faster. + // + // SAFETY: Cube guarantees OOB safety when launching in checked mode. Launching in unchecked mode + // is only available through the use of unsafe code. + let module = unsafe { + server + .device + .create_shader_module_unchecked(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), + }) }; (module, None) }); diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index ce874d277..4dedf3830 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -1066,9 +1066,9 @@ fn index( if let Some(ind) = index { if let Some(len) = len { // Note: This is technically not 100% allowed. According to the WebGPU specification, - // indexing OOB is a "dynamic error" which allows "many possible outcomes". In practice, - // both wgpu and Dawn handle this by either returning 0s or clamping the index - // to valid bounds. In practice this means it's harmless to use in a select. + // any OOB access is a "dynamic error" which allows "many possible outcomes". In practice, + // both wgpu and Dawn handle this by either returning dummy data or clamping the index + // to valid bounds. This means it's harmless to use in a select. 
let out_item = out.item(); value = format!("select({out_item}(0), {value}, {ind} < {len})"); }; From 1e29f780459dce909b667d2ab8f3e31cf166b8ce Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 22 Nov 2024 18:18:22 +0000 Subject: [PATCH 4/7] Merge fix --- crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 910742ed2..1be50e5c5 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -65,12 +65,13 @@ impl cubecl_core::Compiler for WgslCompiler { fn compile( shader: cube::KernelDefinition, compilation_options: &Self::CompilationOptions, + mode: ExecutionMode, ) -> Self::Representation { let mut compiler = Self { compilation_options: compilation_options.clone(), ..Self::default() }; - compiler.compile_shader(shader) + compiler.compile_shader(shader, mode) } fn elem_size(elem: cube::Elem) -> usize { From eb327a712c48f9b3a3280a347c97ecac0b985af3 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Fri, 22 Nov 2024 23:06:53 +0000 Subject: [PATCH 5/7] Fix atomics, fix spirv --- crates/cubecl-wgpu/src/compiler/spirv.rs | 1 - .../src/compiler/wgsl/instructions.rs | 38 ++++++++++--------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/crates/cubecl-wgpu/src/compiler/spirv.rs b/crates/cubecl-wgpu/src/compiler/spirv.rs index 8de503a44..55d8abe42 100644 --- a/crates/cubecl-wgpu/src/compiler/spirv.rs +++ b/crates/cubecl-wgpu/src/compiler/spirv.rs @@ -46,7 +46,6 @@ impl WgpuCompiler for SpirvCompiler { fn create_pipeline( server: &mut WgpuServer, kernel: CompiledKernel, - mode: ExecutionMode, ) -> Arc { let (module, layout) = kernel .repr diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index 4dedf3830..a43717ad3 100644 --- 
a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -1056,26 +1056,30 @@ fn index( }; if out.item().elem().is_atomic() { - value = format!("atomicLoad(&{value})") - }; + // Atomic values don't support casting or bound checking - we just assign the reference. + value = format!("&{value}"); + writeln!(f, "let {out} = {value};") + } else { + // Check for casting + if lhs.elem() != out.elem() { + value = lhs.item().fmt_cast_to(out.item(), value) + }; - if lhs.elem() != out.elem() { - value = lhs.item().fmt_cast_to(out.item(), value) - }; + // Check for bounds. + if let Some(ind) = index { + if let Some(len) = len { + // Note: This is technically not 100% allowed. According to the WebGPU specification, + // any OOB access is a "dynamic error" which allows "many possible outcomes". In practice, + // both wgpu and Dawn handle this by either returning dummy data or clamping the index + // to valid bounds. This means it's harmless to use in a select. + let out_item = out.item(); + value = format!("select({out_item}(0), {value}, {ind} < {len})"); + }; + } - if let Some(ind) = index { - if let Some(len) = len { - // Note: This is technically not 100% allowed. According to the WebGPU specification, - // any OOB access is a "dynamic error" which allows "many possible outcomes". In practice, - // both wgpu and Dawn handle this by either returning dummy data or clamping the index - // to valid bounds. This means it's harmless to use in a select. 
- let out_item = out.item(); - value = format!("select({out_item}(0), {value}, {ind} < {len})"); - }; + let out = out.fmt_left(); + writeln!(f, "{out} = {value};") } - - let out = out.fmt_left(); - writeln!(f, "{out} = {value};") } fn index_assign( From 7b80c7a8d4166e727d4f0a88983cd9eb8c531ce2 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Fri, 22 Nov 2024 23:26:00 +0000 Subject: [PATCH 6/7] Update references --- crates/cubecl-wgpu/tests/constant_array.wgsl | 8 +++- crates/cubecl-wgpu/tests/plane_elect.wgsl | 6 +++ crates/cubecl-wgpu/tests/plane_sum.wgsl | 12 +++++- .../cubecl-wgpu/tests/sequence_for_loop.wgsl | 22 +++++++++-- crates/cubecl-wgpu/tests/slice_assign.wgsl | 12 +++++- crates/cubecl-wgpu/tests/unary_bench.wgsl | 38 +++++++++++++++---- 6 files changed, 83 insertions(+), 15 deletions(-) diff --git a/crates/cubecl-wgpu/tests/constant_array.wgsl b/crates/cubecl-wgpu/tests/constant_array.wgsl index 2baa153c3..15d99a9a3 100644 --- a/crates/cubecl-wgpu/tests/constant_array.wgsl +++ b/crates/cubecl-wgpu/tests/constant_array.wgsl @@ -23,6 +23,12 @@ let _0 = info[1u]; let _1 = id < _0; if _1 { let _2 = arrays_0[id]; +var l_1_0: u32; +var l_1_1: bool; +l_1_0 = info[0u]; +l_1_1 = id < l_1_0; +if l_1_1 { output_0_global[id] = _2; } -} \ No newline at end of file +} +} diff --git a/crates/cubecl-wgpu/tests/plane_elect.wgsl b/crates/cubecl-wgpu/tests/plane_elect.wgsl index ac1c83792..ee899265f 100644 --- a/crates/cubecl-wgpu/tests/plane_elect.wgsl +++ b/crates/cubecl-wgpu/tests/plane_elect.wgsl @@ -17,5 +17,11 @@ fn kernel_elect( ) { let _0 = subgroupElect(); let _1 = u32(_0); +var l_0_0: u32; +var l_0_1: bool; +l_0_0 = info[0u]; +l_0_1 = local_idx < l_0_0; +if l_0_1 { output_0_global[local_idx] = _1; } +} diff --git a/crates/cubecl-wgpu/tests/plane_sum.wgsl b/crates/cubecl-wgpu/tests/plane_sum.wgsl index 92aa4a734..5077b3fb3 100644 --- a/crates/cubecl-wgpu/tests/plane_sum.wgsl +++ b/crates/cubecl-wgpu/tests/plane_sum.wgsl @@ -15,10 +15,18 @@ const 
WORKGROUP_SIZE_Z = 1u; fn kernel_sum( @builtin(local_invocation_index) local_idx: u32, ) { -let _0 = output_0_global[local_idx]; +var l_0_0: u32; +l_0_0 = info[0u]; +let _0 = select(f32(0), output_0_global[local_idx], local_idx < l_0_0); let _1 = subgroupAdd(_0); let _2 = local_idx == 0u; if _2 { +var l_1_0: u32; +var l_1_1: bool; +l_1_0 = info[0u]; +l_1_1 = 0u < l_1_0; +if l_1_1 { output_0_global[0u] = _1; } -} \ No newline at end of file +} +} diff --git a/crates/cubecl-wgpu/tests/sequence_for_loop.wgsl b/crates/cubecl-wgpu/tests/sequence_for_loop.wgsl index ffcf0b661..dc037bdb8 100644 --- a/crates/cubecl-wgpu/tests/sequence_for_loop.wgsl +++ b/crates/cubecl-wgpu/tests/sequence_for_loop.wgsl @@ -19,10 +19,26 @@ let _0 = local_idx != 0u; if _0 { return; } -let _1 = output_0_global[0u]; +var l_0_0: u32; +l_0_0 = info[0u]; +let _1 = select(f32(0), output_0_global[0u], 0u < l_0_0); let _2 = _1 + 1f; +var l_0_1: u32; +var l_0_2: bool; +l_0_1 = info[0u]; +l_0_2 = 0u < l_0_1; +if l_0_2 { output_0_global[0u] = _2; -let _3 = output_0_global[0u]; +} +var l_0_3: u32; +l_0_3 = info[0u]; +let _3 = select(f32(0), output_0_global[0u], 0u < l_0_3); let _4 = _3 + 4f; +var l_0_4: u32; +var l_0_5: bool; +l_0_4 = info[0u]; +l_0_5 = 0u < l_0_4; +if l_0_5 { output_0_global[0u] = _4; -} \ No newline at end of file +} +} diff --git a/crates/cubecl-wgpu/tests/slice_assign.wgsl b/crates/cubecl-wgpu/tests/slice_assign.wgsl index e39f03978..3a19078c0 100644 --- a/crates/cubecl-wgpu/tests/slice_assign.wgsl +++ b/crates/cubecl-wgpu/tests/slice_assign.wgsl @@ -24,7 +24,15 @@ if _0 { let slice_1_0_offset = 2u; let slice_1_0_length = 3u - 2u; let slice_1_0_ptr = &output_0_global; -let _1 = input_0_global[0u]; +var l_1_0: u32; +l_1_0 = info[0u]; +let _1 = select(f32(0), input_0_global[0u], 0u < l_1_0); +var l_1_1: u32; +var l_1_2: bool; +l_1_1 = slice_1_0_length; +l_1_2 = 0u < l_1_1; +if l_1_2 { (*slice_1_0_ptr)[0u + slice_1_0_offset] = _1; } -} \ No newline at end of file +} +} diff --git 
a/crates/cubecl-wgpu/tests/unary_bench.wgsl b/crates/cubecl-wgpu/tests/unary_bench.wgsl index 4ff827305..182f1463d 100644 --- a/crates/cubecl-wgpu/tests/unary_bench.wgsl +++ b/crates/cubecl-wgpu/tests/unary_bench.wgsl @@ -33,22 +33,46 @@ for (var l_2_2: u32 = 0u; l_2_2 < 256u; l_2_2++) { let _3 = l_2_2 % 2u; let _4 = _3 == 0u; if _4 { -let _5 = input_0_global[id]; -let _6 = input_1_global[id]; +var l_3_0: u32; +l_3_0 = info[0u]; +let _5 = select(vec4(0), input_0_global[id], id < l_3_0); +var l_3_1: u32; +l_3_1 = info[1u]; +let _6 = select(vec4(0), input_1_global[id], id < l_3_1); let _7 = _5 * _6; let _8 = cos(_7); -let _9 = output_0_global[id]; +var l_3_2: u32; +l_3_2 = info[2u]; +let _9 = select(vec4(0), output_0_global[id], id < l_3_2); let _10 = _9 - _8; +var l_3_3: u32; +var l_3_4: bool; +l_3_3 = info[2u]; +l_3_4 = id < l_3_3; +if l_3_4 { output_0_global[id] = _10; +} } else { -let _11 = input_0_global[id]; -let _12 = input_1_global[id]; +var l_3_0: u32; +l_3_0 = info[0u]; +let _11 = select(vec4(0), input_0_global[id], id < l_3_0); +var l_3_1: u32; +l_3_1 = info[1u]; +let _12 = select(vec4(0), input_1_global[id], id < l_3_1); let _13 = _11 * _12; let _14 = cos(_13); -let _15 = output_0_global[id]; +var l_3_2: u32; +l_3_2 = info[2u]; +let _15 = select(vec4(0), output_0_global[id], id < l_3_2); let _16 = _15 + _14; +var l_3_3: u32; +var l_3_4: bool; +l_3_3 = info[2u]; +l_3_4 = id < l_3_3; +if l_3_4 { output_0_global[id] = _16; } } } -} \ No newline at end of file +} +} From b022db7e918d6ecf33bc7eb7eadae32eb39bd015 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Fri, 22 Nov 2024 23:50:19 +0000 Subject: [PATCH 7/7] Add tests --- crates/cubecl-core/src/runtime_tests/index.rs | 56 +++++++++++++++++++ crates/cubecl-core/src/runtime_tests/mod.rs | 2 + 2 files changed, 58 insertions(+) create mode 100644 crates/cubecl-core/src/runtime_tests/index.rs diff --git a/crates/cubecl-core/src/runtime_tests/index.rs b/crates/cubecl-core/src/runtime_tests/index.rs new file 
mode 100644 index 000000000..a1c62f3c3 --- /dev/null +++ b/crates/cubecl-core/src/runtime_tests/index.rs @@ -0,0 +1,56 @@ +use crate as cubecl; + +use cubecl::prelude::*; + +#[cube(launch)] +pub fn kernel_assign(output: &mut Array) { + if UNIT_POS == 0 { + let item = F::new(5.0); + // Assign normally. + output[0] = item; + + // out of bounds write should not show up in the array. + output[2] = F::new(10.0); + + // out of bounds read should be read as 0. + output[1] = output[2]; + } +} + +pub fn test_kernel_index_scalar( + client: ComputeClient, +) { + let handle = client.create(F::as_bytes(&[F::new(0.0), F::new(1.0), F::new(123.0)])); + let handle_slice = handle.clone().offset_end(1); + let vectorization = 1; + + kernel_assign::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::default(), + unsafe { ArrayArg::from_raw_parts::(&handle_slice, 3, vectorization) }, + ); + + let actual = client.read_one(handle.binding()); + let actual = F::from_bytes(&actual); + + assert_eq!(actual[0], F::new(5.0)); + assert_eq!(actual[1], F::new(0.0)); + assert_eq!(actual[2], F::new(123.0)); +} + +#[allow(missing_docs)] +#[macro_export] +macro_rules! testgen_index { + () => { + use super::*; + + #[test] + fn test_assign_index() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::index::test_kernel_index_scalar::( + client, + ); + } + }; +} diff --git a/crates/cubecl-core/src/runtime_tests/mod.rs b/crates/cubecl-core/src/runtime_tests/mod.rs index 867b40aae..b96ca8fdf 100644 --- a/crates/cubecl-core/src/runtime_tests/mod.rs +++ b/crates/cubecl-core/src/runtime_tests/mod.rs @@ -5,6 +5,7 @@ pub mod cmma; pub mod const_match; pub mod constants; pub mod different_rank; +pub mod index; pub mod launch; pub mod metadata; pub mod plane; @@ -23,6 +24,7 @@ macro_rules! 
testgen_all { type IntType = i32; type UintType = u32; + cubecl_core::testgen_index!(); cubecl_core::testgen_assign!(); cubecl_core::testgen_branch!(); cubecl_core::testgen_const_match!();