|
34 | 34 | from vllm.utils import get_open_port
|
35 | 35 |
|
36 | 36 |
|
| 37 | +def parse_args(): |
| 38 | + import argparse |
| 39 | + parser = argparse.ArgumentParser(description="Data Parallel Inference") |
| 40 | + parser.add_argument("--model", |
| 41 | + type=str, |
| 42 | + default="ibm-research/PowerMoE-3b", |
| 43 | + help="Model name or path") |
| 44 | + parser.add_argument("--dp-size", |
| 45 | + type=int, |
| 46 | + default=2, |
| 47 | + help="Data parallel size") |
| 48 | + parser.add_argument("--tp-size", |
| 49 | + type=int, |
| 50 | + default=2, |
| 51 | + help="Tensor parallel size") |
| 52 | + parser.add_argument("--node-size", |
| 53 | + type=int, |
| 54 | + default=1, |
| 55 | + help="Total number of nodes") |
| 56 | + parser.add_argument("--node-rank", |
| 57 | + type=int, |
| 58 | + default=0, |
| 59 | + help="Rank of the current node") |
| 60 | + parser.add_argument("--master-addr", |
| 61 | + type=str, |
| 62 | + default="", |
| 63 | + help="Master node IP address") |
| 64 | + parser.add_argument("--master-port", |
| 65 | + type=int, |
| 66 | + default=0, |
| 67 | + help="Master node port") |
| 68 | + return parser.parse_args() |
| 69 | + |
| 70 | + |
37 | 71 | def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
|
38 | 72 | dp_master_port, GPUs_per_dp_rank):
|
39 | 73 | os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
|
@@ -95,37 +129,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
|
95 | 129 |
|
96 | 130 |
|
97 | 131 | if __name__ == "__main__":
|
98 |
| - import argparse |
99 |
| - parser = argparse.ArgumentParser(description="Data Parallel Inference") |
100 |
| - parser.add_argument("--model", |
101 |
| - type=str, |
102 |
| - default="ibm-research/PowerMoE-3b", |
103 |
| - help="Model name or path") |
104 |
| - parser.add_argument("--dp-size", |
105 |
| - type=int, |
106 |
| - default=2, |
107 |
| - help="Data parallel size") |
108 |
| - parser.add_argument("--tp-size", |
109 |
| - type=int, |
110 |
| - default=2, |
111 |
| - help="Tensor parallel size") |
112 |
| - parser.add_argument("--node-size", |
113 |
| - type=int, |
114 |
| - default=1, |
115 |
| - help="Total number of nodes") |
116 |
| - parser.add_argument("--node-rank", |
117 |
| - type=int, |
118 |
| - default=0, |
119 |
| - help="Rank of the current node") |
120 |
| - parser.add_argument("--master-addr", |
121 |
| - type=str, |
122 |
| - default="", |
123 |
| - help="Master node IP address") |
124 |
| - parser.add_argument("--master-port", |
125 |
| - type=int, |
126 |
| - default=0, |
127 |
| - help="Master node port") |
128 |
| - args = parser.parse_args() |
| 132 | + |
| 133 | + args = parse_args() |
129 | 134 |
|
130 | 135 | dp_size = args.dp_size
|
131 | 136 | tp_size = args.tp_size
|
|
0 commit comments