
Commit 41e2846

Merge branch 'dev'
2 parents: 8098d61 + 19513f3

35 files changed: +14577 / -142 lines

exllamav3/architecture/gemma3.py

Lines changed: 10 additions & 0 deletions
@@ -138,6 +138,11 @@ def __init__(
         self.vision_pp.size = read_dict(read_prep_config, dict, ["size"], no_default)


+    def default_max_position_embeddings(self):
+        # Fixed for Gemma3, usually not present in config.json
+        return 131072
+
+
 class Gemma3TextConfig(Config):
     arch_string = "Gemma3ForCausalLM"

@@ -218,6 +223,11 @@ def __init__(
         self.final_logit_softcapping = self.read_cfg(float, "final_logit_softcapping", 0.0)


+    def default_max_position_embeddings(self):
+        # Fixed for Gemma2, usually not present in config.json
+        return 8192
+
+
 class Gemma3Model(Model):
     config_class = Gemma3Config
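The new method supplies a fallback when max_position_embeddings is absent from config.json. A minimal, self-contained sketch of the pattern; the helper name below is hypothetical and not the repository's code:

def resolve_max_position_embeddings(cfg_json: dict, default: int = 131072) -> int:
    # Prefer the value from config.json; fall back to the architecture's fixed default
    value = cfg_json.get("max_position_embeddings")
    return value if value is not None else default

assert resolve_max_position_embeddings({}) == 131072
assert resolve_max_position_embeddings({"max_position_embeddings": 32768}) == 32768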

exllamav3/conversion/convert_model.py

Lines changed: 14 additions & 0 deletions
@@ -51,6 +51,18 @@

 num_ref_states = 5

+def check_system():
+    if os.environ.get("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") is not None:
+        print(
+            "\n"
+            f" !! {col_red}IMPORTANT: The environment variable TORCH_ALLOW_TF32_CUBLAS_OVERRIDE is set!{col_default}\n"
+            f" !! {col_red}This causes Torch to run in reduced precision mode, which is likely to cause this "
+            f"script to fail or result in broken models.{col_default}\n"
+            "\n"
+        )
+
+
 def save_dict(filename, dict_, args):
     path = os.path.join(args["work_dir"], filename)
     with open(path, "w", encoding = "utf8") as f:

@@ -104,6 +116,8 @@ def prepare_env(args):


 def prepare(args) -> (dict, dict, bool, str):
+    check_system()
+
     if not args.work_dir:
         return None, None, False, "Must specify --work_dir"
     if not args.in_dir and not args.resume:
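check_system() warns when the TF32 override is active, since reduced-precision matmuls can corrupt calibration and quantization. A hedged, standalone sketch of the same check and the remedy, using plain prints instead of the converter's color constants:

import os
import torch

# When TORCH_ALLOW_TF32_CUBLAS_OVERRIDE is set, PyTorch forces TF32 matmuls in cuBLAS,
# which is the reduced-precision mode the converter's warning refers to.
if os.environ.get("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") is not None:
    print("TF32 override is set; unset it before converting:")
    print("    unset TORCH_ALLOW_TF32_CUBLAS_OVERRIDE")
print("allow_tf32 =", torch.backends.cuda.matmul.allow_tf32)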

exllamav3/exllamav3_ext/add.cu

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
1+
#include <cuda_fp16.h>
2+
#include "add.cuh"
3+
#include <c10/cuda/CUDAGuard.h>
4+
#include <ATen/cuda/CUDAContext.h>
5+
#include "util.h"
6+
#include "util.cuh"
7+
8+
#define NUM_THREADS 1024
9+
10+
#define KERNEL_DEF(xt, yt, zt, kernel, fn) \
11+
__launch_bounds__(NUM_THREADS) \
12+
__global__ void kernel \
13+
( \
14+
const xt* __restrict__ x, \
15+
const yt* __restrict__ y, \
16+
zt* __restrict__ z, \
17+
const uint64_t numel \
18+
) \
19+
{ \
20+
uint64_t idx = ((uint64_t)blockIdx.x * NUM_THREADS + (uint64_t)threadIdx.x); \
21+
if (idx >= numel) return; \
22+
xt a = x[idx]; \
23+
yt b = y[idx]; \
24+
z[idx] = fn; \
25+
}
26+
27+
KERNEL_DEF(half, half, half, add_kernel_hhh, __hadd(a, b))
28+
KERNEL_DEF(half, half, float, add_kernel_hhf, __half2float(__hadd(a, b)))
29+
KERNEL_DEF(half, float, half, add_kernel_hfh, __float2half_rn(__half2float(a) + b))
30+
KERNEL_DEF(half, float, float, add_kernel_hff, __half2float(a) + b)
31+
KERNEL_DEF(float, half, half, add_kernel_fhh, __float2half_rn(a + __half2float(b)))
32+
KERNEL_DEF(float, half, float, add_kernel_fhf, a + __half2float(b))
33+
KERNEL_DEF(float, float, half, add_kernel_ffh, __float2half_rn(a + b))
34+
KERNEL_DEF(float, float, float, add_kernel_fff, a + b)
35+
36+
#undef KERNEL_DEF
37+
38+
/*
39+
x + y -> z
40+
Works inplace if x == z or y == z
41+
*/
42+
43+
void add_gr
44+
(
45+
const at::Tensor& x,
46+
const at::Tensor& y,
47+
at::Tensor& z,
48+
Graph* graph
49+
)
50+
{
51+
const at::cuda::OptionalCUDAGuard device_guard(x.device());
52+
cudaStream_t stream = graph ? graph->capture_stream : at::cuda::getCurrentCUDAStream().stream();
53+
54+
auto xt = x.dtype();
55+
auto yt = y.dtype();
56+
auto zt = z.dtype();
57+
uint64_t numel = x.numel();
58+
int blocks = (int) CEIL_DIVIDE(numel, (uint64_t) NUM_THREADS);
59+
60+
#define INSTANCE(xt_, yt_, zt_, xt__, yt__, zt__, kernel) \
61+
if (xt == xt_ && yt == yt_ && zt == zt_) \
62+
{ \
63+
kernel<<<blocks, NUM_THREADS, 0, stream>>> \
64+
( \
65+
(const xt__*) x.data_ptr(), \
66+
(const yt__*) y.data_ptr(), \
67+
(zt__*) z.data_ptr(), \
68+
numel \
69+
); \
70+
if (graph) graph->record_param((void*) &kernel, GP_add_x, 0); \
71+
if (graph) graph->record_param((void*) &kernel, GP_add_y, 1); \
72+
if (graph) graph->record_param((void*) &kernel, GP_add_z, 2); \
73+
if (graph) graph->record_param((void*) &kernel, GP_end, 0); \
74+
cuda_check(cudaPeekAtLastError()); \
75+
}
76+
77+
INSTANCE(at::kHalf, at::kHalf, at::kHalf, half, half, half , add_kernel_hhh)
78+
INSTANCE(at::kHalf, at::kHalf, at::kFloat, half, half, float, add_kernel_hhf)
79+
INSTANCE(at::kHalf, at::kFloat, at::kHalf, half, float, half , add_kernel_hfh)
80+
INSTANCE(at::kHalf, at::kFloat, at::kFloat, half, float, float, add_kernel_hff)
81+
INSTANCE(at::kFloat, at::kHalf, at::kHalf, float, half, half , add_kernel_fhh)
82+
INSTANCE(at::kFloat, at::kHalf, at::kFloat, float, half, float, add_kernel_fhf)
83+
INSTANCE(at::kFloat, at::kFloat, at::kHalf, float, float, half , add_kernel_ffh)
84+
INSTANCE(at::kFloat, at::kFloat, at::kFloat, float, float, float, add_kernel_fff)
85+
86+
#undef INSTANCE
87+
88+
cuda_check(cudaPeekAtLastError());
89+
}
90+
91+
void add
92+
(
93+
const at::Tensor& x,
94+
const at::Tensor& y,
95+
at::Tensor& z
96+
)
97+
{
98+
add_gr(x, y, z, nullptr);
99+
}
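Once the extension is rebuilt, the new element-wise add can be sanity-checked from Python. A sketch under assumptions: the import path below is inferred from the project layout and is not shown in this diff.

import torch
from exllamav3.ext import exllamav3_ext as ext  # assumed import path, not part of this diff

x = torch.randn(1024, device = "cuda", dtype = torch.half)
y = torch.randn(1024, device = "cuda", dtype = torch.float)
z = torch.empty(1024, device = "cuda", dtype = torch.float)

ext.add(x, y, z)  # half + float -> float dispatches to add_kernel_hff
torch.testing.assert_close(z, x.float() + y, rtol = 1e-3, atol = 1e-3)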

exllamav3/exllamav3_ext/add.cuh

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
1+
#pragma once
2+
3+
#include <ATen/Tensor.h>
4+
#include "graph.cuh"
5+
6+
void add_gr
7+
(
8+
const at::Tensor& x,
9+
const at::Tensor& y,
10+
at::Tensor& z,
11+
Graph* graph
12+
);
13+
14+
void add
15+
(
16+
const at::Tensor& x,
17+
const at::Tensor& y,
18+
at::Tensor& z
19+
);

exllamav3/exllamav3_ext/bindings.cpp

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,7 @@
 #include "routing.cuh"
 #include "gdn.cuh"
 #include "causal_conv1d.cuh"
+#include "add.cuh"

 #include "quant/quantize.cuh"
 #include "quant/pack.cuh"

@@ -23,6 +24,7 @@
 #include "quant/exl3_gemm.cuh"
 #include "quant/exl3_kernel_map.cuh"
 #include "quant/util.cuh"
+#include "quant/exl3_devctx.cuh"

 #include "generator/strings.h"
 #include "generator/sampling_basic.cuh"

@@ -87,6 +89,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("exl3_gemm", &exl3_gemm, "exl3_gemm");
     m.def("exl3_gemm_num_kernel_shapes", &exl3_gemm_num_kernel_shapes, "exl3_gemm_num_kernel_shapes");
     m.def("exl3_gemm_shape_compat", &exl3_gemm_shape_compat, "exl3_gemm_shape_compat");
+    m.def("g_get_cc", &g_get_cc, "g_get_cc");
+    m.def("g_get_num_sms", &g_get_num_sms, "g_get_num_sms");
     m.def("exl3_mgemm", &exl3_mgemm, "exl3_mgemm");
     m.def("hgemm", &hgemm, "hgemm");
     m.def("rope", &rope, "rope");

@@ -95,6 +99,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("relu2_mul", &relu2_mul, "relu2_mul");
     m.def("xielu", &xielu, "xielu");
     m.def("add_sigmoid_gate", &add_sigmoid_gate, "add_sigmoid_gate");
+    m.def("add", &add, "add");

     m.def("gated_delta_net_fused_op", &gated_delta_net_fused_op, "gated_delta_net_fused_op");
     m.def("cuda_recurrent_gated_delta_rule", &cuda_recurrent_gated_delta_rule, "cuda_recurrent_gated_delta_rule");

exllamav3/exllamav3_ext/graph.cu

Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
1+
#include <Python.h>
2+
#include "graph.cuh"
3+
#include <c10/cuda/CUDAGuard.h>
4+
#include <ATen/cuda/CUDAContext.h>
5+
//#include <torch/extension.h>
6+
#include "util.h"
7+
#include "util.cuh"
8+
9+
Graph::Graph()
10+
{
11+
ready = false;
12+
graph = NULL;
13+
graph_exec = NULL;
14+
}
15+
16+
Graph::~Graph()
17+
{
18+
if (graph) cudaGraphDestroy(graph);
19+
if (graph_exec) cudaGraphExecDestroy(graph_exec);
20+
}
21+
22+
cudaStream_t Graph::capture_begin()
23+
{
24+
// Make sure nothing is pending
25+
cudaDeviceSynchronize();
26+
27+
// Create capture stream
28+
cuda_check(cudaStreamCreateWithFlags(&capture_stream, cudaStreamNonBlocking));
29+
30+
// Begin capture
31+
cuda_check(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeThreadLocal));
32+
return capture_stream;
33+
}
34+
35+
void Graph::capture_end()
36+
{
37+
// End capture
38+
cuda_check(cudaStreamEndCapture(capture_stream, &graph));
39+
cuda_check(cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0));
40+
//inspect_graph();
41+
42+
// Get graph nodes
43+
size_t num_nodes;
44+
cudaGraphGetNodes(graph, nullptr, &num_nodes);
45+
nodes.resize(num_nodes);
46+
cudaGraphGetNodes(graph, nodes.data(), &num_nodes);
47+
48+
// Store copies of all node param structures
49+
node_params.resize(num_nodes);
50+
node_needs_update.resize(num_nodes);
51+
for (int i = 0; i < num_nodes; ++i)
52+
node_needs_update[i] = false;
53+
54+
int n = 0;
55+
int c = 0;
56+
while (true)
57+
{
58+
cudaGraphNodeType t{};
59+
cudaGraphNodeGetType(nodes[n], &t);
60+
61+
// Node type: kernel
62+
if (t == cudaGraphNodeTypeKernel)
63+
{
64+
cudaGraphKernelNodeGetParams(nodes[n], &node_params[n]);
65+
// DBGX(node_params[n].func);
66+
67+
for(; c < graph_sites.size(); c++)
68+
{
69+
void* func = std::get<0>(graph_sites[c]);
70+
// DBGX(func);
71+
72+
if (func != node_params[n].func) break;
73+
74+
int param_id = std::get<1>(graph_sites[c]);
75+
int param_offset = std::get<2>(graph_sites[c]);
76+
77+
graph_node_sites.push_back(std::tuple<int, int, int>(n, param_id, param_offset));
78+
if (param_id == GP_end) { c++; break; }
79+
80+
// DBGI2(param_id, param_offset);
81+
}
82+
}
83+
84+
n++;
85+
if (c == graph_sites.size()) break;
86+
if (n == num_nodes) TORCH_CHECK(false, "Graph recording failed");
87+
};
88+
89+
// Destroy capture stream
90+
cuda_check(cudaStreamDestroy(capture_stream));
91+
92+
// Graph is ready
93+
ready = true;
94+
}
95+
96+
void Graph::record_param(void* kernel, int param_id, int param_offset)
97+
{
98+
graph_sites.push_back(std::tuple<void*, int, int>(kernel, param_id, param_offset));
99+
}
100+
101+
void Graph::launch(std::vector<PPTR> params, cudaStream_t stream)
102+
{
103+
int p = 0;
104+
int n = 0;
105+
while (true)
106+
{
107+
if (std::get<1>(graph_node_sites[n]) == std::get<0>(params[p]))
108+
{
109+
if (std::get<0>(params[p]) != GP_end)
110+
{
111+
void* new_value = std::get<1>(params[p]);
112+
int node_idx = std::get<0>(graph_node_sites[n]);
113+
int param_offset = std::get<2>(graph_node_sites[n]);
114+
115+
// DBGI3(p, node_idx, param_offset);
116+
117+
void** p_old_value = (void**) node_params[node_idx].kernelParams[param_offset];
118+
if (*p_old_value != new_value)
119+
{
120+
*p_old_value = new_value;
121+
node_needs_update[node_idx] = true;
122+
}
123+
}
124+
else
125+
{
126+
// DBGI(p);
127+
}
128+
p++;
129+
}
130+
131+
n++;
132+
if (p == params.size()) break;
133+
if (n == graph_node_sites.size()) TORCH_CHECK(false, "Graph update failed");
134+
}
135+
136+
for (int n = 0; n < nodes.size(); ++n)
137+
{
138+
// DBGI(n);
139+
if (!node_needs_update[n]) continue;
140+
// printf("update\n");
141+
cudaGraphExecKernelNodeSetParams(graph_exec, nodes[n], &node_params[n]);
142+
node_needs_update[n] = false;
143+
}
144+
145+
cudaGraphLaunch(graph_exec, stream);
146+
}
147+
148+
void Graph::inspect_graph()
149+
{
150+
// Get the number of nodes in the graph
151+
size_t numNodes;
152+
cudaGraphGetNodes(graph, nullptr, &numNodes);
153+
154+
// Get the nodes in the graph
155+
std::vector<cudaGraphNode_t> nodes(numNodes);
156+
cudaGraphGetNodes(graph, nodes.data(), &numNodes);
157+
DBGI(nodes.size());
158+
159+
// Inspect each node
160+
for (size_t i = 0; i < numNodes; ++i)
161+
{
162+
cudaGraphNodeType nodeType;
163+
cudaGraphNodeGetType(nodes[i], &nodeType);
164+
165+
if (nodeType == cudaGraphNodeTypeKernel)
166+
{
167+
cudaKernelNodeParams nodeParams;
168+
cudaGraphKernelNodeGetParams(nodes[i], &nodeParams);
169+
std::cout << "Kernel node " << i << ":" << std::endl;
170+
std::cout << " Function pointer: " << nodeParams.func << std::endl;
171+
std::cout << " Grid dimensions: (" << nodeParams.gridDim.x << ", " << nodeParams.gridDim.y << ", " << nodeParams.gridDim.z << ")" << std::endl;
172+
std::cout << " Block dimensions: (" << nodeParams.blockDim.x << ", " << nodeParams.blockDim.y << ", " << nodeParams.blockDim.z << ")" << std::endl;
173+
std::cout << " Shared memory: " << nodeParams.sharedMemBytes << " bytes" << std::endl;
174+
175+
} else {
176+
std::cout << "Node " << i << " is not a kernel node." << std::endl;
177+
}
178+
}
179+
}
180+
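Graph captures a sequence of kernel launches into a CUDA graph and, on later launches, patches the recorded kernel pointer arguments in place (record_param/launch) before replaying, avoiding per-kernel launch overhead. The class is internal C++ and is not exposed in the binding changes shown above; for intuition only, the sketch below shows the same capture-and-replay idea using PyTorch's public torch.cuda.CUDAGraph API, which is not the code added here.

import torch

# Static buffers: a captured graph always reads and writes the same addresses,
# so new inputs are copied into these tensors before each replay.
static_x = torch.randn(1024, device = "cuda")
static_y = torch.randn(1024, device = "cuda")
static_z = torch.empty(1024, device = "cuda")

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    torch.add(static_x, static_y, out = static_z)

static_x.fill_(1.0)
static_y.fill_(1.0)
g.replay()  # re-runs the captured launch with the updated buffer contents
assert torch.allclose(static_z, torch.full((1024,), 2.0, device = "cuda"))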
