As noted in #916, when using the ResNet model with multiple GPUs, it appears the activation scales are not handled correctly. For example, running the following (Brevitas version: several, including current dev; Torch version: <1.13.1|2.7.1>):
BREVITAS_JIT=1 python bnn_pynq_train.py --network RESNET18_4W4A --gpus 0,1 --epochs 2
During evaluation, there is a large discrepancy between the train and test accuracy (one that does not occur in the single-CPU case), and the test loss is NaN:
# Train log:
Epoch: [2][499/500] Time 0.780 (1.052) Data 0.003 (0.004) Loss 0.2890 (0.3099) Prec@1 42.000 (34.490) Prec@5 93.000 (85.882)
# Test log:
Test: [99/100] Model Time 1.618 (1.721) Loss Time 0.000 (0.000) Loss nan (nan) Prec@1 14.000 (10.000) Prec@5 55.000 (50.000)
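For triage purposes, one way to check whether the learned activation scales are behind the NaN test loss is to scan the trained model's state_dict for NaN scale values. A minimal sketch, assuming `model` is the trained QuantResNet instance; the key suffix simply mirrors the missing keys reported further down:

```python
import torch

# Hedged diagnostic sketch: look for NaN activation scales in the trained model.
# `model` is assumed to be the trained QuantResNet; the key suffix mirrors the
# scaling_impl.value keys named in the load error below.
for name, tensor in model.state_dict().items():
    if name.endswith("scaling_impl.value"):
        status = "NaN" if torch.isnan(tensor).any() else "ok"
        print(f"{name}: {status}")
```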
Furthermore, trying to evaluate the stored checkpoint as follows:
BREVITAS_JIT=1 python bnn_pynq_train.py --network RESNET18_4W4A --evaluate --resume ./checkpoints/resnet18_4w4a_ddp.tar --strict --gpus <0|0,1>
results in the following missing-key error:
RuntimeError: Error(s) in loading state_dict for QuantResNet:
    Missing key(s) in state_dict:
"relu.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer1.0.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer1.0.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer1.0.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer1.1.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer1.1.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer1.1.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer2.0.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer2.0.downsample.2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer2.0.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer2.0.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer2.1.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer2.1.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer2.1.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer3.0.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer3.0.downsample.2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer3.0.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer3.0.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer3.1.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer3.1.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer3.1.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer4.0.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer4.0.downsample.2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer4.0.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer4.0.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer4.1.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer4.1.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
"layer4.1.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value".
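For what it's worth, the missing keys can be confirmed directly from the checkpoint file, which also rules out the usual `module.`-prefix mismatch introduced by nn.DataParallel. A rough sketch, assuming the checkpoint keeps the weights under a `state_dict` entry (adjust if the script stores them differently):

```python
import torch

# Hedged sketch: check whether the scaling_impl.value entries are truly absent
# from the saved checkpoint, or merely renamed (e.g. with a "module." prefix
# added by nn.DataParallel). Assumes a "state_dict" entry in the checkpoint;
# falls back to the raw object otherwise.
ckpt = torch.load("./checkpoints/resnet18_4w4a_ddp.tar", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt

scale_keys = [k for k in state_dict if k.endswith("scaling_impl.value")]
prefixed_keys = [k for k in state_dict if k.startswith("module.")]
print(f"scaling_impl.value keys in checkpoint: {len(scale_keys)}")
print(f"keys with a 'module.' prefix: {len(prefixed_keys)}")
```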
Some other considerations:
Could this be related to how the training script parallelises the model across GPUs, e.g. via nn.DataParallel? (See the sketch below.)
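One possible explanation (just a guess on my side) is the caveat documented for nn.DataParallel: the module is re-replicated on every forward call, so any Python-level state change made inside forward, such as lazily creating or switching to a learned scale parameter after collecting runtime statistics, happens on the throwaway replicas and never reaches the base module. A toy sketch of that behaviour (not Brevitas code, and it assumes two visible GPUs):

```python
import torch
import torch.nn as nn

# Toy module illustrating the nn.DataParallel caveat: Python attributes
# mutated inside forward are changed on the per-device replicas, which are
# discarded after the call, so the base module never sees the update.
class LazyInit(nn.Module):
    def __init__(self):
        super().__init__()
        self.initialized = False  # e.g. "runtime stats collected, scale created"

    def forward(self, x):
        if not self.initialized:
            self.initialized = True  # flipped on the replica only
        return x

model = LazyInit().cuda()
dp = nn.DataParallel(model, device_ids=[0, 1])
dp(torch.randn(8, 4).cuda())
print(model.initialized)  # False: the flag was flipped on the replicas, not on `model`
```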