Skip to content

Commit 5ea29e3

Browse files
wbruna, stduhpf, and Cyberhan123
committed
feat: transition from compile-time to runtime backend discovery
Applies the following parts from leejet#1184 and leejet#1368:
- Introduce ggml_extend_backend.hpp for dynamic backend management.
- Convert backend-specific SD_USE_* preprocessor tests to runtime tests, propagating the backend handler when needed.

Additionally, to make this work with minimal changes:
- A new function sd_get_default_backend replaces the default backend selection in src/stable-diffusion.cpp and src/upscaler.cpp, preserving the SD_VK_DEVICE env var support.
- Clean up SD_USE_* defines from CMakeLists.txt (no other build changes).

This is not just a refactor, because it improves device selection a bit:
- Previously, Vulkan selected device 0 by default, but this was the wrong choice on my system, which has the iGPU as 0 and the discrete card as 1. The new selection algorithm correctly prioritizes the discrete GPU.
- The upscaler now follows SD_VK_DEVICE too.

Co-authored-by: Stéphane du Hamel <stephduh@live.fr>
Co-authored-by: Cyberhan123 <255542417@qq.com>
1 parent c97702e commit 5ea29e3

12 files changed

Lines changed: 543 additions & 219 deletions

CMakeLists.txt

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,37 +72,31 @@ option(SD_USE_SYSTEM_GGML "sd: use system-installed GGML library" OFF
7272
if(SD_CUDA)
7373
message("-- Use CUDA as backend stable-diffusion")
7474
set(GGML_CUDA ON)
75-
add_definitions(-DSD_USE_CUDA)
7675
endif()
7776

7877
if(SD_METAL)
7978
message("-- Use Metal as backend stable-diffusion")
8079
set(GGML_METAL ON)
81-
add_definitions(-DSD_USE_METAL)
8280
endif()
8381

8482
if (SD_VULKAN)
8583
message("-- Use Vulkan as backend stable-diffusion")
8684
set(GGML_VULKAN ON)
87-
add_definitions(-DSD_USE_VULKAN)
8885
endif ()
8986

9087
if (SD_OPENCL)
9188
message("-- Use OpenCL as backend stable-diffusion")
9289
set(GGML_OPENCL ON)
93-
add_definitions(-DSD_USE_OPENCL)
9490
endif ()
9591

9692
if (SD_HIPBLAS)
9793
message("-- Use HIPBLAS as backend stable-diffusion")
9894
set(GGML_HIP ON)
99-
add_definitions(-DSD_USE_CUDA)
10095
endif ()
10196

10297
if(SD_MUSA)
10398
message("-- Use MUSA as backend stable-diffusion")
10499
set(GGML_MUSA ON)
105-
add_definitions(-DSD_USE_CUDA)
106100
endif()
107101

108102
if(SD_WEBP)
@@ -222,7 +216,6 @@ if(SD_SYCL)
222216
message("-- Use SYCL as backend stable-diffusion")
223217
set(GGML_SYCL ON)
224218
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl")
225-
add_definitions(-DSD_USE_SYCL)
226219
# disable fast-math on host, see:
227220
# https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-10/fp-model-fp.html
228221
if (WIN32)

src/common_block.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#ifndef __COMMON_BLOCK_HPP__
22
#define __COMMON_BLOCK_HPP__
33

4+
#include "ggml-backend.h"
45
#include "ggml_extend.hpp"
6+
#include "util.h"
57

68
class DownSampleBlock : public GGMLBlock {
79
protected:
@@ -248,9 +250,6 @@ class FeedForward : public GGMLBlock {
248250
float scale = 1.f;
249251
if (precision_fix) {
250252
scale = 1.f / 128.f;
251-
#ifdef SD_USE_VULKAN
252-
force_prec_f32 = true;
253-
#endif
254253
}
255254
// The purpose of the scale here is to prevent NaN issues in certain situations.
256255
// For example, when using Vulkan without enabling force_prec_f32,
@@ -264,6 +263,9 @@ class FeedForward : public GGMLBlock {
264263

265264
auto net_0 = std::dynamic_pointer_cast<UnaryBlock>(blocks["net.0"]);
266265
auto net_2 = std::dynamic_pointer_cast<Linear>(blocks["net.2"]);
266+
if (sd_backend_is(ctx->backend, "Vulkan")) {
267+
net_2->set_force_prec_f32(true);
268+
}
267269

268270
x = net_0->forward(ctx, x); // [ne3, ne2, ne1, inner_dim]
269271
x = net_2->forward(ctx, x); // [ne3, ne2, ne1, dim_out]

src/ggml_extend.hpp

Lines changed: 109 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,12 @@
2424

2525
#include "ggml-alloc.h"
2626
#include "ggml-backend.h"
27-
#include "ggml-cpu.h"
2827
#include "ggml.h"
28+
#include "ggml_extend_backend.hpp"
2929

3030
#include "model.h"
3131
#include "tensor.hpp"
3232

33-
#ifdef SD_USE_CUDA
34-
#include "ggml-cuda.h"
35-
#endif
36-
37-
#ifdef SD_USE_METAL
38-
#include "ggml-metal.h"
39-
#endif
40-
41-
#ifdef SD_USE_VULKAN
42-
#include "ggml-vulkan.h"
43-
#endif
44-
45-
#ifdef SD_USE_OPENCL
46-
#include "ggml-opencl.h"
47-
#endif
48-
49-
#ifdef SD_USE_SYCL
50-
#include "ggml-sycl.h"
51-
#endif
52-
5333
#include "rng.hpp"
5434
#include "tensor_ggml.hpp"
5535
#include "util.h"
@@ -91,6 +71,45 @@ __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const cha
9171
}
9272
}
9373

74+
// Returns true if any loaded ggml backend device reports exactly this name
// (exact, case-sensitive match). Loads all backends first so the device
// registry is populated before enumeration.
__STATIC_INLINE__ bool backend_name_exists(const std::string& name) {
    ggml_backend_load_all_once();
    // ggml_backend_dev_count()/ggml_backend_dev_get() use size_t indices;
    // keep the loop unsigned to avoid a narrowing conversion.
    const size_t device_count = ggml_backend_dev_count();
    for (size_t i = 0; i < device_count; i++) {
        if (name == ggml_backend_dev_name(ggml_backend_dev_get(i))) {
            return true;
        }
    }
    return false;
}
84+
85+
// Validates a user-supplied backend name. An empty name (meaning "use the
// default backend") or a name that matches a loaded device is returned
// unchanged; any other name logs a warning and falls back to "" (default).
__STATIC_INLINE__ std::string sanitize_backend_name(std::string name) {
    if (!name.empty() && !backend_name_exists(name)) {
        LOG_WARN("Backend %s not found, using default backend", name.c_str());
        return "";
    }
    return name;
}
93+
94+
// Returns the name of the device ggml would pick by default: discrete GPU
// first, then integrated GPU, then CPU.
// should pick the same backend as ggml_backend_init_best
__STATIC_INLINE__ std::string get_default_backend_name() {
    ggml_backend_load_all_once();
    ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
    if (dev == nullptr) {
        dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU);
    }
    if (dev == nullptr) {
        dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
    }
    if (dev == nullptr) {
        // Defensive: a CPU device should always be registered, but avoid
        // passing a null device to ggml_backend_dev_name if it is not.
        return "";
    }
    return ggml_backend_dev_name(dev);
}
102+
103+
// Initializes a ggml backend by device name; an empty name selects the best
// available backend (ggml_backend_init_best's policy).
__STATIC_INLINE__ ggml_backend_t init_named_backend(std::string name = "") {
    ggml_backend_load_all_once();
    LOG_DEBUG("Initializing backend: %s", name.c_str());
    return name.empty() ? ggml_backend_init_best()
                        : ggml_backend_init_by_name(name.c_str(), nullptr);
}
112+
94113
static_assert(GGML_MAX_NAME >= 128, "GGML_MAX_NAME must be at least 128");
95114

96115
// n-mode tensor-matrix product
@@ -1286,25 +1305,25 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_ones_like(ggml_context* ctx,
12861305
return ggml_ext_ones(ctx, x->ne[0], x->ne[1], x->ne[2], x->ne[3]);
12871306
}
12881307

1289-
__STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_tensor* a) {
1290-
#ifdef SD_USE_VULKAN
1291-
auto zero_index = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:zero_int");
1292-
auto out = ggml_reshape_1d(ctx, a, ggml_nelements(a));
1293-
out = ggml_get_rows(ctx, out, zero_index);
1294-
out = ggml_reshape(ctx, out, a);
1295-
// auto out = ggml_cast(ctx, a, GGML_TYPE_F32);
1296-
return out;
1297-
#else
1298-
auto out = ggml_reshape_2d(ctx, a, 1, ggml_nelements(a));
1299-
ggml_tensor* one = ggml_ext_ones(ctx, 1, 1, 1, 1); // [1,]
1300-
if (ggml_is_transposed(out)) {
1301-
out = ggml_mul_mat(ctx, one, out);
1308+
__STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_backend_t backend, ggml_tensor* a) {
1309+
if (sd_backend_is(backend, "Vulkan")) {
1310+
auto zero_index = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:zero_int");
1311+
auto out = ggml_reshape_1d(ctx, a, ggml_nelements(a));
1312+
out = ggml_get_rows(ctx, out, zero_index);
1313+
out = ggml_reshape(ctx, out, a);
1314+
// auto out = ggml_cast(ctx, a, GGML_TYPE_F32);
1315+
return out;
13021316
} else {
1303-
out = ggml_mul_mat(ctx, out, one);
1317+
auto out = ggml_reshape_2d(ctx, a, 1, ggml_nelements(a));
1318+
ggml_tensor* one = ggml_ext_ones(ctx, 1, 1, 1, 1); // [1,]
1319+
if (ggml_is_transposed(out)) {
1320+
out = ggml_mul_mat(ctx, one, out);
1321+
} else {
1322+
out = ggml_mul_mat(ctx, out, one);
1323+
}
1324+
out = ggml_reshape(ctx, out, a);
1325+
return out;
13041326
}
1305-
out = ggml_reshape(ctx, out, a);
1306-
#endif
1307-
return out;
13081327
}
13091328

13101329
// q: [N, L_q, C(n_head*d_head)] or [N*n_head, L_q, d_head]
@@ -1496,16 +1515,14 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_group_norm(ggml_context* ctx,
14961515
}
14971516

14981517
__STATIC_INLINE__ void ggml_ext_backend_tensor_get_and_sync(ggml_backend_t backend, const ggml_tensor* tensor, void* data, size_t offset, size_t size) {
1499-
#if defined(SD_USE_CUDA) || defined(SD_USE_SYCL)
1500-
if (!ggml_backend_is_cpu(backend)) {
1518+
if ((sd_backend_is(backend, "ROCm") || sd_backend_is(backend, "CUDA") || sd_backend_is(backend, "SYCL")) &&
1519+
!ggml_backend_is_cpu(backend)) {
15011520
ggml_backend_tensor_get_async(backend, tensor, data, offset, size);
15021521
ggml_backend_synchronize(backend);
1503-
} else {
1504-
ggml_backend_tensor_get(tensor, data, offset, size);
1522+
return;
15051523
}
1506-
#else
1524+
15071525
ggml_backend_tensor_get(tensor, data, offset, size);
1508-
#endif
15091526
}
15101527

15111528
__STATIC_INLINE__ float ggml_ext_backend_tensor_get_f32(ggml_tensor* tensor) {
@@ -1664,14 +1681,15 @@ struct WeightAdapter {
16641681
float scale = 1.f;
16651682
} conv2d;
16661683
};
1667-
virtual ggml_tensor* patch_weight(ggml_context* ctx, ggml_tensor* weight, const std::string& weight_name) = 0;
1684+
virtual ggml_tensor* patch_weight(ggml_context* ctx, ggml_backend_t backend, ggml_tensor* weight, const std::string& weight_name) = 0;
16681685
virtual ggml_tensor* forward_with_lora(ggml_context* ctx,
1686+
ggml_backend_t backend,
16691687
ggml_tensor* x,
16701688
ggml_tensor* w,
16711689
ggml_tensor* b,
16721690
const std::string& prefix,
1673-
ForwardParams forward_params) = 0;
1674-
virtual size_t get_extra_graph_size() = 0;
1691+
ForwardParams forward_params) = 0;
1692+
virtual size_t get_extra_graph_size() = 0;
16751693
};
16761694

16771695
struct GGMLRunnerContext {
@@ -2192,6 +2210,14 @@ struct GGMLRunner {
21922210
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
21932211
weight_adapter = adapter;
21942212
}
2213+
2214+
ggml_backend_t get_runtime_backend() {
2215+
return runtime_backend;
2216+
}
2217+
2218+
ggml_backend_t get_params_backend() {
2219+
return params_backend;
2220+
}
21952221
};
21962222

21972223
class GGMLBlock {
@@ -2336,6 +2362,14 @@ class Linear : public UnaryBlock {
23362362
force_prec_f32(force_prec_f32),
23372363
scale(scale) {}
23382364

2365+
void set_scale(float scale_) {
2366+
scale = scale_;
2367+
}
2368+
2369+
void set_force_prec_f32(bool force_prec_f32_) {
2370+
force_prec_f32 = force_prec_f32_;
2371+
}
2372+
23392373
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
23402374
ggml_tensor* w = params["weight"];
23412375
ggml_tensor* b = nullptr;
@@ -2347,7 +2381,7 @@ class Linear : public UnaryBlock {
23472381
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR;
23482382
forward_params.linear.force_prec_f32 = force_prec_f32;
23492383
forward_params.linear.scale = scale;
2350-
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
2384+
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, ctx->backend, x, w, b, prefix, forward_params);
23512385
}
23522386
return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
23532387
}
@@ -2463,7 +2497,7 @@ class Conv2d : public UnaryBlock {
24632497
forward_params.conv2d.circular_x = ctx->circular_x_enabled;
24642498
forward_params.conv2d.circular_y = ctx->circular_y_enabled;
24652499
forward_params.conv2d.scale = scale;
2466-
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
2500+
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, ctx->backend, x, w, b, prefix, forward_params);
24672501
}
24682502
return ggml_ext_conv_2d(ctx->ggml_ctx,
24692503
x,
@@ -2527,15 +2561,15 @@ class Conv3d : public UnaryBlock {
25272561
ggml_tensor* w = params["weight"];
25282562
ggml_tensor* b = nullptr;
25292563
if (ctx->weight_adapter) {
2530-
w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight");
2564+
w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, ctx->backend, w, prefix + "weight");
25312565
if (w->type != GGML_TYPE_F16) {
25322566
w = ggml_cast(ctx->ggml_ctx, w, GGML_TYPE_F16);
25332567
}
25342568
}
25352569
if (bias) {
25362570
b = params["bias"];
25372571
if (ctx->weight_adapter) {
2538-
b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "bias");
2572+
b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, ctx->backend, b, prefix + "bias");
25392573
}
25402574
}
25412575
return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels,
@@ -2582,12 +2616,12 @@ class LayerNorm : public UnaryBlock {
25822616
if (elementwise_affine) {
25832617
w = params["weight"];
25842618
if (ctx->weight_adapter) {
2585-
w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight");
2619+
w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, ctx->backend, w, prefix + "weight");
25862620
}
25872621
if (bias) {
25882622
b = params["bias"];
25892623
if (ctx->weight_adapter) {
2590-
b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "bias");
2624+
b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, ctx->backend, b, prefix + "bias");
25912625
}
25922626
}
25932627
}
@@ -2630,8 +2664,8 @@ class GroupNorm : public GGMLBlock {
26302664
w = params["weight"];
26312665
b = params["bias"];
26322666
if (ctx->weight_adapter) {
2633-
w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight");
2634-
b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "bias");
2667+
w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, ctx->backend, w, prefix + "weight");
2668+
b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, ctx->backend, b, prefix + "bias");
26352669
}
26362670
}
26372671
return ggml_ext_group_norm(ctx->ggml_ctx, x, w, b, num_groups);
@@ -2665,7 +2699,7 @@ class RMSNorm : public UnaryBlock {
26652699
ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
26662700
ggml_tensor* w = params["weight"];
26672701
if (ctx->weight_adapter) {
2668-
w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "weight");
2702+
w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, ctx->backend, w, prefix + "weight");
26692703
}
26702704
x = ggml_rms_norm(ctx->ggml_ctx, x, eps);
26712705
x = ggml_mul_inplace(ctx->ggml_ctx, x, w);
@@ -2748,6 +2782,7 @@ class MultiheadAttention : public GGMLBlock {
27482782

27492783
__STATIC_INLINE__ ggml_tensor* ggml_ext_lokr_forward(
27502784
ggml_context* ctx,
2785+
ggml_backend_t backend,
27512786
ggml_tensor* h, // Input: [q, batch] or [W, H, q, batch]
27522787
ggml_tensor* w1, // Outer C (Full rank)
27532788
ggml_tensor* w1a, // Outer A (Low rank part 1)
@@ -2778,29 +2813,29 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_lokr_forward(
27782813
int merge_batch_uq = batch;
27792814
int merge_batch_vp = batch;
27802815

2781-
#if SD_USE_VULKAN
2782-
if (batch > 1) {
2783-
// no access to backend here, worst case is slightly worse perfs for other backends when built alongside Vulkan backend
2784-
int max_batch = 65535;
2785-
int max_batch_uq = max_batch / uq;
2786-
merge_batch_uq = 1;
2787-
for (int i = max_batch_uq; i > 0; i--) {
2788-
if (batch % i == 0) {
2789-
merge_batch_uq = i;
2790-
break;
2816+
if (sd_backend_is(backend, "Vulkan")) {
2817+
if (batch > 1) {
2818+
// no access to backend here, worst case is slightly worse perfs for other backends when built alongside Vulkan backend
2819+
int max_batch = 65535;
2820+
int max_batch_uq = max_batch / uq;
2821+
merge_batch_uq = 1;
2822+
for (int i = max_batch_uq; i > 0; i--) {
2823+
if (batch % i == 0) {
2824+
merge_batch_uq = i;
2825+
break;
2826+
}
27912827
}
2792-
}
27932828

2794-
int max_batch_vp = max_batch / vp;
2795-
merge_batch_vp = 1;
2796-
for (int i = max_batch_vp; i > 0; i--) {
2797-
if (batch % i == 0) {
2798-
merge_batch_vp = i;
2799-
break;
2829+
int max_batch_vp = max_batch / vp;
2830+
merge_batch_vp = 1;
2831+
for (int i = max_batch_vp; i > 0; i--) {
2832+
if (batch % i == 0) {
2833+
merge_batch_vp = i;
2834+
break;
2835+
}
28002836
}
28012837
}
28022838
}
2803-
#endif
28042839

28052840
ggml_tensor* h_split = ggml_reshape_3d(ctx, h, vq, uq * merge_batch_uq, batch / merge_batch_uq);
28062841
if (w2 != NULL) {

0 commit comments

Comments
 (0)