@@ -275,12 +275,7 @@ func (f UASTMode) Eval(ctx *sql.Context, row sql.Row) (out interface{}, err erro
275
275
return nil , fmt .Errorf ("invalid uast mode %s" , m )
276
276
}
277
277
278
- u , err := getUAST (ctx , bytes , lang , "" , mode )
279
- if err != nil {
280
- return nil , err
281
- }
282
-
283
- return u , nil
278
+ return getUAST (ctx , bytes , lang , "" , mode )
284
279
}
285
280
286
281
// UASTXPath performs an XPath query over the given UAST nodes.
@@ -318,37 +313,20 @@ func (f *UASTXPath) Eval(ctx *sql.Context, row sql.Row) (out interface{}, err er
318
313
return nil , nil
319
314
}
320
315
321
- left , err = sql . Array ( sql . Blob ). Convert (left )
316
+ nodes , err := nodesFromBlobArray (left )
322
317
if err != nil {
323
318
return nil , err
324
319
}
325
320
326
- arr := left .([]interface {})
327
- var nodes = make ([]* uast.Node , len (arr ))
328
- for i , n := range arr {
329
- node := uast .NewNode ()
330
- if err := node .Unmarshal (n .([]byte )); err != nil {
331
- return nil , err
332
- }
333
- nodes [i ] = node
334
- }
335
-
336
- right , err := f .Right .Eval (ctx , row )
321
+ xpath , err := exprToString (ctx , f .Right , row )
337
322
if err != nil {
338
323
return nil , err
339
324
}
340
325
341
- if right == nil {
326
+ if xpath == "" {
342
327
return nil , nil
343
328
}
344
329
345
- right , err = sql .Text .Convert (right )
346
- if err != nil {
347
- return nil , err
348
- }
349
-
350
- xpath := right .(string )
351
-
352
330
var result []interface {}
353
331
for _ , n := range nodes {
354
332
ns , err := tools .Filter (n , xpath )
@@ -368,6 +346,25 @@ func (f *UASTXPath) Eval(ctx *sql.Context, row sql.Row) (out interface{}, err er
368
346
return result , nil
369
347
}
370
348
349
+ func nodesFromBlobArray (data interface {}) ([]* uast.Node , error ) {
350
+ data , err := sql .Array (sql .Blob ).Convert (data )
351
+ if err != nil {
352
+ return nil , err
353
+ }
354
+
355
+ arr := data .([]interface {})
356
+ var nodes = make ([]* uast.Node , len (arr ))
357
+ for i , n := range arr {
358
+ node := uast .NewNode ()
359
+ if err := node .Unmarshal (n .([]byte )); err != nil {
360
+ return nil , err
361
+ }
362
+ nodes [i ] = node
363
+ }
364
+
365
+ return nodes , nil
366
+ }
367
+
371
368
func (f UASTXPath ) String () string {
372
369
return fmt .Sprintf ("uast_xpath(%s, %s)" , f .Left , f .Right )
373
370
}
@@ -475,3 +472,117 @@ func getUAST(
475
472
476
473
return result , nil
477
474
}
475
+
476
+ // UASTExtract extracts keys from an UAST.
477
+ type UASTExtract struct {
478
+ expression.BinaryExpression
479
+ }
480
+
481
+ // NewUASTExtract creates a new UASTExtract UDF.
482
+ func NewUASTExtract (uast , key sql.Expression ) sql.Expression {
483
+ return & UASTExtract {expression.BinaryExpression {Left : uast , Right : key }}
484
+ }
485
+
486
+ // String implements the fmt.Stringer interface.
487
+ func (u * UASTExtract ) String () string {
488
+ return fmt .Sprintf ("uast_extract(%s, %s)" , u .Left , u .Right )
489
+ }
490
+
491
+ // Type implements the sql.Expression interface.
492
+ func (u * UASTExtract ) Type () sql.Type {
493
+ return sql .Array (sql .Array (sql .Text ))
494
+ }
495
+
496
+ // Eval implements the sql.Expression interface.
497
+ func (u * UASTExtract ) Eval (ctx * sql.Context , row sql.Row ) (out interface {}, err error ) {
498
+ defer func () {
499
+ if r := recover (); r != nil {
500
+ err = fmt .Errorf ("uast: unknown error: %s" , r )
501
+ }
502
+ }()
503
+
504
+ span , ctx := ctx .Span ("gitbase.UASTExtract" )
505
+ defer span .Finish ()
506
+
507
+ left , err := u .Left .Eval (ctx , row )
508
+ if err != nil {
509
+ return nil , err
510
+ }
511
+
512
+ if left == nil {
513
+ return nil , nil
514
+ }
515
+
516
+ nodes , err := nodesFromBlobArray (left )
517
+ if err != nil {
518
+ return nil , err
519
+ }
520
+
521
+ key , err := exprToString (ctx , u .Right , row )
522
+ if err != nil {
523
+ return nil , err
524
+ }
525
+
526
+ if key == "" {
527
+ return nil , nil
528
+ }
529
+
530
+ extracted := make ([][]string , len (nodes ))
531
+ for i , n := range nodes {
532
+ extracted [i ] = extractInfo (n , key )
533
+ }
534
+
535
+ return extracted , nil
536
+ }
537
+
538
+ const (
539
+ keyType = "@type"
540
+ keyToken = "@token"
541
+ keyRoles = "@role"
542
+ keyStartPos = "@startpos"
543
+ keyEndPos = "@endpos"
544
+ )
545
+
546
+ func extractInfo (n * uast.Node , key string ) []string {
547
+
548
+ info := []string {}
549
+ switch key {
550
+ case keyType :
551
+ info = append (info , n .InternalType )
552
+ case keyToken :
553
+ info = append (info , n .Token )
554
+ case keyRoles :
555
+ roles := make ([]string , len (n .Roles ))
556
+ for i , rol := range n .Roles {
557
+ roles [i ] = rol .String ()
558
+ }
559
+
560
+ info = append (info , roles ... )
561
+ case keyStartPos :
562
+ info = append (info , n .StartPosition .String ())
563
+ case keyEndPos :
564
+ info = append (info , n .EndPosition .String ())
565
+ default :
566
+ v , ok := n .Properties [key ]
567
+ if ok {
568
+ info = append (info , v )
569
+ }
570
+ }
571
+
572
+ return info
573
+ }
574
+
575
+ // TransformUp implements the sql.Expression interface.
576
+ func (u * UASTExtract ) TransformUp (f sql.TransformExprFunc ) (sql.Expression , error ) {
577
+ left , err := u .Left .TransformUp (f )
578
+ if err != nil {
579
+ return nil , err
580
+ }
581
+
582
+ rigth , err := u .Right .TransformUp (f )
583
+ if err != nil {
584
+ return nil , err
585
+ }
586
+
587
+ return f (NewUASTExtract (left , rigth ))
588
+ }
0 commit comments