From 59e6abfae15b311e4531122d8ca2dd04ff51c1af Mon Sep 17 00:00:00 2001 From: mzusman Date: Mon, 19 Aug 2024 16:06:58 +0300 Subject: [PATCH 01/45] Migrate mamba_ssm and causal_conv1d kernels to vLLM --- CMakeLists.txt | 2 + csrc/mamba/causal_conv1d/causal_conv1d.cu | 757 +++++++++++++++++++++ csrc/mamba/causal_conv1d/causal_conv1d.h | 106 +++ csrc/mamba/causal_conv1d/static_switch.h | 25 + csrc/mamba/mamba_ssm/selective_scan.h | 274 ++++++++ csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 622 +++++++++++++++++ csrc/mamba/mamba_ssm/static_switch.h | 25 + csrc/ops.h | 10 + csrc/torch_bindings.cpp | 10 + 9 files changed, 1831 insertions(+) create mode 100644 csrc/mamba/causal_conv1d/causal_conv1d.cu create mode 100644 csrc/mamba/causal_conv1d/causal_conv1d.h create mode 100644 csrc/mamba/causal_conv1d/static_switch.h create mode 100644 csrc/mamba/mamba_ssm/selective_scan.h create mode 100644 csrc/mamba/mamba_ssm/selective_scan_fwd.cu create mode 100644 csrc/mamba/mamba_ssm/static_switch.h diff --git a/CMakeLists.txt b/CMakeLists.txt index d47f1bb305a9..8ee888625765 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -173,6 +173,8 @@ endif() # set(VLLM_EXT_SRC + "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" + "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/cache_kernels.cu" "csrc/attention/attention_kernels.cu" "csrc/pos_encoding_kernels.cu" diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu new file mode 100644 index 000000000000..81c7cf46fe33 --- /dev/null +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -0,0 +1,757 @@ +#include +#include +#include +#include + +#include "causal_conv1d.h" +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include + +#include "static_switch.h" + + + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) 
\ + if (WTYPE == at::ScalarType::Half) { \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::BFloat16) { \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ + } + +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template +void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +void set_conv_params_fwd(ConvParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t width, + // device pointers + const at::Tensor x, + const at::Tensor weight, + const at::Tensor out, + void* bias_ptr, + bool silu_activation) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.width = width; + + params.silu_activation = silu_activation; + + // Set the pointers and strides. + params.x_ptr = x.data_ptr(); + params.weight_ptr = weight.data_ptr(); + params.bias_ptr = bias_ptr; + params.out_ptr = out.data_ptr(); + // All stride are in elements, not bytes. + params.x_batch_stride = x.stride(0); + params.x_c_stride = x.stride(1); + params.x_l_stride = x.stride(-1); + params.weight_c_stride = weight.stride(0); + params.weight_width_stride = weight.stride(1); + params.out_batch_stride = out.stride(0); + params.out_c_stride = out.stride(1); + params.out_l_stride = out.stride(-1); +} + + +at::Tensor +causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, + const c10::optional &bias_, + const c10::optional &seq_idx_, + const c10::optional &seq_pos_idx_, + const c10::optional &initial_states_, + c10::optional &final_states_out_, + bool silu_activation) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int width = weight.size(-1); + + CHECK_SHAPE(x, batch_size, dim, seqlen); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); + const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; + + if (is_channel_last) { + TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); + TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8"); + } + TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + if (seq_idx_.has_value()) { + TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout"); + auto seq_idx = seq_idx_.value(); + TORCH_CHECK(seq_idx.scalar_type() == 
torch::kInt32); + TORCH_CHECK(seq_idx.is_cuda()); + TORCH_CHECK(seq_idx.is_contiguous()); + CHECK_SHAPE(seq_idx, batch_size, seqlen); + } + if (seq_pos_idx_.has_value()) { + auto seq_pos_idx = seq_pos_idx_.value(); + TORCH_CHECK(seq_pos_idx.scalar_type() == torch::kInt32); + TORCH_CHECK(seq_pos_idx.is_cuda()); + TORCH_CHECK(seq_pos_idx.is_contiguous()); + CHECK_SHAPE(seq_pos_idx, batch_size, seqlen); + } + at::Tensor out = torch::empty_like(x); + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, + bias_.has_value() ? bias_.value().data_ptr() : nullptr, + silu_activation); + + if (seq_idx_.has_value()) { + params.seq_idx_ptr = seq_idx_.value().data_ptr(); + } else { + params.seq_idx_ptr = nullptr; + } + + if (seq_pos_idx_.has_value()) { + params.seq_pos_idx_ptr = seq_pos_idx_.value().data_ptr(); + } else { + params.seq_pos_idx_ptr = nullptr; + } + if (initial_states_.has_value()) { + TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout"); + auto initial_states = initial_states_.value(); + TORCH_CHECK(initial_states.scalar_type() == input_type); + TORCH_CHECK(initial_states.is_cuda()); + CHECK_SHAPE(initial_states, batch_size, dim, width - 1); + TORCH_CHECK(initial_states.stride(1) == 1); + params.initial_states_ptr = initial_states.data_ptr(); + params.initial_states_batch_stride = initial_states.stride(0); + params.initial_states_c_stride = initial_states.stride(1); + params.initial_states_l_stride = initial_states.stride(2); + } else { + params.initial_states_ptr = nullptr; + } + + if (final_states_out_.has_value()) { + TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout"); + auto final_states = final_states_out_.value(); + TORCH_CHECK(final_states.scalar_type() == input_type); + TORCH_CHECK(final_states.is_cuda()); + CHECK_SHAPE(final_states, batch_size, dim, width - 1); + TORCH_CHECK(final_states.stride(1) == 1); + params.final_states_ptr = final_states.data_ptr(); + params.final_states_batch_stride = final_states.stride(0); + params.final_states_c_stride = final_states.stride(1); + params.final_states_l_stride = final_states.stride(2); + } else { + params.final_states_ptr = nullptr; + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { + DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] { + if (!is_channel_last) { + causal_conv1d_fwd_cuda(params, stream); + } else { + causal_conv1d_channellast_fwd_cuda(params, stream); + } + }); + }); + return out; +} + + +at::Tensor +causal_conv1d_update(const at::Tensor &x, + const at::Tensor &conv_state, + const at::Tensor &weight, + const c10::optional &bias_, + bool silu_activation) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + TORCH_CHECK(conv_state.scalar_type() == input_type); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(conv_state.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int 
batch_size = sizes[0]; + const int dim = sizes[1]; + const int width = weight.size(-1); + + CHECK_SHAPE(x, batch_size, dim); + CHECK_SHAPE(conv_state, batch_size, dim, width); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + at::Tensor out = torch::empty_like(x); + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out, + bias_.has_value() ? bias_.value().data_ptr() : nullptr, + silu_activation); + params.conv_state_ptr = conv_state.data_ptr(); + // All stride are in elements, not bytes. + params.conv_state_batch_stride = conv_state.stride(0); + params.conv_state_c_stride = conv_state.stride(1); + params.conv_state_l_stride = conv_state.stride(2); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { + DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] { + causal_conv1d_update_cuda(params, stream); + }); + }); + return out; +} + +template +struct Causal_conv1d_fwd_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static_assert(kWidth <= kNElts); + static constexpr bool kIsVecLoad = kIsVecLoad_; + static constexpr int kNLoadsIndex = kNElts / 4; + using vec_t = typename BytesToType::Type; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadIndexT = cub::BlockLoad; + using BlockLoadIndexVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + + static constexpr int kSmemIOSize = (kIsVecLoad && kNLoadsIndex == 1) + ? 0 + : std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockLoadIndexT::TempStorage), + sizeof(typename BlockLoadIndexVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_fwd_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. 
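+ // Layout of smem_: the first kSmemIOSize bytes are aliased cub BlockLoad/BlockStore temp storage;
+ // the kNThreads-entry exchange buffer that follows lets each thread pass its last kNElts elements
+ // to the next thread, carrying the previous chunk's tail into thread 0.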
+ extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_vec = reinterpret_cast(smem_); + auto& smem_load_index = reinterpret_cast(smem_); + auto& smem_load_index_vec = reinterpret_cast(smem_); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_store_vec = reinterpret_cast(smem_); + vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + int *seq_pos_idx = !kHasSeqPosIdx ? nullptr : reinterpret_cast(params.seq_pos_idx_ptr) + batch_id * params.seqlen; + + // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. + if (tidx == 0) { + input_t zeros[kNElts] = {0}; + smem_exchange[kNThreads - 1] = reinterpret_cast(zeros)[0]; + } + + float weight_vals[kWidth]; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + + constexpr int kChunkSize = kNThreads * kNElts; + const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; + for (int chunk = 0; chunk < n_chunks; ++chunk) { + input_t x_vals_load[2 * kNElts] = {0}; + int seq_pos_idx_load[kNElts]; + if constexpr(kIsVecLoad) { + Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); + if (kHasSeqPosIdx) + Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(reinterpret_cast(seq_pos_idx), *reinterpret_cast(seq_pos_idx_load), (params.seqlen - chunk * kChunkSize) / kNElts * Ktraits::kNLoadsIndex); + } else { + __syncthreads(); + Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); + if (kHasSeqPosIdx) + Ktraits::BlockLoadIndexT(smem_load_index).Load(seq_pos_idx, seq_pos_idx_load, (params.seqlen - chunk * kChunkSize), 0); + } + x += kChunkSize; + if (kHasSeqPosIdx) seq_pos_idx += kChunkSize; + __syncthreads(); + // Thread kNThreads - 1 don't write yet, so that thread 0 can read + // the last elements of the previous chunk. + if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + __syncthreads(); + reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; + __syncthreads(); + // Now thread kNThreads - 1 can write the last elements of the current chunk. 
+ if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + + float x_vals[2 * kNElts]; + #pragma unroll + for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } + + float out_vals[kNElts]; + #pragma unroll + + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = bias_val; + #pragma unroll + int w = 0; + if (kHasSeqPosIdx){ + if(seq_pos_idx_load[i] < kWidth){ + w = kWidth - seq_pos_idx_load[i] - 1; + } + } + for (; w < kWidth; ++w) { + out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; + } + } + + if (params.silu_activation) { + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); + } + } + + input_t out_vals_store[kNElts]; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } + if constexpr(kIsVecLoad) { + Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); + } else { + Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); + } + out += kChunkSize; + } +} + +template +void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; + BOOL_SWITCH(params.seq_pos_idx_ptr != nullptr, kHasSeqPosIdx, [&] { + BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { + using Ktraits = Causal_conv1d_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize; + dim3 grid(params.batch, params.dim); + auto kernel = &causal_conv1d_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); +} + +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); + } +} + +template +struct Causal_conv1d_channellast_fwd_kernel_traits { + // The cache line is 128 bytes, and we try to read 16 bytes per thread. + // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension. + // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128 + // threads). Each each load is 16 x 32|64 elements in the L x C dimensions. + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static_assert(kNThreads % 32 == 0); + static constexpr int kNWarps = kNThreads / 32; + static constexpr int kWidth = kWidth_; + static constexpr int kChunkSizeL = kChunkSizeL_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 
4 : 8; + static constexpr int kNEltsPerRow = 128 / kNBytes; + static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now + static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); + static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now + static_assert(kNColsPerWarp * kNThreadsPerRow == 32); + static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; + static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; + static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); + static constexpr bool kIsVecLoad = kIsVecLoad_; + using vec_t = typename BytesToType::Type; + // using BlockLoadT = cub::BlockLoad; + // using BlockStoreT = cub::BlockStore; + // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), + // sizeof(typename BlockStoreT::TempStorage)}); + // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + constexpr int kNWarp = Ktraits::kNWarps; + constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; + constexpr int kLPerLoad = Ktraits::kNColsPerLoad; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. + __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts]; + + const int batch_id = blockIdx.x; + const int chunk_l_id = blockIdx.y; + const int chunk_c_id = blockIdx.z; + const int tid = threadIdx.x; + const int l_idx = tid / kNThreadsPerC; + const int c_idx = tid % kNThreadsPerC; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + weight_t *weight = reinterpret_cast(params.weight_ptr) + + chunk_c_id * kChunkSizeC * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast(params.seq_idx_ptr) + + batch_id * params.seqlen + chunk_l_id * kChunkSizeL; + input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr + : reinterpret_cast(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + // The last L-chunk will also have enough info to write to final states, since it also contain a few x values + // from the previous L-chunk. + input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? 
nullptr + : reinterpret_cast(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + + #pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t x_vals_load[kNElts] = {0}; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + l * kLPerLoad * params.x_l_stride); + } + reinterpret_cast(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; + } + // Load the elements from the previous chunk that are needed for convolution. + if (l_idx < kWidth - 1) { + input_t x_vals_load[kNElts] = {0}; + if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 + && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x - (kWidth - 1) * params.x_l_stride); + } else if (initial_states != nullptr + && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0 + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(initial_states); + } + reinterpret_cast(x_smem[l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; + } + + __syncthreads(); + + if (final_states != nullptr + && l_idx < kWidth - 1 + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1) + // So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx] + *reinterpret_cast(final_states) = reinterpret_cast(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx]; + } + + constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); + static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); + constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; + static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); + // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity + static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); + static_assert((kLPerThread & (kLPerThread - 1)) == 0); + static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); + static_assert(kNThreadsPerRow <= 32); + + const int row_idx = tid / kNThreadsPerRow; + const int col_idx = tid % kNThreadsPerRow; + + float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); + float weight_vals[kWidth] = {0}; + if (chunk_c_id * kChunkSizeC + row_idx < params.dim) { + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; + } + } + float x_vals[kWidth - 1 + kLPerThread]; + #pragma unroll + for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { + x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); + } + int seq_idx_thread[kWidth - 1 + kLPerThread]; + if constexpr (kHasSeqIdx) { + #pragma unroll + for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { + seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? 
seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1; + } + } + + float out_vals[kLPerThread]; + #pragma unroll + for (int i = 0; i < kLPerThread; ++i) { + out_vals[i] = bias_val; + const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + if constexpr (!kHasSeqIdx) { + out_vals[i] += weight_vals[w] * x_vals[i + w]; + } else { + out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f; + } + } + if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); } + } + + __syncthreads(); + #pragma unroll + for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; } + __syncthreads(); + + #pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t out_vals_store[kNElts]; + reinterpret_cast(out_vals_store)[0] = reinterpret_cast(x_smem[l * kLPerLoad + l_idx])[c_idx]; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + *reinterpret_cast(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast(out_vals_store)[0]; + } + } + +} + +template +void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] { + using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; + const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; + dim3 grid(params.batch, n_chunks_L, n_chunks_C); + dim3 block(Ktraits::kNThreads); + auto kernel = &causal_conv1d_channellast_fwd_kernel; + // if (kSmemSize >= 48 * 1024) { + // C10_CUDA_CHECK(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + // } + // kernel<<>>(params); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream); + } +} + +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t 
stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +/////// + + + + +template +struct Causal_conv1d_update_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_update_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y * kNThreads + tidx; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride + + channel_id * params.conv_state_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 
0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + float weight_vals[kWidth] = {0}; + if (channel_id < params.dim) { + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + } + + float x_vals[kWidth] = {0}; + if (channel_id < params.dim) { + #pragma unroll + for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } + x_vals[kWidth - 1] = float(x[0]); + #pragma unroll + for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); } + } + + float out_val = bias_val; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; } + if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } + if (channel_id < params.dim) { out[0] = input_t(out_val); } +} + +template +void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + using Ktraits = Causal_conv1d_update_kernel_traits; + dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); + auto kernel = &causal_conv1d_update_kernel; + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); + } +} + +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h new file mode 100644 index 000000000000..4e05744a8bbd --- /dev/null +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -0,0 +1,106 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ConvParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, width; + bool silu_activation; + + index_t x_batch_stride; + index_t x_c_stride; + index_t x_l_stride; + index_t weight_c_stride; + index_t weight_width_stride; + index_t out_batch_stride; + index_t out_c_stride; + index_t out_l_stride; + + index_t conv_state_batch_stride; + index_t conv_state_c_stride; + index_t conv_state_l_stride; + + // Common data pointers. 
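+ // x is the input activation, weight the per-channel (depthwise) filter and out the result;
+ // conv_state holds the rolling window of recent inputs consumed by causal_conv1d_update.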
+ void *__restrict__ x_ptr; + void *__restrict__ weight_ptr; + void *__restrict__ bias_ptr; + void *__restrict__ out_ptr; + + void *__restrict__ conv_state_ptr; + + void *__restrict__ seq_idx_ptr; + void *__restrict__ seq_pos_idx_ptr; + + // No __restrict__ since initial_states could be the same as final_states. + void * initial_states_ptr; + index_t initial_states_batch_stride; + index_t initial_states_l_stride; + index_t initial_states_c_stride; + + void * final_states_ptr; + index_t final_states_batch_stride; + index_t final_states_l_stride; + index_t final_states_c_stride; +}; + + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ inline T operator()(T const & x, T const & y) { return x + y; } +}; + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +template<> +struct Allreduce<2> { +template +static __device__ inline T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; diff --git a/csrc/mamba/causal_conv1d/static_switch.h b/csrc/mamba/causal_conv1d/static_switch.h new file mode 100644 index 000000000000..0f4ad3eb6223 --- /dev/null +++ b/csrc/mamba/causal_conv1d/static_switch.h @@ -0,0 +1,25 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h new file mode 100644 index 000000000000..69d72bf255e9 --- /dev/null +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -0,0 +1,274 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#ifndef USE_ROCM + #include +#else + #include +#endif +#include +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, dstate, n_groups, n_chunks; + int dim_ngroups_ratio; + bool is_variable_B; + bool is_variable_C; + + bool delta_softplus; + + index_t A_d_stride; + index_t A_dstate_stride; + index_t B_batch_stride; + index_t B_d_stride; + index_t B_dstate_stride; + index_t B_group_stride; + index_t C_batch_stride; + index_t C_d_stride; + index_t C_dstate_stride; + index_t C_group_stride; + index_t u_batch_stride; + index_t u_d_stride; + index_t delta_batch_stride; + index_t delta_d_stride; + index_t z_batch_stride; + index_t z_d_stride; + index_t out_batch_stride; + index_t out_d_stride; + index_t out_z_batch_stride; + index_t out_z_d_stride; + + // Common data pointers. + void *__restrict__ A_ptr; + void *__restrict__ B_ptr; + void *__restrict__ C_ptr; + void *__restrict__ D_ptr; + void *__restrict__ u_ptr; + void *__restrict__ delta_ptr; + void *__restrict__ delta_bias_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; + void *__restrict__ z_ptr; + void *__restrict__ out_z_ptr; + void *__restrict__ index_ptr; +}; + + + + +#ifndef USE_ROCM + + constexpr size_t custom_max(std::initializer_list ilist) + { + return std::max(ilist); + } + + template + constexpr T constexpr_min(T a, T b) { + return std::min(a, b); + } + +#else + constexpr size_t custom_max(std::initializer_list ilist) + { + return *std::max_element(ilist.begin(), ilist.end()); + } + + template + constexpr T constexpr_min(T a, T b) { + return a < b ? 
a : b; + } +#endif + + +#define MAX_DSTATE 256 + + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; +} + +inline __device__ float3 operator+(const float3 &a, const float3 &b) { + return {a.x + b.x, a.y + b.y, a.z + b.z}; +} + +inline __device__ float4 operator+(const float4 & a, const float4 & b){ + return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter{ + static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { + #pragma unroll + for (int i = 0; i < N; ++i) { dst[i] = src[i]; } + } +}; + +template +struct Converter{ + static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } + } +}; + +#if __CUDA_ARCH__ >= 800 +template +struct Converter{ + static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template struct SSMScanOp; + +template<> +struct SSMScanOp { + __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { + return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); + } +}; + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +template struct SSMScanPrefixCallbackOp { + using scan_t = std::conditional_t, float2, float4>; + scan_t running_prefix; + // Constructor + __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. 
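+ // Each call returns the prefix accumulated so far and folds the new block aggregate into
+ // running_prefix, so successive chunks resume the scan where the previous one ended.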
+ __device__ scan_t operator()(scan_t block_aggregate) { + scan_t old_prefix = running_prefix; + running_prefix = SSMScanOp()(running_prefix, block_aggregate); + return old_prefix; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void load_input(typename Ktraits::input_t *u, + typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadT::TempStorage &smem_load, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_vec = reinterpret_cast(smem_load); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadVecT(smem_load_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + #ifdef USE_ROCM + , Ktraits::kNThreads * Ktraits::kNLoads + #endif + + ); + } else { + typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); + } +} + +template +inline __device__ void load_index(int *u, + int (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_index_vec = reinterpret_cast(smem_load_index); + Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + ); + } else { + Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0); + } +} + +template +inline __device__ void load_weight(typename Ktraits::input_t *Bvar, + typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, + int seqlen) { + constexpr int kNItems = Ktraits::kNItems; + typename Ktraits::input_t B_vals_load[kNItems]; + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + reinterpret_cast(B_vals_load) + ); + } else { + typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + // #pragma unroll + // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } + Converter::to_float(B_vals_load, B_vals); +} + +template +inline __device__ void store_output(typename Ktraits::input_t *out, + const float (&out_vals)[Ktraits::kNItems], + typename Ktraits::BlockStoreT::TempStorage &smem_store, + int seqlen) { + typename Ktraits::input_t write_vals[Ktraits::kNItems]; + #pragma unroll + for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_store_vec = reinterpret_cast(smem_store); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockStoreVecT(smem_store_vec).Store( + reinterpret_cast(out), + reinterpret_cast(write_vals) + ); + } else { + typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); + } +} diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu new file mode 100644 index 000000000000..b15a1b10f4c9 --- /dev/null +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -0,0 +1,622 @@ +#include +#include +#include +#include "selective_scan.h" + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#ifndef USE_ROCM + #include + #include + #include +#else + #include + namespace cub = hipcub; +#endif + +#include "selective_scan.h" +#include "static_switch.h" + +template +struct Selective_Scan_fwd_kernel_traits { + static_assert(kNItems_ % 4 
== 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kHasZ = kHasZ_; + static constexpr bool kUseIndex = kUseIndex_; + + static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; + static constexpr int kNLoadsIndex = kNItems / 4; + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadIndexT = cub::BlockLoad; + using BlockLoadIndexVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + sizeof(typename BlockLoadIndexT::TempStorage), + sizeof(typename BlockLoadIndexVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr bool kUseIndex = Ktraits::kUseIndex; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr int kNRows = Ktraits::kNRows; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. 
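+ // Layout of smem_: aliased cub load/store temp storage (kSmemIOSize bytes), then the BlockScan
+ // temp storage, then kNRows * MAX_DSTATE running-prefix entries (placed at offset kSmemSize)
+ // that carry the scan state between chunks.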
+ extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); + // auto& smem_load = reinterpret_cast(smem_loadstorescan); + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_index = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); + // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); + scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; + int *index = !kUseIndex ? 
nullptr :reinterpret_cast(params.index_ptr) + batch_id * params.seqlen; + + float D_val[kNRows] = {0}; + if (params.D_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; + } + } + float delta_bias[kNRows] = {0}; + if (params.delta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; + } + } + + + // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; + // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; + // } + + constexpr int kChunkSize = kNThreads * kNItems; + for (int chunk = 0; chunk < params.n_chunks; ++chunk) { + input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; + int index_vals_load[kNRows][kNItems]; + + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { __syncthreads(); } + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (kUseIndex) { + load_index(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize); + } + } + if constexpr (kUseIndex) { + index += kChunkSize; + } + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; + if (params.delta_softplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + delta_u_vals[r][i] = delta_vals[r][i] * u_val; + out_vals[r][i] = D_val[r] * u_val; + } + } + + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + constexpr float kLog2e = M_LOG2E; + A_val[r] *= kLog2e; + } + // This variable holds B * C if both B and C are constant across seqlen. If only B varies + // across seqlen, this holds C. If only C varies across seqlen, this holds B. + // If both B and C vary, this is unused. + weight_t BC_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (kIsVariableB) { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (1)); + if constexpr (!kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + } + if constexpr (kIsVariableC) { + auto &smem_load_weight_C = !kIsVariableB ? 
smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (1 )); + if constexpr (!kIsVariableB) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } + } + } + if constexpr (!kIsVariableB && !kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if (r > 0) { __syncthreads(); } // Scan could be using the same smem + scan_t thread_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), + !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); + + // Reset A bar for cumulative sequences (Real) + if constexpr (kUseIndex) { + if (index_vals_load[r][i] == 0) { + thread_data[i].x = 0.f; + } + } + + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); + } + } + } + // Initialize running total + scan_t running_prefix; + // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read + running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f)); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + typename Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. + if (threadIdx.x == 0) { + smem_running_prefix[state_idx] = prefix_op.running_prefix; + x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const weight_t C_val = !kIsVariableC + ? BC_val[r] + : (!kIsVariableB ? 
BC_val[r] * C_vals[i] : C_vals[i]); + out_vals[r][i] += thread_data[i].y * C_val; + } + } + } + + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + input_t z_vals[kNItems]; + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + out_vals[r][i] *= z_val / (1 + expf(-z_val)); + } + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + } + + Bvar += kChunkSize * 1; + Cvar += kChunkSize * 1; + } +} + +template +void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { + // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block + // processing 1 row. + constexpr int kNRows = 1; + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); +} + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { + + #ifndef USE_ROCM + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } + #else + if (params.seqlen <= 256) { + selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } + #endif +} + +template void 
selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ + if (WTYPE == at::ScalarType::Half) { \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::BFloat16) { \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT(WTYPE, NAME, ...) \ + if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ + } + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + const bool is_variable_B, + const bool is_variable_C, + // device pointers + const torch::Tensor u, + const torch::Tensor delta, + const torch::Tensor A, + const torch::Tensor B, + const torch::Tensor C, + const torch::Tensor out, + const torch::Tensor z, + const torch::Tensor out_z, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + bool has_z, + bool delta_softplus, + void* index_ptr) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.dstate = dstate; + params.n_groups = n_groups; + params.n_chunks = n_chunks; + params.dim_ngroups_ratio = dim / n_groups; + + params.delta_softplus = delta_softplus; + + params.is_variable_B = is_variable_B; + params.is_variable_C = is_variable_C; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D_ptr; + params.delta_bias_ptr = delta_bias_ptr; + params.out_ptr = out.data_ptr(); + params.x_ptr = x_ptr; + params.z_ptr = has_z ? z.data_ptr() : nullptr; + params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + + params.index_ptr = index_ptr; + + // All stride are in elements, not bytes. + params.A_d_stride = A.stride(0); + params.A_dstate_stride = A.stride(1); + if (!is_variable_B) { + params.B_d_stride = B.stride(0); + } else { + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + } + params.B_dstate_stride = !is_variable_B ? 
B.stride(1) : B.stride(2); + if (!is_variable_C) { + params.C_d_stride = C.stride(0); + } else { + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + } + params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + if (has_z) { + params.z_batch_stride = z.stride(0); + params.z_d_stride = z.stride(1); + params.out_z_batch_stride = out_z.stride(0); + params.out_z_d_stride = out_z.stride(1); + } + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); +} + +std::vector +selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, + const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, + const c10::optional &D_, + const c10::optional &z_, + const c10::optional &delta_bias_, + bool delta_softplus, + const c10::optional &index_, + const c10::optional &x) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + const bool is_complex = weight_type == at::ScalarType::ComplexFloat; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = is_variable_B ? B.size(1) : 1; + + TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + if (!is_variable_B) { + CHECK_SHAPE(B, dim, dstate); + } else { + CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + } + if (!is_variable_C) { + CHECK_SHAPE(C, dim, dstate); + } else { + CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? 
seqlen: seqlen * 2); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + } + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + if (index_.has_value()) { + auto index = index_.value(); + TORCH_CHECK(index.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(index.is_cuda()); + CHECK_SHAPE(index, batch_size, seqlen); + } + + at::Tensor z, out_z; + const bool has_z = z_.has_value(); + if (has_z) { + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + CHECK_SHAPE(z, batch_size, dim, seqlen); + out_z = torch::empty_like(z); + } + + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + // at::Tensor out = torch::empty_like(u); + // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout + at::Tensor out = torch::empty_like(delta); + if (x.has_value()){ + auto _x = x.value(); + TORCH_CHECK(_x.scalar_type() == weight_type); + TORCH_CHECK(_x.is_cuda()); + TORCH_CHECK(_x.stride(-1) == 1); + CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2); + } + + SSMParamsBase params; + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, out, z, out_z, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x.value().data_ptr(), + has_z, + delta_softplus, + index_.has_value() ? index_.value().data_ptr() : nullptr); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { + DISPATCH_WTYPE_FLOAT(A.scalar_type(), "selective_scan_fwd", [&] { + selective_scan_fwd_cuda(params, stream); + }); + }); + std::vector result = {out, x.value()}; + if (has_z) { result.push_back(out_z); } + return result; +} + diff --git a/csrc/mamba/mamba_ssm/static_switch.h b/csrc/mamba/mamba_ssm/static_switch.h new file mode 100644 index 000000000000..7920ac045d0a --- /dev/null +++ b/csrc/mamba/mamba_ssm/static_switch.h @@ -0,0 +1,25 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/ops.h b/csrc/ops.h index 609459990102..d35324d39864 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -176,6 +176,16 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); +std::vector +selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, + const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, + const c10::optional &D_, + const c10::optional &z_, + const c10::optional &delta_bias_, + bool delta_softplus, + const c10::optional &index_, + const c10::optional &x); + #ifndef USE_ROCM using fptr_t = int64_t; fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index e26c2e28f2ec..b032c8965b01 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -243,6 +243,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "()"); ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); + + // Mamba selective scan kerenl + ops.def("selective_scan_fwd(Tensor! u, Tensor! delta," + "Tensor! A, Tensor! B, Tensor C," + "Tensor! D_, Tensor! z_, Tensor! delta_bias_," + "bool delta_softplus," + "Tensor! index_, Tensor! &x) -> ()"); + ops.impl("selective_scan_fwd", torch::kCUDA, + &selective_scan_fwd); + } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { From d2348ec0a6ee2740f3db7de863552cceb1b1e1f6 Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 20 Aug 2024 15:43:03 +0300 Subject: [PATCH 02/45] Casual conv1d compiles --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 81c7cf46fe33..0b31071889f7 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -1,7 +1,6 @@ #include #include #include -#include #include "causal_conv1d.h" #include @@ -379,11 +378,10 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } float out_vals[kNElts]; - #pragma unroll + #pragma unroll for (int i = 0; i < kNElts; ++i) { out_vals[i] = bias_val; - #pragma unroll int w = 0; if (kHasSeqPosIdx){ if(seq_pos_idx_load[i] < kWidth){ @@ -483,7 +481,6 @@ void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) { constexpr int kWidth = Ktraits::kWidth; constexpr int kNThreads = Ktraits::kNThreads; constexpr int kNElts = Ktraits::kNElts; - constexpr int kNWarp = Ktraits::kNWarps; constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; constexpr int kLPerLoad = Ktraits::kNColsPerLoad; constexpr int kChunkSizeL = Ktraits::kChunkSizeL; From 66ee5afdce27c85fbd16fb20002550758551b914 Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 20 Aug 2024 15:46:35 +0300 Subject: [PATCH 03/45] Add casual_conv1d to _custom_ops --- vllm/_custom_ops.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 1f0a111a53bc..a51bc9e32209 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -451,6 +451,26 @@ def ggml_mul_mat_a8( return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row) +# mamba +def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, + bias_: 
Optional[torch.Tensor], + seq_idx_: Optional[torch.Tensor], + seq_pos_idx_: Optional[torch.Tensor], + initial_states_: Optional[torch.Tensor], + final_states_out_: Optional[torch.Tensor], + silu_activation: bool) -> torch.Tensor: + return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, + seq_pos_idx_, initial_states_, + final_states_out_, silu_activation) + + +def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, bias_: Optional[torch.Tensor], + silu_activation: bool) -> torch.Tensor: + return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, + silu_activation) + + # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, From 7a0d2067564a2fd00a68dab92429b687939c0f2c Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 20 Aug 2024 15:47:00 +0300 Subject: [PATCH 04/45] Add mamba ops and triton kernels --- vllm/model_executor/layers/mamba/layer.py | 194 +++++++++ .../layers/mamba/ops/casual_conv1d.py | 156 +++++++ .../layers/mamba/ops/mamba_ssm.py | 380 ++++++++++++++++++ 3 files changed, 730 insertions(+) create mode 100644 vllm/model_executor/layers/mamba/layer.py create mode 100644 vllm/model_executor/layers/mamba/ops/casual_conv1d.py create mode 100644 vllm/model_executor/layers/mamba/ops/mamba_ssm.py diff --git a/vllm/model_executor/layers/mamba/layer.py b/vllm/model_executor/layers/mamba/layer.py new file mode 100644 index 000000000000..8c472dd7928b --- /dev/null +++ b/vllm/model_executor/layers/mamba/layer.py @@ -0,0 +1,194 @@ +from dataclasses import dataclass +from typing import Optional +import torch.nn as nn +import torch +from torch.nn.parameter import Parameter + +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear +from vllm.model_executor.layers.mamba.ops.casual_conv1d import causal_conv1d_fn, causal_conv1d_update +from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn, selective_state_update +from vllm.model_executor.utils import set_weight_attrs + +@dataclass +class MambaCacheParams: + is_prompt: bool = False + conv_state: torch.Tensor = torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + + + +class Mamba(nn.Module): + + def __init__(self,hidden_size: int, + mamba_d_state: int, + mamba_d_conv: int, + mamba_expand: int, + mamba_dt_rank: int, + mamba_conv_use_bias: bool, + mamba_proj_use_bias: bool, + activation_func:str = "silu", + rms_norm_eps:float = 1e-5): + super().__init__() + + self.hidden_size = hidden_size + self.ssm_state_size = mamba_d_state + self.conv_kernel_size = mamba_d_conv + self.intermediate_size = mamba_expand * hidden_size + self.time_step_rank = mamba_dt_rank + self.use_conv_bias = mamba_conv_use_bias + self.use_bias = mamba_proj_use_bias + + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, + bias=self.use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. 
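        # ColumnParallelLinear allocates a 2D weight of shape
        # (intermediate_size // tp_size, conv_kernel_size); unsqueeze(1) below
        # gives it the (dim, 1, width) layout of a depthwise conv weight, and
        # forward() views it back to (dim, width) before calling the kernels.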
+ # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear(self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False, + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear(self.time_step_rank, + self.intermediate_size, + bias=True, + skip_bias_add=True) + + def weight_loader(param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + param.data.copy_( + loaded_weight.data.split(loaded_weight.shape[0] // tp_size, + dim=0)[tp_rank]) + + def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): + weight_loader(param, -torch.exp(loaded_weight.float())) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": weight_loader}) + set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) + + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=self.use_bias, + input_is_parallel=True, + ) + self.activation = activation_func + + self.dt_layernorm = RMSNorm(self.time_step_rank, + eps=rms_norm_eps) + self.b_layernorm = RMSNorm(self.ssm_state_size, + eps=rms_norm_eps) + self.c_layernorm = RMSNorm(self.ssm_state_size, + eps=rms_norm_eps) + + + + def forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[MambaCacheParams] = None + ): + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) + hidden_states, gate = projected_states.chunk(2, dim=1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + if cache_params is not None and not cache_params.is_prompt: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0)) + cache_params.conv_state.copy_(conv_states) + + hidden_states,_ = causal_conv1d_fn( + hidden_states, + conv_weights, + bias=self.conv1d.bias, + activation=self.activation, + ) + + # 3. State Space Model sequence transformation + # 3.a. 
input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1, + ) + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) + + discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + if cache_params is not None and not cache_params.is_prompt: + scan_outputs = selective_state_update( + cache_params.ssm_state, + hidden_states[..., 0], + discrete_time_step[..., 0], + self.A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + self.A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_state.copy_(ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0] + return contextualized_states + diff --git a/vllm/model_executor/layers/mamba/ops/casual_conv1d.py b/vllm/model_executor/layers/mamba/ops/casual_conv1d.py new file mode 100644 index 000000000000..d9db67bba981 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/casual_conv1d.py @@ -0,0 +1,156 @@ +# Copyright (c) 2024, Tri Dao. + +from typing import Optional +import torch +import torch.nn.functional as F + +from vllm import _custom_ops as ops + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + seq_idx: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out=None, + activation: str = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert (initial_states is + None), "initial_states must be None if seq_idx is not None" + assert (not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and (initial_states.stride(2) != 1 + and initial_states.stride(1) != 1): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert (final_states_out.stride(2) == 1 + or final_states_out.stride(1) == 1) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty(batch, + width - 1, + dim, + device=x.device, + dtype=x.dtype).transpose(1, 2) + else: + final_states_out = None + + out = ops.causal_conv1d_fwd(x, weight, bias, seq_idx, 
initial_states, + final_states_out, activation + in ["silu", "swish"]) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out=None, + activation: str = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, + weight.unsqueeze(1), + bias, + padding=width - 1, + groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): + """ + x: (batch, dim) + conv_state: (batch, dim, width) + weight: (dim, width) + bias: (dim,) + + out: (batch, dim) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, + activation) + + +def causal_conv1d_update_ref(x, + conv_state, + weight, + bias=None, + activation=None): + """ + x: (batch, dim) + conv_state: (batch, dim, width) + weight: (dim, width) + bias: (dim,) + + out: (batch, dim) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + batch, dim = x.shape + width = weight.shape[1] + assert conv_state.shape == (batch, dim, width) + assert weight.shape == (dim, width) + conv_state.copy_(torch.roll(conv_state, shifts=-1, + dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = x + out = torch.sum(conv_state * weight, dim=-1) # (B D) + if bias is not None: + out += bias + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py new file mode 100644 index 000000000000..5529bf6dae6a --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -0,0 +1,380 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. 
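This new module pairs a Triton single-token update kernel (`selective_state_update`) with a Python wrapper around the CUDA prefill op (`selective_scan_fn`), plus pure-PyTorch reference versions of both. Below is a minimal sketch of the intended calling pattern using only the reference functions; the shapes and tensor values are made up for illustration, and the fused variants accept the same arguments, as the Mamba layer above shows.

```
import torch
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_scan_ref, selective_state_update_ref)

batch, dim, dstate, seqlen = 2, 4, 16, 8
u = torch.randn(batch, dim, seqlen)      # input sequence (B, D, L)
delta = torch.rand(batch, dim, seqlen)   # time step, made positive via softplus
A = -torch.rand(dim, dstate)             # negative real-valued state matrix
B = torch.randn(batch, dstate, seqlen)   # input-dependent B (B, N, L)
C = torch.randn(batch, dstate, seqlen)   # input-dependent C (B, N, L)
D = torch.ones(dim)                      # skip connection
dt_bias = torch.zeros(dim)

# Prefill: scan the whole prompt and keep the final SSM state.
out, last_state = selective_scan_ref(u, delta, A, B, C, D,
                                     delta_bias=dt_bias, delta_softplus=True,
                                     return_last_state=True)

# Decode: advance one token, updating last_state (B, D, N) in place.
out_t = selective_state_update_ref(last_state, u[..., -1], delta[..., -1],
                                   A, B[..., -1], C[..., -1], D,
                                   dt_bias=dt_bias, dt_softplus=True)
```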
+ +"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this +""" + +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + +from mamba_ssm.ops.triton.softplus import softplus + + +@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) +@triton.jit +def _selective_scan_update_kernel( + # Pointers to matrices + state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, + # Matrix dimensions + batch, nheads, dim, dstate, nheads_ngroups_ratio, + # Strides + stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate, + stride_x_batch, stride_x_head, stride_x_dim, + stride_dt_batch, stride_dt_head, stride_dt_dim, + stride_dt_bias_head, stride_dt_bias_dim, + stride_A_head, stride_A_dim, stride_A_dstate, + stride_B_batch, stride_B_group, stride_B_dstate, + stride_C_batch, stride_C_group, stride_C_dstate, + stride_D_head, stride_D_dim, + stride_z_batch, stride_z_head, stride_z_dim, + stride_out_batch, stride_out_head, stride_out_dim, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + TIE_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head + if HAS_DT_BIAS: + dt_bias_ptr += pid_h * stride_dt_bias_head + A_ptr += pid_h * stride_A_head + B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group + C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group + if HAS_Z: + z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + if HAS_DT_BIAS: + dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim + if HAS_D: + D_ptr += pid_h * stride_D_head + A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_D: + D_ptrs = D_ptr + offs_m * stride_D_dim + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + + state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if not TIE_HDIM: + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + else: + dt = 
tl.load(dt_ptr).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptr).to(tl.float32) + dA = tl.exp(A * dt) # scalar, not a matrix + + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + if not TIE_HDIM: + dB = B[None, :] * dt[:, None] + else: + dB = B * dt # vector of size (dstate,) + state = state * dA + dB * x[:, None] + tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + +def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + out = torch.empty_like(x) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)) + # We don't want autotune since it will overwrite the state + # We instead tune by hand. 
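    # Hand-picked (BLOCK_SIZE_M, num_warps) pairs: the tile over dim shrinks as
    # dstate grows, so each program's (BLOCK_SIZE_M, BLOCK_SIZE_DSTATE) state
    # tile stays near 512 elements up to dstate=128; only larger states use
    # 8 warps.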
+ BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 + else ((16, 4) if dstate <= 32 else + ((8, 4) if dstate <= 64 else + ((4, 4) if dstate <= 128 else + ((4, 8)))))) + tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0 + with torch.cuda.device(x.device.index): + _selective_scan_update_kernel[grid]( + state, x, dt, dt_bias, A, B, C, D, z, out, + batch, nheads, dim, dstate, nheads // ngroups, + state.stride(0), state.stride(1), state.stride(2), state.stride(3), + x.stride(0), x.stride(1), x.stride(2), + dt.stride(0), dt.stride(1), dt.stride(2), + *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0, + A.stride(0), A.stride(1), A.stride(2), + B.stride(0), B.stride(1), B.stride(2), + C.stride(0), C.stride(1), C.stride(2), + *(D.stride(0), D.stride(1)) if D is not None else 0, + z_strides[0], z_strides[1], z_strides[2], + out.stride(0), out.stride(1), out.stride(2), + dt_softplus, + tie_hdim, + BLOCK_SIZE_M, + num_warps=num_warps, + ) + if not has_heads: + out = out.squeeze(1) + return out + + +def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + dt = dt + dt_bias + dt = F.softplus(dt) if dt_softplus else dt + dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate) + B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) + C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) + dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate) + state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate + out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) + if D is not None: + out += (x * D).to(out.dtype) + out = (out if z is None else out * F.silu(z)).to(x.dtype) + if not has_heads: + out = out.squeeze(1) + return out + + +def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False, position_indices = None, prev_state = None): + """if 
return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). Note that the gradient of the last state and prev_state (if provided) is + not considered in the backward pass. + """ + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + n_chunks = int((u.shape[-1] + 2048 - 1) / 2048) + x = torch.zeros( + (u.shape[0], u.shape[1], n_chunks, int(A.shape[1] * 2),), + device=u.device, + dtype=torch.float32, + requires_grad=u.requires_grad + ) + x[:, :, 0, 0::2] = 1 + if prev_state is not None: + x[:, :, 0, 1::2].copy_(prev_state) + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, position_indices, x) + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if z is not None: + return out if not return_last_state else (out, last_state) + else: + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + + +def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False, position_indices = None, prev_state=None): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + prev_state: r(B D N), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + if position_indices is not None and position_indices[0,i] == 0: + x = deltaB_u[:, :, i] + else: + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + From 145b6b7615d24e15cd21c3e9e6158959fae7609c Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 20 Aug 2024 15:49:23 +0300 Subject: [PATCH 05/45] Add casual_conv1d update --- vllm/_custom_ops.py | 4 +--- vllm/model_executor/layers/mamba/ops/casual_conv1d.py | 5 ++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a51bc9e32209..3e29c6c526da 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -455,12 +455,10 @@ def ggml_mul_mat_a8( def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, bias_: Optional[torch.Tensor], seq_idx_: Optional[torch.Tensor], - seq_pos_idx_: Optional[torch.Tensor], initial_states_: Optional[torch.Tensor], final_states_out_: Optional[torch.Tensor], silu_activation: bool) -> torch.Tensor: - return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, - seq_pos_idx_, initial_states_, + return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, initial_states_, final_states_out_, silu_activation) diff --git a/vllm/model_executor/layers/mamba/ops/casual_conv1d.py b/vllm/model_executor/layers/mamba/ops/casual_conv1d.py index d9db67bba981..c34afdeb5add 100644 --- a/vllm/model_executor/layers/mamba/ops/casual_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/casual_conv1d.py @@ -108,7 +108,7 @@ def causal_conv1d_ref( else: final_states_out = final_states out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) - return out if not return_final_states else (out, final_states_out) + return (out, None) if not return_final_states else (out, final_states_out) def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): @@ -123,8 +123,7 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): if activation not in [None, "silu", "swish"]: raise NotImplementedError("activation must be None, silu, or swish") activation = activation in ["silu", "swish"] - return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, - activation) + return ops.causal_conv1d_update(x, conv_state, weight, bias, activation) def causal_conv1d_update_ref(x, From 2bdd7f557cccbe7040f49446cf9496f1884e7719 Mon Sep 17 00:00:00 2001 From: 
mzusman Date: Tue, 20 Aug 2024 16:03:19 +0300 Subject: [PATCH 06/45] setup selective scan fwd pass --- vllm/_custom_ops.py | 16 ++++++++++++++-- .../model_executor/layers/mamba/ops/mamba_ssm.py | 16 ++++++++++++++-- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3e29c6c526da..288f4c879203 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -458,8 +458,9 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, initial_states_: Optional[torch.Tensor], final_states_out_: Optional[torch.Tensor], silu_activation: bool) -> torch.Tensor: - return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, initial_states_, - final_states_out_, silu_activation) + return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, + initial_states_, final_states_out_, + silu_activation) def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, @@ -469,6 +470,17 @@ def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, silu_activation) +def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, + B: torch.Tensor, C: torch.Tensor, + D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, index_: Optional[torch.Tensor], + x: Optional[torch.Tensor]) -> List[torch.Tensor]: + return torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, + delta_bias_, delta_softplus, index_, + x) + + # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 5529bf6dae6a..7a9ffa6983b4 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -10,9 +10,21 @@ import triton.language as tl from einops import rearrange, repeat +from vllm import _custom_ops as ops +from packaging import version -from mamba_ssm.ops.triton.softplus import softplus +TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") +if TRITON3: + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) + return dt +else: + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + return dt @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) @@ -296,7 +308,7 @@ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_ x[:, :, 0, 0::2] = 1 if prev_state is not None: x[:, :, 0, 1::2].copy_(prev_state) - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, position_indices, x) + out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, position_indices, x) last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) if z is not None: return out if not return_last_state else (out, last_state) From e25dbfe84d4065bd7ba078edc840c4f82c12e4b8 Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 20 Aug 2024 18:26:20 +0300 Subject: [PATCH 07/45] Format --- .../layers/mamba/ops/mamba_ssm.py | 261 +++++++++++++----- 1 file changed, 192 insertions(+), 69 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 7a9ffa6983b4..057b4016362a 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py 
+++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -1,8 +1,5 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. -"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this -""" - import torch import torch.nn.functional as F @@ -16,37 +13,74 @@ TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") if TRITON3: + @triton.jit def softplus(dt): dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) return dt else: + @triton.jit def softplus(dt): dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) return dt -@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) + +@triton.heuristics( + {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) -@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) +@triton.heuristics( + {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) @triton.jit def _selective_scan_update_kernel( # Pointers to matrices - state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, + state_ptr, + x_ptr, + dt_ptr, + dt_bias_ptr, + A_ptr, + B_ptr, + C_ptr, + D_ptr, + z_ptr, + out_ptr, # Matrix dimensions - batch, nheads, dim, dstate, nheads_ngroups_ratio, + batch, + nheads, + dim, + dstate, + nheads_ngroups_ratio, # Strides - stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate, - stride_x_batch, stride_x_head, stride_x_dim, - stride_dt_batch, stride_dt_head, stride_dt_dim, - stride_dt_bias_head, stride_dt_bias_dim, - stride_A_head, stride_A_dim, stride_A_dstate, - stride_B_batch, stride_B_group, stride_B_dstate, - stride_C_batch, stride_C_group, stride_C_dstate, - stride_D_head, stride_D_dim, - stride_z_batch, stride_z_head, stride_z_dim, - stride_out_batch, stride_out_head, stride_out_dim, + stride_state_batch, + stride_state_head, + stride_state_dim, + stride_state_dstate, + stride_x_batch, + stride_x_head, + stride_x_dim, + stride_dt_batch, + stride_dt_head, + stride_dt_dim, + stride_dt_bias_head, + stride_dt_bias_dim, + stride_A_head, + stride_A_dim, + stride_A_dstate, + stride_B_batch, + stride_B_group, + stride_B_dstate, + stride_C_batch, + stride_C_group, + stride_C_dstate, + stride_D_head, + stride_D_dim, + stride_z_batch, + stride_z_head, + stride_z_dim, + stride_out_batch, + stride_out_head, + stride_out_dim, # Meta-parameters DT_SOFTPLUS: tl.constexpr, TIE_HDIM: tl.constexpr, @@ -65,22 +99,26 @@ def _selective_scan_update_kernel( if HAS_DT_BIAS: dt_bias_ptr += pid_h * stride_dt_bias_head A_ptr += pid_h * stride_A_head - B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group - C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group + B_ptr += pid_b * stride_B_batch + (pid_h // + nheads_ngroups_ratio) * stride_B_group + C_ptr += pid_b * stride_C_batch + (pid_h // + nheads_ngroups_ratio) * stride_C_group if HAS_Z: z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) - state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) + state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + + offs_n[None, :] * stride_state_dstate) x_ptrs = x_ptr + offs_m * stride_x_dim dt_ptrs = dt_ptr + offs_m 
* stride_dt_dim if HAS_DT_BIAS: dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim if HAS_D: D_ptr += pid_h * stride_D_head - A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) + A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + + offs_n[None, :] * stride_A_dstate) B_ptrs = B_ptr + offs_n * stride_B_dstate C_ptrs = C_ptr + offs_n * stride_C_dstate if HAS_D: @@ -89,15 +127,20 @@ def _selective_scan_update_kernel( z_ptrs = z_ptr + offs_m * stride_z_dim out_ptrs = out_ptr + offs_m * stride_out_dim - state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) + state = tl.load(state_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0) x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if not TIE_HDIM: dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if HAS_DT_BIAS: - dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) if DT_SOFTPLUS: dt = softplus(dt) - A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + A = tl.load(A_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) dA = tl.exp(A * dt[:, None]) else: dt = tl.load(dt_ptr).to(tl.float32) @@ -120,7 +163,9 @@ def _selective_scan_update_kernel( else: dB = B * dt # vector of size (dstate,) state = state * dA + dB * x[:, None] - tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + tl.store(state_ptrs, + state, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) out = tl.sum(state * C[None, :], axis=1) if HAS_D: out += x * D @@ -129,7 +174,16 @@ def _selective_scan_update_kernel( tl.store(out_ptrs, out, mask=offs_m < dim) -def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): +def selective_state_update(state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) @@ -179,29 +233,61 @@ def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, assert dt_bias.shape == (nheads, dim) out = torch.empty_like(x) grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) - z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)) + z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else + (0, 0, 0)) # We don't want autotune since it will overwrite the state # We instead tune by hand. 
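    # tie_hdim below checks whether A, dt and dt_bias are broadcast
    # (zero-stride) along their trailing dims; if so, every channel in a head
    # shares one dt and dA, and the kernel takes the scalar TIE_HDIM path.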
- BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 - else ((16, 4) if dstate <= 32 else - ((8, 4) if dstate <= 64 else - ((4, 4) if dstate <= 128 else - ((4, 8)))))) - tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0 + BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else + ((16, 4) if dstate <= 32 else + ((8, 4) if dstate <= 64 else + ((4, 4) if dstate <= 128 else ((4, 8)))))) + tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride( + -1) == 0 and dt_bias.stride(-1) == 0 with torch.cuda.device(x.device.index): _selective_scan_update_kernel[grid]( - state, x, dt, dt_bias, A, B, C, D, z, out, - batch, nheads, dim, dstate, nheads // ngroups, - state.stride(0), state.stride(1), state.stride(2), state.stride(3), - x.stride(0), x.stride(1), x.stride(2), - dt.stride(0), dt.stride(1), dt.stride(2), - *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0, - A.stride(0), A.stride(1), A.stride(2), - B.stride(0), B.stride(1), B.stride(2), - C.stride(0), C.stride(1), C.stride(2), + state, + x, + dt, + dt_bias, + A, + B, + C, + D, + z, + out, + batch, + nheads, + dim, + dstate, + nheads // ngroups, + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + x.stride(0), + x.stride(1), + x.stride(2), + dt.stride(0), + dt.stride(1), + dt.stride(2), + *(dt_bias.stride(0), + dt_bias.stride(1)) if dt_bias is not None else 0, + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + C.stride(0), + C.stride(1), + C.stride(2), *(D.stride(0), D.stride(1)) if D is not None else 0, - z_strides[0], z_strides[1], z_strides[2], - out.stride(0), out.stride(1), out.stride(2), + z_strides[0], + z_strides[1], + z_strides[2], + out.stride(0), + out.stride(1), + out.stride(2), dt_softplus, tie_hdim, BLOCK_SIZE_M, @@ -212,7 +298,16 @@ def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, return out -def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): +def selective_state_update_ref(state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) @@ -262,11 +357,16 @@ def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=No assert dt_bias.shape == (nheads, dim) dt = dt + dt_bias dt = F.softplus(dt) if dt_softplus else dt - dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate) - B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) - C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) - dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate) - state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate + dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * + A) # (batch, nheads, dim, dstate) + B = repeat(B, "b g n -> b (g h) n", + h=nheads // ngroups) # (batch, nheads, dstate) + C = repeat(C, "b g n -> b (g h) n", + h=nheads // ngroups) # (batch, nheads, dstate) + dB = rearrange(dt, "b h d -> b h d 1") * rearrange( + B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate) + state.copy_(state * dA + + dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) if D is not None: out += (x * D).to(out.dtype) @@ -276,11 +376,20 @@ def selective_state_update_ref(state, x, dt, A, B, 
C, D=None, z=None, dt_bias=No return out -def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False, position_indices = None, prev_state = None): +def selective_scan_fn(u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + position_indices=None, + prev_state=None): """if return_last_state is True, returns (out, last_state) - last_state has shape (batch, dim, dstate). Note that the gradient of the last state and prev_state (if provided) is - not considered in the backward pass. + last_state has shape (batch, dim, dstate). """ if u.stride(-1) != 1: u = u.contiguous() @@ -299,16 +408,20 @@ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_ if C.dim() == 3: C = rearrange(C, "b dstate l -> b 1 dstate l") n_chunks = int((u.shape[-1] + 2048 - 1) / 2048) - x = torch.zeros( - (u.shape[0], u.shape[1], n_chunks, int(A.shape[1] * 2),), - device=u.device, - dtype=torch.float32, - requires_grad=u.requires_grad - ) + x = torch.zeros(( + u.shape[0], + u.shape[1], + n_chunks, + int(A.shape[1] * 2), + ), + device=u.device, + dtype=torch.float32, + requires_grad=u.requires_grad) x[:, :, 0, 0::2] = 1 if prev_state is not None: x[:, :, 0, 1::2].copy_(prev_state) - out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, position_indices, x) + out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, + delta_softplus, position_indices, x) last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) if z is not None: return out if not return_last_state else (out, last_state) @@ -317,8 +430,18 @@ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_ return out_z if not return_last_state else (out_z, last_state) -def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False, position_indices = None, prev_state=None): +def selective_scan_ref(u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + position_indices=None, + prev_state=None): """ u: r(B D L) delta: r(B D L) @@ -345,9 +468,11 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta is_variable_C = C.dim() >= 3 if A.is_complex(): if is_variable_B: - B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + B = torch.view_as_complex( + rearrange(B.float(), "... (L two) -> ... L two", two=2)) if is_variable_C: - C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) + C = torch.view_as_complex( + rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) else: B = B.float() C = C.float() @@ -366,7 +491,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) last_state = None for i in range(u.shape[2]): - if position_indices is not None and position_indices[0,i] == 0: + if position_indices is not None and position_indices[0, i] == 0: x = deltaB_u[:, :, i] else: x = deltaA[:, :, i] * x + deltaB_u[:, :, i] @@ -382,11 +507,9 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta if y.is_complex(): y = y.real * 2 ys.append(y) - y = torch.stack(ys, dim=2) # (batch dim L) + y = torch.stack(ys, dim=2) # (batch dim L) out = y if D is None else y + u * rearrange(D, "d -> d 1") if z is not None: out = out * F.silu(z) out = out.to(dtype=dtype_in) return out if not return_last_state else (out, last_state) - - From 64b61609d8d1213523546c1a854e0c6f9a6b316d Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 20 Aug 2024 18:35:18 +0300 Subject: [PATCH 08/45] Do not have a mamba layer for now, push in a future PR --- vllm/model_executor/layers/mamba/layer.py | 194 ---------------------- 1 file changed, 194 deletions(-) delete mode 100644 vllm/model_executor/layers/mamba/layer.py diff --git a/vllm/model_executor/layers/mamba/layer.py b/vllm/model_executor/layers/mamba/layer.py deleted file mode 100644 index 8c472dd7928b..000000000000 --- a/vllm/model_executor/layers/mamba/layer.py +++ /dev/null @@ -1,194 +0,0 @@ -from dataclasses import dataclass -from typing import Optional -import torch.nn as nn -import torch -from torch.nn.parameter import Parameter - -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear -from vllm.model_executor.layers.mamba.ops.casual_conv1d import causal_conv1d_fn, causal_conv1d_update -from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn, selective_state_update -from vllm.model_executor.utils import set_weight_attrs - -@dataclass -class MambaCacheParams: - is_prompt: bool = False - conv_state: torch.Tensor = torch.Tensor() - ssm_state: torch.Tensor = torch.Tensor() - - - -class Mamba(nn.Module): - - def __init__(self,hidden_size: int, - mamba_d_state: int, - mamba_d_conv: int, - mamba_expand: int, - mamba_dt_rank: int, - mamba_conv_use_bias: bool, - mamba_proj_use_bias: bool, - activation_func:str = "silu", - rms_norm_eps:float = 1e-5): - super().__init__() - - self.hidden_size = hidden_size - self.ssm_state_size = mamba_d_state - self.conv_kernel_size = mamba_d_conv - self.intermediate_size = mamba_expand * hidden_size - self.time_step_rank = mamba_dt_rank - self.use_conv_bias = mamba_conv_use_bias - self.use_bias = mamba_proj_use_bias - - self.conv1d = ColumnParallelLinear( - input_size=self.conv_kernel_size, - output_size=self.intermediate_size, - bias=self.use_conv_bias, - ) - # unsqueeze to fit conv1d weights shape into the linear weights shape. 
- # Can't do this in `weight_loader` since it already exists in - # `ColumnParallelLinear` and `set_weight_attrs` - # doesn't allow to override it - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - - self.in_proj = MergedColumnParallelLinear(self.hidden_size, - [self.intermediate_size] * 2, - bias=self.use_bias) - # selective projection used to make dt, B and C input dependent - self.x_proj = RowParallelLinear( - self.intermediate_size, - self.time_step_rank + self.ssm_state_size * 2, - bias=False, - ) - # time step projection (discretization) - - # In the forward we need to apply dt_proj without the bias, - # as the bias is added in the selective scan kernel. - self.dt_proj = ColumnParallelLinear(self.time_step_rank, - self.intermediate_size, - bias=True, - skip_bias_add=True) - - def weight_loader(param: Parameter, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - param.data.copy_( - loaded_weight.data.split(loaded_weight.shape[0] // tp_size, - dim=0)[tp_rank]) - - def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): - weight_loader(param, -torch.exp(loaded_weight.float())) - - tp_size = get_tensor_model_parallel_world_size() - self.A = nn.Parameter( - torch.empty( - self.intermediate_size // tp_size, - self.ssm_state_size, - dtype=torch.float32, - )) - self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) - - set_weight_attrs(self.D, {"weight_loader": weight_loader}) - set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) - - self.out_proj = RowParallelLinear( - self.intermediate_size, - self.hidden_size, - bias=self.use_bias, - input_is_parallel=True, - ) - self.activation = activation_func - - self.dt_layernorm = RMSNorm(self.time_step_rank, - eps=rms_norm_eps) - self.b_layernorm = RMSNorm(self.ssm_state_size, - eps=rms_norm_eps) - self.c_layernorm = RMSNorm(self.ssm_state_size, - eps=rms_norm_eps) - - - - def forward( - self, - hidden_states: torch.Tensor, - cache_params: Optional[MambaCacheParams] = None - ): - # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) - hidden_states, gate = projected_states.chunk(2, dim=1) - - # 2. Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) - if cache_params is not None and not cache_params.is_prompt: - hidden_states = causal_conv1d_update( - hidden_states.squeeze(-1), - cache_params.conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - ) - hidden_states = hidden_states.unsqueeze(-1) - else: - if cache_params is not None: - conv_states = nn.functional.pad( - hidden_states, - (self.conv_kernel_size - hidden_states.shape[-1], 0)) - cache_params.conv_state.copy_(conv_states) - - hidden_states,_ = causal_conv1d_fn( - hidden_states, - conv_weights, - bias=self.conv1d.bias, - activation=self.activation, - ) - - # 3. State Space Model sequence transformation - # 3.a. 
input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0] - - time_step, B, C = torch.split( - ssm_parameters, - [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], - dim=-1, - ) - time_step = self.dt_layernorm(time_step.contiguous()) - B = self.b_layernorm(B.contiguous()) - C = self.c_layernorm(C.contiguous()) - - discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = (self.dt_proj.bias.float() if hasattr( - self.dt_proj, "bias") else None) - if cache_params is not None and not cache_params.is_prompt: - scan_outputs = selective_state_update( - cache_params.ssm_state, - hidden_states[..., 0], - discrete_time_step[..., 0], - self.A, - B[:, 0], - C[:, 0], - self.D, - gate[..., 0], - time_proj_bias, - dt_softplus=True, - ).unsqueeze(-1) - else: - scan_outputs, ssm_state = selective_scan_fn( - hidden_states, - discrete_time_step, - self.A, - B.transpose(1, 2), - C.transpose(1, 2), - self.D.float(), - gate, - time_proj_bias, - delta_softplus=True, - return_last_state=True, - ) - if ssm_state is not None and cache_params is not None: - cache_params.ssm_state.copy_(ssm_state) - - # 4. Final linear projection - contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0] - return contextualized_states - From 2ff36cb4545868519014f060f0aa22d3de06ae93 Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 20 Aug 2024 18:50:07 +0300 Subject: [PATCH 09/45] Format --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 1530 +++++++++-------- csrc/mamba/causal_conv1d/causal_conv1d.h | 145 +- csrc/mamba/causal_conv1d/static_switch.h | 23 +- csrc/mamba/mamba_ssm/selective_scan.h | 416 ++--- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 1235 +++++++------ csrc/mamba/mamba_ssm/static_switch.h | 23 +- csrc/ops.h | 17 +- csrc/torch_bindings.cpp | 17 +- .../layers/mamba/ops/casual_conv1d.py | 76 +- .../layers/mamba/ops/mamba_ssm.py | 174 +- vllm/model_executor/models/jamba.py | 11 +- 11 files changed, 1862 insertions(+), 1805 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 0b31071889f7..75a7bad3fa06 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -12,743 +12,919 @@ #include "static_switch.h" +#define CHECK_SHAPE(x, ...) \ + TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \ + #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), \ + "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ + if (WTYPE == at::ScalarType::Half) { \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::BFloat16) { \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), \ + "'"); \ + } - -#define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - -#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ - if (ITYPE == at::ScalarType::Half) { \ - using input_t = at::Half; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::BFloat16) { \ - using input_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::Float) { \ - using input_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ - if (WTYPE == at::ScalarType::Half) { \ - using weight_t = at::Half; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::BFloat16) { \ - using weight_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ - } - -template -void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); template -void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +void causal_conv1d_fwd_cuda(ConvParamsBase& params, cudaStream_t stream); +template +void causal_conv1d_channellast_fwd_cuda(ConvParamsBase& params, + cudaStream_t stream); -template -void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template +void causal_conv1d_update_cuda(ConvParamsBase& params, cudaStream_t stream); -void set_conv_params_fwd(ConvParamsBase ¶ms, +void set_conv_params_fwd(ConvParamsBase& params, // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t width, + const size_t batch, const size_t dim, + const size_t seqlen, const size_t width, // device pointers - const at::Tensor x, - const at::Tensor weight, - const at::Tensor out, - void* bias_ptr, + const at::Tensor x, const at::Tensor weight, + const at::Tensor out, void* bias_ptr, bool silu_activation) { - - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - params.batch = batch; - params.dim = dim; - params.seqlen = seqlen; - params.width = width; - - params.silu_activation = silu_activation; - - // Set the pointers and strides. - params.x_ptr = x.data_ptr(); - params.weight_ptr = weight.data_ptr(); - params.bias_ptr = bias_ptr; - params.out_ptr = out.data_ptr(); - // All stride are in elements, not bytes. - params.x_batch_stride = x.stride(0); - params.x_c_stride = x.stride(1); - params.x_l_stride = x.stride(-1); - params.weight_c_stride = weight.stride(0); - params.weight_width_stride = weight.stride(1); - params.out_batch_stride = out.stride(0); - params.out_c_stride = out.stride(1); - params.out_l_stride = out.stride(-1); + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.width = width; + + params.silu_activation = silu_activation; + + // Set the pointers and strides. + params.x_ptr = x.data_ptr(); + params.weight_ptr = weight.data_ptr(); + params.bias_ptr = bias_ptr; + params.out_ptr = out.data_ptr(); + // All stride are in elements, not bytes. 
+ params.x_batch_stride = x.stride(0); + params.x_c_stride = x.stride(1); + params.x_l_stride = x.stride(-1); + params.weight_c_stride = weight.stride(0); + params.weight_width_stride = weight.stride(1); + params.out_batch_stride = out.stride(0); + params.out_c_stride = out.stride(1); + params.out_l_stride = out.stride(-1); } +at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, + const c10::optional& bias_, + const c10::optional& seq_idx_, + const c10::optional& seq_pos_idx_, + const c10::optional& initial_states_, + c10::optional& final_states_out_, + bool silu_activation) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || + input_type == at::ScalarType::Half || + input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || + weight_type == at::ScalarType::Half || + weight_type == at::ScalarType::BFloat16); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int width = weight.size(-1); + + CHECK_SHAPE(x, batch_size, dim, seqlen); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); + const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; + + if (is_channel_last) { + TORCH_CHECK( + dim % 8 == 0, + "causal_conv1d only supports channel dimension divisible by 8 for now"); + TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, + "causal_conv1d with channel last layout requires strides " + "(x.stride(0) and x.stride(2)) to be multiples of 8"); + } + TORCH_CHECK(width >= 2 && width <= 4, + "causal_conv1d only supports width between 2 and 4"); + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + if (seq_idx_.has_value()) { + TORCH_CHECK(is_channel_last, + "seq_idx is only supported for channel last layout"); + auto seq_idx = seq_idx_.value(); + TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32); + TORCH_CHECK(seq_idx.is_cuda()); + TORCH_CHECK(seq_idx.is_contiguous()); + CHECK_SHAPE(seq_idx, batch_size, seqlen); + } + if (seq_pos_idx_.has_value()) { + auto seq_pos_idx = seq_pos_idx_.value(); + TORCH_CHECK(seq_pos_idx.scalar_type() == torch::kInt32); + TORCH_CHECK(seq_pos_idx.is_cuda()); + TORCH_CHECK(seq_pos_idx.is_contiguous()); + CHECK_SHAPE(seq_pos_idx, batch_size, seqlen); + } + at::Tensor out = torch::empty_like(x); + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, + bias_.has_value() ? 
bias_.value().data_ptr() : nullptr, + silu_activation); + + if (seq_idx_.has_value()) { + params.seq_idx_ptr = seq_idx_.value().data_ptr(); + } else { + params.seq_idx_ptr = nullptr; + } + + if (seq_pos_idx_.has_value()) { + params.seq_pos_idx_ptr = seq_pos_idx_.value().data_ptr(); + } else { + params.seq_pos_idx_ptr = nullptr; + } + if (initial_states_.has_value()) { + TORCH_CHECK(is_channel_last, + "initial_states is only supported for channel last layout"); + auto initial_states = initial_states_.value(); + TORCH_CHECK(initial_states.scalar_type() == input_type); + TORCH_CHECK(initial_states.is_cuda()); + CHECK_SHAPE(initial_states, batch_size, dim, width - 1); + TORCH_CHECK(initial_states.stride(1) == 1); + params.initial_states_ptr = initial_states.data_ptr(); + params.initial_states_batch_stride = initial_states.stride(0); + params.initial_states_c_stride = initial_states.stride(1); + params.initial_states_l_stride = initial_states.stride(2); + } else { + params.initial_states_ptr = nullptr; + } + + if (final_states_out_.has_value()) { + TORCH_CHECK(is_channel_last, + "final_states is only supported for channel last layout"); + auto final_states = final_states_out_.value(); + TORCH_CHECK(final_states.scalar_type() == input_type); + TORCH_CHECK(final_states.is_cuda()); + CHECK_SHAPE(final_states, batch_size, dim, width - 1); + TORCH_CHECK(final_states.stride(1) == 1); + params.final_states_ptr = final_states.data_ptr(); + params.final_states_batch_stride = final_states.stride(0); + params.final_states_c_stride = final_states.stride(1); + params.final_states_l_stride = final_states.stride(2); + } else { + params.final_states_ptr = nullptr; + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16( + x.scalar_type(), "causal_conv1d_fwd", [&] { + DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16( + weight.scalar_type(), "causal_conv1d_fwd", [&] { + if (!is_channel_last) { + causal_conv1d_fwd_cuda(params, stream); + } else { + causal_conv1d_channellast_fwd_cuda(params, + stream); + } + }); + }); + return out; +} -at::Tensor -causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, - const c10::optional &bias_, - const c10::optional &seq_idx_, - const c10::optional &seq_pos_idx_, - const c10::optional &initial_states_, - c10::optional &final_states_out_, - bool silu_activation) { - auto input_type = x.scalar_type(); - auto weight_type = weight.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); - - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(weight.is_cuda()); - - const auto sizes = x.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int width = weight.size(-1); - - CHECK_SHAPE(x, batch_size, dim, seqlen); - CHECK_SHAPE(weight, dim, width); - - TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); - const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; - - if (is_channel_last) { - TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); - TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout 
requires strides (x.stride(0) and x.stride(2)) to be multiples of 8"); - } - TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); - - if (bias_.has_value()) { - auto bias = bias_.value(); - TORCH_CHECK(bias.scalar_type() == weight_type); - TORCH_CHECK(bias.is_cuda()); - TORCH_CHECK(bias.stride(-1) == 1); - CHECK_SHAPE(bias, dim); - } - - if (seq_idx_.has_value()) { - TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout"); - auto seq_idx = seq_idx_.value(); - TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32); - TORCH_CHECK(seq_idx.is_cuda()); - TORCH_CHECK(seq_idx.is_contiguous()); - CHECK_SHAPE(seq_idx, batch_size, seqlen); - } - if (seq_pos_idx_.has_value()) { - auto seq_pos_idx = seq_pos_idx_.value(); - TORCH_CHECK(seq_pos_idx.scalar_type() == torch::kInt32); - TORCH_CHECK(seq_pos_idx.is_cuda()); - TORCH_CHECK(seq_pos_idx.is_contiguous()); - CHECK_SHAPE(seq_pos_idx, batch_size, seqlen); - } - at::Tensor out = torch::empty_like(x); +at::Tensor causal_conv1d_update(const at::Tensor& x, + const at::Tensor& conv_state, + const at::Tensor& weight, + const c10::optional& bias_, + bool silu_activation) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || + input_type == at::ScalarType::Half || + input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || + weight_type == at::ScalarType::Half || + weight_type == at::ScalarType::BFloat16); + TORCH_CHECK(conv_state.scalar_type() == input_type); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(conv_state.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int width = weight.size(-1); + + CHECK_SHAPE(x, batch_size, dim); + CHECK_SHAPE(conv_state, batch_size, dim, width); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(width >= 2 && width <= 4, + "causal_conv1d only supports width between 2 and 4"); + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + at::Tensor out = torch::empty_like(x); + + ConvParamsBase params; + set_conv_params_fwd( + params, batch_size, dim, /*seqlen=*/1, width, x, weight, out, + bias_.has_value() ? bias_.value().data_ptr() : nullptr, silu_activation); + params.conv_state_ptr = conv_state.data_ptr(); + // All stride are in elements, not bytes. + params.conv_state_batch_stride = conv_state.stride(0); + params.conv_state_c_stride = conv_state.stride(1); + params.conv_state_l_stride = conv_state.stride(2); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16( + x.scalar_type(), "causal_conv1d_update", [&] { + DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16( + weight.scalar_type(), "causal_conv1d_update", [&] { + causal_conv1d_update_cuda(params, stream); + }); + }); + return out; +} - ConvParamsBase params; - set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, - bias_.has_value() ? 
bias_.value().data_ptr() : nullptr, - silu_activation); +template +struct Causal_conv1d_fwd_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static_assert(kWidth <= kNElts); + static constexpr bool kIsVecLoad = kIsVecLoad_; + static constexpr int kNLoadsIndex = kNElts / 4; + using vec_t = typename BytesToType::Type; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = + cub::BlockLoad; + using BlockLoadIndexT = + cub::BlockLoad; + using BlockLoadIndexVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = + cub::BlockStore; + + static constexpr int kSmemIOSize = + (kIsVecLoad && kNLoadsIndex == 1) + ? 0 + : std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockLoadIndexT::TempStorage), + sizeof(typename BlockLoadIndexVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; +}; - if (seq_idx_.has_value()) { - params.seq_idx_ptr = seq_idx_.value().data_ptr(); +template +__global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel( + ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = + reinterpret_cast(smem_); + auto& smem_load_vec = + reinterpret_cast(smem_); + auto& smem_load_index = + reinterpret_cast(smem_); + auto& smem_load_index_vec = + reinterpret_cast( + smem_); + auto& smem_store = + reinterpret_cast(smem_); + auto& smem_store_vec = + reinterpret_cast(smem_); + vec_t* smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y; + input_t* x = reinterpret_cast(params.x_ptr) + + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + weight_t* weight = reinterpret_cast(params.weight_ptr) + + channel_id * params.weight_c_stride; + input_t* out = reinterpret_cast(params.out_ptr) + + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = + params.bias_ptr == nullptr + ? 0.f + : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + int* seq_pos_idx = !kHasSeqPosIdx + ? nullptr + : reinterpret_cast(params.seq_pos_idx_ptr) + + batch_id * params.seqlen; + + // Thread 0 will load the last elements of the previous chunk, so we + // initialize those to 0. 
+ if (tidx == 0) { + input_t zeros[kNElts] = {0}; + smem_exchange[kNThreads - 1] = reinterpret_cast(zeros)[0]; + } + + float weight_vals[kWidth]; +#pragma unroll + for (int i = 0; i < kWidth; ++i) { + weight_vals[i] = float(weight[i * params.weight_width_stride]); + } + + constexpr int kChunkSize = kNThreads * kNElts; + const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; + for (int chunk = 0; chunk < n_chunks; ++chunk) { + input_t x_vals_load[2 * kNElts] = {0}; + int seq_pos_idx_load[kNElts]; + if constexpr (kIsVecLoad) { + Ktraits::BlockLoadVecT(smem_load_vec) + .Load(reinterpret_cast(x), + *reinterpret_cast(&x_vals_load[kNElts]), + (params.seqlen - chunk * kChunkSize) / kNElts); + if (kHasSeqPosIdx) + Ktraits::BlockLoadIndexVecT(smem_load_index_vec) + .Load(reinterpret_cast(seq_pos_idx), + *reinterpret_cast( + seq_pos_idx_load), + (params.seqlen - chunk * kChunkSize) / kNElts * + Ktraits::kNLoadsIndex); } else { - params.seq_idx_ptr = nullptr; + __syncthreads(); + Ktraits::BlockLoadT(smem_load).Load( + x, *reinterpret_cast(&x_vals_load[kNElts]), + params.seqlen - chunk * kChunkSize); + if (kHasSeqPosIdx) + Ktraits::BlockLoadIndexT(smem_load_index) + .Load(seq_pos_idx, seq_pos_idx_load, + (params.seqlen - chunk * kChunkSize), 0); } - - if (seq_pos_idx_.has_value()) { - params.seq_pos_idx_ptr = seq_pos_idx_.value().data_ptr(); - } else { - params.seq_pos_idx_ptr = nullptr; + x += kChunkSize; + if (kHasSeqPosIdx) seq_pos_idx += kChunkSize; + __syncthreads(); + // Thread kNThreads - 1 don't write yet, so that thread 0 can read + // the last elements of the previous chunk. + if (tidx < kNThreads - 1) { + smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } - if (initial_states_.has_value()) { - TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout"); - auto initial_states = initial_states_.value(); - TORCH_CHECK(initial_states.scalar_type() == input_type); - TORCH_CHECK(initial_states.is_cuda()); - CHECK_SHAPE(initial_states, batch_size, dim, width - 1); - TORCH_CHECK(initial_states.stride(1) == 1); - params.initial_states_ptr = initial_states.data_ptr(); - params.initial_states_batch_stride = initial_states.stride(0); - params.initial_states_c_stride = initial_states.stride(1); - params.initial_states_l_stride = initial_states.stride(2); - } else { - params.initial_states_ptr = nullptr; + __syncthreads(); + reinterpret_cast(x_vals_load)[0] = + smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; + __syncthreads(); + // Now thread kNThreads - 1 can write the last elements of the current + // chunk. 
+ if (tidx == kNThreads - 1) { + smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } - if (final_states_out_.has_value()) { - TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout"); - auto final_states = final_states_out_.value(); - TORCH_CHECK(final_states.scalar_type() == input_type); - TORCH_CHECK(final_states.is_cuda()); - CHECK_SHAPE(final_states, batch_size, dim, width - 1); - TORCH_CHECK(final_states.stride(1) == 1); - params.final_states_ptr = final_states.data_ptr(); - params.final_states_batch_stride = final_states.stride(0); - params.final_states_c_stride = final_states.stride(1); - params.final_states_l_stride = final_states.stride(2); - } else { - params.final_states_ptr = nullptr; + float x_vals[2 * kNElts]; +#pragma unroll + for (int i = 0; i < 2 * kNElts; ++i) { + x_vals[i] = float(x_vals_load[i]); } - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)x.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { - DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] { - if (!is_channel_last) { - causal_conv1d_fwd_cuda(params, stream); - } else { - causal_conv1d_channellast_fwd_cuda(params, stream); - } - }); - }); - return out; -} - + float out_vals[kNElts]; -at::Tensor -causal_conv1d_update(const at::Tensor &x, - const at::Tensor &conv_state, - const at::Tensor &weight, - const c10::optional &bias_, - bool silu_activation) { - auto input_type = x.scalar_type(); - auto weight_type = weight.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); - TORCH_CHECK(conv_state.scalar_type() == input_type); - - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(conv_state.is_cuda()); - TORCH_CHECK(weight.is_cuda()); - - const auto sizes = x.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int width = weight.size(-1); - - CHECK_SHAPE(x, batch_size, dim); - CHECK_SHAPE(conv_state, batch_size, dim, width); - CHECK_SHAPE(weight, dim, width); - - TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); - - if (bias_.has_value()) { - auto bias = bias_.value(); - TORCH_CHECK(bias.scalar_type() == weight_type); - TORCH_CHECK(bias.is_cuda()); - TORCH_CHECK(bias.stride(-1) == 1); - CHECK_SHAPE(bias, dim); +#pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = bias_val; + int w = 0; + if (kHasSeqPosIdx) { + if (seq_pos_idx_load[i] < kWidth) { + w = kWidth - seq_pos_idx_load[i] - 1; + } + } + for (; w < kWidth; ++w) { + out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; + } } - at::Tensor out = torch::empty_like(x); - - ConvParamsBase params; - set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out, - bias_.has_value() ? bias_.value().data_ptr() : nullptr, - silu_activation); - params.conv_state_ptr = conv_state.data_ptr(); - // All stride are in elements, not bytes. 
- params.conv_state_batch_stride = conv_state.stride(0); - params.conv_state_c_stride = conv_state.stride(1); - params.conv_state_l_stride = conv_state.stride(2); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)x.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { - DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] { - causal_conv1d_update_cuda(params, stream); - }); - }); - return out; -} - -template -struct Causal_conv1d_fwd_kernel_traits { - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kWidth = kWidth_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : 8; - static_assert(kWidth <= kNElts); - static constexpr bool kIsVecLoad = kIsVecLoad_; - static constexpr int kNLoadsIndex = kNElts / 4; - using vec_t = typename BytesToType::Type; - using BlockLoadT = cub::BlockLoad; - using BlockLoadVecT = cub::BlockLoad; - using BlockLoadIndexT = cub::BlockLoad; - using BlockLoadIndexVecT = cub::BlockLoad; - using BlockStoreT = cub::BlockStore; - using BlockStoreVecT = cub::BlockStore; - - static constexpr int kSmemIOSize = (kIsVecLoad && kNLoadsIndex == 1) - ? 0 - : std::max({sizeof(typename BlockLoadT::TempStorage), - sizeof(typename BlockStoreT::TempStorage), - sizeof(typename BlockLoadIndexT::TempStorage), - sizeof(typename BlockLoadIndexVecT::TempStorage)}); - static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; - static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_fwd_kernel(ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNElts = Ktraits::kNElts; - static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; - using input_t = typename Ktraits::input_t; - using vec_t = typename Ktraits::vec_t; - using weight_t = typename Ktraits::weight_t; - - // Shared memory. - extern __shared__ char smem_[]; - auto& smem_load = reinterpret_cast(smem_); - auto& smem_load_vec = reinterpret_cast(smem_); - auto& smem_load_index = reinterpret_cast(smem_); - auto& smem_load_index_vec = reinterpret_cast(smem_); - auto& smem_store = reinterpret_cast(smem_); - auto& smem_store_vec = reinterpret_cast(smem_); - vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); - - const int tidx = threadIdx.x; - const int batch_id = blockIdx.x; - const int channel_id = blockIdx.y; - input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride - + channel_id * params.x_c_stride; - weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + channel_id * params.out_c_stride; - float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); - - int *seq_pos_idx = !kHasSeqPosIdx ? nullptr : reinterpret_cast(params.seq_pos_idx_ptr) + batch_id * params.seqlen; - - // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. 
- if (tidx == 0) { - input_t zeros[kNElts] = {0}; - smem_exchange[kNThreads - 1] = reinterpret_cast(zeros)[0]; + if (params.silu_activation) { +#pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); + } } - float weight_vals[kWidth]; - #pragma unroll - for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } - - constexpr int kChunkSize = kNThreads * kNElts; - const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; - for (int chunk = 0; chunk < n_chunks; ++chunk) { - input_t x_vals_load[2 * kNElts] = {0}; - int seq_pos_idx_load[kNElts]; - if constexpr(kIsVecLoad) { - Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); - if (kHasSeqPosIdx) - Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(reinterpret_cast(seq_pos_idx), *reinterpret_cast(seq_pos_idx_load), (params.seqlen - chunk * kChunkSize) / kNElts * Ktraits::kNLoadsIndex); - } else { - __syncthreads(); - Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); - if (kHasSeqPosIdx) - Ktraits::BlockLoadIndexT(smem_load_index).Load(seq_pos_idx, seq_pos_idx_load, (params.seqlen - chunk * kChunkSize), 0); - } - x += kChunkSize; - if (kHasSeqPosIdx) seq_pos_idx += kChunkSize; - __syncthreads(); - // Thread kNThreads - 1 don't write yet, so that thread 0 can read - // the last elements of the previous chunk. - if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } - __syncthreads(); - reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; - __syncthreads(); - // Now thread kNThreads - 1 can write the last elements of the current chunk. 
- if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } - - float x_vals[2 * kNElts]; - #pragma unroll - for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } - - float out_vals[kNElts]; - - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - out_vals[i] = bias_val; - int w = 0; - if (kHasSeqPosIdx){ - if(seq_pos_idx_load[i] < kWidth){ - w = kWidth - seq_pos_idx_load[i] - 1; - } - } - for (; w < kWidth; ++w) { - out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; - } - } - - if (params.silu_activation) { - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); - } - } - - input_t out_vals_store[kNElts]; - #pragma unroll - for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } - if constexpr(kIsVecLoad) { - Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); - } else { - Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); - } - out += kChunkSize; + input_t out_vals_store[kNElts]; +#pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals_store[i] = out_vals[i]; + } + if constexpr (kIsVecLoad) { + Ktraits::BlockStoreVecT(smem_store_vec) + .Store(reinterpret_cast(out), + reinterpret_cast(out_vals_store), + (params.seqlen - chunk * kChunkSize) / kNElts); + } else { + Ktraits::BlockStoreT(smem_store) + .Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); } + out += kChunkSize; + } } -template -void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { - static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; - BOOL_SWITCH(params.seq_pos_idx_ptr != nullptr, kHasSeqPosIdx, [&] { - BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { - using Ktraits = Causal_conv1d_fwd_kernel_traits; - constexpr int kSmemSize = Ktraits::kSmemSize; - dim3 grid(params.batch, params.dim); - auto kernel = &causal_conv1d_fwd_kernel; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); +template +void causal_conv1d_fwd_launch(ConvParamsBase& params, cudaStream_t stream) { + static constexpr int kNElts = sizeof(input_t) == 4 ? 
4 : 8; + BOOL_SWITCH(params.seq_pos_idx_ptr != nullptr, kHasSeqPosIdx, [&] { + BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { + using Ktraits = + Causal_conv1d_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize; + dim3 grid(params.batch, params.dim); + auto kernel = &causal_conv1d_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); + }); } -template -void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); - } +template +void causal_conv1d_fwd_cuda(ConvParamsBase& params, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); + } } -template +template struct Causal_conv1d_channellast_fwd_kernel_traits { - // The cache line is 128 bytes, and we try to read 16 bytes per thread. - // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension. - // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128 - // threads). Each each load is 16 x 32|64 elements in the L x C dimensions. - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static_assert(kNThreads % 32 == 0); - static constexpr int kNWarps = kNThreads / 32; - static constexpr int kWidth = kWidth_; - static constexpr int kChunkSizeL = kChunkSizeL_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : 8; - static constexpr int kNEltsPerRow = 128 / kNBytes; - static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now - static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); - static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now - static_assert(kNColsPerWarp * kNThreadsPerRow == 32); - static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; - static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; - static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); - static constexpr bool kIsVecLoad = kIsVecLoad_; - using vec_t = typename BytesToType::Type; - // using BlockLoadT = cub::BlockLoad; - // using BlockStoreT = cub::BlockStore; - // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), - // sizeof(typename BlockStoreT::TempStorage)}); - // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; + // The cache line is 128 bytes, and we try to read 16 bytes per thread. + // So we have 8 threads per "row", so 32 or 64 elements in the channel + // dimension. That leaves 4 columns per warp, and so 16 columns per block + // (assuming each block has 128 threads). Each each load is 16 x 32|64 + // elements in the L x C dimensions. 
+ using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static_assert(kNThreads % 32 == 0); + static constexpr int kNWarps = kNThreads / 32; + static constexpr int kWidth = kWidth_; + static constexpr int kChunkSizeL = kChunkSizeL_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static constexpr int kNEltsPerRow = 128 / kNBytes; + static constexpr int kNThreadsPerRow = + kNEltsPerRow / kNElts; // Always 8 for now + static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); + static constexpr int kNColsPerWarp = + 32 / kNThreadsPerRow; // Always 4 for now + static_assert(kNColsPerWarp * kNThreadsPerRow == 32); + static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; + static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; + static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); + static constexpr bool kIsVecLoad = kIsVecLoad_; + using vec_t = typename BytesToType::Type; + // using BlockLoadT = cub::BlockLoad; using BlockStoreT = + // cub::BlockStore; static constexpr int kSmemSize = + // std::max({sizeof(typename BlockLoadT::TempStorage), + // sizeof(typename + // BlockStoreT::TempStorage)}); + // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; }; -template -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNElts = Ktraits::kNElts; - constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; - constexpr int kLPerLoad = Ktraits::kNColsPerLoad; - constexpr int kChunkSizeL = Ktraits::kChunkSizeL; - constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; - using input_t = typename Ktraits::input_t; - using vec_t = typename Ktraits::vec_t; - using weight_t = typename Ktraits::weight_t; - - // Shared memory. - __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts]; - - const int batch_id = blockIdx.x; - const int chunk_l_id = blockIdx.y; - const int chunk_c_id = blockIdx.z; - const int tid = threadIdx.x; - const int l_idx = tid / kNThreadsPerC; - const int c_idx = tid % kNThreadsPerC; - input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride - + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - weight_t *weight = reinterpret_cast(params.weight_ptr) - + chunk_c_id * kChunkSizeC * params.weight_c_stride; - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast(params.seq_idx_ptr) - + batch_id * params.seqlen + chunk_l_id * kChunkSizeL; - input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr - : reinterpret_cast(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - // The last L-chunk will also have enough info to write to final states, since it also contain a few x values - // from the previous L-chunk. - input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? 
nullptr - : reinterpret_cast(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - - #pragma unroll - for (int l = 0; l < Ktraits::kNLoads; ++l) { - input_t x_vals_load[kNElts] = {0}; - if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + l * kLPerLoad * params.x_l_stride); - } - reinterpret_cast(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; - } - // Load the elements from the previous chunk that are needed for convolution. - if (l_idx < kWidth - 1) { - input_t x_vals_load[kNElts] = {0}; - if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 - && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x - (kWidth - 1) * params.x_l_stride); - } else if (initial_states != nullptr - && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0 - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(initial_states); - } - reinterpret_cast(x_smem[l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; +template +__global__ +__launch_bounds__(Ktraits::kNThreads) void causal_conv1d_channellast_fwd_kernel( + ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; + constexpr int kLPerLoad = Ktraits::kNColsPerLoad; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. + __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts]; + + const int batch_id = blockIdx.x; + const int chunk_l_id = blockIdx.y; + const int chunk_c_id = blockIdx.z; + const int tid = threadIdx.x; + const int l_idx = tid / kNThreadsPerC; + const int c_idx = tid % kNThreadsPerC; + input_t* x = reinterpret_cast(params.x_ptr) + + batch_id * params.x_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + + chunk_c_id * kChunkSizeC + c_idx * kNElts; + weight_t* weight = reinterpret_cast(params.weight_ptr) + + chunk_c_id * kChunkSizeC * params.weight_c_stride; + input_t* out = reinterpret_cast(params.out_ptr) + + batch_id * params.out_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + + chunk_c_id * kChunkSizeC + c_idx * kNElts; + int* seq_idx = !kHasSeqIdx + ? nullptr + : reinterpret_cast(params.seq_idx_ptr) + + batch_id * params.seqlen + chunk_l_id * kChunkSizeL; + input_t* initial_states = + params.initial_states_ptr == nullptr || chunk_l_id > 0 + ? nullptr + : reinterpret_cast(params.initial_states_ptr) + + batch_id * params.initial_states_batch_stride + + l_idx * params.initial_states_l_stride + + chunk_c_id * kChunkSizeC + c_idx * kNElts; + // The last L-chunk will also have enough info to write to final states, since + // it also contain a few x values from the previous L-chunk. + input_t* final_states = + params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 + ? 
nullptr + : reinterpret_cast(params.final_states_ptr) + + batch_id * params.final_states_batch_stride + + l_idx * params.final_states_l_stride + + chunk_c_id * kChunkSizeC + c_idx * kNElts; + +#pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t x_vals_load[kNElts] = {0}; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen && + chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = + *reinterpret_cast(x + l * kLPerLoad * params.x_l_stride); } - - __syncthreads(); - - if (final_states != nullptr - && l_idx < kWidth - 1 - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1) - // So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx] - *reinterpret_cast(final_states) = reinterpret_cast(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx]; + reinterpret_cast( + x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = + reinterpret_cast(x_vals_load)[0]; + } + // Load the elements from the previous chunk that are needed for convolution. + if (l_idx < kWidth - 1) { + input_t x_vals_load[kNElts] = {0}; + if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 && + chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen && + chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = + *reinterpret_cast(x - (kWidth - 1) * params.x_l_stride); + } else if (initial_states != nullptr && + chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0 && + chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = + *reinterpret_cast(initial_states); } - - constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); - static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); - constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; - static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); - // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity - static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); - static_assert((kLPerThread & (kLPerThread - 1)) == 0); - static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); - static_assert(kNThreadsPerRow <= 32); - - const int row_idx = tid / kNThreadsPerRow; - const int col_idx = tid % kNThreadsPerRow; - - float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 
0.f : float(reinterpret_cast(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); - float weight_vals[kWidth] = {0}; - if (chunk_c_id * kChunkSizeC + row_idx < params.dim) { - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; - } + reinterpret_cast(x_smem[l_idx])[c_idx] = + reinterpret_cast(x_vals_load)[0]; + } + + __syncthreads(); + + if (final_states != nullptr && l_idx < kWidth - 1 && + chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - + // 1) So last few elements (index params.seqlen - kWidth + 1 + l_idx) are + // stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * + // kChunkSizeL - kWidth + 1)][c_idx] + *reinterpret_cast(final_states) = reinterpret_cast( + x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx]; + } + + constexpr int kLPerThread = + std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); + static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); + constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; + static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); + // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for + // simplicity + static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); + static_assert((kLPerThread & (kLPerThread - 1)) == 0); + static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); + static_assert(kNThreadsPerRow <= 32); + + const int row_idx = tid / kNThreadsPerRow; + const int col_idx = tid % kNThreadsPerRow; + + float bias_val = + params.bias_ptr == nullptr || + chunk_c_id * kChunkSizeC + row_idx >= params.dim + ? 0.f + : float(reinterpret_cast( + params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); + float weight_vals[kWidth] = {0}; + if (chunk_c_id * kChunkSizeC + row_idx < params.dim) { +#pragma unroll + for (int w = 0; w < kWidth; ++w) { + weight_vals[w] = weight[row_idx * params.weight_c_stride + + w * params.weight_width_stride]; } - float x_vals[kWidth - 1 + kLPerThread]; - #pragma unroll + } + float x_vals[kWidth - 1 + kLPerThread]; +#pragma unroll + for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { + x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); + } + int seq_idx_thread[kWidth - 1 + kLPerThread]; + if constexpr (kHasSeqIdx) { +#pragma unroll for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { - x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); + seq_idx_thread[i] = + chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= + 0 + ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] + : -1; } - int seq_idx_thread[kWidth - 1 + kLPerThread]; - if constexpr (kHasSeqIdx) { - #pragma unroll - for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { - seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1; - } + } + + float out_vals[kLPerThread]; +#pragma unroll + for (int i = 0; i < kLPerThread; ++i) { + out_vals[i] = bias_val; + const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; +#pragma unroll + for (int w = 0; w < kWidth; ++w) { + if constexpr (!kHasSeqIdx) { + out_vals[i] += weight_vals[w] * x_vals[i + w]; + } else { + out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur + ? 
weight_vals[w] * x_vals[i + w] + : 0.f; + } } - - float out_vals[kLPerThread]; - #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { - out_vals[i] = bias_val; - const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - if constexpr (!kHasSeqIdx) { - out_vals[i] += weight_vals[w] * x_vals[i + w]; - } else { - out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f; - } - } - if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); } + if (params.silu_activation) { + out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); } - - __syncthreads(); - #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; } - __syncthreads(); - - #pragma unroll - for (int l = 0; l < Ktraits::kNLoads; ++l) { - input_t out_vals_store[kNElts]; - reinterpret_cast(out_vals_store)[0] = reinterpret_cast(x_smem[l * kLPerLoad + l_idx])[c_idx]; - if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - *reinterpret_cast(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast(out_vals_store)[0]; - } + } + + __syncthreads(); +#pragma unroll + for (int i = 0; i < kLPerThread; ++i) { + x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; + } + __syncthreads(); + +#pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t out_vals_store[kNElts]; + reinterpret_cast(out_vals_store)[0] = + reinterpret_cast(x_smem[l * kLPerLoad + l_idx])[c_idx]; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen && + chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + *reinterpret_cast(out + l * kLPerLoad * params.out_l_stride) = + reinterpret_cast(out_vals_store)[0]; } - + } } -template -void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { - BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] { - using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits; - // constexpr int kSmemSize = Ktraits::kSmemSize; - constexpr int kChunkSizeL = Ktraits::kChunkSizeL; - constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; - const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; - const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; - dim3 grid(params.batch, n_chunks_L, n_chunks_C); - dim3 block(Ktraits::kNThreads); - auto kernel = &causal_conv1d_channellast_fwd_kernel; - // if (kSmemSize >= 48 * 1024) { - // C10_CUDA_CHECK(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - // } - // kernel<<>>(params); - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); +template +void causal_conv1d_channellast_fwd_launch(ConvParamsBase& params, + cudaStream_t stream) { + BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] { + using Ktraits = + Causal_conv1d_channellast_fwd_kernel_traits; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; + const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; + dim3 grid(params.batch, n_chunks_L, n_chunks_C); + dim3 block(Ktraits::kNThreads); + auto kernel = &causal_conv1d_channellast_fwd_kernel; + // if (kSmemSize >= 48 * 1024) { + // C10_CUDA_CHECK(cudaFuncSetAttribute( + // kernel, 
cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + // } + // kernel<<>>(params); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); } -template -void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream); - } +template +void causal_conv1d_channellast_fwd_cuda(ConvParamsBase& params, + cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, + stream); + } else if (params.width == 3) { + causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, + stream); + } else if (params.width == 4) { + causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, + stream); + } } -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); - -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase& params, + cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase& params, + cudaStream_t stream); +template void causal_conv1d_fwd_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase& params, + cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase& params, + cudaStream_t stream); +template void causal_conv1d_fwd_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_fwd_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_fwd_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_fwd_cuda( + ConvParamsBase& params, cudaStream_t stream); + +template void causal_conv1d_channellast_fwd_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda( + ConvParamsBase& 
params, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda( + ConvParamsBase& params, cudaStream_t stream); /////// - - - -template +template struct Causal_conv1d_update_kernel_traits { - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kWidth = kWidth_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); }; -template -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_update_kernel(ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - - const int tidx = threadIdx.x; - const int batch_id = blockIdx.x; - const int channel_id = blockIdx.y * kNThreads + tidx; - input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride - + channel_id * params.x_c_stride; - input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride - + channel_id * params.conv_state_c_stride; - weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + channel_id * params.out_c_stride; - float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); - - float weight_vals[kWidth] = {0}; - if (channel_id < params.dim) { - #pragma unroll - for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } +template +__global__ +__launch_bounds__(Ktraits::kNThreads) void causal_conv1d_update_kernel( + ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y * kNThreads + tidx; + input_t* x = reinterpret_cast(params.x_ptr) + + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + input_t* conv_state = reinterpret_cast(params.conv_state_ptr) + + batch_id * params.conv_state_batch_stride + + channel_id * params.conv_state_c_stride; + weight_t* weight = reinterpret_cast(params.weight_ptr) + + channel_id * params.weight_c_stride; + input_t* out = reinterpret_cast(params.out_ptr) + + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = + params.bias_ptr == nullptr || channel_id >= params.dim + ? 
0.f + : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + float weight_vals[kWidth] = {0}; + if (channel_id < params.dim) { +#pragma unroll + for (int i = 0; i < kWidth; ++i) { + weight_vals[i] = float(weight[i * params.weight_width_stride]); } + } - float x_vals[kWidth] = {0}; - if (channel_id < params.dim) { - #pragma unroll - for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } - x_vals[kWidth - 1] = float(x[0]); - #pragma unroll - for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); } + float x_vals[kWidth] = {0}; + if (channel_id < params.dim) { +#pragma unroll + for (int i = 0; i < kWidth - 1; ++i) { + x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } - - float out_val = bias_val; - #pragma unroll - for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; } - if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } - if (channel_id < params.dim) { out[0] = input_t(out_val); } + x_vals[kWidth - 1] = float(x[0]); +#pragma unroll + for (int i = 0; i < kWidth; ++i) { + conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); + } + } + + float out_val = bias_val; +#pragma unroll + for (int i = 0; i < kWidth; ++i) { + out_val += weight_vals[i] * x_vals[i]; + } + if (params.silu_activation) { + out_val = out_val / (1 + expf(-out_val)); + } + if (channel_id < params.dim) { + out[0] = input_t(out_val); + } } -template -void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { - using Ktraits = Causal_conv1d_update_kernel_traits; - dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); - auto kernel = &causal_conv1d_update_kernel; - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); +template +void causal_conv1d_update_launch(ConvParamsBase& params, cudaStream_t stream) { + using Ktraits = + Causal_conv1d_update_kernel_traits; + dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); + auto kernel = &causal_conv1d_update_kernel; + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } -template -void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); - } +template +void causal_conv1d_update_cuda(ConvParamsBase& params, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); + } } -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void 
causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase& params, + cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase& params, + cudaStream_t stream); +template void causal_conv1d_update_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase& params, + cudaStream_t stream); +template void causal_conv1d_update_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_update_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_update_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_update_cuda( + ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_update_cuda( + ConvParamsBase& params, cudaStream_t stream); diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h index 4e05744a8bbd..76a634ba70ec 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.h +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -9,98 +9,103 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// struct ConvParamsBase { - using index_t = uint32_t; - - int batch, dim, seqlen, width; - bool silu_activation; - - index_t x_batch_stride; - index_t x_c_stride; - index_t x_l_stride; - index_t weight_c_stride; - index_t weight_width_stride; - index_t out_batch_stride; - index_t out_c_stride; - index_t out_l_stride; - - index_t conv_state_batch_stride; - index_t conv_state_c_stride; - index_t conv_state_l_stride; - - // Common data pointers. - void *__restrict__ x_ptr; - void *__restrict__ weight_ptr; - void *__restrict__ bias_ptr; - void *__restrict__ out_ptr; - - void *__restrict__ conv_state_ptr; - - void *__restrict__ seq_idx_ptr; - void *__restrict__ seq_pos_idx_ptr; - - // No __restrict__ since initial_states could be the same as final_states. - void * initial_states_ptr; - index_t initial_states_batch_stride; - index_t initial_states_l_stride; - index_t initial_states_c_stride; - - void * final_states_ptr; - index_t final_states_batch_stride; - index_t final_states_l_stride; - index_t final_states_c_stride; + using index_t = uint32_t; + + int batch, dim, seqlen, width; + bool silu_activation; + + index_t x_batch_stride; + index_t x_c_stride; + index_t x_l_stride; + index_t weight_c_stride; + index_t weight_width_stride; + index_t out_batch_stride; + index_t out_c_stride; + index_t out_l_stride; + + index_t conv_state_batch_stride; + index_t conv_state_c_stride; + index_t conv_state_l_stride; + + // Common data pointers. + void* __restrict__ x_ptr; + void* __restrict__ weight_ptr; + void* __restrict__ bias_ptr; + void* __restrict__ out_ptr; + + void* __restrict__ conv_state_ptr; + + void* __restrict__ seq_idx_ptr; + void* __restrict__ seq_pos_idx_ptr; + + // No __restrict__ since initial_states could be the same as final_states. 
+ void* initial_states_ptr; + index_t initial_states_batch_stride; + index_t initial_states_l_stride; + index_t initial_states_c_stride; + + void* final_states_ptr; + index_t final_states_batch_stride; + index_t final_states_l_stride; + index_t final_states_c_stride; }; +template +struct BytesToType {}; -template struct BytesToType {}; - -template<> struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); +template <> +struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); }; -template<> struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); +template <> +struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); }; -template<> struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); +template <> +struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); }; -template<> struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); +template <> +struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); }; -template<> struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); +template <> +struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct SumOp { -__device__ inline T operator()(T const & x, T const & y) { return x + y; } + __device__ inline T operator()(T const& x, T const& y) { return x + y; } }; -template +template struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); - template - static __device__ inline T run(T x, Operator &op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); - return Allreduce::run(x, op); - } + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } }; -template<> +template <> struct Allreduce<2> { -template -static __device__ inline T run(T x, Operator &op) { + template + static __device__ inline T run(T x, Operator& op) { x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); return x; -} + } }; diff --git a/csrc/mamba/causal_conv1d/static_switch.h b/csrc/mamba/causal_conv1d/static_switch.h index 0f4ad3eb6223..11c876842395 100644 --- a/csrc/mamba/causal_conv1d/static_switch.h +++ b/csrc/mamba/causal_conv1d/static_switch.h @@ -1,4 +1,5 @@ -// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h #pragma once @@ -13,13 +14,13 @@ /// some_function(...); /// }); /// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - static constexpr bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - static constexpr bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() +#define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ + [&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index 69d72bf255e9..dbc4e6dac112 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -5,270 +5,292 @@ #pragma once #ifndef USE_ROCM - #include + #include #else - #include + #include #endif #include //////////////////////////////////////////////////////////////////////////////////////////////////// struct SSMParamsBase { - using index_t = uint32_t; - - int batch, dim, seqlen, dstate, n_groups, n_chunks; - int dim_ngroups_ratio; - bool is_variable_B; - bool is_variable_C; - - bool delta_softplus; - - index_t A_d_stride; - index_t A_dstate_stride; - index_t B_batch_stride; - index_t B_d_stride; - index_t B_dstate_stride; - index_t B_group_stride; - index_t C_batch_stride; - index_t C_d_stride; - index_t C_dstate_stride; - index_t C_group_stride; - index_t u_batch_stride; - index_t u_d_stride; - index_t delta_batch_stride; - index_t delta_d_stride; - index_t z_batch_stride; - index_t z_d_stride; - index_t out_batch_stride; - index_t out_d_stride; - index_t out_z_batch_stride; - index_t out_z_d_stride; - - // Common data pointers. - void *__restrict__ A_ptr; - void *__restrict__ B_ptr; - void *__restrict__ C_ptr; - void *__restrict__ D_ptr; - void *__restrict__ u_ptr; - void *__restrict__ delta_ptr; - void *__restrict__ delta_bias_ptr; - void *__restrict__ out_ptr; - void *__restrict__ x_ptr; - void *__restrict__ z_ptr; - void *__restrict__ out_z_ptr; - void *__restrict__ index_ptr; + using index_t = uint32_t; + + int batch, dim, seqlen, dstate, n_groups, n_chunks; + int dim_ngroups_ratio; + bool is_variable_B; + bool is_variable_C; + + bool delta_softplus; + + index_t A_d_stride; + index_t A_dstate_stride; + index_t B_batch_stride; + index_t B_d_stride; + index_t B_dstate_stride; + index_t B_group_stride; + index_t C_batch_stride; + index_t C_d_stride; + index_t C_dstate_stride; + index_t C_group_stride; + index_t u_batch_stride; + index_t u_d_stride; + index_t delta_batch_stride; + index_t delta_d_stride; + index_t z_batch_stride; + index_t z_d_stride; + index_t out_batch_stride; + index_t out_d_stride; + index_t out_z_batch_stride; + index_t out_z_d_stride; + + // Common data pointers. + void* __restrict__ A_ptr; + void* __restrict__ B_ptr; + void* __restrict__ C_ptr; + void* __restrict__ D_ptr; + void* __restrict__ u_ptr; + void* __restrict__ delta_ptr; + void* __restrict__ delta_bias_ptr; + void* __restrict__ out_ptr; + void* __restrict__ x_ptr; + void* __restrict__ z_ptr; + void* __restrict__ out_z_ptr; + void* __restrict__ index_ptr; }; - - - #ifndef USE_ROCM - constexpr size_t custom_max(std::initializer_list ilist) - { - return std::max(ilist); - } +constexpr size_t custom_max(std::initializer_list ilist) { + return std::max(ilist); +} - template - constexpr T constexpr_min(T a, T b) { - return std::min(a, b); - } +template +constexpr T constexpr_min(T a, T b) { + return std::min(a, b); +} #else - constexpr size_t custom_max(std::initializer_list ilist) - { - return *std::max_element(ilist.begin(), ilist.end()); - } +constexpr size_t custom_max(std::initializer_list ilist) { + return *std::max_element(ilist.begin(), ilist.end()); +} - template - constexpr T constexpr_min(T a, T b) { - return a < b ? 
a : b; - } +template +constexpr T constexpr_min(T a, T b) { + return a < b ? a : b; +} #endif - #define MAX_DSTATE 256 - -inline __device__ float2 operator+(const float2 & a, const float2 & b){ - return {a.x + b.x, a.y + b.y}; +inline __device__ float2 operator+(const float2& a, const float2& b) { + return {a.x + b.x, a.y + b.y}; } -inline __device__ float3 operator+(const float3 &a, const float3 &b) { +inline __device__ float3 operator+(const float3& a, const float3& b) { return {a.x + b.x, a.y + b.y, a.z + b.z}; } -inline __device__ float4 operator+(const float4 & a, const float4 & b){ - return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; +inline __device__ float4 operator+(const float4& a, const float4& b) { + return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; } //////////////////////////////////////////////////////////////////////////////////////////////////// -template struct BytesToType {}; +template +struct BytesToType {}; -template<> struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); +template <> +struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); }; -template<> struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); +template <> +struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); }; -template<> struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); +template <> +struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); }; -template<> struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); +template <> +struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); }; -template<> struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); +template <> +struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Converter{ - static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { dst[i] = src[i]; } +template +struct Converter { + static inline __device__ void to_float(const scalar_t (&src)[N], + float (&dst)[N]) { +#pragma unroll + for (int i = 0; i < N; ++i) { + dst[i] = src[i]; } + } }; -template -struct Converter{ - static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { - static_assert(N % 2 == 0); - auto &src2 = reinterpret_cast(src); - auto &dst2 = reinterpret_cast(dst); - #pragma unroll - for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } +template +struct Converter { + static inline __device__ void to_float(const at::Half (&src)[N], + float (&dst)[N]) { + static_assert(N % 2 == 0); + auto& src2 = reinterpret_cast(src); + auto& dst2 = reinterpret_cast(dst); +#pragma unroll + for (int i = 0; i < N / 2; ++i) { + dst2[i] = __half22float2(src2[i]); } + } }; #if __CUDA_ARCH__ >= 800 -template -struct Converter{ - static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { - static_assert(N % 2 == 0); - auto &src2 = reinterpret_cast(src); - auto &dst2 = reinterpret_cast(dst); - #pragma unroll - for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } +template +struct Converter { + static inline __device__ void to_float(const at::BFloat16 (&src)[N], + float (&dst)[N]) { + static_assert(N % 2 == 0); + auto& src2 
= reinterpret_cast(src); + auto& dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { + dst2[i] = __bfloat1622float2(src2[i]); } + } }; #endif //////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct SSMScanOp; -template struct SSMScanOp; - -template<> +template <> struct SSMScanOp { - __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { - return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); - } + __device__ __forceinline__ float2 operator()(const float2& ab0, + const float2& ab1) const { + return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); + } }; // A stateful callback functor that maintains a running prefix to be applied // during consecutive scan operations. -template struct SSMScanPrefixCallbackOp { - using scan_t = std::conditional_t, float2, float4>; - scan_t running_prefix; - // Constructor - __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} - // Callback operator to be entered by the first warp of threads in the block. - // Thread-0 is responsible for returning a value for seeding the block-wide scan. - __device__ scan_t operator()(scan_t block_aggregate) { - scan_t old_prefix = running_prefix; - running_prefix = SSMScanOp()(running_prefix, block_aggregate); - return old_prefix; - } +template +struct SSMScanPrefixCallbackOp { + using scan_t = + std::conditional_t, float2, float4>; + scan_t running_prefix; + // Constructor + __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) + : running_prefix(running_prefix_) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide + // scan. 
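+  // Combining the running prefix (a, b) with a block aggregate (a', b') via
+  // SSMScanOp yields (a' * a, a' * b + b'), i.e. the linear recurrence
+  // x_t = a_t * x_{t-1} + b_t folded across consecutive chunk scans.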
+ __device__ scan_t operator()(scan_t block_aggregate) { + scan_t old_prefix = running_prefix; + running_prefix = SSMScanOp()(running_prefix, block_aggregate); + return old_prefix; + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void load_input(typename Ktraits::input_t *u, - typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], - typename Ktraits::BlockLoadT::TempStorage &smem_load, - int seqlen) { - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_vec = reinterpret_cast(smem_load); - using vec_t = typename Ktraits::vec_t; - typename Ktraits::BlockLoadVecT(smem_load_vec).Load( - reinterpret_cast(u), - reinterpret_cast(u_vals) - #ifdef USE_ROCM - , Ktraits::kNThreads * Ktraits::kNLoads - #endif - - ); - } else { - typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); - } +template +inline __device__ void load_input( + typename Ktraits::input_t* u, + typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadT::TempStorage& smem_load, int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_vec = + reinterpret_cast( + smem_load); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadVecT(smem_load_vec) + .Load(reinterpret_cast(u), + reinterpret_cast(u_vals) +#ifdef USE_ROCM + , + Ktraits::kNThreads * Ktraits::kNLoads +#endif + + ); + } else { + typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); + } } -template -inline __device__ void load_index(int *u, - int (&u_vals)[Ktraits::kNItems], - typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index, - int seqlen) { - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_index_vec = reinterpret_cast(smem_load_index); - Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load( - reinterpret_cast(u), - reinterpret_cast(u_vals) - ); - } else { - Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0); - } +template +inline __device__ void load_index( + int* u, int (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadIndexT::TempStorage& smem_load_index, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_index_vec = + reinterpret_cast( + smem_load_index); + Ktraits::BlockLoadIndexVecT(smem_load_index_vec) + .Load(reinterpret_cast(u), + reinterpret_cast(u_vals)); + } else { + Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0); + } } -template -inline __device__ void load_weight(typename Ktraits::input_t *Bvar, - typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], - typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, - int seqlen) { - constexpr int kNItems = Ktraits::kNItems; - typename Ktraits::input_t B_vals_load[kNItems]; - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); - using vec_t = typename Ktraits::vec_t; - typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( - reinterpret_cast(Bvar), - reinterpret_cast(B_vals_load) - ); - } else { - typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); - } - // #pragma unroll - // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } - Converter::to_float(B_vals_load, B_vals); +template +inline __device__ void load_weight( + typename Ktraits::input_t* Bvar, + typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadWeightT::TempStorage& smem_load_weight, + int seqlen) { + constexpr int kNItems = 
Ktraits::kNItems; + typename Ktraits::input_t B_vals_load[kNItems]; + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_weight_vec = + reinterpret_cast( + smem_load_weight); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec) + .Load(reinterpret_cast(Bvar), + reinterpret_cast(B_vals_load)); + } else { + typename Ktraits::BlockLoadWeightT(smem_load_weight) + .Load(Bvar, B_vals_load, seqlen, 0.f); + } + // #pragma unroll + // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } + Converter::to_float(B_vals_load, B_vals); } -template -inline __device__ void store_output(typename Ktraits::input_t *out, - const float (&out_vals)[Ktraits::kNItems], - typename Ktraits::BlockStoreT::TempStorage &smem_store, - int seqlen) { - typename Ktraits::input_t write_vals[Ktraits::kNItems]; - #pragma unroll - for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_store_vec = reinterpret_cast(smem_store); - using vec_t = typename Ktraits::vec_t; - typename Ktraits::BlockStoreVecT(smem_store_vec).Store( - reinterpret_cast(out), - reinterpret_cast(write_vals) - ); - } else { - typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); - } +template +inline __device__ void store_output( + typename Ktraits::input_t* out, const float (&out_vals)[Ktraits::kNItems], + typename Ktraits::BlockStoreT::TempStorage& smem_store, int seqlen) { + typename Ktraits::input_t write_vals[Ktraits::kNItems]; +#pragma unroll + for (int i = 0; i < Ktraits::kNItems; ++i) { + write_vals[i] = out_vals[i]; + } + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_store_vec = + reinterpret_cast( + smem_store); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockStoreVecT(smem_store_vec) + .Store(reinterpret_cast(out), + reinterpret_cast(write_vals)); + } else { + typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); + } } diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index b15a1b10f4c9..cf5d1311ea2f 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -8,615 +8,710 @@ #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK #ifndef USE_ROCM - #include - #include - #include + #include + #include + #include #else - #include - namespace cub = hipcub; + #include +namespace cub = hipcub; #endif #include "selective_scan.h" #include "static_switch.h" -template +template struct Selective_Scan_fwd_kernel_traits { - static_assert(kNItems_ % 4 == 0); - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. - static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; - static constexpr int kNItems = kNItems_; - static constexpr int kNRows = kNRows_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 
4 : constexpr_min(8, kNItems); - static_assert(kNItems % kNElts == 0); - static constexpr int kNLoads = kNItems / kNElts; - static constexpr bool kIsEvenLen = kIsEvenLen_; - static constexpr bool kIsVariableB = kIsVariableB_; - static constexpr bool kIsVariableC = kIsVariableC_; - static constexpr bool kHasZ = kHasZ_; - static constexpr bool kUseIndex = kUseIndex_; - - static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; - static constexpr int kNLoadsIndex = kNItems / 4; - using vec_t = typename BytesToType::Type; - using scan_t = float2; - using BlockLoadT = cub::BlockLoad; - using BlockLoadVecT = cub::BlockLoad; - using BlockLoadIndexT = cub::BlockLoad; - using BlockLoadIndexVecT = cub::BlockLoad; - using BlockLoadWeightT = cub::BlockLoad; - using BlockLoadWeightVecT = cub::BlockLoad; - using BlockStoreT = cub::BlockStore; - using BlockStoreVecT = cub::BlockStore; - // using BlockScanT = cub::BlockScan; - // using BlockScanT = cub::BlockScan; - using BlockScanT = cub::BlockScan; - static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage), - sizeof(typename BlockLoadVecT::TempStorage), - sizeof(typename BlockLoadIndexT::TempStorage), - sizeof(typename BlockLoadIndexVecT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), - sizeof(typename BlockStoreT::TempStorage), - sizeof(typename BlockStoreVecT::TempStorage)}); - static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves + // occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 
4 : constexpr_min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kHasZ = kHasZ_; + static constexpr bool kUseIndex = kUseIndex_; + + static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; + static constexpr int kNLoadsIndex = kNItems / 4; + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = + cub::BlockLoad; + using BlockLoadIndexT = + cub::BlockLoad; + using BlockLoadIndexVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = + cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = + cub::BlockStore; + // using BlockScanT = cub::BlockScan; using BlockScanT = cub::BlockScan; + using BlockScanT = + cub::BlockScan; + static constexpr int kSmemIOSize = + custom_max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + sizeof(typename BlockLoadIndexT::TempStorage), + sizeof(typename BlockLoadIndexVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * + sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * + sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemSize = + kSmemIOSize + sizeof(typename BlockScanT::TempStorage); }; -template -__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) -void selective_scan_fwd_kernel(SSMParamsBase params) { - constexpr bool kIsVariableB = Ktraits::kIsVariableB; - constexpr bool kIsVariableC = Ktraits::kIsVariableC; - constexpr bool kHasZ = Ktraits::kHasZ; - constexpr bool kUseIndex = Ktraits::kUseIndex; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNItems = Ktraits::kNItems; - constexpr int kNRows = Ktraits::kNRows; - constexpr bool kDirectIO = Ktraits::kDirectIO; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - using scan_t = typename Ktraits::scan_t; - - // Shared memory. 
- extern __shared__ char smem_[]; - // cast to lvalue reference of expected type - // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); - // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); - // auto& smem_load = reinterpret_cast(smem_loadstorescan); - auto& smem_load = reinterpret_cast(smem_); - auto& smem_load_weight = reinterpret_cast(smem_); - auto& smem_load_index = reinterpret_cast(smem_); - auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); - auto& smem_store = reinterpret_cast(smem_); - auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); - // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); - // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); - scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); - - const int batch_id = blockIdx.x; - const int dim_id = blockIdx.y; - const int group_id = dim_id / (params.dim_ngroups_ratio); - input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride - + dim_id * kNRows * params.u_d_stride; - input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride - + dim_id * kNRows * params.delta_d_stride; - weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; - weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; - input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; - weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; - input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; - scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; - int *index = !kUseIndex ? nullptr :reinterpret_cast(params.index_ptr) + batch_id * params.seqlen; - - float D_val[kNRows] = {0}; - if (params.D_ptr != nullptr) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; +template +__global__ __launch_bounds__( + Ktraits::kNThreads, + Ktraits::kMinBlocks) void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr bool kUseIndex = Ktraits::kUseIndex; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr int kNRows = Ktraits::kNRows; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. 
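+  // Dynamic shared memory layout: the block load/store temp storages all
+  // alias the start of smem_ (a second weight-load storage is placed right
+  // after the first), the block-scan temp storage begins at kSmemIOSize, and
+  // the running-prefix scan_t values are kept past kSmemSize.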
+ extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + // + 2 * MAX_DSTATE * sizeof(weight_t)); auto& smem_load = + // reinterpret_cast(smem_loadstorescan); + auto& smem_load = + reinterpret_cast(smem_); + auto& smem_load_weight = + reinterpret_cast(smem_); + auto& smem_load_index = + reinterpret_cast(smem_); + auto& smem_load_weight1 = + *reinterpret_cast( + smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = + reinterpret_cast(smem_); + auto& smem_scan = + *reinterpret_cast( + smem_ + Ktraits::kSmemIOSize); + // weight_t *smem_a = reinterpret_cast(smem_ + + // smem_loadstorescan_size); weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); + scan_t* smem_running_prefix = + reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t* u = reinterpret_cast(params.u_ptr) + + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t* delta = reinterpret_cast(params.delta_ptr) + + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + weight_t* A = reinterpret_cast(params.A_ptr) + + dim_id * kNRows * params.A_d_stride; + weight_t* B = reinterpret_cast(params.B_ptr) + + dim_id * kNRows * params.B_d_stride; + input_t* Bvar = reinterpret_cast(params.B_ptr) + + batch_id * params.B_batch_stride + + group_id * params.B_group_stride; + weight_t* C = reinterpret_cast(params.C_ptr) + + dim_id * kNRows * params.C_d_stride; + input_t* Cvar = reinterpret_cast(params.C_ptr) + + batch_id * params.C_batch_stride + + group_id * params.C_group_stride; + scan_t* x = reinterpret_cast(params.x_ptr) + + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * + params.dstate; + int* index = !kUseIndex ? 
nullptr + : reinterpret_cast(params.index_ptr) + + batch_id * params.seqlen; + + float D_val[kNRows] = {0}; + if (params.D_ptr != nullptr) { +#pragma unroll + for (int r = 0; r < kNRows; ++r) { + D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; + } + } + float delta_bias[kNRows] = {0}; + if (params.delta_bias_ptr != nullptr) { +#pragma unroll + for (int r = 0; r < kNRows; ++r) { + delta_bias[r] = + reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; + } + } + + // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += + // blockDim.x) { + // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; + // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * + // C[state_idx * params.C_dstate_stride]; + // } + + constexpr int kChunkSize = kNThreads * kNItems; + for (int chunk = 0; chunk < params.n_chunks; ++chunk) { + input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; + int index_vals_load[kNRows][kNItems]; + + __syncthreads(); +#pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { + __syncthreads(); } + } + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, + params.seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { + __syncthreads(); + } + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], + smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (kUseIndex) { + load_index(index + r * params.delta_d_stride, + index_vals_load[r], smem_load_index, + params.seqlen - chunk * kChunkSize); + } } - float delta_bias[kNRows] = {0}; - if (params.delta_bias_ptr != nullptr) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; + if constexpr (kUseIndex) { + index += kChunkSize; + } + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], + out_vals[kNRows][kNItems]; +#pragma unroll + for (int r = 0; r < kNRows; ++r) { +#pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; + if (params.delta_softplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f + ? 
log1pf(expf(delta_vals[r][i])) + : delta_vals[r][i]; } + delta_u_vals[r][i] = delta_vals[r][i] * u_val; + out_vals[r][i] = D_val[r] * u_val; + } } - - // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { - // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; - // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; - // } - - constexpr int kChunkSize = kNThreads * kNItems; - for (int chunk = 0; chunk < params.n_chunks; ++chunk) { - input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; - int index_vals_load[kNRows][kNItems]; - - __syncthreads(); - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - if constexpr (!kDirectIO) { - if (r > 0) { __syncthreads(); } - } - load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); - if constexpr (!kDirectIO) { __syncthreads(); } - load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); - if constexpr (kUseIndex) { - load_index(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize); - } + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; +#pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = + A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of + // expf. + constexpr float kLog2e = M_LOG2E; + A_val[r] *= kLog2e; + } + // This variable holds B * C if both B and C are constant across seqlen. + // If only B varies across seqlen, this holds C. If only C varies across + // seqlen, this holds B. If both B and C vary, this is unused. + weight_t BC_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (kIsVariableB) { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, + (params.seqlen - chunk * kChunkSize) * (1)); + if constexpr (!kIsVariableC) { +#pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = + C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } } - if constexpr (kUseIndex) { - index += kChunkSize; + } + if constexpr (kIsVariableC) { + auto& smem_load_weight_C = + !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, + (params.seqlen - chunk * kChunkSize) * (1)); + if constexpr (!kIsVariableB) { +#pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = + B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } } - u += kChunkSize; - delta += kChunkSize; - - float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; - #pragma unroll + } + if constexpr (!kIsVariableB && !kIsVariableC) { +#pragma unroll for (int r = 0; r < kNRows; ++r) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float u_val = float(u_vals[r][i]); - delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; - if (params.delta_softplus) { - delta_vals[r][i] = delta_vals[r][i] <= 20.f ? 
log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; - } - delta_u_vals[r][i] = delta_vals[r][i] * u_val; - out_vals[r][i] = D_val[r] * u_val; - } + BC_val[r] = + B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * + C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; } - - __syncthreads(); - for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { - weight_t A_val[kNRows]; - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; - // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. - constexpr float kLog2e = M_LOG2E; - A_val[r] *= kLog2e; - } - // This variable holds B * C if both B and C are constant across seqlen. If only B varies - // across seqlen, this holds C. If only C varies across seqlen, this holds B. - // If both B and C vary, this is unused. - weight_t BC_val[kNRows]; - weight_t B_vals[kNItems], C_vals[kNItems]; - if constexpr (kIsVariableB) { - load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, - smem_load_weight, (params.seqlen - chunk * kChunkSize) * (1)); - if constexpr (!kIsVariableC) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; - } - } - } - if constexpr (kIsVariableC) { - auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; - load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (1 )); - if constexpr (!kIsVariableB) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; - } - } - } - if constexpr (!kIsVariableB && !kIsVariableC) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; - } + } + +#pragma unroll + for (int r = 0; r < kNRows; ++r) { + if (r > 0) { + __syncthreads(); + } // Scan could be using the same smem + scan_t thread_data[kNItems]; +#pragma unroll + for (int i = 0; i < kNItems; ++i) { + thread_data[i] = + make_float2(exp2f(delta_vals[r][i] * A_val[r]), + !kIsVariableB ? delta_u_vals[r][i] + : B_vals[i] * delta_u_vals[r][i]); + + // Reset A bar for cumulative sequences (Real) + if constexpr (kUseIndex) { + if (index_vals_load[r][i] == 0) { + thread_data[i].x = 0.f; } + } - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - if (r > 0) { __syncthreads(); } // Scan could be using the same smem - scan_t thread_data[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), - !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); - - // Reset A bar for cumulative sequences (Real) - if constexpr (kUseIndex) { - if (index_vals_load[r][i] == 0) { - thread_data[i].x = 0.f; - } - } - - if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct - if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { - thread_data[i] = make_float2(1.f, 0.f); - } - } - } - // Initialize running total - scan_t running_prefix; - // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read - running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? 
smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f)); - // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); - SSMScanPrefixCallbackOp prefix_op(running_prefix); - typename Ktraits::BlockScanT(smem_scan).InclusiveScan( - thread_data, thread_data, SSMScanOp(), prefix_op - ); - // There's a syncthreads in the scan op, so we don't need to sync here. - // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. - if (threadIdx.x == 0) { - smem_running_prefix[state_idx] = prefix_op.running_prefix; - x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; - } - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - const weight_t C_val = !kIsVariableC - ? BC_val[r] - : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]); - out_vals[r][i] += thread_data[i].y * C_val; - } + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is + // correct + if (threadIdx.x * kNItems + i >= + params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); } + } } - - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; - __syncthreads(); - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - if constexpr (!kDirectIO) { - if (r > 0) { __syncthreads(); } - } - store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + // Initialize running total + scan_t running_prefix; + // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) + // needs to read + running_prefix = + chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] + : (threadIdx.x % 32 == 0 + ? smem_running_prefix[state_idx + r * MAX_DSTATE] + : make_float2(1.f, 0.f)); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? + // smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + typename Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading + // and writing. + if (threadIdx.x == 0) { + smem_running_prefix[state_idx] = prefix_op.running_prefix; + x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = + prefix_op.running_prefix; + } +#pragma unroll + for (int i = 0; i < kNItems; ++i) { + const weight_t C_val = + !kIsVariableC + ? BC_val[r] + : (!kIsVariableB ? 
BC_val[r] * C_vals[i] : C_vals[i]); + out_vals[r][i] += thread_data[i].y * C_val; } + } + } - if constexpr (kHasZ) { - input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride - + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; - input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride - + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - input_t z_vals[kNItems]; - __syncthreads(); - load_input(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float z_val = z_vals[i]; - out_vals[r][i] *= z_val / (1 + expf(-z_val)); - } - __syncthreads(); - store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); - } + input_t* out = reinterpret_cast(params.out_ptr) + + batch_id * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); +#pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { + __syncthreads(); } + } + store_output(out + r * params.out_d_stride, out_vals[r], + smem_store, params.seqlen - chunk * kChunkSize); + } - Bvar += kChunkSize * 1; - Cvar += kChunkSize * 1; + if constexpr (kHasZ) { + input_t* z = reinterpret_cast(params.z_ptr) + + batch_id * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t* out_z = reinterpret_cast(params.out_z_ptr) + + batch_id * params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + + chunk * kChunkSize; +#pragma unroll + for (int r = 0; r < kNRows; ++r) { + input_t z_vals[kNItems]; + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals, smem_load, + params.seqlen - chunk * kChunkSize); +#pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + out_vals[r][i] *= z_val / (1 + expf(-z_val)); + } + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_vals[r], + smem_store, params.seqlen - chunk * kChunkSize); + } } + + Bvar += kChunkSize * 1; + Cvar += kChunkSize * 1; + } } -template -void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { - // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block - // processing 1 row. - constexpr int kNRows = 1; - BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { - BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { - BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { - BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { - BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] { - using Ktraits = Selective_Scan_fwd_kernel_traits; - // constexpr int kSmemSize = Ktraits::kSmemSize; - constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); - // printf("smem_size = %d\n", kSmemSize); - dim3 grid(params.batch, params.dim / kNRows); - auto kernel = &selective_scan_fwd_kernel; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); +template +void selective_scan_fwd_launch(SSMParamsBase& params, cudaStream_t stream) { + // Only kNRows == 1 is tested for now, which ofc doesn't differ from + // previously when we had each block processing 1 row. 
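+  // Each nested BOOL_SWITCH below fixes one compile-time flag (even seqlen,
+  // variable B, variable C, has z, use index), so one kernel specialization
+  // is instantiated per combination of those flags.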
+ constexpr int kNRows = 1; + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, [&] { + BOOL_SWITCH(params.index_ptr != nullptr, kUseIndex, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits< + kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, + kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kSmemSize = + Ktraits::kSmemSize + + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); + }); }); + }); } -template -void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { - - #ifndef USE_ROCM - if (params.seqlen <= 128) { - selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 256) { - selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 512) { - selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 1024) { - selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); - } else { - selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); - } - #else - if (params.seqlen <= 256) { - selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 512) { - selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 1024) { - selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); - } else { - selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); - } - #endif +template +void selective_scan_fwd_cuda(SSMParamsBase& params, cudaStream_t stream) { +#ifndef USE_ROCM + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } +#else + if (params.seqlen <= 256) { + selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } +#endif } -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); - -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - -#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) 
\ - if (ITYPE == at::ScalarType::Half) { \ - using input_t = at::Half; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::BFloat16) { \ - using input_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::Float) { \ - using input_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ - if (WTYPE == at::ScalarType::Half) { \ - using weight_t = at::Half; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::BFloat16) { \ - using weight_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT(WTYPE, NAME, ...) \ - if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ - } - -template -void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); - -void set_ssm_params_fwd(SSMParamsBase ¶ms, +template void selective_scan_fwd_cuda( + SSMParamsBase& params, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase& params, + cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase& params, + cudaStream_t stream); + +#define CHECK_SHAPE(x, ...) \ + TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \ + #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), \ + "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ + if (WTYPE == at::ScalarType::Half) { \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::BFloat16) { \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), \ + "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT(WTYPE, NAME, ...) 
\ + if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), \ + "'"); \ + } + +template +void selective_scan_fwd_cuda(SSMParamsBase& params, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase& params, // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t dstate, - const size_t n_groups, - const size_t n_chunks, - const bool is_variable_B, - const bool is_variable_C, + const size_t batch, const size_t dim, + const size_t seqlen, const size_t dstate, + const size_t n_groups, const size_t n_chunks, + const bool is_variable_B, const bool is_variable_C, // device pointers - const torch::Tensor u, - const torch::Tensor delta, - const torch::Tensor A, - const torch::Tensor B, - const torch::Tensor C, - const torch::Tensor out, - const torch::Tensor z, - const torch::Tensor out_z, - void* D_ptr, - void* delta_bias_ptr, - void* x_ptr, - bool has_z, - bool delta_softplus, - void* index_ptr) { - - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - params.batch = batch; - params.dim = dim; - params.seqlen = seqlen; - params.dstate = dstate; - params.n_groups = n_groups; - params.n_chunks = n_chunks; - params.dim_ngroups_ratio = dim / n_groups; - - params.delta_softplus = delta_softplus; - - params.is_variable_B = is_variable_B; - params.is_variable_C = is_variable_C; - - // Set the pointers and strides. - params.u_ptr = u.data_ptr(); - params.delta_ptr = delta.data_ptr(); - params.A_ptr = A.data_ptr(); - params.B_ptr = B.data_ptr(); - params.C_ptr = C.data_ptr(); - params.D_ptr = D_ptr; - params.delta_bias_ptr = delta_bias_ptr; - params.out_ptr = out.data_ptr(); - params.x_ptr = x_ptr; - params.z_ptr = has_z ? z.data_ptr() : nullptr; - params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; - - params.index_ptr = index_ptr; - - // All stride are in elements, not bytes. - params.A_d_stride = A.stride(0); - params.A_dstate_stride = A.stride(1); - if (!is_variable_B) { - params.B_d_stride = B.stride(0); - } else { - params.B_batch_stride = B.stride(0); - params.B_group_stride = B.stride(1); - } - params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); - if (!is_variable_C) { - params.C_d_stride = C.stride(0); - } else { - params.C_batch_stride = C.stride(0); - params.C_group_stride = C.stride(1); - } - params.C_dstate_stride = !is_variable_C ? 
C.stride(1) : C.stride(2); - params.u_batch_stride = u.stride(0); - params.u_d_stride = u.stride(1); - params.delta_batch_stride = delta.stride(0); - params.delta_d_stride = delta.stride(1); - if (has_z) { - params.z_batch_stride = z.stride(0); - params.z_d_stride = z.stride(1); - params.out_z_batch_stride = out_z.stride(0); - params.out_z_d_stride = out_z.stride(1); - } - params.out_batch_stride = out.stride(0); - params.out_d_stride = out.stride(1); + const torch::Tensor u, const torch::Tensor delta, + const torch::Tensor A, const torch::Tensor B, + const torch::Tensor C, const torch::Tensor out, + const torch::Tensor z, const torch::Tensor out_z, + void* D_ptr, void* delta_bias_ptr, void* x_ptr, + bool has_z, bool delta_softplus, void* index_ptr) { + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.dstate = dstate; + params.n_groups = n_groups; + params.n_chunks = n_chunks; + params.dim_ngroups_ratio = dim / n_groups; + + params.delta_softplus = delta_softplus; + + params.is_variable_B = is_variable_B; + params.is_variable_C = is_variable_C; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D_ptr; + params.delta_bias_ptr = delta_bias_ptr; + params.out_ptr = out.data_ptr(); + params.x_ptr = x_ptr; + params.z_ptr = has_z ? z.data_ptr() : nullptr; + params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + + params.index_ptr = index_ptr; + + // All stride are in elements, not bytes. + params.A_d_stride = A.stride(0); + params.A_dstate_stride = A.stride(1); + if (!is_variable_B) { + params.B_d_stride = B.stride(0); + } else { + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + } + params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); + if (!is_variable_C) { + params.C_d_stride = C.stride(0); + } else { + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + } + params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + if (has_z) { + params.z_batch_stride = z.stride(0); + params.z_d_stride = z.stride(1); + params.out_z_batch_stride = out_z.stride(0); + params.out_z_d_stride = out_z.stride(1); + } + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); } -std::vector -selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, - const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, - const c10::optional &D_, - const c10::optional &z_, - const c10::optional &delta_bias_, - bool delta_softplus, - const c10::optional &index_, - const c10::optional &x) { - auto input_type = u.scalar_type(); - auto weight_type = A.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); - - const bool is_variable_B = B.dim() >= 3; - const bool is_variable_C = C.dim() >= 3; - const bool is_complex = weight_type == at::ScalarType::ComplexFloat; - - TORCH_CHECK(delta.scalar_type() == input_type); - TORCH_CHECK(B.scalar_type() == (!is_variable_B ? 
weight_type : input_type)); - TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); - - TORCH_CHECK(u.is_cuda()); - TORCH_CHECK(delta.is_cuda()); - TORCH_CHECK(A.is_cuda()); - TORCH_CHECK(B.is_cuda()); - TORCH_CHECK(C.is_cuda()); - - TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); - TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); - - const auto sizes = u.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int dstate = A.size(1); - const int n_groups = is_variable_B ? B.size(1) : 1; - - TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); - - CHECK_SHAPE(u, batch_size, dim, seqlen); - CHECK_SHAPE(delta, batch_size, dim, seqlen); - CHECK_SHAPE(A, dim, dstate); - if (!is_variable_B) { - CHECK_SHAPE(B, dim, dstate); - } else { - CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); - TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); - } - if (!is_variable_C) { - CHECK_SHAPE(C, dim, dstate); - } else { - CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); - TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); - } - - if (D_.has_value()) { - auto D = D_.value(); - TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(D.is_cuda()); - TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); - CHECK_SHAPE(D, dim); - } - - if (delta_bias_.has_value()) { - auto delta_bias = delta_bias_.value(); - TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(delta_bias.is_cuda()); - TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); - CHECK_SHAPE(delta_bias, dim); - } - if (index_.has_value()) { - auto index = index_.value(); - TORCH_CHECK(index.scalar_type() == at::ScalarType::Int); - TORCH_CHECK(index.is_cuda()); - CHECK_SHAPE(index, batch_size, seqlen); - } - - at::Tensor z, out_z; - const bool has_z = z_.has_value(); - if (has_z) { - z = z_.value(); - TORCH_CHECK(z.scalar_type() == input_type); - TORCH_CHECK(z.is_cuda()); - TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); - CHECK_SHAPE(z, batch_size, dim, seqlen); - out_z = torch::empty_like(z); - } - - const int n_chunks = (seqlen + 2048 - 1) / 2048; - // const int n_chunks = (seqlen + 1024 - 1) / 1024; - // at::Tensor out = torch::empty_like(u); - // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout - at::Tensor out = torch::empty_like(delta); - if (x.has_value()){ - auto _x = x.value(); - TORCH_CHECK(_x.scalar_type() == weight_type); - TORCH_CHECK(_x.is_cuda()); - TORCH_CHECK(_x.stride(-1) == 1); - CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2); - } - - SSMParamsBase params; - set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, - u, delta, A, B, C, out, z, out_z, - D_.has_value() ? D_.value().data_ptr() : nullptr, - delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, - x.value().data_ptr(), - has_z, - delta_softplus, - index_.has_value() ? 
index_.value().data_ptr() : nullptr); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)u.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { +std::vector selective_scan_fwd( + const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, + const torch::Tensor& B, const torch::Tensor& C, + const c10::optional& D_, + const c10::optional& z_, + const c10::optional& delta_bias_, bool delta_softplus, + const c10::optional& index_, + const c10::optional& x) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || + input_type == at::ScalarType::Half || + input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || + weight_type == at::ScalarType::ComplexFloat); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + const bool is_complex = weight_type == at::ScalarType::ComplexFloat; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = is_variable_B ? B.size(1) : 1; + + TORCH_CHECK(dstate <= 256, + "selective_scan only supports state dimension <= 256"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + if (!is_variable_B) { + CHECK_SHAPE(B, dim, dstate); + } else { + CHECK_SHAPE(B, batch_size, n_groups, dstate, + !is_complex ? seqlen : seqlen * 2); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + } + if (!is_variable_C) { + CHECK_SHAPE(C, dim, dstate); + } else { + CHECK_SHAPE(C, batch_size, n_groups, dstate, + !is_complex ? 
seqlen : seqlen * 2); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + } + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + if (index_.has_value()) { + auto index = index_.value(); + TORCH_CHECK(index.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(index.is_cuda()); + CHECK_SHAPE(index, batch_size, seqlen); + } + + at::Tensor z, out_z; + const bool has_z = z_.has_value(); + if (has_z) { + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + CHECK_SHAPE(z, batch_size, dim, seqlen); + out_z = torch::empty_like(z); + } + + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + // at::Tensor out = torch::empty_like(u); + // Right now u has BHL layout and delta has HBL layout, and we want out to + // have HBL layout + at::Tensor out = torch::empty_like(delta); + if (x.has_value()) { + auto _x = x.value(); + TORCH_CHECK(_x.scalar_type() == weight_type); + TORCH_CHECK(_x.is_cuda()); + TORCH_CHECK(_x.stride(-1) == 1); + CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2); + } + + SSMParamsBase params; + set_ssm_params_fwd( + params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, + is_variable_B, is_variable_C, u, delta, A, B, C, out, z, out_z, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x.value().data_ptr(), has_z, delta_softplus, + index_.has_value() ? index_.value().data_ptr() : nullptr); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16( + u.scalar_type(), "selective_scan_fwd", [&] { DISPATCH_WTYPE_FLOAT(A.scalar_type(), "selective_scan_fwd", [&] { - selective_scan_fwd_cuda(params, stream); + selective_scan_fwd_cuda(params, stream); }); - }); - std::vector result = {out, x.value()}; - if (has_z) { result.push_back(out_z); } - return result; + }); + std::vector result = {out, x.value()}; + if (has_z) { + result.push_back(out_z); + } + return result; } - diff --git a/csrc/mamba/mamba_ssm/static_switch.h b/csrc/mamba/mamba_ssm/static_switch.h index 7920ac045d0a..d95531cf59ca 100644 --- a/csrc/mamba/mamba_ssm/static_switch.h +++ b/csrc/mamba/mamba_ssm/static_switch.h @@ -1,4 +1,5 @@ -// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h #pragma once @@ -13,13 +14,13 @@ /// some_function(...); /// }); /// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - constexpr bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() +#define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/ops.h b/csrc/ops.h index d35324d39864..4d86358b41d3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -176,15 +176,14 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); -std::vector -selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, - const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, - const c10::optional &D_, - const c10::optional &z_, - const c10::optional &delta_bias_, - bool delta_softplus, - const c10::optional &index_, - const c10::optional &x); +std::vector selective_scan_fwd( + const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, + const torch::Tensor& B, const torch::Tensor& C, + const c10::optional& D_, + const c10::optional& z_, + const c10::optional& delta_bias_, bool delta_softplus, + const c10::optional& index_, + const c10::optional& x); #ifndef USE_ROCM using fptr_t = int64_t; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b032c8965b01..806436d0d10e 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -243,16 +243,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "()"); ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); - + // Mamba selective scan kerenl - ops.def("selective_scan_fwd(Tensor! u, Tensor! delta," - "Tensor! A, Tensor! B, Tensor C," - "Tensor! D_, Tensor! z_, Tensor! delta_bias_," - "bool delta_softplus," - "Tensor! index_, Tensor! &x) -> ()"); - ops.impl("selective_scan_fwd", torch::kCUDA, - &selective_scan_fwd); - + ops.def( + "selective_scan_fwd(Tensor! u, Tensor! delta," + "Tensor! A, Tensor! B, Tensor C," + "Tensor! D_, Tensor! z_, Tensor! delta_bias_," + "bool delta_softplus," + "Tensor! index_, Tensor! &x) -> ()"); + ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { diff --git a/vllm/model_executor/layers/mamba/ops/casual_conv1d.py b/vllm/model_executor/layers/mamba/ops/casual_conv1d.py index c34afdeb5add..f83817fa1704 100644 --- a/vllm/model_executor/layers/mamba/ops/casual_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/casual_conv1d.py @@ -1,8 +1,8 @@ # Copyright (c) 2024, Tri Dao. 
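The wrappers kept in this file delegate to the new CUDA ops; the pure-PyTorch
reference implementations removed below reappear under tests/kernels later in
this series. A minimal usage sketch, assuming the call pattern used by those
tests; the sizes here are illustrative only:

    import torch
    from vllm.model_executor.layers.mamba.ops.casual_conv1d import (
        causal_conv1d_fn, causal_conv1d_update)

    batch, dim, seqlen, width = 2, 64, 128, 4
    x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.bfloat16)
    weight = torch.randn(dim, width, device="cuda", dtype=torch.float32)
    bias = torch.randn(dim, device="cuda", dtype=torch.float32)

    # Prefill: run the whole sequence through the causal conv.
    out, _ = causal_conv1d_fn(x, weight, bias, initial_states=None,
                              return_final_states=False, activation="silu")

    # Decode: advance one token against a rolling (batch, dim, width) state,
    # which causal_conv1d_update modifies in place.
    conv_state = torch.zeros(batch, dim, width, device="cuda", dtype=x.dtype)
    x_step = torch.randn(batch, dim, device="cuda", dtype=x.dtype)
    out_step = causal_conv1d_update(x_step, conv_state, weight, bias,
                                    activation="silu")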
from typing import Optional + import torch -import torch.nn.functional as F from vllm import _custom_ops as ops @@ -66,51 +66,6 @@ def causal_conv1d_fn( return (out, None) if not return_final_states else (out, final_states_out) -def causal_conv1d_ref( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - initial_states: Optional[torch.Tensor] = None, - return_final_states: bool = False, - final_states_out=None, - activation: str = "silu", -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1) - - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.to(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - if initial_states is None: - out = F.conv1d(x, - weight.unsqueeze(1), - bias, - padding=width - 1, - groups=dim) - else: - x = torch.cat([initial_states, x], dim=-1) - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) - out = out[..., :seqlen] - if return_final_states: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in) # (batch, dim, width - 1) - if final_states_out is not None: - final_states_out.copy_(final_states) - else: - final_states_out = final_states - out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) - return (out, None) if not return_final_states else (out, final_states_out) - - def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): """ x: (batch, dim) @@ -124,32 +79,3 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): raise NotImplementedError("activation must be None, silu, or swish") activation = activation in ["silu", "swish"] return ops.causal_conv1d_update(x, conv_state, weight, bias, activation) - - -def causal_conv1d_update_ref(x, - conv_state, - weight, - bias=None, - activation=None): - """ - x: (batch, dim) - conv_state: (batch, dim, width) - weight: (dim, width) - bias: (dim,) - - out: (batch, dim) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - batch, dim = x.shape - width = weight.shape[1] - assert conv_state.shape == (batch, dim, width) - assert weight.shape == (dim, width) - conv_state.copy_(torch.roll(conv_state, shifts=-1, - dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = x - out = torch.sum(conv_state * weight, dim=-1) # (B D) - if bias is not None: - out += bias - return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 057b4016362a..bad7ff6c1853 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -1,14 +1,12 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. 
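The two wrappers kept in this file implement the selective-SSM recurrence
state_t = exp(delta_t * A) * state_{t-1} + delta_t * B_t * u_t with output
y_t = dot(C_t, state_t) + D * u_t: selective_scan_fn runs it over a full
prompt, while selective_state_update advances a single decode step in place.
A rough sketch of how the two chain together, assuming the argument names
used by the calls elsewhere in this series (sizes illustrative):

    import torch
    from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
        selective_scan_fn, selective_state_update)

    batch, dim, dstate, seqlen = 2, 4, 16, 128
    dev = "cuda"
    u = torch.randn(batch, dim, seqlen, device=dev)
    delta = torch.rand(batch, dim, seqlen, device=dev)
    A = -torch.rand(dim, dstate, device=dev)            # (dim, dstate)
    B = torch.randn(batch, dstate, seqlen, device=dev)  # time-varying B
    C = torch.randn(batch, dstate, seqlen, device=dev)  # time-varying C
    D = torch.randn(dim, device=dev)

    # Prefill: scan the prompt and keep the final SSM state.
    out, last_state = selective_scan_fn(u, delta, A, B, C, D,
                                        delta_softplus=True,
                                        return_last_state=True)

    # Decode: advance the state by one token; the state tensor is updated
    # in place and the step output is returned.
    state = last_state.contiguous()                     # (batch, dim, dstate)
    x_t = torch.randn(batch, dim, device=dev)
    dt_t = torch.rand(batch, dim, device=dev)
    dt_bias = torch.rand(dim, device=dev) - 4.0
    B_t = torch.randn(batch, dstate, device=dev)
    C_t = torch.randn(batch, dstate, device=dev)
    y_t = selective_state_update(state, x_t, dt_t, A, B_t, C_t, D=D,
                                 dt_bias=dt_bias, dt_softplus=True)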
import torch -import torch.nn.functional as F - import triton import triton.language as tl +from einops import rearrange +from packaging import version -from einops import rearrange, repeat from vllm import _custom_ops as ops -from packaging import version TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") @@ -158,10 +156,7 @@ def _selective_scan_update_kernel( if HAS_Z: z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if not TIE_HDIM: - dB = B[None, :] * dt[:, None] - else: - dB = B * dt # vector of size (dstate,) + dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt state = state * dA + dB * x[:, None] tl.store(state_ptrs, state, @@ -298,84 +293,6 @@ def selective_state_update(state, return out -def selective_state_update_ref(state, - x, - dt, - A, - B, - C, - D=None, - z=None, - dt_bias=None, - dt_softplus=False): - """ - Argument: - state: (batch, dim, dstate) or (batch, nheads, dim, dstate) - x: (batch, dim) or (batch, nheads, dim) - dt: (batch, dim) or (batch, nheads, dim) - A: (dim, dstate) or (nheads, dim, dstate) - B: (batch, dstate) or (batch, ngroups, dstate) - C: (batch, dstate) or (batch, ngroups, dstate) - D: (dim,) or (nheads, dim) - z: (batch, dim) or (batch, nheads, dim) - dt_bias: (dim,) or (nheads, dim) - Return: - out: (batch, dim) or (batch, nheads, dim) - """ - has_heads = state.dim() > 3 - if state.dim() == 3: - state = state.unsqueeze(1) - if x.dim() == 2: - x = x.unsqueeze(1) - if dt.dim() == 2: - dt = dt.unsqueeze(1) - if A.dim() == 2: - A = A.unsqueeze(0) - if B.dim() == 2: - B = B.unsqueeze(1) - if C.dim() == 2: - C = C.unsqueeze(1) - if D is not None and D.dim() == 1: - D = D.unsqueeze(0) - if z is not None and z.dim() == 2: - z = z.unsqueeze(1) - if dt_bias is not None and dt_bias.dim() == 1: - dt_bias = dt_bias.unsqueeze(0) - batch, nheads, dim, dstate = state.shape - assert x.shape == (batch, nheads, dim) - assert dt.shape == x.shape - assert A.shape == (nheads, dim, dstate) - ngroups = B.shape[1] - assert nheads % ngroups == 0, "nheads must be divisible by ngroups" - assert B.shape == (batch, ngroups, dstate) - assert C.shape == B.shape - if D is not None: - assert D.shape == (nheads, dim) - if z is not None: - assert z.shape == x.shape - if dt_bias is not None: - assert dt_bias.shape == (nheads, dim) - dt = dt + dt_bias - dt = F.softplus(dt) if dt_softplus else dt - dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * - A) # (batch, nheads, dim, dstate) - B = repeat(B, "b g n -> b (g h) n", - h=nheads // ngroups) # (batch, nheads, dstate) - C = repeat(C, "b g n -> b (g h) n", - h=nheads // ngroups) # (batch, nheads, dstate) - dB = rearrange(dt, "b h d -> b h d 1") * rearrange( - B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate) - state.copy_(state * dA + - dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate - out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) - if D is not None: - out += (x * D).to(out.dtype) - out = (out if z is None else out * F.silu(z)).to(x.dtype) - if not has_heads: - out = out.squeeze(1) - return out - - def selective_scan_fn(u, delta, A, @@ -428,88 +345,3 @@ def selective_scan_fn(u, else: out_z = rest[0] return out_z if not return_last_state else (out_z, last_state) - - -def selective_scan_ref(u, - delta, - A, - B, - C, - D=None, - z=None, - delta_bias=None, - delta_softplus=False, - return_last_state=False, - position_indices=None, - prev_state=None): - """ - u: r(B D L) - delta: r(B D L) - A: c(D N) or r(D N) - B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N 
L) - C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) - D: r(D) - z: r(B D L) - delta_bias: r(D), fp32 - prev_state: r(B D N), fp32 - - out: r(B D L) - last_state (optional): r(B D dstate) or c(B D dstate) - """ - dtype_in = u.dtype - u = u.float() - delta = delta.float() - if delta_bias is not None: - delta = delta + delta_bias[..., None].float() - if delta_softplus: - delta = F.softplus(delta) - batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] - is_variable_B = B.dim() >= 3 - is_variable_C = C.dim() >= 3 - if A.is_complex(): - if is_variable_B: - B = torch.view_as_complex( - rearrange(B.float(), "... (L two) -> ... L two", two=2)) - if is_variable_C: - C = torch.view_as_complex( - rearrange(C.float(), "... (L two) -> ... L two", two=2)) - else: - B = B.float() - C = C.float() - x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state - ys = [] - deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) - if not is_variable_B: - deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) - else: - if B.dim() == 3: - deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) - else: - B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) - deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) - if is_variable_C and C.dim() == 4: - C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) - last_state = None - for i in range(u.shape[2]): - if position_indices is not None and position_indices[0, i] == 0: - x = deltaB_u[:, :, i] - else: - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] - if not is_variable_C: - y = torch.einsum('bdn,dn->bd', x, C) - else: - if C.dim() == 3: - y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) - else: - y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) - if i == u.shape[2] - 1: - last_state = x - if y.is_complex(): - y = y.real * 2 - ys.append(y) - y = torch.stack(ys, dim=2) # (batch dim L) - out = y if D is None else y + u * rearrange(D, "d -> d 1") - if z is not None: - out = out * F.silu(z) - out = out.to(dtype=dtype_in) - return out if not return_last_state else (out, last_state) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index b82eb14fb5f2..93d4cfa10d12 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -4,9 +4,6 @@ from typing import Dict, Iterable, List, Optional, Tuple import torch -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn -from mamba_ssm.ops.triton.selective_state_update import selective_state_update from torch import nn from torch.nn.parameter import Parameter from transformers import JambaConfig @@ -24,6 +21,10 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.ops.casual_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler @@ -161,7 +162,7 @@ def mamba_forward(self, (self.conv_kernel_size - hidden_states.shape[-1], 0)) cache_params.conv_state.copy_(conv_states) - hidden_states = causal_conv1d_fn( + hidden_states, _ = causal_conv1d_fn( hidden_states, conv_weights, self.conv1d.bias, @@ -920,7 +921,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = 
param.weight_loader weight_loader(param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id) break From 5f9c383e898699379473125d507b25b0ce42ad3a Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 20 Aug 2024 18:55:09 +0300 Subject: [PATCH 10/45] Take off mamba from image and requirements --- Dockerfile | 23 ----------------------- requirements-mamba.txt | 3 --- 2 files changed, 26 deletions(-) delete mode 100644 requirements-mamba.txt diff --git a/Dockerfile b/Dockerfile index c13cb5c7e7a9..92f35855c619 100644 --- a/Dockerfile +++ b/Dockerfile @@ -47,9 +47,6 @@ COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-cuda.txt -COPY requirements-mamba.txt requirements-mamba.txt -RUN python3 -m pip install packaging -RUN python3 -m pip install -r requirements-mamba.txt # cuda arch list used by torch # can be useful for both `dev` and `test` @@ -138,22 +135,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-dev.txt #################### DEV IMAGE #################### -#################### MAMBA Build IMAGE #################### -FROM dev as mamba-builder -# max jobs used for build -ARG max_jobs=2 -ENV MAX_JOBS=${max_jobs} - -WORKDIR /usr/src/mamba - -COPY requirements-mamba.txt requirements-mamba.txt - -# Download the wheel or build it if a pre-compiled release doesn't exist -RUN pip --verbose wheel -r requirements-mamba.txt \ - --no-build-isolation --no-deps --no-cache-dir - -#################### MAMBA Build IMAGE #################### - #################### vLLM installation IMAGE #################### # image with vLLM installed FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base @@ -189,10 +170,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install dist/*.whl --verbose -RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \ - --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir - RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp310-cp310-linux_x86_64.whl #################### vLLM installation IMAGE #################### diff --git a/requirements-mamba.txt b/requirements-mamba.txt deleted file mode 100644 index 1838e87d063d..000000000000 --- a/requirements-mamba.txt +++ /dev/null @@ -1,3 +0,0 @@ -# Mamba dependencies -mamba-ssm>=1.2.2 -causal-conv1d>=1.2.0 From ac8354e67ccc7b095d58ef559378ea9cf75b558d Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 20 Aug 2024 18:59:49 +0300 Subject: [PATCH 11/45] Add tests --- tests/kernels/test_causal_conv1d.py | 80 +++++++++ tests/kernels/test_mamba_ssm.py | 168 ++++++++++++++++++ .../layers/mamba/ops/mamba_ssm.py | 2 +- 3 files changed, 249 insertions(+), 1 deletion(-) create mode 100644 tests/kernels/test_causal_conv1d.py create mode 100644 tests/kernels/test_mamba_ssm.py diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py new file mode 100644 index 000000000000..2bf883d1373f --- /dev/null +++ b/tests/kernels/test_causal_conv1d.py @@ -0,0 +1,80 @@ +from typing import Optional +from einops import rearrange, repeat +import torch +import torch.nn.functional as F + +def causal_conv1d_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, 
+ initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out=None, + activation: str = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, + weight.unsqueeze(1), + bias, + padding=width - 1, + groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update_ref(x, + conv_state, + weight, + bias=None, + activation=None): + """ + x: (batch, dim) + conv_state: (batch, dim, width) + weight: (dim, width) + bias: (dim,) + + out: (batch, dim) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + batch, dim = x.shape + width = weight.shape[1] + assert conv_state.shape == (batch, dim, width) + assert weight.shape == (dim, width) + conv_state.copy_(torch.roll(conv_state, shifts=-1, + dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = x + out = torch.sum(conv_state * weight, dim=-1) # (B D) + if bias is not None: + out += bias + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) + + + diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py new file mode 100644 index 000000000000..cff7221a4af1 --- /dev/null +++ b/tests/kernels/test_mamba_ssm.py @@ -0,0 +1,168 @@ +from typing import Optional +from einops import rearrange, repeat +import torch +import torch.nn.functional as F + + +def selective_state_update_ref(state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = 
B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + dt = dt + dt_bias + dt = F.softplus(dt) if dt_softplus else dt + dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * + A) # (batch, nheads, dim, dstate) + B = repeat(B, "b g n -> b (g h) n", + h=nheads // ngroups) # (batch, nheads, dstate) + C = repeat(C, "b g n -> b (g h) n", + h=nheads // ngroups) # (batch, nheads, dstate) + dB = rearrange(dt, "b h d -> b h d 1") * rearrange( + B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate) + state.copy_(state * dA + + dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate + out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) + if D is not None: + out += (x * D).to(out.dtype) + out = (out if z is None else out * F.silu(z)).to(x.dtype) + if not has_heads: + out = out.squeeze(1) + return out + + + +def selective_scan_ref(u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + position_indices=None, + prev_state=None): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + prev_state: r(B D N), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex( + rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex( + rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + if position_indices is not None and position_indices[0, i] == 0: + x = deltaB_u[:, :, i] + else: + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index bad7ff6c1853..561386a0037c 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -333,7 +333,7 @@ def selective_scan_fn(u, ), device=u.device, dtype=torch.float32, - requires_grad=u.requires_grad) + requires_grad=False) x[:, :, 0, 0::2] = 1 if prev_state is not None: x[:, :, 0, 1::2].copy_(prev_state) From ea80282108e3c72fb8695646945056e201d7eec8 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 22 Aug 2024 09:16:30 +0300 Subject: [PATCH 12/45] Some small fixes, tests still do not pass --- csrc/torch_bindings.cpp | 6 +- tests/kernels/test_mamba_ssm.py | 115 +++++++++++++++++++++++++++++++- 2 files changed, 117 insertions(+), 4 deletions(-) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 806436d0d10e..6397c500fc3f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -247,10 +247,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Mamba selective scan kerenl ops.def( "selective_scan_fwd(Tensor! u, Tensor! delta," - "Tensor! A, Tensor! B, Tensor C," - "Tensor! D_, Tensor! z_, Tensor! delta_bias_," + "Tensor! A, Tensor! B, Tensor! C," + "Tensor? D_, Tensor? z_, Tensor? delta_bias_," "bool delta_softplus," - "Tensor! index_, Tensor! &x) -> ()"); + "Tensor? index_, Tensor? 
x) -> Tensor[]"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); } diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index cff7221a4af1..bb694241caff 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -1,7 +1,9 @@ -from typing import Optional from einops import rearrange, repeat import torch import torch.nn.functional as F +import pytest + +from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn def selective_state_update_ref(state, @@ -166,3 +168,114 @@ def selective_scan_ref(u, out = out * F.silu(z) out = out.to(dtype=dtype_in) return out if not return_last_state else (out, last_state) + + +@pytest.mark.parametrize('wtype', [torch.float32]) +@pytest.mark.parametrize('itype', [torch.float32]) +@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("return_last_state", [True]) +@pytest.mark.parametrize('has_delta_bias', [True]) +@pytest.mark.parametrize('delta_softplus', [True]) +@pytest.mark.parametrize('has_z', [True]) +@pytest.mark.parametrize('has_D', [True]) +@pytest.mark.parametrize("varBC_groups", [1, 2]) +@pytest.mark.parametrize("is_variable_C", [True]) +@pytest.mark.parametrize("is_variable_B", [True]) +@pytest.mark.parametrize("scan_chunks", [1,2,3]) +def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, + delta_softplus, return_last_state, seqlen, itype, wtype, scan_chunks): + if varBC_groups > 1 and (not is_variable_B or not is_variable_C): + pytest.skip() # This config is not applicable + device = 'cuda' + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + if has_z: # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + # set seed + torch.random.manual_seed(0) + batch_size = 2 + dim = 4 + dstate = 8 + is_complex = wtype == torch.complex64 + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)) + if not is_variable_B: + B_shape = (dim, dstate) + elif varBC_groups == 1: + B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype) + if not is_variable_C: + C_shape = (dim, dstate) + elif varBC_groups == 1: + C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + else: + C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) + C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype) + if has_D: + D = torch.randn(dim, device=device, dtype=torch.float32) + else: + D = None + if has_z: + z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) + else: + z = None + if has_delta_bias: + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)) + else: + delta_bias = None + u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) + delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)) + A_ref = A.detach().clone() + B_ref = B.detach().clone() + C_ref = C.detach().clone() + D_ref = D.detach().clone() if D is not None else None + z_ref = z.detach().clone() if z is not None else None + u_ref = u.detach().clone() + delta_ref = delta.detach().clone() + delta_bias_ref = delta_bias.detach().clone() 
if delta_bias is not None else None + state = None + state_ref = None + outs = [] + for c in range(scan_chunks): + chunked_prompt_len = seqlen // scan_chunks + chunk_start = chunked_prompt_len * c + chunk_end = chunked_prompt_len * (c + 1) + if c == scan_chunks - 1: + chunk_end = seqlen + _B = B + if is_variable_B: + _B = B[...,chunk_start:chunk_end] + _C = C + if is_variable_B: + _C = C[...,chunk_start:chunk_end] + _z = z + if has_z: + _z = z[...,chunk_start:chunk_end] + out, *rest = selective_scan_fn( + u[...,chunk_start:chunk_end], delta[...,chunk_start:chunk_end], A, _B, _C, D, z=_z, + delta_bias=delta_bias, delta_softplus=delta_softplus, + return_last_state=return_last_state,prev_state=state if c > 0 else None + ) + outs.append(out) + if return_last_state: + state = rest[0] + if len(outs) > 1: + out = torch.cat(outs,dim=-1) + out_ref, *rest = selective_scan_ref( + u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref, + delta_bias=delta_bias_ref, delta_softplus=delta_softplus, + return_last_state=return_last_state + ) + if return_last_state: + state_ref = rest[0] + + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + if return_last_state: + print(f'State max diff: {(state - state_ref).abs().max().item()}') + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) From 2f1549521182ec5a91f7fe812e7bb212c28b6849 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 22 Aug 2024 11:25:27 +0300 Subject: [PATCH 13/45] Fix tests --- tests/kernels/test_mamba_ssm.py | 57 +++++++++++++++---- .../layers/mamba/ops/mamba_ssm.py | 2 +- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index bb694241caff..f76d2d82fda4 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -3,7 +3,7 @@ import torch.nn.functional as F import pytest -from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn +from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn, selective_state_update def selective_state_update_ref(state, @@ -229,16 +229,10 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z delta_bias = None u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)) - A_ref = A.detach().clone() - B_ref = B.detach().clone() - C_ref = C.detach().clone() - D_ref = D.detach().clone() if D is not None else None - z_ref = z.detach().clone() if z is not None else None - u_ref = u.detach().clone() - delta_ref = delta.detach().clone() - delta_bias_ref = delta_bias.detach().clone() if delta_bias is not None else None state = None state_ref = None + out = None + out_ref = None outs = [] for c in range(scan_chunks): chunked_prompt_len = seqlen // scan_chunks @@ -254,6 +248,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z _C = C[...,chunk_start:chunk_end] _z = z if has_z: + assert z is not None _z = z[...,chunk_start:chunk_end] out, *rest = selective_scan_fn( u[...,chunk_start:chunk_end], delta[...,chunk_start:chunk_end], A, _B, _C, D, z=_z, @@ -266,16 +261,56 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z if len(outs) > 1: out = torch.cat(outs,dim=-1) out_ref, *rest = selective_scan_ref( - u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref, - 
delta_bias=delta_bias_ref, delta_softplus=delta_softplus, + u, delta, A, B, C, D, z=z, + delta_bias=delta_bias, delta_softplus=delta_softplus, return_last_state=return_last_state ) if return_last_state: state_ref = rest[0] + assert out is not None and out_ref is not None print(f'Output max diff: {(out - out_ref).abs().max().item()}') print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) if return_last_state: + assert state is not None and state_ref is not None print(f'State max diff: {(state - state_ref).abs().max().item()}') assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + + + +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +def test_selective_state_update(dim, dstate, has_z, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + if torch.version.hip: + atol *= 2 + # set seed + torch.random.manual_seed(0) + batch_size = 2 + state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) + x = torch.randn(batch_size, dim, device=device, dtype=itype) + dt = torch.randn(batch_size, dim, device=device, dtype=itype) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(batch_size, dstate, device=device) + C = torch.randn(batch_size, dstate, device=device) + D = torch.randn(dim, device=device) + if has_z: + z = torch.randn_like(x) + else: + z = None + state_ref = state.detach().clone() + out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) + out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 561386a0037c..ce9ed5e54935 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -340,7 +340,7 @@ def selective_scan_fn(u, out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, position_indices, x) last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) - if z is not None: + if z is None: return out if not return_last_state else (out, last_state) else: out_z = rest[0] From b51fd281f42d007d6262e3e044ea6d26d0b5ab45 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 22 Aug 2024 16:21:35 +0300 Subject: [PATCH 14/45] Causal conv1d tests are passing --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 2 +- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 1 - csrc/ops.h | 14 +++ csrc/torch_bindings.cpp | 21 ++++- tests/kernels/test_causal_conv1d.py | 99 ++++++++++++++++++++++ vllm/_custom_ops.py | 2 +- 6 files changed, 135 insertions(+), 4 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 75a7bad3fa06..6c19741a0c50 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -94,7 +94,7 @@ 
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, const c10::optional& seq_idx_, const c10::optional& seq_pos_idx_, const c10::optional& initial_states_, - c10::optional& final_states_out_, + const c10::optional& final_states_out_, bool silu_activation) { auto input_type = x.scalar_type(); auto weight_type = weight.scalar_type(); diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index cf5d1311ea2f..1080db7b3019 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -410,7 +410,6 @@ void selective_scan_fwd_launch(SSMParamsBase& params, cudaStream_t stream) { constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); - // printf("smem_size = %d\n", kSmemSize); dim3 grid(params.batch, params.dim / kNRows); auto kernel = &selective_scan_fwd_kernel; if (kSmemSize >= 48 * 1024) { diff --git a/csrc/ops.h b/csrc/ops.h index 4d86358b41d3..c57b774a1432 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -185,6 +185,20 @@ std::vector selective_scan_fwd( const c10::optional& index_, const c10::optional& x); +at::Tensor causal_conv1d_update(const at::Tensor& x, + const at::Tensor& conv_state, + const at::Tensor& weight, + const c10::optional& bias_, + bool silu_activation); + +at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, + const c10::optional& bias_, + const c10::optional& seq_idx_, + const c10::optional& seq_pos_idx_, + const c10::optional& initial_states_, + const c10::optional& final_states_out_, + bool silu_activation); + #ifndef USE_ROCM using fptr_t = int64_t; fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 6397c500fc3f..48f9368f5c71 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -244,7 +244,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); - // Mamba selective scan kerenl + // Mamba selective scan kernel ops.def( "selective_scan_fwd(Tensor! u, Tensor! delta," "Tensor! A, Tensor! B, Tensor! C," @@ -252,6 +252,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "bool delta_softplus," "Tensor? index_, Tensor? x) -> Tensor[]"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); + + ops.def( + "causal_conv1d_update(Tensor! x," + "Tensor! conv_state," + "Tensor! weight," + "Tensor? bias_," + "bool silu_activation) -> Tensor"); + ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); + + ops.def( + "causal_conv1d_fwd(Tensor! x, Tensor! weight," + "Tensor? bias_," + "Tensor? seq_idx_," + "Tensor? seq_pos_idx_," + "Tensor? initial_states_," + "Tensor? 
final_states_out_," + "bool silu_activation) -> Tensor"); + ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); + } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 2bf883d1373f..6270c46581b3 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -2,6 +2,10 @@ from einops import rearrange, repeat import torch import torch.nn.functional as F +import pytest + +from vllm.model_executor.layers.mamba.ops.casual_conv1d import causal_conv1d_fn + def causal_conv1d_ref( x: torch.Tensor, @@ -78,3 +82,98 @@ def causal_conv1d_update_ref(x, +@pytest.mark.parametrize("return_final_states", [False, True]) +@pytest.mark.parametrize("has_initial_states", [False, True]) +@pytest.mark.parametrize("channel_last", [False, True]) +@pytest.mark.parametrize("itype", [torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("width", [4]) +@pytest.mark.parametrize( + "seqlen", [128, 512, 4096] +) +@pytest.mark.parametrize('dim', [64, 4096 + 32]) +def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states): + if not channel_last and (has_initial_states or return_final_states): + pytest.skip("Only channel_last support initial_states or return_final_states") + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + # set seed + torch.random.manual_seed(0) + batch = 2 + # batch = 1 + if not channel_last: + x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :] + else: + x = rearrange( + torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" + ) + weight = torch.randn(dim, width, device=device, dtype=torch.float32) + if has_bias: + bias = torch.randn(dim, device=device, dtype=torch.float32) + else: + bias = None + if has_initial_states: + initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=itype).transpose(1, 2) + else: + initial_states = None + x_ref = x.detach().clone() + weight_ref = weight.detach().clone() + bias_ref = bias.detach().clone() if bias is not None else None + initial_states_ref = initial_states.detach().clone() if initial_states is not None else None + activation = None if not silu_activation else "silu" + out, _ = causal_conv1d_fn(x, weight, bias, initial_states=initial_states, return_final_states=return_final_states, + activation=activation) + out_ref, _ = causal_conv1d_ref(x_ref, weight_ref, bias_ref, initial_states=initial_states_ref, return_final_states=return_final_states, activation=activation) + if return_final_states: + out, final_states = out + out_ref, final_states_ref = out_ref + print(f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}") + print(f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}") + assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + if return_final_states: + out += F.sigmoid(final_states).sum(dim=-1, keepdim=True) + 
out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True) + + +@pytest.mark.parametrize("itype", [torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("width", [2, 3, 4]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + # set seed + torch.random.manual_seed(0) + batch = 2 + # batch = 1 + # dim = 64 + x = torch.randn(batch, dim, device=device, dtype=itype) + conv_state = torch.randn(batch, dim, width, device=device, dtype=itype) + weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) + if has_bias: + bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + else: + bias = None + conv_state_ref = conv_state.detach().clone() + activation = None if not silu_activation else "silu" + out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation) + out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + assert torch.equal(conv_state, conv_state_ref) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 288f4c879203..32dd29b649af 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -458,7 +458,7 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, initial_states_: Optional[torch.Tensor], final_states_out_: Optional[torch.Tensor], silu_activation: bool) -> torch.Tensor: - return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, + return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_,None, initial_states_, final_states_out_, silu_activation) From 0cc22529d2f8d6de3b351f74c2ce8a229c1d3596 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 22 Aug 2024 16:22:12 +0300 Subject: [PATCH 15/45] Import --- tests/kernels/test_causal_conv1d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 6270c46581b3..724868c52620 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -4,7 +4,7 @@ import torch.nn.functional as F import pytest -from vllm.model_executor.layers.mamba.ops.casual_conv1d import causal_conv1d_fn +from vllm.model_executor.layers.mamba.ops.casual_conv1d import causal_conv1d_fn, causal_conv1d_update def causal_conv1d_ref( From d65dfb655d2f6e52cbf87a6b36e67c88aaaefa05 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 22 Aug 2024 16:23:20 +0300 Subject: [PATCH 16/45] Tests --- tests/kernels/test_causal_conv1d.py | 85 +++++++++++++++++------ tests/kernels/test_mamba_ssm.py | 100 +++++++++++++++++++--------- vllm/_custom_ops.py | 2 +- 3 files changed, 133 insertions(+), 54 deletions(-) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 724868c52620..79625c2b4e03 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -81,7 +81,6 @@ def causal_conv1d_update_ref(x, return (out if activation is None else F.silu(out)).to(dtype=dtype_in) - 
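The update path tested below keeps a per-channel rolling window of the last `width` inputs instead of re-running the convolution over the whole sequence. A minimal sketch of that semantics, assuming the (batch, dim) input and (batch, dim, width) conv_state shapes used by causal_conv1d_update_ref above (the fused CUDA op is expected to mutate conv_state in place the same way):

    import torch
    import torch.nn.functional as F

    def conv1d_update_sketch(x, conv_state, weight, bias=None, activation=None):
        # Shift the window one step left and append the newest token.
        conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
        conv_state[:, :, -1] = x
        # Depthwise dot product over the width-sized window, in the weight dtype.
        out = torch.einsum("bdw,dw->bd", conv_state.to(weight.dtype), weight)
        if bias is not None:
            out = out + bias
        return (out if activation is None else F.silu(out)).to(x.dtype)
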
@pytest.mark.parametrize("return_final_states", [False, True]) @pytest.mark.parametrize("has_initial_states", [False, True]) @pytest.mark.parametrize("channel_last", [False, True]) @@ -89,13 +88,13 @@ def causal_conv1d_update_ref(x, @pytest.mark.parametrize("silu_activation", [False, True]) @pytest.mark.parametrize("has_bias", [False, True]) @pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize( - "seqlen", [128, 512, 4096] -) +@pytest.mark.parametrize("seqlen", [128, 512, 4096]) @pytest.mark.parametrize('dim', [64, 4096 + 32]) -def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states): +def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, + channel_last, has_initial_states, return_final_states): if not channel_last and (has_initial_states or return_final_states): - pytest.skip("Only channel_last support initial_states or return_final_states") + pytest.skip( + "Only channel_last support initial_states or return_final_states") device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: @@ -106,34 +105,62 @@ def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, cha batch = 2 # batch = 1 if not channel_last: - x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :] + x = torch.randn(batch, + 4096 + dim + 64, + seqlen, + device=device, + dtype=itype)[:, 4096:4096 + dim, :] else: x = rearrange( - torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" - ) + torch.randn(batch, + seqlen, + 4096 + dim + 64, + device=device, + dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s") weight = torch.randn(dim, width, device=device, dtype=torch.float32) if has_bias: bias = torch.randn(dim, device=device, dtype=torch.float32) else: bias = None if has_initial_states: - initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=itype).transpose(1, 2) + initial_states = torch.randn(batch, + width - 1, + dim, + device=device, + dtype=itype).transpose(1, 2) else: initial_states = None x_ref = x.detach().clone() weight_ref = weight.detach().clone() bias_ref = bias.detach().clone() if bias is not None else None - initial_states_ref = initial_states.detach().clone() if initial_states is not None else None + initial_states_ref = initial_states.detach().clone( + ) if initial_states is not None else None activation = None if not silu_activation else "silu" - out, _ = causal_conv1d_fn(x, weight, bias, initial_states=initial_states, return_final_states=return_final_states, - activation=activation) - out_ref, _ = causal_conv1d_ref(x_ref, weight_ref, bias_ref, initial_states=initial_states_ref, return_final_states=return_final_states, activation=activation) + out, _ = causal_conv1d_fn(x, + weight, + bias, + initial_states=initial_states, + return_final_states=return_final_states, + activation=activation) + out_ref, _ = causal_conv1d_ref(x_ref, + weight_ref, + bias_ref, + initial_states=initial_states_ref, + return_final_states=return_final_states, + activation=activation) if return_final_states: out, final_states = out out_ref, final_states_ref = out_ref - print(f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}") - print(f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}") - assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) + print( + 
f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}" + ) + print( + f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}" + ) + assert torch.allclose(final_states, + final_states_ref, + rtol=rtol, + atol=atol) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") @@ -162,18 +189,32 @@ def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype): # dim = 64 x = torch.randn(batch, dim, device=device, dtype=itype) conv_state = torch.randn(batch, dim, width, device=device, dtype=itype) - weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) + weight = torch.randn(dim, + width, + device=device, + dtype=torch.float32, + requires_grad=True) if has_bias: - bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + bias = torch.randn(dim, + device=device, + dtype=torch.float32, + requires_grad=True) else: bias = None conv_state_ref = conv_state.detach().clone() activation = None if not silu_activation else "silu" - out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation) - out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation) + out = causal_conv1d_update(x, + conv_state, + weight, + bias, + activation=activation) + out_ref = causal_conv1d_update_ref(x, + conv_state_ref, + weight, + bias, + activation=activation) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index f76d2d82fda4..fbb670487dd5 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -84,7 +84,6 @@ def selective_state_update_ref(state, return out - def selective_scan_ref(u, delta, A, @@ -181,9 +180,10 @@ def selective_scan_ref(u, @pytest.mark.parametrize("varBC_groups", [1, 2]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) -@pytest.mark.parametrize("scan_chunks", [1,2,3]) -def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, - delta_softplus, return_last_state, seqlen, itype, wtype, scan_chunks): +@pytest.mark.parametrize("scan_chunks", [1, 2, 3]) +def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, + has_z, has_delta_bias, delta_softplus, + return_last_state, seqlen, itype, wtype, scan_chunks): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable device = 'cuda' @@ -204,17 +204,25 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z if not is_variable_B: B_shape = (dim, dstate) elif varBC_groups == 1: - B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + B_shape = (batch_size, dstate, + seqlen if not is_complex else seqlen * 2) else: - B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) - B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype) + B_shape = (batch_size, varBC_groups, dstate, + seqlen if not is_complex else seqlen * 2) + B = torch.randn(*B_shape, + device=device, + dtype=wtype if not is_variable_B else itype) if not is_variable_C: C_shape = (dim, dstate) elif varBC_groups == 
1: - C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) + C_shape = (batch_size, dstate, + seqlen if not is_complex else seqlen * 2) else: - C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) - C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype) + C_shape = (batch_size, varBC_groups, dstate, + seqlen if not is_complex else seqlen * 2) + C = torch.randn(*C_shape, + device=device, + dtype=wtype if not is_variable_C else itype) if has_D: D = torch.randn(dim, device=device, dtype=torch.float32) else: @@ -224,11 +232,13 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z else: z = None if has_delta_bias: - delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)) + delta_bias = (0.5 * + torch.rand(dim, device=device, dtype=torch.float32)) else: delta_bias = None u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) - delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)) + delta = (0.5 * + torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)) state = None state_ref = None out = None @@ -242,29 +252,40 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z chunk_end = seqlen _B = B if is_variable_B: - _B = B[...,chunk_start:chunk_end] + _B = B[..., chunk_start:chunk_end] _C = C if is_variable_B: - _C = C[...,chunk_start:chunk_end] + _C = C[..., chunk_start:chunk_end] _z = z if has_z: assert z is not None - _z = z[...,chunk_start:chunk_end] - out, *rest = selective_scan_fn( - u[...,chunk_start:chunk_end], delta[...,chunk_start:chunk_end], A, _B, _C, D, z=_z, - delta_bias=delta_bias, delta_softplus=delta_softplus, - return_last_state=return_last_state,prev_state=state if c > 0 else None - ) + _z = z[..., chunk_start:chunk_end] + out, *rest = selective_scan_fn(u[..., chunk_start:chunk_end], + delta[..., chunk_start:chunk_end], + A, + _B, + _C, + D, + z=_z, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + return_last_state=return_last_state, + prev_state=state if c > 0 else None) outs.append(out) if return_last_state: state = rest[0] if len(outs) > 1: - out = torch.cat(outs,dim=-1) - out_ref, *rest = selective_scan_ref( - u, delta, A, B, C, D, z=z, - delta_bias=delta_bias, delta_softplus=delta_softplus, - return_last_state=return_last_state - ) + out = torch.cat(outs, dim=-1) + out_ref, *rest = selective_scan_ref(u, + delta, + A, + B, + C, + D, + z=z, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + return_last_state=return_last_state) if return_last_state: state_ref = rest[0] @@ -278,8 +299,8 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) - -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("has_z", [False, True]) @pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) @@ -306,11 +327,28 @@ def test_selective_state_update(dim, dstate, has_z, itype): else: z = None state_ref = state.detach().clone() - out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) - out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) + out = selective_state_update(state, + x, + dt, + A, + B, + C, + D=D, + 
z=z, + dt_bias=dt_bias, + dt_softplus=True) + out_ref = selective_state_update_ref(state_ref, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 32dd29b649af..12230ba50f54 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -458,7 +458,7 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, initial_states_: Optional[torch.Tensor], final_states_out_: Optional[torch.Tensor], silu_activation: bool) -> torch.Tensor: - return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_,None, + return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, None, initial_states_, final_states_out_, silu_activation) From e7b2b3249af4355c126eaf86989ed8e9e6eb13ed Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 22 Aug 2024 18:02:29 +0300 Subject: [PATCH 17/45] Format --- csrc/torch_bindings.cpp | 21 +++++---- tests/kernels/test_causal_conv1d.py | 36 ++++++---------- tests/kernels/test_mamba_ssm.py | 66 ++++++++--------------------- 3 files changed, 41 insertions(+), 82 deletions(-) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 48f9368f5c71..53f53f8f515b 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -255,22 +255,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "causal_conv1d_update(Tensor! x," - "Tensor! conv_state," - "Tensor! weight," - "Tensor? bias_," - "bool silu_activation) -> Tensor"); + "Tensor! conv_state," + "Tensor! weight," + "Tensor? bias_," + "bool silu_activation) -> Tensor"); ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); ops.def( "causal_conv1d_fwd(Tensor! x, Tensor! weight," - "Tensor? bias_," - "Tensor? seq_idx_," - "Tensor? seq_pos_idx_," - "Tensor? initial_states_," - "Tensor? final_states_out_," - "bool silu_activation) -> Tensor"); + "Tensor? bias_," + "Tensor? seq_idx_," + "Tensor? seq_pos_idx_," + "Tensor? initial_states_," + "Tensor? 
final_states_out_," + "bool silu_activation) -> Tensor"); ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); - } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 79625c2b4e03..4373d28f4c2e 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -1,10 +1,12 @@ from typing import Optional -from einops import rearrange, repeat + +import pytest import torch import torch.nn.functional as F -import pytest +from einops import rearrange -from vllm.model_executor.layers.mamba.ops.casual_conv1d import causal_conv1d_fn, causal_conv1d_update +from vllm.model_executor.layers.mamba.ops.casual_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) def causal_conv1d_ref( @@ -14,7 +16,7 @@ def causal_conv1d_ref( initial_states: Optional[torch.Tensor] = None, return_final_states: bool = False, final_states_out=None, - activation: str = "silu", + activation: Optional[str] = "silu", ): """ x: (batch, dim, seqlen) @@ -90,8 +92,10 @@ def causal_conv1d_update_ref(x, @pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize("seqlen", [128, 512, 4096]) @pytest.mark.parametrize('dim', [64, 4096 + 32]) -def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, - channel_last, has_initial_states, return_final_states): +@pytest.mark.parametrize('batch', [1, 2]) +def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, + itype, channel_last, has_initial_states, + return_final_states): if not channel_last and (has_initial_states or return_final_states): pytest.skip( "Only channel_last support initial_states or return_final_states") @@ -99,11 +103,8 @@ def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 - rtolw, atolw = (1e-3, 1e-3) # set seed torch.random.manual_seed(0) - batch = 2 - # batch = 1 if not channel_last: x = torch.randn(batch, 4096 + dim + 64, @@ -151,19 +152,11 @@ def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, if return_final_states: out, final_states = out out_ref, final_states_ref = out_ref - print( - f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}" - ) - print( - f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}" - ) assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) if return_final_states: @@ -176,17 +169,16 @@ def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, @pytest.mark.parametrize("has_bias", [False, True]) @pytest.mark.parametrize("width", [2, 3, 4]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype): +@pytest.mark.parametrize("batch", [1, 2]) +def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, + itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 - rtolw, atolw = (1e-3, 1e-3) # set seed torch.random.manual_seed(0) batch = 2 - # batch = 1 - # dim = 64 x = torch.randn(batch, dim, device=device, dtype=itype) 
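    # weight/bias are deliberately float32 while x and conv_state use `itype`;
    # both the reference helper and the kernel accumulate the depthwise dot
    # product in float32 (see the float weight_vals/out_vals buffers in
    # causal_conv1d.cu), so the bf16/fp16 cases only need the loosened
    # rtol/atol chosen above rather than exact equality.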
conv_state = torch.randn(batch, dim, width, device=device, dtype=itype) weight = torch.randn(dim, @@ -214,7 +206,5 @@ def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype): bias, activation=activation) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index fbb670487dd5..796de355ffc0 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -1,9 +1,10 @@ -from einops import rearrange, repeat +import pytest import torch import torch.nn.functional as F -import pytest +from einops import rearrange, repeat -from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn, selective_state_update +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) def selective_state_update_ref(state, @@ -120,16 +121,8 @@ def selective_scan_ref(u, batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] is_variable_B = B.dim() >= 3 is_variable_C = C.dim() >= 3 - if A.is_complex(): - if is_variable_B: - B = torch.view_as_complex( - rearrange(B.float(), "... (L two) -> ... L two", two=2)) - if is_variable_C: - C = torch.view_as_complex( - rearrange(C.float(), "... (L two) -> ... L two", two=2)) - else: - B = B.float() - C = C.float() + B = B.float() + C = C.float() x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state ys = [] deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) @@ -158,8 +151,6 @@ def selective_scan_ref(u, y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) if i == u.shape[2] - 1: last_state = x - if y.is_complex(): - y = y.real * 2 ys.append(y) y = torch.stack(ys, dim=2) # (batch dim L) out = y if D is None else y + u * rearrange(D, "d -> d 1") @@ -199,43 +190,30 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, batch_size = 2 dim = 4 dstate = 8 - is_complex = wtype == torch.complex64 A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)) if not is_variable_B: - B_shape = (dim, dstate) + B_shape = [dim, dstate] elif varBC_groups == 1: - B_shape = (batch_size, dstate, - seqlen if not is_complex else seqlen * 2) + B_shape = [batch_size, dstate, seqlen] else: - B_shape = (batch_size, varBC_groups, dstate, - seqlen if not is_complex else seqlen * 2) + B_shape = [batch_size, varBC_groups, dstate, seqlen] B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype) if not is_variable_C: - C_shape = (dim, dstate) + C_shape = [dim, dstate] elif varBC_groups == 1: - C_shape = (batch_size, dstate, - seqlen if not is_complex else seqlen * 2) + C_shape = [batch_size, dstate, seqlen] else: - C_shape = (batch_size, varBC_groups, dstate, - seqlen if not is_complex else seqlen * 2) + C_shape = [batch_size, varBC_groups, dstate, seqlen] C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype) - if has_D: - D = torch.randn(dim, device=device, dtype=torch.float32) - else: - D = None - if has_z: - z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) - else: - z = None - if has_delta_bias: - delta_bias = (0.5 * - torch.rand(dim, device=device, dtype=torch.float32)) - else: - delta_bias = None + D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None + z = torch.randn(batch_size, dim, 
seqlen, device=device, + dtype=itype) if has_z else None + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) + ) if has_delta_bias else None u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)) @@ -290,12 +268,9 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, state_ref = rest[0] assert out is not None and out_ref is not None - print(f'Output max diff: {(out - out_ref).abs().max().item()}') - print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) if return_last_state: assert state is not None and state_ref is not None - print(f'State max diff: {(state - state_ref).abs().max().item()}') assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) @@ -322,10 +297,7 @@ def test_selective_state_update(dim, dstate, has_z, itype): B = torch.randn(batch_size, dstate, device=device) C = torch.randn(batch_size, dstate, device=device) D = torch.randn(dim, device=device) - if has_z: - z = torch.randn_like(x) - else: - z = None + z = torch.randn_like(x) if has_z else None state_ref = state.detach().clone() out = selective_state_update(state, x, @@ -348,7 +320,5 @@ def test_selective_state_update(dim, dstate, has_z, itype): dt_bias=dt_bias, dt_softplus=True) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) From 2c9fe00869b63f9ae995fe91557017f4d76b7edd Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 22 Aug 2024 18:06:20 +0300 Subject: [PATCH 18/45] Cleanup --- tests/kernels/test_causal_conv1d.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 4373d28f4c2e..eadf0964ac82 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -15,7 +15,7 @@ def causal_conv1d_ref( bias: Optional[torch.Tensor] = None, initial_states: Optional[torch.Tensor] = None, return_final_states: bool = False, - final_states_out=None, + final_states_out:Optional[torch.Tensor]=None, activation: Optional[str] = "silu", ): """ @@ -54,11 +54,11 @@ def causal_conv1d_ref( return (out, None) if not return_final_states else (out, final_states_out) -def causal_conv1d_update_ref(x, - conv_state, - weight, - bias=None, - activation=None): +def causal_conv1d_update_ref(x:torch.Tensor, + conv_state:torch.Tensor, + weight:torch.Tensor, + bias:Optional[torch.Tensor]=None, + activation:Optional[str]=None): """ x: (batch, dim) conv_state: (batch, dim, width) From c82cc309a34e6cdfec4d4c422e02d6d23229f744 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 22 Aug 2024 18:10:39 +0300 Subject: [PATCH 19/45] Align with main --- vllm/model_executor/models/jamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 93d4cfa10d12..70c0a49bd499 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -921,7 +921,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = param.weight_loader weight_loader(param, loaded_weight, - name, + weight_name, shard_id=shard_id, expert_id=expert_id) break From 6c83e5fbc83914d0c84a6c8b03cb7e10a2e53550 Mon 
Sep 17 00:00:00 2001 From: mzusman Date: Thu, 22 Aug 2024 18:21:57 +0300 Subject: [PATCH 20/45] Format --- tests/kernels/test_causal_conv1d.py | 12 ++++++------ .../model_executor/layers/mamba/ops/casual_conv1d.py | 11 ++++++++--- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index eadf0964ac82..a2ddd4ef61bc 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -15,7 +15,7 @@ def causal_conv1d_ref( bias: Optional[torch.Tensor] = None, initial_states: Optional[torch.Tensor] = None, return_final_states: bool = False, - final_states_out:Optional[torch.Tensor]=None, + final_states_out: Optional[torch.Tensor] = None, activation: Optional[str] = "silu", ): """ @@ -54,11 +54,11 @@ def causal_conv1d_ref( return (out, None) if not return_final_states else (out, final_states_out) -def causal_conv1d_update_ref(x:torch.Tensor, - conv_state:torch.Tensor, - weight:torch.Tensor, - bias:Optional[torch.Tensor]=None, - activation:Optional[str]=None): +def causal_conv1d_update_ref(x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None): """ x: (batch, dim) conv_state: (batch, dim, width) diff --git a/vllm/model_executor/layers/mamba/ops/casual_conv1d.py b/vllm/model_executor/layers/mamba/ops/casual_conv1d.py index f83817fa1704..413c8bc227ae 100644 --- a/vllm/model_executor/layers/mamba/ops/casual_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/casual_conv1d.py @@ -66,7 +66,11 @@ def causal_conv1d_fn( return (out, None) if not return_final_states else (out, final_states_out) -def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): +def causal_conv1d_update(x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None): """ x: (batch, dim) conv_state: (batch, dim, width) @@ -77,5 +81,6 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): """ if activation not in [None, "silu", "swish"]: raise NotImplementedError("activation must be None, silu, or swish") - activation = activation in ["silu", "swish"] - return ops.causal_conv1d_update(x, conv_state, weight, bias, activation) + activation_bool = activation in ["silu", "swish"] + return ops.causal_conv1d_update(x, conv_state, weight, bias, + activation_bool) From b6a00cbf759f4ee01cc3f5e3174d888c8e2c56fe Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 22 Aug 2024 20:18:38 +0300 Subject: [PATCH 21/45] Add init py files --- vllm/model_executor/layers/mamba/__init__.py | 0 vllm/model_executor/layers/mamba/ops/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 vllm/model_executor/layers/mamba/__init__.py create mode 100644 vllm/model_executor/layers/mamba/ops/__init__.py diff --git a/vllm/model_executor/layers/mamba/__init__.py b/vllm/model_executor/layers/mamba/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/model_executor/layers/mamba/ops/__init__.py b/vllm/model_executor/layers/mamba/ops/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From ef69b6c65a5d57f4e843f2c780b24c59ace9be85 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 22 Aug 2024 20:18:53 +0300 Subject: [PATCH 22/45] Move kernels to cuda only --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 
ece756e206e3..839e741dfb1b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -176,8 +176,6 @@ endif() # set(VLLM_EXT_SRC - "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" - "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/cache_kernels.cu" "csrc/attention/attention_kernels.cu" "csrc/pos_encoding_kernels.cu" @@ -205,6 +203,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_MakeAvailable(cutlass) list(APPEND VLLM_EXT_SRC + "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" + "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" From 152f3317bdb21424deac8673ab2237ed04cdd2dc Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 22 Aug 2024 20:19:55 +0300 Subject: [PATCH 23/45] Revert "Move kernels to cuda only" This reverts commit ef69b6c65a5d57f4e843f2c780b24c59ace9be85. --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 839e741dfb1b..ece756e206e3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -176,6 +176,8 @@ endif() # set(VLLM_EXT_SRC + "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" + "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/cache_kernels.cu" "csrc/attention/attention_kernels.cu" "csrc/pos_encoding_kernels.cu" @@ -203,8 +205,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_MakeAvailable(cutlass) list(APPEND VLLM_EXT_SRC - "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" - "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" From 39f0fa010163ce94e0f3e98fbe32e01433421947 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 22 Aug 2024 20:40:07 +0300 Subject: [PATCH 24/45] move kernels to if cuda --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ece756e206e3..839e741dfb1b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -176,8 +176,6 @@ endif() # set(VLLM_EXT_SRC - "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" - "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/cache_kernels.cu" "csrc/attention/attention_kernels.cu" "csrc/pos_encoding_kernels.cu" @@ -205,6 +203,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_MakeAvailable(cutlass) list(APPEND VLLM_EXT_SRC + "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" + "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" From 42f94b7014479331c9f2f70796b14fc40c269c63 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 22 Aug 2024 21:04:10 +0300 Subject: [PATCH 25/45] Fix tests --- tests/kernels/test_causal_conv1d.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index a2ddd4ef61bc..94639671c549 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -137,21 +137,22 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, initial_states_ref = initial_states.detach().clone( ) if initial_states is not None else None activation = None if not silu_activation else "silu" - out, _ = causal_conv1d_fn(x, - weight, - bias, - initial_states=initial_states, - return_final_states=return_final_states, - activation=activation) - out_ref, _ = causal_conv1d_ref(x_ref, 
- weight_ref, - bias_ref, - initial_states=initial_states_ref, - return_final_states=return_final_states, - activation=activation) + out, final_states = causal_conv1d_fn( + x, + weight, + bias, + initial_states=initial_states, + return_final_states=return_final_states, + activation=activation) + out_ref, final_states_ref = causal_conv1d_ref( + x_ref, + weight_ref, + bias_ref, + initial_states=initial_states_ref, + return_final_states=return_final_states, + activation=activation) if return_final_states: - out, final_states = out - out_ref, final_states_ref = out_ref + assert final_states is not None and final_states_ref is not None assert torch.allclose(final_states, final_states_ref, rtol=rtol, From f0507813d1db788b224f8662fc01ae198eb69a42 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 25 Aug 2024 11:32:57 +0300 Subject: [PATCH 26/45] Revert formating --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 1531 +++++++++----------- csrc/mamba/causal_conv1d/causal_conv1d.h | 150 +- csrc/mamba/causal_conv1d/static_switch.h | 24 +- csrc/mamba/mamba_ssm/selective_scan.h | 417 +++--- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 1235 ++++++++-------- csrc/mamba/mamba_ssm/static_switch.h | 1 + 6 files changed, 1534 insertions(+), 1824 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 6c19741a0c50..98ce5f9563f0 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -1,3 +1,4 @@ +// clang-format off #include #include #include @@ -12,919 +13,743 @@ #include "static_switch.h" -#define CHECK_SHAPE(x, ...) \ - TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \ - #x " must have shape (" #__VA_ARGS__ ")") - -#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ - if (ITYPE == at::ScalarType::Half) { \ - using input_t = at::Half; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::BFloat16) { \ - using input_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::Float) { \ - using input_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), \ - "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ - if (WTYPE == at::ScalarType::Half) { \ - using weight_t = at::Half; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::BFloat16) { \ - using weight_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), \ - "'"); \ - } -template -void causal_conv1d_fwd_cuda(ConvParamsBase& params, cudaStream_t stream); -template -void causal_conv1d_channellast_fwd_cuda(ConvParamsBase& params, - cudaStream_t stream); +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) 
\ + if (WTYPE == at::ScalarType::Half) { \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::BFloat16) { \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ + } + +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); template -void causal_conv1d_update_cuda(ConvParamsBase& params, cudaStream_t stream); +void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -void set_conv_params_fwd(ConvParamsBase& params, +void set_conv_params_fwd(ConvParamsBase ¶ms, // sizes - const size_t batch, const size_t dim, - const size_t seqlen, const size_t width, + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t width, // device pointers - const at::Tensor x, const at::Tensor weight, - const at::Tensor out, void* bias_ptr, + const at::Tensor x, + const at::Tensor weight, + const at::Tensor out, + void* bias_ptr, bool silu_activation) { - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - params.batch = batch; - params.dim = dim; - params.seqlen = seqlen; - params.width = width; - - params.silu_activation = silu_activation; - - // Set the pointers and strides. - params.x_ptr = x.data_ptr(); - params.weight_ptr = weight.data_ptr(); - params.bias_ptr = bias_ptr; - params.out_ptr = out.data_ptr(); - // All stride are in elements, not bytes. - params.x_batch_stride = x.stride(0); - params.x_c_stride = x.stride(1); - params.x_l_stride = x.stride(-1); - params.weight_c_stride = weight.stride(0); - params.weight_width_stride = weight.stride(1); - params.out_batch_stride = out.stride(0); - params.out_c_stride = out.stride(1); - params.out_l_stride = out.stride(-1); -} -at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, - const c10::optional& bias_, - const c10::optional& seq_idx_, - const c10::optional& seq_pos_idx_, - const c10::optional& initial_states_, - const c10::optional& final_states_out_, - bool silu_activation) { - auto input_type = x.scalar_type(); - auto weight_type = weight.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || - input_type == at::ScalarType::Half || - input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || - weight_type == at::ScalarType::Half || - weight_type == at::ScalarType::BFloat16); - - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(weight.is_cuda()); - - const auto sizes = x.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int width = weight.size(-1); - - CHECK_SHAPE(x, batch_size, dim, seqlen); - CHECK_SHAPE(weight, dim, width); - - TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); - const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; - - if (is_channel_last) { - TORCH_CHECK( - dim % 8 == 0, - "causal_conv1d only supports channel dimension divisible by 8 for now"); - TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, - "causal_conv1d with channel last layout requires strides " - "(x.stride(0) and x.stride(2)) to be multiples of 8"); - } - TORCH_CHECK(width >= 2 && width <= 4, - "causal_conv1d only supports width between 2 and 4"); - - if (bias_.has_value()) { - auto bias = bias_.value(); - TORCH_CHECK(bias.scalar_type() == 
weight_type); - TORCH_CHECK(bias.is_cuda()); - TORCH_CHECK(bias.stride(-1) == 1); - CHECK_SHAPE(bias, dim); - } - - if (seq_idx_.has_value()) { - TORCH_CHECK(is_channel_last, - "seq_idx is only supported for channel last layout"); - auto seq_idx = seq_idx_.value(); - TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32); - TORCH_CHECK(seq_idx.is_cuda()); - TORCH_CHECK(seq_idx.is_contiguous()); - CHECK_SHAPE(seq_idx, batch_size, seqlen); - } - if (seq_pos_idx_.has_value()) { - auto seq_pos_idx = seq_pos_idx_.value(); - TORCH_CHECK(seq_pos_idx.scalar_type() == torch::kInt32); - TORCH_CHECK(seq_pos_idx.is_cuda()); - TORCH_CHECK(seq_pos_idx.is_contiguous()); - CHECK_SHAPE(seq_pos_idx, batch_size, seqlen); - } - at::Tensor out = torch::empty_like(x); - - ConvParamsBase params; - set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, - bias_.has_value() ? bias_.value().data_ptr() : nullptr, - silu_activation); - - if (seq_idx_.has_value()) { - params.seq_idx_ptr = seq_idx_.value().data_ptr(); - } else { - params.seq_idx_ptr = nullptr; - } - - if (seq_pos_idx_.has_value()) { - params.seq_pos_idx_ptr = seq_pos_idx_.value().data_ptr(); - } else { - params.seq_pos_idx_ptr = nullptr; - } - if (initial_states_.has_value()) { - TORCH_CHECK(is_channel_last, - "initial_states is only supported for channel last layout"); - auto initial_states = initial_states_.value(); - TORCH_CHECK(initial_states.scalar_type() == input_type); - TORCH_CHECK(initial_states.is_cuda()); - CHECK_SHAPE(initial_states, batch_size, dim, width - 1); - TORCH_CHECK(initial_states.stride(1) == 1); - params.initial_states_ptr = initial_states.data_ptr(); - params.initial_states_batch_stride = initial_states.stride(0); - params.initial_states_c_stride = initial_states.stride(1); - params.initial_states_l_stride = initial_states.stride(2); - } else { - params.initial_states_ptr = nullptr; - } - - if (final_states_out_.has_value()) { - TORCH_CHECK(is_channel_last, - "final_states is only supported for channel last layout"); - auto final_states = final_states_out_.value(); - TORCH_CHECK(final_states.scalar_type() == input_type); - TORCH_CHECK(final_states.is_cuda()); - CHECK_SHAPE(final_states, batch_size, dim, width - 1); - TORCH_CHECK(final_states.stride(1) == 1); - params.final_states_ptr = final_states.data_ptr(); - params.final_states_batch_stride = final_states.stride(0); - params.final_states_c_stride = final_states.stride(1); - params.final_states_l_stride = final_states.stride(2); - } else { - params.final_states_ptr = nullptr; - } - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)x.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16( - x.scalar_type(), "causal_conv1d_fwd", [&] { - DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16( - weight.scalar_type(), "causal_conv1d_fwd", [&] { - if (!is_channel_last) { - causal_conv1d_fwd_cuda(params, stream); - } else { - causal_conv1d_channellast_fwd_cuda(params, - stream); - } - }); - }); - return out; + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.width = width; + + params.silu_activation = silu_activation; + + // Set the pointers and strides. 
+ params.x_ptr = x.data_ptr(); + params.weight_ptr = weight.data_ptr(); + params.bias_ptr = bias_ptr; + params.out_ptr = out.data_ptr(); + // All stride are in elements, not bytes. + params.x_batch_stride = x.stride(0); + params.x_c_stride = x.stride(1); + params.x_l_stride = x.stride(-1); + params.weight_c_stride = weight.stride(0); + params.weight_width_stride = weight.stride(1); + params.out_batch_stride = out.stride(0); + params.out_c_stride = out.stride(1); + params.out_l_stride = out.stride(-1); } -at::Tensor causal_conv1d_update(const at::Tensor& x, - const at::Tensor& conv_state, - const at::Tensor& weight, - const c10::optional& bias_, - bool silu_activation) { - auto input_type = x.scalar_type(); - auto weight_type = weight.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || - input_type == at::ScalarType::Half || - input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || - weight_type == at::ScalarType::Half || - weight_type == at::ScalarType::BFloat16); - TORCH_CHECK(conv_state.scalar_type() == input_type); - - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(conv_state.is_cuda()); - TORCH_CHECK(weight.is_cuda()); - - const auto sizes = x.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int width = weight.size(-1); - - CHECK_SHAPE(x, batch_size, dim); - CHECK_SHAPE(conv_state, batch_size, dim, width); - CHECK_SHAPE(weight, dim, width); - - TORCH_CHECK(width >= 2 && width <= 4, - "causal_conv1d only supports width between 2 and 4"); - - if (bias_.has_value()) { - auto bias = bias_.value(); - TORCH_CHECK(bias.scalar_type() == weight_type); - TORCH_CHECK(bias.is_cuda()); - TORCH_CHECK(bias.stride(-1) == 1); - CHECK_SHAPE(bias, dim); - } - - at::Tensor out = torch::empty_like(x); - - ConvParamsBase params; - set_conv_params_fwd( - params, batch_size, dim, /*seqlen=*/1, width, x, weight, out, - bias_.has_value() ? bias_.value().data_ptr() : nullptr, silu_activation); - params.conv_state_ptr = conv_state.data_ptr(); - // All stride are in elements, not bytes. - params.conv_state_batch_stride = conv_state.stride(0); - params.conv_state_c_stride = conv_state.stride(1); - params.conv_state_l_stride = conv_state.stride(2); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)x.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16( - x.scalar_type(), "causal_conv1d_update", [&] { - DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16( - weight.scalar_type(), "causal_conv1d_update", [&] { - causal_conv1d_update_cuda(params, stream); - }); - }); - return out; -} -template -struct Causal_conv1d_fwd_kernel_traits { - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kWidth = kWidth_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 
4 : 8; - static_assert(kWidth <= kNElts); - static constexpr bool kIsVecLoad = kIsVecLoad_; - static constexpr int kNLoadsIndex = kNElts / 4; - using vec_t = typename BytesToType::Type; - using BlockLoadT = cub::BlockLoad; - using BlockLoadVecT = - cub::BlockLoad; - using BlockLoadIndexT = - cub::BlockLoad; - using BlockLoadIndexVecT = cub::BlockLoad; - using BlockStoreT = cub::BlockStore; - using BlockStoreVecT = - cub::BlockStore; - - static constexpr int kSmemIOSize = - (kIsVecLoad && kNLoadsIndex == 1) - ? 0 - : std::max({sizeof(typename BlockLoadT::TempStorage), - sizeof(typename BlockStoreT::TempStorage), - sizeof(typename BlockLoadIndexT::TempStorage), - sizeof(typename BlockLoadIndexVecT::TempStorage)}); - static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; - static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; -}; +at::Tensor +causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, + const c10::optional &bias_, + const c10::optional &seq_idx_, + const c10::optional &seq_pos_idx_, + const c10::optional &initial_states_, + const c10::optional &final_states_out_, + bool silu_activation) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int width = weight.size(-1); + + CHECK_SHAPE(x, batch_size, dim, seqlen); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); + const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; + + if (is_channel_last) { + TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); + TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8"); + } + TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + if (seq_idx_.has_value()) { + TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout"); + auto seq_idx = seq_idx_.value(); + TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32); + TORCH_CHECK(seq_idx.is_cuda()); + TORCH_CHECK(seq_idx.is_contiguous()); + CHECK_SHAPE(seq_idx, batch_size, seqlen); + } + if (seq_pos_idx_.has_value()) { + auto seq_pos_idx = seq_pos_idx_.value(); + TORCH_CHECK(seq_pos_idx.scalar_type() == torch::kInt32); + TORCH_CHECK(seq_pos_idx.is_cuda()); + TORCH_CHECK(seq_pos_idx.is_contiguous()); + CHECK_SHAPE(seq_pos_idx, batch_size, seqlen); + } + at::Tensor out = torch::empty_like(x); + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, + bias_.has_value() ? 
bias_.value().data_ptr() : nullptr, + silu_activation); -template -__global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel( - ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNElts = Ktraits::kNElts; - static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; - using input_t = typename Ktraits::input_t; - using vec_t = typename Ktraits::vec_t; - using weight_t = typename Ktraits::weight_t; - - // Shared memory. - extern __shared__ char smem_[]; - auto& smem_load = - reinterpret_cast(smem_); - auto& smem_load_vec = - reinterpret_cast(smem_); - auto& smem_load_index = - reinterpret_cast(smem_); - auto& smem_load_index_vec = - reinterpret_cast( - smem_); - auto& smem_store = - reinterpret_cast(smem_); - auto& smem_store_vec = - reinterpret_cast(smem_); - vec_t* smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); - - const int tidx = threadIdx.x; - const int batch_id = blockIdx.x; - const int channel_id = blockIdx.y; - input_t* x = reinterpret_cast(params.x_ptr) + - batch_id * params.x_batch_stride + - channel_id * params.x_c_stride; - weight_t* weight = reinterpret_cast(params.weight_ptr) + - channel_id * params.weight_c_stride; - input_t* out = reinterpret_cast(params.out_ptr) + - batch_id * params.out_batch_stride + - channel_id * params.out_c_stride; - float bias_val = - params.bias_ptr == nullptr - ? 0.f - : float(reinterpret_cast(params.bias_ptr)[channel_id]); - - int* seq_pos_idx = !kHasSeqPosIdx - ? nullptr - : reinterpret_cast(params.seq_pos_idx_ptr) + - batch_id * params.seqlen; - - // Thread 0 will load the last elements of the previous chunk, so we - // initialize those to 0. - if (tidx == 0) { - input_t zeros[kNElts] = {0}; - smem_exchange[kNThreads - 1] = reinterpret_cast(zeros)[0]; - } - - float weight_vals[kWidth]; -#pragma unroll - for (int i = 0; i < kWidth; ++i) { - weight_vals[i] = float(weight[i * params.weight_width_stride]); - } - - constexpr int kChunkSize = kNThreads * kNElts; - const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; - for (int chunk = 0; chunk < n_chunks; ++chunk) { - input_t x_vals_load[2 * kNElts] = {0}; - int seq_pos_idx_load[kNElts]; - if constexpr (kIsVecLoad) { - Ktraits::BlockLoadVecT(smem_load_vec) - .Load(reinterpret_cast(x), - *reinterpret_cast(&x_vals_load[kNElts]), - (params.seqlen - chunk * kChunkSize) / kNElts); - if (kHasSeqPosIdx) - Ktraits::BlockLoadIndexVecT(smem_load_index_vec) - .Load(reinterpret_cast(seq_pos_idx), - *reinterpret_cast( - seq_pos_idx_load), - (params.seqlen - chunk * kChunkSize) / kNElts * - Ktraits::kNLoadsIndex); + if (seq_idx_.has_value()) { + params.seq_idx_ptr = seq_idx_.value().data_ptr(); } else { - __syncthreads(); - Ktraits::BlockLoadT(smem_load).Load( - x, *reinterpret_cast(&x_vals_load[kNElts]), - params.seqlen - chunk * kChunkSize); - if (kHasSeqPosIdx) - Ktraits::BlockLoadIndexT(smem_load_index) - .Load(seq_pos_idx, seq_pos_idx_load, - (params.seqlen - chunk * kChunkSize), 0); + params.seq_idx_ptr = nullptr; } - x += kChunkSize; - if (kHasSeqPosIdx) seq_pos_idx += kChunkSize; - __syncthreads(); - // Thread kNThreads - 1 don't write yet, so that thread 0 can read - // the last elements of the previous chunk. 
- if (tidx < kNThreads - 1) { - smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; + + if (seq_pos_idx_.has_value()) { + params.seq_pos_idx_ptr = seq_pos_idx_.value().data_ptr(); + } else { + params.seq_pos_idx_ptr = nullptr; } - __syncthreads(); - reinterpret_cast(x_vals_load)[0] = - smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; - __syncthreads(); - // Now thread kNThreads - 1 can write the last elements of the current - // chunk. - if (tidx == kNThreads - 1) { - smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; + if (initial_states_.has_value()) { + TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout"); + auto initial_states = initial_states_.value(); + TORCH_CHECK(initial_states.scalar_type() == input_type); + TORCH_CHECK(initial_states.is_cuda()); + CHECK_SHAPE(initial_states, batch_size, dim, width - 1); + TORCH_CHECK(initial_states.stride(1) == 1); + params.initial_states_ptr = initial_states.data_ptr(); + params.initial_states_batch_stride = initial_states.stride(0); + params.initial_states_c_stride = initial_states.stride(1); + params.initial_states_l_stride = initial_states.stride(2); + } else { + params.initial_states_ptr = nullptr; } - float x_vals[2 * kNElts]; -#pragma unroll - for (int i = 0; i < 2 * kNElts; ++i) { - x_vals[i] = float(x_vals_load[i]); + if (final_states_out_.has_value()) { + TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout"); + auto final_states = final_states_out_.value(); + TORCH_CHECK(final_states.scalar_type() == input_type); + TORCH_CHECK(final_states.is_cuda()); + CHECK_SHAPE(final_states, batch_size, dim, width - 1); + TORCH_CHECK(final_states.stride(1) == 1); + params.final_states_ptr = final_states.data_ptr(); + params.final_states_batch_stride = final_states.stride(0); + params.final_states_c_stride = final_states.stride(1); + params.final_states_l_stride = final_states.stride(2); + } else { + params.final_states_ptr = nullptr; } - float out_vals[kNElts]; + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { + DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] { + if (!is_channel_last) { + causal_conv1d_fwd_cuda(params, stream); + } else { + causal_conv1d_channellast_fwd_cuda(params, stream); + } + }); + }); + return out; +} -#pragma unroll - for (int i = 0; i < kNElts; ++i) { - out_vals[i] = bias_val; - int w = 0; - if (kHasSeqPosIdx) { - if (seq_pos_idx_load[i] < kWidth) { - w = kWidth - seq_pos_idx_load[i] - 1; - } - } - for (; w < kWidth; ++w) { - out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; - } - } - if (params.silu_activation) { -#pragma unroll - for (int i = 0; i < kNElts; ++i) { - out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); - } +at::Tensor +causal_conv1d_update(const at::Tensor &x, + const at::Tensor &conv_state, + const at::Tensor &weight, + const c10::optional &bias_, + bool silu_activation) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == 
at::ScalarType::BFloat16); + TORCH_CHECK(conv_state.scalar_type() == input_type); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(conv_state.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int width = weight.size(-1); + + CHECK_SHAPE(x, batch_size, dim); + CHECK_SHAPE(conv_state, batch_size, dim, width); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); } - input_t out_vals_store[kNElts]; -#pragma unroll - for (int i = 0; i < kNElts; ++i) { - out_vals_store[i] = out_vals[i]; + at::Tensor out = torch::empty_like(x); + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out, + bias_.has_value() ? bias_.value().data_ptr() : nullptr, + silu_activation); + params.conv_state_ptr = conv_state.data_ptr(); + // All stride are in elements, not bytes. + params.conv_state_batch_stride = conv_state.stride(0); + params.conv_state_c_stride = conv_state.stride(1); + params.conv_state_l_stride = conv_state.stride(2); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { + DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] { + causal_conv1d_update_cuda(params, stream); + }); + }); + return out; +} + +template +struct Causal_conv1d_fwd_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static_assert(kWidth <= kNElts); + static constexpr bool kIsVecLoad = kIsVecLoad_; + static constexpr int kNLoadsIndex = kNElts / 4; + using vec_t = typename BytesToType::Type; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadIndexT = cub::BlockLoad; + using BlockLoadIndexVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + + static constexpr int kSmemIOSize = (kIsVecLoad && kNLoadsIndex == 1) + ? 0 + : std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockLoadIndexT::TempStorage), + sizeof(typename BlockLoadIndexVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_fwd_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. 
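// Illustrative sketch, not part of the patch: the sizing arithmetic of
// Causal_conv1d_fwd_kernel_traits spelled out for the 128-thread configuration
// this file instantiates.  For 2-byte types each thread handles 8 elements per
// load (16 bytes, one uint4); for float it is 4 elements (also 16 bytes).  The
// exchange buffer keeps one such vector per thread, which is what
// kSmemExchangeSize counts.
template <int kNThreads_, int kNBytes_>
struct SizingSketch {
  static constexpr int kNElts = kNBytes_ == 4 ? 4 : 8;          // elements per thread per load
  static constexpr int kVecBytes = kNBytes_ * kNElts;           // always 16 -> one uint4
  static constexpr int kChunkSize = kNThreads_ * kNElts;        // sequence elements per iteration
  static constexpr int kSmemExchange = kNThreads_ * kVecBytes;  // bytes, as in kSmemExchangeSize
};
static_assert(SizingSketch<128, 2>::kChunkSize == 1024, "half/bf16: 1024 elements per chunk");
static_assert(SizingSketch<128, 4>::kChunkSize == 512, "float: 512 elements per chunk");
static_assert(SizingSketch<128, 2>::kSmemExchange == 2048, "2 KiB exchange buffer");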
+ extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_vec = reinterpret_cast(smem_); + auto& smem_load_index = reinterpret_cast(smem_); + auto& smem_load_index_vec = reinterpret_cast(smem_); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_store_vec = reinterpret_cast(smem_); + vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + int *seq_pos_idx = !kHasSeqPosIdx ? nullptr : reinterpret_cast(params.seq_pos_idx_ptr) + batch_id * params.seqlen; + + // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. + if (tidx == 0) { + input_t zeros[kNElts] = {0}; + smem_exchange[kNThreads - 1] = reinterpret_cast(zeros)[0]; } - if constexpr (kIsVecLoad) { - Ktraits::BlockStoreVecT(smem_store_vec) - .Store(reinterpret_cast(out), - reinterpret_cast(out_vals_store), - (params.seqlen - chunk * kChunkSize) / kNElts); - } else { - Ktraits::BlockStoreT(smem_store) - .Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); + + float weight_vals[kWidth]; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + + constexpr int kChunkSize = kNThreads * kNElts; + const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; + for (int chunk = 0; chunk < n_chunks; ++chunk) { + input_t x_vals_load[2 * kNElts] = {0}; + int seq_pos_idx_load[kNElts]; + if constexpr(kIsVecLoad) { + Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); + if (kHasSeqPosIdx) + Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(reinterpret_cast(seq_pos_idx), *reinterpret_cast(seq_pos_idx_load), (params.seqlen - chunk * kChunkSize) / kNElts * Ktraits::kNLoadsIndex); + } else { + __syncthreads(); + Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); + if (kHasSeqPosIdx) + Ktraits::BlockLoadIndexT(smem_load_index).Load(seq_pos_idx, seq_pos_idx_load, (params.seqlen - chunk * kChunkSize), 0); + } + x += kChunkSize; + if (kHasSeqPosIdx) seq_pos_idx += kChunkSize; + __syncthreads(); + // Thread kNThreads - 1 don't write yet, so that thread 0 can read + // the last elements of the previous chunk. + if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + __syncthreads(); + reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; + __syncthreads(); + // Now thread kNThreads - 1 can write the last elements of the current chunk. 
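// Illustrative sketch, not part of the patch: the neighbour exchange used just
// above, as one self-contained device helper.  Thread t publishes its current
// vector and then reads thread t-1's vector; thread NT-1 defers its write so
// slot NT-1 still holds the tail of the previous chunk when thread 0 reads it.
// Assumes blockDim.x == NT; plain float arrays stand in for vec_t.
template <int NT, int VEC>
__device__ void exchange_left_neighbour(const float (&cur)[VEC], float (&prev)[VEC],
                                        float (*smem)[VEC] /* NT rows */) {
  const int t = threadIdx.x;
  if (t < NT - 1) {                        // last thread waits: slot NT-1 still holds old data
    for (int i = 0; i < VEC; ++i) smem[t][i] = cur[i];
  }
  __syncthreads();
  const int src = t > 0 ? t - 1 : NT - 1;  // thread 0 reads the previous chunk's tail
  for (int i = 0; i < VEC; ++i) prev[i] = smem[src][i];
  __syncthreads();
  if (t == NT - 1) {                       // now it is safe to park this chunk's tail
    for (int i = 0; i < VEC; ++i) smem[t][i] = cur[i];
  }
}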
+ if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + + float x_vals[2 * kNElts]; + #pragma unroll + for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } + + float out_vals[kNElts]; + + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = bias_val; + int w = 0; + if (kHasSeqPosIdx){ + if(seq_pos_idx_load[i] < kWidth){ + w = kWidth - seq_pos_idx_load[i] - 1; + } + } + for (; w < kWidth; ++w) { + out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; + } + } + + if (params.silu_activation) { + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); + } + } + + input_t out_vals_store[kNElts]; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } + if constexpr(kIsVecLoad) { + Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); + } else { + Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); + } + out += kChunkSize; } - out += kChunkSize; - } } -template -void causal_conv1d_fwd_launch(ConvParamsBase& params, cudaStream_t stream) { - static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; - BOOL_SWITCH(params.seq_pos_idx_ptr != nullptr, kHasSeqPosIdx, [&] { - BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { - using Ktraits = - Causal_conv1d_fwd_kernel_traits; - constexpr int kSmemSize = Ktraits::kSmemSize; - dim3 grid(params.batch, params.dim); - auto kernel = &causal_conv1d_fwd_kernel; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); +template +void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; + BOOL_SWITCH(params.seq_pos_idx_ptr != nullptr, kHasSeqPosIdx, [&] { + BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { + using Ktraits = Causal_conv1d_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize; + dim3 grid(params.batch, params.dim); + auto kernel = &causal_conv1d_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); - }); } -template -void causal_conv1d_fwd_cuda(ConvParamsBase& params, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); - } +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); + } } -template +template struct Causal_conv1d_channellast_fwd_kernel_traits { - // The cache line is 128 bytes, and we try to read 16 bytes per thread. - // So we have 8 threads per "row", so 32 or 64 elements in the channel - // dimension. 
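// Illustrative sketch, not part of the patch: the concrete numbers behind the
// channel-last tiling comment above, mirroring the constants computed in the
// traits that follow (128 threads per block).
template <int kNThreads_, int kNBytes_>
struct ChannelLastTileSketch {
  static constexpr int kNElts = kNBytes_ == 4 ? 4 : 8;
  static constexpr int kNEltsPerRow = 128 / kNBytes_;             // one 128-byte cache line
  static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts;   // 8 for both element widths
  static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow;      // 4 L positions per warp
  static constexpr int kNColsPerLoad = kNColsPerWarp * (kNThreads_ / 32);
};
static_assert(ChannelLastTileSketch<128, 2>::kNEltsPerRow == 64, "half/bf16: 64 channels per row");
static_assert(ChannelLastTileSketch<128, 4>::kNEltsPerRow == 32, "float: 32 channels per row");
static_assert(ChannelLastTileSketch<128, 2>::kNColsPerLoad == 16, "16 L positions per block load");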
That leaves 4 columns per warp, and so 16 columns per block - // (assuming each block has 128 threads). Each each load is 16 x 32|64 - // elements in the L x C dimensions. - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static_assert(kNThreads % 32 == 0); - static constexpr int kNWarps = kNThreads / 32; - static constexpr int kWidth = kWidth_; - static constexpr int kChunkSizeL = kChunkSizeL_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : 8; - static constexpr int kNEltsPerRow = 128 / kNBytes; - static constexpr int kNThreadsPerRow = - kNEltsPerRow / kNElts; // Always 8 for now - static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); - static constexpr int kNColsPerWarp = - 32 / kNThreadsPerRow; // Always 4 for now - static_assert(kNColsPerWarp * kNThreadsPerRow == 32); - static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; - static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; - static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); - static constexpr bool kIsVecLoad = kIsVecLoad_; - using vec_t = typename BytesToType::Type; - // using BlockLoadT = cub::BlockLoad; using BlockStoreT = - // cub::BlockStore; static constexpr int kSmemSize = - // std::max({sizeof(typename BlockLoadT::TempStorage), - // sizeof(typename - // BlockStoreT::TempStorage)}); - // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; + // The cache line is 128 bytes, and we try to read 16 bytes per thread. + // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension. + // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128 + // threads). Each each load is 16 x 32|64 elements in the L x C dimensions. + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static_assert(kNThreads % 32 == 0); + static constexpr int kNWarps = kNThreads / 32; + static constexpr int kWidth = kWidth_; + static constexpr int kChunkSizeL = kChunkSizeL_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 
4 : 8; + static constexpr int kNEltsPerRow = 128 / kNBytes; + static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now + static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); + static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now + static_assert(kNColsPerWarp * kNThreadsPerRow == 32); + static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; + static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; + static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); + static constexpr bool kIsVecLoad = kIsVecLoad_; + using vec_t = typename BytesToType::Type; + // using BlockLoadT = cub::BlockLoad; + // using BlockStoreT = cub::BlockStore; + // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), + // sizeof(typename BlockStoreT::TempStorage)}); + // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; }; -template -__global__ -__launch_bounds__(Ktraits::kNThreads) void causal_conv1d_channellast_fwd_kernel( - ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNElts = Ktraits::kNElts; - constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; - constexpr int kLPerLoad = Ktraits::kNColsPerLoad; - constexpr int kChunkSizeL = Ktraits::kChunkSizeL; - constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; - using input_t = typename Ktraits::input_t; - using vec_t = typename Ktraits::vec_t; - using weight_t = typename Ktraits::weight_t; - - // Shared memory. - __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts]; - - const int batch_id = blockIdx.x; - const int chunk_l_id = blockIdx.y; - const int chunk_c_id = blockIdx.z; - const int tid = threadIdx.x; - const int l_idx = tid / kNThreadsPerC; - const int c_idx = tid % kNThreadsPerC; - input_t* x = reinterpret_cast(params.x_ptr) + - batch_id * params.x_batch_stride + - (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + - chunk_c_id * kChunkSizeC + c_idx * kNElts; - weight_t* weight = reinterpret_cast(params.weight_ptr) + - chunk_c_id * kChunkSizeC * params.weight_c_stride; - input_t* out = reinterpret_cast(params.out_ptr) + - batch_id * params.out_batch_stride + - (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + - chunk_c_id * kChunkSizeC + c_idx * kNElts; - int* seq_idx = !kHasSeqIdx - ? nullptr - : reinterpret_cast(params.seq_idx_ptr) + - batch_id * params.seqlen + chunk_l_id * kChunkSizeL; - input_t* initial_states = - params.initial_states_ptr == nullptr || chunk_l_id > 0 - ? nullptr - : reinterpret_cast(params.initial_states_ptr) + - batch_id * params.initial_states_batch_stride + - l_idx * params.initial_states_l_stride + - chunk_c_id * kChunkSizeC + c_idx * kNElts; - // The last L-chunk will also have enough info to write to final states, since - // it also contain a few x values from the previous L-chunk. - input_t* final_states = - params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 - ? 
nullptr - : reinterpret_cast(params.final_states_ptr) + - batch_id * params.final_states_batch_stride + - l_idx * params.final_states_l_stride + - chunk_c_id * kChunkSizeC + c_idx * kNElts; - -#pragma unroll - for (int l = 0; l < Ktraits::kNLoads; ++l) { - input_t x_vals_load[kNElts] = {0}; - if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen && - chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(x_vals_load)[0] = - *reinterpret_cast(x + l * kLPerLoad * params.x_l_stride); +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; + constexpr int kLPerLoad = Ktraits::kNColsPerLoad; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. + __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts]; + + const int batch_id = blockIdx.x; + const int chunk_l_id = blockIdx.y; + const int chunk_c_id = blockIdx.z; + const int tid = threadIdx.x; + const int l_idx = tid / kNThreadsPerC; + const int c_idx = tid % kNThreadsPerC; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + weight_t *weight = reinterpret_cast(params.weight_ptr) + + chunk_c_id * kChunkSizeC * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast(params.seq_idx_ptr) + + batch_id * params.seqlen + chunk_l_id * kChunkSizeL; + input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr + : reinterpret_cast(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + // The last L-chunk will also have enough info to write to final states, since it also contain a few x values + // from the previous L-chunk. + input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr + : reinterpret_cast(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + + #pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t x_vals_load[kNElts] = {0}; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + l * kLPerLoad * params.x_l_stride); + } + reinterpret_cast(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; } - reinterpret_cast( - x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = - reinterpret_cast(x_vals_load)[0]; - } - // Load the elements from the previous chunk that are needed for convolution. 
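// Illustrative sketch, not part of the patch: the initial_states / final_states
// plumbing above is a (width-1)-element carry between L-chunks.  The first
// chunk seeds its halo from initial_states (or zeros), later chunks read it
// from the tail of the previous chunk, and the last chunk writes its own tail
// to final_states.  Single-channel CPU version, assuming chunk_len >= width-1;
// the helper name is made up for this sketch.
#include <cmath>

static void causal_conv1d_chunk_ref(const float* x, int chunk_len, const float* weight,
                                    int width, float bias, bool silu,
                                    float* state /* width-1 elems, oldest first, in/out */,
                                    float* out) {
  for (int l = 0; l < chunk_len; ++l) {
    float acc = bias;
    for (int w = 0; w < width; ++w) {
      int idx = l - (width - 1 - w);
      float v = idx >= 0 ? x[idx] : state[idx + width - 1];  // halo from the previous chunk
      acc += weight[w] * v;
    }
    out[l] = silu ? acc / (1.f + std::exp(-acc)) : acc;
  }
  // New carry: the last width-1 inputs of this chunk.
  for (int i = 0; i < width - 1; ++i) state[i] = x[chunk_len - (width - 1) + i];
}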
- if (l_idx < kWidth - 1) { - input_t x_vals_load[kNElts] = {0}; - if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 && - chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen && - chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(x_vals_load)[0] = - *reinterpret_cast(x - (kWidth - 1) * params.x_l_stride); - } else if (initial_states != nullptr && - chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0 && - chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(x_vals_load)[0] = - *reinterpret_cast(initial_states); + // Load the elements from the previous chunk that are needed for convolution. + if (l_idx < kWidth - 1) { + input_t x_vals_load[kNElts] = {0}; + if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 + && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x - (kWidth - 1) * params.x_l_stride); + } else if (initial_states != nullptr + && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0 + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(initial_states); + } + reinterpret_cast(x_smem[l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; } - reinterpret_cast(x_smem[l_idx])[c_idx] = - reinterpret_cast(x_vals_load)[0]; - } - - __syncthreads(); - - if (final_states != nullptr && l_idx < kWidth - 1 && - chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - - // 1) So last few elements (index params.seqlen - kWidth + 1 + l_idx) are - // stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * - // kChunkSizeL - kWidth + 1)][c_idx] - *reinterpret_cast(final_states) = reinterpret_cast( - x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx]; - } - - constexpr int kLPerThread = - std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); - static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); - constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; - static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); - // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for - // simplicity - static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); - static_assert((kLPerThread & (kLPerThread - 1)) == 0); - static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); - static_assert(kNThreadsPerRow <= 32); - - const int row_idx = tid / kNThreadsPerRow; - const int col_idx = tid % kNThreadsPerRow; - - float bias_val = - params.bias_ptr == nullptr || - chunk_c_id * kChunkSizeC + row_idx >= params.dim - ? 
0.f - : float(reinterpret_cast( - params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); - float weight_vals[kWidth] = {0}; - if (chunk_c_id * kChunkSizeC + row_idx < params.dim) { -#pragma unroll - for (int w = 0; w < kWidth; ++w) { - weight_vals[w] = weight[row_idx * params.weight_c_stride + - w * params.weight_width_stride]; + + __syncthreads(); + + if (final_states != nullptr + && l_idx < kWidth - 1 + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1) + // So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx] + *reinterpret_cast(final_states) = reinterpret_cast(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx]; + } + + constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); + static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); + constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; + static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); + // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity + static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); + static_assert((kLPerThread & (kLPerThread - 1)) == 0); + static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); + static_assert(kNThreadsPerRow <= 32); + + const int row_idx = tid / kNThreadsPerRow; + const int col_idx = tid % kNThreadsPerRow; + + float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); + float weight_vals[kWidth] = {0}; + if (chunk_c_id * kChunkSizeC + row_idx < params.dim) { + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; + } } - } - float x_vals[kWidth - 1 + kLPerThread]; -#pragma unroll - for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { - x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); - } - int seq_idx_thread[kWidth - 1 + kLPerThread]; - if constexpr (kHasSeqIdx) { -#pragma unroll + float x_vals[kWidth - 1 + kLPerThread]; + #pragma unroll for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { - seq_idx_thread[i] = - chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= - 0 - ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] - : -1; + x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); } - } - - float out_vals[kLPerThread]; -#pragma unroll - for (int i = 0; i < kLPerThread; ++i) { - out_vals[i] = bias_val; - const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; -#pragma unroll - for (int w = 0; w < kWidth; ++w) { - if constexpr (!kHasSeqIdx) { - out_vals[i] += weight_vals[w] * x_vals[i + w]; - } else { - out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur - ? weight_vals[w] * x_vals[i + w] - : 0.f; - } + int seq_idx_thread[kWidth - 1 + kLPerThread]; + if constexpr (kHasSeqIdx) { + #pragma unroll + for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { + seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? 
seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1; + } } - if (params.silu_activation) { - out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); + + float out_vals[kLPerThread]; + #pragma unroll + for (int i = 0; i < kLPerThread; ++i) { + out_vals[i] = bias_val; + const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + if constexpr (!kHasSeqIdx) { + out_vals[i] += weight_vals[w] * x_vals[i + w]; + } else { + out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f; + } + } + if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); } } - } - - __syncthreads(); -#pragma unroll - for (int i = 0; i < kLPerThread; ++i) { - x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; - } - __syncthreads(); - -#pragma unroll - for (int l = 0; l < Ktraits::kNLoads; ++l) { - input_t out_vals_store[kNElts]; - reinterpret_cast(out_vals_store)[0] = - reinterpret_cast(x_smem[l * kLPerLoad + l_idx])[c_idx]; - if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen && - chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - *reinterpret_cast(out + l * kLPerLoad * params.out_l_stride) = - reinterpret_cast(out_vals_store)[0]; + + __syncthreads(); + #pragma unroll + for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; } + __syncthreads(); + + #pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t out_vals_store[kNElts]; + reinterpret_cast(out_vals_store)[0] = reinterpret_cast(x_smem[l * kLPerLoad + l_idx])[c_idx]; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + *reinterpret_cast(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast(out_vals_store)[0]; + } } - } + } -template -void causal_conv1d_channellast_fwd_launch(ConvParamsBase& params, - cudaStream_t stream) { - BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] { - using Ktraits = - Causal_conv1d_channellast_fwd_kernel_traits; - // constexpr int kSmemSize = Ktraits::kSmemSize; - constexpr int kChunkSizeL = Ktraits::kChunkSizeL; - constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; - const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; - const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; - dim3 grid(params.batch, n_chunks_L, n_chunks_C); - dim3 block(Ktraits::kNThreads); - auto kernel = &causal_conv1d_channellast_fwd_kernel; - // if (kSmemSize >= 48 * 1024) { - // C10_CUDA_CHECK(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - // } - // kernel<<>>(params); - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); +template +void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] { + using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; + const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; + dim3 grid(params.batch, n_chunks_L, n_chunks_C); + dim3 block(Ktraits::kNThreads); + auto kernel = &causal_conv1d_channellast_fwd_kernel; + // if (kSmemSize >= 48 * 1024) { + // C10_CUDA_CHECK(cudaFuncSetAttribute( + // kernel, 
cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + // } + // kernel<<>>(params); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); } -template -void causal_conv1d_channellast_fwd_cuda(ConvParamsBase& params, - cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, - stream); - } else if (params.width == 3) { - causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, - stream); - } else if (params.width == 4) { - causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, - stream); - } +template +void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream); + } } -template void causal_conv1d_fwd_cuda(ConvParamsBase& params, - cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase& params, - cudaStream_t stream); -template void causal_conv1d_fwd_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase& params, - cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase& params, - cudaStream_t stream); -template void causal_conv1d_fwd_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_fwd_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_fwd_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_fwd_cuda( - ConvParamsBase& params, cudaStream_t stream); - -template void causal_conv1d_channellast_fwd_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda( - ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void 
causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); /////// -template + + + +template struct Causal_conv1d_update_kernel_traits { - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kWidth = kWidth_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); }; -template -__global__ -__launch_bounds__(Ktraits::kNThreads) void causal_conv1d_update_kernel( - ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - - const int tidx = threadIdx.x; - const int batch_id = blockIdx.x; - const int channel_id = blockIdx.y * kNThreads + tidx; - input_t* x = reinterpret_cast(params.x_ptr) + - batch_id * params.x_batch_stride + - channel_id * params.x_c_stride; - input_t* conv_state = reinterpret_cast(params.conv_state_ptr) + - batch_id * params.conv_state_batch_stride + - channel_id * params.conv_state_c_stride; - weight_t* weight = reinterpret_cast(params.weight_ptr) + - channel_id * params.weight_c_stride; - input_t* out = reinterpret_cast(params.out_ptr) + - batch_id * params.out_batch_stride + - channel_id * params.out_c_stride; - float bias_val = - params.bias_ptr == nullptr || channel_id >= params.dim - ? 0.f - : float(reinterpret_cast(params.bias_ptr)[channel_id]); - - float weight_vals[kWidth] = {0}; - if (channel_id < params.dim) { -#pragma unroll - for (int i = 0; i < kWidth; ++i) { - weight_vals[i] = float(weight[i * params.weight_width_stride]); +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_update_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y * kNThreads + tidx; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride + + channel_id * params.conv_state_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 
0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + float weight_vals[kWidth] = {0}; + if (channel_id < params.dim) { + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } } - } - float x_vals[kWidth] = {0}; - if (channel_id < params.dim) { -#pragma unroll - for (int i = 0; i < kWidth - 1; ++i) { - x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); - } - x_vals[kWidth - 1] = float(x[0]); -#pragma unroll - for (int i = 0; i < kWidth; ++i) { - conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); + float x_vals[kWidth] = {0}; + if (channel_id < params.dim) { + #pragma unroll + for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } + x_vals[kWidth - 1] = float(x[0]); + #pragma unroll + for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); } } - } - - float out_val = bias_val; -#pragma unroll - for (int i = 0; i < kWidth; ++i) { - out_val += weight_vals[i] * x_vals[i]; - } - if (params.silu_activation) { - out_val = out_val / (1 + expf(-out_val)); - } - if (channel_id < params.dim) { - out[0] = input_t(out_val); - } + + float out_val = bias_val; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; } + if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } + if (channel_id < params.dim) { out[0] = input_t(out_val); } } -template -void causal_conv1d_update_launch(ConvParamsBase& params, cudaStream_t stream) { - using Ktraits = - Causal_conv1d_update_kernel_traits; - dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); - auto kernel = &causal_conv1d_update_kernel; - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); +template +void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + using Ktraits = Causal_conv1d_update_kernel_traits; + dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); + auto kernel = &causal_conv1d_update_kernel; + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } -template -void causal_conv1d_update_cuda(ConvParamsBase& params, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); - } +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); + } } -template void causal_conv1d_update_cuda(ConvParamsBase& params, - cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase& params, - cudaStream_t stream); -template void causal_conv1d_update_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase& params, - cudaStream_t stream); -template void causal_conv1d_update_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_update_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_update_cuda( - ConvParamsBase& params, 
cudaStream_t stream); -template void causal_conv1d_update_cuda( - ConvParamsBase& params, cudaStream_t stream); -template void causal_conv1d_update_cuda( - ConvParamsBase& params, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h index 76a634ba70ec..7ff9ba8594a1 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.h +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -1,7 +1,8 @@ /****************************************************************************** * Copyright (c) 2024, Tri Dao. ******************************************************************************/ - +// clang-format off +// adapted from #pragma once #include @@ -9,103 +10,100 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// struct ConvParamsBase { - using index_t = uint32_t; - - int batch, dim, seqlen, width; - bool silu_activation; - - index_t x_batch_stride; - index_t x_c_stride; - index_t x_l_stride; - index_t weight_c_stride; - index_t weight_width_stride; - index_t out_batch_stride; - index_t out_c_stride; - index_t out_l_stride; - - index_t conv_state_batch_stride; - index_t conv_state_c_stride; - index_t conv_state_l_stride; - - // Common data pointers. - void* __restrict__ x_ptr; - void* __restrict__ weight_ptr; - void* __restrict__ bias_ptr; - void* __restrict__ out_ptr; - - void* __restrict__ conv_state_ptr; - - void* __restrict__ seq_idx_ptr; - void* __restrict__ seq_pos_idx_ptr; - - // No __restrict__ since initial_states could be the same as final_states. - void* initial_states_ptr; - index_t initial_states_batch_stride; - index_t initial_states_l_stride; - index_t initial_states_c_stride; - - void* final_states_ptr; - index_t final_states_batch_stride; - index_t final_states_l_stride; - index_t final_states_c_stride; + using index_t = uint32_t; + + int batch, dim, seqlen, width; + bool silu_activation; + + index_t x_batch_stride; + index_t x_c_stride; + index_t x_l_stride; + index_t weight_c_stride; + index_t weight_width_stride; + index_t out_batch_stride; + index_t out_c_stride; + index_t out_l_stride; + + index_t conv_state_batch_stride; + index_t conv_state_c_stride; + index_t conv_state_l_stride; + + // Common data pointers. + void *__restrict__ x_ptr; + void *__restrict__ weight_ptr; + void *__restrict__ bias_ptr; + void *__restrict__ out_ptr; + + void *__restrict__ conv_state_ptr; + + void *__restrict__ seq_idx_ptr; + void *__restrict__ seq_pos_idx_ptr; + + // No __restrict__ since initial_states could be the same as final_states. 
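// Illustrative sketch, not part of the patch: why these two pointers are not
// __restrict__-qualified.  A caller may pass the same buffer as both
// initial_states and final_states and advance the carried state in place;
// __restrict__ would promise the compiler the pointers never alias.  A sketch
// of such an in-place update (single channel, made-up helper name):
static void advance_conv_state(float* state /* width-1 elems, oldest first, in/out */,
                               const float* x_chunk, int chunk_len, int width) {
  for (int i = 0; i < width - 1; ++i) {
    int idx = chunk_len - (width - 1) + i;         // last width-1 inputs of the chunk
    state[i] = idx >= 0 ? x_chunk[idx]
                        : state[idx + width - 1];  // short chunk: shift part of the old state
  }                                                // (old-state reads stay ahead of writes, so in-place is safe)
}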
+ void * initial_states_ptr; + index_t initial_states_batch_stride; + index_t initial_states_l_stride; + index_t initial_states_c_stride; + + void * final_states_ptr; + index_t final_states_batch_stride; + index_t final_states_l_stride; + index_t final_states_c_stride; }; -template -struct BytesToType {}; -template <> -struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); }; -template <> -struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); }; -template <> -struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); }; -template <> -struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); }; -template <> -struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct SumOp { - __device__ inline T operator()(T const& x, T const& y) { return x + y; } +__device__ inline T operator()(T const & x, T const & y) { return x + y; } }; -template +template struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); - template - static __device__ inline T run(T x, Operator& op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); - return Allreduce::run(x, op); - } + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } }; -template <> +template<> struct Allreduce<2> { - template - static __device__ inline T run(T x, Operator& op) { +template +static __device__ inline T run(T x, Operator &op) { x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); return x; - } +} }; diff --git a/csrc/mamba/causal_conv1d/static_switch.h b/csrc/mamba/causal_conv1d/static_switch.h index 11c876842395..ce002d11b55f 100644 --- a/csrc/mamba/causal_conv1d/static_switch.h +++ b/csrc/mamba/causal_conv1d/static_switch.h @@ -1,6 +1,6 @@ -// Inspired by -// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h +// clang-format off #pragma once @@ -14,13 +14,13 @@ /// some_function(...); /// }); /// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - static constexpr bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - static constexpr bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() +#define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ + [&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index dbc4e6dac112..cf75b86b9630 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -1,296 +1,275 @@ /****************************************************************************** * Copyright (c) 2023, Tri Dao. ******************************************************************************/ +// clang-format off #pragma once #ifndef USE_ROCM - #include + #include #else - #include + #include #endif #include //////////////////////////////////////////////////////////////////////////////////////////////////// struct SSMParamsBase { - using index_t = uint32_t; - - int batch, dim, seqlen, dstate, n_groups, n_chunks; - int dim_ngroups_ratio; - bool is_variable_B; - bool is_variable_C; - - bool delta_softplus; - - index_t A_d_stride; - index_t A_dstate_stride; - index_t B_batch_stride; - index_t B_d_stride; - index_t B_dstate_stride; - index_t B_group_stride; - index_t C_batch_stride; - index_t C_d_stride; - index_t C_dstate_stride; - index_t C_group_stride; - index_t u_batch_stride; - index_t u_d_stride; - index_t delta_batch_stride; - index_t delta_d_stride; - index_t z_batch_stride; - index_t z_d_stride; - index_t out_batch_stride; - index_t out_d_stride; - index_t out_z_batch_stride; - index_t out_z_d_stride; - - // Common data pointers. - void* __restrict__ A_ptr; - void* __restrict__ B_ptr; - void* __restrict__ C_ptr; - void* __restrict__ D_ptr; - void* __restrict__ u_ptr; - void* __restrict__ delta_ptr; - void* __restrict__ delta_bias_ptr; - void* __restrict__ out_ptr; - void* __restrict__ x_ptr; - void* __restrict__ z_ptr; - void* __restrict__ out_z_ptr; - void* __restrict__ index_ptr; + using index_t = uint32_t; + + int batch, dim, seqlen, dstate, n_groups, n_chunks; + int dim_ngroups_ratio; + bool is_variable_B; + bool is_variable_C; + + bool delta_softplus; + + index_t A_d_stride; + index_t A_dstate_stride; + index_t B_batch_stride; + index_t B_d_stride; + index_t B_dstate_stride; + index_t B_group_stride; + index_t C_batch_stride; + index_t C_d_stride; + index_t C_dstate_stride; + index_t C_group_stride; + index_t u_batch_stride; + index_t u_d_stride; + index_t delta_batch_stride; + index_t delta_d_stride; + index_t z_batch_stride; + index_t z_d_stride; + index_t out_batch_stride; + index_t out_d_stride; + index_t out_z_batch_stride; + index_t out_z_d_stride; + + // Common data pointers. 
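// Illustrative sketch, not part of the patch: what these buffers feed, written
// for one batch element and one channel with state size dstate (variable B and
// C of shape [seqlen][dstate]; z-gating omitted).  A simplified CPU reference
// of the selective-scan recurrence, not the kernel's exact numerics.
#include <cmath>

static void selective_scan_ref(const float* u, const float* delta, const float* A,
                               const float* B, const float* C, float D, float delta_bias,
                               bool delta_softplus, int seqlen, int dstate, float* out) {
  float h[256] = {0.f};                            // dstate <= MAX_DSTATE (256), defined later in this header
  for (int t = 0; t < seqlen; ++t) {
    float dt = delta[t] + delta_bias;
    if (delta_softplus) dt = std::log1p(std::exp(dt));
    float y = D * u[t];
    for (int s = 0; s < dstate; ++s) {
      h[s] = std::exp(dt * A[s]) * h[s] + dt * B[t * dstate + s] * u[t];  // h_t = a*h_{t-1} + b
      y += C[t * dstate + s] * h[s];
    }
    out[t] = y;
  }
}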
+ void *__restrict__ A_ptr; + void *__restrict__ B_ptr; + void *__restrict__ C_ptr; + void *__restrict__ D_ptr; + void *__restrict__ u_ptr; + void *__restrict__ delta_ptr; + void *__restrict__ delta_bias_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; + void *__restrict__ z_ptr; + void *__restrict__ out_z_ptr; + void *__restrict__ index_ptr; }; + + + #ifndef USE_ROCM -constexpr size_t custom_max(std::initializer_list ilist) { - return std::max(ilist); -} + constexpr size_t custom_max(std::initializer_list ilist) + { + return std::max(ilist); + } -template -constexpr T constexpr_min(T a, T b) { - return std::min(a, b); -} + template + constexpr T constexpr_min(T a, T b) { + return std::min(a, b); + } #else -constexpr size_t custom_max(std::initializer_list ilist) { - return *std::max_element(ilist.begin(), ilist.end()); -} + constexpr size_t custom_max(std::initializer_list ilist) + { + return *std::max_element(ilist.begin(), ilist.end()); + } -template -constexpr T constexpr_min(T a, T b) { - return a < b ? a : b; -} + template + constexpr T constexpr_min(T a, T b) { + return a < b ? a : b; + } #endif + #define MAX_DSTATE 256 -inline __device__ float2 operator+(const float2& a, const float2& b) { - return {a.x + b.x, a.y + b.y}; + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; } -inline __device__ float3 operator+(const float3& a, const float3& b) { +inline __device__ float3 operator+(const float3 &a, const float3 &b) { return {a.x + b.x, a.y + b.y, a.z + b.z}; } -inline __device__ float4 operator+(const float4& a, const float4& b) { - return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; +inline __device__ float4 operator+(const float4 & a, const float4 & b){ + return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; } //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct BytesToType {}; +template struct BytesToType {}; -template <> -struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); }; -template <> -struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); }; -template <> -struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); }; -template <> -struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); }; -template <> -struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Converter { - static inline __device__ void to_float(const scalar_t (&src)[N], - float (&dst)[N]) { -#pragma unroll - for (int i = 0; i < N; ++i) { - dst[i] = src[i]; +template +struct Converter{ + static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { + #pragma unroll + for (int i = 0; i < N; ++i) { dst[i] = src[i]; } } - } }; -template -struct Converter { - static inline __device__ void to_float(const at::Half (&src)[N], - 
float (&dst)[N]) { - static_assert(N % 2 == 0); - auto& src2 = reinterpret_cast(src); - auto& dst2 = reinterpret_cast(dst); -#pragma unroll - for (int i = 0; i < N / 2; ++i) { - dst2[i] = __half22float2(src2[i]); +template +struct Converter{ + static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } } - } }; #if __CUDA_ARCH__ >= 800 -template -struct Converter { - static inline __device__ void to_float(const at::BFloat16 (&src)[N], - float (&dst)[N]) { - static_assert(N % 2 == 0); - auto& src2 = reinterpret_cast(src); - auto& dst2 = reinterpret_cast(dst); - #pragma unroll - for (int i = 0; i < N / 2; ++i) { - dst2[i] = __bfloat1622float2(src2[i]); +template +struct Converter{ + static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } } - } }; #endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct SSMScanOp; -template <> +template struct SSMScanOp; + +template<> struct SSMScanOp { - __device__ __forceinline__ float2 operator()(const float2& ab0, - const float2& ab1) const { - return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); - } + __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { + return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); + } }; // A stateful callback functor that maintains a running prefix to be applied // during consecutive scan operations. -template -struct SSMScanPrefixCallbackOp { - using scan_t = - std::conditional_t, float2, float4>; - scan_t running_prefix; - // Constructor - __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) - : running_prefix(running_prefix_) {} - // Callback operator to be entered by the first warp of threads in the block. - // Thread-0 is responsible for returning a value for seeding the block-wide - // scan. - __device__ scan_t operator()(scan_t block_aggregate) { - scan_t old_prefix = running_prefix; - running_prefix = SSMScanOp()(running_prefix, block_aggregate); - return old_prefix; - } +template struct SSMScanPrefixCallbackOp { + using scan_t = std::conditional_t, float2, float4>; + scan_t running_prefix; + // Constructor + __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. 
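// Illustrative sketch, not part of the patch: the algebra behind SSMScanOp and
// this prefix callback.  h_t = a_t * h_{t-1} + b_t is the affine map
// h -> a*h + b; composing two such maps gives the associative combine
// (a1*a0, a1*b0 + b1), so cub::BlockScan can evaluate the recurrence in
// parallel, and the running prefix carries the composed map from one chunk to
// the next.  Host-side version of the same idea:
#include <utility>
#include <vector>

using AB = std::pair<float, float>;  // (a, b), i.e. the map h -> a*h + b
static AB combine(const AB& ab0, const AB& ab1) {
  return {ab1.first * ab0.first, ab1.first * ab0.second + ab1.second};
}

// Inclusive scan over one chunk, seeded with a running prefix: the role this
// callback plays between consecutive BlockScan passes.
static AB scan_chunk(std::vector<AB>& xs, AB running_prefix) {
  for (auto& x : xs) {
    x = combine(running_prefix, x);  // earlier map applied first, as in SSMScanOp
    running_prefix = x;
  }
  return running_prefix;             // becomes the prefix for the next chunk
}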
+ __device__ scan_t operator()(scan_t block_aggregate) { + scan_t old_prefix = running_prefix; + running_prefix = SSMScanOp()(running_prefix, block_aggregate); + return old_prefix; + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void load_input( - typename Ktraits::input_t* u, - typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], - typename Ktraits::BlockLoadT::TempStorage& smem_load, int seqlen) { - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_vec = - reinterpret_cast( - smem_load); - using vec_t = typename Ktraits::vec_t; - typename Ktraits::BlockLoadVecT(smem_load_vec) - .Load(reinterpret_cast(u), - reinterpret_cast(u_vals) -#ifdef USE_ROCM - , - Ktraits::kNThreads * Ktraits::kNLoads -#endif - - ); - } else { - typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); - } +template +inline __device__ void load_input(typename Ktraits::input_t *u, + typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadT::TempStorage &smem_load, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_vec = reinterpret_cast(smem_load); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadVecT(smem_load_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + #ifdef USE_ROCM + , Ktraits::kNThreads * Ktraits::kNLoads + #endif + + ); + } else { + typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); + } } -template -inline __device__ void load_index( - int* u, int (&u_vals)[Ktraits::kNItems], - typename Ktraits::BlockLoadIndexT::TempStorage& smem_load_index, - int seqlen) { - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_index_vec = - reinterpret_cast( - smem_load_index); - Ktraits::BlockLoadIndexVecT(smem_load_index_vec) - .Load(reinterpret_cast(u), - reinterpret_cast(u_vals)); - } else { - Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0); - } +template +inline __device__ void load_index(int *u, + int (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_index_vec = reinterpret_cast(smem_load_index); + Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + ); + } else { + Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0); + } } -template -inline __device__ void load_weight( - typename Ktraits::input_t* Bvar, - typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], - typename Ktraits::BlockLoadWeightT::TempStorage& smem_load_weight, - int seqlen) { - constexpr int kNItems = Ktraits::kNItems; - typename Ktraits::input_t B_vals_load[kNItems]; - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_weight_vec = - reinterpret_cast( - smem_load_weight); - using vec_t = typename Ktraits::vec_t; - typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec) - .Load(reinterpret_cast(Bvar), - reinterpret_cast(B_vals_load)); - } else { - typename Ktraits::BlockLoadWeightT(smem_load_weight) - .Load(Bvar, B_vals_load, seqlen, 0.f); - } - // #pragma unroll - // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } - Converter::to_float(B_vals_load, B_vals); +template +inline __device__ void load_weight(typename Ktraits::input_t *Bvar, + typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, + int seqlen) { + constexpr int kNItems = 
Ktraits::kNItems; + typename Ktraits::input_t B_vals_load[kNItems]; + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + reinterpret_cast(B_vals_load) + ); + } else { + typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + // #pragma unroll + // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } + Converter::to_float(B_vals_load, B_vals); } -template -inline __device__ void store_output( - typename Ktraits::input_t* out, const float (&out_vals)[Ktraits::kNItems], - typename Ktraits::BlockStoreT::TempStorage& smem_store, int seqlen) { - typename Ktraits::input_t write_vals[Ktraits::kNItems]; -#pragma unroll - for (int i = 0; i < Ktraits::kNItems; ++i) { - write_vals[i] = out_vals[i]; - } - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_store_vec = - reinterpret_cast( - smem_store); - using vec_t = typename Ktraits::vec_t; - typename Ktraits::BlockStoreVecT(smem_store_vec) - .Store(reinterpret_cast(out), - reinterpret_cast(write_vals)); - } else { - typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); - } +template +inline __device__ void store_output(typename Ktraits::input_t *out, + const float (&out_vals)[Ktraits::kNItems], + typename Ktraits::BlockStoreT::TempStorage &smem_store, + int seqlen) { + typename Ktraits::input_t write_vals[Ktraits::kNItems]; + #pragma unroll + for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_store_vec = reinterpret_cast(smem_store); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockStoreVecT(smem_store_vec).Store( + reinterpret_cast(out), + reinterpret_cast(write_vals) + ); + } else { + typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); + } } diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 1080db7b3019..40c8d4d91f51 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -1,3 +1,4 @@ +// clang-format off #include #include #include @@ -8,709 +9,615 @@ #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK #ifndef USE_ROCM - #include - #include - #include + #include + #include + #include #else - #include -namespace cub = hipcub; + #include + namespace cub = hipcub; #endif #include "selective_scan.h" #include "static_switch.h" -template +template struct Selective_Scan_fwd_kernel_traits { - static_assert(kNItems_ % 4 == 0); - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves - // occupancy. - static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; - static constexpr int kNItems = kNItems_; - static constexpr int kNRows = kNRows_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 
4 : constexpr_min(8, kNItems); - static_assert(kNItems % kNElts == 0); - static constexpr int kNLoads = kNItems / kNElts; - static constexpr bool kIsEvenLen = kIsEvenLen_; - static constexpr bool kIsVariableB = kIsVariableB_; - static constexpr bool kIsVariableC = kIsVariableC_; - static constexpr bool kHasZ = kHasZ_; - static constexpr bool kUseIndex = kUseIndex_; - - static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; - static constexpr int kNLoadsIndex = kNItems / 4; - using vec_t = typename BytesToType::Type; - using scan_t = float2; - using BlockLoadT = cub::BlockLoad; - using BlockLoadVecT = - cub::BlockLoad; - using BlockLoadIndexT = - cub::BlockLoad; - using BlockLoadIndexVecT = cub::BlockLoad; - using BlockLoadWeightT = cub::BlockLoad; - using BlockLoadWeightVecT = - cub::BlockLoad; - using BlockStoreT = cub::BlockStore; - using BlockStoreVecT = - cub::BlockStore; - // using BlockScanT = cub::BlockScan; using BlockScanT = cub::BlockScan; - using BlockScanT = - cub::BlockScan; - static constexpr int kSmemIOSize = - custom_max({sizeof(typename BlockLoadT::TempStorage), - sizeof(typename BlockLoadVecT::TempStorage), - sizeof(typename BlockLoadIndexT::TempStorage), - sizeof(typename BlockLoadIndexVecT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * - sizeof(typename BlockLoadWeightT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * - sizeof(typename BlockLoadWeightVecT::TempStorage), - sizeof(typename BlockStoreT::TempStorage), - sizeof(typename BlockStoreVecT::TempStorage)}); - static constexpr int kSmemSize = - kSmemIOSize + sizeof(typename BlockScanT::TempStorage); + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 
4 : constexpr_min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kHasZ = kHasZ_; + static constexpr bool kUseIndex = kUseIndex_; + + static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; + static constexpr int kNLoadsIndex = kNItems / 4; + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadIndexT = cub::BlockLoad; + using BlockLoadIndexVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + sizeof(typename BlockLoadIndexT::TempStorage), + sizeof(typename BlockLoadIndexVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); }; -template -__global__ __launch_bounds__( - Ktraits::kNThreads, - Ktraits::kMinBlocks) void selective_scan_fwd_kernel(SSMParamsBase params) { - constexpr bool kIsVariableB = Ktraits::kIsVariableB; - constexpr bool kIsVariableC = Ktraits::kIsVariableC; - constexpr bool kHasZ = Ktraits::kHasZ; - constexpr bool kUseIndex = Ktraits::kUseIndex; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNItems = Ktraits::kNItems; - constexpr int kNRows = Ktraits::kNRows; - constexpr bool kDirectIO = Ktraits::kDirectIO; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - using scan_t = typename Ktraits::scan_t; - - // Shared memory. 
- extern __shared__ char smem_[]; - // cast to lvalue reference of expected type - // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); - // auto& smem_load = reinterpret_cast(smem_ - // + 2 * MAX_DSTATE * sizeof(weight_t)); auto& smem_load = - // reinterpret_cast(smem_loadstorescan); - auto& smem_load = - reinterpret_cast(smem_); - auto& smem_load_weight = - reinterpret_cast(smem_); - auto& smem_load_index = - reinterpret_cast(smem_); - auto& smem_load_weight1 = - *reinterpret_cast( - smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); - auto& smem_store = - reinterpret_cast(smem_); - auto& smem_scan = - *reinterpret_cast( - smem_ + Ktraits::kSmemIOSize); - // weight_t *smem_a = reinterpret_cast(smem_ + - // smem_loadstorescan_size); weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); - scan_t* smem_running_prefix = - reinterpret_cast(smem_ + Ktraits::kSmemSize); - - const int batch_id = blockIdx.x; - const int dim_id = blockIdx.y; - const int group_id = dim_id / (params.dim_ngroups_ratio); - input_t* u = reinterpret_cast(params.u_ptr) + - batch_id * params.u_batch_stride + - dim_id * kNRows * params.u_d_stride; - input_t* delta = reinterpret_cast(params.delta_ptr) + - batch_id * params.delta_batch_stride + - dim_id * kNRows * params.delta_d_stride; - weight_t* A = reinterpret_cast(params.A_ptr) + - dim_id * kNRows * params.A_d_stride; - weight_t* B = reinterpret_cast(params.B_ptr) + - dim_id * kNRows * params.B_d_stride; - input_t* Bvar = reinterpret_cast(params.B_ptr) + - batch_id * params.B_batch_stride + - group_id * params.B_group_stride; - weight_t* C = reinterpret_cast(params.C_ptr) + - dim_id * kNRows * params.C_d_stride; - input_t* Cvar = reinterpret_cast(params.C_ptr) + - batch_id * params.C_batch_stride + - group_id * params.C_group_stride; - scan_t* x = reinterpret_cast(params.x_ptr) + - (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * - params.dstate; - int* index = !kUseIndex ? 
nullptr - : reinterpret_cast(params.index_ptr) + - batch_id * params.seqlen; - - float D_val[kNRows] = {0}; - if (params.D_ptr != nullptr) { -#pragma unroll - for (int r = 0; r < kNRows; ++r) { - D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; - } - } - float delta_bias[kNRows] = {0}; - if (params.delta_bias_ptr != nullptr) { -#pragma unroll - for (int r = 0; r < kNRows; ++r) { - delta_bias[r] = - reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; - } - } - - // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += - // blockDim.x) { - // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; - // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * - // C[state_idx * params.C_dstate_stride]; - // } - - constexpr int kChunkSize = kNThreads * kNItems; - for (int chunk = 0; chunk < params.n_chunks; ++chunk) { - input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; - int index_vals_load[kNRows][kNItems]; - - __syncthreads(); -#pragma unroll - for (int r = 0; r < kNRows; ++r) { - if constexpr (!kDirectIO) { - if (r > 0) { - __syncthreads(); +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr bool kUseIndex = Ktraits::kUseIndex; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr int kNRows = Ktraits::kNRows; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. 
+ extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); + // auto& smem_load = reinterpret_cast(smem_loadstorescan); + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_index = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); + // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); + scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; + int *index = !kUseIndex ? nullptr :reinterpret_cast(params.index_ptr) + batch_id * params.seqlen; + + float D_val[kNRows] = {0}; + if (params.D_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; } - } - load_input(u + r * params.u_d_stride, u_vals[r], smem_load, - params.seqlen - chunk * kChunkSize); - if constexpr (!kDirectIO) { - __syncthreads(); - } - load_input(delta + r * params.delta_d_stride, delta_vals_load[r], - smem_load, params.seqlen - chunk * kChunkSize); - if constexpr (kUseIndex) { - load_index(index + r * params.delta_d_stride, - index_vals_load[r], smem_load_index, - params.seqlen - chunk * kChunkSize); - } } - if constexpr (kUseIndex) { - index += kChunkSize; - } - u += kChunkSize; - delta += kChunkSize; - - float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], - out_vals[kNRows][kNItems]; -#pragma unroll - for (int r = 0; r < kNRows; ++r) { -#pragma unroll - for (int i = 0; i < kNItems; ++i) { - float u_val = float(u_vals[r][i]); - delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; - if (params.delta_softplus) { - delta_vals[r][i] = delta_vals[r][i] <= 20.f - ? 
log1pf(expf(delta_vals[r][i])) - : delta_vals[r][i]; + float delta_bias[kNRows] = {0}; + if (params.delta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; } - delta_u_vals[r][i] = delta_vals[r][i] * u_val; - out_vals[r][i] = D_val[r] * u_val; - } } - __syncthreads(); - for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { - weight_t A_val[kNRows]; -#pragma unroll - for (int r = 0; r < kNRows; ++r) { - A_val[r] = - A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; - // Multiply the real part of A with LOG2E so we can use exp2f instead of - // expf. - constexpr float kLog2e = M_LOG2E; - A_val[r] *= kLog2e; - } - // This variable holds B * C if both B and C are constant across seqlen. - // If only B varies across seqlen, this holds C. If only C varies across - // seqlen, this holds B. If both B and C vary, this is unused. - weight_t BC_val[kNRows]; - weight_t B_vals[kNItems], C_vals[kNItems]; - if constexpr (kIsVariableB) { - load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, - smem_load_weight, - (params.seqlen - chunk * kChunkSize) * (1)); - if constexpr (!kIsVariableC) { -#pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = - C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; - } + + // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; + // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; + // } + + constexpr int kChunkSize = kNThreads * kNItems; + for (int chunk = 0; chunk < params.n_chunks; ++chunk) { + input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; + int index_vals_load[kNRows][kNItems]; + + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { __syncthreads(); } + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (kUseIndex) { + load_index(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize); + } } - } - if constexpr (kIsVariableC) { - auto& smem_load_weight_C = - !kIsVariableB ? smem_load_weight : smem_load_weight1; - load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, - (params.seqlen - chunk * kChunkSize) * (1)); - if constexpr (!kIsVariableB) { -#pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = - B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; - } + if constexpr (kUseIndex) { + index += kChunkSize; } - } - if constexpr (!kIsVariableB && !kIsVariableC) { -#pragma unroll + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll for (int r = 0; r < kNRows; ++r) { - BC_val[r] = - B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * - C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; + if (params.delta_softplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? 
log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + delta_u_vals[r][i] = delta_vals[r][i] * u_val; + out_vals[r][i] = D_val[r] * u_val; + } } - } - -#pragma unroll - for (int r = 0; r < kNRows; ++r) { - if (r > 0) { - __syncthreads(); - } // Scan could be using the same smem - scan_t thread_data[kNItems]; -#pragma unroll - for (int i = 0; i < kNItems; ++i) { - thread_data[i] = - make_float2(exp2f(delta_vals[r][i] * A_val[r]), - !kIsVariableB ? delta_u_vals[r][i] - : B_vals[i] * delta_u_vals[r][i]); - - // Reset A bar for cumulative sequences (Real) - if constexpr (kUseIndex) { - if (index_vals_load[r][i] == 0) { - thread_data[i].x = 0.f; + + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + constexpr float kLog2e = M_LOG2E; + A_val[r] *= kLog2e; + } + // This variable holds B * C if both B and C are constant across seqlen. If only B varies + // across seqlen, this holds C. If only C varies across seqlen, this holds B. + // If both B and C vary, this is unused. + weight_t BC_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (kIsVariableB) { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (1)); + if constexpr (!kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + } + if constexpr (kIsVariableC) { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (1 )); + if constexpr (!kIsVariableB) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } + } + } + if constexpr (!kIsVariableB && !kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } } - } - if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is - // correct - if (threadIdx.x * kNItems + i >= - params.seqlen - chunk * kChunkSize) { - thread_data[i] = make_float2(1.f, 0.f); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if (r > 0) { __syncthreads(); } // Scan could be using the same smem + scan_t thread_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), + !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); + + // Reset A bar for cumulative sequences (Real) + if constexpr (kUseIndex) { + if (index_vals_load[r][i] == 0) { + thread_data[i].x = 0.f; + } + } + + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); + } + } + } + // Initialize running total + scan_t running_prefix; + // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read + running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? 
smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f)); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + typename Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. + if (threadIdx.x == 0) { + smem_running_prefix[state_idx] = prefix_op.running_prefix; + x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const weight_t C_val = !kIsVariableC + ? BC_val[r] + : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]); + out_vals[r][i] += thread_data[i].y * C_val; + } } - } } - // Initialize running total - scan_t running_prefix; - // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) - // needs to read - running_prefix = - chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] - : (threadIdx.x % 32 == 0 - ? smem_running_prefix[state_idx + r * MAX_DSTATE] - : make_float2(1.f, 0.f)); - // running_prefix = chunk > 0 && threadIdx.x == 0 ? - // smem_running_prefix[state_idx] : make_float2(1.f, 0.f); - SSMScanPrefixCallbackOp prefix_op(running_prefix); - typename Ktraits::BlockScanT(smem_scan).InclusiveScan( - thread_data, thread_data, SSMScanOp(), prefix_op); - // There's a syncthreads in the scan op, so we don't need to sync here. - // Unless there's only 1 warp, but then it's the same thread (0) reading - // and writing. - if (threadIdx.x == 0) { - smem_running_prefix[state_idx] = prefix_op.running_prefix; - x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = - prefix_op.running_prefix; - } -#pragma unroll - for (int i = 0; i < kNItems; ++i) { - const weight_t C_val = - !kIsVariableC - ? BC_val[r] - : (!kIsVariableB ? 
BC_val[r] * C_vals[i] : C_vals[i]); - out_vals[r][i] += thread_data[i].y * C_val; + + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); } - } - } - input_t* out = reinterpret_cast(params.out_ptr) + - batch_id * params.out_batch_stride + - dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; - __syncthreads(); -#pragma unroll - for (int r = 0; r < kNRows; ++r) { - if constexpr (!kDirectIO) { - if (r > 0) { - __syncthreads(); + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + input_t z_vals[kNItems]; + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + out_vals[r][i] *= z_val / (1 + expf(-z_val)); + } + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } } - } - store_output(out + r * params.out_d_stride, out_vals[r], - smem_store, params.seqlen - chunk * kChunkSize); - } - if constexpr (kHasZ) { - input_t* z = reinterpret_cast(params.z_ptr) + - batch_id * params.z_batch_stride + - dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; - input_t* out_z = reinterpret_cast(params.out_z_ptr) + - batch_id * params.out_z_batch_stride + - dim_id * kNRows * params.out_z_d_stride + - chunk * kChunkSize; -#pragma unroll - for (int r = 0; r < kNRows; ++r) { - input_t z_vals[kNItems]; - __syncthreads(); - load_input(z + r * params.z_d_stride, z_vals, smem_load, - params.seqlen - chunk * kChunkSize); -#pragma unroll - for (int i = 0; i < kNItems; ++i) { - float z_val = z_vals[i]; - out_vals[r][i] *= z_val / (1 + expf(-z_val)); - } - __syncthreads(); - store_output(out_z + r * params.out_z_d_stride, out_vals[r], - smem_store, params.seqlen - chunk * kChunkSize); - } + Bvar += kChunkSize * 1; + Cvar += kChunkSize * 1; } - - Bvar += kChunkSize * 1; - Cvar += kChunkSize * 1; - } } -template -void selective_scan_fwd_launch(SSMParamsBase& params, cudaStream_t stream) { - // Only kNRows == 1 is tested for now, which ofc doesn't differ from - // previously when we had each block processing 1 row. 
- constexpr int kNRows = 1; - BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { - BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { - BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { - BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, [&] { - BOOL_SWITCH(params.index_ptr != nullptr, kUseIndex, [&] { - using Ktraits = Selective_Scan_fwd_kernel_traits< - kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, - kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>; - // constexpr int kSmemSize = Ktraits::kSmemSize; - constexpr int kSmemSize = - Ktraits::kSmemSize + - kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); - dim3 grid(params.batch, params.dim / kNRows); - auto kernel = &selective_scan_fwd_kernel; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - kSmemSize)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); +template +void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { + // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block + // processing 1 row. + constexpr int kNRows = 1; + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); }); - }); }); - }); } -template -void selective_scan_fwd_cuda(SSMParamsBase& params, cudaStream_t stream) { -#ifndef USE_ROCM - if (params.seqlen <= 128) { - selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 256) { - selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 512) { - selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 1024) { - selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); - } else { - selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); - } -#else - if (params.seqlen <= 256) { - selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 512) { - selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 1024) { - selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); - } else { - selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); - } -#endif +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { + + #ifndef USE_ROCM + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if 
(params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } + #else + if (params.seqlen <= 256) { + selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } + #endif } -template void selective_scan_fwd_cuda( - SSMParamsBase& params, cudaStream_t stream); -template void selective_scan_fwd_cuda(SSMParamsBase& params, - cudaStream_t stream); -template void selective_scan_fwd_cuda(SSMParamsBase& params, - cudaStream_t stream); - -#define CHECK_SHAPE(x, ...) \ - TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \ - #x " must have shape (" #__VA_ARGS__ ")") - -#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ - if (ITYPE == at::ScalarType::Half) { \ - using input_t = at::Half; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::BFloat16) { \ - using input_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::Float) { \ - using input_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), \ - "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ - if (WTYPE == at::ScalarType::Half) { \ - using weight_t = at::Half; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::BFloat16) { \ - using weight_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), \ - "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT(WTYPE, NAME, ...) \ - if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), \ - "'"); \ - } - -template -void selective_scan_fwd_cuda(SSMParamsBase& params, cudaStream_t stream); - -void set_ssm_params_fwd(SSMParamsBase& params, +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) 
\ + if (WTYPE == at::ScalarType::Half) { \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::BFloat16) { \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT(WTYPE, NAME, ...) \ + if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ + } + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase ¶ms, // sizes - const size_t batch, const size_t dim, - const size_t seqlen, const size_t dstate, - const size_t n_groups, const size_t n_chunks, - const bool is_variable_B, const bool is_variable_C, + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + const bool is_variable_B, + const bool is_variable_C, // device pointers - const torch::Tensor u, const torch::Tensor delta, - const torch::Tensor A, const torch::Tensor B, - const torch::Tensor C, const torch::Tensor out, - const torch::Tensor z, const torch::Tensor out_z, - void* D_ptr, void* delta_bias_ptr, void* x_ptr, - bool has_z, bool delta_softplus, void* index_ptr) { - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - params.batch = batch; - params.dim = dim; - params.seqlen = seqlen; - params.dstate = dstate; - params.n_groups = n_groups; - params.n_chunks = n_chunks; - params.dim_ngroups_ratio = dim / n_groups; - - params.delta_softplus = delta_softplus; - - params.is_variable_B = is_variable_B; - params.is_variable_C = is_variable_C; - - // Set the pointers and strides. - params.u_ptr = u.data_ptr(); - params.delta_ptr = delta.data_ptr(); - params.A_ptr = A.data_ptr(); - params.B_ptr = B.data_ptr(); - params.C_ptr = C.data_ptr(); - params.D_ptr = D_ptr; - params.delta_bias_ptr = delta_bias_ptr; - params.out_ptr = out.data_ptr(); - params.x_ptr = x_ptr; - params.z_ptr = has_z ? z.data_ptr() : nullptr; - params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; - - params.index_ptr = index_ptr; - - // All stride are in elements, not bytes. - params.A_d_stride = A.stride(0); - params.A_dstate_stride = A.stride(1); - if (!is_variable_B) { - params.B_d_stride = B.stride(0); - } else { - params.B_batch_stride = B.stride(0); - params.B_group_stride = B.stride(1); - } - params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); - if (!is_variable_C) { - params.C_d_stride = C.stride(0); - } else { - params.C_batch_stride = C.stride(0); - params.C_group_stride = C.stride(1); - } - params.C_dstate_stride = !is_variable_C ? 
C.stride(1) : C.stride(2); - params.u_batch_stride = u.stride(0); - params.u_d_stride = u.stride(1); - params.delta_batch_stride = delta.stride(0); - params.delta_d_stride = delta.stride(1); - if (has_z) { - params.z_batch_stride = z.stride(0); - params.z_d_stride = z.stride(1); - params.out_z_batch_stride = out_z.stride(0); - params.out_z_d_stride = out_z.stride(1); - } - params.out_batch_stride = out.stride(0); - params.out_d_stride = out.stride(1); + const torch::Tensor u, + const torch::Tensor delta, + const torch::Tensor A, + const torch::Tensor B, + const torch::Tensor C, + const torch::Tensor out, + const torch::Tensor z, + const torch::Tensor out_z, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + bool has_z, + bool delta_softplus, + void* index_ptr) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.dstate = dstate; + params.n_groups = n_groups; + params.n_chunks = n_chunks; + params.dim_ngroups_ratio = dim / n_groups; + + params.delta_softplus = delta_softplus; + + params.is_variable_B = is_variable_B; + params.is_variable_C = is_variable_C; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D_ptr; + params.delta_bias_ptr = delta_bias_ptr; + params.out_ptr = out.data_ptr(); + params.x_ptr = x_ptr; + params.z_ptr = has_z ? z.data_ptr() : nullptr; + params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + + params.index_ptr = index_ptr; + + // All stride are in elements, not bytes. + params.A_d_stride = A.stride(0); + params.A_dstate_stride = A.stride(1); + if (!is_variable_B) { + params.B_d_stride = B.stride(0); + } else { + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + } + params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); + if (!is_variable_C) { + params.C_d_stride = C.stride(0); + } else { + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + } + params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + if (has_z) { + params.z_batch_stride = z.stride(0); + params.z_d_stride = z.stride(1); + params.out_z_batch_stride = out_z.stride(0); + params.out_z_d_stride = out_z.stride(1); + } + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); } -std::vector selective_scan_fwd( - const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, - const torch::Tensor& B, const torch::Tensor& C, - const c10::optional& D_, - const c10::optional& z_, - const c10::optional& delta_bias_, bool delta_softplus, - const c10::optional& index_, - const c10::optional& x) { - auto input_type = u.scalar_type(); - auto weight_type = A.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || - input_type == at::ScalarType::Half || - input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || - weight_type == at::ScalarType::ComplexFloat); - - const bool is_variable_B = B.dim() >= 3; - const bool is_variable_C = C.dim() >= 3; - const bool is_complex = weight_type == at::ScalarType::ComplexFloat; - - TORCH_CHECK(delta.scalar_type() == input_type); - TORCH_CHECK(B.scalar_type() == (!is_variable_B ? 
weight_type : input_type)); - TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); - - TORCH_CHECK(u.is_cuda()); - TORCH_CHECK(delta.is_cuda()); - TORCH_CHECK(A.is_cuda()); - TORCH_CHECK(B.is_cuda()); - TORCH_CHECK(C.is_cuda()); - - TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); - TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); - - const auto sizes = u.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int dstate = A.size(1); - const int n_groups = is_variable_B ? B.size(1) : 1; - - TORCH_CHECK(dstate <= 256, - "selective_scan only supports state dimension <= 256"); - - CHECK_SHAPE(u, batch_size, dim, seqlen); - CHECK_SHAPE(delta, batch_size, dim, seqlen); - CHECK_SHAPE(A, dim, dstate); - if (!is_variable_B) { - CHECK_SHAPE(B, dim, dstate); - } else { - CHECK_SHAPE(B, batch_size, n_groups, dstate, - !is_complex ? seqlen : seqlen * 2); - TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); - } - if (!is_variable_C) { - CHECK_SHAPE(C, dim, dstate); - } else { - CHECK_SHAPE(C, batch_size, n_groups, dstate, - !is_complex ? seqlen : seqlen * 2); - TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); - } - - if (D_.has_value()) { - auto D = D_.value(); - TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(D.is_cuda()); - TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); - CHECK_SHAPE(D, dim); - } - - if (delta_bias_.has_value()) { - auto delta_bias = delta_bias_.value(); - TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(delta_bias.is_cuda()); - TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); - CHECK_SHAPE(delta_bias, dim); - } - if (index_.has_value()) { - auto index = index_.value(); - TORCH_CHECK(index.scalar_type() == at::ScalarType::Int); - TORCH_CHECK(index.is_cuda()); - CHECK_SHAPE(index, batch_size, seqlen); - } - - at::Tensor z, out_z; - const bool has_z = z_.has_value(); - if (has_z) { - z = z_.value(); - TORCH_CHECK(z.scalar_type() == input_type); - TORCH_CHECK(z.is_cuda()); - TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); - CHECK_SHAPE(z, batch_size, dim, seqlen); - out_z = torch::empty_like(z); - } - - const int n_chunks = (seqlen + 2048 - 1) / 2048; - // const int n_chunks = (seqlen + 1024 - 1) / 1024; - // at::Tensor out = torch::empty_like(u); - // Right now u has BHL layout and delta has HBL layout, and we want out to - // have HBL layout - at::Tensor out = torch::empty_like(delta); - if (x.has_value()) { - auto _x = x.value(); - TORCH_CHECK(_x.scalar_type() == weight_type); - TORCH_CHECK(_x.is_cuda()); - TORCH_CHECK(_x.stride(-1) == 1); - CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2); - } - - SSMParamsBase params; - set_ssm_params_fwd( - params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, - is_variable_B, is_variable_C, u, delta, A, B, C, out, z, out_z, - D_.has_value() ? D_.value().data_ptr() : nullptr, - delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, - x.value().data_ptr(), has_z, delta_softplus, - index_.has_value() ? 
index_.value().data_ptr() : nullptr); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)u.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16( - u.scalar_type(), "selective_scan_fwd", [&] { +std::vector +selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, + const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, + const c10::optional &D_, + const c10::optional &z_, + const c10::optional &delta_bias_, + bool delta_softplus, + const c10::optional &index_, + const c10::optional &x) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + const bool is_complex = weight_type == at::ScalarType::ComplexFloat; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = is_variable_B ? B.size(1) : 1; + + TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + if (!is_variable_B) { + CHECK_SHAPE(B, dim, dstate); + } else { + CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + } + if (!is_variable_C) { + CHECK_SHAPE(C, dim, dstate); + } else { + CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? 
seqlen: seqlen * 2); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + } + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + if (index_.has_value()) { + auto index = index_.value(); + TORCH_CHECK(index.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(index.is_cuda()); + CHECK_SHAPE(index, batch_size, seqlen); + } + + at::Tensor z, out_z; + const bool has_z = z_.has_value(); + if (has_z) { + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + CHECK_SHAPE(z, batch_size, dim, seqlen); + out_z = torch::empty_like(z); + } + + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + // at::Tensor out = torch::empty_like(u); + // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout + at::Tensor out = torch::empty_like(delta); + if (x.has_value()){ + auto _x = x.value(); + TORCH_CHECK(_x.scalar_type() == weight_type); + TORCH_CHECK(_x.is_cuda()); + TORCH_CHECK(_x.stride(-1) == 1); + CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2); + } + + SSMParamsBase params; + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, out, z, out_z, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x.value().data_ptr(), + has_z, + delta_softplus, + index_.has_value() ? 
index_.value().data_ptr() : nullptr); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { DISPATCH_WTYPE_FLOAT(A.scalar_type(), "selective_scan_fwd", [&] { - selective_scan_fwd_cuda(params, stream); + selective_scan_fwd_cuda(params, stream); }); - }); - std::vector result = {out, x.value()}; - if (has_z) { - result.push_back(out_z); - } - return result; + }); + std::vector result = {out, x.value()}; + if (has_z) { result.push_back(out_z); } + return result; } + diff --git a/csrc/mamba/mamba_ssm/static_switch.h b/csrc/mamba/mamba_ssm/static_switch.h index d95531cf59ca..d2ecfce47220 100644 --- a/csrc/mamba/mamba_ssm/static_switch.h +++ b/csrc/mamba/mamba_ssm/static_switch.h @@ -2,6 +2,7 @@ // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h +// clang-format off #pragma once /// @param COND - a boolean expression to switch by From c8ffba5bde288ca2c10d538a977fee89e519cba4 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 25 Aug 2024 11:33:36 +0300 Subject: [PATCH 27/45] Format --- csrc/mamba/causal_conv1d/static_switch.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/mamba/causal_conv1d/static_switch.h b/csrc/mamba/causal_conv1d/static_switch.h index ce002d11b55f..6ba8221db065 100644 --- a/csrc/mamba/causal_conv1d/static_switch.h +++ b/csrc/mamba/causal_conv1d/static_switch.h @@ -1,4 +1,5 @@ -// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h // clang-format off From 04f947bf9c29d401a71806a061c401538068b3e2 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 25 Aug 2024 11:38:09 +0300 Subject: [PATCH 28/45] Add comments on adapted from mamba/casual conv1d repos --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 2 ++ csrc/mamba/causal_conv1d/causal_conv1d.h | 2 +- csrc/mamba/causal_conv1d/static_switch.h | 1 + csrc/mamba/mamba_ssm/selective_scan.h | 1 + csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 1 + csrc/mamba/mamba_ssm/static_switch.h | 1 + 6 files changed, 7 insertions(+), 1 deletion(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 98ce5f9563f0..0819543f0a50 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -1,4 +1,6 @@ // clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu +// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu #include #include #include diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h index 7ff9ba8594a1..e909bcd5391e 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.h +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -2,7 +2,7 @@ * Copyright (c) 2024, Tri Dao. 
******************************************************************************/ // clang-format off -// adapted from +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h #pragma once #include diff --git a/csrc/mamba/causal_conv1d/static_switch.h b/csrc/mamba/causal_conv1d/static_switch.h index 6ba8221db065..ef74bf447f84 100644 --- a/csrc/mamba/causal_conv1d/static_switch.h +++ b/csrc/mamba/causal_conv1d/static_switch.h @@ -2,6 +2,7 @@ // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h // clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h #pragma once diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index cf75b86b9630..0070c92f6cd0 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -2,6 +2,7 @@ * Copyright (c) 2023, Tri Dao. ******************************************************************************/ // clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h #pragma once diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 40c8d4d91f51..79c4ce338370 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -1,4 +1,5 @@ // clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh #include #include #include diff --git a/csrc/mamba/mamba_ssm/static_switch.h b/csrc/mamba/mamba_ssm/static_switch.h index d2ecfce47220..840cb2374a2f 100644 --- a/csrc/mamba/mamba_ssm/static_switch.h +++ b/csrc/mamba/mamba_ssm/static_switch.h @@ -3,6 +3,7 @@ // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h // clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/static_switch.h #pragma once /// @param COND - a boolean expression to switch by From 732db18f8389fce7a2e995b9e21782a880aa1e07 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 25 Aug 2024 11:52:34 +0300 Subject: [PATCH 29/45] pare down number of w/i dtype combinations --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 44 ++++------------------- 1 file changed, 6 insertions(+), 38 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 0819543f0a50..a98f0dbabbc4 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -19,33 +19,23 @@ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ if (ITYPE == at::ScalarType::Half) { \ using input_t = at::Half; \ + using weight_t = at::Half; \ __VA_ARGS__(); \ } else if (ITYPE == at::ScalarType::BFloat16) { \ using input_t = at::BFloat16; \ + using weight_t = at::BFloat16; \ __VA_ARGS__(); \ } else if (ITYPE == at::ScalarType::Float) { \ using input_t = float; \ + using weight_t = float; \ __VA_ARGS__(); \ } else { \ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ } -#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) 
\ - if (WTYPE == at::ScalarType::Half) { \ - using weight_t = at::Half; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::BFloat16) { \ - using weight_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ - } template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); @@ -204,14 +194,12 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)x.get_device()}; auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { - DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] { + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { if (!is_channel_last) { causal_conv1d_fwd_cuda(params, stream); } else { causal_conv1d_channellast_fwd_cuda(params, stream); } - }); }); return out; } @@ -268,10 +256,8 @@ causal_conv1d_update(const at::Tensor &x, // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)x.get_device()}; auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { - DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] { + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { causal_conv1d_update_cuda(params, stream); - }); }); return out; } @@ -651,23 +637,11 @@ void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t str } template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); /////// @@ -747,11 +721,5 @@ void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { } template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 
-template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); From fdca1ff1a8eb1550eeab8b18e7632553e5d6f6cd Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 25 Aug 2024 11:53:45 +0300 Subject: [PATCH 30/45] Clean up not used --- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 79c4ce338370..d2ae92feae0f 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -383,20 +383,6 @@ template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaS AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ } -#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ - if (WTYPE == at::ScalarType::Half) { \ - using weight_t = at::Half; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::BFloat16) { \ - using weight_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ - } - #define DISPATCH_WTYPE_FLOAT(WTYPE, NAME, ...) 
\ if (WTYPE == at::ScalarType::Float) { \ using weight_t = float; \ From fe70a39eb6a06038b4e89d5559031c4c953b5d4e Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 25 Aug 2024 12:31:34 +0300 Subject: [PATCH 31/45] Rename typo --- .../layers/mamba/ops/{casual_conv1d.py => causal_conv1d.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename vllm/model_executor/layers/mamba/ops/{casual_conv1d.py => causal_conv1d.py} (100%) diff --git a/vllm/model_executor/layers/mamba/ops/casual_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py similarity index 100% rename from vllm/model_executor/layers/mamba/ops/casual_conv1d.py rename to vllm/model_executor/layers/mamba/ops/causal_conv1d.py From 9a0e538a846567aa11b2ac196a630a5682b96480 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 25 Aug 2024 12:34:23 +0300 Subject: [PATCH 32/45] Add comment on einops --- requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-test.txt b/requirements-test.txt index cdbc3e50cc9e..46eb05fc3109 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -11,7 +11,7 @@ pytest-shard # testing utils awscli -einops # required for MPT and qwen-vl +einops # required for MPT, qwen-vl and Mamba httpx peft requests From 619a40a813bfcbbd79b2d2f1b2aa9a49db718482 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 25 Aug 2024 12:34:46 +0300 Subject: [PATCH 33/45] Remove requirement for einops --- vllm/model_executor/layers/mamba/ops/mamba_ssm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index ce9ed5e54935..af4228ce3855 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -3,7 +3,6 @@ import torch import triton import triton.language as tl -from einops import rearrange from packaging import version from vllm import _custom_ops as ops @@ -321,9 +320,9 @@ def selective_scan_fn(u, if z is not None and z.stride(-1) != 1: z = z.contiguous() if B.dim() == 3: - B = rearrange(B, "b dstate l -> b 1 dstate l") + B = B.unsqueeze(1) if C.dim() == 3: - C = rearrange(C, "b dstate l -> b 1 dstate l") + C = B.unsqueeze(1) n_chunks = int((u.shape[-1] + 2048 - 1) / 2048) x = torch.zeros(( u.shape[0], From 5d0d2db78fd06120991234d221294d8df10032b7 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 25 Aug 2024 13:21:49 +0300 Subject: [PATCH 34/45] Fix tests after paring down kernels --- tests/kernels/test_causal_conv1d.py | 8 ++++---- tests/kernels/test_mamba_ssm.py | 6 +++--- vllm/model_executor/layers/mamba/ops/mamba_ssm.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 94639671c549..4944577a21c8 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -118,9 +118,9 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s") - weight = torch.randn(dim, width, device=device, dtype=torch.float32) + weight = torch.randn(dim, width, device=device, dtype=itype) if has_bias: - bias = torch.randn(dim, device=device, dtype=torch.float32) + bias = torch.randn(dim, device=device, dtype=itype) else: bias = None if has_initial_states: @@ -185,12 +185,12 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, weight = torch.randn(dim, 
width, device=device, - dtype=torch.float32, + dtype=itype, requires_grad=True) if has_bias: bias = torch.randn(dim, device=device, - dtype=torch.float32, + dtype=itype, requires_grad=True) else: bias = None diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index 796de355ffc0..d3cb0a8656a0 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -197,7 +197,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, B_shape = [batch_size, dstate, seqlen] else: B_shape = [batch_size, varBC_groups, dstate, seqlen] - B = torch.randn(*B_shape, + B = torch.randn(B_shape, device=device, dtype=wtype if not is_variable_B else itype) if not is_variable_C: @@ -206,7 +206,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, C_shape = [batch_size, dstate, seqlen] else: C_shape = [batch_size, varBC_groups, dstate, seqlen] - C = torch.randn(*C_shape, + C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype) D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None @@ -288,7 +288,7 @@ def test_selective_state_update(dim, dstate, has_z, itype): atol *= 2 # set seed torch.random.manual_seed(0) - batch_size = 2 + batch_size = 1 state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) x = torch.randn(batch_size, dim, device=device, dtype=itype) dt = torch.randn(batch_size, dim, device=device, dtype=itype) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index af4228ce3855..869c69214caf 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -322,7 +322,7 @@ def selective_scan_fn(u, if B.dim() == 3: B = B.unsqueeze(1) if C.dim() == 3: - C = B.unsqueeze(1) + C = C.unsqueeze(1) n_chunks = int((u.shape[-1] + 2048 - 1) / 2048) x = torch.zeros(( u.shape[0], From c6223753c9a60909655c07cbae9dc3dc79a26952 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 25 Aug 2024 13:43:38 +0300 Subject: [PATCH 35/45] format --- tests/kernels/test_causal_conv1d.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 4944577a21c8..eb9ed2fbfa39 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -119,10 +119,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s") weight = torch.randn(dim, width, device=device, dtype=itype) - if has_bias: - bias = torch.randn(dim, device=device, dtype=itype) - else: - bias = None + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None if has_initial_states: initial_states = torch.randn(batch, width - 1, @@ -188,10 +185,7 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, dtype=itype, requires_grad=True) if has_bias: - bias = torch.randn(dim, - device=device, - dtype=itype, - requires_grad=True) + bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) else: bias = None conv_state_ref = conv_state.detach().clone() From cdc92058f5ababb8e9a7b15f265c8a7b36418a4a Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 25 Aug 2024 14:49:34 +0300 Subject: [PATCH 36/45] Fix typo --- tests/kernels/test_causal_conv1d.py | 2 +- vllm/model_executor/models/jamba.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff 
--git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index eb9ed2fbfa39..7bf338b36953 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from einops import rearrange -from vllm.model_executor.layers.mamba.ops.casual_conv1d import ( +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 70c0a49bd499..911e8eaf0a60 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -21,7 +21,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.ops.casual_conv1d import ( +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) From 308c922660bbf39d33a5743c301d3651462a1e0e Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 25 Aug 2024 15:17:13 +0300 Subject: [PATCH 37/45] register meta functions to the kernels --- vllm/_custom_ops.py | 48 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f189c38314a8..1a916267f565 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -489,6 +489,21 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, silu_activation) +try: + torch.ops._C.causal_conv1d_fwd # noqa B018 + + @torch.library.register_fake("_C::causal_conv1d_fwd") + def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], + seq_idx_: Optional[torch.Tensor], + initial_states_: Optional[torch.Tensor], + final_states_out_: Optional[torch.Tensor], + silu_activation: bool) -> torch.Tensor: + return torch.empty_like((x)) +except Exception: + pass + + def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, bias_: Optional[torch.Tensor], silu_activation: bool) -> torch.Tensor: @@ -496,6 +511,19 @@ def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, silu_activation) +try: + torch.ops._C.causal_conv1d_update # noqa B018 + + @torch.library.register_fake("_C::causal_conv1d_update") + def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, + bias_: Optional[torch.Tensor], + silu_activation: bool) -> torch.Tensor: + return torch.empty_like((x)) +except Exception: + pass + + def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], @@ -507,6 +535,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, x) +try: + torch.ops._C.selective_scan_fwd # noqa B018 + + @torch.library.register_fake("_C::selective_scan_fwd") + def selective_scan_fwd_fake( + u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, + B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor], + z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, index_: Optional[torch.Tensor], + x: Optional[torch.Tensor]) -> List[torch.Tensor]: + return [ + torch.empty_like(u), + torch.empty((u.size(0), u.size(1), A.size(1)), + dtype=u.dtype, + device=u.device) + ] +except Exception: + pass + + # moe def 
moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, From d921a486e5e2f058916161a5b42382c6bfe95a68 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 25 Aug 2024 15:24:48 +0300 Subject: [PATCH 38/45] Revert "register meta functions to the kernels" This reverts commit 308c922660bbf39d33a5743c301d3651462a1e0e. --- vllm/_custom_ops.py | 48 --------------------------------------------- 1 file changed, 48 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 1a916267f565..f189c38314a8 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -489,21 +489,6 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, silu_activation) -try: - torch.ops._C.causal_conv1d_fwd # noqa B018 - - @torch.library.register_fake("_C::causal_conv1d_fwd") - def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, - bias_: Optional[torch.Tensor], - seq_idx_: Optional[torch.Tensor], - initial_states_: Optional[torch.Tensor], - final_states_out_: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: - return torch.empty_like((x)) -except Exception: - pass - - def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, bias_: Optional[torch.Tensor], silu_activation: bool) -> torch.Tensor: @@ -511,19 +496,6 @@ def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, silu_activation) -try: - torch.ops._C.causal_conv1d_update # noqa B018 - - @torch.library.register_fake("_C::causal_conv1d_update") - def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor, - weight: torch.Tensor, - bias_: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: - return torch.empty_like((x)) -except Exception: - pass - - def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], @@ -535,26 +507,6 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, x) -try: - torch.ops._C.selective_scan_fwd # noqa B018 - - @torch.library.register_fake("_C::selective_scan_fwd") - def selective_scan_fwd_fake( - u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, - B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor], - z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor], - delta_softplus: bool, index_: Optional[torch.Tensor], - x: Optional[torch.Tensor]) -> List[torch.Tensor]: - return [ - torch.empty_like(u), - torch.empty((u.size(0), u.size(1), A.size(1)), - dtype=u.dtype, - device=u.device) - ] -except Exception: - pass - - # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, From a8078e7583d7d088a8de31e425ed74977ab0295c Mon Sep 17 00:00:00 2001 From: mzusman Date: Mon, 26 Aug 2024 10:14:26 +0300 Subject: [PATCH 39/45] move to ifndef ROCm --- csrc/torch_bindings.cpp | 51 +++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 997864e70e56..f3925c7c2372 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -202,6 +202,32 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA, &cutlass_scaled_mm_supports_fp8); + // Mamba selective scan kernel + ops.def( + "selective_scan_fwd(Tensor! u, Tensor! delta," + "Tensor! 
A, Tensor! B, Tensor! C," + "Tensor? D_, Tensor? z_, Tensor? delta_bias_," + "bool delta_softplus," + "Tensor? index_, Tensor? x) -> Tensor[]"); + ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); + + ops.def( + "causal_conv1d_update(Tensor! x," + "Tensor! conv_state," + "Tensor! weight," + "Tensor? bias_," + "bool silu_activation) -> Tensor"); + ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); + + ops.def( + "causal_conv1d_fwd(Tensor! x, Tensor! weight," + "Tensor? bias_," + "Tensor? seq_idx_," + "Tensor? seq_pos_idx_," + "Tensor? initial_states_," + "Tensor? final_states_out_," + "bool silu_activation) -> Tensor"); + ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); #endif // Quantized GEMM for GPTQ. @@ -259,32 +285,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); - // Mamba selective scan kernel - ops.def( - "selective_scan_fwd(Tensor! u, Tensor! delta," - "Tensor! A, Tensor! B, Tensor! C," - "Tensor? D_, Tensor? z_, Tensor? delta_bias_," - "bool delta_softplus," - "Tensor? index_, Tensor? x) -> Tensor[]"); - ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); - ops.def( - "causal_conv1d_update(Tensor! x," - "Tensor! conv_state," - "Tensor! weight," - "Tensor? bias_," - "bool silu_activation) -> Tensor"); - ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); - - ops.def( - "causal_conv1d_fwd(Tensor! x, Tensor! weight," - "Tensor? bias_," - "Tensor? seq_idx_," - "Tensor? seq_pos_idx_," - "Tensor? initial_states_," - "Tensor? final_states_out_," - "bool silu_activation) -> Tensor"); - ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { From 2ca8db7a3010a132f68548f7dd629add6bb984e3 Mon Sep 17 00:00:00 2001 From: mzusman Date: Mon, 26 Aug 2024 10:29:53 +0300 Subject: [PATCH 40/45] Format --- csrc/torch_bindings.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index f3925c7c2372..fb0accbbabac 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -284,8 +284,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "()"); ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); - - } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { From abf02fa0f8a65f53015efca21cece6eb02170500 Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 27 Aug 2024 13:18:33 +0300 Subject: [PATCH 41/45] Reduce combinations of bool switch to reduce wheel size --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 25 ++++++----- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 49 +++++++++------------- 2 files changed, 31 insertions(+), 43 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index a98f0dbabbc4..d3d6be4fe881 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -404,19 +404,18 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { template void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { static constexpr int kNElts = sizeof(input_t) == 4 ? 
4 : 8; - BOOL_SWITCH(params.seq_pos_idx_ptr != nullptr, kHasSeqPosIdx, [&] { - BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { - using Ktraits = Causal_conv1d_fwd_kernel_traits; - constexpr int kSmemSize = Ktraits::kSmemSize; - dim3 grid(params.batch, params.dim); - auto kernel = &causal_conv1d_fwd_kernel; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); + constexpr kHasSeqPosIdx = false; + BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { + using Ktraits = Causal_conv1d_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize; + dim3 grid(params.batch, params.dim); + auto kernel = &causal_conv1d_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index d2ae92feae0f..03bc3e2110e3 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -311,26 +311,21 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block // processing 1 row. constexpr int kNRows = 1; + constexpr bool kIsVariableB = true; + constexpr bool kIsVariableC = true; + constexpr bool kHasZ = true; BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { - BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { - BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { - BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { - BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] { - using Ktraits = Selective_Scan_fwd_kernel_traits; - // constexpr int kSmemSize = Ktraits::kSmemSize; - constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); - // printf("smem_size = %d\n", kSmemSize); - dim3 grid(params.batch, params.dim / kNRows); - auto kernel = &selective_scan_fwd_kernel; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); + BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); } @@ -369,27 +364,23 @@ template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaS #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) 
\ if (ITYPE == at::ScalarType::Half) { \ using input_t = at::Half; \ + using weight_t = at::Half; \ __VA_ARGS__(); \ } else if (ITYPE == at::ScalarType::BFloat16) { \ using input_t = at::BFloat16; \ + using weight_t = at::BFloat16; \ __VA_ARGS__(); \ } else if (ITYPE == at::ScalarType::Float) { \ using input_t = float; \ + using weight_t = float; \ __VA_ARGS__(); \ } else { \ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ } -#define DISPATCH_WTYPE_FLOAT(WTYPE, NAME, ...) \ - if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ - } template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); @@ -598,10 +589,8 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)u.get_device()}; auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { - DISPATCH_WTYPE_FLOAT(A.scalar_type(), "selective_scan_fwd", [&] { + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { selective_scan_fwd_cuda(params, stream); - }); }); std::vector result = {out, x.value()}; if (has_z) { result.push_back(out_z); } From 633225ca21f537fe78eb96777221d8c5103a73a0 Mon Sep 17 00:00:00 2001 From: mzusman Date: Tue, 27 Aug 2024 16:59:15 +0300 Subject: [PATCH 42/45] Fix, use float as weight dtype --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 10 +++++----- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index d3d6be4fe881..1a365e9a30a8 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -19,18 +19,18 @@ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ if (ITYPE == at::ScalarType::Half) { \ using input_t = at::Half; \ - using weight_t = at::Half; \ + using weight_t = at::Half; \ __VA_ARGS__(); \ } else if (ITYPE == at::ScalarType::BFloat16) { \ using input_t = at::BFloat16; \ - using weight_t = at::BFloat16; \ + using weight_t = at::BFloat16; \ __VA_ARGS__(); \ } else if (ITYPE == at::ScalarType::Float) { \ using input_t = float; \ - using weight_t = float; \ + using weight_t = float; \ __VA_ARGS__(); \ } else { \ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ @@ -404,7 +404,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { template void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { static constexpr int kNElts = sizeof(input_t) == 4 ? 
4 : 8; - constexpr kHasSeqPosIdx = false; + constexpr bool kHasSeqPosIdx = false; BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { using Ktraits = Causal_conv1d_fwd_kernel_traits; constexpr int kSmemSize = Ktraits::kSmemSize; diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 03bc3e2110e3..347b6666165a 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -364,18 +364,18 @@ template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaS #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ if (ITYPE == at::ScalarType::Half) { \ using input_t = at::Half; \ - using weight_t = at::Half; \ + using weight_t = float; \ __VA_ARGS__(); \ } else if (ITYPE == at::ScalarType::BFloat16) { \ using input_t = at::BFloat16; \ - using weight_t = at::BFloat16; \ + using weight_t = float; \ __VA_ARGS__(); \ } else if (ITYPE == at::ScalarType::Float) { \ using input_t = float; \ - using weight_t = float; \ + using weight_t = float; \ __VA_ARGS__(); \ } else { \ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ From 1f35bbe2ac38d0b8b341c93617d384769c6927d3 Mon Sep 17 00:00:00 2001 From: mzusman Date: Wed, 28 Aug 2024 14:43:45 +0300 Subject: [PATCH 43/45] Take down seq_pos_idx, not used atm, will comeback in a following PR --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 78 ++++++++--------------- csrc/mamba/causal_conv1d/causal_conv1d.h | 37 ++++++++++- 2 files changed, 63 insertions(+), 52 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 1a365e9a30a8..88a64a8ece58 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -89,7 +89,6 @@ at::Tensor causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, const c10::optional &bias_, const c10::optional &seq_idx_, - const c10::optional &seq_pos_idx_, const c10::optional &initial_states_, const c10::optional &final_states_out_, bool silu_activation) { @@ -135,13 +134,7 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, TORCH_CHECK(seq_idx.is_contiguous()); CHECK_SHAPE(seq_idx, batch_size, seqlen); } - if (seq_pos_idx_.has_value()) { - auto seq_pos_idx = seq_pos_idx_.value(); - TORCH_CHECK(seq_pos_idx.scalar_type() == torch::kInt32); - TORCH_CHECK(seq_pos_idx.is_cuda()); - TORCH_CHECK(seq_pos_idx.is_contiguous()); - CHECK_SHAPE(seq_pos_idx, batch_size, seqlen); - } + at::Tensor out = torch::empty_like(x); ConvParamsBase params; @@ -155,11 +148,6 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, params.seq_idx_ptr = nullptr; } - if (seq_pos_idx_.has_value()) { - params.seq_pos_idx_ptr = seq_pos_idx_.value().data_ptr(); - } else { - params.seq_pos_idx_ptr = nullptr; - } if (initial_states_.has_value()) { TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout"); auto initial_states = initial_states_.value(); @@ -215,6 +203,7 @@ causal_conv1d_update(const at::Tensor &x, auto weight_type = weight.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == 
at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations"); TORCH_CHECK(conv_state.scalar_type() == input_type); TORCH_CHECK(x.is_cuda()); @@ -273,26 +262,19 @@ struct Causal_conv1d_fwd_kernel_traits { static constexpr int kNElts = kNBytes == 4 ? 4 : 8; static_assert(kWidth <= kNElts); static constexpr bool kIsVecLoad = kIsVecLoad_; - static constexpr int kNLoadsIndex = kNElts / 4; using vec_t = typename BytesToType::Type; using BlockLoadT = cub::BlockLoad; using BlockLoadVecT = cub::BlockLoad; - using BlockLoadIndexT = cub::BlockLoad; - using BlockLoadIndexVecT = cub::BlockLoad; using BlockStoreT = cub::BlockStore; using BlockStoreVecT = cub::BlockStore; - - static constexpr int kSmemIOSize = (kIsVecLoad && kNLoadsIndex == 1) + static constexpr int kSmemIOSize = kIsVecLoad ? 0 - : std::max({sizeof(typename BlockLoadT::TempStorage), - sizeof(typename BlockStoreT::TempStorage), - sizeof(typename BlockLoadIndexT::TempStorage), - sizeof(typename BlockLoadIndexVecT::TempStorage)}); + : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; }; -template +template __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(ConvParamsBase params) { constexpr int kWidth = Ktraits::kWidth; @@ -307,8 +289,6 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { extern __shared__ char smem_[]; auto& smem_load = reinterpret_cast(smem_); auto& smem_load_vec = reinterpret_cast(smem_); - auto& smem_load_index = reinterpret_cast(smem_); - auto& smem_load_index_vec = reinterpret_cast(smem_); auto& smem_store = reinterpret_cast(smem_); auto& smem_store_vec = reinterpret_cast(smem_); vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); @@ -322,8 +302,6 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + channel_id * params.out_c_stride; float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); - - int *seq_pos_idx = !kHasSeqPosIdx ? nullptr : reinterpret_cast(params.seq_pos_idx_ptr) + batch_id * params.seqlen; // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. 
if (tidx == 0) { @@ -339,19 +317,13 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; for (int chunk = 0; chunk < n_chunks; ++chunk) { input_t x_vals_load[2 * kNElts] = {0}; - int seq_pos_idx_load[kNElts]; if constexpr(kIsVecLoad) { - Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); - if (kHasSeqPosIdx) - Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(reinterpret_cast(seq_pos_idx), *reinterpret_cast(seq_pos_idx_load), (params.seqlen - chunk * kChunkSize) / kNElts * Ktraits::kNLoadsIndex); - } else { + typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); + } else { __syncthreads(); - Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); - if (kHasSeqPosIdx) - Ktraits::BlockLoadIndexT(smem_load_index).Load(seq_pos_idx, seq_pos_idx_load, (params.seqlen - chunk * kChunkSize), 0); + typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); } - x += kChunkSize; - if (kHasSeqPosIdx) seq_pos_idx += kChunkSize; + x += kChunkSize; __syncthreads(); // Thread kNThreads - 1 don't write yet, so that thread 0 can read // the last elements of the previous chunk. @@ -367,17 +339,11 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } float out_vals[kNElts]; - #pragma unroll for (int i = 0; i < kNElts; ++i) { out_vals[i] = bias_val; - int w = 0; - if (kHasSeqPosIdx){ - if(seq_pos_idx_load[i] < kWidth){ - w = kWidth - seq_pos_idx_load[i] - 1; - } - } - for (; w < kWidth; ++w) { + #pragma unroll + for (int w = 0; w < kWidth; ++w) { out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; } } @@ -393,28 +359,38 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { #pragma unroll for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } if constexpr(kIsVecLoad) { - Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); + typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); } else { - Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); + typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); } out += kChunkSize; } } + template void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; - constexpr bool kHasSeqPosIdx = false; BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { using Ktraits = Causal_conv1d_fwd_kernel_traits; constexpr int kSmemSize = Ktraits::kSmemSize; dim3 grid(params.batch, params.dim); - auto kernel = &causal_conv1d_fwd_kernel; + + auto kernel = &causal_conv1d_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + #ifndef USE_ROCM C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } + #else + // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function. 
+ C10_CUDA_CHECK(cudaFuncSetAttribute( + (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; + #endif + } kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } @@ -535,7 +511,7 @@ void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) { *reinterpret_cast(final_states) = reinterpret_cast(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx]; } - constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); + constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h index e909bcd5391e..bb25314c8bbb 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.h +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -37,7 +37,6 @@ struct ConvParamsBase { void *__restrict__ conv_state_ptr; void *__restrict__ seq_idx_ptr; - void *__restrict__ seq_pos_idx_ptr; // No __restrict__ since initial_states could be the same as final_states. void * initial_states_ptr; @@ -52,6 +51,42 @@ struct ConvParamsBase { }; +#ifndef USE_ROCM + #include + + template + __device__ inline T shuffle_xor(T val, int offset) { + return __shfl_xor_sync(uint32_t(-1), val, offset); + } + + constexpr size_t custom_max(std::initializer_list ilist) + { + return std::max(ilist); + } + + template + constexpr T constexpr_min(T a, T b) { + return std::min(a, b); + } + +#else + #include + + template + __device__ inline T shuffle_xor(T val, int offset) { + return __shfl_xor(val, offset); + } + constexpr size_t custom_max(std::initializer_list ilist) + { + return *std::max_element(ilist.begin(), ilist.end()); + } + + template + constexpr T constexpr_min(T a, T b) { + return a < b ? a : b; + } +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// template struct BytesToType {}; From bed44c4a9ae28ed8b3b0510c0dbaac6e2efd6b2f Mon Sep 17 00:00:00 2001 From: mzusman Date: Wed, 28 Aug 2024 14:44:23 +0300 Subject: [PATCH 44/45] Add comments and guard checks on disabled "features" --- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 38 +++++++++------------- csrc/torch_bindings.cpp | 1 - vllm/_custom_ops.py | 2 +- 3 files changed, 17 insertions(+), 24 deletions(-) diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 347b6666165a..df968dda92ad 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -311,6 +311,7 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block // processing 1 row. 
constexpr int kNRows = 1; + // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size constexpr bool kIsVariableB = true; constexpr bool kIsVariableC = true; constexpr bool kHasZ = true; @@ -485,11 +486,10 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); + TORCH_CHECK(weight_type == at::ScalarType::Float); const bool is_variable_B = B.dim() >= 3; const bool is_variable_C = C.dim() >= 3; - const bool is_complex = weight_type == at::ScalarType::ComplexFloat; TORCH_CHECK(delta.scalar_type() == input_type); TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); @@ -516,18 +516,13 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, CHECK_SHAPE(u, batch_size, dim, seqlen); CHECK_SHAPE(delta, batch_size, dim, seqlen); CHECK_SHAPE(A, dim, dstate); - if (!is_variable_B) { - CHECK_SHAPE(B, dim, dstate); - } else { - CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); - TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); - } - if (!is_variable_C) { - CHECK_SHAPE(C, dim, dstate); - } else { - CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); - TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); - } + TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size") + CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen ); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + + TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size") + CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); if (D_.has_value()) { auto D = D_.value(); @@ -553,14 +548,13 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, at::Tensor z, out_z; const bool has_z = z_.has_value(); - if (has_z) { - z = z_.value(); - TORCH_CHECK(z.scalar_type() == input_type); - TORCH_CHECK(z.is_cuda()); - TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); - CHECK_SHAPE(z, batch_size, dim, seqlen); - out_z = torch::empty_like(z); - } + TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size") + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + CHECK_SHAPE(z, batch_size, dim, seqlen); + out_z = torch::empty_like(z); const int n_chunks = (seqlen + 2048 - 1) / 2048; // const int n_chunks = (seqlen + 1024 - 1) / 1024; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index fb0accbbabac..7783acd741f5 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -223,7 +223,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "causal_conv1d_fwd(Tensor! x, Tensor! weight," "Tensor? bias_," "Tensor? seq_idx_," - "Tensor? seq_pos_idx_," "Tensor? initial_states_," "Tensor? 
final_states_out_," "bool silu_activation) -> Tensor"); diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 057c71c908fc..89ca4f9ac6c9 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -498,7 +498,7 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, initial_states_: Optional[torch.Tensor], final_states_out_: Optional[torch.Tensor], silu_activation: bool) -> torch.Tensor: - return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, None, + return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, initial_states_, final_states_out_, silu_activation) From 950701acefdb8a1e8f0b1262064760fb925263c4 Mon Sep 17 00:00:00 2001 From: mzusman Date: Wed, 28 Aug 2024 18:40:37 +0300 Subject: [PATCH 45/45] Fix header file --- csrc/ops.h | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/ops.h b/csrc/ops.h index 09b0c67f059a..8d24545de898 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -213,7 +213,6 @@ at::Tensor causal_conv1d_update(const at::Tensor& x, at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, const c10::optional& bias_, const c10::optional& seq_idx_, - const c10::optional& seq_pos_idx_, const c10::optional& initial_states_, const c10::optional& final_states_out_, bool silu_activation);