Skip to content

Commit 11f28f3

Browse files
authored
Add doc string for models in sqlflow_models (#1717)
* Add doc string for models in sqlflow_models * Fix unit tests * Wait until tensorflow imported to fix unit test * Wait longer
1 parent 4493962 commit 11f28f3

File tree

5 files changed

+29
-3
lines changed

5 files changed

+29
-3
lines changed

cmd/repl/repl_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func testMainFastFail(t *testing.T, interactive bool) {
4949

5050
done := make(chan error)
5151
go func() { done <- cmd.Wait() }()
52-
timeout := time.After(2 * time.Second) // 2s are enough for **fast** fail
52+
timeout := time.After(4 * time.Second) // 4s are enough for **fast** fail
5353

5454
select {
5555
case <-timeout:
@@ -137,7 +137,7 @@ func TestComplete(t *testing.T) {
137137

138138
p.InsertText(`RAIN `, false, true)
139139
c = s.completer(*p.Document())
140-
a.Equal(11, len(c))
140+
a.Equal(18, len(c))
141141

142142
p.InsertText(`DNN`, false, true)
143143
c = s.completer(*p.Document())

pkg/sql/codegen/attribute/attribute.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ package attribute
1616
import (
1717
"encoding/json"
1818
"fmt"
19+
"log"
20+
"os/exec"
1921
"reflect"
2022
"sort"
2123
"strings"
@@ -158,6 +160,18 @@ func NewDictionaryFromModelDefinition(estimator, prefix string) Dictionary {
158160
// PremadeModelParamsDocs stores parameters and documents of all known models
159161
var PremadeModelParamsDocs map[string]map[string]string
160162

163+
// ExtractDocString extracts parameter documents from python doc strings
164+
func ExtractDocString(module ...string) {
165+
cmd := exec.Command("python", "-uc", fmt.Sprintf("__import__('extract_docstring').print_param_doc('%s')", strings.Join(module, "', '")))
166+
output, e := cmd.CombinedOutput()
167+
if e != nil {
168+
log.Println("ExtractDocString failed: ", e, string(output))
169+
}
170+
if e := json.Unmarshal(output, &PremadeModelParamsDocs); e != nil {
171+
log.Println("ExtractDocString failed:", e, string(output))
172+
}
173+
}
174+
161175
func removeUnnecessaryParams() {
162176
// The following parameters of canned estimators are already supported in the COLUMN clause.
163177
for _, v := range PremadeModelParamsDocs {
@@ -171,5 +185,6 @@ func init() {
171185
if err := json.Unmarshal([]byte(ModelParameterJSON), &PremadeModelParamsDocs); err != nil {
172186
panic(err) // assertion
173187
}
188+
ExtractDocString("sqlflow_models")
174189
removeUnnecessaryParams()
175190
}

pkg/sql/codegen/attribute/attribute_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ func TestDictionaryValidate(t *testing.T) {
4343
func TestPremadeModelParamsDocs(t *testing.T) {
4444
a := assert.New(t)
4545

46-
a.Equal(11, len(PremadeModelParamsDocs))
46+
a.Equal(18, len(PremadeModelParamsDocs))
4747
a.Equal(len(PremadeModelParamsDocs["DNNClassifier"]), 12)
4848
a.NotContains(PremadeModelParamsDocs["DNNClassifier"], "feature_columns")
4949
a.Contains(PremadeModelParamsDocs["DNNClassifier"], "optimizer")

python/extract_docstring.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,16 @@ def parse_ctor_args(f, prefix=''):
7474
[' '.join(doc.split()).replace("`", "'") for doc in total[2::2]]))
7575

7676

77+
def print_param_doc(*modules):
78+
param_doc = {} # { "class_names": {"parameters": "splitted docstrings"} }
79+
for module in modules:
80+
models = filter(lambda m: inspect.isclass(m[1]),
81+
inspect.getmembers(__import__(module)))
82+
for name, cls in models:
83+
param_doc[f'{module}.{name}'] = parse_ctor_args(cls, ':param')
84+
print(json.dumps(param_doc))
85+
86+
7787
if __name__ == "__main__":
7888
param_doc = {} # { "class_names": {"parameters": "splitted docstrings"} }
7989

scripts/test/ipython.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ DATASOURCE="mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0"
4545
export PYTHONPATH=$GOPATH/src/sqlflow.org/sqlflow/python
4646

4747
sqlflowserver &
48+
sleep 10
4849
# e2e test for standard SQL
4950
SQLFLOW_DATASOURCE=${DATASOURCE} SQLFLOW_SERVER=localhost:50051 ipython python/test_magic.py
5051
# TODO(yi): Re-enable the end-to-end test of Ant XGBoost after accelerating Travis CI.

0 commit comments

Comments
 (0)