ENH: Add extension point for dataset-loader associations #2246
@@ -0,0 +1,140 @@
from zipline.pipeline import USEquityPricingLoader
from zipline.pipeline.data import (
    Column,
    DataSet,
    BoundColumn,
    USEquityPricing,
)
from zipline.pipeline.dispatcher import (
    PipelineDispatcher,
    clear_all_associations,
)
from zipline.pipeline.loaders.base import PipelineLoader
from zipline.pipeline.sentinels import NotSpecified
from zipline.testing import ZiplineTestCase
from zipline.testing.fixtures import WithAdjustmentReader
from zipline.testing.predicates import (
    assert_raises_str,
    assert_equal,
)
from zipline.utils.numpy_utils import float64_dtype


class FakeDataSet(DataSet):
    test_col = Column(float64_dtype)


class FakeColumn(BoundColumn):
    pass


class FakePipelineLoader(PipelineLoader):

    def load_adjusted_array(self, columns, dates, assets, mask):
        pass


class UnrelatedType(object):
    pass


class PipelineDispatcherTestCase(WithAdjustmentReader, ZiplineTestCase):

    @classmethod
    def init_class_fixtures(cls):
        super(PipelineDispatcherTestCase, cls).init_class_fixtures()
        cls.default_pipeline_loader = USEquityPricingLoader(
            cls.bcolz_equity_daily_bar_reader,
            cls.adjustment_reader,
        )

        cls.add_class_callback(clear_all_associations)

    def test_load_not_registered(self):
        fake_col_instance = FakeColumn(
            float64_dtype,
            NotSpecified,
            FakeDataSet,
            'test',
            None,
            {},
        )
        fake_pl_instance = FakePipelineLoader()
        pipeline_dispatcher = PipelineDispatcher(
            column_loaders={fake_col_instance: fake_pl_instance}
        )

        expected_dict = {fake_col_instance: fake_pl_instance}
        assert_equal(pipeline_dispatcher.column_loaders, expected_dict)

        msg = "No pipeline loader registered for %s" % USEquityPricing.close
        with assert_raises_str(LookupError, msg):
            pipeline_dispatcher(USEquityPricing.close)

    def test_register_unrelated_type(self):
        pipeline_dispatcher = PipelineDispatcher()
        fake_pl_instance = FakePipelineLoader()

        msg = "Data provided is neither a BoundColumn nor a DataSet"
        with assert_raises_str(TypeError, msg):
            pipeline_dispatcher.register(UnrelatedType, fake_pl_instance)

    def test_passive_registration(self):
        pipeline_dispatcher = PipelineDispatcher()
        assert_equal(pipeline_dispatcher.column_loaders, {})

        # imitate user registering a custom pipeline loader first
        custom_loader = FakePipelineLoader()
        pipeline_dispatcher.register(USEquityPricing.close, custom_loader)
        expected_dict = {USEquityPricing.close: custom_loader}
        assert_equal(pipeline_dispatcher.column_loaders, expected_dict)

        # now check that trying to register something else won't change it
        pipeline_dispatcher.register(
            USEquityPricing.close, self.default_pipeline_loader
        )
        assert_equal(pipeline_dispatcher.column_loaders, expected_dict)

    def test_normal_ops(self):
        fake_loader_instance = FakePipelineLoader()
        fake_col_instance = FakeColumn(
            float64_dtype,
            NotSpecified,
            FakeDataSet,
            'test',
            None,
            {},
        )
        pipeline_dispatcher = PipelineDispatcher(
            column_loaders={
                fake_col_instance: fake_loader_instance
            },
            dataset_loaders={
                FakeDataSet: fake_loader_instance
            }
        )
        expected_dict = {
            fake_col_instance: fake_loader_instance,
            FakeDataSet.test_col: fake_loader_instance,
        }
        assert_equal(pipeline_dispatcher.column_loaders, expected_dict)

        pipeline_dispatcher.register(
            USEquityPricing.close, fake_loader_instance
        )
        expected_dict = {
            fake_col_instance: fake_loader_instance,
            FakeDataSet.test_col: fake_loader_instance,
            USEquityPricing.close: fake_loader_instance,
        }
        assert_equal(pipeline_dispatcher.column_loaders, expected_dict)

        assert_equal(
            pipeline_dispatcher(fake_col_instance), fake_loader_instance
        )
        assert_equal(
            pipeline_dispatcher(FakeDataSet.test_col), fake_loader_instance
        )
        assert_equal(
            pipeline_dispatcher(USEquityPricing.close), fake_loader_instance
        )
@@ -0,0 +1,63 @@
from zipline.pipeline.data import BoundColumn, DataSet
from zipline.pipeline.loaders.base import PipelineLoader
from zipline.utils.compat import mappingproxy


class PipelineDispatcher(object):
    """Helper class for building a dispatching function for a PipelineLoader.

    Parameters
    ----------
    column_loaders : dict[BoundColumn -> PipelineLoader]
        Map from columns to pipeline loader for those columns.
    dataset_loaders : dict[DataSet -> PipelineLoader]
        Map from datasets to pipeline loader for those datasets.
    """
    def __init__(self, column_loaders=None, dataset_loaders=None):
        self._column_loaders = (
            column_loaders if column_loaders is not None else {}
        )
        self.column_loaders = mappingproxy(self._column_loaders)
        if dataset_loaders is not None:
            for dataset, pl in dataset_loaders.items():
                self.register(dataset, pl)

    def __call__(self, column):
        if column in self._column_loaders:

Contributor (on the `if column in self._column_loaders:` check): It's generally more idiomatic in Python to write this in the EAFP style (try/except) rather than as an explicit membership test; c.f. https://blogs.msdn.microsoft.com/pythonengineering/2016/06/29/idiomatic-python-eafp-versus-lbyl/, for example. A sketch of that suggestion follows the method below.

            return self._column_loaders[column]
        else:
            raise LookupError("No pipeline loader registered for %s" % column)
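
A minimal sketch of the EAFP variant that the comment appears to suggest; the try/except shape here is an assumption, not code from this diff:

```python
# Assumed EAFP-style rewrite of PipelineDispatcher.__call__ (not part of
# this PR): attempt the lookup and translate KeyError, rather than testing
# membership first.
def __call__(self, column):
    try:
        return self._column_loaders[column]
    except KeyError:
        raise LookupError("No pipeline loader registered for %s" % column)
```

The behavior would be the same either way: a missing column still surfaces as a LookupError naming the column.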

    def register(self, data, pl):
        """Register a given PipelineLoader to a column or columns of a dataset

        Parameters
        ----------
        data : BoundColumn or DataSet
            The column or dataset for which to register the PipelineLoader
        pl : PipelineLoader
            The PipelineLoader to register for the column or dataset columns
        """
        assert isinstance(pl, PipelineLoader)

        # make it so that in either case nothing will happen if the column is
        # already registered, allowing users to register their own loaders
        # early on in extensions
        if isinstance(data, BoundColumn):
            if data not in self._column_loaders:
                self._column_loaders[data] = pl
        elif issubclass(data, DataSet):
            for c in data.columns:
                if c not in self._column_loaders:
                    self._column_loaders[c] = pl
        else:
            raise TypeError("Data provided is neither a BoundColumn "
                            "nor a DataSet")

    def clear(self):
        """Unregisters all dataset-loader associations"""
        self._column_loaders.clear()


global_pipeline_dispatcher = PipelineDispatcher()
register_pipeline_loader = global_pipeline_dispatcher.register
clear_all_associations = global_pipeline_dispatcher.clear
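
Together with the run-time change below (where the default USEquityPricing loader is registered last), `register_pipeline_loader` is the extension point named in the PR title: because `register` never overwrites an existing association, a loader registered early from a user extension takes precedence. A minimal sketch of such an extension, assuming hypothetical user-defined names `MyDataSet`, `my_col`, and `MyCustomLoader`:

```python
# Hypothetical user extension module; MyDataSet, my_col, and MyCustomLoader
# are names assumed for illustration only.
from zipline.pipeline import register_pipeline_loader
from zipline.pipeline.data import Column, DataSet
from zipline.pipeline.loaders.base import PipelineLoader
from zipline.utils.numpy_utils import float64_dtype


class MyDataSet(DataSet):
    my_col = Column(float64_dtype)


class MyCustomLoader(PipelineLoader):
    def load_adjusted_array(self, columns, dates, assets, mask):
        # A real loader would return an AdjustedArray per requested column;
        # omitted in this sketch.
        pass


# Associates every column of MyDataSet with the custom loader. Because
# registration is passive, the default loader registered later in _run
# will not overwrite these associations.
register_pipeline_loader(MyDataSet, MyCustomLoader())
```
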
@@ -5,6 +5,7 @@
 import warnings

 import click

 try:
     from pygments import highlight
     from pygments.lexers import PythonLexer

@@ -21,6 +22,10 @@
 from zipline.data.data_portal import DataPortal
 from zipline.finance import metrics
 from zipline.finance.trading import TradingEnvironment
+from zipline.pipeline import (
+    register_pipeline_loader,
+    global_pipeline_dispatcher,
+)
 from zipline.pipeline.data import USEquityPricing
 from zipline.pipeline.loaders import USEquityPricingLoader
 from zipline.utils.factory import create_simulation_parameters

@@ -166,12 +171,13 @@ def _run(handle_data,
             bundle_data.adjustment_reader,
         )

+        # we register our default loader last, after any loaders from users
+        # have been registered via extensions
+        register_pipeline_loader(USEquityPricing, pipeline_loader)
+
         def choose_loader(column):
-            if column in USEquityPricing.columns:
-                return pipeline_loader
-            raise ValueError(
-                "No PipelineLoader registered for column %s." % column
-            )
+            return global_pipeline_dispatcher(column)
     else:
         env = TradingEnvironment(environ=environ)
         choose_loader = None

Review comment: I haven't looked carefully at these tests yet, but in general, it's not expected that anyone should ever construct BoundColumn instances explicitly. The usual way to get a bound column is to do:
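Presumably this means attribute access on a DataSet subclass, which is the pattern the diff itself uses for `USEquityPricing.close` and `FakeDataSet.test_col`; a short sketch (the assert is illustrative, not taken from the PR):

```python
# Columns are normally obtained via attribute access on a DataSet subclass,
# which yields a BoundColumn already bound to that dataset; calling
# BoundColumn(...) directly is not the expected usage.
from zipline.pipeline.data import BoundColumn, USEquityPricing

close_column = USEquityPricing.close
assert isinstance(close_column, BoundColumn)
```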