| 1 | +import contextlib |
| 2 | +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union |
| 3 | +import torch |
| 4 | +from torch._higher_order_ops.utils import ( |
| 5 | + materialize_as_graph, |
| 6 | + check_input_alias_and_mutation_return_outputs, |
| 7 | + # _maybe_reenter_make_fx, |
| 8 | +) |
| 9 | + |
| 10 | +_TEST_EXPORT = False |
| 11 | + |
| 12 | + |
| 13 | +@contextlib.contextmanager |
| 14 | +def enable_code_export_control_flow(): |
| 15 | + """Enables the code meant to be exported.""" |
| 16 | + global _TEST_EXPORT |
| 17 | + old = _TEST_EXPORT |
| 18 | + _TEST_EXPORT = True |
| 19 | + try: |
| 20 | + yield |
| 21 | + finally: |
| 22 | + _TEST_EXPORT = old |
| 23 | + |
| 24 | + |
| 25 | +def is_exporting() -> bool: |
| 26 | + """ |
| 27 | + Returns True if :func:`torch.compiler.is_exporting` or |
| 28 | + :func:`torch.compiler.is_compiling` returns True, or if |
| 29 | + ``_TEST_EXPORT`` was set through :func:`enable_code_export_control_flow`. |
| 30 | + """ |
| 31 | + return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling() |
| 32 | + |
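# A minimal usage sketch: ``enable_code_export_control_flow`` forces
# :func:`is_exporting` to return True so the export branch of :func:`loop_for`
# can be exercised outside of :func:`torch.export.export`.
#
#     assert not is_exporting()
#     with enable_code_export_control_flow():
#         assert is_exporting()
#     assert not is_exporting()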
| 33 | + |
| 34 | +def _loop_for_fn(n_iter, body_fn, reduction_dim, args): |
| 35 | + """ |
| 36 | + Python implementation of the loop. |
| 37 | +
|
| 38 | + :param n_iter: number of iterations, an int64 tensor with no dimension |
| 39 | + :param body_fn: function implementing the loop body |
| 40 | + :param reduction_dim: dimensions along which the lists produced by the loop are concatenated |
| 41 | + :param args: arguments to the loop body |
| 42 | + :return: the loop outputs, one concatenated tensor per output of ``body_fn`` |
| 43 | + """ |
| 44 | + res = [] |
| 45 | + for i in torch.arange(n_iter, dtype=n_iter.dtype): |
| 46 | + r = body_fn(i, *args) |
| 47 | + if isinstance(r, tuple): |
| 48 | + assert not res or len(r) == len(res[-1]), ( |
| 49 | + f"Unexpected number of results {len(r)} for function {body_fn}, " |
| 50 | + f"expected {len(res[-1])}" |
| 51 | + ) |
| 52 | + res.append(r) |
| 53 | + else: |
| 54 | + assert isinstance(r, torch.Tensor), ( |
| 55 | + f"Unexpected type {r} for function {body_fn}, " |
| 56 | + f"it must be a tuple or a Tensor." |
| 57 | + ) |
| 58 | + assert not res or len(res[-1]) == 1, ( |
| 59 | + f"Unexpected number of results {len(r)} for function {body_fn}, " |
| 60 | + f"expected {len(res[-1])}" |
| 61 | + ) |
| 62 | + res.append((r,)) |
| 63 | + |
| 64 | + if not res: |
| 65 | + return torch.empty(tuple(), dtype=torch.float32, device=args[0].device) |
| 66 | + if len(res) == 1: |
| 67 | + final = res[0] |
| 68 | + else: |
| 69 | + n_res = len(res[0]) |
| 70 | + final = [ |
| 71 | + torch.cat( |
| 72 | + [r[i] for r in res], |
| 73 | + dim=( |
| 74 | + 0 if reduction_dim is None or i >= len(reduction_dim) else reduction_dim[i] |
| 75 | + ), |
| 76 | + ) |
| 77 | + for i in range(n_res) |
| 78 | + ] |
| 79 | + return tuple(final) if len(final) > 1 else final[0] |
| 80 | + |
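# A minimal sketch of the eager semantics implemented by ``_loop_for_fn``,
# assuming a simple hypothetical body returning one tensor per iteration:
#
#     def body(i, x):
#         # shift the input by the iteration index, shape stays (2, 3)
#         return (x + i,)
#
#     x = torch.randn(2, 3)
#     out = _loop_for_fn(torch.tensor(4, dtype=torch.int64), body, None, (x,))
#     # the four (2, 3) results are concatenated on dim 0: out.shape == (8, 3)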
| 81 | + |
| 82 | +def make_custom_loop_for( |
| 83 | + n_iter: torch.Tensor, |
| 84 | + body_fn: Callable, |
| 85 | + reduction_dim: Optional[Sequence[int]], |
| 86 | + args: Sequence[torch.Tensor], |
| 87 | + body_gm: Optional[torch.fx.GraphModule] = None, |
| 88 | + body_mutated_inputs: Optional[List[Any]] = None, |
| 89 | + body_outputs: Optional[List[Any]] = None, |
| 90 | +) -> Tuple[str, torch.library.CustomOpDef]: |
| 91 | + """ |
| 92 | + Defines a custom operator for a loop in order to avoid |
| 93 | + :func:`torch.export.export` digging into it. |
| 94 | + It registers the custom op and a custom conversion |
| 95 | + to ONNX. |
| 96 | +
|
| 97 | + :param n_iter: number of iterations, defined by a tensor with no dimension |
| 98 | + :param body_fn: the loop body defined as a function |
| 99 | + :param reduction_dim: dimensions used to concatenate the results |
| 100 | + :param args: list of tensors, inputs to the body |
| 101 | + :param body_gm: torch.fx.GraphModule equivalent to *body_fn* |
| 102 | + :param body_mutated_inputs: inputs mutated by *body_gm* |
| 103 | + :param body_outputs: outputs of *body_gm* |
| 104 | + :return: a name and the custom op definition, the name |
| 105 | + is used to cache the custom op |
| 106 | + """ |
| 107 | + assert body_gm is not None, "body_gm cannot be None" |
| 108 | + assert body_mutated_inputs is not None, "body_mutated_inputs cannot be None" |
| 109 | + assert body_outputs is not None, "body_outputs cannot be None" |
| 110 | + |
| 111 | + srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs) |
| 112 | + sred = "x".join(map(str, reduction_dim)) if reduction_dim else "" |
| 113 | + full_name = ( |
| 114 | + body_fn.__qualname__.replace("<locals>", "L") |
| 115 | + .replace("<lambda>", "l") |
| 116 | + .replace(".", "_") |
| 117 | + ) |
| 118 | + name = f"loop_for_onnx_{full_name}_{srank}_{sred}" |
| 119 | + |
| 120 | + schema = "(str body_fn, Tensor n_iter, Tensor[] body_inputs) -> Tensor" |
| 121 | + if len(body_outputs) > 1: |
| 122 | + schema += "[]" |
| 123 | + custom_def = torch.library.CustomOpDef("onnx_higher_ops", "loop_for", schema, body_fn) |
| 124 | + custom_def.register_kernel("cpu")(body_fn) |
| 125 | + |
| 126 | + custom_def._abstract_fn = lambda _fn_id, *_args, _o=body_outputs: ( |
| 127 | + tuple([torch.empty_like(s) for s in _o]) if len(_o) > 1 else torch.empty_like(_o[0]) |
| 128 | + ) |
| 129 | + return name, custom_def |
| 130 | + |
| 131 | + |
| 132 | +def loop_for( |
| 133 | + n_iter: Union[torch.SymInt, torch.Tensor], |
| 134 | + body_fn: Callable[..., Tuple[torch.Tensor]], |
| 135 | + args: Sequence[torch.Tensor], |
| 136 | + reduction_dim: Optional[Sequence[int]] = None, |
| 137 | +) -> Tuple[torch.Tensor, ...]: |
| 138 | + """ |
| 139 | + Higher-order operator used to easily export a loop to ONNX. |
| 140 | + It does not fully work with :func:`torch.export.export`: |
| 141 | + the loop is exported as a custom op which is replaced by a loop operator afterwards. |
| 142 | + Every iteration produces tensors; they are gathered into lists, |
| 143 | + one per output, and each list is concatenated into a single tensor. |
| 144 | +
|
| 145 | + :param n_iter: number of iterations, it can be fixed or |
| 146 | + variable but must be given as an int64 tensor with no dimension |
| 147 | + :param body_fn: loop body, takes only tensors and returns |
| 148 | + only tensors, the first argument is the iteration number |
| 149 | + as a tensor with no dimension, all the other arguments |
| 150 | + are left unchanged during the loop |
| 151 | + :param args: the tensors available at every iteration |
| 152 | + :param reduction_dim: the loop aggregates the results into lists, |
| 153 | + one per output, each list is concatenated into one |
| 154 | + tensor along one dimension, the first dimension |
| 155 | + by default, but it can be defined otherwise |
| 156 | + """ |
| 157 | + assert args, "The function should have at least one arg." |
| 158 | + assert ( |
| 159 | + isinstance(n_iter, torch.Tensor) |
| 160 | + and n_iter.dtype == torch.int64 |
| 161 | + and len(n_iter.shape) == 0 |
| 162 | + ), f"Only a tensor for one int64 is allowed for n_iter but it equal to {n_iter}." |
| 163 | + if is_exporting(): |
| 164 | + from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER |
| 165 | + |
| 166 | + # tracer = _CURRENT_MAKE_FX_TRACER.fx_tracer |
| 167 | + root = _CURRENT_MAKE_FX_TRACER.fx_tracer.root |
| 168 | + # graph = _CURRENT_MAKE_FX_TRACER.fx_tracer.graph |
| 169 | + |
| 170 | + body_gm: torch.fx.GraphModule = materialize_as_graph( |
| 171 | + body_fn, (torch.tensor(0, dtype=torch.int64), *args) |
| 172 | + ) |
| 173 | + ( |
| 174 | + _1, |
| 175 | + _2, |
| 176 | + _3, |
| 177 | + body_mutated_inputs, |
| 178 | + body_outputs, |
| 179 | + ) = check_input_alias_and_mutation_return_outputs(body_gm) |
| 180 | + name, _custom_ops = make_custom_loop_for( |
| 181 | + n_iter, |
| 182 | + body_fn, |
| 183 | + reduction_dim, |
| 184 | + args, |
| 185 | + body_gm=body_gm, |
| 186 | + body_mutated_inputs=body_mutated_inputs, |
| 187 | + body_outputs=body_outputs, |
| 188 | + ) |
| 189 | + root.register_module(name, body_gm) |
| 190 | + # body_graph = _maybe_reenter_make_fx(body_fn)(n_iter, *args) |
| 191 | + return torch.ops.onnx_higher_ops.loop_for(name, n_iter, args) |
| 192 | + |
| 193 | + return _loop_for_fn(n_iter, body_fn, reduction_dim, args) |
| 194 | + |
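# A minimal usage sketch of ``loop_for`` with a hypothetical body and an
# explicit ``reduction_dim``; in eager mode the call goes through
# ``_loop_for_fn``, under export it goes through the custom op branch above:
#
#     def body(i, x):
#         return (x + i,)
#
#     x = torch.randn(2, 3)
#     res = loop_for(torch.tensor(4, dtype=torch.int64), body, (x,), reduction_dim=[1])
#     # the four (2, 3) results are concatenated on dim 1: res.shape == (2, 12)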
| 195 | + |
| 196 | +""" |
| 197 | + proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph) |
| 198 | + proxy_mode.tracer.root.register_module(body_graph_name, body_graph) |
| 199 | +
|
| 200 | + args = (cond_graph, body_graph, carried_inputs, additional_inputs) |
| 201 | +
|
| 202 | + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) |
| 203 | +
|
| 204 | + out_proxy = proxy_mode.tracer.create_proxy( |
| 205 | + "call_function", op, proxy_args, {}, name=op._name |
| 206 | + ) |
| 207 | +
|
| 208 | + out = op( |
| 209 | + cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs |
| 210 | + ) |
| 211 | + return track_tensor_tree( |
| 212 | + out, out_proxy, constant=None, tracer=proxy_mode.tracer |
| 213 | + ) |
| 214 | +""" |