Skip to content

Commit 37098d8

Browse files
authored
Merge pull request #135 from pytti-tools/test
merge test for release
2 parents 520c29e + 0c94a36 commit 37098d8

File tree

5 files changed

+558
-321
lines changed

5 files changed

+558
-321
lines changed

src/pytti/ImageGuide.py

Lines changed: 31 additions & 282 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,13 @@ def __init__(
7777
embedder: nn.Module,
7878
optimizer: optim.Optimizer = None,
7979
lr: float = None,
80-
null_update=True,
80+
# null_update=True,
8181
params=None,
8282
writer=None,
83-
OUTPATH=None,
84-
base_name=None,
8583
fig=None,
8684
axs=None,
85+
base_name=None,
86+
OUTPATH=None, # <<<<<<<<<<<<<<
8787
#####################
8888
video_frames=None, # # only need this to pass to animate_video_source
8989
optical_flows=None,
@@ -108,7 +108,7 @@ def __init__(
108108
self.optimizer = optimizer
109109
self.dataframe = []
110110

111-
self.null_update = null_update
111+
# self.null_update = null_update
112112
self.params = params
113113
self.writer = writer
114114
self.OUTPATH = OUTPATH
@@ -117,13 +117,13 @@ def __init__(
117117
self.axs = axs
118118
self.video_frames = video_frames
119119
self.optical_flows = optical_flows
120-
if stabilization_augs is None:
121-
stabilization_augs = []
120+
# if stabilization_augs is None:
121+
# stabilization_augs = []
122122
self.stabilization_augs = stabilization_augs
123123
self.last_frame_semantic = last_frame_semantic
124124
self.semantic_init_prompt = semantic_init_prompt
125-
if init_augs is None:
126-
init_augs = []
125+
# if init_augs is None:
126+
# init_augs = []
127127
self.init_augs = init_augs
128128

129129
def run_steps(
@@ -152,8 +152,26 @@ def run_steps(
152152
# and here we can check if the DirectImageGuide was
153153
# initialized with a renderer or not, and call self.renderer.update()
154154
# if appropriate
155-
if not self.null_update:
156-
self.update(i + i_offset, i + skipped_steps)
155+
# if not self.null_update:
156+
# self.update(i + i_offset, i + skipped_steps)
157+
self.update(
158+
model=self,
159+
img=self.image_rep,
160+
i=i + i_offset,
161+
stage_i=i + skipped_steps,
162+
params=self.params,
163+
writer=self.writer,
164+
fig=self.fig,
165+
axs=self.axs,
166+
base_name=self.base_name,
167+
optical_flows=self.optical_flows,
168+
video_frames=self.video_frames,
169+
stabilization_augs=self.stabilization_augs,
170+
last_frame_semantic=self.last_frame_semantic,
171+
embedder=self.embedder,
172+
init_augs=self.init_augs,
173+
semantic_init_prompt=self.semantic_init_prompt,
174+
)
157175
losses = self.train(
158176
i + skipped_steps,
159177
prompts,
@@ -343,277 +361,8 @@ def train(
343361

344362
return {"TOTAL": float(total_loss)}
345363

346-
def report_out(
347-
self,
348-
i,
349-
stage_i,
350-
# model,
351-
writer,
352-
fig, # default to None...
353-
axs, # default to None...
354-
clear_every,
355-
display_every,
356-
approximate_vram_usage,
357-
display_scale,
358-
show_graphs,
359-
show_palette,
360-
):
361-
model = self
362-
img = self.image_rep # pretty sure this is right
363-
# DM: I bet this could be abstracted out into a report_out() function or whatever
364-
if clear_every > 0 and i > 0 and i % clear_every == 0:
365-
display.clear_output()
366-
367-
if display_every > 0 and i % display_every == 0:
368-
logger.debug(f"Step {i} losses:")
369-
if model.dataframe:
370-
rec = model.dataframe[0].iloc[-1]
371-
logger.debug(rec)
372-
if writer is not None:
373-
for k, v in rec.iteritems():
374-
writer.add_scalar(
375-
tag=f"losses/{k}", scalar_value=v, global_step=i
376-
)
377-
378-
# does this VRAM stuff even do anything?
379-
if approximate_vram_usage:
380-
logger.debug("VRAM Usage:")
381-
print_vram_usage() # update this function to use logger
382-
# update this stuff to use/rely on tensorboard
383-
display_width = int(img.image_shape[0] * display_scale)
384-
display_height = int(img.image_shape[1] * display_scale)
385-
if stage_i > 0 and show_graphs:
386-
model.plot_losses(axs)
387-
im = img.decode_image()
388-
sidebyside = make_hbox(
389-
im.resize((display_width, display_height), Image.LANCZOS),
390-
fig,
391-
)
392-
display.display(sidebyside)
393-
else:
394-
im = img.decode_image()
395-
display.display(
396-
im.resize((display_width, display_height), Image.LANCZOS)
397-
)
398-
logger.debug(PixelImage)
399-
logger.debug(type(PixelImage))
400-
if show_palette and isinstance(img, PixelImage):
401-
logger.debug("Palette:")
402-
display.display(img.render_pallet())
403-
404-
def save_out(
405-
self,
406-
i,
407-
# img,
408-
writer,
409-
OUTPATH,
410-
base_name,
411-
save_every,
412-
file_namespace,
413-
backups,
414-
):
415-
img = self.image_rep
416-
# save
417-
# if i > 0 and save_every > 0 and i % save_every == 0:
418-
if i > 0 and save_every > 0 and (i + 1) % save_every == 0:
419-
im = (
420-
img.decode_image()
421-
) # let's turn this into a property so decoding is cheap
422-
# n = i // save_every
423-
n = (i + 1) // save_every
424-
filename = f"{OUTPATH}/{file_namespace}/{base_name}_{n}.png"
425-
logger.debug(filename)
426-
im.save(filename)
427-
428-
im_np = np.array(im)
429-
if writer is not None:
430-
writer.add_image(
431-
tag="pytti output",
432-
# img_tensor=filename, # thought this would work?
433-
img_tensor=im_np,
434-
global_step=i,
435-
dataformats="HWC", # this was the key
436-
)
437-
438-
if backups > 0:
439-
filename = f"backup/{file_namespace}/{base_name}_{n}.bak"
440-
torch.save(img.state_dict(), filename)
441-
if n > backups:
442-
443-
# YOOOOOOO let's not start shell processes unnecessarily
444-
# and then execute commands using string interpolation.
445-
# Replace this with a pythonic folder removal, then see
446-
# if we can't deprecate the folder removal entirely. What
447-
# is the purpose of "backups" here? Just use the frames that
448-
# are being written to disk.
449-
subprocess.run(
450-
[
451-
"rm",
452-
f"backup/{file_namespace}/{base_name}_{n-backups}.bak",
453-
]
454-
)
455-
456-
def update(
457-
self,
458-
# params,
459-
# move to class
460-
i,
461-
stage_i,
462-
):
364+
def update(self, model, img, i, stage_i, *args, **kwargs):
463365
"""
464-
Orchestrates animation transformations and reporting
366+
update hook called ever step
465367
"""
466-
# logger.debug("model.update called")
467-
468-
# ... I have regrets.
469-
params = self.params
470-
writer = self.writer
471-
OUTPATH = self.OUTPATH
472-
base_name = self.base_name
473-
fig = self.fig
474-
axs = self.axs
475-
video_frames = self.video_frames
476-
optical_flows = self.optical_flows
477-
stabilization_augs = self.stabilization_augs
478-
last_frame_semantic = self.last_frame_semantic
479-
semantic_init_prompt = self.semantic_init_prompt
480-
init_augs = self.init_augs
481-
482-
model = self
483-
img = self.image_rep
484-
embedder = self.embedder
485-
486-
model.report_out(
487-
i=i,
488-
stage_i=stage_i,
489-
# model=model,
490-
writer=writer,
491-
fig=fig, # default to None...
492-
axs=axs, # default to None...
493-
clear_every=params.clear_every,
494-
display_every=params.display_every,
495-
approximate_vram_usage=params.approximate_vram_usage,
496-
display_scale=params.display_scale,
497-
show_graphs=params.show_graphs,
498-
show_palette=params.show_palette,
499-
)
500-
501-
model.save_out(
502-
i=i,
503-
# img=img,
504-
writer=writer,
505-
OUTPATH=OUTPATH,
506-
base_name=base_name,
507-
save_every=params.save_every,
508-
file_namespace=params.file_namespace,
509-
backups=params.backups,
510-
)
511-
512-
# animate
513-
################
514-
## TO DO: attach T as a class attribute
515-
t = (i - params.pre_animation_steps) / (
516-
params.steps_per_frame * params.frames_per_second
517-
)
518-
set_t(t) # this won't need to be a thing with `t`` attached to the class
519-
if i >= params.pre_animation_steps:
520-
# next_step_pil = None
521-
if (i - params.pre_animation_steps) % params.steps_per_frame == 0:
522-
logger.debug(f"Time: {t:.4f} seconds")
523-
# update_rotoscopers(
524-
ROTOSCOPERS.update_rotoscopers(
525-
((i - params.pre_animation_steps) // params.steps_per_frame + 1)
526-
* params.frame_stride
527-
)
528-
if params.reset_lr_each_frame:
529-
model.set_optim(None)
530-
531-
if params.animation_mode == "2D":
532-
533-
next_step_pil = animate_2d(
534-
translate_y=params.translate_y,
535-
translate_x=params.translate_x,
536-
rotate_2d=params.rotate_2d,
537-
zoom_x_2d=params.zoom_x_2d,
538-
zoom_y_2d=params.zoom_y_2d,
539-
infill_mode=params.infill_mode,
540-
sampling_mode=params.sampling_mode,
541-
writer=writer,
542-
i=i,
543-
img=img,
544-
t=t, # just here for logging
545-
)
546-
547-
###########################
548-
elif params.animation_mode == "3D":
549-
try:
550-
im
551-
except NameError:
552-
im = img.decode_image()
553-
with vram_usage_mode("Optical Flow Loss"):
554-
# zoom_3d -> rename to animate_3d or transform_3d
555-
flow, next_step_pil = zoom_3d(
556-
img,
557-
(
558-
params.translate_x,
559-
params.translate_y,
560-
params.translate_z_3d,
561-
),
562-
params.rotate_3d,
563-
params.field_of_view,
564-
params.near_plane,
565-
params.far_plane,
566-
border_mode=params.infill_mode,
567-
sampling_mode=params.sampling_mode,
568-
stabilize=params.lock_camera,
569-
)
570-
freeze_vram_usage()
571-
572-
for optical_flow in optical_flows:
573-
optical_flow.set_last_step(im)
574-
optical_flow.set_target_flow(flow)
575-
optical_flow.set_enabled(True)
576-
577-
elif params.animation_mode == "Video Source":
578-
579-
flow_im, next_step_pil = animate_video_source(
580-
i=i,
581-
img=img,
582-
video_frames=video_frames,
583-
optical_flows=optical_flows,
584-
base_name=base_name,
585-
pre_animation_steps=params.pre_animation_steps,
586-
frame_stride=params.frame_stride,
587-
steps_per_frame=params.steps_per_frame,
588-
file_namespace=params.file_namespace,
589-
reencode_each_frame=params.reencode_each_frame,
590-
lock_palette=params.lock_palette,
591-
save_every=params.save_every,
592-
infill_mode=params.infill_mode,
593-
sampling_mode=params.sampling_mode,
594-
)
595-
596-
if params.animation_mode != "off":
597-
try:
598-
for aug in stabilization_augs:
599-
aug.set_comp(next_step_pil)
600-
aug.set_enabled(True)
601-
if last_frame_semantic is not None:
602-
last_frame_semantic.set_image(embedder, next_step_pil)
603-
last_frame_semantic.set_enabled(True)
604-
for aug in init_augs:
605-
aug.set_enabled(False)
606-
if semantic_init_prompt is not None:
607-
semantic_init_prompt.set_enabled(False)
608-
except UnboundLocalError:
609-
logger.critical(
610-
"\n\n-----< PYTTI-TOOLS > ------"
611-
"If you are seeing this error, it might mean "
612-
"you are using an option that expects you have "
613-
"provided an init_image or video_file.\n\nIf you "
614-
"think you are seeing this message in error, please "
615-
"file an issue here: "
616-
"https://github.com/pytti-tools/pytti-core/issues/new"
617-
"-----< PYTTI-TOOLS > ------\n\n"
618-
)
619-
raise
368+
pass

0 commit comments

Comments
 (0)