Fix fedprox MNIST tutorial for pytorch #1634 (base: develop)
Conversation
Pull Request Overview
This PR fixes a bug in the FedProx MNIST tutorial for PyTorch by correcting the timing of the old weights update in the FedProx optimizers and by ensuring proper initialization of the reference weights. Key changes include:
- Initializing "w_old" in both FedProxOptimizer and FedProxAdam constructors.
- Guarding the application of the proximal term in the optimizer step using a new "apply_proximal" flag.
- Updating the tutorial notebook to set the old weights once at the beginning of training rather than repeatedly per batch.
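The timing fix in the last bullet can be demonstrated with a toy, pure-Python stand-in for the FedProx update (the `ToySGDWithProx` class and the quadratic loss below are illustrative, not the tutorial's actual code): if `set_old_weights` is called every batch, `w_old` always equals the current weights, the proximal term vanishes, and `mu` has no effect.

```python
# Minimal sketch (no PyTorch) of why the timing of set_old_weights matters.
# ToySGDWithProx and the quadratic "loss" are stand-ins, not OpenFL code.

class ToySGDWithProx:
    def __init__(self, lr=0.1, mu=1.0):
        self.lr, self.mu, self.w_old = lr, mu, None

    def set_old_weights(self, w):
        self.w_old = w

    def step(self, w, grad):
        # FedProx update: gradient plus proximal pull toward w_old
        return w - self.lr * (grad + self.mu * (w - self.w_old))

def train(per_batch, mu, steps=20):
    opt = ToySGDWithProx(mu=mu)
    w = 0.0
    opt.set_old_weights(w)          # correct: once, at the start of the round
    for _ in range(steps):
        if per_batch:
            opt.set_old_weights(w)  # bug: w_old == w, proximal term is zero
        grad = 2 * (w - 5.0)        # d/dw of (w - 5)^2
        w = opt.step(w, grad)
    return w

# With the bug, mu has no effect: both runs reduce to plain SGD.
buggy_mu0, buggy_mu1 = train(True, 0.0), train(True, 1.0)
# With the fix, a larger mu keeps w closer to w_old = 0.
fixed_mu0, fixed_mu1 = train(False, 0.0), train(False, 1.0)
```

With the per-batch bug, `mu = 0.0` and `mu = 1.0` produce identical final weights, which matches the symptom reported in the linked issue.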
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| openfl/utilities/optimizers/torch/fedprox.py | Added initialization for "w_old" and applied conditional proximal updates in both step and adam functions. |
| openfl-tutorials/experimental/workflow/403_Federated_FedProx_PyTorch_MNIST_Workflow_Tutorial.ipynb | Revised execution counts and repositioned the set_old_weights call to occur only once at training start. |
kminhta left a comment:
Good catch @nisha987! Thanks for taking up that long-standing issue.
| "dampening": dampening, | ||
| "lr": lr, | ||
| "momentum": momentum, | ||
| "mu": mu, |
General: we should probably rename FedProxOptimizer to FedProxSGD to maintain the naming convention, but maybe we can save that for a future PR.
Thanks @nisha987, this is looking good. I had one more minor comment regarding ...
Signed-off-by: Shekhawat, Nisha <[email protected]>
Added signoff to all commits.
```python
    raise ValueError(f"Invalid learning rate: {lr}")
if weight_decay < 0.0:
    raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if mu < 0.0:
```
A one-liner `assert mu >= 0.0, f"FedProx regularizer coefficient must be greater than or equal to 0, got {mu}"` is sufficient.
I don't think there are any scenarios where a negative mu is used. It is OK to raise an exception here instead of a warning.
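As a sketch, the suggested assert-style validation might look like the following in the constructor (`validate_hparams` is a hypothetical helper written for illustration, not OpenFL code):

```python
# Hypothetical helper illustrating the reviewer's suggested one-liner
# validation; in the real optimizer this would live in __init__.

def validate_hparams(lr, weight_decay, mu):
    assert lr >= 0.0, f"Invalid learning rate: {lr}"
    assert weight_decay >= 0.0, f"Invalid weight_decay value: {weight_decay}"
    assert mu >= 0.0, (
        f"FedProx regularizer coefficient must be greater than or equal to 0, got {mu}"
    )

validate_hparams(0.01, 0.0, 0.1)   # valid hyperparameters pass silently

try:
    validate_hparams(0.01, 0.0, -1.0)
    msg = ""
except AssertionError as e:
    msg = str(e)                   # negative mu raises immediately
```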
IMPORTANT: This optimizer requires a reference to the original (global) model parameters to calculate the proximal term. These must be set explicitly using the set_old_weights() method before training begins. The old weights (w_old) must match the order and structure of the model's parameters. Typically, w_old should be set to the initial global model parameters received from the aggregator at the beginning of each round. If mu > 0 and w_old is not set, the optimizer will raise a ValueError.
Is it possible to record the first value of the weights supplied to the optimizer as w_old at the start of the first iteration? The user would not need to set it explicitly.
This is a good suggestion. I hadn't considered this.
@nisha987 - the optimizer is supplied with the model parameters when it is first initialized. These parameters are updated during optimizer.step(). This is the global model, and it should be possible to also record these weights (before updating) as w_old internally rather than having the user explicitly do it, as @MasterSkepticista suggests.
Then we won't risk running into the issue that we saw in the FedProx MNIST tutorial.
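A minimal sketch of that idea, with plain Python lists standing in for tensors (`LazyProxOptimizer` is hypothetical, not the PR's code): the optimizer snapshots its parameters as w_old on the first call to step(), before applying any update, so set_old_weights() is never needed.

```python
# Hypothetical sketch of lazily capturing w_old inside the optimizer.
# Lists of floats stand in for tensors; a real implementation would
# clone torch parameters instead.

class LazyProxOptimizer:
    def __init__(self, params, lr=0.1, mu=0.5):
        self.params = params       # stand-in for the model's parameters
        self.lr, self.mu = lr, mu
        self.w_old = None

    def step(self, grads):
        if self.w_old is None:
            # First step: record the incoming (global) weights
            # before any update is applied.
            self.w_old = list(self.params)
        for i, g in enumerate(grads):
            prox = self.mu * (self.params[i] - self.w_old[i])
            self.params[i] -= self.lr * (g + prox)

opt = LazyProxOptimizer([1.0, 2.0])
opt.step([0.0, 0.0])   # snapshots w_old = [1.0, 2.0], zero grads leave params unchanged
opt.step([1.0, 1.0])   # params[0]: 1.0 - 0.1 * (1.0 + 0.0) = 0.9
```

This removes the footgun entirely: the reference weights are guaranteed to be the round's initial global weights, which is exactly the invariant the tutorial bug violated.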
```python
if mu > 0 and w_old is None:
    raise ValueError(
        "FedProx requires old weights to be set when mu > 0. "
        "Please call set_old_weights() before optimization step."
    )
```
This is assumed to be verified during init, no?
```python
self._validate_old_weights(mu, w_old)

# Apply proximal term when mu != 0
apply_proximal = w_old is not None
```
Overall, consider a simpler approach: apply_proximal should always be true. We guarantee that mu will be >= 0.0 during initialization, which means negative values never appear. As for the 0.0 case, it means no contribution from the regularizer term, so it is implied that it has no effect on the weights.
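A quick numeric check of this point, using toy scalar updates rather than the real optimizer code: with mu == 0.0, an unconditionally applied proximal term contributes nothing, so the step is identical to plain SGD and the apply_proximal flag is unnecessary.

```python
# Toy scalar updates (not the real optimizer) showing that applying the
# proximal term unconditionally is safe when mu >= 0: mu == 0.0 is a no-op.

def sgd_step(w, grad, lr=0.1):
    return w - lr * grad

def fedprox_step(w, grad, w_old, lr=0.1, mu=0.0):
    # Proximal term applied unconditionally; mu == 0 contributes nothing.
    return w - lr * (grad + mu * (w - w_old))

w, grad, w_old = 2.0, 0.5, 1.0
plain = sgd_step(w, grad)
prox0 = fedprox_step(w, grad, w_old, mu=0.0)   # identical to plain SGD
prox1 = fedprox_step(w, grad, w_old, mu=1.0)   # pulled toward w_old
```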
```python
if weight_decay != 0:
    d_p = d_p.add(p, alpha=weight_decay)
```
Similar case: weight_decay can be guaranteed to be >= 0.0 during initialization, and this check can be avoided.
Summary
Models trained with different mu values in the FedProx optimizer appear to have identical weights, despite the expectation that different mu values should produce different models. This suggests that the proximal term in the FedProx algorithm is not being applied correctly.
Description (Mandatory)
1. Incorrect Timing of the set_old_weights Call

In the current notebook implementation, set_old_weights is called after gradient computation but right before the optimizer step. This is problematic because:
- It causes the reference weights (w_old) to be nearly identical to the current weights
- The proximal term mu * (param - w_old_param) becomes effectively zero

2. Missing Initial w_old Value in the Optimizers

The optimizers (FedProxOptimizer and FedProxAdam) don't initialize the w_old parameter in their constructors. If step() is called before set_old_weights, there may be an error accessing the uninitialized w_old.
Testing