|
20 | 20 | from __future__ import print_function |
21 | 21 |
|
22 | 22 | from absl.testing import parameterized |
23 | | - |
24 | 23 | import mesh_tensorflow as mtf |
25 | 24 | from mesh_tensorflow import test_utils |
26 | 25 | import mock |
27 | 26 | import numpy as np |
28 | | - |
29 | 27 | import tensorflow.compat.v1 as tf |
| 28 | +import tensorflow_probability as tfp |
| 29 | + |
30 | 30 | from tensorflow.python.framework import test_util # pylint:disable=g-direct-tensorflow-import |
31 | 31 |
|
32 | 32 |
|
@@ -85,6 +85,68 @@ def testDense(self, units, use_bias, new_dim_name): |
85 | 85 |
|
86 | 86 | self.assertEqual(actual.shape, expected.shape) |
87 | 87 |
|
| 88 | + @test_util.run_in_graph_and_eager_modes() |
| 89 | + def testCorr2DInput(self): |
| 90 | + batch = 4 |
| 91 | + channels = 3 |
| 92 | + inputs = tf.random_normal([batch, channels]) |
| 93 | + |
| 94 | + graph = mtf.Graph() |
| 95 | + mesh = mtf.Mesh(graph, "my_mesh") |
| 96 | + batch_dim = mtf.Dimension("batch", batch) |
| 97 | + channels_dim = mtf.Dimension("channels", channels) |
| 98 | + |
| 99 | + mtf_inputs = mtf.import_tf_tensor( |
| 100 | + mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim])) |
| 101 | + mtf_outputs = mtf.layers.corr(mtf_inputs, dim=channels_dim) |
| 102 | + mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( |
| 103 | + shape=[], layout={}, devices=[""]) |
| 104 | + lowering = mtf.Lowering(graph, {mesh: mesh_impl}) |
| 105 | + actual_outputs = lowering.export_to_tf_tensor(mtf_outputs) |
| 106 | + |
| 107 | + expected_outputs = tfp.stats.correlation( |
| 108 | + inputs, sample_axis=0, event_axis=1) |
| 109 | + tf_group = lowering.copy_masters_to_slices() |
| 110 | + init = tf.global_variables_initializer() |
| 111 | + self.evaluate(init) |
| 112 | + self.evaluate(tf_group) |
| 113 | + actual, expected = self.evaluate([actual_outputs, expected_outputs]) |
| 114 | + |
| 115 | + self.assertEqual(actual.shape, expected.shape) |
| 116 | + self.assertAllClose(actual, expected) |
| 117 | + |
| 118 | + @test_util.run_in_graph_and_eager_modes() |
| 119 | + def testCorr3DInput(self): |
| 120 | + batch = 4 |
| 121 | + sequence = 5 |
| 122 | + channels = 3 |
| 123 | + inputs = tf.random_normal([batch, sequence, channels]) |
| 124 | + |
| 125 | + graph = mtf.Graph() |
| 126 | + mesh = mtf.Mesh(graph, "my_mesh") |
| 127 | + batch_dim = mtf.Dimension("batch", batch) |
| 128 | + seq_dim = mtf.Dimension("seq", sequence) |
| 129 | + channels_dim = mtf.Dimension("channels", channels) |
| 130 | + |
| 131 | + mtf_inputs = mtf.import_tf_tensor( |
| 132 | + mesh, inputs, shape=mtf.Shape([batch_dim, seq_dim, channels_dim])) |
| 133 | + mtf_outputs = mtf.layers.corr(mtf_inputs, dim=channels_dim) |
| 134 | + mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( |
| 135 | + shape=[], layout={}, devices=[""]) |
| 136 | + lowering = mtf.Lowering(graph, {mesh: mesh_impl}) |
| 137 | + actual_outputs = lowering.export_to_tf_tensor(mtf_outputs) |
| 138 | + |
| 139 | + expected_outputs = tfp.stats.correlation( |
| 140 | + inputs, sample_axis=[0, 1], event_axis=2) |
| 141 | + tf_group = lowering.copy_masters_to_slices() |
| 142 | + init = tf.global_variables_initializer() |
| 143 | + self.evaluate(init) |
| 144 | + self.evaluate(tf_group) |
| 145 | + actual, expected = self.evaluate([actual_outputs, expected_outputs]) |
| 146 | + |
| 147 | + self.assertEqual(actual.shape, expected.shape) |
| 148 | + self.assertAllClose(actual, expected) |
| 149 | + |
88 | 150 | @test_util.run_in_graph_and_eager_modes() |
89 | 151 | def testLayerNorm(self): |
90 | 152 | batch = 2 |
|
0 commit comments