Skip to content

Commit aff896e

Browse files
Improve validate and fix some bugs
1 parent b3026b5 commit aff896e

File tree

2 files changed

+142
-130
lines changed

2 files changed

+142
-130
lines changed

bio2zarr/vcf.py

Lines changed: 140 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -37,69 +37,6 @@
3737
)
3838

3939

40-
def assert_all_missing_float(a):
41-
v = np.array(a, dtype=np.float32).view(np.int32)
42-
assert np.all(v == FLOAT32_MISSING_AS_INT32)
43-
44-
45-
def assert_prefix_integer_equal_1d(vcf_val, zarr_val):
46-
v = np.array(vcf_val, dtype=np.int32, ndmin=1)
47-
z = np.array(zarr_val, dtype=np.int32, ndmin=1)
48-
v[v == VCF_INT_MISSING] = -1
49-
v[v == VCF_INT_FILL] = -2
50-
k = v.shape[0]
51-
assert np.all(z[k:] == -2)
52-
nt.assert_array_equal(v, z[:k])
53-
54-
55-
def assert_prefix_integer_equal_2d(vcf_val, zarr_val):
56-
assert len(vcf_val.shape) == 2
57-
vcf_val[vcf_val == VCF_INT_MISSING] = -1
58-
vcf_val[vcf_val == VCF_INT_FILL] = -2
59-
if vcf_val.shape[1] == 1:
60-
nt.assert_array_equal(vcf_val[:, 0], zarr_val)
61-
else:
62-
k = vcf_val.shape[1]
63-
nt.assert_array_equal(vcf_val, zarr_val[:, :k])
64-
assert np.all(zarr_val[:, k:] == -2)
65-
66-
67-
# FIXME these are sort-of working. It's not clear that we're
68-
# handling the different dimensions and padded etc correctly.
69-
# Will need to hand-craft from examples to test
70-
def assert_prefix_float_equal_1d(vcf_val, zarr_val):
71-
v = np.array(vcf_val, dtype=np.float32, ndmin=1)
72-
# vi = v.view(np.int32)
73-
z = np.array(zarr_val, dtype=np.float32, ndmin=1)
74-
zi = z.view(np.int32)
75-
assert np.sum(zi == FLOAT32_MISSING_AS_INT32) == 0
76-
k = v.shape[0]
77-
assert np.all(zi[k:] == FLOAT32_FILL_AS_INT32)
78-
# assert np.where(zi[:k] == FLOAT32_FILL_AS_INT32)
79-
nt.assert_array_almost_equal(v, z[:k])
80-
# nt.assert_array_equal(v, z[:k])
81-
82-
83-
def assert_prefix_float_equal_2d(vcf_val, zarr_val):
84-
assert len(vcf_val.shape) == 2
85-
if vcf_val.shape[1] == 1:
86-
vcf_val = vcf_val[:, 0]
87-
v = np.array(vcf_val, dtype=np.float32, ndmin=2)
88-
vi = v.view(np.int32)
89-
z = np.array(zarr_val, dtype=np.float32, ndmin=2)
90-
zi = z.view(np.int32)
91-
assert np.all((zi == FLOAT32_MISSING_AS_INT32) == (vi == FLOAT32_MISSING_AS_INT32))
92-
assert np.all((zi == FLOAT32_FILL_AS_INT32) == (vi == FLOAT32_FILL_AS_INT32))
93-
# print(vcf_val, zarr_val)
94-
# assert np.sum(zi == FLOAT32_MISSING_AS_INT32) == 0
95-
k = v.shape[0]
96-
# print("k", k)
97-
assert np.all(zi[k:] == FLOAT32_FILL_AS_INT32)
98-
# assert np.where(zi[:k] == FLOAT32_FILL_AS_INT32)
99-
nt.assert_array_almost_equal(v, z[:k])
100-
# nt.assert_array_equal(v, z[:k])
101-
102-
10340
@dataclasses.dataclass
10441
class VcfFieldSummary:
10542
num_chunks: int = 0
@@ -177,13 +114,8 @@ def smallest_dtype(self):
177114
elif self.vcf_type == "Flag":
178115
ret = "bool"
179116
else:
180-
assert self.vcf_type == "String"
117+
assert self.vcf_type in ("String", "Character")
181118
ret = "str"
182-
# if s.max_number == 0:
183-
# ret = "str"
184-
# else:
185-
# ret = "O"
186-
# print("smallest dtype", self.name, self.vcf_type,":", ret)
187119
return ret
188120

189121

@@ -363,23 +295,27 @@ def sanitise_value_float_1d(buff, j, value):
363295
buff[j] = FLOAT32_MISSING
364296
else:
365297
value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False)
298+
# numpy will map None values to Nan, but we need a
299+
# specific NaN
300+
value[np.isnan(value)] = FLOAT32_MISSING
366301
value = drop_empty_second_dim(value)
367302
buff[j] = FLOAT32_FILL
368-
# TODO check for missing?
369303
buff[j, : value.shape[0]] = value
370304

371305

372306
def sanitise_value_float_2d(buff, j, value):
373307
if value is None:
374308
buff[j] = FLOAT32_MISSING
375309
else:
376-
value = np.array(value, dtype=buff.dtype, copy=False)
310+
# print("value = ", value)
311+
value = np.array(value, ndmin=2, dtype=buff.dtype, copy=False)
377312
buff[j] = FLOAT32_FILL
378-
# TODO check for missing?
379-
buff[j, :, : value.shape[0]] = value
313+
buff[j, :, : value.shape[1]] = value
380314

381315

382316
def sanitise_int_array(value, ndmin, dtype):
317+
if isinstance(value, tuple):
318+
value = [VCF_INT_MISSING if x is None else x for x in value]
383319
value = np.array(value, ndmin=ndmin, copy=False)
384320
value[value == VCF_INT_MISSING] = -1
385321
value[value == VCF_INT_FILL] = -2
@@ -497,7 +433,7 @@ def sanitiser_factory(self, shape):
497433
else:
498434
return sanitise_value_int_2d
499435
else:
500-
assert self.vcf_field.vcf_type == "String"
436+
assert self.vcf_field.vcf_type in ("String", "Character")
501437
if len(shape) == 1:
502438
return sanitise_value_string_scalar
503439
elif len(shape) == 2:
@@ -527,6 +463,8 @@ def update_bounds_float(summary, value, number_dim):
527463

528464
def update_bounds_integer(summary, value, number_dim):
529465
# print("update bounds int", summary, value)
466+
if isinstance(value, tuple):
467+
value = [VCF_INT_MISSING if x is None else x for x in value]
530468
value = np.array(value, dtype=np.int32, copy=False)
531469
# Mask out missing and fill values
532470
a = value[value >= MIN_INT_VALUE]
@@ -579,7 +517,6 @@ def __init__(self, column, partition_index, executor, futures, chunk_size=1):
579517
def _update_bounds(self, value):
580518
if value is not None:
581519
summary = self.column.vcf_field.summary
582-
# print("update", self.column.vcf_field.full_name, value)
583520
if self._summary_bounds_update is not None:
584521
self._summary_bounds_update(summary, value)
585522

@@ -1314,6 +1251,123 @@ def convert_vcf(
13141251
)
13151252

13161253

1254+
def assert_all_missing_float(a):
1255+
v = np.array(a, dtype=np.float32).view(np.int32)
1256+
nt.assert_equal(v, FLOAT32_MISSING_AS_INT32)
1257+
1258+
1259+
def assert_all_fill_float(a):
1260+
v = np.array(a, dtype=np.float32).view(np.int32)
1261+
nt.assert_equal(v, FLOAT32_FILL_AS_INT32)
1262+
1263+
1264+
def assert_all_missing_int(a):
1265+
v = np.array(a, dtype=int)
1266+
nt.assert_equal(v, -1)
1267+
1268+
1269+
def assert_all_fill_int(a):
1270+
v = np.array(a, dtype=int)
1271+
nt.assert_equal(v, -2)
1272+
1273+
1274+
def assert_all_missing_string(a):
1275+
nt.assert_equal(a, ".")
1276+
1277+
1278+
def assert_all_fill_string(a):
1279+
nt.assert_equal(a, "")
1280+
1281+
1282+
def assert_all_fill(zarr_val, vcf_type):
1283+
if vcf_type == "Integer":
1284+
assert_all_fill_int(zarr_val)
1285+
elif vcf_type in ("String", "Character"):
1286+
assert_all_fill_string(zarr_val)
1287+
elif vcf_type == "Float":
1288+
assert_all_fill_float(zarr_val)
1289+
else:
1290+
assert False
1291+
1292+
1293+
def assert_all_missing(zarr_val, vcf_type):
1294+
if vcf_type == "Integer":
1295+
assert_all_missing_int(zarr_val)
1296+
elif vcf_type in ("String", "Character"):
1297+
assert_all_missing_string(zarr_val)
1298+
elif vcf_type == "Flag":
1299+
assert zarr_val == False # noqa 712
1300+
elif vcf_type == "Float":
1301+
assert_all_missing_float(zarr_val)
1302+
else:
1303+
assert False
1304+
1305+
1306+
def assert_info_val_missing(zarr_val, vcf_type):
1307+
assert_all_missing(zarr_val, vcf_type)
1308+
1309+
1310+
def assert_format_val_missing(zarr_val, vcf_type):
1311+
assert_info_val_missing(zarr_val, vcf_type)
1312+
1313+
1314+
# Note: checking exact equality may prove problematic here
1315+
# but we should be deterministically storing what cyvcf2
1316+
# provides, which should compare equal.
1317+
1318+
1319+
def assert_info_val_equal(vcf_val, zarr_val, vcf_type):
1320+
assert vcf_val is not None
1321+
if not isinstance(vcf_val, tuple):
1322+
# Scalar
1323+
zarr_val = np.array(zarr_val, ndmin=1)
1324+
assert len(zarr_val.shape) == 1
1325+
assert vcf_val == zarr_val[0]
1326+
if len(zarr_val) > 1:
1327+
assert_all_fill(zarr_val[1:], vcf_type)
1328+
else:
1329+
vcf_missing_value_map = {
1330+
"Integer": -1,
1331+
"Float": FLOAT32_MISSING,
1332+
"String": ".",
1333+
"Character": ".",
1334+
}
1335+
v = [vcf_missing_value_map[vcf_type] if x is None else x for x in vcf_val]
1336+
missing = np.array([j for j, x in enumerate(vcf_val) if x is None], dtype=int)
1337+
a = np.array(v)
1338+
k = len(a)
1339+
# We are checking for int missing twice here, but it's necessary to have
1340+
# a separate check for floats because different NaNs compare equal
1341+
nt.assert_equal(a, zarr_val[:k])
1342+
assert_all_missing(zarr_val[missing], vcf_type)
1343+
if k < len(zarr_val):
1344+
assert_all_fill(zarr_val[k:], vcf_type)
1345+
1346+
1347+
def assert_format_val_equal(vcf_val, zarr_val, vcf_type):
1348+
assert vcf_val is not None
1349+
assert isinstance(vcf_val, np.ndarray)
1350+
1351+
assert vcf_val.shape[0] == zarr_val.shape[0]
1352+
if len(vcf_val.shape) == len(zarr_val.shape) + 1:
1353+
assert vcf_val.shape[-1] == 1
1354+
vcf_val = vcf_val[..., 0]
1355+
assert len(vcf_val.shape) <= 2
1356+
assert len(vcf_val.shape) == len(zarr_val.shape)
1357+
if len(vcf_val.shape) == 2:
1358+
k = vcf_val.shape[1]
1359+
if zarr_val.shape[1] != k:
1360+
assert_all_fill(zarr_val[:, k:], vcf_type)
1361+
zarr_val = zarr_val[:, :k]
1362+
assert vcf_val.shape == zarr_val.shape
1363+
if vcf_type == "Integer":
1364+
vcf_val[vcf_val == VCF_INT_MISSING] = INT_MISSING
1365+
elif vcf_type == "Float":
1366+
nt.assert_equal(vcf_val.view(np.int32), zarr_val.view(np.int32))
1367+
1368+
nt.assert_equal(vcf_val, zarr_val)
1369+
1370+
13171371
def validate(vcf_path, zarr_path, show_progress=False):
13181372
store = zarr.DirectoryStore(zarr_path)
13191373

@@ -1369,7 +1423,12 @@ def validate(vcf_path, zarr_path, show_progress=False):
13691423
# TODO FILTERS
13701424

13711425
if call_genotype is None:
1372-
assert row.genotype is None
1426+
val = None
1427+
try:
1428+
val = row.format("GT")
1429+
except KeyError:
1430+
pass
1431+
assert val is None
13731432
else:
13741433
gt = row.genotype.array()
13751434
gt_zarr = next(call_genotype)
@@ -1382,39 +1441,13 @@ def validate(vcf_path, zarr_path, show_progress=False):
13821441
# print(gt_vcf)
13831442
nt.assert_array_equal(gt_zarr, gt_vcf)
13841443

1385-
# TODO this is basically right, but the details about float padding
1386-
# need to be worked out in particular. Need to find examples of
1387-
# VCFs with Number=. Float fields.
13881444
for name, (vcf_type, zarr_iter) in info_fields.items():
1389-
vcf_val = None
1390-
try:
1391-
vcf_val = row.INFO[name]
1392-
except KeyError:
1393-
pass
1445+
vcf_val = row.INFO.get(name, None)
13941446
zarr_val = next(zarr_iter)
13951447
if vcf_val is None:
1396-
if vcf_type == "Integer":
1397-
assert np.all(zarr_val == -1)
1398-
elif vcf_type == "String":
1399-
assert np.all(zarr_val == ".")
1400-
elif vcf_type == "Flag":
1401-
assert zarr_val == False # noqa 712
1402-
elif vcf_type == "Float":
1403-
assert_all_missing_float(zarr_val)
1404-
else:
1405-
assert False
1448+
assert_info_val_missing(zarr_val, vcf_type)
14061449
else:
1407-
# print(name, vcf_type, vcf_val, zarr_val, sep="\t")
1408-
if vcf_type == "Integer":
1409-
assert_prefix_integer_equal_1d(vcf_val, zarr_val)
1410-
elif vcf_type == "Float":
1411-
assert_prefix_float_equal_1d(vcf_val, zarr_val)
1412-
elif vcf_type == "Flag":
1413-
assert zarr_val == True # noqa 712
1414-
elif vcf_type == "String":
1415-
assert np.all(zarr_val == vcf_val)
1416-
else:
1417-
assert False
1450+
assert_info_val_equal(vcf_val, zarr_val, vcf_type)
14181451

14191452
for name, (vcf_type, zarr_iter) in format_fields.items():
14201453
vcf_val = None
@@ -1424,27 +1457,6 @@ def validate(vcf_path, zarr_path, show_progress=False):
14241457
pass
14251458
zarr_val = next(zarr_iter)
14261459
if vcf_val is None:
1427-
if vcf_type == "Integer":
1428-
assert np.all(zarr_val == -1)
1429-
elif vcf_type == "Float":
1430-
assert_all_missing_float(zarr_val)
1431-
elif vcf_type == "String":
1432-
assert np.all(zarr_val == ".")
1433-
else:
1434-
print("vcf_val", vcf_type, name, vcf_val)
1435-
assert False
1460+
assert_format_val_missing(zarr_val, vcf_type)
14361461
else:
1437-
assert vcf_val.shape[0] == zarr_val.shape[0]
1438-
if vcf_type == "Integer":
1439-
assert_prefix_integer_equal_2d(vcf_val, zarr_val)
1440-
elif vcf_type == "Float":
1441-
assert_prefix_float_equal_2d(vcf_val, zarr_val)
1442-
elif vcf_type == "String":
1443-
nt.assert_array_equal(vcf_val, zarr_val)
1444-
1445-
# assert_prefix_string_equal_2d(vcf_val, zarr_val)
1446-
else:
1447-
print(name)
1448-
print(vcf_val)
1449-
print(zarr_val)
1450-
assert False
1462+
assert_format_val_equal(vcf_val, zarr_val, vcf_type)

tests/test_vcf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,11 +304,11 @@ def test_full_pipeline(self, ds, tmp_path, worker_processes):
304304
[
305305
"sample.vcf.gz",
306306
"sample_no_genotypes.vcf.gz",
307-
# "info_field_type_combos.vcf.gz",
307+
"info_field_type_combos.vcf.gz",
308308
],
309309
)
310310
def test_by_validating(name, tmp_path):
311311
path = f"tests/data/vcf/{name}"
312312
out = tmp_path / "test.zarr"
313-
vcf.convert_vcf([path], out)
313+
vcf.convert_vcf([path], out, worker_processes=0)
314314
vcf.validate(path, out)

0 commit comments

Comments
 (0)