@@ -91,8 +91,13 @@ func (s *spi) String() string {
9191 return fmt .Sprintf ("SPI(FWD: 0x%x, REV: 0x%x)" , uint32 (s .forward ), uint32 (s .reverse ))
9292}
9393
94+ type encrNode struct {
95+ spi []spi
96+ count int
97+ }
98+
9499type encrMap struct {
95- nodes map [netip.Addr ][] * spi
100+ nodes map [netip.Addr ]encrNode
96101 sync.Mutex
97102}
98103
@@ -105,7 +110,7 @@ func (e *encrMap) String() string {
105110 b .WriteString (k .String ())
106111 b .WriteString (":" )
107112 b .WriteString ("[" )
108- for _ , s := range v {
113+ for _ , s := range v . spi {
109114 b .WriteString (s .String ())
110115 b .WriteString ("," )
111116 }
@@ -126,10 +131,10 @@ func (d *driver) setupEncryption(remoteIP netip.Addr) error {
126131 }
127132 log .G (context .TODO ()).Debugf ("Programming encryption between %s and %s" , localIP , remoteIP )
128133
129- indices := make ([]* spi , 0 , len (keys ))
134+ indices := make ([]spi , 0 , len (keys ))
130135
131136 for i , k := range keys {
132- spis := & spi {buildSPI (advIP .AsSlice (), remoteIP .AsSlice (), k .tag ), buildSPI (remoteIP .AsSlice (), advIP .AsSlice (), k .tag )}
137+ spis := spi {buildSPI (advIP .AsSlice (), remoteIP .AsSlice (), k .tag ), buildSPI (remoteIP .AsSlice (), advIP .AsSlice (), k .tag )}
133138 dir := reverse
134139 if i == 0 {
135140 dir = bidir
@@ -149,7 +154,10 @@ func (d *driver) setupEncryption(remoteIP netip.Addr) error {
149154 }
150155
151156 d .secMap .Lock ()
152- d .secMap .nodes [remoteIP ] = indices
157+ node := d .secMap .nodes [remoteIP ]
158+ node .spi = indices
159+ node .count ++
160+ d .secMap .nodes [remoteIP ] = node
153161 d .secMap .Unlock ()
154162
155163 return nil
@@ -158,13 +166,20 @@ func (d *driver) setupEncryption(remoteIP netip.Addr) error {
158166func (d * driver ) removeEncryption (remoteIP netip.Addr ) error {
159167 log .G (context .TODO ()).Debugf ("removeEncryption(%s)" , remoteIP )
160168
161- d .secMap .Lock ()
162- indices , ok := d .secMap .nodes [remoteIP ]
163- d .secMap .Unlock ()
164- if ! ok {
169+ spi := func () []spi {
170+ d .secMap .Lock ()
171+ defer d .secMap .Unlock ()
172+ node := d .secMap .nodes [remoteIP ]
173+ if node .count == 1 {
174+ delete (d .secMap .nodes , remoteIP )
175+ return node .spi
176+ }
177+ node .count --
178+ d .secMap .nodes [remoteIP ] = node
165179 return nil
166- }
167- for i , idxs := range indices {
180+ }()
181+
182+ for i , idxs := range spi {
168183 dir := reverse
169184 if i == 0 {
170185 dir = bidir
@@ -263,7 +278,7 @@ func (d *driver) programInput(vni uint32, add bool) error {
263278 return nil
264279}
265280
266- func programSA (localIP , remoteIP net.IP , spi * spi , k * key , dir int , add bool ) (fSA * netlink.XfrmState , rSA * netlink.XfrmState , lastErr error ) {
281+ func programSA (localIP , remoteIP net.IP , spi spi , k * key , dir int , add bool ) (fSA * netlink.XfrmState , rSA * netlink.XfrmState , lastErr error ) {
267282 var (
268283 action = "Removing"
269284 xfrmProgram = ns .NlHandle ().XfrmStateDel
@@ -436,12 +451,12 @@ func buildAeadAlgo(k *key, s int) *netlink.XfrmStateAlgo {
436451 }
437452}
438453
439- func (d * driver ) secMapWalk (f func (netip.Addr , []* spi ) ([]* spi , bool )) error {
454+ func (d * driver ) secMapWalk (f func (netip.Addr , []spi ) ([]spi , bool )) error {
440455 d .secMap .Lock ()
441- for node , indices := range d .secMap .nodes {
442- idxs , stop := f (node , indices )
456+ for rIP , node := range d .secMap .nodes {
457+ idxs , stop := f (rIP , node . spi )
443458 if idxs != nil {
444- d .secMap .nodes [node ] = idxs
459+ d .secMap .nodes [rIP ] = encrNode { idxs , node . count }
445460 }
446461 if stop {
447462 break
@@ -457,7 +472,7 @@ func (d *driver) setKeys(keys []*key) error {
457472 // Accept the encryption keys and clear any stale encryption map
458473 d .Lock ()
459474 d .keys = keys
460- d .secMap = & encrMap {nodes : map [netip.Addr ][] * spi {}}
475+ d .secMap = & encrMap {nodes : map [netip.Addr ]encrNode {}}
461476 d .Unlock ()
462477 log .G (context .TODO ()).Debugf ("Initial encryption keys: %v" , keys )
463478 return nil
@@ -506,7 +521,7 @@ func (d *driver) updateKeys(newKey, primary, pruneKey *key) error {
506521 return types .InvalidParameterErrorf ("attempting to both make a key (index %d) primary and delete it" , priIdx )
507522 }
508523
509- d .secMapWalk (func (rIP netip.Addr , spis []* spi ) ([]* spi , bool ) {
524+ d .secMapWalk (func (rIP netip.Addr , spis []spi ) ([]spi , bool ) {
510525 return updateNodeKey (lIP .AsSlice (), aIP .AsSlice (), rIP .AsSlice (), spis , d .keys , newIdx , priIdx , delIdx ), false
511526 })
512527
@@ -534,15 +549,15 @@ func (d *driver) updateKeys(newKey, primary, pruneKey *key) error {
534549 *********************************************************/
535550
536551// Spis and keys are sorted in such away the one in position 0 is the primary
537- func updateNodeKey (lIP , aIP , rIP net.IP , idxs []* spi , curKeys []* key , newIdx , priIdx , delIdx int ) []* spi {
552+ func updateNodeKey (lIP , aIP , rIP net.IP , idxs []spi , curKeys []* key , newIdx , priIdx , delIdx int ) []spi {
538553 log .G (context .TODO ()).Debugf ("Updating keys for node: %s (%d,%d,%d)" , rIP , newIdx , priIdx , delIdx )
539554
540555 spis := idxs
541556 log .G (context .TODO ()).Debugf ("Current: %v" , spis )
542557
543558 // add new
544559 if newIdx != - 1 {
545- spis = append (spis , & spi {
560+ spis = append (spis , spi {
546561 forward : buildSPI (aIP , rIP , curKeys [newIdx ].tag ),
547562 reverse : buildSPI (rIP , aIP , curKeys [newIdx ].tag ),
548563 })
0 commit comments