Skip to content

Commit 3ea5916

Browse files
author
Scott Sanderson
authored
Merge pull request #2609 from quantopian/currency-fixes
Currency Improvements
2 parents 5e61943 + 09fb188 commit 3ea5916

File tree

14 files changed

+306
-164
lines changed

14 files changed

+306
-164
lines changed

tests/data/test_daily_bars.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from toolz import merge
3737
from trading_calendars import get_calendar
3838

39+
from zipline.currency import MISSING_CURRENCY_CODE
3940
from zipline.data.bar_reader import (
4041
NoDataAfterDate,
4142
NoDataBeforeDate,
@@ -59,7 +60,7 @@
5960
expected_bar_values_2d,
6061
make_bar_data,
6162
)
62-
from zipline.testing import seconds_to_timestamp
63+
from zipline.testing import seconds_to_timestamp, powerset
6364
from zipline.testing.fixtures import (
6465
WithAssetFinder,
6566
WithBcolzEquityDailyBarReader,
@@ -522,13 +523,45 @@ def test_get_last_traded_dt(self):
522523
)
523524

524525
def test_listing_currency(self):
525-
assets = np.array(list(self.assets))
526-
# TODO: Test loading codes for missing assets.
527-
results = self.daily_bar_reader.currency_codes(assets)
528-
expected = self.make_equity_daily_bar_currency_codes(
529-
self.DAILY_BARS_TEST_QUERY_COUNTRY_CODE, assets,
526+
# Test loading on all assets.
527+
all_assets = np.array(list(self.assets))
528+
all_results = self.daily_bar_reader.currency_codes(all_assets)
529+
all_expected = self.make_equity_daily_bar_currency_codes(
530+
self.DAILY_BARS_TEST_QUERY_COUNTRY_CODE, all_assets,
530531
).values
531-
assert_equal(results, expected)
532+
assert_equal(all_results, all_expected)
533+
534+
# Check all possible subsets of assets.
535+
for indices in map(list, powerset(range(len(all_assets)))):
536+
# Empty queries aren't currently supported.
537+
if not indices:
538+
continue
539+
assets = all_assets[indices]
540+
results = self.daily_bar_reader.currency_codes(assets)
541+
expected = all_expected[indices]
542+
543+
assert_equal(results, expected)
544+
545+
def test_listing_currency_for_nonexistent_asset(self):
546+
reader = self.daily_bar_reader
547+
548+
valid_sid = max(self.assets)
549+
valid_currency = reader.currency_codes(np.array([valid_sid]))[0]
550+
invalid_sids = [-1, -2]
551+
552+
# XXX: We currently require at least one valid sid here, because the
553+
# MultiCountryDailyBarReader needs one valid sid to be able to dispatch
554+
# to a child reader. We could probably make that work, but there are no
555+
# real-world cases where we expect to get all-invalid currency queries,
556+
# so it's unclear whether we should do work to explicitly support such
557+
# queries.
558+
mixed = np.array(invalid_sids + [valid_sid])
559+
result = self.daily_bar_reader.currency_codes(mixed)
560+
expected = np.array(
561+
[MISSING_CURRENCY_CODE] * 2 + [valid_currency],
562+
dtype='S3'
563+
)
564+
assert_equal(result, expected)
532565

533566

534567
class BcolzDailyBarTestCase(WithBcolzEquityDailyBarReader, _DailyBarsTestCase):

tests/data/test_fx.py

Lines changed: 5 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import itertools
22

3-
import h5py
43
import pandas as pd
54
import numpy as np
65

76
from zipline.data.fx import DEFAULT_FX_RATE
8-
from zipline.data.fx.hdf5 import HDF5FXRateReader, HDF5FXRateWriter
97

108
from zipline.testing.predicates import assert_equal
119
import zipline.testing.fixtures as zp_fixtures
@@ -59,33 +57,6 @@ def make_fx_rates(cls, fields, currencies, sessions):
5957
'tokyo_mid': cls.tokyo_mid_rates,
6058
}
6159

62-
@classmethod
63-
def get_expected_rate_scalar(cls, rate, quote, base, dt):
64-
"""Get the expected FX rate for the given scalar coordinates.
65-
"""
66-
if rate == DEFAULT_FX_RATE:
67-
rate = cls.FX_RATES_DEFAULT_RATE
68-
69-
col = cls.fx_rates[rate][quote][base]
70-
# PERF: We call this function a lot in this suite, and get_loc is
71-
# surprisingly expensive, so optimizing it has a meaningful impact on
72-
# overall suite performance. See test_fast_get_loc_ffilled_for
73-
# assurance that this behaves the same as get_loc.
74-
ix = fast_get_loc_ffilled(col.index.values, dt.asm8)
75-
return col.values[ix]
76-
77-
@classmethod
78-
def get_expected_rates(cls, rate, quote, bases, dts):
79-
"""Get an array of expected FX rates for the given indices.
80-
"""
81-
out = np.empty((len(dts), len(bases)), dtype='float64')
82-
83-
for i, dt in enumerate(dts):
84-
for j, base in enumerate(bases):
85-
out[i, j] = cls.get_expected_rate_scalar(rate, quote, base, dt)
86-
87-
return out
88-
8960
@property
9061
def reader(self):
9162
raise NotImplementedError("Must be implemented by test suite.")
@@ -110,7 +81,7 @@ def test_scalar_lookup(self):
11081
if quote == base:
11182
assert_equal(result_scalar, 1.0)
11283

113-
expected = self.get_expected_rate_scalar(rate, quote, base, dt)
84+
expected = self.get_expected_fx_rate_scalar(rate, quote, base, dt)
11485
assert_equal(result_scalar, expected)
11586

11687
def test_vectorized_lookup(self):
@@ -135,7 +106,7 @@ def test_vectorized_lookup(self):
135106
# ...And check that we get the expected result when querying
136107
# for those dates/currencies.
137108
result = self.reader.get_rates(rate, quote, bases, dts)
138-
expected = self.get_expected_rates(rate, quote, bases, dts)
109+
expected = self.get_expected_fx_rates(rate, quote, bases, dts)
139110

140111
assert_equal(result, expected)
141112

@@ -205,47 +176,14 @@ class HDF5FXReaderTestCase(zp_fixtures.WithTmpDir,
205176
@classmethod
206177
def init_class_fixtures(cls):
207178
super(HDF5FXReaderTestCase, cls).init_class_fixtures()
208-
209179
path = cls.tmpdir.getpath('fx_rates.h5')
210-
211-
# Set by WithFXRates.
212-
sessions = cls.fx_rates_sessions
213-
214-
# Write in-memory data to h5 file.
215-
with h5py.File(path, 'w') as h5_file:
216-
writer = HDF5FXRateWriter(h5_file)
217-
fx_data = ((rate, quote, quote_frame.values)
218-
for rate, rate_dict in cls.fx_rates.items()
219-
for quote, quote_frame in rate_dict.items())
220-
221-
writer.write(
222-
dts=sessions.values,
223-
currencies=np.array(cls.FX_RATES_CURRENCIES, dtype='S3'),
224-
data=fx_data,
225-
)
226-
227-
h5_file = cls.enter_class_context(h5py.File(path, 'r'))
228-
cls.h5_fx_reader = HDF5FXRateReader(
229-
h5_file,
230-
default_rate=cls.FX_RATES_DEFAULT_RATE,
231-
)
180+
cls.h5_fx_reader = cls.write_h5_fx_rates(path)
232181

233182
@property
234183
def reader(self):
235184
return self.h5_fx_reader
236185

237186

238-
def fast_get_loc_ffilled(dts, dt):
239-
"""
240-
Equivalent to dts.get_loc(dt, method='ffill'), but with reasonable
241-
microperformance.
242-
"""
243-
ix = dts.searchsorted(dt, side='right') - 1
244-
if ix < 0:
245-
raise KeyError(dt)
246-
return ix
247-
248-
249187
class FastGetLocTestCase(zp_fixtures.ZiplineTestCase):
250188

251189
def test_fast_get_loc_ffilled(self):
@@ -258,12 +196,12 @@ def test_fast_get_loc_ffilled(self):
258196
])
259197

260198
for dt in pd.date_range('2014-01-02', '2014-01-08'):
261-
result = fast_get_loc_ffilled(dts.values, dt.asm8)
199+
result = zp_fixtures.fast_get_loc_ffilled(dts.values, dt.asm8)
262200
expected = dts.get_loc(dt, method='ffill')
263201
assert_equal(result, expected)
264202

265203
with self.assertRaises(KeyError):
266204
dts.get_loc(pd.Timestamp('2014-01-01'), method='ffill')
267205

268206
with self.assertRaises(KeyError):
269-
fast_get_loc_ffilled(dts, pd.Timestamp('2014-01-01'))
207+
zp_fixtures.fast_get_loc_ffilled(dts, pd.Timestamp('2014-01-01'))

tests/utils/test_numpy_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
from toolz import curry
2020
from toolz.curried.operator import ne
2121

22+
from zipline.testing.predicates import assert_equal
2223
from zipline.utils.functional import mapall as lazy_mapall
2324
from zipline.utils.numpy_utils import (
25+
bytes_array_to_native_str_object_array,
2426
is_float,
2527
is_int,
2628
is_datetime,
@@ -92,3 +94,13 @@ def test_is_datetime(self):
9294

9395
for bad_value in everything_but(datetime, CASES):
9496
self.assertFalse(is_datetime(bad_value))
97+
98+
99+
class ArrayUtilsTestCase(TestCase):
100+
101+
def test_bytes_array_to_native_str_object_array(self):
102+
a = array([b'abc', b'def'], dtype='S3')
103+
result = bytes_array_to_native_str_object_array(a)
104+
expected = array(['abc', 'def'], dtype=object)
105+
106+
assert_equal(result, expected)

zipline/currency.py

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,11 @@
1-
from functools import partial, total_ordering
2-
1+
from functools import total_ordering
32
from iso4217 import Currency as ISO4217Currency
43

5-
import numpy as np
6-
74
_ALL_CURRENCIES = {}
85

96

10-
def strs_to_sids(strs, category_num):
11-
"""TODO: Improve this.
12-
"""
13-
out = np.full(len(strs), category_num << 50, dtype='i8')
14-
casted_buffer = np.ndarray(
15-
shape=out.shape,
16-
dtype='S6',
17-
buffer=out,
18-
strides=out.strides,
19-
)
20-
casted_buffer[:] = np.array(strs, dtype='S6')
21-
return out
22-
23-
24-
def str_to_sid(str_, category_num):
25-
return strs_to_sids([str_], category_num)[0]
26-
27-
28-
iso_currency_to_sid = partial(str_to_sid, category_num=3)
7+
# Special sentinel used to represent unknown or missing currencies.
8+
MISSING_CURRENCY_CODE = 'XXX'
299

3010

3111
@total_ordering
@@ -48,15 +28,20 @@ def __new__(cls, code):
4828
try:
4929
return _ALL_CURRENCIES[code]
5030
except KeyError:
51-
try:
52-
iso_currency = ISO4217Currency(code)
53-
except ValueError:
54-
raise ValueError(
55-
"{!r} is not a valid currency code.".format(code)
56-
)
31+
# This isn't a real
32+
if code == MISSING_CURRENCY_CODE:
33+
name = "NO CURRENCY"
34+
else:
35+
try:
36+
name = ISO4217Currency(code).currency_name
37+
except ValueError:
38+
raise ValueError(
39+
"{!r} is not a valid currency code.".format(code)
40+
)
41+
5742
obj = _ALL_CURRENCIES[code] = super(Currency, cls).__new__(cls)
58-
obj._currency = iso_currency
59-
obj._sid = iso_currency_to_sid(iso_currency.value)
43+
obj._code = code
44+
obj._name = name
6045
return obj
6146

6247
@property
@@ -67,7 +52,7 @@ def code(self):
6752
-------
6853
code : str
6954
"""
70-
return self._currency.value
55+
return self._code
7156

7257
@property
7358
def name(self):
@@ -77,13 +62,7 @@ def name(self):
7762
-------
7863
name : str
7964
"""
80-
return self._currency.currency_name
81-
82-
@property
83-
def sid(self):
84-
"""Unique integer identifier for this currency.
85-
"""
86-
return self._sid
65+
return self._name
8766

8867
def __eq__(self, other):
8968
if type(self) != type(other):

zipline/data/bcolz_daily_bars.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from toolz import compose
3535
from trading_calendars import get_calendar
3636

37+
from zipline.currency import MISSING_CURRENCY_CODE
3738
from zipline.data.session_bars import CurrencyAwareSessionBarReader
3839
from zipline.data.bar_reader import (
3940
NoDataAfterDate,
@@ -706,5 +707,14 @@ def get_value(self, sid, dt, field):
706707
return price
707708

708709
def currency_codes(self, sids):
709-
# TODO: Better handling for this.
710-
return np.full(len(sids), b'USD', dtype='S3')
710+
# XXX: This is pretty inefficient. This reader doesn't really support
711+
# country codes, so we always either return USD or
712+
# MISSING_CURRENCY_CODE if we don't know about the sid at all.
713+
first_rows = self._first_rows
714+
out = []
715+
for sid in sids:
716+
if sid in first_rows:
717+
out.append('USD')
718+
else:
719+
out.append(MISSING_CURRENCY_CODE)
720+
return np.array(out, dtype='S3')

0 commit comments

Comments
 (0)