diff --git a/openfl-tutorials/experimental/workflow/403_Federated_FedProx_PyTorch_MNIST_Workflow_Tutorial.ipynb b/openfl-tutorials/experimental/workflow/403_Federated_FedProx_PyTorch_MNIST_Workflow_Tutorial.ipynb
index 349832dc89..45410f8719 100644
--- a/openfl-tutorials/experimental/workflow/403_Federated_FedProx_PyTorch_MNIST_Workflow_Tutorial.ipynb
+++ b/openfl-tutorials/experimental/workflow/403_Federated_FedProx_PyTorch_MNIST_Workflow_Tutorial.ipynb
@@ -47,7 +47,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -80,7 +80,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -126,7 +126,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -177,7 +177,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -210,22 +210,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Aggregator step \"start\" registered\n",
-      "Collaborator step \"aggregated_model_validation\" registered\n",
-      "Collaborator step \"train\" registered\n",
-      "Collaborator step \"local_model_validation\" registered\n",
-      "Aggregator step \"join\" registered\n",
-      "Aggregator step \"end\" registered\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "class FederatedFlow(FLSpec):\n",
     "    def __init__(self, model=None, optimizer=None, rounds=10, **kwargs):\n",
@@ -273,13 +260,18 @@
     "\n",
     "        self.model.train()\n",
     "        self.optimizer = get_optimizer(self.model)\n",
+    "        \n",
+    "        # Set old weights ONCE at the beginning of training.\n",
+    "        # This sets the reference weights to the global model weights\n",
+    "        # received from the aggregator, implementing FedProx correctly.\n",
+    "        self.optimizer.set_old_weights([p.clone().detach() for p in self.model.parameters()])\n",
+    "        \n",
     "        for batch_idx, (data, target) in enumerate(self.train_loader):\n",
     "            self.optimizer.zero_grad()\n",
     "            output = self.model(data)\n",
     "            loss = F.cross_entropy(output, target)\n",
     "            loss.backward()\n",
     "            \n",
-    "            self.optimizer.set_old_weights([p.clone().detach() for p in self.model.parameters()])\n",
     "            self.optimizer.step()\n",
     "\n",
     "            if (len(data) * batch_idx) / len(self.train_loader.dataset) >= log_threshold:\n",
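The hunk above is the core fix: calling set_old_weights() inside the batch loop re-anchored the proximal term to the weights from the immediately preceding step, so mu * (w - w_old) stayed near zero and FedProx degenerated to plain SGD. A minimal sketch of the corrected per-round pattern (names follow the notebook; the optimizer hyperparameters are illustrative assumptions):

```python
# Sketch of the corrected FedProx round: the proximal anchor is captured once,
# from the global weights the collaborator received, not once per batch.
import torch.nn.functional as F

def train_one_round(model, train_loader, get_optimizer):
    model.train()
    optimizer = get_optimizer(model)  # assumed to return a FedProxOptimizer
    # Anchor w_old to the incoming global weights, once per round.
    optimizer.set_old_weights([p.clone().detach() for p in model.parameters()])
    for data, target in train_loader:
        optimizer.zero_grad()
        loss = F.cross_entropy(model(data), target)
        loss.backward()
        optimizer.step()  # gradient is augmented with mu * (w - w_old)
    return loss.item()
```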
@@ -335,266 +327,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    [~250 lines of captured stdout trimmed: three federated rounds of
-     aggregated-model validation, training, and local validation across
-     collaborators 0-3. The per-round "join" summaries report average
-     aggregated model accuracy 0.0752 / 0.7958 / 0.8924, average training
-     loss 0.8530 / 0.4684 / 0.2507, and average local model validation
-     accuracy 0.7826 / 0.8754 / 0.9265, ending with "Flow ended".]
-   ],
+   "outputs": [],
    "source": [
     "model = Net()\n",
     "flflow = FederatedFlow(model, get_optimizer(model), rounds=3, checkpoint=False)\n",
@@ -605,7 +340,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "venv",
+   "display_name": "env_name",
    "language": "python",
    "name": "python3"
   },
@@ -619,7 +354,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.10"
+   "version": "3.11.12"
   }
  },
 "nbformat": 4,
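The notebook builds its optimizer through a get_optimizer(model) helper defined in an earlier cell that this diff does not touch. For context, a plausible minimal shape of that helper, under the assumption that it constructs a FedProxOptimizer (the lr and mu values are illustrative, not taken from the notebook):

```python
# Hypothetical reconstruction of the notebook's get_optimizer helper.
from openfl.utilities.optimizers.torch.fedprox import FedProxOptimizer

def get_optimizer(model):
    # mu > 0 enables the proximal term; set_old_weights() must then be
    # called before the first step(), or the optimizer raises ValueError.
    return FedProxOptimizer(model.parameters(), lr=0.01, mu=0.1)
```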
diff --git a/openfl/utilities/optimizers/torch/fedprox.py b/openfl/utilities/optimizers/torch/fedprox.py
old mode 100644
new mode 100755
index 2475260e32..d86ffec453
--- a/openfl/utilities/optimizers/torch/fedprox.py
+++ b/openfl/utilities/optimizers/torch/fedprox.py
@@ -20,6 +20,14 @@ class FedProxOptimizer(Optimizer):
     It introduces a proximal term to the federated averaging algorithm to reduce
     the impact of devices with outlying updates.
 
+    IMPORTANT: This optimizer requires a reference to the original (global) model parameters
+    to calculate the proximal term. These must be set explicitly using the set_old_weights()
+    method before training begins. The old weights (w_old) must match the order and structure
+    of the model's parameters. Typically, w_old should be set to the initial global model
+    parameters received from the aggregator at the beginning of each round.
+
+    If mu > 0 and w_old is not set, the optimizer will raise a ValueError.
+
     Paper: https://arxiv.org/pdf/1812.06127.pdf
 
     Attributes:
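A minimal sketch of the contract documented above: with mu > 0, step() fails until set_old_weights() has been called (model shape and hyperparameters are illustrative):

```python
import torch
from openfl.utilities.optimizers.torch.fedprox import FedProxOptimizer

model = torch.nn.Linear(4, 2)
opt = FedProxOptimizer(model.parameters(), lr=0.1, mu=0.01)

model(torch.randn(8, 4)).sum().backward()
try:
    opt.step()  # w_old has not been set yet
except ValueError as e:
    print(e)  # "FedProx requires old weights to be set when mu > 0. ..."

opt.set_old_weights([p.clone().detach() for p in model.parameters()])
opt.step()  # now applies d_p += mu * (w - w_old) before the SGD update
```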
@@ -67,7 +75,14 @@ def __init__(
         if weight_decay < 0.0:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
         if mu < 0.0:
-            raise ValueError(f"Invalid mu value: {mu}")
+            import warnings
+
+            warnings.warn(
+                f"Negative mu value ({mu}) will cause the proximal term to reward "
+                f"deviations from global weights, which may be counterintuitive.",
+                UserWarning,
+                stacklevel=2,
+            )
         defaults = {
             "dampening": dampening,
             "lr": lr,
@@ -75,6 +90,7 @@ def __init__(
             "mu": mu,
             "nesterov": nesterov,
             "weight_decay": weight_decay,
+            "w_old": None,  # Set later via set_old_weights()
         }
 
         if nesterov and (momentum <= 0 or dampening != 0):
@@ -93,6 +109,47 @@ def __setstate__(self, state):
         for group in self.param_groups:
             group.setdefault("nesterov", False)
 
+    def _validate_old_weights(self, mu, w_old):
+        """Validate old weights for FedProx regularization.
+
+        Args:
+            mu: Proximal term coefficient.
+            w_old: Old weights reference.
+
+        Raises:
+            ValueError: If mu > 0 and w_old is None.
+        """
+        if mu > 0 and w_old is None:
+            raise ValueError(
+                "FedProx requires old weights to be set when mu > 0. "
+                "Please call set_old_weights() before the optimization step."
+            )
+
+    def _apply_momentum(self, p, d_p, momentum, dampening, nesterov):
+        """Apply momentum to the gradient.
+
+        Args:
+            p: Parameter.
+            d_p: Gradient.
+            momentum: Momentum factor.
+            dampening: Dampening factor.
+            nesterov: Whether to use Nesterov momentum.
+
+        Returns:
+            The momentum-adjusted gradient.
+        """
+        param_state = self.state[p]
+        if "momentum_buffer" not in param_state:
+            buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
+        else:
+            buf = param_state["momentum_buffer"]
+            buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+        if nesterov:
+            d_p = d_p.add(buf, alpha=momentum)
+        else:
+            d_p = buf
+        return d_p
+
     @torch.no_grad()
     def step(self, closure=None):
         """Perform a single optimization step.
@@ -115,25 +172,33 @@ def step(self, closure=None):
             nesterov = group["nesterov"]
             mu = group["mu"]
             w_old = group["w_old"]
-            for p, w_old_p in zip(group["params"], w_old):
+
+            # Validate old weights for FedProx
+            self._validate_old_weights(mu, w_old)
+
+            # The proximal term is applied whenever w_old is set
+            # (it is a no-op when mu == 0)
+            apply_proximal = w_old is not None
+
+            for i, p in enumerate(group["params"]):
                 if p.grad is None:
                     continue
+
                 d_p = p.grad
+
+                # Apply weight decay
                 if weight_decay != 0:
                     d_p = d_p.add(p, alpha=weight_decay)
+
+                # Apply momentum
                 if momentum != 0:
-                    param_state = self.state[p]
-                    if "momentum_buffer" not in param_state:
-                        buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
-                    else:
-                        buf = param_state["momentum_buffer"]
-                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
-                    if nesterov:
-                        d_p = d_p.add(buf, alpha=momentum)
-                    else:
-                        d_p = buf
-                if w_old is not None:
+                    d_p = self._apply_momentum(p, d_p, momentum, dampening, nesterov)
+
+                # Apply the proximal term mu * (w - w_old)
+                if apply_proximal:
+                    w_old_p = w_old[i]
                     d_p.add_(p - w_old_p, alpha=mu)
+
+                # Apply the gradient step
                 p.add_(d_p, alpha=-group["lr"])
 
         return loss
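A quick numeric check of the update step() now implements. With momentum and weight decay off, one step computes w <- w - lr * (grad + mu * (w - w_old)); the values below are illustrative:

```python
import torch
from openfl.utilities.optimizers.torch.fedprox import FedProxOptimizer

w = torch.nn.Parameter(torch.tensor([1.0]))
opt = FedProxOptimizer([w], lr=0.5, mu=0.2)  # momentum/weight decay left at defaults (0)
opt.set_old_weights([torch.tensor([0.0])])   # global anchor w_old = 0

(w * 3.0).sum().backward()  # grad = 3
opt.step()                  # w <- 1 - 0.5 * (3 + 0.2 * (1 - 0)) = -0.6
print(w.data)               # tensor([-0.6000])
```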
@@ -141,8 +206,19 @@ def step(self, closure=None):
     def set_old_weights(self, old_weights):
         """Set the global weights parameter to `old_weights` value.
 
+        This method must be called before training begins to set the reference point for
+        calculating the proximal term in FedProx. Typically, this should be set to the
+        initial global model parameters received from the aggregator at the beginning
+        of each federated learning round.
+
+        If mu > 0 and this method is not called, the optimizer will raise a ValueError
+        during the optimization step.
+
         Args:
-            old_weights: The old weights to be set.
+            old_weights: List of parameter tensors representing the global model weights.
+                Must match the order and structure of the model's parameters
+                being optimized (typically obtained by calling
+                [p.clone().detach() for p in model.parameters()]).
         """
         for param_group in self.param_groups:
             param_group["w_old"] = old_weights
@@ -153,12 +229,19 @@ class FedProxAdam(Optimizer):
 
     Implements the FedProx optimization algorithm with Adam optimizer.
 
+    IMPORTANT: This optimizer requires a reference to the original (global) model parameters
+    to calculate the proximal term. These must be set explicitly using the set_old_weights()
+    method before training begins. The old weights (w_old) must match the order and structure
+    of the model's parameters. Typically, w_old should be set to the initial global model
+    parameters received from the aggregator at the beginning of each round.
+
+    If mu > 0 and w_old is not set, the optimizer will raise a ValueError.
+
     Attributes:
         params: Parameters to be stored for optimization.
         mu: Proximal term coefficient.
         lr: Learning rate.
-        betas: Coefficients used for computing running averages of gradient
-            and its square.
+        betas: Coefficients used for computing running averages of gradient and its square.
         eps: Value for computational stability.
         weight_decay: Weight decay (L2 penalty).
         amsgrad: Whether to use the AMSGrad variant of this algorithm.
@@ -204,7 +287,14 @@ def __init__(
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
         if mu < 0.0:
-            raise ValueError(f"Invalid mu value: {mu}")
+            import warnings
+
+            warnings.warn(
+                f"Negative mu value ({mu}) will cause the proximal term to reward "
+                f"deviations from global weights, which may be counterintuitive.",
+                UserWarning,
+                stacklevel=2,
+            )
         defaults = {
             "lr": lr,
             "betas": betas,
@@ -212,6 +302,7 @@ def __init__(
             "weight_decay": weight_decay,
             "amsgrad": amsgrad,
             "mu": mu,
+            "w_old": None,  # Set later via set_old_weights()
         }
         super().__init__(params, defaults)
 
@@ -224,8 +315,19 @@ def __setstate__(self, state):
     def set_old_weights(self, old_weights):
         """Set the global weights parameter to `old_weights` value.
 
+        This method must be called before training begins to set the reference point for
+        calculating the proximal term in FedProx. Typically, this should be set to the
+        initial global model parameters received from the aggregator at the beginning
+        of each federated learning round.
+
+        If mu > 0 and this method is not called, the optimizer will raise a ValueError
+        during the optimization step.
+
         Args:
-            old_weights: The old weights to be set.
+            old_weights: List of parameter tensors representing the global model weights.
+                Must match the order and structure of the model's parameters
+                being optimized (typically obtained by calling
+                [p.clone().detach() for p in model.parameters()]).
         """
         for param_group in self.param_groups:
             param_group["w_old"] = old_weights
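The same set_old_weights() contract applies to FedProxAdam. A per-round usage sketch with illustrative hyperparameters:

```python
import torch
from openfl.utilities.optimizers.torch.fedprox import FedProxAdam

model = torch.nn.Linear(10, 2)
opt = FedProxAdam(model.parameters(), lr=1e-3, mu=0.1)

# At the start of each round, anchor the proximal term to the global weights:
opt.set_old_weights([p.clone().detach() for p in model.parameters()])
loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()  # Adam update on grad + mu * (w - w_old)
```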
" + "Please call set_old_weights() before optimization step.", + ) + + def _apply_proximal_term(self, grad, param, w_old_p, mu): + """Apply proximal term to gradient. + + Args: + grad: Gradient + param: Parameter + w_old_p: Old weight parameter + mu: Proximal term coefficient + + Returns: + Modified gradient + """ + return grad.add(param - w_old_p, alpha=mu) + + def _compute_adam_step( + self, + param, + grad, + exp_avg, + exp_avg_sq, + max_exp_avg_sq, + step, + amsgrad, + beta1, + beta2, + lr, + weight_decay, + eps, + ): + """Compute Adam optimization step. + + Args: + param: Parameter + grad: Gradient + exp_avg: Exponential moving average + exp_avg_sq: Exponential moving average squared + max_exp_avg_sq: Maximum exponential moving average squared + step: Step count + amsgrad: Whether to use AMSGrad + beta1: Beta1 coefficient + beta2: Beta2 coefficient + lr: Learning rate + weight_decay: Weight decay + eps: Epsilon value + """ + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + step_size = lr / bias_correction1 + param.addcdiv_(exp_avg, denom, value=-step_size) + def adam( self, params, @@ -348,31 +532,32 @@ def adam( mu (float): Proximal term coefficient. w_old: The old weights. """ + # Validate old weights for FedProx + self._validate_old_weights(mu, w_old) + + # Apply proximal term when mu != 0 + apply_proximal = w_old is not None + for i, param in enumerate(params): - w_old_p = w_old[i] grad = grads[i] - grad.add_(param - w_old_p, alpha=mu) - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step = state_steps[i] - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - if weight_decay != 0: - grad = grad.add(param, alpha=weight_decay) - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - if amsgrad: - # Maintains the maximum of all 2nd moment running avg. till now - torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) - # Use the max. for normalizing running avg. of gradient - denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps) - else: - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) - - step_size = lr / bias_correction1 - - param.addcdiv_(exp_avg, denom, value=-step_size) + + # Apply proximal term if needed + if apply_proximal: + w_old_p = w_old[i] + grad = self._apply_proximal_term(grad, param, w_old_p, mu) + + # Apply Adam optimization steps + self._compute_adam_step( + param, + grad, + exp_avgs[i], + exp_avg_sqs[i], + max_exp_avg_sqs[i] if amsgrad else None, + state_steps[i], + amsgrad, + beta1, + beta2, + lr, + weight_decay, + eps, + )