Skip to content

Commit 1d47d66

Browse files
authored
Merge pull request #131 from mgxd/fix/zip-extract-race-condition
FIX: Avoid directory clobber during zip extraction
2 parents cb9566f + a9f7f5b commit 1d47d66

File tree

5 files changed

+76
-32
lines changed

5 files changed

+76
-32
lines changed

templateflow/__init__.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,8 @@
3737
del version
3838
del PackageNotFoundError
3939

40-
import os
41-
42-
from . import api
43-
from .conf import TF_USE_DATALAD, update
44-
45-
if not TF_USE_DATALAD and os.getenv('TEMPLATEFLOW_AUTOUPDATE', '1') not in (
46-
'false',
47-
'off',
48-
'0',
49-
'no',
50-
'n',
51-
):
52-
# trigger skeleton autoupdate
53-
update(local=True, overwrite=False, silent=True)
40+
from templateflow import api
41+
from templateflow.conf import update
5442

5543
__all__ = [
5644
'__copyright__',

templateflow/cli.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from templateflow import __package__, api
3333
from templateflow._loader import Loader as _Loader
34-
from templateflow.conf import TF_HOME, TF_USE_DATALAD
34+
from templateflow.conf import TF_HOME, TF_USE_DATALAD, TF_AUTOUPDATE
3535

3636
load_data = _Loader(__package__)
3737

@@ -91,6 +91,7 @@ def config():
9191
9292
TEMPLATEFLOW_HOME={TF_HOME}
9393
TEMPLATEFLOW_USE_DATALAD={'on' if TF_USE_DATALAD else 'off'}
94+
TEMPLATEFLOW_AUTOUPDATE={'on' if TF_AUTOUPDATE else 'off'}
9495
""")
9596

9697

templateflow/conf/__init__.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,36 @@
1010

1111
load_data = Loader(__package__)
1212

13+
14+
def _env_to_bool(envvar: str, default: bool) -> bool:
15+
"""Check for environment variable switches and convert to booleans."""
16+
switches = {
17+
'on': {'true', 'on', '1', 'yes', 'y'},
18+
'off': {'false', 'off', '0', 'no', 'n'},
19+
}
20+
21+
val = getenv(envvar, default)
22+
if isinstance(val, str):
23+
if val.lower() in switches['on']:
24+
return True
25+
elif val.lower() in switches['off']:
26+
return False
27+
else:
28+
# TODO: Create templateflow logger
29+
print(
30+
f'{envvar} is set to unknown value <{val}>. '
31+
f'Falling back to default value <{default}>'
32+
)
33+
return default
34+
return bool(val)
35+
36+
1337
TF_DEFAULT_HOME = Path.home() / '.cache' / 'templateflow'
1438
TF_HOME = Path(getenv('TEMPLATEFLOW_HOME', str(TF_DEFAULT_HOME)))
1539
TF_GITHUB_SOURCE = 'https://github.com/templateflow/templateflow.git'
1640
TF_S3_ROOT = 'https://templateflow.s3.amazonaws.com'
17-
TF_USE_DATALAD = getenv('TEMPLATEFLOW_USE_DATALAD', 'false').lower() in (
18-
'true',
19-
'on',
20-
'1',
21-
'yes',
22-
'y',
23-
)
41+
TF_USE_DATALAD = _env_to_bool('TEMPLATEFLOW_USE_DATALAD', False)
42+
TF_AUTOUPDATE = _env_to_bool('TEMPLATEFLOW_AUTOUPDATE', True)
2443
TF_CACHED = True
2544
TF_GET_TIMEOUT = 10
2645

@@ -50,7 +69,7 @@ def _init_cache():
5069
if not TF_USE_DATALAD:
5170
from ._s3 import update as _update_s3
5271

53-
_update_s3(TF_HOME, local=True, overwrite=True)
72+
_update_s3(TF_HOME, local=True, overwrite=TF_AUTOUPDATE, silent=True)
5473

5574

5675
_init_cache()

templateflow/conf/_s3.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,23 +80,32 @@ def _update_skeleton(skel_file, dest, overwrite=True, silent=False):
8080
dest = Path(dest)
8181
dest.mkdir(exist_ok=True, parents=True)
8282
with ZipFile(skel_file, 'r') as zipref:
83+
allfiles = sorted(zipref.namelist())
84+
8385
if overwrite:
84-
zipref.extractall(str(dest))
85-
return True
86+
newfiles = allfiles
87+
else:
88+
current_files = [s.relative_to(dest) for s in dest.glob('**/*')]
89+
existing = sorted({'%s/' % s.parent for s in current_files}) + [
90+
str(s) for s in current_files
91+
]
92+
newfiles = sorted(set(allfiles) - set(existing))
8693

87-
allfiles = zipref.namelist()
88-
current_files = [s.relative_to(dest) for s in dest.glob('**/*')]
89-
existing = sorted({'%s/' % s.parent for s in current_files}) + [
90-
str(s) for s in current_files
91-
]
92-
newfiles = sorted(set(allfiles) - set(existing))
9394
if newfiles:
9495
if not silent:
9596
print(
9697
'Updating TEMPLATEFLOW_HOME using S3. Adding:\n%s'
9798
% '\n'.join(newfiles)
9899
)
99-
zipref.extractall(str(dest), members=newfiles)
100+
for fl in newfiles:
101+
localpath = dest / fl
102+
if localpath.exists():
103+
continue
104+
try:
105+
zipref.extract(fl, path=dest)
106+
except FileExistsError:
107+
# If there is a conflict, do not clobber
108+
pass
100109
return True
101110
if not silent:
102111
print('TEMPLATEFLOW_HOME directory (S3 type) was up-to-date.')
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import os
2+
from concurrent.futures import ProcessPoolExecutor
3+
4+
import pytest
5+
6+
CPUs = os.cpu_count() or 1
7+
8+
9+
def _update():
    """Worker entry point: run one forced, silent templateflow update.

    Returns ``True`` once ``update`` has returned, so the caller can assert
    on the future's result.
    """
    # Imported inside the function, so the import statement executes when
    # the worker is called rather than when this test module is loaded.
    from templateflow.conf import update

    update(local=False, overwrite=True, silent=True)
    return True
14+
15+
16+
@pytest.mark.skipif(CPUs < 2, reason='At least 2 CPUs are required')
def test_multi_proc_update(tmp_path, monkeypatch):
    """Submit two ``_update`` jobs to a two-worker process pool.

    ``TEMPLATEFLOW_HOME`` is pointed at a directory under ``tmp_path`` before
    the pool is created, and each job's result must be truthy.
    """
    # NOTE(review): whether the workers actually resolve TF_HOME from this
    # patched variable presumably depends on the process start method and on
    # whether ``templateflow.conf`` was already imported - confirm.
    monkeypatch.setenv('TEMPLATEFLOW_HOME', str(tmp_path / 'tf_home'))

    with ProcessPoolExecutor(max_workers=2) as pool:
        pending = [pool.submit(_update) for _ in range(2)]

    # ``result()`` re-raises any exception raised inside a worker.
    for job in pending:
        assert job.result()

0 commit comments

Comments
 (0)