Skip to content

Commit e35d473

Browse files
Fix validation problems
1 parent f23fda1 commit e35d473

File tree

2 files changed: 83 additions and 42 deletions (+83 -42 lines changed)

bio2zarr/vcf.py

Lines changed: 54 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,7 @@ def smallest_dtype(self):
116116
elif self.vcf_type == "Flag":
117117
ret = "bool"
118118
elif self.vcf_type == "Character":
119-
ret = "S1"
119+
ret = "U1"
120120
else:
121121
assert self.vcf_type == "String"
122122
ret = "O"
@@ -1393,7 +1393,7 @@ def assert_all_fill(zarr_val, vcf_type):
13931393
assert_all_fill_string(zarr_val)
13941394
elif vcf_type == "Float":
13951395
assert_all_fill_float(zarr_val)
1396-
else:
1396+
else: # pragma: no cover
13971397
assert False
13981398

13991399

@@ -1406,7 +1406,7 @@ def assert_all_missing(zarr_val, vcf_type):
14061406
assert zarr_val == False # noqa 712
14071407
elif vcf_type == "Float":
14081408
assert_all_missing_float(zarr_val)
1409-
else:
1409+
else: # pragma: no cover
14101410
assert False
14111411

14121412

@@ -1425,19 +1425,20 @@ def assert_format_val_missing(zarr_val, vcf_type):
14251425

14261426
def assert_info_val_equal(vcf_val, zarr_val, vcf_type):
14271427
assert vcf_val is not None
1428-
if not isinstance(vcf_val, tuple):
1429-
# Scalar
1430-
zarr_val = np.array(zarr_val, ndmin=1)
1431-
assert len(zarr_val.shape) == 1
1432-
assert vcf_val == zarr_val[0]
1433-
if len(zarr_val) > 1:
1434-
assert_all_fill(zarr_val[1:], vcf_type)
1435-
else:
1428+
if vcf_type in ("String", "Character"):
1429+
split = list(vcf_val.split(","))
1430+
k = len(split)
1431+
if k == 1:
1432+
# Scalar
1433+
assert vcf_val == zarr_val
1434+
else:
1435+
nt.assert_equal(split, zarr_val[:k])
1436+
assert_all_fill(zarr_val[k:], vcf_type)
1437+
1438+
elif isinstance(vcf_val, tuple):
14361439
vcf_missing_value_map = {
14371440
"Integer": -1,
14381441
"Float": FLOAT32_MISSING,
1439-
"String": ".",
1440-
"Character": ".",
14411442
}
14421443
v = [vcf_missing_value_map[vcf_type] if x is None else x for x in vcf_val]
14431444
missing = np.array([j for j, x in enumerate(vcf_val) if x is None], dtype=int)
@@ -1449,31 +1450,50 @@ def assert_info_val_equal(vcf_val, zarr_val, vcf_type):
14491450
assert_all_missing(zarr_val[missing], vcf_type)
14501451
if k < len(zarr_val):
14511452
assert_all_fill(zarr_val[k:], vcf_type)
1453+
else:
1454+
# Scalar
1455+
zarr_val = np.array(zarr_val, ndmin=1)
1456+
assert len(zarr_val.shape) == 1
1457+
assert vcf_val == zarr_val[0]
1458+
if len(zarr_val) > 1:
1459+
assert_all_fill(zarr_val[1:], vcf_type)
14521460

14531461

14541462
def assert_format_val_equal(vcf_val, zarr_val, vcf_type):
14551463
assert vcf_val is not None
14561464
assert isinstance(vcf_val, np.ndarray)
1457-
1458-
assert vcf_val.shape[0] == zarr_val.shape[0]
1459-
if len(vcf_val.shape) == len(zarr_val.shape) + 1:
1460-
assert vcf_val.shape[-1] == 1
1461-
vcf_val = vcf_val[..., 0]
1462-
assert len(vcf_val.shape) <= 2
1463-
assert len(vcf_val.shape) == len(zarr_val.shape)
1464-
if len(vcf_val.shape) == 2:
1465-
k = vcf_val.shape[1]
1466-
if zarr_val.shape[1] != k:
1467-
assert_all_fill(zarr_val[:, k:], vcf_type)
1468-
zarr_val = zarr_val[:, :k]
1469-
assert vcf_val.shape == zarr_val.shape
1470-
if vcf_type == "Integer":
1471-
vcf_val[vcf_val == VCF_INT_MISSING] = INT_MISSING
1472-
vcf_val[vcf_val == VCF_INT_FILL] = INT_FILL
1473-
elif vcf_type == "Float":
1474-
nt.assert_equal(vcf_val.view(np.int32), zarr_val.view(np.int32))
1475-
1476-
nt.assert_equal(vcf_val, zarr_val)
1465+
if vcf_type in ("String", "Character"):
1466+
assert len(vcf_val) == len(zarr_val)
1467+
for v, z in zip(vcf_val, zarr_val):
1468+
split = list(v.split(","))
1469+
# Note: deliberately duplicating logic here between this and the
1470+
# INFO col above to make sure all combinations are covered by tests
1471+
k = len(split)
1472+
if k == 1:
1473+
assert v == z
1474+
else:
1475+
nt.assert_equal(split, z[:k])
1476+
assert_all_fill(z[k:], vcf_type)
1477+
else:
1478+
assert vcf_val.shape[0] == zarr_val.shape[0]
1479+
if len(vcf_val.shape) == len(zarr_val.shape) + 1:
1480+
assert vcf_val.shape[-1] == 1
1481+
vcf_val = vcf_val[..., 0]
1482+
assert len(vcf_val.shape) <= 2
1483+
assert len(vcf_val.shape) == len(zarr_val.shape)
1484+
if len(vcf_val.shape) == 2:
1485+
k = vcf_val.shape[1]
1486+
if zarr_val.shape[1] != k:
1487+
assert_all_fill(zarr_val[:, k:], vcf_type)
1488+
zarr_val = zarr_val[:, :k]
1489+
assert vcf_val.shape == zarr_val.shape
1490+
if vcf_type == "Integer":
1491+
vcf_val[vcf_val == VCF_INT_MISSING] = INT_MISSING
1492+
vcf_val[vcf_val == VCF_INT_FILL] = INT_FILL
1493+
elif vcf_type == "Float":
1494+
nt.assert_equal(vcf_val.view(np.int32), zarr_val.view(np.int32))
1495+
1496+
nt.assert_equal(vcf_val, zarr_val)
14771497

14781498

14791499
def validate(vcf_path, zarr_path, show_progress=False):
@@ -1541,12 +1561,8 @@ def validate(vcf_path, zarr_path, show_progress=False):
15411561
gt = row.genotype.array()
15421562
gt_zarr = next(call_genotype)
15431563
gt_vcf = gt[:, :-1]
1544-
# NOTE weirdly cyvcf2 seems to remap genotypes automatically
1564+
# NOTE cyvcf2 remaps genotypes automatically
15451565
# into the same missing/pad encoding that sgkit uses.
1546-
# if np.any(gt_zarr < 0):
1547-
# print("MISSING")
1548-
# print(gt_zarr)
1549-
# print(gt_vcf)
15501566
nt.assert_array_equal(gt_zarr, gt_vcf)
15511567

15521568
for name, (vcf_type, zarr_iter) in info_fields.items():

tests/test_vcf_examples.py

Lines changed: 29 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -654,13 +654,38 @@ def ds(self, tmp_path_factory):
654654
return sg.load_dataset(out)
655655

656656
def test_info_string1(self, ds):
657-
print(repr(ds["variant_IS1"].values))
657+
values = ds["variant_IS1"].values
658+
non_missing = values[values != "."]
659+
nt.assert_array_equal(non_missing, ["bc"])
660+
661+
def test_info_char1(self, ds):
662+
values = ds["variant_IC1"].values
663+
non_missing = values[values != "."]
664+
nt.assert_array_equal(non_missing, "f")
658665

659666
def test_info_string2(self, ds):
660-
print(repr(ds["variant_IS2"].values))
667+
values = ds["variant_IS2"].values
668+
missing = np.all(values == ".", axis=1)
669+
non_missing_rows = values[~missing]
670+
nt.assert_array_equal(
671+
non_missing_rows, [["hij", "d"], [".", "d"], ["hij", "."]]
672+
)
661673

662-
def test_format_string2(self, ds):
663-
print(repr(ds["call_FS2"].values))
674+
# FIXME can't figure out how to do the row masking properly here
675+
# def test_format_string1(self, ds):
676+
# values = ds["call_FS1"].values
677+
# missing = np.all(values == ".", axis=1)
678+
# non_missing_rows = values[~missing]
679+
# print(non_missing_rows)
680+
# # nt.assert_array_equal(non_missing_rows, [["bc"], ["."]])
681+
682+
# def test_format_string2(self, ds):
683+
# values = ds["call_FS2"].values
684+
# missing = np.all(values == ".", axis=1)
685+
# non_missing_rows = values[~missing]
686+
# non_missing = [v for v in pcvcf["FORMAT/FS2"].values if v is not None]
687+
# nt.assert_array_equal(non_missing[0], [["bc", "op"], [".", "op"]])
688+
# nt.assert_array_equal(non_missing[1], [["bc", "."], [".", "."]])
664689

665690

666691
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)