104104from zipline .utils .numpy_utils import bytes_array_to_native_str_object_array
105105
106106from .base import FXRateReader , DEFAULT_FX_RATE
107+ from .utils import check_dts , is_sorted_ascending
107108
108109HDF5_FX_VERSION = 0
109110
@@ -189,10 +190,7 @@ def get_rates(self, rate, quote, bases, dts):
189190 if rate == DEFAULT_FX_RATE :
190191 rate = self ._default_rate
191192
192- # TODO: Commenting this _check_dts out for now to bypass the
193- # estimates loader date bounds issue. Will need to address
194- # this before finalizing anything.
195- # self._check_dts(self.dts, dts)
193+ check_dts (self .dts , dts )
196194
197195 row_ixs = self .dts .searchsorted (dts , side = 'right' ) - 1
198196 col_ixs = self .currencies .get_indexer (bases )
@@ -207,51 +205,48 @@ def get_rates(self, rate, quote, bases, dts):
207205
208206 # OPTIMIZATION: Row indices correspond to dates, which must be in
209207 # sorted order. Rather than reading the entire dataset from h5, we can
210- # read just the interval from min_row to max_row inclusive.
208+ # read just the interval from min_row to max_row inclusive.
211209 #
212- # We don't bother with a similar optimization for columns because in
213- # expectation we're going to load most of the
214-
215- # array, so it's easier to pull all columns and reindex in memory. For
216- # rows, however, a quick and easy optimization is to pull just the
217- # slice from min(row_ixs) to max(row_ixs).
218- min_row = max (row_ixs [0 ], 0 )
219- max_row = row_ixs [- 1 ]
220- rows = dataset [min_row :max_row + 1 ] # +1 to be inclusive of end
221-
222- out = rows [row_ixs - min_row ][:, col_ixs ]
210+ # However, we also need to handle two important edge cases:
211+ #
212+ # 1. row_ixs contains -1 for dts before the start of self.dts.
213+ # 2. col_ixs contains -1 for any currencies we don't know about.
214+ #
215+ # If either of the above cases obtains, we want to return NaN for the
216+ # corresponding output locations.
223217
224- # get_indexer returns -1 for failed lookups. Fill these in with NaN.
218+ # We handle (1) by reading raw data into a buffer with one extra
219+ # row. When we then apply the row index to permute the raw data into
220+ # the correct order, any rows with values of -1 will pull from the
221+ # extra row, which will always contain NaN.
222+ #
223+ # We handle (2) by overwriting columns with indices of -1 with NaN as a
224+ # postprocessing step.
225+ slice_begin = max (row_ixs [0 ], 0 )
226+ slice_end = max (row_ixs [- 1 ], 0 ) + 1 # +1 to be inclusive of end date.
227+
228+ # Allocate a buffer full of NaNs with one extra row. See
229+ # OPTIMIZATION notes above.
230+ buf = np .full (
231+ (slice_end - slice_begin + 1 , len (self .currencies )),
232+ np .nan ,
233+ )
234+
235+ # Read data into all but the last row of the buffer.
236+ dataset .read_direct (
237+ buf [:- 1 ],
238+ np .s_ [slice_begin :slice_end ],
239+ )
240+
241+ # Permute the rows into place. Row indices of -1 pull from the
242+ # trailing NaN row; columns with index -1 are overwritten with NaN below.
243+ out = buf [:, col_ixs ][row_ixs - slice_begin ]
244+
245+ # Fill missing columns with NaN. See OPTIMIZATION notes above.
225246 out [:, col_ixs == - 1 ] = np .nan
226247
227- # TODO: searchsorted also gives -1 for failed lookups. However, these
228- # failed lookups arise due to the estimates date bounds bug that we
229- # have not yet addressed, so this is a temporary fix.
230- out [row_ixs == - 1 , :] = np .nan
231-
232248 return out
233249
234- def _check_dts (self , stored , requested ):
235- """Validate that requested dates are in bounds for what we have stored.
236- """
237- request_start , request_end = requested [[0 , - 1 ]]
238- data_start , data_end = stored [[0 , - 1 ]]
239-
240- if request_start < data_start :
241- raise ValueError (
242- "Requested fx rates starting at {}, but data starts at {}"
243- .format (request_start , data_start )
244- )
245-
246- if request_end > data_end :
247- raise ValueError (
248- "Requested fx rates ending at {}, but data ends at {}"
249- .format (request_end , data_end )
250- )
251-
252- if not is_sorted_ascending (requested ):
253- raise ValueError ("Requested fx rates with non-ascending dts." )
254-
255250
256251class HDF5FXRateWriter (object ):
257252 """Writer class for HDF5 files consumed by HDF5FXRateReader.
@@ -320,7 +315,3 @@ def _write_data_group(self, dts, currencies, data):
320315
321316 def _log_writing (self , * path ):
322317 log .debug ("Writing {}" , '/' .join (path ))
323-
324-
325- def is_sorted_ascending (array ):
326- return (np .maximum .accumulate (array ) <= array ).all ()
0 commit comments