diff --git a/tests/pipeline/test_dispatcher.py b/tests/pipeline/test_dispatcher.py new file mode 100644 index 0000000000..6937dfd5a6 --- /dev/null +++ b/tests/pipeline/test_dispatcher.py @@ -0,0 +1,93 @@ +from zipline.pipeline.data import ( + Column, + DataSet, + BoundColumn, + USEquityPricing, +) +from zipline.pipeline.dispatcher import PipelineDispatcher +from zipline.pipeline.loaders.base import PipelineLoader +from zipline.pipeline.sentinels import NotSpecified +from zipline.testing import ZiplineTestCase +from zipline.testing.predicates import ( + assert_raises_str, + assert_equal, +) +from zipline.utils.numpy_utils import float64_dtype + + +class FakeDataSet(DataSet): + test_col = Column(float64_dtype) + + +class FakeColumn(BoundColumn): + pass + + +class FakePipelineLoader(PipelineLoader): + + def load_adjusted_array(self, columns, dates, assets, mask): + pass + + +class UnrelatedType(object): + pass + + +class PipelineDispatcherTestCase(ZiplineTestCase): + + def test_load_not_registered(self): + fake_col_instance = FakeColumn( + float64_dtype, + NotSpecified, + FakeDataSet, + 'test', + None, + {}, + ) + fake_pl_instance = FakePipelineLoader() + pipeline_dispatcher = PipelineDispatcher( + {fake_col_instance: fake_pl_instance} + ) + + expected_dict = {fake_col_instance: fake_pl_instance} + assert_equal(pipeline_dispatcher._column_loaders, expected_dict) + + msg = "No pipeline loader registered for %s" % USEquityPricing.close + with assert_raises_str(LookupError, msg): + pipeline_dispatcher(USEquityPricing.close) + + def test_register_unrelated_type(self): + fake_pl_instance = FakePipelineLoader() + + msg = "%s is neither a BoundColumn nor a DataSet" % UnrelatedType + with assert_raises_str(TypeError, msg): + PipelineDispatcher( + {UnrelatedType: fake_pl_instance} + ) + + def test_normal_ops(self): + fake_loader_instance = FakePipelineLoader() + fake_col_instance = FakeColumn( + float64_dtype, + NotSpecified, + FakeDataSet, + 'test', + None, + {}, + ) + 
pipeline_dispatcher = PipelineDispatcher({ + fake_col_instance: fake_loader_instance, + FakeDataSet: fake_loader_instance + }) + + expected_dict = { + fake_col_instance: fake_loader_instance, + FakeDataSet.test_col: fake_loader_instance, + } + assert_equal(pipeline_dispatcher._column_loaders, expected_dict) + assert_equal( + pipeline_dispatcher(fake_col_instance), fake_loader_instance + ) + assert_equal( + pipeline_dispatcher(FakeDataSet.test_col), fake_loader_instance + ) diff --git a/zipline/pipeline/__init__.py b/zipline/pipeline/__init__.py index a169256bb9..5ed5b2418c 100644 --- a/zipline/pipeline/__init__.py +++ b/zipline/pipeline/__init__.py @@ -9,6 +9,7 @@ from .graph import ExecutionPlan, TermGraph from .pipeline import Pipeline from .loaders import USEquityPricingLoader +from .dispatcher import PipelineDispatcher def engine_from_files(daily_bar_path, @@ -60,4 +61,5 @@ def engine_from_files(daily_bar_path, 'SimplePipelineEngine', 'Term', 'TermGraph', + 'PipelineDispatcher', ) diff --git a/zipline/pipeline/dispatcher.py b/zipline/pipeline/dispatcher.py new file mode 100644 index 0000000000..dbd95c45b3 --- /dev/null +++ b/zipline/pipeline/dispatcher.py @@ -0,0 +1,28 @@ +from zipline.pipeline.data import BoundColumn, DataSet + + +class PipelineDispatcher(object): + """Helper class for building a dispatching function for a PipelineLoader. + + Parameters + ---------- + loaders : dict[BoundColumn or DataSet -> PipelineLoader] + Map from columns or datasets to pipeline loader for those objects. 
+ """ + def __init__(self, loaders): + self._column_loaders = {} + for data, pl in loaders.items(): + if isinstance(data, BoundColumn): + self._column_loaders[data] = pl + elif issubclass(data, DataSet): + for c in data.columns: + self._column_loaders[c] = pl + else: + raise TypeError("%s is neither a BoundColumn " + "nor a DataSet" % data) + + def __call__(self, column): + if column in self._column_loaders: + return self._column_loaders[column] + else: + raise LookupError("No pipeline loader registered for %s" % column) diff --git a/zipline/utils/run_algo.py b/zipline/utils/run_algo.py index 80ece34b78..8f11ce47cc 100644 --- a/zipline/utils/run_algo.py +++ b/zipline/utils/run_algo.py @@ -2,7 +2,6 @@ import os import sys import warnings - try: from pygments import highlight from pygments.lexers import PythonLexer @@ -18,6 +17,7 @@ from zipline.data.loader import load_market_data from zipline.data.data_portal import DataPortal from zipline.finance import metrics +from zipline.pipeline import PipelineDispatcher from zipline.finance.trading import SimulationParameters from zipline.pipeline.data import USEquityPricing from zipline.pipeline.loaders import USEquityPricingLoader @@ -72,6 +72,7 @@ def _run(handle_data, metrics_set, local_namespace, environ, + pipeline_dispatcher, blotter, benchmark_returns): """Run a backtest for the given algorithm. @@ -155,16 +156,14 @@ def _run(handle_data, adjustment_reader=bundle_data.adjustment_reader, ) - pipeline_loader = USEquityPricingLoader( - bundle_data.equity_daily_bar_reader, - bundle_data.adjustment_reader, - ) - - def choose_loader(column): - if column in USEquityPricing.columns: - return pipeline_loader - raise ValueError( - "No PipelineLoader registered for column %s." 
% column + if pipeline_dispatcher is None: + # create the default dispatcher + pipeline_loader = USEquityPricingLoader( + bundle_data.equity_daily_bar_reader, + bundle_data.adjustment_reader, + ) + pipeline_dispatcher = PipelineDispatcher( + {USEquityPricing: pipeline_loader} ) if isinstance(metrics_set, six.string_types): @@ -181,8 +180,8 @@ def choose_loader(column): perf = TradingAlgorithm( namespace=namespace, + get_pipeline_loader=pipeline_dispatcher, data_portal=data, - get_pipeline_loader=choose_loader, trading_calendar=trading_calendar, sim_params=SimulationParameters( start_session=start, @@ -225,7 +224,7 @@ def load_extensions(default, extensions, strict, environ, reload=False): ---------- default : bool Load the default exension (~/.zipline/extension.py)? - extension : iterable[str] + extensions : iterable[str] The paths to the extensions to load. If the path ends in ``.py`` it is treated as a script and executed. If it does not end in ``.py`` it is treated as a module to be imported. @@ -285,6 +284,7 @@ def run_algorithm(start, extensions=(), strict_extensions=True, environ=os.environ, + pipeline_dispatcher=None, blotter='default'): """ Run a trading algorithm. @@ -338,6 +338,9 @@ def run_algorithm(start, environ : mapping[str -> str], optional The os environment to use. Many extensions use this to get parameters. This defaults to ``os.environ``. + pipeline_dispatcher : PipelineDispatcher, optional + The pipeline dispatcher to use, which should contain any column-to- + loader associations necessary to run the trading algorithm. blotter : str or zipline.finance.blotter.Blotter, optional Blotter to use with this algorithm. If passed as a string, we look for a blotter construction function registered with @@ -376,6 +379,7 @@ def run_algorithm(start, metrics_set=metrics_set, local_namespace=False, environ=environ, + pipeline_dispatcher=pipeline_dispatcher, blotter=blotter, benchmark_returns=benchmark_returns, )