Skip to content

Commit 6be86f2

Browse files
authored
Update attribute definition by named type methods and hide attribute.Description (#2504)
* hide Description * polish * polish * update
1 parent 4097352 commit 6be86f2

File tree

10 files changed

+260
-301
lines changed

10 files changed

+260
-301
lines changed

doc/model_parameter.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ INTO sqlflow_models.my_xgb_regression_model;
7979
<tr>
8080
<td>max_bin</td>
8181
<td>%!s(<nil>)</td>
82-
<td>used if tree_method is set to hist, Maximum number of discrete bins to bucket continuous features.</td>
82+
<td>Only used if tree_method is set to hist, Maximum number of discrete bins to bucket continuous features.</td>
8383
</tr>
8484
<tr>
8585
<td>max_delta_step</td>
@@ -153,7 +153,7 @@ INTO sqlflow_models.my_xgb_regression_model;
153153
</tr>
154154
<tr>
155155
<td>silent</td>
156-
<td>%!s(<nil>)</td>
156+
<td>bool</td>
157157
<td>Whether to print messages while running boosting. Deprecated. Use verbosity instead.</td>
158158
</tr>
159159
<tr>

pkg/attribute/attribute.go

Lines changed: 115 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"log"
2020
"os/exec"
2121
"reflect"
22+
"regexp"
2223
"sort"
2324
"strings"
2425
"sync"
@@ -30,29 +31,29 @@ const (
3031
)
3132

3233
var (
33-
// Bool indicates that the corresponding attribute is a boolean
34-
Bool = reflect.TypeOf(true)
35-
// Int indicates that the corresponding attribute is an integer
36-
Int = reflect.TypeOf(0)
37-
// Float indicates that the corresponding attribute is a float32
38-
Float = reflect.TypeOf(float32(0.))
39-
// String indicates the corresponding attribute is a string
40-
String = reflect.TypeOf("")
41-
// IntList indicates the corresponding attribute is a list of integers
42-
IntList = reflect.TypeOf([]int{})
43-
// Unknown type indicates that the attribute type is dynamically determined.
44-
Unknown = reflect.Type(nil)
34+
// boolType indicates that the corresponding attribute is a boolean
35+
boolType = reflect.TypeOf(true)
36+
// intType indicates that the corresponding attribute is an integer
37+
intType = reflect.TypeOf(0)
38+
// floatType indicates that the corresponding attribute is a float32
39+
floatType = reflect.TypeOf(float32(0.))
40+
// stringType indicates the corresponding attribute is a string
41+
stringType = reflect.TypeOf("")
42+
// intListType indicates the corresponding attribute is a list of integers
43+
intListType = reflect.TypeOf([]int{})
44+
// unknownType indicates that the attribute type is dynamically determined.
45+
unknownType = reflect.Type(nil)
4546
)
4647

4748
// Dictionary contains the description of all attributes
48-
type Dictionary map[string]*Description
49-
50-
// Description describes a requirement for a particular attribute
51-
type Description struct {
52-
Type reflect.Type
53-
Default interface{}
54-
Doc string
55-
Checker func(i interface{}) error
49+
type Dictionary map[string]*description
50+
51+
// description describes a requirement for a particular attribute
52+
type description struct {
53+
typ reflect.Type
54+
defaultValue interface{}
55+
doc string
56+
checker func(i interface{}) error
5657
}
5758

5859
// Int declares an attribute of int-typed in Dictionary d.
@@ -74,25 +75,31 @@ func (d Dictionary) Int(name string, value interface{}, doc string, checker func
7475
}
7576
}
7677

77-
d[name] = &Description{
78-
Type: Int,
79-
Default: value,
80-
Doc: doc,
81-
Checker: interfaceChecker,
78+
d[name] = &description{
79+
typ: intType,
80+
defaultValue: value,
81+
doc: doc,
82+
checker: interfaceChecker,
8283
}
8384
return d
8485
}
8586

8687
// Float declares an attribute of float32-typed in Dictionary d.
8788
func (d Dictionary) Float(name string, value interface{}, doc string, checker func(float32) error) Dictionary {
8889
interfaceChecker := func(v interface{}) error {
90+
var fValue float32
8991
if floatValue, ok := v.(float32); ok {
90-
if checker != nil {
91-
return checker(floatValue)
92-
}
93-
return nil
92+
fValue = floatValue
93+
} else if intValue, ok := v.(int); ok { // implicit type conversion from int to float
94+
fValue = float32(intValue)
95+
} else {
96+
return fmt.Errorf("attribute %s must be of type float, but got %T", name, v)
97+
}
98+
99+
if checker != nil {
100+
return checker(fValue)
94101
}
95-
return fmt.Errorf("attribute %s must be of type float, but got %T", name, v)
102+
return nil
96103
}
97104

98105
if value != nil {
@@ -102,11 +109,20 @@ func (d Dictionary) Float(name string, value interface{}, doc string, checker fu
102109
}
103110
}
104111

105-
d[name] = &Description{
106-
Type: Float,
107-
Default: value,
108-
Doc: doc,
109-
Checker: interfaceChecker,
112+
var fInterfaceValue interface{}
113+
if value == nil {
114+
fInterfaceValue = nil
115+
} else if floatValue, ok := value.(float32); ok {
116+
fInterfaceValue = floatValue
117+
} else if intValue, ok := value.(int); ok { // implicit type conversion from int to float
118+
fInterfaceValue = float32(intValue)
119+
}
120+
121+
d[name] = &description{
122+
typ: floatType,
123+
defaultValue: fInterfaceValue,
124+
doc: doc,
125+
checker: interfaceChecker,
110126
}
111127
return d
112128
}
@@ -130,11 +146,11 @@ func (d Dictionary) Bool(name string, value interface{}, doc string, checker fun
130146
}
131147
}
132148

133-
d[name] = &Description{
134-
Type: Bool,
135-
Default: value,
136-
Doc: doc,
137-
Checker: interfaceChecker,
149+
d[name] = &description{
150+
typ: boolType,
151+
defaultValue: value,
152+
doc: doc,
153+
checker: interfaceChecker,
138154
}
139155
return d
140156
}
@@ -158,11 +174,11 @@ func (d Dictionary) String(name string, value interface{}, doc string, checker f
158174
}
159175
}
160176

161-
d[name] = &Description{
162-
Type: String,
163-
Default: value,
164-
Doc: doc,
165-
Checker: interfaceChecker,
177+
d[name] = &description{
178+
typ: stringType,
179+
defaultValue: value,
180+
doc: doc,
181+
checker: interfaceChecker,
166182
}
167183
return d
168184
}
@@ -186,11 +202,11 @@ func (d Dictionary) IntList(name string, value interface{}, doc string, checker
186202
}
187203
}
188204

189-
d[name] = &Description{
190-
Type: IntList,
191-
Default: value,
192-
Doc: doc,
193-
Checker: interfaceChecker,
205+
d[name] = &description{
206+
typ: intListType,
207+
defaultValue: value,
208+
doc: doc,
209+
checker: interfaceChecker,
194210
}
195211
return d
196212
}
@@ -204,28 +220,28 @@ func (d Dictionary) Unknown(name string, value interface{}, doc string, checker
204220
}
205221
}
206222

207-
d[name] = &Description{
208-
Type: Unknown,
209-
Default: value,
210-
Doc: doc,
211-
Checker: checker,
223+
d[name] = &description{
224+
typ: unknownType,
225+
defaultValue: value,
226+
doc: doc,
227+
checker: checker,
212228
}
213229
return d
214230
}
215231

216-
// FillDefaults fills default values defined in Dictionary to attrs.
217-
func (d Dictionary) FillDefaults(attrs map[string]interface{}) {
232+
// ExportDefaults exports default values defined in Dictionary to attrs.
233+
func (d Dictionary) ExportDefaults(attrs map[string]interface{}) {
218234
for k, v := range d {
219235
// Do not fill default value for unknown type, and with nil default values.
220-
if v.Type == Unknown {
236+
if v.typ == unknownType {
221237
continue
222238
}
223-
if v.Default == nil {
239+
if v.defaultValue == nil {
224240
continue
225241
}
226242
_, ok := attrs[k]
227243
if !ok {
228-
attrs[k] = v.Default
244+
attrs[k] = v.defaultValue
229245
}
230246
}
231247
}
@@ -235,7 +251,7 @@ func (d Dictionary) FillDefaults(attrs map[string]interface{}) {
235251
// 2. Customer checker
236252
func (d Dictionary) Validate(attrs map[string]interface{}) error {
237253
for k, v := range attrs {
238-
var desc *Description
254+
var desc *description
239255
desc, ok := d[k]
240256
if !ok {
241257
// Support attribute definition like "model.*" to match
@@ -254,15 +270,15 @@ func (d Dictionary) Validate(attrs map[string]interface{}) error {
254270
}
255271
}
256272

257-
if desc.Type != Unknown && desc.Type != reflect.TypeOf(v) {
273+
if desc.typ != unknownType && desc.typ != reflect.TypeOf(v) {
258274
// Allow implicit conversion from int to float to ease typing
259-
if !(desc.Type == Float && reflect.TypeOf(v) == Int) {
260-
return fmt.Errorf(errUnexpectedType, k, desc.Type, v)
275+
if !(desc.typ == floatType && reflect.TypeOf(v) == intType) {
276+
return fmt.Errorf(errUnexpectedType, k, desc.typ, v)
261277
}
262278
}
263279

264-
if desc.Checker != nil {
265-
if err := desc.Checker(v); err != nil {
280+
if desc.checker != nil {
281+
if err := desc.checker(v); err != nil {
266282
return err
267283
}
268284
}
@@ -293,7 +309,7 @@ func (d Dictionary) GenerateTableInHTML() string {
293309
<td>%s</td>
294310
</tr>`
295311
// NOTE(tony): if the doc string has multiple lines, need to replace \n with <br>
296-
s := fmt.Sprintf(t, k, desc.Type, strings.Replace(desc.Doc, "\n", `<br>`, -1))
312+
s := fmt.Sprintf(t, k, desc.typ, strings.Replace(desc.doc, "\n", `<br>`, -1))
297313
l = append(l, s)
298314
}
299315

@@ -311,9 +327,40 @@ func (d Dictionary) Update(other Dictionary) Dictionary {
311327

312328
// NewDictionaryFromModelDefinition create a new Dictionary according to pre-made estimators or XGBoost model types.
313329
func NewDictionaryFromModelDefinition(estimator, prefix string) Dictionary {
330+
isXGBoostModel := strings.HasPrefix(estimator, "xgboost")
331+
re := regexp.MustCompile("[^a-z]")
332+
314333
var d = Dictionary{}
315334
for param, doc := range PremadeModelParamsDocs[estimator] {
316-
d[prefix+param] = &Description{Unknown, nil, doc, nil}
335+
desc := &description{unknownType, nil, doc, nil}
336+
d[prefix+param] = desc
337+
338+
if !isXGBoostModel {
339+
continue
340+
}
341+
342+
// Fill typ field according to the model parameter doc
343+
// The doc would be like: "int Maximum tree depth for base learners"
344+
pieces := strings.SplitN(strings.TrimSpace(desc.doc), " ", 2)
345+
if len(pieces) != 2 {
346+
continue
347+
}
348+
349+
maybeType := re.ReplaceAllString(pieces[0], "")
350+
switch strings.ToLower(maybeType) {
351+
case "float":
352+
desc.typ = floatType
353+
desc.doc = pieces[1]
354+
case "int":
355+
desc.typ = intType
356+
desc.doc = pieces[1]
357+
case "string":
358+
desc.typ = stringType
359+
desc.doc = pieces[1]
360+
case "boolean":
361+
desc.typ = boolType
362+
desc.doc = pieces[1]
363+
}
317364
}
318365
return d
319366
}

pkg/attribute/attribute_test.go

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -123,17 +123,13 @@ func TestDictionaryNamedTypeChecker(t *testing.T) {
123123
func TestDictionaryValidate(t *testing.T) {
124124
a := assert.New(t)
125125

126-
checker := func(i interface{}) error {
127-
ii, ok := i.(int)
128-
if !ok {
129-
return fmt.Errorf("%T %v should of type integer", i, i)
130-
}
131-
if ii < 0 {
126+
checker := func(i int) error {
127+
if i < 0 {
132128
return fmt.Errorf("some error")
133129
}
134130
return nil
135131
}
136-
tb := Dictionary{"a": {Int, 1, "attribute a", checker}, "b": {Float, 1, "attribute b", nil}}
132+
tb := Dictionary{}.Int("a", 1, "attribute a", checker).Float("b", float32(1), "attribute b", nil)
137133
a.NoError(tb.Validate(map[string]interface{}{"a": 1}))
138134
a.EqualError(tb.Validate(map[string]interface{}{"a": -1}), "some error")
139135
a.EqualError(tb.Validate(map[string]interface{}{"_a": -1}), fmt.Sprintf(errUnsupportedAttribute, "_a"))
@@ -165,7 +161,7 @@ func TestParamsDocs(t *testing.T) {
165161
func TestNewAndUpdateDictionary(t *testing.T) {
166162
a := assert.New(t)
167163

168-
commonAttrs := Dictionary{"a": {Int, 1, "attribute a", nil}}
164+
commonAttrs := Dictionary{}.Int("a", 1, "attribute a", nil)
169165
specificAttrs := NewDictionaryFromModelDefinition("DNNClassifier", "model.")
170166
a.Equal(len(specificAttrs), 12)
171167
a.Equal(len(specificAttrs.Update(specificAttrs)), 12)
@@ -180,12 +176,12 @@ func TestNewAndUpdateDictionary(t *testing.T) {
180176

181177
func TestDictionary_GenerateTableInHTML(t *testing.T) {
182178
a := assert.New(t)
183-
tb := Dictionary{
184-
"a": {Int, 1, `this is a
179+
tb := Dictionary{}.
180+
Int("a", 1, `this is a
185181
multiple line
186-
doc string.`, nil},
187-
"世界": {String, "", `42`, nil},
188-
}
182+
doc string.`, nil).
183+
String("世界", "", `42`, nil)
184+
189185
expected := `<table>
190186
<tr>
191187
<td>Name</td>

0 commit comments

Comments
 (0)