@@ -245,6 +245,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
+    vk_pipeline pipeline_rwkv_wkv6_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -528,6 +529,13 @@ struct vk_op_pool2d_push_constants {
     int32_t p0; int32_t p1;
 };
 
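+// Push constants for the RWKV wkv6 kernel: B = n_seqs, T = seq_length,
+// C = n_embed, H = n_heads (filled in by ggml_vk_rwkv_wkv6 below).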
+struct vk_op_rwkv_wkv6_push_constants {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t H;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -2014,6 +2022,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
 
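+    // 7 storage-buffer bindings (k, v, r, tf, td, state, dst); the device
+    // subgroup size is passed as a specialization constant.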
+    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+
     for (auto &c : compiles) {
         c.wait();
     }
@@ -5022,6 +5032,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_pool2d_f32;
         }
         return nullptr;
+    case GGML_OP_RWKV_WKV6:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rwkv_wkv6_f32;
+        }
+        return nullptr;
     case GGML_OP_LEAKY_RELU:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_leaky_relu_f32;
@@ -5424,6 +5439,134 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
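+// Bind the six source tensors and the destination, then record a dispatch of
+// the rwkv_wkv6 pipeline with the given push constants.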
+static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
+    const ggml_tensor * k = dst->src[0];
+    const ggml_tensor * v = dst->src[1];
+    const ggml_tensor * r = dst->src[2];
+    const ggml_tensor * tf = dst->src[3];
+    const ggml_tensor * td = dst->src[4];
+    const ggml_tensor * state = dst->src[5];
+
+    GGML_ASSERT(!ggml_is_quantized(k->type));
+    GGML_ASSERT(!ggml_is_quantized(v->type));
+    GGML_ASSERT(!ggml_is_quantized(r->type));
+    GGML_ASSERT(!ggml_is_quantized(tf->type));
+    GGML_ASSERT(!ggml_is_quantized(td->type));
+    GGML_ASSERT(!ggml_is_quantized(state->type));
+    GGML_ASSERT(dst->buffer != nullptr);
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
+    GGML_ASSERT(pipeline != nullptr);
+
+    if (dryrun) {
+        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        return;
+    }
+
+    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
+    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
+    ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
+    ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
+    ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
+    ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
+
+    ggml_vk_sync_buffers(subctx);
+
+    vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
+    uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
+    bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
+
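+    // On UMA devices the tensors may be host-resident; try to resolve them
+    // directly instead of going through a dedicated device buffer.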
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
+        ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
+        ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
+        ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
+        ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
+        ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
+        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+
+        K_uma = d_K != nullptr;
+        V_uma = d_V != nullptr;
+        R_uma = d_R != nullptr;
+        TF_uma = d_TF != nullptr;
+        TD_uma = d_TD != nullptr;
+        STATE_uma = d_State != nullptr;
+        DST_uma = d_D != nullptr;
+    }
+
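+    // Any tensor not resolved via UMA falls back to its device-local buffer.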
+    if (!K_uma) {
+        d_K = k_buf_ctx->dev_buffer;
+        k_offset = vk_tensor_offset(k) + k->view_offs;
+    }
+    if (!V_uma) {
+        d_V = v_buf_ctx->dev_buffer;
+        v_offset = vk_tensor_offset(v) + v->view_offs;
+    }
+    if (!R_uma) {
+        d_R = r_buf_ctx->dev_buffer;
+        r_offset = vk_tensor_offset(r) + r->view_offs;
+    }
+    if (!TF_uma) {
+        d_TF = tf_buf_ctx->dev_buffer;
+        tf_offset = vk_tensor_offset(tf) + tf->view_offs;
+    }
+    if (!TD_uma) {
+        d_TD = td_buf_ctx->dev_buffer;
+        td_offset = vk_tensor_offset(td) + td->view_offs;
+    }
+    if (!STATE_uma) {
+        d_State = state_buf_ctx->dev_buffer;
+        state_offset = vk_tensor_offset(state) + state->view_offs;
+    }
+    if (!DST_uma) {
+        d_D = dst_buf_ctx->dev_buffer;
+        dst_offset = vk_tensor_offset(dst) + dst->view_offs;
+    }
+
+    const uint64_t k_size = ggml_nbytes(k);
+    const uint64_t v_size = ggml_nbytes(v);
+    const uint64_t r_size = ggml_nbytes(r);
+    const uint64_t tf_size = ggml_nbytes(tf);
+    const uint64_t td_size = ggml_nbytes(td);
+    const uint64_t state_size = ggml_nbytes(state);
+    const uint64_t dst_size = ggml_nbytes(dst);
+
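+    // One workgroup per (sequence, head) pair: B * H groups in x, 1 in y/z.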
+    std::array<uint32_t, 3> elements = {
+        (uint32_t)(pc.B * pc.H),
+        1,
+        1
+    };
+
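+    // Buffer order matches the pipeline's 7 bindings: k, v, r, tf, td, state, dst.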
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+        vk_subbuffer{ d_K, k_offset, k_size },
+        vk_subbuffer{ d_V, v_offset, v_size },
+        vk_subbuffer{ d_R, r_offset, r_size },
+        vk_subbuffer{ d_TF, tf_offset, tf_size },
+        vk_subbuffer{ d_TD, td_offset, td_size },
+        vk_subbuffer{ d_State, state_offset, state_size },
+        vk_subbuffer{ d_D, dst_offset, dst_size }
+    }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+}
+
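+// Read B/T/C/H from the node's tensor shapes and forward to the dispatcher above.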
+static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+    const size_t seq_length = dst->src[0]->ne[3];
+    const size_t n_embed = dst->ne[0];
+    const size_t n_heads = dst->src[0]->ne[2];
+    const size_t n_seqs = dst->src[5]->ne[1];
+
+    ggml_vk_op_f32_rwkv6(
+        ctx, subctx, dst,
+        {
+            (uint32_t)n_seqs,
+            (uint32_t)seq_length,
+            (uint32_t)n_embed,
+            (uint32_t)n_heads,
+        },
+        dryrun
+    );
+}
+
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     int * op_params = (int *)dst->op_params;
 
@@ -6569,6 +6712,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
         break;
@@ -6768,6 +6912,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_FLASH_ATTN_EXT:
         ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
 
+        break;
+
+    case GGML_OP_RWKV_WKV6:
+        ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
+
         break;
     default:
         return false;
@@ -6848,6 +6997,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_REPEAT:
         buf = tensor->buffer;
@@ -7724,6 +7874,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_IM2COL:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_POOL_2D:
+        case GGML_OP_RWKV_WKV6:
         case GGML_OP_LEAKY_RELU:
             return true;
         default:
@@ -8300,7 +8451,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
         tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
-    } else {
+    } else if (tensor->op == GGML_OP_RWKV_WKV6) {
+        tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
+                                      tensor->src[4], tensor->src[5]);
+    }
+    else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
         GGML_ABORT("fatal error");
     }