@@ -51,7 +51,7 @@ def approx(
5151nps_tf .dtype64 = tf .float64
5252
5353
54- @pytest .fixture (params = [nps_torch , nps_tf ], scope = "module" )
54+ @pytest .fixture (params = [nps_tf , nps_torch ], scope = "module" )
5555def nps (request ):
5656 return request .param
5757
@@ -64,14 +64,17 @@ def generate_data(
6464 n_context = 5 ,
6565 n_target = 7 ,
6666 binary = False ,
67+ dtype = None ,
6768):
68- xc = B .randn (nps .dtype , batch_size , dim_x , n_context )
69- yc = B .randn (nps .dtype , batch_size , dim_y , n_context )
70- xt = B .randn (nps .dtype , batch_size , dim_x , n_target )
71- yt = B .randn (nps .dtype , batch_size , dim_y , n_target )
69+ if dtype is None :
70+ dtype = nps .dtype
71+ xc = B .randn (dtype , batch_size , dim_x , n_context )
72+ yc = B .randn (dtype , batch_size , dim_y , n_context )
73+ xt = B .randn (dtype , batch_size , dim_x , n_target )
74+ yt = B .randn (dtype , batch_size , dim_y , n_target )
7275 if binary :
73- yc = B .cast (nps . dtype , yc >= 0 )
74- yt = B .cast (nps . dtype , yt >= 0 )
76+ yc = B .cast (dtype , yc >= 0 )
77+ yt = B .cast (dtype , yt >= 0 )
7578 return xc , yc , xt , yt
7679
7780
0 commit comments