@@ -53,10 +53,13 @@ def _drop_input_dims(ds, input_dims, suffix='_input'):
53
53
out .coords [dim ] = newdim , ds [dim ].data , ds [dim ].attrs
54
54
return out
55
55
56
- def _maybe_stack_batch_dims (ds , input_dims , stacked_dim_name = 'sample' ):
56
+ def _maybe_stack_batch_dims (ds , input_dims , squeeze_batch_dim , stacked_dim_name = 'sample' ):
57
57
batch_dims = [d for d in ds .dims if d not in input_dims ]
58
58
if len (batch_dims ) < 2 :
59
- return ds .expand_dims (stacked_dim_name , 0 )
59
+ if (squeeze_batch_dim ):
60
+ return ds
61
+ else :
62
+ return ds .expand_dims (stacked_dim_name , 0 )
60
63
ds_stack = ds .stack (** {stacked_dim_name : batch_dims })
61
64
# ensure correct order
62
65
dim_order = (stacked_dim_name ,) + tuple (input_dims )
@@ -89,6 +92,10 @@ class BatchGenerator:
89
92
preload_batch : bool, optional
90
93
If ``True``, each batch will be loaded into memory before reshaping /
91
94
processing, triggering any dask arrays to be computed.
95
+ squeeze_batch_dim : bool, optional
96
+ If ``False``, each batch's dataset will have a "batch" dimension of size 1
97
+ prepended to the array. This functionality is useful for interoperability
98
+ with Keras / Tensorflow.
92
99
93
100
Yields
94
101
------
@@ -104,6 +111,7 @@ def __init__(
104
111
batch_dims = {},
105
112
concat_input_dims = False ,
106
113
preload_batch = True ,
114
+ squeeze_batch_dim = True
107
115
):
108
116
109
117
self .ds = _as_xarray_dataset (ds )
@@ -113,6 +121,7 @@ def __init__(
113
121
self .batch_dims = OrderedDict (batch_dims )
114
122
self .concat_input_dims = concat_input_dims
115
123
self .preload_batch = preload_batch
124
+ self .squeeze_batch_dim = squeeze_batch_dim
116
125
117
126
def __iter__ (self ):
118
127
for ds_batch in self ._iterate_batch_dims (self .ds ):
@@ -131,11 +140,11 @@ def __iter__(self):
131
140
new_input_dims = [
132
141
dim + new_dim_suffix for dim in self .input_dims
133
142
]
134
- yield _maybe_stack_batch_dims (dsc , new_input_dims )
143
+ yield _maybe_stack_batch_dims (dsc , new_input_dims , self . squeeze_batch_dim )
135
144
else :
136
145
for ds_input in input_generator :
137
146
yield _maybe_stack_batch_dims (
138
- ds_input , list (self .input_dims )
147
+ ds_input , list (self .input_dims ), self . squeeze_batch_dim
139
148
)
140
149
141
150
def _iterate_batch_dims (self , ds ):
0 commit comments