Skip to content

Commit 60c2df2

Browse files
authored
[0.9.1][bugfix] pd proxy enhance usability (#2538)
### What this PR does / why we need it? pd proxy enhance usability,add prefiller-hosts-num、prefiller-ports-inc、decoder-hosts-num、decoder-ports-inc paramter in pd proxy Signed-off-by: liziyu <[email protected]>
1 parent 4e578f5 commit 60c2df2

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

examples/disaggregate_prefill_v1/load_balance_proxy_server_example.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,15 +184,31 @@ def parse_args():
184184
type=str,
185185
nargs="+",
186186
default=["localhost"])
187+
parser.add_argument("--prefiller-hosts-num",
188+
type=int,
189+
nargs="+",
190+
default=None)
187191
parser.add_argument("--prefiller-ports",
188192
type=int,
189193
nargs="+",
190194
default=[8001])
195+
parser.add_argument("--prefiller-ports-inc",
196+
type=int,
197+
nargs="+",
198+
default=None)
191199
parser.add_argument("--decoder-hosts",
192200
type=str,
193201
nargs="+",
194202
default=["localhost"])
203+
parser.add_argument("--decoder-hosts-num",
204+
type=int,
205+
nargs="+",
206+
default=None)
195207
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
208+
parser.add_argument("--decoder-ports-inc",
209+
type=int,
210+
nargs="+",
211+
default=None)
196212
parser.add_argument("--max-retries",
197213
type=int,
198214
default=3,
@@ -209,6 +225,49 @@ def parse_args():
209225
if len(args.decoder_hosts) != len(args.decoder_ports):
210226
raise ValueError(
211227
"Number of decoder hosts must match number of decoder ports")
228+
if args.prefiller_hosts_num is not None and (len(args.prefiller_hosts_num)
229+
!= len(args.prefiller_hosts)):
230+
raise ValueError(
231+
"Number of prefiller hosts num must match number of prefiller hosts"
232+
)
233+
if args.prefiller_ports_inc is not None and (len(args.prefiller_ports_inc)
234+
!= len(args.prefiller_ports)):
235+
raise ValueError(
236+
"Number of prefiller ports inc must match number of prefiller ports"
237+
)
238+
if args.decoder_hosts_num is not None and (len(args.decoder_hosts_num) !=
239+
len(args.decoder_hosts)):
240+
raise ValueError(
241+
"Number of decoder hosts num must match number of decoder hosts")
242+
if args.decoder_ports_inc is not None and (len(args.decoder_ports_inc) !=
243+
len(args.decoder_ports)):
244+
raise ValueError(
245+
"Number of decoder ports inc must match number of decoder ports")
246+
247+
if args.prefiller_hosts_num is not None:
248+
args.prefiller_hosts = [
249+
host for host, num in zip(args.prefiller_hosts,
250+
args.prefiller_hosts_num)
251+
for _ in range(num)
252+
]
253+
if args.prefiller_ports_inc is not None:
254+
args.prefiller_ports = [(int(port) + i) for port, inc in zip(
255+
args.prefiller_ports, args.prefiller_ports_inc)
256+
for i in range(inc)]
257+
258+
if args.decoder_hosts_num is not None:
259+
args.decoder_hosts = [
260+
host
261+
for host, num in zip(args.decoder_hosts, args.decoder_hosts_num)
262+
for _ in range(num)
263+
]
264+
if args.decoder_ports_inc is not None:
265+
args.decoder_ports = [
266+
(int(port) + i)
267+
for port, inc in zip(args.decoder_ports, args.decoder_ports_inc)
268+
for i in range(inc)
269+
]
270+
212271
args.prefiller_instances = list(
213272
zip(args.prefiller_hosts, args.prefiller_ports))
214273
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))

0 commit comments

Comments
 (0)