
Thank you for your work; I have a question for you, possibly a latent bug #75

@wenhaoli-xmu

Description


Minimal reproduction code:

import types

import torch
import torch.distributed as dist
import torch.nn.functional as F

# Modifier, model_forward, model_model_forward, do_projection, check_and_apply_qk_rope,
# create_ring_flash_attention_forward and update_ring_flash_attn_params are project-level
# helpers whose imports are omitted here.


def layer_forward(self, hidden_states):
    residual = hidden_states
    hidden_states = self.input_layernorm(hidden_states)

    # Causal mask: -inf wherever the key position lies ahead of the query position.
    position_ids = torch.arange(hidden_states.shape[1], dtype=torch.int64, device=hidden_states.device)[None, :]
    q_idx = position_ids.clone().T
    k_idx = position_ids.clone()
    mask = torch.where(q_idx < k_idx, -float('inf'), 0.0)[None, None, :, :].to(hidden_states.dtype)

    hidden_states = self.self_attn(hidden_states, attention_mask=mask)
    hidden_states = residual.to(hidden_states.device) + hidden_states
    
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states

    return hidden_states
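
For intuition, here is what the mask above evaluates to on a 4-token toy sequence (standalone sketch using the causal q_idx < k_idx condition):

import torch

T = 4
pos = torch.arange(T)[None, :]
q_idx, k_idx = pos.T, pos
mask = torch.where(q_idx < k_idx, -float('inf'), 0.0)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])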


def self_attn_forward(self, hidden_states, attention_mask):

    num_heads, embed_dim = self.config.num_attention_heads, self.config.hidden_size
    num_kv_heads = self.config.num_key_value_heads
    head_dim = embed_dim // num_heads

    # query & key & value projection (GQA: fewer key/value heads than query heads)
    ques = do_projection(self.q_proj, hidden_states, num_heads, head_dim, head_first=False)
    keys = do_projection(self.k_proj, hidden_states, num_kv_heads, head_dim, head_first=False)
    vals = do_projection(self.v_proj, hidden_states, num_kv_heads, head_dim, head_first=False)

    # position embedding
    pos = torch.arange(0, keys.shape[1])
    pos = pos[None, :].to(keys.device)
    cos, sin = self.rotary_emb(keys, pos)
    ques, keys = check_and_apply_qk_rope(ques, keys, cos, sin)

    attn_output = self.ring_attention(
        query_states=ques,
        key_states=keys,
        value_states=vals,
        attention_mask=attention_mask,
        query_length=ques.shape[1],
        is_causal=True)

    attn_output = attn_output.flatten(2)
    attn_output = self.o_proj(attn_output)

    return attn_output
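
do_projection is not shown in the issue; below is a minimal sketch of what it presumably does, inferred from the call sites (the body is an assumption, not the project's actual implementation):

def do_projection(proj, hidden_states, num_heads, head_dim, head_first=True):
    # Hypothetical: apply the linear projection and split the hidden dim into heads.
    bsz, seq_len, _ = hidden_states.shape
    out = proj(hidden_states).view(bsz, seq_len, num_heads, head_dim)
    # head_first=False keeps (batch, seq, heads, head_dim), the layout flash attention expects.
    return out.transpose(1, 2) if head_first else out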


class ModelForTraining(Modifier):
    def __init__(self, model, save_ckp: str, load_ckp: str, config: str):
        self.get_conf(config)

        model.forward = types.MethodType(model_forward, model)
        model.model.forward = types.MethodType(model_model_forward, model.model)
        self.num_layers = len(model.model.layers)

        for layer in model.model.layers:
            layer.forward = types.MethodType(layer_forward, layer)
            ring_attention = create_ring_flash_attention_forward(None, 1)[0]
            # Bind via a default argument: a bare closure would resolve
            # ring_attention at call time, so every layer would end up calling
            # the function created in the last loop iteration.
            layer.self_attn.ring_attention = (
                lambda *args, _fn=ring_attention, **kwargs: _fn(*args, **kwargs))
            layer.self_attn.forward = types.MethodType(self_attn_forward, layer.self_attn)

        super().__init__(model, save_ckp, load_ckp)

    def forward(self, input_ids, labels):
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        # Cumulative sequence lengths handed to the ring flash attention kernels;
        # torch.arange yields the boundaries [0, 1, ..., seq_len].
        cu_seqlens = torch.arange(input_ids.shape[-1] + 1, dtype=torch.int32, device=rank)
        update_ring_flash_attn_params(cu_seqlens, None)

        # Each rank runs the model on its own contiguous chunk of the sequence dim.
        input_ids_chunk = torch.chunk(input_ids, world_size, dim=1)[rank]
        logits_chunk = self.model(input_ids=input_ids_chunk)
        labels_chunk = torch.chunk(labels, world_size, dim=1)[rank]

        loss = F.cross_entropy(
            logits_chunk.view(-1, logits_chunk.shape[-1]),
            labels_chunk.view(-1),
            reduction='mean')

        return loss
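
For intuition, the sequence-parallel split in forward behaves like this (standalone sketch; world_size and rank stand in for dist.get_world_size() / dist.get_rank()):

import torch

world_size, seq_len = 2, 8
input_ids = torch.arange(seq_len)[None, :]  # shape (1, 8)
for rank in range(world_size):
    chunk = torch.chunk(input_ids, world_size, dim=1)[rank]
    print(rank, chunk.tolist())
# 0 [[0, 1, 2, 3]]
# 1 [[4, 5, 6, 7]]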

I am training on 2×A100 80G. With context lengths up to 60K I can train LLaMA3-8B normally, with the following test results:

Context length    Time      Memory
10240             11.826     7161.762
20480             13.819    14339.225
30720             16.012    19535.596
40960             18.444    24731.967
51200             21.275    29929.838
61440             23.661    35130.709

But beyond 60K, it fails with the following error:

[rank0]:[E ProcessGroupNCCL.cpp:1414] [PG 0 Rank 0] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f2071379897 in /mnt/petrelfs/liwenhao/miniconda3/envs/llava/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f2071329b25 in /mnt/petrelfs/liwenhao/miniconda3/envs/llava/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f2071451718 in /mnt/petrelfs/liwenhao/miniconda3/envs/llava/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f207264fe36 in /mnt/petrelfs/liwenhao/miniconda3/envs/llava/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x58 (0x7f2072653f38 in /mnt/petrelfs/liwenhao/miniconda3/envs/llava/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x77c (0x7f20726595ac in /mnt/petrelfs/liwenhao/miniconda3/envs/llava/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f207265a31c in /mnt/petrelfs/liwenhao/miniconda3/envs/llava/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdbbf4 (0x7f20be0ffbf4 in /mnt/hwfile/liwenhao/miniconda3/envs/llava/bin/../lib/libstdc++.so.6)
frame #8: <unknown function> + 0x7dd5 (0x7f20c606edd5 in /lib64/libpthread.so.0)
frame #9: clone + 0x6d (0x7f20c568eead in /lib64/libc.so.6)

The library versions I am using are as follows:
transformers=4.45
flash_attention=2.5.8
torch=2.3.0+cu121
