Skip to content

Commit 344fa98

Browse files
Add diagonal & isotropic sqrt_infos in geo factors
Often times the square root information matrix is a diagonal or isotropic matrix. However, in our generated geo factors we always assume they are full dense matrices leading to many more operations than needed. This commit modifies `geo_factors_codegen.py` to create 3 variants for each geo factor, one which assumes `sqrt_info` is a square matrix (the existing behavior), one which assumes it is a diagonal matrix, and one which assumes it is an isotropic matrix. These variants are all generated with the same name & are distinguished only by their signatures (thus this only works for C++, which is all we currently generate factors for anyway). The square version takes a square eigen matrix, the diagonal version a eigen vector whose entries are those of the diagonal of `sqrt_info`, and a single scalar for the isotropic version. The header that the user imports remains the same, except now rather than containing the implementation directly, it re-imports the headers for the 3 different versions (which are stores in sub-directories). Topic: sqrt_info_variants Relative: geo_factors_use_skip_directory_nesting
1 parent f9a6f2d commit 344fa98

File tree

1 file changed

+201
-26
lines changed

1 file changed

+201
-26
lines changed

symforce/codegen/geo_factors_codegen.py

Lines changed: 201 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
# This source code is under the Apache 2.0 license found in the LICENSE file.
44
# ----------------------------------------------------------------------------
55

6+
import copy
7+
import functools
8+
import inspect
69
from pathlib import Path
710

811
import symforce.symbolic as sf
912
from symforce import ops
1013
from symforce import typing as T
1114
from symforce.codegen import Codegen
1215
from symforce.codegen import CppConfig
16+
from symforce.typing_util import get_type
1317

1418
TYPES = (sf.Rot2, sf.Rot3, sf.V3, sf.Pose2, sf.Pose3)
1519

@@ -28,7 +32,8 @@ def get_between_factor_docstring(between_argument_name: str) -> str:
2832
Args:
2933
sqrt_info: Square root information matrix to whiten residual. This can be computed from
3034
a covariance matrix as the cholesky decomposition of the inverse. In the case
31-
of a diagonal it will contain 1/sigma values. Must match the tangent dim.
35+
of a diagonal it will contain 1/sigma values. Pass in a single scalar if
36+
isotropic, a vector if diagonal, and a square matrix otherwise.
3237
""".format(
3338
a_T_b=between_argument_name
3439
)
@@ -47,7 +52,8 @@ def get_prior_docstring() -> str:
4752
Args:
4853
sqrt_info: Square root information matrix to whiten residual. This can be computed from
4954
a covariance matrix as the cholesky decomposition of the inverse. In the case
50-
of a diagonal it will contain 1/sigma values. Must match the tangent dim.
55+
of a diagonal it will contain 1/sigma values. Pass in a single scalar if
56+
isotropic, a vector if diagonal, and a square matrix otherwise.
5157
"""
5258

5359

@@ -83,29 +89,194 @@ def prior_factor(
8389
return residual
8490

8591

92+
def modify_argument(
93+
core_func: T.Callable, arg_to_modify: str, new_arg_type: T.Type, modification: T.Callable
94+
) -> T.Callable:
95+
"""
96+
Returns a wrapper which applies modification to the arg_to_modify parameter before forwarding
97+
it and any other arguments on to core_func. Also, sets the type annotation of arg_to_modify
98+
to new_arg_type in the returned function.
99+
100+
If arg_to_modify is not a parameter of core_func, then a Value Error will be raised.
101+
"""
102+
try:
103+
arg_index = list(inspect.signature(core_func).parameters).index(arg_to_modify)
104+
except ValueError as error:
105+
raise ValueError(f"{arg_to_modify} is not an argument of {core_func}") from error
106+
107+
@functools.wraps(core_func)
108+
def wrapper(*args: T.Any, **kwargs: T.Any) -> T.Any:
109+
args_list = list(args)
110+
if arg_index < len(args):
111+
# Then arg_to_modify was passed in args
112+
args_list[arg_index] = modification(args[arg_index])
113+
else:
114+
# arg_to_modify should have been passed in kwargs
115+
try:
116+
kwargs[arg_to_modify] = modification(kwargs[arg_to_modify])
117+
except KeyError as error:
118+
raise TypeError(f"{wrapper} missing required argument {arg_to_modify}") from error
119+
120+
return core_func(*args_list, **kwargs)
121+
122+
wrapper.__annotations__ = dict(wrapper.__annotations__, **{arg_to_modify: new_arg_type})
123+
124+
return wrapper
125+
126+
127+
def is_not_fixed_size_square_matrix(type_t: T.Type) -> bool:
128+
return (
129+
not issubclass(type_t, sf.Matrix)
130+
or type_t == sf.Matrix
131+
or type_t.SHAPE[0] != type_t.SHAPE[1]
132+
)
133+
134+
135+
def _get_sqrt_info_dim(func: T.Callable) -> int:
136+
"""
137+
Raises ValueError if func does not have a parameter named sqrt_info annotated as a fixed
138+
sized square symbolic matrix.
139+
Returns the matrix dimension of sqrt_info type annotation of func.
140+
"""
141+
if "sqrt_info" not in func.__annotations__:
142+
raise ValueError(
143+
"sqrt_info missing annotation. Either add one or explicitly pass in expected number of dimensions"
144+
)
145+
sqrt_info_type = func.__annotations__["sqrt_info"]
146+
if is_not_fixed_size_square_matrix(sqrt_info_type):
147+
raise ValueError(
148+
f"""Expected sqrt_info to be annotated as a fixed size square matrix. Instead
149+
found {sqrt_info_type}. Either fix annotation or explicitly pass in expected number
150+
of dimensions of sqrt_info."""
151+
)
152+
return sqrt_info_type.SHAPE[0]
153+
154+
155+
def isotropic_sqrt_info_wrapper(func: T.Callable) -> T.Callable:
156+
"""
157+
Wraps func, except instead of taking a square matrix for the sqrt_info argument, it takes a
158+
scalar and passes that scalar times the identity matrix in for the value of sqrt_info.
159+
160+
Raises ValueError if func does not have an argument named sqrt_info annotated as a fixed size
161+
symbolic square matrix.
162+
"""
163+
sqrt_info_dim = _get_sqrt_info_dim(func)
164+
165+
return modify_argument(
166+
func,
167+
arg_to_modify="sqrt_info",
168+
new_arg_type=T.Scalar,
169+
modification=lambda sqrt_info: sqrt_info * sf.M.eye(sqrt_info_dim, sqrt_info_dim),
170+
)
171+
172+
173+
def diagonal_sqrt_info_wrapper(func: T.Callable) -> T.Callable:
174+
"""
175+
Wraps func, except instead of taking a square matrix for the sqrt_info argument, it takes a
176+
vector representing the diagonal and passes the corresponding diagonal matrix in for the
177+
value of sqrt_info.
178+
179+
Raises ValueError if func does not have an argument named sqrt_info annotated as a fixed size
180+
symbolic square matrix.
181+
"""
182+
sqrt_info_dim = _get_sqrt_info_dim(func)
183+
184+
return modify_argument(
185+
func,
186+
arg_to_modify="sqrt_info",
187+
new_arg_type=type(sf.M(sqrt_info_dim, 1)),
188+
modification=sf.M.diag,
189+
)
190+
191+
192+
def override_annotations(func: T.Callable, input_types: T.Sequence[T.ElementOrType]) -> T.Callable:
193+
"""
194+
Returns copy of func which is the same except its parameters are annotated with input_types.
195+
196+
Raises a ValueError if the length of input_types does not match parameter count of func.
197+
"""
198+
parameters = inspect.signature(func).parameters
199+
if len(parameters) != len(input_types):
200+
raise ValueError(
201+
f"{func} has {len(parameters)} inputs, but input_types has length {len(input_types)}"
202+
)
203+
new_func = copy.copy(func)
204+
new_func.__annotations__ = dict(
205+
func.__annotations__, **{param: get_type(tp) for param, tp in zip(parameters, input_types)}
206+
)
207+
return new_func
208+
209+
210+
def generate_with_alternate_sqrt_infos(
211+
output_dir: T.Openable,
212+
func: T.Callable,
213+
name: str,
214+
which_args: T.Sequence[str],
215+
input_types: T.Sequence[T.ElementOrType] = None,
216+
output_names: T.Sequence[str] = None,
217+
docstring: str = None,
218+
) -> None:
219+
"""
220+
Generates func with linearization into output_dir / name, along with two overloads of func:
221+
one which takes a single scalar representing an isotropic matrix for the sqrt_info argument,
222+
and another which instead takes a vector representing a diagonal matrix for the same argument.
223+
224+
A common header located at output_dir / name.h will re-export each of these overloads.
225+
226+
As usual, if func does not have concrete type annotations, then input_types must be passed in
227+
(the same as you would if calling Codegen.function direction on func).
228+
"""
229+
230+
common_header = Path(output_dir, name + ".h")
231+
common_header.parent.mkdir(exist_ok=True, parents=True)
232+
233+
annotated_func = func if input_types is None else override_annotations(func, input_types)
234+
235+
for func_variant, variant_name in [
236+
(diagonal_sqrt_info_wrapper(annotated_func), f"{name}_diagonal"),
237+
(isotropic_sqrt_info_wrapper(annotated_func), f"{name}_isotropic"),
238+
(annotated_func, f"{name}_square"),
239+
]:
240+
Codegen.function(
241+
func=func_variant,
242+
output_names=output_names,
243+
config=CppConfig(),
244+
docstring=docstring,
245+
).with_linearization(name=name, which_args=which_args).generate_function(
246+
Path(output_dir, name),
247+
skip_directory_nesting=True,
248+
generated_file_name=variant_name + ".h",
249+
)
250+
251+
with common_header.open("a") as f:
252+
f.write(f'#include "./{name}/{variant_name}.h"\n')
253+
254+
86255
def generate_between_factors(types: T.Sequence[T.Type], output_dir: T.Openable) -> None:
87256
"""
88257
Generates between factors for each type in types into output_dir.
89258
"""
90259
for cls in types:
91260
tangent_dim = ops.LieGroupOps.tangent_dim(cls)
92-
between_codegen = Codegen.function(
261+
generate_with_alternate_sqrt_infos(
262+
output_dir,
93263
func=between_factor,
264+
name=f"between_factor_{cls.__name__.lower()}",
265+
which_args=["a", "b"],
94266
input_types=[cls, cls, cls, sf.M(tangent_dim, tangent_dim), sf.Symbol],
95267
output_names=["res"],
96-
config=CppConfig(),
97268
docstring=get_between_factor_docstring("a_T_b"),
98-
).with_linearization(name=f"between_factor_{cls.__name__.lower()}", which_args=["a", "b"])
99-
between_codegen.generate_function(output_dir, skip_directory_nesting=True)
269+
)
100270

101-
prior_codegen = Codegen.function(
271+
generate_with_alternate_sqrt_infos(
272+
output_dir,
102273
func=prior_factor,
274+
name=f"prior_factor_{cls.__name__.lower()}",
275+
which_args=["value"],
103276
input_types=[cls, cls, sf.M(tangent_dim, tangent_dim), sf.Symbol],
104277
output_names=["res"],
105-
config=CppConfig(),
106278
docstring=get_prior_docstring(),
107-
).with_linearization(name=f"prior_factor_{cls.__name__.lower()}", which_args=["value"])
108-
prior_codegen.generate_function(output_dir, skip_directory_nesting=True)
279+
)
109280

110281

111282
def generate_pose3_extra_factors(output_dir: T.Openable) -> None:
@@ -173,42 +344,46 @@ def prior_factor_pose3_position(
173344
) -> sf.Matrix:
174345
return prior_factor(value.t, prior, sqrt_info, epsilon)
175346

176-
between_rotation_codegen = Codegen.function(
347+
generate_with_alternate_sqrt_infos(
348+
output_dir,
177349
func=between_factor_pose3_rotation,
350+
name="between_factor_pose3_rotation",
351+
which_args=["a", "b"],
178352
output_names=["res"],
179-
config=CppConfig(),
180353
docstring=get_between_factor_docstring("a_R_b"),
181-
).with_linearization(name="between_factor_pose3_rotation", which_args=["a", "b"])
182-
between_rotation_codegen.generate_function(output_dir, skip_directory_nesting=True)
354+
)
183355

184-
between_position_codegen = Codegen.function(
356+
generate_with_alternate_sqrt_infos(
357+
output_dir,
185358
func=between_factor_pose3_position,
359+
name="between_factor_pose3_position",
360+
which_args=["a", "b"],
186361
output_names=["res"],
187-
config=CppConfig(),
188362
docstring=get_between_factor_docstring("a_t_b"),
189-
).with_linearization(name="between_factor_pose3_position", which_args=["a", "b"])
190-
between_position_codegen.generate_function(output_dir, skip_directory_nesting=True)
363+
)
191364

192365
between_translation_norm_codegen = Codegen.function(
193366
func=between_factor_pose3_translation_norm, output_names=["res"], config=CppConfig()
194367
).with_linearization(name="between_factor_pose3_translation_norm", which_args=["a", "b"])
195368
between_translation_norm_codegen.generate_function(output_dir, skip_directory_nesting=True)
196369

197-
prior_rotation_codegen = Codegen.function(
370+
generate_with_alternate_sqrt_infos(
371+
output_dir,
198372
func=prior_factor_pose3_rotation,
373+
name="prior_factor_pose3_rotation",
199374
output_names=["res"],
200-
config=CppConfig(),
375+
which_args=["value"],
201376
docstring=get_prior_docstring(),
202-
).with_linearization(name="prior_factor_pose3_rotation", which_args=["value"])
203-
prior_rotation_codegen.generate_function(output_dir, skip_directory_nesting=True)
377+
)
204378

205-
prior_position_codegen = Codegen.function(
379+
generate_with_alternate_sqrt_infos(
380+
output_dir,
206381
func=prior_factor_pose3_position,
382+
name="prior_factor_pose3_position",
207383
output_names=["res"],
208-
config=CppConfig(),
384+
which_args=["value"],
209385
docstring=get_prior_docstring(),
210-
).with_linearization(name="prior_factor_pose3_position", which_args=["value"])
211-
prior_position_codegen.generate_function(output_dir, skip_directory_nesting=True)
386+
)
212387

213388

214389
def generate(output_dir: Path) -> None:

0 commit comments

Comments
 (0)