Commit 34a11a2

Merge branch 'tensorflow:main' into frighterafix#1384
2 parents 917ba41 + d85f921 commit 34a11a2

38 files changed: +1017 -2278 lines changed

discussion/lazybones.pdf

-134 KB
Binary file not shown.

discussion/lazybones.py

Lines changed: 0 additions & 493 deletions
This file was deleted.
oryx/experimental/autoconj/BUILD

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
# Copyright 2021 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# [internal] load pytype.bzl (pytype_strict_library)
# [internal] load strict.bzl

licenses(["notice"])

package(default_visibility = ["//visibility:public"])

# pytype_strict
py_library(
    name = "einsum",
    srcs = ["einsum.py"],
    srcs_version = "PY3",
    deps = [
        # jax dep,
        "//oryx/experimental/matching:jax_rewrite",
        "//oryx/experimental/matching:matcher",
    ],
)

# py_strict
py_test(
    name = "einsum_test",
    srcs = ["einsum_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":einsum",
        # absl/testing:absltest dep,
        # jax dep,
        # numpy dep,
        "//oryx/experimental/matching:jax_rewrite",
        "//oryx/experimental/matching:matcher",
        "//oryx/experimental/matching:rules",
        "//oryx/internal:test_util",
    ],
)
oryx/experimental/autoconj/einsum.py

Lines changed: 237 additions & 0 deletions
@@ -0,0 +1,237 @@
# Copyright 2021 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Contains the `Einsum` expression and utilities.
16+
17+
The `Einsum` pattern is used for pattern matching and term rewriting in JAX.
18+
JAX does not have an underlying einsum primitive; a call to `jnp.einsum` turns
19+
into its component `dot_general`, `broadcast`, and `transpose` primitives and
20+
therefore einsums do not directly appear in JAXprs. Autoconj is based on
21+
rewriting expressions and combining expressions into large einsums and because
22+
JAX does not have a primitive representation, we need to create our own.
23+
24+
Along with the `Einsum` pattern we include utilities for manipulating
25+
einsums. For example, the `compose_einsums` function contains the logic for
26+
taking two nested einsums and combining them into a single one. These utilities
27+
are needed for systems such as
28+
[Autoconj](https://papers.nips.cc/paper/2018/hash/9b89bedda1fc8a2d88c448e361194f02-Abstract.html),
29+
which aim to create large, monolothic einsums in a program. The functions are
30+
based on their [implementations in
31+
Autoconj](https://github.com/google-research/autoconj/blob/master/autoconj/rewrites.py).
32+
"""
33+
import collections
34+
import dataclasses
35+
import functools
36+
37+
from typing import Any, Dict, Iterator, List, Sequence, Tuple, Union
38+
39+
import jax
40+
import jax.numpy as jnp
41+
42+
from oryx.experimental.matching import jax_rewrite as jr
43+
from oryx.experimental.matching import matcher
44+
45+
__all__ = [
46+
'compose_einsums',
47+
'Einsum',
48+
'einsum_letters',
49+
]
50+
51+
Bindings = matcher.Bindings
52+
Continuation = matcher.Continuation
53+
Expr = matcher.Expr
54+
Pattern = matcher.Pattern
55+
Success = matcher.Success
56+
57+
_EINSUM_RANGE = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
58+
59+
60+
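
# Note (illustrative sketch, not part of the module's API): tracing
# `jnp.einsum` exposes only its decomposition; for example,
#   jax.make_jaxpr(lambda a, b: jnp.einsum('ab,bc->ac', a, b))(
#       jnp.ones((5, 2)), jnp.ones((2, 3)))
# produces a jaxpr built from `dot_general`, with no einsum primitive.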


def einsum_letters() -> Iterator[str]:
  """Returns an iterator over valid einsum index names."""
  yield from _EINSUM_RANGE
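
# `einsum_letters` feeds the `defaultdict` in `compose_einsums` below, which
# mints a fresh canonical letter the first time each index is looked up:
#   str_iterator = einsum_letters()
#   subs_map = collections.defaultdict(lambda: next(str_iterator))
#   subs_map['i'], subs_map['j']  # => ('a', 'b')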


@dataclasses.dataclass(frozen=True)
class Einsum(jr.JaxExpression):
  """An expression that executes a JAX einsum on its operands.

  JAX offers a `jax.numpy.einsum` function, but it is executed as a series of
  JAX primitive operations, including `dot_general` and `broadcast`. This
  means that when a function with an `Einsum` is traced, the `Einsum` does not
  explicitly show up in the resulting JAXpr. For the purposes of term
  rewriting, we therefore need our own `Einsum` representation that can be
  constructed from JAX primitives using rewrite rules.

  Attributes:
    formula: A string describing the `Einsum`'s operation. See the [NumPy
      documentation](
      https://numpy.org/doc/stable/reference/generated/numpy.einsum.html) for
      more information.
    operands: The inputs to the `Einsum`.
  """
  formula: Union[Pattern, str]
  operands: Union[Pattern, Tuple[Any, ...]]

  @functools.lru_cache(None)
  def shape_dtype(self) -> jax.ShapeDtypeStruct:
    """Computes the shape and dtype of the result of this `Einsum`.

    This function traces the JAX execution and does not incur any FLOPs.
    Because `Einsum`s are immutable, it is safe to cache the result of this
    function to avoid retracing.

    Returns:
      A `jax.ShapeDtypeStruct` object describing the shape and dtype of the
      `Einsum`.
    """
    # We can trace the evaluation without incurring any FLOPs.
    operand_shape_dtypes = tuple(
        jax.ShapeDtypeStruct(operand.shape, operand.dtype)
        for operand in self.operands)

    def _eval_fun(*args):
      return jnp.einsum(self.formula, *args)

    return jax.eval_shape(_eval_fun, *operand_shape_dtypes)

  @property
  def shape(self) -> Tuple[int, ...]:
    return self.shape_dtype().shape

  @property
  def dtype(self) -> jnp.dtype:
    return self.shape_dtype().dtype

  # Matching methods

  def match(self, expr: Expr, bindings: Bindings,
            succeed: Continuation) -> Success:
    """Matches the formula and operands of an `Einsum`."""
    if not isinstance(expr, Einsum):
      return
    yield from matcher.matcher((self.operands, self.formula))(
        (expr.operands, expr.formula), bindings, succeed)

  # Rules methods

  def tree_map(self, fn) -> 'Einsum':
    """Maps a function across the formula and operands of an `Einsum`."""
    return Einsum(self.formula, tuple(map(fn, self.operands)))

  def tree_children(self) -> Iterator[Any]:
    """Returns an iterator over the operands of an `Einsum`."""
    yield from self.operands

  # JAX rewriting methods

  def evaluate(self, env: Dict[str, Any]) -> Any:
    """Evaluates an `Einsum` in an environment."""
    operands = jr.evaluate(self.operands, env)
    return jnp.einsum(self.formula, *operands)

  # Builtin methods

  def __str__(self) -> str:
    return f'(einsum[{self.formula}] {" ".join(map(str, self.operands))})'
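
# A hedged illustration of the shape inference above: `jax.eval_shape` runs
# the einsum abstractly, so no FLOPs are spent. For example,
#   jax.eval_shape(lambda a, b: jnp.einsum('ab,bc->ac', a, b),
#                  jax.ShapeDtypeStruct((5, 2), jnp.float32),
#                  jax.ShapeDtypeStruct((2, 3), jnp.float32))
# returns ShapeDtypeStruct(shape=(5, 3), dtype=float32).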


def split_einsum_formula(formula: str) -> Tuple[List[str], str]:
  """Splits an einsum formula string into its component axis names."""
  input_formula, output_formula = formula.split('->')
  return input_formula.split(','), output_formula


def reconstitute_einsum_formula(input_formulas: Sequence[str],
                                output_formula: str) -> str:
  """Joins einsum input formulas and output formula into a complete formula."""
  joined_input_formula = ','.join(input_formulas)
  return f'{joined_input_formula}->{output_formula}'
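
# Round-trip sketch (illustrative):
#   split_einsum_formula('ab,bc->ac')                # => (['ab', 'bc'], 'ac')
#   reconstitute_einsum_formula(['ab', 'bc'], 'ac')  # => 'ab,bc->ac'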


def compose_einsums(parent_formula: str, left_args: Tuple[Any, ...],
                    child_einsum: Einsum,
                    right_args: Tuple[Any, ...]) -> Einsum:
  """Combines nested einsums into a single einsum.

  Einsums are linear functions, so the composition of two (or more) einsums
  can be represented as a single one. Composed einsums often come up during
  the term rewriting phase of Autoconj, where a series of linear operations (a
  matrix multiplication followed by a transpose, for example) needs to be
  folded together into a single einsum in order to represent a function as a
  sum of einsums. This function takes a composition of einsums (an einsum with
  an einsum as one of its arguments) and returns a flattened, single einsum.

  As an example use-case, suppose we have matrices `w`, `x`, `y` and `z` along
  with the following `Einsum`s:
  ```python
  child_op = Einsum('ab,bc->ac', (x, y))
  parent_op = Einsum('ab,bc,cd->ad', (w, child_op, z))
  ```

  These two operations can be combined to form the single `Einsum`
  ```python
  combined_op = Einsum('ab,bc,cd,de->ae', (w, x, y, z))
  ```

  Implementation based on `_compose_einsums` in
  [Autoconj](https://github.com/google-research/autoconj/blob/master/autoconj/rewrites.py).

  Args:
    parent_formula: The formula of the parent einsum.
    left_args: The sequence of arguments to the left of the child einsum.
    child_einsum: An `Einsum` that is an argument in the `parent_formula`.
    right_args: The sequence of arguments to the right of the child einsum.

  Returns:
    A single un-nested `Einsum` that computes the same quantity as the nested
    einsums.
  """
  parent_in_formulas, parent_out_formula = split_einsum_formula(parent_formula)
  child_formula, child_args = child_einsum.formula, child_einsum.operands
  child_in_formulas, child_out_formula = split_einsum_formula(child_formula)
  num_left = len(left_args)
  # The number of output dimensions of the child einsum should match the
  # number of dimensions of the corresponding input in the parent einsum.
  if len(child_out_formula) != len(parent_in_formulas[num_left]):
    raise ValueError(f'Child output formula {child_out_formula} and '
                     f'parent formula {parent_in_formulas[num_left]} have'
                     ' inconsistent size.')
  str_iterator = einsum_letters()
  # A dictionary where each time we access a new key, we generate a fresh
  # letter.
  subs_map = collections.defaultdict(lambda: next(str_iterator))
  # Splices out the old input formula.
  old_in_formula = parent_in_formulas[num_left]
  parent_in_formulas = (
      parent_in_formulas[:num_left] + parent_in_formulas[num_left + 1:])
  # Canonicalizes input and output formulas (optional, for cleanliness).
  parent_in_formulas = [
      ''.join(subs_map[idx] for idx in subs) for subs in parent_in_formulas
  ]
  out_formula = ''.join(subs_map[idx] for idx in parent_out_formula)
  # Maps child output indices to the corresponding parent indices.
  subs_map.update((cidx + '_child', subs_map[idx])
                  for cidx, idx in zip(child_out_formula, old_in_formula))
  # Updates the child input formulas to use the parent mappings.
  child_in_formulas = [
      ''.join(subs_map[idx + '_child'] for idx in subs)
      for subs in child_in_formulas
  ]
  # Concatenates the formulas and arguments.
  new_in_formulas = (
      parent_in_formulas[:num_left] + child_in_formulas +
      parent_in_formulas[num_left:])
  new_args = left_args + child_args + right_args
  new_formula = reconstitute_einsum_formula(new_in_formulas, out_formula)
  return Einsum(new_formula, tuple(new_args))
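
For orientation, here is a hedged usage sketch of the new module. It assumes the inferred import path `oryx.experimental.autoconj.einsum` (the import used by the test file below) and mirrors the `compose_einsums` docstring example with the shapes used in `einsum_test.py`:

```python
import jax.numpy as jnp

from oryx.experimental.autoconj import einsum  # inferred path
from oryx.experimental.matching import jax_rewrite as jr

# Symbolic matrix variables (name, shape, dtype).
w = jr.JaxVar('w', (4, 5), jnp.float32)
x = jr.JaxVar('x', (5, 2), jnp.float32)
y = jr.JaxVar('y', (2, 3), jnp.float32)
z = jr.JaxVar('z', (3, 1), jnp.float32)

# A nested einsum: the child computes x @ y, which the parent consumes.
child_op = einsum.Einsum('ab,bc->ac', (x, y))
parent_op = einsum.Einsum('ab,bc,cd->ad', (w, child_op, z))

# Flatten the nesting into a single einsum over all four operands.
single_op = einsum.compose_einsums(parent_op.formula, (w,), child_op, (z,))
assert single_op.shape == parent_op.shape == (4, 1)
```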
oryx/experimental/autoconj/einsum_test.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
# Copyright 2021 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for tensorflow_probability.spinoffs.oryx.experimental.autoconj.einsum."""
16+
from absl.testing import absltest
17+
18+
from jax import random
19+
import jax.numpy as jnp
20+
21+
import numpy as np
22+
23+
from oryx.experimental.autoconj import einsum
24+
from oryx.experimental.matching import jax_rewrite as jr
25+
from oryx.experimental.matching import matcher
26+
from oryx.experimental.matching import rules
27+
from oryx.internal import test_util
28+
29+
Var = matcher.Var
30+
Segment = matcher.Segment
31+
JaxVar = jr.JaxVar
32+
Einsum = einsum.Einsum
33+
34+
35+


class EinsumTest(test_util.TestCase):

  def test_can_match_einsum_components(self):
    x = JaxVar('x', (5,), jnp.float32)
    op = Einsum('a,a->', (x, x))
    pattern = Einsum(Var('formula'), (matcher.Segment('args'),))
    self.assertDictEqual(
        matcher.match(pattern, op), {
            'formula': 'a,a->',
            'args': (x, x)
        })

  def test_can_replace_einsum_operands(self):
    x = JaxVar('x', (5,), jnp.float32)
    y = JaxVar('y', (5,), jnp.float32)
    z = JaxVar('z', (5,), jnp.float32)
    op = Einsum('a,a->', (x, y))
    pattern = Einsum(Var('formula'), (matcher.Segment('args'),))

    def replace_with_z(formula, args):
      del args
      return Einsum(formula, (z, z))

    replace_rule = rules.make_rule(pattern, replace_with_z)
    replaced_op = replace_rule(op)
    self.assertEqual(replaced_op, Einsum('a,a->', (z, z)))

  def test_einsum_correctly_infers_shape_and_dtype(self):
    x = JaxVar('x', (5, 2), jnp.float32)
    y = JaxVar('y', (2, 3), jnp.float32)
    op = Einsum('ab,bc->ac', (x, y))
    self.assertEqual(op.dtype, jnp.float32)
    self.assertTupleEqual(op.shape, (5, 3))

  def test_einsum_evaluates_to_correct_value(self):
    x = JaxVar('x', (5, 2), jnp.float32)
    y = JaxVar('y', (2, 3), jnp.float32)
    op = Einsum('ab,bc->ac', (x, y))
    x_val = jnp.arange(10.).reshape((5, 2))
    y_val = jnp.arange(6.).reshape((2, 3))
    np.testing.assert_allclose(
        op.evaluate(dict(x=x_val, y=y_val)),
        jnp.einsum('ab,bc->ac', x_val, y_val))


class EinsumOperationsTest(test_util.TestCase):

  def test_can_compose_nested_einsums_to_make_single_einsum(self):
    w = JaxVar('w', (4, 5), jnp.float32)
    x = JaxVar('x', (5, 2), jnp.float32)
    y = JaxVar('y', (2, 3), jnp.float32)
    z = JaxVar('z', (3, 1), jnp.float32)

    child_op = Einsum('ab,bc->ac', (x, y))
    parent_op = Einsum('ab,bc,cd->ad', (w, child_op, z))

    single_op = einsum.compose_einsums(parent_op.formula, (w,), child_op, (z,))
    self.assertEqual(single_op.dtype, parent_op.dtype)
    self.assertTupleEqual(single_op.shape, parent_op.shape)

    keys = random.split(random.PRNGKey(0), 4)

    env = {
        a.name: random.normal(key, a.shape, dtype=a.dtype)
        for key, a in zip(keys, [w, x, y, z])
    }
    np.testing.assert_allclose(
        parent_op.evaluate(env), single_op.evaluate(env), rtol=1e-6, atol=1e-6)


if __name__ == '__main__':
  absltest.main()
