|
31 | 31 | from absl import app |
32 | 32 | from absl import flags |
33 | 33 |
|
34 | | -from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy |
| 34 | +from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import compute_dp_sgd_privacy_statement |
35 | 35 |
|
36 | | -FLAGS = flags.FLAGS |
37 | 36 |
|
38 | | -flags.DEFINE_integer('N', None, 'Total number of examples') |
39 | | -flags.DEFINE_integer('batch_size', None, 'Batch size') |
40 | | -flags.DEFINE_float('noise_multiplier', None, 'Noise multiplier for DP-SGD') |
41 | | -flags.DEFINE_float('epochs', None, 'Number of epochs (may be fractional)') |
42 | | -flags.DEFINE_float('delta', 1e-6, 'Target delta') |
| 37 | +_NUM_EXAMPLES = flags.DEFINE_integer('N', None, 'Total number of examples.') |
| 38 | +_BATCH_SIZE = flags.DEFINE_integer('batch_size', None, 'Batch size.') |
| 39 | +_NOISE_MULTIPLIER = flags.DEFINE_float( |
| 40 | + 'noise_multiplier', None, 'Noise multiplier for DP-SGD.' |
| 41 | +) |
| 42 | +_NUM_EPOCHS = flags.DEFINE_float( |
| 43 | + 'epochs', None, 'Number of epochs (may be fractional).' |
| 44 | +) |
| 45 | +_DELTA = flags.DEFINE_float('delta', 1e-6, 'Target delta.') |
| 46 | +_USED_MICROBATCHING = flags.DEFINE_bool( |
| 47 | + 'used_microbatching', |
| 48 | + True, |
| 49 | + 'Whether microbatching was used (with microbatch size greater than one).', |
| 50 | +) |
| 51 | +_MAX_EXAMPLES_PER_USER = flags.DEFINE_integer( |
| 52 | + 'max_examples_per_user', |
| 53 | + None, |
| 54 | + ( |
| 55 | + 'Maximum number of examples per user, applicable. Used to compute a' |
| 56 | + ' user-level DP guarantee.' |
| 57 | + ), |
| 58 | +) |
| 59 | + |
| 60 | +flags.mark_flags_as_required(['N', 'batch_size', 'noise_multiplier', 'epochs']) |
43 | 61 |
|
44 | 62 |
|
45 | 63 | def main(argv): |
46 | 64 | del argv # argv is not used. |
47 | 65 |
|
48 | | - assert FLAGS.N is not None, 'Flag N is missing.' |
49 | | - assert FLAGS.batch_size is not None, 'Flag batch_size is missing.' |
50 | | - assert FLAGS.noise_multiplier is not None, 'Flag noise_multiplier is missing.' |
51 | | - assert FLAGS.epochs is not None, 'Flag epochs is missing.' |
52 | | - compute_dp_sgd_privacy(FLAGS.N, FLAGS.batch_size, FLAGS.noise_multiplier, |
53 | | - FLAGS.epochs, FLAGS.delta) |
| 66 | + statement = compute_dp_sgd_privacy_statement( |
| 67 | + _NUM_EXAMPLES.value, |
| 68 | + _BATCH_SIZE.value, |
| 69 | + _NUM_EPOCHS.value, |
| 70 | + _NOISE_MULTIPLIER.value, |
| 71 | + _DELTA.value, |
| 72 | + _USED_MICROBATCHING.value, |
| 73 | + _MAX_EXAMPLES_PER_USER.value, |
| 74 | + ) |
| 75 | + print(statement) |
54 | 76 |
|
55 | 77 |
|
56 | 78 | if __name__ == '__main__': |
|
0 commit comments