@@ -25,8 +25,8 @@ import (
2525
2626const featureDerivationRows = 1000
2727
28- // FeatureColumnMap is a mapping from column name to FeatureColumn struct
29- type FeatureColumnMap map [string ]columns.FeatureColumn
28+ // FeatureColumnMap is like: target -> key -> FeatureColumn
29+ type FeatureColumnMap map [string ]map [ string ] columns.FeatureColumn
3030
3131// ColumnSpecMap is a mappign from column name to ColumnSpec struct
3232type ColumnSpecMap map [string ]* columns.ColumnSpec
@@ -35,9 +35,10 @@ type ColumnSpecMap map[string]*columns.ColumnSpec
3535// NOTE that the target is not important for analyzing feature derivation.
3636func makeFeatureColumnMap (parsedFeatureColumns map [string ][]columns.FeatureColumn ) FeatureColumnMap {
3737 fcMap := make (FeatureColumnMap )
38- for _ , fcList := range parsedFeatureColumns {
38+ for target , fcList := range parsedFeatureColumns {
39+ fcMap [target ] = make (map [string ]columns.FeatureColumn )
3940 for _ , fc := range fcList {
40- fcMap [fc .GetKey ()] = fc
41+ fcMap [target ][ fc .GetKey ()] = fc
4142 }
4243 }
4344 return fcMap
@@ -245,44 +246,58 @@ func InferFeatureColumns(slct *standardSelect,
245246
246247 // 1. Infer omited category_id_column for embedding_columns
247248 // 2. Add derivated feature column.
248- for slctKey := range selectFieldTypeMap {
249- if fc , ok := fcMap [slctKey ]; ok {
250- if fc .GetColumnType () == columns .ColumnTypeEmbedding {
251- if fc .(* columns.EmbeddingColumn ).CategoryColumn == nil {
252- cs , ok := csMap [fc .GetKey ()]
253- if ! ok {
254- return nil , nil , fmt .Errorf ("column not found or infered: %s" , fc .GetKey ())
249+ //
250+ // need to store FeatureColumn under it's target in case of
251+ // the same column used for different target, e.g.
252+ // COLUMN EMBEDDING(c1) for deep
253+ // EMBEDDING(c2) for deep
254+ // EMBEDDING(c1) for wide
255+ for target := range parsedFeatureColumns {
256+ for slctKey := range selectFieldTypeMap {
257+ fcTargetMap , ok := fcMap [target ]
258+ if ! ok {
259+ // create map for current target
260+ fcMap [target ] = make (map [string ]columns.FeatureColumn )
261+ fcTargetMap = fcMap [target ]
262+ }
263+ if fc , ok := fcTargetMap [slctKey ]; ok {
264+ if fc .GetColumnType () == columns .ColumnTypeEmbedding {
265+ if fc .(* columns.EmbeddingColumn ).CategoryColumn == nil {
266+ cs , ok := csMap [fc .GetKey ()]
267+ if ! ok {
268+ return nil , nil , fmt .Errorf ("column not found or infered: %s" , fc .GetKey ())
269+ }
270+ // FIXME(typhoonzero): when to use sequence_category_id_column?
271+ fc .(* columns.EmbeddingColumn ).CategoryColumn = & columns.CategoryIDColumn {
272+ Key : cs .ColumnName ,
273+ BucketSize : cs .Shape [0 ],
274+ Delimiter : cs .Delimiter ,
275+ Dtype : cs .DType ,
276+ }
277+ }
278+ }
279+ } else {
280+ cs , ok := csMap [slctKey ]
281+ if ! ok {
282+ return nil , nil , fmt .Errorf ("column not found or infered: %s" , slctKey )
283+ }
284+ if cs .DType != "string" {
285+ fcMap [target ][slctKey ] = & columns.NumericColumn {
286+ Key : cs .ColumnName ,
287+ Shape : cs .Shape ,
288+ Dtype : cs .DType ,
289+ Delimiter : cs .Delimiter ,
255290 }
256- // FIXME(typhoonzero): when to use sequence_category_id_column?
257- fc .(* columns.EmbeddingColumn ).CategoryColumn = & columns.CategoryIDColumn {
291+ } else {
292+ // FIXME(typhoonzero): need full test case for string numeric columns
293+ fcMap [target ][slctKey ] = & columns.CategoryIDColumn {
258294 Key : cs .ColumnName ,
259- BucketSize : cs .Shape [ 0 ] ,
295+ BucketSize : len ( cs .Vocabulary ) ,
260296 Delimiter : cs .Delimiter ,
261297 Dtype : cs .DType ,
262298 }
263299 }
264300 }
265- } else {
266- cs , ok := csMap [slctKey ]
267- if ! ok {
268- return nil , nil , fmt .Errorf ("column not found or infered: %s" , slctKey )
269- }
270- if cs .DType != "string" {
271- fcMap [slctKey ] = & columns.NumericColumn {
272- Key : cs .ColumnName ,
273- Shape : cs .Shape ,
274- Dtype : cs .DType ,
275- Delimiter : cs .Delimiter ,
276- }
277- } else {
278- // FIXME(typhoonzero): need full test case for string numeric columns
279- fcMap [slctKey ] = & columns.CategoryIDColumn {
280- Key : cs .ColumnName ,
281- BucketSize : len (cs .Vocabulary ),
282- Delimiter : cs .Delimiter ,
283- Dtype : cs .DType ,
284- }
285- }
286301 }
287302 }
288303
0 commit comments