Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,28 @@ webgpu = [ "khal/webgpu" ]
cpu = ["khal/cpu", "vortx-shaders/cpu"]
cpu-parallel = ["cpu", "vortx-shaders/cpu-parallel"]
cuda = ["khal/cuda", "khal-builder/cuda", "vortx-shaders/cuda"]
metal = ["khal/metal"]
push_constants = ["khal/push_constants", "vortx-shaders/push_constants"]
subgroup_ops = ["khal/subgroup_ops", "vortx-shaders/subgroup_ops"]

[workspace.package]
version = "0.1.1"

[workspace.dependencies]
khal-std = "0.1"
khal = { version = "0.1", features = ["derive"]}
khal-std = "0.2"
khal = { version = "0.2", features = ["derive"]}

[dependencies]
bytemuck = "1"
include_dir = "0.7"
nalgebra = "0.34"
nalgebra = "0.35"
khal = { workspace = true }
khal-std = { workspace = true }
# Shader crate provides both GPU shader code and generated ShaderArgs via spirv_bindgen
vortx-shaders = { version = "0.1", path = "vortx-shaders" }

[dev-dependencies]
nalgebra = { version = "0.34", features = ["rand"] }
nalgebra = { version = "0.35", features = ["rand"] }
futures-test = "0.3"
serial_test = "3"
approx = "0.5"
Expand All @@ -51,12 +52,13 @@ anyhow = "1"
wgpu = "29"

[build-dependencies]
khal-builder = "0.1.1"
khal-builder = "0.2"
# To build the shader from the dependency instead of local path.
vortx-shaders = { version = "0.1.1", path = "./vortx-shaders" }

#[patch.crates-io]
[patch.crates-io]
#khal-builder = { path = "../khal/crates/khal-builder" }
#khal = { path = "../khal/crates/khal" }
#khal-std = { path = "../khal/crates/khal-std" }
#khal-derive = { path = "../khal/crates/khal-derive" }
#glamx = { git = "https://github.com/dimforge/glamx", branch = "bytemuck" }
8 changes: 8 additions & 0 deletions src/linalg/contiguous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@ mod test {
gpu_contiguous_generic(&cuda).await;
}

/// Runs the backend-agnostic contiguous test against the Metal backend.
///
/// Only compiled when the `metal` feature is enabled, and serialized with the
/// other `serial_test` tests. Panics if the Metal backend cannot be created.
#[cfg(feature = "metal")]
#[futures_test::test]
#[serial_test::serial]
async fn gpu_contiguous_metal() {
    // Build the Metal device first, then wrap it in the backend enum.
    let device = khal::backend::metal::Metal::new().unwrap();
    let backend = GpuBackend::Metal(device);
    gpu_contiguous_generic(&backend).await;
}

async fn gpu_contiguous_generic(backend: &GpuBackend) {
let contiguous = super::Contiguous::from_backend(backend).unwrap();

Expand Down
8 changes: 8 additions & 0 deletions src/linalg/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,14 @@ mod test {
gpu_gemm_generic(&cuda).await;
}

/// Runs the backend-agnostic GEMM test against the Metal backend.
///
/// Only compiled when the `metal` feature is enabled, and serialized with the
/// other `serial_test` tests. Panics if the Metal backend cannot be created.
#[cfg(feature = "metal")]
#[futures_test::test]
#[serial_test::serial]
async fn gpu_gemm_metal() {
    // Build the Metal device first, then wrap it in the backend enum.
    let device = khal::backend::metal::Metal::new().unwrap();
    let backend = GpuBackend::Metal(device);
    gpu_gemm_generic(&backend).await;
}

async fn gpu_gemm_generic(backend: &GpuBackend) {
let gemm = super::Gemm::from_backend(backend).unwrap();

Expand Down
8 changes: 8 additions & 0 deletions src/linalg/op_assign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,14 @@ mod test {
gpu_op_assign_with_backend(&cuda).await;
}

/// Runs the backend-agnostic op-assign test against the Metal backend.
///
/// Only compiled when the `metal` feature is enabled, and serialized with the
/// other `serial_test` tests. Panics if the Metal backend cannot be created.
#[cfg(feature = "metal")]
#[futures_test::test]
#[serial_test::serial]
async fn gpu_op_assign_metal() {
    // Build the Metal device first, then wrap it in the backend enum.
    let device = khal::backend::metal::Metal::new().unwrap();
    let backend = GpuBackend::Metal(device);
    gpu_op_assign_with_backend(&backend).await;
}

async fn gpu_op_assign_with_backend(backend: &GpuBackend) {
let ops = [
OpAssignVariant::Add,
Expand Down
8 changes: 8 additions & 0 deletions src/linalg/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,14 @@ mod test {
gpu_reduce_generic(&cuda).await;
}

/// Runs the backend-agnostic reduce test against the Metal backend.
///
/// Only compiled when the `metal` feature is enabled, and serialized with the
/// other `serial_test` tests. Panics if the Metal backend cannot be created.
#[cfg(feature = "metal")]
#[futures_test::test]
#[serial_test::serial]
async fn gpu_reduce_metal() {
    // Build the Metal device first, then wrap it in the backend enum.
    let device = khal::backend::metal::Metal::new().unwrap();
    let backend = GpuBackend::Metal(device);
    gpu_reduce_generic(&backend).await;
}

async fn gpu_reduce_generic(backend: &GpuBackend) {
let ops = [
ReduceVariant::Min,
Expand Down
7 changes: 5 additions & 2 deletions vortx-shaders/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ cuda = ["khal-std/cuda", "khal/cuda"]
[dependencies]
khal-std = { workspace = true }
# glamx provides UVec3 and other glam types (no_std compatible, used on all targets).
glamx = { version = "0.2", default-features = false, features = ["nostd-libm", "bytemuck"] }
glamx = { version = "0.3", default-features = false, features = ["nostd-libm", "bytemuck"] }

# Host-only dependencies (excluded on GPU targets: spirv and nvptx64).
[build-dependencies]
khal-std = { workspace = true }

# Host-only dependencies (excluded on GPU targets: spirv, nvptx64).
[target.'cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))'.dependencies]
khal = { workspace = true }
bytemuck = { version = "1", features = ["derive"] }
7 changes: 1 addition & 6 deletions vortx-shaders/build.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
// Build script entry point for the shader crate.
//
// Re-exports this crate's source location to host crates that build the
// shaders.
fn main() {
// Read the manifest directory from the environment; the build aborts with
// the message below if the variable is missing.
let manifest_dir =
std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set by cargo");
// Emit the manifest directory as build-script metadata under the
// `manifest_dir` key.
println!("cargo::metadata=manifest_dir={manifest_dir}");
// Ask cargo to re-run this script only when build.rs itself changes.
println!("cargo:rerun-if-changed=build.rs");
// NOTE(review): presumably khal-std's shared setup for shader crates; it may
// already cover the manual metadata export above — confirm against khal-std.
khal_std::setup_shader_crate_build();
}
2 changes: 1 addition & 1 deletion vortx-shaders/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#![allow(clippy::too_many_arguments)]

// Enable std on host for generated ShaderArgs structs (not on GPU targets).
#[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
#[cfg(not(target_arch_is_gpu))]
extern crate std;

pub mod linalg;
Expand Down
10 changes: 5 additions & 5 deletions vortx-shaders/src/linalg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ pub use shape::Shape;
pub use shape::{Shapes1, Shapes2, Shapes3};

// Re-export generated ShaderArgs structs (only available on host)
#[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
#[cfg(not(target_arch_is_gpu))]
pub use contiguous::{Contiguous, ContiguousWithOffset};
#[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
#[cfg(not(target_arch_is_gpu))]
pub use gemm::{GemmNaive, GemmTiled};
#[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
#[cfg(not(target_arch_is_gpu))]
pub use op_assign::{GpuAdd, GpuCopy, GpuCopyWithOffsets, GpuDiv, GpuMul, GpuSub};
#[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
#[cfg(not(target_arch_is_gpu))]
pub use reduce::{ReduceAdd, ReduceMax, ReduceMin, ReduceMul, ReduceSqNorm};
#[cfg(not(any(target_arch = "spirv", target_arch = "nvptx64")))]
#[cfg(not(target_arch_is_gpu))]
pub use repeat::Repeat;
5 changes: 1 addition & 4 deletions vortx-shaders/src/linalg/op_assign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ const MAX_NUM_THREADS: u32 = MAX_NUM_WORKGROUPS * WORKGROUP_SIZE;
/// Binary operation offsets.
#[repr(C)]
#[derive(Clone, Copy)]
#[cfg_attr(
not(any(target_arch = "spirv", target_arch = "nvptx64")),
derive(bytemuck::Pod, bytemuck::Zeroable)
)]
#[cfg_attr(not(target_arch_is_gpu), derive(bytemuck::Pod, bytemuck::Zeroable))]
pub struct BinOpOffsets {
pub a: u32,
pub b: u32,
Expand Down
20 changes: 4 additions & 16 deletions vortx-shaders/src/linalg/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@ use glamx::UVec4;
/// (Samples, Channels, Height, Width), where height is the row count, and width the column count.
#[repr(C)]
#[derive(Clone, Copy)]
#[cfg_attr(
not(any(target_arch = "spirv", target_arch = "nvptx64")),
derive(bytemuck::Pod, bytemuck::Zeroable)
)]
#[cfg_attr(not(target_arch_is_gpu), derive(bytemuck::Pod, bytemuck::Zeroable))]
pub struct Shape {
/// Number of rows in each matrix of the tensor.
pub n: u32,
Expand Down Expand Up @@ -103,10 +100,7 @@ pub fn div_ceil4(a: u32) -> u32 {
#[cfg(feature = "push_constants")]
#[repr(C)]
#[derive(Clone, Copy)]
#[cfg_attr(
not(any(target_arch = "spirv", target_arch = "nvptx64")),
derive(bytemuck::Pod, bytemuck::Zeroable)
)]
#[cfg_attr(not(target_arch_is_gpu), derive(bytemuck::Pod, bytemuck::Zeroable))]
pub struct Shapes2 {
/// First shape (typically output or left operand).
pub shape_a: Shape,
Expand All @@ -118,10 +112,7 @@ pub struct Shapes2 {
#[cfg(feature = "push_constants")]
#[repr(C)]
#[derive(Clone, Copy)]
#[cfg_attr(
not(any(target_arch = "spirv", target_arch = "nvptx64")),
derive(bytemuck::Pod, bytemuck::Zeroable)
)]
#[cfg_attr(not(target_arch_is_gpu), derive(bytemuck::Pod, bytemuck::Zeroable))]
pub struct Shapes3 {
/// Output shape.
pub shape_out: Shape,
Expand All @@ -135,10 +126,7 @@ pub struct Shapes3 {
#[cfg(feature = "push_constants")]
#[repr(C)]
#[derive(Clone, Copy)]
#[cfg_attr(
not(any(target_arch = "spirv", target_arch = "nvptx64")),
derive(bytemuck::Pod, bytemuck::Zeroable)
)]
#[cfg_attr(not(target_arch_is_gpu), derive(bytemuck::Pod, bytemuck::Zeroable))]
pub struct Shapes1 {
/// The shape.
pub shape: Shape,
Expand Down
2 changes: 1 addition & 1 deletion vortx-shaders/src/utils/trig.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Trigonometric utility functions.

#[cfg(any(target_arch = "spirv", target_arch = "nvptx64"))]
#[cfg(target_arch_is_gpu)]
use khal_std::num_traits::Float;

/// The value of pi.
Expand Down
Loading