Skip to content

Commit 7319b6e

Browse files
committed
Fix graph update for MLP with post layernorm
1 parent 4230dab commit 7319b6e

File tree

3 files changed

+17
-14
lines changed

3 files changed

+17
-14
lines changed

exllamav2/exllamav2_ext/cuda/graph.cu

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ void Graph::attach_label(cudaStream_t stream, int label, int sublabel)
133133
}
134134

135135
template <typename T>
136-
void Graph::update_param(int label, int sublabel, int param, T value)
136+
void Graph::update_param(int label, int sublabel, int param, T value, bool debug)
137137
{
138138
for (int i = 0; i < node_labels.size(); ++i)
139139
{
@@ -145,19 +145,22 @@ void Graph::update_param(int label, int sublabel, int param, T value)
145145

146146
node_needs_update[i] = true;
147147

148-
// printf("-----------------------------------------------------\n");
149-
// printf("UPDATED:\n");
150-
// DBGI(i);
151-
// inspect_graph();
148+
if (debug)
149+
{
150+
printf("-----------------------------------------------------\n");
151+
printf("UPDATED: ");
152+
DBGI(i);
153+
inspect_graph();
154+
}
152155
}
153156
}
154157

155-
void Graph::update_param_ptr(int label, int sublabel, int param, void* value)
158+
void Graph::update_param_ptr(int label, int sublabel, int param, void* value, bool debug)
156159
{
157-
update_param<void*>(label, sublabel, param, value);
160+
update_param<void*>(label, sublabel, param, value, debug);
158161
}
159162

160-
void Graph::update_param_int(int label, int sublabel, int param, int value)
163+
void Graph::update_param_int(int label, int sublabel, int param, int value, bool debug)
161164
{
162-
update_param<int>(label, sublabel, param, value);
165+
update_param<int>(label, sublabel, param, value, debug);
163166
}

exllamav2/exllamav2_ext/cuda/graph.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ public:
4646
void attach_label(cudaStream_t stream, int label, int sublabel);
4747

4848
template <typename T>
49-
void update_param(int label, int sublabel, int param, T value);
49+
void update_param(int label, int sublabel, int param, T value, bool debug);
5050

51-
void update_param_ptr(int label, int sublabel, int param, void* value);
52-
void update_param_int(int label, int sublabel, int param, int value);
51+
void update_param_ptr(int label, int sublabel, int param, void* value, bool debug = false);
52+
void update_param_int(int label, int sublabel, int param, int value, bool debug = false);
5353
};
5454

5555

exllamav2/exllamav2_ext/cuda/q_mlp.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ void QMLP::forward_
109109
if (graph->count())
110110
{
111111
graph->begin_capture(stream);
112-
forward_run_(stream, cublas_handle, (half*) x, rows, columns, loras, lora_temp, graph);
112+
forward_run_(stream, cublas_handle, (void*) x, rows, columns, loras, lora_temp, graph);
113113
graph->end_capture(stream);
114114
// printf("**** record ****\n");
115115
// DBGI2(rows, columns);
@@ -225,7 +225,7 @@ void QMLP::forward_run_
225225

226226
else
227227
{
228-
gemm_half_q_half_cuda(stream, cublas_handle, temp_a, down, temp_state, rows, columns, intermediate_size, true, temp_dq, graph, 0);
228+
gemm_half_q_half_cuda(stream, cublas_handle, temp_a, down, temp_state, rows, columns, intermediate_size, true, temp_dq, false, NULL, 0, false, graph, 0);
229229
if (layernorm_is_rms)
230230
rms_norm_cuda(stream, temp_state, post_layernorm, x, norm_epsilon, rows, columns, true, false, residual_fp32, graph, KernelLabels::POST_NORM);
231231
else

0 commit comments

Comments
 (0)