2424from matplotlib .colors import Normalize
2525from matplotlib import cm
2626import openmesh as om
27+ import pandas as pd
2728
2829
2930class MyOptions (DeformationOptions ):
@@ -87,6 +88,12 @@ def optimize(opt):
8788 target_shape [:, :3 ].astype (np .float32 )).cuda ()
8889 target_shape .unsqueeze_ (0 )
8990
91+ target_faces_arr = target_mesh .face_vertex_indices ()
92+ target_faces = target_faces_arr .copy ()
93+ target_faces = torch .from_numpy (
94+ target_faces [:, :3 ].astype (np .int64 )).cuda ()
95+ target_faces .unsqueeze_ (0 )
96+
9097 states = torch .load (opt .ckpt )
9198 if "states" in states :
9299 states = states ["states" ]
@@ -99,7 +106,6 @@ def optimize(opt):
99106 new_label_path = opt .model .replace (os .path .splitext (opt .model )[1 ], ".picked" )
100107 orig_label_path = opt .source_model .replace (os .path .splitext (opt .source_model )[1 ], ".picked" )
101108 logger .info ("Loading picked labels {} and {}" .format (orig_label_path , new_label_path ))
102- import pandas as pd
103109 new_label = pd .read_csv (new_label_path , delimiter = " " ,skiprows = 1 , header = None )
104110 orig_label = pd .read_csv (orig_label_path , delimiter = " " ,skiprows = 1 , header = None )
105111 orig_label_name = orig_label .iloc [:,5 ]
@@ -153,40 +159,56 @@ def optimize(opt):
153159 # traceback.print_exc(file=sys.stdout)
154160
155161 # source_points[0] = center_bounding_box(source_points[0])[0]
156- elif not os .path .isfile (opt .model .replace (os .path .splitext (opt .model )[1 ], ".picked" )) and os .path .isfile (opt .source_model .replace (os .path .splitext (opt .source_model )[1 ], ".picked" )):
157- logger .info ("Assuming Faust model" )
162+ elif not os .path .isfile (opt .model .replace (os .path .splitext (opt .model )[1 ], ".picked" )) and \
163+ os .path .isfile (opt .source_model .replace (os .path .splitext (opt .source_model )[1 ], ".picked" )):
164+ logger .info ("Could not find {}. Assuming SMPL model" .format (opt .model .replace (os .path .splitext (opt .model )[1 ], ".picked" )))
165+ source_shape , source_faces = read_trimesh (opt .source_model )
166+ assert (source_faces .shape [0 ] == target_faces .shape [1 ]), \
167+ "opt.model must be a SMPL model with {} faces and {} vertices. Otherwise a correspondence file {} must be present." .format (
168+ source_faces .shape [0 ], source_shape .shape [0 ], opt .model .replace (os .path .splitext (opt .model )[1 ], ".picked" ))
169+ # align faces not vertices
158170 orig_label_path = opt .source_model .replace (os .path .splitext (opt .source_model )[1 ], ".picked" )
159171 logger .info ("Loading picked labels {}" .format (orig_label_path ))
160- import pandas as pd
161172 orig_label = pd .read_csv (orig_label_path , delimiter = " " ,skiprows = 1 , header = None )
162- orig_label_name = orig_label .iloc [:,5 ]
163- source_shape , _ = read_trimesh (opt .source_model )
164- source_shape = torch .from_numpy (source_shape [None , :,:3 ]).cuda ().float ()
165- if orig_label .shape [1 ] == 10 :
166- idx = torch .from_numpy (orig_label .iloc [:,9 ].to_numpy ()).long ()
167- source_points = source_shape [:,idx ,:]
168- target_points = target_shape [:,idx ,:]
169- else :
170- source_points = torch .from_numpy (orig_label .iloc [:,6 :9 ].to_numpy ().astype (np .float32 ))
171- source_points = source_points .unsqueeze (0 ).cuda ()
172- # find the closest point on the original meshes
173- source_points , idx , _ = faiss_knn (1 , source_points , source_shape , NCHW = False )
174- source_points = source_points .squeeze (2 ) # B,N,3
175- idx = idx .squeeze (- 1 )
176- target_points = target_shape [:,idx ,:]
173+ source_shape = torch .from_numpy (source_shape [None , :, :3 ]).cuda ().float ()
174+ source_faces = torch .from_numpy (source_faces [None , :, :3 ]).cuda ().long ()
175+ idx = torch .from_numpy (orig_label .iloc [:,1 ].to_numpy ()).long ()
176+ source_points = torch .gather (source_shape .unsqueeze (1 ).expand (- 1 , idx .numel (), - 1 , - 1 ), 2 , source_faces [:, idx , :, None ].expand (- 1 , - 1 , - 1 , 3 ))
177+ source_points = source_points .mean (dim = - 2 )
178+ target_points = torch .gather (target_shape .unsqueeze (1 ).expand (- 1 , idx .numel (), - 1 , - 1 ), 2 , target_faces [:, idx , :, None ].expand (- 1 , - 1 , - 1 , 3 ))
179+ target_points = target_points .mean (dim = - 2 )
180+
181+ _ , source_center , _ = center_bounding_box (source_shape [0 ])
182+ source_points -= source_center
183+ elif not os .path .isfile (opt .model .replace (os .path .splitext (opt .model )[1 ], ".picked" )):
184+ logger .info ("Could not find {}. Assuming SMPL model" .format (opt .model .replace (os .path .splitext (opt .model )[1 ], ".picked" )))
185+ source_shape , source_faces = read_trimesh (opt .source_model )
186+ assert (source_faces .shape [0 ] == target_faces .shape [1 ]), \
187+ "opt.model must be a SMPL model with {} faces and {} vertices. Otherwise a correspondence file {} must be present." .format (
188+ source_faces .shape [0 ], source_shape .shape [0 ], opt .model .replace (os .path .splitext (opt .model )[1 ], ".picked" ))
189+ source_shape , source_faces = read_trimesh (opt .source_model )
177190
178191 _ , source_center , _ = center_bounding_box (source_shape [0 ])
179192 source_points -= source_center
180- elif opt .corres_idx is None and target_shape .shape [1 ] == shape_v .shape [1 ]:
181- logger .info ("No correspondence provided, assuming registered Faust models" )
182- # corresp_idx = torch.randint(0, shape_f.shape[1], (100,)).cuda()
183- corresp_v = torch .unique (torch .randint (0 , shape_v .shape [1 ], (4800 ,))).cuda ()
184- target_points = torch .index_select (target_shape , 1 , corresp_v )
185- source_points = torch .index_select (shape_v , 1 , corresp_v )
193+
194+ source_shape = torch .from_numpy (source_shape [None , :, :3 ]).cuda ().float ()
195+ source_faces = torch .from_numpy (source_faces [None , :, :3 ]).cuda ().long ()
196+ # select a subset of faces, otherwise optimization is too slow
197+ idx = torch .from_numpy (np .random .permutation (2048 )).cuda ().long ()
198+ source_points = torch .gather (source_shape .unsqueeze (1 ).expand (- 1 , source_faces .shape [1 ], - 1 , - 1 ), 2 , source_faces [:, idx ,:, None ].expand (- 1 , - 1 , - 1 , 3 ))
199+ source_points = source_points .mean (dim = - 2 )
200+ target_points = torch .gather (target_shape .unsqueeze (1 ).expand (- 1 , source_faces .shape [1 ], - 1 , - 1 ), 2 , target_faces [:,idx ,: None ].expand (- 1 , - 1 , - 1 , 3 ))
201+ target_points = target_points .mean (dim = - 2 )
202+
203+
204+ target_points = target_points [:, idx ]
205+ source_points = source_points [:, idx ]
206+
186207
187208 target_shape [0 ], target_center , target_scale = center_bounding_box (target_shape [0 ])
188209 _ , _ , source_scale = center_bounding_box (shape_v [0 ])
189- target_scale_factor = (source_scale / target_scale )[1 ]
210+ # scale according y axis (body height)
211+ target_scale_factor = (source_scale / target_scale )[0 ,1 ]
190212 target_shape *= target_scale_factor
191213 target_points -= target_center
192214 target_points = (target_points * target_scale_factor ).detach ()
@@ -227,22 +249,16 @@ def optimize(opt):
227249 target_points , cage_v , cage_f , verbose = False )
228250 loss_mvc = torch .mean ((weights - weights_ref )** 2 )
229251 # reg = torch.sum((cage_init-cage_v)**2, dim=-1)*1e-4
230- reg = 0
252+ reg = torch . tensor ( 0.0 ). cuda ()
231253 if opt .clap_weight > 0 :
232254 reg = lap_loss (cage_init , cage_v , face = cage_f )* opt .clap_weight
233255 reg = reg .mean ()
234256 if opt .mvc_weight > 0 :
235257 reg += mvc_reg_loss (weights )* opt .mvc_weight
236258
237- # weight regularizer with the shape difference
238- # dist = torch.sum((source_points - target_points)**2, dim=-1)
239- # weights = torch.exp(-dist)
240- # reg = reg*weights*0.1
241-
242259 loss = loss_mvc + reg
243260 if (t + 1 ) % 50 == 0 :
244- print ("t {}/{} mvc_loss: {} reg: {}" .format (t ,
245- opt .nepochs , loss_mvc .item (), reg .item ()))
261+ print ("t {}/{} mvc_loss: {} reg: {}" .format (t , opt .nepochs , loss_mvc .item (), reg .item ()))
246262
247263 if loss_mvc .item () < 5e-6 :
248264 break
@@ -315,6 +331,7 @@ def test_all(opt, new_cage_shape):
315331 net = networks .FixedSourceDeformer (opt , 3 , opt .num_point , bottleneck_size = opt .bottleneck_size ,
316332 template_vertices = states ["template_vertices" ], template_faces = states ["template_faces" ].cuda (),
317333 source_vertices = states ["source_vertices" ], source_faces = states ["source_faces" ]).cuda ()
334+ print (net )
318335 load_network (net , states )
319336
320337 source_points = torch .from_numpy (
@@ -346,20 +363,8 @@ def test_all(opt, new_cage_shape):
346363 source_mesh_arr [:] = deformed [0 ].cpu ().detach ().numpy ()
347364 om .write_mesh (os .path .join (
348365 opt .log_dir , opt .subdir , "template-{}-Sab.obj" .format (t_filename )), source_mesh )
349- # if data["target_face"] is not None and data["target_mesh"] is not None:
350- # pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sa.ply".format(t_filename)),
351- # source_mesh[0].detach().cpu(), source_face[b].detach().cpu())
352366 pymesh .save_mesh_raw (os .path .join (opt .log_dir , opt .subdir , "template-{}-Sb.ply" .format (t_filename )),
353367 data ["target_mesh" ][b ].detach ().cpu (), data ["target_face" ][b ].detach ().cpu ())
354- # pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sab.ply".format(t_filename)),
355- # deformed[b].detach().cpu(), source_face[b].detach().cpu())
356-
357- # else:
358- # save_ply(source_mesh[0].detach().cpu(), os.path.join(opt.log_dir, opt.subdir,"template-{}-Sa.ply".format(t_filename)))
359- # save_ply(target_shape[b].detach().cpu(), os.path.join(opt.log_dir, opt.subdir,"template-{}-Sb.ply".format(t_filename)),
360- # normals=data["target_normals"][b].detach().cpu())
361- # save_ply(deformed[b].detach().cpu(), os.path.join(opt.log_dir, opt.subdir,"template-{}-Sab.ply".format(t_filename)),
362- # normals=data["target_normals"][b].detach().cpu())
363368
364369 pymesh .save_mesh_raw (
365370 os .path .join (opt .log_dir , opt .subdir , "template-{}-cage1.ply" .format (t_filename )),
@@ -369,70 +374,12 @@ def test_all(opt, new_cage_shape):
369374 os .path .join (opt .log_dir , opt .subdir , "template-{}-cage2.ply" .format (t_filename )),
370375 outputs ["new_cage" ][b ].detach ().cpu (), outputs ["cage_face" ][b ].detach ().cpu (),
371376 )
372-
373- # if opt.opt_lap and deformed.shape[1] == source_mesh.shape[1]:
374- # deformed = optimize_lap(opt, source_mesh, deformed, source_face)
375- # for b in range(deformed.shape[0]):
376- # pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sab-optlap.ply".format(t_filename)),
377- # deformed[b].detach().cpu(), source_face[b].detach().cpu())
378-
379377 if i % 20 == 0 :
380378 logger .success ("[{}/{}] Done" .format (i , len (dataloader )))
381379
382380 dataset .render_result (os .path .join (opt .log_dir , opt .subdir ))
383381
384382
385- def optimize_lap (opt , source_shape , deformed_shape , face ):
386- """
387- source_shape (B,N,3)
388- deformed_shape (B,N,3)
389- face (B,F,3)
390- """
391- B = deformed_shape .shape [0 ]
392- if opt .corres_idx is None :
393- n_selected = int (source_shape .shape [1 ] * 0.6 )
394- corresp_v = torch .unique (torch .randint (
395- 0 , source_shape .shape [1 ], (n_selected , ))).view (1 , - 1 , 1 ).cuda ()
396- else :
397- corresp_idx = torch .from_numpy (np .loadtxt (
398- opt .corres_idx , delimiter = "," , dtype = np .int64 )).cuda ()
399- _ , corresp_idx_2 = torch .unbind (corresp_idx , dim = 1 )
400- corresp_v = torch .unique (torch .gather (
401- face , 1 , corresp_idx_2 .view (1 , - 1 , 1 ).expand (- 1 , - 1 , 3 ))).view (1 , - 1 , 1 )
402-
403- fixed_points = torch .gather (
404- deformed_shape , 1 , corresp_v .expand (- 1 , - 1 , 3 )).detach ()
405-
406- deformed_shape = deformed_shape .detach ()
407- deformed_shape .requires_grad_ (True )
408- source_shape = source_shape .detach ()
409- optimizer = torch .optim .Adam ([deformed_shape ], lr = 0.0005 , betas = (0.1 , 0.1 ))
410- scheduler = torch .optim .lr_scheduler .ReduceLROnPlateau (
411- optimizer , factor = 0.5 , min_lr = 1e-6 , verbose = True , patience = 25 )
412- lap_loss = MeshLaplacianLoss (torch .nn .MSELoss (reduction = "none" ), use_cot = True ,
413- use_norm = True , consistent_topology = True , precompute_L = True )
414-
415- for t in range (2000 ):
416- lap_loss_value = torch .mean (
417- lap_loss (source_shape , deformed_shape , face = face ).view (B , - 1 ))
418- fixed_points_new = torch .gather (
419- deformed_shape , 1 , corresp_v .expand (- 1 , - 1 , 3 ))
420- reg_value = torch .mean (
421- torch .sum ((fixed_points - fixed_points_new )** 2 , dim = - 1 ))
422- loss = lap_loss_value + reg_value
423- loss .backward ()
424- if (t + 1 ) % 50 == 0 :
425- print ("t {}/{} lap: {} reg: {}" .format (t , 2000 ,
426- lap_loss_value .item (), reg_value .item ()))
427- if loss < 1e-8 :
428- logger .success ("Optimization converged!" )
429- break
430- optimizer .step ()
431- scheduler .step (loss .item ())
432-
433- return deformed_shape
434-
435-
436383if __name__ == "__main__" :
437384 parser = MyOptions ()
438385 opt = parser .parse ()
0 commit comments