|
| 1 | +# Copyright 2021 The TensorFlow Probability Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================ |
| 15 | +"""Contains the `Einsum` expression and utilities. |
| 16 | +
|
| 17 | +The `Einsum` pattern is used for pattern matching and term rewriting in JAX. |
| 18 | +JAX does not have an underlying einsum primitive; a call to `jnp.einsum` turns |
| 19 | +into its component `dot_general`, `broadcast`, and `transpose` primitives and |
| 20 | +therefore einsums do not directly appear in JAXprs. Autoconj is based on |
| 21 | +rewriting expressions and combining expressions into large einsums and because |
| 22 | +JAX does not have a primitive representation, we need to create our own. |
| 23 | +
|
| 24 | +Along with the `Einsum` pattern we include utilities for manipulating |
| 25 | +einsums. For example, the `compose_einsums` function contains the logic for |
| 26 | +taking two nested einsums and combining them into a single one. These utilities |
| 27 | +are needed for systems such as |
| 28 | +[Autoconj](https://papers.nips.cc/paper/2018/hash/9b89bedda1fc8a2d88c448e361194f02-Abstract.html), |
| 29 | +which aim to create large, monolothic einsums in a program. The functions are |
| 30 | +based on their [implementations in |
| 31 | +Autoconj](https://github.com/google-research/autoconj/blob/master/autoconj/rewrites.py). |
| 32 | +""" |
| 33 | +import collections |
| 34 | +import dataclasses |
| 35 | +import functools |
| 36 | + |
| 37 | +from typing import Any, Dict, Iterator, List, Sequence, Tuple, Union |
| 38 | + |
| 39 | +import jax |
| 40 | +import jax.numpy as jnp |
| 41 | + |
| 42 | +from oryx.experimental.matching import jax_rewrite as jr |
| 43 | +from oryx.experimental.matching import matcher |
| 44 | + |
| 45 | +__all__ = [ |
| 46 | + 'compose_einsums', |
| 47 | + 'Einsum', |
| 48 | + 'einsum_letters', |
| 49 | +] |
| 50 | + |
| 51 | +Bindings = matcher.Bindings |
| 52 | +Continuation = matcher.Continuation |
| 53 | +Expr = matcher.Expr |
| 54 | +Pattern = matcher.Pattern |
| 55 | +Success = matcher.Success |
| 56 | + |
| 57 | +_EINSUM_RANGE = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' |
| 58 | + |
| 59 | + |
| 60 | +def einsum_letters() -> Iterator[str]: |
| 61 | + """Returns an iterator over valid einsum index names.""" |
| 62 | + yield from _EINSUM_RANGE |
| 63 | + |
| 64 | + |
| 65 | +@dataclasses.dataclass(frozen=True) |
| 66 | +class Einsum(jr.JaxExpression): |
| 67 | + """An expression that executes a JAX einsum on its operands. |
| 68 | +
|
| 69 | + JAX offers a `jax.numpy.einsum` function but is executed as a series of JAX |
| 70 | + primitive operations including `dot_general` and `broadcast`. This means that |
| 71 | + when a function with an `Einsum` is traced, the `Einsum` does not explicitly |
| 72 | + show up in the resulting JAXpr. For the purposes of term rewriting, we |
| 73 | + therefore need our own `Einsum` representation that can be constructed from |
| 74 | + JAX primitives using rewrite rules. |
| 75 | +
|
| 76 | + Attributes: |
| 77 | + formula: A string describing the `Einsum`'s operation. See the [NumPy |
| 78 | + documentation]( |
| 79 | + https://numpy.org/doc/stable/reference/generated/numpy.einsum.html) for |
| 80 | + more information. |
| 81 | + operands: The inputs to the Einsum. |
| 82 | + """ |
| 83 | + formula: Union[Pattern, str] |
| 84 | + operands: Union[Pattern, Tuple[Any]] |
| 85 | + |
| 86 | + @functools.lru_cache(None) |
| 87 | + def shape_dtype(self) -> jax.ShapeDtypeStruct: |
| 88 | + """Computes the shape and dtype of the result of this `Einsum`. |
| 89 | +
|
| 90 | + This function traces the JAX execution and does not incur any FLOPs. To |
| 91 | + avoid retracing, however, we are safe to cache the result of this function |
| 92 | + because `Einsum`s are immutable. |
| 93 | +
|
| 94 | + Returns: |
| 95 | + A `jax.ShapeDtypeStruct` object describing the shape and dtype of the |
| 96 | + `Einsum`. |
| 97 | + """ |
| 98 | + # We can trace the evaluation without incurring any FLOPs. |
| 99 | + operand_shape_dtypes = tuple( |
| 100 | + jax.ShapeDtypeStruct(operand.shape, operand.dtype) |
| 101 | + for operand in self.operands) |
| 102 | + |
| 103 | + def _eval_fun(*args): |
| 104 | + return jnp.einsum(self.formula, *args) |
| 105 | + |
| 106 | + return jax.eval_shape(_eval_fun, *operand_shape_dtypes) |
| 107 | + |
| 108 | + @property |
| 109 | + def shape(self) -> Tuple[int]: |
| 110 | + return self.shape_dtype().shape |
| 111 | + |
| 112 | + @property |
| 113 | + def dtype(self) -> jnp.dtype: |
| 114 | + return self.shape_dtype().dtype |
| 115 | + |
| 116 | + # Matching methods |
| 117 | + |
| 118 | + def match(self, expr: Expr, bindings: Bindings, |
| 119 | + succeed: Continuation) -> Success: |
| 120 | + """Matches the formula and operands of an `Einsum`.""" |
| 121 | + if not isinstance(expr, Einsum): |
| 122 | + return |
| 123 | + yield from matcher.matcher((self.operands, self.formula))( |
| 124 | + (expr.operands, expr.formula), bindings, succeed) |
| 125 | + |
| 126 | + # Rules methods |
| 127 | + |
| 128 | + def tree_map(self, fn) -> 'Einsum': |
| 129 | + """Maps a function across the formula and operands of an `Einsum`.""" |
| 130 | + return Einsum(self.formula, tuple(map(fn, self.operands))) |
| 131 | + |
| 132 | + def tree_children(self) -> Iterator[Any]: |
| 133 | + """Returns an iterator over the operands of an `Einsum`.""" |
| 134 | + yield from self.operands |
| 135 | + |
| 136 | + # JAX rewriting methods |
| 137 | + |
| 138 | + def evaluate(self, env: Dict[str, Any]) -> Any: |
| 139 | + """Evaluates an `Einsum` in an environment.""" |
| 140 | + operands = jr.evaluate(self.operands, env) |
| 141 | + return jnp.einsum(self.formula, *operands) |
| 142 | + |
| 143 | + # Builtin methods |
| 144 | + |
| 145 | + def __str__(self) -> str: |
| 146 | + return f'(einsum[{self.formula}] {" ".join(map(str, self.operands))})' |
| 147 | + |
| 148 | + |
| 149 | +def split_einsum_formula(formula: str) -> Tuple[List[str], str]: |
| 150 | + """Splits an einsum formula string into its component axis names.""" |
| 151 | + input_formula, output_formula = formula.split('->') |
| 152 | + return input_formula.split(','), output_formula |
| 153 | + |
| 154 | + |
| 155 | +def reconstitute_einsum_formula(input_formulas: Sequence[str], |
| 156 | + output_formula: str) -> str: |
| 157 | + """Joins einsum input formulas and output formula into a complete formula.""" |
| 158 | + joined_input_formula = ','.join(input_formulas) |
| 159 | + return f'{joined_input_formula}->{output_formula}' |
| 160 | + |
| 161 | + |
| 162 | +def compose_einsums(parent_formula: str, left_args: Tuple[Any], |
| 163 | + child_einsum: Einsum, right_args: Tuple[Any]) -> Einsum: |
| 164 | + """Combines nested einsums into a single einsum. |
| 165 | +
|
| 166 | + Einsums are linear functions and thus the composition of two (or more) einsums |
| 167 | + can be represented as a single one. Composed einsums often come up during the |
| 168 | + term rewriting phase of Autoconj, where a series of linear operations (a |
| 169 | + matrix multiplication followed by a transpose, for example) need to be |
| 170 | + folded together into a single einsum in order to represent a function as a |
| 171 | + sum of einsums. This function takes a composition of einsums (an einsum with |
| 172 | + an einsum as one its arguments) and returns a flattened, single einsum. |
| 173 | +
|
| 174 | + As an example use-case, suppose we have matrices `w, x, y` and `z` along with |
| 175 | + the following `Einsum`s: |
| 176 | + ```python |
| 177 | + child_op = Einsum('ab,bc->ac', (x, y)) |
| 178 | + parent_op = Einsum('ab,bc,cd->ad', (w, child_op, z)) |
| 179 | + ``` |
| 180 | +
|
| 181 | + These two operations can be combined to form the single `Einsum` |
| 182 | + ```python |
| 183 | + combined_op = Einsum('ab,bc,cd,de->ae', (w, x, y, z)) |
| 184 | + ``` |
| 185 | +
|
| 186 | + Implementation based on `_compose_einsums` in |
| 187 | + [Autoconj](https://github.com/google-research/autoconj/blob/master/autoconj/rewrites.py). |
| 188 | +
|
| 189 | + Args: |
| 190 | + parent_formula: The formula of the parent einsum. |
| 191 | + left_args: The sequence of arguments to the left of the child einsum. |
| 192 | + child_einsum: An `Einsum` that is an argument in the `parent_formula`. |
| 193 | + right_args: The sequence of arguments to the right of the child einsum. |
| 194 | +
|
| 195 | + Returns: |
| 196 | + A single un-nested `Einsum` that computes the same quantity as the nested |
| 197 | + einsums. |
| 198 | + """ |
| 199 | + parent_in_formulas, parent_out_formula = split_einsum_formula(parent_formula) |
| 200 | + child_formula, child_args = child_einsum.formula, child_einsum.operands |
| 201 | + child_in_formulas, child_out_formula = split_einsum_formula(child_formula) |
| 202 | + num_left = len(left_args) |
| 203 | + # Number of output dimensions of child einsum should match number of |
| 204 | + # dimensions in parent einsum. |
| 205 | + if len(child_out_formula) != len(parent_in_formulas[num_left]): |
| 206 | + raise ValueError(f'Child output formula {child_out_formula} and ' |
| 207 | + f'parent formula {parent_in_formulas[num_left]} have' |
| 208 | + ' inconsistent size.') |
| 209 | + str_iterator = einsum_letters() |
| 210 | + # Creates a dictionary where each time we access a new element, we generate |
| 211 | + # a new letter. |
| 212 | + subs_map = collections.defaultdict(lambda: next(str_iterator)) |
| 213 | + # Splices out the old input formula |
| 214 | + old_in_formula = parent_in_formulas[num_left] |
| 215 | + parent_in_formulas = ( |
| 216 | + parent_in_formulas[:num_left] + parent_in_formulas[num_left + 1:]) |
| 217 | + # Canonicalizes input and output formulas (optional, for cleanliness) |
| 218 | + parent_in_formulas = [ |
| 219 | + ''.join(subs_map[idx] for idx in subs) for subs in parent_in_formulas |
| 220 | + ] |
| 221 | + out_formula = ''.join(subs_map[idx] for idx in parent_out_formula) |
| 222 | + # Maps child output indices with corresponding parent indices |
| 223 | + subs_map.update((pidx + '_child', subs_map[idx]) |
| 224 | + for pidx, idx in zip(child_out_formula, old_in_formula)) |
| 225 | + # Updates the child input formulas to use parent mappings |
| 226 | + child_in_formulas = [ |
| 227 | + ''.join(subs_map[idx + '_child'] |
| 228 | + for idx in subs) |
| 229 | + for subs in child_in_formulas |
| 230 | + ] |
| 231 | + # Concatenates the formulas and arguments |
| 232 | + new_in_formulas = ( |
| 233 | + parent_in_formulas[:num_left] + child_in_formulas + |
| 234 | + parent_in_formulas[num_left:]) |
| 235 | + new_args = left_args + child_args + right_args |
| 236 | + new_formula = reconstitute_einsum_formula(new_in_formulas, out_formula) |
| 237 | + return Einsum(new_formula, tuple(new_args)) |
0 commit comments