Skip to content

Commit 25510b8

Browse files
author
Andrew Daniels
authored
Merge pull request #2420 from quantopian/adjusted-array-map
ENH: Add AdjustedArray.map_labels
2 parents 650c23c + ffa1023 commit 25510b8

File tree

3 files changed

+65
-7
lines changed

3 files changed

+65
-7
lines changed

tests/pipeline/test_adjusted_array.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,3 +707,41 @@ def test_inspect(self):
707707
)
708708
got = adj_array.inspect()
709709
self.assertEqual(expected, got)
710+
711+
def test_update_labels(self):
712+
data = array([
713+
['aaa', 'bbb', 'ccc'],
714+
['ddd', 'eee', 'fff'],
715+
['ggg', 'hhh', 'iii'],
716+
['jjj', 'kkk', 'lll'],
717+
['mmm', 'nnn', 'ooo'],
718+
])
719+
label_array = LabelArray(data, missing_value='')
720+
721+
adj_array = AdjustedArray(
722+
data=label_array,
723+
adjustments={4: [ObjectOverwrite(2, 3, 0, 0, 'ppp')]},
724+
missing_value='',
725+
)
726+
727+
expected_data = array([
728+
['aaa-foo', 'bbb-foo', 'ccc-foo'],
729+
['ddd-foo', 'eee-foo', 'fff-foo'],
730+
['ggg-foo', 'hhh-foo', 'iii-foo'],
731+
['jjj-foo', 'kkk-foo', 'lll-foo'],
732+
['mmm-foo', 'nnn-foo', 'ooo-foo'],
733+
])
734+
expected_label_array = LabelArray(expected_data, missing_value='')
735+
736+
expected_adj_array = AdjustedArray(
737+
data=expected_label_array,
738+
adjustments={4: [ObjectOverwrite(2, 3, 0, 0, 'ppp-foo')]},
739+
missing_value='',
740+
)
741+
742+
adj_array.update_labels(lambda x: x + '-foo')
743+
744+
# Check that the mapped AdjustedArray has the expected baseline
745+
# values and adjustment values.
746+
check_arrays(adj_array.data, expected_adj_array.data)
747+
self.assertEqual(adj_array.adjustments, expected_adj_array.adjustments)

zipline/lib/adjusted_array.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
uint32,
1414
uint8,
1515
)
16+
from six import iteritems
1617
from zipline.errors import (
1718
WindowLengthNotPositive,
1819
WindowLengthTooLong,
@@ -164,7 +165,7 @@ def __init__(self, data, adjustments, missing_value):
164165
self.adjustments = adjustments
165166
self.missing_value = missing_value
166167

167-
@lazyval
168+
@property
168169
def data(self):
169170
"""
170171
The data stored in this array.
@@ -237,6 +238,25 @@ def inspect(self):
237238
adjustments=self.adjustments,
238239
)
239240

241+
def update_labels(self, func):
242+
"""
243+
Map a function over baseline and adjustment values in place.
244+
245+
Note that the baseline data values must be a LabelArray.
246+
"""
247+
if not isinstance(self.data, LabelArray):
248+
raise TypeError(
249+
'update_labels only supported if data is of type LabelArray.'
250+
)
251+
252+
# Map the baseline values.
253+
self._data = self._data.map(func)
254+
255+
# Map each of the adjustments.
256+
for _, row_adjustments in iteritems(self.adjustments):
257+
for adjustment in row_adjustments:
258+
adjustment.value = func(adjustment.value)
259+
240260

241261
def ensure_adjusted_array(ndarray_or_adjusted_array, missing_value):
242262
if isinstance(ndarray_or_adjusted_array, AdjustedArray):

zipline/lib/adjustment.pxd

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ cdef class Float64Adjustment(Adjustment):
4242
"""
4343
Base class for adjustments that operate on Float64 data.
4444
"""
45-
cdef readonly np.float64_t value
45+
cdef public np.float64_t value
4646

4747

4848
cdef class Float64Multiply(Float64Adjustment):
@@ -146,7 +146,7 @@ cdef class Float641DArrayOverwrite(ArrayAdjustment):
146146
[ 4., 16., 17., 18., 19.],
147147
[ 20., 21., 22., 23., 24.]])
148148
"""
149-
cdef readonly np.float64_t[:] values
149+
cdef public np.float64_t[:] values
150150
cpdef mutate(self, np.float64_t[:, :] data)
151151

152152

@@ -184,7 +184,7 @@ cdef class Datetime641DArrayOverwrite(ArrayAdjustment):
184184
[False, False, False],
185185
[False, True, True]], dtype=bool)
186186
"""
187-
cdef readonly np.int64_t[:] values
187+
cdef public np.int64_t[:] values
188188
cpdef mutate(self, np.int64_t[:, :] data)
189189

190190

@@ -226,7 +226,7 @@ cdef class _Int64Adjustment(Adjustment):
226226
This is private because we never actually operate on integers as data, but
227227
we use integer arrays to represent datetime and timedelta data.
228228
"""
229-
cdef readonly np.int64_t value
229+
cdef public np.int64_t value
230230

231231

232232
cdef class Int64Overwrite(_Int64Adjustment):
@@ -312,7 +312,7 @@ cdef class _ObjectAdjustment(Adjustment):
312312
We use only this for categorical data, where our data buffer is an array of
313313
indices into an array of unique Python string objects.
314314
"""
315-
cdef readonly object value
315+
cdef public object value
316316

317317

318318
cdef class ObjectOverwrite(_ObjectAdjustment):
@@ -330,7 +330,7 @@ cdef class BooleanAdjustment(Adjustment):
330330
instead we work with uint8 values everywhere, and we do validation/coercion
331331
at API boundaries.
332332
"""
333-
cdef readonly np.uint8_t value
333+
cdef public np.uint8_t value
334334

335335

336336
cdef class BooleanOverwrite(BooleanAdjustment):

0 commit comments

Comments
 (0)