Skip to content

Commit a0c9082

Browse files
committed
merge ba_helper into one
1 parent 273eacb commit a0c9082

File tree

5 files changed

+114
-130
lines changed

5 files changed

+114
-130
lines changed

ba_example.py

Lines changed: 99 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,110 @@
11
from time import perf_counter
2-
import torch
2+
33
import pypose as pp
4+
import torch
5+
import torch.nn as nn
46

5-
from ba_helpers import Reproj, least_square_error
6-
from datapipes.bal_loader import get_problem, read_bal_data
7-
from bae.sparse.py_ops import *
7+
from datapipes.bal_loader import get_problem
8+
from bae.autograd.function import TrackingTensor, map_transform
89
from bae.optim import LM
9-
from bae.utils.pysolvers import PCG, CuDSS
10+
from bae.utils.ba import rotate_quat
11+
from bae.utils.pysolvers import PCG
1012

11-
# TARGET_DATASET = "ladybug"
12-
# TARGET_PROBLEM = "problem-1723-156502-pre"
13-
# TARGET_PROBLEM = "problem-49-7776-pre"
14-
# TARGET_PROBLEM = "problem-1695-155710-pre"
15-
# TARGET_PROBLEM = "problem-969-105826-pre"
1613
TARGET_DATASET = "trafalgar"
1714
TARGET_PROBLEM = "problem-257-65132-pre"
15+
# other options:
16+
# TARGET_DATASET = "ladybug"
17+
# TARGET_PROBLEM = "problem-1723-156502-pre"
1818
# TARGET_DATASET = "dubrovnik"
1919
# TARGET_PROBLEM = "problem-356-226730-pre"
2020

21-
22-
23-
DEVICE = 'cuda'
21+
DEVICE = "cuda"
2422
OPTIMIZE_INTRINSICS = True
25-
26-
USE_QUATERNIONS = True
27-
28-
file_name = f'{TARGET_DATASET}.{TARGET_PROBLEM}'
29-
dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET, use_quat=USE_QUATERNIONS)
30-
31-
if OPTIMIZE_INTRINSICS:
32-
NUM_CAMERA_PARAMS = 10 if USE_QUATERNIONS else 9
33-
else:
34-
NUM_CAMERA_PARAMS = 7 if USE_QUATERNIONS else 6
35-
36-
print(f'Fetched {TARGET_PROBLEM} from {TARGET_DATASET}')
37-
38-
trimmed_dataset = dataset
39-
trimmed_dataset = {k: v.to(DEVICE) for k, v in trimmed_dataset.items() if type(v) == torch.Tensor}
40-
41-
input = {
42-
"points_2d": trimmed_dataset['points_2d'],
43-
"camera_indices": trimmed_dataset['camera_index_of_observations'],
44-
"point_indices": trimmed_dataset['point_index_of_observations']
45-
}
46-
47-
model = Reproj(
48-
trimmed_dataset['camera_params'][:, :NUM_CAMERA_PARAMS].clone(),
49-
trimmed_dataset['points_3d'].clone()
50-
).to(DEVICE)
51-
strategy = pp.optim.strategy.TrustRegion(up=2.0, down=0.5**4)
52-
solver = PCG(tol=1e-4, maxiter=250) # or CuDSS()
53-
optimizer = LM(model, strategy=strategy, solver=solver, reject=30)
54-
55-
print('Loss:', least_square_error(
56-
model.pose,
57-
model.points_3d,
58-
trimmed_dataset['camera_index_of_observations'],
59-
trimmed_dataset['point_index_of_observations'],
60-
trimmed_dataset['points_2d'],
61-
).item())
62-
63-
print("Initial loss", optimizer.model.loss(input, None).item())
64-
65-
start = perf_counter()
66-
for idx in range(20):
67-
loss = optimizer.step(input)
68-
print('Iteration', idx, 'loss', loss.item(), 'time', perf_counter() - start)
69-
70-
torch.cuda.synchronize()
71-
end = perf_counter()
72-
print('Time', end - start)
73-
74-
print('Ending loss:', least_square_error(
75-
model.pose,
76-
model.points_3d,
77-
trimmed_dataset['camera_index_of_observations'],
78-
trimmed_dataset['point_index_of_observations'],
79-
trimmed_dataset['points_2d'],
80-
).item())
23+
NUM_CAMERA_PARAMS = 10 if OPTIMIZE_INTRINSICS else 7
24+
25+
26+
@map_transform
27+
def project(points, camera_params):
28+
projection = rotate_quat(points, camera_params[..., :7])
29+
projection = -projection[..., :2] / projection[..., [2]]
30+
31+
f = camera_params[..., [-3]]
32+
k1 = camera_params[..., [-2]]
33+
k2 = camera_params[..., [-1]]
34+
35+
n = torch.sum(projection**2, axis=-1, keepdim=True)
36+
r = 1 + k1 * n + k2 * n**2
37+
return projection * r * f
38+
39+
40+
class Residual(nn.Module):
41+
def __init__(self, camera_params, points):
42+
super().__init__()
43+
self.pose = nn.Parameter(TrackingTensor(camera_params))
44+
self.points = nn.Parameter(TrackingTensor(points))
45+
self.pose.trim_SE3_grad = True
46+
47+
def forward(self, observes, cidx, pidx):
48+
points_proj = project(self.points[pidx], self.pose[cidx])
49+
return points_proj - observes
50+
51+
52+
def least_square_error(camera_params, points, cidx, pidx, observes):
53+
model = Residual(camera_params, points)
54+
loss = model(observes, cidx, pidx)
55+
return torch.sum(loss**2, dim=-1).mean()
56+
57+
58+
def main():
59+
dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET)
60+
print(f"Fetched {TARGET_PROBLEM} from {TARGET_DATASET}")
61+
62+
dataset = {
63+
key: value.to(DEVICE)
64+
for key, value in dataset.items()
65+
if isinstance(value, torch.Tensor)
66+
}
67+
input = {
68+
"observes": dataset["points_2d"],
69+
"cidx": dataset["camera_index_of_observations"],
70+
"pidx": dataset["point_index_of_observations"],
71+
}
72+
73+
model = Residual(
74+
dataset["camera_params"][:, :NUM_CAMERA_PARAMS].clone(),
75+
dataset["points_3d"].clone(),
76+
).to(DEVICE)
77+
strategy = pp.optim.strategy.TrustRegion(up=2.0, down=0.5**4)
78+
solver = PCG(tol=1e-4, maxiter=250)
79+
optimizer = LM(model, strategy=strategy, solver=solver, reject=30)
80+
81+
print('Loss:', least_square_error(
82+
model.pose,
83+
model.points,
84+
dataset["camera_index_of_observations"],
85+
dataset["point_index_of_observations"],
86+
dataset["points_2d"],
87+
).item())
88+
89+
print("Initial loss", optimizer.model.loss(input, None).item())
90+
91+
start = perf_counter()
92+
for idx in range(20):
93+
loss = optimizer.step(input)
94+
print("Iteration", idx, "loss", loss.item(), "time", perf_counter() - start)
95+
96+
torch.cuda.synchronize()
97+
end = perf_counter()
98+
print("Time", end - start)
99+
100+
print('Ending loss:', least_square_error(
101+
model.pose,
102+
model.points,
103+
dataset["camera_index_of_observations"],
104+
dataset["point_index_of_observations"],
105+
dataset["points_2d"],
106+
).item())
107+
108+
109+
if __name__ == "__main__":
110+
main()

ba_helpers.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

datapipes/bal_io.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def _rotvec_to_quat_xyzw(rotvec: torch.Tensor) -> torch.Tensor:
1717
return torch.cat([xyz, cos_half], dim=-1)
1818

1919

20-
def read_bal_data(file_name: str, use_quat: bool = False) -> dict:
20+
def read_bal_data(file_name: str, use_quat: bool = True) -> dict:
2121
"""
2222
Read a Bundle Adjustment in the Large dataset problem text file.
2323
@@ -29,8 +29,8 @@ def read_bal_data(file_name: str, use_quat: bool = False) -> dict:
2929
3030
Each camera has 9 parameters: Rodrigues rotvec (3), translation (3), f, k1, k2.
3131
This loader outputs either:
32-
- use_quat=False: [tx, ty, tz, rx, ry, rz, f, k1, k2] (9)
3332
- use_quat=True: [tx, ty, tz, qx, qy, qz, qw, f, k1, k2] (10)
33+
- use_quat=False: [tx, ty, tz, rx, ry, rz, f, k1, k2] (9)
3434
"""
3535
with open(file_name, "r") as file:
3636
n_cameras, n_points, n_observations = map(int, file.readline().split())
@@ -70,4 +70,3 @@ def read_bal_data(file_name: str, use_quat: bool = False) -> dict:
7070
"camera_index_of_observations": camera_indices,
7171
"point_index_of_observations": point_indices,
7272
}
73-

datapipes/bal_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def _ensure_problem_available(dataset: str, problem_name: str, cache_dir: Path)
8383
os.replace(tmp_path, txt_path)
8484
return txt_path
8585

86-
def get_problem(problem_name, dataset, cache_dir='bal_data', use_quat=False):
86+
def get_problem(problem_name, dataset, cache_dir='bal_data', use_quat=True):
8787
cache_path = Path(cache_dir)
8888
print(f"Preparing data for {dataset}...")
8989
_validate_dataset(dataset)

tests/autograd/test_bal_jacobian.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
if str(_REPO_ROOT) not in sys.path:
2020
sys.path.insert(0, str(_REPO_ROOT))
2121

22-
from ba_helpers import Reproj, project, least_square_error # noqa: E402
22+
from ba_example import Residual, project, least_square_error # noqa: E402
2323
from bae.autograd.function import TrackingTensor, map_transform
2424
import bae.autograd.graph as autograd_graph # noqa: E402
2525
from bae.optim import LM # noqa: E402
@@ -234,11 +234,11 @@ def _final_bal_per_pixel_error(
234234
point_idx: torch.Tensor,
235235
) -> float:
236236
input = {
237-
"points_2d": points_2d,
238-
"camera_indices": camera_idx,
239-
"point_indices": point_idx,
237+
"observes": points_2d,
238+
"cidx": camera_idx,
239+
"pidx": point_idx,
240240
}
241-
model = Reproj(camera_params.clone(), points_3d.clone())
241+
model = Residual(camera_params.clone(), points_3d.clone())
242242
strategy = pp.optim.strategy.TrustRegion(up=2.0, down=0.5**4)
243243
solver = PCG(tol=1e-4, maxiter=250)
244244
optimizer = LM(model, strategy=strategy, solver=solver, reject=30)
@@ -248,7 +248,7 @@ def _final_bal_per_pixel_error(
248248

249249
return least_square_error(
250250
model.pose,
251-
model.points_3d,
251+
model.points,
252252
camera_idx,
253253
point_idx,
254254
points_2d,
@@ -284,16 +284,16 @@ def test_bal_jacobian_structure_no_empty_columns(
284284
camera_idx = camera_idx.to(device=device)
285285
point_idx = point_idx.to(device=device)
286286

287-
model = Reproj(camera_params.clone(), points_3d.clone()).to(device)
287+
model = Residual(camera_params.clone(), points_3d.clone()).to(device)
288288
residual = model(points_2d, camera_idx, point_idx)
289289
n_obs = int(points_2d.shape[0])
290290

291-
J_cam, J_pts = autograd_graph.jacobian(residual, [model.pose, model.points_3d])
291+
J_cam, J_pts = autograd_graph.jacobian(residual, [model.pose, model.points])
292292
assert J_cam.layout == torch.sparse_bsr
293293
assert J_pts.layout == torch.sparse_bsr
294294

295295
n_cams = model.pose.shape[0]
296-
n_pts = model.points_3d.shape[0]
296+
n_pts = model.points.shape[0]
297297

298298
assert J_cam.shape == (n_obs * 2, n_cams * 9)
299299
assert J_pts.shape == (n_obs * 2, n_pts * 3)
@@ -414,7 +414,7 @@ def _final_bal_per_pixel_error_fixed_first_camera_cat(
414414
camera_all = torch.cat([camera_se3_all, model.intrinsics.tensor()], dim=-1)
415415
return least_square_error(
416416
camera_all,
417-
model.points_3d.tensor(),
417+
model.points.tensor(),
418418
camera_idx,
419419
point_idx,
420420
points_2d,
@@ -603,9 +603,9 @@ def test_bal_jacobian_structure_assert_failed_when_missing_observation_appearanc
603603
n_pts=n_pts,
604604
)
605605

606-
model = Reproj(camera_params.clone(), points_3d.clone()).to(device)
606+
model = Residual(camera_params.clone(), points_3d.clone()).to(device)
607607
residual = model(points_2d, camera_idx2, point_idx2)
608-
J_cam, J_pts = autograd_graph.jacobian(residual, [model.pose, model.points_3d])
608+
J_cam, J_pts = autograd_graph.jacobian(residual, [model.pose, model.points])
609609

610610
with pytest.raises(AssertionError):
611611
_assert_bal_correctness_criteria(

0 commit comments

Comments
 (0)