1616from interface import implements
1717from numpy import iinfo , uint32 , multiply
1818
19+ from zipline .currency import MISSING_CURRENCY_CODE
1920from zipline .data .fx import ExplodingFXRateReader
2021from zipline .lib .adjusted_array import AdjustedArray
22+ from zipline .utils .numpy_utils import (
23+ repeat_first_axis ,
24+ bytes_array_to_native_str_object_array ,
25+ )
2126
2227from .base import PipelineLoader
2328from .utils import shift_dates
29+ from ..data .equity_pricing import EquityPricing
2430
2531UINT32_MAX = iinfo (uint32 ).max
2632
@@ -81,9 +87,12 @@ def load_adjusted_array(self, domain, columns, dates, sids, mask):
8187 sessions = domain .all_sessions ()
8288 shifted_dates = shift_dates (sessions , dates [0 ], dates [- 1 ], shift = 1 )
8389
84- colnames = [c .name for c in columns ]
85- raw_arrays = self .raw_price_reader .load_raw_arrays (
86- colnames ,
90+ ohlcv_cols , currency_cols = self ._split_column_types (columns )
91+ del columns # From here on we should use ohlcv_cols or currency_cols.
92+ ohlcv_colnames = [c .name for c in ohlcv_cols ]
93+
94+ raw_ohlcv_arrays = self .raw_price_reader .load_raw_arrays (
95+ ohlcv_colnames ,
8796 shifted_dates [0 ],
8897 shifted_dates [- 1 ],
8998 sids ,
@@ -93,25 +102,40 @@ def load_adjusted_array(self, domain, columns, dates, sids, mask):
93102 # dates to load currency conversion rates to make them line up with
94103 # dates used to fetch prices.
95104 self ._inplace_currency_convert (
96- columns ,
97- raw_arrays ,
105+ ohlcv_cols ,
106+ raw_ohlcv_arrays ,
98107 shifted_dates ,
99108 sids ,
100109 )
101110
102111 adjustments = self .adjustments_reader .load_pricing_adjustments (
103- colnames ,
112+ ohlcv_colnames ,
104113 dates ,
105114 sids ,
106115 )
107116
108117 out = {}
109- for c , c_raw , c_adjs in zip (columns , raw_arrays , adjustments ):
118+ for c , c_raw , c_adjs in zip (ohlcv_cols , raw_ohlcv_arrays , adjustments ):
110119 out [c ] = AdjustedArray (
111120 c_raw .astype (c .dtype ),
112121 c_adjs ,
113122 c .missing_value ,
114123 )
124+
125+ for c in currency_cols :
126+ codes_1d = bytes_array_to_native_str_object_array (
127+ self .raw_price_reader .currency_codes (sids )
128+ )
129+ # XXX: Should this just be the contract of `currency_codes`?
130+ codes_1d [codes_1d == MISSING_CURRENCY_CODE ] = None
131+
132+ codes = repeat_first_axis (codes_1d , len (dates ))
133+ out [c ] = AdjustedArray (
134+ codes ,
135+ adjustments = {},
136+ missing_value = None ,
137+ )
138+
115139 return out
116140
117141 @property
@@ -169,6 +193,33 @@ def _inplace_currency_convert(self, columns, arrays, dates, sids):
169193 for arr in arrays :
170194 multiply (arr , rates , out = arr )
171195
196+ def _split_column_types (self , columns ):
197+ """Split out currency columns from OHLCV columns.
198+
199+ Parameters
200+ ----------
201+ columns : list[zipline.pipeline.data.BoundColumn]
202+ Columns to be loaded by ``load_adjusted_array``.
203+
204+ Returns
205+ -------
206+ ohlcv_columns : list[zipline.pipeline.data.BoundColumn]
207+ Price and volume columns from ``columns``.
208+ currency_columns : list[zipline.pipeline.data.BoundColumn]
209+ Currency code column from ``columns``, if present.
210+ """
211+ currency_name = EquityPricing .currency .name
212+
213+ ohlcv = []
214+ currency = []
215+ for c in columns :
216+ if c .name == currency_name :
217+ currency .append (c )
218+ else :
219+ ohlcv .append (c )
220+
221+ return ohlcv , currency
222+
172223
173224# Backwards compat alias.
174225USEquityPricingLoader = EquityPricingLoader
0 commit comments