Skip to content

Commit 660b4e8

Browse files
Joe Jevnik authored and llllllllll committed
BUG: ExprData lazily constructs empty odo_kwargs dict to preserve equality in more cases
1 parent 013cd1b commit 660b4e8

File tree

2 files changed

+78 -43 lines changed

tests/pipeline/test_blaze.py

Lines changed: 26 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2411,16 +2411,17 @@ def __str__(self):
24112411
checkpoints=BadRepr('checkpoints'),
24122412
odo_kwargs={'a': 'b'},
24132413
)),
2414-
"ExprData(expr='expr', deltas='deltas',"
2415-
" checkpoints='checkpoints', odo_kwargs={'a': 'b'})",
2414+
"ExprData(expr=expr, deltas=deltas,"
2415+
" checkpoints=checkpoints, odo_kwargs={'a': 'b'})",
24162416
)
24172417

24182418
def test_exprdata_eq(self):
24192419
dshape = 'var * {sid: int64, asof_date: datetime, value: float64}'
24202420
base_expr = bz.symbol('base', dshape)
24212421
checkpoints_expr = bz.symbol('checkpoints', dshape)
24222422

2423-
odo_kwargs = {'a': 1, 'b': 2}
2423+
# use a nested dict to emulate real call sites
2424+
odo_kwargs = {'a': {'c': 1, 'd': 2}, 'b': {'e': 3, 'f': 4}}
24242425

24252426
actual = ExprData(
24262427
expr=base_expr,
@@ -2435,18 +2436,35 @@ def test_exprdata_eq(self):
24352436
odo_kwargs=odo_kwargs,
24362437
)
24372438
self.assertEqual(actual, same)
2439+
self.assertEqual(hash(actual), hash(same))
24382440

24392441
different_obs = [
2440-
actual._replace(expr=bz.symbol('not base', dshape)),
2441-
actual._replace(expr=bz.symbol('not deltas', dshape)),
2442-
actual._replace(checkpoints=bz.symbol('not checkpoints', dshape)),
2443-
actual._replace(checkpoints=None),
2444-
actual._replace(odo_kwargs={k: ~v for k, v in odo_kwargs.items()}),
2442+
actual.replace(expr=bz.symbol('not base', dshape)),
2443+
actual.replace(expr=bz.symbol('not deltas', dshape)),
2444+
actual.replace(checkpoints=bz.symbol('not checkpoints', dshape)),
2445+
actual.replace(checkpoints=None),
2446+
actual.replace(odo_kwargs={
2447+
# invert the leaf values
2448+
ok: {ik: ~iv for ik, iv in ov.items()}
2449+
for ok, ov in odo_kwargs.items()
2450+
}),
24452451
]
24462452

24472453
for different in different_obs:
24482454
self.assertNotEqual(actual, different)
24492455

2456+
actual_with_none_odo_kwargs = actual.replace(odo_kwargs=None)
2457+
same_with_none_odo_kwargs = same.replace(odo_kwargs=None)
2458+
2459+
self.assertEqual(
2460+
actual_with_none_odo_kwargs,
2461+
same_with_none_odo_kwargs,
2462+
)
2463+
self.assertEqual(
2464+
hash(actual_with_none_odo_kwargs),
2465+
hash(same_with_none_odo_kwargs),
2466+
)
2467+
24502468
def test_blaze_loader_lookup_failure(self):
24512469
class D(DataSet):
24522470
c = Column(dtype='float64')

zipline/pipeline/loaders/blaze/core.py

Lines changed: 52 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -138,7 +138,6 @@
138138
from __future__ import division, absolute_import
139139

140140
from abc import ABCMeta, abstractproperty
141-
from collections import namedtuple
142141
from functools import partial
143142
from itertools import count
144143
import warnings
@@ -699,12 +698,7 @@ def from_blaze(expr,
699698
getdataset = op.attrgetter('dataset')
700699

701700

702-
_expr_data_base = namedtuple(
703-
'ExprData', 'expr deltas checkpoints odo_kwargs'
704-
)
705-
706-
707-
class ExprData(_expr_data_base):
701+
class ExprData(object):
708702
"""A pair of expressions and data resources. The expressions will be
709703
computed using the resources as the starting scope.
710704
@@ -719,37 +713,66 @@ class ExprData(_expr_data_base):
719713
odo_kwargs : dict, optional
720714
The keyword arguments to forward to the odo calls internally.
721715
"""
722-
def __new__(cls,
723-
expr,
724-
deltas=None,
725-
checkpoints=None,
726-
odo_kwargs=None):
727-
return super(ExprData, cls).__new__(
728-
cls,
729-
expr,
730-
deltas,
731-
checkpoints,
732-
odo_kwargs or {},
733-
)
716+
def __init__(self,
717+
expr,
718+
deltas=None,
719+
checkpoints=None,
720+
odo_kwargs=None):
721+
self.expr = expr
722+
self.deltas = deltas
723+
self.checkpoints = checkpoints
724+
self._odo_kwargs = odo_kwargs
725+
726+
def replace(self, **kwargs):
727+
base_kwargs = {
728+
'expr': self.expr,
729+
'deltas': self.deltas,
730+
'checkpoints': self.checkpoints,
731+
'odo_kwargs': self._odo_kwargs,
732+
}
733+
invalid_kwargs = set(kwargs) - set(base_kwargs)
734+
if invalid_kwargs:
735+
raise TypeError('invalid param(s): %s' % sorted(invalid_kwargs))
736+
737+
base_kwargs.update(kwargs)
738+
return type(self)(**base_kwargs)
739+
740+
def __iter__(self):
741+
yield self.expr
742+
yield self.deltas
743+
yield self.checkpoints
744+
yield self.odo_kwargs
745+
746+
@property
747+
def odo_kwargs(self):
748+
out = self._odo_kwargs
749+
if out is None:
750+
out = {}
751+
return out
734752

735753
def __repr__(self):
736754
# If the expressions have _resources() then the repr will
737755
# drive computation so we take the str here.
738-
return repr(_expr_data_base(
739-
str(self.expr),
740-
str(self.deltas),
741-
str(self.checkpoints),
742-
self.odo_kwargs,
743-
))
756+
return (
757+
'ExprData(expr=%s, deltas=%s, checkpoints=%s, odo_kwargs=%r)' % (
758+
self.expr,
759+
self.deltas,
760+
self.checkpoints,
761+
self.odo_kwargs,
762+
)
763+
)
744764

745765
@staticmethod
746766
def _expr_eq(a, b):
747767
return a is b is None or a.isidentical(b)
748768

749769
def __hash__(self):
750-
return hash(
751-
(self.expr, self.deltas, self.checkpoints, id(self.odo_kwargs))
752-
)
770+
return hash((
771+
self.expr,
772+
self.deltas,
773+
self.checkpoints,
774+
id(self._odo_kwargs),
775+
))
753776

754777
def __eq__(self, other):
755778
if not isinstance(other, ExprData):
@@ -759,15 +782,9 @@ def __eq__(self, other):
759782
self._expr_eq(self.expr, other.expr) and
760783
self._expr_eq(self.deltas, other.deltas) and
761784
self._expr_eq(self.checkpoints, other.checkpoints) and
762-
self.odo_kwargs is other.odo_kwargs
785+
self._odo_kwargs is other._odo_kwargs
763786
)
764787

765-
def __ne__(self, other):
766-
# note: ``tuple`` (inherited from ``namedtuple``) adds a ``__ne__``,
767-
# but we want to rely on ``__eq__` to use ``isidentical`` to compare
768-
# expressions
769-
return not (self == other)
770-
771788

772789
class BlazeLoader(object):
773790
"""A PipelineLoader for datasets constructed with ``from_blaze``.

0 commit comments

Comments
 (0)