[Module] Refactor init_weights to config-based param_init system #2633
fegin wants to merge 17 commits into gh/fegin/103/base
Conversation
**NOT READY TO REVIEW**

**Motivation and design decisions**

Parameter initialization should be configurable without modifying module code. Users reuse provided helpers (e.g., `init_trunc_normal`, `init_zeros`) to customize how parameters are initialized. Parameters are identified by their FQNs, and `init_by_regex` lets users map regex patterns to initializers, so a single declarative dict at the model config level controls all parameter init.

In theory, we could just apply param_init at the root module and be done, with no recursion needed. So why does `init_states` recurse? The main motivation is that buffer initialization is usually inherently coupled with module internals and not user-configurable. Since buffers require per-module logic via `init_self_buffers`, we provide a recursive `init_states` that handles both: it walks the module tree, initializing parameters via the configurable param_init and buffers via module-specific overrides, in a single pass.

**Summary**

- Remove `init_weights`.
- `init_states` auto-recurses the module tree, then calls `init_self_parameters` and `init_self_buffers` on each module.
- `param_init` (a NamedInitializer callable on Module.Config) uses regex-matched FQN patterns to map parameters to initializers. Child modules without their own `param_init` delegate up the parent chain, so init logic is defined once at the model config level.
- `make_decoder_param_init` provides shared patterns for decoder-based models, reused by Llama3, Llama4, Qwen3, and DeepSeek V3. Model-specific extensions (Flux DiT-style, GPT-OSS MoE biases) are composed via dict merge.
- Removes init_mean/init_std fields from Linear.Config and Embedding.Config; init parameters are now expressed entirely through param_init patterns.

ghstack-source-id: 024899c
Pull-Request: #2633
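The FQN-regex dispatch described above can be sketched roughly as follows. This is a minimal, torch-free illustration; the actual signatures of `init_by_regex` and the helpers in this PR may differ, and the lambdas stand in for real initializers such as `init_trunc_normal` and `init_zeros`.

```python
import re

def init_by_regex(patterns):
    """Return a NamedInitializer-style callable: (fqn, param) -> result.

    `patterns` maps regex strings to initializer callables; the first pattern
    that fully matches the parameter's FQN wins. Sketch only.
    """
    compiled = [(re.compile(p), fn) for p, fn in patterns.items()]

    def named_init(fqn, param):
        for pat, fn in compiled:
            if pat.fullmatch(fqn):
                return fn(param)
        raise KeyError(f"no initializer matched FQN {fqn!r}")

    return named_init

# A single declarative dict at the model-config level controls all init.
param_init = init_by_regex({
    r".*\.bias": lambda p: "zeros",          # stand-in for init_zeros
    r".*\.weight": lambda p: "trunc_normal", # stand-in for init_trunc_normal
})
```

In the real system each initializer would mutate the `nn.Parameter` in place (e.g. via `torch.nn.init.trunc_normal_`); the dispatch mechanism is the part sketched here.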
**Verification** with loss_compare.py

| Model | Steps | Result |
|---|---|---|
| Llama3 | 10 | Bitwise identical |
| Qwen3 | 10 | Bitwise identical |
| DeepSeek V3 | 10 | Bitwise identical |
| GPT-OSS | 10 | Bitwise identical |

ghstack-source-id: e6704e8
Pull-Request: #2633
```python
except AttributeError:
    # field(init=False) not yet set, ignore this field.
    continue
if callable(val) and not dataclasses.is_dataclass(val):
```
oh, can a dataclass ever be callable?
btw, https://fburl.com/code/8t7xhgna is making Config.param_init a config, not a function, so that it should still be serializable. I think we should do this too, to make it reproducible (given a code version).
Right now parallelize_fn and other functions in ModelSpec are not serializable, which should be fixed later.
The config is a dictionary that maps from regex to function. So from a serializability perspective, the requirement is the same?
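One way the serializability point could be addressed (a hedged sketch; `INIT_REGISTRY` and the scheme of storing helper names are hypothetical, not part of this PR): if the dict values are named, registered helpers rather than anonymous functions, the pattern-to-initializer mapping can round-trip through JSON by storing names instead of callables.

```python
import json

# Hypothetical registry of named initializer helpers. The real helpers in the
# PR (init_trunc_normal, init_zeros, ...) could be registered the same way;
# the lambdas here are placeholders.
INIT_REGISTRY = {
    "init_zeros": lambda param: param,
    "init_trunc_normal": lambda param: param,
}

# Serializable form: regex pattern -> helper *name* (a plain string).
config = {r".*\.bias": "init_zeros", r".*\.weight": "init_trunc_normal"}

serialized = json.dumps(config)                      # round-trips cleanly
restored = {
    pat: INIT_REGISTRY[name]
    for pat, name in json.loads(serialized).items()  # names -> callables
}
```

An unnamed lambda in the config would break this: there is no stable name to store, which is exactly why serialization is unsound once anonymous functions appear.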
```python
buffer_device = buffer_device or self.freqs_cis.device
if self.rope is not None:
    self.rope.init_weights(buffer_device=buffer_device)
# RoPE's _init_self_buffers was already called by auto-recursion
```
Not sure about this, but isn't it possible that it is set on the meta device while buffer_device here is cuda? Shall we add a device assertion?
Not quite sure if I understand the question. If the meta device is used, then everything is on the meta device. buffer_device is only used when the trainer parallelizes the model and tries to re-initialize the states again. The comment here just wants to explain the order, not the device.
An assertion is good, but I want to understand if I'm missing something.
torchtitan/protocols/module.py
Outdated
```python
# Type alias for simple parameter initializers: (param) -> Any
# Uses Any return type because nn.init.* functions return Tensor,
# but the return value is always ignored by the dispatch layer.
ParamInitializer = Callable[[nn.Parameter], Any]
```
If it's only used in common/param_init.py, can we put it there for now?
That will be inconsistent: we would have ParamInitializer in common/param_init.py but NamedParamInitializer here. Ideally, we should put both in common/param_init.py, but I cannot do that due to a circular import.
torchtitan/protocols/module.py
Outdated
```python
# and the prefix resets (FQNs are relative to the owner).
if self.param_init is not None:
    param_init = self.param_init
    param_prefix = ""
```
Not sure about this. E.g., let's say I created a new module to plug into each layer, for which I want to use depth init.
- If I provide a `config.param_init` to this module, I would lose the layer id because of this line.
- Of course, I can change the overall decoder init. But if that's the case, what's the point of supporting `param_init` on each module? We could always call `get_default_decoder_init` and modify the dictionary.
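The concern above can be made concrete with a small sketch. `depth_init` is a hypothetical depth-scaled initializer (the 1/sqrt(2 * layers) scaling is a common convention, not necessarily this repo's): it needs the layer index out of the FQN, so it works when matched from the root but breaks once a per-module `param_init` sees FQNs relative to the owning module.

```python
import math
import re

def depth_init(fqn: str) -> float:
    """Return a depth-scaled init std; requires the layer id in the FQN.
    Hypothetical example, not code from this PR."""
    m = re.search(r"layers\.(\d+)\.", fqn)
    if m is None:
        raise ValueError(f"no layer id in FQN {fqn!r}")
    layer_id = int(m.group(1))
    return 0.02 / math.sqrt(2 * (layer_id + 1))

# Matched from the root: the layer id is visible in the FQN.
std = depth_init("layers.3.my_module.weight")

# Matched from the module's own param_init after the prefix reset, the FQN is
# just "weight", so depth_init("weight") would raise ValueError.
```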
torchtitan/protocols/module.py
Outdated
```python
if self.param_init is not None:
    instance.param_init = self.param_init
```
we could do this in the constructor; curious why we need to do it here
Because we don't have a common `Module.__init__`, users who define a new Module need to remember to do the assignment. An alternative is to define a common `Module.__init__`. I thought about the latter, but that requires some rework of the module inheritance hierarchy, and users would then have to call `super().__init__(cfg, *args, **kwargs)`. The latter constraint is okay, I guess, but the former needs some time to verify. If we do want to change this, we can do another round of changes to the Module hierarchy later.
torchtitan/protocols/module.py
Outdated
```python
def init_states(
    self,
    *,
    param_init: NamedParamInitializer | None = None,
```
This seems different from https://fburl.com/code/ia9w7w5i:
- over there, if the param_init is missing, the lookup relies on the parent;
- here, the parent's param_init is pushed to all children. I think this leaks information from global to local modules, so it should be avoided in general.
I'm not sure though. Parent modules already hold the param_init; is that considered leaking from global to local? I didn't do the parent walk because, unlike the internal code (Node, which has a parent pointer), our Module doesn't have this implementation, so I deliberately didn't include it. Adding a parent reference would also create a circular dependency and make GC slower, though the latter should not be a real concern since we never GC modules. Just pointing out this minor issue.
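The two resolution styles under discussion can be contrasted with a toy sketch (class and attribute names are illustrative, not torchtitan's): ancestor lookup walks up the parent chain until some ancestor defines `param_init`, instead of pushing the root's `param_init` down into every child. The sketch also shows the parent pointer the thread worries about, which creates the circular reference.

```python
class Module:
    """Toy module tree; `parent` is the back-reference discussed above."""

    def __init__(self, param_init=None):
        self.param_init = param_init
        self.parent = None        # circular reference once a child is added
        self.children = []

    def add(self, child):
        child.parent = self
        self.children.append(child)
        return child

    def resolve_param_init(self):
        # Ancestor lookup: a module without its own param_init delegates
        # up the parent chain rather than receiving a pushed-down copy.
        node = self
        while node is not None:
            if node.param_init is not None:
                return node.param_init
            node = node.parent
        return None

root = Module(param_init="root-init")
layer = root.add(Module())                 # no param_init: delegates to root
leaf = layer.add(Module(param_init="leaf-init"))  # overrides locally
```

With push-down, `layer` would instead hold a copy of the root's initializer, which is the "leaking global into local" concern raised above.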
I changed the parameter init to ancestor-lookup style. The comment above reflects an issue brought by the circular reference, so the original concern about parent references is legit and needs some careful handling. But the code looks cleaner for each inherited module.
The typing of param_init is still unchanged. As long as there is an unnamed function, which the internal code also uses, serialization is not sound.
**NOT READY TO REVIEW** **NOT READY TO REVIEW** **Motiviation and Design decision** Parameter initialization should be configurable without modifying module code. Users reuse provided helpers (e.g., `init_trunc_normal`, `init_zeros`) to customize how parameters are initialized. Parameters are identified by their FQNs, and `init_by_regex` lets users map regex patterns to initializers — so a single declarative dict at the model config level controls all parameter init. In theory, we could just apply param_init at the root module and be done — no recursion needed. So why does `init_states` recurse? The main motivation is that buffer initialization is usually inherently coupled with module internals and not user-configurable. Since buffers require per-module logic via `init_self_buffers`, we provide a recursive init_states that handles both: it walks the module tree, initializing parameters via the configurable param_init and buffers via module-specific overrides, in a single pass. **Summary** - Remove `init_weights`. - `init_states` auto-recurses the module tree, then calls `init_self_parameters` and `init_self_buffers` on each module. - `param_init` (a NamedInitializer callable on Module.Config) uses regex-matched FQN patterns to map parameters to initializers. Child modules without their own `param_init` delegate up the parent chain, so init logic is defined once at the model config level. - `make_decoder_param_init provides` shared patterns for decoder-based model, reused by Llama3, Llama4, Qwen3, and DeepSeek V3. Model-specific extensions (Flux DiT-style, GPT-OSS MoE biases) are composed via dict merge. - Removes init_mean/init_std fields from Linear.Config and Embedding.Config — init parameters are now expressed entirely through param_init patterns. 
**Verification** with loss_compare.py ┌─────────────┬───────┬───────────────────┐ │ Model │ Steps │ Result │ ├─────────────┼───────┼───────────────────┤ │ Llama3 │ 10 │ Bitwise identical │ ├─────────────┼───────┼───────────────────┤ │ Qwen3 │ 10 │ Bitwise identical │ ├─────────────┼───────┼───────────────────┤ │ DeepSeek V3 │ 10 │ Bitwise identical │ ├─────────────┼───────┼───────────────────┤ │ GPT-OSS │ 10 │ Bitwise identical │ └─────────────┴───────┴───────────────────┘ ghstack-source-id: 4427985 Pull-Request: #2633
**NOT READY TO REVIEW** **NOT READY TO REVIEW** **Motiviation and Design decision** Parameter initialization should be configurable without modifying module code. Users reuse provided helpers (e.g., `init_trunc_normal`, `init_zeros`) to customize how parameters are initialized. Parameters are identified by their FQNs, and `init_by_regex` lets users map regex patterns to initializers — so a single declarative dict at the model config level controls all parameter init. In theory, we could just apply param_init at the root module and be done — no recursion needed. So why does `init_states` recurse? The main motivation is that buffer initialization is usually inherently coupled with module internals and not user-configurable. Since buffers require per-module logic via `init_self_buffers`, we provide a recursive init_states that handles both: it walks the module tree, initializing parameters via the configurable param_init and buffers via module-specific overrides, in a single pass. **Summary** - Remove `init_weights`. - `init_states` auto-recurses the module tree, then calls `init_self_parameters` and `init_self_buffers` on each module. - `param_init` (a NamedInitializer callable on Module.Config) uses regex-matched FQN patterns to map parameters to initializers. Child modules without their own `param_init` delegate up the parent chain, so init logic is defined once at the model config level. - `make_decoder_param_init provides` shared patterns for decoder-based model, reused by Llama3, Llama4, Qwen3, and DeepSeek V3. Model-specific extensions (Flux DiT-style, GPT-OSS MoE biases) are composed via dict merge. - Removes init_mean/init_std fields from Linear.Config and Embedding.Config — init parameters are now expressed entirely through param_init patterns. 
**Verification** with loss_compare.py ┌─────────────┬───────┬───────────────────┐ │ Model │ Steps │ Result │ ├─────────────┼───────┼───────────────────┤ │ Llama3 │ 10 │ Bitwise identical │ ├─────────────┼───────┼───────────────────┤ │ Qwen3 │ 10 │ Bitwise identical │ ├─────────────┼───────┼───────────────────┤ │ DeepSeek V3 │ 10 │ Bitwise identical │ ├─────────────┼───────┼───────────────────┤ │ GPT-OSS │ 10 │ Bitwise identical │ └─────────────┴───────┴───────────────────┘ ghstack-source-id: ccfa49b Pull-Request: #2633
**NOT READY TO REVIEW** **NOT READY TO REVIEW** **Motiviation and Design decision** Parameter initialization should be configurable without modifying module code. Users reuse provided helpers (e.g., `init_trunc_normal`, `init_zeros`) to customize how parameters are initialized. Parameters are identified by their FQNs, and `init_by_regex` lets users map regex patterns to initializers — so a single declarative dict at the model config level controls all parameter init. In theory, we could just apply param_init at the root module and be done — no recursion needed. So why does `init_states` recurse? The main motivation is that buffer initialization is usually inherently coupled with module internals and not user-configurable. Since buffers require per-module logic via `init_self_buffers`, we provide a recursive init_states that handles both: it walks the module tree, initializing parameters via the configurable param_init and buffers via module-specific overrides, in a single pass. **Summary** - Remove `init_weights`. - `init_states` auto-recurses the module tree, then calls `init_self_parameters` and `init_self_buffers` on each module. - `param_init` (a NamedInitializer callable on Module.Config) uses regex-matched FQN patterns to map parameters to initializers. Child modules without their own `param_init` delegate up the parent chain, so init logic is defined once at the model config level. - `make_decoder_param_init provides` shared patterns for decoder-based model, reused by Llama3, Llama4, Qwen3, and DeepSeek V3. Model-specific extensions (Flux DiT-style, GPT-OSS MoE biases) are composed via dict merge. - Removes init_mean/init_std fields from Linear.Config and Embedding.Config — init parameters are now expressed entirely through param_init patterns. 
**Verification** with loss_compare.py ┌─────────────┬───────┬───────────────────┐ │ Model │ Steps │ Result │ ├─────────────┼───────┼───────────────────┤ │ Llama3 │ 10 │ Bitwise identical │ ├─────────────┼───────┼───────────────────┤ │ Qwen3 │ 10 │ Bitwise identical │ ├─────────────┼───────┼───────────────────┤ │ DeepSeek V3 │ 10 │ Bitwise identical │ ├─────────────┼───────┼───────────────────┤ │ GPT-OSS │ 10 │ Bitwise identical │ └─────────────┴───────┴───────────────────┘ ghstack-source-id: 8ec94ee Pull-Request: #2633
@claude review this pull request. In particular, please comment on where it is best to expand the general transformer layer config into per-layer configs. I think in general there are three options:
Claude finished @tianyu-l's task in 2m 50s

**PR Review: Refactor init_weights to config-based param_init system**
**Overall Assessment**

This is a substantial and well-motivated refactor that centralizes parameter initialization into a declarative, config-based system. The design cleanly separates "what to initialize" (config-level `param_init`) from module code.

The numeric verification (bitwise-identical loss for Llama3 and Qwen3 over 100 steps) is reassuring, though only 2 of the 6+ models were verified. The Flux and GPT-OSS models have quite different init patterns, so numeric parity for those would strengthen confidence.

**Analysis: Where to Expand Per-Layer Config**

This is the core design question. Here is my analysis of the three options.

**Option 1:** No per-layer config;

**Option 2 (current PR):** The model config class has an `expand()` method.

Advantages:

Disadvantages:

**Option 3:** The model constructor only accepts already-expanded per-layer configs; expansion logic lives in

Advantages:

Disadvantages:

**My recommendation:** Option 2 is the best fit for this codebase, with a minor refinement. The

The refinement I'd suggest: consider making

**Specific Code Review Findings**

1. Dual-path construction in
@claude Let me clarify. For option 2, yes we could just remove the logic in

For option 3, I didn't mean putting the expansion logic in
Claude finished @tianyu-l's task in 1m 47s
**Revised Analysis: Option 2 vs Clarified Option 3**

Thank you for the clarification. The distinction is meaningful — placing expansion logic in

**Understanding the Current Boundaries**

Looking at the code as it stands in this PR, the responsibilities are already split:
The key observation is that the model-specific
This duplication is exactly the problem. Both

**Revised Assessment of Option 3**

With the clarification that expansion lives in `model/__init__.py`, the sketch becomes:

```python
# In model/__init__.py (e.g., llama4/__init__.py)
def _expand_layers(config: Decoder.Config) -> list[TransformerBlock.Config]:
    """Llama4-specific layer expansion: iRoPE + MoE interleaving."""
    layers = []
    for layer_id in range(config.n_layers):
        layer_cfg = copy.deepcopy(config.layer)
        resolve_per_layer(layer_cfg, layer_id)
        # iRoPE
        if layer_cfg.every_n_layers_nope and layer_id % layer_cfg.every_n_layers_nope == 0:
            layer_cfg = replace(layer_cfg, attention=replace(layer_cfg.attention, use_rope=False))
        # MoE interleaving
        if (layer_id + 1) % layer_cfg.interleave_moe_layer_step == 0:
            layer_cfg = replace(layer_cfg, feed_forward=None)
        else:
            layer_cfg = replace(layer_cfg, moe=None)
        layers.append(layer_cfg)
    return layers


def model_registry(flavor: str) -> ModelSpec:
    config = llama4_configs[flavor]
    config.layers = _expand_layers(config)  # expansion happens here
    return ModelSpec(model=config)  # other ModelSpec fields elided
```

And the `TransformerBlock` in `common/decoder.py` needs only one implementation:

```python
# In common/decoder.py — ONE implementation, no subclasses needed
class TransformerBlock(Module):
    def __init__(self, config, *, layer_id, dim, n_layers):
        super().__init__()
        self.attention = config.attention.build(dim=dim)
        self.attention_norm = config.attention_norm.build(normalized_shape=dim)
        self.ffn_norm = config.ffn_norm.build(normalized_shape=dim)
        # Config already has exactly one of feed_forward/moe set
        if config.feed_forward is not None:
            self.feed_forward = config.feed_forward.build(dim=dim)
        if config.moe is not None:
            self.moe = config.moe.build(dim=dim)
```

**Where I Think Option 3 (Clarified) Is Stronger**

1. Eliminates the duplicated structural logic.
2. Eliminates model-specific `TransformerBlock` subclasses. All four model-specific
3. Better separation of "what" vs "how".
4.

**Where Option 2 Still Has Merit**

1. Discoverability. With Option 2, the expansion is co-located with the config class definition (
2. The
3. Flux already uses `expand()` on Config.

**Updated Recommendation**

I now agree that Option 3 (with expansion in
The refinements I'd suggest for Option 3:
One caveat: for the
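Beyond the layer-expansion question, the other half of the PR is the recursive `init_states` walk. The class and method names below come from the PR summary, but the bodies and the `add_child` helper are hypothetical — a framework-free sketch of the single-pass traversal, not the actual implementation:

```python
class Module:
    """Minimal stand-in for the framework's Module, sketching the
    recursive init_states described in the PR summary."""

    def __init__(self):
        self._children: dict[str, "Module"] = {}

    def add_child(self, name: str, child: "Module") -> "Module":
        # Hypothetical helper for the sketch; returns the child so
        # trees can be built by chaining.
        self._children[name] = child
        return child

    def init_self_parameters(self):
        """Parameters: normally resolved via the config-level
        param_init (regex-matched FQNs); a no-op placeholder here."""

    def init_self_buffers(self):
        """Buffers: coupled to module internals, so each module
        overrides this with its own logic."""

    def init_states(self):
        # One pass over the module tree: parameters via the
        # configurable param_init, buffers via per-module overrides.
        self.init_self_parameters()
        self.init_self_buffers()
        for child in self._children.values():
            child.init_states()
```

In this shape the `param_init` configurability lives entirely inside `init_self_parameters`; the traversal itself stays model-agnostic, which is why one recursive entry point can serve every model.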
Stack from ghstack (oldest at bottom):
**Motivation and Design Decisions**
Parameter initialization should be configurable without modifying module code. Users reuse provided helpers (e.g., `init_trunc_normal`, `init_zeros`) to customize how parameters are initialized. Parameters are identified by their FQNs, and `init_by_regex` lets users map regex patterns to initializers — so a single declarative dict at the model config level controls all parameter init.

In theory, we could just apply `param_init` at the root module and be done — no recursion needed. So why does `init_states` recurse? The main motivation is that buffer initialization is usually inherently coupled with module internals and not user-configurable. Since buffers require per-module logic via `init_self_buffers`, we provide a recursive `init_states` that handles both: it walks the module tree, initializing parameters via the configurable `param_init` and buffers via module-specific overrides, in a single pass.

**Summary**
- Remove `init_weights`.
- `init_states` auto-recurses the module tree, then calls `init_self_parameters` and `init_self_buffers` on each module.
- `param_init` (a NamedInitializer callable on Module.Config) uses regex-matched FQN patterns to map parameters to initializers. Child modules without their own `param_init` delegate up the parent chain, so init logic is defined once at the model config level.
- `make_decoder_param_init` provides shared patterns for decoder-based models, reused by Llama3, Llama4, Qwen3, and DeepSeek V3. Model-specific extensions (Flux DiT-style, GPT-OSS MoE biases) are composed via dict merge.
- Removes `init_mean`/`init_std` fields from `Linear.Config` and `Embedding.Config` — init parameters are now expressed entirely through `param_init` patterns.

**Verification**
With loss_compare.py, after changing the code to make the initialization order exactly the same as on the main branch, we get Llama3 and Qwen3 loss parity.
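The order-sensitivity here is generic to any sequential RNG: each initializer consumes shared generator state, so reordering consumers changes every subsequent draw even under the same seed. A toy illustration with Python's stdlib `random` (not the actual training code):

```python
import random

def draw(n):
    # Stand-in for one parameter's random init: n draws from the
    # process-wide generator.
    return [random.random() for _ in range(n)]

random.seed(0)
a_first, b_second = draw(2), draw(3)   # init A, then B

random.seed(0)
b_first, a_second = draw(3), draw(2)   # init B, then A

# Same seed, same total number of draws, but each "parameter" gets
# different values because the draws interleave differently.
assert a_first != a_second
assert b_second != b_first
```

The same effect applies to `init_trunc_normal` and friends: bitwise parity with the main branch requires visiting parameters in exactly the same order.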
However, the current code doesn't pin that order, because doing so would make the code ugly. A different loss is expected whenever the initialization order changes.
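To make the regex-keyed scheme concrete, here is a dependency-free sketch of first-match dispatch over FQN patterns plus dict-merge composition. `init_by_regex` and `make_decoder_param_init` are names from the PR, but the signatures, the list-based stand-ins for parameters, and the toy initializers are assumptions for illustration:

```python
import re

def init_zeros(p):
    p[:] = [0.0] * len(p)

def init_ones(p):
    p[:] = [1.0] * len(p)

def init_by_regex(params, patterns):
    """Apply the first initializer whose pattern fully matches each
    parameter's FQN; raising on unmatched params catches config gaps."""
    for fqn, p in params.items():
        for pat, init in patterns.items():
            if re.fullmatch(pat, fqn):
                init(p)
                break
        else:
            raise KeyError(f"no initializer matched {fqn!r}")

def make_decoder_param_init():
    # Hypothetical shared decoder patterns, keyed by FQN regex.
    return {
        r".*norm\.weight": init_ones,
        r".*\.weight": init_zeros,  # placeholder for trunc-normal
    }

# Model-specific extension composed via dict merge; the extension is
# listed first so broader base patterns cannot shadow it under
# first-match dispatch.
gpt_oss_param_init = {
    r"layers\..*\.moe\..*\.bias": init_zeros,
    **make_decoder_param_init(),
}
```

Applying `init_by_regex(params, gpt_oss_param_init)` to a dict of FQN-keyed parameters then routes norm weights, projection weights, and MoE biases to their respective initializers from one declarative table, which is the composition the summary describes.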