[Module] Convert remaining nn.Module classes to Module protocol (#2565)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0)
(oldest at bottom):
* __->__ #2565
**Why**
It is hard to make the remaining changes, 1) state initialization and 2)
the sharding spec/local map spec change, while some modules are still
nn.Module: the logic would need several fallback paths.
This PR introduces a minimal change to convert all the remaining
nn.Module classes to Module:
1. Don't define a Config for most of the Modules.
2. Make init_weights() a materialized method with empty logic.
While 2 makes it impossible to detect that a Module accidentally forgot
to define init_weights(), that is fine because init_weights() will
change in the next PR.
**Summary**
Convert 35 classes from plain nn.Module to the torchtitan Module
protocol across core models (19) and experiments (16).
Key design decisions:
1. Module without Config for non-configurable classes: If all
   constructor args come from the parent module (its Config or runtime),
   the class inherits Module without defining Config -- just a direct
   constructor. Config + build() is reserved for classes with
   independently user-configurable fields.
2. init_weights is a default no-op in base Module: Changed from
   abstractmethod + raise NotImplementedError to a default pass
   implementation. Subclasses with learnable parameters override it;
   all others inherit the no-op. This eliminates boilerplate empty
   init_weights methods.
3. ModuleContainer for namespace grouping: Added ModuleContainer(Module)
   in protocols/module.py to replace bare nn.Module() instances used as
   attribute namespace containers (e.g., self.mid = ModuleContainer()
   in the Flux autoencoder).
4. Container types not converted: nn.ModuleDict/nn.ModuleList subclasses
   (e.g., SliceableModuleDict) are left as-is. Diamond inheritance with
   these container types adds complexity for no benefit.
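The conventions above can be sketched roughly as follows. This is a minimal, hypothetical rendering for illustration only: the class names (Module, ModuleContainer, QKNorm) follow the PR text, but the bodies, the use of nn.LayerNorm, and the method signatures are assumptions, not torchtitan's actual protocols/module.py.

```python
import torch
import torch.nn as nn


class Module(nn.Module):
    """Base protocol with a default no-op init_weights (decision 2)."""

    def init_weights(self) -> None:
        # Subclasses with learnable parameters override this;
        # parameter-free modules simply inherit the no-op.
        pass


class ModuleContainer(Module):
    """Namespace container replacing bare nn.Module() attributes (decision 3)."""


class QKNorm(Module):
    """Decision 1: no Config -- every constructor arg comes from the parent,
    so the class is built with a direct constructor, not Config + build()."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # placeholder norm for the sketch

    def init_weights(self) -> None:
        # Overrides the base no-op because this module has parameters.
        self.norm.reset_parameters()
```

A parent module could then write `self.mid = ModuleContainer()` purely for attribute grouping, and its recursive weight initialization can call `init_weights()` on every child uniformly, with parameter-free children falling through to the inherited no-op.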
Core models: GroupedExperts, TokenChoiceTopKRouter, TokenReorderer,
GptOssGroupedExperts, VarlenAttentionWrapper, FlexAttentionWrapper,
ScaledDotProductAttentionWrapper, QKNorm, SelfAttention, Modulation,
AttnBlock, ResnetBlock, Downsample, Upsample, Encoder, Decoder,
DiagonalGaussian, AutoEncoder, FluxEmbedder.
Experiments: VLM siglip2 (5), Projector, RL/vLLM attention wrappers (3),
RL/vLLM Qwen3 components (4), graph trainer (2), vLLM model wrapper.