2 changes: 2 additions & 0 deletions repro.sh
@@ -0,0 +1,2 @@
NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh
rm -rf outputs/checkpoint/step-30 && NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh
25 changes: 24 additions & 1 deletion torchtitan/components/checkpoint.py
@@ -397,6 +397,7 @@ def dcp_load(
state_dict: dict[str, Any],
checkpoint_id: str,
from_hf: bool,
step: int = -1,
) -> None:
"""Load the checkpoint with dcp.
Args:
@@ -420,11 +421,21 @@ def dcp_load(
state_dict = self.sd_adapter.from_hf(hf_state_dict)
self.states[MODEL].load_state_dict(state_dict)
else:
before_load = state_dict['tok_embeddings.weight']._local_tensor
dcp.load(state_dict, checkpoint_id=checkpoint_id)
after_load = state_dict['tok_embeddings.weight']._local_tensor
try:
assert torch.equal(before_load, after_load)
# dcp.load(state_dict, checkpoint_id=checkpoint_id)
except:
import fbvscode
fbvscode.set_trace()

# TODO: Since we flatten the model states in state_dict, we need to
# manually call load_state_dict() for the model. Need to fix this.
if MODEL in self.states:
# import fbvscode
# fbvscode.set_trace()
self.states[MODEL].load_state_dict(state_dict)

@torch.no_grad()
@@ -488,6 +499,9 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
enable_garbage_collection=True,
)
self._purge_stale_checkpoints()
# import fbvscode
# fbvscode.set_trace()
self.load(step=-1)

logger.info(
"Finished saving the checkpoint (or staging if async is enabled)"
@@ -522,6 +536,7 @@ def load(self, step: int = -1) -> bool:

model_only = False
from_hf = False
# torch.distributed.breakpoint()
if not os.path.exists(self.folder):
model_only = self.initial_load_model_only
from_hf = self.initial_load_in_hf
Expand Down Expand Up @@ -576,11 +591,18 @@ def load(self, step: int = -1) -> bool:
logger.info(f"Loading the checkpoint from {checkpoint_id}.")
begin = time.monotonic()
states = self._states_to_load(model_only)
before_load = states['tok_embeddings.weight']._local_tensor
self.dcp_load(
states,
checkpoint_id=checkpoint_id,
from_hf=from_hf,
)
after_load = states['tok_embeddings.weight']._local_tensor
try:
assert torch.equal(before_load, after_load)
except:
import fbvscode
fbvscode.set_trace()
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -699,7 +721,8 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
k: v for k, v in self.states.items() if k not in self.exclude_from_loading
}

states_to_load = self._flattened_model_states_sd(states_to_load)
# states_to_load = self._flattened_model_states_sd(states_to_load)
states_to_load = self._flattened_model_states_sd()

if self.ft_manager:
states_to_load.pop(DATALOADER)
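A minimal, single-process sketch of the round-trip check this diff wraps around dcp.load: save a state dict, perturb the in-memory tensor, load the checkpoint back in place, and require bitwise equality with what was saved. It drops the DTensor/torchtitan pieces (the real check compares _local_tensor shards); the gloo group, checkpoint path, and tensor shape are illustrative assumptions.

import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

# single-rank gloo group so dcp has a process group to coordinate with
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

ckpt_dir = "outputs/dcp_roundtrip_check"  # illustrative path
state = {"tok_embeddings.weight": torch.randn(2000, 256)}
reference = state["tok_embeddings.weight"].clone()

dcp.save(state, checkpoint_id=ckpt_dir)

# perturb the live tensor, then load the checkpoint back into it in place
state["tok_embeddings.weight"].add_(1.0)
dcp.load(state, checkpoint_id=ckpt_dir)

# a correct round trip restores the saved values bit for bit
assert torch.equal(state["tok_embeddings.weight"], reference)
dist.destroy_process_group()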
4 changes: 2 additions & 2 deletions torchtitan/models/attention.py
@@ -205,8 +205,8 @@ def _init_backend(cls) -> None:

# Add CuDNN on B200 w/ highest priority
cls.backends = [
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
# SDPBackend.FLASH_ATTENTION,
# SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.MATH,
]
if has_cuda_capability(10, 0):
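Restricting cls.backends to SDPBackend.MATH here, together with the global enable_*_sdp calls added at the top of train.py, pins scaled_dot_product_attention to the math reference kernel so the model and its reference copy see identical attention numerics; the flash and memory-efficient backends are faster but their backward kernels are not guaranteed to be bitwise reproducible. A minimal sketch of the same pinning via the context-manager API; the shapes and device handling are illustrative:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = "cuda" if torch.cuda.is_available() else "cpu"
q = torch.randn(1, 16, 128, 64, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)

# allow only the math reference kernel inside this block
with sdpa_kernel([SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)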
1 change: 1 addition & 0 deletions torchtitan/models/llama3/__init__.py
@@ -30,6 +30,7 @@
llama3_configs = {
"debugmodel": TransformerModelArgs(
dim=256, n_layers=6, n_heads=16, vocab_size=2000, rope_theta=500000
# dim=256, n_layers=6, n_heads=16, vocab_size=2017, rope_theta=500000
),
"debugmodel_flex_attn": TransformerModelArgs(
dim=256,
6 changes: 3 additions & 3 deletions torchtitan/models/llama3/train_configs/debug_model.toml
@@ -42,7 +42,7 @@ min_lr_factor = 0.0
local_batch_size = 8
seq_len = 2048
max_norm = 1.0 # grad norm clipping
steps = 10
steps = 30
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

[parallelism]
@@ -55,9 +55,9 @@ pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable = false
enable = true
folder = "checkpoint"
interval = 10
interval = 15
last_save_model_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
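With steps = 30 and interval = 15, this config should produce checkpoints at step 15 and step 30, which is what repro.sh relies on: its second command deletes outputs/checkpoint/step-30 and reruns, so the job resumes from step-15 and exercises the load path under investigation. A tiny sketch of that schedule, assuming the usual save-every-interval-and-at-the-last-step rule:

steps, interval = 30, 15
save_steps = [s for s in range(1, steps + 1) if s % interval == 0 or s == steps]
assert save_steps == [15, 30]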
149 changes: 139 additions & 10 deletions torchtitan/train.py
@@ -11,6 +9 @@
from typing import Any, Generator, Iterable, Optional

import torch
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.protocols.train_spec as train_spec_module
@@ -154,8 +157,11 @@ def __init__(self, job_config: JobConfig):
logger.info(
f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
)
with torch.device("meta"):
with torch.device("cuda"):
# import fbvscode
# fbvscode.set_trace()
model = self.train_spec.model_cls(model_args)
# model = torch.nn.Linear(1024, 1024, device="cuda")

# Build the collection of model converters. No-op if `model.converters` empty
model_converters = build_model_converters(job_config, parallel_dims)
@@ -257,15 +263,19 @@ def __init__(self, job_config: JobConfig):
ensure_pp_loss_visible(parallel_dims, job_config, color)
else:
# apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)
import copy
model.init_weights()
ref_model = copy.deepcopy(model)
ref_model = self.train_spec.parallelize_fn(ref_model, parallel_dims, job_config)
ref_model.train()
self.ref_model_parts = [ref_model]

model.to_empty(device=init_device)
with torch.no_grad():
model.init_weights(buffer_device=buffer_device)
model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)
model.train()

self.model_parts = [model]



self.ft_manager.maybe_set_all_reduce_hook(self.model_parts)

# initialize device memory monitor and get peak flops for MFU calculation
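The pattern introduced in this hunk: build and initialize the model once on a concrete device (hence the meta-to-cuda device change earlier in this file), deepcopy it before any parallelization, then run the same parallelize_fn over both copies so every rank ends up with directly comparable shards of bit-identical weights. A condensed, hedged sketch of that construction; model_cls, parallelize_fn, parallel_dims, and job_config stand in for the torchtitan objects:

import copy
import torch

def build_model_and_reference(model_cls, model_args, parallelize_fn, parallel_dims, job_config):
    # construct and initialize once on a real device so deepcopy captures actual weights
    with torch.device("cuda"):
        model = model_cls(model_args)
    model.init_weights()
    ref_model = copy.deepcopy(model)
    # apply the identical parallelization to both copies so shards line up rank by rank
    ref_model = parallelize_fn(ref_model, parallel_dims, job_config)
    model = parallelize_fn(model, parallel_dims, job_config)
    model.train()
    ref_model.train()
    return model, ref_model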
@@ -294,6 +304,19 @@ def __init__(self, job_config: JobConfig):
self.model_parts
)
)

self.ref_optimizers = self.train_spec.build_optimizers_fn(
self.ref_model_parts, job_config.optimizer, parallel_dims, self.ft_manager
)
self.ref_lr_schedulers = self.train_spec.build_lr_schedulers_fn(
self.ref_optimizers, job_config.lr_scheduler, job_config.training.steps
)
self.ref_optimizers.register_step_post_hook(
lambda *args, **kwargs: model_converters.post_optimizer_hook(
self.ref_model_parts
)
)

self.metrics_processor.optimizers = self.optimizers
self.metrics_processor.model_parts = self.model_parts

@@ -320,6 +343,24 @@ def __init__(self, job_config: JobConfig):
ft_manager=self.ft_manager,
)

self.ref_checkpointer = CheckpointManager(
dataloader=self.dataloader,
model_parts=self.ref_model_parts,
optimizers=self.ref_optimizers,
lr_schedulers=self.ref_lr_schedulers,
states={"train_state": self},
checkpoint_config=job_config.checkpoint,
sd_adapter=(
self.train_spec.state_dict_adapter(
model_args, job_config.model.hf_assets_path
)
if self.train_spec.state_dict_adapter
else None
),
base_folder=job_config.job.dump_folder,
ft_manager=self.ft_manager,
)

loss_parallel_enabled = (
parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
)
@@ -460,16 +501,32 @@ def forward_backward_step(
with self.maybe_enable_amp:
pred = model_parts[0](inputs)
loss = self.loss_fn(pred, labels)

import copy
ref_inputs = copy.deepcopy(inputs)
ref_pred = self.ref_model_parts[0](ref_inputs)
ref_loss = self.loss_fn(ref_pred, labels)

try:
assert torch.equal(pred, ref_pred)
assert torch.equal(loss, ref_loss)
except:
import fbvscode
fbvscode.set_trace()

# need to free to before bwd to avoid peaking memory
del pred
del ref_pred
loss.backward()
ref_loss.backward()

return loss
return loss, ref_loss

def train_step(
self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
):
self.optimizers.zero_grad()
self.ref_optimizers.zero_grad()
# Save the current step learning rate for logging
lr = self.lr_schedulers.schedulers[0].get_last_lr()[0]

@@ -478,25 +535,73 @@ def train_step(
parallel_dims = self.parallel_dims

accumulated_losses = []
ref_accumulated_losses = []
# If data runs out during gradient accumulation, that
# entire step will not be executed.
for microbatch in range(self.gradient_accumulation_steps):
# import fbvscode
# fbvscode.set_trace()
for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()):
full_param = param.full_tensor()
ref_full_param = ref_param.full_tensor()
try:
assert torch.equal(full_param, ref_full_param)
except:
import fbvscode
fbvscode.set_trace()

input_dict, labels = next(data_iterator)
loss = self.forward_backward_step(input_dict, labels)
loss, ref_loss = self.forward_backward_step(input_dict, labels)
accumulated_losses.append(loss.detach())
ref_accumulated_losses.append(ref_loss.detach())

for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()):
full_param = param.full_tensor()
ref_full_param = ref_param.full_tensor()
full_param_grad = param.grad.full_tensor()
ref_full_param_grad = ref_param.grad.full_tensor()
try:

assert torch.equal(full_param, ref_full_param)
except:
import fbvscode
fbvscode.set_trace()
try:
assert torch.equal(full_param_grad, ref_full_param_grad)
except:
import fbvscode
fbvscode.set_trace()


grad_norm = dist_utils.clip_grad_norm_(
[p for m in self.model_parts for p in m.parameters()],
self.job_config.training.max_norm,
foreach=True,
foreach=False,
pp_mesh=(
parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
),
ep_enabled=parallel_dims.ep_enabled,
)
ref_grad_norm = dist_utils.clip_grad_norm_(
[p for m in self.ref_model_parts for p in m.parameters()],
self.job_config.training.max_norm,
foreach=False,
pp_mesh=(
parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
),
ep_enabled=parallel_dims.ep_enabled,
)
try:
assert torch.equal(grad_norm, ref_grad_norm)
except:
import fbvscode
fbvscode.set_trace()
self.checkpointer.maybe_wait_for_staging()
self.ref_checkpointer.maybe_wait_for_staging()
self.optimizers.step()
self.lr_schedulers.step()
self.ref_optimizers.step()
self.ref_lr_schedulers.step()

# Reduce the data collected over gradient accumulation steps.
loss = torch.sum(torch.stack(accumulated_losses))
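The inline checks threaded through forward_backward_step and train_step above all reduce to one operation: gather the full tensor behind each DTensor and require exact equality. They compare predictions and losses after the forward pass, parameters before the microbatch, parameters and gradients after backward, and the two grad norms (clip_grad_norm_ is switched to foreach=False, presumably to keep the fused foreach path out of the comparison). A hedged helper that factors the parameter and gradient part out, raising instead of dropping into fbvscode; it assumes the public torch.distributed.tensor import path (older releases expose DTensor under torch.distributed._tensor):

import torch
from torch.distributed.tensor import DTensor

def _full(t: torch.Tensor) -> torch.Tensor:
    # DTensor parameters and gradients must be gathered before an exact comparison
    return t.full_tensor() if isinstance(t, DTensor) else t

def assert_models_match(model: torch.nn.Module, ref_model: torch.nn.Module, check_grads: bool = False) -> None:
    for (name, p), (_, ref_p) in zip(model.named_parameters(), ref_model.named_parameters()):
        if not torch.equal(_full(p), _full(ref_p)):
            raise AssertionError(f"parameter mismatch: {name}")
        if check_grads and p.grad is not None:
            if not torch.equal(_full(p.grad), _full(ref_p.grad)):
                raise AssertionError(f"gradient mismatch: {name}")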
@@ -539,7 +644,11 @@ def train_step(
def train(self):
job_config = self.job_config

self.checkpointer.load(step=job_config.checkpoint.load_step)
# self.checkpointer.load(step=job_config.checkpoint.load_step)
# self.ref_checkpointer.load(step=job_config.checkpoint.load_step)

# import fbvscode
# fbvscode.set_trace()
logger.info(f"Training starts at step {self.step + 1}")

leaf_folder = (
@@ -590,6 +699,26 @@ def train(self):
self.checkpointer.save(
self.step, last_step=(self.step == job_config.training.steps)
)
# for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()):
# full_param = param.full_tensor()
# ref_full_param = ref_param.full_tensor()
# try:
# assert torch.equal(full_param, ref_full_param)
# except:
# import fbvscode
# fbvscode.set_trace()
self.checkpointer.load(step=job_config.checkpoint.load_step)
for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()):
full_param = param.full_tensor()
ref_full_param = ref_param.full_tensor()
local_param = param._local_tensor
ref_local_param = ref_param._local_tensor
try:
assert torch.equal(local_param, ref_local_param)
assert torch.equal(full_param, ref_full_param)
except:
import fbvscode
fbvscode.set_trace()

# Run validation if validator is available
if (