@@ -38,6 +38,23 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
     return new_dict
 
 
+def process_cond_list(d, prefix=""):
+    if hasattr(d, "__iter__") and not hasattr(d, "items"):
+        for index, item in enumerate(d):
+            process_cond_list(item, f"{prefix}.{index}")
+        return d
+    elif hasattr(d, "items"):
+        for k, v in list(d.items()):
+            if isinstance(v, dict):
+                process_cond_list(v, f"{prefix}.{k}")
+            elif isinstance(v, torch.Tensor):
+                d[k] = v.clone()
+            elif isinstance(v, (list, tuple)):
+                for index, item in enumerate(v):
+                    process_cond_list(item, f"{prefix}.{k}.{index}")
+        return d
+
+
 class TrainSampler(comfy.samplers.Sampler):
     def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
         self.loss_fn = loss_fn
@@ -50,6 +67,7 @@ def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_ac
         self.training_dtype = training_dtype
 
     def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
+        model_wrap.conds = process_cond_list(model_wrap.conds)
         cond = model_wrap.conds["positive"]
         dataset_size = sigmas.size(0)
         torch.cuda.empty_cache()
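For context on what the added `process_cond_list` does before the training loop runs: it walks the nested conds structure and replaces every tensor it finds in a dict with a `.clone()`, so in-place mutation during a training step cannot corrupt conditioning that is shared or cached elsewhere (the `prefix` argument is only threaded through the recursion as a path label and is otherwise unused). Below is a minimal, self-contained sketch of that behavior; the `conds` dict and `shared` tensor are hypothetical stand-ins, not ComfyUI's actual cond layout.

```python
# Minimal sketch of the cloning behavior of process_cond_list (copied from
# the diff above). The cond structure here is a hypothetical stand-in for
# the nested lists/dicts of tensors that model_wrap.conds actually holds.
import torch

def process_cond_list(d, prefix=""):
    if hasattr(d, "__iter__") and not hasattr(d, "items"):
        for index, item in enumerate(d):
            process_cond_list(item, f"{prefix}.{index}")
        return d
    elif hasattr(d, "items"):
        for k, v in list(d.items()):
            if isinstance(v, dict):
                process_cond_list(v, f"{prefix}.{k}")
            elif isinstance(v, torch.Tensor):
                d[k] = v.clone()
            elif isinstance(v, (list, tuple)):
                for index, item in enumerate(v):
                    process_cond_list(item, f"{prefix}.{k}.{index}")
        return d

# Hypothetical cond structure: the same tensor referenced from two places
# inside a list of nested dicts.
shared = torch.zeros(1, 4)
conds = {"positive": [{"cross_attn": shared, "extra": {"pooled": shared}}]}

conds = process_cond_list(conds)

# Every tensor held in a dict is now a clone, so an in-place update during
# a training step no longer writes into the original storage.
conds["positive"][0]["cross_attn"].add_(1.0)
print(shared.sum().item())  # 0.0 -- the shared tensor is untouched
```

Note the function mutates the dicts in place and also returns them, which is why the call site can write `model_wrap.conds = process_cond_list(model_wrap.conds)` without copying the outer structure.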