Skip to content

Commit dc70d90

Browse files
authored
(torchx/runopt) Allow runopt type to be builtin list[str] and dict[str,str]
Differential Revision: D78767495 Pull Request resolved: #1093
1 parent 4adf7f6 commit dc70d90

File tree

4 files changed

+104
-35
lines changed

4 files changed

+104
-35
lines changed

torchx/runner/config.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -278,14 +278,14 @@ def dump(
278278
continue
279279

280280
# serialize list elements with `;` delimiter (consistent with torchx cli)
281-
if opt.opt_type == List[str]:
281+
if opt.is_type_list_of_str:
282282
# deal with empty or None default lists
283283
if opt.default:
284284
# pyre-ignore[6] opt.default type checked already as List[str]
285285
val = ";".join(opt.default)
286286
else:
287287
val = _NONE
288-
elif opt.opt_type == Dict[str, str]:
288+
elif opt.is_type_dict_of_str:
289289
# deal with empty or None default dicts
290290
if opt.default:
291291
# pyre-ignore[16] opt.default type checked already as Dict[str, str]
@@ -536,26 +536,26 @@ def load(scheduler: str, f: TextIO, cfg: Dict[str, CfgVal]) -> None:
536536
# this also handles empty or None lists
537537
cfg[name] = None
538538
else:
539-
runopt = runopts.get(name)
539+
opt = runopts.get(name)
540540

541-
if runopt is None:
541+
if opt is None:
542542
log.warning(
543543
f"`{name} = {value}` was declared in the [{section}] section "
544544
f" of the config file but is not a runopt of `{scheduler}` scheduler."
545545
f" Remove the entry from the config file to no longer see this warning"
546546
)
547547
else:
548-
if runopt.opt_type is bool:
548+
if opt.opt_type is bool:
549549
# need to handle bool specially since str -> bool is based on
550550
# str emptiness not value (e.g. bool("False") == True)
551551
cfg[name] = config.getboolean(section, name)
552-
elif runopt.opt_type is List[str]:
552+
elif opt.is_type_list_of_str:
553553
cfg[name] = value.split(";")
554-
elif runopt.opt_type is Dict[str, str]:
554+
elif opt.is_type_dict_of_str:
555555
cfg[name] = {
556556
s.split(":", 1)[0]: s.split(":", 1)[1]
557557
for s in value.replace(",", ";").split(";")
558558
}
559559
else:
560560
# pyre-ignore[29]
561-
cfg[name] = runopt.opt_type(value)
561+
cfg[name] = opt.opt_type(value)

torchx/runner/test/config_test.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,22 +95,34 @@ def _run_opts(self) -> runopts:
9595
)
9696
opts.add(
9797
"l",
98-
type_=List[str],
98+
type_=list[str],
9999
default=["a", "b", "c"],
100100
help="a list option",
101101
)
102102
opts.add(
103-
"l_none",
103+
"l_typing",
104104
type_=List[str],
105+
default=["a", "b", "c"],
106+
help="a typing.List option",
107+
)
108+
opts.add(
109+
"l_none",
110+
type_=list[str],
105111
default=None,
106112
help="a None list option",
107113
)
108114
opts.add(
109115
"d",
110-
type_=Dict[str, str],
116+
type_=dict[str, str],
111117
default={"foo": "bar"},
112118
help="a dict option",
113119
)
120+
opts.add(
121+
"d_typing",
122+
type_=Dict[str, str],
123+
default={"foo": "bar"},
124+
help="a typing.Dict option",
125+
)
114126
opts.add(
115127
"d_none",
116128
type_=Dict[str, str],
@@ -151,6 +163,10 @@ def _run_opts(self) -> runopts:
151163
[test]
152164
s = my_default
153165
i = 100
166+
l = abc;def
167+
l_typing = ghi;jkl
168+
d = a:b,c:d
169+
d_typing = e:f,g:h
154170
"""
155171

156172
_MY_CONFIG2 = """#
@@ -387,6 +403,10 @@ def test_apply_dirs(self, _) -> None:
387403
self.assertEqual("runtime_value", cfg.get("s"))
388404
self.assertEqual(100, cfg.get("i"))
389405
self.assertEqual(1.2, cfg.get("f"))
406+
self.assertEqual({"a": "b", "c": "d"}, cfg.get("d"))
407+
self.assertEqual({"e": "f", "g": "h"}, cfg.get("d_typing"))
408+
self.assertEqual(["abc", "def"], cfg.get("l"))
409+
self.assertEqual(["ghi", "jkl"], cfg.get("l_typing"))
390410

391411
def test_dump_invalid_scheduler(self) -> None:
392412
with self.assertRaises(ValueError):
@@ -460,7 +480,7 @@ def test_dump_and_load_all_runopt_types(self, _) -> None:
460480

461481
# all runopts in the TestScheduler have defaults, just check against those
462482
for opt_name, opt in TestScheduler("test").run_opts():
463-
self.assertEqual(cfg.get(opt_name), opt.default)
483+
self.assertEqual(opt.default, cfg.get(opt_name))
464484

465485
def test_dump_and_load_all_registered_schedulers(self) -> None:
466486
# dump all the runopts for all registered schedulers

torchx/specs/api.py

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,60 @@ class runopt:
789789
is_required: bool
790790
help: str
791791

792+
@property
def is_type_list_of_str(self) -> bool:
    """
    Tells whether this option is declared as a list of strings.

    Both spellings of the type are recognized: the ``typing`` alias
    (``List[str]``) and the builtin generic (``list[str]``).

    Returns:
        bool: True if ``opt_type`` is ``List[str]`` or ``list[str]``, False otherwise.
    """
    # the typing alias and the builtin generic do not compare equal to each
    # other, so both spellings have to be listed explicitly
    list_of_str_types = (List[str], list[str])
    return self.opt_type in list_of_str_types
801+
802+
@property
def is_type_dict_of_str(self) -> bool:
    """
    Tells whether this option is declared as a dict of string keys to string values.

    Both spellings of the type are recognized: the ``typing`` alias
    (``Dict[str, str]``) and the builtin generic (``dict[str, str]``).

    Returns:
        bool: True if ``opt_type`` is ``Dict[str, str]`` or ``dict[str, str]``, False otherwise.
    """
    # the typing alias and the builtin generic do not compare equal to each
    # other, so both spellings have to be listed explicitly
    dict_of_str_types = (Dict[str, str], dict[str, str])
    return self.opt_type in dict_of_str_types
811+
812+
def cast_to_type(self, value: str) -> CfgVal:
    """Casts the given `value` (in its string representation) to the type of this run option.

    Below are the cast rules for each option type and value literal:

    1. opt_type=str, value="foo" -> "foo"
    1. opt_type=bool, value="True"/"False" -> True/False
       (case-insensitive; any literal other than "true" casts to False)
    1. opt_type=int, value="1" -> 1
    1. opt_type=float, value="1.1" -> 1.1
    1. opt_type=list[str]/List[str], value="a,b,c" or value="a;b;c" -> ["a", "b", "c"]
    1. opt_type=dict[str,str]/Dict[str,str],
       value="key1:val1,key2:val2" or value="key1:val1;key2:val2" -> {"key1": "val1", "key2": "val2"}

    Empty entries (e.g. the one produced by a trailing delimiter such as "a,b," or
    "k:v;") are dropped for both lists and dicts, so an empty literal casts to an
    empty list/dict.

    NOTE: dict parsing uses ":" as the kv separator (rather than the standard "=") because "=" is used
    at the top-level cfg to parse runopts (notice the plural) from the CLI. Originally torchx only supported
    primitives and list[str] as CfgVal but dict[str,str] was added in https://github.com/pytorch/torchx/pull/855

    Args:
        value: the string literal to cast

    Returns:
        `value` converted to this option's `opt_type`

    Raises:
        ValueError: if `opt_type` is `None`, or if a non-empty dict entry
            does not contain the ":" key-value separator
    """

    if self.opt_type is None:
        raise ValueError("runopt's opt_type cannot be `None`")
    elif self.opt_type == bool:
        return value.lower() == "true"
    elif self.opt_type in (List[str], list[str]):
        # lists may be ; or , delimited
        # also deal with trailing "," by removing empty strings
        return [v for v in value.replace(";", ",").split(",") if v]
    elif self.opt_type in (Dict[str, str], dict[str, str]):
        # dicts may be ; or , delimited (same as lists)
        parsed: Dict[str, str] = {}
        for entry in value.replace(";", ",").split(","):
            if not entry:
                # trailing delimiter or empty literal; mirror the list behavior
                continue
            # split on the first ":" only so that values may themselves contain ":"
            key, sep, val = entry.partition(":")
            if not sep:
                raise ValueError(
                    f"invalid dict entry `{entry}` in `{value}`,"
                    " expected `key:value` pairs delimited by `,` or `;`"
                )
            parsed[key] = val
        return parsed
    else:
        assert self.opt_type in (str, int, float)
        return self.opt_type(value)
845+
792846

793847
class runopts:
794848
"""
@@ -948,27 +1002,11 @@ def cfg_from_str(self, cfg_str: str) -> Dict[str, CfgVal]:
9481002
9491003
"""
9501004

951-
def _cast_to_type(value: str, opt_type: Type[CfgVal]) -> CfgVal:
952-
if opt_type == bool:
953-
return value.lower() == "true"
954-
elif opt_type in (List[str], list[str]):
955-
# lists may be ; or , delimited
956-
# also deal with trailing "," by removing empty strings
957-
return [v for v in value.replace(";", ",").split(",") if v]
958-
elif opt_type in (Dict[str, str], dict[str, str]):
959-
return {
960-
s.split(":", 1)[0]: s.split(":", 1)[1]
961-
for s in value.replace(";", ",").split(",")
962-
}
963-
else:
964-
# pyre-ignore[19, 6] type won't be dict here as we handled it above
965-
return opt_type(value)
966-
9671005
cfg: Dict[str, CfgVal] = {}
9681006
for key, val in to_dict(cfg_str).items():
969-
runopt_ = self.get(key)
970-
if runopt_:
971-
cfg[key] = _cast_to_type(val, runopt_.opt_type)
1007+
opt = self.get(key)
1008+
if opt:
1009+
cfg[key] = opt.cast_to_type(val)
9721010
else:
9731011
logger.warning(
9741012
f"{YELLOW_BOLD}Unknown run option passed to scheduler: {key}={val}{RESET}"
@@ -982,16 +1020,16 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
9821020
cfg: Dict[str, CfgVal] = {}
9831021
cfg_dict = json.loads(json_repr)
9841022
for key, val in cfg_dict.items():
985-
runopt_ = self.get(key)
986-
if runopt_:
1023+
opt = self.get(key)
1024+
if opt:
9871025
# Optional runopt cfg values default their value to None,
9881026
# but use `type_` to specify their type when provided.
9891027
# Make sure not to treat None's as lists/dictionaries
9901028
if val is None:
9911029
cfg[key] = val
992-
elif runopt_.opt_type == List[str]:
1030+
elif opt.is_type_list_of_str:
9931031
cfg[key] = [str(v) for v in val]
994-
elif runopt_.opt_type == Dict[str, str]:
1032+
elif opt.is_type_dict_of_str:
9951033
cfg[key] = {str(k): str(v) for k, v in val.items()}
9961034
else:
9971035
cfg[key] = val

torchx/specs/test/api_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
RetryPolicy,
3939
Role,
4040
RoleStatus,
41+
runopt,
4142
runopts,
4243
)
4344

@@ -437,6 +438,16 @@ def test_valid_values(self) -> None:
437438
self.assertTrue(cfg.get("preemptible"))
438439
self.assertIsNone(cfg.get("unknown"))
439440

441+
def test_runopt_cast_to_type_typing_list(self) -> None:
    """A ``typing.List[str]`` runopt casts both "," and ";" delimited literals."""
    list_opt = runopt(default="", opt_type=List[str], is_required=False, help="help")
    # (string literal, expected list) pairs, one per supported delimiter
    cases = (
        ("a,b,c", ["a", "b", "c"]),
        ("abc;def;ghi", ["abc", "def", "ghi"]),
    )
    for literal, expected in cases:
        self.assertEqual(expected, list_opt.cast_to_type(literal))
445+
446+
def test_runopt_cast_to_type_builtin_list(self) -> None:
    """A builtin ``list[str]`` runopt casts both "," and ";" delimited literals."""
    list_opt = runopt(default="", opt_type=list[str], is_required=False, help="help")
    # (string literal, expected list) pairs, one per supported delimiter
    cases = (
        ("a,b,c", ["a", "b", "c"]),
        ("abc;def;ghi", ["abc", "def", "ghi"]),
    )
    for literal, expected in cases:
        self.assertEqual(expected, list_opt.cast_to_type(literal))
450+
440451
def test_runopts_add(self) -> None:
441452
"""
442453
tests for various add option variations

0 commit comments

Comments
 (0)