@@ -19,6 +19,7 @@ import (
1919 "log"
2020 "os/exec"
2121 "reflect"
22+ "regexp"
2223 "sort"
2324 "strings"
2425 "sync"
@@ -30,29 +31,29 @@ const (
3031)
3132
3233var (
33- // Bool indicates that the corresponding attribute is a boolean
34- Bool = reflect .TypeOf (true )
35- // Int indicates that the corresponding attribute is an integer
36- Int = reflect .TypeOf (0 )
37- // Float indicates that the corresponding attribute is a float32
38- Float = reflect .TypeOf (float32 (0. ))
39- // String indicates the corresponding attribute is a string
40- String = reflect .TypeOf ("" )
41- // IntList indicates the corresponding attribute is a list of integers
42- IntList = reflect .TypeOf ([]int {})
43- // Unknown type indicates that the attribute type is dynamically determined.
44- Unknown = reflect .Type (nil )
34+ // boolType indicates that the corresponding attribute is a boolean
35+ boolType = reflect .TypeOf (true )
36+ // intType indicates that the corresponding attribute is an integer
37+ intType = reflect .TypeOf (0 )
38+ // floatType indicates that the corresponding attribute is a float32
39+ floatType = reflect .TypeOf (float32 (0. ))
40+ // stringType indicates the corresponding attribute is a string
41+ stringType = reflect .TypeOf ("" )
42+ // intListType indicates the corresponding attribute is a list of integers
43+ intListType = reflect .TypeOf ([]int {})
44+ // unknownType indicates that the attribute type is dynamically determined.
45+ unknownType = reflect .Type (nil )
4546)
4647
4748// Dictionary contains the description of all attributes
48- type Dictionary map [string ]* Description
49-
50- // Description describes a requirement for a particular attribute
51- type Description struct {
52- Type reflect.Type
53- Default interface {}
54- Doc string
55- Checker func (i interface {}) error
49+ type Dictionary map [string ]* description
50+
51+ // description describes a requirement for a particular attribute
52+ type description struct {
53+ typ reflect.Type
54+ defaultValue interface {}
55+ doc string
56+ checker func (i interface {}) error
5657}
5758
5859// Int declares an attribute of int-typed in Dictionary d.
@@ -74,25 +75,31 @@ func (d Dictionary) Int(name string, value interface{}, doc string, checker func
7475 }
7576 }
7677
77- d [name ] = & Description {
78- Type : Int ,
79- Default : value ,
80- Doc : doc ,
81- Checker : interfaceChecker ,
78+ d [name ] = & description {
79+ typ : intType ,
80+ defaultValue : value ,
81+ doc : doc ,
82+ checker : interfaceChecker ,
8283 }
8384 return d
8485}
8586
8687// Float declares an attribute of float32-typed in Dictionary d.
8788func (d Dictionary ) Float (name string , value interface {}, doc string , checker func (float32 ) error ) Dictionary {
8889 interfaceChecker := func (v interface {}) error {
90+ var fValue float32
8991 if floatValue , ok := v .(float32 ); ok {
90- if checker != nil {
91- return checker (floatValue )
92- }
93- return nil
92+ fValue = floatValue
93+ } else if intValue , ok := v .(int ); ok { // implicit type conversion from int to float
94+ fValue = float32 (intValue )
95+ } else {
96+ return fmt .Errorf ("attribute %s must be of type float, but got %T" , name , v )
97+ }
98+
99+ if checker != nil {
100+ return checker (fValue )
94101 }
95- return fmt . Errorf ( "attribute %s must be of type float, but got %T" , name , v )
102+ return nil
96103 }
97104
98105 if value != nil {
@@ -102,11 +109,20 @@ func (d Dictionary) Float(name string, value interface{}, doc string, checker fu
102109 }
103110 }
104111
105- d [name ] = & Description {
106- Type : Float ,
107- Default : value ,
108- Doc : doc ,
109- Checker : interfaceChecker ,
112+ var fInterfaceValue interface {}
113+ if value == nil {
114+ fInterfaceValue = nil
115+ } else if floatValue , ok := value .(float32 ); ok {
116+ fInterfaceValue = floatValue
117+ } else if intValue , ok := value .(int ); ok { // implicit type conversion from int to float
118+ fInterfaceValue = float32 (intValue )
119+ }
120+
121+ d [name ] = & description {
122+ typ : floatType ,
123+ defaultValue : fInterfaceValue ,
124+ doc : doc ,
125+ checker : interfaceChecker ,
110126 }
111127 return d
112128}
@@ -130,11 +146,11 @@ func (d Dictionary) Bool(name string, value interface{}, doc string, checker fun
130146 }
131147 }
132148
133- d [name ] = & Description {
134- Type : Bool ,
135- Default : value ,
136- Doc : doc ,
137- Checker : interfaceChecker ,
149+ d [name ] = & description {
150+ typ : boolType ,
151+ defaultValue : value ,
152+ doc : doc ,
153+ checker : interfaceChecker ,
138154 }
139155 return d
140156}
@@ -158,11 +174,11 @@ func (d Dictionary) String(name string, value interface{}, doc string, checker f
158174 }
159175 }
160176
161- d [name ] = & Description {
162- Type : String ,
163- Default : value ,
164- Doc : doc ,
165- Checker : interfaceChecker ,
177+ d [name ] = & description {
178+ typ : stringType ,
179+ defaultValue : value ,
180+ doc : doc ,
181+ checker : interfaceChecker ,
166182 }
167183 return d
168184}
@@ -186,11 +202,11 @@ func (d Dictionary) IntList(name string, value interface{}, doc string, checker
186202 }
187203 }
188204
189- d [name ] = & Description {
190- Type : IntList ,
191- Default : value ,
192- Doc : doc ,
193- Checker : interfaceChecker ,
205+ d [name ] = & description {
206+ typ : intListType ,
207+ defaultValue : value ,
208+ doc : doc ,
209+ checker : interfaceChecker ,
194210 }
195211 return d
196212}
@@ -204,28 +220,28 @@ func (d Dictionary) Unknown(name string, value interface{}, doc string, checker
204220 }
205221 }
206222
207- d [name ] = & Description {
208- Type : Unknown ,
209- Default : value ,
210- Doc : doc ,
211- Checker : checker ,
223+ d [name ] = & description {
224+ typ : unknownType ,
225+ defaultValue : value ,
226+ doc : doc ,
227+ checker : checker ,
212228 }
213229 return d
214230}
215231
216- // FillDefaults fills default values defined in Dictionary to attrs.
217- func (d Dictionary ) FillDefaults (attrs map [string ]interface {}) {
232+ // ExportDefaults exports default values defined in Dictionary to attrs.
233+ func (d Dictionary ) ExportDefaults (attrs map [string ]interface {}) {
218234 for k , v := range d {
219235 // Do not fill default value for unknown type, and with nil default values.
220- if v .Type == Unknown {
236+ if v .typ == unknownType {
221237 continue
222238 }
223- if v .Default == nil {
239+ if v .defaultValue == nil {
224240 continue
225241 }
226242 _ , ok := attrs [k ]
227243 if ! ok {
228- attrs [k ] = v .Default
244+ attrs [k ] = v .defaultValue
229245 }
230246 }
231247}
@@ -235,7 +251,7 @@ func (d Dictionary) FillDefaults(attrs map[string]interface{}) {
235251// 2. Customer checker
236252func (d Dictionary ) Validate (attrs map [string ]interface {}) error {
237253 for k , v := range attrs {
238- var desc * Description
254+ var desc * description
239255 desc , ok := d [k ]
240256 if ! ok {
241257 // Support attribute definition like "model.*" to match
@@ -254,15 +270,15 @@ func (d Dictionary) Validate(attrs map[string]interface{}) error {
254270 }
255271 }
256272
257- if desc .Type != Unknown && desc .Type != reflect .TypeOf (v ) {
273+ if desc .typ != unknownType && desc .typ != reflect .TypeOf (v ) {
258274 // Allow implicit conversion from int to float to ease typing
259- if ! (desc .Type == Float && reflect .TypeOf (v ) == Int ) {
260- return fmt .Errorf (errUnexpectedType , k , desc .Type , v )
275+ if ! (desc .typ == floatType && reflect .TypeOf (v ) == intType ) {
276+ return fmt .Errorf (errUnexpectedType , k , desc .typ , v )
261277 }
262278 }
263279
264- if desc .Checker != nil {
265- if err := desc .Checker (v ); err != nil {
280+ if desc .checker != nil {
281+ if err := desc .checker (v ); err != nil {
266282 return err
267283 }
268284 }
@@ -293,7 +309,7 @@ func (d Dictionary) GenerateTableInHTML() string {
293309 <td>%s</td>
294310</tr>`
295311 // NOTE(tony): if the doc string has multiple lines, need to replace \n with <br>
296- s := fmt .Sprintf (t , k , desc .Type , strings .Replace (desc .Doc , "\n " , `<br>` , - 1 ))
312+ s := fmt .Sprintf (t , k , desc .typ , strings .Replace (desc .doc , "\n " , `<br>` , - 1 ))
297313 l = append (l , s )
298314 }
299315
@@ -311,9 +327,40 @@ func (d Dictionary) Update(other Dictionary) Dictionary {
311327
312328// NewDictionaryFromModelDefinition create a new Dictionary according to pre-made estimators or XGBoost model types.
313329func NewDictionaryFromModelDefinition (estimator , prefix string ) Dictionary {
330+ isXGBoostModel := strings .HasPrefix (estimator , "xgboost" )
331+ re := regexp .MustCompile ("[^a-z]" )
332+
314333 var d = Dictionary {}
315334 for param , doc := range PremadeModelParamsDocs [estimator ] {
316- d [prefix + param ] = & Description {Unknown , nil , doc , nil }
335+ desc := & description {unknownType , nil , doc , nil }
336+ d [prefix + param ] = desc
337+
338+ if ! isXGBoostModel {
339+ continue
340+ }
341+
342+ // Fill typ field according to the model parameter doc
343+ // The doc would be like: "int Maximum tree depth for base learners"
344+ pieces := strings .SplitN (strings .TrimSpace (desc .doc ), " " , 2 )
345+ if len (pieces ) != 2 {
346+ continue
347+ }
348+
349+ maybeType := re .ReplaceAllString (pieces [0 ], "" )
350+ switch strings .ToLower (maybeType ) {
351+ case "float" :
352+ desc .typ = floatType
353+ desc .doc = pieces [1 ]
354+ case "int" :
355+ desc .typ = intType
356+ desc .doc = pieces [1 ]
357+ case "string" :
358+ desc .typ = stringType
359+ desc .doc = pieces [1 ]
360+ case "boolean" :
361+ desc .typ = boolType
362+ desc .doc = pieces [1 ]
363+ }
317364 }
318365 return d
319366}
0 commit comments