|
26 | 26 | parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100) |
27 | 27 | parser.add_argument("--epochs", help="number of epochs", type=int, default=1) |
28 | 28 | parser.add_argument("--export_dir", help="HDFS path to export saved_model", default="mnist_export") |
29 | | -parser.add_argument("--format", help="example format: (csv|pickle|tfr)", choices=["csv", "pickle", "tfr"], default="csv") |
| 29 | +parser.add_argument("--format", help="example format: (csv|tfr)", choices=["csv", "tfr"], default="csv") |
30 | 30 | parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format") |
31 | 31 | parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format") |
32 | 32 | parser.add_argument("--model", help="HDFS path to save/load model during train/inference", default="mnist_model") |
@@ -56,22 +56,22 @@ def toNumpy(bytestr): |
56 | 56 | return (image, label) |
57 | 57 |
|
58 | 58 | dataRDD = images.map(lambda x: toNumpy(bytes(x[0]))) |
59 | | -else: |
60 | | - if args.format == "csv": |
61 | | - images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')]) |
62 | | - labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')]) |
63 | | - else: # args.format == "pickle": |
64 | | - images = sc.pickleFile(args.images) |
65 | | - labels = sc.pickleFile(args.labels) |
| 59 | +else: # "csv" |
66 | 60 | print("zipping images and labels") |
| 61 | + # If partitions of images/labels don't match, you can use the following code: |
| 62 | + # images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')]).zipWithIndex().map(lambda x: (x[1], x[0])) |
| 63 | + # labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')]).zipWithIndex().map(lambda x: (x[1], x[0])) |
| 64 | + # dataRDD = images.join(labels).map(lambda x: (x[1][0], x[1][1])) |
| 65 | + images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')]) |
| 66 | + labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')]) |
67 | 67 | dataRDD = images.zip(labels) |
68 | 68 |
|
69 | 69 | cluster = TFCluster.run(sc, mnist_dist.map_fun, args, args.cluster_size, num_ps, args.tensorboard, TFCluster.InputMode.SPARK, log_dir=args.model) |
70 | 70 | if args.mode == "train": |
71 | 71 | cluster.train(dataRDD, args.epochs) |
72 | | -else: |
73 | | - labelRDD = cluster.inference(dataRDD) |
74 | | - labelRDD.saveAsTextFile(args.output) |
| 72 | +else: # inference |
| 73 | + predRDD = cluster.inference(dataRDD) |
| 74 | + predRDD.saveAsTextFile(args.output) |
75 | 75 |
|
76 | 76 | cluster.shutdown(grace_secs=30) |
77 | 77 |
|
|
0 commit comments