
Commit 9d8935a

Fix LoRA Trainer bugs with FP8 models. (Comfy-Org#9854)
* Fix adapter weight init
* Fix fp8 model training
* Avoid inference tensor
1 parent d3ba754 commit 9d8935a

File tree

comfy/ops.py
comfy/weight_adapter/loha.py
comfy/weight_adapter/lokr.py
comfy/weight_adapter/lora.py
comfy/weight_adapter/oft.py
comfy_extras/nodes_train.py

6 files changed (+34 -15 lines changed)

comfy/ops.py

Lines changed: 7 additions & 6 deletions

@@ -365,12 +365,13 @@ def reset_parameters(self):
             return None
 
         def forward_comfy_cast_weights(self, input):
-            try:
-                out = fp8_linear(self, input)
-                if out is not None:
-                    return out
-            except Exception as e:
-                logging.info("Exception during fp8 op: {}".format(e))
+            if not self.training:
+                try:
+                    out = fp8_linear(self, input)
+                    if out is not None:
+                        return out
+                except Exception as e:
+                    logging.info("Exception during fp8 op: {}".format(e))
 
             weight, bias = cast_bias_weight(self, input)
             return torch.nn.functional.linear(input, weight, bias)
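
The effect of this change is that the fp8_linear fast path is only attempted in inference mode; during LoRA training the layer falls back to cast_bias_weight plus a regular torch.nn.functional.linear call, which participates in autograd. Below is a minimal sketch of the same pattern in plain PyTorch; QuantAwareLinear and _try_fp8_fast_path are hypothetical names used for illustration, not ComfyUI API.

    # Minimal sketch (not the actual ComfyUI class): gate a quantized fast path
    # on self.training so the autograd-friendly fallback is used while training.
    import torch

    class QuantAwareLinear(torch.nn.Linear):
        def forward(self, input):
            if not self.training:
                out = self._try_fp8_fast_path(input)  # inference-only fused path
                if out is not None:
                    return out
            # Training fallback: cast weights to the activation dtype and use the
            # regular linear op, which supports gradients.
            weight = self.weight.to(input.dtype)
            bias = self.bias.to(input.dtype) if self.bias is not None else None
            return torch.nn.functional.linear(input, weight, bias)

        def _try_fp8_fast_path(self, input):
            return None  # placeholder for a real fp8 matmul kernel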

comfy/weight_adapter/loha.py

Lines changed: 4 additions & 4 deletions

@@ -130,12 +130,12 @@ def __init__(self, loaded_keys, weights):
     def create_train(cls, weight, rank=1, alpha=1.0):
         out_dim = weight.shape[0]
         in_dim = weight.shape[1:].numel()
-        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
-        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
+        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
         torch.nn.init.normal_(mat1, 0.1)
         torch.nn.init.constant_(mat2, 0.0)
-        mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
-        mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
+        mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
         torch.nn.init.normal_(mat3, 0.1)
         torch.nn.init.normal_(mat4, 0.01)
         return LohaDiff(
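
This change (and the matching edits to lokr.py, lora.py, and oft.py below) allocates the trainable adapter tensors in float32 instead of inheriting the base weight's dtype, which for an FP8 checkpoint would be an 8-bit float that weight initialization and optimizer updates cannot handle well. A minimal sketch of the pattern, assuming plain PyTorch; the helper name and shapes are illustrative, not ComfyUI API.

    # Minimal sketch: trainable adapter factors live in float32 on the same
    # device as the (possibly fp8) frozen base weight.
    import torch

    def make_adapter_factors(weight: torch.Tensor, rank: int = 1):
        out_dim = weight.shape[0]
        in_dim = weight.shape[1:].numel()
        # weight.dtype is deliberately NOT used for the trainable tensors.
        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
        torch.nn.init.normal_(mat1, 0.1)       # same init scheme as the LoHa factors above
        torch.nn.init.constant_(mat2, 0.0)
        return torch.nn.Parameter(mat1), torch.nn.Parameter(mat2)

    # Optimizer state then also stays in full precision, e.g.:
    # opt = torch.optim.AdamW(make_adapter_factors(base_weight, rank=4), lr=1e-3)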

comfy/weight_adapter/lokr.py

Lines changed: 2 additions & 2 deletions

@@ -89,8 +89,8 @@ def create_train(cls, weight, rank=1, alpha=1.0):
         in_dim = weight.shape[1:].numel()
         out1, out2 = factorization(out_dim, rank)
         in1, in2 = factorization(in_dim, rank)
-        mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
-        mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
+        mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32)
+        mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32)
         torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
         torch.nn.init.constant_(mat1, 0.0)
         return LokrDiff(

comfy/weight_adapter/lora.py

Lines changed: 2 additions & 2 deletions

@@ -66,8 +66,8 @@ def __init__(self, loaded_keys, weights):
     def create_train(cls, weight, rank=1, alpha=1.0):
         out_dim = weight.shape[0]
         in_dim = weight.shape[1:].numel()
-        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
-        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
+        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
         torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
         torch.nn.init.constant_(mat2, 0.0)
         return LoraDiff(
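
As in the other adapters, only the allocation dtype changes here; the init scheme is untouched. With mat2 zero-initialized, the initial LoRA update mat1 @ mat2 is exactly zero, so training starts from the unmodified base weight. A quick check of that property below; the alpha / rank scaling is the usual LoRA convention, assumed for illustration rather than taken from this diff.

    import torch

    out_dim, in_dim, rank, alpha = 64, 32, 4, 1.0
    mat1 = torch.empty(out_dim, rank, dtype=torch.float32)
    mat2 = torch.empty(rank, in_dim, dtype=torch.float32)
    torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
    torch.nn.init.constant_(mat2, 0.0)

    delta = (alpha / rank) * (mat1 @ mat2)   # zero at step 0
    assert torch.count_nonzero(delta) == 0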

comfy/weight_adapter/oft.py

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ def __init__(self, loaded_keys, weights):
     def create_train(cls, weight, rank=1, alpha=1.0):
         out_dim = weight.shape[0]
         block_size, block_num = factorization(out_dim, rank)
-        block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
+        block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32)
         return OFTDiff(
             (block, None, alpha, None)
         )

comfy_extras/nodes_train.py

Lines changed: 18 additions & 0 deletions

@@ -38,6 +38,23 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
     return new_dict
 
 
+def process_cond_list(d, prefix=""):
+    if hasattr(d, "__iter__") and not hasattr(d, "items"):
+        for index, item in enumerate(d):
+            process_cond_list(item, f"{prefix}.{index}")
+        return d
+    elif hasattr(d, "items"):
+        for k, v in list(d.items()):
+            if isinstance(v, dict):
+                process_cond_list(v, f"{prefix}.{k}")
+            elif isinstance(v, torch.Tensor):
+                d[k] = v.clone()
+            elif isinstance(v, (list, tuple)):
+                for index, item in enumerate(v):
+                    process_cond_list(item, f"{prefix}.{k}.{index}")
+    return d
+
+
 class TrainSampler(comfy.samplers.Sampler):
     def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
         self.loss_fn = loss_fn
@@ -50,6 +67,7 @@ def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
         self.training_dtype = training_dtype
 
     def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
+        model_wrap.conds = process_cond_list(model_wrap.conds)
         cond = model_wrap.conds["positive"]
         dataset_size = sigmas.size(0)
         torch.cuda.empty_cache()
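
The new process_cond_list recursively walks the conditioning structure and clones every tensor it finds, which matches the commit note "Avoid inference tensor": conditioning tensors created under torch.inference_mode() cannot be saved for backward, while clones made outside inference mode can. A minimal sketch of the failure mode and the clone workaround, with illustrative shapes only.

    import torch

    with torch.inference_mode():
        cond = torch.randn(1, 77, 768)        # e.g. a cached text conditioning

    weight = torch.nn.Parameter(torch.randn(768, 768))

    safe_cond = cond.clone()                  # regular tensor, usable in autograd
    loss = (safe_cond @ weight).sum()         # using `cond` directly here would raise
    loss.backward()                           # a RuntimeError about inference tensors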
