Skip to content

Commit f9a6f2d

Browse files
Simplify geo_factors_codegen w/ skip_dir_nesting
Previously, `geo_factors_codegen.py` was using an ad-hoc system of generating all the C++ factors into a `factors` folder. It did this by calling `Codegen.generate_function` to generate the code into a temporary file, read the file into a string, then re-wrote the string into the desired location. I assume this was to avoid all the fluff that's generated by default with `Codegen.generate_function`. However, there is the `skip_directory_nesting` optional argument for `Codegen.generate_function` which does precisely that (I think the code in this file might have been written before that option was added). So, to reduce confusion (such as the confusion I faced when I first started looking at this file) and complexity, I rewrote the code to instead use the `skip_directory_nesting` argument. Topic: geo_factors_use_skip_directory_nesting
1 parent 4c36881 commit f9a6f2d

File tree

1 file changed

+13
-56
lines changed

1 file changed

+13
-56
lines changed

symforce/codegen/geo_factors_codegen.py

Lines changed: 13 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import symforce.symbolic as sf
99
from symforce import ops
10-
from symforce import python_util
1110
from symforce import typing as T
1211
from symforce.codegen import Codegen
1312
from symforce.codegen import CppConfig
@@ -84,38 +83,10 @@ def prior_factor(
8483
return residual
8584

8685

87-
def get_function_code(codegen: Codegen, cleanup: bool = True) -> str:
86+
def generate_between_factors(types: T.Sequence[T.Type], output_dir: T.Openable) -> None:
8887
"""
89-
Return just the function code from a Codegen object.
88+
Generates between factors for each type in types into output_dir.
9089
"""
91-
# Codegen
92-
data = codegen.generate_function()
93-
94-
# Read
95-
assert codegen.name is not None
96-
filename = "{}.h".format(codegen.name)
97-
func_code = (data.function_dir / filename).read_text()
98-
99-
# Cleanup
100-
if cleanup:
101-
python_util.remove_if_exists(data.output_dir)
102-
103-
return func_code
104-
105-
106-
def get_filename(codegen: Codegen) -> str:
107-
"""
108-
Helper to get appropriate filename
109-
"""
110-
assert codegen.name is not None
111-
return codegen.name + ".h"
112-
113-
114-
def get_between_factors(types: T.Sequence[T.Type]) -> T.Dict[str, str]:
115-
"""
116-
Compute
117-
"""
118-
files_dict: T.Dict[str, str] = {}
11990
for cls in types:
12091
tangent_dim = ops.LieGroupOps.tangent_dim(cls)
12192
between_codegen = Codegen.function(
@@ -125,7 +96,7 @@ def get_between_factors(types: T.Sequence[T.Type]) -> T.Dict[str, str]:
12596
config=CppConfig(),
12697
docstring=get_between_factor_docstring("a_T_b"),
12798
).with_linearization(name=f"between_factor_{cls.__name__.lower()}", which_args=["a", "b"])
128-
files_dict[get_filename(between_codegen)] = get_function_code(between_codegen)
99+
between_codegen.generate_function(output_dir, skip_directory_nesting=True)
129100

130101
prior_codegen = Codegen.function(
131102
func=prior_factor,
@@ -134,14 +105,12 @@ def get_between_factors(types: T.Sequence[T.Type]) -> T.Dict[str, str]:
134105
config=CppConfig(),
135106
docstring=get_prior_docstring(),
136107
).with_linearization(name=f"prior_factor_{cls.__name__.lower()}", which_args=["value"])
137-
files_dict[get_filename(prior_codegen)] = get_function_code(prior_codegen)
138-
139-
return files_dict
108+
prior_codegen.generate_function(output_dir, skip_directory_nesting=True)
140109

141110

142-
def get_pose3_extra_factors(files_dict: T.Dict[str, str]) -> None:
111+
def generate_pose3_extra_factors(output_dir: T.Openable) -> None:
143112
"""
144-
Generates factors specific to Poses which penalize individual components
113+
Generates factors specific to Poses which penalize individual components into output_dir.
145114
146115
This includes factors for only the position or rotation components of a Pose. This can't be
147116
done by wrapping the other generated functions because we need jacobians with respect to the
@@ -210,53 +179,41 @@ def prior_factor_pose3_position(
210179
config=CppConfig(),
211180
docstring=get_between_factor_docstring("a_R_b"),
212181
).with_linearization(name="between_factor_pose3_rotation", which_args=["a", "b"])
182+
between_rotation_codegen.generate_function(output_dir, skip_directory_nesting=True)
213183

214184
between_position_codegen = Codegen.function(
215185
func=between_factor_pose3_position,
216186
output_names=["res"],
217187
config=CppConfig(),
218188
docstring=get_between_factor_docstring("a_t_b"),
219189
).with_linearization(name="between_factor_pose3_position", which_args=["a", "b"])
190+
between_position_codegen.generate_function(output_dir, skip_directory_nesting=True)
220191

221192
between_translation_norm_codegen = Codegen.function(
222193
func=between_factor_pose3_translation_norm, output_names=["res"], config=CppConfig()
223194
).with_linearization(name="between_factor_pose3_translation_norm", which_args=["a", "b"])
195+
between_translation_norm_codegen.generate_function(output_dir, skip_directory_nesting=True)
224196

225197
prior_rotation_codegen = Codegen.function(
226198
func=prior_factor_pose3_rotation,
227199
output_names=["res"],
228200
config=CppConfig(),
229201
docstring=get_prior_docstring(),
230202
).with_linearization(name="prior_factor_pose3_rotation", which_args=["value"])
203+
prior_rotation_codegen.generate_function(output_dir, skip_directory_nesting=True)
231204

232205
prior_position_codegen = Codegen.function(
233206
func=prior_factor_pose3_position,
234207
output_names=["res"],
235208
config=CppConfig(),
236209
docstring=get_prior_docstring(),
237210
).with_linearization(name="prior_factor_pose3_position", which_args=["value"])
238-
239-
files_dict[get_filename(between_rotation_codegen)] = get_function_code(between_rotation_codegen)
240-
files_dict[get_filename(between_position_codegen)] = get_function_code(between_position_codegen)
241-
files_dict[get_filename(between_translation_norm_codegen)] = get_function_code(
242-
between_translation_norm_codegen
243-
)
244-
files_dict[get_filename(prior_rotation_codegen)] = get_function_code(prior_rotation_codegen)
245-
files_dict[get_filename(prior_position_codegen)] = get_function_code(prior_position_codegen)
211+
prior_position_codegen.generate_function(output_dir, skip_directory_nesting=True)
246212

247213

248214
def generate(output_dir: Path) -> None:
249215
"""
250216
Prior factors and between factors for C++.
251217
"""
252-
# Compute code
253-
files_dict = get_between_factors(types=TYPES)
254-
get_pose3_extra_factors(files_dict)
255-
256-
# Create output dir
257-
factors_dir = output_dir / "factors"
258-
factors_dir.mkdir(parents=True)
259-
260-
# Write out
261-
for filename, code in files_dict.items():
262-
(factors_dir / filename).write_text(code)
218+
generate_between_factors(types=TYPES, output_dir=output_dir / "factors")
219+
generate_pose3_extra_factors(output_dir / "factors")

0 commit comments

Comments
 (0)