
Activation scales not stored in state_dict when using multiple GPUs for ResNet model in BNN-PYNQ example #1349

@nickfraser

Description


As noted in #916, when the ResNet model is trained on multiple GPUs, the activation scales do not appear to be handled correctly. For example, run the following (Brevitas version: several, including current dev; Torch version: <1.13.1|2.7.1>):

BREVITAS_JIT=1 python bnn_pynq_train.py --network RESNET18_4W4A --gpus 0,1 --epochs 2

During evaluation, there is a large gap between train and test accuracy that does not occur in the single-GPU case:

# Train log:
Epoch: [2][499/500]     Time 0.780 (1.052)      Data 0.003 (0.004)      Loss 0.2890 (0.3099)    Prec@1 42.000 (34.490)  Prec@5 93.000 (85.882)
# Test log:
Test: [99/100]  Model Time 1.618 (1.721)        Loss Time 0.000 (0.000) Loss nan (nan)  Prec@1 14.000 (10.000)  Prec@5 55.000 (50.000)

Furthermore, evaluating the stored checkpoint as follows:

BREVITAS_JIT=1 python bnn_pynq_train.py --network RESNET18_4W4A --evaluate --resume ./checkpoints/resnet18_4w4a_ddp.tar --strict --gpus <0|0,1>

fails with the following missing-key error:

RuntimeError: Error(s) in loading state_dict for QuantResNet:
        Missing key(s) in state_dict:
          "relu.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer1.0.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer1.0.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer1.0.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer1.1.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer1.1.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer1.1.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer2.0.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer2.0.downsample.2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer2.0.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer2.0.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer2.1.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer2.1.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer2.1.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer3.0.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer3.0.downsample.2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer3.0.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer3.0.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer3.1.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer3.1.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer3.1.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer4.0.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer4.0.downsample.2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer4.0.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer4.0.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer4.1.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer4.1.relu2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
          "layer4.1.relu_out.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value".
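To check which activation scales a saved checkpoint lacks without going through `load_state_dict`, a small key-diff helper can help. This is a minimal sketch under stated assumptions: `missing_scale_keys` and `strip_dp_prefix` are hypothetical helpers, the `scaling_impl.value` suffix is taken from the error above, and the `module.` prefix is the one nn.DataParallel adds to state_dict keys because it stores the wrapped network as `self.module`.

```python
"""Hypothetical helper: report which activation-scale keys a checkpoint lacks."""
from typing import Iterable, List

SCALE_SUFFIX = "scaling_impl.value"  # suffix taken from the missing-key error above
DP_PREFIX = "module."  # prefix nn.DataParallel adds to state_dict keys


def strip_dp_prefix(key: str) -> str:
    """Remove a leading 'module.' so DataParallel and plain keys compare equal."""
    return key[len(DP_PREFIX):] if key.startswith(DP_PREFIX) else key


def missing_scale_keys(model_keys: Iterable[str], ckpt_keys: Iterable[str]) -> List[str]:
    """Return the scale keys the model expects but the checkpoint does not contain."""
    expected = {strip_dp_prefix(k) for k in model_keys if k.endswith(SCALE_SUFFIX)}
    present = {strip_dp_prefix(k) for k in ckpt_keys}
    return sorted(expected - present)


if __name__ == "__main__":
    # Toy key lists standing in for model.state_dict().keys() and the checkpoint keys.
    model_keys = [
        "conv1.weight",
        "relu.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
        "layer1.0.relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value",
    ]
    ckpt_keys = ["module.conv1.weight"]
    for key in missing_scale_keys(model_keys, ckpt_keys):
        print("missing:", key)
```

With real files, `model_keys` could come from `model.state_dict().keys()` on a freshly built QuantResNet and `ckpt_keys` from the object returned by `torch.load(path, map_location="cpu")`; if the checkpoint is a dict holding the weights under a `"state_dict"` entry (an assumption about the trainer's checkpoint format), pass that entry's keys instead.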

Some other considerations:

  • Does this affect every use of "param_from_stats" with nn.DataParallel?
  • Should we move to DistributedDataParallel (DDP) instead, since it now seems to be the recommended approach?
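On the first question, one plausible explanation (not confirmed in this issue) is the caveat in the PyTorch nn.DataParallel documentation: the module is replicated on every forward, so updates made to the running module during forward, such as incrementing a counter attribute, are lost when the replicas are discarded. If the runtime-stats scaling behind "param_from_stats" tracks its collection steps in a plain attribute and only exports `value` once that counter is non-zero (our assumption about Brevitas, not verified here), the original module would never leave its initial state, which would match the missing keys above. The toy below is pure Python and does not use PyTorch: `ToyStatsScaling`, `replicate` and `train` are hypothetical, and `copy.copy` is only a simplified stand-in for DataParallel's replication.

```python
"""Toy, pure-Python simulation of lazily initialised state being lost on replicas.

Hypothetical names throughout; copy.copy only approximates DataParallel's
per-forward replication (real replicas receive broadcast tensors).
"""
import copy


class ToyStatsScaling:
    """Collects a running max for a few steps, then uses it as the scale."""

    def __init__(self, collect_stats_steps: int = 3):
        self.collect_stats_steps = collect_stats_steps
        self.counter = 0  # plain Python int, like a non-tensor module attribute
        self.value = 1.0  # placeholder scale before statistics are collected
        self.running_max = 0.0

    def forward(self, x: float) -> float:
        if self.counter < self.collect_stats_steps:
            self.running_max = max(self.running_max, abs(x))
            self.counter += 1  # on a replica, this rebinds the copy's attribute only
            if self.counter == self.collect_stats_steps:
                self.value = self.running_max
        return x / self.value

    def state_dict(self) -> dict:
        # Assumed behaviour: omit 'value' until at least one stats step has run.
        return {} if self.counter == 0 else {"value": self.value}


def replicate(module):
    # Stand-in for DataParallel: each forward runs on a fresh shallow copy,
    # which is thrown away afterwards.
    return copy.copy(module)


def train(module, inputs, data_parallel: bool) -> dict:
    for x in inputs:
        target = replicate(module) if data_parallel else module
        target.forward(x)
    # The checkpoint is always taken from the original module.
    return module.state_dict()


if __name__ == "__main__":
    data = [0.5, 2.0, 1.5, 0.7]
    print("single device:", train(ToyStatsScaling(), data, data_parallel=False))
    print("data parallel:", train(ToyStatsScaling(), data, data_parallel=True))
```

Running it prints `{'value': 2.0}` for the single-device case and `{}` for the replicated one. If this mechanism is the cause, it would affect any module that mutates non-tensor attributes during forward under nn.DataParallel, not only ResNet. With DDP, each process runs forward on its own full copy of the model, so such updates would persist in that process's module.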
