Skip to content

Commit 135211f

Browse files
committed
gitbase, function: add mode parameter to uast function
Signed-off-by: Javi Fontan <[email protected]>
1 parent 1e942c0 commit 135211f

File tree

3 files changed

+80
-8
lines changed

3 files changed

+80
-8
lines changed

internal/function/uast.go

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/sirupsen/logrus"
88
"github.com/src-d/gitbase"
9+
bblfsh "gopkg.in/bblfsh/client-go.v2"
910
"gopkg.in/bblfsh/client-go.v2/tools"
1011
"gopkg.in/bblfsh/sdk.v1/protocol"
1112
"gopkg.in/bblfsh/sdk.v1/uast"
@@ -25,11 +26,12 @@ type UAST struct {
2526
Blob sql.Expression
2627
Lang sql.Expression
2728
XPath sql.Expression
29+
Mode sql.Expression
2830
}
2931

3032
// NewUAST creates a new UAST UDF.
3133
func NewUAST(args ...sql.Expression) (sql.Expression, error) {
32-
var blob, lang, xpath sql.Expression
34+
var blob, lang, xpath, mode sql.Expression
3335
switch len(args) {
3436
case 1:
3537
blob = args[0]
@@ -40,10 +42,15 @@ func NewUAST(args ...sql.Expression) (sql.Expression, error) {
4042
blob = args[0]
4143
lang = args[1]
4244
xpath = args[2]
45+
case 4:
46+
blob = args[0]
47+
lang = args[1]
48+
xpath = args[2]
49+
mode = args[3]
4350
default:
44-
return nil, sql.ErrInvalidArgumentNumber.New("1, 2 or 3", len(args))
51+
return nil, sql.ErrInvalidArgumentNumber.New("1, 2, 3 or 4", len(args))
4552
}
46-
return &UAST{blob, lang, xpath}, nil
53+
return &UAST{blob, lang, xpath, mode}, nil
4754
}
4855

4956
// IsNullable implements the Expression interface.
@@ -79,7 +86,7 @@ func (f UAST) Children() []sql.Expression {
7986

8087
// TransformUp implements the Expression interface.
8188
func (f UAST) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
82-
var lang, xpath sql.Expression
89+
var lang, xpath, mode sql.Expression
8390
blob, err := f.Blob.TransformUp(fn)
8491
if err != nil {
8592
return nil, err
@@ -99,10 +106,21 @@ func (f UAST) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
99106
}
100107
}
101108

102-
return fn(&UAST{Blob: blob, Lang: lang, XPath: xpath})
109+
if f.Mode != nil {
110+
mode, err = f.Mode.TransformUp(fn)
111+
if err != nil {
112+
return nil, err
113+
}
114+
}
115+
116+
return fn(&UAST{Blob: blob, Lang: lang, XPath: xpath, Mode: mode})
103117
}
104118

105119
func (f UAST) String() string {
120+
if f.Lang != nil && f.XPath != nil && f.Mode != nil {
121+
return fmt.Sprintf("uast(%s, %s, %s, %s)", f.Blob, f.Lang, f.XPath, f.Mode)
122+
}
123+
106124
if f.Lang != nil && f.XPath != nil {
107125
return fmt.Sprintf("uast(%s, %s, %s)", f.Blob, f.Lang, f.XPath)
108126
}
@@ -184,6 +202,39 @@ func (f UAST) Eval(ctx *sql.Context, row sql.Row) (out interface{}, err error) {
184202
xpath = x.(string)
185203
}
186204

205+
modeSet := false
206+
var mode bblfsh.Mode
207+
if f.Mode != nil {
208+
x, err := f.Mode.Eval(ctx, row)
209+
if err != nil {
210+
return nil, err
211+
}
212+
213+
if x == nil {
214+
return nil, nil
215+
}
216+
217+
x, err = sql.Text.Convert(x)
218+
if err != nil {
219+
return nil, err
220+
}
221+
222+
m := x.(string)
223+
224+
switch m {
225+
case "semantic":
226+
mode = bblfsh.Semantic
227+
case "annotated":
228+
mode = bblfsh.Annotated
229+
case "native":
230+
mode = bblfsh.Native
231+
default:
232+
return nil, fmt.Errorf("invalid uast mode %s", m)
233+
}
234+
235+
modeSet = true
236+
}
237+
187238
client, err := session.BblfshClient()
188239
if err != nil {
189240
return nil, err
@@ -202,13 +253,19 @@ func (f UAST) Eval(ctx *sql.Context, row sql.Row) (out interface{}, err error) {
202253
}
203254
}
204255

205-
resp, err := client.Parse(ctx, lang, bytes)
256+
var resp *protocol.ParseResponse
257+
if modeSet {
258+
resp, err = client.ParseWithMode(ctx, mode, lang, bytes)
259+
} else {
260+
resp, err = client.Parse(ctx, lang, bytes)
261+
}
262+
206263
if err != nil {
207264
logrus.Warn(ErrParseBlob.New(err))
208265
return nil, nil
209266
}
210267

211-
if resp.Status != protocol.Ok {
268+
if len(resp.Errors) > 0 {
212269
logrus.Warn(ErrParseBlob.New(strings.Join(resp.Errors, "\n")))
213270
}
214271

internal/function/uast_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/src-d/gitbase"
88
"github.com/stretchr/testify/require"
9+
bblfsh "gopkg.in/bblfsh/client-go.v2"
910
"gopkg.in/bblfsh/client-go.v2/tools"
1011
"gopkg.in/bblfsh/sdk.v1/protocol"
1112
"gopkg.in/bblfsh/sdk.v1/uast"
@@ -155,7 +156,7 @@ func bblfshFixtures(t *testing.T, ctx *sql.Context) (uast []interface{}, filtere
155156
client, err := ctx.Session.(*gitbase.Session).BblfshClient()
156157
require.NoError(t, err)
157158

158-
resp, err := client.Parse(context.Background(), "python", []byte(testCode))
159+
resp, err := client.Parse(context.Background(), bblfsh.Semantic, "python", []byte(testCode))
159160
require.NoError(t, err)
160161
require.Equal(t, protocol.Ok, resp.Status, "errors: %v", resp.Errors)
161162
testUAST, err := resp.UAST.Marshal()

session.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,20 @@ func (c *BblfshClient) Parse(
137137
DoWithContext(ctx)
138138
}
139139

140+
// ParseWithMode the given content with the given language.
141+
func (c *BblfshClient) ParseWithMode(
142+
ctx context.Context,
143+
mode bblfsh.Mode,
144+
lang string,
145+
content []byte,
146+
) (*protocol.ParseResponse, error) {
147+
return c.NewParseRequest().
148+
Mode(mode).
149+
Language(lang).
150+
Content(string(content)).
151+
DoWithContext(ctx)
152+
}
153+
140154
// BblfshClient returns a BblfshClient.
141155
func (s *Session) BblfshClient() (*BblfshClient, error) {
142156
s.bblfshMu.Lock()

0 commit comments

Comments
 (0)