Original Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dist_test/python/mnist_replica.py

Notes:
- This assumes that you have already [installed Spark, TensorFlow, and TensorFlowOnSpark](https://github.com/yahoo/TensorFlowOnSpark/wiki/GetStarted_Standalone).
- This code has been heavily modified to support different input formats (CSV and TFRecords) and different data ingestion methods (`InputMode.TENSORFLOW` and `InputMode.SPARK`).
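
The commands below assume that `${SPARK_HOME}` and `${TFoS_HOME}` point to your Spark and TensorFlowOnSpark installations, and that `${MASTER}` (exported when the Standalone cluster is started below) refers to your Spark master. A minimal sketch, with placeholder paths for your own setup:

```
export SPARK_HOME=/path/to/spark              # placeholder: your Spark installation
export TFoS_HOME=/path/to/TensorFlowOnSpark   # placeholder: your TensorFlowOnSpark clone
```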

### Download MNIST data

```
mkdir ${TFoS_HOME}/mnist
pushd ${TFoS_HOME}/mnist
curl -O "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"
curl -O "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"
curl -O "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"
curl -O "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"
popd
```
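
If the downloads succeed, the four gzip archives should now be in `${TFoS_HOME}/mnist`; a quick sanity check:

```
ls -l ${TFoS_HOME}/mnist/*.gz
```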

### Convert the MNIST zip files using Spark

```
cd ${TFoS_HOME}
# rm -rf examples/mnist/csv
${SPARK_HOME}/bin/spark-submit \
--master ${MASTER} \
${TFoS_HOME}/examples/mnist/mnist_data_setup.py \
--output examples/mnist/csv \
--format csv
ls -lR examples/mnist/csv
```
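
The notes above also mention TFRecords as a supported input format. Assuming this version of `mnist_data_setup.py` accepts `tfr` as a `--format` value (check its `--help` output if unsure), the TFRecord conversion would look like:

```
# Assumption: "tfr" is an accepted --format value in this version of the script.
# rm -rf examples/mnist/tfr
${SPARK_HOME}/bin/spark-submit \
--master ${MASTER} \
${TFoS_HOME}/examples/mnist/mnist_data_setup.py \
--output examples/mnist/tfr \
--format tfr
ls -lR examples/mnist/tfr
```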

### Start Spark Standalone Cluster

```
export MASTER=spark://$(hostname):7077
export SPARK_WORKER_INSTANCES=2
export CORES_PER_WORKER=1
export TOTAL_CORES=$((${CORES_PER_WORKER}*${SPARK_WORKER_INSTANCES}))
${SPARK_HOME}/sbin/start-master.sh; ${SPARK_HOME}/sbin/start-slave.sh -c $CORES_PER_WORKER -m 3G ${MASTER}
```
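
Before submitting jobs, it is worth confirming that the master and both workers are running. A quick check, assuming a JDK (for `jps`) on the PATH; alternatively, open the master web UI (http://localhost:8080 by default) and look for the registered workers:

```
jps | grep -E "Master|Worker"    # should list one Master and two Worker processes
```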

### Run distributed MNIST training using `InputMode.SPARK`

```
# rm -rf mnist_model
${SPARK_HOME}/bin/spark-submit \
--master ${MASTER} \
--py-files ${TFoS_HOME}/examples/mnist/spark/mnist_dist.py \
--conf spark.cores.max=${TOTAL_CORES} \
--conf spark.task.cpus=${CORES_PER_WORKER} \
--conf spark.executorEnv.JAVA_HOME="$JAVA_HOME" \
${TFoS_HOME}/examples/mnist/spark/mnist_spark.py \
--cluster_size ${SPARK_WORKER_INSTANCES} \
--images examples/mnist/csv/train/images \
--labels examples/mnist/csv/train/labels \
--format csv \
--mode train \
--model mnist_model

ls -l mnist_model
```
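
If the training job hangs or fails, the TensorFlow worker output is easiest to find in the Spark executor logs, which the Standalone workers write under `${SPARK_HOME}/work` by default:

```
# Newest application directory first; each contains per-executor stderr/stdout logs.
ls -lt ${SPARK_HOME}/work
```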

### Run distributed MNIST inference using `InputMode.SPARK`

```
# rm -rf predictions
${SPARK_HOME}/bin/spark-submit \
--master ${MASTER} \
--py-files ${TFoS_HOME}/examples/mnist/spark/mnist_dist.py \
--conf spark.cores.max=${TOTAL_CORES} \
--conf spark.task.cpus=${CORES_PER_WORKER} \
--conf spark.executorEnv.JAVA_HOME="$JAVA_HOME" \
${TFoS_HOME}/examples/mnist/spark/mnist_spark.py \
--cluster_size ${SPARK_WORKER_INSTANCES} \
--images examples/mnist/csv/test/images \
--labels examples/mnist/csv/test/labels \
--mode inference \
--format csv \
--model mnist_model \
--output predictions

less predictions/part-00000
```

The prediction result should look like:
```
2017-02-10T23:29:17.009563 Label: 7, Prediction: 7
2017-02-10T23:29:17.009677 Label: 2, Prediction: 2
2017-02-10T23:29:17.009721 Label: 1, Prediction: 1
2017-02-10T23:29:17.009761 Label: 0, Prediction: 0
2017-02-10T23:29:17.009799 Label: 4, Prediction: 4
2017-02-10T23:29:17.009838 Label: 1, Prediction: 1
2017-02-10T23:29:17.009876 Label: 4, Prediction: 4
2017-02-10T23:29:17.009914 Label: 9, Prediction: 9
2017-02-10T23:29:17.009951 Label: 5, Prediction: 6
2017-02-10T23:29:17.009989 Label: 9, Prediction: 9
2017-02-10T23:29:17.010026 Label: 0, Prediction: 0
```
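
Since each output line carries both the label and the prediction, a rough test-set accuracy can be computed directly from the prediction files (Spark writes them as `part-*`); a small awk sketch:

```
# Fields per line: <timestamp> Label: <label>, Prediction: <prediction>
# "$3+0" strips the trailing comma from the label before comparing.
awk '{ total++; if ($3+0 == $5+0) correct++ }
     END { printf "accuracy: %d/%d = %.4f\n", correct, total, correct/total }' predictions/part-*
```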

### Shut down the Spark Standalone Cluster

```
${SPARK_HOME}/sbin/stop-slave.sh; ${SPARK_HOME}/sbin/stop-master.sh
```