Skip to content

Commit cb5c0ce

Browse files
feat: add generation parameters to PIL image.info, update submodule
1 parent b60214e commit cb5c0ce

File tree

6 files changed

+110
-25
lines changed

6 files changed

+110
-25
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,8 @@ stable_diffusion = StableDiffusion(
267267
)
268268
output = stable_diffusion.generate_image(
269269
prompt="a lovely cat",
270-
width=512, # Must be a multiple of 64
271-
height=512, # Must be a multiple of 64
270+
width=512,
271+
height=512,
272272
progress_callback=callback,
273273
# seed=1337, # Uncomment to set a specific seed (use -1 for a random seed)
274274
)

stable_diffusion_cpp/_internals.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ def __init__(
3939
keep_control_net_on_cpu: bool,
4040
keep_vae_on_cpu: bool,
4141
diffusion_flash_attn: bool,
42+
diffusion_conv_direct: bool,
43+
vae_conv_direct: bool,
4244
chroma_use_dit_mask: bool,
4345
chroma_use_t5_mask: bool,
4446
chroma_t5_mask_pad: int,
@@ -69,6 +71,8 @@ def __init__(
6971
keep_control_net_on_cpu=keep_control_net_on_cpu,
7072
keep_vae_on_cpu=keep_vae_on_cpu,
7173
diffusion_flash_attn=diffusion_flash_attn,
74+
diffusion_conv_direct=diffusion_conv_direct,
75+
vae_conv_direct=vae_conv_direct,
7276
chroma_use_dit_mask=chroma_use_dit_mask,
7377
chroma_use_t5_mask=chroma_use_t5_mask,
7478
chroma_t5_mask_pad=chroma_t5_mask_pad,
@@ -127,10 +131,12 @@ def __init__(
127131
self,
128132
upscaler_path: str,
129133
n_threads: int,
134+
diffusion_conv_direct: bool,
130135
verbose: bool,
131136
):
132137
self.upscaler_path = upscaler_path
133138
self.n_threads = n_threads
139+
self.diffusion_conv_direct = diffusion_conv_direct
134140
self.verbose = verbose
135141
self._exit_stack = ExitStack()
136142

@@ -146,7 +152,11 @@ def __init__(
146152
raise ValueError(f"Upscaler model path does not exist: {upscaler_path}")
147153

148154
# Load the image upscaling model ctx
149-
self.upscaler = sd_cpp.new_upscaler_ctx(upscaler_path.encode("utf-8"), self.n_threads)
155+
self.upscaler = sd_cpp.new_upscaler_ctx(
156+
upscaler_path.encode("utf-8"),
157+
self.n_threads,
158+
self.diffusion_conv_direct,
159+
)
150160

151161
# Check if the model was loaded successfully
152162
if self.upscaler is None:

stable_diffusion_cpp/stable_diffusion.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def __init__(
4040
keep_control_net_on_cpu: bool = False,
4141
keep_vae_on_cpu: bool = False,
4242
diffusion_flash_attn: bool = False,
43+
diffusion_conv_direct: bool = False,
44+
vae_conv_direct: bool = False,
4345
chroma_use_dit_mask: bool = True,
4446
chroma_use_t5_mask: bool = False,
4547
chroma_t5_mask_pad: int = 1,
@@ -80,6 +82,8 @@ def __init__(
8082
keep_control_net_on_cpu: Keep controlnet in CPU (for low vram).
8183
keep_vae_on_cpu: Keep vae in CPU (for low vram).
8284
diffusion_flash_attn: Use flash attention in diffusion model (can reduce memory usage significantly). May lower quality or crash if backend not supported.
85+
diffusion_conv_direct: Use Conv2d direct in the diffusion model. May crash if backend not supported.
86+
vae_conv_direct: Use Conv2d direct in the vae model (should improve performance). May crash if backend not supported.
8387
chroma_use_dit_mask: Use DiT mask for chroma.
8488
chroma_use_t5_mask: Use T5 mask for chroma.
8589
chroma_t5_mask_pad: T5 mask padding size of chroma.
@@ -114,6 +118,8 @@ def __init__(
114118
self.keep_control_net_on_cpu = keep_control_net_on_cpu
115119
self.keep_vae_on_cpu = keep_vae_on_cpu
116120
self.diffusion_flash_attn = diffusion_flash_attn
121+
self.diffusion_conv_direct = diffusion_conv_direct
122+
self.vae_conv_direct = vae_conv_direct
117123
self.chroma_use_dit_mask = chroma_use_dit_mask
118124
self.chroma_use_t5_mask = chroma_use_t5_mask
119125
self.chroma_t5_mask_pad = chroma_t5_mask_pad
@@ -160,6 +166,8 @@ def __init__(
160166
keep_control_net_on_cpu=self.keep_control_net_on_cpu,
161167
keep_vae_on_cpu=self.keep_vae_on_cpu,
162168
diffusion_flash_attn=self.diffusion_flash_attn,
169+
diffusion_conv_direct=self.diffusion_conv_direct,
170+
vae_conv_direct=self.vae_conv_direct,
163171
chroma_use_dit_mask=self.chroma_use_dit_mask,
164172
chroma_use_t5_mask=self.chroma_use_t5_mask,
165173
chroma_t5_mask_pad=self.chroma_t5_mask_pad,
@@ -175,6 +183,7 @@ def __init__(
175183
_UpscalerModel(
176184
upscaler_path=upscaler_path,
177185
n_threads=self.n_threads,
186+
diffusion_conv_direct=self.diffusion_conv_direct,
178187
verbose=self.verbose,
179188
)
180189
)
@@ -276,7 +285,7 @@ def generate_image(
276285

277286
sample_method = validate_and_set_input(sample_method, SAMPLE_METHOD_MAP, "sample_method")
278287

279-
# Ensure dimensions are multiples of 64
288+
# Ensure valid dimensions
280289
width = validate_dimensions(width, "width")
281290
height = validate_dimensions(height, "height")
282291

@@ -407,7 +416,42 @@ def _create_blank_mask_image(width: int, height: int):
407416
)
408417

409418
# Convert the C array of images to a Python list of images
410-
return self._sd_image_t_p_to_images(c_images, batch_count, upscale_factor)
419+
images = self._sd_image_t_p_to_images(c_images, batch_count, upscale_factor)
420+
421+
# Attach metadata safely
422+
for i, image in enumerate(images):
423+
image.info.update(
424+
{
425+
# Generation Parameters
426+
"prompt": prompt,
427+
"negative_prompt": negative_prompt,
428+
"seed": seed + i if batch_count > 1 else seed, # Increment seed for each image in batch
429+
"sample_steps": sample_steps,
430+
"sample_method": sample_method,
431+
"cfg_scale": cfg_scale,
432+
"slg_scale": slg_scale,
433+
"skip_layers": skip_layers,
434+
"skip_layer_start": skip_layer_start,
435+
"skip_layer_end": skip_layer_end,
436+
"guidance": guidance,
437+
"eta": eta,
438+
"width": width,
439+
"height": height,
440+
# Model Context Parameters
441+
"model_path": self.model_path,
442+
"diffusion_model_path": self.diffusion_model_path,
443+
"vae_path": self.vae_path,
444+
"clip_l_path": self.clip_l_path,
445+
"clip_g_path": self.clip_g_path,
446+
"t5xxl_path": self.t5xxl_path,
447+
"taesd_path": self.taesd_path,
448+
"control_net_path": self.control_net_path,
449+
"rng_type": self.rng_type,
450+
"clip_skip": clip_skip,
451+
}
452+
)
453+
454+
return images
411455

412456
# ============================================
413457
# Generate Video
@@ -476,7 +520,7 @@ def generate_video(
476520

477521
# sample_method = validate_and_set_input(sample_method, SAMPLE_METHOD_MAP, "sample_method")
478522

479-
# # Ensure dimensions are multiples of 64
523+
# # Ensure valid dimensions
480524
# width = validate_dimensions(width, "width")
481525
# height = validate_dimensions(height, "height")
482526

@@ -865,10 +909,9 @@ def __del__(self) -> None:
865909

866910

867911
def validate_dimensions(dimension: Union[int, float], attribute_name: str) -> int:
868-
"""Dimensions must be a multiple of 64 otherwise a GGML_ASSERT error is encountered."""
869912
dimension = int(dimension)
870-
if dimension <= 0 or dimension % 64 != 0:
871-
raise ValueError(f"The '{attribute_name}' must be a multiple of 64.")
913+
if dimension <= 0:
914+
raise ValueError(f"The '{attribute_name}' must be greater than 0.")
872915
return dimension
873916

874917

stable_diffusion_cpp/stable_diffusion_cpp.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ class GGMLType(IntEnum):
308308
# ------------ sd_ctx_params_t ------------
309309

310310

311-
# typedef struct { const char* model_path; const char* clip_l_path; const char* clip_g_path; const char* t5xxl_path; const char* diffusion_model_path; const char* vae_path; const char* taesd_path; const char* control_net_path; const char* lora_model_dir; const char* embedding_dir; const char* stacked_id_embed_dir; bool vae_decode_only; bool vae_tiling; bool free_params_immediately; int n_threads; enum sd_type_t wtype; enum rng_type_t rng_type; enum schedule_t schedule; bool keep_clip_on_cpu; bool keep_control_net_on_cpu; bool keep_vae_on_cpu; bool diffusion_flash_attn; bool chroma_use_dit_mask; bool chroma_use_t5_mask; int chroma_t5_mask_pad; } sd_ctx_params_t;
311+
# typedef struct { const char* model_path; const char* clip_l_path; const char* clip_g_path; const char* t5xxl_path; const char* diffusion_model_path; const char* vae_path; const char* taesd_path; const char* control_net_path; const char* lora_model_dir; const char* embedding_dir; const char* stacked_id_embed_dir; bool vae_decode_only; bool vae_tiling; bool free_params_immediately; int n_threads; enum sd_type_t wtype; enum rng_type_t rng_type; enum schedule_t schedule; bool keep_clip_on_cpu; bool keep_control_net_on_cpu; bool keep_vae_on_cpu; bool diffusion_flash_attn; bool diffusion_conv_direct; bool vae_conv_direct; bool chroma_use_dit_mask; bool chroma_use_t5_mask; int chroma_t5_mask_pad; } sd_ctx_params_t;
312312
class sd_ctx_params_t(ctypes.Structure):
313313
_fields_ = [
314314
("model_path", ctypes.c_char_p),
@@ -333,6 +333,8 @@ class sd_ctx_params_t(ctypes.Structure):
333333
("keep_control_net_on_cpu", ctypes.c_bool),
334334
("keep_vae_on_cpu", ctypes.c_bool),
335335
("diffusion_flash_attn", ctypes.c_bool),
336+
("diffusion_conv_direct", ctypes.c_bool),
337+
("vae_conv_direct", ctypes.c_bool),
336338
("chroma_use_dit_mask", ctypes.c_bool),
337339
("chroma_use_t5_mask", ctypes.c_bool),
338340
("chroma_t5_mask_pad", ctypes.c_int),
@@ -532,18 +534,20 @@ class upscaler_ctx_t(ctypes.Structure):
532534
# ------------ new_upscaler_ctx ------------
533535

534536

535-
# SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path, int n_threads);
537+
# SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path, int n_threads, bool direct);
536538
@ctypes_function(
537539
"new_upscaler_ctx",
538540
[
539541
ctypes.c_char_p, # esrgan_path
540542
ctypes.c_int, # n_threads
543+
ctypes.c_bool, # direct
541544
],
542545
upscaler_ctx_t_p_ctypes,
543546
)
544547
def new_upscaler_ctx(
545548
esrgan_path: bytes,
546549
n_threads: int,
550+
direct: bool,
547551
/,
548552
) -> upscaler_ctx_t_p: ...
549553

tests/test_txt2img.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,19 @@
99

1010
LORA_DIR = "C:\\stable-diffusion\\loras"
1111

12+
PROMPTS = [
13+
# {"add": "_lora", "prompt": "a lovely cat <lora:realism_lora:1>"}, # With LORA
14+
# {"add": "", "prompt": "a lovely cat"}, # Without LORA
15+
# {"add": "_lora", "prompt": "a cute cat glass statue <lora:glass_statue_v1:1>"}, # With LORA
16+
{"add": "", "prompt": "a cute cat glass statue"}, # Without LORA
17+
]
18+
STEPS = 4
19+
20+
OUTPUT_DIR = "tests/outputs"
21+
if not os.path.exists(OUTPUT_DIR):
22+
os.makedirs(OUTPUT_DIR)
23+
24+
1225
stable_diffusion = StableDiffusion(
1326
model_path=MODEL_PATH,
1427
lora_model_dir=LORA_DIR,
@@ -20,22 +33,11 @@ def callback(step: int, steps: int, time: float):
2033

2134

2235
try:
23-
prompts = [
24-
# {"add": "_lora", "prompt": "a lovely cat <lora:realism_lora:1>"}, # With LORA
25-
# {"add": "", "prompt": "a lovely cat"}, # Without LORA
26-
{"add": "_lora", "prompt": "a cute cat glass statue <lora:glass_statue_v1:1>"}, # With LORA
27-
{"add": "", "prompt": "a cute cat glass statue"}, # Without LORA
28-
]
29-
30-
OUTPUT_DIR = "tests/outputs"
31-
if not os.path.exists(OUTPUT_DIR):
32-
os.makedirs(OUTPUT_DIR)
33-
34-
for prompt in prompts:
36+
for prompt in PROMPTS:
3537
# Generate images
3638
images = stable_diffusion.generate_image(
3739
prompt=prompt["prompt"],
38-
sample_steps=4,
40+
sample_steps=STEPS,
3941
progress_callback=callback,
4042
)
4143

@@ -46,3 +48,29 @@ def callback(step: int, steps: int, time: float):
4648
except Exception as e:
4749
traceback.print_exc()
4850
print("Test - txt2img failed: ", e)
51+
52+
# # ======== C++ CLI ========
53+
54+
# import subprocess
55+
56+
# stable_diffusion = None # Clear model
57+
58+
# SD_CPP_CLI = "C:\\Users\\Willi\\Documents\\GitHub\\stable-diffusion.cpp\\build\\bin\\sd"
59+
60+
# for prompt in PROMPTS:
61+
# cli_cmd = [
62+
# SD_CPP_CLI,
63+
# "--model",
64+
# MODEL_PATH,
65+
# "--lora-model-dir",
66+
# LORA_DIR,
67+
# "--prompt",
68+
# prompt["prompt"],
69+
# "--steps",
70+
# str(STEPS),
71+
# "--output",
72+
# f"{OUTPUT_DIR}/txt2img{prompt['add']}_cli.png",
73+
# "-v",
74+
# ]
75+
# print(" ".join(cli_cmd))
76+
# subprocess.run(cli_cmd, check=True)

0 commit comments

Comments
 (0)