Skip to content

Commit a42a36d

Browse files
authored
(torchx/aws_batch) Add user tag and runopt. Make list API only return torchx jobs by filtering on job tag (#672)
1 parent a3e15a0 commit a42a36d

File tree

3 files changed

+101
-67
lines changed

3 files changed

+101
-67
lines changed

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
https://docs.aws.amazon.com/AmazonECR/latest/userguide/getting-started-cli.html#cli-create-repository
3535
for how to create a image repository.
3636
"""
37-
37+
import getpass
3838
import re
3939
import threading
4040
from dataclasses import dataclass
@@ -79,6 +79,11 @@
7979
from torchx.workspace.docker_workspace import DockerWorkspaceMixin
8080
from typing_extensions import TypedDict
8181

82+
TAG_TORCHX_VER = "torchx.pytorch.org/version"
83+
TAG_TORCHX_APPNAME = "torchx.pytorch.org/app-name"
84+
TAG_TORCHX_USER = "torchx.pytorch.org/user"
85+
86+
8287
if TYPE_CHECKING:
8388
from docker import DockerClient
8489

@@ -244,6 +249,7 @@ def _local_session() -> "boto3.session.Session":
244249

245250
class AWSBatchOpts(TypedDict, total=False):
246251
queue: str
252+
user: str
247253
image_repo: Optional[str]
248254
share_id: Optional[str]
249255
priority: Optional[int]
@@ -417,8 +423,9 @@ def _submit_dryrun(self, app: AppDef, cfg: AWSBatchOpts) -> AppDryRunInfo[BatchJ
417423
],
418424
},
419425
"tags": {
420-
"torchx.pytorch.org/version": torchx.__version__,
421-
"torchx.pytorch.org/app-name": app.name,
426+
TAG_TORCHX_VER: torchx.__version__,
427+
TAG_TORCHX_APPNAME: app.name,
428+
TAG_TORCHX_USER: cfg.get("user"),
422429
},
423430
},
424431
**(
@@ -455,6 +462,12 @@ def _cancel_existing(self, app_id: str) -> None:
455462
def _run_opts(self) -> runopts:
456463
opts = runopts()
457464
opts.add("queue", type_=str, help="queue to schedule job in", required=True)
465+
opts.add(
466+
"user",
467+
type_=str,
468+
default=getpass.getuser(),
469+
help="The username to tag the job with. `getpass.getuser()` if not specified.",
470+
)
458471
opts.add(
459472
"share_id",
460473
type_=str,
@@ -582,36 +595,50 @@ def list(self) -> List[ListAppResponse]:
582595
for resp in self._client.get_paginator("describe_job_queues").paginate():
583596
queue_names = [queue["jobQueueName"] for queue in resp["jobQueues"]]
584597
for qn in queue_names:
585-
apps_in_queue = self._list_by_queue(qn)
586-
all_apps += [
587-
ListAppResponse(
588-
app_id=f"{qn}:{app['jobName']}",
589-
state=JOB_STATE[app["status"]],
590-
)
591-
for app in apps_in_queue
592-
]
598+
all_apps.extend(self._list_by_queue(qn))
593599
return all_apps
594600

595-
def _list_by_queue(self, queue_name: str) -> List[Dict[str, Any]]:
596-
# By default only running jobs are listed by batch/boto client's list_jobs API
601+
def _list_by_queue(self, queue_name: str) -> List[ListAppResponse]:
602+
# By default, only running jobs are listed by batch/boto client's list_jobs API
597603
# When 'filters' parameter is specified, jobs with all statuses are listed
598604
# So use AFTER_CREATED_AT filter to list jobs in all statuses
599605
# MS_AFTER_EPOCH (milliseconds since the epoch) can later be used to list jobs by timeframe
600-
milli_seconds_after_epoch = "1"
601-
job_summary_list = []
606+
MS_AFTER_EPOCH = "1"
607+
EVERY_STATUS = {"name": "AFTER_CREATED_AT", "values": [MS_AFTER_EPOCH]}
608+
609+
jobs = []
602610
for resp in self._client.get_paginator("list_jobs").paginate(
603611
jobQueue=queue_name,
604-
filters=[
605-
{
606-
"name": "AFTER_CREATED_AT",
607-
"values": [
608-
milli_seconds_after_epoch,
609-
],
610-
},
611-
],
612+
filters=[EVERY_STATUS],
613+
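# describe_jobs API accepts at most 100 jobIds per call (NOTE: MaxItems caps the TOTAL items returned across all pages; PageSize is the per-page limit)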
614+
PaginationConfig={"MaxItems": 100},
612615
):
613-
job_summary_list.extend(resp["jobSummaryList"])
614-
return job_summary_list
616+
617+
# torchx.pytorch.org/version tag is used to filter torchx jobs
618+
# the list_jobs() API only returns job summaries, which do not include the jobs' tags,
619+
# so we need to call the describe_jobs API.
620+
# Ideally, Batch would let us pass tags as a filter to the list_jobs API,
621+
# but this is currently not supported
622+
job_ids = [js["jobId"] for js in resp["jobSummaryList"]]
623+
for jobdesc in self._get_torchx_submitted_jobs(job_ids):
624+
jobs.append(
625+
ListAppResponse(
626+
app_id=f"{queue_name}:{jobdesc['jobName']}",
627+
state=JOB_STATE[jobdesc["status"]],
628+
)
629+
)
630+
631+
return jobs
632+
633+
def _get_torchx_submitted_jobs(self, job_ids: List[str]) -> List[Dict[str, Any]]:
634+
if not job_ids:
635+
return []
636+
637+
return [
638+
jobdesc
639+
for jobdesc in self._client.describe_jobs(jobs=job_ids)["jobs"]
640+
if TAG_TORCHX_VER in jobdesc["tags"]
641+
]
615642

616643
def _stream_events(
617644
self,

torchx/schedulers/test/aws_batch_scheduler_test.py

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -94,41 +94,50 @@ def test_create_scheduler(self) -> None:
9494
self.assertIsInstance(scheduler, AWSBatchScheduler)
9595

9696
def test_submit_dryrun_with_share_id(self) -> None:
97-
scheduler = create_scheduler("test")
9897
app = _test_app()
9998
cfg = AWSBatchOpts({"queue": "testqueue", "share_id": "fooshare"})
100-
info = scheduler._submit_dryrun(app, cfg)
99+
info = create_scheduler("test").submit_dryrun(app, cfg)
101100

102101
req = info.request
103102
job_def = req.job_def
104103
self.assertEqual(req.share_id, "fooshare")
105104
# must be set for jobs submitted to a queue with scheduling policy
106105
self.assertEqual(job_def["schedulingPriority"], 0)
107106

108-
def test_submit_dryrun_with_priority(self) -> None:
109-
scheduler = create_scheduler("test")
110-
app = _test_app()
111-
107+
def test_submit_dryrun_with_priority_but_not_share_id(self) -> None:
112108
with self.assertRaisesRegex(ValueError, "config value.*priority.*share_id"):
113109
cfg = AWSBatchOpts({"queue": "testqueue", "priority": 42})
114-
info = scheduler._submit_dryrun(app, cfg)
110+
create_scheduler("test").submit_dryrun(_test_app(), cfg)
115111

116-
cfg = AWSBatchOpts(
117-
{"queue": "testqueue", "share_id": "fooshare", "priority": 42}
118-
)
119-
info = scheduler._submit_dryrun(app, cfg)
112+
def test_submit_dryrun_with_priority(self) -> None:
113+
cfg = AWSBatchOpts({"queue": "testqueue", "share_id": "foo", "priority": 42})
114+
info = create_scheduler("test").submit_dryrun(_test_app(), cfg)
120115

121116
req = info.request
122117
job_def = req.job_def
123-
self.assertEqual(req.share_id, "fooshare")
118+
self.assertEqual(req.share_id, "foo")
124119
self.assertEqual(job_def["schedulingPriority"], 42)
125120

121+
@patch(
122+
"torchx.schedulers.aws_batch_scheduler.getpass.getuser", return_value="testuser"
123+
)
124+
def test_submit_dryrun_tags(self, _) -> None:
125+
# intentionally not specifying user in cfg to test default
126+
cfg = AWSBatchOpts({"queue": "ignored_in_test"})
127+
info = create_scheduler("test").submit_dryrun(_test_app(), cfg)
128+
self.assertEqual(
129+
{
130+
"torchx.pytorch.org/version": torchx.__version__,
131+
"torchx.pytorch.org/app-name": "test",
132+
"torchx.pytorch.org/user": "testuser",
133+
},
134+
info.request.job_def["tags"],
135+
)
136+
126137
@mock_rand()
127138
def test_submit_dryrun(self) -> None:
128-
scheduler = create_scheduler("test")
129-
app = _test_app()
130-
cfg = AWSBatchOpts({"queue": "testqueue"})
131-
info = scheduler._submit_dryrun(app, cfg)
139+
cfg = AWSBatchOpts({"queue": "testqueue", "user": "testuser"})
140+
info = create_scheduler("test").submit_dryrun(_test_app(), cfg)
132141

133142
req = info.request
134143
self.assertEqual(req.share_id, None)
@@ -248,6 +257,7 @@ def test_submit_dryrun(self) -> None:
248257
"tags": {
249258
"torchx.pytorch.org/version": torchx.__version__,
250259
"torchx.pytorch.org/app-name": "test",
260+
"torchx.pytorch.org/user": "testuser",
251261
},
252262
},
253263
)
@@ -440,13 +450,13 @@ def _mock_scheduler(self) -> AWSBatchScheduler:
440450
{
441451
"jobArn": "arn:aws:batch:us-west-2:495572122715:job/6afc27d7-3559-43ca-89fd-1007b6bf2546",
442452
"jobId": "6afc27d7-3559-43ca-89fd-1007b6bf2546",
443-
"jobName": "echo-v1r560pmwn5t3c",
453+
"jobName": "app-name-42",
444454
"createdAt": 1643949940162,
445455
"status": "SUCCEEDED",
446456
"stoppedAt": 1643950324125,
447457
"container": {"exitCode": 0},
448458
"nodeProperties": {"numNodes": 2},
449-
"jobDefinition": "arn:aws:batch:us-west-2:495572122715:job-definition/echo-v1r560pmwn5t3c:1",
459+
"jobDefinition": "arn:aws:batch:us-west-2:495572122715:job-definition/app-name-42:1",
450460
}
451461
]
452462
}
@@ -568,11 +578,7 @@ def _mock_scheduler(self) -> AWSBatchScheduler:
568578
def test_submit(self) -> None:
569579
scheduler = self._mock_scheduler()
570580
app = _test_app()
571-
cfg = AWSBatchOpts(
572-
{
573-
"queue": "testqueue",
574-
}
575-
)
581+
cfg = AWSBatchOpts({"queue": "testqueue"})
576582

577583
info = scheduler._submit_dryrun(app, cfg)
578584
id = scheduler.schedule(info)
@@ -610,34 +616,35 @@ def test_describe(self) -> None:
610616
def test_list(self) -> None:
611617
scheduler = self._mock_scheduler()
612618
expected_apps = [
613-
ListAppResponse(
614-
app_id="torchx:echo-v1r560pmwn5t3c", state=AppState.SUCCEEDED
615-
)
619+
ListAppResponse(app_id="torchx:app-name-42", state=AppState.SUCCEEDED)
616620
]
617621
apps = scheduler.list()
618-
self.assertEqual(apps, expected_apps)
622+
self.assertEqual(expected_apps, apps)
623+
624+
def test_list_no_jobs(self) -> None:
625+
scheduler = AWSBatchScheduler("test", client=MagicMock())
626+
scheduler._client.get_paginator.side_effect = MockPaginator(
627+
describe_job_queues=[
628+
{
629+
"jobQueues": [
630+
{"jobQueueName": "torchx", "state": "ENABLED"},
631+
],
632+
}
633+
],
634+
list_jobs=[{"jobSummaryList": []}],
635+
)
636+
637+
self.assertEqual([], scheduler.list())
619638

620639
def test_log_iter(self) -> None:
621640
scheduler = self._mock_scheduler()
622641
logs = scheduler.log_iter("testqueue:app-name-42", "echo", k=1, regex="foo.*")
623-
self.assertEqual(
624-
list(logs),
625-
[
626-
"foo\n",
627-
"foobar\n",
628-
],
629-
)
642+
self.assertEqual(list(logs), ["foo\n", "foobar\n"])
630643

631644
def test_log_iter_running_job(self) -> None:
632645
scheduler = self._mock_scheduler_running_job()
633646
logs = scheduler.log_iter("testqueue:app-name-42", "echo", k=1, regex="foo.*")
634-
self.assertEqual(
635-
[
636-
"foo\n",
637-
"foobar\n",
638-
],
639-
list(logs),
640-
)
647+
self.assertEqual(["foo\n", "foobar\n"], list(logs))
641648

642649
def test_local_session(self) -> None:
643650
a: object = _local_session()

torchx/specs/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,7 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
794794

795795
def cfg_from_str(self, cfg_str: str) -> Dict[str, CfgVal]:
796796
"""
797-
Parses scheduler ``runcfg`` from a string literal and returns
797+
Parses scheduler ``cfg`` from a string literal and returns
798798
a cfg map where the cfg values have been cast into the appropriate
799799
types as specified by this runopts object. Unknown keys are ignored
800800
and not returned in the resulting map.

0 commit comments

Comments
 (0)