@@ -11642,79 +11642,191 @@ static void ggml_compute_forward_add_rel_pos(
1164211642 }
1164311643}
1164411644
11645- // ggml_compute_forward_rwkv_wkv
11645+ // ggml_compute_forward_rwkv_wkv6
1164611646
11647- static void ggml_compute_forward_rwkv_wkv_f32 (
11647+ static void ggml_compute_forward_rwkv_wkv6_f32 (
1164811648 const struct ggml_compute_params * params ,
1164911649 struct ggml_tensor * dst ) {
11650- const size_t T = dst -> src [1 ]-> ne [3 ];
11651- const size_t C = dst -> ne [0 ];
11652- const size_t H = dst -> src [1 ]-> ne [2 ];
11653- const size_t n_seqs = dst -> src [5 ]-> ne [1 ];
11650+ const int64_t T = dst -> src [1 ]-> ne [3 ];
11651+ const int64_t C = dst -> ne [0 ];
11652+ const int64_t HEADS = dst -> src [1 ]-> ne [2 ];
11653+ const int64_t n_seqs = dst -> src [5 ]-> ne [1 ];
11654+ const int64_t head_size = C / HEADS ;
1165411655
1165511656 float * dst_data = (float * ) dst -> data ;
1165611657 float * state = ((float * ) dst -> data ) + C * T ;
1165711658
11658- if (params -> ith != 0 ) {
11659+ const int ith = params -> ith ;
11660+ const int nth = params -> nth ;
11661+
11662+ if (ith >= HEADS ) {
1165911663 return ;
1166011664 }
1166111665
11662- memset (dst_data , 0 , T * C * sizeof (float ));
11666+ const int h_start = (HEADS * ith ) / nth ;
11667+ const int h_end = ((HEADS * (ith + 1 )) / nth < HEADS ) ?
11668+ (HEADS * (ith + 1 )) / nth : HEADS ;
1166311669
1166411670 float * k = (float * ) dst -> src [0 ]-> data ;
1166511671 float * v = (float * ) dst -> src [1 ]-> data ;
1166611672 float * r = (float * ) dst -> src [2 ]-> data ;
1166711673 float * time_faaaa = (float * ) dst -> src [3 ]-> data ;
1166811674 float * time_decay = (float * ) dst -> src [4 ]-> data ;
1166911675
11670- size_t t_stride = H * ( C / H );
11676+ size_t t_stride = HEADS * head_size ; // Same to C
1167111677
11672- size_t h_stride = C / H ;
11673- size_t h_stride_2d = (C / H ) * (C / H );
11678+ size_t h_stride = C / HEADS ;
11679+ GGML_ASSERT (C % HEADS == 0 ); // C must be divisible by HEADS
11680+ size_t h_stride_2d = head_size * head_size ;
1167411681
11675- // basically fused operations:
11676- // dst = r @ (time_faaaa * (k @ v) + state),
11677- // state = time_decay * state + (k @ v),
11678- // recursive through each token
11679- for (size_t t = 0 ; t < T ; t ++ ) {
11680- size_t t_offset = t * t_stride ;
11681- size_t state_offset = (C / H ) * C * (t / (T / n_seqs ));
11682- float * state_cur = state + state_offset ;
11683- float * state_prev = t % (T / n_seqs ) ? state_cur : (float * )dst -> src [5 ]-> data + state_offset ;
11682+ if (ith == 0 ) {
11683+ memset (dst_data , 0 , T * C * sizeof (float ));
11684+ }
11685+ ggml_barrier (params -> threadpool );
1168411686
11685- for (size_t h = 0 ; h < H ; h ++ ) {
11686- size_t h_offset = h * h_stride ;
11687- size_t t_h_offset = t_offset + h_offset ;
11688- size_t h_2d_offset = h * h_stride_2d ;
1168911687
11690- for (size_t i = 0 ; i < C / H ; i ++ ) {
11691- size_t t_h_i_offset = t_h_offset + i ;
11692- size_t h_i_offset = h_offset + i ;
11693- size_t h_2d_i_offset = h_2d_offset + i * h_stride ;
11688+ #if defined(__AVX__ ) && !defined(__AVX512F__ )
11689+ #define GGML_F32X GGML_F32x8
11690+ #define GGML_F32X_SET1 GGML_F32x8_SET1
11691+ #define GGML_F32X_LOAD GGML_F32x8_LOAD
11692+ #define GGML_F32X_STORE GGML_F32x8_STORE
11693+ #define GGML_F32X_MUL GGML_F32x8_MUL
11694+ #define GGML_F32X_FMA GGML_F32x8_FMA
11695+ #define WKV_VECTOR_SIZE 8
11696+ #elif defined(__AVX512F__ )
11697+ #define GGML_F32X GGML_F32x16
11698+ #define GGML_F32X_SET1 GGML_F32x16_SET1
11699+ #define GGML_F32X_LOAD GGML_F32x16_LOAD
11700+ #define GGML_F32X_STORE GGML_F32x16_STORE
11701+ #define GGML_F32X_MUL GGML_F32x16_MUL
11702+ #define GGML_F32X_FMA GGML_F32x16_FMA
11703+ #define WKV_VECTOR_SIZE 16
11704+ #elif defined(__ARM_NEON ) && defined(__aarch64__ )
11705+ #define GGML_F32X GGML_F32x4
11706+ #define GGML_F32X_SET1 GGML_F32x4_SET1
11707+ #define GGML_F32X_LOAD GGML_F32x4_LOAD
11708+ #define GGML_F32X_STORE GGML_F32x4_STORE
11709+ #define GGML_F32X_MUL GGML_F32x4_MUL
11710+ #define GGML_F32X_FMA GGML_F32x4_FMA
11711+ #define WKV_VECTOR_SIZE 4
11712+ #endif
1169411713
11695- float k_val = k [t_h_i_offset ];
11696- float r_val = r [t_h_i_offset ];
11697- float time_faaaa_val = time_faaaa [h_i_offset ];
11698- // RWKV v6: different time_decay for each token.
11699- float time_decay_val = time_decay [t_h_i_offset ];
11714+ #ifdef WKV_VECTOR_SIZE
11715+ const int64_t vec_count = head_size / WKV_VECTOR_SIZE ;
11716+
11717+ for (int64_t t = 0 ; t < T ; t ++ ) {
11718+ size_t t_offset = t * t_stride ;
11719+ size_t state_offset = head_size * C * (t / (T / n_seqs ));
11720+ float * state_cur = state + state_offset ;
11721+ float * state_prev = t % (T / n_seqs ) ? state_cur : (float * )dst -> src [5 ]-> data + state_offset ;
11722+
11723+ for (int64_t h = h_start ; h < h_end ; h ++ ) {
11724+ size_t h_offset = h * h_stride ;
11725+ size_t t_h_offset = t_offset + h_offset ;
11726+ size_t h_2d_offset = h * h_stride_2d ;
11727+
11728+ for (int64_t i = 0 ; i < head_size ; i ++ ) {
11729+ size_t t_h_i_offset = t_h_offset + i ;
11730+ size_t h_i_offset = h_offset + i ;
11731+ size_t h_2d_i_offset = h_2d_offset + i * h_stride ;
11732+
11733+ float k_val = k [t_h_i_offset ];
11734+ float r_val = r [t_h_i_offset ];
11735+ float time_faaaa_val = time_faaaa [h_i_offset ];
11736+ float time_decay_val = time_decay [t_h_i_offset ];
11737+
11738+ // Broadcast scalar values to vectors
11739+ GGML_F32X k_vec = GGML_F32X_SET1 (k_val );
11740+ GGML_F32X r_vec = GGML_F32X_SET1 (r_val );
11741+ GGML_F32X time_faaaa_vec = GGML_F32X_SET1 (time_faaaa_val );
11742+ GGML_F32X time_decay_vec = GGML_F32X_SET1 (time_decay_val );
11743+
11744+ for (int64_t j = 0 ; j < vec_count ; j ++ ) {
11745+ size_t base_j = j * WKV_VECTOR_SIZE ;
11746+ size_t t_h_j_offset = t_h_offset + base_j ;
11747+ size_t h_2d_i_j_offset = h_2d_i_offset + base_j ;
11748+
11749+ // Load x elements at once
11750+ GGML_F32X v_vec = GGML_F32X_LOAD (& v [t_h_j_offset ]);
11751+ GGML_F32X prev_state_vec = GGML_F32X_LOAD (& state_prev [h_2d_i_j_offset ]);
11752+ GGML_F32X dst_vec = GGML_F32X_LOAD (& dst_data [t_h_j_offset ]);
11753+
11754+ // Compute kv = v * k
11755+ GGML_F32X kv_vec = GGML_F32X_MUL (v_vec , k_vec );
11756+
11757+ // Compute temp = kv * time_faaaa + prev_state
11758+ GGML_F32X temp_vec = GGML_F32X_FMA (prev_state_vec , kv_vec , time_faaaa_vec );
11759+
11760+ // Update dst: dst += temp * r
11761+ dst_vec = GGML_F32X_FMA (dst_vec , temp_vec , r_vec );
11762+ GGML_F32X_STORE (& dst_data [t_h_j_offset ], dst_vec );
11763+
11764+ // Update state: state = prev_state * time_decay + kv
11765+ GGML_F32X new_state_vec = GGML_F32X_FMA (kv_vec , prev_state_vec , time_decay_vec );
11766+ GGML_F32X_STORE (& state_cur [h_2d_i_j_offset ], new_state_vec );
11767+ }
1170011768
11701- for (size_t j = 0 ; j < C / H ; j ++ ) {
11702- size_t t_h_j_offset = t_h_offset + j ;
11703- size_t h_2d_i_j_offset = h_2d_i_offset + j ;
11769+ // Handle remaining elements, this will not be used.
11770+ for (int64_t j = vec_count * WKV_VECTOR_SIZE ; j < head_size ; j ++ ) {
11771+ size_t t_h_j_offset = t_h_offset + j ;
11772+ size_t h_2d_i_j_offset = h_2d_i_offset + j ;
11773+ float v_val = v [t_h_j_offset ];
11774+ float kv_val = v_val * k_val ;
11775+ float prev_state_val = state_prev [h_2d_i_j_offset ];
11776+ float temp_val = kv_val * time_faaaa_val + prev_state_val ;
11777+ dst_data [t_h_j_offset ] += temp_val * r_val ;
11778+ state_cur [h_2d_i_j_offset ] = prev_state_val * time_decay_val + kv_val ;
11779+ }
11780+ }
11781+ }
11782+ }
1170411783
11705- float v_val = v [t_h_j_offset ];
11706- float kv_val = v_val * k_val ;
11707- float prev_state_val = state_prev [h_2d_i_j_offset ];
11708- float temp_val = kv_val * time_faaaa_val + prev_state_val ;
11709- dst_data [t_h_j_offset ] += temp_val * r_val ;
11710- state_cur [h_2d_i_j_offset ] = prev_state_val * time_decay_val + kv_val ;
11784+ #else
11785+ // basically fused operations:
11786+ // dst = r @ (time_faaaa * (k @ v) + state),
11787+ // state = time_decay * state + (k @ v),
11788+ // recursive through each token
11789+ for (int64_t t = 0 ; t < T ; t ++ ) {
11790+ size_t t_offset = t * t_stride ;
11791+ size_t state_offset = head_size * C * (t / (T / n_seqs ));
11792+ float * state_cur = state + state_offset ;
11793+ float * state_prev = t % (T / n_seqs ) ? state_cur : (float * )dst -> src [5 ]-> data + state_offset ;
11794+
11795+ for (int64_t h = h_start ; h < h_end ; h ++ ) {
11796+ size_t h_offset = h * h_stride ;
11797+ size_t t_h_offset = t_offset + h_offset ;
11798+ size_t h_2d_offset = h * h_stride_2d ;
11799+
11800+ for (int64_t i = 0 ; i < head_size ; i ++ ) {
11801+ size_t t_h_i_offset = t_h_offset + i ;
11802+ size_t h_i_offset = h_offset + i ;
11803+ size_t h_2d_i_offset = h_2d_offset + i * h_stride ;
11804+
11805+ float k_val = k [t_h_i_offset ];
11806+ float r_val = r [t_h_i_offset ];
11807+ float time_faaaa_val = time_faaaa [h_i_offset ];
11808+ // RWKV v6: different time_decay for each token.
11809+ float time_decay_val = time_decay [t_h_i_offset ];
11810+
11811+ for (int64_t j = 0 ; j < head_size ; j ++ ) {
11812+ size_t t_h_j_offset = t_h_offset + j ;
11813+ size_t h_2d_i_j_offset = h_2d_i_offset + j ;
11814+
11815+ float v_val = v [t_h_j_offset ];
11816+ float kv_val = v_val * k_val ;
11817+ float prev_state_val = state_prev [h_2d_i_j_offset ];
11818+ float temp_val = kv_val * time_faaaa_val + prev_state_val ;
11819+ dst_data [t_h_j_offset ] += temp_val * r_val ;
11820+ state_cur [h_2d_i_j_offset ] = prev_state_val * time_decay_val + kv_val ;
11821+ }
1171111822 }
1171211823 }
1171311824 }
11714- }
11825+ #endif
1171511826}
1171611827
11717- static void ggml_compute_forward_rwkv_wkv (
11828+
11829+ static void ggml_compute_forward_rwkv_wkv6 (
1171811830 const struct ggml_compute_params * params ,
1171911831 struct ggml_tensor * dst ) {
1172011832
@@ -11723,7 +11835,7 @@ static void ggml_compute_forward_rwkv_wkv(
1172311835 switch (src0 -> type ) {
1172411836 case GGML_TYPE_F32 :
1172511837 {
11726- ggml_compute_forward_rwkv_wkv_f32 (params , dst );
11838+ ggml_compute_forward_rwkv_wkv6_f32 (params , dst );
1172711839 } break ;
1172811840 default :
1172911841 {
@@ -12475,9 +12587,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
1247512587 {
1247612588 ggml_compute_forward_add_rel_pos (params , tensor );
1247712589 } break ;
12478- case GGML_OP_RWKV_WKV :
12590+ case GGML_OP_RWKV_WKV6 :
1247912591 {
12480- ggml_compute_forward_rwkv_wkv (params , tensor );
12592+ ggml_compute_forward_rwkv_wkv6 (params , tensor );
1248112593 } break ;
1248212594 case GGML_OP_MAP_UNARY :
1248312595 {
@@ -12775,7 +12887,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
1277512887 case GGML_OP_WIN_PART :
1277612888 case GGML_OP_WIN_UNPART :
1277712889 case GGML_OP_GET_REL_POS :
12778- case GGML_OP_RWKV_WKV :
12890+ case GGML_OP_RWKV_WKV6 :
1277912891 case GGML_OP_MAP_UNARY :
1278012892 case GGML_OP_MAP_BINARY :
1278112893 case GGML_OP_MAP_CUSTOM1_F32 :
0 commit comments