
Commit 6075bd3

feat: add WAN video generation support, improve test suite
1 parent 4e74069 commit 6075bd3

25 files changed: +1228 −821 lines

README.md

Lines changed: 80 additions & 9 deletions
@@ -9,7 +9,7 @@ Simple Python bindings for **@leejet's** [`stable-diffusion.cpp`](https://github
 This package provides:
 
 - Low-level access to C API via `ctypes` interface.
-- High-level Python API for Stable Diffusion and FLUX image generation.
+- High-level Python API for Stable Diffusion, FLUX and Wan image/video generation.
 
 ## Installation
 
@@ -97,7 +97,9 @@ This provides BLAS acceleration using the ROCm cores of your AMD GPU. Make sure
 Windows users refer to [docs/hipBLAS_on_Windows.md](docs%2FhipBLAS_on_Windows.md) for a comprehensive guide and troubleshooting tips.
 
 ```bash
-CMAKE_ARGS="-G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS=gfx1101 -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON" pip install stable-diffusion-cpp-python
+export GFX_NAME=$(rocminfo | grep -m 1 -E "gfx[^0]{1}" | sed -e 's/ *Name: *//' | awk '{$1=$1; print}' || echo "rocminfo missing")
+echo $GFX_NAME
+CMAKE_ARGS="-G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS=$GFX_NAME -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON" pip install stable-diffusion-cpp-python
 ```
 
 </details>
@@ -273,6 +275,9 @@ output = stable_diffusion.generate_image(
     # seed=1337, # Uncomment to set a specific seed (use -1 for a random seed)
 )
 output[0].save("output.png") # Output returned as list of PIL Images
+
+# Model and generation parameters accessible via .info
+print(output[0].info)
 ```
 
 #### <u>With LoRA (Stable Diffusion)</u>
@@ -323,7 +328,6 @@ stable_diffusion = StableDiffusion(
 )
 output = stable_diffusion.generate_image(
     prompt="a lovely cat holding a sign says 'flux.cpp'",
-    sample_steps=4,
     cfg_scale=1.0, # a cfg_scale of 1 is recommended for FLUX
     sample_method="euler", # euler is recommended for FLUX
 )
@@ -369,22 +373,22 @@ output = stable_diffusion.generate_image(
 
 Download the weights from the links below:
 
-- Preconverted gguf model from [silveroxides/Chroma-GGUF](https://huggingface.co/silveroxides/Chroma-GGUF)
-- Otherwise, download chroma's safetensors from [lodestones/Chroma](https://huggingface.co/lodestones/Chroma)
+- Preconverted gguf model from [silveroxides/Chroma1-Flash-GGUF](https://huggingface.co/silveroxides/Chroma1-Flash-GGUF), [silveroxides/Chroma1-Base-GGUF](https://huggingface.co/silveroxides/Chroma1-Base-GGUF) or [silveroxides/Chroma1-HD-GGUF](https://huggingface.co/silveroxides/Chroma1-HD-GGUF) ([silveroxides/Chroma-GGUF](https://huggingface.co/silveroxides/Chroma-GGUF) is DEPRECATED)
+- Otherwise, download chroma's safetensors from [lodestones/Chroma1-Flash](https://huggingface.co/lodestones/Chroma1-Flash), [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base) or [lodestones/Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD) ([lodestones/Chroma](https://huggingface.co/lodestones/Chroma) is DEPRECATED)
 - The `vae` and `t5xxl` models are the same as for FLUX image generation linked above (`clip_l` not required).
 
 ```python
 from stable_diffusion_cpp import StableDiffusion
 
 stable_diffusion = StableDiffusion(
-    diffusion_model_path="../models/chroma-unlocked-v40-Q4_0.gguf", # In place of model_path
+    diffusion_model_path="../models/Chroma1-HD-Flash-Q4_0.gguf", # In place of model_path
     t5xxl_path="../models/t5xxl_fp16.safetensors",
     vae_path="../models/ae.safetensors",
     vae_decode_only=True, # Can be True if we are not generating image to image
+    chroma_use_dit_mask=False,
 )
 output = stable_diffusion.generate_image(
     prompt="a lovely cat holding a sign says 'chroma.cpp'",
-    sample_steps=4,
     cfg_scale=4.0, # a cfg_scale of 4 is recommended for Chroma
     sample_method="euler", # euler is recommended for FLUX
 )
@@ -510,16 +514,83 @@ An `id_embeds.safetensors` file will be generated in `input_images_dir`.
 
 ---
 
+### <u>WAN Video Generation</u>
+
+See [stable-diffusion.cpp WAN download weights](https://github.com/leejet/stable-diffusion.cpp/blob/master/docs/wan.md#download-weights) for a complete list of WAN models.
+
+```python
+from stable_diffusion_cpp import StableDiffusion
+
+stable_diffusion = StableDiffusion(
+    diffusion_model_path="../models/wan2.1_t2v_1.3B_fp16.safetensors", # In place of model_path
+    t5xxl_path="../models/umt5-xxl-encoder-Q8_0.gguf",
+    vae_path="../models/wan_2.1_vae.safetensors",
+    flow_shift=3.0,
+)
+
+output = stable_diffusion.generate_video(
+    prompt="a cute dog jumping",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部, 畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",  # Wan's standard Chinese negative prompt (roughly: garish colors, overexposure, static, blurry detail, worst/low quality, JPEG artifacts, deformed anatomy, extra fingers, motionless frames, cluttered background, etc.)
+    height=832,
+    width=480,
+    cfg_scale=6.0,
+    sample_method="euler",
+    video_frames=33,
+) # Output is a list of PIL Images (video frames)
+```
+
+As the output is simply a list of images (video frames), you can convert it into a video using any library you prefer. The example below uses `ffmpeg-python`; alternatively, libraries such as **OpenCV** or **MoviePy** can also be used (see the OpenCV sketch after this diff).
+
+> **Note**
+>
+> - You'll require the **Python bindings for FFmpeg**, `ffmpeg-python` (`pip install ffmpeg-python`), in addition to an **FFmpeg installation on your system**, accessible in your PATH. Check with `ffmpeg -version`.
+
+```python
+from typing import List
+from PIL import Image
+import numpy as np
+import ffmpeg
+
+def save_video_ffmpeg(frames: List[Image.Image], fps: int, out_path: str) -> None:
+    if not frames:
+        raise ValueError("No frames provided")
+
+    width, height = frames[0].size
+
+    # Concatenate frames into raw RGB bytes
+    raw_bytes = b"".join(np.array(frame.convert("RGB"), dtype=np.uint8).tobytes() for frame in frames)
+    (
+        ffmpeg.input(
+            "pipe:",
+            format="rawvideo",
+            pix_fmt="rgb24",
+            s=f"{width}x{height}",
+            r=fps,
+        )
+        .output(
+            out_path,
+            vcodec="libx264",
+            pix_fmt="yuv420p",
+            r=fps,
+            movflags="+faststart",
+        )
+        .overwrite_output()
+        .run(input=raw_bytes)
+    )
+
+save_video_ffmpeg(output, fps=16, out_path="output.mp4")
+```
+
 ### <u>Listing GGML model and RNG types, schedulers and sample methods</u>
 
 Access the GGML model and RNG types, schedulers, and sample methods via the following maps:
 
 ```python
-from stable_diffusion_cpp import GGML_TYPE_MAP, RNG_TYPE_MAP, SCHEDULE_MAP, SAMPLE_METHOD_MAP
+from stable_diffusion_cpp import GGML_TYPE_MAP, RNG_TYPE_MAP, SCHEDULER_MAP, SAMPLE_METHOD_MAP
 
 print("GGML model types:", list(GGML_TYPE_MAP))
 print("RNG types:", list(RNG_TYPE_MAP))
-print("Schedulers:", list(SCHEDULE_MAP))
+print("Schedulers:", list(SCHEDULER_MAP))
 print("Sample methods:", list(SAMPLE_METHOD_MAP))
 ```
 
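The README hunk above names **OpenCV** as an alternative to `ffmpeg-python` for writing the frames out. A minimal sketch of that route (not part of this commit), assuming `opencv-python` is installed and `output` is the frame list returned by `generate_video`:

```python
from typing import List

import numpy as np
import cv2  # pip install opencv-python
from PIL import Image


def save_video_opencv(frames: List[Image.Image], fps: int, out_path: str) -> None:
    if not frames:
        raise ValueError("No frames provided")

    width, height = frames[0].size
    # VideoWriter expects BGR uint8 frames of a fixed size
    writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    try:
        for frame in frames:
            rgb = np.array(frame.convert("RGB"), dtype=np.uint8)
            writer.write(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()


save_video_opencv(output, fps=16, out_path="output.mp4")
```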

pyproject.toml

Lines changed: 11 additions & 0 deletions
@@ -30,6 +30,14 @@ classifiers = [
     "Programming Language :: Python :: 3.13",
 ]
 
+[project.optional-dependencies]
+dev = [
+    "black>=24.8.0",
+    "isort>=5.13.2",
+    "pytest>=7.4.4",
+    "ffmpeg-python>=0.2.0",
+]
+
 [tool.scikit-build]
 wheel.packages = ["stable_diffusion_cpp"]
 cmake.verbose = true
@@ -54,3 +62,6 @@ profile = "black"
 known_local_folder = ["stable_diffusion_cpp"]
 remove_redundant_aliases = true
 length_sort = true
+
+[tool.pytest.ini_options]
+testpaths = "tests"
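The new `dev` extra and `[tool.pytest.ini_options]` table make the test-suite setup a one-liner. A sketch, assuming a local clone of the repository:

```bash
# Editable install pulling in the new dev extras (black, isort, pytest, ffmpeg-python)
pip install -e ".[dev]"
# pytest now discovers the suite via testpaths = "tests"
pytest
```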

stable_diffusion_cpp/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@
 
 # isort: on
 
-__version__ = "0.3.2"
+__version__ = "0.3.3"

stable_diffusion_cpp/_internals.py

Lines changed: 23 additions & 14 deletions
@@ -5,9 +5,9 @@
 import stable_diffusion_cpp.stable_diffusion_cpp as sd_cpp
 from ._utils import suppress_stdout_stderr
 
-# ============================================
+# ===========================================
 # Stable Diffusion Model
-# ============================================
+# ===========================================
 
 
 class _StableDiffusionModel:
@@ -21,8 +21,10 @@ def __init__(
         model_path: str,
         clip_l_path: str,
         clip_g_path: str,
+        clip_vision_path: str,
         t5xxl_path: str,
         diffusion_model_path: str,
+        high_noise_diffusion_model_path: str,
         vae_path: str,
         taesd_path: str,
         control_net_path: str,
@@ -34,7 +36,7 @@ def __init__(
         n_threads: int,
         wtype: int,
         rng_type: int,
-        schedule: int,
+        offload_params_to_cpu: bool,
         keep_clip_on_cpu: bool,
         keep_control_net_on_cpu: bool,
         keep_vae_on_cpu: bool,
@@ -44,6 +46,7 @@ def __init__(
         chroma_use_dit_mask: bool,
         chroma_use_t5_mask: bool,
         chroma_t5_mask_pad: int,
+        flow_shift: int,
         verbose: bool,
     ):
         self._exit_stack = ExitStack()
@@ -52,8 +55,10 @@ def __init__(
             model_path=model_path.encode("utf-8"),
             clip_l_path=clip_l_path.encode("utf-8"),
             clip_g_path=clip_g_path.encode("utf-8"),
+            clip_vision_path=clip_vision_path.encode("utf-8"),
             t5xxl_path=t5xxl_path.encode("utf-8"),
             diffusion_model_path=diffusion_model_path.encode("utf-8"),
+            high_noise_diffusion_model_path=high_noise_diffusion_model_path.encode("utf-8"),
             vae_path=vae_path.encode("utf-8"),
             taesd_path=taesd_path.encode("utf-8"),
             control_net_path=control_net_path.encode("utf-8"),
@@ -66,7 +71,7 @@ def __init__(
             n_threads=n_threads,
             wtype=wtype,
             rng_type=rng_type,
-            schedule=schedule,
+            offload_params_to_cpu=offload_params_to_cpu,
             keep_clip_on_cpu=keep_clip_on_cpu,
             keep_control_net_on_cpu=keep_control_net_on_cpu,
             keep_vae_on_cpu=keep_vae_on_cpu,
@@ -76,6 +81,7 @@ def __init__(
             chroma_use_dit_mask=chroma_use_dit_mask,
             chroma_use_t5_mask=chroma_use_t5_mask,
             chroma_t5_mask_pad=chroma_t5_mask_pad,
+            flow_shift=flow_shift,
         )
 
         # Load the free_sd_ctx function
@@ -84,11 +90,11 @@ def __init__(
         # Load the model from the file if the path is provided
         if model_path:
             if not os.path.exists(model_path):
-                raise ValueError(f"Model path does not exist: {model_path}")
+                raise ValueError(f"Model path does not exist: '{model_path}'")
 
         if diffusion_model_path:
             if not os.path.exists(diffusion_model_path):
-                raise ValueError(f"Diffusion model path does not exist: {diffusion_model_path}")
+                raise ValueError(f"Diffusion model path does not exist: '{diffusion_model_path}'")
 
         if model_path or diffusion_model_path:
             with suppress_stdout_stderr(disable=verbose):
@@ -97,7 +103,7 @@ def __init__(
 
         # Check if the model was loaded successfully
         if self.model is None:
-            raise ValueError(f"Failed to load model from file: {model_path}")
+            raise ValueError(f"Failed to load model from file: '{model_path}'")
 
         def free_ctx():
             """Free the model from memory."""
@@ -116,9 +122,9 @@ def __del__(self):
         self.close()
 
 
-# ============================================
+# ===========================================
 # Upscaler Model
-# ============================================
+# ===========================================
 
 
 class _UpscalerModel:
@@ -130,13 +136,15 @@ class _UpscalerModel:
     def __init__(
         self,
         upscaler_path: str,
+        offload_params_to_cpu: bool,
+        direct: bool,
         n_threads: int,
-        diffusion_conv_direct: bool,
         verbose: bool,
     ):
         self.upscaler_path = upscaler_path
+        self.offload_params_to_cpu = offload_params_to_cpu
+        self.direct = direct
         self.n_threads = n_threads
-        self.diffusion_conv_direct = diffusion_conv_direct
         self.verbose = verbose
         self._exit_stack = ExitStack()
 
@@ -149,18 +157,19 @@ def __init__(
         self._free_upscaler_ctx = sd_cpp._lib.free_upscaler_ctx
 
         if not os.path.exists(upscaler_path):
-            raise ValueError(f"Upscaler model path does not exist: {upscaler_path}")
+            raise ValueError(f"Upscaler model path does not exist: '{upscaler_path}'")
 
         # Load the image upscaling model ctx
         self.upscaler = sd_cpp.new_upscaler_ctx(
             upscaler_path.encode("utf-8"),
+            self.offload_params_to_cpu,
+            self.direct,
             self.n_threads,
-            self.diffusion_conv_direct,
         )
 
         # Check if the model was loaded successfully
         if self.upscaler is None:
-            raise ValueError(f"Failed to load upscaler model from file: {upscaler_path}")
+            raise ValueError(f"Failed to load upscaler model from file: '{upscaler_path}'")
 
         def free_ctx():
             """Free the model from memory."""

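For orientation: the parameters threaded through `_StableDiffusionModel` here back the high-level API shown in the README. A hedged sketch of how the new `flow_shift` and `offload_params_to_cpu` knobs surface there (names taken from this diff and the README hunk; the exact high-level constructor signature is assumed, not shown in this commit):

```python
from stable_diffusion_cpp import StableDiffusion

# flow_shift and offload_params_to_cpu are forwarded through
# _StableDiffusionModel into the sd.cpp context (see diff above)
stable_diffusion = StableDiffusion(
    diffusion_model_path="../models/wan2.1_t2v_1.3B_fp16.safetensors",
    t5xxl_path="../models/umt5-xxl-encoder-Q8_0.gguf",
    vae_path="../models/wan_2.1_vae.safetensors",
    flow_shift=3.0,  # timestep shift for flow-matching models such as Wan
    offload_params_to_cpu=True,  # assumed to default to False; keeps weights in RAM until needed
)
```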