Skip to content

Commit abb0c3f

Browse files
galenmandrewtensorflower-gardener
authored andcommitted
Migrates compute_dp_sgd_privacy to print new privacy statement from compute_dp_sgd_privacy_lib.
PiperOrigin-RevId: 520147633
1 parent 781483d commit abb0c3f

File tree

1 file changed

+35
-13
lines changed

1 file changed

+35
-13
lines changed

tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,48 @@
3131
from absl import app
3232
from absl import flags
3333

34-
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy
34+
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy_statement
3535

36-
FLAGS = flags.FLAGS
3736

38-
flags.DEFINE_integer('N', None, 'Total number of examples')
39-
flags.DEFINE_integer('batch_size', None, 'Batch size')
40-
flags.DEFINE_float('noise_multiplier', None, 'Noise multiplier for DP-SGD')
41-
flags.DEFINE_float('epochs', None, 'Number of epochs (may be fractional)')
42-
flags.DEFINE_float('delta', 1e-6, 'Target delta')
37+
_NUM_EXAMPLES = flags.DEFINE_integer('N', None, 'Total number of examples.')
38+
_BATCH_SIZE = flags.DEFINE_integer('batch_size', None, 'Batch size.')
39+
_NOISE_MULTIPLIER = flags.DEFINE_float(
40+
'noise_multiplier', None, 'Noise multiplier for DP-SGD.'
41+
)
42+
_NUM_EPOCHS = flags.DEFINE_float(
43+
'epochs', None, 'Number of epochs (may be fractional).'
44+
)
45+
_DELTA = flags.DEFINE_float('delta', 1e-6, 'Target delta.')
46+
_USED_MICROBATCHING = flags.DEFINE_bool(
47+
'used_microbatching',
48+
True,
49+
'Whether microbatching was used (with microbatch size greater than one).',
50+
)
51+
_MAX_EXAMPLES_PER_USER = flags.DEFINE_integer(
52+
'max_examples_per_user',
53+
None,
54+
(
55+
'Maximum number of examples per user, applicable. Used to compute a'
56+
' user-level DP guarantee.'
57+
),
58+
)
59+
60+
flags.mark_flags_as_required(['N', 'batch_size', 'noise_multiplier', 'epochs'])
4361

4462

4563
def main(argv):
4664
del argv # argv is not used.
4765

48-
assert FLAGS.N is not None, 'Flag N is missing.'
49-
assert FLAGS.batch_size is not None, 'Flag batch_size is missing.'
50-
assert FLAGS.noise_multiplier is not None, 'Flag noise_multiplier is missing.'
51-
assert FLAGS.epochs is not None, 'Flag epochs is missing.'
52-
compute_dp_sgd_privacy(FLAGS.N, FLAGS.batch_size, FLAGS.noise_multiplier,
53-
FLAGS.epochs, FLAGS.delta)
66+
statement = compute_dp_sgd_privacy_statement(
67+
_NUM_EXAMPLES.value,
68+
_BATCH_SIZE.value,
69+
_NUM_EPOCHS.value,
70+
_NOISE_MULTIPLIER.value,
71+
_DELTA.value,
72+
_USED_MICROBATCHING.value,
73+
_MAX_EXAMPLES_PER_USER.value,
74+
)
75+
print(statement)
5476

5577

5678
if __name__ == '__main__':

0 commit comments

Comments
 (0)