Skip to content

Commit 09fb188

Browse files
author
Scott Sanderson
committed
ENH: Add currency column to EquityPricing.
1 parent d419142 commit 09fb188

File tree

4 files changed

+82
-8
lines changed

4 files changed

+82
-8
lines changed

tests/utils/test_numpy_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
from toolz import curry
2020
from toolz.curried.operator import ne
2121

22+
from zipline.testing.predicates import assert_equal
2223
from zipline.utils.functional import mapall as lazy_mapall
2324
from zipline.utils.numpy_utils import (
25+
bytes_array_to_native_str_object_array,
2426
is_float,
2527
is_int,
2628
is_datetime,
@@ -92,3 +94,13 @@ def test_is_datetime(self):
9294

9395
for bad_value in everything_but(datetime, CASES):
9496
self.assertFalse(is_datetime(bad_value))
97+
98+
99+
class ArrayUtilsTestCase(TestCase):
100+
101+
def test_bytes_array_to_native_str_object_array(self):
102+
a = array([b'abc', b'def'], dtype='S3')
103+
result = bytes_array_to_native_str_object_array(a)
104+
expected = array(['abc', 'def'], dtype=object)
105+
106+
assert_equal(result, expected)

zipline/pipeline/data/equity_pricing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Dataset representing OHLCV data.
33
"""
4-
from zipline.utils.numpy_utils import float64_dtype
4+
from zipline.utils.numpy_utils import float64_dtype, categorical_dtype
55

66
from ..domain import US_EQUITIES
77
from .dataset import Column, DataSet
@@ -17,6 +17,7 @@ class EquityPricing(DataSet):
1717
low = Column(float64_dtype)
1818
close = Column(float64_dtype)
1919
volume = Column(float64_dtype)
20+
currency = Column(categorical_dtype)
2021

2122

2223
# Backwards compat alias.

zipline/pipeline/loaders/equity_pricing_loader.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,17 @@
1616
from interface import implements
1717
from numpy import iinfo, uint32, multiply
1818

19+
from zipline.currency import MISSING_CURRENCY_CODE
1920
from zipline.data.fx import ExplodingFXRateReader
2021
from zipline.lib.adjusted_array import AdjustedArray
22+
from zipline.utils.numpy_utils import (
23+
repeat_first_axis,
24+
bytes_array_to_native_str_object_array,
25+
)
2126

2227
from .base import PipelineLoader
2328
from .utils import shift_dates
29+
from ..data.equity_pricing import EquityPricing
2430

2531
UINT32_MAX = iinfo(uint32).max
2632

@@ -81,9 +87,12 @@ def load_adjusted_array(self, domain, columns, dates, sids, mask):
8187
sessions = domain.all_sessions()
8288
shifted_dates = shift_dates(sessions, dates[0], dates[-1], shift=1)
8389

84-
colnames = [c.name for c in columns]
85-
raw_arrays = self.raw_price_reader.load_raw_arrays(
86-
colnames,
90+
ohlcv_cols, currency_cols = self._split_column_types(columns)
91+
del columns # From here on we should use ohlcv_cols or currency_cols.
92+
ohlcv_colnames = [c.name for c in ohlcv_cols]
93+
94+
raw_ohlcv_arrays = self.raw_price_reader.load_raw_arrays(
95+
ohlcv_colnames,
8796
shifted_dates[0],
8897
shifted_dates[-1],
8998
sids,
@@ -93,25 +102,40 @@ def load_adjusted_array(self, domain, columns, dates, sids, mask):
93102
# dates to load currency conversion rates to make them line up with
94103
# dates used to fetch prices.
95104
self._inplace_currency_convert(
96-
columns,
97-
raw_arrays,
105+
ohlcv_cols,
106+
raw_ohlcv_arrays,
98107
shifted_dates,
99108
sids,
100109
)
101110

102111
adjustments = self.adjustments_reader.load_pricing_adjustments(
103-
colnames,
112+
ohlcv_colnames,
104113
dates,
105114
sids,
106115
)
107116

108117
out = {}
109-
for c, c_raw, c_adjs in zip(columns, raw_arrays, adjustments):
118+
for c, c_raw, c_adjs in zip(ohlcv_cols, raw_ohlcv_arrays, adjustments):
110119
out[c] = AdjustedArray(
111120
c_raw.astype(c.dtype),
112121
c_adjs,
113122
c.missing_value,
114123
)
124+
125+
for c in currency_cols:
126+
codes_1d = bytes_array_to_native_str_object_array(
127+
self.raw_price_reader.currency_codes(sids)
128+
)
129+
# XXX: Should this just be the contract of `currency_codes`?
130+
codes_1d[codes_1d == MISSING_CURRENCY_CODE] = None
131+
132+
codes = repeat_first_axis(codes_1d, len(dates))
133+
out[c] = AdjustedArray(
134+
codes,
135+
adjustments={},
136+
missing_value=None,
137+
)
138+
115139
return out
116140

117141
@property
@@ -169,6 +193,33 @@ def _inplace_currency_convert(self, columns, arrays, dates, sids):
169193
for arr in arrays:
170194
multiply(arr, rates, out=arr)
171195

196+
def _split_column_types(self, columns):
197+
"""Split out currency columns from OHLCV columns.
198+
199+
Parameters
200+
----------
201+
columns : list[zipline.pipeline.data.BoundColumn]
202+
Columns to be loaded by ``load_adjusted_array``.
203+
204+
Returns
205+
-------
206+
ohlcv_columns : list[zipline.pipeline.data.BoundColumn]
207+
Price and volume columns from ``columns``.
208+
currency_columns : list[zipline.pipeline.data.BoundColumn]
209+
Currency code column from ``columns``, if present.
210+
"""
211+
currency_name = EquityPricing.currency.name
212+
213+
ohlcv = []
214+
currency = []
215+
for c in columns:
216+
if c.name == currency_name:
217+
currency.append(c)
218+
else:
219+
ohlcv.append(c)
220+
221+
return ohlcv, currency
222+
172223

173224
# Backwards compat alias.
174225
USEquityPricingLoader = EquityPricingLoader

zipline/utils/numpy_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
filterwarnings,
1010
)
1111

12+
import six
1213
import numpy as np
1314
from numpy import (
1415
array_equal,
@@ -502,3 +503,12 @@ def compare_datetime_arrays(x, y):
502503
"""
503504

504505
return array_equal(x.view('int64'), y.view('int64'))
506+
507+
508+
def bytes_array_to_native_str_object_array(a):
509+
"""Convert an array of dtype S to an object array containing `str`.
510+
"""
511+
if six.PY2:
512+
return a.astype(object)
513+
else:
514+
return a.astype(str).astype(object)

0 commit comments

Comments
 (0)