Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sc2ts/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(self, path, add_zero_base=True):
def __getitem__(self, key):
    """Return the record stored under ``key`` as an upper-cased string array.

    The record is fetched from ``self.reader``, converted to a NumPy array
    of strings and upper-cased. When ``self.add_zero_base`` is true, a
    leading ``"X"`` entry is prepended to the result.
    """
    bases = np.char.upper(np.array(self.reader[key]).astype(str))
    if not self.add_zero_base:
        return bases
    # NOTE(review): the "X" pad presumably makes array indices line up with
    # 1-based sequence positions — confirm against callers.
    return np.append(["X"], bases)
Expand Down
12 changes: 8 additions & 4 deletions sc2ts/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class Variant:

class Dataset(collections.abc.Mapping):

def __init__(self, path, chunk_cache_size=1, date_field="date"):
def __init__(self, path, chunk_cache_size=1, date_field="date", skip_metadata=False):
logger.info(f"Loading dateset @{path} using {date_field} as date field")
self.date_field = date_field
self.path = pathlib.Path(path)
Expand All @@ -196,9 +196,13 @@ def __init__(self, path, chunk_cache_size=1, date_field="date"):
self.haplotypes = CachedHaplotypeMapping(
self.root, self.sample_id_map, chunk_cache_size
)
self.metadata = CachedMetadataMapping(
self.root, self.sample_id_map, date_field, chunk_cache_size=chunk_cache_size
)
if not skip_metadata:
self.metadata = CachedMetadataMapping(
self.root,
self.sample_id_map,
date_field,
chunk_cache_size=chunk_cache_size,
)

def __getitem__(self, key):
    """Look up ``key`` in the underlying ``root`` store and return the entry."""
    entry = self.root[key]
    return entry
Expand Down
27 changes: 24 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ def fx_alignments_fasta(fx_data_cache):
return cache_path


@pytest.fixture
def fx_alignments_mafft_fasta(fx_data_cache):
    """Return the path of the decompressed MAFFT alignments FASTA file.

    On first use the gzipped test data is decompressed into the data cache
    directory; later calls reuse the cached copy.
    """
    cache_path = fx_data_cache / "alignments-mafft.fasta"
    if cache_path.exists():
        return cache_path
    # Decompress the checked-in test data into the cache.
    with gzip.open("tests/data/alignments-mafft.fasta.gz") as src, open(
        cache_path, "wb"
    ) as dest:
        shutil.copyfileobj(src, dest)
    return cache_path


def encoded_alignments(path):
fr = sc2ts.FastaReader(path)
alignments = {}
Expand All @@ -37,6 +48,11 @@ def encoded_alignments(path):
return alignments


@pytest.fixture
def fx_encoded_alignments_mafft(fx_alignments_mafft_fasta):
    """Return the result of ``encoded_alignments`` for the MAFFT FASTA file."""
    fasta_path = fx_alignments_mafft_fasta
    return encoded_alignments(fasta_path)


@pytest.fixture
def fx_encoded_alignments(fx_alignments_fasta):
    """Return the result of ``encoded_alignments`` for the alignments FASTA file."""
    fasta_path = fx_alignments_fasta
    return encoded_alignments(fasta_path)
Expand Down Expand Up @@ -258,7 +274,9 @@ def recombinant_example_2(tmp_path, fx_ts_map, fx_dataset, ds_path):
date = "2020-03-02"
left = start + 3 + 1
right = end - 3 + 1
ds = sc2ts.tmp_dataset(tmp_path / "tmp.zarr", {f"recombinant_{left}:{right}": a}, date=date)
ds = sc2ts.tmp_dataset(
tmp_path / "tmp.zarr", {f"recombinant_{left}:{right}": a}, date=date
)
rts = sc2ts.extend(
dataset=ds.path,
base_ts=ts_path,
Expand All @@ -267,6 +285,7 @@ def recombinant_example_2(tmp_path, fx_ts_map, fx_dataset, ds_path):
)
return rts


def recombinant_example_3(tmp_path, fx_ts_map, fx_dataset, ds_path):
# Pick a distinct strain to be the root of our three new haplotypes added
# on the first day.
Expand All @@ -286,7 +305,7 @@ def recombinant_example_3(tmp_path, fx_ts_map, fx_dataset, ds_path):
mid_a = a.copy()
mid_start = 15_000
mid_end = 15_009
mid_a[mid_start: mid_end] = 1 # "C"
mid_a[mid_start:mid_end] = 1 # "C"

a = mid_a.copy()
a[start : start + 3] = left_a[start : start + 3]
Expand Down Expand Up @@ -315,7 +334,7 @@ def recombinant_example_3(tmp_path, fx_ts_map, fx_dataset, ds_path):
mut = ts.mutation(mut_id)
assert mut.derived_state == "G"
assert ts.sites_position[mut.site] == start + j + 1

for j, mut_id in enumerate(np.where(ts.mutations_node == mid_node)[0]):
mut = ts.mutation(mut_id)
assert mut.derived_state == "C"
Expand Down Expand Up @@ -345,6 +364,7 @@ def recombinant_example_3(tmp_path, fx_ts_map, fx_dataset, ds_path):
assert rts.num_samples == ts.num_samples + 1
return rts


@pytest.fixture
def fx_recombinant_example_1(tmp_path, fx_data_cache, fx_ts_map, fx_dataset):
cache_path = fx_data_cache / "recombinant_ex1.ts"
Expand All @@ -366,6 +386,7 @@ def fx_recombinant_example_2(tmp_path, fx_data_cache, fx_ts_map, fx_dataset):
ts.dump(cache_path)
return tskit.load(cache_path)


@pytest.fixture
def fx_recombinant_example_3(tmp_path, fx_data_cache, fx_ts_map, fx_dataset):
cache_path = fx_data_cache / "recombinant_ex3.ts"
Expand Down
Binary file added tests/data/alignments-mafft.fasta.gz
Binary file not shown.
16 changes: 16 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,22 @@ def test_date_field(self, fx_dataset):
assert len(diffs) == 6


class TestMafftAlignments:
    """Round-trip checks for MAFFT-aligned sequences stored in a dataset."""

    def test_import(self, tmp_path, fx_encoded_alignments_mafft):
        """Append the MAFFT alignments to a new dataset and read them back."""
        dataset_path = tmp_path / "dataset.vcz"
        sc2ts.Dataset.new(dataset_path)
        sc2ts.Dataset.append_alignments(dataset_path, fx_encoded_alignments_mafft)
        # Open the dataset with skip_metadata=True; only haplotypes are checked.
        dataset = sc2ts.Dataset(dataset_path, skip_metadata=True)
        assert len(dataset.haplotypes) == 19
        for strain, expected in fx_encoded_alignments_mafft.items():
            stored = dataset.haplotypes[strain]
            nt.assert_array_equal(expected, stored)
            # The flanks are marked as deletions
            assert stored[0] == 4
            assert stored[-1] == 4


class TestDatasetAlignments:

def test_fetch_known(self, fx_dataset):
Expand Down
Loading