1
1
package rule
2
2
3
3
import (
4
+ "fmt"
5
+ "reflect"
6
+
4
7
"github.com/src-d/gitbase"
5
8
errors "gopkg.in/src-d/go-errors.v1"
6
9
"gopkg.in/src-d/go-mysql-server.v0/sql"
@@ -50,7 +53,7 @@ func SquashJoins(
50
53
return n , nil
51
54
}
52
55
53
- return buildSquashedTable (t .tables , t .filters , t .columns , t .indexes )
56
+ return buildSquashedTable (a , t .tables , t .filters , t .columns , t .indexes )
54
57
})
55
58
56
59
if err != nil {
@@ -214,7 +217,13 @@ func rearrange(join *plan.InnerJoin, squashedTable *joinedTables) sql.Node {
214
217
215
218
var errInvalidIteratorChain = errors .NewKind ("invalid iterator to chain with %s: %T" )
216
219
220
+ type unsquashableTable struct {
221
+ table sql.Table
222
+ filters []sql.Expression
223
+ }
224
+
217
225
func buildSquashedTable (
226
+ a * analyzer.Analyzer ,
218
227
tables []sql.Table ,
219
228
filters []sql.Expression ,
220
229
columns []string ,
@@ -229,9 +238,28 @@ func buildSquashedTable(
229
238
index = idx
230
239
}
231
240
241
+ tablesByName := make (map [string ]sql.Table )
242
+ for _ , t := range tables {
243
+ tablesByName [t .Name ()] = t
244
+ }
245
+
246
+ var usedTables []string
247
+ var unsquashable []unsquashableTable
248
+ addUnsquashable := func (tableName string ) {
249
+ var f []sql.Expression
250
+ f , filters = filtersForTables (filters , usedTables ... )
251
+ unsquashable = append (unsquashable , unsquashableTable {
252
+ table : tablesByName [tableName ],
253
+ filters : f ,
254
+ })
255
+ }
256
+
232
257
var iter gitbase.ChainableIter
258
+ var squashedTables []string
259
+
233
260
var err error
234
261
for _ , t := range tableNames {
262
+ usedTables = append (usedTables , t )
235
263
switch t {
236
264
case gitbase .RepositoriesTableName :
237
265
switch iter .(type ) {
@@ -245,6 +273,7 @@ func buildSquashedTable(
245
273
if err != nil {
246
274
return nil , err
247
275
}
276
+
248
277
iter = gitbase .NewAllReposIter (f )
249
278
default :
250
279
return nil , errInvalidIteratorChain .New ("repositories" , iter )
@@ -274,9 +303,11 @@ func buildSquashedTable(
274
303
if err != nil {
275
304
return nil , err
276
305
}
306
+
277
307
iter = gitbase .NewAllRemotesIter (f )
278
308
default :
279
- return nil , errInvalidIteratorChain .New ("remotes" , iter )
309
+ addUnsquashable (gitbase .RemotesTableName )
310
+ continue
280
311
}
281
312
case gitbase .ReferencesTableName :
282
313
switch it := iter .(type ) {
@@ -323,7 +354,8 @@ func buildSquashedTable(
323
354
iter = gitbase .NewIndexRefsIter (f , index )
324
355
}
325
356
default :
326
- return nil , errInvalidIteratorChain .New ("refs" , iter )
357
+ addUnsquashable (gitbase .ReferencesTableName )
358
+ continue
327
359
}
328
360
case gitbase .RefCommitsTableName :
329
361
switch it := iter .(type ) {
@@ -375,7 +407,8 @@ func buildSquashedTable(
375
407
iter = gitbase .NewAllRefCommitsIter (f )
376
408
}
377
409
default :
378
- return nil , errInvalidIteratorChain .New ("ref_commits" , iter )
410
+ addUnsquashable (gitbase .RefCommitsTableName )
411
+ continue
379
412
}
380
413
case gitbase .CommitsTableName :
381
414
switch it := iter .(type ) {
@@ -435,7 +468,8 @@ func buildSquashedTable(
435
468
iter = gitbase .NewAllCommitsIter (f , false )
436
469
}
437
470
default :
438
- return nil , errInvalidIteratorChain .New ("commits" , iter )
471
+ addUnsquashable (gitbase .CommitsTableName )
472
+ continue
439
473
}
440
474
case gitbase .CommitTreesTableName :
441
475
switch it := iter .(type ) {
@@ -517,7 +551,8 @@ func buildSquashedTable(
517
551
iter = gitbase .NewAllCommitTreesIter (f )
518
552
}
519
553
default :
520
- return nil , errInvalidIteratorChain .New ("commit_trees" , iter )
554
+ addUnsquashable (gitbase .CommitTreesTableName )
555
+ continue
521
556
}
522
557
case gitbase .CommitBlobsTableName :
523
558
switch it := iter .(type ) {
@@ -580,7 +615,8 @@ func buildSquashedTable(
580
615
iter = gitbase .NewAllCommitBlobsIter (f )
581
616
}
582
617
default :
583
- return nil , errInvalidIteratorChain .New ("commit_blobs" , iter )
618
+ addUnsquashable (gitbase .CommitBlobsTableName )
619
+ continue
584
620
}
585
621
case gitbase .TreeEntriesTableName :
586
622
switch it := iter .(type ) {
@@ -641,7 +677,8 @@ func buildSquashedTable(
641
677
iter = gitbase .NewAllTreeEntriesIter (f )
642
678
}
643
679
default :
644
- return nil , errInvalidIteratorChain .New ("tree_entries" , iter )
680
+ addUnsquashable (gitbase .TreeEntriesTableName )
681
+ continue
645
682
}
646
683
case gitbase .BlobsTableName :
647
684
readContent := stringInSlice (columns , "blob_content" )
@@ -687,7 +724,8 @@ func buildSquashedTable(
687
724
688
725
iter = gitbase .NewTreeEntryBlobsIter (it , f , readContent )
689
726
default :
690
- return nil , errInvalidIteratorChain .New ("blobs" , iter )
727
+ addUnsquashable (gitbase .BlobsTableName )
728
+ continue
691
729
}
692
730
case gitbase .CommitFilesTableName :
693
731
switch it := iter .(type ) {
@@ -734,7 +772,8 @@ func buildSquashedTable(
734
772
iter = gitbase .NewAllCommitFilesIter (f )
735
773
}
736
774
default :
737
- return nil , errInvalidIteratorChain .New ("commit_files" , iter )
775
+ addUnsquashable (gitbase .CommitFilesTableName )
776
+ continue
738
777
}
739
778
case gitbase .FilesTableName :
740
779
readContent := stringInSlice (columns , "blob_content" )
@@ -754,9 +793,12 @@ func buildSquashedTable(
754
793
755
794
iter = gitbase .NewCommitFileFilesIter (it , f , readContent )
756
795
default :
757
- return nil , errInvalidIteratorChain .New ("files" , iter )
796
+ addUnsquashable (gitbase .FilesTableName )
797
+ continue
758
798
}
759
799
}
800
+
801
+ squashedTables = append (squashedTables , t )
760
802
}
761
803
762
804
var originalSchema sql.Schema
@@ -771,17 +813,61 @@ func buildSquashedTable(
771
813
indexedTables = []string {firstTable }
772
814
}
773
815
816
+ var squashMapping = mapping
817
+ if len (unsquashable ) > 0 {
818
+ squashMapping = nil
819
+ }
820
+
821
+ var nonSquashedFilters []sql.Expression
822
+ for _ , t := range unsquashable {
823
+ nonSquashedFilters = append (nonSquashedFilters , t .filters ... )
824
+ }
825
+ nonSquashedFilters = append (nonSquashedFilters , filters ... )
826
+ squashedFilters := filterDiff (allFilters , nonSquashedFilters )
827
+
774
828
table := gitbase .NewSquashedTable (
775
829
iter ,
776
- mapping ,
777
- allFilters ,
830
+ squashMapping ,
831
+ squashedFilters ,
778
832
indexedTables ,
779
- tableNames ... ,
833
+ squashedTables ... ,
780
834
)
781
835
var node sql.Node = plan .NewResolvedTable (table )
782
836
837
+ if len (unsquashable ) > 0 {
838
+ for _ , t := range unsquashable {
839
+ var table sql.Node = plan .NewResolvedTable (t .table )
840
+ if a .Parallelism > 1 {
841
+ table = plan .NewExchange (a .Parallelism , table )
842
+ }
843
+
844
+ if len (t .filters ) > 0 {
845
+ f , err := fixFieldIndexes (
846
+ expression .JoinAnd (t .filters ... ),
847
+ append (t .table .Schema (), node .Schema ()... ),
848
+ )
849
+ if err != nil {
850
+ return nil , err
851
+ }
852
+
853
+ node = plan .NewInnerJoin (
854
+ table ,
855
+ node ,
856
+ f ,
857
+ )
858
+ } else {
859
+ node = plan .NewCrossJoin (table , node )
860
+ }
861
+ }
862
+
863
+ node , err = projectSchema (node , originalSchema )
864
+ if err != nil {
865
+ return nil , err
866
+ }
867
+ }
868
+
783
869
if len (filters ) > 0 {
784
- f , err := fixFieldIndexes (expression .JoinAnd (filters ... ), iter . Schema () )
870
+ f , err := fixFieldIndexes (expression .JoinAnd (filters ... ), originalSchema )
785
871
if err != nil {
786
872
return nil , err
787
873
}
@@ -791,6 +877,39 @@ func buildSquashedTable(
791
877
return node , nil
792
878
}
793
879
880
+ var errUnsquashableFieldNotFound = errors .NewKind ("unable to unsquash table, column %s.%s not found" )
881
+
882
+ // projectSchema wraps the node in a Project node that has the same schema as
883
+ // the one provided.
884
+ func projectSchema (node sql.Node , schema sql.Schema ) (sql.Node , error ) {
885
+ if node .Schema ().Equals (schema ) {
886
+ return node , nil
887
+ }
888
+
889
+ var columnIndexes = make (map [string ]int )
890
+ for i , col := range node .Schema () {
891
+ columnIndexes [fmt .Sprintf ("%s.%s" , col .Source , col .Name )] = i + 1
892
+ }
893
+
894
+ var project = make ([]sql.Expression , len (schema ))
895
+ for i , col := range schema {
896
+ idx := columnIndexes [fmt .Sprintf ("%s.%s" , col .Source , col .Name )]
897
+ if idx <= 0 {
898
+ return nil , errUnsquashableFieldNotFound .New (col .Source , col .Name )
899
+ }
900
+
901
+ project [i ] = expression .NewGetFieldWithTable (
902
+ idx - 1 ,
903
+ col .Type ,
904
+ col .Source ,
905
+ col .Name ,
906
+ col .Nullable ,
907
+ )
908
+ }
909
+
910
+ return plan .NewProject (project , node ), nil
911
+ }
912
+
794
913
// buildSchemaMapping returns a mapping to convert the actual schema into the
795
914
// original schema. If both schemas are equal, nil will be returned.
796
915
func buildSchemaMapping (original , actual sql.Schema ) []int {
@@ -1409,3 +1528,23 @@ func fixFieldIndexes(e sql.Expression, schema sql.Schema) (sql.Expression, error
1409
1528
return nil , analyzer .ErrColumnTableNotFound .New (gf .Table (), gf .Name ())
1410
1529
})
1411
1530
}
1531
+
1532
+ func filterDiff (a , b []sql.Expression ) []sql.Expression {
1533
+ var result []sql.Expression
1534
+
1535
+ for _ , fa := range a {
1536
+ var found bool
1537
+ for _ , fb := range b {
1538
+ if reflect .DeepEqual (fa , fb ) {
1539
+ found = true
1540
+ break
1541
+ }
1542
+ }
1543
+
1544
+ if ! found {
1545
+ result = append (result , fa )
1546
+ }
1547
+ }
1548
+
1549
+ return result
1550
+ }
0 commit comments