33# This source code is under the Apache 2.0 license found in the LICENSE file.
44# ----------------------------------------------------------------------------
55
6+ import copy
7+ import functools
8+ import inspect
69from pathlib import Path
710
811import symforce .symbolic as sf
912from symforce import ops
1013from symforce import typing as T
1114from symforce .codegen import Codegen
1215from symforce .codegen import CppConfig
16+ from symforce .typing_util import get_type
1317
1418TYPES = (sf .Rot2 , sf .Rot3 , sf .V3 , sf .Pose2 , sf .Pose3 )
1519
@@ -28,7 +32,8 @@ def get_between_factor_docstring(between_argument_name: str) -> str:
2832 Args:
2933 sqrt_info: Square root information matrix to whiten residual. This can be computed from
3034 a covariance matrix as the cholesky decomposition of the inverse. In the case
31- of a diagonal it will contain 1/sigma values. Must match the tangent dim.
35+ of a diagonal it will contain 1/sigma values. Pass in a single scalar if
36+ isotropic, a vector if diagonal, and a square matrix otherwise.
3237 """ .format (
3338 a_T_b = between_argument_name
3439 )
@@ -47,7 +52,8 @@ def get_prior_docstring() -> str:
4752 Args:
4853 sqrt_info: Square root information matrix to whiten residual. This can be computed from
4954 a covariance matrix as the cholesky decomposition of the inverse. In the case
50- of a diagonal it will contain 1/sigma values. Must match the tangent dim.
55+ of a diagonal it will contain 1/sigma values. Pass in a single scalar if
56+ isotropic, a vector if diagonal, and a square matrix otherwise.
5157 """
5258
5359
@@ -83,29 +89,194 @@ def prior_factor(
8389 return residual
8490
8591
92+ def modify_argument (
93+ core_func : T .Callable , arg_to_modify : str , new_arg_type : T .Type , modification : T .Callable
94+ ) -> T .Callable :
95+ """
96+ Returns a wrapper which applies modification to the arg_to_modify parameter before forwarding
97+ it and any other arguments on to core_func. Also, sets the type annotation of arg_to_modify
98+ to new_arg_type in the returned function.
99+
100+ If arg_to_modify is not a parameter of core_func, then a Value Error will be raised.
101+ """
102+ try :
103+ arg_index = list (inspect .signature (core_func ).parameters ).index (arg_to_modify )
104+ except ValueError as error :
105+ raise ValueError (f"{ arg_to_modify } is not an argument of { core_func } " ) from error
106+
107+ @functools .wraps (core_func )
108+ def wrapper (* args : T .Any , ** kwargs : T .Any ) -> T .Any :
109+ args_list = list (args )
110+ if arg_index < len (args ):
111+ # Then arg_to_modify was passed in args
112+ args_list [arg_index ] = modification (args [arg_index ])
113+ else :
114+ # arg_to_modify should have been passed in kwargs
115+ try :
116+ kwargs [arg_to_modify ] = modification (kwargs [arg_to_modify ])
117+ except KeyError as error :
118+ raise TypeError (f"{ wrapper } missing required argument { arg_to_modify } " ) from error
119+
120+ return core_func (* args_list , ** kwargs )
121+
122+ wrapper .__annotations__ = dict (wrapper .__annotations__ , ** {arg_to_modify : new_arg_type })
123+
124+ return wrapper
125+
126+
127+ def is_not_fixed_size_square_matrix (type_t : T .Type ) -> bool :
128+ return (
129+ not issubclass (type_t , sf .Matrix )
130+ or type_t == sf .Matrix
131+ or type_t .SHAPE [0 ] != type_t .SHAPE [1 ]
132+ )
133+
134+
135+ def _get_sqrt_info_dim (func : T .Callable ) -> int :
136+ """
137+ Raises ValueError if func does not have a parameter named sqrt_info annotated as a fixed
138+ sized square symbolic matrix.
139+ Returns the matrix dimension of sqrt_info type annotation of func.
140+ """
141+ if "sqrt_info" not in func .__annotations__ :
142+ raise ValueError (
143+ "sqrt_info missing annotation. Either add one or explicitly pass in expected number of dimensions"
144+ )
145+ sqrt_info_type = func .__annotations__ ["sqrt_info" ]
146+ if is_not_fixed_size_square_matrix (sqrt_info_type ):
147+ raise ValueError (
148+ f"""Expected sqrt_info to be annotated as a fixed size square matrix. Instead
149+ found { sqrt_info_type } . Either fix annotation or explicitly pass in expected number
150+ of dimensions of sqrt_info."""
151+ )
152+ return sqrt_info_type .SHAPE [0 ]
153+
154+
155+ def isotropic_sqrt_info_wrapper (func : T .Callable ) -> T .Callable :
156+ """
157+ Wraps func, except instead of taking a square matrix for the sqrt_info argument, it takes a
158+ scalar and passes that scalar times the identity matrix in for the value of sqrt_info.
159+
160+ Raises ValueError if func does not have an argument named sqrt_info annotated as a fixed size
161+ symbolic square matrix.
162+ """
163+ sqrt_info_dim = _get_sqrt_info_dim (func )
164+
165+ return modify_argument (
166+ func ,
167+ arg_to_modify = "sqrt_info" ,
168+ new_arg_type = T .Scalar ,
169+ modification = lambda sqrt_info : sqrt_info * sf .M .eye (sqrt_info_dim , sqrt_info_dim ),
170+ )
171+
172+
173+ def diagonal_sqrt_info_wrapper (func : T .Callable ) -> T .Callable :
174+ """
175+ Wraps func, except instead of taking a square matrix for the sqrt_info argument, it takes a
176+ vector representing the diagonal and passes the corresponding diagonal matrix in for the
177+ value of sqrt_info.
178+
179+ Raises ValueError if func does not have an argument named sqrt_info annotated as a fixed size
180+ symbolic square matrix.
181+ """
182+ sqrt_info_dim = _get_sqrt_info_dim (func )
183+
184+ return modify_argument (
185+ func ,
186+ arg_to_modify = "sqrt_info" ,
187+ new_arg_type = type (sf .M (sqrt_info_dim , 1 )),
188+ modification = sf .M .diag ,
189+ )
190+
191+
192+ def override_annotations (func : T .Callable , input_types : T .Sequence [T .ElementOrType ]) -> T .Callable :
193+ """
194+ Returns copy of func which is the same except its parameters are annotated with input_types.
195+
196+ Raises a ValueError if the length of input_types does not match parameter count of func.
197+ """
198+ parameters = inspect .signature (func ).parameters
199+ if len (parameters ) != len (input_types ):
200+ raise ValueError (
201+ f"{ func } has { len (parameters )} inputs, but input_types has length { len (input_types )} "
202+ )
203+ new_func = copy .copy (func )
204+ new_func .__annotations__ = dict (
205+ func .__annotations__ , ** {param : get_type (tp ) for param , tp in zip (parameters , input_types )}
206+ )
207+ return new_func
208+
209+
210+ def generate_with_alternate_sqrt_infos (
211+ output_dir : T .Openable ,
212+ func : T .Callable ,
213+ name : str ,
214+ which_args : T .Sequence [str ],
215+ input_types : T .Sequence [T .ElementOrType ] = None ,
216+ output_names : T .Sequence [str ] = None ,
217+ docstring : str = None ,
218+ ) -> None :
219+ """
220+ Generates func with linearization into output_dir / name, along with two overloads of func:
221+ one which takes a single scalar representing an isotropic matrix for the sqrt_info argument,
222+ and another which instead takes a vector representing a diagonal matrix for the same argument.
223+
224+ A common header located at output_dir / name.h will re-export each of these overloads.
225+
226+ As usual, if func does not have concrete type annotations, then input_types must be passed in
227+ (the same as you would if calling Codegen.function direction on func).
228+ """
229+
230+ common_header = Path (output_dir , name + ".h" )
231+ common_header .parent .mkdir (exist_ok = True , parents = True )
232+
233+ annotated_func = func if input_types is None else override_annotations (func , input_types )
234+
235+ for func_variant , variant_name in [
236+ (diagonal_sqrt_info_wrapper (annotated_func ), f"{ name } _diagonal" ),
237+ (isotropic_sqrt_info_wrapper (annotated_func ), f"{ name } _isotropic" ),
238+ (annotated_func , f"{ name } _square" ),
239+ ]:
240+ Codegen .function (
241+ func = func_variant ,
242+ output_names = output_names ,
243+ config = CppConfig (),
244+ docstring = docstring ,
245+ ).with_linearization (name = name , which_args = which_args ).generate_function (
246+ Path (output_dir , name ),
247+ skip_directory_nesting = True ,
248+ generated_file_name = variant_name + ".h" ,
249+ )
250+
251+ with common_header .open ("a" ) as f :
252+ f .write (f'#include "./{ name } /{ variant_name } .h"\n ' )
253+
254+
86255def generate_between_factors (types : T .Sequence [T .Type ], output_dir : T .Openable ) -> None :
87256 """
88257 Generates between factors for each type in types into output_dir.
89258 """
90259 for cls in types :
91260 tangent_dim = ops .LieGroupOps .tangent_dim (cls )
92- between_codegen = Codegen .function (
261+ generate_with_alternate_sqrt_infos (
262+ output_dir ,
93263 func = between_factor ,
264+ name = f"between_factor_{ cls .__name__ .lower ()} " ,
265+ which_args = ["a" , "b" ],
94266 input_types = [cls , cls , cls , sf .M (tangent_dim , tangent_dim ), sf .Symbol ],
95267 output_names = ["res" ],
96- config = CppConfig (),
97268 docstring = get_between_factor_docstring ("a_T_b" ),
98- ).with_linearization (name = f"between_factor_{ cls .__name__ .lower ()} " , which_args = ["a" , "b" ])
99- between_codegen .generate_function (output_dir , skip_directory_nesting = True )
269+ )
100270
101- prior_codegen = Codegen .function (
271+ generate_with_alternate_sqrt_infos (
272+ output_dir ,
102273 func = prior_factor ,
274+ name = f"prior_factor_{ cls .__name__ .lower ()} " ,
275+ which_args = ["value" ],
103276 input_types = [cls , cls , sf .M (tangent_dim , tangent_dim ), sf .Symbol ],
104277 output_names = ["res" ],
105- config = CppConfig (),
106278 docstring = get_prior_docstring (),
107- ).with_linearization (name = f"prior_factor_{ cls .__name__ .lower ()} " , which_args = ["value" ])
108- prior_codegen .generate_function (output_dir , skip_directory_nesting = True )
279+ )
109280
110281
111282def generate_pose3_extra_factors (output_dir : T .Openable ) -> None :
@@ -173,42 +344,46 @@ def prior_factor_pose3_position(
173344 ) -> sf .Matrix :
174345 return prior_factor (value .t , prior , sqrt_info , epsilon )
175346
176- between_rotation_codegen = Codegen .function (
347+ generate_with_alternate_sqrt_infos (
348+ output_dir ,
177349 func = between_factor_pose3_rotation ,
350+ name = "between_factor_pose3_rotation" ,
351+ which_args = ["a" , "b" ],
178352 output_names = ["res" ],
179- config = CppConfig (),
180353 docstring = get_between_factor_docstring ("a_R_b" ),
181- ).with_linearization (name = "between_factor_pose3_rotation" , which_args = ["a" , "b" ])
182- between_rotation_codegen .generate_function (output_dir , skip_directory_nesting = True )
354+ )
183355
184- between_position_codegen = Codegen .function (
356+ generate_with_alternate_sqrt_infos (
357+ output_dir ,
185358 func = between_factor_pose3_position ,
359+ name = "between_factor_pose3_position" ,
360+ which_args = ["a" , "b" ],
186361 output_names = ["res" ],
187- config = CppConfig (),
188362 docstring = get_between_factor_docstring ("a_t_b" ),
189- ).with_linearization (name = "between_factor_pose3_position" , which_args = ["a" , "b" ])
190- between_position_codegen .generate_function (output_dir , skip_directory_nesting = True )
363+ )
191364
192365 between_translation_norm_codegen = Codegen .function (
193366 func = between_factor_pose3_translation_norm , output_names = ["res" ], config = CppConfig ()
194367 ).with_linearization (name = "between_factor_pose3_translation_norm" , which_args = ["a" , "b" ])
195368 between_translation_norm_codegen .generate_function (output_dir , skip_directory_nesting = True )
196369
197- prior_rotation_codegen = Codegen .function (
370+ generate_with_alternate_sqrt_infos (
371+ output_dir ,
198372 func = prior_factor_pose3_rotation ,
373+ name = "prior_factor_pose3_rotation" ,
199374 output_names = ["res" ],
200- config = CppConfig () ,
375+ which_args = [ "value" ] ,
201376 docstring = get_prior_docstring (),
202- ).with_linearization (name = "prior_factor_pose3_rotation" , which_args = ["value" ])
203- prior_rotation_codegen .generate_function (output_dir , skip_directory_nesting = True )
377+ )
204378
205- prior_position_codegen = Codegen .function (
379+ generate_with_alternate_sqrt_infos (
380+ output_dir ,
206381 func = prior_factor_pose3_position ,
382+ name = "prior_factor_pose3_position" ,
207383 output_names = ["res" ],
208- config = CppConfig () ,
384+ which_args = [ "value" ] ,
209385 docstring = get_prior_docstring (),
210- ).with_linearization (name = "prior_factor_pose3_position" , which_args = ["value" ])
211- prior_position_codegen .generate_function (output_dir , skip_directory_nesting = True )
386+ )
212387
213388
214389def generate (output_dir : Path ) -> None :
0 commit comments