forked from modelscope/DiffSynth-Studio
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathflux_text_to_image_low_vram.py
More file actions
51 lines (43 loc) · 1.82 KB
/
flux_text_to_image_low_vram.py
File metadata and controls
51 lines (43 loc) · 1.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from diffsynth import download_models, ModelManager, FluxImagePipeline
# Fetch the FLUX.1-dev checkpoints through diffsynth before anything is loaded.
# NOTE(review): the paths below assume the files land under
# models/FLUX/FLUX.1-dev -- confirm against download_models.
download_models(["FLUX.1-dev"])

# To reduce VRAM required, the models are loaded to RAM (device="cpu").
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")

# These checkpoints are loaded without a dtype override.
# NOTE(review): presumably they pick up the manager's bfloat16 setting -- confirm.
text_encoder_and_ae_paths = [
    "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "models/FLUX/FLUX.1-dev/text_encoder_2",
    "models/FLUX/FLUX.1-dev/ae.safetensors",
]
model_manager.load_models(text_encoder_and_ae_paths)

# Load the DiT model in FP8 format.
dit_paths = ["models/FLUX/FLUX.1-dev/flux1-dev.safetensors"]
model_manager.load_models(dit_paths, torch_dtype=torch.float8_e4m3fn)

# Build the pipeline for CUDA from the models held by the manager, then
# switch on CPU offload.
pipe = FluxImagePipeline.from_model_manager(model_manager, device="cuda")
pipe.enable_cpu_offload()
# NOTE(review): presumably required to run the FP8-loaded DiT -- confirm
# against the diffsynth implementation of quantize().
pipe.dit.quantize()
# Positive prompt: a comma-separated tag list followed by one descriptive
# sentence. The pieces are joined by implicit string concatenation, so the
# runtime value is a single line of text.
prompt = (
    "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, "
    "blue eyes, blue dress, medium breasts, dress, underwater, air bubble, "
    "floating hair, refraction, portrait. "
    "The girl's flowing silver hair shimmers with every color of the rainbow "
    "and cascades down, merging with the floating flora around her."
)

# Negative prompt (note the trailing comma is part of the text).
# NOTE(review): "Aissist" looks like a typo but is kept verbatim -- confirm
# before changing it, since editing alters the text sent to the pipeline.
negative_prompt = (
    "worst quality, low quality, monochrome, zombie, "
    "interlocked fingers, Aissist, cleavage, nsfw,"
)
# Baseline render with classifier-free guidance disabled (consistent with the
# original implementation of FLUX.1): only the positive prompt is passed,
# with no negative_prompt and no cfg_scale.
torch.manual_seed(9)  # fixed seed so the run is repeatable
image = pipe(
    prompt=prompt,
    num_inference_steps=50,
    embedded_guidance=3.5,
)
image.save("image_1024.jpg")
# Render with classifier-free guidance enabled: the negative prompt and a
# cfg_scale are supplied alongside the embedded guidance value.
torch.manual_seed(9)  # fixed seed so the run is repeatable
image = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=50,
    cfg_scale=2.0,
    embedded_guidance=3.5,
)
image.save("image_1024_cfg.jpg")
# Highres-fix: resize the most recent `image` to 2048x2048 and feed it back
# through the pipeline as `input_image`, using denoising_strength=0.6 and
# tiled=True at the larger resolution.
highres_side = 2048
torch.manual_seed(10)
upscaled_input = image.resize((highres_side, highres_side))
image = pipe(
    prompt=prompt,
    num_inference_steps=50,
    embedded_guidance=3.5,
    input_image=upscaled_input,
    height=highres_side,
    width=highres_side,
    denoising_strength=0.6,
    tiled=True,
)
image.save("image_2048_highres.jpg")