
[Module] Refactor init_weights to config-based param_init system#2633

Open
fegin wants to merge 17 commits into gh/fegin/103/base from gh/fegin/103/head

Conversation

@fegin
Contributor

@fegin fegin commented Mar 19, 2026

Stack from ghstack (oldest at bottom):

Motivation and design decisions

Parameter initialization should be configurable without modifying module code. Users reuse provided helpers (e.g., init_trunc_normal, init_zeros) to customize how parameters are initialized. Parameters are identified by their FQNs, and init_by_regex lets users map regex patterns to initializers — so a single declarative dict at the model config level controls all parameter init.
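A minimal sketch of what such a regex-keyed dispatch could look like (the signatures and error behavior here are illustrative assumptions, not the PR's actual helpers):

```python
import re
from typing import Any, Callable

# A simple initializer takes just the parameter; names here are illustrative.
ParamInitializer = Callable[[Any], None]

def init_by_regex(
    patterns: dict[str, ParamInitializer],
) -> Callable[[str, Any], None]:
    """Build an FQN-keyed initializer: the first matching pattern wins."""
    compiled = [(re.compile(p), fn) for p, fn in patterns.items()]

    def named_init(fqn: str, param: Any) -> None:
        for pat, fn in compiled:
            if pat.search(fqn):
                fn(param)
                return
        raise KeyError(f"no initializer pattern matched {fqn!r}")

    return named_init
```

A model config could then declare a single dict such as `{r"\.weight$": init_trunc_normal, r"\.bias$": init_zeros}` once at the top level, and every parameter is dispatched by its FQN.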

In theory, we could just apply param_init at the root module and be done — no recursion needed. So why does init_states recurse? The main motivation is that buffer initialization is usually inherently coupled with module internals and not user-configurable. Since buffers require per-module logic via init_self_buffers, we provide a recursive init_states that handles both: it walks the module tree, initializing parameters via the configurable param_init and buffers via module-specific overrides, in a single pass.
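The single-pass walk described above can be sketched with a toy module (hypothetical shape: real modules hold tensors and configs, not strings):

```python
class Module:
    """Toy stand-in for the framework Module: named params plus children."""
    param_init = None  # optional per-module override of the configurable init

    def __init__(self):
        self._params: dict[str, object] = {}
        self._children: dict[str, "Module"] = {}

    def init_self_buffers(self) -> None:
        """Per-module buffer logic; overridden where a module owns buffers."""

    def init_states(self, *, param_init=None, prefix: str = "") -> None:
        # A module's own param_init (if any) takes precedence; otherwise the
        # one inherited from an ancestor is used, so init logic can be
        # declared once at the model config level.
        param_init = self.param_init or param_init
        for name, param in self._params.items():
            param_init(prefix + name, param)  # configurable, FQN-keyed
        self.init_self_buffers()              # module-specific, not configurable
        for name, child in self._children.items():
            child.init_states(param_init=param_init, prefix=f"{prefix}{name}.")
```

One recursion initializes parameters (via the configurable, FQN-keyed callable) and buffers (via per-module overrides) together.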

Summary

  • Remove init_weights.
  • init_states auto-recurses the module tree, then calls init_self_parameters and init_self_buffers on each module.
  • param_init (a NamedInitializer callable on Module.Config) uses regex-matched FQN patterns to map parameters to initializers. Child modules without their own param_init delegate up the parent chain, so init logic is defined once at the model config level.
  • make_decoder_param_init provides shared patterns for decoder-based models, reused by Llama3, Llama4, Qwen3, and DeepSeek V3. Model-specific extensions (Flux DiT-style, GPT-OSS MoE biases) are composed via dict merge.
  • Removes init_mean/init_std fields from Linear.Config and Embedding.Config — init parameters are now expressed entirely through param_init patterns.
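The dict-merge composition mentioned in the summary could look like this (illustrative pattern strings and initializer names, not the PR's actual dict):

```python
def make_decoder_param_init(std: float = 0.02) -> dict[str, str]:
    """Shared FQN-pattern -> initializer mapping for decoder-style models.

    Initializer names stand in for the real callables (e.g. init_trunc_normal).
    """
    return {
        r"tok_embeddings\.weight$": f"trunc_normal(std={std})",
        r"attention\..*\.weight$":  f"trunc_normal(std={std})",
        r"norm\.weight$":           "ones",
    }

# A model-specific config extends the shared dict via merge; later keys win,
# so a model can both add new patterns and override shared ones.
gpt_oss_param_init = {
    **make_decoder_param_init(),
    r"experts\..*\.bias$": "zeros",  # e.g. MoE biases, model-specific
}
```

Because the merge is a plain dict operation, each model's deviation from the shared decoder patterns stays a one-liner in its config.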

Verification

Using loss_compare.py, with the code changed so that the initialization order is exactly the same as on the main branch, we get Llama3 and Qwen3 loss parity:

  ┌─────────────┬───────┬───────────────────┐
  │    Model    │ Steps │      Result       │
  ├─────────────┼───────┼───────────────────┤
  │ Llama3      │ 100   │ Bitwise identical │
  ├─────────────┼───────┼───────────────────┤
  │ Qwen3       │ 100   │ Bitwise identical │
  └─────────────┴───────┴───────────────────┘

However, the current code does not pin that fixed order, because doing so would make the code ugly. A different loss is expected if the initialization order changes.

[ghstack-poisoned]
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 19, 2026
fegin added a commit that referenced this pull request Mar 19, 2026
**NOT READY TO REVIEW**
**NOT READY TO REVIEW**

**Motivation and design decisions**

Parameter initialization should be configurable without modifying module code.  Users reuse provided helpers (e.g., `init_trunc_normal`, `init_zeros`) to customize how parameters are initialized. Parameters are identified by their FQNs, and `init_by_regex` lets users map regex patterns to initializers — so a single declarative dict at the model config level controls all parameter init.

In theory, we could just apply param_init at the root module and be done — no recursion needed. So why does `init_states` recurse? The main motivation is that buffer initialization is usually inherently coupled with module internals and not user-configurable. Since buffers require per-module logic via `init_self_buffers`, we provide a recursive init_states that handles both: it walks the module tree, initializing parameters via the configurable param_init and buffers via module-specific overrides, in a single pass.

**Summary**
- Remove `init_weights`.
- `init_states` auto-recurses the module tree, then calls `init_self_parameters` and `init_self_buffers` on each module.
- `param_init` (a NamedInitializer callable on Module.Config) uses regex-matched FQN patterns to map parameters to initializers. Child modules without their own `param_init` delegate up the parent chain, so init logic is defined once at the model config level.
- `make_decoder_param_init` provides shared patterns for decoder-based models, reused by Llama3, Llama4, Qwen3, and DeepSeek V3. Model-specific extensions (Flux DiT-style, GPT-OSS MoE biases) are composed via dict merge.
- Removes init_mean/init_std fields from Linear.Config and Embedding.Config — init parameters are now expressed entirely through param_init patterns.


ghstack-source-id: 024899c
Pull-Request: #2633
@fegin fegin marked this pull request as draft March 19, 2026 18:04
[ghstack-poisoned]
fegin added a commit that referenced this pull request Mar 19, 2026
ghstack-source-id: 11120b6
Pull-Request: #2633
[ghstack-poisoned]
fegin added a commit that referenced this pull request Mar 19, 2026
ghstack-source-id: 6871155
Pull-Request: #2633
[ghstack-poisoned]
fegin added a commit that referenced this pull request Mar 19, 2026
ghstack-source-id: 803f279
Pull-Request: #2633
[ghstack-poisoned]
fegin added a commit that referenced this pull request Mar 19, 2026
**Verification**

With loss_compare.py:

  ┌─────────────┬───────┬───────────────────┐
  │    Model    │ Steps │      Result       │
  ├─────────────┼───────┼───────────────────┤
  │ Llama3      │ 10    │ Bitwise identical │
  ├─────────────┼───────┼───────────────────┤
  │ Qwen3       │ 10    │ Bitwise identical │
  ├─────────────┼───────┼───────────────────┤
  │ DeepSeek V3 │ 10    │ Bitwise identical │
  ├─────────────┼───────┼───────────────────┤
  │ GPT-OSS     │ 10    │ Bitwise identical │
  └─────────────┴───────┴───────────────────┘

ghstack-source-id: e6704e8
Pull-Request: #2633
[ghstack-poisoned]
fegin added a commit that referenced this pull request Mar 20, 2026

ghstack-source-id: 440e49d
Pull-Request: #2633
[ghstack-poisoned]
fegin added a commit that referenced this pull request Mar 20, 2026

ghstack-source-id: 1bcce2b
Pull-Request: #2633
@fegin fegin changed the title [WIP][Module] Refactor init_weights to config-based param_init system [Module] Refactor init_weights to config-based param_init system Mar 20, 2026
@fegin fegin marked this pull request as ready for review March 20, 2026 02:33
[ghstack-poisoned]
fegin added a commit that referenced this pull request Mar 20, 2026

ghstack-source-id: 884689c
Pull-Request: #2633
except AttributeError:
# field(init=False) not yet set, ignore this field.
continue
if callable(val) and not dataclasses.is_dataclass(val):
Contributor

oh, can a dataclass ever be callable?

Contributor

btw, https://fburl.com/code/8t7xhgna is making Config.param_init a config, not a function, so that it should still be serializable. I think we should do this too, to make it reproducible (given a code version).

Right now parallelize_fn and other functions in ModelSpec are not serializable, which should be fixed later.

Contributor Author

The config is a dictionary that maps from regex to function. So from a serializability perspective, the requirement is the same?

buffer_device = buffer_device or self.freqs_cis.device
if self.rope is not None:
self.rope.init_weights(buffer_device=buffer_device)
# RoPE's _init_self_buffers was already called by auto-recursion
Contributor

Not sure about this, but it's possible that it is set on meta device, when buffer_device here is cuda? Shall we add an assertion for device?

Contributor Author

Not quite sure I understand the question. If the meta device is used, then everything is on the meta device. buffer_device is only used when the trainer parallelizes the model and re-initializes the states. The comment here just explains the order, not the device.

An assertion is good, but I want to understand whether I'm missing something.

# Type alias for simple parameter initializers: (param) -> Any
# Uses Any return type because nn.init.* functions return Tensor,
# but the return value is always ignored by the dispatch layer.
ParamInitializer = Callable[[nn.Parameter], Any]
Contributor

If it's only used in common/param_init.py, can we put it there for now?

Contributor Author

@fegin fegin Mar 23, 2026

That would be inconsistent: we would have ParamInitializer in common/param_init.py but NamedParamInitializer here. Ideally, we should put both in common/param_init.py, but I cannot do that due to the circular import.

# and the prefix resets (FQNs are relative to the owner).
if self.param_init is not None:
param_init = self.param_init
param_prefix = ""
Contributor

Not sure about this.
E.g. let's say I created a new module to plug in each layer, for which I want to use depth init.

  • If I provide a config.param_init to this module, I would lose the layer id because of this line
  • Of course, I can change the overall decoder init. But if that's the case, what's the point of supporting param_init on each module? -- we could always call get_default_decoder_init and modify the dictionary.

Comment on lines +71 to +72
if self.param_init is not None:
instance.param_init = self.param_init
Contributor

we could do this in constructor, curious why we need to do it here

Contributor Author

Because we don't have a common Module.__init__, users who define a new Module need to remember to do the assignment themselves. An alternative is to define a common Module.__init__; I thought about that case, but it requires some rework of the module inheritance hierarchy, and users would then have to call super().__init__(cfg, *args, **kwargs). The latter constraint is okay, I guess, but the former needs some time to verify. If we do want to change this, we can do another round of changes to the Module hierarchy later.

def init_states(
self,
*,
param_init: NamedParamInitializer | None = None,
Contributor

This seems different from https://fburl.com/code/ia9w7w5i

  • over there, if the param_init is missing, the lookup relies on the parent
  • here, the parent's param_init is sent to all children. I think this leaks information from global to local modules, so it should be avoided in general.

Contributor Author

I'm not sure, though. Parent modules already hold the param_init; is that considered leaking from global to local? I didn't do the parent walk because, unlike the internal code, where a Node has a parent, our Module has no such implementation, so I deliberately didn't include it. Adding a parent reference would also create a circular dependency and make GC slower. The latter shouldn't be a real concern, since we never GC modules; I'm just noting this minor issue.

[ghstack-poisoned]
fegin added a commit that referenced this pull request Mar 23, 2026

ghstack-source-id: ccd5db5
Pull-Request: #2633
@fegin
Contributor Author

fegin commented Mar 23, 2026

I changed the parameter init to an ancestor-lookup style. But the comment reflects an issue caused by the circular reference, so the original concern about the parent reference is legitimate. This needs some careful handling, but the code looks cleaner for each inherited init_states.

The typing of param_init is still unchanged. As long as there is an anonymous function (which the internal code also uses), serialization is not sound. DepthScaledTruncNormal and RegexInitializer are serializable if we correctly implement to_dict and from_dict, and I think this holds even if we change param_init to a config. Correct me if I'm wrong.

@tianyu-l

[ghstack-poisoned]
fegin added a commit that referenced this pull request Mar 23, 2026

ghstack-source-id: 4427985
Pull-Request: #2633
fegin added a commit that referenced this pull request Mar 23, 2026
ghstack-source-id: ccfa49b
Pull-Request: #2633
fegin added a commit that referenced this pull request Mar 24, 2026
ghstack-source-id: 8ec94ee
Pull-Request: #2633
fegin added a commit that referenced this pull request Mar 24, 2026
ghstack-source-id: d3cf966
Pull-Request: #2633
fegin added a commit that referenced this pull request Mar 25, 2026
ghstack-source-id: 8006560
Pull-Request: #2633
fegin added a commit that referenced this pull request Mar 25, 2026
ghstack-source-id: 8e8a34a
Pull-Request: #2633
@tianyu-l
Contributor

@claude review this pull request. In particular, please comment on where is best to expand general transformer layer config into per-layer configs. I think in general there are three options:

  1. There is no per-layer config. The TransformerBlock constructor has the logic to create submodules. This is the current behavior before this PR.
  2. The model config class has a method to expand into per-layer config. This is the behavior in the current commit of the PR.
  3. Alternatively, the model constructor only accepts already-expanded per-layer config. The (optional) logic of "creating the per-layer config based on a general config using some per-model logic" can be put outside the model code, e.g. in config_registry.py.

@claude

claude bot commented Mar 26, 2026

Claude finished @tianyu-l's task in 2m 50s


PR Review: Refactor init_weights to config-based param_init system

  • Read core protocol changes (protocols/module.py, protocols/model.py)
  • Read param_init infrastructure (models/common/param_init.py, config/configurable.py)
  • Read decoder changes (models/common/decoder.py)
  • Read model-specific changes (llama3, llama4, qwen3, deepseek_v3, flux, gpt_oss)
  • Read model files (model.py for each)
  • Read test changes
  • Read experiment changes
  • Provide overall review and analysis of per-layer config expansion approaches

Overall Assessment

This is a substantial and well-motivated refactor that centralizes parameter initialization into a declarative, config-based system. The design cleanly separates "what to initialize" (config-level `param_init` dicts) from "how to recurse" (the `init_states` mechanism). The elimination of `init_weights` from individual modules and the removal of `init_mean`/`init_std` from `Linear.Config`/`Embedding.Config` are good simplifications.

The numeric verification (bitwise-identical loss for Llama3 and Qwen3 over 100 steps) is reassuring, though only 2 of the 6+ models were verified. The Flux and GPT-OSS models have quite different init patterns, so numeric parity for those would strengthen confidence.


Analysis: Where to Expand Per-Layer Config

This is the core design question. Here is my analysis of the three options:

Option 1: No per-layer config; TransformerBlock.__init__ has the construction logic.
This is the status quo on main. Layer-level decisions (MoE interleaving, iRoPE `use_rope=False`, depth-scaled init) are computed inside `TransformerBlock.__init__` at build time. The upside is simplicity — one template config, no expansion pass. The downside is that the config object doesn't fully describe what was built: you can't serialize or inspect the per-layer config after construction, and `param_init` must embed `PerLayer` / depth-dependent logic that gets resolved lazily. This also means the constructor mixes structural branching logic (e.g., branching on `moe_enabled` to build either FFN or MoE) with parameter concerns.

Option 2 (current PR): The model config class has an expand() method that produces per-layer configs.
Decoder.Config.expand() deep-copies the template layer config N times, calls _expand_layer() for structural changes (MoE interleaving, iRoPE), and resolves PerLayer markers. This happens between update_from_config() and build(), so by the time the model constructor runs, each layer has a fully-resolved, self-contained config. The model constructor simply iterates config.layers and calls layer_config.build(...).

Advantages:

  • Inspectability: After expand(), the full per-layer config tree is materialized. You could serialize it, compare it, or log it.
  • Clean constructors: TransformerBlock.__init__ receives a config that already has the correct feed_forward=None or moe=None — no structural branching needed in the constructor itself (though the current code still has both paths in __init__ for backwards compat, e.g. Llama4TransformerBlock.__init__ still checks self.moe_enabled).
  • Composability: _expand_layer() is a small, overridable hook. Each model only overrides what's different (Llama3 is identity, Llama4 handles iRoPE + MoE interleaving, DeepSeek handles dense vs MoE layers).
  • param_init resolution: PerLayer markers are resolved to concrete dicts during expand(), so by build time every sub-config has a plain dict[str, Callable]. No lazy resolution needed at init time.

Disadvantages:

  • Deep-copy cost: N deep copies of a potentially large config tree. For a 126-layer 405B model, this is 126 copies. The cost is trivial compared to model construction, but the memory footprint of the config list is non-zero.
  • Config list in config: Decoder.Config.layers: list | None = None is an unusual pattern — it's a field(init=False)-like slot populated by expand(). The to_dict() method explicitly skips lists to avoid serializing this.
  • Dual-path construction: Decoder.__init__ has both if config.layers is not None and else paths (decoder.py:128-137), which is a code smell. If expand() is always called, the fallback path should probably be removed or guarded with an assertion.

Option 3: The model constructor only accepts already-expanded per-layer configs; expansion logic lives in config_registry.py or a standalone function.
This pushes the expansion entirely outside the model code. The model config would always have a layers: list[TransformerBlock.Config] (required), and the config_registry.py would call an expansion utility to produce it.

Advantages:

  • Strict separation: Model code never knows about expansion. It receives a list and iterates it.
  • Flexibility: Different expansion strategies (or no expansion at all, e.g., manually specifying each layer) can be plugged in without touching model code.
  • No expand() on config: The config class is purely declarative data with no behavior.

Disadvantages:

  • Loss of encapsulation: Model-specific expansion logic (e.g., Llama4's iRoPE + MoE interleaving) would live in config_registry.py rather than in the model's config class. This breaks the principle that model-specific knowledge should be co-located with the model. The config_registry.py files would grow substantially.
  • Template convenience: You lose the "specify one template, get N layers" ergonomic. Every config registration would need an explicit expansion call.
  • Boilerplate: Each model's config_registry.py would need to import and call an expansion function, adding repetitive setup code.

My recommendation: Option 2 is the best fit for this codebase, with a minor refinement. The expand() method on the config class is the right place for per-model structural expansion because:

  1. It co-locates model-specific knowledge (which layers are MoE, which use RoPE) with the model definition, respecting the CLAUDE.md guideline "Model-specific code → the specific model folder."
  2. It gives you a clean materialized config tree before build(), enabling inspection and serialization.
  3. The _expand_layer() hook pattern scales well — Llama3 (identity), Llama4 (iRoPE + MoE interleaving), DeepSeek (dense-to-MoE transition), Qwen3 (MoE flag) all override it cleanly.

The refinement I'd suggest: consider making expand() mandatory (remove the fallback path in Decoder.__init__ that builds from config.layer directly). This eliminates the dual-path construction and makes the contract explicit — build() always operates on expanded configs.


Specific Code Review Findings

1. Dual-path construction in Decoder.__init__ should be consolidated

torchtitan/models/common/decoder.py:128-137:

if config.layers is not None:
    for i, layer_config in enumerate(config.layers):
        self.layers[str(i)] = layer_config.build(...)
else:
    for layer_id in range(config.n_layers):
        self.layers[str(layer_id)] = config.layer.build(...)

If expand() is always called before build(), the else branch is dead code. If it's intentionally kept for backwards compatibility (e.g., experiments that don't call expand()), this should have a comment or emit a warning.

2. Llama4TransformerBlock.__init__ still has structural branching that duplicates _expand_layer

torchtitan/models/llama4/model.py:80-107: The constructor still computes self.moe_enabled and does if moe_enabled: branching, even though _expand_layer() already resolves feed_forward=None or moe=None. After expansion, exactly one of config.feed_forward / config.moe is None, so the constructor could simply check which is not None. The same applies to Qwen3TransformerBlock.__init__ and DeepSeekV3TransformerBlock.__init__.

This is a consequence of still supporting the un-expanded path. If expand() becomes mandatory, these constructors simplify significantly.

3. _expand_layer in Llama4 duplicates runtime logic

torchtitan/models/llama4/model.py:144-161: The _expand_layer method recomputes moe_enabled and iRoPE decisions that are also computed in __init__. If expansion becomes the canonical path, the __init__ logic should defer to the config rather than recomputing.

4. PerLayer + lambdas prevent config serialization

The PerLayer(lambda layer_id: {...}) pattern used throughout (e.g., llama3/__init__.py:43-48) captures closures that aren't serializable. The to_dict() method in configurable.py:85-97 works around this by skipping callables, PerLayer, and lists. This means the expanded config can't be round-tripped through serialization. tianyu-l's review comment about making param_init a config rather than a function is well-taken — this is a real limitation for reproducibility.
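The serialization limitation is easy to demonstrate in isolation; this snippet shows a generic Python fact about closures, not code from the PR:

```python
import pickle

# A PerLayer-style closure: per-layer init hyperparameters derived from layer_id.
per_layer_std = lambda layer_id: {"std": 0.02 / (layer_id + 1)}

try:
    pickle.dumps(per_layer_std)
    serializable = True
except (pickle.PicklingError, AttributeError, TypeError):
    serializable = False
# Lambdas (and closures generally) can't be pickled by reference, so a config
# tree holding PerLayer(lambda ...) markers can't be round-tripped with pickle.
```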

5. to_dict() is overly aggressive about skipping

torchtitan/config/configurable.py:85-106: The method skips all lists (isinstance(val, list)), any dict with callable values, and anything with a .resolve attribute. This is very broad — it would skip legitimate list-typed config fields that happen to exist. A more targeted approach (e.g., only skip layers or fields of known non-serializable types) would be safer.

6. Flux expand() is very different from decoder expand()

torchtitan/models/flux/model/model.py:140-227: Flux's expand() manually walks every sub-config and assigns param_init using dataclasses.replace. This is verbose but works. The contrast with the decoder pattern (which uses PerLayer + resolve_per_layer) is jarring. It would be worth a comment explaining why Flux uses a different pattern (DiT-style init is role-based rather than depth-based, so PerLayer isn't the right abstraction).

7. VLM experiment _set_param_init_recursive is a useful escape hatch

torchtitan/experiments/vlm/model/model.py:55-74: This function walks the tree and sets _param_init on any Module that doesn't have one. It's a pragmatic solution for experiment modules that don't use the config system. Worth noting in the docstring that this is an experiment-only pattern.

8. Weight tying + skip_param_init interaction

torchtitan/models/llama3/model.py:137-149: The init_states override re-ties weights before calling super().init_states(). The tok_embeddings config uses skip_param_init and the output config has a real initializer. This means output.weight gets initialized, and tok_embeddings.weight (which is the same tensor) is skipped. The ordering dependency here is subtle — if auto-recursion visits tok_embeddings after output, the skip is correct; if before, the skip runs first and then output's init overwrites. The current BFS order in init_states (queue.pop() pops from the end, so it's LIFO / DFS on children) means output comes after tok_embeddings in iteration, which is correct. But this ordering is fragile — it depends on module registration order in __init__. Consider adding an assertion or comment about this dependency.
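The ordering concern can be shown with plain Python; a list stands in for the tied tensor, the FQNs mirror the ones above, and the loop is a toy assumption about visit order:

```python
shared = [0.0, 0.0]  # one storage reachable under two FQNs (weight tying)
params = {"tok_embeddings.weight": shared, "output.weight": shared}

def skip_param_init(p):   # tok_embeddings: explicitly skipped
    pass

def real_init(p):         # output: the actual initializer
    p[:] = [1.0, 1.0]

inits = {"tok_embeddings.weight": skip_param_init, "output.weight": real_init}

# Iteration order = module registration order. A pure skip is a no-op either
# way, so here output's init always lands; the fragility appears once both
# FQNs carry real initializers, where whichever is visited last silently wins.
for fqn, p in params.items():
    inits[fqn](p)
```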

9. Missing expand() override for Llama3

torchtitan/models/llama3/model.py: Llama3Model.Config doesn't override expand() at all — it inherits from Decoder.Config directly. But Decoder.Config.expand() calls self._expand_layer() which is a no-op for Llama3 (identity return). This is clean, but the expand() base implementation resolves PerLayer markers for all models, which is good.

10. buffer_device assertion in Decoder._init_self_buffers

torchtitan/models/common/decoder.py:156-159:

assert buffer_device is None or buffer_device.type != "meta", (...)

This is an internal invariant check (programmer error), so assert is correct here. But buffer_device is None would pass the assertion — is that intentional? If buffer_device=None reaches this point (e.g., init_states called without buffer_device and freqs_cis is on meta device), the assertion passes but the RoPE rebuild below would create buffers on an unintended device.


Summary

This is a well-designed refactor that makes parameter initialization more declarative and configurable. The expand() pattern (Option 2) is the right architectural choice — it co-locates model-specific config expansion with the model definition while providing clean separation between config resolution and module construction.

Key areas for improvement:

  1. Make expand() mandatory and remove the fallback construction path
  2. Eliminate duplicated structural logic between _expand_layer() and TransformerBlock.__init__
  3. Address serialization limitations of PerLayer lambdas (longer term)
  4. Add safety comments around weight-tying + init ordering assumptions

@tianyu-l
Contributor

@claude Let me clarify.

For option 2, yes we could just remove the logic in TransformerBlock.__init__() and only use expand() in the config.

For option 3, I didn't mean putting the expansion logic in config_registry.py; instead I meant that we can put it in model/__init__.py (e.g. in the model_registry(flavor) function) where we construct model config.

  • This still couples model with expansion logic. Each model has its own expansion logic.
  • The logic is flexible and optional. One could have written expanded logic directly.
  • I would argue that this provides better encapsulation between model core logic, because what each model looks like is completely determined by the expanded per-layer config, and it doesn't matter how this config is provided, via expansion or not. Putting in each model subclass would prescribe how users can provide the final per-layer configs.
  • I disagree with "This breaks the principle that model-specific knowledge should be co-located with the model." Although model.py holds model-specific components, it doesn't know how to construct the model by combining those components with common ones; the actual model composition logic has largely lived in model/__init__.py. So I would say option 3 (moving expansion to model/__init__.py) gives better encapsulation: model.py then holds only model-specific components that aren't defined for other models (and thus not in common/), and model/__init__.py holds the composition logic.
  • Also if we go with option 3, we can remove classes such as Llama4TransformerBlock because the base class TransformerBlock can hold general logic.

@claude

claude bot commented Mar 26, 2026

Claude finished @tianyu-l's task in 1m 47s


  • Read model __init__.py files to understand current composition logic
  • Read model .py files to understand TransformerBlock subclasses
  • Read Decoder and TransformerBlock to understand current expand() design
  • Provide updated analysis of Option 2 vs clarified Option 3

Revised Analysis: Option 2 vs Clarified Option 3

Thank you for the clarification. The distinction is meaningful — placing expansion logic in model/__init__.py alongside the existing composition logic (where model_registry() and config dicts already live) is quite different from putting it in config_registry.py. Let me revise my analysis.

Understanding the Current Boundaries

Looking at the code as it stands in this PR, the responsibilities are already split:

| File | Responsibility |
| --- | --- |
| `model/__init__.py` | Config construction (component wiring, init patterns, hyperparameters), `model_registry()` |
| `model/model.py` | Model-specific components (subclassed `TransformerBlock`, subclassed `Decoder`), `_expand_layer()` (on `Config`) |
| `common/decoder.py` | Shared `Decoder.__init__`, shared `expand()` loop, shared `forward()` |

The key observation is that the model-specific TransformerBlock subclasses exist primarily to handle the structural branching (FFN vs MoE, iRoPE vs RoPE). Looking at the actual code:

  • Llama3TransformerBlock.__init__ — just builds attention + FFN + norms. No branching. Could be the base TransformerBlock.
  • Llama4TransformerBlock.__init__ (model.py:80-111) — computes moe_enabled and attn_use_rope, branches on both. But _expand_layer (model.py:144-161) already resolves these same decisions onto the config.
  • Qwen3TransformerBlock.__init__ (model.py:37-51) — branches on moe_enabled. But _expand_layer (model.py:87-93) already sets feed_forward=None or moe=None.
  • DeepSeekV3TransformerBlock.__init__ (model.py:179-193) — branches on layer_id >= n_dense_layers. But _expand_layer (model.py:280-286) already resolves this.

This duplication is exactly the problem. Both _expand_layer and __init__ compute the same structural decisions. The question is: which should be the canonical path?

Revised Assessment of Option 3

With the clarification that expansion lives in model/__init__.py, Option 3 becomes:

# In model/__init__.py (e.g., llama4/__init__.py)

def _expand_layers(config: Decoder.Config) -> list[TransformerBlock.Config]:
    """Llama4-specific layer expansion: iRoPE + MoE interleaving."""
    layers = []
    for layer_id in range(config.n_layers):
        layer_cfg = copy.deepcopy(config.layer)
        resolve_per_layer(layer_cfg, layer_id)
        # iRoPE
        if layer_cfg.every_n_layers_nope and layer_id % layer_cfg.every_n_layers_nope == 0:
            layer_cfg = replace(layer_cfg, attention=replace(layer_cfg.attention, use_rope=False))
        # MoE interleaving
        if (layer_id + 1) % layer_cfg.interleave_moe_layer_step == 0:
            layer_cfg = replace(layer_cfg, feed_forward=None)
        else:
            layer_cfg = replace(layer_cfg, moe=None)
        layers.append(layer_cfg)
    return layers

def model_registry(flavor: str) -> ModelSpec:
    config = llama4_configs[flavor]
    config.layers = _expand_layers(config)  # expansion happens here
    return ModelSpec(model=config, ...)

And the TransformerBlock.__init__ becomes:

# In common/decoder.py — ONE implementation, no subclasses needed
class TransformerBlock(Module):
    def __init__(self, config, *, layer_id, dim, n_layers):
        super().__init__()
        self.attention = config.attention.build(dim=dim)
        self.attention_norm = config.attention_norm.build(normalized_shape=dim)
        self.ffn_norm = config.ffn_norm.build(normalized_shape=dim)
        # Config already has exactly one of feed_forward/moe set
        if config.feed_forward is not None:
            self.feed_forward = config.feed_forward.build(dim=dim)
        if config.moe is not None:
            self.moe = config.moe.build(dim=dim)

Where I Think Option 3 (Clarified) Is Stronger

1. Eliminates the duplicated structural logic.
This is the strongest argument. Right now, Llama4TransformerBlock.__init__ and Llama4Model.Config._expand_layer compute the same decisions. With Option 3, the expansion logic in __init__.py is the single source of truth, and the constructor just reads what's in the config. No duplication.

2. Eliminates model-specific TransformerBlock subclasses.
Looking at what each subclass actually contributes:

  • Llama3TransformerBlock — adds nothing over a base class that handles FFN-only.
  • Qwen3TransformerBlock — adds moe_enabled flag and branching. With expanded configs, the base class handles this.
  • DeepSeekV3TransformerBlock — adds n_dense_layers and branching. With expanded configs, the base class handles this.
  • Llama4TransformerBlock — adds iRoPE + MoE interleaving. With expanded configs, the base class handles this.

All four model-specific `forward()` implementations are structurally identical: `x + attn(norm(x)); x + ffn_or_moe(norm(x))`. A single base `TransformerBlock` with `if self.moe is not None` / `else self.feed_forward` covers all of them. This aligns with the CLAUDE.md guideline: "Don't create per-model wrappers for the same functionality."
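The shared residual pattern can be sketched with plain-Python stand-ins (toy lambdas replace the real attention/norm/FFN modules; only the `feed_forward`-vs-`moe` dispatch is the point, not the actual codebase API):

```python
class TransformerBlockSketch:
    """Toy sketch of the unified pattern: x + attn(norm(x)); x + mixer(norm(x))."""

    def __init__(self, use_moe: bool):
        self.attention = lambda x: x * 2          # stand-in for real attention
        self.norm = lambda x: x                   # stand-in for RMSNorm
        # Exactly one of the two mixers is set, mirroring the expanded config.
        self.moe = (lambda x: x + 100) if use_moe else None
        self.feed_forward = None if use_moe else (lambda x: x + 1)

    def forward(self, x: float) -> float:
        x = x + self.attention(self.norm(x))      # attention residual
        mixer = self.moe if self.moe is not None else self.feed_forward
        return x + mixer(self.norm(x))            # FFN-or-MoE residual
```

Because the dispatch is a single `is not None` check, both the FFN-only and MoE variants flow through the same `forward`.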

3. Better separation of "what" vs "how".
Your point about encapsulation is well-taken. model.py defines components (what a layer looks like when built); __init__.py defines composition (how to assemble N layers with the right structural variations). The TransformerBlock doesn't need to know why it has an MoE — it just builds whatever the config says.

4. model/__init__.py already holds composition logic.
This is factually correct. Looking at llama4/__init__.py, it already decides interleave steps, iRoPE periods, MoE hidden dims, etc. — these are composition decisions. The expansion function is a natural extension of this existing role.

Where Option 2 Still Has Merit

1. Discoverability. With Option 2, the expansion is co-located with the config class definition (Llama4Model.Config._expand_layer). A reader looking at the Config class immediately sees how layers differ. With Option 3, you have to look in __init__.py for a separate function.

2. The expand() loop is shared. In Option 2, Decoder.Config.expand() provides the deep-copy + resolve_per_layer loop, and models only override _expand_layer(). In Option 3, each model's __init__.py must implement (or call a helper for) the full expansion loop, including deep-copy and resolve_per_layer. There's a risk of inconsistency if models implement this differently. This can be mitigated with a shared utility function, though.

3. Flux already uses expand() on Config. FluxModel.Config.expand() (flux/model/model.py:140-227) uses the Config method approach heavily. Moving decoder expansion to __init__.py would create two patterns: Flux does it on Config, decoders do it in __init__.py. Though one could argue Flux should be updated too if Option 3 is chosen.

Updated Recommendation

I now agree that Option 3 (with expansion in model/__init__.py) is the better design, primarily because:

  1. It eliminates the duplicated structural logic between _expand_layer() and TransformerBlock.__init__() — this is the most compelling practical benefit.
  2. It enables removing all four model-specific TransformerBlock subclasses (Llama3, Llama4, Qwen3, DeepSeek), which are currently near-identical. This directly satisfies the CLAUDE.md principle "unify across models."
  3. It follows the existing responsibility split where __init__.py already handles composition.
  4. The TransformerBlock.forward() can become a single, shared implementation.

The refinements I'd suggest for Option 3:

  • Provide a shared utility function (in common/) for the deep-copy + resolve_per_layer loop, so models don't each reimplement the mechanics.
  • For Llama3, which has no structural variation, the expansion could be as simple as calling the shared utility with no per-layer modifications.
  • The TransformerBlock.Config would still carry model-specific fields (every_n_layers_nope, n_dense_layers, etc.) as "hints" for the expansion function, or those fields could move to the model-level Config (e.g., Llama4Model.Config.every_n_layers_nope) since they're composition-level concerns, not layer-level concerns.
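A minimal sketch of such a shared utility, with a hypothetical `LayerCfg` dataclass standing in for `TransformerBlock.Config` (`resolve_per_layer` and all real fields are omitted; the interleave period below is illustrative):

```python
import copy
from dataclasses import dataclass, replace
from typing import Callable, Optional

@dataclass
class LayerCfg:
    """Minimal stand-in for TransformerBlock.Config (hypothetical fields)."""
    feed_forward: Optional[str] = "ffn"
    moe: Optional[str] = None

def expand_layers(base: LayerCfg, n_layers: int,
                  per_layer: Callable[[LayerCfg, int], LayerCfg] = lambda c, i: c):
    """Shared deep-copy loop; each model passes only its structural tweak."""
    return [per_layer(copy.deepcopy(base), i) for i in range(n_layers)]

# Llama4-style tweak: every 2nd layer becomes MoE (illustrative period)
def llama4_tweak(cfg: LayerCfg, layer_id: int) -> LayerCfg:
    if (layer_id + 1) % 2 == 0:
        return replace(cfg, feed_forward=None, moe="moe")
    return replace(cfg, moe=None)

layers = expand_layers(LayerCfg(), 4, llama4_tweak)
```

Llama3, with no structural variation, would call `expand_layers(base, n_layers)` with the default identity tweak.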

One caveat: for the forward() method, the Llama4 model currently distinguishes self.moe_enabled for the FFN/MoE branch. With expanded configs, this could be simplified to checking attribute existence (hasattr(self, 'moe')) or checking config.moe is not None at build time and setting a flag. Either way, the forward method unifies cleanly.


[ghstack-poisoned]
fegin added a commit that referenced this pull request Mar 26, 2026
**NOT READY TO REVIEW**

**Motivation and design decisions**

Parameter initialization should be configurable without modifying module code. Users reuse provided helpers (e.g., `init_trunc_normal`, `init_zeros`) to customize how parameters are initialized. Parameters are identified by their FQNs, and `init_by_regex` lets users map regex patterns to initializers — so a single declarative dict at the model config level controls all parameter init.

In theory, we could just apply `param_init` at the root module and be done — no recursion needed. So why does `init_states` recurse? The main motivation is that buffer initialization is usually inherently coupled with module internals and not user-configurable. Since buffers require per-module logic via `init_self_buffers`, we provide a recursive `init_states` that handles both: it walks the module tree, initializing parameters via the configurable `param_init` and buffers via module-specific overrides, in a single pass.
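The two-hook recursion can be sketched without any framework (toy `ModuleSketch` class; the `init_self_parameters`/`init_self_buffers`/`init_states` names follow the description above, everything else is illustrative):

```python
class ModuleSketch:
    """Toy module: params are user-configurable, buffers are module-private."""

    def __init__(self, name, children=()):
        self.name, self.children = name, list(children)
        self.params, self.buffers = {}, {}

    def init_self_parameters(self, param_init):
        # Parameters go through the configurable initializer, keyed by FQN.
        self.params[f"{self.name}.weight"] = param_init(f"{self.name}.weight")

    def init_self_buffers(self):
        # Buffers use module-specific logic, not exposed through the config.
        self.buffers[f"{self.name}.freqs"] = "computed-internally"

    def init_states(self, param_init):
        # Single pass over the tree: both hooks fire on every module.
        self.init_self_parameters(param_init)
        self.init_self_buffers()
        for child in self.children:
            child.init_states(param_init)

root = ModuleSketch("model", [ModuleSketch("model.layers")])
root.init_states(lambda fqn: f"init({fqn})")
```

One `init_states` call at the root thus covers configurable parameters and module-internal buffers together.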

**Summary**
- Remove `init_weights`.
- `init_states` auto-recurses the module tree, then calls `init_self_parameters` and `init_self_buffers` on each module.
- `param_init` (a NamedInitializer callable on Module.Config) uses regex-matched FQN patterns to map parameters to initializers. Child modules without their own `param_init` delegate up the parent chain, so init logic is defined once at the model config level.
- `make_decoder_param_init` provides shared patterns for decoder-based models, reused by Llama3, Llama4, Qwen3, and DeepSeek V3. Model-specific extensions (Flux DiT-style, GPT-OSS MoE biases) are composed via dict merge.
- Removes `init_mean`/`init_std` fields from `Linear.Config` and `Embedding.Config` — init parameters are now expressed entirely through `param_init` patterns.
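A minimal sketch of the regex-to-initializer mapping (the `init_by_regex` name comes from the description above; first-match-wins ordering and the specific patterns are assumptions, not the PR's actual patterns):

```python
import re

def init_by_regex(patterns):
    """Map regex-on-FQN -> initializer; first matching pattern wins (assumed)."""
    compiled = [(re.compile(p), fn) for p, fn in patterns.items()]

    def named_initializer(fqn):
        for pat, fn in compiled:
            if pat.search(fqn):
                return fn(fqn)
        raise KeyError(f"no initializer matched {fqn!r}")

    return named_initializer

# Hypothetical shared decoder patterns, extended via dict merge as described.
base = {r"\.attention\..*weight$": lambda f: "trunc_normal",
        r"norm\.weight$": lambda f: "ones"}
extra = {r"\.moe\..*bias$": lambda f: "zeros"}      # model-specific extension
param_init = init_by_regex({**base, **extra})
```

The `{**base, **extra}` merge is what lets a model extend the shared patterns without touching them.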

**Verification**

With `loss_compare.py`:

  ┌─────────────┬───────┬───────────────────┐
  │    Model    │ Steps │      Result       │
  ├─────────────┼───────┼───────────────────┤
  │ Llama3      │ 10    │ Bitwise identical │
  ├─────────────┼───────┼───────────────────┤
  │ Qwen3       │ 10    │ Bitwise identical │
  ├─────────────┼───────┼───────────────────┤
  │ DeepSeek V3 │ 10    │ Bitwise identical │
  ├─────────────┼───────┼───────────────────┤
  │ GPT-OSS     │ 10    │ Bitwise identical │
  └─────────────┴───────┴───────────────────┘

ghstack-source-id: c721621
Pull-Request: #2633
[ghstack-poisoned]
fegin added a commit that referenced this pull request Mar 26, 2026
ghstack-source-id: cff1807
Pull-Request: #2633