1313# limitations under the License.
1414# ==============================================================================
1515"""Tests for dgi."""
16+ from absl .testing import parameterized
1617import tensorflow as tf
18+ import tensorflow .__internal__ .distribute as tfdistribute
19+ import tensorflow .__internal__ .test as tftest
1720import tensorflow_gnn as tfgnn
1821
1922from tensorflow_gnn .runner import orchestration
4245""" % tfgnn .HIDDEN_STATE
4346
4447
45- class DeepGraphInfomaxTest (tf .test .TestCase ):
46-
48+ def _all_eager_distributed_strategy_combinations ():
49+ strategies = [
50+ # MirroredStrategy
51+ tfdistribute .combinations .mirrored_strategy_with_gpu_and_cpu ,
52+ tfdistribute .combinations .mirrored_strategy_with_one_cpu ,
53+ tfdistribute .combinations .mirrored_strategy_with_one_gpu ,
54+ """ # MultiWorkerMirroredStrategy
55+ tfdistribute.combinations.multi_worker_mirrored_2x1_cpu,
56+ tfdistribute.combinations.multi_worker_mirrored_2x1_gpu,
57+ # TPUStrategy
58+ tfdistribute.combinations.tpu_strategy,
59+ tfdistribute.combinations.tpu_strategy_one_core,
60+ tfdistribute.combinations.tpu_strategy_packed_var,
61+ # ParameterServerStrategy
62+ tfdistribute.combinations.parameter_server_strategy_3worker_2ps_cpu,
63+ tfdistribute.combinations.parameter_server_strategy_3worker_2ps_1gpu,
64+ tfdistribute.combinations.parameter_server_strategy_1worker_2ps_cpu,
65+ tfdistribute.combinations.parameter_server_strategy_1worker_2ps_1gpu, """
66+ ]
67+ return tftest .combinations .combine (distribution = strategies )
68+
69+
70+ class DeepGraphInfomaxTest (tf .test .TestCase , parameterized .TestCase ):
71+
72+ global_batch_size = 2
4773 gtspec = tfgnn .create_graph_spec_from_schema_pb (tfgnn .parse_schema (SCHEMA ))
48- task = dgi .DeepGraphInfomax ("node" , seed = 8191 )
74+ task = dgi .DeepGraphInfomax (
75+ "node" , global_batch_size = global_batch_size , seed = 8191 )
76+
77+ def get_graph_tensor (self ):
78+ gt = tfgnn .GraphTensor .from_pieces (
79+ node_sets = {
80+ "node" :
81+ tfgnn .NodeSet .from_fields (
82+ features = {
83+ tfgnn .HIDDEN_STATE :
84+ tf .convert_to_tensor ([[1. , 2. , 3. , 4. ],
85+ [11. , 11. , 11. , 11. ],
86+ [19. , 19. , 19. , 19. ]])
87+ },
88+ sizes = tf .convert_to_tensor ([3 ])),
89+ },
90+ edge_sets = {
91+ "edge" :
92+ tfgnn .EdgeSet .from_fields (
93+ sizes = tf .convert_to_tensor ([2 ]),
94+ adjacency = tfgnn .Adjacency .from_indices (
95+ ("node" , tf .convert_to_tensor ([0 , 1 ], dtype = tf .int32 )),
96+ ("node" , tf .convert_to_tensor ([2 , 0 ], dtype = tf .int32 )),
97+ )),
98+ })
99+ return gt
49100
50101 def build_model (self ):
51102 graph = inputs = tf .keras .layers .Input (type_spec = self .gtspec )
@@ -87,12 +138,12 @@ def test_adapt(self):
87138 feature_name = tfgnn .HIDDEN_STATE )(model (gt ))
88139 actual = adapted (gt )
89140
90- self .assertAllClose (actual , expected )
141+ self .assertAllClose (actual , expected , rtol = 1e-04 , atol = 1e-04 )
91142
92143 def test_fit (self ):
93- gt = tfgnn . random_graph_tensor (self .gtspec )
94- ds = tf . data . Dataset . from_tensors ( gt ). repeat ( 8 )
95- ds = ds . batch ( 2 ). map ( tfgnn .GraphTensor .merge_batch_to_components )
144+ ds = tf . data . Dataset . from_tensors (self .get_graph_tensor ()). repeat ( 8 )
145+ ds = ds . batch ( self . global_batch_size ). map (
146+ tfgnn .GraphTensor .merge_batch_to_components )
96147
97148 model = self .task .adapt (self .build_model ())
98149 model .compile ()
@@ -105,12 +156,47 @@ def get_loss():
105156 model .fit (ds )
106157 after = get_loss ()
107158
108- self .assertAllClose (before , 250.42036 , rtol = 1e-04 , atol = 1e-04 )
109- self .assertAllClose (after , 13.18533 , rtol = 1e-04 , atol = 1e-04 )
159+ self .assertAllClose (before , 92.92909 , rtol = 1e-04 , atol = 1e-04 )
160+ self .assertAllClose (after , 4.05084 , rtol = 1e-04 , atol = 1e-04 )
161+
162+ @tfdistribute .combinations .generate (
163+ tftest .combinations .combine (distribution = [
164+ tfdistribute .combinations .mirrored_strategy_with_one_gpu ,
165+ tfdistribute .combinations .multi_worker_mirrored_2x1_gpu ,
166+ ]))
167+ def test_distributed (self , distribution ):
168+ gt = self .get_graph_tensor ()
169+
170+ def dataset_fn (input_context = None , gt = gt ):
171+ ds = tf .data .Dataset .from_tensors (gt ).repeat (8 )
172+ if input_context :
173+ batch_size = input_context .get_per_replica_batch_size (
174+ self .global_batch_size )
175+ else :
176+ batch_size = self .global_batch_size
177+ ds = ds .batch (batch_size ).map (tfgnn .GraphTensor .merge_batch_to_components )
178+ return ds
179+
180+ with distribution .scope ():
181+ model = self .task .adapt (self .build_model ())
182+ model .compile ()
183+
184+ def get_loss ():
185+ values = model .evaluate (
186+ distribution .distribute_datasets_from_function (dataset_fn ), steps = 4 )
187+ return dict (zip (model .metrics_names , values ))["loss" ]
188+
189+ before = get_loss ()
190+ model .fit (
191+ distribution .distribute_datasets_from_function (dataset_fn ),
192+ steps_per_epoch = 4 )
193+ after = get_loss ()
194+ self .assertAllClose (before , 92.92909 , rtol = 2 , atol = 1 )
195+ self .assertAllClose (after , 4.05084 , rtol = 2 , atol = 1 )
110196
111197 def test_protocol (self ):
112198 self .assertIsInstance (dgi .DeepGraphInfomax , orchestration .Task )
113199
114200
115201if __name__ == "__main__" :
116- tf . test . main ()
202+ tfdistribute . multi_process_runner . test_main ()
0 commit comments