Skip to content

Commit 92761da

Browse files
authored
Merge branch 'master' into fix-goroutines-leak
2 parents fb5f893 + 8e89390 commit 92761da

File tree

17 files changed

+462
-86
lines changed

17 files changed

+462
-86
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
* Fixed goroutines leak within topic reader on network problems
2+
3+
## v3.67.0
4+
* Added `ydb.WithNodeAddressMutator` experimental option for mutate node addresses from `discovery.ListEndpoints` response
25
* Added type assertion checks to enhance type safety and prevent unexpected panics in critical sections of the codebase
36

47
## v3.66.3

internal/cmd/gtrace/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import (
2020
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
2121
)
2222

23-
//nolint:gocyclo
23+
//nolint:gocyclo,funlen
2424
func main() {
2525
var (
2626
// Reports whether we were called from go:generate.

internal/cmd/gtrace/writer.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ func (w *Writer) composeHook(hook Hook, t1, t2, dst string) {
372372
w.line(`}`)
373373
}
374374

375+
//nolint:funlen
375376
func (w *Writer) composeHookCall(fn *Func, h1, h2 string) {
376377
w.newScope(func() {
377378
w.capture(h1, h2)

internal/credentials/oauth2.go

Lines changed: 72 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -349,62 +349,89 @@ func (provider *oauth2TokenExchange) getRequestParams() (string, error) {
349349
}
350350

351351
func (provider *oauth2TokenExchange) processTokenExchangeResponse(result *http.Response, now time.Time) error {
352-
var (
353-
data []byte
354-
err error
355-
)
356-
if result.Body != nil {
357-
data, err = io.ReadAll(result.Body)
358-
if err != nil {
359-
return xerrors.WithStackTrace(err)
360-
}
361-
} else {
362-
data = make([]byte, 0)
352+
data, err := readResponseBody(result)
353+
if err != nil {
354+
return err
363355
}
364356

365357
if result.StatusCode != http.StatusOK {
366-
description := result.Status
358+
return provider.handleErrorResponse(result.Status, data)
359+
}
367360

368-
//nolint:tagliatelle
369-
type errorResponse struct {
370-
ErrorName string `json:"error"`
371-
ErrorDescription string `json:"error_description"`
372-
ErrorURI string `json:"error_uri"`
373-
}
374-
var parsedErrorResponse errorResponse
375-
if err := json.Unmarshal(data, &parsedErrorResponse); err != nil {
376-
description += ", could not parse response: " + err.Error()
361+
parsedResponse, err := parseTokenResponse(data)
362+
if err != nil {
363+
return err
364+
}
377365

378-
return xerrors.WithStackTrace(fmt.Errorf("%w: %s", errCouldNotExchangeToken, description))
379-
}
366+
if err := validateTokenResponse(parsedResponse, provider); err != nil {
367+
return err
368+
}
380369

381-
if parsedErrorResponse.ErrorName != "" {
382-
description += ", error: " + parsedErrorResponse.ErrorName
383-
}
370+
provider.updateToken(parsedResponse, now)
384371

385-
if parsedErrorResponse.ErrorDescription != "" {
386-
description += fmt.Sprintf(", description: %q", parsedErrorResponse.ErrorDescription)
387-
}
372+
return nil
373+
}
388374

389-
if parsedErrorResponse.ErrorURI != "" {
390-
description += ", error_uri: " + parsedErrorResponse.ErrorURI
375+
func readResponseBody(result *http.Response) ([]byte, error) {
376+
if result.Body != nil {
377+
data, err := io.ReadAll(result.Body)
378+
if err != nil {
379+
return nil, xerrors.WithStackTrace(err)
391380
}
392381

393-
return xerrors.WithStackTrace(fmt.Errorf("%w: %s", errCouldNotExchangeToken, description))
382+
return data, nil
394383
}
395384

385+
return make([]byte, 0), nil
386+
}
387+
388+
func (provider *oauth2TokenExchange) handleErrorResponse(status string, data []byte) error {
389+
description := status
390+
396391
//nolint:tagliatelle
397-
type response struct {
398-
AccessToken string `json:"access_token"`
399-
TokenType string `json:"token_type"`
400-
ExpiresIn int64 `json:"expires_in"`
401-
Scope string `json:"scope"`
392+
type errorResponse struct {
393+
ErrorName string `json:"error"`
394+
ErrorDescription string `json:"error_description"`
395+
ErrorURI string `json:"error_uri"`
396+
}
397+
var parsedErrorResponse errorResponse
398+
if err := json.Unmarshal(data, &parsedErrorResponse); err != nil {
399+
description += ", could not parse response: " + err.Error()
400+
401+
return xerrors.WithStackTrace(fmt.Errorf("%w: %s", errCouldNotExchangeToken, description))
402402
}
403-
var parsedResponse response
403+
404+
if parsedErrorResponse.ErrorName != "" {
405+
description += ", error: " + parsedErrorResponse.ErrorName
406+
}
407+
if parsedErrorResponse.ErrorDescription != "" {
408+
description += fmt.Sprintf(", description: %q", parsedErrorResponse.ErrorDescription)
409+
}
410+
if parsedErrorResponse.ErrorURI != "" {
411+
description += ", error_uri: " + parsedErrorResponse.ErrorURI
412+
}
413+
414+
return xerrors.WithStackTrace(fmt.Errorf("%w: %s", errCouldNotExchangeToken, description))
415+
}
416+
417+
//nolint:tagliatelle
418+
type tokenResponse struct {
419+
AccessToken string `json:"access_token"`
420+
TokenType string `json:"token_type"`
421+
ExpiresIn int64 `json:"expires_in"`
422+
Scope string `json:"scope"`
423+
}
424+
425+
func parseTokenResponse(data []byte) (*tokenResponse, error) {
426+
var parsedResponse tokenResponse
404427
if err := json.Unmarshal(data, &parsedResponse); err != nil {
405-
return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotParseResponse, err))
428+
return nil, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotParseResponse, err))
406429
}
407430

431+
return &parsedResponse, nil
432+
}
433+
434+
func validateTokenResponse(parsedResponse *tokenResponse, provider *oauth2TokenExchange) error {
408435
if !strings.EqualFold(parsedResponse.TokenType, "bearer") {
409436
return xerrors.WithStackTrace(
410437
fmt.Errorf("%w: %q", errUnsupportedTokenType, parsedResponse.TokenType))
@@ -423,18 +450,17 @@ func (provider *oauth2TokenExchange) processTokenExchangeResponse(result *http.R
423450
}
424451
}
425452

453+
return nil
454+
}
455+
456+
func (provider *oauth2TokenExchange) updateToken(parsedResponse *tokenResponse, now time.Time) {
426457
provider.receivedToken = "Bearer " + parsedResponse.AccessToken
427458

428-
// Expire time
429-
expireDelta := time.Duration(parsedResponse.ExpiresIn)
430-
expireDelta *= time.Second
459+
expireDelta := time.Duration(parsedResponse.ExpiresIn) * time.Second
431460
provider.receivedTokenExpireTime = now.Add(expireDelta)
432461

433-
updateDelta := time.Duration(parsedResponse.ExpiresIn / updateTimeDivider)
434-
updateDelta *= time.Second
462+
updateDelta := time.Duration(parsedResponse.ExpiresIn/updateTimeDivider) * time.Second
435463
provider.updateTokenTime = now.Add(updateDelta)
436-
437-
return nil
438464
}
439465

440466
func (provider *oauth2TokenExchange) exchangeToken(ctx context.Context, now time.Time) error {

internal/credentials/static.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,7 @@ func (c *Static) Token(ctx context.Context) (token string, err error) {
8787
fmt.Errorf("dial failed: %w", err),
8888
)
8989
}
90-
defer func() {
91-
_ = cc.Close()
92-
}()
90+
defer cc.Close()
9391

9492
client := Ydb_Auth_V1.NewAuthServiceClient(cc)
9593

internal/decimal/decimal.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ func handleRemainingDigits(s string, v *big.Int, precision uint32) (*big.Int, er
224224

225225
// Format returns the string representation of x with the given precision and
226226
// scale.
227+
//
228+
//nolint:funlen
227229
func Format(x *big.Int, precision, scale uint32) string {
228230
// Check for special values and nil pointer upfront.
229231
if x == nil {

internal/discovery/config/config.go

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package config
33
import (
44
"time"
55

6+
"github.com/jonboulle/clockwork"
7+
68
"github.com/ydb-platform/ydb-go-sdk/v3/internal/config"
79
"github.com/ydb-platform/ydb-go-sdk/v3/internal/meta"
810
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
@@ -15,10 +17,12 @@ const (
1517
type Config struct {
1618
config.Common
1719

18-
endpoint string
19-
database string
20-
secure bool
21-
meta *meta.Meta
20+
endpoint string
21+
database string
22+
secure bool
23+
meta *meta.Meta
24+
addressMutator func(address string) string
25+
clock clockwork.Clock
2226

2327
interval time.Duration
2428
trace *trace.Discovery
@@ -28,6 +32,10 @@ func New(opts ...Option) *Config {
2832
c := &Config{
2933
interval: DefaultInterval,
3034
trace: &trace.Discovery{},
35+
addressMutator: func(address string) string {
36+
return address
37+
},
38+
clock: clockwork.NewRealClock(),
3139
}
3240
for _, opt := range opts {
3341
if opt != nil {
@@ -38,10 +46,18 @@ func New(opts ...Option) *Config {
3846
return c
3947
}
4048

49+
func (c *Config) MutateAddress(fqdn string) string {
50+
return c.addressMutator(fqdn)
51+
}
52+
4153
func (c *Config) Meta() *meta.Meta {
4254
return c.meta
4355
}
4456

57+
func (c *Config) Clock() clockwork.Clock {
58+
return c.clock
59+
}
60+
4561
func (c *Config) Interval() time.Duration {
4662
return c.interval
4763
}
@@ -85,6 +101,18 @@ func WithDatabase(database string) Option {
85101
}
86102
}
87103

104+
func WithClock(clock clockwork.Clock) Option {
105+
return func(c *Config) {
106+
c.clock = clock
107+
}
108+
}
109+
110+
func WithAddressMutator(addressMutator func(address string) string) Option {
111+
return func(c *Config) {
112+
c.addressMutator = addressMutator
113+
}
114+
}
115+
88116
// WithSecure set flag for secure connection
89117
func WithSecure(ssl bool) Option {
90118
return func(c *Config) {

internal/discovery/discovery.go

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import (
1919
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
2020
)
2121

22+
//go:generate mockgen -destination grpc_client_mock_test.go -package discovery -write_package_comment=false github.com/ydb-platform/ydb-go-genproto/Ydb_Discovery_V1 DiscoveryServiceClient
23+
2224
func New(ctx context.Context, cc grpc.ClientConnInterface, config *config.Config) *Client {
2325
return &Client{
2426
config: config,
@@ -35,65 +37,85 @@ type Client struct {
3537
client Ydb_Discovery_V1.DiscoveryServiceClient
3638
}
3739

38-
// Discover cluster endpoints
39-
func (c *Client) Discover(ctx context.Context) (endpoints []endpoint.Endpoint, err error) {
40+
func discover(
41+
ctx context.Context,
42+
client Ydb_Discovery_V1.DiscoveryServiceClient,
43+
config *config.Config,
44+
) (endpoints []endpoint.Endpoint, location string, err error) {
4045
var (
41-
onDone = trace.DiscoveryOnDiscover(
42-
c.config.Trace(), &ctx,
43-
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/discovery.(*Client).Discover"),
44-
c.config.Endpoint(), c.config.Database(),
45-
)
4646
request = Ydb_Discovery.ListEndpointsRequest{
47-
Database: c.config.Database(),
47+
Database: config.Database(),
4848
}
4949
response *Ydb_Discovery.ListEndpointsResponse
5050
result Ydb_Discovery.ListEndpointsResult
51-
location string
5251
)
53-
defer func() {
54-
nodes := make([]trace.EndpointInfo, 0, len(endpoints))
55-
for _, e := range endpoints {
56-
nodes = append(nodes, e.Copy())
57-
}
58-
onDone(location, nodes, err)
59-
}()
60-
61-
ctx, err = c.config.Meta().Context(ctx)
62-
if err != nil {
63-
return nil, xerrors.WithStackTrace(err)
64-
}
6552

66-
response, err = c.client.ListEndpoints(ctx, &request)
53+
response, err = client.ListEndpoints(ctx, &request)
6754
if err != nil {
68-
return nil, xerrors.WithStackTrace(err)
55+
return nil, location, xerrors.WithStackTrace(err)
6956
}
7057

7158
if response.GetOperation().GetStatus() != Ydb.StatusIds_SUCCESS {
72-
return nil, xerrors.WithStackTrace(
59+
return nil, location, xerrors.WithStackTrace(
7360
xerrors.FromOperation(response.GetOperation()),
7461
)
7562
}
7663

7764
err = response.GetOperation().GetResult().UnmarshalTo(&result)
7865
if err != nil {
79-
return nil, xerrors.WithStackTrace(err)
66+
return nil, location, xerrors.WithStackTrace(err)
8067
}
8168

8269
location = result.GetSelfLocation()
8370
endpoints = make([]endpoint.Endpoint, 0, len(result.GetEndpoints()))
8471
for _, e := range result.GetEndpoints() {
85-
if e.GetSsl() == c.config.Secure() {
72+
if e.GetSsl() == config.Secure() {
8673
endpoints = append(endpoints, endpoint.New(
87-
net.JoinHostPort(e.GetAddress(), strconv.Itoa(int(e.GetPort()))),
74+
net.JoinHostPort(
75+
config.MutateAddress(e.GetAddress()),
76+
strconv.Itoa(int(e.GetPort())),
77+
),
8878
endpoint.WithLocation(e.GetLocation()),
8979
endpoint.WithID(e.GetNodeId()),
9080
endpoint.WithLoadFactor(e.GetLoadFactor()),
9181
endpoint.WithLocalDC(e.GetLocation() == location),
9282
endpoint.WithServices(e.GetService()),
83+
endpoint.WithLastUpdated(config.Clock().Now()),
9384
))
9485
}
9586
}
9687

88+
return endpoints, result.GetSelfLocation(), nil
89+
}
90+
91+
// Discover cluster endpoints
92+
func (c *Client) Discover(ctx context.Context) (endpoints []endpoint.Endpoint, finalErr error) {
93+
var (
94+
onDone = trace.DiscoveryOnDiscover(
95+
c.config.Trace(), &ctx,
96+
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/discovery.(*Client).Discover"),
97+
c.config.Endpoint(), c.config.Database(),
98+
)
99+
location string
100+
)
101+
defer func() {
102+
nodes := make([]trace.EndpointInfo, 0, len(endpoints))
103+
for _, e := range endpoints {
104+
nodes = append(nodes, e.Copy())
105+
}
106+
onDone(location, nodes, finalErr)
107+
}()
108+
109+
ctx, err := c.config.Meta().Context(ctx)
110+
if err != nil {
111+
return nil, xerrors.WithStackTrace(err)
112+
}
113+
114+
endpoints, location, err = discover(ctx, c.client, c.config)
115+
if err != nil {
116+
return nil, xerrors.WithStackTrace(err)
117+
}
118+
97119
return endpoints, nil
98120
}
99121

0 commit comments

Comments
 (0)