 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import math
 import os
-from typing import Callable
+from typing import Callable, Optional

 from torch.distributed.pipelining.schedules import (
     _PipelineSchedule,
...
 # TODO: It's unclear if this API is general enough to be used by other models.
 # If not, we should move it to a Transformer-specific directory.
 def generate_split_points(
-    pipeline_parallel_schedule: str, pp_dim: int, num_layers: int
+    schedule_str: str,
+    layers_per_stage: Optional[int],
+    pp_dim: int,
+    num_layers: int,
+    input_weight: int = 1,
+    output_weight: int = 1,
 ) -> list[str]:
     """
-    Generate a default split point based on the number of layers and
-    pipeline parallel dimension.
+    Generate a list of split points based on the number of layers and
+    pipeline parallel dimension, ensuring the first and last stages have the fewest layers.

     Args:
-        job_config (JobConfig): The job configuration.
+        schedule_str (str): The name of the pipeline schedule.
+        layers_per_stage (int): The number of layers per stage.
         pp_dim (int): The pipeline parallel dimension.
         num_layers (int): The number of layers in the model.
+        input_weight, output_weight (int): The number of layers the input/output modules count as in the layer calculation.

     Returns:
         list[str]: A list of split point FQNs.
     """

-    schedule_class = get_schedule_class(pipeline_parallel_schedule)
-    if issubclass(schedule_class, PipelineScheduleSingle):
-        num_stages_per_rank = 1
-    elif issubclass(schedule_class, PipelineScheduleMulti):
-        # Multi-stage schedules support more than 2 stages per rank, but this is the default if
-        # no pipeline split is specified
-        num_stages_per_rank = 2
+    schedule_class = get_schedule_class(schedule_str)
+    is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
+    num_stages_per_rank = 1 if is_single_stage_schedule else 2
+
+    if layers_per_stage is not None:
+        total_stages = math.ceil(num_layers / layers_per_stage)
+        if total_stages % pp_dim != 0:
+            raise ValueError(
+                f"Number of stages ({total_stages}) must be divisible by the pipeline parallel dimension ({pp_dim}). "
+                "Each rank should have the same number of stages."
+            )
+        num_stages_per_rank = total_stages // pp_dim
+
+        if is_single_stage_schedule and num_stages_per_rank != 1:
+            raise ValueError(
+                f"Number of stages per rank ({num_stages_per_rank}) must be 1 for single stage schedules."
+            )
+        elif not is_single_stage_schedule and num_stages_per_rank < 2:
+            raise ValueError(
+                f"Number of stages per rank ({num_stages_per_rank}) must be >= 2 for multi stage schedules."
+            )
     else:
-        raise ValueError(f"Unsupported pipeline schedule: {pipeline_parallel_schedule}")
-    total_stages = pp_dim * num_stages_per_rank
-    if total_stages > num_layers:
-        raise ValueError("Total stages cannot be greater than the number of layers")
-
-    base_interval = num_layers // total_stages
-    extra_layers = num_layers % total_stages
-
-    splits = []
-    current_layer = 0
-    for i in range(total_stages - 1):
-        if i == 0:
-            current_layer += base_interval
-        else:
-            # Middle stages get an extra layer if there are any remaining
-            if extra_layers > 0:
-                current_layer += base_interval + 1
-                extra_layers -= 1
-            else:
-                current_layer += base_interval
-        splits.append("layers." + str(current_layer))
+        total_stages = pp_dim * num_stages_per_rank
+        if total_stages > num_layers:
+            raise ValueError("Total stages cannot be greater than the number of layers")
+
+    # Calculate effective number of layers including input and output weights
+    effective_num_layers = num_layers + input_weight + output_weight
+    base_layers_per_stage = effective_num_layers // total_stages
+
+    splits = [""] * (total_stages - 1)
+    current_layer_index = 0
+
+    # First stage
+    layers_on_first_stage = max(0, base_layers_per_stage - input_weight)
+    current_layer_index += layers_on_first_stage
+    splits[0] = "layers." + str(current_layer_index)
+
+    # Last stage
+    layers_on_last_stage = max(0, base_layers_per_stage - output_weight)
+    splits[-1] = "layers." + str(num_layers - layers_on_last_stage)
+
+    # Middle stages
+    remaining_layers = num_layers - layers_on_first_stage - layers_on_last_stage - 1
+    middle_stages = len(splits) - 2
+    layers_per_middle_stage = remaining_layers // middle_stages
+    # split remainder evenly across middle stages
+    remainder = remaining_layers % middle_stages
+
+    for i in range(1, middle_stages + 1):
+        current_layer_index += layers_per_middle_stage
+        if remainder > 0:
+            current_layer_index += 1
+            remainder -= 1
+        splits[i] = "layers." + str(current_layer_index)
+
     logger.info(
         f"No 'pipeline_parallel_split_points' provided so the generated splits are: {splits} "
         "This may be sub-optimal as the number of layers per stage may be unbalanced."
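
For a quick sense of the updated signature, below is a minimal usage sketch. It is not part of this PR; the schedule name, layer count, and pp_dim are illustrative assumptions chosen so the new validation passes.

# Hypothetical call sites for the updated generate_split_points().
# With layers_per_stage set: ceil(32 / 4) = 8 stages, and 8 % pp_dim == 0,
# so each of the 4 pipeline ranks runs 2 stages of the multi-stage schedule.
splits = generate_split_points(
    schedule_str="Interleaved1F1B",
    layers_per_stage=4,
    pp_dim=4,
    num_layers=32,
)

# With layers_per_stage=None, the schedule class alone decides stages per rank
# (1 for single-stage schedules, 2 for multi-stage), as before this change.
default_splits = generate_split_points("Interleaved1F1B", None, 4, 32)

Both calls return a list of split-point FQNs of the form "layers.N", which the caller uses to cut the model into pipeline stages.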