Skip to content

Commit 004162a

Browse files
authored
minor fix (#1494)
1 parent 2429e0b commit 004162a

File tree

4 files changed

+17
-16
lines changed

4 files changed

+17
-16
lines changed

torchtitan/experiments/flux/infra/parallelize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def parallelize_encoders(
143143
fully_shard(t5_model.hf_module, **fsdp_config)
144144

145145
if parallel_dims.dp_replicate_enabled:
146-
logger.info("Applied FSDP to the T5 encoder model")
146+
logger.info("Applied HSDP to the T5 encoder model")
147147
else:
148148
logger.info("Applied FSDP to the T5 encoder model")
149149

torchtitan/experiments/flux/sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def denoise(
172172
_, latent_channels, latent_height, latent_width = latents.shape
173173

174174
# create denoising schedule
175-
timesteps = get_schedule(denoising_steps, latent_channels, shift=True)
175+
timesteps = get_schedule(denoising_steps, latent_height * latent_width, shift=True)
176176

177177
# create positional encodings
178178
POSITION_DIM = 3

torchtitan/experiments/flux/train.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -126,19 +126,20 @@ def forward_backward_step(
126126
# Patchify: Convert latent into a sequence of patches
127127
latents = pack_latents(latents)
128128

129-
latent_noise_pred = model(
130-
img=latents,
131-
img_ids=latent_pos_enc,
132-
txt=t5_encodings,
133-
txt_ids=text_pos_enc,
134-
y=clip_encodings,
135-
timesteps=timesteps,
136-
)
129+
with self.maybe_enable_amp:
130+
latent_noise_pred = model(
131+
img=latents,
132+
img_ids=latent_pos_enc,
133+
txt=t5_encodings,
134+
txt_ids=text_pos_enc,
135+
y=clip_encodings,
136+
timesteps=timesteps,
137+
)
137138

138-
# Convert sequence of patches to latent shape
139-
pred = unpack_latents(latent_noise_pred, latent_height, latent_width)
140-
target = noise - labels
141-
loss = self.loss_fn(pred, target)
139+
# Convert sequence of patches to latent shape
140+
pred = unpack_latents(latent_noise_pred, latent_height, latent_width)
141+
target = noise - labels
142+
loss = self.loss_fn(pred, target)
142143
# pred.shape=(bs, seq_len, vocab_size)
143144
# need to free to before bwd to avoid peaking memory
144145
del (pred, noise, target)

torchtitan/tools/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,11 @@ def check_if_feature_in_pytorch(
166166
# notify users to check if the pull request is included in their pytorch
167167
logger.warning(
168168
"Detected that the pytorch is built from source. Please make sure the PR "
169-
f"({pull_request_link}) is included in pytorch for correct {feature_name}."
169+
f"({pull_request}) is included in pytorch for correct {feature_name}."
170170
)
171171
elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
172172
logger.warning(
173173
f"Detected that the pytorch version {torch.__version__} is older than "
174174
f"{min_nightly_version}. Please upgrade a newer version to include the "
175-
f"change in ({pull_request_link}) for correct {feature_name}."
175+
f"change in ({pull_request}) for correct {feature_name}."
176176
)

0 commit comments

Comments (0)