Commit b60214e

refactor: unify image/text functions into generate_image, update submodules
1 parent 2de4af3 commit b60214e

21 files changed (+827 / -894 lines)
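
The headline change is the API unification: the separate txt_to_img, img_to_img, and edit entry points are folded into a single generate_image method. A minimal sketch of the new call shape, using only parameter names that appear in the README changes below (paths are placeholders):

```python
from stable_diffusion_cpp import StableDiffusion

stable_diffusion = StableDiffusion(model_path="../models/v1-5-pruned-emaonly.safetensors")

# Text to image (previously txt_to_img)
output = stable_diffusion.generate_image(prompt="a lovely cat", width=512, height=512)

# Image to image (previously img_to_img): the source image is now passed as init_image
output = stable_diffusion.generate_image(prompt="blue eyes", init_image="../input.png", strength=0.4)
```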

README.md

Lines changed: 22 additions & 13 deletions
@@ -265,7 +265,7 @@ stable_diffusion = StableDiffusion(
     model_path="../models/v1-5-pruned-emaonly.safetensors",
     # wtype="default", # Weight type (e.g. "q8_0", "f16", etc) (The "default" setting is automatically applied and determines the weight type of a model file)
 )
-output = stable_diffusion.txt_to_img(
+output = stable_diffusion.generate_image(
     prompt="a lovely cat",
     width=512, # Must be a multiple of 64
     height=512, # Must be a multiple of 64
@@ -291,7 +291,7 @@ stable_diffusion = StableDiffusion(
     model_path="../models/v1-5-pruned-emaonly.safetensors",
     lora_model_dir="../models/", # This should point to folder where LoRA weights are stored (not an individual file)
 )
-output = stable_diffusion.txt_to_img(
+output = stable_diffusion.generate_image(
     prompt="a lovely cat<lora:marblesh:1>",
 )
 ```
@@ -319,9 +319,9 @@ stable_diffusion = StableDiffusion(
     clip_l_path="../models/clip_l.safetensors",
     t5xxl_path="../models/t5xxl_fp16.safetensors",
     vae_path="../models/ae.safetensors",
-    vae_decode_only=True, # Can be True if we dont use img_to_img
+    vae_decode_only=True, # Can be True if not generating image to image
 )
-output = stable_diffusion.txt_to_img(
+output = stable_diffusion.generate_image(
     prompt="a lovely cat holding a sign says 'flux.cpp'",
     sample_steps=4,
     cfg_scale=1.0, # a cfg_scale of 1 is recommended for FLUX
@@ -357,7 +357,7 @@ stable_diffusion = StableDiffusion(
     vae_path="../models/ae.safetensors",
     vae_decode_only=False, # Must be False for FLUX Kontext
 )
-output = stable_diffusion.edit(
+output = stable_diffusion.generate_image(
     prompt="make the cat blue",
     images=["input.png"],
     cfg_scale=1.0, # a cfg_scale of 1 is recommended for FLUX
@@ -380,9 +380,9 @@ stable_diffusion = StableDiffusion(
     diffusion_model_path="../models/chroma-unlocked-v40-Q4_0.gguf", # In place of model_path
     t5xxl_path="../models/t5xxl_fp16.safetensors",
     vae_path="../models/ae.safetensors",
-    vae_decode_only=True, # Can be True if we dont use img_to_img
+    vae_decode_only=True, # Can be True if we are not generating image to image
 )
-output = stable_diffusion.txt_to_img(
+output = stable_diffusion.generate_image(
     prompt="a lovely cat holding a sign says 'chroma.cpp'",
     sample_steps=4,
     cfg_scale=4.0, # a cfg_scale of 4 is recommended for Chroma
@@ -410,7 +410,7 @@ stable_diffusion = StableDiffusion(
     clip_g_path="../models/clip_g.safetensors",
     t5xxl_path="../models/t5xxl_fp16.safetensors",
 )
-output = stable_diffusion.txt_to_img(
+output = stable_diffusion.generate_image(
     prompt="a lovely cat holding a sign says 'Stable diffusion 3.5 Large'",
     height=1024,
     width=1024,
@@ -432,9 +432,9 @@ INPUT_IMAGE = "../input.png"
 
 stable_diffusion = StableDiffusion(model_path="../models/v1-5-pruned-emaonly.safetensors")
 
-output = stable_diffusion.img_to_img(
+output = stable_diffusion.generate_image(
     prompt="blue eyes",
-    image=INPUT_IMAGE, # Note: The input image will be automatically resized to the match the width and height arguments (default: 512x512)
+    init_image=INPUT_IMAGE, # Note: The input image will be automatically resized to the match the width and height arguments (default: 512x512)
     strength=0.4,
 )
 ```
@@ -447,9 +447,9 @@ from stable_diffusion_cpp import StableDiffusion
 # Note: Inpainting with a base model gives poor results. A model fine-tuned for inpainting is recommended.
 stable_diffusion = StableDiffusion(model_path="../models/v1-5-pruned-emaonly.safetensors")
 
-output = stable_diffusion.img_to_img(
+output = stable_diffusion.generate_image(
     prompt="blue eyes",
-    image="../input.png",
+    init_image="../input.png",
     mask_image="../mask.png", # A grayscale image where 0 is masked and 255 is unmasked
     strength=0.4,
 )
@@ -478,7 +478,7 @@ stable_diffusion = StableDiffusion(
     # keep_vae_on_cpu=True, # If on low memory GPUs (<= 8GB), setting this to True is recommended to get artifact free images
 )
 
-output = stable_diffusion.txt_to_img(
+output = stable_diffusion.generate_image(
     cfg_scale=5.0, # a cfg_scale of 5.0 is recommended for PhotoMaker
     height=1024,
     width=1024,
@@ -553,6 +553,15 @@ c_image = sd_cpp.sd_image_t(
         ctypes.POINTER(ctypes.c_uint8),
     ),
 ) # Create a new C sd_image_t
+
+# Convert a model from safetensors to gguf format
+sd_cpp.convert(
+    "../models/v1-5-pruned-emaonly.safetensors".encode("utf-8"), # input_path
+    "".encode("utf-8"), # vae_path
+    "../models/v1-5-pruned-emaonly.gguf".encode("utf-8"), # output_path
+    sd_cpp.GGMLType.SD_TYPE_Q8_0, # output_type
+    "".encode("utf-8"), # tensor_type_rules
+)
 ```
 
## Development
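
The new sd_cpp.convert binding added at the end of this README section writes a quantized gguf copy of a checkpoint. Assuming the converted file loads through model_path just like the original safetensors (an assumption, not verified here), usage would look like:

```python
from stable_diffusion_cpp import StableDiffusion

# Assumption: the gguf produced by sd_cpp.convert above is a drop-in replacement for the original checkpoint
stable_diffusion = StableDiffusion(model_path="../models/v1-5-pruned-emaonly.gguf")
output = stable_diffusion.generate_image(prompt="a lovely cat")
```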

stable_diffusion_cpp/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@
 
 # isort: on
 
-__version__ = "0.3.0"
+__version__ = "0.3.1"

stable_diffusion_cpp/_internals.py

Lines changed: 32 additions & 58 deletions
@@ -1,8 +1,8 @@
 import os
+import ctypes
 from contextlib import ExitStack
 
 import stable_diffusion_cpp.stable_diffusion_cpp as sd_cpp
-
 from ._utils import suppress_stdout_stderr
 
 # ============================================
@@ -27,7 +27,7 @@ def __init__(
         taesd_path: str,
         control_net_path: str,
         lora_model_dir: str,
-        embed_dir: str,
+        embedding_dir: str,
         stacked_id_embed_dir: str,
         vae_decode_only: bool,
         vae_tiling: bool,
@@ -36,43 +36,43 @@ def __init__(
         rng_type: int,
         schedule: int,
         keep_clip_on_cpu: bool,
-        keep_control_net_cpu: bool,
+        keep_control_net_on_cpu: bool,
         keep_vae_on_cpu: bool,
         diffusion_flash_attn: bool,
         chroma_use_dit_mask: bool,
         chroma_use_t5_mask: bool,
         chroma_t5_mask_pad: int,
         verbose: bool,
     ):
-        self.model_path = model_path
-        self.clip_l_path = clip_l_path
-        self.clip_g_path = clip_g_path
-        self.t5xxl_path = t5xxl_path
-        self.diffusion_model_path = diffusion_model_path
-        self.vae_path = vae_path
-        self.taesd_path = taesd_path
-        self.control_net_path = control_net_path
-        self.lora_model_dir = lora_model_dir
-        self.embed_dir = embed_dir
-        self.stacked_id_embed_dir = stacked_id_embed_dir
-        self.vae_decode_only = vae_decode_only
-        self.vae_tiling = vae_tiling
-        self.n_threads = n_threads
-        self.wtype = wtype
-        self.rng_type = rng_type
-        self.schedule = schedule
-        self.keep_clip_on_cpu = keep_clip_on_cpu
-        self.keep_control_net_cpu = keep_control_net_cpu
-        self.keep_vae_on_cpu = keep_vae_on_cpu
-        self.diffusion_flash_attn = diffusion_flash_attn
-        self.chroma_use_dit_mask = chroma_use_dit_mask
-        self.chroma_use_t5_mask = chroma_use_t5_mask
-        self.chroma_t5_mask_pad = chroma_t5_mask_pad
-        self.verbose = verbose
-
         self._exit_stack = ExitStack()
-
         self.model = None
+        self.params = sd_cpp.sd_ctx_params_t(
+            model_path=model_path.encode("utf-8"),
+            clip_l_path=clip_l_path.encode("utf-8"),
+            clip_g_path=clip_g_path.encode("utf-8"),
+            t5xxl_path=t5xxl_path.encode("utf-8"),
+            diffusion_model_path=diffusion_model_path.encode("utf-8"),
+            vae_path=vae_path.encode("utf-8"),
+            taesd_path=taesd_path.encode("utf-8"),
+            control_net_path=control_net_path.encode("utf-8"),
+            lora_model_dir=lora_model_dir.encode("utf-8"),
+            embedding_dir=embedding_dir.encode("utf-8"),
+            stacked_id_embed_dir=stacked_id_embed_dir.encode("utf-8"),
+            vae_decode_only=vae_decode_only,
+            vae_tiling=vae_tiling,
+            free_params_immediately=False, # Don't unload model
+            n_threads=n_threads,
+            wtype=wtype,
+            rng_type=rng_type,
+            schedule=schedule,
+            keep_clip_on_cpu=keep_clip_on_cpu,
+            keep_control_net_on_cpu=keep_control_net_on_cpu,
+            keep_vae_on_cpu=keep_vae_on_cpu,
+            diffusion_flash_attn=diffusion_flash_attn,
+            chroma_use_dit_mask=chroma_use_dit_mask,
+            chroma_use_t5_mask=chroma_use_t5_mask,
+            chroma_t5_mask_pad=chroma_t5_mask_pad,
+        )
 
         # Load the free_sd_ctx function
         self._free_sd_ctx = sd_cpp._lib.free_sd_ctx
@@ -88,34 +88,8 @@ def __init__(
 
         if model_path or diffusion_model_path:
             with suppress_stdout_stderr(disable=verbose):
-                # Load the Stable Diffusion model ctx
-                self.model = sd_cpp.new_sd_ctx(
-                    self.model_path.encode("utf-8"),
-                    self.clip_l_path.encode("utf-8"),
-                    self.clip_g_path.encode("utf-8"),
-                    self.t5xxl_path.encode("utf-8"),
-                    self.diffusion_model_path.encode("utf-8"),
-                    self.vae_path.encode("utf-8"),
-                    self.taesd_path.encode("utf-8"),
-                    self.control_net_path.encode("utf-8"),
-                    self.lora_model_dir.encode("utf-8"),
-                    self.embed_dir.encode("utf-8"),
-                    self.stacked_id_embed_dir.encode("utf-8"),
-                    self.vae_decode_only,
-                    self.vae_tiling,
-                    False, # Free params immediately (unload model)
-                    self.n_threads,
-                    self.wtype,
-                    self.rng_type,
-                    self.schedule,
-                    self.keep_clip_on_cpu,
-                    self.keep_control_net_cpu,
-                    self.keep_vae_on_cpu,
-                    self.diffusion_flash_attn,
-                    self.chroma_use_dit_mask,
-                    self.chroma_use_t5_mask,
-                    self.chroma_t5_mask_pad,
-                )
+                # Call function with a pointer to params
+                self.model = sd_cpp.new_sd_ctx(ctypes.byref(self.params))
 
         # Check if the model was loaded successfully
         if self.model is None:
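
The loading path is now a single struct pointer: all constructor options are packed into sd_ctx_params_t and new_sd_ctx receives ctypes.byref(self.params) instead of a 25-argument positional call. A self-contained sketch of the general ctypes pattern (the struct and C function here are illustrative stand-ins, not the real stable_diffusion_cpp bindings):

```python
import ctypes

# Illustrative parameter struct (not the actual sd_ctx_params_t layout)
class ExampleParams(ctypes.Structure):
    _fields_ = [
        ("model_path", ctypes.c_char_p),
        ("n_threads", ctypes.c_int),
        ("vae_decode_only", ctypes.c_bool),
    ]

# Fields are set by keyword, so each value is labeled at the call site and new fields
# can be appended to the struct without silently shifting every other argument,
# which is the main hazard of a long positional call
params = ExampleParams(
    model_path="../models/model.safetensors".encode("utf-8"),
    n_threads=4,
    vae_decode_only=True,
)

# A C function declared as `void *new_ctx(const example_params *p)` would then be called
# with a pointer to the struct, mirroring new_sd_ctx(ctypes.byref(self.params)):
#   ctx = lib.new_ctx(ctypes.byref(params))
print(params.model_path, params.n_threads, params.vae_decode_only)
```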
