Commit 3c77dcf

Adding ROCm support [AMD official] (#359)

Authored by amoskvic (Arseny Moskvichev), ajassani, root, and gabeweisz
* Current progress
* Current progress
* FWD kernel compiles
* Current progress: compiles and imports, not tested
* Adapting tests
* working benchmark for torch hipify
* Removing hard-coded device typos
* all dtypes support
* fixed parallel compile issue
* use optimized causal_conv1d if available
* april18 perf bench script
* april18 perf bench script
* Delete csrc/selective_scan/selective_scan_fwd_kernel_minimal.cuh
* reverted benchmark
* reverted changes to base iteration 1
* removed files not in base
* Ported bwd changes (partial)
* Backward working fp32
* all dtypes with bwd
* gitignore hipified files
* rocm cond and move max min to common
* triton autotune conditional
* Unifying setup.py (in progress)
* triton conditional autotune configs
* some more conditional compiles
* Setup.py functional
* Functional
* Minmax changes
* reduce repeatability
* Removed extra comments
* fix template error
* Update csrc/selective_scan/reverse_scan.cuh

Co-authored-by: Jeff Daily <[email protected]>

* restore permissions to base
* permission for gitignore and readme
* warp size based on code review
* Adding ifndef + warnings for dynamic memory size adjustment
* minor changes to setup
* fall back for warp size conditional
* patch method updated
* Minor stylistic changes + an extra warning about patching
* 4096 kNLoads patch
* Cleanup, conditional kernel launch parameters
* Flexible warp size
* Fix warp size to 32 for CUDA

---------

Co-authored-by: Arseny Moskvichev <[email protected]>
Co-authored-by: ajassani <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Adeem Jassani <[email protected]>
Co-authored-by: Gabe Weisz <[email protected]>
Co-authored-by: ajassani <[email protected]>
Co-authored-by: Jeff Daily <[email protected]>
Co-authored-by: root <[email protected]>
1 parent c2568f5 commit 3c77dcf

11 files changed: 479 additions & 142 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
```diff
@@ -2,3 +2,5 @@
 *.egg-info/
 build/
 **.so
+*.hip
+*_hip.*
```

README.md

File mode changed: 100644 → 100755
Lines changed: 13 additions & 0 deletions
````diff
@@ -17,6 +17,19 @@ Mamba is a new state space model architecture showing promising performance on i
 It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
 with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
 
+## Prerequisites
+
+### Patching ROCm
+
+If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
+
+1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
+
+2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
+```bash
+patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
+```
+
 ## Installation
 
 - [Option] `pip install causal-conv1d>=1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
````
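The patch targets ROCm 6.0's `amd_hip_bf16.h` header, so a quick sanity check that it applied cleanly is to compile a small translation unit that exercises the bf16 conversion intrinsics. A minimal, hypothetical check file (not part of the repo; the file name is an example):

```cpp
// bf16_check.hip: compile with `hipcc -c bf16_check.hip`.
// This only compiles if the (patched) amd_hip_bf16.h header is self-consistent.
#include <hip/hip_bf16.h>

__global__ void bf16_roundtrip(const float* in, float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        __hip_bfloat16 b = __float2bfloat16(in[i]);  // float -> bf16
        out[i] = __bfloat162float(b);                // bf16 -> float
    }
}
```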

csrc/selective_scan/reverse_scan.cuh

Lines changed: 21 additions & 7 deletions
```diff
@@ -4,12 +4,17 @@
 
 #pragma once
 
-#include <cub/config.cuh>
-
-#include <cub/util_ptx.cuh>
-#include <cub/util_type.cuh>
-#include <cub/block/block_raking_layout.cuh>
-// #include <cub/detail/uninitialized_copy.cuh>
+#ifndef USE_ROCM
+#include <cub/config.cuh>
+
+#include <cub/util_ptx.cuh>
+#include <cub/util_type.cuh>
+#include <cub/block/block_raking_layout.cuh>
+// #include <cub/detail/uninitialized_copy.cuh>
+#else
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 #include "uninitialized_copy.cuh"
 
 /**
```
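The `namespace cub = hipcub;` alias is what lets the rest of the file keep using `cub::` unchanged under ROCm, since hipCUB mirrors CUB's API. A standalone sketch of the same pattern (hypothetical kernel, not repo code):

```cpp
// Conditional include + namespace alias: the kernel body compiles unchanged
// with nvcc (CUB) and with hipcc (hipCUB) when USE_ROCM is defined.
#ifndef USE_ROCM
#include <cub/block/block_reduce.cuh>
#else
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;  // hipCUB exposes a CUB-compatible API
#endif

template <int kNThreads>
__global__ void block_sum(const float* in, float* out) {
    using BlockReduceT = cub::BlockReduce<float, kNThreads>;
    __shared__ typename BlockReduceT::TempStorage temp;
    float sum = BlockReduceT(temp).Sum(in[blockIdx.x * kNThreads + threadIdx.x]);
    if (threadIdx.x == 0) { out[blockIdx.x] = sum; }  // lane 0 holds the block sum
}
```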
```diff
@@ -46,6 +51,7 @@ __device__ __forceinline__ T ThreadReverseScanInclusive(
         inclusive = scan_op(inclusive, input[i]);
         output[i] = inclusive;
     }
+    return inclusive;
 }
 
 /**
```

This hunk also fixes a genuine bug: `ThreadReverseScanInclusive` is declared to return `T`, but previously fell off the end of the function without returning, which is undefined behavior.
```diff
@@ -89,7 +95,15 @@ struct WarpReverseScan {
     //---------------------------------------------------------------------
 
     /// Whether the logical warp size and the PTX warp size coincide
-    static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0));
+
+    // In hipcub, warp_threads is defined as HIPCUB_WARP_THREADS ::rocprim::warp_size(),
+    // while in cub it's defined as a macro that takes a redundant unused argument.
+    #ifndef USE_ROCM
+        #define WARP_THREADS CUB_WARP_THREADS(0)
+    #else
+        #define WARP_THREADS HIPCUB_WARP_THREADS
+    #endif
+    static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == WARP_THREADS);
     /// The number of warp scan steps
     static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
     static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);
```
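This matters because the hardware warp width differs by vendor: NVIDIA GPUs always execute 32-wide warps, while AMD CDNA GPUs use 64-wide wavefronts (what `HIPCUB_WARP_THREADS` / `rocprim::warp_size()` report), so `IS_ARCH_WARP` and the scan step count must be derived from the platform rather than a hard-coded 32. A standalone sketch of the same derivation (hypothetical names, not repo code):

```cpp
// Platform-dependent warp width feeding compile-time scan parameters.
#ifndef USE_ROCM
constexpr int kWarpThreads = 32;  // fixed across CUDA architectures
#else
constexpr int kWarpThreads = 64;  // wavefront width on AMD CDNA GPUs
#endif

template <int LOGICAL_WARP_THREADS>
struct WarpScanTraits {
    // True when the logical warp maps 1:1 onto a hardware warp/wavefront.
    static constexpr bool kIsArchWarp = (LOGICAL_WARP_THREADS == kWarpThreads);

    // Number of scan steps: log2 of the logical warp size.
    static constexpr int kSteps = [] {
        int s = 0;
        for (int w = LOGICAL_WARP_THREADS; w > 1; w >>= 1) { ++s; }
        return s;
    }();
    static_assert(LOGICAL_WARP_THREADS == 1 << kSteps,
                  "logical warp size must be a power of two");
};
```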

csrc/selective_scan/selective_scan_bwd_kernel.cuh

Lines changed: 75 additions & 36 deletions
```diff
@@ -9,10 +9,15 @@
 #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
 #include <ATen/cuda/Atomic.cuh>  // For atomicAdd on complex
 
-#include <cub/block/block_load.cuh>
-#include <cub/block/block_store.cuh>
-#include <cub/block/block_scan.cuh>
-#include <cub/block/block_reduce.cuh>
+#ifndef USE_ROCM
+#include <cub/block/block_load.cuh>
+#include <cub/block/block_store.cuh>
+#include <cub/block/block_scan.cuh>
+#include <cub/block/block_reduce.cuh>
+#else
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 
 #include "selective_scan.h"
 #include "selective_scan_common.h"
```
```diff
@@ -33,7 +38,7 @@ struct Selective_Scan_bwd_kernel_traits {
     static constexpr int kNItems = kNItems_;
     static constexpr int kNBytes = sizeof(input_t);
     static_assert(kNBytes == 2 || kNBytes == 4);
-    static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
+    static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
     static_assert(kNItems % kNElts == 0);
     static constexpr int kNLoads = kNItems / kNElts;
     static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
```
```diff
@@ -61,12 +66,13 @@ struct Selective_Scan_bwd_kernel_traits {
     using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
     using BlockReduceComplexT = cub::BlockReduce<complex_t, kNThreads>;
     using BlockExchangeT = cub::BlockExchange<float, kNThreads, !kIsComplex ? kNItems : kNItems * 2>;
-    static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
-        sizeof(typename BlockLoadVecT::TempStorage),
-        (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
-        (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
-        sizeof(typename BlockStoreT::TempStorage),
-        sizeof(typename BlockStoreVecT::TempStorage)});
+
+    static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
+        sizeof(typename BlockLoadVecT::TempStorage),
+        (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
+        (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
+        sizeof(typename BlockStoreT::TempStorage),
+        sizeof(typename BlockStoreVecT::TempStorage)});
     static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage);
     static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
     static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
```
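`constexpr_min` and `custom_max` replace `std::min` and `std::max` here; per the commit bullet "rocm cond and move max min to common", their definitions were moved to a shared header, presumably because the `std::` overloads are not reliably usable in constant expressions under both nvcc and hipcc device compilation. The diff does not show those definitions, so the following is only a plausible sketch (the names match the diff; the signatures are an assumption):

```cpp
#include <cstddef>
#include <initializer_list>

// Plain constexpr helpers that both compilers can fold while initializing
// static constexpr members such as kNElts and kSmemIOSize (assumed shapes).
constexpr int constexpr_min(int a, int b) { return a < b ? a : b; }

constexpr size_t custom_max(std::initializer_list<size_t> vals) {
    size_t m = 0;
    for (size_t v : vals) { m = (v > m) ? v : m; }  // running maximum
    return m;
}

// Usage mirroring the kernel traits above:
//   static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
//   static constexpr int kSmemIOSize = custom_max({sizeof(A), sizeof(B)});
```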
```diff
@@ -263,12 +269,12 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
             // Initialize running total
             scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
             SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
-            Ktraits::BlockScanT(smem_scan).InclusiveScan(
+            typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
                 thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
             );
             scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);
             SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
-            Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
+            typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
                 thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
             );
             if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
```
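The recurring `Ktraits::BlockScanT(...)` to `typename Ktraits::BlockScanT(...)` edits in this file are all the same fix: inside a template, `Ktraits::BlockScanT` is a dependent name, and clang-based hipcc requires the `typename` keyword even when constructing a temporary, whereas nvcc accepted the shorthand. A self-contained illustration (hypothetical trait and kernel, not repo code):

```cpp
struct Traits {
    struct TempStorage {};
    struct Reducer {
        __device__ explicit Reducer(TempStorage&) {}
        __device__ float Sum(float v) const { return v; }  // stand-in for a block reduce
    };
};

template <typename Ktraits>  // Ktraits::Reducer is a dependent name here
__global__ void demo(float* out) {
    __shared__ typename Ktraits::TempStorage storage;
    // Without `typename`, hipcc rejects this temporary construction:
    out[0] = typename Ktraits::Reducer(storage).Sum(1.0f);
}

// Instantiation: demo<Traits><<<1, 1>>>(d_out);
```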
```diff
@@ -297,11 +303,11 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
         // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
         if constexpr (kIsVariableB || kIsVariableC) {
             if constexpr (kIsVariableB) {
-                Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
+                typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
             }
             if constexpr (kIsVariableC) {
                 auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
-                Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
+                typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
             }
             const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
             weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
```
```diff
@@ -316,13 +322,13 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
         }
         if constexpr (!kIsVariableB || !kIsVariableC) {
             float2 dA_dBC_val = make_float2(dA_val, dBC_val);
-            dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
+            dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
             dA_val = dA_dBC_val.x;
             if (threadIdx.x == 0) {
                 smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx];
             }
         } else {
-            dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
+            dA_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
         }
         if (threadIdx.x == 0) {
             smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
```
```diff
@@ -356,12 +362,12 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
             // Initialize running total
             scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
             SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
-            Ktraits::BlockScanT(smem_scan).InclusiveScan(
+            typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
                 thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
             );
             scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
             SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
-            Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
+            typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
                 thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
             );
             if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
```
```diff
@@ -397,7 +403,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                 dB_vals_f[i * 2] = dB_vals[i].real_;
                 dB_vals_f[i * 2 + 1] = dB_vals[i].imag_;
             }
-            Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f);
+            typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f);
         }
         if constexpr (kIsVariableC) {
             #pragma unroll
```
```diff
@@ -406,7 +412,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
                 dC_vals_f[i * 2 + 1] = dC_vals[i].imag_;
             }
             auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
-            Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f);
+            typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f);
         }
         const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x;
         float *dB_cur = reinterpret_cast<float *>(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
```
```diff
@@ -421,14 +427,14 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
         }
         if constexpr (!kIsVariableB || !kIsVariableC) {
             float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_);
-            dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
+            dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
             dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y);
             dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w);
             if (threadIdx.x == 0) {
                 smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx];
             }
         } else {
-            dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);
+            dA_val = typename Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);
         }
         if (threadIdx.x == 0) {
             smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
```
```diff
@@ -465,12 +471,12 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
         Cvar -= kChunkSize * (!kIsComplex ? 1 : 2);
     }
     if (params.dD_ptr != nullptr) {
-        dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
+        dD_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
         if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
     }
     if (params.ddelta_bias_ptr != nullptr) {
         __syncthreads();
-        ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
+        ddelta_bias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
         if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
     }
     for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
```
```diff
@@ -499,13 +505,24 @@ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
         // using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
         // TODO: check this
         constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
-        // printf("smem_size = %d\n", kSmemSize);
+
         dim3 grid(params.batch, params.dim);
+
         auto kernel = &selective_scan_bwd_kernel<Ktraits>;
+
         if (kSmemSize >= 48 * 1024) {
+
+            #ifndef USE_ROCM
             C10_CUDA_CHECK(cudaFuncSetAttribute(
                 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+            #else
+            C10_CUDA_CHECK(cudaFuncSetAttribute(
+                (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+            std::cerr << "Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
+            #endif
+
         }
+
         kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
         C10_CUDA_KERNEL_LAUNCH_CHECK();
     });
```
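Two notes on this hunk. First, the warning string says `selective_scan_fwd_kernel` even though this is the backward kernel file, a copy-paste slip carried over from the forward pass. Second, the `cudaFuncSetAttribute` call is the standard opt-in for launching with more than the default 48 KB of dynamic shared memory per block; per the diff, the hipified call compiles on ROCm <= 6.1 but is currently a non-op, hence the extra warning. A standalone sketch of the opt-in pattern (hypothetical kernel, simplified error handling):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

__global__ void scratch_kernel(float* out) {
    extern __shared__ float scratch[];  // sized at launch time
    scratch[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    out[threadIdx.x] = scratch[threadIdx.x];
}

int main() {
    const int kSmemSize = 64 * 1024;  // 64 KB, above the 48 KB default cap
    if (kSmemSize >= 48 * 1024) {
        // Required on NVIDIA GPUs before launching with >48 KB of dynamic
        // shared memory; per the diff, currently a non-op on ROCm <= 6.1.
        cudaFuncSetAttribute(scratch_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             kSmemSize);
    }
    float* d_out;
    cudaMalloc(&d_out, 256 * sizeof(float));
    scratch_kernel<<<1, 256, kSmemSize>>>(d_out);
    printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(d_out);
    return 0;
}
```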
```diff
@@ -517,15 +534,37 @@ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
 
 template<typename input_t, typename weight_t>
 void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
-    if (params.seqlen <= 128) {
-        selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
-    } else if (params.seqlen <= 256) {
-        selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
-    } else if (params.seqlen <= 512) {
-        selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
-    } else if (params.seqlen <= 1024) {
-        selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
-    } else {
-        selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
+
+    #ifndef USE_ROCM
+    constexpr int warp_size = 32;
+    #else
+    constexpr int warp_size = rocprim::warp_size();
+    #endif
+
+    if (warp_size == 32) {
+        if (params.seqlen <= 128) {
+            selective_scan_bwd_launch<64, 4, input_t, weight_t>(params, stream);
+        } else if (params.seqlen <= 256) {
+            selective_scan_bwd_launch<64, 8, input_t, weight_t>(params, stream);
+        } else if (params.seqlen <= 512) {
+            selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
+        } else if (params.seqlen <= 1024) {
+            selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
+        } else {
+            selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
+        }
+    }
+    #ifdef USE_ROCM
+    else {
+        if (params.seqlen <= 256) {
+            selective_scan_bwd_launch<64, 4, input_t, weight_t>(params, stream);
+        } else if (params.seqlen <= 512) {
+            selective_scan_bwd_launch<64, 8, input_t, weight_t>(params, stream);
+        } else if (params.seqlen <= 1024) {
+            selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
+        } else {
+            selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
+        }
     }
+    #endif
 }
```
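Here `warp_size` is a compile-time constant (32 under CUDA, `rocprim::warp_size()` under ROCm), so the seqlen-to-(kNThreads, kNItems) dispatch table is selected at build time. For comparison, a runtime query of the same quantity looks like this (standalone sketch; under HIP the analogous calls are `hipGetDeviceProperties` and `hipDeviceProp_t::warpSize`):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, /*device=*/0);
    // 32 on NVIDIA hardware; the HIP equivalent reports 64 on AMD CDNA GPUs.
    printf("hardware warp/wavefront width: %d\n", prop.warpSize);
    return 0;
}
```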
