# pytest: skip
import os
import sys

from absl.testing import parameterized
from tensorflow.python.distribute import multi_process_lib
import multiprocessing
import tensorflow as tf
from tensorflow.python.framework import constant_op

from tensorflow.python.training import server_lib

from tensorflow_recommenders_addons import dynamic_embedding as de

import numpy as np
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib

from tensorflow.python.eager import test
from packaging import version
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.eager import def_function
from tensorflow.python.ops import variables

if version.parse(tf.__version__) >= version.parse("2.16"):
  from tf_keras import layers
  from tf_keras import Sequential
  from tf_keras.optimizers import Adam
else:
  from tensorflow.python.keras import layers
  from tensorflow.python.keras import Sequential
  try:
    from tensorflow.keras.optimizers import Adam
  except ImportError:
    from tensorflow.keras.optimizers.legacy import Adam


def create_multi_process_cluster(cluster_spec,
                                 rpc_layer='grpc',
                                 stream_output=False,
                                 collective_leader=None):
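  """Start an in-process multi-worker/ps test cluster.

  Thin wrapper around tf's internal multi_worker_test_base.MultiProcessCluster:
  the cluster described by `cluster_spec` is started before being returned, so
  callers can immediately build a ParameterServerStrategyV2 on top of its
  resolver.
  """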

  cluster = multi_worker_test_base.MultiProcessCluster(
      cluster_resolver_lib.SimpleClusterResolver(
          server_lib.ClusterSpec(cluster_spec), rpc_layer=rpc_layer),
      stream_output=stream_output,
      collective_leader=collective_leader)
  cluster.start()
  return cluster


class ParameterServerStrategyV2Test(test.TestCase):

  @classmethod
  def setUpClass(cls):
    super(ParameterServerStrategyV2Test, cls).setUpClass()
    cluster_spec = {
        "worker": ["localhost:2223", "localhost:2224"],
        "ps": ["localhost:2222"]
    }
    cls.cluster = create_multi_process_cluster(cluster_spec)
    cls.cluster_resolver = cls.cluster.cluster_resolver
    # cls.strategy = DEParameterServerStrategy(cls.cluster_resolver)
    cls.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cls.cluster_resolver)
    cls.coordinator = coordinator_lib.ClusterCoordinator(cls.strategy)

  @classmethod
  def tearDownClass(cls):
    super(ParameterServerStrategyV2Test, cls).tearDownClass()
    cls.cluster.stop()

  #@parameterized.parameters(True, False)
  def testPerWorkerVariableCreation(self):
    var_dtype = tf.dtypes.float32
    var_name = 'var'
    shape = [1]  #if define_shape else None

    # with self.strategy.scope():
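    # NOTE: `per_worker_de_variable` is not an argument of the stock
    # tf.Variable; it is assumed here to be consumed by TFRA's patched Variable
    # so the value lives on each worker instead of on a parameter server.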
    var = variables.Variable(initial_value=[0.0],
                             shape=shape,
                             dtype=var_dtype,
                             name=var_name,
                             per_worker_de_variable=True)

    # Use per-worker variable as a capture
    @def_function.function
    def worker_fn():
      var.assign_add(constant_op.constant([1.0]))
      return var

    num_closures = 10
    for ix in range(num_closures):
      self.coordinator.schedule(worker_fn)
      # Read the PWV many times to ensure result is up-to-date
      self.coordinator.join()
      result_sum = sum(var.read_all()).numpy()
      self.assertEqual(result_sum, ix + 1)

    for _ in range(num_closures):
      self.coordinator.schedule(worker_fn)
    self.coordinator.join()

    # Verify placement of variables
    devices = [wv._get_values().device for wv in var._per_worker_vars._values]
    expected_devices = [
        f'/job:worker/replica:0/task:{ix}/device:CPU:0'
        for ix in range(self.strategy._num_workers)
    ]  # pylint: disable=protected-access
    self.assertAllEqual(devices, expected_devices)

    result_sum = sum(var.read_all()).numpy()
    self.assertEqual(result_sum, num_closures * 2)

  def testKerasFit(self):
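    """Fits a tiny Sequential model with a DE embedding under the PS strategy.

    The model and the de.DynamicEmbeddingOptimizer-wrapped Adam are created
    inside strategy.scope(), and the input is distributed with
    distribute_datasets_from_function so each worker builds its own shard.
    """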
    embed_dim = 8
    with self.strategy.scope():
      model = Sequential([
          layers.Input(shape=(1,), dtype=tf.int32),
          de.keras.layers.Embedding(embed_dim, key_dtype=tf.int32),
          layers.Flatten(),
          layers.Dense(1, activation='sigmoid')
      ])
      optimizer = Adam(1E-3)
      optimizer = de.DynamicEmbeddingOptimizer(optimizer)
      model.compile(loss='binary_crossentropy',
                    optimizer=optimizer,
                    metrics=['accuracy'])

    ids = np.random.randint(0, 100, size=(64 * 2, 1))
    labels = np.random.randint(0, 2, size=(64 * 2, 1))

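    # Each input pipeline shards the data by its pipeline id and batches with
    # the per-replica batch size derived from a global batch size of 32.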
    def dataset_fn(input_context):
      global_batch_size = 32
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      dataset = tf.data.Dataset.from_tensor_slices((ids, labels))
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
      dataset = dataset.batch(batch_size).repeat()
      return dataset

    dataset = self.strategy.distribute_datasets_from_function(dataset_fn)

    history = model.fit(dataset, epochs=1, steps_per_epoch=len(ids) // 64)
    self.assertIn('loss', history.history)


# Borrowed from multi_process_lib._set_spawn_exe_path and modified for
# tf_recommenders_addons.
def custom_set_spawn_exe_path():
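  """Points multiprocessing at the bazel-built test binary.

  When the test is launched from a .py entry point under bazel, guess the
  executable path from sys.argv[0] and TEST_TARGET so that spawned child
  processes re-run the same binary.
  """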
  if sys.argv[0].endswith('.py'):

    def guess_path(package_root):
      # If all we have is a python module path, we'll need to make a guess for
      # the actual executable path.
      if 'bazel-out' in sys.argv[0] and package_root in sys.argv[0]:
        package_root_base = sys.argv[0][:sys.argv[0].rfind(package_root)]
        binary = os.environ['TEST_TARGET'][2:].replace(':', '/', 1)
        print(f"package_root_base {package_root_base} binary {binary}")
        possible_path = os.path.join(package_root_base, package_root, binary)
        print('Guessed test binary path: %s' % possible_path)
        if os.access(possible_path, os.X_OK):
          return possible_path
      return None

    path = guess_path('tf_recommenders_addons')
    if path is None:
      print('Cannot determine binary path. sys.argv[0]=%s os.environ=%s' %
            (sys.argv[0], os.environ))
      raise RuntimeError('Cannot determine binary path')
    sys.argv[0] = path
  # Note that this sets the executable for *all* contexts.
  multiprocessing.get_context().set_executable(sys.argv[0])


# This test is not meant to be run under pytest; run it with:
# bazel test //tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests:parameter_server_bzl
if __name__ == "__main__":
  multi_process_lib._set_spawn_exe_path = custom_set_spawn_exe_path
  v2_compat.enable_v2_behavior()
  multi_process_runner.test_main()