forked from Eco-Sphere/cache-dit
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: run_hidream.py
More file actions
80 lines (68 loc) · 1.97 KB
/
run_hidream.py
File metadata and controls
80 lines (68 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import sys
sys.path.append("..")
import time
import torch
from diffusers import HiDreamImagePipeline
from transformers import AutoTokenizer, LlamaForCausalLM
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from utils import get_args, strify, cachify
import cache_dit
# Parse the example's command-line options (via the local utils helper)
# and echo them so the run configuration appears in the log.
args = get_args()
print(args)

# Llama-3.1-8B-Instruct checkpoint used as HiDream's fourth text encoder.
# Overridable with $LLAMA_DIR (e.g. to point at a local copy). It is
# resolved once so the tokenizer and the model are guaranteed to come
# from the same location.
llama_dir = os.environ.get(
    "LLAMA_DIR",
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
)

tokenizer_4 = AutoTokenizer.from_pretrained(llama_dir)

text_encoder_4 = LlamaForCausalLM.from_pretrained(
    llama_dir,
    # Ask the model to return hidden states and attentions from forward();
    # presumably the HiDream pipeline reads the hidden states — confirm.
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
    # Load the encoder weights in 4-bit via bitsandbytes (smaller footprint).
    quantization_config=TransformersBitsAndBytesConfig(load_in_4bit=True),
)
# Quantize only the "transformer" component of the pipeline to 4-bit NF4
# (bitsandbytes), computing in bfloat16. tokenizer_4 / text_encoder_4 are
# supplied pre-loaded rather than being loaded from the checkpoint.
transformer_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer"],
)

# HiDream checkpoint location, overridable with $HIDREAM_DIR.
hidream_dir = os.environ.get(
    "HIDREAM_DIR",
    "HiDream-ai/HiDream-I1-Full",
)

pipe = HiDreamImagePipeline.from_pretrained(
    hidream_dir,
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
    quantization_config=transformer_quant_config,
)

# Run inference on the GPU.
pipe.to("cuda")
# Optionally enable cache-dit acceleration on the pipeline. `cachify` comes
# from the local utils module and receives the parsed CLI options, which
# presumably select the cache settings — confirm in utils.
# The call must be indented under the `if`; an unindented body here is an
# IndentationError and the script would not start.
if args.cache:
    cachify(args, pipe)
# Time one end-to-end generation. time.perf_counter() is the monotonic,
# high-resolution clock intended for measuring durations. time.time() is
# wall-clock time and can jump if the system clock is adjusted mid-run.
start = time.perf_counter()
image = pipe(
    'A cute girl holding a sign that says "Hi-Dreams.ai".',
    # Default to 1024x1024 unless a size was given on the command line.
    height=1024 if args.height is None else args.height,
    width=1024 if args.width is None else args.width,
    guidance_scale=5.0,
    num_inference_steps=50,
    # Fixed seed on a CPU generator so repeated runs are comparable.
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
end = time.perf_counter()

# Collect cache-dit statistics for the pipeline; strify (local utils)
# folds them, together with the CLI options, into the output filename.
stats = cache_dit.summary(pipe)
time_cost = end - start
save_path = f"hidream.{strify(args, stats)}.png"
print(f"Time cost: {time_cost:.2f}s")
print(f"Saving to {save_path}")
image.save(save_path)