Skip to content

Commit b0d6c29

Browse files
committed
refactor: identity server does not need version passed, get rid of serve
move ParseEndpoint into util pkg
1 parent 791a811 commit b0d6c29

File tree

5 files changed

+156
-101
lines changed

5 files changed

+156
-101
lines changed

pkg/driver/driver.go

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@ package driver
66
import (
77
"context"
88
"fmt"
9+
"net"
910

1011
"github.com/container-storage-interface/spec/lib/go/csi"
12+
"google.golang.org/grpc"
1113
"k8s.io/klog/v2"
1214

1315
"github.com/leaseweb/cloudstack-csi-driver/pkg/cloud"
1416
"github.com/leaseweb/cloudstack-csi-driver/pkg/mount"
17+
"github.com/leaseweb/cloudstack-csi-driver/pkg/util"
1518
)
1619

1720
// Interface is the CloudStack CSI driver interface.
@@ -22,7 +25,6 @@ type Interface interface {
2225

2326
type cloudstackDriver struct {
2427
controller csi.ControllerServer
25-
identity csi.IdentityServer
2628
node csi.NodeServer
2729
options *Options
2830
}
@@ -40,7 +42,6 @@ func New(ctx context.Context, csConnector cloud.Interface, options *Options, mou
4042
options: options,
4143
}
4244

43-
driver.identity = NewIdentityServer(driverVersion)
4445
switch options.Mode {
4546
case ControllerMode:
4647
driver.controller = NewControllerServer(csConnector)
@@ -57,7 +58,46 @@ func New(ctx context.Context, csConnector cloud.Interface, options *Options, mou
5758
}
5859

5960
func (cs *cloudstackDriver) Run(ctx context.Context) error {
60-
return cs.serve(ctx)
61+
logger := klog.FromContext(ctx)
62+
scheme, addr, err := util.ParseEndpoint(cs.options.Endpoint)
63+
if err != nil {
64+
return err
65+
}
66+
67+
listener, err := net.Listen(scheme, addr)
68+
if err != nil {
69+
return fmt.Errorf("failed to listen: %w", err)
70+
}
71+
72+
// Log every request and payloads (request + response)
73+
opts := []grpc.ServerOption{
74+
grpc.UnaryInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
75+
resp, err := handler(klog.NewContext(ctx, logger), req)
76+
if err != nil {
77+
logger.Error(err, "GRPC method failed", "method", info.FullMethod)
78+
}
79+
80+
return resp, err
81+
}),
82+
}
83+
grpcServer := grpc.NewServer(opts...)
84+
85+
csi.RegisterIdentityServer(grpcServer, cs)
86+
switch cs.options.Mode {
87+
case ControllerMode:
88+
csi.RegisterControllerServer(grpcServer, cs.controller)
89+
case NodeMode:
90+
csi.RegisterNodeServer(grpcServer, cs.node)
91+
case AllMode:
92+
csi.RegisterControllerServer(grpcServer, cs.controller)
93+
csi.RegisterNodeServer(grpcServer, cs.node)
94+
default:
95+
return fmt.Errorf("unknown mode: %s", cs.options.Mode)
96+
}
97+
98+
logger.Info("Listening for connections", "address", listener.Addr())
99+
100+
return grpcServer.Serve(listener)
61101
}
62102

63103
func validateMode(mode Mode) error {

pkg/driver/identity.go

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,46 +4,28 @@ import (
44
"context"
55

66
"github.com/container-storage-interface/spec/lib/go/csi"
7-
"google.golang.org/grpc/codes"
8-
"google.golang.org/grpc/status"
97
"k8s.io/klog/v2"
108
)
119

12-
type identityServer struct {
13-
csi.UnimplementedIdentityServer
14-
version string
15-
}
16-
17-
// NewIdentityServer creates a new Identity gRPC server.
18-
func NewIdentityServer(version string) csi.IdentityServer {
19-
return &identityServer{
20-
version: version,
21-
}
22-
}
23-
24-
func (ids *identityServer) GetPluginInfo(ctx context.Context, req *csi.GetPluginInfoRequest) (*csi.GetPluginInfoResponse, error) {
10+
func (cs *cloudstackDriver) GetPluginInfo(ctx context.Context, req *csi.GetPluginInfoRequest) (*csi.GetPluginInfoResponse, error) {
2511
logger := klog.FromContext(ctx)
2612
logger.V(6).Info("GetPluginInfo: called", "args", *req)
27-
if ids.version == "" {
28-
return nil, status.Error(codes.Unavailable, "Driver is missing version")
29-
}
30-
3113
resp := &csi.GetPluginInfoResponse{
3214
Name: DriverName,
33-
VendorVersion: ids.version,
15+
VendorVersion: driverVersion,
3416
}
3517

3618
return resp, nil
3719
}
3820

39-
func (ids *identityServer) Probe(ctx context.Context, req *csi.ProbeRequest) (*csi.ProbeResponse, error) {
21+
func (cs *cloudstackDriver) Probe(ctx context.Context, req *csi.ProbeRequest) (*csi.ProbeResponse, error) {
4022
logger := klog.FromContext(ctx)
4123
logger.V(6).Info("Probe: called", "args", *req)
4224

4325
return &csi.ProbeResponse{}, nil
4426
}
4527

46-
func (ids *identityServer) GetPluginCapabilities(ctx context.Context, req *csi.GetPluginCapabilitiesRequest) (*csi.GetPluginCapabilitiesResponse, error) {
28+
func (cs *cloudstackDriver) GetPluginCapabilities(ctx context.Context, req *csi.GetPluginCapabilitiesRequest) (*csi.GetPluginCapabilitiesResponse, error) {
4729
logger := klog.FromContext(ctx)
4830
logger.V(6).Info("Probe: called", "args", *req)
4931

pkg/driver/server.go

Lines changed: 0 additions & 76 deletions
This file was deleted.

pkg/util/endpoint.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package util
2+
3+
import (
4+
"fmt"
5+
"net/url"
6+
"os"
7+
"path/filepath"
8+
"strings"
9+
)
10+
11+
// ParseEndpoint parses the CSI socket endpoint and returns the components scheme and addr.
12+
func ParseEndpoint(endpoint string) (string, string, error) {
13+
u, err := url.Parse(endpoint)
14+
if err != nil {
15+
return "", "", fmt.Errorf("could not parse endpoint: %w", err)
16+
}
17+
18+
addr := filepath.Join(u.Host, filepath.FromSlash(u.Path))
19+
20+
scheme := strings.ToLower(u.Scheme)
21+
switch scheme {
22+
case "tcp":
23+
case "unix":
24+
addr = filepath.Join("/", addr)
25+
// Remove the socket file if it already exists.
26+
if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
27+
return "", "", fmt.Errorf("could not remove unix domain socket %q: %w", addr, err)
28+
}
29+
default:
30+
return "", "", fmt.Errorf("unsupported protocol: %s", scheme)
31+
}
32+
33+
return scheme, addr, nil
34+
}

pkg/util/endpoint_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package util
2+
3+
import (
4+
"errors"
5+
"testing"
6+
)
7+
8+
func TestParseEndpoint(t *testing.T) {
9+
testCases := []struct {
10+
name string
11+
endpoint string
12+
expScheme string
13+
expAddr string
14+
expErr error
15+
}{
16+
{
17+
name: "valid unix endpoint 1",
18+
endpoint: "unix:///csi/csi.sock",
19+
expScheme: "unix",
20+
expAddr: "/csi/csi.sock",
21+
},
22+
{
23+
name: "valid unix endpoint 2",
24+
endpoint: "unix://csi/csi.sock",
25+
expScheme: "unix",
26+
expAddr: "/csi/csi.sock",
27+
},
28+
{
29+
name: "valid unix endpoint 3",
30+
endpoint: "unix:/csi/csi.sock",
31+
expScheme: "unix",
32+
expAddr: "/csi/csi.sock",
33+
},
34+
{
35+
name: "valid tcp endpoint",
36+
endpoint: "tcp:///127.0.0.1/",
37+
expScheme: "tcp",
38+
expAddr: "/127.0.0.1",
39+
},
40+
{
41+
name: "valid tcp endpoint",
42+
endpoint: "tcp:///127.0.0.1",
43+
expScheme: "tcp",
44+
expAddr: "/127.0.0.1",
45+
},
46+
{
47+
name: "invalid endpoint",
48+
endpoint: "http://127.0.0.1",
49+
expErr: errors.New("unsupported protocol: http"),
50+
},
51+
}
52+
53+
for _, tc := range testCases {
54+
t.Run(tc.name, func(t *testing.T) {
55+
scheme, addr, err := ParseEndpoint(tc.endpoint)
56+
57+
if tc.expErr != nil {
58+
if err.Error() != tc.expErr.Error() {
59+
t.Fatalf("Expecting err: expected %v, got %v", tc.expErr, err)
60+
}
61+
} else {
62+
if err != nil {
63+
t.Fatalf("err is not nil. got: %v", err)
64+
}
65+
if scheme != tc.expScheme {
66+
t.Fatalf("scheme mismatches: expected %v, got %v", tc.expScheme, scheme)
67+
}
68+
69+
if addr != tc.expAddr {
70+
t.Fatalf("addr mismatches: expected %v, got %v", tc.expAddr, addr)
71+
}
72+
}
73+
})
74+
}
75+
}

0 commit comments

Comments
 (0)