Skip to content

Commit 55ba808

Browse files
Johannes Ballé and copybara-github
authored and committed
Adds option to train with finite dataset size.
PiperOrigin-RevId: 449469629 Change-Id: I50eeb854ce9337ab53f806ca1603d7ecb499abb3
1 parent 10cbd03 commit 55ba808

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

models/toy_sources/compression_model.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,17 @@
1818
import matplotlib.pyplot as plt
1919
import numpy as np
2020
import tensorflow as tf
21-
import tensorflow_probability as tfp
21+
22+
23+
def source_dataset(source, batch_size, seed, dataset_size=None):
  """Returns a `tf.data.Dataset` of samples drawn from `source`.

  Args:
    source: Object exposing a `sample(n, seed=...)` method (assumed to be a
      distribution-like source of training examples — confirm with caller).
    batch_size: Number of samples per yielded batch.
    seed: Seed for the underlying `tf.data.Dataset.random` seed stream.
    dataset_size: If not `None`, limits the dataset to this many samples,
      rounded up to a whole number of batches. If `None`, the dataset is
      unbounded.

  Returns:
    A `tf.data.Dataset` where each element is one batch of samples.
  """
  seeds = tf.data.Dataset.random(seed=seed)
  if dataset_size is not None:
    # Ceiling division: take enough batches to cover `dataset_size` samples.
    num_batches = -(-dataset_size // batch_size)
    seeds = seeds.take(num_batches)

  def sample_batch(batch_seed):
    # `Dataset.random` yields int64 scalars; bitcast to the int32 pair that
    # stateless samplers expect.
    return source.sample(batch_size, seed=tf.bitcast(batch_seed, tf.int32))

  return seeds.map(sample_batch)
2232

2333

2434
class CompressionModel(tf.keras.Model, metaclass=abc.ABCMeta):
@@ -131,22 +141,15 @@ def compile(self, **kwargs):
131141
self.distortion = tf.keras.metrics.Mean(name="distortion")
132142
self.grad_rms = tf.keras.metrics.Mean(name="gradient RMS")
133143

134-
def fit(self, batch_size, validation_size, validation_batch_size, **kwargs):
135-
train_data = tf.data.Dataset.from_tensors([])
136-
train_data = train_data.repeat()
137-
train_data = train_data.map(
138-
lambda _: self.source.sample(batch_size),
139-
)
140-
141-
seed = tfp.util.SeedStream(528374623, "compression_model_fit")
142-
# This rounds up to multiple of batch size.
143-
validation_batches = (validation_size - 1) // validation_batch_size + 1
144-
validation_data = tf.data.Dataset.from_tensors([])
145-
validation_data = validation_data.repeat(validation_batches)
146-
validation_data = validation_data.map(
147-
lambda _: self.source.sample(validation_batch_size, seed=seed),
148-
)
149-
144+
def fit(self, batch_size, validation_size, validation_batch_size,
145+
train_size=None, train_seed=None, validation_seed=82913749, **kwargs):
146+
train_data = source_dataset(
147+
self.source, batch_size, train_seed, dataset_size=train_size)
148+
if train_size is not None:
149+
train_data = train_data.repeat()
150+
validation_data = source_dataset(
151+
self.source, validation_batch_size, validation_seed,
152+
dataset_size=validation_size)
150153
super().fit(
151154
train_data,
152155
validation_data=validation_data,

0 commit comments

Comments (0)