|
16 | 16 |
|
17 | 17 |
|
18 | 18 | class JaxInterpreter(torch.fx.Interpreter):
|
19 |
| - """Experimental.""" |
| 19 | + """An `fx.Interpreter` that executes a PyTorch FX graph using JAX. |
| 20 | +
|
| 21 | + This interpreter traverses an FX graph and replaces PyTorch operations with |
| 22 | + their corresponding JAX implementations from the `torchax` operator registry. |
| 23 | + It is a key component in the process of exporting PyTorch models to JAX and |
| 24 | + StableHLO. |
| 25 | + """ |
20 | 26 |
|
21 | 27 | def __init__(self, graph_module):
|
22 | 28 | super().__init__(graph_module)
|
@@ -74,11 +80,24 @@ def _extract_states_from_exported_program(exported_model):
|
74 | 80 |
|
75 | 81 |
|
76 | 82 | def exported_program_to_jax(exported_program, export_raw: bool = False):
|
77 |
| - """returns a pytree of jax arrays(state), and |
| 83 | + """Converts a `torch.export.ExportedProgram` to a JAX-compatible function and state. |
| 84 | +
|
| 85 | + This function takes a PyTorch `ExportedProgram`, runs the necessary |
| 86 | + decompositions, and returns a JAX-compatible function and the model's state |
| 87 | + (parameters and buffers) as JAX arrays. |
| 88 | +
|
| 89 | + **Arguments:** |
78 | 90 |
|
79 |
| - a callable(func) that is jax function. |
| 91 | + * `exported_program` (`torch.export.ExportedProgram`): The PyTorch |
| 92 | + `ExportedProgram` to convert. |
| 93 | + * `export_raw` (`bool`, optional): If `True`, returns the raw states and |
| 94 | + function without converting them to JAX arrays. Defaults to `False`. |
80 | 95 |
|
81 |
| - func(state, input) would be how you call it. |
| 96 | + **Returns:** |
| 97 | +
|
| 98 | + A tuple containing: |
| 99 | + * A pytree of JAX arrays representing the model's state. |
| 100 | + * A JAX-callable function that takes the state and inputs as arguments. |
82 | 101 | """
|
83 | 102 | if torch.__version__ >= '2.2':
|
84 | 103 | # torch version 2.1 didn't expose this yet
|
@@ -115,8 +134,19 @@ def func(states, inputs):
|
115 | 134 |
|
116 | 135 |
|
117 | 136 | def extract_avals(exported):
|
118 |
| - """Return JAX Abstract Value shapes for all input parameters of the exported |
119 |
| - program. This supports dynamic batch dimensions, including with constraints. |
| 137 | + """Returns JAX abstract values (`ShapeDtypeStruct`) for all input parameters of the exported program. |
| 138 | +
|
| 139 | + This function supports dynamic batch dimensions, including those with |
| 140 | + constraints. |
| 141 | +
|
| 142 | + **Arguments:** |
| 143 | +
|
| 144 | + * `exported` (`torch.export.ExportedProgram`): The exported PyTorch program. |
| 145 | +
|
| 146 | + **Returns:** |
| 147 | +
|
| 148 | + A list of `jax.ShapeDtypeStruct` objects representing the abstract values of |
| 149 | + the input parameters. |
120 | 150 | """
|
121 | 151 |
|
122 | 152 | def _to_aval(arg_meta, symbolic_shapes):
|
@@ -232,12 +262,24 @@ def _build_symbolic_shape(sym, constraint, free_symbols):
|
232 | 262 |
|
233 | 263 |
|
234 | 264 | def exported_program_to_stablehlo(exported_program):
|
235 |
| - """Replacement for torch_xla.stablehlo.exported_program_to_stablehlo |
| 265 | + """Converts a `torch.export.ExportedProgram` to StableHLO. |
| 266 | +
|
| 267 | + This function serves as a replacement for |
| 268 | + `torch_xla.stablehlo.exported_program_to_stablehlo`. It supports dynamic |
| 269 | + dimension sizes and generates explicit checks for Dynamo guards in the IR |
| 270 | + using `shape_assertion` custom calls. |
| 271 | +
|
| 272 | + **Arguments:** |
| 273 | +
|
| 274 | + * `exported_program` (`torch.export.ExportedProgram`): The exported PyTorch |
| 275 | + program. |
236 | 276 |
|
237 |
| - Convert a program exported via torch.export to StableHLO. |
| 277 | + **Returns:** |
238 | 278 |
|
239 |
| - This supports dynamic dimension sizes and generates explicit checks for |
240 |
| - dynamo guards in the IR using shape_assertion custom_call ops. |
| 279 | + A tuple containing: |
| 280 | + * The model's state (weights) as a pytree of JAX arrays. |
| 281 | + * A `jax.export.Exported` object containing the StableHLO representation of |
| 282 | + the model. |
241 | 283 | """
|
242 | 284 | weights, func = exported_program_to_jax(exported_program)
|
243 | 285 | jax_avals = extract_avals(exported_program)
|
|
0 commit comments