Minimal reproduction code:

    import types
    import torch
    import torch.distributed as dist
    import torch.nn.functional as F

    # (Modifier, do_projection, check_and_apply_qk_rope, create_ring_flash_attention_forward,
    #  update_ring_flash_attn_params, model_forward and model_model_forward come from my own code)

    def layer_forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # build the attention mask from position indices
        position_ids = torch.arange(hidden_states.shape[1], dtype=torch.int64, device=hidden_states.device)[None, :]
        q_idx = position_ids.clone().T
        k_idx = position_ids.clone()
        mask = torch.where(q_idx > k_idx, -float('inf'), 0)[None, None, :, :].to(hidden_states.dtype)
        hidden_states = self.self_attn(hidden_states, attention_mask=mask)
        hidden_states = residual.to(hidden_states.device) + hidden_states
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states

    def self_attn_forward(self, hidden_states, attention_mask):
        num_heads, embed_dim = self.config.num_attention_heads, self.config.hidden_size
        num_kv_heads = self.config.num_key_value_heads
        head_dim = embed_dim // num_heads

        # query & key & value projection
        ques = do_projection(self.q_proj, hidden_states, num_heads, head_dim, head_first=False)
        keys = do_projection(self.k_proj, hidden_states, num_kv_heads, head_dim, head_first=False)
        vals = do_projection(self.v_proj, hidden_states, num_kv_heads, head_dim, head_first=False)

        # position embedding
        pos = torch.arange(0, keys.shape[1])
        pos = pos[None, :].to(keys.device)
        cos, sin = self.rotary_emb(keys, pos)
        ques, keys = check_and_apply_qk_rope(ques, keys, cos, sin)

        attn_output = self.ring_attention(
            query_states=ques,
            key_states=keys,
            value_states=vals,
            attention_mask=attention_mask,
            query_length=ques.shape[1],
            is_causal=True)

        attn_output = attn_output.flatten(2)
        attn_output = self.o_proj(attn_output)
        return attn_output

    class ModelForTraining(Modifier):
        def __init__(self, model, save_ckp: str, load_ckp: str, config: str):
            self.get_conf(config)
            model.forward = types.MethodType(model_forward, model)
            model.model.forward = types.MethodType(model_model_forward, model.model)
            self.num_layers = len(model.model.layers)

            # patch every decoder layer to use the ring attention forward
            for layer in model.model.layers:
                layer.forward = types.MethodType(layer_forward, layer)
                ring_attention = create_ring_flash_attention_forward(None, 1)[0]
                layer.self_attn.ring_attention = lambda *args, **kwargs: ring_attention(*args, **kwargs)
                layer.self_attn.forward = types.MethodType(self_attn_forward, layer.self_attn)

            super().__init__(model, save_ckp, load_ckp)

        def forward(self, input_ids, labels):
            world_size = dist.get_world_size()
            rank = dist.get_rank()

            cu_seqlens = torch.arange(input_ids.shape[-1] + 1, dtype=torch.int32, device=rank)
            update_ring_flash_attn_params(cu_seqlens, None)

            # each rank processes its own chunk of the full sequence
            input_ids_chunk = torch.chunk(input_ids, world_size, dim=1)[rank]
            logits_chunk = self.model(input_ids=input_ids_chunk)
            labels_chunk = torch.chunk(labels, world_size, dim=1)[rank]

            loss = F.cross_entropy(
                logits_chunk.view(-1, logits_chunk.shape[-1]),
                labels_chunk.view(-1),
                reduction='mean')
            return loss

I ran training on 2x A100 80GB GPUs. Up to a 60k context length, LLaMA3-8B trains normally, with the following results (see the measurement sketch after the table):
| Context length | Time | Memory |
|---|---|---|
| 10240 | 11.826 | 7161.762 |
| 20480 | 13.819 | 14339.225 |
| 30720 | 16.012 | 19535.596 |
| 40960 | 18.444 | 24731.967 |
| 51200 | 21.275 | 29929.838 |
| 61440 | 23.661 | 35130.709 |
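
For reference, a rough sketch of the kind of per-step measurement loop behind the numbers above (not the exact harness; `wrapped_model` stands for the ModelForTraining instance):

    import time
    import torch
    import torch.distributed as dist

    def benchmark_step(wrapped_model, input_ids, labels):
        # run one training step and report wall-clock time and peak GPU memory on this rank
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        start = time.time()

        loss = wrapped_model.forward(input_ids, labels)
        loss.backward()

        torch.cuda.synchronize()
        elapsed = time.time() - start
        peak_mb = torch.cuda.max_memory_allocated() / 2**20
        if dist.get_rank() == 0:
            print(f"len={input_ids.shape[-1]}  time={elapsed:.3f}s  peak_mem={peak_mb:.3f}MB")
        return loss
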
But once the context exceeds 60k, I get the following error:
[rank0]:[E ProcessGroupNCCL.cpp:1414] [PG 0 Rank 0] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f2071379897 in /mnt/petrelfs/liwenhao/miniconda3/envs/llava/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f2071329b25 in /mnt/petrelfs/liwenhao/miniconda3/envs/llava/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f2071451718 in /mnt/petrelfs/liwenhao/miniconda3/envs/llava/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f207264fe36 in /mnt/petrelfs/liwenhao/miniconda3/envs/llava/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x58 (0x7f2072653f38 in /mnt/petrelfs/liwenhao/miniconda3/envs/llava/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x77c (0x7f20726595ac in /mnt/petrelfs/liwenhao/miniconda3/envs/llava/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f207265a31c in /mnt/petrelfs/liwenhao/miniconda3/envs/llava/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdbbf4 (0x7f20be0ffbf4 in /mnt/hwfile/liwenhao/miniconda3/envs/llava/bin/../lib/libstdc++.so.6)
frame #8: <unknown function> + 0x7dd5 (0x7f20c606edd5 in /lib64/libpthread.so.0)
frame #9: clone + 0x6d (0x7f20c568eead in /lib64/libc.so.6)
The library versions I am using:
- transformers=4.45
- flash_attention=2.5.8
- torch=2.3.0+cu121
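
For completeness, the wrapper is driven roughly like this (launched with `torchrun --nproc_per_node=2`; the model id, checkpoint and config arguments below are placeholders rather than my actual script):

    import torch
    import torch.distributed as dist
    from transformers import AutoModelForCausalLM

    # hypothetical driver; model id and checkpoint/config arguments are placeholders
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B",
        torch_dtype=torch.bfloat16,
    ).cuda()
    wrapped = ModelForTraining(model, save_ckp="ckp_out", load_ckp="", config="train_config.json")

    # 61440 tokens still trains; going beyond ~60k triggers the error above
    seq_len = 61440
    input_ids = torch.randint(0, model.config.vocab_size, (1, seq_len), device="cuda")
    labels = input_ids.clone()

    loss = wrapped.forward(input_ids, labels)
    loss.backward()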