Skip to content

Commit 1993aab

Browse files
authored
Add DDP component support in GCP Batch
Differential Revision: D42080776. Pull Request resolved: #669.
1 parent cdc9a76 commit 1993aab

File tree

5 files changed

+58
-20
lines changed

5 files changed

+58
-20
lines changed

dev-requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ boto3==1.20.24
55
captum>=0.4.0
66
flake8==3.9.0
77
fsspec[s3]==2022.1.0
8-
google-api-core>=2.0.1
9-
google-cloud-batch>=0.3.1
8+
google-api-core
9+
google-cloud-batch>=0.5.0
1010
google-cloud-logging>=3.0.0
1111
google-cloud-runtimeconfig>=0.33.2
1212
hydra-core

scripts/component_integration_tests.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def main() -> None:
5151
torchx_image = "dummy_image"
5252
dryrun = False
5353

54-
if scheduler in ("kubernetes", "local_docker", "aws_batch", "lsf"):
54+
if scheduler in ("kubernetes", "local_docker", "aws_batch", "lsf", "gcp_batch"):
5555
try:
5656
build = build_and_push_image()
5757
torchx_image = build.torchx_image
@@ -95,6 +95,13 @@ def main() -> None:
9595
"queue": "torchx",
9696
},
9797
},
98+
"gcp_batch": {
99+
"providers": [
100+
component_provider,
101+
],
102+
"image": torchx_image,
103+
"cfg": {},
104+
},
98105
"ray": {
99106
"providers": [
100107
component_provider,

scripts/gcpbatchint.sh

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ set -ex
99

1010
torchx runopts gcp_batch
1111

12-
APP_ID="$(torchx run --wait --scheduler gcp_batch utils.echo --msg hello)"
12+
APP_ID="$(torchx run --wait --scheduler gcp_batch dist.ddp -j 2x2 --max_retries 3 --script torchx/components/integration_tests/test/dummy_app.py)"
1313
torchx status "$APP_ID"
1414

1515
torchx list -s gcp_batch
@@ -19,3 +19,12 @@ then
1919
echo "expected $APP_ID to be listed"
2020
exit 1
2121
fi
22+
23+
torchx log "$APP_ID"
24+
EXPECTED_MSG="hi from main"
25+
LINES="$(torchx log "$APP_ID" | grep -c "$EXPECTED_MSG")"
26+
if [ "$LINES" -ne 4 ]
27+
then
28+
echo "expected 4 log lines with msg $EXPECTED_MSG"
29+
exit 1
30+
fi

torchx/schedulers/gcp_batch_scheduler.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,7 @@ def _app_to_job(self, app: AppDef) -> "batch_v1.Job":
182182
img_root="",
183183
app_id=name,
184184
replica_id=str(0),
185-
# TODO set value for rank0_env: TORCHX_RANK0_HOST is a place holder for now
186-
rank0_env=("TORCHX_RANK0_HOST"),
185+
rank0_env=("BATCH_MAIN_NODE_HOSTNAME"),
187186
)
188187
role_dict = values.apply(role)
189188
role_dict.env["TORCHX_ROLE_IDX"] = str(role_idx)
@@ -195,14 +194,12 @@ def _app_to_job(self, app: AppDef) -> "batch_v1.Job":
195194
if cpu <= 0:
196195
cpu = 1
197196
MILLI = 1000
198-
# pyre-ignore [8] : pyre gets confused even when types on both sides of = are int
199197
res.cpu_milli = cpu * MILLI
200198
memMB = resource.memMB
201199
if memMB < 0:
202200
raise ValueError(
203201
f"memMB should to be set to a positive value, got {memMB}"
204202
)
205-
# pyre-ignore [8] : pyre gets confused even when types on both sides of = are int
206203
res.memory_mib = memMB
207204

208205
# TODO support named resources
@@ -226,24 +223,40 @@ def _app_to_job(self, app: AppDef) -> "batch_v1.Job":
226223
)
227224
print(f"Using GPUs of type: {machineType}")
228225

226+
# Configure host firewall rules to accept ingress communication
227+
config_network_runnable = batch_v1.Runnable(
228+
script=batch_v1.Runnable.Script(
229+
text="/sbin/iptables -A INPUT -j ACCEPT"
230+
)
231+
)
232+
229233
runnable = batch_v1.Runnable(
230234
container=batch_v1.Runnable.Container(
231235
image_uri=role_dict.image,
232236
commands=[role_dict.entrypoint] + role_dict.args,
233237
entrypoint="",
238+
# Configure docker to use the host network stack to communicate with containers/other hosts in the same network
239+
options="--net host",
234240
)
235241
)
236242

237243
ts = batch_v1.TaskSpec(
238-
runnables=[runnable],
244+
runnables=[config_network_runnable, runnable],
239245
environment=batch_v1.Environment(variables=role_dict.env),
240246
max_retry_count=role_dict.max_retries,
241247
compute_resource=res,
242248
)
243249

250+
task_env = [
251+
batch_v1.Environment(variables={"TORCHX_REPLICA_IDX": str(i)})
252+
for i in range(role_dict.num_replicas)
253+
]
254+
244255
tg = batch_v1.TaskGroup(
245256
task_spec=ts,
246257
task_count=role_dict.num_replicas,
258+
task_count_per_node=1,
259+
task_environments=task_env,
247260
require_hosts_file=True,
248261
)
249262
taskGroups.append(tg)
@@ -338,37 +351,34 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
338351
return None
339352

340353
gpu = 0
341-
# pyre-fixme [16]: Pyre doesn't properly infer job field types
342354
if len(job.allocation_policy.instances) != 0:
343355
gpu_type = job.allocation_policy.instances[0].policy.machine_type
344356
gpu = GPU_TYPE_TO_COUNT[gpu_type]
345357

346358
roles = {}
347-
# pyre-fixme [16]: Pyre doesn't properly infer job field types
348359
for tg in job.task_groups:
349360
env = tg.task_spec.environment.variables
350361
role = env["TORCHX_ROLE_NAME"]
351-
container = tg.task_spec.runnables[0].container
362+
container = tg.task_spec.runnables[1].container
352363
roles[role] = Role(
353364
name=role,
354365
num_replicas=tg.task_count,
355366
image=container.image_uri,
356367
entrypoint=container.commands[0],
357-
args=container.commands[1:],
368+
args=list(container.commands[1:]),
358369
resource=Resource(
359370
cpu=int(tg.task_spec.compute_resource.cpu_milli / 1000),
360371
memMB=tg.task_spec.compute_resource.memory_mib,
361372
gpu=gpu,
362373
),
363-
env=env,
374+
env=dict(env),
364375
max_retries=tg.task_spec.max_retry_count,
365376
)
366377

367378
# Map job -> DescribeAppResponse
368379
# TODO map role/replica status
369380
desc = DescribeAppResponse(
370381
app_id=app_id,
371-
# pyre-fixme [16]: Pyre doesn't properly infer job field types
372382
state=JOB_STATE[job.status.state.name],
373383
roles=list(roles.values()),
374384
)

torchx/schedulers/test/gcp_batch_scheduler_test.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,7 @@ def test_submit_dryrun(self) -> None:
8181
env["TORCHX_ROLE_NAME"] = "trainer"
8282
env["FOO"] = "bar"
8383
res = batch_v1.ComputeResource()
84-
# pyre-ignore [8] : pyre gets confused even when types on both sides of = are int
8584
res.cpu_milli = 2000
86-
# pyre-ignore [8] : pyre gets confused even when types on both sides of = are int
8785
res.memory_mib = 3000
8886
allocationPolicy = batch_v1.AllocationPolicy(
8987
instances=[
@@ -95,6 +93,9 @@ def test_submit_dryrun(self) -> None:
9593
)
9694
],
9795
)
96+
preRunnable = batch_v1.Runnable(
97+
script=batch_v1.Runnable.Script(text="/sbin/iptables -A INPUT -j ACCEPT")
98+
)
9899
runnable = batch_v1.Runnable(
99100
container=batch_v1.Runnable.Container(
100101
image_uri="pytorch/torchx:latest",
@@ -105,12 +106,13 @@ def test_submit_dryrun(self) -> None:
105106
"--app-id",
106107
"app-name-42",
107108
"--rank0_env",
108-
"TORCHX_RANK0_HOST",
109+
"BATCH_MAIN_NODE_HOSTNAME",
109110
],
111+
options="--net host",
110112
)
111113
)
112114
ts = batch_v1.TaskSpec(
113-
runnables=[runnable],
115+
runnables=[preRunnable, runnable],
114116
environment=batch_v1.Environment(variables=env),
115117
max_retry_count=3,
116118
compute_resource=res,
@@ -119,6 +121,10 @@ def test_submit_dryrun(self) -> None:
119121
tg = batch_v1.TaskGroup(
120122
task_spec=ts,
121123
task_count=1,
124+
task_count_per_node=1,
125+
task_environments=[
126+
batch_v1.Environment(variables={"TORCHX_REPLICA_IDX": "0"})
127+
],
122128
require_hosts_file=True,
123129
)
124130
taskGroups.append(tg)
@@ -261,13 +267,19 @@ def _mock_scheduler(self) -> GCPBatchScheduler:
261267
batch_v1.TaskGroup(
262268
task_spec=batch_v1.TaskSpec(
263269
runnables=[
270+
batch_v1.Runnable(
271+
script=batch_v1.Runnable.Script(
272+
text="/sbin/iptables -A INPUT -j ACCEPT"
273+
)
274+
),
264275
batch_v1.Runnable(
265276
container=batch_v1.Runnable.Container(
266277
image_uri="ghcr.io/pytorch/torchx:0.3.0dev0",
267278
commands=["python"] + ["-c", 'print("hello ")'],
268279
entrypoint="",
280+
options="--net host",
269281
)
270-
)
282+
),
271283
],
272284
compute_resource=batch_v1.ComputeResource(
273285
cpu_milli=8000,

0 commit comments

Comments
 (0)