Skip to content

Commit 650c23c

Browse files
Joe Jevnikllllllllll
authored andcommitted
BUG: group exprdata by the expressions, not identity
1 parent 6975deb commit 650c23c

File tree

2 files changed

+56
-6
lines changed

2 files changed

+56
-6
lines changed

tests/pipeline/test_blaze.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2415,6 +2415,38 @@ def __str__(self):
24152415
" checkpoints='checkpoints', odo_kwargs={'a': 'b'})",
24162416
)
24172417

2418+
def test_exprdata_eq(self):
2419+
dshape = 'var * {sid: int64, asof_date: datetime, value: float64}'
2420+
base_expr = bz.symbol('base', dshape)
2421+
checkpoints_expr = bz.symbol('checkpoints', dshape)
2422+
2423+
odo_kwargs = {'a': 1, 'b': 2}
2424+
2425+
actual = ExprData(
2426+
expr=base_expr,
2427+
deltas=None,
2428+
checkpoints=checkpoints_expr,
2429+
odo_kwargs=odo_kwargs,
2430+
)
2431+
same = ExprData(
2432+
expr=base_expr,
2433+
deltas=None,
2434+
checkpoints=checkpoints_expr,
2435+
odo_kwargs=odo_kwargs,
2436+
)
2437+
self.assertEqual(actual, same)
2438+
2439+
different_obs = [
2440+
actual._replace(expr=bz.symbol('not base', dshape)),
2441+
actual._replace(expr=bz.symbol('not deltas', dshape)),
2442+
actual._replace(checkpoints=bz.symbol('not checkpoints', dshape)),
2443+
actual._replace(checkpoints=None),
2444+
actual._replace(odo_kwargs={k: ~v for k, v in odo_kwargs.items()}),
2445+
]
2446+
2447+
for different in different_obs:
2448+
self.assertNotEqual(actual, different)
2449+
24182450
def test_blaze_loader_lookup_failure(self):
24192451
class D(DataSet):
24202452
c = Column(dtype='float64')

zipline/pipeline/loaders/blaze/core.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -729,25 +729,42 @@ def __new__(cls,
729729
expr,
730730
deltas,
731731
checkpoints,
732-
odo_kwargs or {},
732+
frozenset((odo_kwargs or {}).items()),
733733
)
734734

735735
def __repr__(self):
736736
# If the expressions have _resources() then the repr will
737737
# drive computation so we take the str here.
738-
cls = type(self)
739-
return super(ExprData, cls).__repr__(cls(
738+
return repr(_expr_data_base(
740739
str(self.expr),
741740
str(self.deltas),
742741
str(self.checkpoints),
743-
self.odo_kwargs,
742+
dict(self.odo_kwargs),
744743
))
745744

745+
@staticmethod
746+
def _expr_eq(a, b):
747+
return a is b is None or a.isidentical(b)
748+
746749
def __hash__(self):
747-
return id(self)
750+
return super(ExprData, self).__hash__()
748751

749752
def __eq__(self, other):
750-
return self is other
753+
if not isinstance(other, ExprData):
754+
return NotImplemented
755+
756+
return (
757+
self._expr_eq(self.expr, other.expr) and
758+
self._expr_eq(self.deltas, other.deltas) and
759+
self._expr_eq(self.checkpoints, other.checkpoints) and
760+
self.odo_kwargs == other.odo_kwargs
761+
)
762+
763+
def __ne__(self, other):
764+
# note: ``tuple`` (inherited from ``namedtuple``) adds a ``__ne__``,
765+
# but we want to rely on ``__eq__` to use ``isidentical`` to compare
766+
# expressions
767+
return not (self == other)
751768

752769

753770
class BlazeLoader(object):
@@ -907,6 +924,7 @@ def _load_dataset(self,
907924
)
908925

909926
expr, deltas, checkpoints, odo_kwargs = expr_data
927+
odo_kwargs = dict(odo_kwargs)
910928

911929
have_sids = (first(columns).dataset.ndim == 2)
912930
added_query_fields = {AD_FIELD_NAME, TS_FIELD_NAME} | (

0 commit comments

Comments
 (0)