-
Notifications
You must be signed in to change notification settings - Fork 34
Expand file tree
/
Copy pathsd15_config_template.py
More file actions
113 lines (95 loc) · 2.77 KB
/
sd15_config_template.py
File metadata and controls
113 lines (95 loc) · 2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
SD 1.5 LoRA Training Config Template for Kohya sd-scripts
Generates TOML configuration files for train_network.py
"""
import os
def generate_sd15_training_config(
    name: str,
    image_folder: str,
    output_folder: str,
    model_path: str,
    steps: int = 500,
    learning_rate: float = 0.0005,
    lora_rank: int = 16,
    lora_alpha: int = 16,
    resolution: int = 512,
    batch_size: int = 1,
    optimizer: str = "AdamW8bit",
    mixed_precision: str = "fp16",
    gradient_checkpointing: bool = True,
    cache_latents: bool = True,
) -> str:
    """
    Build the TOML configuration text for an SD 1.5 LoRA run with
    sd-scripts' train_network.py.

    Returns the full config file content as a single TOML string.
    """
    def _forward_slashes(path: str) -> str:
        # TOML basic strings treat backslash as an escape character, so
        # Windows-style paths are normalized to forward slashes instead.
        return path.replace('\\', '/')

    model_dir = _forward_slashes(model_path)
    data_dir = _forward_slashes(image_folder)
    out_dir = _forward_slashes(output_folder)

    # TOML booleans must be lowercase, unlike Python's True/False.
    grad_ckpt = str(gradient_checkpointing).lower()
    latents = str(cache_latents).lower()

    return f'''# SD 1.5 LoRA Training Config
# Generated by ComfyUI SD 1.5 LoRA Trainer
[general]
enable_bucket = true
bucket_no_upscale = true
[model]
pretrained_model_name_or_path = "{model_dir}"
[dataset]
train_data_dir = "{data_dir}"
resolution = "{resolution},{resolution}"
caption_extension = ".txt"
[network]
network_module = "networks.lora"
network_dim = {lora_rank}
network_alpha = {lora_alpha}
network_train_unet_only = true
[optimizer]
optimizer_type = "{optimizer}"
learning_rate = {learning_rate:g}
lr_scheduler = "constant"
[training]
output_dir = "{out_dir}"
output_name = "{name}"
save_model_as = "safetensors"
save_precision = "fp16"
max_train_steps = {steps}
train_batch_size = {batch_size}
mixed_precision = "{mixed_precision}"
gradient_checkpointing = {grad_ckpt}
cache_latents = {latents}
sdpa = true
max_data_loader_n_workers = 0
seed = 42
'''
def save_config(config_content: str, config_path: str):
    """Write *config_content* to *config_path* as a UTF-8 text file."""
    with open(config_path, mode='w', encoding='utf-8') as handle:
        handle.write(config_content)
# VRAM mode presets for SD 1.5 (rank is user-controlled, not preset)
# Each preset maps to keyword arguments accepted by
# generate_sd15_training_config; the keys here mirror its parameters.
SD15_VRAM_PRESETS = {
    # Lowest-VRAM mode: 8-bit optimizer, checkpointing on, tiny 256px images.
    "Min (256px)": {
        "optimizer": "AdamW8bit",
        "mixed_precision": "fp16",
        "batch_size": 1,
        "gradient_checkpointing": True,
        "cache_latents": True,
        "resolution": 256,
    },
    # Standard SD 1.5 training resolution with memory-saving options kept on.
    "Low (512px)": {
        "optimizer": "AdamW8bit",
        "mixed_precision": "fp16",
        "batch_size": 1,
        "gradient_checkpointing": True,
        "cache_latents": True,
        "resolution": 512,
    },
    # Highest-quality mode: full AdamW and no gradient checkpointing,
    # trading VRAM for speed at 768px.
    "Max (768px)": {
        "optimizer": "AdamW",
        "mixed_precision": "fp16",
        "batch_size": 1,
        "gradient_checkpointing": False,
        "cache_latents": True,
        "resolution": 768,
    },
}