Weight decay API maybe unintuitive #241

@albertz

I'm not sure if we really recommend it like that anywhere, but I think it's natural to write code like this:

for p in net.parameters():
  p.weight_decay = 0.0001

I noticed that this has several problems:

  • What about auxiliary parameters? You probably don't want weight decay on them. The same goes for any integer or boolean parameters.
    • I think RETURNN would actually just ignore it, so maybe it's not a problem? Or we could also ignore it silently on the returnn-common side to allow for such code?
  • Some variables should maybe not be decayed:
    • In LayerNorm, WeightNorm etc., there is the scale parameter, which is initialized at 1. Any decay should move it towards 1 and not towards 0. (Right?) In Lingvo, you actually find (here) that the weight-norm scale is reparameterized as (1 + g) instead of just g, to avoid this problem.
      • We could rewrite any such code to also use such a reparameterization, which is maybe a good thing, but maybe not?
      • We could add some additional information, like decay_center or so, and the constraint would not be $w^2$ but $(w - c)^2$ instead, such that any configured weight decay would move the parameter towards the configured center (see the sketch after this list). This would also need some extra implementation on the RETURNN side.
      • We could also add some flag Parameter.ignore_weight_decay on the returnn-common side, and if that is enabled (via a module such as LayerNorm), it ignores any writes to weight_decay.
    • I'm not sure whether a decay on biases is good or not.
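
To make the decay_center idea a bit more concrete, here is a minimal sketch of the constraint it would imply. This is plain pseudocode over hypothetical attributes (weight_decay, decay_center, p.data as an array); none of this is existing RETURNN or returnn-common API:

def weight_decay_loss(parameters):
  # Hypothetical sketch: the constraint term becomes weight_decay * (w - c)^2
  # instead of weight_decay * w^2, so the decay pulls w towards c rather than towards 0.
  total = 0.0
  for p in parameters:
    decay = getattr(p, "weight_decay", 0.0)
    if not decay:
      continue
    center = getattr(p, "decay_center", 0.0)  # default center 0 gives the usual L2 decay
    total += decay * ((p.data - center) ** 2).sum()
  return total

A LayerNorm or WeightNorm scale could then just set decay_center = 1 in its module code, and a generic weight-decay loop over all parameters would still do the right thing.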

Many of the arguments are actually about allowing for the simple code above. Or maybe we don't want to allow such simple code? But how exactly would the canonical example of applying weight decay to some generic network look then?
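
If we do want to allow roughly that simple loop, it would probably need to look something like the following sketch. The dtype check and the auxiliary / ignore_weight_decay attributes are hypothetical here, just to make explicit what the library would have to handle (or hide) for the simple code to stay safe:

for p in net.parameters():
  if p.dtype not in ("float32", "float64"):  # skip integer/boolean parameters
    continue
  if getattr(p, "auxiliary", False):  # skip auxiliary parameters, e.g. batch-norm statistics
    continue
  if getattr(p, "ignore_weight_decay", False):  # e.g. set by LayerNorm/WeightNorm for their scale
    continue
  p.weight_decay = 0.0001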
