@@ -22,17 +22,88 @@ static __global__ void upscale_f32(const float * x, float * dst,
2222 dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );
2323}
2424
25+ static __global__ void upscale_f32_bilinear (const float * x, float * dst,
26+ const int nb00, const int nb01, const int nb02, const int nb03,
27+ const int ne00_src, const int ne01_src,
28+ const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
29+ const float sf0, const float sf1, const float sf2, const float sf3,
30+ const float pixel_offset) {
31+ const int64_t index = threadIdx .x + blockIdx .x * blockDim .x ;
32+ const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
33+
34+ if (index >= dst_total_elements) {
35+ return ;
36+ }
37+
38+ const int i10_dst = index % ne10_dst;
39+ const int i11_dst = (index / ne10_dst) % ne11_dst;
40+ const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
41+ const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
42+
43+ const int i02_src = (int )(i12_dst / sf2);
44+ const int i03_src = (int )(i13_dst / sf3);
45+
46+ const float y_src_f = ((float )i11_dst + pixel_offset) / sf1 - pixel_offset;
47+ int y0_src = (int )floorf (y_src_f);
48+ int y1_src = y0_src + 1 ;
49+
50+ y0_src = max (0 , min (y0_src, ne01_src - 1 ));
51+ y1_src = max (0 , min (y1_src, ne01_src - 1 ));
52+
53+ float dy = y_src_f - (float )y0_src;
54+ dy = max (0 .0f , min (dy, 1 .0f ));
55+
56+ float x_src_f = ((float )i10_dst + pixel_offset) / sf0 - pixel_offset;
57+ int x0_src = (int )floorf (x_src_f);
58+ int x1_src = x0_src + 1 ;
59+
60+ x0_src = max (0 , min (x0_src, ne00_src - 1 ));
61+ x1_src = max (0 , min (x1_src, ne00_src - 1 ));
62+
63+ float dx = x_src_f - (float )x0_src;
64+ dx = max (0 .0f , min (dx, 1 .0f ));
65+
66+ const float * p_a = (const float *)((const char *)x + (int64_t )x0_src * nb00 + (int64_t )y0_src * nb01 + (int64_t )i02_src * nb02 + (int64_t )i03_src * nb03);
67+ const float * p_b = (const float *)((const char *)x + (int64_t )x1_src * nb00 + (int64_t )y0_src * nb01 + (int64_t )i02_src * nb02 + (int64_t )i03_src * nb03);
68+ const float * p_c = (const float *)((const char *)x + (int64_t )x0_src * nb00 + (int64_t )y1_src * nb01 + (int64_t )i02_src * nb02 + (int64_t )i03_src * nb03);
69+ const float * p_d = (const float *)((const char *)x + (int64_t )x1_src * nb00 + (int64_t )y1_src * nb01 + (int64_t )i02_src * nb02 + (int64_t )i03_src * nb03);
70+
71+ const float val_a = *p_a;
72+ const float val_b = *p_b;
73+ const float val_c = *p_c;
74+ const float val_d = *p_d;
75+
76+ float result = val_a * (1 .0f - dx) * (1 .0f - dy) +
77+ val_b * dx * (1 .0f - dy) +
78+ val_c * (1 .0f - dx) * dy +
79+ val_d * dx * dy;
80+
81+ dst[index] = result;
82+ }
83+
2584static void upscale_f32_cuda (const float * x, float * dst,
2685 const int nb00, const int nb01, const int nb02, const int nb03,
2786 const int ne10, const int ne11, const int ne12, const int ne13,
2887 const float sf0, const float sf1, const float sf2, const float sf3,
2988 cudaStream_t stream) {
30- int dst_size = ne10 * ne11 * ne12 * ne13;
31- int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1 ) / CUDA_UPSCALE_BLOCK_SIZE;
89+ const int64_t dst_size = ne10 * ne11 * ne12 * ne13;
90+ const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1 ) / CUDA_UPSCALE_BLOCK_SIZE;
3291
3392 upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0 ,stream>>> (x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
3493}
3594
95+ static void upscale_f32_bilinear_cuda (const float * x, float * dst,
96+ const int nb00, const int nb01, const int nb02, const int nb03,
97+ const int ne00_src, const int ne01_src,
98+ const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
99+ const float sf0, const float sf1, const float sf2, const float sf3,
100+ const float pixel_offset, cudaStream_t stream) {
101+ const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
102+ const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1 ) / CUDA_UPSCALE_BLOCK_SIZE;
103+
104+ upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0 ,stream>>> (x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
105+ }
106+
36107void ggml_cuda_op_upscale (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
37108 const ggml_tensor * src0 = dst->src [0 ];
38109 const float * src0_d = (const float *)src0->data ;
@@ -42,10 +113,25 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
42113 GGML_ASSERT (src0->type == GGML_TYPE_F32);
43114 GGML_ASSERT ( dst->type == GGML_TYPE_F32);
44115
45- const float sf0 = (float )dst->ne [0 ]/src0->ne [0 ];
46- const float sf1 = (float )dst->ne [1 ]/src0->ne [1 ];
47- const float sf2 = (float )dst->ne [2 ]/src0->ne [2 ];
116+ const int mode_flags = dst->op_params [0 ];
117+ const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF );
118+
119+ float sf0 = (float )dst->ne [0 ]/src0->ne [0 ];
120+ float sf1 = (float )dst->ne [1 ]/src0->ne [1 ];
121+ float sf2 = (float )dst->ne [2 ]/src0->ne [2 ];
48122 const float sf3 = (float )dst->ne [3 ]/src0->ne [3 ];
49123
50- upscale_f32_cuda (src0_d, dst_d, src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ], dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ], sf0, sf1, sf2, sf3, stream);
124+ if (mode == GGML_SCALE_MODE_NEAREST) {
125+ upscale_f32_cuda (src0_d, dst_d, src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ], dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ], sf0, sf1, sf2, sf3, stream);
126+ } else if (mode == GGML_SCALE_MODE_BILINEAR) {
127+ float pixel_offset = 0 .5f ;
128+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
129+ sf0 = (float )(dst->ne [0 ] - 1 ) / (src0->ne [0 ] - 1 );
130+ sf1 = (float )(dst->ne [1 ] - 1 ) / (src0->ne [1 ] - 1 );
131+ pixel_offset = 0 .0f ;
132+ }
133+ upscale_f32_bilinear_cuda (src0_d, dst_d, src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ],
134+ src0->ne [0 ], src0->ne [1 ], dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ],
135+ sf0, sf1, sf2, sf3, pixel_offset, stream);
136+ }
51137}
0 commit comments