22
22
To build the dataset, run the following from directory containing this file:
23
23
$ tfds build.
24
24
"""
25
+
25
26
import re
27
+ from typing import Any , Iterable
26
28
27
- from typing import Any , Dict , Iterable , Tuple
29
+ from etils import epath
28
30
import numpy as np
29
31
import tensorflow_datasets .public_api as tfds
30
32
33
+
31
34
pd = tfds .core .lazy_imports .pandas
32
35
33
36
_HOMEPAGE = 'https://doi.org/10.6084/m9.figshare.c.978904.v5'
34
37
35
38
_ATOMREF_URL = 'https://figshare.com/ndownloader/files/3195395'
36
- _UNCHARACTERIZED_URL = 'https://springernature.figshare.com/ndownloader/files/3195404'
39
+ _UNCHARACTERIZED_URL = (
40
+ 'https://springernature.figshare.com/ndownloader/files/3195404'
41
+ )
37
42
_MOLECULES_URL = 'https://springernature.figshare.com/ndownloader/files/3195389'
38
43
39
44
_SIZE = 133_885
40
45
_CHARACTERIZED_SIZE = 130_831
41
46
42
47
_MAX_ATOMS = 29
43
48
_CHARGES = {'H' : 1 , 'C' : 6 , 'N' : 7 , 'O' : 8 , 'F' : 9 }
44
- _LABELS = ['tag' , 'index' , 'A' , 'B' , 'C' , 'mu' , 'alpha' , 'homo' , 'lumo' , 'gap' ,
45
- 'r2' , 'zpve' , 'U0' , 'U' , 'H' , 'G' , 'Cv' ]
49
+ _LABELS = [
50
+ 'tag' ,
51
+ 'index' ,
52
+ 'A' ,
53
+ 'B' ,
54
+ 'C' ,
55
+ 'mu' ,
56
+ 'alpha' ,
57
+ 'homo' ,
58
+ 'lumo' ,
59
+ 'gap' ,
60
+ 'r2' ,
61
+ 'zpve' ,
62
+ 'U0' ,
63
+ 'U' ,
64
+ 'H' ,
65
+ 'G' ,
66
+ 'Cv' ,
67
+ ]
46
68
# For each of these targets, we will add a second target with an
47
69
# _atomization suffix that has the thermo term subtracted.
48
70
_ATOMIZATION_TARGETS = ['U0' , 'U' , 'H' , 'G' ]
49
71
50
72
51
- def _process_molecule (atomref , fname ):
73
+ def _process_molecule (
74
+ atomref : dict [str , Any ], fname : epath .PathLike
75
+ ) -> dict [str , Any ]:
52
76
"""Read molecule data from file."""
53
- with open (fname , 'r' ) as f :
77
+ with epath . Path (fname ). open ( ) as f :
54
78
lines = f .readlines ()
55
79
num_atoms = int (lines [0 ].rstrip ())
56
80
frequencies = re .split (r'\s+' , lines [num_atoms + 2 ].rstrip ())
57
81
smiles = re .split (r'\s+' , lines [num_atoms + 3 ].rstrip ())
58
82
inchi = re .split (r'\s+' , lines [num_atoms + 4 ].rstrip ())
59
83
60
- labels = pd .read_table (fname ,
61
- skiprows = 1 ,
62
- nrows = 1 ,
63
- sep = r'\s+' ,
64
- names = _LABELS )
84
+ labels = pd .read_table (fname , skiprows = 1 , nrows = 1 , sep = r'\s+' , names = _LABELS )
65
85
66
- atoms = pd .read_table (fname ,
67
- skiprows = 2 ,
68
- nrows = num_atoms ,
69
- sep = r'\s+' ,
70
- names = ['Z' , 'x' , 'y' , 'z' , 'Mulliken_charge' ])
86
+ atoms = pd .read_table (
87
+ fname ,
88
+ skiprows = 2 ,
89
+ nrows = num_atoms ,
90
+ sep = r'\s+' ,
91
+ names = ['Z' , 'x' , 'y' , 'z' , 'Mulliken_charge' ],
92
+ )
71
93
72
94
# Correct exponential notation (6.8*^-6 -> 6.8e-6).
73
95
for key in ['x' , 'y' , 'z' , 'Mulliken_charge' ]:
74
96
if atoms [key ].values .dtype == 'object' :
75
97
# there are unrecognized numbers.
76
- atoms [key ].values [:] = np .array ([
77
- float (x .replace ('*^' , 'e' ))
78
- for i , x in enumerate (atoms [key ].values )])
79
-
80
- charges = np .pad ([_CHARGES [v ] for v in atoms ['Z' ].values ],
81
- (0 , _MAX_ATOMS - num_atoms ))
82
- positions = np .stack ([atoms ['x' ].values ,
83
- atoms ['y' ].values ,
84
- atoms ['z' ].values ], axis = - 1 ).astype (np .float32 )
98
+ atoms [key ].values [:] = np .array (
99
+ [float (x .replace ('*^' , 'e' )) for i , x in enumerate (atoms [key ].values )]
100
+ )
101
+
102
+ charges = np .pad (
103
+ [_CHARGES [v ] for v in atoms ['Z' ].values ], (0 , _MAX_ATOMS - num_atoms )
104
+ )
105
+ positions = np .stack (
106
+ [atoms ['x' ].values , atoms ['y' ].values , atoms ['z' ].values ], axis = - 1
107
+ ).astype (np .float32 )
85
108
positions = np .pad (positions , ((0 , _MAX_ATOMS - num_atoms ), (0 , 0 )))
86
109
87
110
mulliken_charges = atoms ['Mulliken_charge' ].values .astype (np .float32 )
88
111
mulliken_charges = np .pad (mulliken_charges , ((0 , _MAX_ATOMS - num_atoms )))
89
112
90
- example = {'num_atoms' : num_atoms ,
91
- 'charges' : charges ,
92
- 'Mulliken_charges' : mulliken_charges ,
93
- 'positions' : positions .astype (np .float32 ),
94
- 'frequencies' : frequencies ,
95
- 'SMILES' : smiles [0 ],
96
- 'SMILES_relaxed' : smiles [1 ],
97
- 'InChI' : inchi [0 ],
98
- 'InChI_relaxed' : inchi [1 ],
99
- ** {k : labels [k ].values [0 ] for k in _LABELS }}
113
+ example = {
114
+ 'num_atoms' : num_atoms ,
115
+ 'charges' : charges ,
116
+ 'Mulliken_charges' : mulliken_charges ,
117
+ 'positions' : positions .astype (np .float32 ),
118
+ 'frequencies' : frequencies ,
119
+ 'SMILES' : smiles [0 ],
120
+ 'SMILES_relaxed' : smiles [1 ],
121
+ 'InChI' : inchi [0 ],
122
+ 'InChI_relaxed' : inchi [1 ],
123
+ ** {k : labels [k ].values [0 ] for k in _LABELS },
124
+ }
100
125
101
126
# Create atomization targets by subtracting thermochemical energy of
102
127
# each atom.
@@ -113,8 +138,9 @@ def _process_molecule(atomref, fname):
113
138
def _get_valid_ids (uncharacterized ):
114
139
"""Get valid ids."""
115
140
# Original data files are 1-indexed.
116
- characterized_ids = np .array (sorted (set (range (1 , _SIZE + 1 )) -
117
- set (uncharacterized )))
141
+ characterized_ids = np .array (
142
+ sorted (set (range (1 , _SIZE + 1 )) - set (uncharacterized ))
143
+ )
118
144
assert len (characterized_ids ) == _CHARACTERIZED_SIZE
119
145
return characterized_ids
120
146
@@ -173,27 +199,32 @@ def _info(self) -> tfds.core.DatasetInfo:
173
199
)
174
200
175
201
def _split_generators (
176
- self , dl_manager : tfds .download .DownloadManager ) -> Dict [str , Any ]:
202
+ self , dl_manager : tfds .download .DownloadManager
203
+ ) -> dict [str , Any ]:
177
204
"""Returns SplitGenerators. See superclass method for details."""
178
205
atomref = pd .read_table (
179
206
dl_manager .download ({'atomref' : _ATOMREF_URL })['atomref' ],
180
207
skiprows = 5 ,
181
208
index_col = 'Z' ,
182
209
skipfooter = 1 ,
183
210
sep = r'\s+' ,
184
- names = ['Z' , 'zpve' , 'U0' , 'U' , 'H' , 'G' , 'Cv' ]).to_dict ()
211
+ names = ['Z' , 'zpve' , 'U0' , 'U' , 'H' , 'G' , 'Cv' ],
212
+ ).to_dict ()
185
213
186
214
uncharacterized = pd .read_table (
187
- dl_manager .download (
188
- {'uncharacterized' : _UNCHARACTERIZED_URL })['uncharacterized' ],
215
+ dl_manager .download ({'uncharacterized' : _UNCHARACTERIZED_URL })[
216
+ 'uncharacterized'
217
+ ],
189
218
skiprows = 9 ,
190
219
skipfooter = 1 ,
191
220
sep = r'\s+' ,
192
221
usecols = [0 ],
193
- names = ['index' ]).values [:, 0 ]
222
+ names = ['index' ],
223
+ ).values [:, 0 ]
194
224
195
225
molecules_dir = dl_manager .download_and_extract (
196
- {'dsgdb9nsd' : _MOLECULES_URL })['dsgdb9nsd' ]
226
+ {'dsgdb9nsd' : _MOLECULES_URL }
227
+ )['dsgdb9nsd' ]
197
228
198
229
valid_ids = _get_valid_ids (uncharacterized )
199
230
@@ -202,11 +233,13 @@ def _split_generators(
202
233
def _generate_examples (
203
234
self ,
204
235
split : np .ndarray ,
205
- atomref : Dict [str , Any ],
206
- molecules_dir : Any ) -> Iterable [Tuple [int , Dict [str , Any ]]]:
236
+ atomref : dict [str , Any ],
237
+ molecules_dir : epath .Path ,
238
+ ) -> Iterable [tuple [int , dict [str , Any ]]]:
207
239
"""Dataset generator. See superclass method for details."""
208
240
209
241
for i in split :
210
242
entry = _process_molecule (
211
- atomref , molecules_dir / f'dsgdb9nsd_{ i :06d} .xyz' )
243
+ atomref , molecules_dir / f'dsgdb9nsd_{ i :06d} .xyz'
244
+ )
212
245
yield int (i ), entry
0 commit comments