1313# limitations under the License.
1414# ==============================================================================
1515"""Tests for dgi."""
16+ import os
17+
18+ from absl .testing import parameterized
1619import tensorflow as tf
20+ import tensorflow .__internal__ .distribute as tfdistribute
21+ import tensorflow .__internal__ .test as tftest
1722import tensorflow_gnn as tfgnn
1823
1924from tensorflow_gnn .runner import orchestration
4247""" % tfgnn .HIDDEN_STATE
4348
4449
45- class DeepGraphInfomaxTest (tf .test .TestCase ):
46-
50+ def _all_eager_distributed_strategy_combinations ():
51+ strategies = [
52+ # MirroredStrategy
53+ tfdistribute .combinations .mirrored_strategy_with_gpu_and_cpu ,
54+ tfdistribute .combinations .mirrored_strategy_with_one_cpu ,
55+ tfdistribute .combinations .mirrored_strategy_with_one_gpu ,
56+ """ # MultiWorkerMirroredStrategy
57+ tfdistribute.combinations.multi_worker_mirrored_2x1_cpu,
58+ tfdistribute.combinations.multi_worker_mirrored_2x1_gpu,
59+ # TPUStrategy
60+ tfdistribute.combinations.tpu_strategy,
61+ tfdistribute.combinations.tpu_strategy_one_core,
62+ tfdistribute.combinations.tpu_strategy_packed_var,
63+ # ParameterServerStrategy
64+ tfdistribute.combinations.parameter_server_strategy_3worker_2ps_cpu,
65+ tfdistribute.combinations.parameter_server_strategy_3worker_2ps_1gpu,
66+ tfdistribute.combinations.parameter_server_strategy_1worker_2ps_cpu,
67+ tfdistribute.combinations.parameter_server_strategy_1worker_2ps_1gpu, """
68+ ]
69+ return tftest .combinations .combine (distribution = strategies )
70+
71+
72+ class DeepGraphInfomaxTest (tf .test .TestCase , parameterized .TestCase ):
73+
74+ global_batch_size = 2
4775 gtspec = tfgnn .create_graph_spec_from_schema_pb (tfgnn .parse_schema (SCHEMA ))
48- task = dgi .DeepGraphInfomax ("node" , seed = 8191 )
76+ seed = 8191
77+ task = dgi .DeepGraphInfomax (
78+ "node" , global_batch_size = global_batch_size , seed = seed )
79+
80+ def get_graph_tensor (self ):
81+ gt = tfgnn .GraphTensor .from_pieces (
82+ node_sets = {
83+ "node" :
84+ tfgnn .NodeSet .from_fields (
85+ features = {
86+ tfgnn .HIDDEN_STATE :
87+ tf .convert_to_tensor ([[1. , 2. , 3. , 4. ],
88+ [11. , 11. , 11. , 11. ],
89+ [19. , 19. , 19. , 19. ]])
90+ },
91+ sizes = tf .convert_to_tensor ([3 ])),
92+ },
93+ edge_sets = {
94+ "edge" :
95+ tfgnn .EdgeSet .from_fields (
96+ sizes = tf .convert_to_tensor ([2 ]),
97+ adjacency = tfgnn .Adjacency .from_indices (
98+ ("node" , tf .convert_to_tensor ([0 , 1 ], dtype = tf .int32 )),
99+ ("node" , tf .convert_to_tensor ([2 , 0 ], dtype = tf .int32 )),
100+ )),
101+ })
102+ return gt
49103
50104 def build_model (self ):
51105 graph = inputs = tf .keras .layers .Input (type_spec = self .gtspec )
@@ -56,7 +110,9 @@ def build_model(self):
56110 "edge" ,
57111 tfgnn .TARGET ,
58112 feature_name = tfgnn .HIDDEN_STATE )
59- messages = tf .keras .layers .Dense (16 )(values )
113+ messages = tf .keras .layers .Dense (
114+ 8 , kernel_initializer = tf .constant_initializer (1. ))(
115+ values )
60116
61117 pooled = tfgnn .pool_edges_to_node (
62118 graph ,
@@ -67,7 +123,9 @@ def build_model(self):
67123 h_old = graph .node_sets ["node" ].features [tfgnn .HIDDEN_STATE ]
68124
69125 h_next = tf .keras .layers .Concatenate ()((pooled , h_old ))
70- h_next = tf .keras .layers .Dense (8 )(h_next )
126+ h_next = tf .keras .layers .Dense (
127+ 4 , kernel_initializer = tf .constant_initializer (1. ))(
128+ h_next )
71129
72130 graph = graph .replace_features (
73131 node_sets = {"node" : {
@@ -87,30 +145,71 @@ def test_adapt(self):
87145 feature_name = tfgnn .HIDDEN_STATE )(model (gt ))
88146 actual = adapted (gt )
89147
90- self .assertAllClose (actual , expected )
148+ self .assertAllClose (actual , expected , rtol = 1e-04 , atol = 1e-04 )
91149
92150 def test_fit (self ):
93- gt = tfgnn . random_graph_tensor (self .gtspec )
94- ds = tf . data . Dataset . from_tensors ( gt ). repeat ( 8 )
95- ds = ds . batch ( 2 ). map ( tfgnn .GraphTensor .merge_batch_to_components )
151+ ds = tf . data . Dataset . from_tensors (self .get_graph_tensor ()). repeat ( 8 )
152+ ds = ds . batch ( self . global_batch_size ). map (
153+ tfgnn .GraphTensor .merge_batch_to_components )
96154
155+ tf .random .set_seed (self .seed )
97156 model = self .task .adapt (self .build_model ())
98157 model .compile ()
99158
100159 def get_loss ():
160+ tf .random .set_seed (self .seed )
101161 values = model .evaluate (ds )
102162 return dict (zip (model .metrics_names , values ))["loss" ]
103163
104164 before = get_loss ()
105165 model .fit (ds )
106166 after = get_loss ()
167+ self .assertAllClose (before , 21754138.0 , rtol = 1e-04 , atol = 1e-04 )
168+ self .assertAllClose (after , 16268301.0 , rtol = 1e-04 , atol = 1e-04 )
169+
170+ @tfdistribute .combinations .generate (
171+ tftest .combinations .combine (distribution = [
172+ tfdistribute .combinations .mirrored_strategy_with_one_gpu ,
173+ tfdistribute .combinations .multi_worker_mirrored_2x1_gpu ,
174+ ]))
175+ def test_distributed (self , distribution ):
176+ gt = self .get_graph_tensor ()
177+
178+ def dataset_fn (input_context = None , gt = gt ):
179+ ds = tf .data .Dataset .from_tensors (gt ).repeat (8 )
180+ if input_context :
181+ batch_size = input_context .get_per_replica_batch_size (
182+ self .global_batch_size )
183+ else :
184+ batch_size = self .global_batch_size
185+ ds = ds .batch (batch_size ).map (tfgnn .GraphTensor .merge_batch_to_components )
186+ return ds
187+
188+ with distribution .scope ():
189+ tf .random .set_seed (self .seed )
190+ model = self .task .adapt (self .build_model ())
191+ model .compile ()
192+
193+ def get_loss ():
194+ tf .random .set_seed (self .seed )
195+ values = model .evaluate (
196+ distribution .distribute_datasets_from_function (dataset_fn ), steps = 4 )
197+ return dict (zip (model .metrics_names , values ))["loss" ]
198+
199+ before = get_loss ()
200+ model .fit (
201+ distribution .distribute_datasets_from_function (dataset_fn ),
202+ steps_per_epoch = 4 )
203+ after = get_loss ()
204+ self .assertAllClose (before , 21754138.0 , rtol = 1e-04 , atol = 1e-04 )
205+ self .assertAllClose (after , 16268301.0 , rtol = 1e-04 , atol = 1e-04 )
107206
108- self .assertAllClose ( before , 250.42036 , rtol = 1e-04 , atol = 1e-04 )
109- self . assertAllClose ( after , 13.18533 , rtol = 1e-04 , atol = 1e-04 )
207+ export_dir = os . path . join ( self .get_temp_dir (), "dropout-model" )
208+ model . save ( export_dir )
110209
111210 def test_protocol (self ):
112211 self .assertIsInstance (dgi .DeepGraphInfomax , orchestration .Task )
113212
114213
115214if __name__ == "__main__" :
116- tf . test . main ()
215+ tfdistribute . multi_process_runner . test_main ()
0 commit comments