@@ -85,6 +85,11 @@ func cse(f *Func) {
 		pNum++
 	}
 
+	// Keep a table to remap the memory operand of any memory user which does not have a memory result (such as a regular load)
+	// to some dominating memory operation, skipping the memory defs that do not alias with it.
+	memTable := f.Cache.allocInt32Slice(f.NumValues())
+	defer f.Cache.freeInt32Slice(memTable)
+
 	// Split equivalence classes at points where they have
 	// non-equivalent arguments. Repeat until we can't find any
 	// more splits.
@@ -108,12 +113,23 @@ func cse(f *Func) {
 
 			// Sort by eq class of arguments.
 			slices.SortFunc(e, func(v, w *Value) int {
+				_, idxMem, _, _ := isMemUser(v)
 				for i, a := range v.Args {
-					b := w.Args[i]
-					if valueEqClass[a.ID] < valueEqClass[b.ID] {
+					var aId, bId ID
+					if i != idxMem {
+						b := w.Args[i]
+						aId = a.ID
+						bId = b.ID
+					} else {
+						// A memory user's mem argument may be remapped to allow matching
+						// identical load-like instructions across disjoint stores.
+						aId, _ = getEffectiveMemoryArg(memTable, v)
+						bId, _ = getEffectiveMemoryArg(memTable, w)
+					}
+					if valueEqClass[aId] < valueEqClass[bId] {
 						return -1
 					}
-					if valueEqClass[a.ID] > valueEqClass[b.ID] {
+					if valueEqClass[aId] > valueEqClass[bId] {
 						return +1
 					}
 				}
@@ -126,12 +142,23 @@ func cse(f *Func) {
 				v, w := e[j-1], e[j]
 				// Note: commutative args already correctly ordered by byArgClass.
 				eqArgs := true
+				_, idxMem, _, _ := isMemUser(v)
 				for k, a := range v.Args {
 					if v.Op == OpLocalAddr && k == 1 {
 						continue
 					}
-					b := w.Args[k]
-					if valueEqClass[a.ID] != valueEqClass[b.ID] {
+					var aId, bId ID
+					if k != idxMem {
+						b := w.Args[k]
+						aId = a.ID
+						bId = b.ID
+					} else {
+						// A memory user's mem argument may be remapped to allow matching
+						// identical load-like instructions across disjoint stores.
+						aId, _ = getEffectiveMemoryArg(memTable, v)
+						bId, _ = getEffectiveMemoryArg(memTable, w)
+					}
+					if valueEqClass[aId] != valueEqClass[bId] {
 						eqArgs = false
 						break
 					}
@@ -180,10 +207,19 @@ func cse(f *Func) {
 	defer f.Cache.freeValueSlice(rewrite)
 	for _, e := range partition {
 		slices.SortFunc(e, func(v, w *Value) int {
-			c := cmp.Compare(sdom.domorder(v.Block), sdom.domorder(w.Block))
-			if c != 0 {
+			if c := cmp.Compare(sdom.domorder(v.Block), sdom.domorder(w.Block)); c != 0 {
 				return c
 			}
+			if _, _, _, ok := isMemUser(v); ok {
+				// Additional ordering among the memory users within one block: prefer the earliest
+				// possible value among the set of equivalent values, that is, the one with the lowest
+				// skip count (the lowest number of memory defs skipped until their common def).
+				_, vSkips := getEffectiveMemoryArg(memTable, v)
+				_, wSkips := getEffectiveMemoryArg(memTable, w)
+				if c := cmp.Compare(vSkips, wSkips); c != 0 {
+					return c
+				}
+			}
 			if v.Op == OpLocalAddr {
 				// compare the memory args for OpLocalAddrs in the same block
 				vm := v.Args[1]
@@ -254,7 +290,7 @@ func cse(f *Func) {
 		for _, v := range b.Values {
 			for i, w := range v.Args {
 				if x := rewrite[w.ID]; x != nil {
-					if w.Pos.IsStmt() == src.PosIsStmt {
+					if w.Pos.IsStmt() == src.PosIsStmt && w.Op != OpNilCheck {
 						// about to lose a statement marker, w
 						// w is an input to v; if they're in the same block
 						// and the same line, v is a good-enough new statement boundary.
@@ -420,3 +456,82 @@ func cmpVal(v, w *Value, auxIDs auxmap) types.Cmp {
 
 	return types.CMPeq
 }
+
+// isMemUser reports whether the given instruction only uses its "memory" argument, so we may try to skip memory "defs" that do not alias with its address.
+// On such instructions it returns the index of the pointer argument, the index of the "memory" argument, the access width, and true; otherwise it returns (-1, -1, 0, false).
+func isMemUser(v *Value) (int, int, int64, bool) {
+	switch v.Op {
+	case OpLoad:
+		return 0, 1, v.Type.Size(), true
+	case OpNilCheck:
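+		// A nil check inspects only the pointer itself, not the memory it points to, hence the access width of zero.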
+		return 0, 1, 0, true
+	default:
+		return -1, -1, 0, false
+	}
+}
+
+// isMemDef reports whether the given "memory"-defining instruction's memory destination can be analyzed for aliasing with memory "user" instructions.
+// On such instructions it returns the index of the pointer argument, the index of the "memory" argument, the access width, and true; otherwise it returns (-1, -1, 0, false).
+func isMemDef(v *Value) (int, int, int64, bool) {
+	switch v.Op {
+	case OpStore:
+		return 0, 2, auxToType(v.Aux).Size(), true
+	default:
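+		// Other memory-defining ops (e.g. OpMove, OpZero, calls) are not analyzed here, so they act as barriers for the remapping.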
+		return -1, -1, 0, false
+	}
+}
+
+// The mem table keeps the memTableSkipBits low bits to store the number of skips of the "memory" operand
+// and the remaining bits to store the ID of the destination "memory"-producing instruction.
+const memTableSkipBits = 8
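+
+// For example, an effective memory def with ID 7 reached after skipping two
+// disjoint stores is encoded as 7<<memTableSkipBits | 2.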
+
+// The maximum ID value we are able to store in the memTable; for larger IDs we fall back to the unmodified memory argument.
+const maxId = ID(1<<(31-memTableSkipBits)) - 1
+
+// Return the first possibly-aliasing store along the memory chain starting at v's memory argument
+// and the number of non-aliased stores skipped; results are memoized in memTable.
+func getEffectiveMemoryArg(memTable []int32, v *Value) (ID, uint32) {
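+	// Fast path: return the memoized result if it has already been computed.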
+	if code := uint32(memTable[v.ID]); code != 0 {
+		return ID(code >> memTableSkipBits), code & ((1 << memTableSkipBits) - 1)
+	}
+	if idxPtr, idxMem, width, ok := isMemUser(v); ok {
+		// TODO: We could return early with some predefined value if width == 0.
+		memId := v.Args[idxMem].ID
+		if memId > maxId {
+			return memId, 0
+		}
+		mem, skips := skipDisjointMemDefs(v, idxPtr, idxMem, width)
+		if mem.ID <= maxId {
+			memId = mem.ID
+		} else {
+			skips = 0 // the def's ID doesn't fit in the table, so don't skip
+		}
+		memTable[v.ID] = int32(memId<<memTableSkipBits) | int32(skips)
+		return memId, skips
+	} else {
+		v.Block.Func.Fatalf("expected memory user instruction: %v", v.LongString())
+	}
+	return 0, 0
+}
+
+// Find a memory def that is not trivially disjoint with the user instruction, counting the number
+// of "skips" along the path. Return the corresponding memory def's value and the number of skips.
+func skipDisjointMemDefs(user *Value, idxUserPtr, idxUserMem int, useWidth int64) (*Value, uint32) {
+	usePtr, mem := user.Args[idxUserPtr], user.Args[idxUserMem]
+	const maxSkips = (1 << memTableSkipBits) - 1
+	var skips uint32
+	for skips = 0; skips < maxSkips; skips++ {
+		if idxPtr, idxMem, width, ok := isMemDef(mem); ok {
+			if mem.Args[idxMem].Uses > 50 {
+				// Skipping a memory def with a lot of uses may increase register pressure.
+				break
+			}
+			defPtr := mem.Args[idxPtr]
+			if disjoint(defPtr, width, usePtr, useWidth) {
+				mem = mem.Args[idxMem]
+				continue
+			}
+		}
+		break
+	}
+	return mem, skips
+}
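
A minimal sketch of the pattern this change enables, assuming disjoint can prove that a store to p[1] cannot overlap loads of p[0] (same base pointer, non-overlapping offsets). The second load's memory argument is remapped past the intervening store, so both loads fall into one equivalence class and CSE merges them:

	func sum(p *[2]int64) int64 {
		a := p[0]
		p[1] = 42 // provably disjoint from p[0]
		b := p[0] // effective memory arg remaps past the store; CSE'd with a
		return a + b
	}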