@@ -77,13 +77,13 @@ def __init__(
7777 embedder : nn .Module ,
7878 optimizer : optim .Optimizer = None ,
7979 lr : float = None ,
80- null_update = True ,
80+ # null_update=True,
8181 params = None ,
8282 writer = None ,
83- OUTPATH = None ,
84- base_name = None ,
8583 fig = None ,
8684 axs = None ,
85+ base_name = None ,
86+ OUTPATH = None , # <<<<<<<<<<<<<<
8787 #####################
8888 video_frames = None , # # only need this to pass to animate_video_source
8989 optical_flows = None ,
@@ -108,7 +108,7 @@ def __init__(
108108 self .optimizer = optimizer
109109 self .dataframe = []
110110
111- self .null_update = null_update
111+ # self.null_update = null_update
112112 self .params = params
113113 self .writer = writer
114114 self .OUTPATH = OUTPATH
@@ -117,13 +117,13 @@ def __init__(
117117 self .axs = axs
118118 self .video_frames = video_frames
119119 self .optical_flows = optical_flows
120- if stabilization_augs is None :
121- stabilization_augs = []
120+ # if stabilization_augs is None:
121+ # stabilization_augs = []
122122 self .stabilization_augs = stabilization_augs
123123 self .last_frame_semantic = last_frame_semantic
124124 self .semantic_init_prompt = semantic_init_prompt
125- if init_augs is None :
126- init_augs = []
125+ # if init_augs is None:
126+ # init_augs = []
127127 self .init_augs = init_augs
128128
129129 def run_steps (
@@ -152,8 +152,26 @@ def run_steps(
152152 # and here we can check if the DirectImageGuide was
153153 # initialized with a renderer or not, and call self.renderer.update()
154154 # if appropriate
155- if not self .null_update :
156- self .update (i + i_offset , i + skipped_steps )
155+ # if not self.null_update:
156+ # self.update(i + i_offset, i + skipped_steps)
157+ self .update (
158+ model = self ,
159+ img = self .image_rep ,
160+ i = i + i_offset ,
161+ stage_i = i + skipped_steps ,
162+ params = self .params ,
163+ writer = self .writer ,
164+ fig = self .fig ,
165+ axs = self .axs ,
166+ base_name = self .base_name ,
167+ optical_flows = self .optical_flows ,
168+ video_frames = self .video_frames ,
169+ stabilization_augs = self .stabilization_augs ,
170+ last_frame_semantic = self .last_frame_semantic ,
171+ embedder = self .embedder ,
172+ init_augs = self .init_augs ,
173+ semantic_init_prompt = self .semantic_init_prompt ,
174+ )
157175 losses = self .train (
158176 i + skipped_steps ,
159177 prompts ,
@@ -343,277 +361,8 @@ def train(
343361
344362 return {"TOTAL" : float (total_loss )}
345363
346- def report_out (
347- self ,
348- i ,
349- stage_i ,
350- # model,
351- writer ,
352- fig , # default to None...
353- axs , # default to None...
354- clear_every ,
355- display_every ,
356- approximate_vram_usage ,
357- display_scale ,
358- show_graphs ,
359- show_palette ,
360- ):
361- model = self
362- img = self .image_rep # pretty sure this is right
363- # DM: I bet this could be abstracted out into a report_out() function or whatever
364- if clear_every > 0 and i > 0 and i % clear_every == 0 :
365- display .clear_output ()
366-
367- if display_every > 0 and i % display_every == 0 :
368- logger .debug (f"Step { i } losses:" )
369- if model .dataframe :
370- rec = model .dataframe [0 ].iloc [- 1 ]
371- logger .debug (rec )
372- if writer is not None :
373- for k , v in rec .iteritems ():
374- writer .add_scalar (
375- tag = f"losses/{ k } " , scalar_value = v , global_step = i
376- )
377-
378- # does this VRAM stuff even do anything?
379- if approximate_vram_usage :
380- logger .debug ("VRAM Usage:" )
381- print_vram_usage () # update this function to use logger
382- # update this stuff to use/rely on tensorboard
383- display_width = int (img .image_shape [0 ] * display_scale )
384- display_height = int (img .image_shape [1 ] * display_scale )
385- if stage_i > 0 and show_graphs :
386- model .plot_losses (axs )
387- im = img .decode_image ()
388- sidebyside = make_hbox (
389- im .resize ((display_width , display_height ), Image .LANCZOS ),
390- fig ,
391- )
392- display .display (sidebyside )
393- else :
394- im = img .decode_image ()
395- display .display (
396- im .resize ((display_width , display_height ), Image .LANCZOS )
397- )
398- logger .debug (PixelImage )
399- logger .debug (type (PixelImage ))
400- if show_palette and isinstance (img , PixelImage ):
401- logger .debug ("Palette:" )
402- display .display (img .render_pallet ())
403-
404- def save_out (
405- self ,
406- i ,
407- # img,
408- writer ,
409- OUTPATH ,
410- base_name ,
411- save_every ,
412- file_namespace ,
413- backups ,
414- ):
415- img = self .image_rep
416- # save
417- # if i > 0 and save_every > 0 and i % save_every == 0:
418- if i > 0 and save_every > 0 and (i + 1 ) % save_every == 0 :
419- im = (
420- img .decode_image ()
421- ) # let's turn this into a property so decoding is cheap
422- # n = i // save_every
423- n = (i + 1 ) // save_every
424- filename = f"{ OUTPATH } /{ file_namespace } /{ base_name } _{ n } .png"
425- logger .debug (filename )
426- im .save (filename )
427-
428- im_np = np .array (im )
429- if writer is not None :
430- writer .add_image (
431- tag = "pytti output" ,
432- # img_tensor=filename, # thought this would work?
433- img_tensor = im_np ,
434- global_step = i ,
435- dataformats = "HWC" , # this was the key
436- )
437-
438- if backups > 0 :
439- filename = f"backup/{ file_namespace } /{ base_name } _{ n } .bak"
440- torch .save (img .state_dict (), filename )
441- if n > backups :
442-
443- # YOOOOOOO let's not start shell processes unnecessarily
444- # and then execute commands using string interpolation.
445- # Replace this with a pythonic folder removal, then see
446- # if we can't deprecate the folder removal entirely. What
447- # is the purpose of "backups" here? Just use the frames that
448- # are being written to disk.
449- subprocess .run (
450- [
451- "rm" ,
452- f"backup/{ file_namespace } /{ base_name } _{ n - backups } .bak" ,
453- ]
454- )
455-
456- def update (
457- self ,
458- # params,
459- # move to class
460- i ,
461- stage_i ,
462- ):
364+ def update (self , model , img , i , stage_i , * args , ** kwargs ):
463365 """
464- Orchestrates animation transformations and reporting
366+ update hook called every step
465367 """
466- # logger.debug("model.update called")
467-
468- # ... I have regrets.
469- params = self .params
470- writer = self .writer
471- OUTPATH = self .OUTPATH
472- base_name = self .base_name
473- fig = self .fig
474- axs = self .axs
475- video_frames = self .video_frames
476- optical_flows = self .optical_flows
477- stabilization_augs = self .stabilization_augs
478- last_frame_semantic = self .last_frame_semantic
479- semantic_init_prompt = self .semantic_init_prompt
480- init_augs = self .init_augs
481-
482- model = self
483- img = self .image_rep
484- embedder = self .embedder
485-
486- model .report_out (
487- i = i ,
488- stage_i = stage_i ,
489- # model=model,
490- writer = writer ,
491- fig = fig , # default to None...
492- axs = axs , # default to None...
493- clear_every = params .clear_every ,
494- display_every = params .display_every ,
495- approximate_vram_usage = params .approximate_vram_usage ,
496- display_scale = params .display_scale ,
497- show_graphs = params .show_graphs ,
498- show_palette = params .show_palette ,
499- )
500-
501- model .save_out (
502- i = i ,
503- # img=img,
504- writer = writer ,
505- OUTPATH = OUTPATH ,
506- base_name = base_name ,
507- save_every = params .save_every ,
508- file_namespace = params .file_namespace ,
509- backups = params .backups ,
510- )
511-
512- # animate
513- ################
514- ## TO DO: attach T as a class attribute
515- t = (i - params .pre_animation_steps ) / (
516- params .steps_per_frame * params .frames_per_second
517- )
518- set_t (t ) # this won't need to be a thing with `t`` attached to the class
519- if i >= params .pre_animation_steps :
520- # next_step_pil = None
521- if (i - params .pre_animation_steps ) % params .steps_per_frame == 0 :
522- logger .debug (f"Time: { t :.4f} seconds" )
523- # update_rotoscopers(
524- ROTOSCOPERS .update_rotoscopers (
525- ((i - params .pre_animation_steps ) // params .steps_per_frame + 1 )
526- * params .frame_stride
527- )
528- if params .reset_lr_each_frame :
529- model .set_optim (None )
530-
531- if params .animation_mode == "2D" :
532-
533- next_step_pil = animate_2d (
534- translate_y = params .translate_y ,
535- translate_x = params .translate_x ,
536- rotate_2d = params .rotate_2d ,
537- zoom_x_2d = params .zoom_x_2d ,
538- zoom_y_2d = params .zoom_y_2d ,
539- infill_mode = params .infill_mode ,
540- sampling_mode = params .sampling_mode ,
541- writer = writer ,
542- i = i ,
543- img = img ,
544- t = t , # just here for logging
545- )
546-
547- ###########################
548- elif params .animation_mode == "3D" :
549- try :
550- im
551- except NameError :
552- im = img .decode_image ()
553- with vram_usage_mode ("Optical Flow Loss" ):
554- # zoom_3d -> rename to animate_3d or transform_3d
555- flow , next_step_pil = zoom_3d (
556- img ,
557- (
558- params .translate_x ,
559- params .translate_y ,
560- params .translate_z_3d ,
561- ),
562- params .rotate_3d ,
563- params .field_of_view ,
564- params .near_plane ,
565- params .far_plane ,
566- border_mode = params .infill_mode ,
567- sampling_mode = params .sampling_mode ,
568- stabilize = params .lock_camera ,
569- )
570- freeze_vram_usage ()
571-
572- for optical_flow in optical_flows :
573- optical_flow .set_last_step (im )
574- optical_flow .set_target_flow (flow )
575- optical_flow .set_enabled (True )
576-
577- elif params .animation_mode == "Video Source" :
578-
579- flow_im , next_step_pil = animate_video_source (
580- i = i ,
581- img = img ,
582- video_frames = video_frames ,
583- optical_flows = optical_flows ,
584- base_name = base_name ,
585- pre_animation_steps = params .pre_animation_steps ,
586- frame_stride = params .frame_stride ,
587- steps_per_frame = params .steps_per_frame ,
588- file_namespace = params .file_namespace ,
589- reencode_each_frame = params .reencode_each_frame ,
590- lock_palette = params .lock_palette ,
591- save_every = params .save_every ,
592- infill_mode = params .infill_mode ,
593- sampling_mode = params .sampling_mode ,
594- )
595-
596- if params .animation_mode != "off" :
597- try :
598- for aug in stabilization_augs :
599- aug .set_comp (next_step_pil )
600- aug .set_enabled (True )
601- if last_frame_semantic is not None :
602- last_frame_semantic .set_image (embedder , next_step_pil )
603- last_frame_semantic .set_enabled (True )
604- for aug in init_augs :
605- aug .set_enabled (False )
606- if semantic_init_prompt is not None :
607- semantic_init_prompt .set_enabled (False )
608- except UnboundLocalError :
609- logger .critical (
610- "\n \n -----< PYTTI-TOOLS > ------"
611- "If you are seeing this error, it might mean "
612- "you are using an option that expects you have "
613- "provided an init_image or video_file.\n \n If you "
614- "think you are seeing this message in error, please "
615- "file an issue here: "
616- "https://github.com/pytti-tools/pytti-core/issues/new"
617- "-----< PYTTI-TOOLS > ------\n \n "
618- )
619- raise
368+ pass
0 commit comments