
Conversation

@nisha987
Collaborator

Summary

Models trained with different mu values in the FedProx optimizer appear to have identical weights, despite the expectation that different mu values should produce different models. This suggests that the proximal term in the FedProx algorithm is not being applied correctly.

Type of Change (Mandatory)

Specify the type of change being made.

Description (Mandatory)

1. Incorrect Timing of the set_old_weights Call

In the current notebook implementation, set_old_weights is called after gradient computation but right before the optimizer step. This is problematic because:

  • It sets the reference weights (w_old) to be nearly identical to the current weights
  • When the optimizer step is executed, the proximal term mu * (param - w_old_param) becomes effectively zero
  • This nullifies the effect of different mu values, resulting in identical training outcomes
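The nullification described above can be seen in a minimal sketch (illustrative values, not code from the notebook): the proximal term only pulls toward the global model if w_old is captured before local updates begin.

```python
import torch

# Minimal sketch of the bug: the FedProx proximal term is mu * (param - w_old).
mu = 0.1
param = torch.tensor([1.0, 2.0])  # current local weights after some training

# Correct timing: w_old is the global model snapshot from the round start.
w_old_round_start = torch.tensor([0.5, 1.5])
prox_correct = mu * (param - w_old_round_start)  # non-zero pull toward global model

# Buggy timing: w_old is captured right before step(), so it equals the
# current weights and the proximal term vanishes regardless of mu.
w_old_pre_step = param.clone()
prox_buggy = mu * (param - w_old_pre_step)  # exactly zero
```

With the buggy timing, every choice of mu yields the same update, which matches the observation of identical trained weights.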

2. Missing Initial w_old Value in Optimizer

The optimizers (FedProxOptimizer and FedProxAdam) don't initialize the w_old parameter in their constructors. If step() is called before set_old_weights, accessing the uninitialized w_old may raise an error.
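A hypothetical sketch of the fix for both points (class and method names are illustrative, not the actual OpenFL code): initialize w_old to None in the constructor so step() can fail fast with a clear error instead of crashing on an uninitialized attribute.

```python
import torch
from torch.optim.optimizer import Optimizer

class FedProxSketch(Optimizer):
    """Illustrative FedProx-style SGD; not the OpenFL implementation."""

    def __init__(self, params, lr=0.01, mu=0.0):
        # w_old explicitly starts as None instead of being left undefined.
        defaults = dict(lr=lr, mu=mu, w_old=None)
        super().__init__(params, defaults)

    def set_old_weights(self, w_old):
        for group in self.param_groups:
            group["w_old"] = w_old

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            mu, w_old = group["mu"], group["w_old"]
            if mu > 0 and w_old is None:
                raise ValueError("call set_old_weights() before step() when mu > 0")
            for i, p in enumerate(group["params"]):
                if p.grad is None:
                    continue
                d_p = p.grad
                if mu > 0:
                    # Proximal pull toward the round-start (global) weights.
                    d_p = d_p.add(p - w_old[i], alpha=mu)
                p.add_(d_p, alpha=-group["lr"])
```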

Testing

  • Tested locally.

[Screenshot 2025-05-19 113316]

Contributor

Copilot AI left a comment


Pull Request Overview

This PR fixes a bug in the FedProx MNIST tutorial for PyTorch by correcting the timing of the old weights update in the FedProx optimizers and by ensuring proper initialization of the reference weights. Key changes include:

  • Initializing "w_old" in both FedProxOptimizer and FedProxAdam constructors.
  • Guarding the application of the proximal term in the optimizer step using a new "apply_proximal" flag.
  • Updating the tutorial notebook to set the old weights once at the beginning of training rather than repeatedly per batch.
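The corrected call order from the last bullet can be sketched as follows. This is a hedged, self-contained illustration using a toy model and plain SGD with a manually added proximal gradient; it is not code from the tutorial notebook.

```python
import torch

torch.manual_seed(0)

# Illustrative stand-ins for the tutorial's model and optimizer.
model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# The fix: snapshot the global weights ONCE, before local training starts,
# instead of once per batch right before opt.step().
w_old = [p.detach().clone() for p in model.parameters()]

mu = 0.1
for _ in range(3):  # stand-in for the per-batch training loop
    x = torch.randn(8, 4)
    loss = model(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    # Add the proximal pull toward the round-start weights to each gradient.
    with torch.no_grad():
        for p, w in zip(model.parameters(), w_old):
            p.grad.add_(p - w, alpha=mu)
    opt.step()
```

Because w_old is frozen at the round start, the proximal term stays non-zero as the local weights drift, so different mu values now produce different models.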

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

Files changed:

  • openfl/utilities/optimizers/torch/fedprox.py: Added initialization for "w_old" and applied conditional proximal updates in both step and adam functions.
  • openfl-tutorials/experimental/workflow/403_Federated_FedProx_PyTorch_MNIST_Workflow_Tutorial.ipynb: Revised execution counts and repositioned the set_old_weights call to occur only once at training start.

Collaborator

@kminhta left a comment


Good catch @nisha987! Thanks for taking up that long-standing issue.

"dampening": dampening,
"lr": lr,
"momentum": momentum,
"mu": mu,
Collaborator

General: we should probably rename FedProxOptimizer to FedProxSGD to maintain naming convention, but maybe we can save that for a future PR

@nisha987 force-pushed the nshekhaw/fix_mu_pytorch_fedprox branch 2 times, most recently from df1e827 to b5a2cf7 on May 20, 2025 06:09
@kminhta
Collaborator

kminhta commented May 20, 2025

Thanks @nisha987! This is looking good. I had one more minor comment regarding mu.
Can you also rebase and signoff on your commits? https://github.com/securefederatedai/openfl/pull/1634/checks?check_run_id=42564995654

@nisha987 force-pushed the nshekhaw/fix_mu_pytorch_fedprox branch from d8b1df1 to 2b95cf3 on May 21, 2025 04:47
nisha987 and others added 5 commits May 20, 2025 21:49
Signed-off-by: Shekhawat, Nisha <[email protected]>
@nisha987 force-pushed the nshekhaw/fix_mu_pytorch_fedprox branch from 2b95cf3 to d432356 on May 21, 2025 04:49
@nisha987
Collaborator Author

> Thanks @nisha987 This is looking good. I had one more minor comment regarding mu Can you also rebase and signoff on your commits? https://github.com/securefederatedai/openfl/pull/1634/checks?check_run_id=42564995654

Added signoff to all commits.

raise ValueError(f"Invalid learning rate: {lr}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if mu < 0.0:
Collaborator

A one-liner assert mu >= 0.0, f"FedProx regularizer coefficient must be greater than or equal to 0, got {mu}" is sufficient.
I don't think there are any scenarios where negative mu is used. It is OK to raise an exception here instead of a warning.

Comment on lines +23 to +30
IMPORTANT: This optimizer requires a reference to the original (global) model parameters
to calculate the proximal term. These must be set explicitly using the set_old_weights()
method before training begins. The old weights (w_old) must match the order and structure
of the model's parameters. Typically, w_old should be set to the initial global model
parameters received from the aggregator at the beginning of each round.
If mu > 0 and w_old is not set, the optimizer will raise a ValueError.
Collaborator

Is it possible to record the first value of the weights supplied to the optimizer as w_old at the end of the iteration? The user need not set this.

Collaborator

This is a good suggestion. I hadn't considered this.

@nisha987 - the optimizer is supplied with the model parameters when it is first initialized. These parameters are updated during optimizer.step(). This is the global model, so it should be possible to record these weights (before updating) as w_old internally rather than having the user do it explicitly, as @MasterSkepticista suggests.

Then we won't risk running into the issue that we saw in the FedProx MNIST tutorial.
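The suggestion above can be sketched roughly as follows. The class name and lazy-snapshot placement are assumptions for illustration, not the agreed OpenFL design: the optimizer captures the incoming parameters as w_old on the first step(), so the user never calls set_old_weights().

```python
import torch

class AutoFedProxSGD(torch.optim.Optimizer):
    """Illustrative sketch: w_old is recorded internally on the first step()."""

    def __init__(self, params, lr=0.01, mu=0.0):
        super().__init__(params, dict(lr=lr, mu=mu))
        self.w_old = None  # filled in lazily on the first step()

    @torch.no_grad()
    def step(self):
        if self.w_old is None:
            # Snapshot the global model before the first local update.
            self.w_old = [
                p.detach().clone() for g in self.param_groups for p in g["params"]
            ]
        idx = 0
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    # Gradient plus proximal pull toward the snapshot.
                    d_p = p.grad.add(p - self.w_old[idx], alpha=group["mu"])
                    p.add_(d_p, alpha=-group["lr"])
                idx += 1
```

One caveat of this design: the snapshot must be reset between rounds (e.g. when new global weights arrive), otherwise w_old goes stale.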

Comment on lines +122 to +126
if mu > 0 and w_old is None:
raise ValueError(
"FedProx requires old weights to be set when mu > 0. "
"Please call set_old_weights() before optimization step."
)
Collaborator

This is assumed to be verified during init, no?

self._validate_old_weights(mu, w_old)

# Apply proximal term when mu != 0
apply_proximal = w_old is not None
Collaborator

Overall, consider a simpler approach:

apply_proximal should always be true. We guarantee that mu will be >= 0.0 during initialization, which means negative values never appear. As for the 0.0 case, it means "no contribution" from the regularizer term, so it implicitly has no effect on the weights.
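A quick arithmetic check of the reviewer's point (values are illustrative): with mu guaranteed non-negative at initialization, the proximal term mu * (param - w_old) is an exact no-op at mu == 0, so a separate apply_proximal flag adds nothing.

```python
import torch

param = torch.tensor([1.0, -2.0])
w_old = torch.tensor([0.0, 0.0])
grad = torch.tensor([0.5, 0.5])

# mu == 0: the proximal term contributes exactly zero, no flag needed.
grad_mu0 = grad + 0.0 * (param - w_old)

# mu > 0: the regularizer pulls the gradient toward w_old.
grad_mu01 = grad + 0.1 * (param - w_old)
```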

Comment on lines 189 to 190
if weight_decay != 0:
d_p = d_p.add(p, alpha=weight_decay)
Collaborator

Similar case: weight_decay can be guaranteed to be >= 0.0 during initialization, and this check can be avoided.



Development

Successfully merging this pull request may close these issues.

FedProx: getting same model despite different mu values

3 participants