 `compute_gradient_norms()` function).
 """
 
+from typing import Union, Iterable, Text, TypeAlias
+
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
+
+InputTensor: TypeAlias = Union[
+    tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]
+]
 
 
-def get_registry_generator_fn(tape, layer_registry):
+def get_registry_generator_fn(
+    tape: tf.GradientTape, layer_registry: lr.LayerRegistry
+):
   """Creates the generator function for `compute_gradient_norms()`."""
   if layer_registry is None:
     # Needed for backwards compatibility.
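For context on the new `InputTensor` alias: it covers the three batch shapes a Keras model commonly accepts. A minimal sketch of values that should each satisfy the alias (the data here is illustrative, not from this commit):

```python
import tensorflow as tf

# Each of these is a valid InputTensor under the new alias:
single = tf.ones([8, 4])                    # a single tf.Tensor
multi = [tf.ones([8, 4]), tf.ones([8, 2])]  # an iterable of tensors
named = {                                   # a dict keyed by input name
    "age": tf.ones([8, 1]),
    "income": tf.ones([8, 1]),
}
```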
@@ -53,7 +62,12 @@ def registry_generator_fn(layer_instance, args, kwargs):
   return registry_generator_fn
 
 
-def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
+def compute_gradient_norms(
+    input_model: tf.keras.Model,
+    x_batch: InputTensor,
+    y_batch: tf.Tensor,
+    layer_registry: lr.LayerRegistry,
+):
   """Computes the per-example loss gradient norms for given data.
 
   Applies a variant of the approach given in
@@ -62,7 +76,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers. The
       loss of the model *must* be a scalar loss.
-    x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
+    x_batch: An `InputTensor` representing a batch of inputs to the model. The
       first axis must be the batch dimension.
     y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
       must be the batch dimension. The number of examples should match the
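To make the typed signature concrete, here is a hedged usage sketch. It assumes the module is importable as `clip_grads` and that `lr.make_default_layer_registry()` is the registry factory; both names are assumptions, not shown in this diff:

```python
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr

# Toy model with a scalar loss, as the docstring requires.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(loss=tf.keras.losses.MeanSquaredError())

x_batch = tf.random.normal([8, 4])   # first axis = batch dimension
y_batch = tf.random.normal([8, 1])

# Returns one per-example gradient norm per input; shape [8].
norms = clip_grads.compute_gradient_norms(
    model, x_batch, y_batch, lr.make_default_layer_registry()
)
```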
@@ -106,7 +120,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
   return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
 
 
-def compute_clip_weights(l2_norm_clip, gradient_norms):
+def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
   """Computes the per-example loss/clip weights for clipping.
 
   When the sum of the per-example losses is replaced by a weighted sum, where
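The property the docstring describes is easiest to check with the standard weight choice w_i = C / max(C, ||g_i||), under which each weighted per-example gradient w_i * g_i has norm min(||g_i||, C), i.e. exactly the clipped norm. A sketch of that formula (one way to satisfy the stated property; the function body is not shown in this diff):

```python
import tensorflow as tf

def clip_weights_sketch(l2_norm_clip: float, gradient_norms: tf.Tensor):
  # w_i = C / max(C, ||g_i||): examples already under the clip norm keep
  # weight 1.0; the rest are scaled down so that w_i * ||g_i|| == C.
  return l2_norm_clip / tf.math.maximum(l2_norm_clip, gradient_norms)

weights = clip_weights_sketch(1.0, tf.constant([0.5, 1.0, 4.0]))
# -> [1.0, 1.0, 0.25]
```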
@@ -132,7 +146,11 @@ def compute_clip_weights(l2_norm_clip, gradient_norms):
 
 
 def compute_pred_and_clipped_gradients(
-    input_model, x_batch, y_batch, l2_norm_clip, layer_registry
+    input_model: tf.keras.Model,
+    x_batch: InputTensor,
+    y_batch: tf.Tensor,
+    l2_norm_clip: float,
+    layer_registry: lr.LayerRegistry,
 ):
   """Computes the per-example predictions and per-example clipped loss gradient.
 
@@ -147,7 +165,7 @@ def compute_pred_and_clipped_gradients(
 
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers.
-    x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
+    x_batch: An `InputTensor` representing a batch of inputs to the model. The
       first axis must be the batch dimension.
     y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
       must be the batch dimension. The number of examples should match the
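For completeness, a hedged sketch of how this function could slot into a DP-SGD style training step. It assumes a `(y_pred, clipped_grads)` return pair aligned with `model.trainable_variables`, plus the `clip_grads` and `lr.make_default_layer_registry()` names assumed above; the noise addition is illustrative and not part of this commit:

```python
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr

def dp_train_step(model, optimizer, x_batch, y_batch,
                  l2_norm_clip, noise_multiplier):
  # Assumed return pair: predictions and the clipped loss gradients.
  y_pred, clipped_grads = clip_grads.compute_pred_and_clipped_gradients(
      model, x_batch, y_batch, l2_norm_clip, lr.make_default_layer_registry()
  )
  # DP-SGD: add Gaussian noise calibrated to the clip norm before applying.
  noisy_grads = [
      g + tf.random.normal(tf.shape(g), stddev=l2_norm_clip * noise_multiplier)
      for g in clipped_grads
  ]
  optimizer.apply_gradients(zip(noisy_grads, model.trainable_variables))
  return y_pred
```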