Skip to content

Commit a505f5a

Browse files
Merge pull request #468 from jeromekelleher/finalise-dataset-import
Remove requirement for date_field
2 parents 6a18655 + 3c27b5f commit a505f5a

File tree

6 files changed: +64 additions, −61 deletions

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ jobs:
8080
8181
- name: Validate
8282
run: |
83-
sc2ts validate -v testrun/dataset.zarr testrun/results/test/test_2020-02-02.ts
83+
sc2ts validate -v --date-field=date testrun/dataset.zarr testrun/results/test/test_2020-02-02.ts
8484
8585
- name: Info
8686
run: |

sc2ts/cli.py

Lines changed: 20 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -161,14 +161,20 @@ def import_alignments(dataset, fastas, initialise, progress, verbose):
161161
@click.command()
162162
@click.argument("dataset", type=click.Path(dir_okay=True, file_okay=False))
163163
@click.argument("metadata", type=click.Path(dir_okay=False, file_okay=True))
164+
@click.option(
165+
"--field-descriptions",
166+
type=click.File(mode="r"),
167+
default=None,
168+
help="JSON formatted file of field descriptions",
169+
)
164170
@click.option(
165171
"--viridian",
166172
is_flag=True,
167173
help="Do some preprocessing appropriate for the Viridian metadata "
168174
"(Available at https://figshare.com/ndownloader/files/49694808)",
169175
)
170176
@verbose
171-
def import_metadata(dataset, metadata, viridian, verbose):
177+
def import_metadata(dataset, metadata, field_descriptions, viridian, verbose):
172178
"""
173179
Import a CSV/TSV metadata file into the dataset.
174180
"""
@@ -178,48 +184,14 @@ def import_metadata(dataset, metadata, viridian, verbose):
178184
if viridian:
179185
dtype = {"Artic_primer_version": str}
180186
df_in = pd.read_csv(metadata, sep="\t", dtype=dtype)
181-
date_field = "date"
182187
index_field = "Run"
183188
if viridian:
184189
df_in = sc2ts.massage_viridian_metadata(df_in)
185190
df = df_in.set_index(index_field)
186-
sc2ts.Dataset.add_metadata(dataset, df)
187-
188-
189-
@click.command()
190-
@click.argument("in_dataset", type=click.Path(dir_okay=True, file_okay=False))
191-
@click.argument("out_dataset", type=click.Path(dir_okay=True, file_okay=False))
192-
@click.option(
193-
"--date-field", default="date", help="The metadata field to use for dates"
194-
)
195-
@click.option(
196-
"-a",
197-
"--additional-field",
198-
default=[],
199-
help="Additional fields to sort by",
200-
multiple=True,
201-
)
202-
@chunk_cache_size
203-
@progress
204-
@verbose
205-
def reorder_dataset(
206-
in_dataset,
207-
out_dataset,
208-
chunk_cache_size,
209-
date_field,
210-
additional_field,
211-
progress,
212-
verbose,
213-
):
214-
"""
215-
Create a copy of the specified dataset where the samples are reordered by
216-
date (and optionally other fields).
217-
"""
218-
setup_logging(verbose)
219-
ds = sc2ts.Dataset(
220-
in_dataset, chunk_cache_size=chunk_cache_size, date_field=date_field
221-
)
222-
ds.reorder(out_dataset, show_progress=progress, additional_fields=additional_field)
191+
d = {}
192+
if field_descriptions is not None:
193+
d = json.load(field_descriptions)
194+
sc2ts.Dataset.add_metadata(dataset, df, field_descriptions=d)
223195

224196

225197
@click.command()
@@ -415,6 +387,11 @@ def infer(config_file, start, stop, force):
415387
@dataset
416388
@click.argument("ts_file")
417389
@deletions_as_missing
390+
@click.option(
391+
"--date-field",
392+
default=None,
393+
help="Specify date field to use. Required for metadata.",
394+
)
418395
@click.option(
419396
"--genotypes/--no-genotypes",
420397
default=True,
@@ -440,6 +417,7 @@ def infer(config_file, start, stop, force):
440417
def validate(
441418
dataset,
442419
ts_file,
420+
date_field,
443421
deletions_as_missing,
444422
genotypes,
445423
metadata,
@@ -453,7 +431,9 @@ def validate(
453431
setup_logging(verbose)
454432

455433
ts = tszip.load(ts_file)
456-
ds = sc2ts.Dataset(dataset, chunk_cache_size=chunk_cache_size)
434+
ds = sc2ts.Dataset(
435+
dataset, date_field=date_field, chunk_cache_size=chunk_cache_size
436+
)
457437
if genotypes:
458438
sc2ts.validate_genotypes(ts, ds, deletions_as_missing, show_progress=True)
459439
if metadata:
@@ -564,7 +544,6 @@ def cli():
564544

565545
cli.add_command(import_alignments)
566546
cli.add_command(import_metadata)
567-
cli.add_command(reorder_dataset)
568547

569548
cli.add_command(info_dataset)
570549
cli.add_command(info_matches)

sc2ts/dataset.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,6 @@ def __len__(self):
9999
class CachedMetadataMapping(collections.abc.Mapping):
100100
def __init__(self, root, sample_id_map, date_field, chunk_cache_size):
101101
self.sample_id_map = sample_id_map
102-
self.sample_date = root[f"sample_{date_field}"][:].astype(str)
103-
self.sample_date_array = root[f"sample_{date_field}"]
104102
self.sample_id = root["sample_id"][:].astype(str)
105103
self.sample_id_array = root["sample_id"]
106104
# Mapping of field name to Zarr array
@@ -114,6 +112,10 @@ def __init__(self, root, sample_id_map, date_field, chunk_cache_size):
114112
self.chunk_cache = {}
115113

116114
logger.debug(f"Got {self.num_fields} metadata fields")
115+
self.date_field = date_field
116+
if date_field is not None:
117+
self.sample_date = root[f"sample_{date_field}"][:].astype(str)
118+
self.sample_date_array = root[f"sample_{date_field}"]
117119

118120
@property
119121
def num_fields(self):
@@ -145,6 +147,8 @@ def get_metadata(self, j):
145147
d[key] = bool(d[key])
146148
else:
147149
d[key] = str(d[key])
150+
if self.date_field is None:
151+
raise ValueError("No date field set, cannot get metadata items")
148152
# For compatibility in the short term:
149153
d["date"] = self.sample_date[j]
150154
d["strain"] = self.sample_id[j]
@@ -178,7 +182,7 @@ class Variant:
178182

179183
class Dataset(collections.abc.Mapping):
180184

181-
def __init__(self, path, chunk_cache_size=1, date_field="date", skip_metadata=False):
185+
def __init__(self, path, chunk_cache_size=1, date_field=None):
182186
logger.info(f"Loading dateset @{path} using {date_field} as date field")
183187
self.date_field = date_field
184188
self.path = pathlib.Path(path)
@@ -196,13 +200,12 @@ def __init__(self, path, chunk_cache_size=1, date_field="date", skip_metadata=Fa
196200
self.haplotypes = CachedHaplotypeMapping(
197201
self.root, self.sample_id_map, chunk_cache_size
198202
)
199-
if not skip_metadata:
200-
self.metadata = CachedMetadataMapping(
201-
self.root,
202-
self.sample_id_map,
203-
date_field,
204-
chunk_cache_size=chunk_cache_size,
205-
)
203+
self.metadata = CachedMetadataMapping(
204+
self.root,
205+
self.sample_id_map,
206+
date_field,
207+
chunk_cache_size=chunk_cache_size,
208+
)
206209

207210
def __getitem__(self, key):
208211
return self.root[key]
@@ -432,7 +435,7 @@ def append_alignments(path, alignments):
432435
zarr.consolidate_metadata(store)
433436

434437
@staticmethod
435-
def add_metadata(path, df):
438+
def add_metadata(path, df, field_descriptions=dict()):
436439
"""
437440
Add metadata from the specified dataframe, indexed by sample ID.
438441
Each column will be added as a new array with prefix "sample_"
@@ -467,6 +470,8 @@ def add_metadata(path, df):
467470
overwrite=True,
468471
)
469472
z.attrs["_ARRAY_DIMENSIONS"] = ["samples"]
473+
z.attrs["description"] = field_descriptions.get(colname, "")
474+
470475
z[:] = data
471476
logger.info(f"Wrote metadata array {z.name}")
472477

tests/conftest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def fx_alignments_fasta(fx_data_cache):
3131

3232
@pytest.fixture
3333
def fx_alignments_mafft_fasta(fx_data_cache):
34-
# This is bgzipped so we can access directly
3534
cache_path = fx_data_cache / "alignments-mafft.fasta"
3635
if not cache_path.exists():
3736
with gzip.open("tests/data/alignments-mafft.fasta.gz") as src:
@@ -105,7 +104,7 @@ def fx_dataset(tmp_path, fx_data_cache, fx_alignments_fasta, fx_metadata_df):
105104
)
106105
sc2ts.Dataset.add_metadata(fs_path, fx_metadata_df)
107106
sc2ts.Dataset.create_zip(fs_path, cache_path)
108-
return sc2ts.Dataset(cache_path)
107+
return sc2ts.Dataset(cache_path, date_field="date")
109108

110109

111110
@pytest.fixture

tests/test_cli.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,17 @@ def test_suite_data(self, tmp_path, fx_metadata_tsv, fx_alignments_fasta):
5252
catch_exceptions=False,
5353
)
5454
assert result.exit_code == 0
55+
fields_path = tmp_path / "fields.json"
56+
with open(fields_path, "w") as f:
57+
f.write(json.dumps({"NO SUCH": "A", "Viridian_pangolin": "PANGO"}))
5558

5659
result = runner.invoke(
5760
cli.cli,
58-
f"import-metadata {ds_path} {fx_metadata_tsv} ",
61+
f"import-metadata {ds_path} {fx_metadata_tsv} --field-descriptions={fields_path}",
5962
catch_exceptions=False,
6063
)
64+
ds = sc2ts.Dataset(ds_path)
65+
assert ds.metadata.fields["Viridian_pangolin"].attrs["description"] == "PANGO"
6166

6267
def test_viridian_metadata(
6368
self, tmp_path, fx_raw_viridian_metadata_tsv, fx_alignments_fasta
@@ -379,7 +384,7 @@ def test_date(self, tmp_path, fx_ts_map, fx_dataset, date):
379384
runner = ct.CliRunner(mix_stderr=False)
380385
result = runner.invoke(
381386
cli.cli,
382-
f"validate {fx_dataset.path} {ts_path} ",
387+
f"validate {fx_dataset.path} {ts_path} --date-field=date",
383388
catch_exceptions=False,
384389
)
385390
assert result.exit_code == 0

tests/test_dataset.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,10 @@ def test_add_metadata(self, tmp_path, fx_encoded_alignments, fx_metadata_df):
135135
path = tmp_path / "dataset.vcz"
136136
ds = sc2ts.Dataset.new(path)
137137
sc2ts.Dataset.append_alignments(path, fx_encoded_alignments)
138-
sc2ts.Dataset.add_metadata(path, fx_metadata_df)
138+
field_descriptions = {col: col.upper() for col in fx_metadata_df}
139+
sc2ts.Dataset.add_metadata(
140+
path, fx_metadata_df, field_descriptions=field_descriptions
141+
)
139142

140143
sg_ds = sgkit.load_dataset(path)
141144
assert dict(sg_ds.sizes) == {
@@ -147,7 +150,9 @@ def test_add_metadata(self, tmp_path, fx_encoded_alignments, fx_metadata_df):
147150
}
148151
df = fx_metadata_df.loc[sg_ds["sample_id"].values]
149152
for col in fx_metadata_df:
150-
nt.assert_array_equal(df[col], sg_ds[f"sample_{col}"])
153+
x = sg_ds[f"sample_{col}"]
154+
nt.assert_array_equal(df[col], x)
155+
assert x.attrs["description"] == field_descriptions[col]
151156

152157
def test_create_zip(self, tmp_path, fx_encoded_alignments, fx_metadata_df):
153158

@@ -283,7 +288,7 @@ def test_import(self, tmp_path, fx_encoded_alignments_mafft):
283288
path = tmp_path / "dataset.vcz"
284289
sc2ts.Dataset.new(path)
285290
sc2ts.Dataset.append_alignments(path, fx_encoded_alignments_mafft)
286-
ds = sc2ts.Dataset(path, skip_metadata=True)
291+
ds = sc2ts.Dataset(path)
287292
assert len(ds.haplotypes) == 19
288293
for k, v in fx_encoded_alignments_mafft.items():
289294
h = ds.haplotypes[k]
@@ -363,6 +368,12 @@ def test_known(self, fx_dataset):
363368
assert d["Genbank_N"] == -1
364369
assert d["Viridian_pangolin"] == "A"
365370

371+
def test_known_no_date_field(self, fx_dataset):
372+
ds = sc2ts.Dataset(fx_dataset.path)
373+
374+
with pytest.raises(ValueError, match="No date field set"):
375+
ds.metadata["SRR11772659"]
376+
366377
@pytest.mark.parametrize(
367378
["chunk_size", "cache_size"],
368379
[
@@ -382,7 +393,7 @@ def test_chunk_size_cache_size(
382393
sc2ts.Dataset.new(path, samples_chunk_size=chunk_size)
383394
sc2ts.Dataset.append_alignments(path, fx_encoded_alignments)
384395
sc2ts.Dataset.add_metadata(path, fx_metadata_df)
385-
ds = sc2ts.Dataset(path, chunk_cache_size=cache_size)
396+
ds = sc2ts.Dataset(path, chunk_cache_size=cache_size, date_field="date")
386397
for strain in fx_encoded_alignments.keys():
387398
row = fx_metadata_df.loc[strain]
388399
d1 = ds.metadata[strain]
@@ -406,6 +417,10 @@ def test_as_dataframe(self, fx_dataset, fx_metadata_df):
406417
data2 = df2[col]
407418
nt.assert_array_equal(data1.to_numpy(), data2.to_numpy())
408419

420+
def test_metadata_field_descriptions(self, fx_dataset):
421+
for array in fx_dataset.metadata.fields.values():
422+
assert array.attrs["description"] == ""
423+
409424

410425
class TestEncodeAlignment:
411426
@pytest.mark.parametrize(

0 commit comments

Comments (0)