Commit 008d719

Author: Scott Sanderson

Merge pull request #2640 from quantopian/fx-changes-for-estimates

MAINT: Clarify edge case handling for FX Rate Readers.

2 parents 8443934 + baeac17, commit 008d719

File tree: 6 files changed, +144 -76 lines changed

tests/data/test_fx.py

Lines changed: 57 additions & 17 deletions
@@ -65,20 +65,23 @@ def test_scalar_lookup(self):
         reader = self.reader
 
         rates = self.FX_RATES_RATE_NAMES
-        currencies = self.FX_RATES_CURRENCIES
-        dates = pd.date_range(self.FX_RATES_START_DATE, self.FX_RATES_END_DATE)
-
-        cases = itertools.product(rates, currencies, currencies, dates)
+        quotes = self.FX_RATES_CURRENCIES
+        bases = self.FX_RATES_CURRENCIES + [None]
+        dates = pd.date_range(
+            self.FX_RATES_START_DATE - pd.Timedelta('1 day'),
+            self.FX_RATES_END_DATE,
+        )
+        cases = itertools.product(rates, quotes, bases, dates)
 
         for rate, quote, base, dt in cases:
             dts = pd.DatetimeIndex([dt], tz='UTC')
-            bases = np.array([base])
+            bases = np.array([base], dtype=object)
 
             result = reader.get_rates(rate, quote, bases, dts)
             assert_equal(result.shape, (1, 1))
 
             result_scalar = result[0, 0]
-            if quote == base:
+            if dt >= self.FX_RATES_START_DATE and quote == base:
                 assert_equal(result_scalar, 1.0)
 
             expected = self.get_expected_fx_rate_scalar(rate, quote, base, dt)
@@ -93,12 +96,16 @@ def test_scalar_lookup(self):
     def test_2d_lookup(self):
         rand = np.random.RandomState(42)
 
-        dates = pd.date_range(self.FX_RATES_START_DATE, self.FX_RATES_END_DATE)
+        dates = pd.date_range(
+            self.FX_RATES_START_DATE - pd.Timedelta('2 days'),
+            self.FX_RATES_END_DATE
+        )
         rates = self.FX_RATES_RATE_NAMES + [DEFAULT_FX_RATE]
-        currencies = self.FX_RATES_CURRENCIES
+        possible_quotes = self.FX_RATES_CURRENCIES
+        possible_bases = self.FX_RATES_CURRENCIES + [None]
 
         # For every combination of rate name and quote currency...
-        for rate, quote in itertools.product(rates, currencies):
+        for rate, quote in itertools.product(rates, possible_quotes):
 
             # Choose N random distinct days...
             for ndays in 1, 2, 7, 20:
@@ -107,7 +114,10 @@ def test_2d_lookup(self):
 
             # Choose M random possibly-non-distinct currencies...
             for nbases in 1, 2, 10, 200:
-                bases = rand.choice(currencies, nbases, replace=True)
+                bases = (
+                    rand.choice(possible_bases, nbases, replace=True)
+                    .astype(object)
+                )
 
                 # ...And check that we get the expected result when querying
                 # for those dates/currencies.
@@ -119,18 +129,25 @@ def test_2d_lookup(self):
     def test_columnar_lookup(self):
         rand = np.random.RandomState(42)
 
-        dates = pd.date_range(self.FX_RATES_START_DATE, self.FX_RATES_END_DATE)
+        dates = pd.date_range(
+            self.FX_RATES_START_DATE - pd.Timedelta('2 days'),
+            self.FX_RATES_END_DATE,
+        )
         rates = self.FX_RATES_RATE_NAMES + [DEFAULT_FX_RATE]
-        currencies = self.FX_RATES_CURRENCIES
+        possible_quotes = self.FX_RATES_CURRENCIES
+        possible_bases = self.FX_RATES_CURRENCIES + [None]
         reader = self.reader
 
         # For every combination of rate name and quote currency...
-        for rate, quote in itertools.product(rates, currencies):
+        for rate, quote in itertools.product(rates, possible_quotes):
             for N in 1, 2, 10, 200:
                 # Choose N (date, base) pairs randomly with replacement.
                 dts_raw = rand.choice(dates, N, replace=True)
-                dts = pd.DatetimeIndex(dts_raw, tz='utc').sort_values()
-                bases = rand.choice(currencies, N, replace=True)
+                dts = pd.DatetimeIndex(dts_raw, tz='utc')
+                bases = (
+                    rand.choice(possible_bases, N, replace=True)
+                    .astype(object)
+                )
 
                 # ... And check that we get the expected result when querying
                 # for those dates/currencies.
@@ -175,27 +192,50 @@ def test_load_everything(self):
         assert_equal(london_result, london_rates.values)
 
     def test_read_before_start_date(self):
+        # Reads from before the start of our data should emit NaN. We do this
+        # because, for some Pipeline loaders, it's hard to put a lower bound on
+        # input asof dates, so we end up making queries for asof_dates that
+        # might be before the start of FX data. When that happens, we want to
+        # emit NaN, but we don't want to fail.
         for bad_date in (self.FX_RATES_START_DATE - pd.Timedelta('1 day'),
                          self.FX_RATES_START_DATE - pd.Timedelta('1000 days')):
 
             for rate in self.FX_RATES_RATE_NAMES:
                 quote = 'USD'
                 bases = np.array(['CAD'], dtype=object)
                 dts = pd.DatetimeIndex([bad_date])
-                with self.assertRaises(ValueError):
-                    self.reader.get_rates(rate, quote, bases, dts)
+                result = self.reader.get_rates(rate, quote, bases, dts)
+                assert_equal(result.shape, (1, 1))
+                assert_equal(np.nan, result[0, 0])
 
     def test_read_after_end_date(self):
+        # Reads from **after** the end of our data, on the other hand, should
+        # fail. We can always upper bound the relevant asofs that we're
+        # interested in, and having fx rates forward-fill past the end of data
+        # is confusing and takes a while to debug.
        for bad_date in (self.FX_RATES_END_DATE + pd.Timedelta('1 day'),
                          self.FX_RATES_END_DATE + pd.Timedelta('1000 days')):
 
             for rate in self.FX_RATES_RATE_NAMES:
                 quote = 'USD'
                 bases = np.array(['CAD'], dtype=object)
                 dts = pd.DatetimeIndex([bad_date])
+
                 with self.assertRaises(ValueError):
                     self.reader.get_rates(rate, quote, bases, dts)
 
+                with self.assertRaises(ValueError):
+                    self.reader.get_rates_columnar(rate, quote, bases, dts)
+
+    def test_read_unknown_base(self):
+        for rate in self.FX_RATES_RATE_NAMES:
+            quote = 'USD'
+            for unknown_base in 'XXX', None:
+                bases = np.array([unknown_base], dtype=object)
+                dts = pd.DatetimeIndex([self.FX_RATES_START_DATE])
+                result = self.reader.get_rates(rate, quote, bases, dts)[0, 0]
+                assert_equal(result, np.nan)
+
 
 class InMemoryFXReaderTestCase(_FXReaderTestCase):
 
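Taken together, these tests pin down the contract this commit establishes for FX readers: dates before the start of data and unknown (or missing) base currencies yield NaN, while dates past the end of data raise ValueError. A small illustration of that contract, as a sketch only: it assumes InMemoryFXRateReader is constructed directly from the nested {rate: {quote: DataFrame}} dict its docstring describes, and the rate data below is made up.

import numpy as np
import pandas as pd

from zipline.data.fx.in_memory import InMemoryFXRateReader

# Made-up data: one rate name, one quote currency, two bases, five days.
dts = pd.date_range('2020-01-06', '2020-01-10')
frame = pd.DataFrame(1.1, index=dts, columns=['CAD', 'EUR'])
reader = InMemoryFXRateReader({'mid': {'USD': frame}})

bases = np.array(['CAD'], dtype=object)

# Before the start of data: NaN, not an error.
before = pd.DatetimeIndex(['2020-01-01'])
assert np.isnan(reader.get_rates('mid', 'USD', bases, before)[0, 0])

# Unknown base currency: also NaN.
unknown = np.array(['XXX'], dtype=object)
assert np.isnan(reader.get_rates('mid', 'USD', unknown, before)[0, 0])

# Past the end of data: ValueError, never a silent forward-fill.
after = pd.DatetimeIndex(['2020-02-01'])
try:
    reader.get_rates('mid', 'USD', bases, after)
except ValueError:
    pass  # expected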

zipline/data/fx/base.py

Lines changed: 10 additions & 3 deletions
@@ -4,6 +4,7 @@
 import pandas as pd
 
 from zipline.utils.sentinel import sentinel
+from zipline.lib._factorize import factorize_strings
 
 DEFAULT_FX_RATE = sentinel('DEFAULT_FX_RATE')
 
@@ -127,15 +128,21 @@ def get_rates_columnar(self, rate, quote, bases, dts):
             may appear multiple times.
         dts : np.DatetimeIndex
             Datetimes for which to load rates. The same value may appear
-            multiple times, but datetimes must be sorted in ascending order and
-            localized to UTC.
+            multiple times. Datetimes do not need to be sorted.
         """
         if len(bases) != len(dts):
             raise ValueError(
                 "len(bases) ({}) != len(dts) ({})".format(len(bases), len(dts))
             )
 
-        unique_bases, bases_ix = np.unique(bases, return_inverse=True)
+        bases_ix, unique_bases, _ = factorize_strings(
+            bases,
+            missing_value=None,
+            # Only dts need to be sorted, not bases.
+            sort=False,
+        )
+        # NOTE: np.unique returns unique_dts in sorted order, which is required
+        # for calling get_rates.
         unique_dts, dts_ix = np.unique(dts.values, return_inverse=True)
         rates_2d = self.get_rates(
             rate,
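The get_rates_columnar change keeps the method's overall strategy: collapse a 1D (base, dt) query into one rectangular get_rates call on unique values, then fan results back out with the inverse indices. A minimal sketch of that reduction using plain np.unique on both axes (the commit itself switches bases to factorize_strings with sort=False so that None bases survive factorization; lookup_2d below is a hypothetical stand-in for get_rates):

import numpy as np

def columnar_lookup(lookup_2d, bases, dts):
    # Factorize each axis: unique values, plus each query element's index
    # into them. np.unique returns unique values in ascending order, which
    # satisfies get_rates' requirement that dates be sorted.
    unique_bases, bases_ix = np.unique(bases, return_inverse=True)
    unique_dts, dts_ix = np.unique(dts, return_inverse=True)

    # One rectangular lookup of shape (len(unique_dts), len(unique_bases))...
    rates_2d = lookup_2d(unique_bases, unique_dts)

    # ...then pick one scalar per original (base, dt) pair.
    return rates_2d[dts_ix, bases_ix]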

zipline/data/fx/hdf5.py

Lines changed: 38 additions & 39 deletions
@@ -104,6 +104,7 @@
 from zipline.utils.numpy_utils import bytes_array_to_native_str_object_array
 
 from .base import FXRateReader, DEFAULT_FX_RATE
+from .utils import check_dts, is_sorted_ascending
 
 HDF5_FX_VERSION = 0
 
@@ -189,7 +190,7 @@ def get_rates(self, rate, quote, bases, dts):
         if rate == DEFAULT_FX_RATE:
             rate = self._default_rate
 
-        self._check_dts(self.dts, dts)
+        check_dts(self.dts, dts)
 
         row_ixs = self.dts.searchsorted(dts, side='right') - 1
         col_ixs = self.currencies.get_indexer(bases)
@@ -204,46 +205,48 @@ def get_rates(self, rate, quote, bases, dts):
 
         # OPTIMIZATION: Row indices correspond to dates, which must be in
         # sorted order. Rather than reading the entire dataset from h5, we can
-        # read just the interval from min_row to max_row inclusive.
+        # read just the interval from min_row to max_row inclusive
         #
-        # We don't bother with a similar optimization for columns because in
-        # expectation we're going to load most of the
-
-        # array, so it's easier to pull all columns and reindex in memory. For
-        # rows, however, a quick and easy optimization is to pull just the
-        # slice from min(row_ixs) to max(row_ixs).
-        min_row = row_ixs[0]
-        max_row = row_ixs[-1]
-        rows = dataset[min_row:max_row + 1]  # +1 to be inclusive of end
-
-        out = rows[row_ixs - min_row][:, col_ixs]
+        # However, we also need to handle two important edge cases:
+        #
+        #   1. row_ixs contains -1 for dts before the start of self.dts.
+        #   2. col_ixs contains -1 for any currencies we don't know about.
+        #
+        # If either of the above cases obtains, we want to return NaN for the
+        # corresponding output locations.
 
-        # get_indexer returns -1 for failed lookups. Fill these in with NaN.
+        # We handle (1) by reading raw data into a buffer with one extra
+        # row. When we then apply the row index to permute the raw data into
+        # the correct order, any rows with values of -1 will pull from the
+        # extra row, which will always contain NaN.
+        #
+        # We handle (2) by overwriting columns with indices of -1 with NaN as a
+        # postprocessing step.
+        slice_begin = max(row_ixs[0], 0)
+        slice_end = max(row_ixs[-1], 0) + 1  # +1 to be inclusive of end date.
+
+        # Allocate a buffer full of NaNs with one extra row/column. See
+        # OPTIMIZATION notes above.
+        buf = np.full(
+            (slice_end - slice_begin + 1, len(self.currencies)),
+            np.nan,
+        )
+
+        # Read data into all but the last row/column of the buffer.
+        dataset.read_direct(
+            buf[:-1],
+            np.s_[slice_begin:slice_end],
+        )
+
+        # Permute the rows into place, pulling from the empty NaN locations for
+        # row/column indices of -1.
+        out = buf[:, col_ixs][row_ixs - slice_begin]
+
+        # Fill missing columns with NaN. See OPTIMIZATION notes above.
         out[:, col_ixs == -1] = np.nan
 
         return out
 
-    def _check_dts(self, stored, requested):
-        """Validate that requested dates are in bounds for what we have stored.
-        """
-        request_start, request_end = requested[[0, -1]]
-        data_start, data_end = stored[[0, -1]]
-
-        if request_start < data_start:
-            raise ValueError(
-                "Requested fx rates starting at {}, but data starts at {}"
-                .format(request_start, data_start)
-            )
-
-        if request_end > data_end:
-            raise ValueError(
-                "Requested fx rates ending at {}, but data ends at {}"
-                .format(request_end, data_end)
-            )
-
-        if not is_sorted_ascending(requested):
-            raise ValueError("Requested fx rates with non-ascending dts.")
-
 
 class HDF5FXRateWriter(object):
     """Writer class for HDF5 files consumed by HDF5FXRateReader.
@@ -312,7 +315,3 @@ def _write_data_group(self, dts, currencies, data):
 
     def _log_writing(self, *path):
         log.debug("Writing {}", '/'.join(path))
-
-
-def is_sorted_ascending(array):
-    return (np.maximum.accumulate(array) <= array).all()
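The buffer trick above is easiest to see in isolation: numpy indexing wraps negative indices around, so copying the data into a buffer with one extra all-NaN row makes every row index of -1 select that NaN row for free, while a column index of -1 wraps to a real data column and still needs the explicit overwrite. A self-contained sketch with made-up numbers (plain arrays, not the HDF5 dataset):

import numpy as np

data = np.arange(12, dtype=float).reshape(4, 3)  # 4 stored dts x 3 currencies
row_ixs = np.array([-1, 0, 2])  # -1: requested dt precedes stored data
col_ixs = np.array([1, -1])     # -1: unknown base currency

# One extra all-NaN row at the end; a row index of -1 selects it.
buf = np.full((data.shape[0] + 1, data.shape[1]), np.nan)
buf[:-1] = data

out = buf[:, col_ixs][row_ixs]  # rows with index -1 pull the NaN row
out[:, col_ixs == -1] = np.nan  # -1 columns wrapped to real data; fix them up

print(out)
# [[nan nan]
#  [ 1. nan]
#  [ 7. nan]]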

zipline/data/fx/in_memory.py

Lines changed: 8 additions & 17 deletions
@@ -1,8 +1,10 @@
 """Interface and definitions for foreign exchange rate readers.
 """
 from interface import implements
+import numpy as np
 
 from .base import FXRateReader, DEFAULT_FX_RATE
+from .utils import check_dts
 
 
 class InMemoryFXRateReader(implements(FXRateReader)):
@@ -34,7 +36,7 @@ def get_rates(self, rate, quote, bases, dts):
 
         df = self._data[rate][quote]
 
-        self._check_dts(df.index, dts)
+        check_dts(df.index, dts)
 
         # Get raw values out of the frame.
         #
@@ -51,22 +53,11 @@ def get_rates(self, rate, quote, bases, dts):
         values = df.values
         row_ixs = df.index.searchsorted(dts, side='right') - 1
         col_ixs = df.columns.get_indexer(bases)
-        return values[row_ixs][:, col_ixs]
 
-    def _check_dts(self, stored, requested):
-        """Validate that requested dates are in bounds for what we have stored.
-        """
-        request_start, request_end = requested[[0, -1]]
-        data_start, data_end = stored[[0, -1]]
+        out = values[:, col_ixs][row_ixs]
 
-        if request_start < data_start:
-            raise ValueError(
-                "Requested fx rates starting at {}, but data starts at {}"
-                .format(request_start, data_start)
-            )
+        # Handle dates before start and unknown bases.
+        out[row_ixs == -1] = np.nan
+        out[:, col_ixs == -1] = np.nan
 
-        if request_end > data_end:
-            raise ValueError(
-                "Requested fx rates ending at {}, but data ends at {}"
-                .format(request_end, data_end)
-            )
+        return out
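Both readers compute row_ixs the same way: searchsorted(dts, side='right') - 1 is an as-of lookup, mapping each requested datetime to the most recent stored date at or before it, and to -1 for datetimes that precede the data entirely. For example:

import pandas as pd

stored = pd.DatetimeIndex(['2020-01-02', '2020-01-03', '2020-01-06'])
requested = pd.DatetimeIndex(['2020-01-01', '2020-01-03', '2020-01-05'])

row_ixs = stored.searchsorted(requested, side='right') - 1
print(row_ixs)  # [-1  1  1]: before-start maps to -1; 01-05 ffills to 01-03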

zipline/data/fx/utils.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+import numpy as np
+
+
+def check_dts(stored_dts, requested_dts):
+    """
+    Validate that ``requested_dts`` are valid for querying from an FX reader
+    that has data for ``stored_dts``.
+    """
+    request_end = requested_dts[-1]
+    data_end = stored_dts[-1]
+
+    if not is_sorted_ascending(requested_dts):
+        raise ValueError("Requested fx rates with non-ascending dts.")
+
+    if request_end > data_end:
+        raise ValueError(
+            "Requested fx rates ending at {}, but data ends at {}"
+            .format(request_end, data_end)
+        )
+
+
+def is_sorted_ascending(array):
+    return (np.maximum.accumulate(array) <= array).all()
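is_sorted_ascending compares each element against the running maximum: if any element is smaller than the maximum of everything before it, the array is out of order (ties are allowed). Note also what check_dts no longer checks: the lower-bound validation from the old reader-private _check_dts methods is deliberately dropped, since reads before the start of data now produce NaN instead of raising. A quick illustration:

import numpy as np

from zipline.data.fx.utils import is_sorted_ascending

print(is_sorted_ascending(np.array([1, 2, 2, 5])))  # True: ties are fine
print(is_sorted_ascending(np.array([1, 3, 2])))     # False: running max 3 > 2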

zipline/testing/fixtures.py

Lines changed: 8 additions & 0 deletions
@@ -2196,10 +2196,18 @@ def write_h5_fx_rates(cls, path):
     def get_expected_fx_rate_scalar(cls, rate, quote, base, dt):
         """Get the expected FX rate for the given scalar coordinates.
         """
+        if base is None:
+            return np.nan
+
         if rate == DEFAULT_FX_RATE:
             rate = cls.FX_RATES_DEFAULT_RATE
 
         col = cls.fx_rates[rate][quote][base]
+        if dt < col.index[0]:
+            return np.nan
+        elif dt > col.index[-1]:
+            raise ValueError("dt={} > max dt={}".format(dt, col.index[-1]))
+
         # PERF: We call this function a lot in some suites, and get_loc is
         # surprisingly expensive, so optimizing it has a meaningful impact on
         # overall suite performance. See test_fast_get_loc_ffilled_for
