1- import logging
21import os
3- import tempfile
42from pathlib import Path
53from typing import Any , Mapping
6- from urllib .parse import urlsplit
74
85import mne
96import numpy as np
107import xarray as xr
118from docstring_inheritance import NumpyDocstringInheritanceInitMeta
129from dotenv import load_dotenv
1310from joblib import Parallel , delayed
14- from mne .utils import warn
1511from mne_bids import find_matching_paths , get_bids_path_from_fname , read_raw_bids
1612from pymongo import InsertOne , UpdateOne
17- from s3fs import S3FileSystem
13+ from rich .console import Console
14+ from rich .panel import Panel
15+ from rich .text import Text
1816
1917from braindecode .datasets import BaseConcatDataset
2018
19+ from . import downloader
2120from .bids_eeg_metadata import (
2221 build_query_from_kwargs ,
2322 load_eeg_attrs_from_bids_file ,
3332 EEGBIDSDataset ,
3433 EEGDashBaseDataset ,
3534)
35+ from .logging import logger
3636from .mongodb import MongoConnectionManager
3737from .paths import get_default_cache_dir
38-
39- logger = logging .getLogger ("eegdash" )
38+ from .utils import _init_mongo_client
4039
4140
4241class EEGDash :
@@ -74,19 +73,26 @@ def __init__(self, *, is_public: bool = True, is_staging: bool = False) -> None:
7473
7574 if self .is_public :
7675 DB_CONNECTION_STRING = mne .utils .get_config ("EEGDASH_DB_URI" )
76+ if not DB_CONNECTION_STRING :
77+ try :
78+ _init_mongo_client ()
79+ DB_CONNECTION_STRING = mne .utils .get_config ("EEGDASH_DB_URI" )
80+ except Exception :
81+ DB_CONNECTION_STRING = None
7782 else :
7883 load_dotenv ()
7984 DB_CONNECTION_STRING = os .getenv ("DB_CONNECTION_STRING" )
8085
8186 # Use singleton to get MongoDB client, database, and collection
87+ if not DB_CONNECTION_STRING :
88+ raise RuntimeError (
89+ "No MongoDB connection string configured. Set MNE config 'EEGDASH_DB_URI' "
90+ "or environment variable 'DB_CONNECTION_STRING'."
91+ )
8292 self .__client , self .__db , self .__collection = MongoConnectionManager .get_client (
8393 DB_CONNECTION_STRING , is_staging
8494 )
8595
86- self .filesystem = S3FileSystem (
87- anon = True , client_kwargs = {"region_name" : "us-east-2" }
88- )
89-
9096 def find (
9197 self , query : dict [str , Any ] = None , / , ** kwargs
9298 ) -> list [Mapping [str , Any ]]:
@@ -310,83 +316,6 @@ def _raise_if_conflicting_constraints(
310316 f"Conflicting constraints for '{ key } ': disjoint sets { r_val !r} and { k_val !r} "
311317 )
312318
313- def load_eeg_data_from_s3 (self , s3path : str ) -> xr .DataArray :
314- """Load EEG data from an S3 URI into an ``xarray.DataArray``.
315-
316- Preserves the original filename, downloads sidecar files when applicable
317- (e.g., ``.fdt`` for EEGLAB, ``.vmrk``/``.eeg`` for BrainVision), and uses
318- MNE's direct readers.
319-
320- Parameters
321- ----------
322- s3path : str
323- An S3 URI (should start with "s3://").
324-
325- Returns
326- -------
327- xr.DataArray
328- EEG data with dimensions ``("channel", "time")``.
329-
330- Raises
331- ------
332- ValueError
333- If the file extension is unsupported.
334-
335- """
336- # choose a temp dir so sidecars can be colocated
337- with tempfile .TemporaryDirectory () as tmpdir :
338- # Derive local filenames from the S3 key to keep base name consistent
339- s3_key = urlsplit (s3path ).path # e.g., "/dsXXXX/sub-.../..._eeg.set"
340- basename = Path (s3_key ).name
341- ext = Path (basename ).suffix .lower ()
342- local_main = Path (tmpdir ) / basename
343-
344- # Download main file
345- with (
346- self .filesystem .open (s3path , mode = "rb" ) as fsrc ,
347- open (local_main , "wb" ) as fdst ,
348- ):
349- fdst .write (fsrc .read ())
350-
351- # Determine and fetch any required sidecars
352- sidecars : list [str ] = []
353- if ext == ".set" : # EEGLAB
354- sidecars = [".fdt" ]
355- elif ext == ".vhdr" : # BrainVision
356- sidecars = [".vmrk" , ".eeg" , ".dat" , ".raw" ]
357-
358- for sc_ext in sidecars :
359- sc_key = s3_key [: - len (ext )] + sc_ext
360- sc_uri = f"s3://{ urlsplit (s3path ).netloc } { sc_key } "
361- try :
362- # If sidecar exists, download next to the main file
363- info = self .filesystem .info (sc_uri )
364- if info :
365- sc_local = Path (tmpdir ) / Path (sc_key ).name
366- with (
367- self .filesystem .open (sc_uri , mode = "rb" ) as fsrc ,
368- open (sc_local , "wb" ) as fdst ,
369- ):
370- fdst .write (fsrc .read ())
371- except Exception :
372- # Sidecar not present; skip silently
373- pass
374-
375- # Read using appropriate MNE reader
376- raw = mne .io .read_raw (str (local_main ), preload = True , verbose = False )
377-
378- data = raw .get_data ()
379- fs = raw .info ["sfreq" ]
380- max_time = data .shape [1 ] / fs
381- time_steps = np .linspace (0 , max_time , data .shape [1 ]).squeeze ()
382- channel_names = raw .ch_names
383-
384- return xr .DataArray (
385- data = data ,
386- dims = ["channel" , "time" ],
387- coords = {"time" : time_steps , "channel" : channel_names },
388- )
389-
390319 def load_eeg_data_from_bids_file (self , bids_file : str ) -> xr .DataArray :
391320 """Load EEG data from a local BIDS-formatted file.
392321
@@ -508,39 +437,13 @@ def get(self, query: dict[str, Any]) -> list[xr.DataArray]:
508437 results = Parallel (
509438 n_jobs = - 1 if len (sessions ) > 1 else 1 , prefer = "threads" , verbose = 1
510439 )(
511- delayed (self .load_eeg_data_from_s3 )(self ._get_s3path (session ))
440+ delayed (downloader .load_eeg_from_s3 )(
441+ downloader .get_s3path ("s3://openneuro.org" , session ["bidspath" ])
442+ )
512443 for session in sessions
513444 )
514445 return results
515446
516- def _get_s3path (self , record : Mapping [str , Any ] | str ) -> str :
517- """Build an S3 URI from a DB record or a relative path.
518-
519- Parameters
520- ----------
521- record : dict or str
522- Either a DB record containing a ``'bidspath'`` key, or a relative
523- path string under the OpenNeuro bucket.
524-
525- Returns
526- -------
527- str
528- Fully qualified S3 URI.
529-
530- Raises
531- ------
532- ValueError
533- If a mapping is provided but ``'bidspath'`` is missing.
534-
535- """
536- if isinstance (record , str ):
537- rel = record
538- else :
539- rel = record .get ("bidspath" )
540- if not rel :
541- raise ValueError ("Record missing 'bidspath' for S3 path resolution" )
542- return f"s3://openneuro.org/{ rel } "
543-
544447 def _add_request (self , record : dict ):
545448 """Internal helper method to create a MongoDB insertion request for a record."""
546449 return InsertOne (record )
@@ -552,8 +455,11 @@ def add(self, record: dict):
552455 except ValueError as e :
553456 logger .error ("Validation error for record: %s " , record ["data_name" ])
554457 logger .error (e )
555- except :
556- logger .error ("Error adding record: %s " , record ["data_name" ])
458+ except Exception as exc :
459+ logger .error (
460+ "Error adding record: %s " , record .get ("data_name" , "<unknown>" )
461+ )
462+ logger .debug ("Add operation failed" , exc_info = exc )
557463
558464 def _update_request (self , record : dict ):
559465 """Internal helper method to create a MongoDB update request for a record."""
@@ -572,8 +478,11 @@ def update(self, record: dict):
572478 self .__collection .update_one (
573479 {"data_name" : record ["data_name" ]}, {"$set" : record }
574480 )
575- except : # silent failure
576- logger .error ("Error updating record: %s" , record ["data_name" ])
481+ except Exception as exc : # log and continue
482+ logger .error (
483+ "Error updating record: %s" , record .get ("data_name" , "<unknown>" )
484+ )
485+ logger .debug ("Update operation failed" , exc_info = exc )
577486
578487 def exists (self , query : dict [str , Any ]) -> bool :
579488 """Alias for :meth:`exist` provided for API clarity."""
@@ -726,13 +635,15 @@ def __init__(
726635 self .records = records
727636 self .download = download
728637 self .n_jobs = n_jobs
729- self .eeg_dash_instance = eeg_dash_instance or EEGDash ()
638+ self .eeg_dash_instance = eeg_dash_instance
730639
731640 # Resolve a unified cache directory across code/tests/CI
732641 self .cache_dir = Path (cache_dir or get_default_cache_dir ())
733642
734643 if not self .cache_dir .exists ():
735- warn (f"Cache directory does not exist, creating it: { self .cache_dir } " )
644+ logger .warning (
645+ f"Cache directory does not exist, creating it: { self .cache_dir } "
646+ )
736647 self .cache_dir .mkdir (exist_ok = True , parents = True )
737648
738649 # Separate query kwargs from other kwargs passed to the BaseDataset constructor
@@ -772,21 +683,28 @@ def __init__(
772683 not _suppress_comp_warning
773684 and self .query ["dataset" ] in RELEASE_TO_OPENNEURO_DATASET_MAP .values ()
774685 ):
775- warn (
776- "If you are not participating in the competition, you can ignore this warning!"
777- "\n \n "
778- "EEG 2025 Competition Data Notice:\n "
779- "---------------------------------\n "
780- " You are loading the dataset that is used in the EEG 2025 Competition:\n "
781- "IMPORTANT: The data accessed via `EEGDashDataset` is NOT identical to what you get from `EEGChallengeDataset` object directly.\n "
782- "and it is not what you will use for the competition. Downsampling and filtering were applied to the data"
783- "to allow more people to participate.\n "
784- "\n "
785- "If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.\n "
786- "\n " ,
787- UserWarning ,
788- module = "eegdash" ,
686+ message_text = Text .from_markup (
687+ "[italic]This notice is only for users who are participating in the [link=https://eeg2025.github.io/]EEG 2025 Competition[/link].[/italic]\n \n "
688+ "[bold]EEG 2025 Competition Data Notice![/bold]\n "
689+ "You are loading one of the datasets that is used in competition, but via `EEGDashDataset`.\n \n "
690+ "[bold red]IMPORTANT[/bold red]: \n "
691+ "If you download data from `EEGDashDataset`, it is [u]NOT[/u] identical to the official competition data, which is accessed via `EEGChallengeDataset`. "
692+ "The competition data has been downsampled and filtered.\n \n "
693+ "[bold]If you are participating in the competition, you must use the `EEGChallengeDataset` object to ensure consistency.[/bold] \n \n "
694+ "If you are not participating in the competition, you can ignore this message."
789695 )
696+ warning_panel = Panel (
697+ message_text ,
698+ title = "[yellow]EEG 2025 Competition Data Notice[/yellow]" ,
699+ subtitle = "[cyan]Source: EEGDashDataset[/cyan]" ,
700+ border_style = "yellow" ,
701+ )
702+
703+ try :
704+ Console ().print (warning_panel )
705+ except Exception :
706+ logger .warning (str (message_text ))
707+
790708 if records is not None :
791709 self .records = records
792710 datasets = [
@@ -848,16 +766,15 @@ def __init__(
848766 )
849767 )
850768 elif self .query :
851- # This is the DB query path that we are improving
769+ if self .eeg_dash_instance is None :
770+ self .eeg_dash_instance = EEGDash ()
852771 datasets = self ._find_datasets (
853772 query = build_query_from_kwargs (** self .query ),
854773 description_fields = description_fields ,
855774 base_dataset_kwargs = base_dataset_kwargs ,
856775 )
857776 # We only need filesystem if we need to access S3
858- self .filesystem = S3FileSystem (
859- anon = True , client_kwargs = {"region_name" : "us-east-2" }
860- )
777+ self .filesystem = downloader .get_s3_filesystem ()
861778 else :
862779 raise ValueError (
863780 "You must provide either 'records', a 'data_dir', or a query/keyword arguments for filtering."
0 commit comments