Skip to content

Commit 2d95638

Browse files
authored
Merge pull request kubernetes#73313 from pivotal-k8s/csi-drivers-list
Refactor csiDriversStore
2 parents ae45068 + 84c4662 commit 2d95638

File tree

6 files changed

+167
-53
lines changed

6 files changed

+167
-53
lines changed

pkg/volume/csi/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ go_library(
66
"csi_attacher.go",
77
"csi_block.go",
88
"csi_client.go",
9+
"csi_drivers_store.go",
910
"csi_mounter.go",
1011
"csi_plugin.go",
1112
"csi_util.go",
@@ -44,6 +45,7 @@ go_test(
4445
"csi_attacher_test.go",
4546
"csi_block_test.go",
4647
"csi_client_test.go",
48+
"csi_drivers_store_test.go",
4749
"csi_mounter_test.go",
4850
"csi_plugin_test.go",
4951
],

pkg/volume/csi/csi_client.go

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -147,19 +147,12 @@ func newCsiDriverClient(driverName csiDriverName) (*csiDriverClient, error) {
147147
addr := fmt.Sprintf(csiAddrTemplate, driverName)
148148
requiresV0Client := true
149149
if utilfeature.DefaultFeatureGate.Enabled(features.KubeletPluginsWatcher) {
150-
var existingDriver csiDriver
151-
driverExists := false
152-
func() {
153-
csiDrivers.RLock()
154-
defer csiDrivers.RUnlock()
155-
existingDriver, driverExists = csiDrivers.driversMap[string(driverName)]
156-
}()
157-
150+
existingDriver, driverExists := csiDrivers.Get(string(driverName))
158151
if !driverExists {
159152
return nil, fmt.Errorf("driver name %s not found in the list of registered CSI drivers", driverName)
160153
}
161154

162-
addr = existingDriver.driverEndpoint
155+
addr = existingDriver.endpoint
163156
requiresV0Client = versionRequiresV0Client(existingDriver.highestSupportedVersion)
164157
}
165158

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
Copyright 2019 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package csi
18+
19+
import (
20+
"sync"
21+
22+
utilversion "k8s.io/apimachinery/pkg/util/version"
23+
)
24+
25+
// Driver is a description of a CSI Driver, defined by an enpoint and the
26+
// highest CSI version supported
27+
type Driver struct {
28+
endpoint string
29+
highestSupportedVersion *utilversion.Version
30+
}
31+
32+
// DriversStore holds a list of CSI Drivers
33+
type DriversStore struct {
34+
store
35+
sync.RWMutex
36+
}
37+
38+
type store map[string]Driver
39+
40+
// Get lets you retrieve a CSI Driver by name.
41+
// This method is protected by a mutex.
42+
func (s *DriversStore) Get(driverName string) (Driver, bool) {
43+
s.RLock()
44+
defer s.RUnlock()
45+
46+
driver, ok := s.store[driverName]
47+
return driver, ok
48+
}
49+
50+
// Set lets you save a CSI Driver to the list and give it a specific name.
51+
// This method is protected by a mutex.
52+
func (s *DriversStore) Set(driverName string, driver Driver) {
53+
s.Lock()
54+
defer s.Unlock()
55+
56+
if s.store == nil {
57+
s.store = store{}
58+
}
59+
60+
s.store[driverName] = driver
61+
}
62+
63+
// Delete lets you delete a CSI Driver by name.
64+
// This method is protected by a mutex.
65+
func (s *DriversStore) Delete(driverName string) {
66+
s.Lock()
67+
defer s.Unlock()
68+
69+
delete(s.store, driverName)
70+
}
71+
72+
// Clear deletes all entries in the store.
73+
// This methiod is protected by a mutex.
74+
func (s *DriversStore) Clear() {
75+
s.Lock()
76+
defer s.Unlock()
77+
78+
s.store = store{}
79+
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
Copyright 2019 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package csi_test
18+
19+
import (
20+
"reflect"
21+
"testing"
22+
23+
"k8s.io/kubernetes/pkg/volume/csi"
24+
)
25+
26+
func TestDriversStore(t *testing.T) {
27+
store := &csi.DriversStore{}
28+
someDriver := csi.Driver{}
29+
30+
expectAbsent(t, store, "does-not-exist")
31+
32+
store.Set("some-driver", someDriver)
33+
expectPresent(t, store, "some-driver", someDriver)
34+
35+
store.Delete("some-driver")
36+
expectAbsent(t, store, "some-driver")
37+
38+
store.Set("some-driver", someDriver)
39+
40+
store.Clear()
41+
expectAbsent(t, store, "some-driver")
42+
}
43+
44+
func expectPresent(t *testing.T, store *csi.DriversStore, name string, expected csi.Driver) {
45+
t.Helper()
46+
47+
retrieved, ok := store.Get(name)
48+
49+
if !ok {
50+
t.Fatalf("expected driver '%s' to exist", name)
51+
}
52+
53+
if !reflect.DeepEqual(retrieved, expected) {
54+
t.Fatalf("expected driver '%s' to be equal to %v", name, expected)
55+
}
56+
}
57+
58+
func expectAbsent(t *testing.T, store *csi.DriversStore, name string) {
59+
t.Helper()
60+
61+
if _, ok := store.Get(name); ok {
62+
t.Fatalf("expected driver '%s' not to exist in store", name)
63+
}
64+
}

pkg/volume/csi/csi_plugin.go

Lines changed: 10 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323
"path"
2424
"sort"
2525
"strings"
26-
"sync"
2726
"time"
2827

2928
"context"
@@ -84,25 +83,14 @@ func ProbeVolumePlugins() []volume.VolumePlugin {
8483
// volume.VolumePlugin methods
8584
var _ volume.VolumePlugin = &csiPlugin{}
8685

87-
type csiDriver struct {
88-
driverName string
89-
driverEndpoint string
90-
highestSupportedVersion *utilversion.Version
91-
}
92-
93-
type csiDriversStore struct {
94-
driversMap map[string]csiDriver
95-
sync.RWMutex
96-
}
97-
9886
// RegistrationHandler is the handler which is fed to the pluginwatcher API.
9987
type RegistrationHandler struct {
10088
}
10189

10290
// TODO (verult) consider using a struct instead of global variables
10391
// csiDrivers map keep track of all registered CSI drivers on the node and their
10492
// corresponding sockets
105-
var csiDrivers csiDriversStore
93+
var csiDrivers = &DriversStore{}
10694

10795
var nim nodeinfomanager.Interface
10896

@@ -141,17 +129,12 @@ func (h *RegistrationHandler) RegisterPlugin(pluginName string, endpoint string,
141129
return err
142130
}
143131

144-
func() {
145-
// Storing endpoint of newly registered CSI driver into the map, where CSI driver name will be the key
146-
// all other CSI components will be able to get the actual socket of CSI drivers by its name.
147-
148-
// It's not necessary to lock the entire RegistrationCallback() function because only the CSI
149-
// client depends on this driver map, and the CSI client does not depend on node information
150-
// updated in the rest of the function.
151-
csiDrivers.Lock()
152-
defer csiDrivers.Unlock()
153-
csiDrivers.driversMap[pluginName] = csiDriver{driverName: pluginName, driverEndpoint: endpoint, highestSupportedVersion: highestSupportedVersion}
154-
}()
132+
// Storing endpoint of newly registered CSI driver into the map, where CSI driver name will be the key
133+
// all other CSI components will be able to get the actual socket of CSI drivers by its name.
134+
csiDrivers.Set(pluginName, Driver{
135+
endpoint: endpoint,
136+
highestSupportedVersion: highestSupportedVersion,
137+
})
155138

156139
// Get node info from the driver.
157140
csi, err := newCsiDriverClient(csiDriverName(pluginName))
@@ -200,15 +183,7 @@ func (h *RegistrationHandler) validateVersions(callerName, pluginName string, en
200183
return nil, err
201184
}
202185

203-
// Check for existing drivers with the same name
204-
var existingDriver csiDriver
205-
driverExists := false
206-
func() {
207-
csiDrivers.RLock()
208-
defer csiDrivers.RUnlock()
209-
existingDriver, driverExists = csiDrivers.driversMap[pluginName]
210-
}()
211-
186+
existingDriver, driverExists := csiDrivers.Get(pluginName)
212187
if driverExists {
213188
if !existingDriver.highestSupportedVersion.LessThan(newDriverHighestVersion) {
214189
err := fmt.Errorf("%s for CSI driver %q failed. Another driver with the same name is already registered with a higher supported version: %q", callerName, pluginName, existingDriver.highestSupportedVersion)
@@ -245,8 +220,7 @@ func (p *csiPlugin) Init(host volume.VolumeHost) error {
245220
}
246221
}
247222

248-
// Initializing csiDrivers map and label management channels
249-
csiDrivers = csiDriversStore{driversMap: map[string]csiDriver{}}
223+
// Initializing the label management channels
250224
nim = nodeinfomanager.NewNodeInfoManager(host.GetNodeName(), host)
251225

252226
// TODO(#70514) Init CSINodeInfo object if the CRD exists and create Driver
@@ -657,11 +631,7 @@ func (p *csiPlugin) getPublishContext(client clientset.Interface, handle, driver
657631
}
658632

659633
func unregisterDriver(driverName string) error {
660-
func() {
661-
csiDrivers.Lock()
662-
defer csiDrivers.Unlock()
663-
delete(csiDrivers.driversMap, driverName)
664-
}()
634+
csiDrivers.Delete(driverName)
665635

666636
if err := nim.UninstallCSIDriver(driverName); err != nil {
667637
klog.Errorf("Error uninstalling CSI driver: %v", err)

pkg/volume/csi/csi_plugin_test.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,16 @@ func makeTestPV(name string, sizeGig int, driverName, volID string) *api.Persist
105105
}
106106

107107
func registerFakePlugin(pluginName, endpoint string, versions []string, t *testing.T) {
108-
csiDrivers = csiDriversStore{driversMap: map[string]csiDriver{}}
109108
highestSupportedVersions, err := highestSupportedVersion(versions)
110109
if err != nil {
111110
t.Fatalf("unexpected error parsing versions (%v) for pluginName % q endpoint %q: %#v", versions, pluginName, endpoint, err)
112111
}
113112

114-
csiDrivers.driversMap[pluginName] = csiDriver{driverName: pluginName, driverEndpoint: endpoint, highestSupportedVersion: highestSupportedVersions}
113+
csiDrivers.Clear()
114+
csiDrivers.Set(pluginName, Driver{
115+
endpoint: endpoint,
116+
highestSupportedVersion: highestSupportedVersions,
117+
})
115118
}
116119

117120
func TestPluginGetPluginName(t *testing.T) {
@@ -839,13 +842,16 @@ func TestValidatePluginExistingDriver(t *testing.T) {
839842

840843
for _, tc := range testCases {
841844
// Arrange & Act
842-
csiDrivers = csiDriversStore{driversMap: map[string]csiDriver{}}
843845
highestSupportedVersions1, err := highestSupportedVersion(tc.versions1)
844846
if err != nil {
845847
t.Fatalf("unexpected error parsing version for testcase: %#v", tc)
846848
}
847849

848-
csiDrivers.driversMap[tc.pluginName1] = csiDriver{driverName: tc.pluginName1, driverEndpoint: tc.endpoint1, highestSupportedVersion: highestSupportedVersions1}
850+
csiDrivers.Clear()
851+
csiDrivers.Set(tc.pluginName1, Driver{
852+
endpoint: tc.endpoint1,
853+
highestSupportedVersion: highestSupportedVersions1,
854+
})
849855

850856
// Arrange & Act
851857
err = PluginHandler.ValidatePlugin(tc.pluginName2, tc.endpoint2, tc.versions2, tc.foundInDeprecatedDir2)

0 commit comments

Comments
 (0)