Skip to content

Commit c1e2f9e

Browse files
tomvdw authored and The TensorFlow Datasets Authors committed
Use sequential keys to make keys distribute better
PiperOrigin-RevId: 703494729
1 parent d2ad852 commit c1e2f9e

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

tensorflow_datasets/datasets/smart_buildings/smart_buildings_dataset_builder.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
# The years in the dataset.
3232
YEARS = [19, 20, 21, 22, 23, 24]
3333

34+
_MAX_EXAMPLES_PER_DAY = 24 * 12 # 288 per day
35+
3436
_REWARD_RESPONSES = [
3537
'agentRewardValue',
3638
'productivityReward',
@@ -196,18 +198,20 @@ def _generate_examples(self, path: epath.Path, year: int, pipeline):
196198

197199
return (
198200
pipeline
199-
| f'CreateDates_{year}' >> beam.Create(all_dates)
201+
| f'CreateDates_{year}' >> beam.Create(enumerate(all_dates))
200202
| f'ProcessDate_{year}' >> beam.FlatMap(process_date, path=path)
201203
| f'Reshuffle_{year}' >> beam.Reshuffle()
202204
)
203205

204206

205207
def process_date(
206-
start_time: pd.Timestamp,
208+
day_index_and_start_time: tuple[int, pd.Timestamp],
207209
path: epath.Path,
208210
) -> Iterable[tuple[int, dict[str, Any]]]:
209211
"""Process a single date."""
212+
day_index, start_time = day_index_and_start_time
210213
end_time = start_time + pd.Timedelta(hours=23)
214+
key_offset = day_index * _MAX_EXAMPLES_PER_DAY
211215

212216
reader = controller_reader.ProtoReader(path)
213217
observation_responses = reader.read_observation_responses(
@@ -230,6 +234,12 @@ def process_date(
230234
key=lambda o: to_ns_timestamp(o.start_timestamp),
231235
)
232236

237+
if len(observation_responses) > _MAX_EXAMPLES_PER_DAY:
238+
raise ValueError(
239+
f'Too many observation responses for date {start_time}: '
240+
f'{len(observation_responses)} > {_MAX_EXAMPLES_PER_DAY}'
241+
)
242+
233243
for i in range(len(observation_responses)):
234244
observation_response = json_format.MessageToDict(observation_responses[i])
235245
action_response = json_format.MessageToDict(action_responses[i])
@@ -256,7 +266,7 @@ def process_date(
256266
reward_response[val] = -1 # sentinel value
257267

258268
beam.metrics.Metrics.counter(f'date_{start_time}', 'example_count').inc()
259-
key = int(f'{start_time.toordinal()}{i:05d}')
269+
key = key_offset + i
260270
yield key, {
261271
'observation': observation_response,
262272
'action': action_response,

0 commit comments

Comments
 (0)