|
159 | 159 | },
|
160 | 160 | "outputs": [],
|
161 | 161 | "source": [
|
162 |
| - "latent_dim = 64 \n", |
163 |
| - "\n", |
164 | 162 | "class Autoencoder(Model):\n",
|
165 |
| - " def __init__(self, latent_dim):\n", |
| 163 | + " def __init__(self, latent_dim, shape):\n", |
166 | 164 | " super(Autoencoder, self).__init__()\n",
|
167 |
| - " self.latent_dim = latent_dim \n", |
| 165 | + " self.latent_dim = latent_dim\n", |
| 166 | + " self.shape = shape\n", |
168 | 167 | " self.encoder = tf.keras.Sequential([\n",
|
169 | 168 | " layers.Flatten(),\n",
|
170 | 169 | " layers.Dense(latent_dim, activation='relu'),\n",
|
171 | 170 | " ])\n",
|
172 | 171 | " self.decoder = tf.keras.Sequential([\n",
|
173 |
| - " layers.Dense(784, activation='sigmoid'),\n", |
174 |
| - " layers.Reshape((28, 28))\n", |
| 172 | + " layers.Dense(tf.math.reduce_prod(shape), activation='sigmoid'),\n", |
| 173 | + " layers.Reshape(shape)\n", |
175 | 174 | " ])\n",
|
176 | 175 | "\n",
|
177 | 176 | " def call(self, x):\n",
|
178 | 177 | " encoded = self.encoder(x)\n",
|
179 | 178 | " decoded = self.decoder(encoded)\n",
|
180 | 179 | " return decoded\n",
|
181 |
| - " \n", |
182 |
| - "autoencoder = Autoencoder(latent_dim) " |
| 180 | + "\n", |
| 181 | + "\n", |
| 182 | + "shape = x_test.shape[1:]\n", |
| 183 | + "latent_dim = 64\n", |
| 184 | + "autoencoder = Autoencoder(latent_dim, shape)\n" |
183 | 185 | ]
|
184 | 186 | },
|
185 | 187 | {
|
|
0 commit comments