Commit 001e2ab

Addition kernel
1 parent 78391e7 commit 001e2ab

4 files changed, 122 additions and 1 deletion

exllamav3/exllamav3_ext/add.cu

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
#include <cuda_fp16.h>
#include "add.cuh"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include "util.h"
#include "util.cuh"

#define NUM_THREADS 1024

// Defines a one-element-per-thread elementwise add kernel for one
// (x, y, z) dtype combination; fn computes the sum from operands a and b
#define KERNEL_DEF(xt, yt, zt, kernel, fn) \
__launch_bounds__(NUM_THREADS) \
__global__ void kernel \
( \
    const xt* __restrict__ x, \
    const yt* __restrict__ y, \
    zt* __restrict__ z, \
    const uint64_t numel \
) \
{ \
    uint64_t idx = ((uint64_t)blockIdx.x * NUM_THREADS + (uint64_t)threadIdx.x); \
    if (idx >= numel) return; \
    xt a = x[idx]; \
    yt b = y[idx]; \
    z[idx] = fn; \
}

KERNEL_DEF(half,  half,  half,  add_kernel_hhh, __hadd(a, b))
KERNEL_DEF(half,  half,  float, add_kernel_hhf, __half2float(__hadd(a, b)))
KERNEL_DEF(half,  float, half,  add_kernel_hfh, __float2half_rn(__half2float(a) + b))
KERNEL_DEF(half,  float, float, add_kernel_hff, __half2float(a) + b)
KERNEL_DEF(float, half,  half,  add_kernel_fhh, __float2half_rn(a + __half2float(b)))
KERNEL_DEF(float, half,  float, add_kernel_fhf, a + __half2float(b))
KERNEL_DEF(float, float, half,  add_kernel_ffh, __float2half_rn(a + b))
KERNEL_DEF(float, float, float, add_kernel_fff, a + b)

#undef KERNEL_DEF

/*
x + y -> z
Works in-place if x == z or y == z
*/

void add_gr
(
    const at::Tensor& x,
    const at::Tensor& y,
    at::Tensor& z,
    Graph* graph
)
{
    const at::cuda::OptionalCUDAGuard device_guard(x.device());
    cudaStream_t stream = graph ? graph->capture_stream : at::cuda::getCurrentCUDAStream().stream();

    auto xt = x.dtype();
    auto yt = y.dtype();
    auto zt = z.dtype();
    uint64_t numel = x.numel();
    int blocks = (int) CEIL_DIVIDE(numel, (uint64_t) NUM_THREADS);

// Dispatch on the runtime dtypes, launch the matching kernel and, when a
// graph is capturing, record the pointer arguments so they can be updated
// on replay
#define INSTANCE(xt_, yt_, zt_, xt__, yt__, zt__, kernel) \
    if (xt == xt_ && yt == yt_ && zt == zt_) \
    { \
        kernel<<<blocks, NUM_THREADS, 0, stream>>> \
        ( \
            (const xt__*) x.data_ptr(), \
            (const yt__*) y.data_ptr(), \
            (zt__*) z.data_ptr(), \
            numel \
        ); \
        if (graph) graph->record_param((void*) &kernel, GP_add_x, 0); \
        if (graph) graph->record_param((void*) &kernel, GP_add_y, 1); \
        if (graph) graph->record_param((void*) &kernel, GP_add_z, 2); \
        if (graph) graph->record_param((void*) &kernel, GP_end, 0); \
        cuda_check(cudaPeekAtLastError()); \
    }

    INSTANCE(at::kHalf,  at::kHalf,  at::kHalf,  half,  half,  half,  add_kernel_hhh)
    INSTANCE(at::kHalf,  at::kHalf,  at::kFloat, half,  half,  float, add_kernel_hhf)
    INSTANCE(at::kHalf,  at::kFloat, at::kHalf,  half,  float, half,  add_kernel_hfh)
    INSTANCE(at::kHalf,  at::kFloat, at::kFloat, half,  float, float, add_kernel_hff)
    INSTANCE(at::kFloat, at::kHalf,  at::kHalf,  float, half,  half,  add_kernel_fhh)
    INSTANCE(at::kFloat, at::kHalf,  at::kFloat, float, half,  float, add_kernel_fhf)
    INSTANCE(at::kFloat, at::kFloat, at::kHalf,  float, float, half,  add_kernel_ffh)
    INSTANCE(at::kFloat, at::kFloat, at::kFloat, float, float, float, add_kernel_fff)

#undef INSTANCE

    cuda_check(cudaPeekAtLastError());
}

void add
(
    const at::Tensor& x,
    const at::Tensor& y,
    at::Tensor& z
)
{
    add_gr(x, y, z, nullptr);
}
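
For reference (not part of the commit): expanding one instantiation by hand shows exactly what KERNEL_DEF generates. The half/float/float case becomes a plain one-element-per-thread CUDA kernel:

__launch_bounds__(NUM_THREADS)
__global__ void add_kernel_hff
(
    const half* __restrict__ x,
    const float* __restrict__ y,
    float* __restrict__ z,
    const uint64_t numel
)
{
    // Flat index over all elements; guard against the ragged last block
    uint64_t idx = ((uint64_t)blockIdx.x * NUM_THREADS + (uint64_t)threadIdx.x);
    if (idx >= numel) return;
    half a = x[idx];
    float b = y[idx];
    z[idx] = __half2float(a) + b;  // promote the half operand, add in fp32
}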

exllamav3/exllamav3_ext/add.cuh

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
#pragma once

#include <ATen/Tensor.h>
#include "graph.cuh"

void add_gr
(
    const at::Tensor& x,
    const at::Tensor& y,
    at::Tensor& z,
    Graph* graph
);

void add
(
    const at::Tensor& x,
    const at::Tensor& y,
    at::Tensor& z
);
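
A minimal usage sketch from C++ (illustration only, not in the commit; add_usage_example is a hypothetical helper). The dtype pairs follow the INSTANCE table in add.cu, and the second call exercises the documented in-place case (x == z):

#include <ATen/ATen.h>
#include "add.cuh"

void add_usage_example()
{
    at::Tensor x = at::randn({4096}, at::device(at::kCUDA).dtype(at::kHalf));
    at::Tensor y = at::randn({4096}, at::device(at::kCUDA).dtype(at::kFloat));
    at::Tensor z = at::empty({4096}, at::device(at::kCUDA).dtype(at::kFloat));

    add(x, y, z);  // half + float -> float, dispatches to add_kernel_hff
    add(z, y, z);  // in-place: output aliases x, allowed per add.cu's comment
}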

exllamav3/exllamav3_ext/bindings.cpp

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
 #include "routing.cuh"
 #include "gdn.cuh"
 #include "causal_conv1d.cuh"
+#include "add.cuh"

 #include "quant/quantize.cuh"
 #include "quant/pack.cuh"

@@ -98,6 +99,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("relu2_mul", &relu2_mul, "relu2_mul");
     m.def("xielu", &xielu, "xielu");
     m.def("add_sigmoid_gate", &add_sigmoid_gate, "add_sigmoid_gate");
+    m.def("add", &add, "add");

     m.def("gated_delta_net_fused_op", &gated_delta_net_fused_op, "gated_delta_net_fused_op");
     m.def("cuda_recurrent_gated_delta_rule", &cuda_recurrent_gated_delta_rule, "cuda_recurrent_gated_delta_rule");

exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp

Lines changed: 2 additions & 1 deletion
@@ -7,6 +7,7 @@
 #include "../hgemm.cuh"
 #include "../quant/exl3_gemm.cuh"
 #include "../activation.cuh"
+#include "../add.cuh"

 std::tuple<at::Tensor, at::Tensor> blocksparse_mlp_routing(
     int bsz,

@@ -140,7 +141,7 @@ void BC_BlockSparseMLP::run_bsz1
     }
     else
     {
-        out_d.add_(out_d_sh.value());
+        add(out_d, out_d_sh.value(), out_d);
     }
 }
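
Why replace out_d.add_(...) with the new op? The commit message doesn't say, but a plausible reading of add.cu is that add_gr's Graph* path records the kernel's pointer arguments via record_param, letting this call site participate in the project's CUDA-graph capture, which ATen's in-place add_ doesn't expose. Under capture, the call would presumably look like this hypothetical variant (assuming a Graph* in scope, which this commit does not add):

add_gr(out_d, out_d_sh.value(), out_d, graph);  // graph is hypothetical here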
