
Commit acf6247

Committed by: Mesh TensorFlow Team

Internal

PiperOrigin-RevId: 395547787

1 parent 69bd9c7 · commit acf6247

File tree: 4 files changed (+99, -3 lines)

.github/workflows/build.yml (1 addition, 1 deletion)

@@ -14,7 +14,7 @@ jobs:
     - name: Install dependencies
       run: |
         pip install tf-nightly mock pytest
-        pip install -e .[auto_mtf,transformer]
+        pip install -e .[test,auto_mtf,transformer]
     - name: Test with pytest
       run: pytest
     # The below step just reports the success or failure of tests as a "commit status".

mesh_tensorflow/layers.py (33 additions, 0 deletions)

@@ -947,6 +947,39 @@ def conv3d_transpose_with_blocks(
       variable_dtype, name)


+def corr(x, dim, epsilon=1e-20, name="pearson_correlation"):
+  """Compute correlation along dimension dim, equiv to tfp.stats.correlation.
+
+  It treats the dim Dimension as the random event axis, and all the other dims
+  as the sample axis. Pearson correlation is computed between random events in
+  dim Dimension, and marginalized over the other dims.
+
+  Example usage:
+    inputs = tf.random_normal([batch, channels])
+    mtf_inputs = mtf.import_tf_tensor(
+        mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
+    correlation = corr(mtf_inputs, dim=channels_dim)
+
+  Args:
+    x: a mtf.Tensor whose shape contains dim.
+    dim: a mtf.Dimension.
+    epsilon: a small floating point number for numerical stability.
+    name: a string used for tf.variable_scope.
+
+  Returns:
+    a mtf.Tensor with the shape of [dim, dim].
+  """
+  with tf.variable_scope(name):
+    mean = mtf.reduce_mean(x, output_shape=[dim])
+    dim_name = dim.name
+    x1 = mtf.rename_dimension(x - mean, dim_name, f"{dim_name}_1")
+    x2 = mtf.rename_dimension(x - mean, dim_name, f"{dim_name}_2")
+    variance = lambda z: mtf.sqrt(  # pylint: disable=g-long-lambda
+        mtf.reduce_sum(mtf.square(z), output_shape=z.shape.dims[-1:])) + epsilon
+    v1, v2 = variance(x1), variance(x2)
+    return mtf.matmul(x1, x2) / mtf.matmul(v1, v2)
+
+
 def layer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
   """Layer normalization over dimension dim.
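For intuition, the new mtf.layers.corr should reproduce an ordinary Pearson correlation matrix computed over the sample axis. A minimal NumPy cross-check (not part of the commit; shapes and values chosen only for illustration):

    import numpy as np

    # [batch, channels] input, matching the 2-D example in the docstring above.
    x = np.random.randn(4, 3).astype(np.float32)
    centered = x - x.mean(axis=0, keepdims=True)      # subtract the per-channel mean
    norms = np.sqrt((centered ** 2).sum(axis=0))      # L2 norm of each centered channel
    reference = (centered.T @ centered) / np.outer(norms, norms)  # [channels, channels]
    # np.corrcoef(x, rowvar=False) gives the same matrix up to floating-point error,
    # and is what corr(mtf_inputs, dim=channels_dim) should produce after lowering.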

mesh_tensorflow/layers_test.py (64 additions, 2 deletions)

@@ -20,13 +20,13 @@
 from __future__ import print_function

 from absl.testing import parameterized
-
 import mesh_tensorflow as mtf
 from mesh_tensorflow import test_utils
 import mock
 import numpy as np
-
 import tensorflow.compat.v1 as tf
+import tensorflow_probability as tfp
+
 from tensorflow.python.framework import test_util  # pylint:disable=g-direct-tensorflow-import


@@ -85,6 +85,68 @@ def testDense(self, units, use_bias, new_dim_name):

     self.assertEqual(actual.shape, expected.shape)

+  @test_util.run_in_graph_and_eager_modes()
+  def testCorr2DInput(self):
+    batch = 4
+    channels = 3
+    inputs = tf.random_normal([batch, channels])
+
+    graph = mtf.Graph()
+    mesh = mtf.Mesh(graph, "my_mesh")
+    batch_dim = mtf.Dimension("batch", batch)
+    channels_dim = mtf.Dimension("channels", channels)
+
+    mtf_inputs = mtf.import_tf_tensor(
+        mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
+    mtf_outputs = mtf.layers.corr(mtf_inputs, dim=channels_dim)
+    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
+        shape=[], layout={}, devices=[""])
+    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
+    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
+
+    expected_outputs = tfp.stats.correlation(
+        inputs, sample_axis=0, event_axis=1)
+    tf_group = lowering.copy_masters_to_slices()
+    init = tf.global_variables_initializer()
+    self.evaluate(init)
+    self.evaluate(tf_group)
+    actual, expected = self.evaluate([actual_outputs, expected_outputs])
+
+    self.assertEqual(actual.shape, expected.shape)
+    self.assertAllClose(actual, expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testCorr3DInput(self):
+    batch = 4
+    sequence = 5
+    channels = 3
+    inputs = tf.random_normal([batch, sequence, channels])
+
+    graph = mtf.Graph()
+    mesh = mtf.Mesh(graph, "my_mesh")
+    batch_dim = mtf.Dimension("batch", batch)
+    seq_dim = mtf.Dimension("seq", sequence)
+    channels_dim = mtf.Dimension("channels", channels)
+
+    mtf_inputs = mtf.import_tf_tensor(
+        mesh, inputs, shape=mtf.Shape([batch_dim, seq_dim, channels_dim]))
+    mtf_outputs = mtf.layers.corr(mtf_inputs, dim=channels_dim)
+    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
+        shape=[], layout={}, devices=[""])
+    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
+    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
+
+    expected_outputs = tfp.stats.correlation(
+        inputs, sample_axis=[0, 1], event_axis=2)
+    tf_group = lowering.copy_masters_to_slices()
+    init = tf.global_variables_initializer()
+    self.evaluate(init)
+    self.evaluate(tf_group)
+    actual, expected = self.evaluate([actual_outputs, expected_outputs])
+
+    self.assertEqual(actual.shape, expected.shape)
+    self.assertAllClose(actual, expected)
+
   @test_util.run_in_graph_and_eager_modes()
   def testLayerNorm(self):
     batch = 2
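One detail worth noting from testCorr3DInput: treating both batch and seq as sample axes is equivalent to flattening them into a single sample axis before correlating over channels. A small NumPy sketch of that equivalence (not part of the commit; values are illustrative):

    import numpy as np

    batch, sequence, channels = 4, 5, 3
    x = np.random.randn(batch, sequence, channels).astype(np.float32)
    flat = x.reshape(batch * sequence, channels)       # merge the two sample axes
    reference = np.corrcoef(flat, rowvar=False)        # [channels, channels]
    # tfp.stats.correlation(x, sample_axis=[0, 1], event_axis=2) should match
    # `reference` up to floating-point error.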

setup.py (1 addition, 0 deletions)

@@ -27,6 +27,7 @@
         'auto_mtf': ['ortools'],
         'tensorflow': ['tensorflow>=1.15.0'],
         'transformer': ['tensorflow-datasets', 'scipy'],
+        'test': ['tensorflow_probability']
     },
     tests_require=[
         'ortools',
