Skip to content

Commit 75af377

Browse files
Merge pull request #129 from sccn/develop
Sync the branches
2 parents df08554 + 9cb4abd commit 75af377

23 files changed

+718
-631
lines changed

.github/workflows/doc.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,4 @@ jobs:
6666
with:
6767
github_token: ${{ secrets.GITHUB_TOKEN }}
6868
publish_dir: ./docs/build/html
69+
cname: eegdash.org

eegdash/api.py

Lines changed: 58 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
1-
import logging
21
import os
3-
import tempfile
42
from pathlib import Path
53
from typing import Any, Mapping
6-
from urllib.parse import urlsplit
74

85
import mne
96
import numpy as np
107
import xarray as xr
118
from docstring_inheritance import NumpyDocstringInheritanceInitMeta
129
from dotenv import load_dotenv
1310
from joblib import Parallel, delayed
14-
from mne.utils import warn
1511
from mne_bids import find_matching_paths, get_bids_path_from_fname, read_raw_bids
1612
from pymongo import InsertOne, UpdateOne
17-
from s3fs import S3FileSystem
13+
from rich.console import Console
14+
from rich.panel import Panel
15+
from rich.text import Text
1816

1917
from braindecode.datasets import BaseConcatDataset
2018

19+
from . import downloader
2120
from .bids_eeg_metadata import (
2221
build_query_from_kwargs,
2322
load_eeg_attrs_from_bids_file,
@@ -33,10 +32,10 @@
3332
EEGBIDSDataset,
3433
EEGDashBaseDataset,
3534
)
35+
from .logging import logger
3636
from .mongodb import MongoConnectionManager
3737
from .paths import get_default_cache_dir
38-
39-
logger = logging.getLogger("eegdash")
38+
from .utils import _init_mongo_client
4039

4140

4241
class EEGDash:
@@ -74,19 +73,26 @@ def __init__(self, *, is_public: bool = True, is_staging: bool = False) -> None:
7473

7574
if self.is_public:
7675
DB_CONNECTION_STRING = mne.utils.get_config("EEGDASH_DB_URI")
76+
if not DB_CONNECTION_STRING:
77+
try:
78+
_init_mongo_client()
79+
DB_CONNECTION_STRING = mne.utils.get_config("EEGDASH_DB_URI")
80+
except Exception:
81+
DB_CONNECTION_STRING = None
7782
else:
7883
load_dotenv()
7984
DB_CONNECTION_STRING = os.getenv("DB_CONNECTION_STRING")
8085

8186
# Use singleton to get MongoDB client, database, and collection
87+
if not DB_CONNECTION_STRING:
88+
raise RuntimeError(
89+
"No MongoDB connection string configured. Set MNE config 'EEGDASH_DB_URI' "
90+
"or environment variable 'DB_CONNECTION_STRING'."
91+
)
8292
self.__client, self.__db, self.__collection = MongoConnectionManager.get_client(
8393
DB_CONNECTION_STRING, is_staging
8494
)
8595

86-
self.filesystem = S3FileSystem(
87-
anon=True, client_kwargs={"region_name": "us-east-2"}
88-
)
89-
9096
def find(
9197
self, query: dict[str, Any] = None, /, **kwargs
9298
) -> list[Mapping[str, Any]]:
@@ -310,83 +316,6 @@ def _raise_if_conflicting_constraints(
310316
f"Conflicting constraints for '{key}': disjoint sets {r_val!r} and {k_val!r}"
311317
)
312318

313-
def load_eeg_data_from_s3(self, s3path: str) -> xr.DataArray:
314-
"""Load EEG data from an S3 URI into an ``xarray.DataArray``.
315-
316-
Preserves the original filename, downloads sidecar files when applicable
317-
(e.g., ``.fdt`` for EEGLAB, ``.vmrk``/``.eeg`` for BrainVision), and uses
318-
MNE's direct readers.
319-
320-
Parameters
321-
----------
322-
s3path : str
323-
An S3 URI (should start with "s3://").
324-
325-
Returns
326-
-------
327-
xr.DataArray
328-
EEG data with dimensions ``("channel", "time")``.
329-
330-
Raises
331-
------
332-
ValueError
333-
If the file extension is unsupported.
334-
335-
"""
336-
# choose a temp dir so sidecars can be colocated
337-
with tempfile.TemporaryDirectory() as tmpdir:
338-
# Derive local filenames from the S3 key to keep base name consistent
339-
s3_key = urlsplit(s3path).path # e.g., "/dsXXXX/sub-.../..._eeg.set"
340-
basename = Path(s3_key).name
341-
ext = Path(basename).suffix.lower()
342-
local_main = Path(tmpdir) / basename
343-
344-
# Download main file
345-
with (
346-
self.filesystem.open(s3path, mode="rb") as fsrc,
347-
open(local_main, "wb") as fdst,
348-
):
349-
fdst.write(fsrc.read())
350-
351-
# Determine and fetch any required sidecars
352-
sidecars: list[str] = []
353-
if ext == ".set": # EEGLAB
354-
sidecars = [".fdt"]
355-
elif ext == ".vhdr": # BrainVision
356-
sidecars = [".vmrk", ".eeg", ".dat", ".raw"]
357-
358-
for sc_ext in sidecars:
359-
sc_key = s3_key[: -len(ext)] + sc_ext
360-
sc_uri = f"s3://{urlsplit(s3path).netloc}{sc_key}"
361-
try:
362-
# If sidecar exists, download next to the main file
363-
info = self.filesystem.info(sc_uri)
364-
if info:
365-
sc_local = Path(tmpdir) / Path(sc_key).name
366-
with (
367-
self.filesystem.open(sc_uri, mode="rb") as fsrc,
368-
open(sc_local, "wb") as fdst,
369-
):
370-
fdst.write(fsrc.read())
371-
except Exception:
372-
# Sidecar not present; skip silently
373-
pass
374-
375-
# Read using appropriate MNE reader
376-
raw = mne.io.read_raw(str(local_main), preload=True, verbose=False)
377-
378-
data = raw.get_data()
379-
fs = raw.info["sfreq"]
380-
max_time = data.shape[1] / fs
381-
time_steps = np.linspace(0, max_time, data.shape[1]).squeeze()
382-
channel_names = raw.ch_names
383-
384-
return xr.DataArray(
385-
data=data,
386-
dims=["channel", "time"],
387-
coords={"time": time_steps, "channel": channel_names},
388-
)
389-
390319
def load_eeg_data_from_bids_file(self, bids_file: str) -> xr.DataArray:
391320
"""Load EEG data from a local BIDS-formatted file.
392321
@@ -508,39 +437,13 @@ def get(self, query: dict[str, Any]) -> list[xr.DataArray]:
508437
results = Parallel(
509438
n_jobs=-1 if len(sessions) > 1 else 1, prefer="threads", verbose=1
510439
)(
511-
delayed(self.load_eeg_data_from_s3)(self._get_s3path(session))
440+
delayed(downloader.load_eeg_from_s3)(
441+
downloader.get_s3path("s3://openneuro.org", session["bidspath"])
442+
)
512443
for session in sessions
513444
)
514445
return results
515446

516-
def _get_s3path(self, record: Mapping[str, Any] | str) -> str:
517-
"""Build an S3 URI from a DB record or a relative path.
518-
519-
Parameters
520-
----------
521-
record : dict or str
522-
Either a DB record containing a ``'bidspath'`` key, or a relative
523-
path string under the OpenNeuro bucket.
524-
525-
Returns
526-
-------
527-
str
528-
Fully qualified S3 URI.
529-
530-
Raises
531-
------
532-
ValueError
533-
If a mapping is provided but ``'bidspath'`` is missing.
534-
535-
"""
536-
if isinstance(record, str):
537-
rel = record
538-
else:
539-
rel = record.get("bidspath")
540-
if not rel:
541-
raise ValueError("Record missing 'bidspath' for S3 path resolution")
542-
return f"s3://openneuro.org/{rel}"
543-
544447
def _add_request(self, record: dict):
545448
"""Internal helper method to create a MongoDB insertion request for a record."""
546449
return InsertOne(record)
@@ -552,8 +455,11 @@ def add(self, record: dict):
552455
except ValueError as e:
553456
logger.error("Validation error for record: %s ", record["data_name"])
554457
logger.error(e)
555-
except:
556-
logger.error("Error adding record: %s ", record["data_name"])
458+
except Exception as exc:
459+
logger.error(
460+
"Error adding record: %s ", record.get("data_name", "<unknown>")
461+
)
462+
logger.debug("Add operation failed", exc_info=exc)
557463

558464
def _update_request(self, record: dict):
559465
"""Internal helper method to create a MongoDB update request for a record."""
@@ -572,8 +478,11 @@ def update(self, record: dict):
572478
self.__collection.update_one(
573479
{"data_name": record["data_name"]}, {"$set": record}
574480
)
575-
except: # silent failure
576-
logger.error("Error updating record: %s", record["data_name"])
481+
except Exception as exc: # log and continue
482+
logger.error(
483+
"Error updating record: %s", record.get("data_name", "<unknown>")
484+
)
485+
logger.debug("Update operation failed", exc_info=exc)
577486

578487
def exists(self, query: dict[str, Any]) -> bool:
579488
"""Alias for :meth:`exist` provided for API clarity."""
@@ -726,13 +635,15 @@ def __init__(
726635
self.records = records
727636
self.download = download
728637
self.n_jobs = n_jobs
729-
self.eeg_dash_instance = eeg_dash_instance or EEGDash()
638+
self.eeg_dash_instance = eeg_dash_instance
730639

731640
# Resolve a unified cache directory across code/tests/CI
732641
self.cache_dir = Path(cache_dir or get_default_cache_dir())
733642

734643
if not self.cache_dir.exists():
735-
warn(f"Cache directory does not exist, creating it: {self.cache_dir}")
644+
logger.warning(
645+
f"Cache directory does not exist, creating it: {self.cache_dir}"
646+
)
736647
self.cache_dir.mkdir(exist_ok=True, parents=True)
737648

738649
# Separate query kwargs from other kwargs passed to the BaseDataset constructor
@@ -772,21 +683,28 @@ def __init__(
772683
not _suppress_comp_warning
773684
and self.query["dataset"] in RELEASE_TO_OPENNEURO_DATASET_MAP.values()
774685
):
775-
warn(
776-
"If you are not participating in the competition, you can ignore this warning!"
777-
"\n\n"
778-
"EEG 2025 Competition Data Notice:\n"
779-
"---------------------------------\n"
780-
" You are loading the dataset that is used in the EEG 2025 Competition:\n"
781-
"IMPORTANT: The data accessed via `EEGDashDataset` is NOT identical to what you get from `EEGChallengeDataset` object directly.\n"
782-
"and it is not what you will use for the competition. Downsampling and filtering were applied to the data"
783-
"to allow more people to participate.\n"
784-
"\n"
785-
"If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.\n"
786-
"\n",
787-
UserWarning,
788-
module="eegdash",
686+
message_text = Text.from_markup(
687+
"[italic]This notice is only for users who are participating in the [link=https://eeg2025.github.io/]EEG 2025 Competition[/link].[/italic]\n\n"
688+
"[bold]EEG 2025 Competition Data Notice![/bold]\n"
689+
"You are loading one of the datasets that is used in competition, but via `EEGDashDataset`.\n\n"
690+
"[bold red]IMPORTANT[/bold red]: \n"
691+
"If you download data from `EEGDashDataset`, it is [u]NOT[/u] identical to the official competition data, which is accessed via `EEGChallengeDataset`. "
692+
"The competition data has been downsampled and filtered.\n\n"
693+
"[bold]If you are participating in the competition, you must use the `EEGChallengeDataset` object to ensure consistency.[/bold] \n\n"
694+
"If you are not participating in the competition, you can ignore this message."
789695
)
696+
warning_panel = Panel(
697+
message_text,
698+
title="[yellow]EEG 2025 Competition Data Notice[/yellow]",
699+
subtitle="[cyan]Source: EEGDashDataset[/cyan]",
700+
border_style="yellow",
701+
)
702+
703+
try:
704+
Console().print(warning_panel)
705+
except Exception:
706+
logger.warning(str(message_text))
707+
790708
if records is not None:
791709
self.records = records
792710
datasets = [
@@ -848,16 +766,15 @@ def __init__(
848766
)
849767
)
850768
elif self.query:
851-
# This is the DB query path that we are improving
769+
if self.eeg_dash_instance is None:
770+
self.eeg_dash_instance = EEGDash()
852771
datasets = self._find_datasets(
853772
query=build_query_from_kwargs(**self.query),
854773
description_fields=description_fields,
855774
base_dataset_kwargs=base_dataset_kwargs,
856775
)
857776
# We only need filesystem if we need to access S3
858-
self.filesystem = S3FileSystem(
859-
anon=True, client_kwargs={"region_name": "us-east-2"}
860-
)
777+
self.filesystem = downloader.get_s3_filesystem()
861778
else:
862779
raise ValueError(
863780
"You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering."

0 commit comments

Comments
 (0)