Skip to content

Commit b52885b

Browse files
Johannes Gasteiger authored and
The TensorFlow Datasets Authors committed
Fix opening QM9 files
PiperOrigin-RevId: 619201213
1 parent c8ee73c commit b52885b

File tree

1 file changed

+79
-46
lines changed

1 file changed

+79
-46
lines changed

tensorflow_datasets/datasets/qm9/qm9_dataset_builder.py

Lines changed: 79 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -22,81 +22,106 @@
2222
To build the dataset, run the following from directory containing this file:
2323
$ tfds build.
2424
"""
25+
2526
import re
27+
from typing import Any, Iterable
2628

27-
from typing import Any, Dict, Iterable, Tuple
29+
from etils import epath
2830
import numpy as np
2931
import tensorflow_datasets.public_api as tfds
3032

33+
3134
pd = tfds.core.lazy_imports.pandas
3235

3336
_HOMEPAGE = 'https://doi.org/10.6084/m9.figshare.c.978904.v5'
3437

3538
_ATOMREF_URL = 'https://figshare.com/ndownloader/files/3195395'
36-
_UNCHARACTERIZED_URL = 'https://springernature.figshare.com/ndownloader/files/3195404'
39+
_UNCHARACTERIZED_URL = (
40+
'https://springernature.figshare.com/ndownloader/files/3195404'
41+
)
3742
_MOLECULES_URL = 'https://springernature.figshare.com/ndownloader/files/3195389'
3843

3944
_SIZE = 133_885
4045
_CHARACTERIZED_SIZE = 130_831
4146

4247
_MAX_ATOMS = 29
4348
_CHARGES = {'H': 1, 'C': 6, 'N': 7, 'O': 8, 'F': 9}
44-
_LABELS = ['tag', 'index', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap',
45-
'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']
49+
_LABELS = [
50+
'tag',
51+
'index',
52+
'A',
53+
'B',
54+
'C',
55+
'mu',
56+
'alpha',
57+
'homo',
58+
'lumo',
59+
'gap',
60+
'r2',
61+
'zpve',
62+
'U0',
63+
'U',
64+
'H',
65+
'G',
66+
'Cv',
67+
]
4668
# For each of these targets, we will add a second target with an
4769
# _atomization suffix that has the thermo term subtracted.
4870
_ATOMIZATION_TARGETS = ['U0', 'U', 'H', 'G']
4971

5072

51-
def _process_molecule(atomref, fname):
73+
def _process_molecule(
74+
atomref: dict[str, Any], fname: epath.PathLike
75+
) -> dict[str, Any]:
5276
"""Read molecule data from file."""
53-
with open(fname, 'r') as f:
77+
with epath.Path(fname).open() as f:
5478
lines = f.readlines()
5579
num_atoms = int(lines[0].rstrip())
5680
frequencies = re.split(r'\s+', lines[num_atoms + 2].rstrip())
5781
smiles = re.split(r'\s+', lines[num_atoms + 3].rstrip())
5882
inchi = re.split(r'\s+', lines[num_atoms + 4].rstrip())
5983

60-
labels = pd.read_table(fname,
61-
skiprows=1,
62-
nrows=1,
63-
sep=r'\s+',
64-
names=_LABELS)
84+
labels = pd.read_table(fname, skiprows=1, nrows=1, sep=r'\s+', names=_LABELS)
6585

66-
atoms = pd.read_table(fname,
67-
skiprows=2,
68-
nrows=num_atoms,
69-
sep=r'\s+',
70-
names=['Z', 'x', 'y', 'z', 'Mulliken_charge'])
86+
atoms = pd.read_table(
87+
fname,
88+
skiprows=2,
89+
nrows=num_atoms,
90+
sep=r'\s+',
91+
names=['Z', 'x', 'y', 'z', 'Mulliken_charge'],
92+
)
7193

7294
# Correct exponential notation (6.8*^-6 -> 6.8e-6).
7395
for key in ['x', 'y', 'z', 'Mulliken_charge']:
7496
if atoms[key].values.dtype == 'object':
7597
# there are unrecognized numbers.
76-
atoms[key].values[:] = np.array([
77-
float(x.replace('*^', 'e'))
78-
for i, x in enumerate(atoms[key].values)])
79-
80-
charges = np.pad([_CHARGES[v] for v in atoms['Z'].values],
81-
(0, _MAX_ATOMS - num_atoms))
82-
positions = np.stack([atoms['x'].values,
83-
atoms['y'].values,
84-
atoms['z'].values], axis=-1).astype(np.float32)
98+
atoms[key].values[:] = np.array(
99+
[float(x.replace('*^', 'e')) for i, x in enumerate(atoms[key].values)]
100+
)
101+
102+
charges = np.pad(
103+
[_CHARGES[v] for v in atoms['Z'].values], (0, _MAX_ATOMS - num_atoms)
104+
)
105+
positions = np.stack(
106+
[atoms['x'].values, atoms['y'].values, atoms['z'].values], axis=-1
107+
).astype(np.float32)
85108
positions = np.pad(positions, ((0, _MAX_ATOMS - num_atoms), (0, 0)))
86109

87110
mulliken_charges = atoms['Mulliken_charge'].values.astype(np.float32)
88111
mulliken_charges = np.pad(mulliken_charges, ((0, _MAX_ATOMS - num_atoms)))
89112

90-
example = {'num_atoms': num_atoms,
91-
'charges': charges,
92-
'Mulliken_charges': mulliken_charges,
93-
'positions': positions.astype(np.float32),
94-
'frequencies': frequencies,
95-
'SMILES': smiles[0],
96-
'SMILES_relaxed': smiles[1],
97-
'InChI': inchi[0],
98-
'InChI_relaxed': inchi[1],
99-
**{k: labels[k].values[0] for k in _LABELS}}
113+
example = {
114+
'num_atoms': num_atoms,
115+
'charges': charges,
116+
'Mulliken_charges': mulliken_charges,
117+
'positions': positions.astype(np.float32),
118+
'frequencies': frequencies,
119+
'SMILES': smiles[0],
120+
'SMILES_relaxed': smiles[1],
121+
'InChI': inchi[0],
122+
'InChI_relaxed': inchi[1],
123+
**{k: labels[k].values[0] for k in _LABELS},
124+
}
100125

101126
# Create atomization targets by subtracting thermochemical energy of
102127
# each atom.
@@ -113,8 +138,9 @@ def _process_molecule(atomref, fname):
113138
def _get_valid_ids(uncharacterized):
114139
"""Get valid ids."""
115140
# Original data files are 1-indexed.
116-
characterized_ids = np.array(sorted(set(range(1, _SIZE + 1)) -
117-
set(uncharacterized)))
141+
characterized_ids = np.array(
142+
sorted(set(range(1, _SIZE + 1)) - set(uncharacterized))
143+
)
118144
assert len(characterized_ids) == _CHARACTERIZED_SIZE
119145
return characterized_ids
120146

@@ -173,27 +199,32 @@ def _info(self) -> tfds.core.DatasetInfo:
173199
)
174200

175201
def _split_generators(
176-
self, dl_manager: tfds.download.DownloadManager) -> Dict[str, Any]:
202+
self, dl_manager: tfds.download.DownloadManager
203+
) -> dict[str, Any]:
177204
"""Returns SplitGenerators. See superclass method for details."""
178205
atomref = pd.read_table(
179206
dl_manager.download({'atomref': _ATOMREF_URL})['atomref'],
180207
skiprows=5,
181208
index_col='Z',
182209
skipfooter=1,
183210
sep=r'\s+',
184-
names=['Z', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']).to_dict()
211+
names=['Z', 'zpve', 'U0', 'U', 'H', 'G', 'Cv'],
212+
).to_dict()
185213

186214
uncharacterized = pd.read_table(
187-
dl_manager.download(
188-
{'uncharacterized': _UNCHARACTERIZED_URL})['uncharacterized'],
215+
dl_manager.download({'uncharacterized': _UNCHARACTERIZED_URL})[
216+
'uncharacterized'
217+
],
189218
skiprows=9,
190219
skipfooter=1,
191220
sep=r'\s+',
192221
usecols=[0],
193-
names=['index']).values[:, 0]
222+
names=['index'],
223+
).values[:, 0]
194224

195225
molecules_dir = dl_manager.download_and_extract(
196-
{'dsgdb9nsd': _MOLECULES_URL})['dsgdb9nsd']
226+
{'dsgdb9nsd': _MOLECULES_URL}
227+
)['dsgdb9nsd']
197228

198229
valid_ids = _get_valid_ids(uncharacterized)
199230

@@ -202,11 +233,13 @@ def _split_generators(
202233
def _generate_examples(
203234
self,
204235
split: np.ndarray,
205-
atomref: Dict[str, Any],
206-
molecules_dir: Any) -> Iterable[Tuple[int, Dict[str, Any]]]:
236+
atomref: dict[str, Any],
237+
molecules_dir: epath.Path,
238+
) -> Iterable[tuple[int, dict[str, Any]]]:
207239
"""Dataset generator. See superclass method for details."""
208240

209241
for i in split:
210242
entry = _process_molecule(
211-
atomref, molecules_dir / f'dsgdb9nsd_{i:06d}.xyz')
243+
atomref, molecules_dir / f'dsgdb9nsd_{i:06d}.xyz'
244+
)
212245
yield int(i), entry

0 commit comments

Comments
 (0)