Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -80,7 +80,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -126,7 +126,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -177,7 +177,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -210,22 +210,9 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Aggregator step \"start\" registered\n",
"Collaborator step \"aggregated_model_validation\" registered\n",
"Collaborator step \"train\" registered\n",
"Collaborator step \"local_model_validation\" registered\n",
"Aggregator step \"join\" registered\n",
"Aggregator step \"end\" registered\n"
]
}
],
"outputs": [],
"source": [
"class FederatedFlow(FLSpec):\n",
" def __init__(self, model=None, optimizer=None, rounds=10, **kwargs):\n",
Expand Down Expand Up @@ -273,13 +260,18 @@
"\n",
" self.model.train()\n",
" self.optimizer = get_optimizer(self.model)\n",
" \n",
" # Set old weights ONCE at the beginning of training\n",
" # This sets the reference weights to the global model weights \n",
" # received from the aggregator, implementing FedProx correctly\n",
" self.optimizer.set_old_weights([p.clone().detach() for p in self.model.parameters()])\n",
" \n",
" for batch_idx, (data, target) in enumerate(self.train_loader):\n",
" self.optimizer.zero_grad()\n",
" output = self.model(data)\n",
" loss = F.cross_entropy(output, target)\n",
" loss.backward()\n",
" \n",
" self.optimizer.set_old_weights([p.clone().detach() for p in self.model.parameters()])\n",
" self.optimizer.step()\n",
"\n",
" if (len(data) * batch_idx) / len(self.train_loader.dataset) >= log_threshold:\n",
Expand Down Expand Up @@ -335,266 +327,9 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Calling start\n",
"\u001b[94mPerforming initialization for model\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling aggregated_model_validation\n",
"\u001b[94mPerforming aggregated model validation for collaborator collaborator0, model: 140162497619616\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 4.6833, Accuracy: 171/2500 (7%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling train\n",
"\u001b[94mTrain Epoch: [4096/15000 (27%)]\tLoss: 1.889274\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [8192/15000 (53%)]\tLoss: 1.279191\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [11264/15000 (73%)]\tLoss: 0.994200\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling local_model_validation\n",
"\u001b[94mPerforming local model validation for collaborator collaborator0\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.7548, Accuracy: 1929/2500 (77%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mDone with local model validation for collaborator collaborator0, Accuracy: 0.7716000080108643\u001b[0m\u001b[94m\n",
"\u001b[0mShould transfer from local_model_validation to join\n",
"\n",
"Calling aggregated_model_validation\n",
"\u001b[94mPerforming aggregated model validation for collaborator collaborator1, model: 140158910463952\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 4.7259, Accuracy: 173/2500 (7%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling train\n",
"\u001b[94mTrain Epoch: [4096/15000 (27%)]\tLoss: 1.675623\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [8192/15000 (53%)]\tLoss: 1.068585\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [11264/15000 (73%)]\tLoss: 0.687561\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling local_model_validation\n",
"\u001b[94mPerforming local model validation for collaborator collaborator1\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.6366, Accuracy: 2004/2500 (80%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mDone with local model validation for collaborator collaborator1, Accuracy: 0.8015999794006348\u001b[0m\u001b[94m\n",
"\u001b[0mShould transfer from local_model_validation to join\n",
"\n",
"Calling aggregated_model_validation\n",
"\u001b[94mPerforming aggregated model validation for collaborator collaborator2, model: 140162497661872\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 4.6549, Accuracy: 215/2500 (9%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling train\n",
"\u001b[94mTrain Epoch: [4096/15000 (27%)]\tLoss: 1.879489\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [8192/15000 (53%)]\tLoss: 1.325507\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [11264/15000 (73%)]\tLoss: 0.968176\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling local_model_validation\n",
"\u001b[94mPerforming local model validation for collaborator collaborator2\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.7462, Accuracy: 1901/2500 (76%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mDone with local model validation for collaborator collaborator2, Accuracy: 0.7603999972343445\u001b[0m\u001b[94m\n",
"\u001b[0mShould transfer from local_model_validation to join\n",
"\n",
"Calling aggregated_model_validation\n",
"\u001b[94mPerforming aggregated model validation for collaborator collaborator3, model: 140162498346528\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 4.7129, Accuracy: 193/2500 (8%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling train\n",
"\u001b[94mTrain Epoch: [4096/15000 (27%)]\tLoss: 1.720635\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [8192/15000 (53%)]\tLoss: 1.061211\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [11264/15000 (73%)]\tLoss: 0.762026\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling local_model_validation\n",
"\u001b[94mPerforming local model validation for collaborator collaborator3\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.6378, Accuracy: 1992/2500 (80%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mDone with local model validation for collaborator collaborator3, Accuracy: 0.7968000173568726\u001b[0m\u001b[94m\n",
"\u001b[0mShould transfer from local_model_validation to join\n",
"\n",
"Calling join\n",
"\u001b[94mAverage aggregated model accuracy = 0.07520000264048576\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mAverage training loss = 0.8529909627063148\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mAverage local model validation values = 0.782600000500679\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling aggregated_model_validation\n",
"\u001b[94mPerforming aggregated model validation for collaborator collaborator0, model: 140158910552480\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.6740, Accuracy: 1996/2500 (80%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling train\n",
"\u001b[94mTrain Epoch: [4096/15000 (27%)]\tLoss: 0.974921\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [8192/15000 (53%)]\tLoss: 0.633429\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [11264/15000 (73%)]\tLoss: 0.591566\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling local_model_validation\n",
"\u001b[94mPerforming local model validation for collaborator collaborator0\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.3951, Accuracy: 2214/2500 (89%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mDone with local model validation for collaborator collaborator0, Accuracy: 0.8855999708175659\u001b[0m\u001b[94m\n",
"\u001b[0mShould transfer from local_model_validation to join\n",
"\n",
"Calling aggregated_model_validation\n",
"\u001b[94mPerforming aggregated model validation for collaborator collaborator1, model: 140162497608672\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.6877, Accuracy: 1981/2500 (79%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling train\n",
"\u001b[94mTrain Epoch: [4096/15000 (27%)]\tLoss: 0.824028\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [8192/15000 (53%)]\tLoss: 0.515538\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [11264/15000 (73%)]\tLoss: 0.410188\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling local_model_validation\n",
"\u001b[94mPerforming local model validation for collaborator collaborator1\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.4668, Accuracy: 2166/2500 (87%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mDone with local model validation for collaborator collaborator1, Accuracy: 0.8664000034332275\u001b[0m\u001b[94m\n",
"\u001b[0mShould transfer from local_model_validation to join\n",
"\n",
"Calling aggregated_model_validation\n",
"\u001b[94mPerforming aggregated model validation for collaborator collaborator2, model: 140162498107328\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.6919, Accuracy: 1981/2500 (79%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling train\n",
"\u001b[94mTrain Epoch: [4096/15000 (27%)]\tLoss: 1.025180\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [8192/15000 (53%)]\tLoss: 0.616896\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [11264/15000 (73%)]\tLoss: 0.483282\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling local_model_validation\n",
"\u001b[94mPerforming local model validation for collaborator collaborator2\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.4406, Accuracy: 2163/2500 (87%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mDone with local model validation for collaborator collaborator2, Accuracy: 0.8651999831199646\u001b[0m\u001b[94m\n",
"\u001b[0mShould transfer from local_model_validation to join\n",
"\n",
"Calling aggregated_model_validation\n",
"\u001b[94mPerforming aggregated model validation for collaborator collaborator3, model: 140162498345664\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.6698, Accuracy: 2000/2500 (80%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling train\n",
"\u001b[94mTrain Epoch: [4096/15000 (27%)]\tLoss: 0.725868\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [8192/15000 (53%)]\tLoss: 0.450241\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [11264/15000 (73%)]\tLoss: 0.388554\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling local_model_validation\n",
"\u001b[94mPerforming local model validation for collaborator collaborator3\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.4106, Accuracy: 2211/2500 (88%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mDone with local model validation for collaborator collaborator3, Accuracy: 0.8844000101089478\u001b[0m\u001b[94m\n",
"\u001b[0mShould transfer from local_model_validation to join\n",
"\n",
"Calling join\n",
"\u001b[94mAverage aggregated model accuracy = 0.7958000004291534\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mAverage training loss = 0.4683974838455107\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mAverage local model validation values = 0.8753999918699265\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling aggregated_model_validation\n",
"\u001b[94mPerforming aggregated model validation for collaborator collaborator0, model: 140162648091376\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.3590, Accuracy: 2230/2500 (89%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling train\n",
"\u001b[94mTrain Epoch: [4096/15000 (27%)]\tLoss: 0.406638\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [8192/15000 (53%)]\tLoss: 0.313662\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [11264/15000 (73%)]\tLoss: 0.326520\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling local_model_validation\n",
"\u001b[94mPerforming local model validation for collaborator collaborator0\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.2096, Accuracy: 2338/2500 (94%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mDone with local model validation for collaborator collaborator0, Accuracy: 0.9351999759674072\u001b[0m\u001b[94m\n",
"\u001b[0mShould transfer from local_model_validation to join\n",
"\n",
"Calling aggregated_model_validation\n",
"\u001b[94mPerforming aggregated model validation for collaborator collaborator1, model: 140162646717344\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.3773, Accuracy: 2228/2500 (89%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling train\n",
"\u001b[94mTrain Epoch: [4096/15000 (27%)]\tLoss: 0.392126\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [8192/15000 (53%)]\tLoss: 0.228912\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [11264/15000 (73%)]\tLoss: 0.200197\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling local_model_validation\n",
"\u001b[94mPerforming local model validation for collaborator collaborator1\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.2601, Accuracy: 2317/2500 (93%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mDone with local model validation for collaborator collaborator1, Accuracy: 0.926800012588501\u001b[0m\u001b[94m\n",
"\u001b[0mShould transfer from local_model_validation to join\n",
"\n",
"Calling aggregated_model_validation\n",
"\u001b[94mPerforming aggregated model validation for collaborator collaborator2, model: 140162498503728\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.3683, Accuracy: 2240/2500 (90%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling train\n",
"\u001b[94mTrain Epoch: [4096/15000 (27%)]\tLoss: 0.583415\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [8192/15000 (53%)]\tLoss: 0.407979\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [11264/15000 (73%)]\tLoss: 0.299050\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling local_model_validation\n",
"\u001b[94mPerforming local model validation for collaborator collaborator2\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.2664, Accuracy: 2305/2500 (92%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mDone with local model validation for collaborator collaborator2, Accuracy: 0.921999990940094\u001b[0m\u001b[94m\n",
"\u001b[0mShould transfer from local_model_validation to join\n",
"\n",
"Calling aggregated_model_validation\n",
"\u001b[94mPerforming aggregated model validation for collaborator collaborator3, model: 140162497621488\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.3626, Accuracy: 2226/2500 (89%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling train\n",
"\u001b[94mTrain Epoch: [4096/15000 (27%)]\tLoss: 0.371595\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [8192/15000 (53%)]\tLoss: 0.234668\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mTrain Epoch: [11264/15000 (73%)]\tLoss: 0.177007\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling local_model_validation\n",
"\u001b[94mPerforming local model validation for collaborator collaborator3\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94m\n",
"Test set: Avg. loss: 0.2615, Accuracy: 2305/2500 (92%)\n",
"\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mDone with local model validation for collaborator collaborator3, Accuracy: 0.921999990940094\u001b[0m\u001b[94m\n",
"\u001b[0mShould transfer from local_model_validation to join\n",
"\n",
"Calling join\n",
"\u001b[94mAverage aggregated model accuracy = 0.8924000114202499\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mAverage training loss = 0.250693485351252\u001b[0m\u001b[94m\n",
"\u001b[0m\u001b[94mAverage local model validation values = 0.926499992609024\u001b[0m\u001b[94m\n",
"\u001b[0m\n",
"Calling end\n",
"\u001b[94mFlow ended\u001b[0m\u001b[94m\n",
"\u001b[0m"
]
}
],
"outputs": [],
"source": [
"model = Net()\n",
"flflow = FederatedFlow(model, get_optimizer(model), rounds=3, checkpoint=False)\n",
Expand All @@ -605,7 +340,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": "env_name",
"language": "python",
"name": "python3"
},
Expand All @@ -619,7 +354,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.11.12"
}
},
"nbformat": 4,
Expand Down
Loading
Loading