@@ -187,7 +187,6 @@ if __name__ == "__main__":
 ```
 
 ## Multi-value Input : Movielens
-----------------------------------
 
 The MovieLens data has been used for personalized tag recommendation,which contains 668, 953 tag applications of users
 on movies. Here is a small fraction of data include sparse fields and a multivalent field.
@@ -275,7 +274,6 @@ if __name__ == "__main__":
 ```
 
 ## Multi-value Input : Movielens with feature hashing on the fly
-----------------------------------
 
 ```python
 import numpy as np
@@ -300,7 +298,7 @@ if __name__ == "__main__":
     max_len = max(genres_length)
 
     # Notice : padding=`post`
-    genres_list = pad_sequences(genres_list, maxlen=max_len, padding='post', dtype=str, value=0)
+    genres_list = pad_sequences(genres_list, maxlen=max_len, padding='post', dtype=object, value=0).astype(str)
 
     # 2.set hashing space for each sparse field and generate feature config for sequence feature
 
@@ -358,7 +356,7 @@ if __name__ == "__main__":
     max_len = max(genres_length)
 
     # Notice : padding=`post`
-    genres_list = pad_sequences(genres_list, maxlen=max_len, padding='post', dtype=str, value=0)
+    genres_list = pad_sequences(genres_list, maxlen=max_len, padding='post', dtype=object, value=0).astype(str)
 
     # 2.set hashing space for each sparse field and generate feature config for sequence feature
 
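The `-`/`+` pair in this hunk replaces `dtype=str` with `dtype=object` followed by `.astype(str)`. As a rough illustration of what post-padding produces for variable-length genre lists, here is a minimal pure-Python sketch. The helper `pad_post` and the sample rows are hypothetical stand-ins written for this note; they are not the Keras `pad_sequences` implementation, and the sketch assumes no sequence is longer than `max_len`.

```python
def pad_post(sequences, max_len, value=0):
    """Right-pad each sequence to max_len, then convert every cell to str.

    Hypothetical helper for illustration only; assumes len(seq) <= max_len.
    """
    padded = []
    for seq in sequences:
        row = list(seq) + [value] * (max_len - len(seq))  # fill the tail ("post" padding)
        padded.append([str(item) for item in row])        # mirrors the final .astype(str)
    return padded


# Made-up genre lists of different lengths.
genres_list = [["Comedy", "Drama"], ["Action"], ["Sci-Fi", "Thriller", "Horror"]]
max_len = max(len(genres) for genres in genres_list)

padded_genres = pad_post(genres_list, max_len)
print(padded_genres)
```

In this sketch the integer padding value `0` ends up as the string `'0'` in every padded cell, so each row holds only strings.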
@@ -521,11 +519,11 @@ if __name__ == "__main__":
 The UCI census-income dataset is extracted from the 1994 census database. It contains 299,285 instances of demographic
 information of American adults. There are 40 features in total. We construct a multi-task learning problem from this
 dataset by setting some of the features as prediction targets :
+
 - Task 1: Predict whether the income exceeds $50K;
-- Task 2: Predict whether this person’s marital status is never married.
+- Task 2: Predict whether this person’s marital status is never married.
 
-This example shows how to use ``MMOE`` to solve a multi
-task learning problem. You can get the demo
+This example shows how to use ``MMOE`` to solve a multi task learning problem. You can get the demo
 data [census-income.sample](https://github.com/shenweichen/DeepCTR/tree/master/examples/census-income.sample) and run
 the following codes.
 
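Turning two features into the prediction targets for Task 1 and Task 2 could look roughly like the sketch below. It is pure Python and only illustrative: the raw column names `income_50k` and `marital_stat`, the category strings, and the sample rows are assumptions, while `label_income` and `label_marital` are the task names that appear in the example code.

```python
# Hypothetical raw records; column names and category strings are assumed.
rows = [
    {"income_50k": "50000+.", "marital_stat": "Never married"},
    {"income_50k": "- 50000.", "marital_stat": "Divorced"},
    {"income_50k": "- 50000.", "marital_stat": "Never married"},
]


def add_task_labels(row):
    """Return a copy of the record with one binary label per task."""
    labeled = dict(row)
    # Task 1: does the income exceed $50K?
    labeled["label_income"] = 1 if row["income_50k"].strip() == "50000+." else 0
    # Task 2: is the marital status "never married"?
    labeled["label_marital"] = 1 if row["marital_stat"].strip() == "Never married" else 0
    return labeled


labeled_rows = [add_task_labels(row) for row in rows]
for row in labeled_rows:
    print(row["label_income"], row["label_marital"])
```

Each record then carries one binary target per task, which matches the two `'binary'` task types passed to the model.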
@@ -572,29 +570,29 @@ if __name__ == "__main__":
         data[feat] = lbe.fit_transform(data[feat])
 
     fixlen_feature_columns = [SparseFeat(feat, data[feat].max() + 1, embedding_dim=4) for feat in sparse_features] \
-                             + [DenseFeat(feat, 1, ) for feat in dense_features]
+                             + [DenseFeat(feat, 1, ) for feat in dense_features]
 
     dnn_feature_columns = fixlen_feature_columns
     linear_feature_columns = fixlen_feature_columns
-
+
     feature_names = get_feature_names(linear_feature_columns + dnn_feature_columns)
-
+
     # 3.generate input data for model
-
+
     train, test = train_test_split(data, test_size=0.2, random_state=2020)
     train_model_input = {name: train[name] for name in feature_names}
     test_model_input = {name: test[name] for name in feature_names}
-
+
     # 4.Define Model,train,predict and evaluate
     model = MMOE(dnn_feature_columns, tower_dnn_hidden_units=[], task_types=['binary', 'binary'],
                  task_names=['label_income', 'label_marital'])
     model.compile("adam", loss=["binary_crossentropy", "binary_crossentropy"],
                   metrics=['binary_crossentropy'], )
-
+
     history = model.fit(train_model_input, [train['label_income'].values, train['label_marital'].values],
                         batch_size=256, epochs=10, verbose=2, validation_split=0.2)
     pred_ans = model.predict(test_model_input, batch_size=256)
-
+
     print("test income AUC", round(roc_auc_score(test['label_income'], pred_ans[0]), 4))
     print("test marital AUC", round(roc_auc_score(test['label_marital'], pred_ans[1]), 4))
 