Skip to content

Commit c0aa78f

Browse files
committed
v0.4.1
1 parent e4c3c3d commit c0aa78f

File tree

1 file changed

+131
-33
lines changed

src/votuderep/commands/trainingdata.py

Lines changed: 131 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import subprocess
55
import urllib.request
66
from pathlib import Path
7+
from typing import Dict, Any
78

89
import rich_click as click
910
from rich.console import Console
@@ -12,12 +13,72 @@
1213
from ..utils.logging import get_logger
1314
from ..utils.validators import VotuDerepError
1415

16+
# ---------------------------------------------------------------------------
# Dataset registry
# ---------------------------------------------------------------------------
# Maps a dataset name to its downloadable items. Each item maps an arbitrary
# key to a record with:
#   "url"  -- the full remote URL to fetch, and
#   "path" -- the destination path relative to the chosen --outdir.
#
# To register a new dataset, add another top-level key. Paired-end reads are
# listed as explicit, separate R1/R2 entries -- no implicit pairing logic.
DATASETS: Dict[str, Dict[str, Dict[str, str]]] = {
    "virome": {
        # Assembled contigs for the training set.
        "assembly": {
            "url": "https://zenodo.org/api/records/10650983/files/illumina_sample_pool_megahit.fa.gz/content",
            "path": "human_gut_assembly.fa.gz",
        },
        # Sequencing reads (ENA FTP), one entry per mate file.
        "ERR6797443_R1": {
            "url": "ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR679/003/ERR6797443/ERR6797443_1.fastq.gz",
            "path": "reads/ERR6797443_R1.fastq.gz",
        },
        "ERR6797443_R2": {
            "url": "ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR679/003/ERR6797443/ERR6797443_2.fastq.gz",
            "path": "reads/ERR6797443_R2.fastq.gz",
        },
        "ERR6797444_R1": {
            "url": "ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR679/004/ERR6797444/ERR6797444_1.fastq.gz",
            "path": "reads/ERR6797444_R1.fastq.gz",
        },
        "ERR6797444_R2": {
            "url": "ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR679/004/ERR6797444/ERR6797444_2.fastq.gz",
            "path": "reads/ERR6797444_R2.fastq.gz",
        },
        "ERR6797445_R1": {
            "url": "ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR679/005/ERR6797445/ERR6797445_1.fastq.gz",
            "path": "reads/ERR6797445_R1.fastq.gz",
        },
        "ERR6797445_R2": {
            "url": "ftp://ftp.sra.ebi.ac.uk/vol1/fastq/ERR679/005/ERR6797445/ERR6797445_2.fastq.gz",
            "path": "reads/ERR6797445_R2.fastq.gz",
        },
    }
}
59+
1560
console = Console(stderr=True)
1661
logger = get_logger(__name__)
1762

1863

1964
def download_file(url: str, output_path: str, description: str = "Downloading"):
20-
"""Download a file from a URL with progress indication."""
65+
"""Download a file from a URL with progress indication (urllib)."""
66+
output_file = Path(output_path)
67+
temp_file = Path(str(output_path) + ".downloading")
68+
69+
# Check if final file already exists (complete download)
70+
if output_file.exists():
71+
logger.info(f"File already exists, skipping download: {output_path}")
72+
console.print(
73+
f"[yellow]✓ Skipping {os.path.basename(output_path)} (already exists)[/yellow]"
74+
)
75+
return
76+
77+
# Clean up any partial downloads
78+
if temp_file.exists():
79+
logger.info(f"Removing partial download: {temp_file}")
80+
temp_file.unlink()
81+
2182
try:
2283
with Progress(
2384
SpinnerColumn(),
@@ -32,17 +93,39 @@ def reporthook(block_num, block_size, total_size):
3293
percent = min(100, (block_num * block_size * 100) / total_size)
3394
progress.update(task, completed=percent)
3495

35-
urllib.request.urlretrieve(url, output_path, reporthook)
96+
# Download to temporary file
97+
urllib.request.urlretrieve(url, str(temp_file), reporthook)
3698
progress.update(task, completed=100)
3799

100+
# Rename to final filename on success
101+
temp_file.rename(output_file)
38102
logger.info(f"Downloaded: {output_path}")
39103

40104
except Exception as e:
105+
# Clean up partial download on failure
106+
if temp_file.exists():
107+
temp_file.unlink()
41108
raise VotuDerepError(f"Failed to download {url}: {e}")
42109

43110

44111
def run_curl(url: str, output_path: str, description: str = "Downloading"):
45112
"""Download a file using curl with progress indication."""
113+
output_file = Path(output_path)
114+
temp_file = Path(str(output_path) + ".downloading")
115+
116+
# Check if final file already exists (complete download)
117+
if output_file.exists():
118+
logger.info(f"File already exists, skipping download: {output_path}")
119+
console.print(
120+
f"[yellow]✓ Skipping {os.path.basename(output_path)} (already exists)[/yellow]"
121+
)
122+
return
123+
124+
# Clean up any partial downloads
125+
if temp_file.exists():
126+
logger.info(f"Removing partial download: {temp_file}")
127+
temp_file.unlink()
128+
46129
try:
47130
with Progress(
48131
SpinnerColumn(),
@@ -52,14 +135,20 @@ def run_curl(url: str, output_path: str, description: str = "Downloading"):
52135
) as progress:
53136
task = progress.add_task(f"{description} {os.path.basename(output_path)}")
54137

55-
cmd = ["curl", "-L", url, "-o", output_path]
138+
# Download to temporary file
139+
cmd = ["curl", "-L", url, "-o", str(temp_file)]
56140
subprocess.run(cmd, capture_output=True, text=True, check=True)
57141

58142
progress.update(task, completed=100)
59143

144+
# Rename to final filename on success
145+
temp_file.rename(output_file)
60146
logger.info(f"Downloaded: {output_path}")
61147

62148
except subprocess.CalledProcessError as e:
149+
# Clean up partial download on failure
150+
if temp_file.exists():
151+
temp_file.unlink()
63152
raise VotuDerepError(f"Failed to download {url}: {e.stderr}")
64153
except FileNotFoundError:
65154
raise VotuDerepError("curl command not found. Please install curl.")
@@ -73,51 +162,60 @@ def run_curl(url: str, output_path: str, description: str = "Downloading"):
73162
show_default=True,
74163
help="Where to put the output files",
75164
)
165+
@click.option(
166+
"-n",
167+
"--name",
168+
"dataset_name",
169+
default="virome",
170+
show_default=True,
171+
help="Dataset name to download (registered in DATASETS)",
172+
)
76173
@click.pass_context
77-
def trainingdata(ctx, outdir: str):
174+
def trainingdata(ctx, outdir: str, dataset_name: str):
78175
"""
79176
Download training dataset from the internet.
80177
81-
Downloads viral assembly and sequencing reads for training purposes.
178+
Uses a registry (DATASETS) of named datasets, each containing a set of
179+
{url, path} items. Adds new datasets by extending the DATASETS dict.
82180
"""
83181
verbose = ctx.obj.get("verbose", False)
84182

85183
if verbose:
86184
console.print(f"[blue]Output directory:[/blue] {outdir}")
185+
console.print(f"[blue]Dataset:[/blue] {dataset_name}")
87186

88-
# Create output directory structure
89-
outdir_path = Path(outdir)
90-
reads_dir = outdir_path / "reads"
187+
# Resolve dataset
188+
dataset: Dict[str, Any] = DATASETS.get(dataset_name)
189+
if not dataset:
190+
available = ", ".join(sorted(DATASETS.keys()))
191+
raise VotuDerepError(
192+
f"Unknown dataset '{dataset_name}'. Available datasets: {available}"
193+
)
91194

195+
# Create output directory
196+
outdir_path = Path(outdir)
92197
try:
93-
reads_dir.mkdir(parents=True, exist_ok=True)
94-
logger.info(f"Created directory structure: {reads_dir}")
198+
outdir_path.mkdir(parents=True, exist_ok=True)
199+
logger.info(f"Ensured base output directory exists: {outdir_path}")
95200

96201
console.print("[bold green]Downloading training dataset...[/bold green]")
97202

98-
# Download assembly
99-
assembly_url = "https://zenodo.org/api/records/10650983/files/illumina_sample_pool_megahit.fa.gz/content"
100-
assembly_path = outdir_path / "human_gut_assembly.fa.gz"
101-
102-
console.print("\n[blue]Downloading assembly...[/blue]")
103-
download_file(assembly_url, str(assembly_path), "Downloading assembly")
104-
105-
# Download reads
106-
console.print("\n[blue]Downloading sequencing reads...[/blue]")
107-
ebi_base = "ftp://ftp.sra.ebi.ac.uk/vol1/fastq"
108-
109-
reads_to_download = [
110-
("ERR6797445", "ERR679/005/ERR6797445"),
111-
("ERR6797444", "ERR679/004/ERR6797444"),
112-
("ERR6797443", "ERR679/003/ERR6797443"),
113-
]
114-
115-
for sample_id, path_suffix in reads_to_download:
116-
for read_num in ["1", "2"]:
117-
url = f"{ebi_base}/{path_suffix}/{sample_id}_{read_num}.fastq.gz"
118-
output_file = reads_dir / f"{sample_id}_R{read_num}.fastq.gz"
119-
120-
run_curl(url, str(output_file), f"Downloading {sample_id}_R{read_num}")
203+
# Download each item in the dataset
204+
for key, entry in dataset.items():
205+
url = entry["url"]
206+
rel_path = entry["path"]
207+
dest_path = outdir_path / rel_path
208+
209+
# Ensure parent directory exists for this entry
210+
dest_path.parent.mkdir(parents=True, exist_ok=True)
211+
212+
# Choose downloader: use curl for ftp URLs; urllib for http(s)
213+
is_ftp = url.lower().startswith("ftp://")
214+
console.print(f"\n[blue]Downloading {key}...[/blue]")
215+
if is_ftp:
216+
run_curl(url, str(dest_path), f"Downloading {key}")
217+
else:
218+
download_file(url, str(dest_path), f"Downloading {key}")
121219

122220
console.print("\n[bold green]✓ Training dataset downloaded successfully![/bold green]")
123221
console.print(f"[blue]Files saved to:[/blue] {outdir_path.absolute()}")

0 commit comments

Comments
 (0)