Skip to content

Commit 964bed4

Browse files
authored
Apply the expand to the dataframe instead.
1 parent 158d2f8 commit 964bed4

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

site/en/tutorials/structured_data/preprocessing_layers.ipynb

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,10 @@
295295
"outputs": [],
296296
"source": [
297297
"def df_to_dataset(dataframe, shuffle=True, batch_size=32):\n",
298-
" dataframe = dataframe.copy()\n",
299-
" labels = dataframe.pop('target')\n",
300-
" ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))\n",
298+
" df = dataframe.copy()\n",
299+
" labels = df.pop('target')\n",
300+
" df = {key: value[:,tf.newaxis] for key, value in dataframe.items()}\n",
301+
" ds = tf.data.Dataset.from_tensor_slices((dict(df), labels))\n",
301302
" if shuffle:\n",
302303
" ds = ds.shuffle(buffer_size=len(dataframe))\n",
303304
" ds = ds.batch(batch_size)\n",
@@ -502,7 +503,7 @@
502503
"test_type_layer = get_category_encoding_layer(name='Type',\n",
503504
" dataset=train_ds,\n",
504505
" dtype='string')\n",
505-
"test_type_layer(tf.expand_dims(test_type_col, -1))"
506+
"test_type_layer(test_type_col)"
506507
]
507508
},
508509
{
@@ -527,7 +528,7 @@
527528
" dataset=train_ds,\n",
528529
" dtype='int64',\n",
529530
" max_tokens=5)\n",
530-
"test_age_layer(tf.expand_dims(test_age_col, -1))"
531+
"test_age_layer(test_age_col)"
531532
]
532533
},
533534
{

0 commit comments

Comments
 (0)