@@ -104,10 +104,12 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
104104 }
105105}
106106
107- template <int block_size>
107+ template <int block_size, bool do_multiply = false >
108108static __global__ void rms_norm_f32 (
109109 const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
110- const int64_t stride_sample, const float eps) {
110+ const int64_t stride_sample, const float eps, const float * mul = nullptr , const int64_t mul_stride_row = 0 ,
111+ const int64_t mul_stride_channel = 0 , const int64_t mul_stride_sample = 0 , const int mul_ncols = 0 ,
112+ const int mul_nrows = 0 , const int mul_nchannels = 0 , const int mul_nsamples = 0 ) {
111113 const int nrows = gridDim .x ;
112114 const int nchannels = gridDim .y ;
113115
@@ -119,6 +121,13 @@ static __global__ void rms_norm_f32(
119121 x += sample*stride_sample + channel*stride_channel + row*stride_row;
120122 dst += ((sample*nchannels + channel)*nrows + row)*ncols;
121123
124+ if constexpr (do_multiply) {
125+ const int mul_row = row % mul_nrows;
126+ const int mul_channel = channel % mul_nchannels;
127+ const int mul_sample = sample % mul_nsamples;
128+ mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
129+ }
130+
122131 float tmp = 0 .0f ; // partial sum for thread in warp
123132
124133 for (int col = tid; col < ncols; col += block_size) {
@@ -145,7 +154,12 @@ static __global__ void rms_norm_f32(
145154 const float scale = rsqrtf (mean + eps);
146155
147156 for (int col = tid; col < ncols; col += block_size) {
148- dst[col] = scale * x[col];
157+ if constexpr (do_multiply) {
158+ const int mul_col = col % mul_ncols;
159+ dst[col] = scale * x[col] * mul[mul_col];
160+ } else {
161+ dst[col] = scale * x[col];
162+ }
149163 }
150164}
151165
@@ -310,10 +324,30 @@ static void rms_norm_f32_cuda(
310324 const dim3 blocks_num (nrows, nchannels, nsamples);
311325 if (ncols < 1024 ) {
312326 const dim3 block_dims (WARP_SIZE, 1 , 1 );
313- rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0 , stream>>> (x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
327+ rms_norm_f32<WARP_SIZE, false ><<<blocks_num, block_dims, 0 , stream>>> (x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
328+ } else {
329+ const dim3 block_dims (1024 , 1 , 1 );
330+ rms_norm_f32<1024 , false ><<<blocks_num, block_dims, 0 , stream>>> (x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
331+ }
332+ }
333+
334+ static void rms_norm_mul_f32_cuda (
335+ const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
336+ const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
337+ const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
338+ const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
339+ const float eps, cudaStream_t stream) {
340+ const dim3 blocks_num (nrows, nchannels, nsamples);
341+ if (mul == nullptr ) {
342+ rms_norm_f32_cuda (x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
343+ return ;
344+ }
345+ if (ncols < 1024 ) {
346+ const dim3 block_dims (WARP_SIZE, 1 , 1 );
347+ rms_norm_f32<WARP_SIZE, true ><<<blocks_num, block_dims, 0 , stream>>> (x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
314348 } else {
315349 const dim3 block_dims (1024 , 1 , 1 );
316- rms_norm_f32<1024 ><<<blocks_num, block_dims, 0 , stream>>> (x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
350+ rms_norm_f32<1024 , true ><<<blocks_num, block_dims, 0 , stream>>> (x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples );
317351 }
318352}
319353
@@ -407,6 +441,59 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
407441 rms_norm_f32_cuda (src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
408442}
409443
444+ void ggml_cuda_op_rms_norm_fused (ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
445+ const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src [0 ];
446+ float eps = 0 .0f ;
447+
448+ memcpy (&eps, dst->op_params , sizeof (float ));
449+
450+ const float * src0_d = (const float *) rms_norm_src->data ;
451+ const float * mul_d = nullptr ;
452+ const ggml_tensor * mul_src = nullptr ;
453+
454+ if (mul_tensor->src [0 ] == dst) {
455+ mul_d = (float *) mul_tensor->src [1 ]->data ;
456+ mul_src = mul_tensor->src [1 ];
457+ } else if (mul_tensor->src [1 ] == dst) {
458+ mul_d = (float *) mul_tensor->src [0 ]->data ;
459+ mul_src = mul_tensor->src [0 ];
460+ } else {
461+ GGML_ASSERT (false );
462+ }
463+
464+ float * dst_d = (float *) mul_tensor->data ;
465+ cudaStream_t stream = ctx.stream ();
466+
467+ GGML_ASSERT (rms_norm_src->type == GGML_TYPE_F32);
468+ GGML_ASSERT (dst->type == GGML_TYPE_F32);
469+ GGML_ASSERT (mul_tensor->type == GGML_TYPE_F32);
470+ GGML_ASSERT (eps >= 0 .0f );
471+
472+ const int64_t ne00 = rms_norm_src->ne [0 ];
473+ const int64_t ne01 = rms_norm_src->ne [1 ];
474+ const int64_t ne02 = rms_norm_src->ne [2 ];
475+ const int64_t ne03 = rms_norm_src->ne [3 ];
476+
477+ const size_t ts0 = ggml_type_size (rms_norm_src->type );
478+ GGML_ASSERT (rms_norm_src->nb [0 ] == ts0);
479+ const int64_t s01 = rms_norm_src->nb [1 ] / ts0;
480+ const int64_t s02 = rms_norm_src->nb [2 ] / ts0;
481+ const int64_t s03 = rms_norm_src->nb [3 ] / ts0;
482+
483+ const size_t ts_mul = ggml_type_size (mul_src->type );
484+ GGML_ASSERT (mul_src->nb [0 ] == ts_mul);
485+ const int64_t mul_s01 = mul_src->nb [1 ] / ts_mul;
486+ const int64_t mul_s02 = mul_src->nb [2 ] / ts_mul;
487+ const int64_t mul_s03 = mul_src->nb [3 ] / ts_mul;
488+
489+ const int mul_ncols = mul_src->ne [0 ];
490+ const int mul_nrows = mul_src->ne [1 ];
491+ const int mul_nchannels = mul_src->ne [2 ];
492+ const int mul_nsamples = mul_src->ne [3 ];
493+
494+ rms_norm_mul_f32_cuda (src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
495+ }
496+
410497void ggml_cuda_op_rms_norm_back (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
411498 const ggml_tensor * grad = dst->src [0 ]; // gradients
412499 const ggml_tensor * src0f = dst->src [1 ]; // src0 from forward pass
0 commit comments