Skip to content

Commit e9cf74e

Browse files
d4l3kfacebook-github-bot
authored andcommitted
torchx: use importlib_metadata for Python 3.10+ syntax (#623)
Summary: importlib.metadata in 3.12 breaks compatibility with the dict style interface. This switches TorchX to use importlib_metadata for all versions and switches the code to use the 3.10+ select style interface instead of dict. This avoids having to pin importlib_metadata<5 such as in pytorch/tutorials#2091 Pull Request resolved: #623 Test Plan: Unit tests on both importlib_metadata 5.0 and 4.1.3 Reviewed By: priyaramani Differential Revision: D40561638 Pulled By: d4l3k fbshipit-source-id: 95144406c0e3dcbe203ada3ff3236f7384ab2a5c
1 parent f9fa2fe commit e9cf74e

File tree

5 files changed

+24
-35
lines changed

5 files changed

+24
-35
lines changed

dev-requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ classy-vision>=0.6.0
77
flake8==3.9.0
88
fsspec[s3]==2022.1.0
99
hydra-core
10-
importlib-metadata<5.0
1110
ipython
1211
kfp==1.8.9
1312
moto==3.0.2

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
pyre-extensions
22
docstring-parser==0.8.1
3+
importlib-metadata
34
pyyaml
45
docker
56
filelock

setup.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,6 @@ def get_nightly_version():
8282
"kubernetes": ["kubernetes>=11"],
8383
"ray": ["ray>=1.12.1"],
8484
"dev": dev_reqs,
85-
':python_version < "3.8"': [
86-
"importlib-metadata",
87-
],
8885
},
8986
# PyPI package information.
9087
classifiers=[

torchx/util/entrypoints.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,11 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
try:
8-
from importlib import metadata
9-
from importlib.metadata import EntryPoint
10-
except ImportError:
11-
import importlib_metadata as metadata
12-
from importlib_metadata import EntryPoint
137
from typing import Any, Dict, Optional
148

9+
import importlib_metadata as metadata
10+
from importlib_metadata import EntryPoint
11+
1512

1613
# pyre-ignore-all-errors[3, 2]
1714
def load(group: str, name: str, default=None):
@@ -31,18 +28,13 @@ def load(group: str, name: str, default=None):
3128
raises an error.
3229
"""
3330

34-
entrypoints = metadata.entry_points()
31+
entrypoints = metadata.entry_points().select(group=group)
3532

36-
if group not in entrypoints and default:
33+
if name not in entrypoints.names and default is not None:
3734
return default
3835

39-
eps: Dict[str, EntryPoint] = {ep.name: ep for ep in entrypoints[group]}
40-
41-
if name not in eps and default:
42-
return default
43-
else:
44-
ep = eps[name]
45-
return ep.load()
36+
ep = entrypoints[name]
37+
return ep.load()
4638

4739

4840
def _defer_load_ep(ep: EntryPoint) -> object:
@@ -75,12 +67,12 @@ def load_group(
7567
7668
"""
7769

78-
entrypoints = metadata.entry_points()
70+
entrypoints = metadata.entry_points().select(group=group)
7971

80-
if group not in entrypoints:
72+
if len(entrypoints) == 0:
8173
return default
8274

8375
eps = {}
84-
for ep in entrypoints[group]:
76+
for ep in entrypoints:
8577
eps[ep.name] = _defer_load_ep(ep)
8678
return eps

torchx/util/test/entrypoints_test.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import unittest
8-
9-
try:
10-
from importlib.metadata import EntryPoint
11-
except ImportError:
12-
from importlib_metadata import EntryPoint
138
from configparser import ConfigParser
14-
from typing import Dict, List
9+
from typing import List
1510
from unittest.mock import MagicMock, patch
1611

12+
from importlib_metadata import EntryPoint, EntryPoints
13+
1714
from torchx.util.entrypoints import load, load_group
1815

1916

@@ -61,12 +58,12 @@ def barbaz() -> str:
6158
[ep.grp.missing.mod.test]
6259
baz = torchx.util.test.entrypoints_test.missing_module
6360
"""
64-
_ENTRY_POINTS: Dict[str, List[EntryPoint]] = {
65-
"entrypoints.test": EntryPoint_from_text(_EP_TXT),
66-
"ep.grp.test": EntryPoint_from_text(_EP_GRP_TXT),
67-
"ep.grp.missing.attr.test": EntryPoint_from_text(_EP_GRP_IGN_ATTR_TXT),
68-
"ep.grp.missing.mod.test": EntryPoint_from_text(_EP_GRP_IGN_MOD_TXT),
69-
}
61+
_ENTRY_POINTS: EntryPoints = EntryPoints(
62+
EntryPoint_from_text(_EP_TXT)
63+
+ EntryPoint_from_text(_EP_GRP_TXT)
64+
+ EntryPoint_from_text(_EP_GRP_IGN_ATTR_TXT)
65+
+ EntryPoint_from_text(_EP_GRP_IGN_MOD_TXT)
66+
)
7067

7168
_METADATA_EPS: str = "torchx.util.entrypoints.metadata.entry_points"
7269

@@ -77,6 +74,9 @@ def test_load(self, mock_md_eps: MagicMock) -> None:
7774
print(type(load("entrypoints.test", "foo")))
7875
self.assertEqual("foobar", load("entrypoints.test", "foo")())
7976

77+
with self.assertRaisesRegex(KeyError, "baz"):
78+
load("entrypoints.test", "baz")()
79+
8080
@patch(_METADATA_EPS, return_value=_ENTRY_POINTS)
8181
def test_load_with_default(self, mock_md_eps: MagicMock) -> None:
8282
self.assertEqual("barbaz", load("entrypoints.test", "missing", barbaz)())
@@ -86,7 +86,7 @@ def test_load_with_default(self, mock_md_eps: MagicMock) -> None:
8686
@patch(_METADATA_EPS, return_value=_ENTRY_POINTS)
8787
def test_load_group(self, mock_md_eps: MagicMock) -> None:
8888
eps = load_group("ep.grp.test")
89-
self.assertEqual(2, len(eps))
89+
self.assertEqual(2, len(eps), eps)
9090
self.assertEqual("foobar", eps["foo"]())
9191
self.assertEqual("barbaz", eps["bar"]())
9292

0 commit comments

Comments
 (0)