Skip to content

Commit a3b640d

Browse files
committed
Merge remote-tracking branch 'upstream/main' into fix-pos-id
2 parents 84a110c + 0187d5f commit a3b640d

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

76 files changed

+2270
-1419
lines changed

.ci/docker/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ datasets >= 3.6.0
33
tensorboard
44
wandb
55
fsspec
6-
tyro
6+
tyro >= 1.0.5
77
tokenizers >= 0.15.0
88
safetensors
99
einops

.github/workflows/integration_test_8gpu_features.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ jobs:
6666
fi
6767
python -m pip install --force-reinstall --pre \
6868
"${TORCH_SPEC}" --index-url ${{ matrix.index-url }}
69+
if [[ "${{ matrix.gpu-arch-type }}" == "cuda" ]]; then
70+
python -m pip install --pre torchcomms --index-url ${{ matrix.index-url }}
71+
fi
6972
end=$(date +%s)
7073
echo "pip install torch took $((end - start)) seconds"
7174

.github/workflows/integration_test_8gpu_torchcomms.yaml

Lines changed: 0 additions & 54 deletions
This file was deleted.

tests/integration_tests/features.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,34 @@ def build_features_test_list() -> list[OverrideDefinitions]:
559559
"Float8 emulation test",
560560
"float8_emulation",
561561
),
562+
OverrideDefinitions(
563+
[
564+
[
565+
"--comm.mode torchcomms",
566+
"--parallelism.context_parallel_degree 2",
567+
"--parallelism.pipeline_parallel_degree 2",
568+
"--compile.enable",
569+
],
570+
],
571+
"FSDP+CP+PP+compile with torchcomms",
572+
"torchcomms_3d_dp+cp+pp+compile",
573+
ngpu=8,
574+
skip_rocm_test=True,
575+
),
576+
OverrideDefinitions(
577+
[
578+
[
579+
"--comm.mode torchcomms",
580+
"--parallelism.tensor_parallel_degree 2",
581+
"--parallelism.pipeline_parallel_degree 2",
582+
"--compile.enable",
583+
],
584+
],
585+
"FSDP+TP+PP+compile with torchcomms",
586+
"torchcomms_3d_dp+tp+pp+compile",
587+
ngpu=8,
588+
skip_rocm_test=True,
589+
),
562590
]
563591

564592
return integration_tests_flavors

tests/integration_tests/flux.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,15 @@ def run_single_test(test_flavor: OverrideDefinitions, output_dir: str):
5353
dump_folder_arg = f"--dump_folder {output_dir}/{test_name}"
5454

5555
# Random init encoder for offline testing
56-
random_init_encoder_arg = "--encoder.test_mode --dataloader.encoder.test_mode"
56+
random_init_arg = "--tokenizer.test_mode --encoder.random_init"
5757
clip_encoder_version_arg = (
5858
"--encoder.clip_encoder tests/assets/flux_test_encoders/clip-vit-large-patch14/"
5959
)
6060
t5_encoder_version_arg = (
6161
"--encoder.t5_encoder tests/assets/flux_test_encoders/t5-v1_1-xxl/"
6262
)
63+
t5_tokenizer_path_arg = "--tokenizer.t5_tokenizer_path tests/assets/tokenizer"
64+
clip_tokenizer_path_arg = "--tokenizer.clip_tokenizer_path tests/assets/tokenizer"
6365
hf_assets_path_arg = "--hf_assets_path tests/assets/tokenizer"
6466

6567
all_ranks = ",".join(map(str, range(test_flavor.ngpu)))
@@ -78,9 +80,11 @@ def run_single_test(test_flavor: OverrideDefinitions, output_dir: str):
7880
)
7981

8082
cmd += " " + dump_folder_arg
81-
cmd += " " + random_init_encoder_arg
83+
cmd += " " + random_init_arg
8284
cmd += " " + clip_encoder_version_arg
8385
cmd += " " + t5_encoder_version_arg
86+
cmd += " " + t5_tokenizer_path_arg
87+
cmd += " " + clip_tokenizer_path_arg
8488
cmd += " " + hf_assets_path_arg
8589
if override_arg:
8690
cmd += " " + " ".join(override_arg)

tests/unit_tests/test_configurable.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,30 @@ def __init__(self, config: Config):
200200
self.assertEqual(d2["inner"]["a"], 1)
201201
self.assertEqual(d2["inner"]["b"], 256)
202202

203+
def test_repr_with_unset_init_false(self):
204+
"""repr() must not crash when field(init=False) slots are unset."""
205+
cfg = self.NewStyleComponent.Config(x=10)
206+
# Before build: dim and hidden are unset
207+
r = repr(cfg)
208+
self.assertIn("x=10", r)
209+
self.assertIn("dim=<UNSET>", r)
210+
self.assertIn("hidden=<UNSET>", r)
211+
212+
# After build: all fields set
213+
obj = cfg.build(dim=64, hidden=128)
214+
r2 = repr(obj.config)
215+
self.assertIn("x=10", r2)
216+
self.assertIn("dim=64", r2)
217+
self.assertIn("hidden=128", r2)
218+
self.assertNotIn("UNSET", r2)
219+
220+
def test_repr_no_init_false_fields(self):
221+
"""repr() works normally when there are no field(init=False) fields."""
222+
cfg = self.NoKwargsComponent.Config(x=42)
223+
r = repr(cfg)
224+
self.assertIn("x=42", r)
225+
self.assertNotIn("UNSET", r)
226+
203227
def test_init_false_with_inheritance(self):
204228
"""Child config can redeclare field with default."""
205229

tests/unit_tests/test_dataset_flux.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,21 +65,29 @@ def test_load_dataset(self):
6565
str(256),
6666
"--dataloader.dataset",
6767
dataset_name,
68-
"--dataloader.classifier_free_guidance_prob",
68+
"--dataloader.prompt_dropout_prob",
6969
"0.447",
70-
"--dataloader.encoder.test_mode",
71-
"--encoder.test_mode",
70+
"--tokenizer.test_mode",
71+
"--tokenizer.t5_tokenizer_path",
72+
"tests/assets/tokenizer",
73+
"--tokenizer.clip_tokenizer_path",
74+
"tests/assets/tokenizer",
75+
"--encoder.random_init",
7276
"--encoder.t5_encoder",
7377
"tests/assets/flux_test_encoders/t5-v1_1-xxl",
7478
"--encoder.clip_encoder",
7579
"tests/assets/flux_test_encoders/clip-vit-large-patch14",
7680
]
7781
)
7882

83+
# Build the tokenizer container from config
84+
tokenizer = config.tokenizer.build(tokenizer_path=config.hf_assets_path)
85+
7986
dl = config.dataloader.build(
8087
dp_world_size=world_size,
8188
dp_rank=rank,
8289
local_batch_size=batch_size,
90+
tokenizer=tokenizer,
8391
)
8492

8593
it = iter(dl)
@@ -91,11 +99,11 @@ def test_load_dataset(self):
9199
len(input_data) == 3
92100
) # (clip_encodings, t5_encodings, prompt)
93101
assert labels.shape == (batch_size, 3, 256, 256)
94-
assert input_data["clip_tokens"].shape == (
102+
assert input_data["clip"].shape == (
95103
batch_size,
96104
77,
97105
)
98-
assert input_data["t5_tokens"].shape == (
106+
assert input_data["t5"].shape == (
99107
batch_size,
100108
256,
101109
)
@@ -107,6 +115,7 @@ def test_load_dataset(self):
107115
dp_world_size=world_size,
108116
dp_rank=rank,
109117
local_batch_size=batch_size,
118+
tokenizer=tokenizer,
110119
)
111120
dl_resumed.load_state_dict(state)
112121
it_resumed = iter(dl_resumed)
@@ -119,10 +128,6 @@ def test_load_dataset(self):
119128
torch.manual_seed(i)
120129
input_ids, labels = next(it_resumed)
121130

122-
assert torch.equal(
123-
input_ids["clip_tokens"], expected_input_ids["clip_tokens"]
124-
)
125-
assert torch.equal(
126-
input_ids["t5_tokens"], expected_input_ids["t5_tokens"]
127-
)
131+
assert torch.equal(input_ids["clip"], expected_input_ids["clip"])
132+
assert torch.equal(input_ids["t5"], expected_input_ids["t5"])
128133
assert torch.equal(labels, expected_labels)

tests/unit_tests/test_module.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,22 @@
1313

1414

1515
class TestModuleInitWeights(unittest.TestCase):
16-
"""Tests for Module.init_weights enforcement.
16+
"""Tests for Module.init_weights behavior.
1717
18-
Module.init_weights uses ``raise NotImplementedError`` because
19-
nn.Module's metaclass is plain ``type`` (not ABCMeta), so
20-
@abstractmethod alone does not prevent instantiation of subclasses
21-
that forget to implement init_weights.
18+
Module.init_weights provides a default no-op implementation so that
19+
subclasses without learnable parameters (or loaded from checkpoints)
20+
do not need to override it.
2221
"""
2322

24-
def test_missing_init_weights_raises_on_call(self):
25-
"""Subclass without init_weights gets NotImplementedError at call time."""
23+
def test_default_init_weights_is_noop(self):
24+
"""Subclass without init_weights gets the default no-op."""
2625

27-
class BadModule(Module):
26+
class SimpleModule(Module):
2827
def __init__(self):
2928
super().__init__()
3029

31-
m = BadModule()
32-
with self.assertRaises(NotImplementedError):
33-
m.init_weights()
30+
m = SimpleModule()
31+
m.init_weights() # should not raise
3432

3533
def test_init_weights_implemented(self):
3634
"""Subclass with init_weights works normally."""
@@ -99,16 +97,15 @@ def test_isinstance_checks(self):
9997
self.assertIsInstance(emb, nn.Module)
10098
self.assertIsInstance(emb, Module)
10199

102-
def test_missing_init_weights_raises(self):
103-
"""Diamond class without init_weights raises on call."""
100+
def test_default_init_weights_noop_diamond(self):
101+
"""Diamond class without init_weights gets the default no-op."""
104102

105-
class BadEmbedding(nn.Embedding, Module):
103+
class SimpleEmbedding(nn.Embedding, Module):
106104
def __init__(self, num_embeddings, embedding_dim):
107105
super().__init__(num_embeddings, embedding_dim)
108106

109-
emb = BadEmbedding(10, 4)
110-
with self.assertRaises(NotImplementedError):
111-
emb.init_weights()
107+
emb = SimpleEmbedding(10, 4)
108+
emb.init_weights() # should not raise
112109

113110
def test_module_hierarchy_is_flat(self):
114111
"""Diamond embedding adds no extra layer to the module tree."""

tests/unit_tests/test_rope.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
11+
from torchtitan.models.common.rope import apply_rotary_emb_cos_sin
12+
13+
14+
class TestApplyRotaryEmbCosSin(unittest.TestCase):
15+
def setUp(self):
16+
torch.manual_seed(42)
17+
self.bsz = 2
18+
self.seqlen = 16
19+
self.n_heads = 4
20+
self.head_dim = 64
21+
self.xq = torch.randn(
22+
self.bsz, self.seqlen, self.n_heads, self.head_dim, dtype=torch.bfloat16
23+
)
24+
self.xk = torch.randn(
25+
self.bsz, self.seqlen, self.n_heads, self.head_dim, dtype=torch.bfloat16
26+
)
27+
self.rope_cache = torch.randn(
28+
self.seqlen, self.head_dim * 2, dtype=torch.float32
29+
)
30+
31+
def test_output_dtype_matches_input(self):
32+
xq_out, xk_out = apply_rotary_emb_cos_sin(self.xq, self.xk, self.rope_cache)
33+
self.assertEqual(xq_out.dtype, self.xq.dtype)
34+
self.assertEqual(xk_out.dtype, self.xk.dtype)
35+
36+
def test_output_shape_matches_input(self):
37+
xq_out, xk_out = apply_rotary_emb_cos_sin(self.xq, self.xk, self.rope_cache)
38+
self.assertEqual(xq_out.shape, self.xq.shape)
39+
self.assertEqual(xk_out.shape, self.xk.shape)
40+
41+
def test_computes_in_fp32(self):
42+
"""Output must match a reference computed entirely in float32.
43+
44+
Ensures inductor cannot fuse away the fp32 upcast when compiling
45+
adjacent ops (e.g. q_norm/k_norm) with the RoPE computation.
46+
"""
47+
xq_out, xk_out = apply_rotary_emb_cos_sin(self.xq, self.xk, self.rope_cache)
48+
49+
cos = self.rope_cache[..., : self.head_dim].unsqueeze(0).unsqueeze(2)
50+
sin = self.rope_cache[..., self.head_dim :].unsqueeze(0).unsqueeze(2)
51+
52+
def rotate_half(x):
53+
half = x.shape[-1] // 2
54+
return torch.cat([-x[..., half:], x[..., :half]], dim=-1)
55+
56+
xq_ref = (
57+
(self.xq.float() * cos) + (rotate_half(self.xq.float()) * sin)
58+
).bfloat16()
59+
xk_ref = (
60+
(self.xk.float() * cos) + (rotate_half(self.xk.float()) * sin)
61+
).bfloat16()
62+
63+
self.assertEqual((xq_out - xq_ref).abs().max().item(), 0.0)
64+
self.assertEqual((xk_out - xk_ref).abs().max().item(), 0.0)
65+
66+
67+
if __name__ == "__main__":
68+
unittest.main()

0 commit comments

Comments
 (0)