11import itertools
22
3- import h5py
43import pandas as pd
54import numpy as np
65
76from zipline .data .fx import DEFAULT_FX_RATE
8- from zipline .data .fx .hdf5 import HDF5FXRateReader , HDF5FXRateWriter
97
108from zipline .testing .predicates import assert_equal
119import zipline .testing .fixtures as zp_fixtures
@@ -59,33 +57,6 @@ def make_fx_rates(cls, fields, currencies, sessions):
5957 'tokyo_mid' : cls .tokyo_mid_rates ,
6058 }
6159
62- @classmethod
63- def get_expected_rate_scalar (cls , rate , quote , base , dt ):
64- """Get the expected FX rate for the given scalar coordinates.
65- """
66- if rate == DEFAULT_FX_RATE :
67- rate = cls .FX_RATES_DEFAULT_RATE
68-
69- col = cls .fx_rates [rate ][quote ][base ]
70- # PERF: We call this function a lot in this suite, and get_loc is
71- # surprisingly expensive, so optimizing it has a meaningful impact on
72- # overall suite performance. See test_fast_get_loc_ffilled_for
73- # assurance that this behaves the same as get_loc.
74- ix = fast_get_loc_ffilled (col .index .values , dt .asm8 )
75- return col .values [ix ]
76-
77- @classmethod
78- def get_expected_rates (cls , rate , quote , bases , dts ):
79- """Get an array of expected FX rates for the given indices.
80- """
81- out = np .empty ((len (dts ), len (bases )), dtype = 'float64' )
82-
83- for i , dt in enumerate (dts ):
84- for j , base in enumerate (bases ):
85- out [i , j ] = cls .get_expected_rate_scalar (rate , quote , base , dt )
86-
87- return out
88-
8960 @property
9061 def reader (self ):
9162 raise NotImplementedError ("Must be implemented by test suite." )
@@ -110,7 +81,7 @@ def test_scalar_lookup(self):
11081 if quote == base :
11182 assert_equal (result_scalar , 1.0 )
11283
113- expected = self .get_expected_rate_scalar (rate , quote , base , dt )
84+ expected = self .get_expected_fx_rate_scalar (rate , quote , base , dt )
11485 assert_equal (result_scalar , expected )
11586
11687 def test_vectorized_lookup (self ):
@@ -135,7 +106,7 @@ def test_vectorized_lookup(self):
135106 # ...And check that we get the expected result when querying
136107 # for those dates/currencies.
137108 result = self .reader .get_rates (rate , quote , bases , dts )
138- expected = self .get_expected_rates (rate , quote , bases , dts )
109+ expected = self .get_expected_fx_rates (rate , quote , bases , dts )
139110
140111 assert_equal (result , expected )
141112
@@ -205,47 +176,14 @@ class HDF5FXReaderTestCase(zp_fixtures.WithTmpDir,
205176 @classmethod
206177 def init_class_fixtures (cls ):
207178 super (HDF5FXReaderTestCase , cls ).init_class_fixtures ()
208-
209179 path = cls .tmpdir .getpath ('fx_rates.h5' )
210-
211- # Set by WithFXRates.
212- sessions = cls .fx_rates_sessions
213-
214- # Write in-memory data to h5 file.
215- with h5py .File (path , 'w' ) as h5_file :
216- writer = HDF5FXRateWriter (h5_file )
217- fx_data = ((rate , quote , quote_frame .values )
218- for rate , rate_dict in cls .fx_rates .items ()
219- for quote , quote_frame in rate_dict .items ())
220-
221- writer .write (
222- dts = sessions .values ,
223- currencies = np .array (cls .FX_RATES_CURRENCIES , dtype = 'S3' ),
224- data = fx_data ,
225- )
226-
227- h5_file = cls .enter_class_context (h5py .File (path , 'r' ))
228- cls .h5_fx_reader = HDF5FXRateReader (
229- h5_file ,
230- default_rate = cls .FX_RATES_DEFAULT_RATE ,
231- )
180+ cls .h5_fx_reader = cls .write_h5_fx_rates (path )
232181
233182 @property
234183 def reader (self ):
235184 return self .h5_fx_reader
236185
237186
238- def fast_get_loc_ffilled (dts , dt ):
239- """
240- Equivalent to dts.get_loc(dt, method='ffill'), but with reasonable
241- microperformance.
242- """
243- ix = dts .searchsorted (dt , side = 'right' ) - 1
244- if ix < 0 :
245- raise KeyError (dt )
246- return ix
247-
248-
249187class FastGetLocTestCase (zp_fixtures .ZiplineTestCase ):
250188
251189 def test_fast_get_loc_ffilled (self ):
@@ -258,12 +196,12 @@ def test_fast_get_loc_ffilled(self):
258196 ])
259197
260198 for dt in pd .date_range ('2014-01-02' , '2014-01-08' ):
261- result = fast_get_loc_ffilled (dts .values , dt .asm8 )
199+ result = zp_fixtures . fast_get_loc_ffilled (dts .values , dt .asm8 )
262200 expected = dts .get_loc (dt , method = 'ffill' )
263201 assert_equal (result , expected )
264202
265203 with self .assertRaises (KeyError ):
266204 dts .get_loc (pd .Timestamp ('2014-01-01' ), method = 'ffill' )
267205
268206 with self .assertRaises (KeyError ):
269- fast_get_loc_ffilled (dts , pd .Timestamp ('2014-01-01' ))
207+ zp_fixtures . fast_get_loc_ffilled (dts , pd .Timestamp ('2014-01-01' ))
0 commit comments