diff --git a/.circleci/config.yml b/.circleci/config.yml index cd1217da..25c6840c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -109,7 +109,7 @@ jobs: export TEMPLATEFLOW_USE_DATALAD=on python -m pytest \ --junit-xml=~/tests/datalad.xml --cov templateflow --doctest-modules \ - templateflow/api.py + templateflow/client.py coverage run --append -m templateflow.cli config coverage run --append -m templateflow.cli ls MNI152NLin2009cAsym --suffix T1w diff --git a/docs/api.rst b/docs/api.rst index 4e1ca2d8..90f17651 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -5,5 +5,6 @@ Information on specific functions, classes, and methods. .. toctree:: api/templateflow.cli + api/templateflow.client api/templateflow.api api/templateflow.conf diff --git a/docs/environment.yml b/docs/environment.yml index 5d0f69ff..89a46f85 100644 --- a/docs/environment.yml +++ b/docs/environment.yml @@ -202,6 +202,7 @@ dependencies: - nibabel==3.2.2 - nipreps-versions==1.0.3 - pandas==1.4.2 + - platformdirs - pybids==0.15.2 - sqlalchemy==1.3.24 - hatchling diff --git a/pyproject.toml b/pyproject.toml index 6d06522f..1f749bb3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ license = {file = "LICENSE"} requires-python = ">=3.9" dependencies = [ "acres >= 0.5.0", + "platformdirs >= 4.0", "pybids >= 0.15.2", "requests", "tqdm", diff --git a/templateflow/__init__.py b/templateflow/__init__.py index ee150506..f407eaf8 100644 --- a/templateflow/__init__.py +++ b/templateflow/__init__.py @@ -40,12 +40,14 @@ del PackageNotFoundError from templateflow import api +from templateflow.client import TemplateFlowClient from templateflow.conf import update __all__ = [ '__copyright__', '__packagename__', '__version__', + 'TemplateFlowClient', 'api', 'update', ] diff --git a/templateflow/api.py b/templateflow/api.py index a193e2e6..3b0be03a 100644 --- a/templateflow/api.py +++ b/templateflow/api.py @@ -20,398 +20,41 @@ # # https://www.nipreps.org/community/licensing/ # 
"""TemplateFlow's Python Client.

``templateflow.api`` provides a global, high-level interface to query the TemplateFlow archive.

There are two methods to initialize a client:

    >>> from templateflow import api as client

    >>> from templateflow import TemplateFlowClient
    >>> client = TemplateFlowClient()

The latter method allows additional configuration for the client,
while ``templateflow.api`` is only configurable through environment variables.

.. autofunction:: get

.. autofunction:: ls

.. autofunction:: templates

.. autofunction:: get_metadata

.. autofunction:: get_citations
"""

from .client import TemplateFlowClient
from .conf import _cache

# Module-level singleton backing the ``templateflow.api`` facade; it shares
# the cache instance owned by ``templateflow.conf``.
_client = TemplateFlowClient(cache=_cache)

# Sentinel distinguishing "attribute missing" from any real attribute value.
_MISSING = object()


def __getattr__(name: str):
    # PEP 562 module-level hook: keep exposing the legacy ``TF_LAYOUT`` name,
    # and forward every other unknown attribute (``get``, ``ls``,
    # ``templates``, ...) to the shared client instance.
    if name == 'TF_LAYOUT':
        return _cache.layout
    attr = getattr(_client, name, _MISSING)
    if attr is not _MISSING:
        return attr
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2025 The NiPreps Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
#     https://www.nipreps.org/community/licensing/
#
"""TemplateFlow's Python Client."""

from __future__ import annotations

import os
import sys
from json import loads
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported lazily at runtime (see ``__init__``) so that this module can be
    # imported without pulling in the cache machinery.
    from templateflow.conf.cache import CacheConfig, TemplateFlowCache


class TemplateFlowClient:
    """TemplateFlow client for querying and retrieving template files.

    If instantiated without arguments, uses the default cache, which is
    located at a platform-dependent location (e.g., ``$HOME/.cache/templateflow`` on
    most Unix-like systems), or at the location specified by the ``TEMPLATEFLOW_HOME``
    environment variable:

    >>> client = TemplateFlowClient()
    >>> client  # doctest: +ELLIPSIS
    <TemplateFlowClient[...] cache="...">

    To select a custom cache location, provide the ``root`` argument:

    >>> client = TemplateFlowClient(root='/path/to/templateflow_cache')

    Additional configuration options can be provided as keyword arguments.

    Parameters
    ----------
    root: :class:`os.PathLike` or :class:`str`, optional
        Path to the root of the TemplateFlow cache (will be created if it does not exist).

    Keyword Arguments
    -----------------
    use_datalad: :class:`bool`, optional
        Whether to use DataLad for managing the cache. Defaults to ``False`` or
        the value of the ``TEMPLATEFLOW_USE_DATALAD`` environment variable
        (1/True/on/yes to enable, 0/False/off/no to disable).
    autoupdate: :class:`bool`, optional
        Whether to automatically update the cache on first load.
        Defaults to ``True`` or the value of the ``TEMPLATEFLOW_AUTOUPDATE``
        environment variable (1/True/on/yes to enable, 0/False/off/no to disable).
    timeout: :class:`float`, optional
        Timeout in seconds for network operations. Default is ``10.0`` seconds.
    origin: :class:`str`, optional
        Git repository URL for DataLad installations. Default is
        <https://github.com/templateflow/templateflow.git>.
    s3_root: :class:`str`, optional
        Base URL for S3 downloads. Default is <https://templateflow.s3.amazonaws.com>.
    cache: :class:`TemplateFlowCache`, optional
        A pre-configured TemplateFlowCache instance. If provided, `root` and other
        configuration keyword arguments cannot be used.
    """

    def __init__(
        self,
        root: os.PathLike[str] | str | None = None,
        *,
        cache: TemplateFlowCache | None = None,
        **config_kwargs,
    ):
        if cache is None:
            # Deferred import: keeps ``templateflow.client`` importable (and
            # unit-testable) without triggering the cache layer at import time.
            from templateflow.conf.cache import CacheConfig, TemplateFlowCache

            if root:
                config_kwargs['root'] = Path(root)
            cache = TemplateFlowCache(CacheConfig(**config_kwargs))
        elif root or config_kwargs:
            raise ValueError(
                'If `cache` is provided, `root` and other config kwargs cannot be used.'
            )
        self.cache = cache

    def __repr__(self) -> str:
        cache_type = 'DataLad' if self.cache.config.use_datalad else 'S3'
        return f'<{self.__class__.__name__}[{cache_type}] cache="{self.cache.config.root}">'

    def __getattr__(self, name: str):
        # Mirror the historical ``templateflow.api`` behavior: unknown
        # ``ls_*``/``get_*`` attributes are forwarded to the PyBIDS layout,
        # which synthesizes ``get_<entity>s`` accessors dynamically (those do
        # not appear in ``dir(layout)``). Names that *are* statically defined
        # on the layout are deliberately not forwarded.
        name = name.replace('ls_', 'get_')
        try:
            if name.startswith('get_') and name not in dir(self.cache.layout):
                return getattr(self.cache.layout, name)
        except AttributeError:
            pass
        msg = f"'{self.__class__.__name__}' object has no attribute '{name}'"
        raise AttributeError(msg) from None

    def ls(self, template, **kwargs) -> list[Path]:
        """
        List files pertaining to one or more templates.

        Parameters
        ----------
        template : str
            A template identifier (e.g., ``MNI152NLin2009cAsym``).

        Keyword Arguments
        -----------------
        resolution: int or None
            Index to an specific spatial resolution of the template.
        suffix : str or None
            BIDS suffix
        atlas : str or None
            Name of a particular atlas
        hemi : str or None
            Hemisphere
        space : str or None
            Space template is mapped to
        density : str or None
            Surface density
        desc : str or None
            Description field

        Examples
        --------

        .. testsetup::

            >>> client = TemplateFlowClient()

        >>> client.ls('MNI152Lin', resolution=1, suffix='T1w', desc=None)
        [PosixPath('.../tpl-MNI152Lin/tpl-MNI152Lin_res-01_T1w.nii.gz')]

        >>> client.ls('MNI152Lin', resolution=2, suffix='T1w', desc=None)
        [PosixPath('.../tpl-MNI152Lin/tpl-MNI152Lin_res-02_T1w.nii.gz')]

        >>> client.ls('MNI152Lin', suffix='T1w', desc=None)
        [PosixPath('.../tpl-MNI152Lin/tpl-MNI152Lin_res-01_T1w.nii.gz'),
        PosixPath('.../tpl-MNI152Lin/tpl-MNI152Lin_res-02_T1w.nii.gz')]

        >>> client.ls('fsLR', space=None, hemi='L', density='32k', suffix='sphere')
        [PosixPath('.../tpl-fsLR_hemi-L_den-32k_sphere.surf.gii')]

        >>> client.ls('fsLR', space='madeup')
        []

        """
        from bids.layout import Query

        # Normalize extensions to always have leading dot
        if 'extension' in kwargs:
            kwargs['extension'] = _normalize_ext(kwargs['extension'])

        return [
            Path(p)
            for p in self.cache.layout.get(
                template=Query.ANY if template is None else template, return_type='file', **kwargs
            )
        ]

    def get(self, template, raise_empty=False, **kwargs) -> Path | list[Path]:
        """
        Pull files pertaining to one or more templates down.

        Returns a single :class:`~pathlib.Path` when exactly one file matches,
        and a (possibly empty) list of paths otherwise.

        Parameters
        ----------
        template : str
            A template identifier (e.g., ``MNI152NLin2009cAsym``).
        raise_empty : bool, optional
            Raise exception if no files were matched

        Keyword Arguments
        -----------------
        resolution: int or None
            Index to an specific spatial resolution of the template.
        suffix : str or None
            BIDS suffix
        atlas : str or None
            Name of a particular atlas
        hemi : str or None
            Hemisphere
        space : str or None
            Space template is mapped to
        density : str or None
            Surface density
        desc : str or None
            Description field

        Examples
        --------

        .. testsetup::

            >>> client = TemplateFlowClient()

        >>> str(client.get('MNI152Lin', resolution=1, suffix='T1w', desc=None))
        '.../tpl-MNI152Lin/tpl-MNI152Lin_res-01_T1w.nii.gz'

        >>> str(client.get('MNI152Lin', resolution=2, suffix='T1w', desc=None))
        '.../tpl-MNI152Lin/tpl-MNI152Lin_res-02_T1w.nii.gz'

        >>> [str(p) for p in client.get('MNI152Lin', suffix='T1w', desc=None)]
        ['.../tpl-MNI152Lin/tpl-MNI152Lin_res-01_T1w.nii.gz',
        '.../tpl-MNI152Lin/tpl-MNI152Lin_res-02_T1w.nii.gz']

        >>> str(client.get('fsLR', space=None, hemi='L', density='32k', suffix='sphere'))
        '.../tpl-fsLR_hemi-L_den-32k_sphere.surf.gii'

        >>> client.get('fsLR', space='madeup')
        []

        >>> client.get('fsLR', raise_empty=True, space='madeup') # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        Exception:
        ...

        """
        # List files available
        out_file = self.ls(template, **kwargs)

        if raise_empty and not out_file:
            raise Exception('No results found')

        # Truncate possible S3 error files from previous attempts
        _truncate_s3_errors(out_file)

        # Try DataLad first
        dl_missing = [p for p in out_file if not p.is_file()]
        if self.cache.config.use_datalad and dl_missing:
            for filepath in dl_missing:
                _datalad_get(self.cache.config, filepath)
                dl_missing.remove(filepath)

        # Fall-back to S3 if some files are still missing
        s3_missing = [p for p in out_file if p.is_file() and p.stat().st_size == 0]
        for filepath in s3_missing + dl_missing:
            _s3_get(self.cache.config, filepath)

        not_fetched = [str(p) for p in out_file if not p.is_file() or p.stat().st_size == 0]

        if not_fetched:
            msg = 'Could not fetch template files: {}.'.format(', '.join(not_fetched))
            # The continuation lines below deliberately start at column 0:
            # they are part of the user-facing message, not code indentation.
            if dl_missing and not self.cache.config.use_datalad:
                msg += f"""\
The $TEMPLATEFLOW_HOME folder {self.cache.config.root} seems to contain an initiated DataLad \
dataset, but the environment variable $TEMPLATEFLOW_USE_DATALAD is not \
set or set to one of (false, off, 0). Please set $TEMPLATEFLOW_USE_DATALAD \
on (possible values: true, on, 1)."""

            if s3_missing and self.cache.config.use_datalad:
                msg += f"""\
The $TEMPLATEFLOW_HOME folder {self.cache.config.root} seems to contain an plain \
dataset, but the environment variable $TEMPLATEFLOW_USE_DATALAD is \
set to one of (true, on, 1). Please set $TEMPLATEFLOW_USE_DATALAD \
off (possible values: false, off, 0)."""

            raise RuntimeError(msg)

        if len(out_file) == 1:
            return out_file[0]
        return out_file

    def templates(self, **kwargs) -> list[str]:
        """
        Return a list of available templates.

        Keyword Arguments
        -----------------
        resolution: int or None
            Index to an specific spatial resolution of the template.
        suffix : str or None
            BIDS suffix
        atlas : str
            Name of a particular atlas
        desc : str
            Description field

        Examples
        --------

        .. testsetup::

            >>> client = TemplateFlowClient()

        >>> base = ['MNI152Lin', 'MNI152NLin2009cAsym', 'NKI', 'OASIS30ANTs']
        >>> tpls = client.templates()
        >>> all([t in tpls for t in base])
        True

        >>> sorted(set(base).intersection(client.templates(suffix='PD')))
        ['MNI152Lin', 'MNI152NLin2009cAsym']
        """
        # ``get_templates`` resolves through ``__getattr__`` to the layout's
        # dynamically generated entity accessor.
        return sorted(self.get_templates(**kwargs))

    def get_metadata(self, template) -> dict:
        """
        Fetch the metadata JSON of one template.

        Parameters
        ----------
        template : str
            A template identifier (e.g., ``MNI152NLin2009cAsym``).

        Examples
        --------

        .. testsetup::

            >>> client = TemplateFlowClient()

        >>> client.get_metadata('MNI152Lin')['Name']
        'Linear ICBM Average Brain (ICBM152) Stereotaxic Registration Model'

        """
        tf_home = Path(self.cache.layout.root)
        filepath = tf_home / (f'tpl-{template}') / 'template_description.json'

        # Ensure that template is installed and file is available.
        # BUG FIX: ``_datalad_get`` takes the cache configuration as its
        # first argument; the previous call omitted it.
        if not filepath.is_file():
            _datalad_get(self.cache.config, filepath)
        return loads(filepath.read_text())

    def get_citations(self, template, bibtex=False) -> list[str]:
        """
        Fetch template citations

        Parameters
        ----------
        template : :obj:`str`
            A template identifier (e.g., ``MNI152NLin2009cAsym``).
        bibtex : :obj:`bool`, optional
            Generate citations in BibTeX format.

        """
        data = self.get_metadata(template)
        refs = data.get('ReferencesAndLinks', [])
        if isinstance(refs, dict):
            refs = list(refs.values())

        if not bibtex:
            return refs

        return [_to_bibtex(ref, template, self.cache.config.timeout).rstrip() for ref in refs]


def _datalad_get(config: CacheConfig, filepath: Path) -> None:
    """Fetch *filepath* through DataLad, installing the dataset if needed."""
    if not filepath:
        return

    from datalad import api
    from datalad.support.exceptions import IncompleteResultsError

    try:
        api.get(filepath, dataset=config.root)
    except IncompleteResultsError as exc:
        if exc.failed[0]['message'] == 'path not associated with any dataset':
            # The cache folder exists but is not a DataLad dataset yet:
            # install it from the configured origin and retry once.
            api.install(path=config.root, source=config.origin, recursive=True)
            api.get(filepath, dataset=config.root)
        else:
            raise


def _s3_get(config: CacheConfig, filepath: Path) -> None:
    """Stream one file from the TemplateFlow S3 bucket into *filepath*."""
    from sys import stderr
    from urllib.parse import quote

    import requests
    from tqdm import tqdm

    path = quote(filepath.relative_to(config.root).as_posix())
    url = f'{config.s3_root}/{path}'

    print(f'Downloading {url}', file=stderr)
    # Streaming, so we can iterate over the response.
    r = requests.get(url, stream=True, timeout=config.timeout)
    if r.status_code != 200:
        raise RuntimeError(f'Failed to download {url} with status code {r.status_code}')

    # Total size in bytes.
    total_size = int(r.headers.get('content-length', 0))
    block_size = 1024
    wrote = 0
    # NOTE(review): a path that exists but is not a regular file is presumably
    # a dangling DataLad/annex symlink — remove it so the download replaces
    # it; ``missing_ok`` guards the (unexpected) fully-absent case.
    if not filepath.is_file():
        filepath.unlink(missing_ok=True)

    with filepath.open('wb') as f:
        with tqdm(total=total_size, unit='B', unit_scale=True) as t:
            for data in r.iter_content(block_size):
                wrote = wrote + len(data)
                f.write(data)
                t.update(len(data))

    if total_size != 0 and wrote != total_size:
        raise RuntimeError('ERROR, something went wrong')


def _to_bibtex(doi: str, template: str, timeout: float) -> str:
    """Resolve a doi.org URL to a BibTeX entry; return *doi* unchanged otherwise.

    ``template`` is kept for signature parity with earlier versions; it is
    not used in the conversion.
    """
    if 'doi.org' not in doi:
        return doi

    # Is a DOI URL
    import requests

    response = requests.post(
        doi,
        headers={'Accept': 'application/x-bibtex; charset=utf-8'},
        timeout=timeout,
    )
    if not response.ok:
        print(
            f'Failed to convert DOI <{doi}> to bibtex, returning URL.',
            file=sys.stderr,
        )
        return doi

    # doi.org may not honor requested charset, to safeguard force a bytestream with
    # response.content, then decode into UTF-8.
    bibtex = response.content.decode()

    # doi.org / crossref may still point to the no longer preferred proxy service
    return bibtex.replace('http://dx.doi.org/', 'https://doi.org/')


def _normalize_ext(value):
    """
    Normalize extensions to have a leading dot.

    Examples
    --------
    >>> _normalize_ext(".nii.gz")
    '.nii.gz'
    >>> _normalize_ext("nii.gz")
    '.nii.gz'
    >>> _normalize_ext(("nii", ".nii.gz"))
    ['.nii', '.nii.gz']
    >>> _normalize_ext(("", ".nii.gz"))
    ['', '.nii.gz']
    >>> _normalize_ext((None, ".nii.gz"))
    [None, '.nii.gz']
    >>> _normalize_ext([])
    []

    """

    if not value:
        return value

    if isinstance(value, str):
        return f'{"" if value.startswith(".") else "."}{value}'
    return [_normalize_ext(v) for v in value]


def _truncate_s3_errors(filepaths):
    """
    Truncate XML error bodies saved by previous versions of TemplateFlow.

    Failed S3 downloads used to leave the S3 XML error document on disk in
    place of the requested file; such stubs are blanked so the download
    machinery treats them as missing (zero-size) files.

    Parameters
    ----------
    filepaths : list of Path
        List of file paths to check and truncate if necessary.
    """
    for filepath in filepaths:
        # Only plain files (not symlinks) smaller than 1 KiB can be stale
        # S3 error bodies; anything else is left untouched.
        if filepath.is_file(follow_symlinks=False) and 0 < filepath.stat().st_size < 1024:
            with open(filepath, 'rb') as f:
                content = f.read(100)
            # NOTE(review): the markup in this condition was lost in transit;
            # reconstructed assuming the S3 error-document shape (XML prolog
            # followed by an <Error> element) — confirm against upstream.
            if content.startswith(b'<?xml') and b'<Error>' in content:
                filepath.write_bytes(b'')  # Truncate file to zero bytes
f'{envvar} is set to unknown value <{val}>. ' - f'Falling back to default value <{default}>' - ) - return default - return bool(val) - - -TF_DEFAULT_HOME = Path.home() / '.cache' / 'templateflow' -TF_HOME = Path(getenv('TEMPLATEFLOW_HOME', str(TF_DEFAULT_HOME))).absolute() -TF_GITHUB_SOURCE = 'https://github.com/templateflow/templateflow.git' -TF_S3_ROOT = 'https://templateflow.s3.amazonaws.com' -TF_USE_DATALAD = _env_to_bool('TEMPLATEFLOW_USE_DATALAD', False) -TF_AUTOUPDATE = _env_to_bool('TEMPLATEFLOW_AUTOUPDATE', True) -TF_CACHED = True -TF_GET_TIMEOUT = 10 - -if TF_USE_DATALAD: - try: - from datalad.api import install - except ImportError: - warn('DataLad is not installed ➔ disabled.', stacklevel=2) - TF_USE_DATALAD = False - -if not TF_USE_DATALAD: - from templateflow.conf._s3 import update as _update_s3 - - -def _init_cache(): - global TF_CACHED - - if not TF_HOME.exists() or not list(TF_HOME.iterdir()): - TF_CACHED = False - warn( - f"""\ -TemplateFlow: repository not found at <{TF_HOME}>. Populating a new TemplateFlow stub. +_cache = TemplateFlowCache(config=CacheConfig()) + + +def __getattr__(name: str): + if name == 'TF_HOME': + return _cache.config.root + elif name == 'TF_GITHUB_SOURCE': + return _cache.config.origin + elif name == 'TF_S3_ROOT': + return _cache.config.s3_root + elif name == 'TF_USE_DATALAD': + return _cache.config.use_datalad + elif name == 'TF_AUTOUPDATE': + return _cache.config.autoupdate + elif name == 'TF_CACHED': + return _cache.precached + elif name == 'TF_GET_TIMEOUT': + return _cache.config.timeout + elif name == 'TF_LAYOUT': + return _cache.layout + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + + +if not _cache.precached: + warn( + f"""\ +TemplateFlow: repository not found at <{_cache.config.root}>. Populating a new TemplateFlow stub. 
If the path reported above is not the desired location for TemplateFlow, \ please set the TEMPLATEFLOW_HOME environment variable.""", - ResourceWarning, - stacklevel=2, - ) - if TF_USE_DATALAD: - TF_HOME.parent.mkdir(exist_ok=True, parents=True) - install(path=str(TF_HOME), source=TF_GITHUB_SOURCE, recursive=True) - else: - _update_s3(TF_HOME, local=True, overwrite=TF_AUTOUPDATE, silent=True) - - -_init_cache() + ResourceWarning, + stacklevel=2, + ) + _cache.ensure() def requires_layout(func): @@ -83,9 +49,7 @@ def requires_layout(func): @wraps(func) def wrapper(*args, **kwargs): - from templateflow.conf import TF_LAYOUT - - if TF_LAYOUT is None: + if _cache.layout is None: from bids import __version__ raise RuntimeError(f'A layout with PyBIDS <{__version__}> could not be initiated') @@ -94,101 +58,17 @@ def wrapper(*args, **kwargs): return wrapper -def update(local=False, overwrite=True, silent=False): - """Update an existing DataLad or S3 home.""" - if TF_USE_DATALAD: - success = _update_datalad() - else: - from ._s3 import update as _update_s3 - - success = _update_s3(TF_HOME, local=local, overwrite=overwrite, silent=silent) - - # update Layout only if necessary - if success and TF_LAYOUT is not None: - init_layout() - # ensure the api uses the updated layout - import importlib - - from .. 
import api - - importlib.reload(api) - return success - - -def wipe(): - """Clear the cache if functioning in S3 mode.""" - - if TF_USE_DATALAD: - print('TemplateFlow is configured in DataLad mode, wipe() has no effect') - return - - import importlib - from shutil import rmtree - - from templateflow import api - - def _onerror(func, path, excinfo): - from pathlib import Path - - if Path(path).exists(): - print(f'Warning: could not delete <{path}>, please clear the cache manually.') - - rmtree(TF_HOME, onerror=_onerror) - _init_cache() - - importlib.reload(api) +update = _cache.update +wipe = _cache.wipe def setup_home(force=False): """Initialize/update TF's home if necessary.""" - if not force and not TF_CACHED: + if not force and not _cache.precached: print( f"""\ -TemplateFlow was not cached (TEMPLATEFLOW_HOME={TF_HOME}), \ +TemplateFlow was not cached (TEMPLATEFLOW_HOME={_cache.config.root}), \ a fresh initialization was done.""" ) return False - return update(local=True, overwrite=False) - - -def _update_datalad(): - from datalad.api import update - - print('Updating TEMPLATEFLOW_HOME using DataLad ...') - try: - update(dataset=str(TF_HOME), recursive=True, merge=True) - except Exception as e: # noqa: BLE001 - warn( - f"Error updating TemplateFlow's home directory (using DataLad): {e}", - stacklevel=2, - ) - return False - return True - - -TF_LAYOUT = None - - -def init_layout(): - from bids.layout.index import BIDSLayoutIndexer - - from templateflow.conf.bids import Layout - - global TF_LAYOUT - TF_LAYOUT = Layout( - TF_HOME, - validate=False, - config='templateflow', - indexer=BIDSLayoutIndexer( - validate=False, - ignore=( - re.compile(r'scripts/'), - re.compile(r'/\.'), - re.compile(r'^\.'), - ), - ), - ) - - -with suppress(ImportError): - init_layout() + return _cache.update(local=True, overwrite=False) diff --git a/templateflow/conf/_s3.py b/templateflow/conf/_s3.py index eda04909..9be194b6 100644 --- a/templateflow/conf/_s3.py +++ 
b/templateflow/conf/_s3.py @@ -25,34 +25,35 @@ from pathlib import Path from tempfile import mkstemp -from templateflow.conf import TF_GET_TIMEOUT, load_data +from acres import Loader + +load_data = Loader(__spec__.parent) TF_SKEL_URL = ( 'https://raw.githubusercontent.com/templateflow/python-client/' '{release}/templateflow/conf/templateflow-skel.{ext}' ).format -TF_SKEL_PATH = load_data('templateflow-skel.zip') -TF_SKEL_MD5 = load_data.readable('templateflow-skel.md5').read_text() -def update(dest, local=True, overwrite=True, silent=False): +def update(dest, local=True, overwrite=True, silent=False, *, timeout: int): """Update an S3-backed TEMPLATEFLOW_HOME repository.""" - skel_file = Path((_get_skeleton_file() if not local else None) or TF_SKEL_PATH) + skel_zip = load_data('templateflow-skel.zip') + skel_file = Path((_get_skeleton_file(timeout) if not local else None) or skel_zip) retval = _update_skeleton(skel_file, dest, overwrite=overwrite, silent=silent) - if skel_file != TF_SKEL_PATH: + if skel_file != skel_zip: skel_file.unlink() return retval -def _get_skeleton_file(): +def _get_skeleton_file(timeout: int): import requests try: r = requests.get( TF_SKEL_URL(release='master', ext='md5'), allow_redirects=True, - timeout=TF_GET_TIMEOUT, + timeout=timeout, ) except requests.exceptions.ConnectionError: return @@ -60,11 +61,12 @@ def _get_skeleton_file(): if not r.ok: return - if r.content.decode().split()[0] != TF_SKEL_MD5: + md5 = load_data.readable('templateflow-skel.md5').read_bytes() + if r.content != md5: r = requests.get( TF_SKEL_URL(release='master', ext='zip'), allow_redirects=True, - timeout=TF_GET_TIMEOUT, + timeout=timeout, ) if r.ok: from os import close diff --git a/templateflow/conf/bids.py b/templateflow/conf/bids.py index a143e4de..46108669 100644 --- a/templateflow/conf/bids.py +++ b/templateflow/conf/bids.py @@ -22,9 +22,10 @@ # """Extending pyBIDS for querying TemplateFlow.""" +from acres import Loader from bids.layout import BIDSLayout, 
add_config_paths -from templateflow.conf import load_data +load_data = Loader(__spec__.parent) add_config_paths(templateflow=load_data('config.json')) diff --git a/templateflow/conf/cache.py b/templateflow/conf/cache.py new file mode 100644 index 00000000..62e6ea1b --- /dev/null +++ b/templateflow/conf/cache.py @@ -0,0 +1,178 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright 2025 The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# +from __future__ import annotations + +from dataclasses import dataclass, field +from functools import cache, cached_property +from pathlib import Path +from warnings import warn + +from templateflow.conf.env import env_to_bool, get_templateflow_home + +TYPE_CHECKING = False +if TYPE_CHECKING: + from bids.layout import BIDSLayout + + +# The first CacheConfig is initialized during import, so we need a higher +# level of indirection for warnings to point to the user code. +# After that, we will set the stack level to point to the CacheConfig() caller. 
+STACKLEVEL = 6 + + +@cache +def _have_datalad() -> bool: + import importlib.util + + return importlib.util.find_spec('datalad') is not None + + +@dataclass +class CacheConfig: + root: Path = field(default_factory=get_templateflow_home) + origin: str = field(default='https://github.com/templateflow/templateflow.git') + s3_root: str = field(default='https://templateflow.s3.amazonaws.com') + use_datalad: bool = field(default_factory=env_to_bool('TEMPLATEFLOW_USE_DATALAD', False)) + autoupdate: bool = field(default_factory=env_to_bool('TEMPLATEFLOW_AUTOUPDATE', True)) + timeout: int = field(default=10) + + def __post_init__(self) -> None: + global STACKLEVEL + if self.use_datalad and not _have_datalad(): + self.use_datalad = False + warn('DataLad is not installed ➔ disabled.', stacklevel=STACKLEVEL) + STACKLEVEL = 3 + + +@dataclass +class S3Manager: + s3_root: str + + def install(self, path: Path, overwrite: bool, timeout: int) -> None: + from ._s3 import update + + update(path, local=True, overwrite=overwrite, silent=True, timeout=timeout) + + def update(self, path: Path, local: bool, overwrite: bool, silent: bool, timeout: int) -> bool: + from ._s3 import update as _update_s3 + + return _update_s3(path, local=local, overwrite=overwrite, silent=silent, timeout=timeout) + + def wipe(self, path: Path) -> None: + from shutil import rmtree + + def _onerror(func, path, excinfo): + from pathlib import Path + + if Path(path).exists(): + print(f'Warning: could not delete <{path}>, please clear the cache manually.') + + rmtree(path, onerror=_onerror) + + +@dataclass +class DataladManager: + source: str + + def install(self, path: Path, overwrite: bool, timeout: int) -> None: + from datalad.api import install + + install(path=path, source=self.source, recursive=True) + + def update(self, path: Path, local: bool, overwrite: bool, silent: bool, timeout: int) -> bool: + from datalad.api import update + + print('Updating TEMPLATEFLOW_HOME using DataLad ...') + try: + 
update(dataset=path, recursive=True, merge=True) + except Exception as e: # noqa: BLE001 + warn( + f"Error updating TemplateFlow's home directory (using DataLad): {e}", + stacklevel=2, + ) + return False + return True + + def wipe(self, path: Path) -> None: + print('TemplateFlow is configured in DataLad mode, wipe() has no effect') + + +@dataclass +class TemplateFlowCache: + config: CacheConfig + precached: bool = field(init=False) + manager: DataladManager | S3Manager = field(init=False) + + def __post_init__(self) -> None: + self.manager = ( + DataladManager(self.config.origin) + if self.config.use_datalad + else S3Manager(self.config.s3_root) + ) + # cache.cached checks live, precached stores state at init + self.precached = self.cached + + @property + def cached(self) -> bool: + return self.config.root.is_dir() and any(self.config.root.iterdir()) + + @cached_property + def layout(self) -> BIDSLayout: + import re + + from bids.layout.index import BIDSLayoutIndexer + + from .bids import Layout + + self.ensure() + return Layout( + self.config.root, + validate=False, + config='templateflow', + indexer=BIDSLayoutIndexer( + validate=False, + ignore=(re.compile(r'scripts/'), re.compile(r'/\.'), re.compile(r'^\.')), + ), + ) + + def ensure(self) -> None: + if not self.cached: + self.manager.install( + self.config.root, overwrite=self.config.autoupdate, timeout=self.config.timeout + ) + + def update(self, local: bool = False, overwrite: bool = True, silent: bool = False) -> bool: + if self.manager.update( + self.config.root, + local=local, + overwrite=overwrite, + silent=silent, + timeout=self.config.timeout, + ): + self.__dict__.pop('layout', None) # Uncache property + return True + return False + + def wipe(self) -> None: + self.__dict__.pop('layout', None) # Uncache property + self.manager.wipe(self.config.root) diff --git a/templateflow/conf/env.py b/templateflow/conf/env.py new file mode 100644 index 00000000..3623e56d --- /dev/null +++ b/templateflow/conf/env.py 
@@ -0,0 +1,59 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright 2025 The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# +import os +from functools import partial +from pathlib import Path +from typing import Callable + +from platformdirs import user_cache_dir + + +def _env_to_bool(envvar: str, default: bool) -> bool: + """Check for environment variable switches and convert to booleans.""" + switches = { + 'on': {'true', 'on', '1', 'yes', 'y'}, + 'off': {'false', 'off', '0', 'no', 'n'}, + } + + val = os.getenv(envvar, default) + if isinstance(val, str): + if val.lower() in switches['on']: + return True + elif val.lower() in switches['off']: + return False + else: + # TODO: Create templateflow logger + print( + f'{envvar} is set to unknown value <{val}>. 
' + f'Falling back to default value <{default}>' + ) + return default + return bool(val) + + +def get_templateflow_home() -> Path: + return Path(os.getenv('TEMPLATEFLOW_HOME', user_cache_dir('templateflow'))).absolute() + + +def env_to_bool(envvar: str, default: bool) -> Callable[[], bool]: + return partial(_env_to_bool, envvar, default) diff --git a/templateflow/tests/test_conf.py b/templateflow/tests/test_conf.py index c30ea877..59aeb86c 100644 --- a/templateflow/tests/test_conf.py +++ b/templateflow/tests/test_conf.py @@ -23,12 +23,15 @@ """Tests the config module.""" from importlib import reload +from importlib.util import find_spec from shutil import rmtree import pytest from templateflow import conf as tfc +have_datalad = find_spec('datalad') is not None + def _find_message(lines, msg, reverse=True): if isinstance(lines, str): @@ -61,22 +64,18 @@ def test_conf_init(monkeypatch, tmp_path, use_datalad): def test_setup_home(monkeypatch, tmp_path, capsys, use_datalad): """Check the correct functioning of the installation hook.""" - if use_datalad == 'on': - # ImportError if not installed - pass - home = (tmp_path / f'setup-home-{use_datalad}').absolute() monkeypatch.setenv('TEMPLATEFLOW_USE_DATALAD', use_datalad) monkeypatch.setenv('TEMPLATEFLOW_HOME', str(home)) - use_post = tfc._env_to_bool('TEMPLATEFLOW_USE_DATALAD', False) + use_post = tfc.env._env_to_bool('TEMPLATEFLOW_USE_DATALAD', False) assert use_post is (use_datalad == 'on') with capsys.disabled(): reload(tfc) # Ensure mocks are up-to-date - assert tfc.TF_USE_DATALAD is (use_datalad == 'on') + assert tfc.TF_USE_DATALAD is (use_datalad == 'on' and have_datalad) assert str(tfc.TF_HOME) == str(home) # First execution, the S3 stub is created (or datalad install) assert tfc.TF_CACHED is False @@ -92,11 +91,11 @@ def test_setup_home(monkeypatch, tmp_path, capsys, use_datalad): out = capsys.readouterr()[0] assert _find_message(out, 'TemplateFlow was not cached') is False - if use_datalad == 'on': + if 
use_datalad == 'on' and have_datalad: assert _find_message(out, 'Updating TEMPLATEFLOW_HOME using DataLad') assert updated is True - elif use_datalad == 'off': + else: # At this point, S3 should be up-to-date assert updated is False assert _find_message(out, 'TEMPLATEFLOW_HOME directory (S3 type) was up-to-date.') @@ -114,11 +113,11 @@ def test_setup_home(monkeypatch, tmp_path, capsys, use_datalad): out = capsys.readouterr()[0] assert not _find_message(out, 'TemplateFlow was not cached') - if use_datalad == 'on': + if use_datalad == 'on' and have_datalad: assert _find_message(out, 'Updating TEMPLATEFLOW_HOME using DataLad') assert updated is True - elif use_datalad == 'off': + else: # At this point, S3 should be up-to-date assert updated is False assert _find_message(out, 'TEMPLATEFLOW_HOME directory (S3 type) was up-to-date.') @@ -156,7 +155,7 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0): # noqa: return oldimport(name, globals=globals, locals=locals, fromlist=fromlist, level=level) with monkeypatch.context() as m: - m.setattr(tfc, 'TF_LAYOUT', None) + m.setattr(tfc._cache, 'layout', None) with pytest.raises(RuntimeError): myfunc() diff --git a/templateflow/tests/test_s3.py b/templateflow/tests/test_s3.py index 8c904e23..7b0288ed 100644 --- a/templateflow/tests/test_s3.py +++ b/templateflow/tests/test_s3.py @@ -28,6 +28,8 @@ import pytest import requests +import templateflow +import templateflow.conf._s3 from templateflow import api as tf from templateflow import conf as tfc @@ -37,60 +39,56 @@ def test_get_skel_file(tmp_path, monkeypatch): """Exercise the skeleton file generation.""" - home = (tmp_path / 's3-skel-file').resolve() - monkeypatch.setenv('TEMPLATEFLOW_USE_DATALAD', 'off') - monkeypatch.setenv('TEMPLATEFLOW_HOME', str(home)) + md5content = b'anything' - # First execution, the S3 stub is created (or datalad install) - reload(tfc) + def mock_get(*args, **kwargs): + class MockResponse: + status_code = 200 + ok = True + content 
= md5content + + return MockResponse() + + monkeypatch.setattr(requests, 'get', mock_get) - local_md5 = tfc._s3.TF_SKEL_MD5 - monkeypatch.setattr(tfc._s3, 'TF_SKEL_MD5', 'invent') - new_skel = tfc._s3._get_skeleton_file() + # Mismatching the local MD5 causes an update + new_skel = tfc._s3._get_skeleton_file(timeout=10) assert new_skel is not None assert Path(new_skel).exists() - assert Path(new_skel).stat().st_size > 0 + assert Path(new_skel).read_bytes() == b'anything' - latest_md5 = ( - requests.get( - tfc._s3.TF_SKEL_URL(release='master', ext='md5', allow_redirects=True), timeout=10 - ) - .content.decode() - .split()[0] - ) - monkeypatch.setattr(tfc._s3, 'TF_SKEL_MD5', latest_md5) - assert tfc._s3._get_skeleton_file() is None + md5content = tfc._s3.load_data.readable('templateflow-skel.md5').read_bytes() + # Matching the local MD5 skips the update + assert tfc._s3._get_skeleton_file(timeout=10) is None - monkeypatch.setattr(tfc._s3, 'TF_SKEL_MD5', local_md5) + # Bad URL fails to update monkeypatch.setattr(tfc._s3, 'TF_SKEL_URL', 'http://weird/{release}/{ext}'.format) - assert tfc._s3._get_skeleton_file() is None + assert tfc._s3._get_skeleton_file(timeout=10) is None monkeypatch.setattr( tfc._s3, 'TF_SKEL_URL', tfc._s3.TF_SKEL_URL(release='{release}', ext='{ext}z').format ) - assert tfc._s3._get_skeleton_file() is None + assert tfc._s3._get_skeleton_file(timeout=10) is None def test_update_s3(tmp_path, monkeypatch): """Exercise updating the S3 skeleton.""" newhome = (tmp_path / 's3-update').resolve() - monkeypatch.setenv('TEMPLATEFLOW_USE_DATALAD', 'off') - monkeypatch.setenv('TEMPLATEFLOW_HOME', str(newhome)) - assert tfc._s3.update(newhome) - assert not tfc._s3.update(newhome, overwrite=False) + assert tfc._s3.update(newhome, timeout=10) + assert not tfc._s3.update(newhome, overwrite=False, timeout=10) for p in (newhome / 'tpl-MNI152NLin6Sym').glob('*.nii.gz'): p.unlink() - assert tfc._s3.update(newhome, overwrite=False) + assert tfc._s3.update(newhome, 
overwrite=False, timeout=10) # This should cover the remote zip file fetching - monkeypatch.setattr(tfc._s3, 'TF_SKEL_MD5', 'invent') - assert tfc._s3.update(newhome, local=False) - assert not tfc._s3.update(newhome, local=False, overwrite=False) + # monkeypatch.setattr(tfc._s3, 'TF_SKEL_MD5', 'invent') + assert tfc._s3.update(newhome, local=False, timeout=10) + assert not tfc._s3.update(newhome, local=False, overwrite=False, timeout=10) for p in (newhome / 'tpl-MNI152NLin6Sym').glob('*.nii.gz'): p.unlink() - assert tfc._s3.update(newhome, local=False, overwrite=False) + assert tfc._s3.update(newhome, local=False, overwrite=False, timeout=10) def mock_get(*args, **kwargs): @@ -108,28 +106,20 @@ def test_s3_400_error(monkeypatch): monkeypatch.setattr(requests, 'get', mock_get) with pytest.raises(RuntimeError, match=r'Failed to download .* code 400'): - tf._s3_get( + templateflow.client._s3_get( + tfc._cache.config, Path(tfc.TF_LAYOUT.root) - / 'tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-02_T1w.nii.gz' + / 'tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-02_T1w.nii.gz', ) def test_bad_skeleton(tmp_path, monkeypatch): newhome = (tmp_path / 's3-update').resolve() - monkeypatch.setattr(tfc, 'TF_USE_DATALAD', False) - monkeypatch.setattr(tfc, 'TF_HOME', newhome) - monkeypatch.setattr(tfc, 'TF_LAYOUT', None) - - tfc._init_cache() - tfc.init_layout() - - assert tfc.TF_LAYOUT is not None - assert tfc.TF_LAYOUT.root == str(newhome) + client = templateflow.client.TemplateFlowClient(root=newhome, use_datalad=False) - # Instead of reloading - monkeypatch.setattr(tf, 'TF_LAYOUT', tfc.TF_LAYOUT) + assert client.cache.layout.root == str(newhome) - paths = tf.ls('MNI152NLin2009cAsym', resolution='02', suffix='T1w', desc=None) + paths = client.ls('MNI152NLin2009cAsym', resolution='02', suffix='T1w', desc=None) assert paths path = Path(paths[0]) assert path.read_bytes() == b'' @@ -138,14 +128,14 @@ def test_bad_skeleton(tmp_path, monkeypatch): 
path.write_bytes(error_file.read_bytes()) # Test directly before testing through API paths - tf._truncate_s3_errors(paths) + templateflow.client._truncate_s3_errors(paths) assert path.read_bytes() == b'' path.write_bytes(error_file.read_bytes()) monkeypatch.setattr(requests, 'get', mock_get) with pytest.raises(RuntimeError): - tf.get('MNI152NLin2009cAsym', resolution='02', suffix='T1w', desc=None) + client.get('MNI152NLin2009cAsym', resolution='02', suffix='T1w', desc=None) # Running get clears bad files before attempting to download assert path.read_bytes() == b''