diff --git a/dist_run.py b/dist_run.py
index fc580ea2a..3fbb857c7 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -273,13 +273,11 @@ def main(args):
     pp_rank = pp_mesh.get_local_rank()
     tp_group = tp_mesh.get_group()
     pp_group = pp_mesh.get_group()
-    pp_group_size = pp_group.size()
-    tp_group_size = tp_group.size()
-    logger.info(f"{pp_group_size=}, {tp_group_size=}")
+    logger.info(f"{pp_degree=}, {tp_degree=}")
 
     # Convenience variables
     first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
+    last_pp_rank = pp_degree - 1
 
     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
@@ -297,18 +295,22 @@ def main(args):
     if rank == 0:
         logger.info(f"Model: {model}")
 
-    mbs = 1  # number of micro-batches
-    mb_size = 4  # micro-batch size
-    batch_size = mbs * mb_size  # total batch size
-
+    # Batch size. Since we push batches dynamically through the pipeline rather
+    # than chunking them, this is effectively the micro-batch size in the
+    # pipeline sense. Thus it is interchangeable with the micro-batch size below.
+    batch_size = 4
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension
 
     # Setup KV caches (after model distribution)
-    # TODO: the setting below only works for 1 micro-batch case. To support
-    # multiple micro-batches, we need the KV cache in the model to be aware of
-    # the number of micro-batches and the current micro-batch index.
-    model.setup_caches(mb_size, seqlen_prefill)
+    # The number of cache lanes is the same as the maximum number of
+    # micro-batches that can be "in flight" in parallel -- imagine each
+    # micro-batch takes 1 "pipeline lane", so they need distinct KV cache spaces.
+    # When decoding is done for certain micro-batches, we can reuse their KV
+    # cache lanes.
+    # TODO: bump up the lane count
+    pipeline_lanes = 1
+    model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)
 
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
@@ -317,7 +319,7 @@ def main(args):
     model.to(device)
 
     logger.info(
-        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
     )
 
     # info on stage size and params
@@ -330,17 +332,16 @@ def main(args):
 
     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
     input_pos = torch.arange(seqlen_prefill, device=device)
-    model.setup_input_pos(input_pos)
     model.eval()
 
     # Helper function to get example inputs and outputs for the stages.
     def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
+        mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
         activation = torch.rand(
-            mb_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, dim, device=device, dtype=model_dtype
         )
         logits = torch.rand(
-            mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+            batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
         )
         example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
         example_outputs = (logits if pp_rank == last_pp_rank else activation,)
@@ -358,8 +359,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         output_args=example_outputs,
         group=pp_group,
     )
-    # create schedule
-    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
+
+    # Create schedule
+    # The number of micro-batches for the schedule is 1, because in each step()
+    # call we only push 1 micro-batch into the pipeline. But we can continuously
+    # push new micro-batches into the pipeline as they arrive, achieving the
+    # same pipelining effect.
+    prefiller = ScheduleGPipe(prefill_stage, 1)
 
     prompt = [
         "What is a computer?",
@@ -388,7 +394,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     s = set(prompt_lengths)
     assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
 
-    # with CUDATrackTime() as timer:
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
@@ -401,14 +406,21 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     num_tokens = 40
 
     # Prefill phase
-    # Run context input through pipeline, in 1 step
-    with torch.no_grad():
+    # Run context input through pipeline
+    # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
+    lane = 0
+    kwargs = {"input_pos": input_pos, "cache_lane": lane}
+    with torch.no_grad(), CUDATrackTime() as timer:
         if pp_rank == first_pp_rank:
-            output = prefill_schedule.step(padded_sequence)
+            output = prefiller.step(padded_sequence, **kwargs)
         elif pp_rank == last_pp_rank:
-            output = prefill_schedule.step()
+            output = prefiller.step(**kwargs)
         else:  # middle pp ranks
-            prefill_schedule.step()
+            prefiller.step(**kwargs)
+
+    logger.info(
+        f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )
 
     # Decode the output -- first generated token
     if pp_rank == last_pp_rank:
@@ -430,7 +442,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # seqlen = 1 now
     seqlen_decode = 1
     input_pos = torch.tensor([prompt_lengths[0]], device=device)
-    model.setup_input_pos(input_pos)
 
     # Create decode stage
     logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
@@ -445,11 +456,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decode_schedule = ScheduleGPipe(decode_stage, mbs)
+    decorder = ScheduleGPipe(decode_stage, 1)
 
     # Decoding
-    with torch.no_grad():
+    with torch.no_grad(), CUDATrackTime() as timer:
         for step in range(num_tokens - 1):
+            kwargs = {"input_pos": input_pos, "cache_lane": lane}
             # sendrecv between last and first ranks, only if:
             # first_pp_rank != last_pp_rank.
             if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
@@ -467,11 +479,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decode_schedule.step(new_token)
+                output = decorder.step(new_token, **kwargs)
             elif pp_rank == last_pp_rank:
-                output = decode_schedule.step()
+                output = decorder.step(**kwargs)
             else:  # middle pp ranks
-                decode_schedule.step()
+                decorder.step(**kwargs)
 
             # Decode the output
             if pp_rank == last_pp_rank:
@@ -491,7 +503,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
                 )  # decode_results[i][0]
 
             input_pos += 1
-            model.setup_input_pos(input_pos)
+
+    logger.info(
+        f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )
 
     # Display the decoding results
 
diff --git a/torchchat/export.py b/torchchat/export.py
index affb8b871..263c3815a 100644
--- a/torchchat/export.py
+++ b/torchchat/export.py
@@ -152,9 +152,9 @@ def __init__(self, attention: Attention):
         self.wo = attention.wo
 
         max_batch_size, n_heads, max_seq_length, head_dim = (
-            attention.kv_cache.k_cache.shape
+            attention.kv_cache[0].k_cache.shape
        )
-        cache_dtype = attention.kv_cache.k_cache.dtype
+        cache_dtype = attention.kv_cache[0].k_cache.dtype
         self.kv_cache = CustomKVCache(
             max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
         )
diff --git a/torchchat/model.py b/torchchat/model.py
index aaa72cb2a..228b97c3d 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -606,7 +606,7 @@ def __init__(self, config: TransformerArgs) -> None:
         self.max_batch_size = -1
         self.max_seq_length = -1
 
-    def setup_caches(self, max_batch_size, max_seq_length):
+    def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
         if (
             self.max_seq_length >= max_seq_length
             and self.max_batch_size >= max_batch_size
@@ -620,7 +620,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
             # parallelism may have been applied there and the `n_local_heads``
             # value being adjusted.
             b.attention.setup_cache(
-                max_batch_size, max_seq_length,
+                max_batch_size, max_seq_length, cache_lanes=cache_lanes
             )
 
         freqs_cis = precompute_freqs_cis(
@@ -653,22 +653,15 @@ def distribute(self, device_mesh: DeviceMesh):
             ColwiseParallel(output_layouts=Replicate()),
         )
 
-    # This is a temporary solution to pass input_pos to non-0 pipeline stages
-    # TODO: make `step()` function of dist.pipelining accept args for non-0 stages
-    def setup_input_pos(self, input_pos: Tensor) -> None:
-        self._input_pos = input_pos
-
-    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 0) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
 
-        # TODO: find a better way to pass input_pos to non-0 pipeline stages
-        input_pos = input_pos if input_pos is not None else self._input_pos
         mask = self.causal_mask[None, None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
         if self.tok_embeddings:
             x = self.tok_embeddings(x)
         for _, layer in self.layers.items():
-            x = layer(x, input_pos, freqs_cis, mask)
+            x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane)
         if self.norm:
             x = self.norm(x)
 
@@ -691,7 +684,7 @@ def distribute(self, device_mesh: DeviceMesh):
         self.feed_forward.distribute(device_mesh)
 
     def forward(
-        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
+        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, cache_lane: int = 0
     ) -> Tensor:
         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
         out = h + self.feed_forward(self.ffn_norm(h))
@@ -723,15 +716,16 @@ def __init__(self, config: TransformerArgs):
         self.dim = config.dim
         self._register_load_state_dict_pre_hook(self.load_hook)
 
-    def setup_cache(self, max_batch_size, max_seq_length):
+    def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
         n_local_heads = self.n_local_heads
         # If TP is enabled, the heads would be divided and assigned to different ranks
        if hasattr(self, "tp_degree"):
             n_local_heads = self.n_local_heads // self.tp_degree
 
-        self.kv_cache = KVCache(
-            max_batch_size, max_seq_length, n_local_heads, self.head_dim
-        )
+        self.kv_cache = nn.ModuleList([
+            KVCache(max_batch_size, max_seq_length, n_local_heads, self.head_dim)
+            for _ in range(cache_lanes)
+        ])
 
     def load_hook(self, state_dict, prefix, *args):
         # if prefix + "wq.weight" in state_dict:
@@ -784,6 +778,7 @@ def forward(
         freqs_cis: Tensor,
         mask: Tensor,
         input_pos: Optional[Tensor] = None,
+        cache_lane: int = 0,
     ) -> Tensor:
         bsz, seqlen, _ = x.shape
 
@@ -809,7 +804,7 @@ def forward(
         q, k, v = (x.transpose(1, 2) for x in (q, k, v))
 
         if self.kv_cache is not None:
-            k, v = self.kv_cache.update(input_pos, k, v)
+            k, v = self.kv_cache[cache_lane].update(input_pos, k, v)
 
         k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
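
For readers skimming the patch, the per-lane KV cache pattern it introduces in `Attention.setup_cache` boils down to holding one cache per pipeline lane in an `nn.ModuleList` and indexing it with `cache_lane` in `forward`. Below is a minimal, self-contained sketch of that pattern; `ToyKVCache` and `ToyAttention` are illustrative stand-ins, not torchchat classes.

```python
import torch
import torch.nn as nn


class ToyKVCache(nn.Module):
    """Simplified stand-in for a KV cache holding keys/values for one lane."""

    def __init__(self, max_batch_size, n_heads, max_seq_length, head_dim):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, input_pos, k, v):
        # Write the new keys/values at the given positions, return full caches.
        self.k_cache[:, :, input_pos] = k
        self.v_cache[:, :, input_pos] = v
        return self.k_cache, self.v_cache


class ToyAttention(nn.Module):
    """Keeps one KV cache per pipeline lane; forward picks the lane to use."""

    def setup_cache(self, max_batch_size, max_seq_length, n_heads, head_dim,
                    cache_lanes: int = 1):
        # One independent cache per lane, registered as submodules so they
        # follow .to(device) and show up in state_dict.
        self.kv_cache = nn.ModuleList(
            [ToyKVCache(max_batch_size, n_heads, max_seq_length, head_dim)
             for _ in range(cache_lanes)]
        )

    def forward(self, k, v, input_pos, cache_lane: int = 0):
        # Each in-flight micro-batch reads and writes only its own lane.
        return self.kv_cache[cache_lane].update(input_pos, k, v)


attn = ToyAttention()
attn.setup_cache(max_batch_size=2, max_seq_length=16, n_heads=4, head_dim=8,
                 cache_lanes=2)
k0 = v0 = torch.randn(2, 4, 3, 8)  # (batch, heads, new tokens, head_dim)
k1 = v1 = torch.randn(2, 4, 3, 8)
attn(k0, v0, input_pos=torch.arange(3), cache_lane=0)  # micro-batch A writes lane 0
attn(k1, v1, input_pos=torch.arange(3), cache_lane=1)  # micro-batch B writes lane 1
# Each micro-batch only touched its own lane, so the caches differ.
print(torch.equal(attn.kv_cache[0].k_cache, attn.kv_cache[1].k_cache))  # False
```

Using `nn.ModuleList` rather than a plain Python list is what keeps the per-lane `k_cache`/`v_cache` buffers attached to the module, which is also why the torchchat/export.py hunk now reads the shape and dtype from `attention.kv_cache[0]`.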
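
The comment about reusing KV cache lanes once a micro-batch finishes decoding hints at a small piece of bookkeeping on the driver side. A hypothetical sketch (not part of this patch; `LaneAllocator` is not a torchchat name) of how lanes could be handed out and recycled:

```python
from collections import deque


class LaneAllocator:
    """Hypothetical helper: hand out KV-cache lanes to in-flight micro-batches
    and recycle them when a micro-batch finishes decoding."""

    def __init__(self, pipeline_lanes: int):
        self.free = deque(range(pipeline_lanes))

    def acquire(self) -> int:
        if not self.free:
            raise RuntimeError("All cache lanes busy; wait for a micro-batch to finish")
        return self.free.popleft()

    def release(self, lane: int) -> None:
        self.free.append(lane)


# Example: 2 lanes, 3 micro-batches arriving over time.
lanes = LaneAllocator(pipeline_lanes=2)
lane_a = lanes.acquire()  # micro-batch A prefills/decodes on lane 0
lane_b = lanes.acquire()  # micro-batch B runs concurrently on lane 1
lanes.release(lane_a)     # A finished decoding -> lane 0 is reusable
lane_c = lanes.acquire()  # micro-batch C reuses lane 0
assert lane_c == lane_a
```

With `pipeline_lanes = 1` as in this patch, such an allocator degenerates to always handing out lane 0; it only becomes interesting once the TODO to bump up the lane count is addressed.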