11package rule
22
33import (
4+ "fmt"
5+ "reflect"
6+
47 "github.com/src-d/gitbase"
58 errors "gopkg.in/src-d/go-errors.v1"
69 "gopkg.in/src-d/go-mysql-server.v0/sql"
@@ -50,7 +53,7 @@ func SquashJoins(
5053 return n , nil
5154 }
5255
53- return buildSquashedTable (t .tables , t .filters , t .columns , t .indexes )
56+ return buildSquashedTable (a , t .tables , t .filters , t .columns , t .indexes )
5457 })
5558
5659 if err != nil {
@@ -214,7 +217,13 @@ func rearrange(join *plan.InnerJoin, squashedTable *joinedTables) sql.Node {
214217
215218var errInvalidIteratorChain = errors .NewKind ("invalid iterator to chain with %s: %T" )
216219
220+ type unsquashableTable struct {
221+ table sql.Table
222+ filters []sql.Expression
223+ }
224+
217225func buildSquashedTable (
226+ a * analyzer.Analyzer ,
218227 tables []sql.Table ,
219228 filters []sql.Expression ,
220229 columns []string ,
@@ -229,9 +238,28 @@ func buildSquashedTable(
229238 index = idx
230239 }
231240
241+ tablesByName := make (map [string ]sql.Table )
242+ for _ , t := range tables {
243+ tablesByName [t .Name ()] = t
244+ }
245+
246+ var usedTables []string
247+ var unsquashable []unsquashableTable
248+ addUnsquashable := func (tableName string ) {
249+ var f []sql.Expression
250+ f , filters = filtersForTables (filters , usedTables ... )
251+ unsquashable = append (unsquashable , unsquashableTable {
252+ table : tablesByName [tableName ],
253+ filters : f ,
254+ })
255+ }
256+
232257 var iter gitbase.ChainableIter
258+ var squashedTables []string
259+
233260 var err error
234261 for _ , t := range tableNames {
262+ usedTables = append (usedTables , t )
235263 switch t {
236264 case gitbase .RepositoriesTableName :
237265 switch iter .(type ) {
@@ -245,6 +273,7 @@ func buildSquashedTable(
245273 if err != nil {
246274 return nil , err
247275 }
276+
248277 iter = gitbase .NewAllReposIter (f )
249278 default :
250279 return nil , errInvalidIteratorChain .New ("repositories" , iter )
@@ -274,9 +303,11 @@ func buildSquashedTable(
274303 if err != nil {
275304 return nil , err
276305 }
306+
277307 iter = gitbase .NewAllRemotesIter (f )
278308 default :
279- return nil , errInvalidIteratorChain .New ("remotes" , iter )
309+ addUnsquashable (gitbase .RemotesTableName )
310+ continue
280311 }
281312 case gitbase .ReferencesTableName :
282313 switch it := iter .(type ) {
@@ -323,7 +354,8 @@ func buildSquashedTable(
323354 iter = gitbase .NewIndexRefsIter (f , index )
324355 }
325356 default :
326- return nil , errInvalidIteratorChain .New ("refs" , iter )
357+ addUnsquashable (gitbase .ReferencesTableName )
358+ continue
327359 }
328360 case gitbase .RefCommitsTableName :
329361 switch it := iter .(type ) {
@@ -375,7 +407,8 @@ func buildSquashedTable(
375407 iter = gitbase .NewAllRefCommitsIter (f )
376408 }
377409 default :
378- return nil , errInvalidIteratorChain .New ("ref_commits" , iter )
410+ addUnsquashable (gitbase .RefCommitsTableName )
411+ continue
379412 }
380413 case gitbase .CommitsTableName :
381414 switch it := iter .(type ) {
@@ -435,7 +468,8 @@ func buildSquashedTable(
435468 iter = gitbase .NewAllCommitsIter (f , false )
436469 }
437470 default :
438- return nil , errInvalidIteratorChain .New ("commits" , iter )
471+ addUnsquashable (gitbase .CommitsTableName )
472+ continue
439473 }
440474 case gitbase .CommitTreesTableName :
441475 switch it := iter .(type ) {
@@ -517,7 +551,8 @@ func buildSquashedTable(
517551 iter = gitbase .NewAllCommitTreesIter (f )
518552 }
519553 default :
520- return nil , errInvalidIteratorChain .New ("commit_trees" , iter )
554+ addUnsquashable (gitbase .CommitTreesTableName )
555+ continue
521556 }
522557 case gitbase .CommitBlobsTableName :
523558 switch it := iter .(type ) {
@@ -580,7 +615,8 @@ func buildSquashedTable(
580615 iter = gitbase .NewAllCommitBlobsIter (f )
581616 }
582617 default :
583- return nil , errInvalidIteratorChain .New ("commit_blobs" , iter )
618+ addUnsquashable (gitbase .CommitBlobsTableName )
619+ continue
584620 }
585621 case gitbase .TreeEntriesTableName :
586622 switch it := iter .(type ) {
@@ -641,7 +677,8 @@ func buildSquashedTable(
641677 iter = gitbase .NewAllTreeEntriesIter (f )
642678 }
643679 default :
644- return nil , errInvalidIteratorChain .New ("tree_entries" , iter )
680+ addUnsquashable (gitbase .TreeEntriesTableName )
681+ continue
645682 }
646683 case gitbase .BlobsTableName :
647684 readContent := stringInSlice (columns , "blob_content" )
@@ -687,7 +724,8 @@ func buildSquashedTable(
687724
688725 iter = gitbase .NewTreeEntryBlobsIter (it , f , readContent )
689726 default :
690- return nil , errInvalidIteratorChain .New ("blobs" , iter )
727+ addUnsquashable (gitbase .BlobsTableName )
728+ continue
691729 }
692730 case gitbase .CommitFilesTableName :
693731 switch it := iter .(type ) {
@@ -734,7 +772,8 @@ func buildSquashedTable(
734772 iter = gitbase .NewAllCommitFilesIter (f )
735773 }
736774 default :
737- return nil , errInvalidIteratorChain .New ("commit_files" , iter )
775+ addUnsquashable (gitbase .CommitFilesTableName )
776+ continue
738777 }
739778 case gitbase .FilesTableName :
740779 readContent := stringInSlice (columns , "blob_content" )
@@ -754,9 +793,12 @@ func buildSquashedTable(
754793
755794 iter = gitbase .NewCommitFileFilesIter (it , f , readContent )
756795 default :
757- return nil , errInvalidIteratorChain .New ("files" , iter )
796+ addUnsquashable (gitbase .FilesTableName )
797+ continue
758798 }
759799 }
800+
801+ squashedTables = append (squashedTables , t )
760802 }
761803
762804 var originalSchema sql.Schema
@@ -771,17 +813,61 @@ func buildSquashedTable(
771813 indexedTables = []string {firstTable }
772814 }
773815
816+ var squashMapping = mapping
817+ if len (unsquashable ) > 0 {
818+ squashMapping = nil
819+ }
820+
821+ var nonSquashedFilters []sql.Expression
822+ for _ , t := range unsquashable {
823+ nonSquashedFilters = append (nonSquashedFilters , t .filters ... )
824+ }
825+ nonSquashedFilters = append (nonSquashedFilters , filters ... )
826+ squashedFilters := filterDiff (allFilters , nonSquashedFilters )
827+
774828 table := gitbase .NewSquashedTable (
775829 iter ,
776- mapping ,
777- allFilters ,
830+ squashMapping ,
831+ squashedFilters ,
778832 indexedTables ,
779- tableNames ... ,
833+ squashedTables ... ,
780834 )
781835 var node sql.Node = plan .NewResolvedTable (table )
782836
837+ if len (unsquashable ) > 0 {
838+ for _ , t := range unsquashable {
839+ var table sql.Node = plan .NewResolvedTable (t .table )
840+ if a .Parallelism > 1 {
841+ table = plan .NewExchange (a .Parallelism , table )
842+ }
843+
844+ if len (t .filters ) > 0 {
845+ f , err := fixFieldIndexes (
846+ expression .JoinAnd (t .filters ... ),
847+ append (t .table .Schema (), node .Schema ()... ),
848+ )
849+ if err != nil {
850+ return nil , err
851+ }
852+
853+ node = plan .NewInnerJoin (
854+ table ,
855+ node ,
856+ f ,
857+ )
858+ } else {
859+ node = plan .NewCrossJoin (table , node )
860+ }
861+ }
862+
863+ node , err = projectSchema (node , originalSchema )
864+ if err != nil {
865+ return nil , err
866+ }
867+ }
868+
783869 if len (filters ) > 0 {
784- f , err := fixFieldIndexes (expression .JoinAnd (filters ... ), iter . Schema () )
870+ f , err := fixFieldIndexes (expression .JoinAnd (filters ... ), originalSchema )
785871 if err != nil {
786872 return nil , err
787873 }
@@ -791,6 +877,39 @@ func buildSquashedTable(
791877 return node , nil
792878}
793879
880+ var errUnsquashableFieldNotFound = errors .NewKind ("unable to unsquash table, column %s.%s not found" )
881+
882+ // projectSchema wraps the node in a Project node that has the same schema as
883+ // the one provided.
884+ func projectSchema (node sql.Node , schema sql.Schema ) (sql.Node , error ) {
885+ if node .Schema ().Equals (schema ) {
886+ return node , nil
887+ }
888+
889+ var columnIndexes = make (map [string ]int )
890+ for i , col := range node .Schema () {
891+ columnIndexes [fmt .Sprintf ("%s.%s" , col .Source , col .Name )] = i + 1
892+ }
893+
894+ var project = make ([]sql.Expression , len (schema ))
895+ for i , col := range schema {
896+ idx := columnIndexes [fmt .Sprintf ("%s.%s" , col .Source , col .Name )]
897+ if idx <= 0 {
898+ return nil , errUnsquashableFieldNotFound .New (col .Source , col .Name )
899+ }
900+
901+ project [i ] = expression .NewGetFieldWithTable (
902+ idx - 1 ,
903+ col .Type ,
904+ col .Source ,
905+ col .Name ,
906+ col .Nullable ,
907+ )
908+ }
909+
910+ return plan .NewProject (project , node ), nil
911+ }
912+
794913// buildSchemaMapping returns a mapping to convert the actual schema into the
795914// original schema. If both schemas are equal, nil will be returned.
796915func buildSchemaMapping (original , actual sql.Schema ) []int {
@@ -1409,3 +1528,23 @@ func fixFieldIndexes(e sql.Expression, schema sql.Schema) (sql.Expression, error
14091528 return nil , analyzer .ErrColumnTableNotFound .New (gf .Table (), gf .Name ())
14101529 })
14111530}
1531+
1532+ func filterDiff (a , b []sql.Expression ) []sql.Expression {
1533+ var result []sql.Expression
1534+
1535+ for _ , fa := range a {
1536+ var found bool
1537+ for _ , fb := range b {
1538+ if reflect .DeepEqual (fa , fb ) {
1539+ found = true
1540+ break
1541+ }
1542+ }
1543+
1544+ if ! found {
1545+ result = append (result , fa )
1546+ }
1547+ }
1548+
1549+ return result
1550+ }
0 commit comments