Skip to content

Commit 535f52e

Browse files
authored
refactor(auth): refactor GetGRPCTransportCredsAndEndpoint return type to struct (googleapis#11599)
* Embed google.golang.org/grpc/credentials.TransportCredentials interface in new struct type with additional data. refs: googleapis#11588
1 parent ce7299f commit 535f52e

File tree

3 files changed

+23
-17
lines changed

3 files changed

+23
-17
lines changed

auth/grpctransport/grpctransport.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,13 +262,13 @@ func dial(ctx context.Context, secure bool, opts *Options) (*grpc.ClientConn, er
262262
tOpts.EnableDirectPath = io.EnableDirectPath
263263
tOpts.EnableDirectPathXds = io.EnableDirectPathXds
264264
}
265-
transportCreds, endpoint, err := transport.GetGRPCTransportCredsAndEndpoint(tOpts)
265+
transportCreds, err := transport.GetGRPCTransportCredsAndEndpoint(tOpts)
266266
if err != nil {
267267
return nil, err
268268
}
269269

270270
if !secure {
271-
transportCreds = grpcinsecure.NewCredentials()
271+
transportCreds.TransportCredentials = grpcinsecure.NewCredentials()
272272
}
273273

274274
// Initialize gRPC dial options with transport-level security options.
@@ -324,9 +324,8 @@ func dial(ctx context.Context, secure bool, opts *Options) (*grpc.ClientConn, er
324324
clientUniverseDomain: opts.UniverseDomain,
325325
}),
326326
)
327-
328327
// Attempt Direct Path
329-
grpcOpts, endpoint = configureDirectPath(grpcOpts, opts, endpoint, creds)
328+
grpcOpts, transportCreds.Endpoint = configureDirectPath(grpcOpts, opts, transportCreds.Endpoint, creds)
330329
}
331330

332331
// Add tracing, but before the other options, so that clients can override the
@@ -335,7 +334,7 @@ func dial(ctx context.Context, secure bool, opts *Options) (*grpc.ClientConn, er
335334
grpcOpts = addOpenTelemetryStatsHandler(grpcOpts, opts)
336335
grpcOpts = append(grpcOpts, opts.GRPCDialOpts...)
337336

338-
return grpc.Dial(endpoint, grpcOpts...)
337+
return grpc.Dial(transportCreds.Endpoint, grpcOpts...)
339338
}
340339

341340
// grpcKeyProvider satisfies https://pkg.go.dev/google.golang.org/grpc/credentials#PerRPCCredentials.

auth/internal/transport/cba.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,20 @@ func fixScheme(baseURL string) string {
120120
return baseURL
121121
}
122122

123+
// GRPCTransportCredentials embeds interface TransportCredentials with additional data.
124+
type GRPCTransportCredentials struct {
125+
credentials.TransportCredentials
126+
Endpoint string
127+
// TransportType TransportType
128+
}
129+
123130
// GetGRPCTransportCredsAndEndpoint returns an instance of
124131
// [google.golang.org/grpc/credentials.TransportCredentials], and the
125132
// corresponding endpoint to use for GRPC client.
126-
func GetGRPCTransportCredsAndEndpoint(opts *Options) (credentials.TransportCredentials, string, error) {
133+
func GetGRPCTransportCredsAndEndpoint(opts *Options) (*GRPCTransportCredentials, error) {
127134
config, err := getTransportConfig(opts)
128135
if err != nil {
129-
return nil, "", err
136+
return nil, err
130137
}
131138

132139
defaultTransportCreds := credentials.NewTLS(&tls.Config{
@@ -144,13 +151,13 @@ func GetGRPCTransportCredsAndEndpoint(opts *Options) (credentials.TransportCrede
144151
if config.s2aAddress != "" {
145152
s2aAddr = config.s2aAddress
146153
} else {
147-
return defaultTransportCreds, config.endpoint, nil
154+
return &GRPCTransportCredentials{defaultTransportCreds, config.endpoint}, nil
148155
}
149156
}
150157
} else if config.s2aAddress != "" {
151158
s2aAddr = config.s2aAddress
152159
} else {
153-
return defaultTransportCreds, config.endpoint, nil
160+
return &GRPCTransportCredentials{defaultTransportCreds, config.endpoint}, nil
154161
}
155162

156163
var fallbackOpts *s2a.FallbackOptions
@@ -168,9 +175,9 @@ func GetGRPCTransportCredsAndEndpoint(opts *Options) (credentials.TransportCrede
168175
})
169176
if err != nil {
170177
// Use default if we cannot initialize S2A client transport credentials.
171-
return defaultTransportCreds, config.endpoint, nil
178+
return &GRPCTransportCredentials{defaultTransportCreds, config.endpoint}, nil
172179
}
173-
return s2aTransportCreds, config.s2aMTLSEndpoint, nil
180+
return &GRPCTransportCredentials{s2aTransportCreds, config.s2aMTLSEndpoint}, nil
174181
}
175182

176183
// GetHTTPTransportConfig returns a client certificate source and a function for

auth/internal/transport/cba_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -413,9 +413,9 @@ func TestGetGRPCTransportConfigAndEndpoint_S2A(t *testing.T) {
413413
} else {
414414
t.Setenv(googleAPIUseCertSource, "false")
415415
}
416-
_, endpoint, _ := GetGRPCTransportCredsAndEndpoint(tc.opts)
417-
if tc.want != endpoint {
418-
t.Fatalf("want endpoint: %s, got %s", tc.want, endpoint)
416+
transportCreds, _ := GetGRPCTransportCredsAndEndpoint(tc.opts)
417+
if tc.want != transportCreds.Endpoint {
418+
t.Fatalf("want endpoint: %s, got %s", tc.want, transportCreds.Endpoint)
419419
}
420420
})
421421
}
@@ -764,12 +764,12 @@ func TestGetGRPCTransportCredsAndEndpoint_UniverseDomain(t *testing.T) {
764764
} else {
765765
t.Setenv(googleAPIUseCertSource, "false")
766766
}
767-
_, endpoint, err := GetGRPCTransportCredsAndEndpoint(tc.opts)
767+
transportCreds, err := GetGRPCTransportCredsAndEndpoint(tc.opts)
768768
if err != nil {
769769
t.Fatalf("err: %v", err)
770770
} else {
771-
if tc.wantEndpoint != endpoint {
772-
t.Errorf("want endpoint: %s, got %s", tc.wantEndpoint, endpoint)
771+
if tc.wantEndpoint != transportCreds.Endpoint {
772+
t.Errorf("want endpoint: %s, got %s", tc.wantEndpoint, transportCreds.Endpoint)
773773
}
774774
}
775775
})

0 commit comments

Comments
 (0)