@@ -3987,8 +3987,72 @@ template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kerne
39873987template [[host_name(" kernel_rope_vision_f32" )]] kernel kernel_rope_vision_t kernel_rope_vision<float >;
39883988template [[host_name(" kernel_rope_vision_f16" )]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
39893989
3990+ typedef void (im2col_t )(
3991+ constant ggml_metal_kargs_im2col & args,
3992+ device const float * x,
3993+ device char * dst,
3994+ uint3 tgpig[[threadgroup_position_in_grid]],
3995+ uint3 tgpg[[threadgroups_per_grid]],
3996+ uint3 tpitg[[thread_position_in_threadgroup]],
3997+ uint3 ntg[[threads_per_threadgroup]]);
3998+
3999+ template <typename T>
4000+ kernel void kernel_im2col (
4001+ constant ggml_metal_kargs_im2col & args,
4002+ device const float * x,
4003+ device char * dst,
4004+ uint3 tgpig[[threadgroup_position_in_grid]],
4005+ uint3 tgpg[[threadgroups_per_grid]],
4006+ uint3 tpitg[[thread_position_in_threadgroup]],
4007+ uint3 ntg[[threads_per_threadgroup]]) {
4008+ // const int64_t IC = tgpg[0];
4009+ const int64_t OH = tgpg[1 ];
4010+ const int64_t OW = tgpg[2 ];
4011+
4012+ const int64_t KH = ntg[1 ];
4013+ const int64_t KW = ntg[2 ];
4014+
4015+ int64_t in = tpitg[0 ];
4016+ const int64_t ikh = tpitg[1 ];
4017+ const int64_t ikw = tpitg[2 ];
4018+
4019+ const int64_t iic = tgpig[0 ];
4020+ const int64_t ioh = tgpig[1 ];
4021+ const int64_t iow = tgpig[2 ];
4022+
4023+ const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0 ;
4024+ const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1 ;
4025+
4026+ int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
4027+
4028+ device T * pdst = (device T *) (dst);
4029+
4030+ if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW ) {
4031+ while (in < args.N ) {
4032+ pdst[offset_dst] = 0 .0f ;
4033+ offset_dst += ntg[0 ]*args.CHW *OH*OW;
4034+
4035+ in += ntg[0 ];
4036+ }
4037+ } else {
4038+ int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
4039+
4040+ while (in < args.N ) {
4041+ pdst[offset_dst] = x[offset_src];
4042+
4043+ offset_dst += ntg[0 ]*args.CHW *OH*OW;
4044+ offset_src += ntg[0 ]*args.ofs0 ;
4045+
4046+ in += ntg[0 ];
4047+ }
4048+ }
4049+ }
4050+
4051+ template [[host_name(" kernel_im2col_f32" )]] kernel im2col_t kernel_im2col<float >;
4052+ template [[host_name(" kernel_im2col_f16" )]] kernel im2col_t kernel_im2col<half>;
4053+
39904054// TODO: obolete -- remove
3991- // typedef void (im2col_t )(
4055+ // typedef void (im2col_ext_t )(
39924056// constant ggml_metal_kargs_im2col & args,
39934057// device const float * x,
39944058// device char * dst,
@@ -3998,100 +4062,48 @@ template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t ker
39984062// uint3 ntg[[threads_per_threadgroup]]);
39994063//
40004064// template <typename T>
4001- // kernel void kernel_im2col (
4065+ // kernel void kernel_im2col_ext (
40024066// constant ggml_metal_kargs_im2col & args,
40034067// device const float * x,
40044068// device char * dst,
40054069// uint3 tgpig[[threadgroup_position_in_grid]],
4006- // uint3 tgpg[[threadgroups_per_grid]],
4070+ // uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
40074071// uint3 tpitg[[thread_position_in_threadgroup]],
4008- // uint3 ntg[[threads_per_threadgroup]]) {
4009- // // const int64_t IC = tgpg[0];
4010- // const int64_t OH = tgpg[1];
4011- // const int64_t OW = tgpg[2];
4072+ // uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
4073+ // const int64_t KHW = (int64_t)args.KHW;
40124074//
4013- // // const int64_t N = ntg[0];
4014- // const int64_t KH = ntg[1];
4015- // const int64_t KW = ntg[2];
4075+ // const int64_t d = tgpig[0] / args.CHW;
4076+ // const int64_t chw = tgpig[0] % args.CHW;
4077+ // const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
4078+ // const int64_t HW = tgpig[0] % KHW;
40164079//
4017- // const int64_t in = tpitg[0];
4018- // const int64_t ikh = tpitg[1];
4019- // const int64_t ikw = tpitg[2];
4080+ // const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
4081+ // if (tpitg_0 >= args.N) {
4082+ // return;
4083+ // }
40204084//
4021- // const int64_t iic = tgpig[0];
4022- // const int64_t ioh = tgpig[1];
4023- // const int64_t iow = tgpig[2];
4085+ // const int64_t tpitg_1 = HW / args.KW;
4086+ // const int64_t tpitg_2 = HW % args.KW;
40244087//
4025- // const int64_t iiw = iow* args.s0 + ikw* args.d0 - args.p0;
4026- // const int64_t iih = ioh* args.s1 + ikh* args.d1 - args.p1;
4088+ // const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
4089+ // const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
40274090//
4028- // const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
4091+ // const int64_t offset_dst =
4092+ // (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
4093+ // (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
40294094//
40304095// device T * pdst = (device T *) (dst);
40314096//
40324097// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
40334098// pdst[offset_dst] = 0.0f;
40344099// } else {
4035- // const int64_t offset_src = in* args.ofs0 + iic*args.ofs1 + iih* args.IW + iiw ;
4036- // pdst[offset_dst] = x[offset_src];
4100+ // const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1 ;
4101+ // pdst[offset_dst] = x[offset_src + iih * args.IW + iiw ];
40374102// }
40384103// }
40394104//
4040- // template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
4041- // template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
4042-
4043- typedef void (im2col_ext_t )(
4044- constant ggml_metal_kargs_im2col & args,
4045- device const float * x,
4046- device char * dst,
4047- uint3 tgpig[[threadgroup_position_in_grid]],
4048- uint3 tgpg[[threadgroups_per_grid]],
4049- uint3 tpitg[[thread_position_in_threadgroup]],
4050- uint3 ntg[[threads_per_threadgroup]]);
4051-
4052- template <typename T>
4053- kernel void kernel_im2col_ext (
4054- constant ggml_metal_kargs_im2col & args,
4055- device const float * x,
4056- device char * dst,
4057- uint3 tgpig[[threadgroup_position_in_grid]],
4058- uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
4059- uint3 tpitg[[thread_position_in_threadgroup]],
4060- uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
4061- const int64_t KHW = (int64_t )args.KHW ;
4062-
4063- const int64_t d = tgpig[0 ] / args.CHW ;
4064- const int64_t chw = tgpig[0 ] % args.CHW ;
4065- const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
4066- const int64_t HW = tgpig[0 ] % KHW;
4067-
4068- const int64_t tpitg_0 = (d * ntg[0 ]) + tpitg[0 ];
4069- if (tpitg_0 >= args.N ) {
4070- return ;
4071- }
4072-
4073- const int64_t tpitg_1 = HW / args.KW ;
4074- const int64_t tpitg_2 = HW % args.KW ;
4075-
4076- const int64_t iiw = tgpig[2 ] * args.s0 + tpitg_2 * args.d0 - args.p0 ;
4077- const int64_t iih = tgpig[1 ] * args.s1 + tpitg_1 * args.d1 - args.p1 ;
4078-
4079- const int64_t offset_dst =
4080- (tpitg_0 * tgpg[1 ] * tgpg[2 ] + tgpig[1 ] * tgpg[2 ] + tgpig[2 ]) * args.CHW +
4081- (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
4082-
4083- device T * pdst = (device T *) (dst);
4084-
4085- if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW ) {
4086- pdst[offset_dst] = 0 .0f ;
4087- } else {
4088- const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1 ;
4089- pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
4090- }
4091- }
4092-
4093- template [[host_name(" kernel_im2col_ext_f32" )]] kernel im2col_ext_t kernel_im2col_ext<float >;
4094- template [[host_name(" kernel_im2col_ext_f16" )]] kernel im2col_ext_t kernel_im2col_ext<half>;
4105+ // template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
4106+ // template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
40954107
40964108typedef void (conv_transpose_1d_t )(
40974109 constant ggml_metal_kargs_conv_transpose_1d & args,
0 commit comments