Skip to content

Commit 007d8f4

Browse files
Merge pull request #2215 from beyarkay:patch-1
PiperOrigin-RevId: 556803453
2 parents 51a06aa + f1da823 commit 007d8f4

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

site/en/tutorials/generative/autoencoder.ipynb

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,27 +159,29 @@
159159
},
160160
"outputs": [],
161161
"source": [
162-
"latent_dim = 64 \n",
163-
"\n",
164162
"class Autoencoder(Model):\n",
165-
" def __init__(self, latent_dim):\n",
163+
" def __init__(self, latent_dim, shape):\n",
166164
" super(Autoencoder, self).__init__()\n",
167-
" self.latent_dim = latent_dim \n",
165+
" self.latent_dim = latent_dim\n",
166+
" self.shape = shape\n",
168167
" self.encoder = tf.keras.Sequential([\n",
169168
" layers.Flatten(),\n",
170169
" layers.Dense(latent_dim, activation='relu'),\n",
171170
" ])\n",
172171
" self.decoder = tf.keras.Sequential([\n",
173-
" layers.Dense(784, activation='sigmoid'),\n",
174-
" layers.Reshape((28, 28))\n",
172+
" layers.Dense(tf.math.reduce_prod(shape), activation='sigmoid'),\n",
173+
" layers.Reshape(shape)\n",
175174
" ])\n",
176175
"\n",
177176
" def call(self, x):\n",
178177
" encoded = self.encoder(x)\n",
179178
" decoded = self.decoder(encoded)\n",
180179
" return decoded\n",
181-
" \n",
182-
"autoencoder = Autoencoder(latent_dim) "
180+
"\n",
181+
"\n",
182+
"shape = x_test.shape[1:]\n",
183+
"latent_dim = 64\n",
184+
"autoencoder = Autoencoder(latent_dim, shape)\n"
183185
]
184186
},
185187
{

0 commit comments

Comments
 (0)