Skip to content

Commit d419142

Browse files
author
Scott Sanderson
committed
MAINT: Move expected rate lookup functions into fixture for reuse.
1 parent 2c54c35 commit d419142

File tree

2 files changed

+45
-42
lines changed

2 files changed

+45
-42
lines changed

tests/data/test_fx.py

Lines changed: 4 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -57,33 +57,6 @@ def make_fx_rates(cls, fields, currencies, sessions):
5757
'tokyo_mid': cls.tokyo_mid_rates,
5858
}
5959

60-
@classmethod
61-
def get_expected_rate_scalar(cls, rate, quote, base, dt):
62-
"""Get the expected FX rate for the given scalar coordinates.
63-
"""
64-
if rate == DEFAULT_FX_RATE:
65-
rate = cls.FX_RATES_DEFAULT_RATE
66-
67-
col = cls.fx_rates[rate][quote][base]
68-
# PERF: We call this function a lot in this suite, and get_loc is
69-
# surprisingly expensive, so optimizing it has a meaningful impact on
70-
# overall suite performance. See test_fast_get_loc_ffilled_for
71-
# assurance that this behaves the same as get_loc.
72-
ix = fast_get_loc_ffilled(col.index.values, dt.asm8)
73-
return col.values[ix]
74-
75-
@classmethod
76-
def get_expected_rates(cls, rate, quote, bases, dts):
77-
"""Get an array of expected FX rates for the given indices.
78-
"""
79-
out = np.empty((len(dts), len(bases)), dtype='float64')
80-
81-
for i, dt in enumerate(dts):
82-
for j, base in enumerate(bases):
83-
out[i, j] = cls.get_expected_rate_scalar(rate, quote, base, dt)
84-
85-
return out
86-
8760
@property
8861
def reader(self):
8962
raise NotImplementedError("Must be implemented by test suite.")
@@ -108,7 +81,7 @@ def test_scalar_lookup(self):
10881
if quote == base:
10982
assert_equal(result_scalar, 1.0)
11083

111-
expected = self.get_expected_rate_scalar(rate, quote, base, dt)
84+
expected = self.get_expected_fx_rate_scalar(rate, quote, base, dt)
11285
assert_equal(result_scalar, expected)
11386

11487
def test_vectorized_lookup(self):
@@ -133,7 +106,7 @@ def test_vectorized_lookup(self):
133106
# ...And check that we get the expected result when querying
134107
# for those dates/currencies.
135108
result = self.reader.get_rates(rate, quote, bases, dts)
136-
expected = self.get_expected_rates(rate, quote, bases, dts)
109+
expected = self.get_expected_fx_rates(rate, quote, bases, dts)
137110

138111
assert_equal(result, expected)
139112

@@ -211,17 +184,6 @@ def reader(self):
211184
return self.h5_fx_reader
212185

213186

214-
def fast_get_loc_ffilled(dts, dt):
215-
"""
216-
Equivalent to dts.get_loc(dt, method='ffill'), but with reasonable
217-
microperformance.
218-
"""
219-
ix = dts.searchsorted(dt, side='right') - 1
220-
if ix < 0:
221-
raise KeyError(dt)
222-
return ix
223-
224-
225187
class FastGetLocTestCase(zp_fixtures.ZiplineTestCase):
226188

227189
def test_fast_get_loc_ffilled(self):
@@ -234,12 +196,12 @@ def test_fast_get_loc_ffilled(self):
234196
])
235197

236198
for dt in pd.date_range('2014-01-02', '2014-01-08'):
237-
result = fast_get_loc_ffilled(dts.values, dt.asm8)
199+
result = zp_fixtures.fast_get_loc_ffilled(dts.values, dt.asm8)
238200
expected = dts.get_loc(dt, method='ffill')
239201
assert_equal(result, expected)
240202

241203
with self.assertRaises(KeyError):
242204
dts.get_loc(pd.Timestamp('2014-01-01'), method='ffill')
243205

244206
with self.assertRaises(KeyError):
245-
fast_get_loc_ffilled(dts, pd.Timestamp('2014-01-01'))
207+
zp_fixtures.fast_get_loc_ffilled(dts, pd.Timestamp('2014-01-01'))

zipline/testing/fixtures.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from zipline.algorithm import TradingAlgorithm
2121
from zipline.assets import Equity, Future
2222
from zipline.assets.continuous_futures import CHAIN_PREDICATES
23+
from zipline.data.fx import DEFAULT_FX_RATE
2324
from zipline.finance.asset_restrictions import NoRestrictions
2425
from zipline.utils.memoize import classlazyval
2526
from zipline.pipeline import SimplePipelineEngine
@@ -2191,3 +2192,43 @@ def write_h5_fx_rates(cls, path):
21912192
h5_file,
21922193
default_rate=cls.FX_RATES_DEFAULT_RATE,
21932194
)
2195+
2196+
@classmethod
2197+
def get_expected_fx_rate_scalar(cls, rate, quote, base, dt):
2198+
"""Get the expected FX rate for the given scalar coordinates.
2199+
"""
2200+
if rate == DEFAULT_FX_RATE:
2201+
rate = cls.FX_RATES_DEFAULT_RATE
2202+
2203+
col = cls.fx_rates[rate][quote][base]
2204+
# PERF: We call this function a lot in some suites, and get_loc is
2205+
# surprisingly expensive, so optimizing it has a meaningful impact on
2206+
# overall suite performance. See test_fast_get_loc_ffilled_for
2207+
# assurance that this behaves the same as get_loc.
2208+
ix = fast_get_loc_ffilled(col.index.values, dt.asm8)
2209+
return col.values[ix]
2210+
2211+
@classmethod
2212+
def get_expected_fx_rates(cls, rate, quote, bases, dts):
2213+
"""Get an array of expected FX rates for the given indices.
2214+
"""
2215+
out = np.empty((len(dts), len(bases)), dtype='float64')
2216+
2217+
for i, dt in enumerate(dts):
2218+
for j, base in enumerate(bases):
2219+
out[i, j] = cls.get_expected_fx_rate_scalar(
2220+
rate, quote, base, dt,
2221+
)
2222+
2223+
return out
2224+
2225+
2226+
def fast_get_loc_ffilled(dts, dt):
2227+
"""
2228+
Equivalent to dts.get_loc(dt, method='ffill'), but with reasonable
2229+
microperformance.
2230+
"""
2231+
ix = dts.searchsorted(dt, side='right') - 1
2232+
if ix < 0:
2233+
raise KeyError(dt)
2234+
return ix

0 commit comments

Comments
 (0)