# run_qwen_image_edit.py
# Example script from a fork of Eco-Sphere/cache-dit.
import os
import sys
sys.path.append("..")
import time
import torch
from PIL import Image
from diffusers import QwenImageEditPipeline, QwenImageTransformer2DModel
from utils import GiB, get_args, strify, cachify
import cache_dit
# Parse the example's command-line options and echo them for the run log.
args = get_args()
print(args)

# Checkpoint to load: the value of QWEN_IMAGE_EDIT_DIR when set, otherwise
# the "Qwen/Qwen-Image-Edit" identifier.
model_path = os.environ.get("QWEN_IMAGE_EDIT_DIR", "Qwen/Qwen-Image-Edit")

# https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement
# Request a "balanced" device map only when more than one CUDA device is
# visible and GiB() (helper from utils) returns at most 48.
# NOTE(review): GiB() presumably reports per-GPU memory in GiB -- confirm.
if torch.cuda.device_count() > 1 and GiB() <= 48:
    device_map = "balanced"
else:
    device_map = None

pipe = QwenImageEditPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map=device_map,
)

# Optionally wrap the pipeline with the caching configured by the CLI flags.
if args.cache:
    cachify(args, pipe)

# With zero or one CUDA device, fall back to model CPU offload to save memory.
# NOTE(review): with several GPUs and GiB() > 48, neither a device map nor
# offload is applied in this section -- confirm the intended device placement.
single_or_no_gpu = torch.cuda.device_count() <= 1
if single_or_no_gpu:
    # Enable memory savings
    pipe.enable_model_cpu_offload()
# Load the reference picture once and keep it under its own name, so that
# every pipeline call below edits the same original image. (Previously the
# warmup result was assigned back to the input variable, which made the timed
# run edit an already-edited picture.)
input_image = Image.open("../data/bear.png").convert("RGB")
prompt = "Only change the bear's color to purple"


def _run_edit(source):
    """Run one 50-step edit of *source* with the fixed prompt.

    A new CPU generator seeded with 0 is created on every call, so repeated
    calls start from the same random state. Returns the first image of the
    pipeline output.
    """
    return pipe(
        image=source,
        prompt=prompt,
        negative_prompt=" ",
        generator=torch.Generator(device="cpu").manual_seed(0),
        true_cfg_scale=4.0,
        num_inference_steps=50,
    ).images[0]


if args.compile:
    assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
    # Raise torch._dynamo's recompile limits before compiling the
    # transformer's repeated blocks.
    torch._dynamo.config.recompile_limit = 1024
    torch._dynamo.config.accumulated_recompile_limit = 8192
    pipe.transformer.compile_repeated_blocks(mode="default")

    # Warmup: one untimed pass so the timed run below does not include
    # compilation. The result is discarded on purpose.
    # NOTE(review): the original indentation was lost in extraction; the
    # warmup is assumed to belong to the compile branch -- confirm.
    _run_edit(input_image)

# Timed run. perf_counter() is a monotonic clock meant for measuring intervals.
start = time.perf_counter()
image = _run_edit(input_image)
end = time.perf_counter()

# Collect cache statistics; they are passed to strify() to build the filename.
stats = cache_dit.summary(pipe)
time_cost = end - start
save_path = f"qwen-image-edit.{strify(args, stats)}.png"

print(f"Time cost: {time_cost:.2f}s")
print(f"Saving image to {save_path}")
image.save(save_path)