All code runs on Python 3.6.7 using PyTorch version 1.1.0.
In addition, you will need to install:
- torchvision
- torchtext
- numpy
- pandas
- Download the data from link and extract it to the current directory. This gives you two files: train.csv and test.csv.
- Modify the data path in process_text.py and execute process_text.py.
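Before running process_text.py, it can help to confirm that the extraction produced both files. The following is a minimal sketch, not part of the repository; the helper name `missing_files` is hypothetical, and it only assumes that train.csv and test.csv sit directly in the directory you pass in.

```python
from pathlib import Path


def missing_files(data_dir=".", expected=("train.csv", "test.csv")):
    """Return the expected file names that are absent from data_dir.

    Hypothetical helper: the file names come from the step above,
    everything else is an assumption for illustration.
    """
    root = Path(data_dir)
    return [name for name in expected if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_files()
    if missing:
        print("Missing after extraction:", ", ".join(missing))
    else:
        print("Both CSV files are present.")
```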
- Download the data from link and extract it to the current directory.
- Modify the data path in process_tiny_magenet.py and execute process_tiny_magenet.py.
There are two main scripts:
- train.sh for training using S-SGD, Local SGD and VRL-SGD.
- plot_all.sh for plotting figures.
The following arguments are available:
- --lr: learning rate.
- --model: model name. Options: lenet5, text_cnn, mlp.
- --data-set: dataset name. Options: mnist, DB_Pedia, tiny_imagenet.
- --epochs: the number of epochs to run.
- --gpu-num: the number of GPUs.
- --batch-size: batch size for each machine.
- -r: resume the training.
- --local: whether to communicate periodically.
- --period: the communication period. If --local is not set, it is always 1.
- --cluster-data: each worker only accesses a subset of the data.
- --vrl: whether to execute the VRL-SGD algorithm.
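The flags described above could be declared with argparse roughly as follows. This is a sketch under stated assumptions, not the parser used in main.py: the defaults, the long name `--resume`, and the helper `effective_period` are hypothetical. It mainly illustrates the rule that the communication period is 1 unless --local is set.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented flags (defaults are assumed)."""
    p = argparse.ArgumentParser(description="Sketch of the documented arguments")
    p.add_argument("--lr", type=float, default=0.01, help="learning rate")
    p.add_argument("--model", choices=["lenet5", "text_cnn", "mlp"], default="lenet5")
    p.add_argument("--data-set", choices=["mnist", "DB_Pedia", "tiny_imagenet"],
                   default="mnist")
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--gpu-num", type=int, default=1)
    p.add_argument("--batch-size", type=int, default=32)
    # "--resume" is an assumed long form; only -r is documented.
    p.add_argument("-r", "--resume", action="store_true")
    p.add_argument("--local", action="store_true")
    p.add_argument("--period", type=int, default=1)
    p.add_argument("--cluster-data", action="store_true")
    p.add_argument("--vrl", action="store_true")
    return p


def effective_period(args):
    """Communication period actually used: 1 unless --local is set."""
    return args.period if args.local else 1


if __name__ == "__main__":
    args = build_parser().parse_args(["--local", "--period", "20", "--vrl"])
    print(effective_period(args))
```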
We recommend running 2 epochs of SGD to initialize the weights; without this initialization, the -r parameter cannot be used. After the initialization is completed, rename the saved file, for example from lenet5.pth to lenet5_init.pth.
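The renaming step can be scripted. The sketch below assumes only the naming pattern shown in the example (lenet5.pth becomes lenet5_init.pth); the helper name `mark_as_init` is hypothetical, and where the checkpoint is saved depends on your setup.

```python
from pathlib import Path


def mark_as_init(checkpoint):
    """Rename e.g. lenet5.pth to lenet5_init.pth and return the new path."""
    src = Path(checkpoint)
    dst = src.with_name(f"{src.stem}_init{src.suffix}")
    if not src.is_file():
        raise FileNotFoundError(f"checkpoint not found: {src}")
    if dst.exists():
        # Refuse to overwrite an existing initialization file.
        raise FileExistsError(f"target already exists: {dst}")
    src.rename(dst)
    return dst


if __name__ == "__main__":
    print(mark_as_init("lenet5.pth"))
```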
# S-SGD
python main.py --lr 0.005 --model lenet5 --dataset mnist --epochs 100 --st 0 -s 1 --gpu-num 8 -r --port 6632 --cluster-data
# Local-SGD
python main.py --lr 0.005 --model lenet5 --dataset mnist --epochs 100 --st 0 -s 1 --gpu-num 8 -r --port 6633 --cluster-data --local --period 20
# VRL-SGD
python main.py --lr 0.005 --model lenet5 --dataset mnist --epochs 100 --st 0 -s 1 --gpu-num 8 -r --port 6634 --cluster-data --local --period 20 --vrl
# S-SGD
python main.py --lr 0.005 --model lenet5 --dataset mnist --epochs 100 --st 0 -s 1 --gpu-num 8 -r --port 6632
# Local-SGD
python main.py --lr 0.005 --model lenet5 --dataset mnist --epochs 100 --st 0 -s 1 --gpu-num 8 -r --port 6633 --local --period 20
# VRL-SGD
python main.py --lr 0.005 --model lenet5 --dataset mnist --epochs 100 --st 0 -s 1 --gpu-num 8 -r --port 6634 --local --period 20 --vrl
# S-SGD
python main.py --lr 0.01 --model text_cnn --dataset DB_Pedia --epochs 100 --st 0 -s 1 --gpu-num 8 --port 6632 --batch-size 512 -r --cluster-data
# Local-SGD
python main.py --lr 0.01 --model text_cnn --dataset DB_Pedia --epochs 100 --st 0 -s 1 --gpu-num 8 --port 6632 --batch-size 512 -r --cluster-data --local --period 50
# VRL-SGD
python main.py --lr 0.01 --model text_cnn --dataset DB_Pedia --epochs 100 --st 0 -s 1 --gpu-num 8 --port 6632 --batch-size 512 -r --cluster-data --local --period 50 --vrl
# S-SGD
python main.py --lr 0.01 --model text_cnn --dataset DB_Pedia --epochs 100 --st 0 -s 1 --gpu-num 8 --port 6632 --batch-size 512 -r
# Local-SGD
python main.py --lr 0.01 --model text_cnn --dataset DB_Pedia --epochs 100 --st 0 -s 1 --gpu-num 8 --port 6632 --batch-size 512 -r --local --period 50
# VRL-SGD
python main.py --lr 0.01 --model text_cnn --dataset DB_Pedia --epochs 100 --st 0 -s 1 --gpu-num 8 --port 6632 --batch-size 512 -r --local --period 50 --vrl
# S-SGD
python main.py --lr 0.025 --model mlp --dataset tiny_imagenet --epochs 300 -s 1 --gpu-num 8 --port 6632 --batch-size 256 -r --cluster-data
# Local-SGD
python main.py --lr 0.025 --model mlp --dataset tiny_imagenet --epochs 300 -s 1 --gpu-num 8 --port 6633 --batch-size 256 -r --local --period 20 --cluster-data
# VRL-SGD
python main.py --lr 0.025 --model mlp --dataset tiny_imagenet --epochs 300 -s 1 --gpu-num 8 --port 6634 --batch-size 256 -r --local --period 20 --vrl --cluster-data
# S-SGD
python main.py --lr 0.025 --model mlp --dataset tiny_imagenet --epochs 300 -s 1 --gpu-num 8 --port 6632 --batch-size 256 -r
# Local-SGD
python main.py --lr 0.025 --model mlp --dataset tiny_imagenet --epochs 300 -s 1 --gpu-num 8 --port 6633 --batch-size 256 -r --local --period 20
# VRL-SGD
python main.py --lr 0.025 --model mlp --dataset tiny_imagenet --epochs 300 -s 1 --gpu-num 8 --port 6634 --batch-size 256 -r --local --period 20 --vrl
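Within each group of commands above, the three algorithms differ only in a few trailing flags: S-SGD adds none, Local SGD adds --local --period, and VRL-SGD additionally adds --vrl. The sketch below captures that pattern; the helper `variant_flags` is hypothetical and the base command is copied from the last group, so treat it as an illustration rather than a replacement for train.sh.

```python
def variant_flags(algorithm, period=20):
    """Return the extra flags that select an algorithm, as seen in the commands above."""
    if algorithm == "S-SGD":
        return []
    if algorithm == "Local-SGD":
        return ["--local", "--period", str(period)]
    if algorithm == "VRL-SGD":
        return ["--local", "--period", str(period), "--vrl"]
    raise ValueError(f"unknown algorithm: {algorithm}")


if __name__ == "__main__":
    # Base command taken from the tiny_imagenet examples (port omitted here).
    base = ["python", "main.py", "--lr", "0.025", "--model", "mlp",
            "--dataset", "tiny_imagenet", "--epochs", "300", "-s", "1",
            "--gpu-num", "8", "--batch-size", "256", "-r"]
    for name in ("S-SGD", "Local-SGD", "VRL-SGD"):
        print("# " + name)
        print(" ".join(base + variant_flags(name)))
```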