@@ -37,14 +37,17 @@ def _make_functional_regularized_model(distance_config):
3737 def _make_unregularized_model (inputs , num_classes ):
3838 """Makes standard 1 layer MLP with logistic regression."""
3939 x = tf .keras .layers .Dense (16 , activation = 'relu' )(inputs )
40- return tf .keras .Model (inputs , outputs = tf .keras .layers .Dense (num_classes )(x ))
40+ model = tf .keras .Model (inputs , tf .keras .layers .Dense (num_classes )(x ))
41+ return model
4142
4243 # Each example has 4 features and 2 neighbors, each with an edge weight.
4344 inputs = (tf .keras .Input (shape = (4 ,), dtype = tf .float32 , name = 'features' ),
4445 tf .keras .Input (shape = (2 , 4 ), dtype = tf .float32 , name = 'neighbors' ),
4546 tf .keras .Input (
4647 shape = (2 , 1 ), dtype = tf .float32 , name = 'neighbor_weights' ))
4748 features , neighbors , neighbor_weights = inputs
49+ neighbors = tf .reshape (neighbors , (- 1 ,) + tuple (features .shape [1 :]))
50+ neighbor_weights = tf .reshape (neighbor_weights , [- 1 , 1 ])
4851 unregularized_model = _make_unregularized_model (features , 3 )
4952 logits = unregularized_model (features )
5053 model = tf .keras .Model (inputs = inputs , outputs = logits )
0 commit comments