2424
2525#include " ggml-alloc.h"
2626#include " ggml-backend.h"
27- #include " ggml-cpu.h"
2827#include " ggml.h"
28+ #include " ggml_extend_backend.hpp"
2929
3030#include " model.h"
3131#include " tensor.hpp"
3232
33- #ifdef SD_USE_CUDA
34- #include " ggml-cuda.h"
35- #endif
36-
37- #ifdef SD_USE_METAL
38- #include " ggml-metal.h"
39- #endif
40-
41- #ifdef SD_USE_VULKAN
42- #include " ggml-vulkan.h"
43- #endif
44-
45- #ifdef SD_USE_OPENCL
46- #include " ggml-opencl.h"
47- #endif
48-
49- #ifdef SD_USE_SYCL
50- #include " ggml-sycl.h"
51- #endif
52-
5333#include " rng.hpp"
5434#include " tensor_ggml.hpp"
5535#include " util.h"
@@ -91,6 +71,45 @@ __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const cha
9171 }
9272}
9373
74+ __STATIC_INLINE__ bool backend_name_exists (std::string name) {
75+ ggml_backend_load_all_once ();
76+ const int device_count = ggml_backend_dev_count ();
77+ for (int i = 0 ; i < device_count; i++) {
78+ if (name == ggml_backend_dev_name (ggml_backend_dev_get (i))) {
79+ return true ;
80+ }
81+ }
82+ return false ;
83+ }
84+
85+ __STATIC_INLINE__ std::string sanitize_backend_name (std::string name) {
86+ if (name == " " || backend_name_exists (name)) {
87+ return name;
88+ } else {
89+ LOG_WARN (" Backend %s not found, using default backend" , name.c_str ());
90+ return " " ;
91+ }
92+ }
93+
94+ __STATIC_INLINE__ std::string get_default_backend_name () {
95+ ggml_backend_load_all_once ();
96+ // should pick the same backend as ggml_backend_init_best
97+ ggml_backend_dev_t dev = ggml_backend_dev_by_type (GGML_BACKEND_DEVICE_TYPE_GPU);
98+ dev = dev ? dev : ggml_backend_dev_by_type (GGML_BACKEND_DEVICE_TYPE_IGPU);
99+ dev = dev ? dev : ggml_backend_dev_by_type (GGML_BACKEND_DEVICE_TYPE_CPU);
100+ return ggml_backend_dev_name (dev);
101+ }
102+
103+ __STATIC_INLINE__ ggml_backend_t init_named_backend (std::string name = " " ) {
104+ ggml_backend_load_all_once ();
105+ LOG_DEBUG (" Initializing backend: %s" , name.c_str ());
106+ if (name.empty ()) {
107+ return ggml_backend_init_best ();
108+ } else {
109+ return ggml_backend_init_by_name (name.c_str (), nullptr );
110+ }
111+ }
112+
94113static_assert (GGML_MAX_NAME >= 128 , " GGML_MAX_NAME must be at least 128" );
95114
96115// n-mode tensor-matrix product
@@ -1286,25 +1305,25 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_ones_like(ggml_context* ctx,
12861305 return ggml_ext_ones (ctx, x->ne [0 ], x->ne [1 ], x->ne [2 ], x->ne [3 ]);
12871306}
12881307
1289- __STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32 (ggml_context* ctx, ggml_tensor* a) {
1290- #ifdef SD_USE_VULKAN
1291- auto zero_index = ggml_get_tensor (ctx, " ggml_runner_build_in_tensor:zero_int" );
1292- auto out = ggml_reshape_1d (ctx, a, ggml_nelements (a));
1293- out = ggml_get_rows (ctx, out, zero_index);
1294- out = ggml_reshape (ctx, out, a);
1295- // auto out = ggml_cast(ctx, a, GGML_TYPE_F32);
1296- return out;
1297- #else
1298- auto out = ggml_reshape_2d (ctx, a, 1 , ggml_nelements (a));
1299- ggml_tensor* one = ggml_ext_ones (ctx, 1 , 1 , 1 , 1 ); // [1,]
1300- if (ggml_is_transposed (out)) {
1301- out = ggml_mul_mat (ctx, one, out);
1308+ __STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32 (ggml_context* ctx, ggml_backend_t backend, ggml_tensor* a) {
1309+ if (sd_backend_is (backend, " Vulkan" )) {
1310+ auto zero_index = ggml_get_tensor (ctx, " ggml_runner_build_in_tensor:zero_int" );
1311+ auto out = ggml_reshape_1d (ctx, a, ggml_nelements (a));
1312+ out = ggml_get_rows (ctx, out, zero_index);
1313+ out = ggml_reshape (ctx, out, a);
1314+ // auto out = ggml_cast(ctx, a, GGML_TYPE_F32);
1315+ return out;
13021316 } else {
1303- out = ggml_mul_mat (ctx, out, one);
1317+ auto out = ggml_reshape_2d (ctx, a, 1 , ggml_nelements (a));
1318+ ggml_tensor* one = ggml_ext_ones (ctx, 1 , 1 , 1 , 1 ); // [1,]
1319+ if (ggml_is_transposed (out)) {
1320+ out = ggml_mul_mat (ctx, one, out);
1321+ } else {
1322+ out = ggml_mul_mat (ctx, out, one);
1323+ }
1324+ out = ggml_reshape (ctx, out, a);
1325+ return out;
13041326 }
1305- out = ggml_reshape (ctx, out, a);
1306- #endif
1307- return out;
13081327}
13091328
13101329// q: [N, L_q, C(n_head*d_head)] or [N*n_head, L_q, d_head]
@@ -1496,16 +1515,14 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_group_norm(ggml_context* ctx,
14961515}
14971516
14981517__STATIC_INLINE__ void ggml_ext_backend_tensor_get_and_sync (ggml_backend_t backend, const ggml_tensor* tensor, void * data, size_t offset, size_t size) {
1499- # if defined(SD_USE_CUDA ) || defined(SD_USE_SYCL)
1500- if ( !ggml_backend_is_cpu (backend)) {
1518+ if (( sd_backend_is (backend, " ROCm " ) || sd_backend_is (backend, " CUDA " ) || sd_backend_is (backend, " SYCL " )) &&
1519+ !ggml_backend_is_cpu (backend)) {
15011520 ggml_backend_tensor_get_async (backend, tensor, data, offset, size);
15021521 ggml_backend_synchronize (backend);
1503- } else {
1504- ggml_backend_tensor_get (tensor, data, offset, size);
1522+ return ;
15051523 }
1506- # else
1524+
15071525 ggml_backend_tensor_get (tensor, data, offset, size);
1508- #endif
15091526}
15101527
15111528__STATIC_INLINE__ float ggml_ext_backend_tensor_get_f32 (ggml_tensor* tensor) {
@@ -1664,14 +1681,15 @@ struct WeightAdapter {
16641681 float scale = 1 .f;
16651682 } conv2d;
16661683 };
1667- virtual ggml_tensor* patch_weight (ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) = 0;
1684+ virtual ggml_tensor* patch_weight (ggml_context* ctx, ggml_backend_t backend, ggml_tensor* weight, const std::string& weight_name) = 0;
16681685 virtual ggml_tensor* forward_with_lora (ggml_context* ctx,
1686+ ggml_backend_t backend,
16691687 ggml_tensor* x,
16701688 ggml_tensor* w,
16711689 ggml_tensor* b,
16721690 const std::string& prefix,
1673- ForwardParams forward_params) = 0;
1674- virtual size_t get_extra_graph_size () = 0;
1691+ ForwardParams forward_params) = 0;
1692+ virtual size_t get_extra_graph_size () = 0;
16751693};
16761694
16771695struct GGMLRunnerContext {
@@ -2192,6 +2210,14 @@ struct GGMLRunner {
21922210 void set_weight_adapter (const std::shared_ptr<WeightAdapter>& adapter) {
21932211 weight_adapter = adapter;
21942212 }
2213+
    // Accessor for the backend handle this runner computes graphs on.
    ggml_backend_t get_runtime_backend() {
        return runtime_backend;
    }
2217+
    // Accessor for the backend handle that holds this runner's parameters
    // (may differ from the runtime backend — confirm against runner setup).
    ggml_backend_t get_params_backend() {
        return params_backend;
    }
21952221};
21962222
21972223class GGMLBlock {
@@ -2336,6 +2362,14 @@ class Linear : public UnaryBlock {
23362362 force_prec_f32(force_prec_f32),
23372363 scale(scale) {}
23382364
    // Override the output scale applied in forward().
    void set_scale(float scale_) {
        scale = scale_;
    }
2368+
    // Force (or stop forcing) f32 precision for this layer's matmul in forward().
    void set_force_prec_f32(bool force_prec_f32_) {
        force_prec_f32 = force_prec_f32_;
    }
2372+
23392373 ggml_tensor* forward (GGMLRunnerContext* ctx, ggml_tensor* x) {
23402374 ggml_tensor* w = params[" weight" ];
23412375 ggml_tensor* b = nullptr ;
@@ -2347,7 +2381,7 @@ class Linear : public UnaryBlock {
23472381 forward_params.op_type = WeightAdapter::ForwardParams::op_type_t ::OP_LINEAR;
23482382 forward_params.linear .force_prec_f32 = force_prec_f32;
23492383 forward_params.linear .scale = scale;
2350- return ctx->weight_adapter ->forward_with_lora (ctx->ggml_ctx , x, w, b, prefix, forward_params);
2384+ return ctx->weight_adapter ->forward_with_lora (ctx->ggml_ctx , ctx-> backend , x, w, b, prefix, forward_params);
23512385 }
23522386 return ggml_ext_linear (ctx->ggml_ctx , x, w, b, force_prec_f32, scale);
23532387 }
@@ -2463,7 +2497,7 @@ class Conv2d : public UnaryBlock {
24632497 forward_params.conv2d .circular_x = ctx->circular_x_enabled ;
24642498 forward_params.conv2d .circular_y = ctx->circular_y_enabled ;
24652499 forward_params.conv2d .scale = scale;
2466- return ctx->weight_adapter ->forward_with_lora (ctx->ggml_ctx , x, w, b, prefix, forward_params);
2500+ return ctx->weight_adapter ->forward_with_lora (ctx->ggml_ctx , ctx-> backend , x, w, b, prefix, forward_params);
24672501 }
24682502 return ggml_ext_conv_2d (ctx->ggml_ctx ,
24692503 x,
@@ -2527,15 +2561,15 @@ class Conv3d : public UnaryBlock {
25272561 ggml_tensor* w = params[" weight" ];
25282562 ggml_tensor* b = nullptr ;
25292563 if (ctx->weight_adapter ) {
2530- w = ctx->weight_adapter ->patch_weight (ctx->ggml_ctx , w, prefix + " weight" );
2564+ w = ctx->weight_adapter ->patch_weight (ctx->ggml_ctx , ctx-> backend , w, prefix + " weight" );
25312565 if (w->type != GGML_TYPE_F16) {
25322566 w = ggml_cast (ctx->ggml_ctx , w, GGML_TYPE_F16);
25332567 }
25342568 }
25352569 if (bias) {
25362570 b = params[" bias" ];
25372571 if (ctx->weight_adapter ) {
2538- b = ctx->weight_adapter ->patch_weight (ctx->ggml_ctx , b, prefix + " bias" );
2572+ b = ctx->weight_adapter ->patch_weight (ctx->ggml_ctx , ctx-> backend , b, prefix + " bias" );
25392573 }
25402574 }
25412575 return ggml_ext_conv_3d (ctx->ggml_ctx , x, w, b, in_channels,
@@ -2582,12 +2616,12 @@ class LayerNorm : public UnaryBlock {
25822616 if (elementwise_affine) {
25832617 w = params[" weight" ];
25842618 if (ctx->weight_adapter ) {
2585- w = ctx->weight_adapter ->patch_weight (ctx->ggml_ctx , w, prefix + " weight" );
2619+ w = ctx->weight_adapter ->patch_weight (ctx->ggml_ctx , ctx-> backend , w, prefix + " weight" );
25862620 }
25872621 if (bias) {
25882622 b = params[" bias" ];
25892623 if (ctx->weight_adapter ) {
2590- b = ctx->weight_adapter ->patch_weight (ctx->ggml_ctx , b, prefix + " bias" );
2624+ b = ctx->weight_adapter ->patch_weight (ctx->ggml_ctx , ctx-> backend , b, prefix + " bias" );
25912625 }
25922626 }
25932627 }
@@ -2630,8 +2664,8 @@ class GroupNorm : public GGMLBlock {
26302664 w = params[" weight" ];
26312665 b = params[" bias" ];
26322666 if (ctx->weight_adapter ) {
2633- w = ctx->weight_adapter ->patch_weight (ctx->ggml_ctx , w, prefix + " weight" );
2634- b = ctx->weight_adapter ->patch_weight (ctx->ggml_ctx , b, prefix + " bias" );
2667+ w = ctx->weight_adapter ->patch_weight (ctx->ggml_ctx , ctx-> backend , w, prefix + " weight" );
2668+ b = ctx->weight_adapter ->patch_weight (ctx->ggml_ctx , ctx-> backend , b, prefix + " bias" );
26352669 }
26362670 }
26372671 return ggml_ext_group_norm (ctx->ggml_ctx , x, w, b, num_groups);
@@ -2665,7 +2699,7 @@ class RMSNorm : public UnaryBlock {
26652699 ggml_tensor* forward (GGMLRunnerContext* ctx, ggml_tensor* x) {
26662700 ggml_tensor* w = params[" weight" ];
26672701 if (ctx->weight_adapter ) {
2668- w = ctx->weight_adapter ->patch_weight (ctx->ggml_ctx , w, prefix + " weight" );
2702+ w = ctx->weight_adapter ->patch_weight (ctx->ggml_ctx , ctx-> backend , w, prefix + " weight" );
26692703 }
26702704 x = ggml_rms_norm (ctx->ggml_ctx , x, eps);
26712705 x = ggml_mul_inplace (ctx->ggml_ctx , x, w);
@@ -2748,6 +2782,7 @@ class MultiheadAttention : public GGMLBlock {
27482782
27492783__STATIC_INLINE__ ggml_tensor* ggml_ext_lokr_forward (
27502784 ggml_context* ctx,
2785+ ggml_backend_t backend,
27512786 ggml_tensor* h, // Input: [q, batch] or [W, H, q, batch]
27522787 ggml_tensor* w1, // Outer C (Full rank)
27532788 ggml_tensor* w1a, // Outer A (Low rank part 1)
@@ -2778,29 +2813,29 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_lokr_forward(
27782813 int merge_batch_uq = batch;
27792814 int merge_batch_vp = batch;
27802815
2781- #if SD_USE_VULKAN
2782- if (batch > 1 ) {
2783- // no access to backend here, worst case is slightly worse perfs for other backends when built alongside Vulkan backend
2784- int max_batch = 65535 ;
2785- int max_batch_uq = max_batch / uq;
2786- merge_batch_uq = 1 ;
2787- for (int i = max_batch_uq; i > 0 ; i--) {
2788- if (batch % i == 0 ) {
2789- merge_batch_uq = i;
2790- break ;
2816+ if (sd_backend_is (backend, " Vulkan" )) {
2817+ if (batch > 1 ) {
2818+ // no access to backend here, worst case is slightly worse perfs for other backends when built alongside Vulkan backend
2819+ int max_batch = 65535 ;
2820+ int max_batch_uq = max_batch / uq;
2821+ merge_batch_uq = 1 ;
2822+ for (int i = max_batch_uq; i > 0 ; i--) {
2823+ if (batch % i == 0 ) {
2824+ merge_batch_uq = i;
2825+ break ;
2826+ }
27912827 }
2792- }
27932828
2794- int max_batch_vp = max_batch / vp;
2795- merge_batch_vp = 1 ;
2796- for (int i = max_batch_vp; i > 0 ; i--) {
2797- if (batch % i == 0 ) {
2798- merge_batch_vp = i;
2799- break ;
2829+ int max_batch_vp = max_batch / vp;
2830+ merge_batch_vp = 1 ;
2831+ for (int i = max_batch_vp; i > 0 ; i--) {
2832+ if (batch % i == 0 ) {
2833+ merge_batch_vp = i;
2834+ break ;
2835+ }
28002836 }
28012837 }
28022838 }
2803- #endif
28042839
28052840 ggml_tensor* h_split = ggml_reshape_3d (ctx, h, vq, uq * merge_batch_uq, batch / merge_batch_uq);
28062841 if (w2 != NULL ) {
0 commit comments