Skip to content

Commit 00d6e6a

Browse files
authored
[Module] Convert remaining nn.Module classes to Module protocol (#2565)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom): * __->__ #2565 **Why** It is hard to do the remaining changes, 1) state initialization and 2) sharding spec/local map spec change with some modules being nn.Module. The logic will need several fallback plans. This PR introduces a minimal change to convert all the nn.Module classes to Module. 1. Don't define Config for most of the Modules. 2. Make init_weights() a materialized method with empty logic. While 2 makes it not possible to detect if a Module accidentally forgets to define it, it is fine because we are going to change init_weights() in the next PR. **Summary** Convert 35 classes from plain nn.Module to the torchtitan Module protocol across core models (19) and experiments (16). Key design decisions: 1. Module without Config for non-configurable classes: If all constructor args come from the parent module (its Config or runtime), the class inherits Module without defining Config -- just a direct constructor. Config + build() is reserved for classes with independently user-configurable fields. 2. init_weights is a default no-op in base Module: Changed from abstractmethod + raise NotImplementedError to a default pass implementation. Subclasses with learnable parameters override it; all others inherit the no-op. This eliminates boilerplate empty init_weights methods. 3. ModuleContainer for namespace grouping: Added ModuleContainer(Module) in protocols/module.py to replace bare nn.Module() instances used as attribute namespace containers (e.g., self.mid = ModuleContainer() in Flux autoencoder). 4. Container types not converted: nn.ModuleDict/nn.ModuleList subclasses (e.g., SliceableModuleDict) are left as-is. 
Core models: GroupedExperts, TokenChoiceTopKRouter, TokenReorderer, GptOssGroupedExperts, VarlenAttentionWrapper, FlexAttentionWrapper, ScaledDotProductAttentionWrapper, QKNorm, SelfAttention, Modulation, AttnBlock, ResnetBlock, Downsample, Upsample, Encoder, Decoder, DiagonalGaussian, AutoEncoder, FluxEmbedder. Experiments: VLM siglip2 (5), Projector, RL/vLLM attention wrappers (3), RL/vLLM Qwen3 components (4), graph trainer (2), vLLM model wrapper.
1 parent 4801b18 commit 00d6e6a

File tree

16 files changed

+152
-134
lines changed

16 files changed

+152
-134
lines changed

tests/unit_tests/test_module.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,22 @@
1313

1414

1515
class TestModuleInitWeights(unittest.TestCase):
16-
"""Tests for Module.init_weights enforcement.
16+
"""Tests for Module.init_weights behavior.
1717
18-
Module.init_weights uses ``raise NotImplementedError`` because
19-
nn.Module's metaclass is plain ``type`` (not ABCMeta), so
20-
@abstractmethod alone does not prevent instantiation of subclasses
21-
that forget to implement init_weights.
18+
Module.init_weights provides a default no-op implementation so that
19+
subclasses without learnable parameters (or loaded from checkpoints)
20+
do not need to override it.
2221
"""
2322

24-
def test_missing_init_weights_raises_on_call(self):
25-
"""Subclass without init_weights gets NotImplementedError at call time."""
23+
def test_default_init_weights_is_noop(self):
24+
"""Subclass without init_weights gets the default no-op."""
2625

27-
class BadModule(Module):
26+
class SimpleModule(Module):
2827
def __init__(self):
2928
super().__init__()
3029

31-
m = BadModule()
32-
with self.assertRaises(NotImplementedError):
33-
m.init_weights()
30+
m = SimpleModule()
31+
m.init_weights() # should not raise
3432

3533
def test_init_weights_implemented(self):
3634
"""Subclass with init_weights works normally."""
@@ -99,16 +97,15 @@ def test_isinstance_checks(self):
9997
self.assertIsInstance(emb, nn.Module)
10098
self.assertIsInstance(emb, Module)
10199

102-
def test_missing_init_weights_raises(self):
103-
"""Diamond class without init_weights raises on call."""
100+
def test_default_init_weights_noop_diamond(self):
101+
"""Diamond class without init_weights gets the default no-op."""
104102

105-
class BadEmbedding(nn.Embedding, Module):
103+
class SimpleEmbedding(nn.Embedding, Module):
106104
def __init__(self, num_embeddings, embedding_dim):
107105
super().__init__(num_embeddings, embedding_dim)
108106

109-
emb = BadEmbedding(10, 4)
110-
with self.assertRaises(NotImplementedError):
111-
emb.init_weights()
107+
emb = SimpleEmbedding(10, 4)
108+
emb.init_weights() # should not raise
112109

113110
def test_module_hierarchy_is_flat(self):
114111
"""Diamond embedding adds no extra layer to the module tree."""

torchtitan/experiments/graph_trainer/graph_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
end_with_pass,
2929
get_extra_fsdp_pg_name,
3030
)
31+
from torchtitan.protocols.module import Module
3132
from torchtitan.tools.logging import logger
3233

3334

@@ -182,7 +183,7 @@ def wrapper_fn(args, kwargs):
182183
return wrapper_fn
183184

184185

185-
class CompiledModule(torch.nn.Module):
186+
class CompiledModule(Module):
186187
def __init__(
187188
self,
188189
inner: torch.nn.Module,
@@ -225,6 +226,14 @@ def __delattr__(self, name: str) -> None:
225226
else:
226227
super().__delattr__(name)
227228

229+
def init_weights(self, **kwargs) -> None:
230+
# Explicitly delegate to inner model. Without this override,
231+
# Module.init_weights (a no-op) would be found via MRO before
232+
# the overwritten __getattr__ is triggered, silently skipping
233+
# weight initialization.
234+
# This is similar to state_dict, load_state_dict, ...
235+
self.inner.init_weights(**kwargs)
236+
228237
def state_dict(self, *args, **kwargs) -> Any:
229238
return self.inner.state_dict(*args, **kwargs)
230239

torchtitan/experiments/graph_trainer/simple_fsdp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from torch.distributed.tensor._redistribute import redistribute_local_tensor
2424
from torch.distributed.tensor.placement_types import _StridedShard, Placement
2525

26+
from torchtitan.protocols.module import Module
27+
2628
_active_parametrization = True
2729

2830

@@ -150,7 +152,7 @@ def _register_parametrization(
150152
module.__class__ = module_cls
151153

152154

153-
class ReplicateComputation(torch.nn.Module):
155+
class ReplicateComputation(Module):
154156
def __init__(
155157
self,
156158
device_mesh: DeviceMesh,

torchtitan/experiments/rl/unified/models/attention.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
from torchtitan.experiments.rl.vllm_compat.models.attention import (
1212
VLLMCompatibleFlashAttention,
1313
)
14+
from torchtitan.protocols.module import Module
1415
from vllm.model_executor.layers.attention import Attention
1516

1617
logger = logging.getLogger(__name__)
1718

1819

19-
class VLLMAttention(torch.nn.Module):
20+
class VLLMAttention(Module):
2021
"""Adapter from TorchTitan tensor layout to ``vllm.Attention``.
2122
2223
vLLM's ``Attention`` layer manages KV-cache and paged attention internally,

torchtitan/experiments/rl/unified/models/vllm_wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import torch
1717
import torch.distributed as dist
1818
import torch.distributed.checkpoint as dcp
19-
import torch.nn as nn
2019
from torch.distributed._tensor import DTensor, Replicate
2120
from torch.distributed.checkpoint.state_dict import (
2221
set_model_state_dict,
@@ -29,6 +28,7 @@
2928
replace_with_vllm_attention,
3029
)
3130
from torchtitan.protocols.model_spec import ModelSpec
31+
from torchtitan.protocols.module import Module
3232
from vllm.compilation.decorators import support_torch_compile
3333
from vllm.config import VllmConfig
3434
from vllm.logger import init_logger
@@ -118,7 +118,7 @@ def create_torchtitan_config_from_vllm_config(
118118
"positions": 0,
119119
}
120120
)
121-
class TorchTitanVLLMModelWrapper(nn.Module):
121+
class TorchTitanVLLMModelWrapper(Module):
122122
"""
123123
Generic vLLM-compatible model wrapper for TorchTitan models. Implemented
124124
required interface required by vLLM Engine.

torchtitan/experiments/rl/vllm_compat/models/attention.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010

1111
import torch
1212
from torch.distributed._tensor import DTensor
13+
14+
from torchtitan.protocols.module import Module
1315
from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
1416

1517

16-
class VLLMCompatibleFlashAttention(torch.nn.Module):
18+
class VLLMCompatibleFlashAttention(Module):
1719
"""Wrapper around FlashAttention as used by VLLM"""
1820

1921
def __init__(self) -> None:

torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
# Import from main torchtitan
2424
from torchtitan.models.qwen3.model import Qwen3Model
2525
from torchtitan.protocols.model import BaseModel
26+
from torchtitan.protocols.module import Module
2627

2728
# Import from local experiment's models
2829
from ..attention import VLLMCompatibleFlashAttention
@@ -82,7 +83,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
8283
)
8384

8485

85-
class VLLMRMSNorm(nn.Module):
86+
class VLLMRMSNorm(Module):
8687
"""
8788
RMSNorm using vLLM's exact Triton kernel for bitwise determinism.
8889
Compatible with PyTorch's nn.RMSNorm interface but uses vLLM's implementation.
@@ -104,7 +105,7 @@ def reset_parameters(self):
104105
nn.init.ones_(self.weight)
105106

106107

107-
class FeedForwardVLLMCompat(nn.Module):
108+
class FeedForwardVLLMCompat(Module):
108109
"""
109110
FeedForward module compatible with vLLM implementation.
110111
Uses merged gate_up projection like vLLM.
@@ -132,13 +133,14 @@ def forward(self, x):
132133
output = self.down_proj(activated)
133134
return output
134135

135-
def init_weights(self, init_std: float):
136-
# Initialize like vLLM
136+
def init_weights(self, **kwargs) -> None:
137+
init_std = kwargs.get("init_std")
138+
assert init_std is not None
137139
nn.init.trunc_normal_(self.gate_up_proj.weight, mean=0.0, std=0.02)
138140
nn.init.trunc_normal_(self.down_proj.weight, mean=0.0, std=init_std)
139141

140142

141-
class Attention(nn.Module):
143+
class Attention(Module):
142144
"""
143145
Multi-head attention module compatible with vLLM.
144146
"""
@@ -172,7 +174,9 @@ def __init__(self, model_args: Qwen3Model.Config):
172174
# Always use vLLM compatible flash attention
173175
self.inner_attention = VLLMCompatibleFlashAttention()
174176

175-
def init_weights(self, init_std: float):
177+
def init_weights(self, **kwargs) -> None:
178+
init_std = kwargs.get("init_std")
179+
assert init_std is not None
176180
for linear in (self.wq, self.wk, self.wv):
177181
nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
178182
nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
@@ -226,7 +230,7 @@ def forward(
226230
return self.wo(output)
227231

228232

229-
class TransformerBlock(nn.Module):
233+
class TransformerBlock(Module):
230234
"""
231235
TransformerBlock with vLLM-compatible FFN.
232236
"""
@@ -267,11 +271,11 @@ def forward(
267271

268272
return x
269273

270-
def init_weights(self, buffer_device: torch.device):
274+
def init_weights(self, **kwargs) -> None:
271275
for norm in (self.attention_norm, self.ffn_norm):
272276
norm.reset_parameters()
273-
self.attention.init_weights(self.weight_init_std)
274-
self.feed_forward.init_weights(self.weight_init_std)
277+
self.attention.init_weights(init_std=self.weight_init_std)
278+
self.feed_forward.init_weights(init_std=self.weight_init_std)
275279

276280

277281
class Qwen3VLLMCompatModel(BaseModel):
@@ -318,7 +322,7 @@ def init_weights(
318322
nn.init.normal_(self.tok_embeddings.weight)
319323
for layer in self.layers.values():
320324
if layer is not None:
321-
layer.init_weights(buffer_device)
325+
layer.init_weights(buffer_device=buffer_device)
322326
if self.norm is not None:
323327
self.norm.reset_parameters()
324328
final_out_std = self.config.dim**-0.5

torchtitan/experiments/vlm/model/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from torchtitan.components.tokenizer import BaseTokenizer
1515
from torchtitan.models.common.attention import AttentionMasksType
1616
from torchtitan.models.llama3 import Llama3Model as Llama3
17+
from torchtitan.protocols.module import Module
1718

1819
from .args import Siglip2Config, SpecialTokens
1920
from .siglip2 import VisionTransformer
@@ -34,7 +35,7 @@ def _scatter_img_tokens(h_BSD, tokens_BS, i_NLD, i_mask_NL, img_id):
3435
return h_BSD
3536

3637

37-
class Projector(nn.Module):
38+
class Projector(Module):
3839
"""Project the Encoder embedding to the LLM embedding."""
3940

4041
def __init__(self, in_dim: int, out_dim: int) -> None:
@@ -49,7 +50,7 @@ def forward(self, x_NLD: torch.Tensor):
4950
x_NLD = self.w2(x_NLD)
5051
return x_NLD
5152

52-
def init_weights(self):
53+
def init_weights(self, **kwargs) -> None:
5354
nn.init.xavier_uniform_(self.w1.weight)
5455
if self.w1.bias is not None:
5556
nn.init.zeros_(self.w1.bias)

torchtitan/experiments/vlm/model/siglip2.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
get_causal_mask_mod,
1919
get_document_mask_mod,
2020
)
21+
from torchtitan.protocols.module import Module
2122

2223
from .args import Siglip2Config
2324

@@ -71,7 +72,7 @@ def resize_positional_embeddings(
7172
return resized_embs_BLD
7273

7374

74-
class VisionEmbeddings(nn.Module):
75+
class VisionEmbeddings(Module):
7576
def __init__(self, args: Siglip2Config):
7677
super().__init__()
7778
self.patch_embedding = nn.Linear(
@@ -81,7 +82,7 @@ def __init__(self, args: Siglip2Config):
8182
self.position_embedding = nn.Embedding(args.n_pos_embs**2, args.dim)
8283
self.n_pos_embs = args.n_pos_embs
8384

84-
def init_weights(self):
85+
def init_weights(self, **kwargs) -> None:
8586
nn.init.trunc_normal_(self.patch_embedding.weight, mean=0.0, std=0.02)
8687
nn.init.normal_(self.position_embedding.weight)
8788

@@ -106,7 +107,7 @@ def forward(self, pixels_NLD: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tens
106107
return embeddings
107108

108109

109-
class Attention(nn.Module):
110+
class Attention(Module):
110111
"""
111112
Multi-head attention module.
112113
@@ -151,12 +152,12 @@ def forward(self, x: torch.Tensor, attention_masks: AttentionMasksType):
151152

152153
return self.out_proj(output)
153154

154-
def init_weights(self):
155+
def init_weights(self, **kwargs) -> None:
155156
for linear in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
156157
nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
157158

158159

159-
class FeedForward(nn.Module):
160+
class FeedForward(Module):
160161
def __init__(self, args: Siglip2Config):
161162
super().__init__()
162163
self.fc1 = nn.Linear(args.dim, args.ffn_dim)
@@ -168,12 +169,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
168169
x = self.fc2(x)
169170
return x
170171

171-
def init_weights(self):
172+
def init_weights(self, **kwargs) -> None:
172173
nn.init.trunc_normal_(self.fc1.weight, mean=0.0, std=0.02)
173174
nn.init.trunc_normal_(self.fc2.weight, mean=0.0, std=0.02)
174175

175176

176-
class TransformerLayer(nn.Module):
177+
class TransformerLayer(Module):
177178
def __init__(self, args: Siglip2Config):
178179
super().__init__()
179180
self.layer_norm1 = nn.LayerNorm(args.dim, eps=args.layer_norm_eps)
@@ -188,14 +189,14 @@ def forward(
188189
x = x + self.mlp(self.layer_norm2(x))
189190
return x
190191

191-
def init_weights(self):
192+
def init_weights(self, **kwargs) -> None:
192193
self.layer_norm1.reset_parameters()
193194
self.layer_norm2.reset_parameters()
194195
self.self_attn.init_weights()
195196
self.mlp.init_weights()
196197

197198

198-
class VisionTransformer(nn.Module):
199+
class VisionTransformer(Module):
199200
def __init__(self, args: Siglip2Config):
200201
super().__init__()
201202
self.args = args
@@ -251,7 +252,7 @@ def forward(
251252

252253
return h
253254

254-
def init_weights(self):
255+
def init_weights(self, **kwargs) -> None:
255256
self.embeddings.init_weights()
256257
for layer in self.layers.values():
257258
layer.init_weights()

0 commit comments

Comments
 (0)