forked from Eco-Sphere/cache-dit
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_visual_cloze.py
More file actions
76 lines (63 loc) · 1.88 KB
/
run_visual_cloze.py
File metadata and controls
76 lines (63 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import sys
sys.path.append("..")
import time
import torch
from diffusers import VisualClozePipeline
from diffusers.utils import load_image
from utils import get_args, strify, cachify
import cache_dit
# Parse the command-line options (get_args comes from the local utils module)
# and echo them so the run configuration shows up in the log.
args = get_args()
print(args)

# Load the VisualClozePipeline in bfloat16 at 512 resolution. The checkpoint
# location may be overridden with the VISUAL_CLOZE_DIR environment variable;
# otherwise "VisualCloze/VisualClozePipeline-512" is used.
model_source = os.environ.get(
    "VISUAL_CLOZE_DIR",
    "VisualCloze/VisualClozePipeline-512",
)
pipe = VisualClozePipeline.from_pretrained(
    model_source,
    resolution=512,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# Apply cache-dit caching (via utils.cachify) only when args.cache is truthy.
if args.cache:
    cachify(args, pipe)
# Load the in-context example row and the query row for the try-on task.
# The images are from the VITON-HD dataset at https://github.com/shadow2496/VITON-HD
# The data directory is resolved relative to this script file rather than the
# current working directory. Running from the script's own directory reads the
# same files as before, and running from elsewhere no longer breaks.
_DATA_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "..", "data", "visualcloze"
)


def _load_example(name):
    """Load one image named *name* from the VisualCloze example data directory."""
    return load_image(os.path.join(_DATA_DIR, name))


# NOTE: despite its name, ``image_paths`` holds loaded images, not paths.
# It is a grid with one list per row. The trailing ``None`` in the query row
# is the target slot the pipeline is asked to fill in.
image_paths = [
    # in-context examples
    [
        _load_example("00700_00.jpg"),
        _load_example("03673_00.jpg"),
        _load_example("00700_00_tryon_catvton_0.jpg"),
    ],
    # query with the target image
    [
        _load_example("00555_00.jpg"),
        _load_example("12265_00.jpg"),
        None,
    ],
]
# Task prompt describing the layout of each row; no content prompt is given.
task_prompt = (
    "Each row shows a virtual try-on process that aims to put "
    "[IMAGE2] the clothing onto [IMAGE1] the person, "
    "producing [IMAGE3] the person wearing the new clothing."
)
content_prompt = None

# Run the pipeline and time it. time.perf_counter() is a monotonic,
# high-resolution clock intended for measuring elapsed time. time.time() can
# jump if the system clock is adjusted mid-run.
start = time.perf_counter()
image = pipe(
    task_prompt=task_prompt,
    content_prompt=content_prompt,
    image=image_paths,
    upsampling_height=1632,
    upsampling_width=1232,
    upsampling_strength=0.3,
    guidance_scale=30,
    num_inference_steps=30,
    max_sequence_length=512,
    # Fixed CPU-side seed so repeated runs are comparable.
    generator=torch.Generator("cpu").manual_seed(0),
).images[0][0]  # first output of the first query -- the only one here (verify vs. pipeline docs)
end = time.perf_counter()

cache_dit.summary(pipe)

time_cost = end - start
# strify(args, pipe) (from utils) is embedded in the file name; presumably it
# encodes the cache configuration so different runs get distinct outputs.
save_path = f"visualcloze-512.{strify(args, pipe)}.png"
print(f"Time cost: {time_cost:.2f}s")
print(f"Saving image to {save_path}")
image.save(save_path)