|
21 | 21 | num_ps = 1 |
22 | 22 |
|
23 | 23 | parser = argparse.ArgumentParser() |
24 | | -parser.add_argument("-b", "--batch_size", help="number of records per batch", type=int, default=100) |
25 | | -parser.add_argument("-e", "--epochs", help="number of epochs", type=int, default=0) |
26 | | -parser.add_argument("-f", "--format", help="example format: (csv2|tfr)", choices=["csv2", "tfr"], default="tfr") |
27 | | -parser.add_argument("-i", "--images", help="HDFS path to MNIST images in parallelized format") |
28 | | -parser.add_argument("-l", "--labels", help="HDFS path to MNIST labels in parallelized format") |
29 | | -parser.add_argument("-m", "--model", help="HDFS path to save/load model during train/test", default="mnist_model") |
30 | | -parser.add_argument("-n", "--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors) |
31 | | -parser.add_argument("-o", "--output", help="HDFS path to save test/inference output", default="predictions") |
32 | | -parser.add_argument("-r", "--readers", help="number of reader/enqueue threads", type=int, default=1) |
33 | | -parser.add_argument("-s", "--steps", help="maximum number of steps", type=int, default=1000) |
34 | | -parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true") |
35 | | -parser.add_argument("-X", "--mode", help="train|inference", default="train") |
36 | | -parser.add_argument("-c", "--rdma", help="use rdma connection", default=False) |
37 | | -parser.add_argument("-p", "--driver_ps_nodes", help="""run tensorflow PS node on driver locally. |
| 24 | +parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100) |
| 25 | +parser.add_argument("--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors) |
| 26 | +parser.add_argument("--driver_ps_nodes", help="""run tensorflow PS node on driver locally. |
38 | 27 | You will need to set cluster_size = num_executors + num_ps""", default=False) |
| 28 | +parser.add_argument("--epochs", help="number of epochs", type=int, default=1) |
| 29 | +parser.add_argument("--format", help="example format: (csv2|tfr)", choices=["csv2", "tfr"], default="tfr") |
| 30 | +parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format") |
| 31 | +parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format") |
| 32 | +parser.add_argument("--mode", help="train|inference", default="train") |
| 33 | +parser.add_argument("--model", help="HDFS path to save/load model during train/test", default="mnist_model") |
| 34 | +parser.add_argument("--num_ps", help="number of ps nodes", default=1) |
| 35 | +parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions") |
| 36 | +parser.add_argument("--rdma", help="use rdma connection", default=False) |
| 37 | +parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1) |
| 38 | +parser.add_argument("--shuffle_size", help="size of shuffle buffer", type=int, default=1000) |
| 39 | +parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000) |
| 40 | +parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true") |
39 | 41 | args = parser.parse_args() |
40 | 42 | print("args:", args) |
41 | 43 |
|
42 | 44 |
|
43 | 45 | print("{0} ===== Start".format(datetime.now().isoformat())) |
44 | | -cluster = TFCluster.run(sc, mnist_dist_dataset.map_fun, args, args.cluster_size, num_ps, args.tensorboard, |
| 46 | +cluster = TFCluster.run(sc, mnist_dist_dataset.map_fun, args, args.cluster_size, args.num_ps, args.tensorboard, |
45 | 47 | TFCluster.InputMode.TENSORFLOW, driver_ps_nodes=args.driver_ps_nodes) |
46 | 48 | cluster.shutdown() |
47 | 49 |
|
|
0 commit comments