77
88import symforce .symbolic as sf
99from symforce import ops
10- from symforce import python_util
1110from symforce import typing as T
1211from symforce .codegen import Codegen
1312from symforce .codegen import CppConfig
@@ -84,38 +83,10 @@ def prior_factor(
8483 return residual
8584
8685
87- def get_function_code ( codegen : Codegen , cleanup : bool = True ) -> str :
86+ def generate_between_factors ( types : T . Sequence [ T . Type ], output_dir : T . Openable ) -> None :
8887 """
89- Return just the function code from a Codegen object .
88+ Generates between factors for each type in types into output_dir .
9089 """
91- # Codegen
92- data = codegen .generate_function ()
93-
94- # Read
95- assert codegen .name is not None
96- filename = "{}.h" .format (codegen .name )
97- func_code = (data .function_dir / filename ).read_text ()
98-
99- # Cleanup
100- if cleanup :
101- python_util .remove_if_exists (data .output_dir )
102-
103- return func_code
104-
105-
106- def get_filename (codegen : Codegen ) -> str :
107- """
108- Helper to get appropriate filename
109- """
110- assert codegen .name is not None
111- return codegen .name + ".h"
112-
113-
114- def get_between_factors (types : T .Sequence [T .Type ]) -> T .Dict [str , str ]:
115- """
116- Compute
117- """
118- files_dict : T .Dict [str , str ] = {}
11990 for cls in types :
12091 tangent_dim = ops .LieGroupOps .tangent_dim (cls )
12192 between_codegen = Codegen .function (
@@ -125,7 +96,7 @@ def get_between_factors(types: T.Sequence[T.Type]) -> T.Dict[str, str]:
12596 config = CppConfig (),
12697 docstring = get_between_factor_docstring ("a_T_b" ),
12798 ).with_linearization (name = f"between_factor_{ cls .__name__ .lower ()} " , which_args = ["a" , "b" ])
128- files_dict [ get_filename ( between_codegen )] = get_function_code ( between_codegen )
99+ between_codegen . generate_function ( output_dir , skip_directory_nesting = True )
129100
130101 prior_codegen = Codegen .function (
131102 func = prior_factor ,
@@ -134,14 +105,12 @@ def get_between_factors(types: T.Sequence[T.Type]) -> T.Dict[str, str]:
134105 config = CppConfig (),
135106 docstring = get_prior_docstring (),
136107 ).with_linearization (name = f"prior_factor_{ cls .__name__ .lower ()} " , which_args = ["value" ])
137- files_dict [get_filename (prior_codegen )] = get_function_code (prior_codegen )
138-
139- return files_dict
108+ prior_codegen .generate_function (output_dir , skip_directory_nesting = True )
140109
141110
142- def get_pose3_extra_factors ( files_dict : T .Dict [ str , str ] ) -> None :
111+ def generate_pose3_extra_factors ( output_dir : T .Openable ) -> None :
143112 """
144- Generates factors specific to Poses which penalize individual components
113+ Generates factors specific to Poses which penalize individual components into output_dir.
145114
146115 This includes factors for only the position or rotation components of a Pose. This can't be
147116 done by wrapping the other generated functions because we need jacobians with respect to the
@@ -210,53 +179,41 @@ def prior_factor_pose3_position(
210179 config = CppConfig (),
211180 docstring = get_between_factor_docstring ("a_R_b" ),
212181 ).with_linearization (name = "between_factor_pose3_rotation" , which_args = ["a" , "b" ])
182+ between_rotation_codegen .generate_function (output_dir , skip_directory_nesting = True )
213183
214184 between_position_codegen = Codegen .function (
215185 func = between_factor_pose3_position ,
216186 output_names = ["res" ],
217187 config = CppConfig (),
218188 docstring = get_between_factor_docstring ("a_t_b" ),
219189 ).with_linearization (name = "between_factor_pose3_position" , which_args = ["a" , "b" ])
190+ between_position_codegen .generate_function (output_dir , skip_directory_nesting = True )
220191
221192 between_translation_norm_codegen = Codegen .function (
222193 func = between_factor_pose3_translation_norm , output_names = ["res" ], config = CppConfig ()
223194 ).with_linearization (name = "between_factor_pose3_translation_norm" , which_args = ["a" , "b" ])
195+ between_translation_norm_codegen .generate_function (output_dir , skip_directory_nesting = True )
224196
225197 prior_rotation_codegen = Codegen .function (
226198 func = prior_factor_pose3_rotation ,
227199 output_names = ["res" ],
228200 config = CppConfig (),
229201 docstring = get_prior_docstring (),
230202 ).with_linearization (name = "prior_factor_pose3_rotation" , which_args = ["value" ])
203+ prior_rotation_codegen .generate_function (output_dir , skip_directory_nesting = True )
231204
232205 prior_position_codegen = Codegen .function (
233206 func = prior_factor_pose3_position ,
234207 output_names = ["res" ],
235208 config = CppConfig (),
236209 docstring = get_prior_docstring (),
237210 ).with_linearization (name = "prior_factor_pose3_position" , which_args = ["value" ])
238-
239- files_dict [get_filename (between_rotation_codegen )] = get_function_code (between_rotation_codegen )
240- files_dict [get_filename (between_position_codegen )] = get_function_code (between_position_codegen )
241- files_dict [get_filename (between_translation_norm_codegen )] = get_function_code (
242- between_translation_norm_codegen
243- )
244- files_dict [get_filename (prior_rotation_codegen )] = get_function_code (prior_rotation_codegen )
245- files_dict [get_filename (prior_position_codegen )] = get_function_code (prior_position_codegen )
211+ prior_position_codegen .generate_function (output_dir , skip_directory_nesting = True )
246212
247213
248214def generate (output_dir : Path ) -> None :
249215 """
250216 Prior factors and between factors for C++.
251217 """
252- # Compute code
253- files_dict = get_between_factors (types = TYPES )
254- get_pose3_extra_factors (files_dict )
255-
256- # Create output dir
257- factors_dir = output_dir / "factors"
258- factors_dir .mkdir (parents = True )
259-
260- # Write out
261- for filename , code in files_dict .items ():
262- (factors_dir / filename ).write_text (code )
218+ generate_between_factors (types = TYPES , output_dir = output_dir / "factors" )
219+ generate_pose3_extra_factors (output_dir / "factors" )
0 commit comments