```diff
@@ -18,7 +18,17 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import tensorflow as tf
-import tensorflow_probability as tfp
+
+
+def source_dataset(source, batch_size, seed, dataset_size=None):
+  """Returns a `tf.data.Dataset` of samples from `source`."""
+  dataset = tf.data.Dataset.random(seed=seed)
+  if dataset_size is not None:
+    # This rounds up to a multiple of the batch size.
+    batches = (dataset_size - 1) // batch_size + 1
+    dataset = dataset.take(batches)
+  return dataset.map(
+      lambda seed: source.sample(batch_size, seed=tf.bitcast(seed, tf.int32)))
 
 
 class CompressionModel(tf.keras.Model, metaclass=abc.ABCMeta):
```
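For context, here is a minimal usage sketch of the new helper (not part of the commit; the `tfp.distributions.Normal` source is an assumed stand-in for whatever `self.source` is in the model, and `source_dataset` is the function added above):

```python
# Hypothetical driver for source_dataset(); the Normal source is an assumption.
import tensorflow as tf
import tensorflow_probability as tfp

source = tfp.distributions.Normal(loc=0., scale=1.)

# dataset_size=1000 with batch_size=64 yields (1000 - 1) // 64 + 1 = 16
# batches, i.e. the requested size rounded up to whole batches.
ds = source_dataset(source, batch_size=64, seed=1234, dataset_size=1000)
for batch in ds:
  print(batch.shape)  # (64,)

# The seed plumbing: each int64 element of tf.data.Dataset.random() bitcasts
# to a (2,)-shaped int32 tensor, the stateless-seed format that
# sample(..., seed=...) accepts.
print(tf.bitcast(tf.constant(1234, tf.int64), tf.int32).shape)  # (2,)
```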
```diff
@@ -131,22 +141,15 @@ def compile(self, **kwargs):
     self.distortion = tf.keras.metrics.Mean(name="distortion")
     self.grad_rms = tf.keras.metrics.Mean(name="gradient RMS")
 
-  def fit(self, batch_size, validation_size, validation_batch_size, **kwargs):
-    train_data = tf.data.Dataset.from_tensors([])
-    train_data = train_data.repeat()
-    train_data = train_data.map(
-        lambda _: self.source.sample(batch_size),
-    )
-
-    seed = tfp.util.SeedStream(528374623, "compression_model_fit")
-    # This rounds up to multiple of batch size.
-    validation_batches = (validation_size - 1) // validation_batch_size + 1
-    validation_data = tf.data.Dataset.from_tensors([])
-    validation_data = validation_data.repeat(validation_batches)
-    validation_data = validation_data.map(
-        lambda _: self.source.sample(validation_batch_size, seed=seed),
-    )
-
+  def fit(self, batch_size, validation_size, validation_batch_size,
+          train_size=None, train_seed=None, validation_seed=82913749, **kwargs):
+    train_data = source_dataset(
+        self.source, batch_size, train_seed, dataset_size=train_size)
+    if train_size is not None:
+      train_data = train_data.repeat()
+    validation_data = source_dataset(
+        self.source, validation_batch_size, validation_seed,
+        dataset_size=validation_size)
     super().fit(
         train_data,
         validation_data=validation_data,
```
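And a hedged sketch of the revised `fit` signature in use; `ToyModel` is a hypothetical concrete subclass (the base class is abstract), and `steps_per_epoch` is needed because with `train_size=None` the training stream is infinite:

```python
# ToyModel is a hypothetical concrete subclass of CompressionModel; its
# constructor is an assumption for illustration only.
model = ToyModel(source=tfp.distributions.Normal(loc=0., scale=1.))
model.compile()
model.fit(
    batch_size=64,
    validation_size=1024,        # rounded up to 4 batches of 256
    validation_batch_size=256,
    train_size=None,             # None: endless stream of fresh batches
    train_seed=None,             # None: nondeterministic training data
    epochs=10,
    steps_per_epoch=100,         # required since the training set is infinite
)
```

Fixing `validation_seed` by default keeps the validation set identical across epochs and runs, while the training stream stays random unless `train_seed` and `train_size` are pinned.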