Skip to content

Commit 6b41f31

Browse files
c00wpytorchmergebot
authored andcommitted
config: Support str env variables (pytorch#145980)
Summary: This allows us to use environment variables to set string values. We've added tests for the specific functionality implemented here. Note that we already accidentally started setting up configs to use this, so we're just adding the feature. Additionally, we're not fully validating the underlying type when we set the value (and in general, it's more difficult than we would like to do this). Let me know if people feel strongly, and we can add a PR to do this. Pull Request resolved: pytorch#145980 Approved by: https://github.com/yushangdi, https://github.com/oulgen
1 parent a9ed7bd commit 6b41f31

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

test/test_utils_config_module.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
os.environ["ENV_TRUE"] = "1"
88
os.environ["ENV_FALSE"] = "0"
9+
os.environ["ENV_STR"] = "1234"
10+
os.environ["ENV_STR_EMPTY"] = ""
911

1012
from typing import Optional
1113

@@ -100,6 +102,12 @@ def test_env_name_semantics(self):
100102
config.e_env_force = False
101103
self.assertTrue(config.e_env_force)
102104

105+
def test_env_name_string_semantics(self):
106+
self.assertEqual(config.e_env_default_str, "1234")
107+
self.assertEqual(config.e_env_default_str_empty, "")
108+
config.e_env_default_str = "override"
109+
self.assertEqual(config.e_env_default_str, "override")
110+
103111
def test_multi_env(self):
104112
self.assertTrue(config2.e_env_default_multi)
105113
self.assertTrue(config2.e_env_force_multi)
@@ -129,6 +137,8 @@ def test_save_config(self):
129137
"e_jk_false": False,
130138
"e_env_default": True,
131139
"e_env_default_FALSE": False,
140+
"e_env_default_str": "1234",
141+
"e_env_default_str_empty": "",
132142
"e_env_force": True,
133143
"e_optional": True,
134144
},
@@ -161,6 +171,8 @@ def test_save_config_portable(self):
161171
"e_jk_false": False,
162172
"e_env_default": True,
163173
"e_env_default_FALSE": False,
174+
"e_env_default_str": "1234",
175+
"e_env_default_str_empty": "",
164176
"e_env_force": True,
165177
"e_optional": True,
166178
},
@@ -180,6 +192,8 @@ def test_codegen_config(self):
180192
"""torch.testing._internal.fake_config_module.e_bool = False
181193
torch.testing._internal.fake_config_module.e_env_default = True
182194
torch.testing._internal.fake_config_module.e_env_default_FALSE = False
195+
torch.testing._internal.fake_config_module.e_env_default_str = '1234'
196+
torch.testing._internal.fake_config_module.e_env_default_str_empty = ''
183197
torch.testing._internal.fake_config_module.e_env_force = True
184198
torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""",
185199
)
@@ -202,7 +216,7 @@ def test_codegen_config_function(self):
202216
)
203217

204218
def test_get_hash(self):
205-
hash_value = b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ"
219+
hash_value = b"\x87\xf7\xc6\x1di\x7f\x96-\x85\xdc\x04\xd5\xd0\xf6\x1c\x87"
206220
self.assertEqual(
207221
config.get_hash(),
208222
hash_value,
@@ -259,6 +273,8 @@ def test_dict_copy_semantics(self):
259273
"e_jk_false": False,
260274
"e_env_default": True,
261275
"e_env_default_FALSE": False,
276+
"e_env_default_str": "1234",
277+
"e_env_default_str_empty": "",
262278
"e_env_force": True,
263279
"e_optional": True,
264280
},
@@ -288,6 +304,8 @@ def test_dict_copy_semantics(self):
288304
"e_jk_false": False,
289305
"e_env_default": True,
290306
"e_env_default_FALSE": False,
307+
"e_env_default_str": "1234",
308+
"e_env_default_str_empty": "",
291309
"e_env_force": True,
292310
"e_optional": True,
293311
},
@@ -317,6 +335,8 @@ def test_dict_copy_semantics(self):
317335
"e_jk_false": False,
318336
"e_env_default": True,
319337
"e_env_default_FALSE": False,
338+
"e_env_default_str": "1234",
339+
"e_env_default_str_empty": "",
320340
"e_env_force": True,
321341
"e_optional": True,
322342
},

torch/testing/_internal/fake_config_module.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
e_jk_false: bool = Config(justknob="does_not_exist", default=False)
2525
e_env_default: bool = Config(env_name_default="ENV_TRUE", default=False)
2626
e_env_default_FALSE: bool = Config(env_name_default="ENV_FALSE", default=True)
27+
e_env_default_str: bool = Config(env_name_default="ENV_STR", default="default")
28+
e_env_default_str_empty: bool = Config(
29+
env_name_default="ENV_STR_EMPTY", default="default"
30+
)
2731
e_env_force: bool = Config(env_name_force="ENV_TRUE", default=False)
2832
e_aliased_bool: bool = Config(
2933
alias="torch.testing._internal.fake_config_module2.e_aliasing_bool"

torch/utils/_config_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,13 @@ def Config(
159159
)
160160

161161

162-
def _read_env_variable(name: str) -> Optional[bool]:
162+
def _read_env_variable(name: str) -> Optional[Union[bool, str]]:
163163
value = os.environ.get(name)
164164
if value == "1":
165165
return True
166166
if value == "0":
167167
return False
168-
return None
168+
return value
169169

170170

171171
def install_config_module(module: ModuleType) -> None:

0 commit comments

Comments
 (0)