44import subprocess
55import urllib .request
66from pathlib import Path
7+ from typing import Dict , Any
78
89import rich_click as click
910from rich .console import Console
1213from ..utils .logging import get_logger
1314from ..utils .validators import VotuDerepError
1415
# ---------------------------------------------------------------------------
# Dataset registry
# ---------------------------------------------------------------------------
# DATASETS maps a dataset name to the set of files it contains.  Every item
# is a mapping of an arbitrary key to:
#   "url"  -- the full remote URL to fetch
#   "path" -- the relative output path under the chosen --outdir
#
# To register a new dataset, add another top-level key.  Keep entries fully
# explicit (no implicit pairing logic): paired reads get separate R1 and R2
# items.
DATASETS: Dict[str, Dict[str, Dict[str, str]]] = {
    "virome": {
        # Pooled Illumina assembly hosted on Zenodo.
        "assembly": {
            "url": "https://zenodo.org/api/records/10650983/files/illumina_sample_pool_megahit.fa.gz/content",
            "path": "human_gut_assembly.fa.gz",
        },
        # Sequencing reads from the EBI FTP mirror, one entry per mate.
        "ERR6797443_R1": {
            "url": "ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR679/003/ERR6797443/ERR6797443_1.fastq.gz",
            "path": "reads/ERR6797443_R1.fastq.gz",
        },
        "ERR6797443_R2": {
            "url": "ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR679/003/ERR6797443/ERR6797443_2.fastq.gz",
            "path": "reads/ERR6797443_R2.fastq.gz",
        },
        "ERR6797444_R1": {
            "url": "ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR679/004/ERR6797444/ERR6797444_1.fastq.gz",
            "path": "reads/ERR6797444_R1.fastq.gz",
        },
        "ERR6797444_R2": {
            "url": "ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR679/004/ERR6797444/ERR6797444_2.fastq.gz",
            "path": "reads/ERR6797444_R2.fastq.gz",
        },
        "ERR6797445_R1": {
            "url": "ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR679/005/ERR6797445/ERR6797445_1.fastq.gz",
            "path": "reads/ERR6797445_R1.fastq.gz",
        },
        "ERR6797445_R2": {
            "url": "ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR679/005/ERR6797445/ERR6797445_2.fastq.gz",
            "path": "reads/ERR6797445_R2.fastq.gz",
        },
    }
}
59+
# Console bound to stderr so status/progress output never mixes with data
# that callers may be piping from stdout.
console = Console(stderr=True)
# Module-level logger following the package-wide get_logger convention.
logger = get_logger(__name__)
1762
1863
def download_file(url: str, output_path: str, description: str = "Downloading"):
    """Download a file from a URL with progress indication (urllib).

    The payload is written to ``<output_path>.downloading`` first and renamed
    to the final name only on success, so a file at *output_path* is always a
    complete download.  An existing final file is therefore skipped.

    Args:
        url: Remote URL to fetch (http/https/ftp as supported by urllib).
        output_path: Destination path for the downloaded file.
        description: Label shown next to the progress spinner.

    Raises:
        VotuDerepError: if the download fails for any reason.
    """
    output_file = Path(output_path)
    temp_file = Path(f"{output_path}.downloading")

    # A file at the final path is complete by construction -- skip it.
    if output_file.exists():
        logger.info(f"File already exists, skipping download: {output_path}")
        console.print(
            f"[yellow]✓ Skipping {output_file.name} (already exists)[/yellow]"
        )
        return

    # Remove any leftover partial download from an interrupted earlier run.
    if temp_file.exists():
        logger.info(f"Removing partial download: {temp_file}")
        temp_file.unlink()

    try:
        # NOTE(review): the exact Progress column setup was hidden by the
        # diff hunk -- confirm against the original file's configuration.
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            task = progress.add_task(f"{description} {output_file.name}", total=100)

            def reporthook(block_num, block_size, total_size):
                # Servers may omit Content-Length; urllib then reports a
                # non-positive total -- guard against ZeroDivisionError.
                if total_size > 0:
                    percent = min(100, (block_num * block_size * 100) / total_size)
                    progress.update(task, completed=percent)

            # Download to the temporary file so a failure never leaves a
            # truncated file at the final path.
            urllib.request.urlretrieve(url, str(temp_file), reporthook)
            progress.update(task, completed=100)

        # Publish the completed download under its final name.
        temp_file.rename(output_file)
        logger.info(f"Downloaded: {output_path}")

    except Exception as e:
        # Clean up the partial download so a retry starts fresh.
        if temp_file.exists():
            temp_file.unlink()
        raise VotuDerepError(f"Failed to download {url}: {e}") from e
42109
43110
def run_curl(url: str, output_path: str, description: str = "Downloading"):
    """Download a file using curl with progress indication.

    Mirrors download_file (skip-if-exists, temp-file + rename-on-success) but
    shells out to the external ``curl`` binary.

    Args:
        url: Remote URL to fetch.
        output_path: Destination path for the downloaded file.
        description: Label shown next to the progress spinner.

    Raises:
        VotuDerepError: if curl fails or is not installed.
    """
    output_file = Path(output_path)
    temp_file = Path(f"{output_path}.downloading")

    # A file at the final path is complete by construction -- skip it.
    if output_file.exists():
        logger.info(f"File already exists, skipping download: {output_path}")
        console.print(
            f"[yellow]✓ Skipping {output_file.name} (already exists)[/yellow]"
        )
        return

    # Remove any leftover partial download from an interrupted earlier run.
    if temp_file.exists():
        logger.info(f"Removing partial download: {temp_file}")
        temp_file.unlink()

    try:
        # NOTE(review): the exact Progress column setup was hidden by the
        # diff hunk -- confirm against the original file's configuration.
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            task = progress.add_task(f"{description} {output_file.name}")

            # -f: fail on HTTP error responses instead of saving the error
            # page -- otherwise a 404 body would be renamed to the final
            # filename and mistaken for a complete download.
            cmd = ["curl", "-fL", url, "-o", str(temp_file)]
            subprocess.run(cmd, capture_output=True, text=True, check=True)

            progress.update(task, completed=100)

        # Publish the completed download under its final name.
        temp_file.rename(output_file)
        logger.info(f"Downloaded: {output_path}")

    except subprocess.CalledProcessError as e:
        # Clean up the partial download so a retry starts fresh.
        if temp_file.exists():
            temp_file.unlink()
        raise VotuDerepError(f"Failed to download {url}: {e.stderr}") from e
    except FileNotFoundError as e:
        raise VotuDerepError("curl command not found. Please install curl.") from e
@@ -73,51 +162,60 @@ def run_curl(url: str, output_path: str, description: str = "Downloading"):
73162 show_default = True ,
74163 help = "Where to put the output files" ,
75164)
165+ @click .option (
166+ "-n" ,
167+ "--name" ,
168+ "dataset_name" ,
169+ default = "virome" ,
170+ show_default = True ,
171+ help = "Dataset name to download (registered in DATASETS)" ,
172+ )
76173@click .pass_context
77- def trainingdata (ctx , outdir : str ):
174+ def trainingdata (ctx , outdir : str , dataset_name : str ):
78175 """
79176 Download training dataset from the internet.
80177
81- Downloads viral assembly and sequencing reads for training purposes.
178+ Uses a registry (DATASETS) of named datasets, each containing a set of
179+ {url, path} items. Adds new datasets by extending the DATASETS dict.
82180 """
83181 verbose = ctx .obj .get ("verbose" , False )
84182
85183 if verbose :
86184 console .print (f"[blue]Output directory:[/blue] { outdir } " )
185+ console .print (f"[blue]Dataset:[/blue] { dataset_name } " )
87186
88- # Create output directory structure
89- outdir_path = Path (outdir )
90- reads_dir = outdir_path / "reads"
187+ # Resolve dataset
188+ dataset : Dict [str , Any ] = DATASETS .get (dataset_name )
189+ if not dataset :
190+ available = ", " .join (sorted (DATASETS .keys ()))
191+ raise VotuDerepError (
192+ f"Unknown dataset '{ dataset_name } '. Available datasets: { available } "
193+ )
91194
195+ # Create output directory
196+ outdir_path = Path (outdir )
92197 try :
93- reads_dir .mkdir (parents = True , exist_ok = True )
94- logger .info (f"Created directory structure : { reads_dir } " )
198+ outdir_path .mkdir (parents = True , exist_ok = True )
199+ logger .info (f"Ensured base output directory exists : { outdir_path } " )
95200
96201 console .print ("[bold green]Downloading training dataset...[/bold green]" )
97202
98- # Download assembly
99- assembly_url = "https://zenodo.org/api/records/10650983/files/illumina_sample_pool_megahit.fa.gz/content"
100- assembly_path = outdir_path / "human_gut_assembly.fa.gz"
101-
102- console .print ("\n [blue]Downloading assembly...[/blue]" )
103- download_file (assembly_url , str (assembly_path ), "Downloading assembly" )
104-
105- # Download reads
106- console .print ("\n [blue]Downloading sequencing reads...[/blue]" )
107- ebi_base = "ftp://ftp.sra.ebi.ac.uk/vol1/fastq"
108-
109- reads_to_download = [
110- ("ERR6797445" , "ERR679/005/ERR6797445" ),
111- ("ERR6797444" , "ERR679/004/ERR6797444" ),
112- ("ERR6797443" , "ERR679/003/ERR6797443" ),
113- ]
114-
115- for sample_id , path_suffix in reads_to_download :
116- for read_num in ["1" , "2" ]:
117- url = f"{ ebi_base } /{ path_suffix } /{ sample_id } _{ read_num } .fastq.gz"
118- output_file = reads_dir / f"{ sample_id } _R{ read_num } .fastq.gz"
119-
120- run_curl (url , str (output_file ), f"Downloading { sample_id } _R{ read_num } " )
203+ # Download each item in the dataset
204+ for key , entry in dataset .items ():
205+ url = entry ["url" ]
206+ rel_path = entry ["path" ]
207+ dest_path = outdir_path / rel_path
208+
209+ # Ensure parent directory exists for this entry
210+ dest_path .parent .mkdir (parents = True , exist_ok = True )
211+
212+ # Choose downloader: use curl for ftp URLs; urllib for http(s)
213+ is_ftp = url .lower ().startswith ("ftp://" )
214+ console .print (f"\n [blue]Downloading { key } ...[/blue]" )
215+ if is_ftp :
216+ run_curl (url , str (dest_path ), f"Downloading { key } " )
217+ else :
218+ download_file (url , str (dest_path ), f"Downloading { key } " )
121219
122220 console .print ("\n [bold green]✓ Training dataset downloaded successfully![/bold green]" )
123221 console .print (f"[blue]Files saved to:[/blue] { outdir_path .absolute ()} " )