Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions src/zipline/pipeline/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,11 @@ def _add_to_graph(self, term, parents):
self.graph.add_node(term)

for dependency in term.dependencies:
self._add_to_graph(dependency, parents)
self.graph.add_edge(dependency, term)
if dependency in self.graph:
self.graph.add_edge(dependency, term)
else:
self._add_to_graph(dependency, parents)
self.graph.add_edge(dependency, term)

parents.remove(term)

Expand Down Expand Up @@ -287,6 +290,7 @@ def __init__(self, domain, terms, start_date, end_date, min_extra_rows=0):
self.domain = domain

sessions = domain.sessions()
self.node_dict = dict(self.graph.nodes())
for term in terms.values():
self.set_extra_rows(
term,
Expand All @@ -295,13 +299,17 @@ def __init__(self, domain, terms, start_date, end_date, min_extra_rows=0):
end_date,
min_extra_rows=min_extra_rows,
)
delattr(self, "node_dict")

self._assert_all_loadable_terms_specialized_to(domain)

def set_extra_rows(self, term, all_dates, start_date, end_date, min_extra_rows):
# Specialize any loadable terms before adding extra rows.
term = maybe_specialize(term, self.domain)

if self._has_been_here_before_with_more_min_extra_rows(term, min_extra_rows):
return

# A term can require that additional extra rows beyond the minimum be
# computed. This is most often used with downsampled terms, which need
# to ensure that the first date is a computation date.
Expand All @@ -320,7 +328,7 @@ def set_extra_rows(self, term, all_dates, start_date, end_date, min_extra_rows):
)
)

self._ensure_extra_rows(term, extra_rows_for_term)
self._ensure_extra_rows(term, extra_rows_for_term, min_extra_rows)

for dependency, additional_extra_rows in term.dependencies.items():
self.set_extra_rows(
Expand Down Expand Up @@ -450,12 +458,20 @@ def extra_rows(self):

return {term: self.graph.nodes[term]["extra_rows"] for term in self.graph.nodes}

def _ensure_extra_rows(self, term, N):
def _ensure_extra_rows(self, term, N, min_extra_rows):
"""
Ensure that we're going to compute at least N extra rows of `term`.
"""
attrs = dict(self.graph.nodes())[term]
attrs = self.node_dict[term]
attrs["extra_rows"] = max(N, attrs.get("extra_rows", 0))
attrs["min_extra_rows"] = max(min_extra_rows, attrs.get("min_extra_rows", 0))

def _has_been_here_before_with_more_min_extra_rows(self, term, minimum_extra_rows):
"""
Check if the term has been visited before with the same or greater number of minimum extra rows.
"""
attrs = self.node_dict[term]
return attrs.get("min_extra_rows", -1) >= minimum_extra_rows

def mask_and_dates_for_term(self, term, root_mask_term, workspace, all_dates):
"""
Expand Down
Loading