diff --git a/Cargo.lock b/Cargo.lock
index f21cd1d4..bd6ac23d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -600,6 +600,7 @@ dependencies = [
  "compiletest_rs",
  "cuda_builder",
  "nvvm",
+ "rustc_codegen_nvvm",
  "tracing",
  "tracing-subscriber",
 ]
@@ -1322,6 +1323,7 @@ dependencies = [
  "cuda_std",
  "cust",
  "cust_raw",
+ "gemm-kernels",
  "ndarray",
  "ndarray-rand",
  "rand 0.9.2",
diff --git a/examples/cuda/gemm/Cargo.toml b/examples/cuda/gemm/Cargo.toml
index 8f6645ee..e4964bff 100644
--- a/examples/cuda/gemm/Cargo.toml
+++ b/examples/cuda/gemm/Cargo.toml
@@ -8,6 +8,7 @@ blastoff = { path = "../../../crates/blastoff" }
 cuda_std = { path = "../../../crates/cuda_std" }
 cust = { path = "../../../crates/cust" }
 cust_raw = { path = "../../../crates/cust_raw", features = ["driver"] }
+gemm-kernels = { path = "kernels" }
 ndarray = { version = "0.16", features = ["approx"] }
 ndarray-rand = "0.15.0"
 rand = "0.9"
diff --git a/examples/cuda/gemm/kernels/src/gemm_tiled.rs b/examples/cuda/gemm/kernels/src/gemm_tiled.rs
index 6c0f00ec..5cc172f8 100644
--- a/examples/cuda/gemm/kernels/src/gemm_tiled.rs
+++ b/examples/cuda/gemm/kernels/src/gemm_tiled.rs
@@ -1,7 +1,10 @@
+use core::mem::MaybeUninit;
 use cuda_std::address_space;
 use cuda_std::kernel;
 use cuda_std::thread;
 
+pub const TILE_SIZE: usize = 16;
+
 #[kernel]
 #[allow(improper_ctypes_definitions)]
 /// Tiled GEMM kernel for C = alpha * A * B + beta * C.
@@ -37,12 +40,17 @@ pub unsafe fn gemm_tiled(
     alpha: f32,
     beta: f32,
 ) {
-    const TILE_SIZE: usize = 16;
+    const TILE_SIZE_2D: usize = TILE_SIZE * TILE_SIZE;
 
+    // Shared GPU memory is modelled with `#[address_space(shared)] static mut`. Unlike normal
+    // `static mut`, it is not initialized, and only exists for the duration of the kernel's
+    // (multi-)execution. Because it is not initialized, it must be marked with `MaybeUninit`,
+    // written with `write` (in unsafe blocks because writing a `static mut` is unsafe), and
+    // subsequently read with `assume_init`.
     #[address_space(shared)]
-    static mut TILE_A: [f32; TILE_SIZE * TILE_SIZE] = [0.; TILE_SIZE * TILE_SIZE];
+    static mut TILE_A: [MaybeUninit<f32>; TILE_SIZE_2D] = [MaybeUninit::uninit(); TILE_SIZE_2D];
     #[address_space(shared)]
-    static mut TILE_B: [f32; TILE_SIZE * TILE_SIZE] = [0.; TILE_SIZE * TILE_SIZE];
+    static mut TILE_B: [MaybeUninit<f32>; TILE_SIZE_2D] = [MaybeUninit::uninit(); TILE_SIZE_2D];
 
     // Thread indices within the block.
     let tx = thread::thread_idx_x() as usize;
@@ -57,20 +65,30 @@ pub unsafe fn gemm_tiled(
     for kk in (0..k).step_by(TILE_SIZE) {
         // Collaborative loading of tiles into shared memory.
         if row < m && (kk + tx) < k {
-            unsafe { TILE_A[ty * TILE_SIZE + tx] = mat_a[row * k + (kk + tx)] };
+            unsafe {
+                TILE_A[ty * TILE_SIZE + tx].write(mat_a[row * k + (kk + tx)]);
+            }
         } else {
-            unsafe { TILE_A[ty * TILE_SIZE + tx] = 0.0f32 };
+            unsafe {
+                TILE_A[ty * TILE_SIZE + tx].write(0.0f32);
+            }
         }
         if col < n && (kk + ty) < k {
-            unsafe { TILE_B[ty * TILE_SIZE + tx] = mat_b[(kk + ty) * n + col] };
+            unsafe {
+                TILE_B[ty * TILE_SIZE + tx].write(mat_b[(kk + ty) * n + col]);
+            }
         } else {
-            unsafe { TILE_B[ty * TILE_SIZE + tx] = 0.0f32 };
+            unsafe {
+                TILE_B[ty * TILE_SIZE + tx].write(0.0f32);
+            }
         }
         thread::sync_threads();
 
         // Perform the computation on the tile.
         for i in 0..TILE_SIZE {
-            sum += unsafe { TILE_A[ty * TILE_SIZE + i] * TILE_B[i * TILE_SIZE + tx] };
+            sum += unsafe {
+                TILE_A[ty * TILE_SIZE + i].assume_init() * TILE_B[i * TILE_SIZE + tx].assume_init()
+            };
         }
         thread::sync_threads();
     }
diff --git a/examples/cuda/gemm/kernels/src/lib.rs b/examples/cuda/gemm/kernels/src/lib.rs
index 19fab562..cba9db7c 100644
--- a/examples/cuda/gemm/kernels/src/lib.rs
+++ b/examples/cuda/gemm/kernels/src/lib.rs
@@ -2,4 +2,4 @@ mod gemm_naive;
 mod gemm_tiled;
 
 pub use crate::gemm_naive::gemm_naive;
-pub use crate::gemm_tiled::gemm_tiled;
+pub use crate::gemm_tiled::{TILE_SIZE, gemm_tiled};
diff --git a/examples/cuda/gemm/src/main.rs b/examples/cuda/gemm/src/main.rs
index 73df03d5..531e0c7c 100644
--- a/examples/cuda/gemm/src/main.rs
+++ b/examples/cuda/gemm/src/main.rs
@@ -13,6 +13,7 @@ use cust::memory::CopyDestination as _;
 use cust::module;
 use cust::stream;
 use cust::util::SliceExt as _;
+use gemm_kernels::TILE_SIZE;
 use ndarray::Array;
 use ndarray_rand::RandomExt as _;
 use ndarray_rand::rand_distr::Uniform;
@@ -430,9 +431,6 @@ pub fn gemm_tiled(
     assert_eq!(mat_b.len(), k * n);
     assert_eq!(mat_c.len(), m * n);
 
-    // These values must be aligned with the kernel code.
-    const TILE_SIZE: usize = 16;
-
     let kernel_cell = cell::LazyCell::new(|| {
         module
             .get_function("gemm_tiled")