Skip to content

Commit e0db217

Browse files
author
ywang
committed
use face correspondence instead of vertex
1 parent 1bad6bd commit e0db217

File tree

1 file changed

+51
-104
lines changed

1 file changed

+51
-104
lines changed

optimize_cage.py

Lines changed: 51 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from matplotlib.colors import Normalize
2525
from matplotlib import cm
2626
import openmesh as om
27+
import pandas as pd
2728

2829

2930
class MyOptions(DeformationOptions):
@@ -87,6 +88,12 @@ def optimize(opt):
8788
target_shape[:, :3].astype(np.float32)).cuda()
8889
target_shape.unsqueeze_(0)
8990

91+
target_faces_arr = target_mesh.face_vertex_indices()
92+
target_faces = target_faces_arr.copy()
93+
target_faces = torch.from_numpy(
94+
target_faces[:, :3].astype(np.int64)).cuda()
95+
target_faces.unsqueeze_(0)
96+
9097
states = torch.load(opt.ckpt)
9198
if "states" in states:
9299
states = states["states"]
@@ -99,7 +106,6 @@ def optimize(opt):
99106
new_label_path = opt.model.replace(os.path.splitext(opt.model)[1], ".picked")
100107
orig_label_path = opt.source_model.replace(os.path.splitext(opt.source_model)[1], ".picked")
101108
logger.info("Loading picked labels {} and {}".format(orig_label_path, new_label_path))
102-
import pandas as pd
103109
new_label = pd.read_csv(new_label_path, delimiter=" ",skiprows=1, header=None)
104110
orig_label = pd.read_csv(orig_label_path, delimiter=" ",skiprows=1, header=None)
105111
orig_label_name = orig_label.iloc[:,5]
@@ -153,40 +159,56 @@ def optimize(opt):
153159
# traceback.print_exc(file=sys.stdout)
154160

155161
# source_points[0] = center_bounding_box(source_points[0])[0]
156-
elif not os.path.isfile(opt.model.replace(os.path.splitext(opt.model)[1], ".picked")) and os.path.isfile(opt.source_model.replace(os.path.splitext(opt.source_model)[1], ".picked")):
157-
logger.info("Assuming Faust model")
162+
elif not os.path.isfile(opt.model.replace(os.path.splitext(opt.model)[1], ".picked")) and \
163+
os.path.isfile(opt.source_model.replace(os.path.splitext(opt.source_model)[1], ".picked")):
164+
logger.info("Could not find {}. Assuming SMPL model".format(opt.model.replace(os.path.splitext(opt.model)[1], ".picked")))
165+
source_shape, source_faces = read_trimesh(opt.source_model)
166+
assert(source_faces.shape[0] == target_faces.shape[1]), \
167+
"opt.model must be a SMPL model with {} faces and {} vertices. Otherwise a correspondence file {} must be present.".format(
168+
source_faces.shape[0], source_shape.shape[0], opt.model.replace(os.path.splitext(opt.model)[1], ".picked"))
169+
# align faces not vertices
158170
orig_label_path = opt.source_model.replace(os.path.splitext(opt.source_model)[1], ".picked")
159171
logger.info("Loading picked labels {}".format(orig_label_path))
160-
import pandas as pd
161172
orig_label = pd.read_csv(orig_label_path, delimiter=" ",skiprows=1, header=None)
162-
orig_label_name = orig_label.iloc[:,5]
163-
source_shape, _ = read_trimesh(opt.source_model)
164-
source_shape = torch.from_numpy(source_shape[None, :,:3]).cuda().float()
165-
if orig_label.shape[1] == 10:
166-
idx = torch.from_numpy(orig_label.iloc[:,9].to_numpy()).long()
167-
source_points = source_shape[:,idx,:]
168-
target_points = target_shape[:,idx,:]
169-
else:
170-
source_points = torch.from_numpy(orig_label.iloc[:,6:9].to_numpy().astype(np.float32))
171-
source_points = source_points.unsqueeze(0).cuda()
172-
# find the closest point on the original meshes
173-
source_points, idx, _ = faiss_knn(1, source_points, source_shape, NCHW=False)
174-
source_points = source_points.squeeze(2) # B,N,3
175-
idx = idx.squeeze(-1)
176-
target_points = target_shape[:,idx,:]
173+
source_shape = torch.from_numpy(source_shape[None, :, :3]).cuda().float()
174+
source_faces = torch.from_numpy(source_faces[None, :, :3]).cuda().long()
175+
idx = torch.from_numpy(orig_label.iloc[:,1].to_numpy()).long()
176+
source_points = torch.gather(source_shape.unsqueeze(1).expand(-1, idx.numel(), -1, -1), 2, source_faces[:, idx, :, None].expand(-1, -1, -1, 3))
177+
source_points = source_points.mean(dim=-2)
178+
target_points = torch.gather(target_shape.unsqueeze(1).expand(-1, idx.numel(), -1, -1), 2, target_faces[:, idx, :, None].expand(-1, -1, -1, 3))
179+
target_points = target_points.mean(dim=-2)
180+
181+
_, source_center, _ = center_bounding_box(source_shape[0])
182+
source_points -= source_center
183+
elif not os.path.isfile(opt.model.replace(os.path.splitext(opt.model)[1], ".picked")):
184+
logger.info("Could not find {}. Assuming SMPL model".format(opt.model.replace(os.path.splitext(opt.model)[1], ".picked")))
185+
source_shape, source_faces = read_trimesh(opt.source_model)
186+
assert(source_faces.shape[0] == target_faces.shape[1]), \
187+
"opt.model must be a SMPL model with {} faces and {} vertices. Otherwise a correspondence file {} must be present.".format(
188+
source_faces.shape[0], source_shape.shape[0], opt.model.replace(os.path.splitext(opt.model)[1], ".picked"))
189+
source_shape, source_faces = read_trimesh(opt.source_model)
177190

178191
_, source_center, _ = center_bounding_box(source_shape[0])
179192
source_points -= source_center
180-
elif opt.corres_idx is None and target_shape.shape[1] == shape_v.shape[1]:
181-
logger.info("No correspondence provided, assuming registered Faust models")
182-
# corresp_idx = torch.randint(0, shape_f.shape[1], (100,)).cuda()
183-
corresp_v = torch.unique(torch.randint(0, shape_v.shape[1], (4800,))).cuda()
184-
target_points = torch.index_select(target_shape, 1, corresp_v)
185-
source_points = torch.index_select(shape_v, 1, corresp_v)
193+
194+
source_shape = torch.from_numpy(source_shape[None, :, :3]).cuda().float()
195+
source_faces = torch.from_numpy(source_faces[None, :, :3]).cuda().long()
196+
# select a subset of faces, otherwise optimization is too slow
197+
idx = torch.from_numpy(np.random.permutation(2048)).cuda().long()
198+
source_points = torch.gather(source_shape.unsqueeze(1).expand(-1, source_faces.shape[1], -1, -1), 2, source_faces[:, idx,:, None].expand(-1, -1, -1, 3))
199+
source_points = source_points.mean(dim=-2)
200+
target_points = torch.gather(target_shape.unsqueeze(1).expand(-1, source_faces.shape[1], -1, -1), 2, target_faces[:,idx,: None].expand(-1, -1, -1, 3))
201+
target_points = target_points.mean(dim=-2)
202+
203+
204+
target_points = target_points[:, idx]
205+
source_points = source_points[:, idx]
206+
186207

187208
target_shape[0], target_center, target_scale = center_bounding_box(target_shape[0])
188209
_, _, source_scale = center_bounding_box(shape_v[0])
189-
target_scale_factor = (source_scale/target_scale)[1]
210+
# scale according y axis (body height)
211+
target_scale_factor = (source_scale/target_scale)[0,1]
190212
target_shape *= target_scale_factor
191213
target_points -= target_center
192214
target_points = (target_points*target_scale_factor).detach()
@@ -227,22 +249,16 @@ def optimize(opt):
227249
target_points, cage_v, cage_f, verbose=False)
228250
loss_mvc = torch.mean((weights-weights_ref)**2)
229251
# reg = torch.sum((cage_init-cage_v)**2, dim=-1)*1e-4
230-
reg = 0
252+
reg = torch.tensor(0.0).cuda()
231253
if opt.clap_weight > 0:
232254
reg = lap_loss(cage_init, cage_v, face=cage_f)*opt.clap_weight
233255
reg = reg.mean()
234256
if opt.mvc_weight > 0:
235257
reg += mvc_reg_loss(weights)*opt.mvc_weight
236258

237-
# weight regularizer with the shape difference
238-
# dist = torch.sum((source_points - target_points)**2, dim=-1)
239-
# weights = torch.exp(-dist)
240-
# reg = reg*weights*0.1
241-
242259
loss = loss_mvc + reg
243260
if (t+1) % 50 == 0:
244-
print("t {}/{} mvc_loss: {} reg: {}".format(t,
245-
opt.nepochs, loss_mvc.item(), reg.item()))
261+
print("t {}/{} mvc_loss: {} reg: {}".format(t, opt.nepochs, loss_mvc.item(), reg.item()))
246262

247263
if loss_mvc.item() < 5e-6:
248264
break
@@ -315,6 +331,7 @@ def test_all(opt, new_cage_shape):
315331
net = networks.FixedSourceDeformer(opt, 3, opt.num_point, bottleneck_size=opt.bottleneck_size,
316332
template_vertices=states["template_vertices"], template_faces=states["template_faces"].cuda(),
317333
source_vertices=states["source_vertices"], source_faces=states["source_faces"]).cuda()
334+
print(net)
318335
load_network(net, states)
319336

320337
source_points = torch.from_numpy(
@@ -346,20 +363,8 @@ def test_all(opt, new_cage_shape):
346363
source_mesh_arr[:] = deformed[0].cpu().detach().numpy()
347364
om.write_mesh(os.path.join(
348365
opt.log_dir, opt.subdir, "template-{}-Sab.obj".format(t_filename)), source_mesh)
349-
# if data["target_face"] is not None and data["target_mesh"] is not None:
350-
# pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sa.ply".format(t_filename)),
351-
# source_mesh[0].detach().cpu(), source_face[b].detach().cpu())
352366
pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sb.ply".format(t_filename)),
353367
data["target_mesh"][b].detach().cpu(), data["target_face"][b].detach().cpu())
354-
# pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sab.ply".format(t_filename)),
355-
# deformed[b].detach().cpu(), source_face[b].detach().cpu())
356-
357-
# else:
358-
# save_ply(source_mesh[0].detach().cpu(), os.path.join(opt.log_dir, opt.subdir,"template-{}-Sa.ply".format(t_filename)))
359-
# save_ply(target_shape[b].detach().cpu(), os.path.join(opt.log_dir, opt.subdir,"template-{}-Sb.ply".format(t_filename)),
360-
# normals=data["target_normals"][b].detach().cpu())
361-
# save_ply(deformed[b].detach().cpu(), os.path.join(opt.log_dir, opt.subdir,"template-{}-Sab.ply".format(t_filename)),
362-
# normals=data["target_normals"][b].detach().cpu())
363368

364369
pymesh.save_mesh_raw(
365370
os.path.join(opt.log_dir, opt.subdir, "template-{}-cage1.ply".format(t_filename)),
@@ -369,70 +374,12 @@ def test_all(opt, new_cage_shape):
369374
os.path.join(opt.log_dir, opt.subdir, "template-{}-cage2.ply".format(t_filename)),
370375
outputs["new_cage"][b].detach().cpu(), outputs["cage_face"][b].detach().cpu(),
371376
)
372-
373-
# if opt.opt_lap and deformed.shape[1] == source_mesh.shape[1]:
374-
# deformed = optimize_lap(opt, source_mesh, deformed, source_face)
375-
# for b in range(deformed.shape[0]):
376-
# pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-{}-Sab-optlap.ply".format(t_filename)),
377-
# deformed[b].detach().cpu(), source_face[b].detach().cpu())
378-
379377
if i % 20 == 0:
380378
logger.success("[{}/{}] Done".format(i, len(dataloader)))
381379

382380
dataset.render_result(os.path.join(opt.log_dir, opt.subdir))
383381

384382

385-
def optimize_lap(opt, source_shape, deformed_shape, face):
386-
"""
387-
source_shape (B,N,3)
388-
deformed_shape (B,N,3)
389-
face (B,F,3)
390-
"""
391-
B = deformed_shape.shape[0]
392-
if opt.corres_idx is None:
393-
n_selected = int(source_shape.shape[1] * 0.6)
394-
corresp_v = torch.unique(torch.randint(
395-
0, source_shape.shape[1], (n_selected, ))).view(1, -1, 1).cuda()
396-
else:
397-
corresp_idx = torch.from_numpy(np.loadtxt(
398-
opt.corres_idx, delimiter=",", dtype=np.int64)).cuda()
399-
_, corresp_idx_2 = torch.unbind(corresp_idx, dim=1)
400-
corresp_v = torch.unique(torch.gather(
401-
face, 1, corresp_idx_2.view(1, -1, 1).expand(-1, -1, 3))).view(1, -1, 1)
402-
403-
fixed_points = torch.gather(
404-
deformed_shape, 1, corresp_v.expand(-1, -1, 3)).detach()
405-
406-
deformed_shape = deformed_shape.detach()
407-
deformed_shape.requires_grad_(True)
408-
source_shape = source_shape.detach()
409-
optimizer = torch.optim.Adam([deformed_shape], lr=0.0005, betas=(0.1, 0.1))
410-
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
411-
optimizer, factor=0.5, min_lr=1e-6, verbose=True, patience=25)
412-
lap_loss = MeshLaplacianLoss(torch.nn.MSELoss(reduction="none"), use_cot=True,
413-
use_norm=True, consistent_topology=True, precompute_L=True)
414-
415-
for t in range(2000):
416-
lap_loss_value = torch.mean(
417-
lap_loss(source_shape, deformed_shape, face=face).view(B, -1))
418-
fixed_points_new = torch.gather(
419-
deformed_shape, 1, corresp_v.expand(-1, -1, 3))
420-
reg_value = torch.mean(
421-
torch.sum((fixed_points - fixed_points_new)**2, dim=-1))
422-
loss = lap_loss_value + reg_value
423-
loss.backward()
424-
if (t+1) % 50 == 0:
425-
print("t {}/{} lap: {} reg: {}".format(t, 2000,
426-
lap_loss_value.item(), reg_value.item()))
427-
if loss < 1e-8:
428-
logger.success("Optimization converged!")
429-
break
430-
optimizer.step()
431-
scheduler.step(loss.item())
432-
433-
return deformed_shape
434-
435-
436383
if __name__ == "__main__":
437384
parser = MyOptions()
438385
opt = parser.parse()

0 commit comments

Comments
 (0)