Skip to content

Commit 992816b

Browse files
committed
Graph framework (unused, WIP)
1 parent 2ce86d9 commit 992816b

File tree

2 files changed

+258
-0
lines changed

2 files changed

+258
-0
lines changed

exllamav3/exllamav3_ext/graph.cu

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
#include <Python.h>
2+
#include "graph.cuh"
3+
#include <c10/cuda/CUDAGuard.h>
4+
#include <ATen/cuda/CUDAContext.h>
5+
//#include <torch/extension.h>
6+
#include "util.h"
7+
#include "util.cuh"
8+
9+
// Constructs an empty graph wrapper: nothing captured, nothing instantiated,
// and not ready to launch.
//
// Every CUDA handle starts out null. That includes capture_stream, so it never
// holds an indeterminate value before capture_begin() assigns it, and the null
// checks in the destructor are meaningful. Members are listed in declaration
// order.
Graph::Graph() :
    capture_stream(nullptr),
    graph(nullptr),
    graph_exec(nullptr),
    ready(false)
{
}
15+
16+
// Releases the captured graph and its executable instance when they exist.
// The return codes of the two destroy calls are not inspected.
Graph::~Graph()
{
    if (graph != NULL)
    {
        cudaGraphDestroy(graph);
    }

    if (graph_exec != NULL)
    {
        cudaGraphExecDestroy(graph_exec);
    }
}
21+
22+
// Starts recording a new CUDA graph.
//
// Waits for the device to go idle, creates a fresh non-blocking stream and
// puts it into thread-local capture mode.
//
// Returns the capture stream. The same handle is kept in capture_stream and is
// destroyed by capture_end().
cudaStream_t Graph::capture_begin()
{
    // Make sure nothing is pending. The result is checked like the other calls
    // in this function, so a failure surfaces here instead of during capture
    cuda_check(cudaDeviceSynchronize());

    // Create capture stream
    cuda_check(cudaStreamCreateWithFlags(&capture_stream, cudaStreamNonBlocking));

    // Begin capture
    cuda_check(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeThreadLocal));
    return capture_stream;
}
34+
35+
// Finishes the capture started by capture_begin() and prepares the graph for
// launch().
//
// Steps:
//   1. End capture on capture_stream, instantiate the executable graph and
//      destroy the capture stream.
//   2. Fetch the graph's nodes and size the per-node bookkeeping vectors.
//   3. Walk the nodes in the order cudaGraphGetNodes returns them and pair
//      each kernel node with the sites registered through record_param().
//      Sites are consumed strictly in recorded order. Every match is appended
//      to graph_node_sites as (node index, param id, param offset).
//      NOTE(review): this relies on the node order matching the order in
//      which the kernels were recorded - confirm that assumption holds.
//
// Fails through TORCH_CHECK with "Graph recording failed" when the nodes run
// out before every recorded site has been matched.
void Graph::capture_end()
{
    // End capture
    cuda_check(cudaStreamEndCapture(capture_stream, &graph));
    cuda_check(cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0));
    //inspect_graph();

    // The stream is only needed while recording. Destroy it before the
    // matching below, which can fail, so it is not leaked on that path
    cuda_check(cudaStreamDestroy(capture_stream));

    // Get graph nodes
    size_t num_nodes = 0;
    cuda_check(cudaGraphGetNodes(graph, nullptr, &num_nodes));
    nodes.resize(num_nodes);
    cuda_check(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));

    // Storage for copies of the node param structures, plus cleared dirty flags
    node_params.resize(num_nodes);
    node_needs_update.assign(num_nodes, false);

    // c is the next unmatched entry of graph_sites. The loop bounds guarantee
    // that neither nodes nor graph_sites is indexed past its end, including
    // when the graph is empty or nothing was recorded
    size_t c = 0;
    for (size_t n = 0; n < num_nodes && c < graph_sites.size(); ++n)
    {
        cudaGraphNodeType t{};
        cuda_check(cudaGraphNodeGetType(nodes[n], &t));

        // Only kernel nodes carry patchable arguments
        if (t != cudaGraphNodeTypeKernel) continue;

        // A node whose parameters cannot be read cannot be matched, so it is
        // skipped rather than treated as fatal
        if (cudaGraphKernelNodeGetParams(nodes[n], &node_params[n]) != cudaSuccess) continue;

        for (; c < graph_sites.size(); c++)
        {
            // Stop at the first site that belongs to a different kernel; it is
            // retried against the following nodes
            void* func = std::get<0>(graph_sites[c]);
            if (func != node_params[n].func) break;

            int param_id = std::get<1>(graph_sites[c]);
            int param_offset = std::get<2>(graph_sites[c]);

            graph_node_sites.push_back(std::tuple<int, int, int>((int) n, param_id, param_offset));

            // GP_end closes the run of sites for this node; consume it and
            // move on to the next node
            if (param_id == GP_end) { c++; break; }
        }
    }

    // Every recorded site must have found its node
    TORCH_CHECK(c == graph_sites.size(), "Graph recording failed");

    // Graph is ready
    ready = true;
}
95+
96+
// Registers one patchable kernel argument.
//
// kernel:       function pointer later compared against a kernel node's func
// param_id:     GraphedParams-style identifier; GP_end closes a kernel's run
// param_offset: index later used to address the node's kernelParams array
//
// Entries are appended to graph_sites and consumed in this order by
// capture_end().
void Graph::record_param(void* kernel, int param_id, int param_offset)
{
    graph_sites.emplace_back(kernel, param_id, param_offset);
}
100+
101+
// Patches pointer arguments into the captured kernel nodes and launches the
// executable graph on `stream`.
//
// params: (param id, new pointer value) pairs. They are matched in order
//         against graph_node_sites; sites whose id differs from the current
//         pair are skipped. A pair whose id is GP_end is matched but writes
//         nothing.
// stream: stream the graph is launched on.
//
// Fails through TORCH_CHECK with "Graph update failed" when the sites run out
// before every entry of params has been matched. CUDA failures while updating
// a node or launching are reported through cuda_check.
void Graph::launch(std::vector<PPTR> params, cudaStream_t stream)
{
    // p is the next unmatched entry of params. The loop bounds guarantee that
    // neither params nor graph_node_sites is indexed past its end, including
    // when either of them is empty
    size_t p = 0;
    for (size_t n = 0; n < graph_node_sites.size() && p < params.size(); ++n)
    {
        int param_id = std::get<0>(params[p]);
        if (std::get<1>(graph_node_sites[n]) != param_id) continue;

        if (param_id != GP_end)
        {
            void* new_value = std::get<1>(params[p]);
            int node_idx = std::get<0>(graph_node_sites[n]);
            int param_offset = std::get<2>(graph_node_sites[n]);

            // kernelParams[param_offset] is treated as the address of a slot
            // holding a pointer argument; rewrite it only when the value
            // actually changes so untouched nodes need no update
            void** p_old_value = (void**) node_params[node_idx].kernelParams[param_offset];
            if (*p_old_value != new_value)
            {
                *p_old_value = new_value;
                node_needs_update[node_idx] = true;
            }
        }

        p++;
    }

    // Every supplied parameter must have found its site
    TORCH_CHECK(p == params.size(), "Graph update failed");

    // Push the modified params into the executable graph, then clear the flags
    for (size_t i = 0; i < nodes.size(); ++i)
    {
        if (!node_needs_update[i]) continue;
        cuda_check(cudaGraphExecKernelNodeSetParams(graph_exec, nodes[i], &node_params[i]));
        node_needs_update[i] = false;
    }

    cuda_check(cudaGraphLaunch(graph_exec, stream));
}
147+
148+
// Debug helper: prints one entry per node of the captured graph to std::cout.
// Kernel nodes are listed with their function pointer, grid and block
// dimensions and shared memory size; every other node gets a single line.
void Graph::inspect_graph()
{
    // First call only asks for the node count
    size_t count;
    cudaGraphGetNodes(graph, nullptr, &count);

    // Second call fetches the handles into a local vector. The local keeps the
    // name `nodes` (hiding the member) so the debug output below is unchanged
    std::vector<cudaGraphNode_t> nodes(count);
    cudaGraphGetNodes(graph, nodes.data(), &count);
    DBGI(nodes.size());

    for (size_t idx = 0; idx < count; ++idx)
    {
        cudaGraphNodeType kind;
        cudaGraphNodeGetType(nodes[idx], &kind);

        // Anything that is not a kernel launch is only mentioned
        if (kind != cudaGraphNodeTypeKernel)
        {
            std::cout << "Node " << idx << " is not a kernel node." << std::endl;
            continue;
        }

        cudaKernelNodeParams kp;
        cudaGraphKernelNodeGetParams(nodes[idx], &kp);

        std::cout << "Kernel node " << idx << ":" << std::endl
                  << " Function pointer: " << kp.func << std::endl
                  << " Grid dimensions: (" << kp.gridDim.x << ", " << kp.gridDim.y << ", " << kp.gridDim.z << ")" << std::endl
                  << " Block dimensions: (" << kp.blockDim.x << ", " << kp.blockDim.y << ", " << kp.blockDim.z << ")" << std::endl
                  << " Shared memory: " << kp.sharedMemBytes << " bytes" << std::endl;
    }
}
180+

exllamav3/exllamav3_ext/graph.cuh

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#pragma once
2+
3+
#include <ATen/Tensor.h>
4+
#include <vector>
5+
#include <pybind11/pybind11.h>
6+
namespace py = pybind11;
7+
#include <cuda_runtime.h>
8+
#include <cuda_fp16.h>
9+
10+
// One graph parameter update as passed to Graph::launch():
// element 0 is the parameter id, element 1 the pointer value to patch in.
typedef std::tuple<int, void*> PPTR;
11+
12+
// Identifiers for kernel arguments that can be patched on a captured graph.
// Values are consecutive and start at zero. The group headings below only
// mirror the enumerator name prefixes.
enum GraphedParams
{
    // Terminator: closes the run of recorded sites belonging to one kernel
    GP_end = 0,

    // gemm
    GP_gemm_A,
    GP_gemm_C,

    // mgemm
    GP_mgemm,
    GP_mgemm_A,
    GP_mgemm_C,
    GP_mgemm_indices,
    GP_mgemm_weights,

    // silu_mul
    GP_silu_mul_x,
    GP_silu_mul_y,
    GP_silu_mul_z,

    // gelu_mul
    GP_gelu_mul_x,
    GP_gelu_mul_y,
    GP_gelu_mul_z,

    // relu2_mul
    GP_relu2_mul_x,
    GP_relu2_mul_y,
    GP_relu2_mul_z,

    // xielu
    GP_xielu_x,
    GP_xielu_y,

    // add_sigmoid_gate
    GP_add_sigmoid_gate,
    GP_add_sigmoid_gate_x,
    GP_add_sigmoid_gate_y,
    GP_add_sigmoid_gate_z,

    // add
    GP_add_x,
    GP_add_y,
    GP_add_z
};
49+
50+
// Wrapper around a captured CUDA graph whose kernel pointer arguments can be
// patched between launches.
//
// Usage as implemented in graph.cu: capture_begin() -> enqueue work on the
// returned stream, calling record_param() for each patchable argument ->
// capture_end() -> launch() any number of times.
//
// NOTE(review): the class holds CUDA handles that the destructor destroys, but
// copying is not disabled; copying an instance would presumably destroy the
// same handles twice - confirm that instances are never copied.
class Graph
{
public:
    Graph();
    ~Graph();

    // Recording
    cudaStream_t capture_begin();
    void capture_end();
    void record_param(void* kernel, int param_id, int param_offset);

    // Replay
    void launch(std::vector<PPTR> params, cudaStream_t stream);

    // Debugging
    void inspect_graph();

    // Stream created by capture_begin() and destroyed by capture_end()
    cudaStream_t capture_stream;
    // Graph produced when capture ends
    cudaGraph_t graph;
    // Executable instance of `graph`; this is what launch() runs
    cudaGraphExec_t graph_exec;

    // (kernel function pointer, param id, param offset), in record_param() order
    std::vector<std::tuple<void*, int, int>> graph_sites;
    // (node index, param id, param offset), built by capture_end()
    std::vector<std::tuple<int, int, int>> graph_node_sites;

    // Node handles of `graph`
    std::vector<cudaGraphNode_t> nodes;
    // Kernel parameters per node, indexed like `nodes`; only filled in for
    // kernel nodes that capture_end() visits
    std::vector<cudaKernelNodeParams> node_params;
    // Not referenced by the member functions in graph.cu
    std::vector<void*> current_values;
    // Per-node flag: launch() changed an argument and must update graph_exec
    std::vector<bool> node_needs_update;

    // Set to true at the end of capture_end()
    bool ready;
};
78+

0 commit comments

Comments
 (0)