From e62008ecf30d80a6d81b6ce102eed7f0fda12ce8 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 24 Mar 2025 12:00:00 +0000 Subject: [PATCH 01/44] Initial Setup - Core Files --- go.mod | 27 +++++++++++++++++++++++++++ go.sum | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 go.sum diff --git a/go.mod b/go.mod index ad872dd..f1cdd0d 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,30 @@ module github.com/redis-developer/go-redis-entraid go 1.18 + +require ( + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.0-beta.1 + github.com/redis/go-redis/v9 v9.5.3-0.20250331212737-c248425ade4a + github.com/stretchr/testify v1.10.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +require ( + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 + github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.4.1 + github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/google/uuid v1.6.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect + golang.org/x/crypto v0.33.0 // indirect + golang.org/x/net v0.35.0 // indirect + golang.org/x/sys v0.30.0 // indirect + golang.org/x/text v0.22.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..10a19fe --- /dev/null +++ b/go.sum @@ -0,0 +1,45 @@ +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 h1:g0EZJwz7xkXQiZAI5xi9f3WWFYBlX1CPTrR+NDToRkQ= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0/go.mod h1:XCW7KnZet0Opnr7HccfUw1PLc4CjHqpcaxW8DHklNkQ= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.0-beta.1 h1:iw4+KCeCoieuKodp1d5YhAa1TU/GgogCbw8RbGvsfLA= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.0-beta.1/go.mod h1:AP8cDnDTGIVvayqKAhwzpcAyTJosXpvLYNmVFJb98x8= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.2.3 h1:BAUsn6/icUFtvUalVwCO0+hSF7qgU9DwwcEfCvtILtw= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.1 h1:8BKxhZZLX/WosEeoCvWysmKUscfa9v8LIPEEU0JjE2o= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.1/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/keybase/go-keychain v0.0.0-20231219164618-57a3676c3af6 h1:IsMZxCuZqKuao2vNdfD82fjjgPLfyHLpR41Z88viRWs= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.5.3-0.20250331212737-c248425ade4a h1:R5xgk8m+CF7lVE0EGr+tLkT1eM3Zfd39BJfnANQqpKA= +github.com/redis/go-redis/v9 v9.5.3-0.20250331212737-c248425ade4a/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= +golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= +golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 9b0b345cd2498a73ccb1cb08e07a1988b5dda64d Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 25 Mar 2025 12:00:00 +0000 Subject: [PATCH 02/44] Initial Setup - Configuration Files --- .golangci.yml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .golangci.yml diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..ac9f76c --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,5 @@ +version: "2" +linters: + disable: + - depguard + From bdc94c812397dbe1d25f55b9e581d46afaeefa69 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 26 Mar 2025 12:00:00 +0000 Subject: [PATCH 03/44] Core Interfaces --- shared/identity_provider_response.go | 51 ++++ shared/identity_provider_response_test.go | 330 ++++++++++++++++++++++ 2 files changed, 381 insertions(+) create mode 100644 shared/identity_provider_response.go create mode 100644 shared/identity_provider_response_test.go diff --git a/shared/identity_provider_response.go b/shared/identity_provider_response.go new file mode 100644 index 0000000..da88c8a --- /dev/null +++ b/shared/identity_provider_response.go @@ -0,0 +1,51 @@ +package shared + +import ( + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/redis-developer/go-redis-entraid/internal" + "github.com/redis-developer/go-redis-entraid/token" +) + +const ( + // ResponseTypeAuthResult is the type of the auth result. + ResponseTypeAuthResult = "AuthResult" + // ResponseTypeAccessToken is the type of the access token. + ResponseTypeAccessToken = "AccessToken" + // ResponseTypeRawToken is the type of the response when you have a raw string. + ResponseTypeRawToken = "RawToken" +) + +// IdentityProviderResponseParser is an interface that defines the methods for parsing the identity provider response. +// It is used to parse the response from the identity provider and extract the token. +// If not provided, the default implementation will be used. +type IdentityProviderResponseParser interface { + ParseResponse(response IdentityProviderResponse) (*token.Token, error) +} + +// IdentityProviderResponse is an interface that defines the methods for an identity provider authentication result. +// It is used to get the type of the authentication result, the authentication result itself (can be AuthResult or AccessToken), +type IdentityProviderResponse interface { + // Type returns the type of the auth result + Type() string + AuthResult() public.AuthResult + AccessToken() azcore.AccessToken + RawToken() string +} + +// IdentityProvider is an interface that defines the methods for an identity provider. +// It is used to request a token for authentication. +// The identity provider is responsible for providing the raw authentication token. +type IdentityProvider interface { + // RequestToken requests a token from the identity provider. + // It returns the token, the expiration time, and an error if any. + RequestToken() (IdentityProviderResponse, error) +} + +// NewIDPResponse creates a new auth result based on the type provided. +// It returns an IdentityProviderResponse interface. +// Type can be either AuthResult, AccessToken, or RawToken. +// Second argument is the result of the type provided in the first argument. +func NewIDPResponse(responseType string, result interface{}) (IdentityProviderResponse, error) { + return internal.NewIDPResp(responseType, result) +} diff --git a/shared/identity_provider_response_test.go b/shared/identity_provider_response_test.go new file mode 100644 index 0000000..0b5a014 --- /dev/null +++ b/shared/identity_provider_response_test.go @@ -0,0 +1,330 @@ +package shared + +import ( + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/redis-developer/go-redis-entraid/token" + "github.com/stretchr/testify/assert" +) + +// Mock implementations for testing +type mockIDPResponse struct { + responseType string + authResult *public.AuthResult + accessToken *azcore.AccessToken + rawToken string +} + +func (m *mockIDPResponse) Type() string { + return m.responseType +} + +func (m *mockIDPResponse) AuthResult() public.AuthResult { + if m.authResult == nil { + return public.AuthResult{} + } + return *m.authResult +} + +func (m *mockIDPResponse) AccessToken() azcore.AccessToken { + if m.accessToken == nil { + return azcore.AccessToken{} + } + return *m.accessToken +} + +func (m *mockIDPResponse) RawToken() string { + return m.rawToken +} + +type mockIDPParser struct { + parseError error + token *token.Token +} + +func (m *mockIDPParser) ParseResponse(response IdentityProviderResponse) (*token.Token, error) { + if m.parseError != nil { + return nil, m.parseError + } + return m.token, nil +} + +type mockIDP struct { + response IdentityProviderResponse + err error +} + +func (m *mockIDP) RequestToken() (IdentityProviderResponse, error) { + if m.err != nil { + return nil, m.err + } + return m.response, nil +} + +func TestNewIDPResponse(t *testing.T) { + tests := []struct { + name string + responseType string + result interface{} + expectedError string + }{ + { + name: "Valid AuthResult pointer", + responseType: ResponseTypeAuthResult, + result: &public.AuthResult{}, + }, + { + name: "Valid AuthResult value", + responseType: ResponseTypeAuthResult, + result: public.AuthResult{}, + }, + { + name: "Valid AccessToken pointer", + responseType: ResponseTypeAccessToken, + result: &azcore.AccessToken{Token: "test-token"}, + }, + { + name: "Valid AccessToken value", + responseType: ResponseTypeAccessToken, + result: azcore.AccessToken{Token: "test-token"}, + }, + { + name: "Valid RawToken string", + responseType: ResponseTypeRawToken, + result: "test-token", + }, + { + name: "Valid RawToken string pointer", + responseType: ResponseTypeRawToken, + result: stringPtr("test-token"), + }, + { + name: "Nil result", + responseType: ResponseTypeAuthResult, + result: nil, + expectedError: "result cannot be nil", + }, + { + name: "Nil string pointer", + responseType: ResponseTypeRawToken, + result: (*string)(nil), + expectedError: "raw token cannot be nil", + }, + { + name: "Invalid AuthResult type", + responseType: ResponseTypeAuthResult, + result: "not-an-auth-result", + expectedError: "invalid auth result type: expected public.AuthResult or *public.AuthResult", + }, + { + name: "Invalid AccessToken type", + responseType: ResponseTypeAccessToken, + result: "not-an-access-token", + expectedError: "invalid access token type: expected azcore.AccessToken or *azcore.AccessToken", + }, + { + name: "Invalid RawToken type", + responseType: ResponseTypeRawToken, + result: 123, + expectedError: "invalid raw token type: expected string or *string", + }, + { + name: "Invalid response type", + responseType: "InvalidType", + result: "test", + expectedError: "unsupported identity provider response type: InvalidType", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp, err := NewIDPResponse(tt.responseType, tt.result) + + if tt.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + assert.Nil(t, resp) + return + } + + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, tt.responseType, resp.Type()) + + switch tt.responseType { + case ResponseTypeAuthResult: + assert.NotNil(t, resp.AuthResult()) + case ResponseTypeAccessToken: + assert.NotNil(t, resp.AccessToken()) + assert.NotEmpty(t, resp.AccessToken().Token) + case ResponseTypeRawToken: + assert.NotEmpty(t, resp.RawToken()) + } + }) + } +} + +func stringPtr(s string) *string { + return &s +} + +func TestIdentityProviderResponse(t *testing.T) { + now := time.Now() + expires := now.Add(time.Hour) + + authResult := &public.AuthResult{ + AccessToken: "test-access-token", + ExpiresOn: expires, + } + + accessToken := &azcore.AccessToken{ + Token: "test-access-token", + ExpiresOn: expires, + } + + tests := []struct { + name string + response *mockIDPResponse + expectedType string + }{ + { + name: "AuthResult response", + response: &mockIDPResponse{ + responseType: ResponseTypeAuthResult, + authResult: authResult, + }, + expectedType: ResponseTypeAuthResult, + }, + { + name: "AccessToken response", + response: &mockIDPResponse{ + responseType: ResponseTypeAccessToken, + accessToken: accessToken, + }, + expectedType: ResponseTypeAccessToken, + }, + { + name: "RawToken response", + response: &mockIDPResponse{ + responseType: ResponseTypeRawToken, + rawToken: "test-raw-token", + }, + expectedType: ResponseTypeRawToken, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expectedType, tt.response.Type()) + + switch tt.expectedType { + case ResponseTypeAuthResult: + result := tt.response.AuthResult() + assert.Equal(t, authResult.AccessToken, result.AccessToken) + assert.Equal(t, authResult.ExpiresOn, result.ExpiresOn) + case ResponseTypeAccessToken: + token := tt.response.AccessToken() + assert.Equal(t, accessToken.Token, token.Token) + assert.Equal(t, accessToken.ExpiresOn, token.ExpiresOn) + case ResponseTypeRawToken: + assert.Equal(t, "test-raw-token", tt.response.RawToken()) + } + }) + } +} + +func TestIdentityProvider(t *testing.T) { + tests := []struct { + name string + provider *mockIDP + wantErr bool + }{ + { + name: "Successful token request", + provider: &mockIDP{ + response: &mockIDPResponse{ + responseType: ResponseTypeRawToken, + rawToken: "test-token", + }, + }, + wantErr: false, + }, + { + name: "Failed token request", + provider: &mockIDP{ + err: assert.AnError, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + response, err := tt.provider.RequestToken() + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, response) + } else { + assert.NoError(t, err) + assert.NotNil(t, response) + assert.Equal(t, ResponseTypeRawToken, response.Type()) + assert.Equal(t, "test-token", response.RawToken()) + } + }) + } +} + +func TestIdentityProviderResponseParser(t *testing.T) { + now := time.Now() + expires := now.Add(time.Hour) + testToken := token.New("test-user", "test-password", "test-token", expires, now, int64(time.Hour.Seconds())) + + tests := []struct { + name string + parser *mockIDPParser + response IdentityProviderResponse + wantErr bool + wantToken *token.Token + }{ + { + name: "Successful parse", + parser: &mockIDPParser{ + token: testToken, + }, + response: &mockIDPResponse{ + responseType: ResponseTypeRawToken, + rawToken: "test-token", + }, + wantErr: false, + wantToken: testToken, + }, + { + name: "Failed parse", + parser: &mockIDPParser{ + parseError: assert.AnError, + }, + response: &mockIDPResponse{ + responseType: ResponseTypeRawToken, + rawToken: "test-token", + }, + wantErr: true, + wantToken: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, err := tt.parser.ParseResponse(tt.response) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, token) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantToken, token) + } + }) + } +} From 6c1490b68f83878dcd46b6385fc0c1254e831002 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 27 Mar 2025 12:00:00 +0000 Subject: [PATCH 04/44] Testing Infrastructure - Core Setup --- internal/idp_response.go | 110 ++++++++++ internal/idp_response_test.go | 364 ++++++++++++++++++++++++++++++++++ internal/utils.go | 12 ++ internal/utils_test.go | 58 ++++++ 4 files changed, 544 insertions(+) create mode 100644 internal/idp_response.go create mode 100644 internal/idp_response_test.go create mode 100644 internal/utils.go create mode 100644 internal/utils_test.go diff --git a/internal/idp_response.go b/internal/idp_response.go new file mode 100644 index 0000000..457c5cd --- /dev/null +++ b/internal/idp_response.go @@ -0,0 +1,110 @@ +package internal + +import ( + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" +) + +// IDPResp represents a response from an Identity Provider (IDP) +// It can contain either an AuthResult, AccessToken, or a raw token string +type IDPResp struct { + // resultType indicates which type of response this is + resultType string + authResultVal *public.AuthResult + accessTokenVal *azcore.AccessToken + rawTokenVal string +} + +// NewIDPResp creates a new IDPResp with the given values +// It validates the input and ensures the response type matches the provided value +func NewIDPResp(resultType string, result interface{}) (*IDPResp, error) { + if result == nil { + return nil, fmt.Errorf("result cannot be nil") + } + + r := &IDPResp{resultType: resultType} + + switch resultType { + case "AuthResult": + switch v := result.(type) { + case *public.AuthResult: + r.authResultVal = v + case public.AuthResult: + r.authResultVal = &v + default: + return nil, fmt.Errorf("invalid auth result type: expected public.AuthResult or *public.AuthResult, got %T", result) + } + case "AccessToken": + switch v := result.(type) { + case *azcore.AccessToken: + r.accessTokenVal = v + r.rawTokenVal = v.Token + case azcore.AccessToken: + r.accessTokenVal = &v + r.rawTokenVal = v.Token + default: + return nil, fmt.Errorf("invalid access token type: expected azcore.AccessToken or *azcore.AccessToken, got %T", result) + } + case "RawToken": + switch v := result.(type) { + case string: + r.rawTokenVal = v + case *string: + if v == nil { + return nil, fmt.Errorf("raw token cannot be nil") + } + r.rawTokenVal = *v + default: + return nil, fmt.Errorf("invalid raw token type: expected string or *string, got %T", result) + } + default: + return nil, fmt.Errorf("unsupported identity provider response type: %s", resultType) + } + + return r, nil +} + +// Type returns the type of response this IDPResp represents +func (a *IDPResp) Type() string { + return a.resultType +} + +// AuthResult returns the AuthResult if present, or an empty AuthResult if not set +// Use HasAuthResult() to check if the value is actually set +func (a *IDPResp) AuthResult() public.AuthResult { + if a.authResultVal == nil { + return public.AuthResult{} + } + return *a.authResultVal +} + +// HasAuthResult returns true if an AuthResult is set +func (a *IDPResp) HasAuthResult() bool { + return a.authResultVal != nil +} + +// AccessToken returns the AccessToken if present, or an empty AccessToken if not set +// Use HasAccessToken() to check if the value is actually set +func (a *IDPResp) AccessToken() azcore.AccessToken { + if a.accessTokenVal == nil { + return azcore.AccessToken{} + } + return *a.accessTokenVal +} + +// HasAccessToken returns true if an AccessToken is set +func (a *IDPResp) HasAccessToken() bool { + return a.accessTokenVal != nil +} + +// RawToken returns the raw token string +func (a *IDPResp) RawToken() string { + return a.rawTokenVal +} + +// HasRawToken returns true if a raw token is set +func (a *IDPResp) HasRawToken() bool { + return a.rawTokenVal != "" +} diff --git a/internal/idp_response_test.go b/internal/idp_response_test.go new file mode 100644 index 0000000..59f266b --- /dev/null +++ b/internal/idp_response_test.go @@ -0,0 +1,364 @@ +package internal + +import ( + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/stretchr/testify/assert" +) + +func TestIDPResp_Type(t *testing.T) { + tests := []struct { + name string + resultType string + want string + }{ + { + name: "AuthResult type", + resultType: "AuthResult", + want: "AuthResult", + }, + { + name: "Empty type", + resultType: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := &IDPResp{ + resultType: tt.resultType, + } + if got := resp.Type(); got != tt.want { + t.Errorf("IDPResp.Type() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIDPResp_AuthResult(t *testing.T) { + now := time.Now() + authResult := &public.AuthResult{ + AccessToken: "test-token", + ExpiresOn: now, + } + + tests := []struct { + name string + authResult *public.AuthResult + wantToken string + wantExpiresOn time.Time + }{ + { + name: "With AuthResult", + authResult: authResult, + wantToken: "test-token", + wantExpiresOn: now, + }, + { + name: "Nil AuthResult", + authResult: nil, + wantToken: "", + wantExpiresOn: time.Time{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := &IDPResp{ + authResultVal: tt.authResult, + } + got := resp.AuthResult() + if got.AccessToken != tt.wantToken { + t.Errorf("IDPResp.AuthResult().AccessToken = %v, want %v", got.AccessToken, tt.wantToken) + } + if !got.ExpiresOn.Equal(tt.wantExpiresOn) { + t.Errorf("IDPResp.AuthResult().ExpiresOn = %v, want %v", got.ExpiresOn, tt.wantExpiresOn) + } + }) + } +} + +func TestIDPResp_AccessToken(t *testing.T) { + now := time.Now() + accessToken := &azcore.AccessToken{ + Token: "test-token", + ExpiresOn: now, + } + + tests := []struct { + name string + accessToken *azcore.AccessToken + wantToken string + wantExpiresOn time.Time + }{ + { + name: "With AccessToken", + accessToken: accessToken, + wantToken: "test-token", + wantExpiresOn: now, + }, + { + name: "Nil AccessToken", + accessToken: nil, + wantToken: "", + wantExpiresOn: time.Time{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := &IDPResp{ + accessTokenVal: tt.accessToken, + } + got := resp.AccessToken() + if got.Token != tt.wantToken { + t.Errorf("IDPResp.AccessToken().Token = %v, want %v", got.Token, tt.wantToken) + } + if !got.ExpiresOn.Equal(tt.wantExpiresOn) { + t.Errorf("IDPResp.AccessToken().ExpiresOn = %v, want %v", got.ExpiresOn, tt.wantExpiresOn) + } + }) + } +} + +func TestIDPResp_RawToken(t *testing.T) { + tests := []struct { + name string + rawToken string + want string + }{ + { + name: "With RawToken", + rawToken: "test-raw-token", + want: "test-raw-token", + }, + { + name: "Empty RawToken", + rawToken: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := &IDPResp{ + rawTokenVal: tt.rawToken, + } + if got := resp.RawToken(); got != tt.want { + t.Errorf("IDPResp.RawToken() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewIDPResp(t *testing.T) { + tests := []struct { + name string + resultType string + result interface{} + wantErr bool + checkResult func(t *testing.T, resp *IDPResp) + }{ + { + name: "valid AuthResult pointer", + resultType: "AuthResult", + result: &public.AuthResult{ + AccessToken: "test-token", + }, + wantErr: false, + checkResult: func(t *testing.T, resp *IDPResp) { + assert.True(t, resp.HasAuthResult()) + assert.Equal(t, "test-token", resp.AuthResult().AccessToken) + assert.False(t, resp.HasAccessToken()) + assert.False(t, resp.HasRawToken()) + }, + }, + { + name: "valid AuthResult value", + resultType: "AuthResult", + result: public.AuthResult{ + AccessToken: "test-token", + }, + wantErr: false, + checkResult: func(t *testing.T, resp *IDPResp) { + assert.True(t, resp.HasAuthResult()) + assert.Equal(t, "test-token", resp.AuthResult().AccessToken) + }, + }, + { + name: "valid AccessToken pointer", + resultType: "AccessToken", + result: &azcore.AccessToken{ + Token: "test-token", + ExpiresOn: time.Now(), + }, + wantErr: false, + checkResult: func(t *testing.T, resp *IDPResp) { + assert.True(t, resp.HasAccessToken()) + assert.Equal(t, "test-token", resp.AccessToken().Token) + assert.Equal(t, "test-token", resp.RawToken()) + }, + }, + { + name: "valid AccessToken value", + resultType: "AccessToken", + result: azcore.AccessToken{ + Token: "test-token", + ExpiresOn: time.Now(), + }, + wantErr: false, + checkResult: func(t *testing.T, resp *IDPResp) { + assert.True(t, resp.HasAccessToken()) + assert.Equal(t, "test-token", resp.AccessToken().Token) + assert.Equal(t, "test-token", resp.RawToken()) + }, + }, + { + name: "valid RawToken string", + resultType: "RawToken", + result: "test-token", + wantErr: false, + checkResult: func(t *testing.T, resp *IDPResp) { + assert.True(t, resp.HasRawToken()) + assert.Equal(t, "test-token", resp.RawToken()) + assert.False(t, resp.HasAuthResult()) + assert.False(t, resp.HasAccessToken()) + }, + }, + { + name: "valid RawToken string pointer", + resultType: "RawToken", + result: stringPtr("test-token"), + wantErr: false, + checkResult: func(t *testing.T, resp *IDPResp) { + assert.True(t, resp.HasRawToken()) + assert.Equal(t, "test-token", resp.RawToken()) + }, + }, + { + name: "nil result", + resultType: "AuthResult", + result: nil, + wantErr: true, + }, + { + name: "nil RawToken pointer", + resultType: "RawToken", + result: (*string)(nil), + wantErr: true, + }, + { + name: "invalid AuthResult type", + resultType: "AuthResult", + result: "not-an-auth-result", + wantErr: true, + }, + { + name: "invalid AccessToken type", + resultType: "AccessToken", + result: "not-an-access-token", + wantErr: true, + }, + { + name: "invalid RawToken type", + resultType: "RawToken", + result: 123, + wantErr: true, + }, + { + name: "unsupported result type", + resultType: "InvalidType", + result: "test", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewIDPResp(tt.resultType, tt.result) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) + return + } + + assert.NoError(t, err) + assert.NotNil(t, got) + assert.Equal(t, tt.resultType, got.Type()) + + if tt.checkResult != nil { + tt.checkResult(t, got) + } + }) + } +} + +func stringPtr(s string) *string { + return &s +} + +func BenchmarkIDPResp_Type(b *testing.B) { + resp := &IDPResp{ + resultType: "AuthResult", + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + resp.Type() + } +} + +func BenchmarkIDPResp_AuthResult(b *testing.B) { + now := time.Now() + authResult := &public.AuthResult{ + AccessToken: "test-token", + ExpiresOn: now, + } + resp := &IDPResp{ + authResultVal: authResult, + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + resp.AuthResult() + } +} + +func BenchmarkIDPResp_AccessToken(b *testing.B) { + now := time.Now() + accessToken := &azcore.AccessToken{ + Token: "test-token", + ExpiresOn: now, + } + resp := &IDPResp{ + accessTokenVal: accessToken, + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + resp.AccessToken() + } +} + +func BenchmarkIDPResp_RawToken(b *testing.B) { + resp := &IDPResp{ + rawTokenVal: "test-raw-token", + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + resp.RawToken() + } +} + +func BenchmarkNewIDPResp(b *testing.B) { + now := time.Now() + authResult := &public.AuthResult{ + AccessToken: "test-token", + ExpiresOn: now, + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = NewIDPResp("AuthResult", authResult) + } +} diff --git a/internal/utils.go b/internal/utils.go new file mode 100644 index 0000000..ba82f1c --- /dev/null +++ b/internal/utils.go @@ -0,0 +1,12 @@ +package internal + +// IsClosed checks if a channel is closed. +func IsClosed(ch <-chan struct{}) bool { + select { + case <-ch: + return true + default: + } + + return false +} diff --git a/internal/utils_test.go b/internal/utils_test.go new file mode 100644 index 0000000..e80ccb0 --- /dev/null +++ b/internal/utils_test.go @@ -0,0 +1,58 @@ +package internal + +import "testing" + +func TestIsClosedWithNilChannel(t *testing.T) { + t.Parallel() + var ch chan struct{} + if IsClosed(ch) { + t.Error("expected nil channel to be open") + } +} + +func TestIsClosedWithEmptyChannel(t *testing.T) { + t.Parallel() + ch := make(chan struct{}) + if IsClosed(ch) { + t.Error("expected empty channel to be open") + } + + close(ch) + if !IsClosed(ch) { + t.Error("expected empty channel to be closed") + } +} + +func TestIsClosedWithClosedChannel(t *testing.T) { + t.Parallel() + ch := make(chan struct{}) + close(ch) + if !IsClosed(ch) { + t.Error("expected closed channel to be closed") + } +} + +func BenchmarkIsClosedWithNilChannel(b *testing.B) { + var ch chan struct{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + IsClosed(ch) + } +} + +func BenchmarkIsClosedWithEmptyChannel(b *testing.B) { + ch := make(chan struct{}) + b.ResetTimer() + for i := 0; i < b.N; i++ { + IsClosed(ch) + } +} + +func BenchmarkIsClosedWithClosedChannel(b *testing.B) { + ch := make(chan struct{}) + close(ch) + b.ResetTimer() + for i := 0; i < b.N; i++ { + IsClosed(ch) + } +} From 49585a7bd4dd8842a33d7f3e84654aa43f60308f Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Fri, 28 Mar 2025 12:00:00 +0000 Subject: [PATCH 05/44] Identity Provider - Core Interface --- identity/authority_configuration.go | 59 ++++ identity/authority_configuration_test.go | 160 +++++++++ identity/azure_default_identity_provider.go | 77 +++++ .../azure_default_identity_provider_test.go | 101 ++++++ identity/confidential_identity_provider.go | 169 ++++++++++ .../confidential_identity_provider_test.go | 308 ++++++++++++++++++ identity/identity_test.go | 86 +++++ identity/managed_identity_provider.go | 124 +++++++ identity/managed_identity_provider_test.go | 305 +++++++++++++++++ identity/providers.go | 25 ++ identity/providers_test.go | 52 +++ 11 files changed, 1466 insertions(+) create mode 100644 identity/authority_configuration.go create mode 100644 identity/authority_configuration_test.go create mode 100644 identity/azure_default_identity_provider.go create mode 100644 identity/azure_default_identity_provider_test.go create mode 100644 identity/confidential_identity_provider.go create mode 100644 identity/confidential_identity_provider_test.go create mode 100644 identity/identity_test.go create mode 100644 identity/managed_identity_provider.go create mode 100644 identity/managed_identity_provider_test.go create mode 100644 identity/providers.go create mode 100644 identity/providers_test.go diff --git a/identity/authority_configuration.go b/identity/authority_configuration.go new file mode 100644 index 0000000..bb229dd --- /dev/null +++ b/identity/authority_configuration.go @@ -0,0 +1,59 @@ +package identity + +import "fmt" + +const ( + // AuthorityTypeDefault is the default authority type. + // This is used to specify the authority type when requesting a token. + AuthorityTypeDefault = "default" + // AuthorityTypeMultiTenant is the multi-tenant authority type. + // This is used to specify the multi-tenant authority type when requesting a token. + // This type of authority is used to authenticate the identity when requesting a token. + AuthorityTypeMultiTenant = "multi-tenant" + // AuthorityTypeCustom is the custom authority type. + // This is used to specify the custom authority type when requesting a token. + AuthorityTypeCustom = "custom" +) + +// AuthorityConfiguration represents the authority configuration for the identity provider. +// It is used to configure the authority type and authority URL when requesting a token. +type AuthorityConfiguration struct { + // AuthorityType is the type of authority used to authenticate with the identity provider. + // This can be either "default", "multi-tenant", or "custom". + AuthorityType string + + // Authority is the authority used to authenticate with the identity provider. + // This is typically the URL of the identity provider. + // For example, "https://login.microsoftonline.com/{tenantID}/v2.0" + Authority string + + // TenantID is the tenant ID of the identity provider. + // This is used to identify the tenant when requesting a token. + // This is typically the ID of the Azure Active Directory tenant. + TenantID string +} + +// getAuthority returns the authority URL based on the authority type. +// The authority type can be either "default", "multi-tenant", or "custom". +func (a AuthorityConfiguration) getAuthority() (string, error) { + if a.AuthorityType == "" { + a.AuthorityType = AuthorityTypeDefault + } + + switch a.AuthorityType { + case AuthorityTypeDefault: + return "https://login.microsoftonline.com/common", nil + case AuthorityTypeMultiTenant: + if a.TenantID == "" { + return "", fmt.Errorf("tenant ID is required when using multi-tenant authority type") + } + return fmt.Sprintf("https://login.microsoftonline.com/%s", a.TenantID), nil + case AuthorityTypeCustom: + if a.Authority == "" { + return "", fmt.Errorf("authority is required when using custom authority type") + } + return a.Authority, nil + default: + return "", fmt.Errorf("invalid authority type: %s", a.AuthorityType) + } +} diff --git a/identity/authority_configuration_test.go b/identity/authority_configuration_test.go new file mode 100644 index 0000000..7ae4a67 --- /dev/null +++ b/identity/authority_configuration_test.go @@ -0,0 +1,160 @@ +package identity + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAuthorityConfiguration(t *testing.T) { + t.Parallel() + tests := []struct { + name string + authorityType string + tenantID string + authority string + expected string + expectError bool + }{ + { + name: "Default Authority", + authorityType: AuthorityTypeDefault, + expected: "https://login.microsoftonline.com/common", + expectError: false, + }, + { + name: "Multi-Tenant Authority", + authorityType: AuthorityTypeMultiTenant, + tenantID: "12345", + expected: "https://login.microsoftonline.com/12345", + expectError: false, + }, + { + name: "Custom Authority", + authorityType: AuthorityTypeCustom, + authority: "https://custom-authority.com", + expected: "https://custom-authority.com", + expectError: false, + }, + { + name: "Invalid Authority Type", + authorityType: "invalid", + expectError: true, + }, + { + name: "Missing Tenant ID for Multi-Tenant", + authorityType: AuthorityTypeMultiTenant, + expectError: true, + }, + { + name: "Missing Authority for Custom", + authorityType: AuthorityTypeCustom, + expectError: true, + }, + { + name: "Default Authority Type with Tenant ID", + authorityType: AuthorityTypeDefault, + tenantID: "12345", + expected: "https://login.microsoftonline.com/common", + expectError: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ac := AuthorityConfiguration{ + AuthorityType: test.authorityType, + TenantID: test.tenantID, + Authority: test.authority, + } + result, err := ac.getAuthority() + if test.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, test.expected, result) + } + }) + } +} + +func TestAuthorityConfigurationDefault(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{} + result, err := ac.getAuthority() + assert.NoError(t, err) + assert.Equal(t, "https://login.microsoftonline.com/common", result) +} + +func TestAuthorityConfigurationMultiTenant(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{ + AuthorityType: AuthorityTypeMultiTenant, + TenantID: "12345", + } + result, err := ac.getAuthority() + assert.NoError(t, err) + assert.Equal(t, "https://login.microsoftonline.com/12345", result) +} + +func TestAuthorityConfigurationCustom(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{ + AuthorityType: AuthorityTypeCustom, + Authority: "https://custom-authority.com", + } + result, err := ac.getAuthority() + assert.NoError(t, err) + assert.Equal(t, "https://custom-authority.com", result) +} + +func TestAuthorityConfigurationInvalid(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{ + AuthorityType: "invalid", + } + result, err := ac.getAuthority() + assert.Error(t, err) + assert.Equal(t, "", result) +} + +func TestAuthorityConfigurationMissingTenantID(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{ + AuthorityType: AuthorityTypeMultiTenant, + } + result, err := ac.getAuthority() + assert.Error(t, err) + assert.Equal(t, "", result) +} + +func TestAuthorityConfigurationMissingAuthority(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{ + AuthorityType: AuthorityTypeCustom, + } + result, err := ac.getAuthority() + assert.Error(t, err) + assert.Equal(t, "", result) +} + +func TestAuthorityConfigurationDefaultAuthorityType(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{ + TenantID: "12345", + } + result, err := ac.getAuthority() + assert.NoError(t, err) + assert.Equal(t, "https://login.microsoftonline.com/common", result) +} + +func TestAuthorityConfigurationDefaultAuthorityTypeWithTenantID(t *testing.T) { + t.Parallel() + ac := AuthorityConfiguration{ + AuthorityType: AuthorityTypeDefault, + TenantID: "12345", + } + result, err := ac.getAuthority() + assert.NoError(t, err) + assert.Equal(t, "https://login.microsoftonline.com/common", result) +} diff --git a/identity/azure_default_identity_provider.go b/identity/azure_default_identity_provider.go new file mode 100644 index 0000000..713ce96 --- /dev/null +++ b/identity/azure_default_identity_provider.go @@ -0,0 +1,77 @@ +package identity + +import ( + "context" + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/redis-developer/go-redis-entraid/shared" +) + +// DefaultAzureIdentityProviderOptions represents the options for the DefaultAzureIdentityProvider. +type DefaultAzureIdentityProviderOptions struct { + // AzureOptions is the options used to configure the Azure identity provider. + AzureOptions *azidentity.DefaultAzureCredentialOptions + // Scopes is the list of scopes used to request a token from the identity provider. + Scopes []string + + // credFactory is a factory for creating the default Azure credential. + // This is used for testing purposes, to allow mocking the credential creation. + // If not provided, the default implementation - azidentity.NewDefaultAzureCredential will be used + credFactory credFactory +} + +type credFactory interface { + NewDefaultAzureCredential(options *azidentity.DefaultAzureCredentialOptions) (azureCredential, error) +} + +type azureCredential interface { + GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) +} + +type defaultCredFactory struct{} + +func (d *defaultCredFactory) NewDefaultAzureCredential(options *azidentity.DefaultAzureCredentialOptions) (azureCredential, error) { + return azidentity.NewDefaultAzureCredential(options) +} + +type DefaultAzureIdentityProvider struct { + options *azidentity.DefaultAzureCredentialOptions + credFactory credFactory + scopes []string +} + +// NewDefaultAzureIdentityProvider creates a new DefaultAzureIdentityProvider. +func NewDefaultAzureIdentityProvider(opts DefaultAzureIdentityProviderOptions) (*DefaultAzureIdentityProvider, error) { + if opts.Scopes == nil { + opts.Scopes = []string{RedisScopeDefault} + } + + return &DefaultAzureIdentityProvider{ + options: opts.AzureOptions, + scopes: opts.Scopes, + credFactory: opts.credFactory, + }, nil +} + +// RequestToken requests a token from the Azure Default Identity provider. +// It returns the token, the expiration time, and an error if any. +func (a *DefaultAzureIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) { + credFactory := a.credFactory + if credFactory == nil { + credFactory = &defaultCredFactory{} + } + cred, err := credFactory.NewDefaultAzureCredential(a.options) + if err != nil { + return nil, fmt.Errorf("failed to create default azure credential: %w", err) + } + + token, err := cred.GetToken(context.TODO(), policy.TokenRequestOptions{Scopes: a.scopes}) + if err != nil { + return nil, fmt.Errorf("failed to get token: %w", err) + } + + return shared.NewIDPResponse(shared.ResponseTypeAccessToken, &token) +} diff --git a/identity/azure_default_identity_provider_test.go b/identity/azure_default_identity_provider_test.go new file mode 100644 index 0000000..305e6b4 --- /dev/null +++ b/identity/azure_default_identity_provider_test.go @@ -0,0 +1,101 @@ +package identity + +import ( + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestNewDefaultAzureIdentityProvider(t *testing.T) { + t.Parallel() + // Create a new DefaultAzureIdentityProvider with default options + provider, err := NewDefaultAzureIdentityProvider(DefaultAzureIdentityProviderOptions{}) + if err != nil { + t.Fatalf("failed to create DefaultAzureIdentityProvider: %v", err) + } + + // Check if the provider is not nil + if provider == nil { + t.Fatal("provider should not be nil") + } + + if provider.scopes == nil { + t.Fatal("provider.scopes should not be nil") + } + + assert.Contains(t, provider.scopes, RedisScopeDefault, "provider should contain default scope") +} +func TestAzureDefaultIdentityProvider_RequestToken(t *testing.T) { + t.Parallel() + provider, err := NewDefaultAzureIdentityProvider(DefaultAzureIdentityProviderOptions{}) + if err != nil { + t.Fatalf("failed to create DefaultAzureIdentityProvider: %v", err) + } + + // Request a token from the provider in incorrect environment + // should fail. + token, err := provider.RequestToken() + assert.Nil(t, token, "token should be nil") + assert.Error(t, err, "failed to request token") + + // use mockAzureCredential to simulate the environment + mToken := azcore.AccessToken{ + Token: testJWTToken, + } + mCreds := &mockAzureCredential{} + mCreds.On("GetToken", mock.Anything, mock.Anything).Return(mToken, nil) + mCredFactory := &mockCredFactory{} + mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(mCreds, nil) + provider.credFactory = mCredFactory + token, err = provider.RequestToken() + assert.NotNil(t, token, "token should not be nil") + assert.NoError(t, err, "failed to request token") + assert.Equal(t, shared.ResponseTypeAccessToken, token.Type(), "token type should be access token") + assert.Equal(t, mToken, token.AccessToken(), "access token should be equal to testJWTToken") +} + +func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) { + // Create a new DefaultAzureIdentityProvider with custom scopes + scopes := []string{"https://example.com/.default"} + provider, err := NewDefaultAzureIdentityProvider(DefaultAzureIdentityProviderOptions{ + Scopes: scopes, + }) + if err != nil { + t.Fatalf("failed to create DefaultAzureIdentityProvider: %v", err) + } + + t.Run("RequestToken with custom scopes", func(t *testing.T) { + // Request a token from the provider + token, err := provider.RequestToken() + assert.Nil(t, token, "token should be nil") + assert.Error(t, err, "failed to request token") + + // use mockAzureCredential to simulate the environment + mToken := azcore.AccessToken{ + Token: testJWTToken, + } + mCreds := &mockAzureCredential{} + mCreds.On("GetToken", mock.Anything, policy.TokenRequestOptions{Scopes: scopes}).Return(mToken, nil) + mCredFactory := &mockCredFactory{} + mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(mCreds, nil) + provider.credFactory = mCredFactory + token, err = provider.RequestToken() + assert.NotNil(t, token, "token should not be nil") + assert.NoError(t, err, "failed to request token") + assert.Equal(t, shared.ResponseTypeAccessToken, token.Type(), "token type should be access token") + assert.Equal(t, mToken, token.AccessToken(), "access token should be equal to testJWTToken") + }) + t.Run("RequestToken with error from credFactory", func(t *testing.T) { + // use mockAzureCredential to simulate the environment + mCredFactory := &mockCredFactory{} + mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(nil, assert.AnError) + provider.credFactory = mCredFactory + token, err := provider.RequestToken() + assert.Nil(t, token, "token should be nil") + assert.Error(t, err, "failed to request token") + }) +} diff --git a/identity/confidential_identity_provider.go b/identity/confidential_identity_provider.go new file mode 100644 index 0000000..97876fd --- /dev/null +++ b/identity/confidential_identity_provider.go @@ -0,0 +1,169 @@ +package identity + +import ( + "context" + "crypto" + "crypto/x509" + "fmt" + + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" + "github.com/redis-developer/go-redis-entraid/shared" +) + +// ConfidentialIdentityProviderOptions represents the options for the confidential identity provider. +type ConfidentialIdentityProviderOptions struct { + // ClientID is the client ID used to authenticate with the identity provider. + ClientID string + + // CredentialsType is the type of credentials used to authenticate with the identity provider. + // This can be either "ClientSecret" or "ClientCertificate". + CredentialsType string + + // ClientSecret is the client secret used to authenticate with the identity provider. + ClientSecret string + + // ClientCert is the client certificate used to authenticate with the identity provider. + ClientCert []*x509.Certificate + // ClientPrivateKey is the private key used to authenticate with the identity provider. + ClientPrivateKey crypto.PrivateKey + + // Scopes is the list of scopes used to request a token from the identity provider. + Scopes []string + + // Authority is the authority used to authenticate with the identity provider. + Authority AuthorityConfiguration + + // confidentialCredFactory is a factory for creating the confidential credential. + // This is used for testing purposes, to allow mocking the credential creation. + confidentialCredFactory confidentialCredFactory +} + +// ConfidentialIdentityProvider represents a confidential identity provider. +type ConfidentialIdentityProvider struct { + // clientID is the client ID used to authenticate with the identity provider. + clientID string + + // credential is the credential used to authenticate with the identity provider. + credential confidential.Credential + + // scopes is the list of scopes used to request a token from the identity provider. + scopes []string + + // client confidential is the client used to request a token from the identity provider. + client confidentialTokenClient +} + +// confidentialCredFacotory is a factory for creating the confidential credential. +// Introduced for testing purposes. This allows mocking the credential creation, default behavior is to use the confidential.NewCredFromSecret and confidential.NewCredFromCert methods. +type confidentialCredFactory interface { + NewCredFromSecret(clientSecret string) (confidential.Credential, error) + NewCredFromCert(clientCert []*x509.Certificate, clientPrivateKey crypto.PrivateKey) (confidential.Credential, error) +} + +// confidentialTokenClient is an interface that defines the methods for a confidential token client. +// It is used to acquire a token using the client credentials. +// Introduced for testing purposes. This allows mocking the token client, default behavior is to use the +// client returned by confidential.New method. +type confidentialTokenClient interface { + // AcquireTokenByCredential acquires a token using the client credentials. + // It returns the token and an error if any. + AcquireTokenByCredential(ctx context.Context, scopes []string, opts ...confidential.AcquireByCredentialOption) (confidential.AuthResult, error) +} + +type defaultConfidentialCredFactory struct{} + +func (d *defaultConfidentialCredFactory) NewCredFromSecret(clientSecret string) (confidential.Credential, error) { + return confidential.NewCredFromSecret(clientSecret) +} + +func (d *defaultConfidentialCredFactory) NewCredFromCert(clientCert []*x509.Certificate, clientPrivateKey crypto.PrivateKey) (confidential.Credential, error) { + return confidential.NewCredFromCert(clientCert, clientPrivateKey) +} + +// NewConfidentialIdentityProvider creates a new confidential identity provider. +// It is used to configure the identity provider when requesting a token. +// It is used to specify the client ID, tenant ID, and scopes for the identity. +// It is also used to specify the type of credentials used to authenticate with the identity provider. +// The credentials can be either a client secret or a client certificate. +// The authority is used to authenticate with the identity provider. +func NewConfidentialIdentityProvider(opts ConfidentialIdentityProviderOptions) (*ConfidentialIdentityProvider, error) { + var credential confidential.Credential + var credFactory confidentialCredFactory + var authority string + var err error + + if opts.ClientID == "" { + return nil, fmt.Errorf("client ID is required") + } + + if opts.CredentialsType != ClientSecretCredentialType && opts.CredentialsType != ClientCertificateCredentialType { + return nil, fmt.Errorf("invalid credentials type") + } + + // Get the authority from the authority configuration. + authority, err = opts.Authority.getAuthority() + if err != nil { + return nil, fmt.Errorf("failed to get authority: %w", err) + } + + credFactory = &defaultConfidentialCredFactory{} + if opts.confidentialCredFactory != nil { + credFactory = opts.confidentialCredFactory + } + + switch opts.CredentialsType { + case ClientSecretCredentialType: + // ClientSecretCredentialType is the type of credentials that uses a client secret to authenticate. + if opts.ClientSecret == "" { + return nil, fmt.Errorf("client secret is required when using client secret credentials") + } + + credential, err = credFactory.NewCredFromSecret(opts.ClientSecret) + if err != nil { + return nil, fmt.Errorf("failed to create client secret credential: %w", err) + } + case ClientCertificateCredentialType: + // ClientCertificateCredentialType is the type of credentials that uses a client certificate to authenticate. + if opts.ClientCert == nil { + return nil, fmt.Errorf("client certificate is required when using client certificate credentials") + } + if opts.ClientPrivateKey == nil { + return nil, fmt.Errorf("client private key is required when using client certificate credentials") + } + credential, err = credFactory.NewCredFromCert(opts.ClientCert, opts.ClientPrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to create client certificate credential: %w", err) + } + } + + client, err := confidential.New(authority, opts.ClientID, credential) + if err != nil { + return nil, fmt.Errorf("failed to create client: %w", err) + } + + if opts.Scopes == nil { + opts.Scopes = []string{RedisScopeDefault} + } + + return &ConfidentialIdentityProvider{ + clientID: opts.ClientID, + credential: credential, + scopes: opts.Scopes, + client: &client, + }, nil +} + +// RequestToken requests a token from the identity provider. +// It returns the identity provider response, including the auth result. +func (c *ConfidentialIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) { + if c.client == nil { + return nil, fmt.Errorf("client is not initialized") + } + + result, err := c.client.AcquireTokenByCredential(context.TODO(), c.scopes) + if err != nil { + return nil, fmt.Errorf("failed to acquire token: %w", err) + } + + return shared.NewIDPResponse(shared.ResponseTypeAuthResult, &result) +} diff --git a/identity/confidential_identity_provider_test.go b/identity/confidential_identity_provider_test.go new file mode 100644 index 0000000..df57d17 --- /dev/null +++ b/identity/confidential_identity_provider_test.go @@ -0,0 +1,308 @@ +package identity + +import ( + "crypto/x509" + "fmt" + "testing" + "time" + + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestNewConfidentialIdentityProvider(t *testing.T) { + t.Run("base", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientSecret", + ClientSecret: "client-secret", + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{}, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err != nil { + t.Errorf("NewConfidentialIdentityProvider() error = %v", err) + return + } + if provider == nil { + t.Errorf("NewConfidentialIdentityProvider() provider = nil") + return + } + }) + + t.Run("with client certificate", func(t *testing.T) { + t.Parallel() + credFactory := &mockConfidentialCredentialFactory{} + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientCertificate", + ClientCert: []*x509.Certificate{}, + ClientPrivateKey: "private-key", + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{}, + confidentialCredFactory: credFactory, + } + credFactory.On("NewCredFromCert", opts.ClientCert, opts.ClientPrivateKey).Return(confidential.Credential{}, nil) + provider, err := NewConfidentialIdentityProvider(opts) + // confidential.New will fail since the credentials are invalid + assert.ErrorContains(t, err, "failed to create client:") + assert.Nil(t, provider) + }) + + t.Run("with failing client certificate", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientCertificate", + ClientCert: []*x509.Certificate{}, + ClientPrivateKey: "private-key", + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{}, + } + // invalid certificate should fail + provider, err := NewConfidentialIdentityProvider(opts) + assert.ErrorContains(t, err, "failed to create client certificate credential:") + assert.Nil(t, provider) + }) + + t.Run("with invalid credentials type", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "invalid-credentials-type", + ClientSecret: "client-secret", + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{}, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err == nil { + t.Errorf("NewConfidentialIdentityProvider() error = nil, want error") + return + } + if provider != nil { + t.Errorf("NewConfidentialIdentityProvider() provider = %v, want nil", provider) + return + } + }) + + t.Run("with missing client id", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + CredentialsType: "ClientSecret", + } + provider, err := NewConfidentialIdentityProvider(opts) + if err == nil { + t.Errorf("NewConfidentialIdentityProvider() error = nil, want error") + return + } + if provider != nil { + t.Errorf("NewConfidentialIdentityProvider() provider = %v, want nil", provider) + return + } + }) + + t.Run("with bad authority type", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientSecret", + ClientSecret: "client-secret", + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{AuthorityType: "bad-authority-type"}, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err == nil { + t.Errorf("NewConfidentialIdentityProvider() error = nil, want error") + return + } + if provider != nil { + t.Errorf("NewConfidentialIdentityProvider() provider = %v, want nil", provider) + return + } + }) + t.Run("with missing client secret", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientSecret", + Scopes: []string{"scope1", "scope2"}, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err == nil { + t.Errorf("NewConfidentialIdentityProvider() error = nil, want error") + return + } + if provider != nil { + t.Errorf("NewConfidentialIdentityProvider() provider = %v, want nil", provider) + return + } + }) + + t.Run("with credentials from secret error", func(t *testing.T) { + t.Parallel() + credFactory := &mockConfidentialCredentialFactory{} + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientSecret", + ClientSecret: "client-secret", + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{}, + confidentialCredFactory: credFactory, + } + credFactory.On("NewCredFromSecret", "client-secret").Return(confidential.Credential{}, fmt.Errorf("error creating credential")) + provider, err := NewConfidentialIdentityProvider(opts) + if err == nil { + t.Errorf("NewConfidentialIdentityProvider() error = nil, want error") + return + } + if provider != nil { + t.Errorf("NewConfidentialIdentityProvider() provider = %v, want nil", provider) + return + } + credFactory.AssertExpectations(t) + }) + + t.Run("empty certificate", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientCertificate", + ClientCert: nil, + ClientPrivateKey: "private key", + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{}, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err == nil { + t.Errorf("NewConfidentialIdentityProvider() error = nil, want error") + return + } + if provider != nil { + t.Errorf("NewConfidentialIdentityProvider() provider = %v, want nil", provider) + return + } + }) + + t.Run("empty private key", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientCertificate", + ClientCert: []*x509.Certificate{}, + ClientPrivateKey: nil, + Scopes: []string{"scope1", "scope2"}, + Authority: AuthorityConfiguration{}, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err == nil { + t.Errorf("NewConfidentialIdentityProvider() error = nil, want error") + return + } + if provider != nil { + t.Errorf("NewConfidentialIdentityProvider() provider = %v, want nil", provider) + return + } + }) + t.Run("validate default scopes", func(t *testing.T) { + t.Parallel() + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientSecret", + ClientSecret: "client-secret", + Authority: AuthorityConfiguration{}, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err != nil { + t.Errorf("NewConfidentialIdentityProvider() error = %v", err) + return + } + if provider == nil { + t.Errorf("NewConfidentialIdentityProvider() provider = nil") + return + } + if len(provider.scopes) == 0 { + t.Errorf("NewConfidentialIdentityProvider() provider.Scopes = %v, want non-empty", provider.scopes) + return + } + assert.Contains(t, provider.scopes, RedisScopeDefault) + }) +} + +func TestConfidentialIdentityProvider_RequestToken(t *testing.T) { + t.Run("with mock client", func(t *testing.T) { + t.Parallel() + mClient := &mockConfidentialTokenClient{} + + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientSecret", + ClientSecret: "client-secret", + Authority: AuthorityConfiguration{ + AuthorityType: AuthorityTypeCustom, + Authority: "https://test-authority.dev/test", + }, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err != nil { + t.Errorf("NewConfidentialIdentityProvider() error = %v", err) + return + } + if provider == nil { + t.Errorf("NewConfidentialIdentityProvider() provider = nil") + return + } + expiresOn := time.Now().Add(time.Hour) + provider.client = mClient + mClient.On("AcquireTokenByCredential", mock.Anything, mock.Anything). + Return(confidential.AuthResult{ + ExpiresOn: expiresOn, + }, nil) + token, err := provider.RequestToken() + if err != nil { + t.Errorf("RequestToken() error = %v", err) + return + } + assert.NotEmpty(t, token, "RequestToken() token should not be empty") + assert.Equal(t, token.Type(), shared.ResponseTypeAuthResult, "RequestToken() token type should be AuthResult") + assert.Equal(t, token.AuthResult().ExpiresOn, expiresOn, "RequestToken() token expiration should match") + }) + t.Run("with error", func(t *testing.T) { + t.Parallel() + mClient := &mockConfidentialTokenClient{} + + opts := ConfidentialIdentityProviderOptions{ + ClientID: "client-id", + CredentialsType: "ClientSecret", + ClientSecret: "client-secret", + Authority: AuthorityConfiguration{ + AuthorityType: AuthorityTypeCustom, + Authority: "https://test-authority.dev/test", + }, + } + provider, err := NewConfidentialIdentityProvider(opts) + if err != nil { + t.Errorf("NewConfidentialIdentityProvider() error = %v", err) + return + } + if provider == nil { + t.Errorf("NewConfidentialIdentityProvider() provider = nil") + return + } + provider.client = mClient + mClient.On("AcquireTokenByCredential", mock.Anything, mock.Anything). + Return(confidential.AuthResult{}, fmt.Errorf("error acquiring token")) + token, err := provider.RequestToken() + assert.ErrorContains(t, err, "failed to acquire token:") + assert.Empty(t, token, "RequestToken() token should be empty") + }) + t.Run("without initialization", func(t *testing.T) { + t.Parallel() + provider := &ConfidentialIdentityProvider{} + token, err := provider.RequestToken() + assert.ErrorContains(t, err, "client is not initialized") + assert.Empty(t, token, "RequestToken() token should be empty") + }) +} diff --git a/identity/identity_test.go b/identity/identity_test.go new file mode 100644 index 0000000..4d33a65 --- /dev/null +++ b/identity/identity_test.go @@ -0,0 +1,86 @@ +package identity + +import ( + "context" + "crypto" + "crypto/x509" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" + "github.com/stretchr/testify/mock" +) + +// testJWTToken is a JWT token for testing +// +// { +// "iss": "test jwt", +// "iat": 1743515011, +// "exp": 1775051011, +// "aud": "www.example.com", +// "sub": "test@test.com", +// "oid": "test" +// } +// +// key: qwertyuiopasdfghjklzxcvbnm123456 +const testJWTToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJ0ZXN0IGp3dCIsImlhdCI6MTc0MzUxNTAxMSwiZXhwIjoxNzc1MDUxMDExLCJhdWQiOiJ3d3cuZXhhbXBsZS5jb20iLCJzdWIiOiJ0ZXN0QHRlc3QuY29tIiwib2lkIjoidGVzdCJ9.6RG721V2eFlSLsCRmo53kSRRrTZIe1UPdLZCUEvIarU" + +type mockAzureCredential struct { + mock.Mock +} + +func (m *mockAzureCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { + args := m.Called(ctx, options) + if args.Get(0) == nil { + return azcore.AccessToken{}, args.Error(1) + } + return args.Get(0).(azcore.AccessToken), args.Error(1) +} + +type mockCredFactory struct { + // Mock implementation of the credFactory interface + mock.Mock +} + +func (m *mockCredFactory) NewDefaultAzureCredential(options *azidentity.DefaultAzureCredentialOptions) (azureCredential, error) { + args := m.Called(options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(azureCredential), args.Error(1) +} + +type mockConfidentialCredentialFactory struct { + // Mock implementation of the confidentialCredFactory interface + mock.Mock +} + +func (m *mockConfidentialCredentialFactory) NewCredFromSecret(clientSecret string) (confidential.Credential, error) { + args := m.Called(clientSecret) + if args.Get(0) == nil { + return confidential.Credential{}, args.Error(1) + } + return args.Get(0).(confidential.Credential), args.Error(1) +} + +func (m *mockConfidentialCredentialFactory) NewCredFromCert(clientCert []*x509.Certificate, clientPrivateKey crypto.PrivateKey) (confidential.Credential, error) { + args := m.Called(clientCert, clientPrivateKey) + if args.Get(0) == nil { + return confidential.Credential{}, args.Error(1) + } + return args.Get(0).(confidential.Credential), args.Error(1) +} + +type mockConfidentialTokenClient struct { + // Mock implementation of the confidentialTokenClient interface + mock.Mock +} + +func (m *mockConfidentialTokenClient) AcquireTokenByCredential(ctx context.Context, scopes []string, options ...confidential.AcquireByCredentialOption) (confidential.AuthResult, error) { + args := m.Called(ctx, options) + if args.Get(0) == nil { + return confidential.AuthResult{}, args.Error(1) + } + return args.Get(0).(confidential.AuthResult), args.Error(1) +} diff --git a/identity/managed_identity_provider.go b/identity/managed_identity_provider.go new file mode 100644 index 0000000..6a2153d --- /dev/null +++ b/identity/managed_identity_provider.go @@ -0,0 +1,124 @@ +package identity + +import ( + "context" + "errors" + "fmt" + + mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/redis-developer/go-redis-entraid/shared" +) + +// ManagedIdentityClient is an interface that defines the methods for a managed identity client. +// It is used to acquire a token using the managed identity. +type ManagedIdentityClient interface { + // AcquireToken acquires a token using the managed identity. + // It returns the token and an error if any. + AcquireToken(ctx context.Context, resource string, opts ...mi.AcquireTokenOption) (public.AuthResult, error) +} + +// ManagedIdentityProviderOptions represents the options for the managed identity provider. +// It is used to configure the identity provider when requesting a token. +type ManagedIdentityProviderOptions struct { + // UserAssignedClientID is the client ID of the user assigned identity. + // This is used to identify the identity when requesting a token. + UserAssignedClientID string + // ManagedIdentityType is the type of managed identity. + // This can be either SystemAssigned or UserAssigned. + ManagedIdentityType string + // Scopes is a list of scopes that the identity has access to. + // This is used to specify the permissions that the identity has when requesting a token. + Scopes []string +} + +// ManagedIdentityProvider represents a managed identity provider. +type ManagedIdentityProvider struct { + // userAssignedClientID is the client ID of the user assigned identity. + // This is used to identify the identity when requesting a token. + userAssignedClientID string + + // managedIdentityType is the type of managed identity. + // This can be either SystemAssigned or UserAssigned. + managedIdentityType string + + // scopes is a list of scopes that the identity has access to. + // This is used to specify the permissions that the identity has when requesting a token. + scopes []string + + // client is the managed identity client used to request a token. + client ManagedIdentityClient +} + +// realManagedIdentityClient is a wrapper around the real mi.Client that implements our interface +type realManagedIdentityClient struct { + client ManagedIdentityClient +} + +func (c *realManagedIdentityClient) AcquireToken(ctx context.Context, resource string, opts ...mi.AcquireTokenOption) (public.AuthResult, error) { + return c.client.AcquireToken(ctx, resource, opts...) +} + +// NewManagedIdentityProvider creates a new managed identity provider for Azure with managed identity. +// It is used to configure the identity provider when requesting a token. +func NewManagedIdentityProvider(opts ManagedIdentityProviderOptions) (*ManagedIdentityProvider, error) { + var client ManagedIdentityClient + + if opts.ManagedIdentityType != SystemAssignedIdentity && opts.ManagedIdentityType != UserAssignedIdentity { + return nil, errors.New("invalid managed identity type") + } + + switch opts.ManagedIdentityType { + case SystemAssignedIdentity: + // SystemAssignedIdentity is the type of identity that is automatically managed by Azure. + // This type of identity is automatically created and managed by Azure. + // It is used to authenticate the identity when requesting a token. + miClient, err := mi.New(mi.SystemAssigned()) + if err != nil { + return nil, fmt.Errorf("couldn't create managed identity client: %w", err) + } + client = &realManagedIdentityClient{client: miClient} + case UserAssignedIdentity: + // UserAssignedIdentity is required to be specified when using a user assigned identity. + if opts.UserAssignedClientID == "" { + return nil, errors.New("user assigned client ID is required when using user assigned identity") + } + // UserAssignedIdentity is the type of identity that is managed by the user. + miClient, err := mi.New(mi.UserAssignedClientID(opts.UserAssignedClientID)) + if err != nil { + return nil, fmt.Errorf("couldn't create managed identity client: %w", err) + } + client = &realManagedIdentityClient{client: miClient} + } + + return &ManagedIdentityProvider{ + userAssignedClientID: opts.UserAssignedClientID, + managedIdentityType: opts.ManagedIdentityType, + scopes: opts.Scopes, + client: client, + }, nil +} + +// RequestToken requests a token from the managed identity provider. +// It returns IdentityProviderResponse, which contains the Acc and the expiration time. +func (m *ManagedIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) { + if m.client == nil { + return nil, errors.New("managed identity client is not initialized") + } + + // default resource is RedisResource == "https://redis.azure.com" + // if no scopes are provided, use the default resource + // if scopes are provided, use the first scope as the resource + resource := RedisResource + if len(m.scopes) > 0 { + resource = m.scopes[0] + } + // acquire token using the managed identity client + // the resource is the URL of the resource that the identity has access to + authResult, err := m.client.AcquireToken(context.TODO(), resource) + if err != nil { + return nil, fmt.Errorf("couldn't acquire token: %w", err) + } + + return shared.NewIDPResponse(shared.ResponseTypeAuthResult, &authResult) +} diff --git a/identity/managed_identity_provider_test.go b/identity/managed_identity_provider_test.go new file mode 100644 index 0000000..dc90c39 --- /dev/null +++ b/identity/managed_identity_provider_test.go @@ -0,0 +1,305 @@ +package identity + +import ( + "context" + "errors" + "testing" + "time" + + mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// MockManagedIdentityClient is a mock implementation of the managed identity client +type MockManagedIdentityClient struct { + mock.Mock +} + +func (m *MockManagedIdentityClient) AcquireToken(ctx context.Context, resource string, opts ...mi.AcquireTokenOption) (public.AuthResult, error) { + args := m.Called(ctx, resource) + return args.Get(0).(public.AuthResult), args.Error(1) +} + +func TestNewManagedIdentityProvider(t *testing.T) { + tests := []struct { + name string + opts ManagedIdentityProviderOptions + expectedError string + }{ + { + name: "System assigned identity", + opts: ManagedIdentityProviderOptions{ + ManagedIdentityType: SystemAssignedIdentity, + Scopes: []string{"https://redis.azure.com"}, + }, + expectedError: "", + }, + { + name: "User assigned identity with client ID", + opts: ManagedIdentityProviderOptions{ + ManagedIdentityType: UserAssignedIdentity, + UserAssignedClientID: "test-client-id", + Scopes: []string{"https://redis.azure.com"}, + }, + expectedError: "", + }, + { + name: "User assigned identity without client ID", + opts: ManagedIdentityProviderOptions{ + ManagedIdentityType: UserAssignedIdentity, + Scopes: []string{"https://redis.azure.com"}, + }, + expectedError: "user assigned client ID is required when using user assigned identity", + }, + { + name: "Invalid identity type", + opts: ManagedIdentityProviderOptions{ + ManagedIdentityType: "invalid-type", + Scopes: []string{"https://redis.azure.com"}, + }, + expectedError: "invalid managed identity type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := NewManagedIdentityProvider(tt.opts) + + if tt.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + assert.Nil(t, provider) + } else { + assert.NoError(t, err) + assert.NotNil(t, provider) + assert.Equal(t, tt.opts.ManagedIdentityType, provider.managedIdentityType) + assert.Equal(t, tt.opts.UserAssignedClientID, provider.userAssignedClientID) + assert.Equal(t, tt.opts.Scopes, provider.scopes) + assert.NotNil(t, provider.client) + } + }) + } +} + +func TestRequestToken(t *testing.T) { + tests := []struct { + name string + provider *ManagedIdentityProvider + expectedError string + }{ + { + name: "Success with default resource", + provider: &ManagedIdentityProvider{ + scopes: []string{}, + client: new(MockManagedIdentityClient), + }, + expectedError: "", + }, + { + name: "Success with custom resource", + provider: &ManagedIdentityProvider{ + scopes: []string{"custom-resource"}, + client: new(MockManagedIdentityClient), + }, + expectedError: "", + }, + { + name: "Error when client is nil", + provider: &ManagedIdentityProvider{ + scopes: []string{}, + client: nil, + }, + expectedError: "managed identity client is not initialized", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up the mock expectations if we have a mock client + if tt.provider.client != nil { + mockClient := tt.provider.client.(*MockManagedIdentityClient) + expectedResource := RedisResource + if len(tt.provider.scopes) > 0 { + expectedResource = tt.provider.scopes[0] + } + + if tt.expectedError == "" { + mockClient.On("AcquireToken", mock.Anything, expectedResource). + Return(public.AuthResult{ + AccessToken: "test-token", + ExpiresOn: time.Now().Add(time.Hour), + }, nil) + } else { + mockClient.On("AcquireToken", mock.Anything, expectedResource). + Return(public.AuthResult{}, errors.New(tt.expectedError)) + } + } + + response, err := tt.provider.RequestToken() + + if tt.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + assert.Nil(t, response) + } else { + assert.NoError(t, err) + assert.NotNil(t, response) + } + + // Verify mock expectations + if tt.provider.client != nil { + mockClient := tt.provider.client.(*MockManagedIdentityClient) + mockClient.AssertExpectations(t) + } + }) + } +} + +func TestRequestToken_ErrorCases(t *testing.T) { + tests := []struct { + name string + provider *ManagedIdentityProvider + setupMock func(*MockManagedIdentityClient) + expectedError string + }{ + { + name: "AcquireToken fails", + provider: &ManagedIdentityProvider{ + scopes: []string{}, + client: new(MockManagedIdentityClient), + }, + setupMock: func(m *MockManagedIdentityClient) { + m.On("AcquireToken", mock.Anything, RedisResource). + Return(public.AuthResult{}, errors.New("failed to acquire token")) + }, + expectedError: "couldn't acquire token: failed to acquire token", + }, + { + name: "AcquireToken fails with custom resource", + provider: &ManagedIdentityProvider{ + scopes: []string{"custom-resource"}, + client: new(MockManagedIdentityClient), + }, + setupMock: func(m *MockManagedIdentityClient) { + m.On("AcquireToken", mock.Anything, "custom-resource"). + Return(public.AuthResult{}, errors.New("failed to acquire token")) + }, + expectedError: "couldn't acquire token: failed to acquire token", + }, + { + name: "AcquireToken fails with invalid resource", + provider: &ManagedIdentityProvider{ + scopes: []string{"invalid-resource"}, + client: new(MockManagedIdentityClient), + }, + setupMock: func(m *MockManagedIdentityClient) { + m.On("AcquireToken", mock.Anything, "invalid-resource"). + Return(public.AuthResult{}, errors.New("invalid resource")) + }, + expectedError: "couldn't acquire token: invalid resource", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockClient := tt.provider.client.(*MockManagedIdentityClient) + tt.setupMock(mockClient) + + response, err := tt.provider.RequestToken() + + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + assert.Nil(t, response) + mockClient.AssertExpectations(t) + }) + } +} + +// MockMIClient is a mock implementation of the mi.Client interface +type MockMIClient struct { + mock.Mock +} + +func (m *MockMIClient) AcquireToken(ctx context.Context, resource string, opts ...mi.AcquireTokenOption) (public.AuthResult, error) { + args := m.Called(ctx, resource) + return args.Get(0).(public.AuthResult), args.Error(1) +} + +func (m *MockMIClient) Close() error { + args := m.Called() + return args.Error(0) +} + +func TestRealManagedIdentityClient(t *testing.T) { + // Create a mock managed identity client + mockMIClient := new(MockManagedIdentityClient) + client := &realManagedIdentityClient{client: mockMIClient} + + tests := []struct { + name string + resource string + setupMock func(*MockManagedIdentityClient) + expectedError string + }{ + { + name: "Success with default resource", + resource: RedisResource, + setupMock: func(m *MockManagedIdentityClient) { + m.On("AcquireToken", mock.Anything, RedisResource, mock.Anything). + Return(public.AuthResult{ + AccessToken: "test-token", + ExpiresOn: time.Now().Add(time.Hour), + }, nil) + }, + }, + { + name: "Success with custom resource", + resource: "custom-resource", + setupMock: func(m *MockManagedIdentityClient) { + m.On("AcquireToken", mock.Anything, "custom-resource", mock.Anything). + Return(public.AuthResult{ + AccessToken: "test-token", + ExpiresOn: time.Now().Add(time.Hour), + }, nil) + }, + }, + { + name: "Error from underlying client", + resource: RedisResource, + setupMock: func(m *MockManagedIdentityClient) { + m.On("AcquireToken", mock.Anything, RedisResource, mock.Anything). + Return(public.AuthResult{}, errors.New("underlying client error")) + }, + expectedError: "underlying client error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset the mock for each test + mockMIClient.ExpectedCalls = nil + mockMIClient.Calls = nil + + // Set up the mock + tt.setupMock(mockMIClient) + + // Call AcquireToken with empty options slice to match mock setup + result, err := client.AcquireToken(context.Background(), tt.resource, []mi.AcquireTokenOption{}...) + + if tt.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + assert.Equal(t, public.AuthResult{}, result) + } else { + assert.NoError(t, err) + assert.NotEqual(t, public.AuthResult{}, result) + assert.Equal(t, "test-token", result.AccessToken) + } + + // Verify mock expectations + mockMIClient.AssertExpectations(t) + }) + } +} diff --git a/identity/providers.go b/identity/providers.go new file mode 100644 index 0000000..24126f3 --- /dev/null +++ b/identity/providers.go @@ -0,0 +1,25 @@ +package identity + +// CredentialsProviderOptions is a struct that holds the options for the credentials provider. + +const ( + // SystemAssignedIdentity is the type of identity that is automatically managed by Azure. + SystemAssignedIdentity = "SystemAssigned" + // UserAssignedIdentity is the type of identity that is managed by the user. + UserAssignedIdentity = "UserAssigned" + + // ClientSecretCredentialType is the type of credentials that uses a client secret to authenticate. + ClientSecretCredentialType = "ClientSecret" + // ClientCertificateCredentialType is the type of credentials that uses a client certificate to authenticate. + ClientCertificateCredentialType = "ClientCertificate" + + // RedisScopeDefault is the default scope for Redis. + // This is used to specify the scope that the identity has access to when requesting a token. + // The scope is typically the URL of the resource that the identity has access to. + RedisScopeDefault = "https://redis.azure.com/.default" + + // RedisResource is the default resource for Redis. + // This is used to specify the resource that the identity has access to when requesting a token. + // The resource is typically the URL of the resource that the identity has access to. + RedisResource = "https://redis.azure.com" +) diff --git a/identity/providers_test.go b/identity/providers_test.go new file mode 100644 index 0000000..0712d0f --- /dev/null +++ b/identity/providers_test.go @@ -0,0 +1,52 @@ +package identity + +import ( + "testing" +) + +func TestConstants(t *testing.T) { + tests := []struct { + name string + got string + expected string + }{ + { + name: "SystemAssignedIdentity", + got: SystemAssignedIdentity, + expected: "SystemAssigned", + }, + { + name: "UserAssignedIdentity", + got: UserAssignedIdentity, + expected: "UserAssigned", + }, + { + name: "ClientSecretCredentialType", + got: ClientSecretCredentialType, + expected: "ClientSecret", + }, + { + name: "ClientCertificateCredentialType", + got: ClientCertificateCredentialType, + expected: "ClientCertificate", + }, + { + name: "RedisScopeDefault", + got: RedisScopeDefault, + expected: "https://redis.azure.com/.default", + }, + { + name: "RedisResource", + got: RedisResource, + expected: "https://redis.azure.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.got != tt.expected { + t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.expected) + } + }) + } +} From 62ee78bc054e72911bbda1a2d1acbd6eae2cf351 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Sat, 29 Mar 2025 12:00:00 +0000 Subject: [PATCH 06/44] Identity Provider - Default Implementation --- providers.go | 149 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 providers.go diff --git a/providers.go b/providers.go new file mode 100644 index 0000000..f579079 --- /dev/null +++ b/providers.go @@ -0,0 +1,149 @@ +package entraid + +import ( + "fmt" + + "github.com/redis-developer/go-redis-entraid/identity" + "github.com/redis-developer/go-redis-entraid/manager" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis/go-redis/v9/auth" +) + +// CredentialsProviderOptions is a struct that holds the options for the credentials provider. +// It is used to configure the streaming credentials provider when requesting a token with a token manager. +type CredentialsProviderOptions struct { + // ClientID is the client ID of the identity. + // This is used to identify the identity when requesting a token. + ClientID string + + // TokenManagerOptions is the options for the token manager. + // This is used to configure the token manager when requesting a token. + TokenManagerOptions manager.TokenManagerOptions + + // tokenManagerFactory is a private field that can be injected from within the package. + // It is used to create a token manager for the credentials provider. + tokenManagerFactory func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) +} + +// defaultTokenManagerFactory is the default implementation of the token manager factory. +// It creates a new token manager using the provided identity provider and options. +func defaultTokenManagerFactory(provider shared.IdentityProvider, options manager.TokenManagerOptions) (manager.TokenManager, error) { + return manager.NewTokenManager(provider, options) +} + +// getTokenManagerFactory returns the token manager factory to use. +// If no factory is provided, it returns the default implementation. +func (o *CredentialsProviderOptions) getTokenManagerFactory() func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + if o.tokenManagerFactory == nil { + return defaultTokenManagerFactory + } + return o.tokenManagerFactory +} + +// Managed identity type + +// ManagedIdentityCredentialsProviderOptions is a struct that holds the options for the managed identity credentials provider. +type ManagedIdentityCredentialsProviderOptions struct { + // CredentialsProviderOptions is the options for the credentials provider. + // This is used to configure the credentials provider when requesting a token. + // It is used to specify the client ID, tenant ID, and scopes for the identity. + CredentialsProviderOptions + + // ManagedIdentityProviderOptions is the options for the managed identity provider. + // This is used to configure the managed identity provider when requesting a token. + identity.ManagedIdentityProviderOptions +} + +// NewManagedIdentityCredentialsProvider creates a new streaming credentials provider for managed identity. +// It uses the provided options to configure the provider. +// Use this when you want either a system assigned identity or a user assigned identity. +// The system assigned identity is automatically managed by Azure and does not require any additional configuration. +// The user assigned identity is a separate resource that can be managed independently. +func NewManagedIdentityCredentialsProvider(options ManagedIdentityCredentialsProviderOptions) (auth.StreamingCredentialsProvider, error) { + // Create a new identity provider using the managed identity type. + idp, err := identity.NewManagedIdentityProvider(options.ManagedIdentityProviderOptions) + if err != nil { + return nil, fmt.Errorf("cannot create managed identity provider: %w", err) + } + + // Create a new token manager using the identity provider. + tokenManager, err := options.getTokenManagerFactory()(idp, options.TokenManagerOptions) + if err != nil { + return nil, fmt.Errorf("cannot create token manager: %w", err) + } + // Create a new credentials provider using the token manager. + credentialsProvider, err := NewCredentialsProvider(tokenManager, options.CredentialsProviderOptions) + if err != nil { + return nil, fmt.Errorf("cannot create credentials provider: %w", err) + } + + return credentialsProvider, nil +} + +// ConfidentialCredentialsProviderOptions is a struct that holds the options for the confidential credentials provider. +// It is used to configure the credentials provider when requesting a token. +type ConfidentialCredentialsProviderOptions struct { + // CredentialsProviderOptions is the options for the credentials provider. + // This is used to configure the credentials provider when requesting a token. + CredentialsProviderOptions + + // ConfidentialIdentityProviderOptions is the options for the confidential identity provider. + // This is used to configure the identity provider when requesting a token. + identity.ConfidentialIdentityProviderOptions +} + +// NewConfidentialCredentialsProvider creates a new confidential credentials provider. +// It uses client id and client credentials to authenticate with the identity provider. +// The client credentials can be either a client secret or a client certificate. +func NewConfidentialCredentialsProvider(options ConfidentialCredentialsProviderOptions) (auth.StreamingCredentialsProvider, error) { + // Create a new identity provider using the client ID and client credentials. + idp, err := identity.NewConfidentialIdentityProvider(options.ConfidentialIdentityProviderOptions) + if err != nil { + return nil, fmt.Errorf("cannot create confidential identity provider: %w", err) + } + + // Create a new token manager using the identity provider. + tokenManager, err := options.getTokenManagerFactory()(idp, options.TokenManagerOptions) + if err != nil { + return nil, fmt.Errorf("cannot create token manager: %w", err) + } + + // Create a new credentials provider using the token manager. + credentialsProvider, err := NewCredentialsProvider(tokenManager, options.CredentialsProviderOptions) + if err != nil { + return nil, fmt.Errorf("cannot create credentials provider: %w", err) + } + return credentialsProvider, nil +} + +// DefaultAzureCredentialsProviderOptions is a struct that holds the options for the default azure credentials provider. +// It is used to configure the credentials provider when requesting a token. +type DefaultAzureCredentialsProviderOptions struct { + CredentialsProviderOptions + identity.DefaultAzureIdentityProviderOptions +} + +// NewDefaultAzureCredentialsProvider creates a new default azure credentials provider. +// It uses the default azure identity provider to authenticate with the identity provider. +// The default azure identity provider is a special type of identity provider that uses the default azure identity to authenticate. +// It is used to authenticate with the identity provider when requesting a token. +func NewDefaultAzureCredentialsProvider(options DefaultAzureCredentialsProviderOptions) (auth.StreamingCredentialsProvider, error) { + // Create a new identity provider using the default azure identity type. + idp, err := identity.NewDefaultAzureIdentityProvider(options.DefaultAzureIdentityProviderOptions) + if err != nil { + return nil, fmt.Errorf("cannot create default azure identity provider: %w", err) + } + + // Create a new token manager using the identity provider. + tokenManager, err := options.getTokenManagerFactory()(idp, options.TokenManagerOptions) + if err != nil { + return nil, fmt.Errorf("cannot create token manager: %w", err) + } + + // Create a new credentials provider using the token manager. + credentialsProvider, err := NewCredentialsProvider(tokenManager, options.CredentialsProviderOptions) + if err != nil { + return nil, fmt.Errorf("cannot create credentials provider: %w", err) + } + return credentialsProvider, nil +} From 9219edc2009269263e961eeddb29d8324c546f57 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Sun, 30 Mar 2025 12:00:00 +0000 Subject: [PATCH 07/44] Token Management - Core Interface --- manager/defaults.go | 180 +++++ manager/errors.go | 9 + manager/manager_test.go | 180 +++++ manager/token_manager.go | 358 +++++++++ manager/token_manager_test.go | 1372 +++++++++++++++++++++++++++++++++ 5 files changed, 2099 insertions(+) create mode 100644 manager/defaults.go create mode 100644 manager/errors.go create mode 100644 manager/manager_test.go create mode 100644 manager/token_manager.go create mode 100644 manager/token_manager_test.go diff --git a/manager/defaults.go b/manager/defaults.go new file mode 100644 index 0000000..56587d0 --- /dev/null +++ b/manager/defaults.go @@ -0,0 +1,180 @@ +package manager + +import ( + "errors" + "fmt" + "net" + "os" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis-developer/go-redis-entraid/token" +) + +const ( + DefaultExpirationRefreshRatio = 0.7 + DefaultRetryOptionsMaxAttempts = 3 + DefaultRetryOptionsInitialDelayMs = 1000 + DefaultRetryOptionsBackoffMultiplier = 2.0 + DefaultRetryOptionsMaxDelayMs = 10000 +) + +// defaultIsRetryable is a function that checks if the error is retriable. +// It takes an error as an argument and returns a boolean value. +// The function checks if the error is a net.Error and if it is a timeout or temporary error. +// Returns true for nil errors. +var defaultIsRetryable = func(err error) bool { + if err == nil { + return true + } + + var netErr net.Error + if errors.As(err, &netErr) { + // Check for timeout first as it's more specific + if netErr.Timeout() { + return true + } + // For temporary errors, we'll use a more modern approach + var tempErr interface{ Temporary() bool } + if errors.As(err, &tempErr) { + return tempErr.Temporary() + } + } + + return errors.Is(err, os.ErrDeadlineExceeded) +} + +// defaultRetryOptionsOr returns the default retry options if the provided options are not set. +// It sets the maximum number of attempts, initial delay, maximum delay, and backoff multiplier. +// The default values are 3 attempts, 1000 ms initial delay, 10000 ms maximum delay, and 2.0 backoff multiplier. +// The values can be overridden by the user. +func defaultRetryOptionsOr(retryOptions RetryOptions) RetryOptions { + if retryOptions.IsRetryable == nil { + retryOptions.IsRetryable = defaultIsRetryable + } + + if retryOptions.MaxAttempts <= 0 { + retryOptions.MaxAttempts = DefaultRetryOptionsMaxAttempts + } + if retryOptions.InitialDelayMs == 0 { + retryOptions.InitialDelayMs = DefaultRetryOptionsInitialDelayMs + } + if retryOptions.BackoffMultiplier == 0 { + retryOptions.BackoffMultiplier = DefaultRetryOptionsBackoffMultiplier + } + if retryOptions.MaxDelayMs == 0 { + retryOptions.MaxDelayMs = DefaultRetryOptionsMaxDelayMs + } + return retryOptions +} + +// defaultIdentityProviderResponseParserOr returns the default token parser if the provided token parser is not set. +// It sets the default token parser to the defaultIdentityProviderResponseParser function. +// The default token parser is used to parse the raw token and return a Token object. +func defaultIdentityProviderResponseParserOr(idpResponseParser shared.IdentityProviderResponseParser) shared.IdentityProviderResponseParser { + if idpResponseParser == nil { + return &defaultIdentityProviderResponseParser{} + } + return idpResponseParser +} + +func defaultTokenManagerOptionsOr(options TokenManagerOptions) TokenManagerOptions { + options.RetryOptions = defaultRetryOptionsOr(options.RetryOptions) + options.IdentityProviderResponseParser = defaultIdentityProviderResponseParserOr(options.IdentityProviderResponseParser) + if options.ExpirationRefreshRatio == 0 { + options.ExpirationRefreshRatio = DefaultExpirationRefreshRatio + } + return options +} + +type defaultIdentityProviderResponseParser struct{} + +// ParseResponse parses the response from the identity provider and extracts the token. +// It takes an IdentityProviderResponse as an argument and returns a Token and an error if any. +// The IdentityProviderResponse contains the raw token and the expiration time. +func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.IdentityProviderResponse) (*token.Token, error) { + if response == nil { + return nil, fmt.Errorf("identity provider response cannot be nil") + } + + var username, password, rawToken string + var expiresOn time.Time + now := time.Now().UTC() + + switch response.Type() { + case shared.ResponseTypeAuthResult: + authResult := response.AuthResult() + if authResult.ExpiresOn.IsZero() { + return nil, fmt.Errorf("auth result expiration time is not set") + } + if authResult.IDToken.Oid == "" { + return nil, fmt.Errorf("auth result OID is empty") + } + rawToken = authResult.IDToken.RawToken + username = authResult.IDToken.Oid + password = rawToken + expiresOn = authResult.ExpiresOn.UTC() + + case shared.ResponseTypeRawToken, shared.ResponseTypeAccessToken: + tokenStr := response.RawToken() + + if response.Type() == shared.ResponseTypeAccessToken { + accessToken := response.AccessToken() + if accessToken.Token == "" { + return nil, fmt.Errorf("access token value is empty") + } + tokenStr = accessToken.Token + expiresOn = accessToken.ExpiresOn.UTC() + } + + if tokenStr == "" { + return nil, fmt.Errorf("raw token is empty") + } + + claims := struct { + jwt.RegisteredClaims + Oid string `json:"oid,omitempty"` + }{} + + // Parse the token to extract claims, but note that signature verification + // should be handled by the identity provider + _, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims) + if err != nil { + return nil, fmt.Errorf("failed to parse JWT token: %w", err) + } + + if claims.Oid == "" { + return nil, fmt.Errorf("JWT token does not contain OID claim") + } + + rawToken = tokenStr + username = claims.Oid + password = rawToken + + if expiresOn.IsZero() && claims.ExpiresAt != nil { + expiresOn = claims.ExpiresAt.UTC() + } + + default: + return nil, fmt.Errorf("unsupported response type: %s", response.Type()) + } + + if expiresOn.IsZero() { + return nil, fmt.Errorf("token expiration time is not set") + } + + if expiresOn.Before(now) { + return nil, fmt.Errorf("token has expired at %s (current time: %s)", expiresOn, now) + } + + // Create the token with consistent time reference + return token.New( + username, + password, + rawToken, + expiresOn, + now, + int64(time.Until(expiresOn).Seconds()), + ), nil +} diff --git a/manager/errors.go b/manager/errors.go new file mode 100644 index 0000000..840d46d --- /dev/null +++ b/manager/errors.go @@ -0,0 +1,9 @@ +package manager + +import "fmt" + +// ErrTokenManagerAlreadyCanceled is returned when the token manager is already canceled. +var ErrTokenManagerAlreadyCanceled = fmt.Errorf("token manager already canceled") + +// ErrTokenManagerAlreadyStarted is returned when the token manager is already started. +var ErrTokenManagerAlreadyStarted = fmt.Errorf("token manager already started") diff --git a/manager/manager_test.go b/manager/manager_test.go new file mode 100644 index 0000000..6d6bd32 --- /dev/null +++ b/manager/manager_test.go @@ -0,0 +1,180 @@ +package manager + +import ( + "net" + "os" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis-developer/go-redis-entraid/token" + "github.com/stretchr/testify/mock" +) + +// testJWTToken is a JWT token for testing +// +// { +// "iss": "test jwt", +// "iat": 1743515011, +// "exp": 1775051011, +// "aud": "www.example.com", +// "sub": "test@test.com", +// "oid": "test" +// } +// +// key: qwertyuiopasdfghjklzxcvbnm123456 +const testJWTToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJ0ZXN0IGp3dCIsImlhdCI6MTc0MzUxNTAxMSwiZXhwIjoxNzc1MDUxMDExLCJhdWQiOiJ3d3cuZXhhbXBsZS5jb20iLCJzdWIiOiJ0ZXN0QHRlc3QuY29tIiwib2lkIjoidGVzdCJ9.6RG721V2eFlSLsCRmo53kSRRrTZIe1UPdLZCUEvIarU" + +// testJWTExpiredToken is an expired JWT token for testing +// +// { +// "iss": "test jwt", +// "iat": 1617795148, +// "exp": 1617795148, +// "aud": "www.example.com", +// "sub": "test@test.com", +// "oid": "test" +// } +// +// key: qwertyuiopasdfghjklzxcvbnm123456 +const testJWTExpiredToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJ0ZXN0IGp3dCIsImlhdCI6MTYxNzc5NTE0OCwiZXhwIjoxNjE3Nzk1MTQ4LCJhdWQiOiJ3d3cuZXhhbXBsZS5jb20iLCJzdWIiOiJ0ZXN0QHRlc3QuY29tIiwib2lkIjoidGVzdCJ9.IbGPhHRiPYcpUDrhAPf4h3gH1XXBOu560NYT59rUMzc" + +// testJWTWithZeroExpiryToken is a JWT token with zero expiry for testing +// +// { +// "iss": "test jwt", +// "iat": 1744025944, +// "exp": null, +// "aud": "www.example.com", +// "sub": "test@test.com", +// "oid": "test" +// } +// key: qwertyuiopasdfghjklzxcvbnm123456 +const testJWTWithZeroExpiryToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJ0ZXN0IGp3dCIsImlhdCI6MTc0NDAyNTk0NCwiZXhwIjpudWxsLCJhdWQiOiJ3d3cuZXhhbXBsZS5jb20iLCJzdWIiOiJ0ZXN0QHRlc3QuY29tIiwib2lkIjoidGVzdCJ9.bLSANIzawE5Y6rgspvvUaRhkBq6Y4E0ggjXlmHRn8ew" + +var testTokenValid = token.New( + "test", + "password", + "test", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), +) + +type mockIdentityProviderResponseParser struct { + // Mock implementation of the IdentityProviderResponseParser interface + mock.Mock +} + +func (m *mockIdentityProviderResponseParser) ParseResponse(response shared.IdentityProviderResponse) (*token.Token, error) { + args := m.Called(response) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*token.Token), args.Error(1) +} + +type mockIdentityProvider struct { + // Mock implementation of the mockIdentityProvider interface + // Add any necessary fields or methods for the mock identity provider here + mock.Mock +} + +func (m *mockIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(shared.IdentityProviderResponse), args.Error(1) +} + +func newMockError(retriable bool) error { + if retriable { + return &mockError{ + isTimeout: true, + isTemporary: true, + error: os.ErrDeadlineExceeded, + } + } else { + return &mockError{ + isTimeout: false, + isTemporary: false, + error: os.ErrInvalid, + } + } +} + +type mockError struct { + // Mock implementation of the network error + error + isTimeout bool + isTemporary bool +} + +func (m *mockError) Error() string { + return "this is mock error" +} + +func (m *mockError) Timeout() bool { + return m.isTimeout +} +func (m *mockError) Temporary() bool { + return m.isTemporary +} +func (m *mockError) Unwrap() error { + return m.error +} + +func (m *mockError) Is(err error) bool { + return m.error == err +} + +var _ net.Error = (*mockError)(nil) + +type mockTokenListener struct { + // Mock implementation of the TokenManagerListener interface + mock.Mock + Id int32 +} + +func (m *mockTokenListener) OnTokenNext(token *token.Token) { + _ = m.Called(token) +} + +func (m *mockTokenListener) OnTokenError(err error) { + _ = m.Called(err) +} + +type authResult struct { + // ResultType is the type of the response (AuthResult, AccessToken, or RawToken) + ResultType string + // AuthResultVal is the auth result value + AuthResultVal *public.AuthResult + // AccessTokenVal is the access token value + AccessTokenVal *azcore.AccessToken + // RawTokenVal is the raw token value + RawTokenVal string +} + +func (a *authResult) Type() string { + return a.ResultType +} + +func (a *authResult) AuthResult() public.AuthResult { + if a.AuthResultVal == nil { + return public.AuthResult{} + } + return *a.AuthResultVal +} + +func (a *authResult) AccessToken() azcore.AccessToken { + if a.AccessTokenVal == nil { + return azcore.AccessToken{} + } + return *a.AccessTokenVal +} + +func (a *authResult) RawToken() string { + return a.RawTokenVal +} diff --git a/manager/token_manager.go b/manager/token_manager.go new file mode 100644 index 0000000..6d4d8b0 --- /dev/null +++ b/manager/token_manager.go @@ -0,0 +1,358 @@ +package manager + +import ( + "fmt" + "sync" + "time" + + "github.com/redis-developer/go-redis-entraid/internal" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis-developer/go-redis-entraid/token" +) + +// TokenManagerOptions is a struct that contains the options for the TokenManager. +type TokenManagerOptions struct { + // ExpirationRefreshRatio is the ratio of the token expiration time to refresh the token. + // It is used to determine when to refresh the token. + // The value should be between 0 and 1. + // For example, if the expiration time is 1 hour and the ratio is 0.75, + // the token will be refreshed after 45 minutes. (the token is refreshed when 75% of its lifetime has passed) + // + // default: 0.7 + ExpirationRefreshRatio float64 + // LowerRefreshBoundMs is the lower bound for the refresh time in milliseconds. + // Represents the minimum time in milliseconds before token expiration to trigger a refresh. + // This value sets a fixed lower bound for when a token refresh should occur, regardless + // of the token's total lifetime. + // + // default: 0 ms (no lower bound, refresh based on ExpirationRefreshRatio) + LowerRefreshBoundMs int64 + + // IdentityProviderResponseParser is an optional object that implements the IdentityProviderResponseParser interface. + // It is used to parse the response from the identity provider and extract the token. + // If not provided, the default implementation will be used. + // The objects ParseResponse method will be called to parse the response and return the token. + // + // required: false + // default: defaultIdentityProviderResponseParser + IdentityProviderResponseParser shared.IdentityProviderResponseParser + // RetryOptions is a struct that contains the options for retrying the token request. + // It contains the maximum number of attempts, initial delay, maximum delay, and backoff multiplier. + // + // The default values are 3 attempts, 1000 ms initial delay, 10000 ms maximum delay, and 2.0 backoff multiplier. + RetryOptions RetryOptions +} + +// RetryOptions is a struct that contains the options for retrying the token request. +type RetryOptions struct { + // IsRetryable is a function that checks if the error is retriable. + // It takes an error as an argument and returns a boolean value. + // + // default: defaultRetryableFunc + IsRetryable func(err error) bool + // MaxAttempts is the maximum number of attempts to retry the token request. + // + // default: 3 + MaxAttempts int + // InitialDelayMs is the initial delay in milliseconds before retrying the token request. + // + // default: 1000 ms + InitialDelayMs int + // MaxDelayMs is the maximum delay in milliseconds between retry attempts. + // + // default: 10000 ms + MaxDelayMs int + // BackoffMultiplier is the multiplier for the backoff delay. + // default: 2.0 + BackoffMultiplier float64 +} + +// TokenManager is an interface that defines the methods for managing tokens. +// It provides methods to get a token and start the token manager. +// The TokenManager is responsible for obtaining and refreshing the token. +// It is typically used in conjunction with an IdentityProvider to obtain the token. +type TokenManager interface { + // GetToken returns the token for authentication. + // It takes a boolean value forceRefresh as an argument. + GetToken(forceRefresh bool) (*token.Token, error) + // Start starts the token manager and returns a channel that will receive updates. + Start(listener TokenListener) (CancelFunc, error) + // Close closes the token manager and releases any resources. + Close() error +} + +// CancelFunc is a function that cancels the token manager. +type CancelFunc func() error + +// TokenListener is an interface that contains the methods for receiving updates from the token manager. +// The token manager will call the listener's OnTokenNext method with the updated token. +// If an error occurs, the token manager will call the listener's OnTokenError method with the error. +type TokenListener interface { + // OnTokenNext is called when the token is updated. + OnTokenNext(t *token.Token) + // OnTokenError is called when an error occurs. + OnTokenError(err error) +} + +// entraidIdentityProviderResponseParser is the default implementation of the IdentityProviderResponseParser interface. +var entraidIdentityProviderResponseParser shared.IdentityProviderResponseParser = &defaultIdentityProviderResponseParser{} + +// NewTokenManager creates a new TokenManager. +// It takes an IdentityProvider and TokenManagerOptions as arguments and returns a TokenManager interface. +// The IdentityProvider is used to obtain the token, and the TokenManagerOptions contains options for the TokenManager. +// The TokenManager is responsible for managing the token and refreshing it when necessary. +func NewTokenManager(idp shared.IdentityProvider, options TokenManagerOptions) (TokenManager, error) { + if options.ExpirationRefreshRatio < 0 || options.ExpirationRefreshRatio > 1 { + return nil, fmt.Errorf("expiration refresh ratio must be between 0 and 1") + } + options = defaultTokenManagerOptionsOr(options) + + if idp == nil { + return nil, fmt.Errorf("identity provider is required") + } + + return &entraidTokenManager{ + idp: idp, + token: nil, + closedChan: nil, + expirationRefreshRatio: options.ExpirationRefreshRatio, + lowerRefreshBoundMs: options.LowerRefreshBoundMs, + lowerBoundDuration: time.Duration(options.LowerRefreshBoundMs) * time.Millisecond, + identityProviderResponseParser: options.IdentityProviderResponseParser, + retryOptions: options.RetryOptions, + }, nil +} + +// entraidTokenManager is a struct that implements the TokenManager interface. +type entraidTokenManager struct { + // idp is the identity provider used to obtain the token. + idp shared.IdentityProvider + + // token is the authentication token for the user which should be kept in memory if valid. + token *token.Token + + // tokenRWLock is a read-write lock used to protect the token from concurrent access. + tokenRWLock sync.RWMutex + + // identityProviderResponseParser is the parser used to parse the response from the identity provider. + // It`s ParseResponse method will be called to parse the response and return the token. + identityProviderResponseParser shared.IdentityProviderResponseParser + + // retryOptions is a struct that contains the options for retrying the token request. + // It contains the maximum number of attempts, initial delay, maximum delay, and backoff multiplier. + // The default values are 3 attempts, 1000 ms initial delay, 10000 ms maximum delay, and 2.0 backoff multiplier. + // The values can be overridden by the user. + retryOptions RetryOptions + + // listener is the single listener for the token manager. + // It is used to receive updates from the token manager. + // The token manager will call the listener's OnTokenNext method with the updated token. + // If an error occurs, the token manager will call the listener's OnTokenError method with the error. + // if listener is set, Start will fail + listener TokenListener + + // lock locks the listener to prevent concurrent access. + lock sync.Mutex + + // expirationRefreshRatio is the ratio of the token expiration time to refresh the token. + // It is used to determine when to refresh the token. + // The value should be between 0 and 1. + // For example, if the expiration time is 1 hour and the ratio is 0.75, + // the token will be refreshed after 45 minutes. (the token is refreshed when 75% of its lifetime has passed) + expirationRefreshRatio float64 + + // lowerRefreshBoundMs is the lower bound for the refresh time in milliseconds. + // Represents the minimum time in milliseconds before token expiration to trigger a refresh, in milliseconds. + // This value sets a fixed lower bound for when a token refresh should occur, regardless + // of the token's total lifetime. + lowerRefreshBoundMs int64 + + // lowerBoundDuration is the lower bound for the refresh time in time.Duration. + lowerBoundDuration time.Duration + + // closedChan is a channel that is closedChan when the token manager is closedChan. + // It is used to signal the token manager to stop requesting tokens. + closedChan chan struct{} +} + +func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { + e.tokenRWLock.RLock() + // check if the token is nil and if it is not expired + if !forceRefresh && e.token != nil && time.Now().Add(e.lowerBoundDuration).Before(e.token.ExpirationOn()) { + t := e.token + e.tokenRWLock.RUnlock() + return t, nil + } + e.tokenRWLock.RUnlock() + + // Upgrade to write lock for token update + e.tokenRWLock.Lock() + defer e.tokenRWLock.Unlock() + + // Double-check pattern to avoid unnecessary token refresh + if !forceRefresh && e.token != nil && time.Now().Add(e.lowerBoundDuration).Before(e.token.ExpirationOn()) { + return e.token, nil + } + + idpResult, err := e.idp.RequestToken() + if err != nil { + return nil, fmt.Errorf("failed to request token from idp: %w", err) + } + + t, err := e.identityProviderResponseParser.ParseResponse(idpResult) + if err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + if t == nil { + return nil, fmt.Errorf("failed to get token: token is nil") + } + + // Store the token + e.token = t + // Return the token - no need to copy since it's immutable + return t, nil +} + +// Start starts the token manager and returns cancelFunc to stop the token manager. +// It takes a TokenListener as an argument, which is used to receive updates. +// The token manager will call the listener's OnTokenNext method with the updated token. +// If an error occurs, the token manager will call the listener's OnError method with the error. +// +// Note: The initial token is delivered synchronously. +// The TokenListener will receive the token immediately, before the token manager goroutine starts. +func (e *entraidTokenManager) Start(listener TokenListener) (CancelFunc, error) { + e.lock.Lock() + defer e.lock.Unlock() + if e.listener != nil { + return nil, ErrTokenManagerAlreadyStarted + } + + if e.closedChan != nil && !internal.IsClosed(e.closedChan) { + // there is a hanging goroutine that is waiting for the closedChan to be closed + // if the closedChan is not nil and not closed, close it + close(e.closedChan) + } + + t, err := e.GetToken(true) + if err != nil { + go listener.OnTokenError(err) + return nil, fmt.Errorf("failed to start token manager: %w", err) + } + + // Deliver initial token synchronously + listener.OnTokenNext(t) + + e.closedChan = make(chan struct{}) + e.listener = listener + + go func(listener TokenListener, closed <-chan struct{}) { + maxDelay := time.Duration(e.retryOptions.MaxDelayMs) * time.Millisecond + initialDelay := time.Duration(e.retryOptions.InitialDelayMs) * time.Millisecond + + for { + timeToRenewal := e.durationToRenewal() + select { + case <-closed: + return + case <-time.After(timeToRenewal): + if timeToRenewal == 0 { + // Token was requested immediately, guard against infinite loop + select { + case <-closed: + return + case <-time.After(initialDelay): + // continue to attempt + } + } + + // Token is about to expire, refresh it + delay := initialDelay + for i := 0; i < e.retryOptions.MaxAttempts; i++ { + t, err := e.GetToken(true) + if err == nil { + listener.OnTokenNext(t) + break + } + + // check if err is retriable + if e.retryOptions.IsRetryable(err) { + if i == e.retryOptions.MaxAttempts-1 { + // last attempt, call OnTokenError + listener.OnTokenError(fmt.Errorf("max attempts reached: %w", err)) + return + } + + // Exponential backoff + if delay < maxDelay { + delay = time.Duration(float64(delay) * e.retryOptions.BackoffMultiplier) + } + if delay > maxDelay { + delay = maxDelay + } + + select { + case <-closed: + return + case <-time.After(delay): + // continue to next attempt + } + } else { + // not retriable + listener.OnTokenError(err) + return + } + } + } + } + }(listener, e.closedChan) + + return e.Close, nil +} + +// Close closes the token manager and releases any resources. +func (e *entraidTokenManager) Close() error { + e.lock.Lock() + defer e.lock.Unlock() + + if e.closedChan == nil || e.listener == nil { + return ErrTokenManagerAlreadyCanceled + } + e.listener = nil + close(e.closedChan) + + return nil +} + +// durationToRenewal calculates the duration to the next token renewal. +// It returns the duration to the next token renewal based on the expiration refresh ratio and the lower bound duration. +// If the token is nil, it returns 0. +// If the time till expiration is less than the lower bound duration, it returns 0 to renew the token now. +func (e *entraidTokenManager) durationToRenewal() time.Duration { + e.tokenRWLock.RLock() + if e.token == nil { + e.tokenRWLock.RUnlock() + return 0 + } + + timeTillExpiration := time.Until(e.token.ExpirationOn()) + e.tokenRWLock.RUnlock() + + // if the timeTillExpiration is less than the lower bound (or 0), return 0 to renew the token NOW + if timeTillExpiration <= e.lowerBoundDuration || timeTillExpiration <= 0 { + return 0 + } + + // Calculate the time to renew the token based on the expiration refresh ratio + // Since timeTillExpiration is guarded by the lower bound, we can safely multiply it by the ratio + // and assume the duration is a positive number + duration := time.Duration(float64(timeTillExpiration) * e.expirationRefreshRatio) + + // if the duration will take us past the lower bound, return the duration to lower bound + if timeTillExpiration-e.lowerBoundDuration < duration { + return timeTillExpiration - e.lowerBoundDuration + } + + // return the calculated duration + return duration +} diff --git a/manager/token_manager_test.go b/manager/token_manager_test.go new file mode 100644 index 0000000..433f808 --- /dev/null +++ b/manager/token_manager_test.go @@ -0,0 +1,1372 @@ +package manager + +import ( + "fmt" + "log" + "math/rand" + "os" + "reflect" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis-developer/go-redis-entraid/token" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var assertFuncNameMatches = func(t *testing.T, func1, func2 interface{}) { + funcName1 := runtime.FuncForPC(reflect.ValueOf(func1).Pointer()).Name() + funcName2 := runtime.FuncForPC(reflect.ValueOf(func2).Pointer()).Name() + assert.Equal(t, funcName1, funcName2) +} + +func TestTokenManager(t *testing.T) { + t.Parallel() + t.Run("Without IDP", func(t *testing.T) { + t.Parallel() + tokenManager, err := NewTokenManager(nil, + TokenManagerOptions{}, + ) + assert.Error(t, err) + assert.Nil(t, tokenManager) + }) + + t.Run("With IDP", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{}, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + }) +} + +func TestTokenManagerWithOptions(t *testing.T) { + t.Parallel() + t.Run("Bad Expiration Refresh Ration", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + options := TokenManagerOptions{ + ExpirationRefreshRatio: 5, + } + tokenManager, err := NewTokenManager(idp, options) + assert.Error(t, err) + assert.Nil(t, tokenManager) + }) + t.Run("With IDP and Options", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + options := TokenManagerOptions{ + ExpirationRefreshRatio: 0.5, + } + tokenManager, err := NewTokenManager(idp, options) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Equal(t, 0.5, tm.expirationRefreshRatio) + }) + t.Run("Default Options", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + options := TokenManagerOptions{} + tokenManager, err := NewTokenManager(idp, options) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Equal(t, DefaultExpirationRefreshRatio, tm.expirationRefreshRatio) + assert.NotNil(t, tm.retryOptions.IsRetryable) + assertFuncNameMatches(t, tm.retryOptions.IsRetryable, defaultIsRetryable) + assert.Equal(t, DefaultRetryOptionsMaxAttempts, tm.retryOptions.MaxAttempts) + assert.Equal(t, DefaultRetryOptionsInitialDelayMs, tm.retryOptions.InitialDelayMs) + assert.Equal(t, DefaultRetryOptionsMaxDelayMs, tm.retryOptions.MaxDelayMs) + assert.Equal(t, DefaultRetryOptionsBackoffMultiplier, tm.retryOptions.BackoffMultiplier) + }) +} + +func TestDefaultIdentityProviderResponseParserOr(t *testing.T) { + t.Parallel() + var f shared.IdentityProviderResponseParser = &mockIdentityProviderResponseParser{} + + result := defaultIdentityProviderResponseParserOr(f) + assert.NotNil(t, result) + assert.Equal(t, result, f) + + defaultParser := defaultIdentityProviderResponseParserOr(nil) + assert.NotNil(t, defaultParser) + assert.NotEqual(t, defaultParser, f) + assert.Equal(t, entraidIdentityProviderResponseParser, defaultParser) +} + +func TestDefaultIsRetryable(t *testing.T) { + t.Parallel() + // with network error timeout + t.Run("Non-Retryable Error", func(t *testing.T) { + t.Parallel() + err := &azcore.ResponseError{ + StatusCode: 500, + } + is := defaultIsRetryable(err) + assert.False(t, is) + }) + + t.Run("Nil Error", func(t *testing.T) { + t.Parallel() + var err error + is := defaultIsRetryable(err) + assert.True(t, is) + + is = defaultIsRetryable(nil) + assert.True(t, is) + }) + + t.Run("Retryable Error with Timeout", func(t *testing.T) { + t.Parallel() + err := newMockError(true) + result := defaultIsRetryable(err) + assert.True(t, result) + }) + t.Run("Retryable Error with Temporary", func(t *testing.T) { + t.Parallel() + err := newMockError(true) + result := defaultIsRetryable(err) + assert.True(t, result) + }) + + t.Run("Retryable Error with err parent of os.ErrDeadlineExceeded", func(t *testing.T) { + t.Parallel() + err := fmt.Errorf("timeout: %w", os.ErrDeadlineExceeded) + res := defaultIsRetryable(err) + assert.True(t, res) + }) +} + +func TestTokenManager_Close(t *testing.T) { + t.Parallel() + t.Run("Close", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + assert.NotPanics(t, func() { + err = tokenManager.Close() + assert.Error(t, err) + }) + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + assert.NoError(t, err) + + idp.On("RequestToken").Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnTokenNext", testTokenValid).Return() + + assert.NotPanics(t, func() { + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + }) + assert.NotNil(t, tm.listener) + + err = tokenManager.Close() + assert.Nil(t, tm.listener) + assert.NoError(t, err) + + assert.NotPanics(t, func() { + err = tokenManager.Close() + assert.Error(t, err) + }) + }) + + t.Run("Close with Cancel", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + assert.NoError(t, err) + + idp.On("RequestToken").Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnTokenNext", testTokenValid).Return() + + assert.NotPanics(t, func() { + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + err = cancel() + assert.NoError(t, err) + assert.Nil(t, tm.listener) + err = cancel() + assert.Error(t, err) + assert.Nil(t, tm.listener) + }) + }) + t.Run("Close in multiple threads", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + assert.NoError(t, err) + + idp.On("RequestToken").Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnTokenNext", testTokenValid).Return() + + assert.NotPanics(t, func() { + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + var hasStopped int + var alreadyStopped int32 + wg := &sync.WaitGroup{} + + // Start 50000 goroutines to close the token manager + // and check if the listener is nil after each close. + numExecutions := 50000 + for i := 0; i < numExecutions; i++ { + wg.Add(1) + go func() { + defer wg.Done() + time.Sleep(time.Duration(int64(rand.Intn(100)) * int64(time.Millisecond))) + err := tokenManager.Close() + if err == nil { + hasStopped += 1 + return + } else { + atomic.AddInt32(&alreadyStopped, 1) + } + assert.Nil(t, tm.listener) + assert.Error(t, err) + }() + } + wg.Wait() + assert.Nil(t, tm.listener) + assert.Equal(t, 1, hasStopped) + assert.Equal(t, int32(numExecutions-1), atomic.LoadInt32(&alreadyStopped)) + }) + }) +} + +func TestTokenManager_Start(t *testing.T) { + t.Parallel() + t.Run("Start in multiple threads", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + assert.NoError(t, err) + + idp.On("RequestToken").Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnTokenNext", testTokenValid).Return() + + assert.NotPanics(t, func() { + var hasStarted int + var alreadyStarted int32 + wg := &sync.WaitGroup{} + + numExecutions := 50000 + for i := 0; i < numExecutions; i++ { + wg.Add(1) + go func() { + defer wg.Done() + time.Sleep(time.Duration(int64(rand.Intn(100)) * int64(time.Millisecond))) + _, err := tokenManager.Start(listener) + if err == nil { + hasStarted += 1 + return + } else { + atomic.AddInt32(&alreadyStarted, 1) + } + assert.NotNil(t, tm.listener) + assert.Error(t, err) + }() + } + wg.Wait() + assert.NotNil(t, tm.listener) + assert.Equal(t, 1, hasStarted) + assert.Equal(t, int32(numExecutions-1), atomic.LoadInt32(&alreadyStarted)) + cancel, err := tokenManager.Start(listener) + assert.Nil(t, cancel) + assert.Error(t, err) + assert.NotNil(t, tm.listener) + }) + }) + + t.Run("concurrent stress token manager", func(t *testing.T) { + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + assert.NoError(t, err) + + assert.NotPanics(t, func() { + last := &atomic.Int32{} + wg := &sync.WaitGroup{} + + idp.On("RequestToken").Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnTokenNext", testTokenValid).Return() + numExecutions := int32(50000) + for i := int32(0); i < numExecutions; i++ { + wg.Add(1) + go func(num int32) { + defer wg.Done() + var err error + time.Sleep(time.Duration(int64(rand.Intn(1000)+(300-int(num)/2)) * int64(time.Millisecond))) + last.Store(num) + if num%2 == 0 { + err = tokenManager.Close() + } else { + l := &mockTokenListener{Id: num} + l.On("OnTokenNext", testTokenValid).Return() + _, err = tokenManager.Start(l) + } + if err != nil { + if err != ErrTokenManagerAlreadyCanceled && err != ErrTokenManagerAlreadyStarted { + // this is un unexpected error, fail the test + assert.Error(t, err) + } + } + }(i) + } + wg.Wait() + lastExecution := last.Load() + if lastExecution%2 == 0 { + if tm.listener != nil { + l := tm.listener.(*mockTokenListener) + log.Printf("FAILING WITH lastExecution [STARTED]:[LISTENER:%d]: %d", l.Id, lastExecution) + } + assert.Nil(t, tm.listener) + } else { + if tm.listener == nil { + log.Printf("FAILING WITH lastExecution[STOPPED]: %d", lastExecution) + } + assert.NotNil(t, tm.listener) + cancel, err := tokenManager.Start(listener) + assert.Nil(t, cancel) + assert.Error(t, err) + // Close the token manager + err = tokenManager.Close() + assert.Nil(t, err) + } + assert.Nil(t, tm.listener) + }) + }) +} + +func TestDefaultIdentityProviderResponseParser(t *testing.T) { + t.Parallel() + parser := &defaultIdentityProviderResponseParser{} + t.Run("Default IdentityProviderResponseParser with type AuthResult", func(t *testing.T) { + t.Parallel() + authResultVal := testAuthResult(time.Now().Add(time.Hour).UTC()) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: authResultVal, + } + token1, err := parser.ParseResponse(idpResponse) + assert.NoError(t, err) + assert.NotNil(t, token1) + assert.Equal(t, authResultVal.ExpiresOn, token1.ExpirationOn()) + }) + t.Run("Default IdentityProviderResponseParser with type AccessToken", func(t *testing.T) { + t.Parallel() + accessToken := &azcore.AccessToken{ + Token: testJWTToken, + ExpiresOn: time.Now().Add(time.Hour).UTC(), + } + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAccessToken, + AccessTokenVal: accessToken, + } + token1, err := parser.ParseResponse(idpResponse) + assert.NoError(t, err) + assert.NotNil(t, token1) + assert.Equal(t, accessToken.ExpiresOn, token1.ExpirationOn()) + assert.Equal(t, accessToken.Token, token1.RawCredentials()) + }) + t.Run("Default IdentityProviderResponseParser with type RawToken", func(t *testing.T) { + t.Parallel() + idpResponse := &authResult{ + ResultType: shared.ResponseTypeRawToken, + RawTokenVal: testJWTToken, + } + token1, err := parser.ParseResponse(idpResponse) + assert.NoError(t, err) + assert.NotNil(t, token1) + }) + + t.Run("Default IdentityProviderResponseParser with expired JWT Token", func(t *testing.T) { + t.Parallel() + idpResponse := &authResult{ + ResultType: shared.ResponseTypeRawToken, + RawTokenVal: testJWTExpiredToken, + } + token1, err := parser.ParseResponse(idpResponse) + assert.Error(t, err) + assert.Nil(t, token1) + }) + + t.Run("Default IdentityProviderResponseParser with zero expiry JWT Token", func(t *testing.T) { + t.Parallel() + idpResponse := &authResult{ + ResultType: shared.ResponseTypeRawToken, + RawTokenVal: testJWTWithZeroExpiryToken, + } + token1, err := parser.ParseResponse(idpResponse) + assert.Error(t, err) + assert.Nil(t, token1) + }) + + t.Run("Default IdentityProviderResponseParser with type Unknown", func(t *testing.T) { + t.Parallel() + idpResponse := &authResult{ + ResultType: "Unknown", + } + token1, err := parser.ParseResponse(idpResponse) + assert.Error(t, err) + assert.Nil(t, token1) + }) + + types := []string{ + shared.ResponseTypeAuthResult, + shared.ResponseTypeAccessToken, + shared.ResponseTypeRawToken, + } + for _, rt := range types { + t.Run(fmt.Sprintf("Default IdentityProviderResponseParser with response type %s and nil value", rt), func(t *testing.T) { + idpResponse := &authResult{ + ResultType: rt, + } + token1, err := parser.ParseResponse(idpResponse) + assert.Error(t, err) + assert.Nil(t, token1) + }) + } + + t.Run("Default IdentityProviderResponseParser with response nil", func(t *testing.T) { + t.Parallel() + token1, err := parser.ParseResponse(nil) + assert.Error(t, err) + assert.Nil(t, token1) + }) + t.Run("Default IdentityProviderResponseParser with expired token", func(t *testing.T) { + t.Parallel() + authResultVal := testAuthResult(time.Now().Add(-time.Hour)) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: authResultVal, + } + token1, err := parser.ParseResponse(idpResponse) + assert.Error(t, err) + assert.Nil(t, token1) + }) +} + +func TestEntraidTokenManager_GetToken(t *testing.T) { + t.Parallel() + t.Run("GetToken", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + rawResponse := &authResult{ + ResultType: shared.ResponseTypeRawToken, + RawTokenVal: "test", + } + + idp.On("RequestToken").Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnTokenNext", testTokenValid).Return() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + token1, err := tokenManager.GetToken(false) + assert.NoError(t, err) + assert.NotNil(t, token1) + }) + + t.Run("GetToken with parse error", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + rawResponse := &authResult{ + ResultType: shared.ResponseTypeRawToken, + RawTokenVal: "test", + } + + idp.On("RequestToken").Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(nil, fmt.Errorf("parse error")) + listener.On("OnTokenError", mock.Anything).Return() + + cancel, err := tokenManager.Start(listener) + assert.Error(t, err) + assert.Nil(t, cancel) + assert.Nil(t, tm.listener) + }) + t.Run("GetToken with expired token", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{}, + ) + assert.NoError(t, err) + + authResultVal := testAuthResult(time.Now().Add(-time.Hour)) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: authResultVal, + } + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + idp.On("RequestToken").Return(idpResponse, nil) + + token1, err := tokenManager.GetToken(false) + assert.Error(t, err) + assert.Nil(t, token1) + }) + + t.Run("GetToken with nil token", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{}, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + _, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + assert.NoError(t, err) + + idp.On("RequestToken").Return(rawResponse, nil) + + token1, err := tokenManager.GetToken(false) + assert.Error(t, err) + assert.Nil(t, token1) + }) + + t.Run("GetToken with nil from parser", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + _, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + + idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + assert.NoError(t, err) + idp.On("RequestToken").Return(idpResponse, nil) + mParser.On("ParseResponse", idpResponse).Return(nil, nil) + + token1, err := tokenManager.GetToken(false) + assert.Error(t, err) + assert.Nil(t, token1) + }) + + t.Run("GetToken with idp error", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + _, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + + idp.On("RequestToken").Return(nil, fmt.Errorf("idp error")) + + token1, err := tokenManager.GetToken(false) + assert.Error(t, err) + assert.Nil(t, token1) + }) +} + +func TestEntraidTokenManager_durationToRenewal(t *testing.T) { + t.Parallel() + t.Run("durationToRenewal", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + tokenManager, err := NewTokenManager(idp, TokenManagerOptions{ + LowerRefreshBoundMs: 1000 * 60 * 60, // 1 hour + }) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + + result := tm.durationToRenewal() + // returns 0 for nil token + assert.Equal(t, time.Duration(0), result) + + // get token that expires before the lower bound + assert.NotPanics(t, func() { + expiresSoon := testAuthResult(time.Now().Add(tm.lowerBoundDuration - time.Minute).UTC()) + idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult, + expiresSoon) + assert.NoError(t, err) + idp.On("RequestToken").Return(idpResponse, nil).Once() + tm.token = nil + _, err = tm.GetToken(false) + assert.NoError(t, err) + assert.NotNil(t, tm.token) + + // return zero, should happen now since it expires before the lower bound + result = tm.durationToRenewal() + assert.Equal(t, time.Duration(0), result) + }) + + // get token that expires after the lower bound and expirationRefreshRatio to 1 + assert.NotPanics(t, func() { + tm.expirationRefreshRatio = 1 + expiresAfterlb := testAuthResult(time.Now().Add(tm.lowerBoundDuration + time.Hour).UTC()) + idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult, + expiresAfterlb) + assert.NoError(t, err) + idp.On("RequestToken").Return(idpResponse, nil).Once() + tm.token = nil + _, err = tm.GetToken(false) + assert.NoError(t, err) + assert.NotNil(t, tm.token) + + // return time to lower bound, if the returned time will be after the lower bound + result = tm.durationToRenewal() + assert.InEpsilon(t, time.Until(tm.token.ExpirationOn().Add(-1*tm.lowerBoundDuration)), result, float64(time.Second)) + }) + + }) +} + +func TestEntraidTokenManager_Streaming(t *testing.T) { + t.Parallel() + t.Run("Start and Close", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + authResultVal := testAuthResult(expiresOn) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: authResultVal, + } + + idp.On("RequestToken").Return(idpResponse, nil).Once() + token1 := token.New( + "test", + "test", + "test", + expiresOn, + time.Now(), + int64(time.Until(expiresOn)), + ) + + mParser.On("ParseResponse", idpResponse).Return(token1, nil).Once() + listener.On("OnTokenNext", token1).Return().Once() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + toRenewal := tm.durationToRenewal() + assert.NotEqual(t, time.Duration(0), toRenewal) + assert.NotEqual(t, expiresIn, toRenewal) + assert.True(t, expiresIn > toRenewal) + <-time.After(toRenewal / 10) + assert.NotNil(t, tm.listener) + assert.NoError(t, tokenManager.Close()) + assert.Nil(t, tm.listener) + assert.Panics(t, func() { + close(tm.closedChan) + }) + + <-time.After(toRenewal) + assert.Error(t, tokenManager.Close()) + mock.AssertExpectationsForObjects(t, idp, mParser, listener) + }) + + t.Run("Start and Listen with 0 renewal duration", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + LowerRefreshBoundMs: 1000 * 60 * 60, // 1 hour + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + assert.NoError(t, err) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + + res := testAuthResult(expiresOn) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: res, + } + done := make(chan struct{}) + var twice int32 + var start, stop time.Time + idp.On("RequestToken").Run(func(args mock.Arguments) { + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse.AuthResultVal = res + if atomic.LoadInt32(&twice) == 1 { + stop = time.Now() + close(done) + return + } else { + atomic.StoreInt32(&twice, 1) + start = time.Now() + } + }).Return(idpResponse, nil) + + listener.On("OnTokenNext", mock.AnythingOfType("*token.Token")).Return() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + toRenewal := tm.durationToRenewal() + assert.Equal(t, time.Duration(0), toRenewal) + assert.True(t, expiresIn > toRenewal) + + <-done + assert.NoError(t, cancel()) + + assert.InDelta(t, stop.Sub(start), time.Duration(tm.retryOptions.InitialDelayMs)*time.Millisecond, float64(200*time.Millisecond)) + + idp.AssertNumberOfCalls(t, "RequestToken", 2) + listener.AssertNumberOfCalls(t, "OnTokenNext", 2) + mock.AssertExpectationsForObjects(t, idp, listener) + }) + + t.Run("Start and Listen with 0 renewal duration and closing the token", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + LowerRefreshBoundMs: 1000 * 60 * 60, // 1 hour + RetryOptions: RetryOptions{ + InitialDelayMs: 5000, // 5 seconds + }, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + assert.NoError(t, err) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: res, + } + idp.On("RequestToken").Run(func(args mock.Arguments) { + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse.AuthResultVal = res + }).Return(idpResponse, nil) + + listener.On("OnTokenNext", mock.AnythingOfType("*token.Token")).Return() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + toRenewal := tm.durationToRenewal() + assert.Equal(t, time.Duration(0), toRenewal) + assert.True(t, expiresIn > toRenewal) + + <-time.After(time.Duration(tm.retryOptions.InitialDelayMs/2) * time.Millisecond) + assert.NoError(t, cancel()) + assert.Nil(t, tm.listener) + assert.Panics(t, func() { + close(tm.closedChan) + }) + + // called only once since the token manager was closed prior to initial delay passing + idp.AssertNumberOfCalls(t, "RequestToken", 1) + listener.AssertNumberOfCalls(t, "OnTokenNext", 1) + mock.AssertExpectationsForObjects(t, idp, listener) + }) + + t.Run("Start and Listen", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{}, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + assert.NoError(t, err) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + + res := testAuthResult(expiresOn) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: res, + } + idp.On("RequestToken").Run(func(args mock.Arguments) { + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse.AuthResultVal = res + }).Return(idpResponse, nil) + + listener.On("OnTokenNext", mock.AnythingOfType("*token.Token")).Return() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + toRenewal := tm.durationToRenewal() + assert.NotEqual(t, time.Duration(0), toRenewal) + assert.NotEqual(t, expiresIn, toRenewal) + assert.True(t, expiresIn > toRenewal) + + <-time.After(toRenewal + time.Second) + + mock.AssertExpectationsForObjects(t, idp, listener) + }) + + t.Run("Start and Listen with retriable error", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{}, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + assert.NoError(t, err) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: res, + } + + noErrCall := idp.On("RequestToken").Run(func(args mock.Arguments) { + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse.AuthResultVal = res + }).Return(idpResponse, nil) + + listener.On("OnTokenNext", mock.AnythingOfType("*token.Token")).Return() + listener.On("OnTokenError", mock.Anything).Run(func(args mock.Arguments) { + err := args.Get(0) + assert.NotNil(t, err) + }).Return().Maybe() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + noErrCall.Unset() + returnErr := newMockError(true) + idp.On("RequestToken").Return(nil, returnErr) + + toRenewal := tm.durationToRenewal() + assert.NotEqual(t, time.Duration(0), toRenewal) + assert.NotEqual(t, expiresIn, toRenewal) + assert.True(t, expiresIn > toRenewal) + <-time.After(toRenewal + 100*time.Millisecond) + idp.AssertNumberOfCalls(t, "RequestToken", 2) + listener.AssertNumberOfCalls(t, "OnTokenNext", 1) + mock.AssertExpectationsForObjects(t, idp, listener) + }) + + t.Run("Start and Listen with NOT retriable error", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{}, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + assert.NoError(t, err) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: res, + } + + noErrCall := idp.On("RequestToken").Run(func(args mock.Arguments) { + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse.AuthResultVal = res + }).Return(idpResponse, nil) + + listener.On("OnTokenNext", mock.AnythingOfType("*token.Token")).Return() + listener.On("OnTokenError", mock.Anything).Run(func(args mock.Arguments) { + err := args.Get(0).(error) + assert.NotNil(t, err) + }).Return() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + noErrCall.Unset() + returnErr := newMockError(false) + idp.On("RequestToken").Return(nil, returnErr) + + toRenewal := tm.durationToRenewal() + assert.NotEqual(t, time.Duration(0), toRenewal) + assert.NotEqual(t, expiresIn, toRenewal) + assert.True(t, expiresIn > toRenewal) + <-time.After(toRenewal + 100*time.Millisecond) + + idp.AssertNumberOfCalls(t, "RequestToken", 2) + listener.AssertNumberOfCalls(t, "OnTokenNext", 1) + listener.AssertNumberOfCalls(t, "OnTokenError", 1) + mock.AssertExpectationsForObjects(t, idp, listener) + }) + + t.Run("Start and Listen with retriable error - max retries and max delay", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + maxAttempts := 3 + maxDelayMs := 500 + initialDelayMs := 100 + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + RetryOptions: RetryOptions{ + MaxAttempts: maxAttempts, + MaxDelayMs: maxDelayMs, + InitialDelayMs: initialDelayMs, + BackoffMultiplier: 10, + }, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + assert.NoError(t, err) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + res.IDToken.Oid = "test" + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: res, + } + + noErrCall := idp.On("RequestToken").Run(func(args mock.Arguments) { + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + res.IDToken.Oid = "test" + idpResponse.AuthResultVal = res + }).Return(idpResponse, nil) + var start, end time.Time + var elapsed time.Duration + + _ = listener. + On("OnTokenNext", mock.AnythingOfType("*token.Token")). + Run(func(_ mock.Arguments) { + start = time.Now() + }).Return() + maxAttemptsReached := make(chan struct{}) + listener.On("OnTokenError", mock.Anything).Run(func(args mock.Arguments) { + err := args.Get(0).(error) + end = time.Now() + elapsed = end.Sub(start) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "max attempts reached") + close(maxAttemptsReached) + }).Return() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + toRenewal := tm.durationToRenewal() + assert.NotEqual(t, time.Duration(0), toRenewal) + assert.NotEqual(t, expiresIn, toRenewal) + assert.True(t, expiresIn > toRenewal) + + noErrCall.Unset() + returnErr := newMockError(true) + + idp.On("RequestToken").Return(nil, returnErr) + + select { + case <-time.After(toRenewal + time.Duration(maxAttempts*maxDelayMs)*time.Millisecond): + assert.Fail(t, "Timeout - max retries not reached") + case <-maxAttemptsReached: + } + + // initialRenewal window, maxAttempts - 1 * max delay + the initial one which was lower than max delay + allDelaysShouldBe := toRenewal + allDelaysShouldBe += time.Duration(initialDelayMs) * time.Millisecond + allDelaysShouldBe += time.Duration(maxAttempts-1) * time.Duration(maxDelayMs) * time.Millisecond + + assert.InEpsilon(t, elapsed, allDelaysShouldBe, float64(10*time.Millisecond)) + + idp.AssertNumberOfCalls(t, "RequestToken", tm.retryOptions.MaxAttempts+1) + listener.AssertNumberOfCalls(t, "OnTokenNext", 1) + listener.AssertNumberOfCalls(t, "OnTokenError", 1) + mock.AssertExpectationsForObjects(t, idp, listener) + }) + t.Run("Start and Listen and close during retries", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + RetryOptions: RetryOptions{ + MaxAttempts: 100, + }, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + assert.NoError(t, err) + + expiresIn := time.Second + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: res, + } + + noErrCall := idp.On("RequestToken").Run(func(args mock.Arguments) { + expiresOn := time.Now().Add(expiresIn).UTC() + res := testAuthResult(expiresOn) + idpResponse.AuthResultVal = res + }).Return(idpResponse, nil) + + listener.On("OnTokenNext", mock.AnythingOfType("*token.Token")).Return() + maxAttemptsReached := make(chan struct{}) + listener.On("OnTokenError", mock.Anything).Run(func(args mock.Arguments) { + err := args.Get(0).(error) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "max attempts reached") + close(maxAttemptsReached) + }).Return().Maybe() + + cancel, err := tokenManager.Start(listener) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + noErrCall.Unset() + returnErr := newMockError(true) + idp.On("RequestToken").Return(nil, returnErr) + + toRenewal := tm.durationToRenewal() + assert.NotEqual(t, time.Duration(0), toRenewal) + assert.NotEqual(t, expiresIn, toRenewal) + assert.True(t, expiresIn > toRenewal) + + <-time.After(toRenewal + 500*time.Millisecond) + assert.Nil(t, cancel()) + + select { + case <-maxAttemptsReached: + assert.Fail(t, "Max retries reached, token manager not closed") + case <-tm.closedChan: + } + + <-time.After(50 * time.Millisecond) + + // maxAttempts + the initial one + idp.AssertNumberOfCalls(t, "RequestToken", 2) + listener.AssertNumberOfCalls(t, "OnTokenError", 0) + mock.AssertExpectationsForObjects(t, idp, listener) + }) +} + +func testAuthResult(expiersOn time.Time) *public.AuthResult { + r := &public.AuthResult{ + ExpiresOn: expiersOn, + } + r.IDToken.Oid = "test" + return r +} + +func BenchmarkTokenManager_GetToken(b *testing.B) { + idp := &mockIdentityProvider{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + if err != nil { + b.Fatal(err) + } + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + if err != nil { + b.Fatal(err) + } + + idp.On("RequestToken").Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = tokenManager.GetToken(false) + } +} + +func BenchmarkTokenManager_Start(b *testing.B) { + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + if err != nil { + b.Fatal(err) + } + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + if err != nil { + b.Fatal(err) + } + + idp.On("RequestToken").Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnTokenNext", testTokenValid).Return() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = tokenManager.Start(listener) + } +} + +func BenchmarkTokenManager_Close(b *testing.B) { + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + if err != nil { + b.Fatal(err) + } + + rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") + if err != nil { + b.Fatal(err) + } + + idp.On("RequestToken").Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnTokenNext", testTokenValid).Return() + + _, err = tokenManager.Start(listener) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = tokenManager.Close() + } +} + +func BenchmarkTokenManager_durationToRenewal(b *testing.B) { + idp := &mockIdentityProvider{} + tokenManager, err := NewTokenManager(idp, TokenManagerOptions{ + LowerRefreshBoundMs: 1000 * 60 * 60, // 1 hour + }) + if err != nil { + b.Fatal(err) + } + + tm, ok := tokenManager.(*entraidTokenManager) + if !ok { + b.Fatal("failed to cast to entraidTokenManager") + } + + expiresAfterlb := testAuthResult(time.Now().Add(tm.lowerBoundDuration + time.Hour).UTC()) + idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult, expiresAfterlb) + if err != nil { + b.Fatal(err) + } + + idp.On("RequestToken").Return(idpResponse, nil) + _, err = tm.GetToken(false) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tm.durationToRenewal() + } +} From d5d0aa82c8eb1ec170cd26f92eb38717d3500ead Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 31 Mar 2025 12:00:00 +0000 Subject: [PATCH 08/44] Token Management - Implementation --- token/token.go | 83 ++++++++++++++++++++ token/token_test.go | 180 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 263 insertions(+) create mode 100644 token/token.go create mode 100644 token/token_test.go diff --git a/token/token.go b/token/token.go new file mode 100644 index 0000000..0c67e4a --- /dev/null +++ b/token/token.go @@ -0,0 +1,83 @@ +package token + +import ( + "time" + + "github.com/redis/go-redis/v9/auth" +) + +// Ensure Token implements the auth.Credentials interface. +var _ auth.Credentials = (*Token)(nil) + +// New creates a new token with the specified username, password, raw token, expiration time, received at time, and time to live. +// NOTE: This won't do any validation on the token, expiresOn, receivedAt, or ttl. It will simply create a new token instance. +func New(username, password, rawToken string, expiresOn, receivedAt time.Time, ttl int64) *Token { + return &Token{ + username: username, + password: password, + expiresOn: expiresOn, + receivedAt: receivedAt, + ttl: ttl, + rawToken: rawToken, + } +} + +// Token represents parsed authentication token used to access the Redis server. +// It implements the auth.Credentials interface. +type Token struct { + // username is the username of the user. + username string + // password is the password of the user. + password string + // expiresOn is the expiration time of the token. + expiresOn time.Time + // ttl is the time to live of the token. + ttl int64 + // rawToken is the authentication token. + rawToken string + // receivedAt is the time when the token was received. + receivedAt time.Time +} + +// BasicAuth returns the username and password for basic authentication. +func (t *Token) BasicAuth() (string, string) { + return t.username, t.password +} + +// RawCredentials returns the raw credentials for authentication. +func (t *Token) RawCredentials() string { + return t.rawToken +} + +// ExpirationOn returns the expiration time of the token. +func (t *Token) ExpirationOn() time.Time { + return t.expiresOn +} + +// Copy creates a copy of the token. +func (t *Token) Copy() *Token { + return copyToken(t) +} + +// compareCredentials two tokens if they are the same credentials +func (t *Token) compareCredentials(token *Token) bool { + return t.username == token.username && t.password == token.password +} + +// compareRawCredentials two tokens if they are the same raw credentials +func (t *Token) compareRawCredentials(token *Token) bool { + return t.rawToken == token.rawToken +} + +// compareToken compares two tokens if they are the same token +func (t *Token) compareToken(token *Token) bool { + return t.compareCredentials(token) && t.compareRawCredentials(token) +} + +// copyToken creates a copy of the token. +func copyToken(token *Token) *Token { + if token == nil { + return nil + } + return New(token.username, token.password, token.rawToken, token.expiresOn, token.receivedAt, token.ttl) +} diff --git a/token/token_test.go b/token/token_test.go new file mode 100644 index 0000000..4324612 --- /dev/null +++ b/token/token_test.go @@ -0,0 +1,180 @@ +package token + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNew(t *testing.T) { + t.Parallel() + token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + assert.Equal(t, "username", token.username) + assert.Equal(t, "password", token.password) + assert.Equal(t, "rawToken", token.rawToken) + assert.Equal(t, int64(3600), token.ttl) +} + +func TestBasicAuth(t *testing.T) { + t.Parallel() + token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + username, password := token.BasicAuth() + assert.Equal(t, "username", username) + assert.Equal(t, "password", password) +} + +func TestRawCredentials(t *testing.T) { + t.Parallel() + token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + rawCredentials := token.RawCredentials() + assert.Equal(t, "rawToken", rawCredentials) +} + +func TestExpirationOn(t *testing.T) { + t.Parallel() + token := New("username", "password", "rawToken", time.Now().Add(1*time.Hour), time.Now(), 3600) + expirationOn := token.ExpirationOn() + assert.True(t, expirationOn.After(time.Now())) +} + +func TestTokenExpiration(t *testing.T) { + t.Parallel() + token := New("username", "password", "rawToken", time.Now().Add(1*time.Hour), time.Now(), 3600) + assert.True(t, token.ExpirationOn().After(time.Now())) + + token.expiresOn = time.Now().Add(-1 * time.Hour) + assert.False(t, token.ExpirationOn().After(time.Now())) +} + +func TestTokenReceivedAt(t *testing.T) { + t.Parallel() + token := New("username", "password", "rawToken", time.Now(), time.Now().Add(1*time.Hour), 3600) + assert.True(t, token.receivedAt.After(time.Now().Add(-1*time.Hour))) + assert.True(t, token.receivedAt.Before(time.Now().Add(1*time.Hour))) +} + +func TestTokenTTL(t *testing.T) { + t.Parallel() + token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + assert.Equal(t, int64(3600), token.ttl) + + token.ttl = 7200 + assert.Equal(t, int64(7200), token.ttl) +} + +func TestCopyToken(t *testing.T) { + t.Parallel() + token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + copiedToken := copyToken(token) + + assert.Equal(t, token.username, copiedToken.username) + assert.Equal(t, token.password, copiedToken.password) + assert.Equal(t, token.rawToken, copiedToken.rawToken) + assert.Equal(t, token.ttl, copiedToken.ttl) + assert.Equal(t, token.expiresOn, copiedToken.expiresOn) + assert.Equal(t, token.receivedAt, copiedToken.receivedAt) + + // change the copied token + copiedToken.expiresOn = time.Now().Add(-1 * time.Hour) + assert.NotEqual(t, token.expiresOn, copiedToken.expiresOn) + + // copy nil + copiedToken = copyToken(nil) + assert.Nil(t, copiedToken) + // copy empty token + copiedToken = copyToken(&Token{}) + assert.NotNil(t, copiedToken) + anotherCopy := copiedToken.Copy() + anotherCopy.rawToken = "changed" + assert.NotEqual(t, copiedToken, anotherCopy) +} + +func TestTokenCompare(t *testing.T) { + t.Parallel() + // Create two tokens with the same credentials + token1 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + token2 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + assert.True(t, token1.compareCredentials(token2)) + assert.True(t, token1.compareRawCredentials(token2)) + assert.True(t, token1.compareToken(token2)) + + // Create two tokens with different credentials and different raw credentials + token3 := New("username", "differentPassword", "differentRawToken", time.Now(), time.Now(), 3600) + assert.False(t, token1.compareCredentials(token3)) + assert.False(t, token1.compareRawCredentials(token3)) + assert.False(t, token1.compareToken(token3)) + + // Create token with same credentials but different rawCredentials + token4 := New("username", "password", "differentRawToken", time.Now(), time.Now(), 3600) + assert.False(t, token1.compareRawCredentials(token4)) + assert.False(t, token1.compareToken(token4)) + assert.True(t, token1.compareCredentials(token4)) +} + +func BenchmarkNew(b *testing.B) { + now := time.Now() + b.ResetTimer() + for i := 0; i < b.N; i++ { + New("username", "password", "rawToken", now, now, 3600) + } +} + +func BenchmarkBasicAuth(b *testing.B) { + token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + b.ResetTimer() + for i := 0; i < b.N; i++ { + token.BasicAuth() + } +} + +func BenchmarkRawCredentials(b *testing.B) { + token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + b.ResetTimer() + for i := 0; i < b.N; i++ { + token.RawCredentials() + } +} + +func BenchmarkExpirationOn(b *testing.B) { + token := New("username", "password", "rawToken", time.Now().Add(1*time.Hour), time.Now(), 3600) + b.ResetTimer() + for i := 0; i < b.N; i++ { + token.ExpirationOn() + } +} + +func BenchmarkCopyToken(b *testing.B) { + token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + b.ResetTimer() + for i := 0; i < b.N; i++ { + token.Copy() + } +} + +func BenchmarkCompareCredentials(b *testing.B) { + token1 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + token2 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + b.ResetTimer() + for i := 0; i < b.N; i++ { + token1.compareCredentials(token2) + } +} + +func BenchmarkCompareRawCredentials(b *testing.B) { + token1 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + token2 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + b.ResetTimer() + for i := 0; i < b.N; i++ { + token1.compareRawCredentials(token2) + } +} + +func BenchmarkCompareToken(b *testing.B) { + token1 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + token2 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + b.ResetTimer() + for i := 0; i < b.N; i++ { + token1.compareToken(token2) + } +} From c846eb0f9100e4e6da226c63063c5cf5b08b6d23 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 1 Apr 2025 12:00:00 +0000 Subject: [PATCH 09/44] Token Listener Implementation --- token_listener.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 token_listener.go diff --git a/token_listener.go b/token_listener.go new file mode 100644 index 0000000..e76bba7 --- /dev/null +++ b/token_listener.go @@ -0,0 +1,24 @@ +package entraid + +import ( + "github.com/redis-developer/go-redis-entraid/manager" + "github.com/redis-developer/go-redis-entraid/token" +) + +type entraidTokenListener struct { + cp *entraidCredentialsProvider +} + +func tokenListenerFromCP(cp *entraidCredentialsProvider) manager.TokenListener { + return &entraidTokenListener{ + cp, + } +} + +func (l *entraidTokenListener) OnTokenNext(t *token.Token) { + l.cp.onTokenNext(t) +} + +func (l *entraidTokenListener) OnTokenError(err error) { + l.cp.onTokenError(err) +} From d7b60ddf46f3f7eb3ee5f78b831b3456ffc67912 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 2 Apr 2025 12:00:00 +0000 Subject: [PATCH 10/44] Credentials Provider - Core Interface --- credentials_provider.go | 129 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 credentials_provider.go diff --git a/credentials_provider.go b/credentials_provider.go new file mode 100644 index 0000000..5f0b56b --- /dev/null +++ b/credentials_provider.go @@ -0,0 +1,129 @@ +package entraid + +import ( + "fmt" + "sync" + + "github.com/redis-developer/go-redis-entraid/manager" + "github.com/redis-developer/go-redis-entraid/token" + "github.com/redis/go-redis/v9/auth" +) + +// Ensure entraidCredentialsProvider implements the auth.StreamingCredentialsProvider interface. +var _ auth.StreamingCredentialsProvider = (*entraidCredentialsProvider)(nil) + +// entraidCredentialsProvider is a struct that implements the StreamingCredentialsProvider interface. +type entraidCredentialsProvider struct { + options CredentialsProviderOptions + + tokenManager manager.TokenManager + cancelTokenManager manager.CancelFunc + + // listeners is a slice of listeners that are notified when the token manager receives a new token. + listeners []auth.CredentialsListener + + // rwLock is a mutex that is used to synchronize access to the listeners slice. + rwLock sync.RWMutex +} + +// onTokenNext is a method that is called when the token manager receives a new token. +func (e *entraidCredentialsProvider) onTokenNext(t *token.Token) { + e.rwLock.RLock() + defer e.rwLock.RUnlock() + // Notify all listeners with the new token. + for _, listener := range e.listeners { + listener.OnNext(t) + } +} + +// onTokenError is a method that is called when the token manager encounters an error. +// It notifies all listeners with the error. +func (e *entraidCredentialsProvider) onTokenError(err error) { + e.rwLock.RLock() + defer e.rwLock.RUnlock() + + // Notify all listeners with the error + for _, listener := range e.listeners { + listener.OnError(err) + } +} + +// Subscribe subscribes to the credentials provider and returns a channel that will receive updates. +// The first response is blocking, then data will notify the listener. +// The listener will be notified with the credentials when they are available. +// The listener will be notified with an error if there is an error obtaining the credentials. +// The caller can cancel the subscription by calling the cancel function which is the second return value. +func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.CancelProviderFunc, error) { + e.rwLock.Lock() + // Check if the listener is already in the list of listeners. + alreadySubscribed := false + for _, l := range e.listeners { + if l == listener { + alreadySubscribed = true + break + } + } + + if !alreadySubscribed { + // add new listener + e.listeners = append(e.listeners, listener) + } + e.rwLock.Unlock() + + token, err := e.tokenManager.GetToken(false) + if err != nil { + go listener.OnError(err) + return nil, nil, err + } + + // Notify the listener with the credentials. + go listener.OnNext(token) + + cancel := func() error { + // Remove the listener from the list of listeners. + e.rwLock.Lock() + defer e.rwLock.Unlock() + for i, l := range e.listeners { + if l == listener { + e.listeners = append(e.listeners[:i], e.listeners[i+1:]...) + break + } + } + if len(e.listeners) == 0 { + if e.cancelTokenManager != nil { + defer func() { + e.cancelTokenManager = nil + e.listeners = nil + }() + return e.cancelTokenManager() + } + } + return nil + } + + return token, cancel, nil +} + +// NewCredentialsProvider creates a new credentials provider. +// It takes a TokenManager and CredentialProviderOptions as arguments and returns a StreamingCredentialsProvider interface. +// The TokenManager is used to obtain the token, and the CredentialProviderOptions contains options for the credentials provider. +// The credentials provider is responsible for managing the credentials and refreshing them when necessary. +// It returns an error if the token manager cannot be started. +// +// This function is typically used when you need to create a custom credentials provider with a specific token manager. +// For most use cases, it's recommended to use the type-specific constructors: +// - NewManagedIdentityCredentialsProvider for managed identity authentication +// - NewConfidentialCredentialsProvider for client secret or certificate authentication +// - NewDefaultAzureCredentialsProvider for default Azure identity authentication +func NewCredentialsProvider(tokenManager manager.TokenManager, options CredentialsProviderOptions) (auth.StreamingCredentialsProvider, error) { + cp := &entraidCredentialsProvider{ + tokenManager: tokenManager, + options: options, + } + cancelTokenManager, err := cp.tokenManager.Start(tokenListenerFromCP(cp)) + if err != nil { + return nil, fmt.Errorf("couldn't start token manager: %w", err) + } + cp.cancelTokenManager = cancelTokenManager + return cp, nil +} From 02f9fe113da73aa315138855061044c7fef6f934 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 3 Apr 2025 12:00:00 +0000 Subject: [PATCH 11/44] Configuration - Core Options --- .testcoverage.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.testcoverage.yml b/.testcoverage.yml index fa9a770..2b3b4fa 100644 --- a/.testcoverage.yml +++ b/.testcoverage.yml @@ -15,11 +15,11 @@ threshold: # (optional; default 0) # Minimum coverage percentage required for each package. - package: 80 + package: 70 # (optional; default 0) # Minimum overall project coverage percentage required. - total: 95 + total: 80 # Holds regexp rules which will override thresholds for matched files or packages # using their paths. From e6ee88d4e6198455983f020e7d8cc99cd9a3767a Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Fri, 4 Apr 2025 12:00:00 +0000 Subject: [PATCH 12/44] Version Management - Core --- version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.go b/version.go index 4770c2f..fbcb668 100644 --- a/version.go +++ b/version.go @@ -1,4 +1,4 @@ -package redis +package entraid const version = "0.0.1" From 8dd60ed03cbb295b766c4b4654014cbb489cebe6 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Sat, 5 Apr 2025 12:00:00 +0000 Subject: [PATCH 13/44] CI/CD - GitHub Actions --- .github/workflows/bench.yml | 49 +++++++++++++++++++++++++++++++++++++ .github/workflows/build.yml | 2 +- 2 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/bench.yml diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml new file mode 100644 index 0000000..692dc78 --- /dev/null +++ b/.github/workflows/bench.yml @@ -0,0 +1,49 @@ +name: Benchmark Performance +on: + pull_request: + branches: + - master + - main + push: + branches: + - master + - main +permissions: + # deployments permission to deploy GitHub pages website + deployments: write + # contents permission to update benchmark contents in gh-pages branch + contents: write + +jobs: + benchmark: + name: Performance regression check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version: "stable" + - name: Install dependencies + run: go mod tidy + - name: Run benchmark + run: go test ./... -bench=. -benchmem -count 2 -timeout 1m | tee benchmarks.txt + # Download previous benchmark result from cache (if exists) + - name: Download previous benchmark data + uses: actions/cache@v4 + with: + path: ./cache + key: ${{ runner.os }}-benchmark + # Run `github-action-benchmark` action + - name: Store benchmark result + uses: benchmark-action/github-action-benchmark@v1 + with: + name: Go Benchmark + tool: 'go' + output-file-path: benchmarks.txt + github-token: ${{ secrets.GITHUB_TOKEN }} + auto-push: true + # Show alert with commit comment on detecting possible performance regression + alert-threshold: '200%' + comment-on-alert: true + fail-on-alert: true + alert-comment-cc-users: '@ndyakov' diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 499a5bf..149981e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,7 +19,7 @@ jobs: - name: Install dependencies run: go mod tidy - name: Run tests with coverage - run: go test ./... -coverprofile=./cover.out -covermode=atomic -coverpkg=./... + run: go test ./... -coverprofile=./cover.out -covermode=atomic -race -count 2 -timeout 1m - name: Upload coverage uses: actions/upload-artifact@v4 with: From 49dd82c4b0ab00483ad1ca8f35b31ba266854ee6 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Sun, 6 Apr 2025 12:00:00 +0000 Subject: [PATCH 14/44] CI/CD - Git Hooks --- .githooks/pre-commit | 4 ++++ install-git-hook.sh | 2 ++ 2 files changed, 6 insertions(+) create mode 100755 .githooks/pre-commit create mode 100755 install-git-hook.sh diff --git a/.githooks/pre-commit b/.githooks/pre-commit new file mode 100755 index 0000000..495b5de --- /dev/null +++ b/.githooks/pre-commit @@ -0,0 +1,4 @@ +#!/usr/bin/env bash + +goimports -l -w . # includes go fmt +golangci-lint run # includes golint, go vet diff --git a/install-git-hook.sh b/install-git-hook.sh new file mode 100755 index 0000000..5598436 --- /dev/null +++ b/install-git-hook.sh @@ -0,0 +1,2 @@ +chmod ug+x ./.githooks/* +git config core.hooksPath ./.githooks From 72bd3d79f16008d15e60d9ab35c3e977fa34c254 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 7 Apr 2025 12:00:00 +0000 Subject: [PATCH 15/44] Documentation - README --- README.md | 691 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 691 insertions(+) diff --git a/README.md b/README.md index e5b9399..e9ad564 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,693 @@ # go-redis-entraid Entra ID extension for go-redis + +## Table of Contents +- [Introduction](#introduction) +- [Quick Start](#quick-start) +- [Architecture Overview](#architecture-overview) +- [Authentication Providers](#authentication-providers) +- [Configuration Guide](#configuration-guide) +- [Examples](#examples) +- [Testing](#testing) +- [FAQ](#faq) + +## Introduction + +go-redis-entraid is a Go library that provides Entra ID (formerly Azure AD) authentication support for Redis Enterprise Cloud. It enables secure authentication using various Entra ID identity types and manages token lifecycle automatically. + +### Version Compatibility +- Go: 1.16+ +- Redis: 6.0+ +- Azure Entra ID: Latest + +### Key Features +- Support for multiple Entra ID identity types +- Automatic token refresh and management +- Configurable token refresh policies +- Retry mechanisms with exponential backoff +- Thread-safe token management +- Streaming credentials provider interface + +## Quick Start + +### Minimal Example +Here's the simplest way to get started: + +```go +package main + +import ( + "context" + "fmt" + "log" + "os" + "strings" + + "github.com/redis-developer/go-redis-entraid/entraid" + "github.com/redis/go-redis/v9" +) + +func main() { + // Get required environment variables + clientID := os.Getenv("AZURE_CLIENT_ID") + redisEndpoint := os.Getenv("REDIS_ENDPOINT") + if clientID == "" || redisEndpoint == "" { + log.Fatal("AZURE_CLIENT_ID and REDIS_ENDPOINT environment variables are required") + } + + // Create credentials provider + provider, err := entraid.NewManagedIdentityCredentialsProvider(entraid.ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: entraid.CredentialsProviderOptions{ + ClientID: clientID, + }, + }) + if err != nil { + log.Fatalf("Failed to create credentials provider: %v", err) + } + + // Create Redis client + client := redis.NewClient(&redis.Options{ + Addr: redisEndpoint, + StreamingCredentialsProvider: provider, + }) + defer client.Close() + + // Test connection + ctx := context.Background() + if err := client.Ping(ctx).Err(); err != nil { + log.Fatalf("Failed to connect to Redis: %v", err) + } + log.Println("Connected to Redis!") +} +``` + +### Environment Setup +```bash +# Required environment variables +export AZURE_CLIENT_ID="your-client-id" +export REDIS_ENDPOINT="your-redis-endpoint:6380" + +# Optional environment variables +export AZURE_TENANT_ID="your-tenant-id" +export AZURE_CLIENT_SECRET="your-client-secret" +export AZURE_AUTHORITY_HOST="https://login.microsoftonline.com" # For custom authority +``` + +### Running the Example +```bash +go mod init your-app +go get github.com/redis-developer/go-redis-entraid +go run main.go +``` + +## Architecture Overview + +### Component Diagram +```mermaid +graph TD + A[Redis Client] --> B[StreamingCredentialsProvider] + B --> C[Token Manager] + C --> D[Identity Provider] + D --> E[Azure Entra ID] + + subgraph "Token Management" + C --> F[Token Cache] + C --> G[Token Refresh] + C --> H[Error Handling] + end + + subgraph "Identity Providers" + D --> I[Managed Identity] + D --> J[Confidential Client] + D --> K[Default Azure Identity] + D --> L[Custom Provider] + end +``` + +### Token Lifecycle +```mermaid +sequenceDiagram + participant Client + participant TokenManager + participant IdentityProvider + participant Azure + + Client->>TokenManager: GetToken() + alt Token Valid + TokenManager->>Client: Return cached token + else Token Expired + TokenManager->>IdentityProvider: RequestToken() + IdentityProvider->>Azure: Authenticate + Azure->>IdentityProvider: Return token + IdentityProvider->>TokenManager: Cache token + TokenManager->>Client: Return new token + end +``` + +### Component Responsibilities + +1. **Redis Client** + - Handles Redis connections + - Manages connection pooling + - Implements Redis protocol + +2. **StreamingCredentialsProvider** + - Provides authentication credentials + - Handles token refresh + - Manages authentication state + +3. **Token Manager** + - Caches tokens + - Handles token refresh + - Implements retry logic + - Manages token lifecycle + +4. **Identity Provider** + - Authenticates with Azure + - Handles different auth types + - Manages credentials + +## Authentication Providers + +### Provider Selection Guide + +```mermaid +graph TD + A[Choose Authentication] --> B{Managed Identity?} + B -->|Yes| C{System Assigned?} + B -->|No| D{Client Credentials?} + C -->|Yes| E[SystemAssignedIdentity] + C -->|No| F[UserAssignedIdentity] + D -->|Yes| G{Client Secret?} + D -->|No| H[DefaultAzureIdentity] + G -->|Yes| I[ClientSecret] + G -->|No| J[ClientCertificate] +``` + +### Provider Comparison + +| Provider Type | Best For | Security | Configuration | Performance | +|--------------|----------|----------|---------------|-------------| +| System Assigned | Azure-hosted apps | Highest | Minimal | Best | +| User Assigned | Shared identity | High | Moderate | Good | +| Client Secret | Service auth | High | Moderate | Good | +| Client Cert | High security | Highest | Complex | Good | +| Default Azure | Development | Moderate | Minimal | Good | + +## Configuration Guide + +### Environment Variables +```bash +# Required +AZURE_CLIENT_ID=your-client-id +REDIS_ENDPOINT=your-redis-endpoint:6380 + +# Optional +AZURE_TENANT_ID=your-tenant-id +AZURE_CLIENT_SECRET=your-client-secret +``` + +### Available Configuration Options + +#### 1. CredentialsProviderOptions +Base options for all credential providers: +```go +type CredentialsProviderOptions struct { + // Required: Client ID for authentication + ClientID string + + // Optional: Token manager configuration + TokenManagerOptions manager.TokenManagerOptions +} +``` + +#### 2. TokenManagerOptions +Options for token management: +```go +type TokenManagerOptions struct { + // Optional: Ratio of token lifetime to trigger refresh (0-1) + // Default: 0.7 (refresh at 70% of token lifetime) + ExpirationRefreshRatio float64 + + // Optional: Minimum time before expiration to refresh (ms) + // Default: 10000 (10 seconds) + LowerRefreshBoundMs int64 + + // Optional: Configuration for retry behavior + RetryOptions RetryOptions + + // Optional: Custom response parser + IdentityProviderResponseParser IdentityProviderResponseParser +} +``` + +#### 3. RetryOptions +Options for retry behavior: +```go +type RetryOptions struct { + // Optional: Maximum number of retry attempts + // Default: 3 + MaxAttempts int + + // Optional: Initial delay between retries (ms) + // Default: 1000 (1 second) + InitialDelayMs int64 + + // Optional: Maximum delay between retries (ms) + // Default: 30000 (30 seconds) + MaxDelayMs int64 + + // Optional: Multiplier for exponential backoff + // Default: 2.0 + BackoffMultiplier float64 + + // Optional: Custom retry predicate + IsRetryable func(error) bool +} +``` + +#### 4. ManagedIdentityProviderOptions +Options for managed identity authentication: +```go +type ManagedIdentityProviderOptions struct { + // Required: Type of managed identity + ManagedIdentityType ManagedIdentityType // SystemAssignedIdentity or UserAssignedIdentity + + // Optional: Client ID for user-assigned identity + UserAssignedClientID string + + // Optional: Scopes for token access + // Default: ["https://redis.azure.com/.default"] + Scopes []string +} +``` + +#### 5. ConfidentialIdentityProviderOptions +Options for confidential client authentication: +```go +type ConfidentialIdentityProviderOptions struct { + // Required: Client ID for authentication + ClientID string + + // Required: Type of credentials + CredentialsType string // identity.ClientSecretCredentialType or identity.ClientCertificateCredentialType + + // Required for ClientSecret: Client secret value + ClientSecret string + + // Required for ClientCertificate: Client certificate + // Type: []*x509.Certificate + ClientCert []*x509.Certificate + + // Required for ClientCertificate: Client private key + // Type: crypto.PrivateKey + ClientPrivateKey crypto.PrivateKey + + // Required: Authority configuration + Authority AuthorityConfiguration + + // Optional: Scopes for token access + // Default: ["https://redis.azure.com/.default"] + Scopes []string +} +``` + +#### 6. AuthorityConfiguration +Options for authority configuration: +```go +type AuthorityConfiguration struct { + // Required: Type of authority + AuthorityType AuthorityType // "default", "multi-tenant", or "custom" + + // Required: Azure AD tenant ID + // Use "common" for multi-tenant applications + TenantID string + + // Optional: Custom authority URL + // Required for custom authority type + Authority string +} +``` + +#### 7. DefaultAzureIdentityProviderOptions +Options for default Azure identity: +```go +type DefaultAzureIdentityProviderOptions struct { + // Optional: Azure identity provider options + AzureOptions *azidentity.DefaultAzureCredentialOptions + + // Optional: Scopes for token access + // Default: ["https://redis.azure.com/.default"] + Scopes []string +} +``` + +### Configuration Examples + +#### Basic Configuration +```go +options := entraid.CredentialsProviderOptions{ + ClientID: os.Getenv("AZURE_CLIENT_ID"), + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + LowerRefreshBoundMs: 10000, + }, +} +``` + +#### Advanced Configuration +```go +options := entraid.CredentialsProviderOptions{ + ClientID: os.Getenv("AZURE_CLIENT_ID"), + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + LowerRefreshBoundMs: 10000, + RetryOptions: manager.RetryOptions{ + MaxAttempts: 3, + InitialDelayMs: 1000, + MaxDelayMs: 30000, + BackoffMultiplier: 2.0, + IsRetryable: func(err error) bool { + return strings.Contains(err.Error(), "network error") || + strings.Contains(err.Error(), "timeout") + }, + }, + }, +} +``` + +#### Authority Configuration +```go +// Multi-tenant application +authority := identity.AuthorityConfiguration{ + AuthorityType: identity.AuthorityTypeMultiTenant, + TenantID: "common", +} + +// Single-tenant application +authority := identity.AuthorityConfiguration{ + AuthorityType: identity.AuthorityTypeDefault, + TenantID: os.Getenv("AZURE_TENANT_ID"), +} + +// Custom authority +authority := identity.AuthorityConfiguration{ + AuthorityType: identity.AuthorityTypeCustom, + TenantID: os.Getenv("AZURE_TENANT_ID"), + Authority: fmt.Sprintf("%s/%s/v2.0", + os.Getenv("AZURE_AUTHORITY_HOST"), + os.Getenv("AZURE_TENANT_ID")), +} +``` + +## Examples + +### System Assigned Identity +```go +// Create provider for system assigned identity +provider, err := entraid.NewManagedIdentityCredentialsProvider(entraid.ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: entraid.CredentialsProviderOptions{ + ClientID: os.Getenv("AZURE_CLIENT_ID"), + }, + ManagedIdentityType: identity.SystemAssignedIdentity, +}) +``` + +### User Assigned Identity +```go +// Create provider for user assigned identity +provider, err := entraid.NewManagedIdentityCredentialsProvider(entraid.ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: entraid.CredentialsProviderOptions{ + ClientID: os.Getenv("AZURE_CLIENT_ID"), + }, + ManagedIdentityType: identity.UserAssignedIdentity, + UserAssignedClientID: os.Getenv("USER_ASSIGNED_CLIENT_ID"), +}) +``` + +### Client Secret Authentication +```go +// Create provider for client secret authentication +provider, err := entraid.NewConfidentialCredentialsProvider(entraid.ConfidentialIdentityProviderOptions{ + CredentialsProviderOptions: entraid.CredentialsProviderOptions{ + ClientID: os.Getenv("AZURE_CLIENT_ID"), + }, + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: os.Getenv("AZURE_CLIENT_SECRET"), + Authority: identity.AuthorityConfiguration{ + AuthorityType: identity.AuthorityTypeDefault, + TenantID: os.Getenv("AZURE_TENANT_ID"), + }, +}) +``` + +### Client Certificate Authentication +```go +// Create provider for client certificate authentication +cert, err := tls.LoadX509KeyPair("cert.pem", "key.pem") +if err != nil { + log.Fatal(err) +} + +provider, err := entraid.NewConfidentialCredentialsProvider(entraid.ConfidentialIdentityProviderOptions{ + CredentialsProviderOptions: entraid.CredentialsProviderOptions{ + ClientID: os.Getenv("AZURE_CLIENT_ID"), + }, + CredentialsType: identity.ClientCertificateCredentialType, + ClientCert: []*x509.Certificate{cert.Leaf}, + ClientPrivateKey: cert.PrivateKey, + Authority: identity.AuthorityConfiguration{ + AuthorityType: identity.AuthorityTypeDefault, + TenantID: os.Getenv("AZURE_TENANT_ID"), + }, +}) +``` + +### Advanced Usage with Custom Identity Provider + +This example shows how to implement your own IdentityProvider while leveraging our TokenManager and StreamingCredentialsProvider. This is useful when you need to authenticate with a custom token source but want to benefit from our token management and streaming capabilities. + +```go +package main + +import ( + "context" + "fmt" + "log" + "os" + "time" + + "github.com/redis-developer/go-redis-entraid/entraid" + "github.com/redis-developer/go-redis-entraid/entraid/identity" + "github.com/redis-developer/go-redis-entraid/entraid/manager" + "github.com/redis-developer/go-redis-entraid/entraid/shared" + "github.com/redis/go-redis/v9" +) + +// CustomIdentityProvider implements the IdentityProvider interface +type CustomIdentityProvider struct { + // Add any fields needed for your custom authentication + tokenEndpoint string + clientID string + clientSecret string +} + +// RequestToken implements the IdentityProvider interface +func (p *CustomIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) { + // Implement your custom token retrieval logic here + // This could be calling your own auth service, using a different auth protocol, etc. + + // For this example, we'll simulate getting a JWT token + token := "your.jwt.token" + + // Create a response using NewIDPResponse with RawToken type + return shared.NewIDPResponse(shared.ResponseTypeRawToken, token) +} + +func main() { + // Create your custom identity provider + customProvider := &CustomIdentityProvider{ + tokenEndpoint: "https://your-auth-endpoint.com/token", + clientID: os.Getenv("CUSTOM_CLIENT_ID"), + clientSecret: os.Getenv("CUSTOM_CLIENT_SECRET"), + } + + // Create token manager with your custom provider + tokenManager, err := manager.NewTokenManager(customProvider, manager.TokenManagerOptions{ + // Configure token refresh behavior + ExpirationRefreshRatio: 0.7, + LowerRefreshBoundMs: 10000, + }) + if err != nil { + log.Fatalf("Failed to create token manager: %v", err) + } + + // Create credentials provider using our StreamingCredentialsProvider + provider, err := entraid.NewCredentialsProvider(tokenManager, entraid.CredentialsProviderOptions{ + // Add any additional options needed + OnReAuthenticationError: func(err error) error { + log.Printf("Re-authentication error: %v", err) + return err + }, + }) + if err != nil { + log.Fatalf("Failed to create credentials provider: %v", err) + } + + // Create Redis client with your custom provider + client := redis.NewClient(&redis.Options{ + Addr: os.Getenv("REDIS_ENDPOINT"), + StreamingCredentialsProvider: provider, + }) + defer client.Close() + + // Test the connection + ctx := context.Background() + if err := client.Ping(ctx).Err(); err != nil { + log.Fatalf("Failed to connect to Redis: %v", err) + } + log.Println("Connected to Redis with custom identity provider!") +} +``` + +Key points about this implementation: + +1. **Custom Identity Provider**: + - Implements the `IdentityProvider` interface with `RequestToken` method + - Returns a response using `shared.NewIDPResponse` with `ResponseTypeRawToken` + - Handles your custom authentication logic + +2. **Token Management**: + - Uses our `TokenManager` for automatic token refresh + - Benefits from our retry mechanisms + - Handles token caching and lifecycle + +3. **Streaming Credentials**: + - Uses our `StreamingCredentialsProvider` for Redis integration + - Handles connection authentication + - Manages token streaming to Redis + +4. **Error Handling**: + - Implements proper error handling + - Uses our error callback mechanisms + - Provides logging and monitoring hooks + +This approach gives you the flexibility of custom authentication while benefiting from our robust token management and Redis integration features. + +## Testing + +### Unit Testing +```go +func TestManagedIdentityProvider(t *testing.T) { + // Create test provider + provider, err := entraid.NewManagedIdentityCredentialsProvider(entraid.ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: entraid.CredentialsProviderOptions{ + ClientID: "test-client-id", + }, + }) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + // Test token retrieval + token, err := provider.GetToken(context.Background()) + if err != nil { + t.Fatalf("Failed to get token: %v", err) + } + if token == "" { + t.Error("Expected non-empty token") + } +} +``` + +### Integration Testing +```go +func TestRedisConnection(t *testing.T) { + // Create provider + provider, err := entraid.NewManagedIdentityCredentialsProvider(entraid.ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: entraid.CredentialsProviderOptions{ + ClientID: os.Getenv("AZURE_CLIENT_ID"), + }, + }) + if err != nil { + t.Fatalf("Failed to create provider: %v", err) + } + + // Create Redis client + client := redis.NewClient(&redis.Options{ + Addr: os.Getenv("REDIS_ENDPOINT"), + StreamingCredentialsProvider: provider, + }) + defer client.Close() + + // Test connection + ctx := context.Background() + if err := client.Ping(ctx).Err(); err != nil { + t.Fatalf("Failed to connect to Redis: %v", err) + } +} +``` + +## FAQ + +### Q: How do I handle token expiration? +A: The library handles token expiration automatically. Tokens are refreshed when they reach 70% of their lifetime (configurable via `ExpirationRefreshRatio`). You can customize this behavior using `TokenManagerOptions`. + +### Q: What's the difference between managed identity types? +A: +- System Assigned: Automatically created and managed by Azure +- User Assigned: Created and managed by you, can be shared across resources +- Default Azure: Uses environment-based authentication, good for development + +### Q: How do I handle connection failures? +A: The library includes built-in retry mechanisms in the TokenManager. You can configure retry behavior using `RetryOptions`: +```go +RetryOptions: manager.RetryOptions{ + MaxAttempts: 3, + InitialDelayMs: 1000, + MaxDelayMs: 30000, + BackoffMultiplier: 2.0, +} +``` + +### Q: Does this work with Redis Cluster? +A: Yes, the library works with both standalone Redis and Redis Cluster. Use the appropriate Redis client constructor: +```go +// For standalone Redis +client := redis.NewClient(&redis.Options{ + Addr: "your-endpoint:6380", + StreamingCredentialsProvider: provider, +}) + +// For Redis Cluster +client := redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: []string{"your-endpoint:6380"}, + StreamingCredentialsProvider: provider, +}) +``` + +### Q: How do I implement custom authentication? +A: You can create a custom identity provider by implementing the `IdentityProvider` interface: +```go +// IdentityProviderResponse is an interface that defines the methods for an identity provider authentication result. +// It is used to get the type of the authentication result, the authentication result itself (can be AuthResult or AccessToken), +type IdentityProviderResponse interface { + // Type returns the type of the auth result + Type() string + AuthResult() public.AuthResult + AccessToken() azcore.AccessToken + RawToken() string +} + +// IdentityProvider is an interface that defines the methods for an identity provider. +// It is used to request a token for authentication. +// The identity provider is responsible for providing the raw authentication token. +type IdentityProvider interface { + // RequestToken requests a token from the identity provider. + // It returns the token, the expiration time, and an error if any. + RequestToken() (IdentityProviderResponse, error) +} +``` + +### Q: What happens if token refresh fails? +A: The library will retry according to the configured `RetryOptions`. If all retries fail, the error will be propagated to the client. \ No newline at end of file From 52a8be5873b2f54d9c86356d00908a271d2159df Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 8 Apr 2025 12:00:00 +0000 Subject: [PATCH 16/44] Documentation - Contributing Guide --- CONTRIBUTING.md | 103 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 CONTRIBUTING.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..9947ff1 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,103 @@ +# Contributing to go-redis-entraid + +We welcome contributions from the community! If you'd like to contribute to this project, please follow these guidelines: + +## Getting Started + +1. Fork the repository +2. Create a new branch for your feature or bugfix +3. Make your changes +4. Run the tests and ensure they pass +5. Submit a pull request + +## Development Setup + +```bash +# Clone your fork +git clone https://github.com/your-username/go-redis-entraid.git +cd go-redis-entraid + +# Install dependencies +go mod download + +# Run tests +go test ./... +``` + +## Code Style and Standards + +- Follow the Go standard formatting (`go fmt`) +- Write clear and concise commit messages +- Include tests for new features +- Update documentation as needed +- Follow the existing code style and patterns + +## Testing + +We maintain high test coverage for the project. When contributing: + +- Add tests for new features +- Ensure existing tests pass +- Run the test coverage tool: + ```bash + go test -coverprofile=cover.out ./... + go tool cover -html=cover.out + ``` + +## Pull Request Process + +1. Ensure your code passes all tests +2. Update the README.md if necessary +3. Submit your pull request with a clear description of the changes + +## Reporting Issues + +If you find a bug or have a feature request: + +1. Check the existing issues to avoid duplicates +2. Create a new issue with: + - A clear title and description + - Steps to reproduce (for bugs) + - Expected and actual behavior + - Environment details (Go version, OS, etc.) + +## Development Workflow + +1. Create a new branch for your feature/fix: + ```bash + git checkout -b feature/your-feature-name + ``` + +2. Make your changes and commit them: + ```bash + git add . + git commit -m "Description of your changes" + ``` + +3. Push your changes to your fork: + ```bash + git push origin feature/your-feature-name + ``` + +4. Create a pull request from your fork to the main repository + +## Review Process + +- All pull requests will be reviewed by maintainers +- Be prepared to make changes based on feedback +- Ensure your code meets the project's standards +- Address any CI/CD failures + +## Documentation + +- Update relevant documentation when making changes +- Include examples for new features +- Update the README if necessary +- Add comments to complex code sections + +## Questions? + +If you have any questions about contributing, please: +1. Check the existing documentation +2. Look through existing issues +3. Create a new issue if your question hasn't been answered \ No newline at end of file From 02d53c90a2020c133d218d572358f56a64687ca3 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 9 Apr 2025 12:00:00 +0000 Subject: [PATCH 17/44] Version Management - Tests --- version_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version_test.go b/version_test.go index e95de1c..8351bea 100644 --- a/version_test.go +++ b/version_test.go @@ -1,4 +1,4 @@ -package redis +package entraid import ( "testing" From 3bf1fa90454a819b1db0da40de66d37f859adf6d Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 10 Apr 2025 12:00:00 +0000 Subject: [PATCH 18/44] Identity Provider Tests --- providers_test.go | 703 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 703 insertions(+) create mode 100644 providers_test.go diff --git a/providers_test.go b/providers_test.go new file mode 100644 index 0000000..cb2bcd5 --- /dev/null +++ b/providers_test.go @@ -0,0 +1,703 @@ +package entraid + +import ( + "errors" + "sync" + "testing" + "time" + + "github.com/redis-developer/go-redis-entraid/identity" + "github.com/redis-developer/go-redis-entraid/manager" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis-developer/go-redis-entraid/token" + "github.com/redis/go-redis/v9/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockTokenManager implements the TokenManager interface for testing +type mockTokenManager struct { + token *token.Token + err error +} + +func (m *mockTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { + return m.token, m.err +} + +func (m *mockTokenManager) Start(listener manager.TokenListener) (manager.CancelFunc, error) { + if m.err != nil { + listener.OnTokenError(m.err) + return nil, m.err + } + + listener.OnTokenNext(m.token) + return func() error { return nil }, nil +} + +func (m *mockTokenManager) Close() error { + return nil +} + +// mockCredentialsListener implements the CredentialsListener interface for testing +type mockCredentialsListener struct { + LastTokenCh chan string + LastErrCh chan error +} + +func (m *mockCredentialsListener) OnNext(credentials auth.Credentials) { + if m.LastTokenCh == nil { + m.LastTokenCh = make(chan string) + } + m.LastTokenCh <- credentials.RawCredentials() +} + +func (m *mockCredentialsListener) OnError(err error) { + if m.LastErrCh == nil { + m.LastErrCh = make(chan error) + } + m.LastErrCh <- err +} + +// testTokenManagerFactory is a factory function that returns a mock token manager +func testTokenManagerFactory(token *token.Token, err error) func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return func(provider shared.IdentityProvider, options manager.TokenManagerOptions) (manager.TokenManager, error) { + return &mockTokenManager{ + token: token, + err: err, + }, nil + } +} + +func TestNewManagedIdentityCredentialsProvider(t *testing.T) { + tests := []struct { + name string + options ManagedIdentityCredentialsProviderOptions + expectedError error + }{ + { + name: "valid managed identity options", + options: ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ + UserAssignedClientID: "test-client-id", + ManagedIdentityType: identity.UserAssignedIdentity, + Scopes: []string{identity.RedisScopeDefault}, + }, + }, + expectedError: nil, + }, + { + name: "system assigned identity", + options: ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ + ManagedIdentityType: identity.SystemAssignedIdentity, + Scopes: []string{identity.RedisScopeDefault}, + }, + }, + expectedError: nil, + }, + { + name: "invalid managed identity type", + options: ManagedIdentityCredentialsProviderOptions{ + ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ + ManagedIdentityType: "invalid-type", + }, + }, + expectedError: errors.New("invalid managed identity type"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a test token + testToken := token.New( + "test", + "test", + "mock-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Set the token manager factory in the options + tt.options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + + provider, err := NewManagedIdentityCredentialsProvider(tt.options) + if tt.expectedError != nil { + assert.Error(t, err) + assert.Nil(t, provider) + } else { + assert.NoError(t, err) + assert.NotNil(t, provider) + + // Test the provider with a mock listener + listener := &mockCredentialsListener{LastTokenCh: make(chan string)} + _, _, err := provider.Subscribe(listener) + assert.NoError(t, err) + assert.Equal(t, "mock-token", <-listener.LastTokenCh) + } + }) + } +} + +func TestNewConfidentialCredentialsProvider(t *testing.T) { + tests := []struct { + name string + options ConfidentialCredentialsProviderOptions + expectedError error + }{ + { + name: "valid confidential options with client secret", + options: ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: "test-secret", + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + }, + expectedError: nil, + }, + { + name: "missing required fields", + options: ConfidentialCredentialsProviderOptions{ + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + CredentialsType: identity.ClientSecretCredentialType, + }, + }, + expectedError: errors.New("client ID is required"), + }, + { + name: "invalid credentials type", + options: ConfidentialCredentialsProviderOptions{ + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: "invalid-type", + }, + }, + expectedError: errors.New("invalid credentials type"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a test token + testToken := token.New( + "test", + "test", + "mock-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Set the token manager factory in the options + tt.options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + + provider, err := NewConfidentialCredentialsProvider(tt.options) + if tt.expectedError != nil { + assert.Error(t, err) + assert.Nil(t, provider) + } else { + assert.NoError(t, err) + assert.NotNil(t, provider) + + // Test the provider with a mock listener + listener := &mockCredentialsListener{LastTokenCh: make(chan string)} + _, _, err := provider.Subscribe(listener) + assert.NoError(t, err) + assert.Equal(t, "mock-token", <-listener.LastTokenCh) + } + }) + } +} + +func TestNewDefaultAzureCredentialsProvider(t *testing.T) { + tests := []struct { + name string + options DefaultAzureCredentialsProviderOptions + expectedError error + }{ + { + name: "valid default azure options", + options: DefaultAzureCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + DefaultAzureIdentityProviderOptions: identity.DefaultAzureIdentityProviderOptions{ + Scopes: []string{identity.RedisScopeDefault}, + }, + }, + expectedError: nil, + }, + { + name: "empty options", + options: DefaultAzureCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + }, + expectedError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a test token + testToken := token.New( + "test", + "test", + "mock-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Set the token manager factory in the options + tt.options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + + provider, err := NewDefaultAzureCredentialsProvider(tt.options) + if tt.expectedError != nil { + assert.Error(t, err) + assert.Nil(t, provider) + } else { + assert.NoError(t, err) + assert.NotNil(t, provider) + + // Test the provider with a mock listener + listener := &mockCredentialsListener{LastTokenCh: make(chan string)} + _, _, err := provider.Subscribe(listener) + assert.NoError(t, err) + assert.Equal(t, "mock-token", <-listener.LastTokenCh) + } + }) + } +} + +func TestCredentialsProviderErrorHandling(t *testing.T) { + t.Run("on re-authentication error", func(t *testing.T) { + options := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: "test-secret", + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + + // Create a test token + testToken := token.New( + "test", + "test", + "mock-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Set the token manager factory in the options + options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + + provider, err := NewConfidentialCredentialsProvider(options) + require.NoError(t, err) + require.NotNil(t, provider) + + // Test that the error handler is properly set + // Note: This is a simplified test as actual authentication would require Azure credentials + assert.NotNil(t, provider) + }) + + t.Run("on retryable error", func(t *testing.T) { + options := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: "test-secret", + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + + // Create a test token + testToken := token.New( + "test", + "test", + "mock-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Set the token manager factory in the options + options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + + provider, err := NewConfidentialCredentialsProvider(options) + require.NoError(t, err) + require.NotNil(t, provider) + + // Test that the error handler is properly set + // Note: This is a simplified test as actual authentication would require Azure credentials + assert.NotNil(t, provider) + }) +} + +func TestCredentialsProviderInterface(t *testing.T) { + // Test that all providers implement the StreamingCredentialsProvider interface + tests := []struct { + name string + provider auth.StreamingCredentialsProvider + }{ + { + name: "managed identity provider", + provider: func() auth.StreamingCredentialsProvider { + options := ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ + UserAssignedClientID: "test-client-id", + ManagedIdentityType: identity.UserAssignedIdentity, + Scopes: []string{identity.RedisScopeDefault}, + }, + } + + // Create a test token + testToken := token.New( + "test", + "test", + "mock-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Set the token manager factory in the options + options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + + p, _ := NewManagedIdentityCredentialsProvider(options) + return p + }(), + }, + { + name: "confidential provider", + provider: func() auth.StreamingCredentialsProvider { + options := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: "test-secret", + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + + // Create a test token + testToken := token.New( + "test", + "test", + "mock-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Set the token manager factory in the options + options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + + p, _ := NewConfidentialCredentialsProvider(options) + return p + }(), + }, + { + name: "default azure provider", + provider: func() auth.StreamingCredentialsProvider { + options := DefaultAzureCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + DefaultAzureIdentityProviderOptions: identity.DefaultAzureIdentityProviderOptions{ + Scopes: []string{identity.RedisScopeDefault}, + }, + } + + // Create a test token + testToken := token.New( + "test", + "test", + "mock-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Set the token manager factory in the options + options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + + p, _ := NewDefaultAzureCredentialsProvider(options) + return p + }(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test that the provider implements the interface by calling its methods + // Note: These are simplified tests as actual authentication would require Azure credentials + listener := &mockCredentialsListener{} + credentials, cancel, err := tt.provider.Subscribe(listener) + assert.NotNil(t, credentials) + assert.NotNil(t, cancel) + assert.NoError(t, err) + }) + } +} + +func TestCredentialsProviderSubscribe(t *testing.T) { + // Create a test token + testToken := token.New( + "test", + "test", + "mock-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + + // Create a test provider + options := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: "test-secret", + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + + // Set the token manager factory in the options + options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + + provider, err := NewConfidentialCredentialsProvider(options) + require.NoError(t, err) + require.NotNil(t, provider) + + t.Run("concurrent subscribe and cancel", func(t *testing.T) { + const numListeners = 10 + var wg sync.WaitGroup + listeners := make([]*mockCredentialsListener, numListeners) + cancels := make([]auth.CancelProviderFunc, numListeners) + + // Subscribe multiple listeners concurrently + for i := 0; i < numListeners; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string, 1), + LastErrCh: make(chan error, 1), + } + listeners[idx] = listener + _, cancel, err := provider.Subscribe(listener) + require.NoError(t, err) + cancels[idx] = cancel + }(i) + } + wg.Wait() + + // Verify all listeners received the token + for i, listener := range listeners { + select { + case token := <-listener.LastTokenCh: + assert.Equal(t, "mock-token", token, "listener %d received wrong token", i) + case err := <-listener.LastErrCh: + t.Fatalf("listener %d received error: %v", i, err) + } + } + + // Cancel all subscriptions concurrently + for i := 0; i < numListeners; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + err := cancels[idx]() + require.NoError(t, err) + }(i) + } + wg.Wait() + + // Verify no more tokens are sent after cancellation + for i, listener := range listeners { + select { + case token := <-listener.LastTokenCh: + t.Fatalf("listener %d received unexpected token after cancellation: %s", i, token) + case err := <-listener.LastErrCh: + t.Fatalf("listener %d received unexpected error after cancellation: %v", i, err) + default: + // No message received, which is expected + } + } + }) +} + +func TestCredentialsProviderOptions(t *testing.T) { + t.Run("default token manager factory", func(t *testing.T) { + options := CredentialsProviderOptions{} + factory := options.getTokenManagerFactory() + assert.NotNil(t, factory) + }) + + t.Run("custom token manager factory", func(t *testing.T) { + m := &mockTokenManager{} + customFactory := func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return m, nil + } + options := CredentialsProviderOptions{ + tokenManagerFactory: customFactory, + } + tm, err := options.getTokenManagerFactory()(nil, manager.TokenManagerOptions{}) + assert.NotNil(t, tm) + assert.NoError(t, err) + assert.Equal(t, m, tm) + }) +} + +func TestCredentialsProviderErrorScenarios(t *testing.T) { + t.Run("token manager start error", func(t *testing.T) { + // Create a test provider with invalid options + options := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: "invalid-type", // Invalid credentials type + ClientSecret: "test-secret", + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + + provider, err := NewConfidentialCredentialsProvider(options) + assert.Error(t, err) + assert.Nil(t, provider) + }) + + t.Run("token manager get token error", func(t *testing.T) { + // Create a test provider with invalid options + options := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: "", // Empty client secret + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + + provider, err := NewConfidentialCredentialsProvider(options) + assert.Error(t, err) + assert.Nil(t, provider) + }) + + t.Run("concurrent error handling", func(t *testing.T) { + // Create a test provider with invalid options + options := ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ + ManagedIdentityType: "invalid-type", // Invalid managed identity type + Scopes: []string{identity.RedisScopeDefault}, + }, + } + + provider, err := NewManagedIdentityCredentialsProvider(options) + assert.Error(t, err) + assert.Nil(t, provider) + }) + + t.Run("concurrent token updates", func(t *testing.T) { + // Create a test provider with invalid options + options := DefaultAzureCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + DefaultAzureIdentityProviderOptions: identity.DefaultAzureIdentityProviderOptions{ + Scopes: []string{}, // Empty scopes + }, + } + + provider, err := NewDefaultAzureCredentialsProvider(options) + assert.Error(t, err) + assert.Nil(t, provider) + }) +} From e0d0879376beac54a3dcb5b73e9d59f84723088d Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Fri, 11 Apr 2025 12:00:00 +0000 Subject: [PATCH 19/44] Token Listener Tests --- token_listener_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 token_listener_test.go diff --git a/token_listener_test.go b/token_listener_test.go new file mode 100644 index 0000000..2185061 --- /dev/null +++ b/token_listener_test.go @@ -0,0 +1,46 @@ +package entraid + +import ( + "errors" + "testing" + "time" + + "github.com/redis-developer/go-redis-entraid/token" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTokenListenerFromCP(t *testing.T) { + cp := &entraidCredentialsProvider{} + listener := tokenListenerFromCP(cp) + + require.NotNil(t, listener) + _, ok := listener.(*entraidTokenListener) + assert.True(t, ok, "listener should be of type entraidTokenListener") +} + +func TestOnTokenNext(t *testing.T) { + cp := &entraidCredentialsProvider{} + listener := tokenListenerFromCP(cp) + + now := time.Now() + testToken := token.New("test-user", "test-pass", "test-token", now.Add(time.Hour), now, 3600) + + listener.OnTokenNext(testToken) + + // Since we can't directly access the internal state of entraidCredentialsProvider, + // we'll verify that the listener was created and the call didn't panic + assert.NotNil(t, listener) +} + +func TestOnTokenError(t *testing.T) { + cp := &entraidCredentialsProvider{} + listener := tokenListenerFromCP(cp) + + testError := errors.New("test error") + listener.OnTokenError(testError) + + // Since we can't directly access the internal state of entraidCredentialsProvider, + // we'll verify that the listener was created and the call didn't panic + assert.NotNil(t, listener) +} From 305fcaf8d4cc5b99797db3c561ca53a0573a8a3b Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Sat, 12 Apr 2025 09:54:48 +0300 Subject: [PATCH 20/44] Fix flaky test --- .gitignore | 3 +++ manager/token_manager_test.go | 3 +++ 2 files changed, 6 insertions(+) diff --git a/.gitignore b/.gitignore index 8455110..ce0bc89 100644 --- a/.gitignore +++ b/.gitignore @@ -3,5 +3,8 @@ *.tar.gz *.dic coverage.txt +cover.out **/coverage.txt +**/cover.out .vscode +tmp/ diff --git a/manager/token_manager_test.go b/manager/token_manager_test.go index 433f808..4f476b0 100644 --- a/manager/token_manager_test.go +++ b/manager/token_manager_test.go @@ -857,7 +857,10 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { assert.Equal(t, time.Duration(0), toRenewal) assert.True(t, expiresIn > toRenewal) + // wait for request token to be called <-done + // wait a bit for listener to be notified + <-time.After(10 * time.Millisecond) assert.NoError(t, cancel()) assert.InDelta(t, stop.Sub(start), time.Duration(tm.retryOptions.InitialDelayMs)*time.Millisecond, float64(200*time.Millisecond)) From bf65058e32952af528b03aba191fcfc0735fdf9f Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Sat, 12 Apr 2025 15:34:11 +0300 Subject: [PATCH 21/44] improve credentials provider --- credentials_provider.go | 58 +++++++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/credentials_provider.go b/credentials_provider.go index 5f0b56b..98709de 100644 --- a/credentials_provider.go +++ b/credentials_provider.go @@ -1,3 +1,6 @@ +// Package entraid provides a credentials provider that manages token retrieval and notifies listeners +// of token updates. It implements the auth.StreamingCredentialsProvider interface and is designed +// for use with the Redis authentication system. package entraid import ( @@ -14,19 +17,20 @@ var _ auth.StreamingCredentialsProvider = (*entraidCredentialsProvider)(nil) // entraidCredentialsProvider is a struct that implements the StreamingCredentialsProvider interface. type entraidCredentialsProvider struct { - options CredentialsProviderOptions + options CredentialsProviderOptions // Configuration options for the provider. - tokenManager manager.TokenManager - cancelTokenManager manager.CancelFunc + tokenManager manager.TokenManager // Manages token retrieval. + cancelTokenManager manager.CancelFunc // Function to cancel the token manager. // listeners is a slice of listeners that are notified when the token manager receives a new token. - listeners []auth.CredentialsListener + listeners []auth.CredentialsListener // Slice of listeners notified on token updates. // rwLock is a mutex that is used to synchronize access to the listeners slice. - rwLock sync.RWMutex + rwLock sync.RWMutex // Mutex for synchronizing access to the listeners slice. } // onTokenNext is a method that is called when the token manager receives a new token. +// It notifies all registered listeners with the new token. func (e *entraidCredentialsProvider) onTokenNext(t *token.Token) { e.rwLock.RLock() defer e.rwLock.RUnlock() @@ -37,7 +41,7 @@ func (e *entraidCredentialsProvider) onTokenNext(t *token.Token) { } // onTokenError is a method that is called when the token manager encounters an error. -// It notifies all listeners with the error. +// It notifies all registered listeners with the error. func (e *entraidCredentialsProvider) onTokenError(err error) { e.rwLock.RLock() defer e.rwLock.RUnlock() @@ -48,11 +52,18 @@ func (e *entraidCredentialsProvider) onTokenError(err error) { } } -// Subscribe subscribes to the credentials provider and returns a channel that will receive updates. -// The first response is blocking, then data will notify the listener. -// The listener will be notified with the credentials when they are available. -// The listener will be notified with an error if there is an error obtaining the credentials. -// The caller can cancel the subscription by calling the cancel function which is the second return value. +// Subscribe subscribes a listener to the credentials provider. +// It returns the current credentials, a cancel function to unsubscribe, and an error if the subscription fails. +// +// Parameters: +// - listener: The listener that will receive updates about token changes. +// +// Returns: +// - auth.Credentials: The current credentials for the listener. +// - auth.CancelProviderFunc: A function that can be called to unsubscribe the listener. +// - error: An error if the subscription fails, such as if the token cannot be retrieved. +// +// Note: If the listener is already subscribed, it will not receive duplicate notifications. func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.CancelProviderFunc, error) { e.rwLock.Lock() // Check if the listener is already in the list of listeners. @@ -83,17 +94,20 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener // Remove the listener from the list of listeners. e.rwLock.Lock() defer e.rwLock.Unlock() + for i, l := range e.listeners { if l == listener { e.listeners = append(e.listeners[:i], e.listeners[i+1:]...) break } } + + // Clear the listeners slice if it's empty if len(e.listeners) == 0 { + e.listeners = make([]auth.CredentialsListener, 0) if e.cancelTokenManager != nil { defer func() { e.cancelTokenManager = nil - e.listeners = nil }() return e.cancelTokenManager() } @@ -104,21 +118,21 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener return token, cancel, nil } -// NewCredentialsProvider creates a new credentials provider. -// It takes a TokenManager and CredentialProviderOptions as arguments and returns a StreamingCredentialsProvider interface. -// The TokenManager is used to obtain the token, and the CredentialProviderOptions contains options for the credentials provider. -// The credentials provider is responsible for managing the credentials and refreshing them when necessary. -// It returns an error if the token manager cannot be started. +// NewCredentialsProvider creates a new credentials provider with the specified token manager and options. +// It returns a StreamingCredentialsProvider interface and an error if the token manager cannot be started. +// +// Parameters: +// - tokenManager: The TokenManager used to obtain tokens. +// - options: Options for configuring the credentials provider. // -// This function is typically used when you need to create a custom credentials provider with a specific token manager. -// For most use cases, it's recommended to use the type-specific constructors: -// - NewManagedIdentityCredentialsProvider for managed identity authentication -// - NewConfidentialCredentialsProvider for client secret or certificate authentication -// - NewDefaultAzureCredentialsProvider for default Azure identity authentication +// Returns: +// - auth.StreamingCredentialsProvider: The newly created credentials provider. +// - error: An error if the token manager cannot be started. func NewCredentialsProvider(tokenManager manager.TokenManager, options CredentialsProviderOptions) (auth.StreamingCredentialsProvider, error) { cp := &entraidCredentialsProvider{ tokenManager: tokenManager, options: options, + listeners: make([]auth.CredentialsListener, 0), } cancelTokenManager, err := cp.tokenManager.Start(tokenListenerFromCP(cp)) if err != nil { From 43a27796924d22ea7acd8c7aa0bb88e96b6f1c4b Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 14 Apr 2025 17:57:28 +0300 Subject: [PATCH 22/44] add more tests --- credentials_provider.go | 4 +- credentials_provider_test.go | 3 + providers_test.go | 316 +++++++++++++++++++++++++++++++---- token/token.go | 10 ++ 4 files changed, 302 insertions(+), 31 deletions(-) create mode 100644 credentials_provider_test.go diff --git a/credentials_provider.go b/credentials_provider.go index 98709de..471eb2f 100644 --- a/credentials_provider.go +++ b/credentials_provider.go @@ -83,12 +83,12 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener token, err := e.tokenManager.GetToken(false) if err != nil { - go listener.OnError(err) + //go listener.OnError(err) return nil, nil, err } // Notify the listener with the credentials. - go listener.OnNext(token) + //go listener.OnNext(token) cancel := func() error { // Remove the listener from the list of listeners. diff --git a/credentials_provider_test.go b/credentials_provider_test.go new file mode 100644 index 0000000..58eabe7 --- /dev/null +++ b/credentials_provider_test.go @@ -0,0 +1,3 @@ +package entraid + +// This file is intentionally empty as all tests have been moved to providers_test.go diff --git a/providers_test.go b/providers_test.go index cb2bcd5..d392dfd 100644 --- a/providers_test.go +++ b/providers_test.go @@ -19,20 +19,51 @@ import ( type mockTokenManager struct { token *token.Token err error + lock sync.Mutex } +const rawTokenString = "mock-token" +const tokenExpiration = 100 * time.Millisecond + func (m *mockTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { + if forceRefresh { + m.token = token.New( + "test", + "test", + rawTokenString, + time.Now().Add(tokenExpiration), + time.Now(), + int64(100*time.Millisecond), + ) + } return m.token, m.err } func (m *mockTokenManager) Start(listener manager.TokenListener) (manager.CancelFunc, error) { - if m.err != nil { - listener.OnTokenError(m.err) - return nil, m.err - } + done := make(chan struct{}) + go func() { + for { + select { + case <-time.After(tokenExpiration): + m.lock.Lock() + if m.err != nil { + listener.OnTokenError(m.err) + return + } + listener.OnTokenNext(m.token) + m.lock.Unlock() + case <-done: + // Exit the loop if done channel is closed + return + + } + } + }() - listener.OnTokenNext(m.token) - return func() error { return nil }, nil + return func() error { + close(done) + return nil + }, nil } func (m *mockTokenManager) Close() error { @@ -45,6 +76,17 @@ type mockCredentialsListener struct { LastErrCh chan error } +func (m *mockCredentialsListener) readWithTimeout(timeout time.Duration) (string, error) { + select { + case tk := <-m.LastTokenCh: + return tk, nil + case err := <-m.LastErrCh: + return "", err + case <-time.After(timeout): + return "", errors.New("timeout waiting for token") + } +} + func (m *mockCredentialsListener) OnNext(credentials auth.Credentials) { if m.LastTokenCh == nil { m.LastTokenCh = make(chan string) @@ -60,10 +102,10 @@ func (m *mockCredentialsListener) OnError(err error) { } // testTokenManagerFactory is a factory function that returns a mock token manager -func testTokenManagerFactory(token *token.Token, err error) func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { +func testTokenManagerFactory(tk *token.Token, err error) func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { return func(provider shared.IdentityProvider, options manager.TokenManagerOptions) (manager.TokenManager, error) { return &mockTokenManager{ - token: token, + token: tk, err: err, }, nil } @@ -124,7 +166,7 @@ func TestNewManagedIdentityCredentialsProvider(t *testing.T) { testToken := token.New( "test", "test", - "mock-token", + rawTokenString, time.Now().Add(time.Hour), time.Now(), int64(time.Hour), @@ -143,9 +185,15 @@ func TestNewManagedIdentityCredentialsProvider(t *testing.T) { // Test the provider with a mock listener listener := &mockCredentialsListener{LastTokenCh: make(chan string)} - _, _, err := provider.Subscribe(listener) + tk, cancel, err := provider.Subscribe(listener) + defer func() { + err := cancel() + if err != nil { + panic(err) + } + }() + assert.Equal(t, rawTokenString, tk.RawCredentials()) assert.NoError(t, err) - assert.Equal(t, "mock-token", <-listener.LastTokenCh) } }) } @@ -203,7 +251,7 @@ func TestNewConfidentialCredentialsProvider(t *testing.T) { testToken := token.New( "test", "test", - "mock-token", + rawTokenString, time.Now().Add(time.Hour), time.Now(), int64(time.Hour), @@ -222,9 +270,15 @@ func TestNewConfidentialCredentialsProvider(t *testing.T) { // Test the provider with a mock listener listener := &mockCredentialsListener{LastTokenCh: make(chan string)} - _, _, err := provider.Subscribe(listener) + credentials, cancel, err := provider.Subscribe(listener) + defer func() { + err := cancel() + if err != nil { + panic(err) + } + }() + assert.Equal(t, rawTokenString, credentials.RawCredentials()) assert.NoError(t, err) - assert.Equal(t, "mock-token", <-listener.LastTokenCh) } }) } @@ -270,7 +324,7 @@ func TestNewDefaultAzureCredentialsProvider(t *testing.T) { testToken := token.New( "test", "test", - "mock-token", + rawTokenString, time.Now().Add(time.Hour), time.Now(), int64(time.Hour), @@ -289,9 +343,15 @@ func TestNewDefaultAzureCredentialsProvider(t *testing.T) { // Test the provider with a mock listener listener := &mockCredentialsListener{LastTokenCh: make(chan string)} - _, _, err := provider.Subscribe(listener) + tk, cancel, err := provider.Subscribe(listener) + defer func() { + err := cancel() + if err != nil { + panic(err) + } + }() + assert.Equal(t, rawTokenString, tk.RawCredentials()) assert.NoError(t, err) - assert.Equal(t, "mock-token", <-listener.LastTokenCh) } }) } @@ -319,7 +379,7 @@ func TestCredentialsProviderErrorHandling(t *testing.T) { testToken := token.New( "test", "test", - "mock-token", + rawTokenString, time.Now().Add(time.Hour), time.Now(), int64(time.Hour), @@ -358,7 +418,7 @@ func TestCredentialsProviderErrorHandling(t *testing.T) { testToken := token.New( "test", "test", - "mock-token", + rawTokenString, time.Now().Add(time.Hour), time.Now(), int64(time.Hour), @@ -404,7 +464,7 @@ func TestCredentialsProviderInterface(t *testing.T) { testToken := token.New( "test", "test", - "mock-token", + rawTokenString, time.Now().Add(time.Hour), time.Now(), int64(time.Hour), @@ -440,7 +500,7 @@ func TestCredentialsProviderInterface(t *testing.T) { testToken := token.New( "test", "test", - "mock-token", + rawTokenString, time.Now().Add(time.Hour), time.Now(), int64(time.Hour), @@ -472,7 +532,7 @@ func TestCredentialsProviderInterface(t *testing.T) { testToken := token.New( "test", "test", - "mock-token", + rawTokenString, time.Now().Add(time.Hour), time.Now(), int64(time.Hour), @@ -505,10 +565,10 @@ func TestCredentialsProviderSubscribe(t *testing.T) { testToken := token.New( "test", "test", - "mock-token", - time.Now().Add(time.Hour), + rawTokenString, + time.Now().Add(tokenExpiration), time.Now(), - int64(time.Hour), + int64(tokenExpiration), ) // Create a test provider @@ -561,10 +621,12 @@ func TestCredentialsProviderSubscribe(t *testing.T) { // Verify all listeners received the token for i, listener := range listeners { select { - case token := <-listener.LastTokenCh: - assert.Equal(t, "mock-token", token, "listener %d received wrong token", i) + case tk := <-listener.LastTokenCh: + assert.Equal(t, rawTokenString, tk, "listener %d received wrong token", i) case err := <-listener.LastErrCh: t.Fatalf("listener %d received error: %v", i, err) + case <-time.After(3 * tokenExpiration): + t.Fatalf("listener %d timed out waiting for token", i) } } @@ -582,8 +644,8 @@ func TestCredentialsProviderSubscribe(t *testing.T) { // Verify no more tokens are sent after cancellation for i, listener := range listeners { select { - case token := <-listener.LastTokenCh: - t.Fatalf("listener %d received unexpected token after cancellation: %s", i, token) + case tk := <-listener.LastTokenCh: + t.Fatalf("listener %d received unexpected token after cancellation: %s", i, tk) case err := <-listener.LastErrCh: t.Fatalf("listener %d received unexpected error after cancellation: %v", i, err) default: @@ -701,3 +763,199 @@ func TestCredentialsProviderErrorScenarios(t *testing.T) { assert.Nil(t, provider) }) } + +func TestCredentialsProviderWithMockIdentityProvider(t *testing.T) { + t.Parallel() + + t.Run("Subscribe and Unsubscribe", func(t *testing.T) { + t.Parallel() + + // Create mock token manager + tm := &mockTokenManager{ + token: token.New( + "test", + "test", + "test-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ), + } + + // Create credentials provider + cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) + assert.NoError(t, err) + assert.NotNil(t, cp) + + // Create mock listener + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string), + LastErrCh: make(chan error), + } + + // Subscribe listener + credentials, cancel, err := cp.Subscribe(listener) + assert.NoError(t, err) + assert.NotNil(t, credentials) + assert.NotNil(t, cancel) + + // Wait for initial token + tk, err := listener.readWithTimeout(time.Second) + assert.NoError(t, err) + assert.Equal(t, "test-token", tk) + + // Unsubscribe + err = cancel() + assert.NoError(t, err) + }) + + t.Run("Multiple Listeners", func(t *testing.T) { + t.Parallel() + + // Create mock token manager + tm := &mockTokenManager{ + token: token.New( + "test", + "test", + "test-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ), + } + + // Create credentials provider + cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) + assert.NoError(t, err) + assert.NotNil(t, cp) + + // Create multiple mock listeners + listener1 := &mockCredentialsListener{ + LastTokenCh: make(chan string), + LastErrCh: make(chan error), + } + listener2 := &mockCredentialsListener{ + LastTokenCh: make(chan string), + LastErrCh: make(chan error), + } + + // Subscribe first listener + credentials1, cancel1, err := cp.Subscribe(listener1) + assert.NoError(t, err) + assert.NotNil(t, credentials1) + assert.NotNil(t, cancel1) + + // Subscribe second listener + credentials2, cancel2, err := cp.Subscribe(listener2) + assert.NoError(t, err) + assert.NotNil(t, credentials2) + assert.NotNil(t, cancel2) + + // Wait for initial tokens + token1, err := listener1.readWithTimeout(time.Second) + assert.NoError(t, err) + assert.Equal(t, "test-token", token1) + + token2, err := listener2.readWithTimeout(time.Second) + assert.NoError(t, err) + assert.Equal(t, "test-token", token2) + + // Unsubscribe first listener + err = cancel1() + assert.NoError(t, err) + + // Unsubscribe second listener + err = cancel2() + assert.NoError(t, err) + }) + + t.Run("Token Updates", func(t *testing.T) { + t.Parallel() + + // Create mock token manager + tm := &mockTokenManager{ + token: token.New( + "test", + "test", + "initial-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ), + } + + // Create credentials provider + cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) + assert.NoError(t, err) + assert.NotNil(t, cp) + + // Create mock listener + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string), + LastErrCh: make(chan error), + } + + // Subscribe listener + credentials, cancel, err := cp.Subscribe(listener) + assert.NoError(t, err) + assert.NotNil(t, credentials) + assert.NotNil(t, cancel) + + // Wait for initial token + tk, err := listener.readWithTimeout(time.Second) + assert.NoError(t, err) + assert.Equal(t, "initial-token", tk) + + tm.lock.Lock() + // Update token + tm.token = token.New( + "test", + "test", + "updated-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + tm.lock.Unlock() + + // Wait for token update + tk, err = listener.readWithTimeout(time.Second) + assert.NoError(t, err) + assert.Equal(t, "updated-token", tk) + + // Unsubscribe + err = cancel() + assert.NoError(t, err) + }) + + t.Run("Error Handling", func(t *testing.T) { + t.Parallel() + + // Create mock token manager with error + tm := &mockTokenManager{ + err: assert.AnError, + } + + // Create credentials provider + cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) + assert.NoError(t, err) + assert.NotNil(t, cp) + + // Create mock listener + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string), + LastErrCh: make(chan error), + } + + // Subscribe listener + credentials, cancel, err := cp.Subscribe(listener) + assert.Error(t, err) + assert.Nil(t, credentials) + assert.Nil(t, cancel) + + // Wait for error + _, err = listener.readWithTimeout(time.Second) + assert.Error(t, err) + assert.Equal(t, assert.AnError, err) + }) +} diff --git a/token/token.go b/token/token.go index 0c67e4a..66bd100 100644 --- a/token/token.go +++ b/token/token.go @@ -46,9 +46,19 @@ func (t *Token) BasicAuth() (string, string) { // RawCredentials returns the raw credentials for authentication. func (t *Token) RawCredentials() string { + return t.RawToken() +} + +// RawToken returns the raw token. +func (t *Token) RawToken() string { return t.rawToken } +// ReceivedAt returns the time when the token was received. +func (t *Token) ReceivedAt() time.Time { + return t.receivedAt +} + // ExpirationOn returns the expiration time of the token. func (t *Token) ExpirationOn() time.Time { return t.expiresOn From e05c56c5ece43b603b13ac80f2453da6b330cb2b Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 14 Apr 2025 23:05:15 +0300 Subject: [PATCH 23/44] refactor tests --- .testcoverage.yml | 12 +- credentials_provider.go | 34 +- credentials_provider_test.go | 581 ++++++++++++++++++++++++++- entraid_test.go | 212 ++++++++++ manager/errors.go | 4 +- manager/token_manager.go | 10 +- manager/token_manager_test.go | 2 +- providers.go | 6 + providers_test.go | 725 +++++++++------------------------- token/token_test.go | 28 +- token_listener.go | 12 + 11 files changed, 1038 insertions(+), 588 deletions(-) create mode 100644 entraid_test.go diff --git a/.testcoverage.yml b/.testcoverage.yml index 2b3b4fa..91a1592 100644 --- a/.testcoverage.yml +++ b/.testcoverage.yml @@ -11,15 +11,15 @@ profile: cover.out threshold: # (optional; default 0) # Minimum coverage percentage required for individual files. - file: 70 + file: 85 # (optional; default 0) # Minimum coverage percentage required for each package. - package: 70 + package: 90 # (optional; default 0) # Minimum overall project coverage percentage required. - total: 80 + total: 85 # Holds regexp rules which will override thresholds for matched files or packages # using their paths. @@ -28,9 +28,9 @@ threshold: # new threshold to it. If project has multiple rules that match same path, # override rules should be listed in order from specific to more general rules. override: - # Increase coverage threshold to 100% for `foo` package - # (default is 80, as configured above in this example). - - path: ^pkg/lib/foo$ + - path: ^internal$ + threshold: 95 + - path: ^token$ threshold: 100 # Holds regexp rules which will exclude matched files or packages diff --git a/credentials_provider.go b/credentials_provider.go index 471eb2f..50ed6b7 100644 --- a/credentials_provider.go +++ b/credentials_provider.go @@ -19,8 +19,8 @@ var _ auth.StreamingCredentialsProvider = (*entraidCredentialsProvider)(nil) type entraidCredentialsProvider struct { options CredentialsProviderOptions // Configuration options for the provider. - tokenManager manager.TokenManager // Manages token retrieval. - cancelTokenManager manager.CancelFunc // Function to cancel the token manager. + tokenManager manager.TokenManager // Manages token retrieval. + closeTokenManager manager.CloseFunc // Function to cancel the token manager. // listeners is a slice of listeners that are notified when the token manager receives a new token. listeners []auth.CredentialsListener // Slice of listeners notified on token updates. @@ -65,6 +65,12 @@ func (e *entraidCredentialsProvider) onTokenError(err error) { // // Note: If the listener is already subscribed, it will not receive duplicate notifications. func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.CancelProviderFunc, error) { + // First try to get a token, only then subscribe the listener. + token, err := e.tokenManager.GetToken(false) + if err != nil { + return nil, nil, err + } + e.rwLock.Lock() // Check if the listener is already in the list of listeners. alreadySubscribed := false @@ -81,15 +87,6 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener } e.rwLock.Unlock() - token, err := e.tokenManager.GetToken(false) - if err != nil { - //go listener.OnError(err) - return nil, nil, err - } - - // Notify the listener with the credentials. - //go listener.OnNext(token) - cancel := func() error { // Remove the listener from the list of listeners. e.rwLock.Lock() @@ -105,11 +102,14 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener // Clear the listeners slice if it's empty if len(e.listeners) == 0 { e.listeners = make([]auth.CredentialsListener, 0) - if e.cancelTokenManager != nil { - defer func() { - e.cancelTokenManager = nil - }() - return e.cancelTokenManager() + if e.closeTokenManager != nil { + err := e.closeTokenManager() + if err != nil { + return fmt.Errorf("couldn't cancel token manager: %w", err) + } + // Set the cancelTokenManager to nil to indicate that it has been canceled. + // This prevents multiple calls to cancelTokenManager. + e.closeTokenManager = nil } } return nil @@ -138,6 +138,6 @@ func NewCredentialsProvider(tokenManager manager.TokenManager, options Credentia if err != nil { return nil, fmt.Errorf("couldn't start token manager: %w", err) } - cp.cancelTokenManager = cancelTokenManager + cp.closeTokenManager = cancelTokenManager return cp, nil } diff --git a/credentials_provider_test.go b/credentials_provider_test.go index 58eabe7..7f172b4 100644 --- a/credentials_provider_test.go +++ b/credentials_provider_test.go @@ -1,3 +1,582 @@ package entraid -// This file is intentionally empty as all tests have been moved to providers_test.go +import ( + "sync" + "testing" + "time" + + "github.com/redis-developer/go-redis-entraid/identity" + "github.com/redis-developer/go-redis-entraid/manager" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis-developer/go-redis-entraid/token" + "github.com/redis/go-redis/v9/auth" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestCredentialsProviderErrorScenarios(t *testing.T) { + t.Run("token manager start error", func(t *testing.T) { + // Create a test provider with invalid options + options := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: "invalid-type", // Invalid credentials type + ClientSecret: "test-secret", + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + + provider, err := NewConfidentialCredentialsProvider(options) + assert.Error(t, err) + assert.Nil(t, provider) + }) + + t.Run("token manager get token error", func(t *testing.T) { + // Create a test provider with invalid options + options := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: "", // Empty client secret + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + + provider, err := NewConfidentialCredentialsProvider(options) + assert.Error(t, err) + assert.Nil(t, provider) + }) + + t.Run("concurrent error handling", func(t *testing.T) { + // Create a test provider with invalid options + options := ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ + ManagedIdentityType: "invalid-type", // Invalid managed identity type + Scopes: []string{identity.RedisScopeDefault}, + }, + } + + provider, err := NewManagedIdentityCredentialsProvider(options) + assert.Error(t, err) + assert.Nil(t, provider) + }) + + t.Run("concurrent token updates", func(t *testing.T) { + // Create a test provider with invalid options + options := DefaultAzureCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + DefaultAzureIdentityProviderOptions: identity.DefaultAzureIdentityProviderOptions{ + Scopes: []string{}, // Empty scopes + }, + } + + provider, err := NewDefaultAzureCredentialsProvider(options) + // bad options - empty scopes + assert.Error(t, err) + assert.Nil(t, provider) + }) +} + +func TestCredentialsProviderWithMockIdentityProvider(t *testing.T) { + t.Parallel() + + t.Run("Subscribe and Unsubscribe", func(t *testing.T) { + t.Parallel() + + // Create mock token manager + tm := &fakeTokenManager{ + token: token.New( + "test", + "test", + "test-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ), + } + + // Create credentials provider + cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) + assert.NoError(t, err) + assert.NotNil(t, cp) + + // Create mock listener + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string), + LastErrCh: make(chan error), + } + + // Subscribe listener + credentials, cancel, err := cp.Subscribe(listener) + assert.NoError(t, err) + assert.NotNil(t, credentials) + assert.NotNil(t, cancel) + + // Wait for initial token + tk, err := listener.readWithTimeout(time.Second) + assert.NoError(t, err) + assert.Equal(t, "test-token", tk) + + // Unsubscribe + err = cancel() + assert.NoError(t, err) + }) + + t.Run("Multiple Listeners", func(t *testing.T) { + t.Parallel() + + // Create mock token manager + tm := &fakeTokenManager{ + token: token.New( + "test", + "test", + "test-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ), + } + + // Create credentials provider + cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) + assert.NoError(t, err) + assert.NotNil(t, cp) + + // Create multiple mock listeners + listener1 := &mockCredentialsListener{ + LastTokenCh: make(chan string), + LastErrCh: make(chan error), + } + listener2 := &mockCredentialsListener{ + LastTokenCh: make(chan string), + LastErrCh: make(chan error), + } + + // Subscribe first listener + credentials1, cancel1, err := cp.Subscribe(listener1) + assert.NoError(t, err) + assert.NotNil(t, credentials1) + assert.NotNil(t, cancel1) + + // Subscribe second listener + credentials2, cancel2, err := cp.Subscribe(listener2) + assert.NoError(t, err) + assert.NotNil(t, credentials2) + assert.NotNil(t, cancel2) + + // Wait for initial tokens + token1, err := listener1.readWithTimeout(time.Second) + assert.NoError(t, err) + assert.Equal(t, "test-token", token1) + + token2, err := listener2.readWithTimeout(time.Second) + assert.NoError(t, err) + assert.Equal(t, "test-token", token2) + + // Unsubscribe first listener + err = cancel1() + assert.NoError(t, err) + + // Unsubscribe second listener + err = cancel2() + assert.NoError(t, err) + }) + + t.Run("Token Updates", func(t *testing.T) { + t.Parallel() + + // Create mock token manager + tm := &fakeTokenManager{ + token: token.New( + "test", + "test", + "initial-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ), + } + + // Create credentials provider + cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) + assert.NoError(t, err) + assert.NotNil(t, cp) + + // Create mock listener + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string), + LastErrCh: make(chan error), + } + + // Subscribe listener + credentials, cancel, err := cp.Subscribe(listener) + assert.NoError(t, err) + assert.NotNil(t, credentials) + assert.NotNil(t, cancel) + + // Wait for initial token + tk, err := listener.readWithTimeout(time.Second) + assert.NoError(t, err) + assert.Equal(t, "initial-token", tk) + + tm.lock.Lock() + // Update token + tm.token = token.New( + "test", + "test", + "updated-token", + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) + tm.lock.Unlock() + + // Wait for token update + tk, err = listener.readWithTimeout(time.Second) + assert.NoError(t, err) + assert.Equal(t, "updated-token", tk) + + // Unsubscribe + err = cancel() + assert.NoError(t, err) + }) + + t.Run("Error Handling", func(t *testing.T) { + t.Parallel() + + // Create mock token manager with error + tm := &fakeTokenManager{ + err: assert.AnError, + } + + // Create credentials provider + cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) + assert.Error(t, err) + assert.Nil(t, cp) + }) +} + +func TestCredentialsProviderOptions(t *testing.T) { + t.Run("default token manager factory", func(t *testing.T) { + options := CredentialsProviderOptions{} + factory := options.getTokenManagerFactory() + assert.NotNil(t, factory) + }) + + t.Run("custom token manager factory", func(t *testing.T) { + m := &fakeTokenManager{} + customFactory := func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return m, nil + } + options := CredentialsProviderOptions{ + tokenManagerFactory: customFactory, + } + tm, err := options.getTokenManagerFactory()(nil, manager.TokenManagerOptions{}) + assert.NotNil(t, tm) + assert.NoError(t, err) + assert.Equal(t, m, tm) + }) +} + +func TestCredentialsProviderSubscribe(t *testing.T) { + // Create a test provider + opts := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: "test-secret", + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + t.Run("double subscribe and cancel resubscribe", func(t *testing.T) { + t.Parallel() + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(tokenExpiration), + time.Now(), + int64(tokenExpiration), + ) + + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string, 1), + LastErrCh: make(chan error, 1), + } + mtm := &mockTokenManager{done: make(chan struct{})} + // Set the token manager factory in the options + options := opts + options.tokenManagerFactory = mockTokenManagerFactory(mtm) + mtm.On("GetToken", false).Return(testToken, nil) + mtm.On("Start", mock.Anything). + Run(mockTokenManagerLoop(mtm, tokenExpiration, testToken, nil)). + Return(manager.CloseFunc(mtm.Close), nil) + provider, err := NewConfidentialCredentialsProvider(options) + require.NoError(t, err) + require.NotNil(t, provider) + // Subscribe the listener + tk, cancel, err := provider.Subscribe(listener) + require.NoError(t, err) + require.NotNil(t, tk) + require.NotNil(t, cancel) + // try to subscribe the same listener again + tk2, cancel2, err := provider.Subscribe(listener) + require.NoError(t, err) + require.NotNil(t, tk2) + require.NotNil(t, cancel2) + // Verify the listener received the token once + select { + case tk := <-listener.LastTokenCh: + assert.Equal(t, rawTokenString, tk, "listener received wrong token") + case err := <-listener.LastErrCh: + t.Fatalf("listener received error: %v", err) + case <-time.After(3 * tokenExpiration): + t.Fatalf("listener timed out waiting for token") + } + // verify it is not received again + select { + case tk := <-listener.LastTokenCh: + t.Fatalf("listener received unexpected token: %v", tk) + case err := <-listener.LastErrCh: + t.Fatalf("listener received unexpected error: %v", err) + case <-time.After(tokenExpiration / 2): + // No message received, which is expected + } + + }) + + t.Run("concurrent subscribe and cancel with error ", func(t *testing.T) { + t.Parallel() + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(tokenExpiration), + time.Now(), + int64(tokenExpiration), + ) + mtm := &mockTokenManager{done: make(chan struct{})} + // Set the token manager factory in the options + options := opts + options.tokenManagerFactory = mockTokenManagerFactory(mtm) + mtm.On("GetToken", false).Return(testToken, nil) + + mtm.On("Start", mock.Anything). + Run(mockTokenManagerLoop(mtm, tokenExpiration, nil, errTokenError)). + Return(manager.CloseFunc(mtm.Close), nil) + provider, err := NewConfidentialCredentialsProvider(options) + require.NoError(t, err) + require.NotNil(t, provider) + var wg sync.WaitGroup + listeners := make([]*mockCredentialsListener, numListeners) + cancels := make([]auth.CancelProviderFunc, numListeners) + + // Subscribe multiple listeners concurrently + for i := 0; i < numListeners; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string, 1), + LastErrCh: make(chan error, 1), + } + listeners[idx] = listener + _, cancel, err := provider.Subscribe(listener) + require.NoError(t, err) + cancels[idx] = cancel + }(i) + } + wg.Wait() + + // Verify all listeners received the token + for i, listener := range listeners { + select { + case tk := <-listener.LastTokenCh: + t.Fatalf("listener %d received token: %v", i, tk) + case err := <-listener.LastErrCh: + assert.Equal(t, errTokenError.Error(), err.Error(), "listener %d received wrong error", i) + case <-time.After(3 * tokenExpiration): + t.Fatalf("listener %d timed out waiting for token", i) + } + } + + // Cancel all subscriptions concurrently + for i := 0; i < numListeners; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + err := cancels[idx]() + require.NoError(t, err) + }(i) + } + wg.Wait() + + // Verify no more tokens are sent after cancellation + for i, listener := range listeners { + select { + case tk := <-listener.LastTokenCh: + t.Fatalf("listener %d received unexpected token after cancellation: %s", i, tk) + case err := <-listener.LastErrCh: + t.Fatalf("listener %d received unexpected error after cancellation: %v", i, err) + case <-time.After(3 * tokenExpiration): + // No message received, which is expected + } + } + }) + + t.Run("concurrent subscribe and get token error ", func(t *testing.T) { + t.Parallel() + mtm := &mockTokenManager{done: make(chan struct{})} + // Set the token manager factory in the options + options := opts + options.tokenManagerFactory = mockTokenManagerFactory(mtm) + mtm.On("GetToken", false).Return(nil, assert.AnError) + + mtm.On("Start", mock.Anything). + Run(mockTokenManagerLoop(mtm, tokenExpiration, nil, errTokenError)). + Return(manager.CloseFunc(mtm.Close), nil) + provider, err := NewConfidentialCredentialsProvider(options) + require.NoError(t, err) + require.NotNil(t, provider) + + var wg sync.WaitGroup + listeners := make([]*mockCredentialsListener, numListeners) + + // Subscribe multiple listeners concurrently + for i := 0; i < numListeners; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string, 1), + LastErrCh: make(chan error, 1), + } + listeners[idx] = listener + tk, cancel, err := provider.Subscribe(listener) + require.Nil(t, tk) + require.Error(t, err) + require.Nil(t, cancel) + }(i) + } + wg.Wait() + + // Verify no more tokens are sent after cancellation + for i, listener := range listeners { + select { + case tk := <-listener.LastTokenCh: + t.Fatalf("listener %d received unexpected token after cancellation: %s", i, tk) + case err := <-listener.LastErrCh: + t.Fatalf("listener %d received unexpected error after cancellation: %v", i, err) + case <-time.After(3 * tokenExpiration): + // No message received, which is expected + } + } + }) + + t.Run("concurrent subscribe and cancel", func(t *testing.T) { + t.Parallel() + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(tokenExpiration), + time.Now(), + int64(tokenExpiration), + ) + // Set the token manager factory in the options + options := opts + options.tokenManagerFactory = testFakeTokenManagerFactory(testToken, nil) + + provider, err := NewConfidentialCredentialsProvider(options) + require.NoError(t, err) + require.NotNil(t, provider) + var wg sync.WaitGroup + listeners := make([]*mockCredentialsListener, numListeners) + cancels := make([]auth.CancelProviderFunc, numListeners) + + // Subscribe multiple listeners concurrently + for i := 0; i < numListeners; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + listener := &mockCredentialsListener{ + LastTokenCh: make(chan string, 1), + LastErrCh: make(chan error, 1), + } + listeners[idx] = listener + _, cancel, err := provider.Subscribe(listener) + require.NoError(t, err) + cancels[idx] = cancel + }(i) + } + wg.Wait() + + // Verify all listeners received the token + for i, listener := range listeners { + select { + case tk := <-listener.LastTokenCh: + assert.Equal(t, rawTokenString, tk, "listener %d received wrong token", i) + case err := <-listener.LastErrCh: + t.Fatalf("listener %d received error: %v", i, err) + case <-time.After(3 * tokenExpiration): + t.Fatalf("listener %d timed out waiting for token", i) + } + } + + // Cancel all subscriptions concurrently + for i := 0; i < numListeners; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + err := cancels[idx]() + require.NoError(t, err) + }(i) + } + wg.Wait() + + // Verify no more tokens are sent after cancellation + for i, listener := range listeners { + select { + case tk := <-listener.LastTokenCh: + t.Fatalf("listener %d received unexpected token after cancellation: %s", i, tk) + case err := <-listener.LastErrCh: + t.Fatalf("listener %d received unexpected error after cancellation: %v", i, err) + case <-time.After(3 * tokenExpiration): + // No message received, which is expected + } + } + }) +} diff --git a/entraid_test.go b/entraid_test.go new file mode 100644 index 0000000..86b994f --- /dev/null +++ b/entraid_test.go @@ -0,0 +1,212 @@ +package entraid + +import ( + "errors" + "flag" + "sync" + "testing" + "time" + + "github.com/redis-developer/go-redis-entraid/manager" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis-developer/go-redis-entraid/token" + "github.com/redis/go-redis/v9/auth" + "github.com/stretchr/testify/mock" +) + +// fakeTokenManager implements the TokenManager interface for testing +type fakeTokenManager struct { + token *token.Token + err error + lock sync.Mutex +} + +const rawTokenString = "mock-token" + +// numListeners is set to 3 for short tests and 12 for long tests +var numListeners = 12 + +// tokenExpiration is set to 100ms for long tests and 10ms for short tests +var tokenExpiration = 100 * time.Millisecond + +func init() { + testing.Init() + flag.Parse() + tokenExpiration = 100 * time.Millisecond + numListeners = 12 + if testing.Short() { + tokenExpiration = 10 * time.Millisecond + numListeners = 3 + } +} + +func (m *fakeTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { + if forceRefresh { + m.token = token.New( + "test", + "test", + rawTokenString, + time.Now().Add(tokenExpiration), + time.Now(), + int64(100*time.Millisecond), + ) + } + return m.token, m.err +} + +func (m *fakeTokenManager) Start(listener manager.TokenListener) (manager.CloseFunc, error) { + if m.err != nil { + return nil, m.err + } + done := make(chan struct{}) + go func() { + for { + select { + case <-time.After(tokenExpiration): + m.lock.Lock() + if m.err != nil { + listener.OnTokenError(m.err) + return + } + listener.OnTokenNext(m.token) + m.lock.Unlock() + case <-done: + // Exit the loop if done channel is closed + return + + } + } + }() + + return func() error { + close(done) + return nil + }, nil +} + +func (m *fakeTokenManager) Close() error { + return nil +} + +// mockCredentialsListener implements the CredentialsListener interface for testing +type mockCredentialsListener struct { + LastTokenCh chan string + LastErrCh chan error +} + +func (m *mockCredentialsListener) readWithTimeout(timeout time.Duration) (string, error) { + select { + case tk := <-m.LastTokenCh: + return tk, nil + case err := <-m.LastErrCh: + return "", err + case <-time.After(timeout): + return "", errors.New("timeout waiting for token") + } +} + +func (m *mockCredentialsListener) OnNext(credentials auth.Credentials) { + if m.LastTokenCh == nil { + m.LastTokenCh = make(chan string) + } + m.LastTokenCh <- credentials.RawCredentials() +} + +func (m *mockCredentialsListener) OnError(err error) { + if m.LastErrCh == nil { + m.LastErrCh = make(chan error) + } + m.LastErrCh <- err +} + +// testFakeTokenManagerFactory is a factory function that returns a mock token manager +func testFakeTokenManagerFactory(tk *token.Token, err error) func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return func(provider shared.IdentityProvider, options manager.TokenManagerOptions) (manager.TokenManager, error) { + return &fakeTokenManager{ + token: tk, + err: err, + }, nil + } +} + +// mockTokenManager is a mock implementation of the TokenManager interface +type mockTokenManager struct { + mock.Mock + idp shared.IdentityProvider + done chan struct{} + options manager.TokenManagerOptions + listener manager.TokenListener + lock sync.Mutex +} + +func (m *mockTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { + args := m.Called(forceRefresh) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*token.Token), args.Error(1) +} + +func (m *mockTokenManager) Start(listener manager.TokenListener) (manager.CloseFunc, error) { + args := m.Called(listener) + m.lock.Lock() + if m.done == nil { + m.done = make(chan struct{}) + } + if m.listener != nil { + defer m.lock.Unlock() + return nil, manager.ErrTokenManagerAlreadyStarted + } + if m.listener == nil { + m.listener = listener + } + m.lock.Unlock() + return args.Get(0).(manager.CloseFunc), args.Error(1) +} +func (m *mockTokenManager) Close() error { + m.lock.Lock() + defer m.lock.Unlock() + if m.listener == nil { + return manager.ErrTokenManagerAlreadyClosed + } + if m.listener != nil { + m.listener = nil + } + if m.done != nil { + close(m.done) + m.done = nil + } + return nil +} + +// mockTokenManagerFactory is a factory function that returns a mock token manager +func mockTokenManagerFactory(mtm *mockTokenManager) func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return func(provider shared.IdentityProvider, options manager.TokenManagerOptions) (manager.TokenManager, error) { + mtm.idp = provider + mtm.options = options + return mtm, nil + } +} + +var errTokenError = errors.New("token error") + +func mockTokenManagerLoop(mtm *mockTokenManager, tokenExpiration time.Duration, testToken *token.Token, err error) func(args mock.Arguments) { + return func(args mock.Arguments) { + go func() { + for { + select { + case <-mtm.done: + return + case <-time.After(tokenExpiration): + mtm.lock.Lock() + if err != nil { + mtm.listener.OnTokenError(err) + } else { + mtm.listener.OnTokenNext(testToken) + } + mtm.lock.Unlock() + } + } + }() + } +} diff --git a/manager/errors.go b/manager/errors.go index 840d46d..40b707b 100644 --- a/manager/errors.go +++ b/manager/errors.go @@ -2,8 +2,8 @@ package manager import "fmt" -// ErrTokenManagerAlreadyCanceled is returned when the token manager is already canceled. -var ErrTokenManagerAlreadyCanceled = fmt.Errorf("token manager already canceled") +// ErrTokenManagerAlreadyClosed is returned when the token manager is already closed. +var ErrTokenManagerAlreadyClosed = fmt.Errorf("token manager already closed") // ErrTokenManagerAlreadyStarted is returned when the token manager is already started. var ErrTokenManagerAlreadyStarted = fmt.Errorf("token manager already started") diff --git a/manager/token_manager.go b/manager/token_manager.go index 6d4d8b0..0f232fe 100644 --- a/manager/token_manager.go +++ b/manager/token_manager.go @@ -76,13 +76,13 @@ type TokenManager interface { // It takes a boolean value forceRefresh as an argument. GetToken(forceRefresh bool) (*token.Token, error) // Start starts the token manager and returns a channel that will receive updates. - Start(listener TokenListener) (CancelFunc, error) + Start(listener TokenListener) (CloseFunc, error) // Close closes the token manager and releases any resources. Close() error } -// CancelFunc is a function that cancels the token manager. -type CancelFunc func() error +// CloseFunc is a function that closes the token manager. +type CloseFunc func() error // TokenListener is an interface that contains the methods for receiving updates from the token manager. // The token manager will call the listener's OnTokenNext method with the updated token. @@ -221,7 +221,7 @@ func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error) // // Note: The initial token is delivered synchronously. // The TokenListener will receive the token immediately, before the token manager goroutine starts. -func (e *entraidTokenManager) Start(listener TokenListener) (CancelFunc, error) { +func (e *entraidTokenManager) Start(listener TokenListener) (CloseFunc, error) { e.lock.Lock() defer e.lock.Unlock() if e.listener != nil { @@ -316,7 +316,7 @@ func (e *entraidTokenManager) Close() error { defer e.lock.Unlock() if e.closedChan == nil || e.listener == nil { - return ErrTokenManagerAlreadyCanceled + return ErrTokenManagerAlreadyClosed } e.listener = nil close(e.closedChan) diff --git a/manager/token_manager_test.go b/manager/token_manager_test.go index 4f476b0..276a958 100644 --- a/manager/token_manager_test.go +++ b/manager/token_manager_test.go @@ -388,7 +388,7 @@ func TestTokenManager_Start(t *testing.T) { _, err = tokenManager.Start(l) } if err != nil { - if err != ErrTokenManagerAlreadyCanceled && err != ErrTokenManagerAlreadyStarted { + if err != ErrTokenManagerAlreadyClosed && err != ErrTokenManagerAlreadyStarted { // this is un unexpected error, fail the test assert.Error(t, err) } diff --git a/providers.go b/providers.go index f579079..2d053d7 100644 --- a/providers.go +++ b/providers.go @@ -119,7 +119,13 @@ func NewConfidentialCredentialsProvider(options ConfidentialCredentialsProviderO // DefaultAzureCredentialsProviderOptions is a struct that holds the options for the default azure credentials provider. // It is used to configure the credentials provider when requesting a token. type DefaultAzureCredentialsProviderOptions struct { + // CredentialsProviderOptions is the options for the credentials provider. + // This is used to configure the credentials provider when requesting a token. + // It includes the clientId and TokenManagerOptions. CredentialsProviderOptions + // DefaultAzureIdentityProviderOptions is the options for the default azure identity provider. + // This is used to configure the identity provider when requesting a token. + // It is used to specify the client ID, tenant ID, and scopes for the identity. identity.DefaultAzureIdentityProviderOptions } diff --git a/providers_test.go b/providers_test.go index d392dfd..88d88b3 100644 --- a/providers_test.go +++ b/providers_test.go @@ -2,7 +2,7 @@ package entraid import ( "errors" - "sync" + "fmt" "testing" "time" @@ -12,105 +12,8 @@ import ( "github.com/redis-developer/go-redis-entraid/token" "github.com/redis/go-redis/v9/auth" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -// mockTokenManager implements the TokenManager interface for testing -type mockTokenManager struct { - token *token.Token - err error - lock sync.Mutex -} - -const rawTokenString = "mock-token" -const tokenExpiration = 100 * time.Millisecond - -func (m *mockTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { - if forceRefresh { - m.token = token.New( - "test", - "test", - rawTokenString, - time.Now().Add(tokenExpiration), - time.Now(), - int64(100*time.Millisecond), - ) - } - return m.token, m.err -} - -func (m *mockTokenManager) Start(listener manager.TokenListener) (manager.CancelFunc, error) { - done := make(chan struct{}) - go func() { - for { - select { - case <-time.After(tokenExpiration): - m.lock.Lock() - if m.err != nil { - listener.OnTokenError(m.err) - return - } - listener.OnTokenNext(m.token) - m.lock.Unlock() - case <-done: - // Exit the loop if done channel is closed - return - - } - } - }() - - return func() error { - close(done) - return nil - }, nil -} - -func (m *mockTokenManager) Close() error { - return nil -} - -// mockCredentialsListener implements the CredentialsListener interface for testing -type mockCredentialsListener struct { - LastTokenCh chan string - LastErrCh chan error -} - -func (m *mockCredentialsListener) readWithTimeout(timeout time.Duration) (string, error) { - select { - case tk := <-m.LastTokenCh: - return tk, nil - case err := <-m.LastErrCh: - return "", err - case <-time.After(timeout): - return "", errors.New("timeout waiting for token") - } -} - -func (m *mockCredentialsListener) OnNext(credentials auth.Credentials) { - if m.LastTokenCh == nil { - m.LastTokenCh = make(chan string) - } - m.LastTokenCh <- credentials.RawCredentials() -} - -func (m *mockCredentialsListener) OnError(err error) { - if m.LastErrCh == nil { - m.LastErrCh = make(chan error) - } - m.LastErrCh <- err -} - -// testTokenManagerFactory is a factory function that returns a mock token manager -func testTokenManagerFactory(tk *token.Token, err error) func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { - return func(provider shared.IdentityProvider, options manager.TokenManagerOptions) (manager.TokenManager, error) { - return &mockTokenManager{ - token: tk, - err: err, - }, nil - } -} - func TestNewManagedIdentityCredentialsProvider(t *testing.T) { tests := []struct { name string @@ -173,7 +76,7 @@ func TestNewManagedIdentityCredentialsProvider(t *testing.T) { ) // Set the token manager factory in the options - tt.options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + tt.options.tokenManagerFactory = testFakeTokenManagerFactory(testToken, nil) provider, err := NewManagedIdentityCredentialsProvider(tt.options) if tt.expectedError != nil { @@ -258,7 +161,7 @@ func TestNewConfidentialCredentialsProvider(t *testing.T) { ) // Set the token manager factory in the options - tt.options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + tt.options.tokenManagerFactory = testFakeTokenManagerFactory(testToken, nil) provider, err := NewConfidentialCredentialsProvider(tt.options) if tt.expectedError != nil { @@ -331,7 +234,7 @@ func TestNewDefaultAzureCredentialsProvider(t *testing.T) { ) // Set the token manager factory in the options - tt.options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + tt.options.tokenManagerFactory = testFakeTokenManagerFactory(testToken, nil) provider, err := NewDefaultAzureCredentialsProvider(tt.options) if tt.expectedError != nil { @@ -357,86 +260,6 @@ func TestNewDefaultAzureCredentialsProvider(t *testing.T) { } } -func TestCredentialsProviderErrorHandling(t *testing.T) { - t.Run("on re-authentication error", func(t *testing.T) { - options := ConfidentialCredentialsProviderOptions{ - CredentialsProviderOptions: CredentialsProviderOptions{ - ClientID: "test-client-id", - TokenManagerOptions: manager.TokenManagerOptions{ - ExpirationRefreshRatio: 0.7, - }, - }, - ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ - ClientID: "test-client-id", - CredentialsType: identity.ClientSecretCredentialType, - ClientSecret: "test-secret", - Scopes: []string{identity.RedisScopeDefault}, - Authority: identity.AuthorityConfiguration{}, - }, - } - - // Create a test token - testToken := token.New( - "test", - "test", - rawTokenString, - time.Now().Add(time.Hour), - time.Now(), - int64(time.Hour), - ) - - // Set the token manager factory in the options - options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) - - provider, err := NewConfidentialCredentialsProvider(options) - require.NoError(t, err) - require.NotNil(t, provider) - - // Test that the error handler is properly set - // Note: This is a simplified test as actual authentication would require Azure credentials - assert.NotNil(t, provider) - }) - - t.Run("on retryable error", func(t *testing.T) { - options := ConfidentialCredentialsProviderOptions{ - CredentialsProviderOptions: CredentialsProviderOptions{ - ClientID: "test-client-id", - TokenManagerOptions: manager.TokenManagerOptions{ - ExpirationRefreshRatio: 0.7, - }, - }, - ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ - ClientID: "test-client-id", - CredentialsType: identity.ClientSecretCredentialType, - ClientSecret: "test-secret", - Scopes: []string{identity.RedisScopeDefault}, - Authority: identity.AuthorityConfiguration{}, - }, - } - - // Create a test token - testToken := token.New( - "test", - "test", - rawTokenString, - time.Now().Add(time.Hour), - time.Now(), - int64(time.Hour), - ) - - // Set the token manager factory in the options - options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) - - provider, err := NewConfidentialCredentialsProvider(options) - require.NoError(t, err) - require.NotNil(t, provider) - - // Test that the error handler is properly set - // Note: This is a simplified test as actual authentication would require Azure credentials - assert.NotNil(t, provider) - }) -} - func TestCredentialsProviderInterface(t *testing.T) { // Test that all providers implement the StreamingCredentialsProvider interface tests := []struct { @@ -471,7 +294,7 @@ func TestCredentialsProviderInterface(t *testing.T) { ) // Set the token manager factory in the options - options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + options.tokenManagerFactory = testFakeTokenManagerFactory(testToken, nil) p, _ := NewManagedIdentityCredentialsProvider(options) return p @@ -507,7 +330,7 @@ func TestCredentialsProviderInterface(t *testing.T) { ) // Set the token manager factory in the options - options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + options.tokenManagerFactory = testFakeTokenManagerFactory(testToken, nil) p, _ := NewConfidentialCredentialsProvider(options) return p @@ -539,7 +362,7 @@ func TestCredentialsProviderInterface(t *testing.T) { ) // Set the token manager factory in the options - options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + options.tokenManagerFactory = testFakeTokenManagerFactory(testToken, nil) p, _ := NewDefaultAzureCredentialsProvider(options) return p @@ -560,18 +383,127 @@ func TestCredentialsProviderInterface(t *testing.T) { } } -func TestCredentialsProviderSubscribe(t *testing.T) { +func TestNewManagedIdentityCredentialsProvider_TokenManagerFactoryError(t *testing.T) { + options := ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ + UserAssignedClientID: "test-client-id", + ManagedIdentityType: identity.UserAssignedIdentity, + Scopes: []string{identity.RedisScopeDefault}, + }, + } + + // Set the token manager factory to return an error + options.tokenManagerFactory = func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return nil, fmt.Errorf("token manager factory error") + } + + provider, err := NewManagedIdentityCredentialsProvider(options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token manager factory error") + assert.Nil(t, provider) +} + +func TestNewConfidentialCredentialsProvider_TokenManagerFactoryError(t *testing.T) { + options := ConfidentialCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ + ClientID: "test-client-id", + CredentialsType: identity.ClientSecretCredentialType, + ClientSecret: "test-secret", + Scopes: []string{identity.RedisScopeDefault}, + Authority: identity.AuthorityConfiguration{}, + }, + } + + // Set the token manager factory to return an error + options.tokenManagerFactory = func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return nil, fmt.Errorf("token manager factory error") + } + + provider, err := NewConfidentialCredentialsProvider(options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token manager factory error") + assert.Nil(t, provider) +} + +func TestNewDefaultAzureCredentialsProvider_TokenManagerFactoryError(t *testing.T) { + options := DefaultAzureCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + DefaultAzureIdentityProviderOptions: identity.DefaultAzureIdentityProviderOptions{ + Scopes: []string{identity.RedisScopeDefault}, + }, + } + + // Set the token manager factory to return an error + options.tokenManagerFactory = func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return nil, fmt.Errorf("token manager factory error") + } + + provider, err := NewDefaultAzureCredentialsProvider(options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token manager factory error") + assert.Nil(t, provider) +} + +func TestNewManagedIdentityCredentialsProvider_TokenManagerStartError(t *testing.T) { + options := ManagedIdentityCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + }, + }, + ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ + UserAssignedClientID: "test-client-id", + ManagedIdentityType: identity.UserAssignedIdentity, + Scopes: []string{identity.RedisScopeDefault}, + }, + } + // Create a test token testToken := token.New( "test", "test", rawTokenString, - time.Now().Add(tokenExpiration), + time.Now().Add(time.Hour), time.Now(), - int64(tokenExpiration), + int64(time.Hour), ) - // Create a test provider + // Create a mock token manager that returns an error on Start + mockTM := &fakeTokenManager{ + token: testToken, + err: fmt.Errorf("token manager start error"), + } + + // Set the token manager factory to return our mock + options.tokenManagerFactory = func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return mockTM, nil + } + + provider, err := NewManagedIdentityCredentialsProvider(options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token manager start error") + assert.Nil(t, provider) +} + +func TestNewConfidentialCredentialsProvider_TokenManagerStartError(t *testing.T) { options := ConfidentialCredentialsProviderOptions{ CredentialsProviderOptions: CredentialsProviderOptions{ ClientID: "test-client-id", @@ -588,374 +520,69 @@ func TestCredentialsProviderSubscribe(t *testing.T) { }, } - // Set the token manager factory in the options - options.tokenManagerFactory = testTokenManagerFactory(testToken, nil) + // Create a test token + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) - provider, err := NewConfidentialCredentialsProvider(options) - require.NoError(t, err) - require.NotNil(t, provider) - - t.Run("concurrent subscribe and cancel", func(t *testing.T) { - const numListeners = 10 - var wg sync.WaitGroup - listeners := make([]*mockCredentialsListener, numListeners) - cancels := make([]auth.CancelProviderFunc, numListeners) - - // Subscribe multiple listeners concurrently - for i := 0; i < numListeners; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - listener := &mockCredentialsListener{ - LastTokenCh: make(chan string, 1), - LastErrCh: make(chan error, 1), - } - listeners[idx] = listener - _, cancel, err := provider.Subscribe(listener) - require.NoError(t, err) - cancels[idx] = cancel - }(i) - } - wg.Wait() - - // Verify all listeners received the token - for i, listener := range listeners { - select { - case tk := <-listener.LastTokenCh: - assert.Equal(t, rawTokenString, tk, "listener %d received wrong token", i) - case err := <-listener.LastErrCh: - t.Fatalf("listener %d received error: %v", i, err) - case <-time.After(3 * tokenExpiration): - t.Fatalf("listener %d timed out waiting for token", i) - } - } - - // Cancel all subscriptions concurrently - for i := 0; i < numListeners; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - err := cancels[idx]() - require.NoError(t, err) - }(i) - } - wg.Wait() - - // Verify no more tokens are sent after cancellation - for i, listener := range listeners { - select { - case tk := <-listener.LastTokenCh: - t.Fatalf("listener %d received unexpected token after cancellation: %s", i, tk) - case err := <-listener.LastErrCh: - t.Fatalf("listener %d received unexpected error after cancellation: %v", i, err) - default: - // No message received, which is expected - } - } - }) -} + // Create a mock token manager that returns an error on Start + mockTM := &fakeTokenManager{ + token: testToken, + err: fmt.Errorf("token manager start error"), + } -func TestCredentialsProviderOptions(t *testing.T) { - t.Run("default token manager factory", func(t *testing.T) { - options := CredentialsProviderOptions{} - factory := options.getTokenManagerFactory() - assert.NotNil(t, factory) - }) - - t.Run("custom token manager factory", func(t *testing.T) { - m := &mockTokenManager{} - customFactory := func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { - return m, nil - } - options := CredentialsProviderOptions{ - tokenManagerFactory: customFactory, - } - tm, err := options.getTokenManagerFactory()(nil, manager.TokenManagerOptions{}) - assert.NotNil(t, tm) - assert.NoError(t, err) - assert.Equal(t, m, tm) - }) + // Set the token manager factory to return our mock + options.tokenManagerFactory = func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return mockTM, nil + } + + provider, err := NewConfidentialCredentialsProvider(options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token manager start error") + assert.Nil(t, provider) } -func TestCredentialsProviderErrorScenarios(t *testing.T) { - t.Run("token manager start error", func(t *testing.T) { - // Create a test provider with invalid options - options := ConfidentialCredentialsProviderOptions{ - CredentialsProviderOptions: CredentialsProviderOptions{ - ClientID: "test-client-id", - TokenManagerOptions: manager.TokenManagerOptions{ - ExpirationRefreshRatio: 0.7, - }, - }, - ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ - ClientID: "test-client-id", - CredentialsType: "invalid-type", // Invalid credentials type - ClientSecret: "test-secret", - Scopes: []string{identity.RedisScopeDefault}, - Authority: identity.AuthorityConfiguration{}, - }, - } - - provider, err := NewConfidentialCredentialsProvider(options) - assert.Error(t, err) - assert.Nil(t, provider) - }) - - t.Run("token manager get token error", func(t *testing.T) { - // Create a test provider with invalid options - options := ConfidentialCredentialsProviderOptions{ - CredentialsProviderOptions: CredentialsProviderOptions{ - ClientID: "test-client-id", - TokenManagerOptions: manager.TokenManagerOptions{ - ExpirationRefreshRatio: 0.7, - }, - }, - ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{ - ClientID: "test-client-id", - CredentialsType: identity.ClientSecretCredentialType, - ClientSecret: "", // Empty client secret - Scopes: []string{identity.RedisScopeDefault}, - Authority: identity.AuthorityConfiguration{}, - }, - } - - provider, err := NewConfidentialCredentialsProvider(options) - assert.Error(t, err) - assert.Nil(t, provider) - }) - - t.Run("concurrent error handling", func(t *testing.T) { - // Create a test provider with invalid options - options := ManagedIdentityCredentialsProviderOptions{ - CredentialsProviderOptions: CredentialsProviderOptions{ - ClientID: "test-client-id", - TokenManagerOptions: manager.TokenManagerOptions{ - ExpirationRefreshRatio: 0.7, - }, - }, - ManagedIdentityProviderOptions: identity.ManagedIdentityProviderOptions{ - ManagedIdentityType: "invalid-type", // Invalid managed identity type - Scopes: []string{identity.RedisScopeDefault}, - }, - } - - provider, err := NewManagedIdentityCredentialsProvider(options) - assert.Error(t, err) - assert.Nil(t, provider) - }) - - t.Run("concurrent token updates", func(t *testing.T) { - // Create a test provider with invalid options - options := DefaultAzureCredentialsProviderOptions{ - CredentialsProviderOptions: CredentialsProviderOptions{ - ClientID: "test-client-id", - TokenManagerOptions: manager.TokenManagerOptions{ - ExpirationRefreshRatio: 0.7, - }, - }, - DefaultAzureIdentityProviderOptions: identity.DefaultAzureIdentityProviderOptions{ - Scopes: []string{}, // Empty scopes +func TestNewDefaultAzureCredentialsProvider_TokenManagerStartError(t *testing.T) { + options := DefaultAzureCredentialsProviderOptions{ + CredentialsProviderOptions: CredentialsProviderOptions{ + ClientID: "test-client-id", + TokenManagerOptions: manager.TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, }, - } + }, + DefaultAzureIdentityProviderOptions: identity.DefaultAzureIdentityProviderOptions{ + Scopes: []string{identity.RedisScopeDefault}, + }, + } - provider, err := NewDefaultAzureCredentialsProvider(options) - assert.Error(t, err) - assert.Nil(t, provider) - }) -} + // Create a test token + testToken := token.New( + "test", + "test", + rawTokenString, + time.Now().Add(time.Hour), + time.Now(), + int64(time.Hour), + ) -func TestCredentialsProviderWithMockIdentityProvider(t *testing.T) { - t.Parallel() + // Create a mock token manager that returns an error on Start + mockTM := &fakeTokenManager{ + token: testToken, + err: fmt.Errorf("token manager start error"), + } - t.Run("Subscribe and Unsubscribe", func(t *testing.T) { - t.Parallel() + // Set the token manager factory to return our mock + options.tokenManagerFactory = func(shared.IdentityProvider, manager.TokenManagerOptions) (manager.TokenManager, error) { + return mockTM, nil + } - // Create mock token manager - tm := &mockTokenManager{ - token: token.New( - "test", - "test", - "test-token", - time.Now().Add(time.Hour), - time.Now(), - int64(time.Hour), - ), - } - - // Create credentials provider - cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) - assert.NoError(t, err) - assert.NotNil(t, cp) - - // Create mock listener - listener := &mockCredentialsListener{ - LastTokenCh: make(chan string), - LastErrCh: make(chan error), - } - - // Subscribe listener - credentials, cancel, err := cp.Subscribe(listener) - assert.NoError(t, err) - assert.NotNil(t, credentials) - assert.NotNil(t, cancel) - - // Wait for initial token - tk, err := listener.readWithTimeout(time.Second) - assert.NoError(t, err) - assert.Equal(t, "test-token", tk) - - // Unsubscribe - err = cancel() - assert.NoError(t, err) - }) - - t.Run("Multiple Listeners", func(t *testing.T) { - t.Parallel() - - // Create mock token manager - tm := &mockTokenManager{ - token: token.New( - "test", - "test", - "test-token", - time.Now().Add(time.Hour), - time.Now(), - int64(time.Hour), - ), - } - - // Create credentials provider - cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) - assert.NoError(t, err) - assert.NotNil(t, cp) - - // Create multiple mock listeners - listener1 := &mockCredentialsListener{ - LastTokenCh: make(chan string), - LastErrCh: make(chan error), - } - listener2 := &mockCredentialsListener{ - LastTokenCh: make(chan string), - LastErrCh: make(chan error), - } - - // Subscribe first listener - credentials1, cancel1, err := cp.Subscribe(listener1) - assert.NoError(t, err) - assert.NotNil(t, credentials1) - assert.NotNil(t, cancel1) - - // Subscribe second listener - credentials2, cancel2, err := cp.Subscribe(listener2) - assert.NoError(t, err) - assert.NotNil(t, credentials2) - assert.NotNil(t, cancel2) - - // Wait for initial tokens - token1, err := listener1.readWithTimeout(time.Second) - assert.NoError(t, err) - assert.Equal(t, "test-token", token1) - - token2, err := listener2.readWithTimeout(time.Second) - assert.NoError(t, err) - assert.Equal(t, "test-token", token2) - - // Unsubscribe first listener - err = cancel1() - assert.NoError(t, err) - - // Unsubscribe second listener - err = cancel2() - assert.NoError(t, err) - }) - - t.Run("Token Updates", func(t *testing.T) { - t.Parallel() - - // Create mock token manager - tm := &mockTokenManager{ - token: token.New( - "test", - "test", - "initial-token", - time.Now().Add(time.Hour), - time.Now(), - int64(time.Hour), - ), - } - - // Create credentials provider - cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) - assert.NoError(t, err) - assert.NotNil(t, cp) - - // Create mock listener - listener := &mockCredentialsListener{ - LastTokenCh: make(chan string), - LastErrCh: make(chan error), - } - - // Subscribe listener - credentials, cancel, err := cp.Subscribe(listener) - assert.NoError(t, err) - assert.NotNil(t, credentials) - assert.NotNil(t, cancel) - - // Wait for initial token - tk, err := listener.readWithTimeout(time.Second) - assert.NoError(t, err) - assert.Equal(t, "initial-token", tk) - - tm.lock.Lock() - // Update token - tm.token = token.New( - "test", - "test", - "updated-token", - time.Now().Add(time.Hour), - time.Now(), - int64(time.Hour), - ) - tm.lock.Unlock() - - // Wait for token update - tk, err = listener.readWithTimeout(time.Second) - assert.NoError(t, err) - assert.Equal(t, "updated-token", tk) - - // Unsubscribe - err = cancel() - assert.NoError(t, err) - }) - - t.Run("Error Handling", func(t *testing.T) { - t.Parallel() - - // Create mock token manager with error - tm := &mockTokenManager{ - err: assert.AnError, - } - - // Create credentials provider - cp, err := NewCredentialsProvider(tm, CredentialsProviderOptions{}) - assert.NoError(t, err) - assert.NotNil(t, cp) - - // Create mock listener - listener := &mockCredentialsListener{ - LastTokenCh: make(chan string), - LastErrCh: make(chan error), - } - - // Subscribe listener - credentials, cancel, err := cp.Subscribe(listener) - assert.Error(t, err) - assert.Nil(t, credentials) - assert.Nil(t, cancel) - - // Wait for error - _, err = listener.readWithTimeout(time.Second) - assert.Error(t, err) - assert.Equal(t, assert.AnError, err) - }) + provider, err := NewDefaultAzureCredentialsProvider(options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token manager start error") + assert.Nil(t, provider) } diff --git a/token/token_test.go b/token/token_test.go index 4324612..f94e892 100644 --- a/token/token_test.go +++ b/token/token_test.go @@ -47,13 +47,6 @@ func TestTokenExpiration(t *testing.T) { assert.False(t, token.ExpirationOn().After(time.Now())) } -func TestTokenReceivedAt(t *testing.T) { - t.Parallel() - token := New("username", "password", "rawToken", time.Now(), time.Now().Add(1*time.Hour), 3600) - assert.True(t, token.receivedAt.After(time.Now().Add(-1*time.Hour))) - assert.True(t, token.receivedAt.Before(time.Now().Add(1*time.Hour))) -} - func TestTokenTTL(t *testing.T) { t.Parallel() token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) @@ -112,6 +105,27 @@ func TestTokenCompare(t *testing.T) { assert.True(t, token1.compareCredentials(token4)) } +func TestTokenReceivedAt(t *testing.T) { + t.Parallel() + // Create a token with a specific receivedAt time + receivedAt := time.Now() + token := New("username", "password", "rawToken", time.Now(), receivedAt, 3600) + + assert.True(t, token.receivedAt.After(time.Now().Add(-1*time.Hour))) + assert.True(t, token.receivedAt.Before(time.Now().Add(1*time.Hour))) + + // Check if the receivedAt time is set correctly + assert.Equal(t, receivedAt, token.ReceivedAt()) + + tcopiedToken := token.Copy() + // Check if the copied token has the same receivedAt time + assert.Equal(t, receivedAt, tcopiedToken.ReceivedAt()) + // Check if the copied token is not the same instance as the original token + assert.NotSame(t, token, tcopiedToken) + // Check if the copied token is a new instance + assert.NotNil(t, tcopiedToken) +} + func BenchmarkNew(b *testing.B) { now := time.Now() b.ResetTimer() diff --git a/token_listener.go b/token_listener.go index e76bba7..19a0e31 100644 --- a/token_listener.go +++ b/token_listener.go @@ -5,20 +5,32 @@ import ( "github.com/redis-developer/go-redis-entraid/token" ) +// entraidTokenListener implements the TokenListener interface for the entraidCredentialsProvider. +// It listens for token updates and errors from the token manager and notifies the credentials provider. type entraidTokenListener struct { cp *entraidCredentialsProvider } +// tokenListenerFromCP creates a new entraidTokenListener from the given entraidCredentialsProvider. +// It is used to listen for token updates and errors from the token manager. +// This function is typically called when starting the token manager. +// It returns a pointer to the entraidTokenListener instance that is created from the credentials provider. func tokenListenerFromCP(cp *entraidCredentialsProvider) manager.TokenListener { return &entraidTokenListener{ cp, } } +// OnTokenNext is called when the token manager receives a new token. +// It notifies the credentials provider with the new token. +// This function is typically called when the token manager successfully retrieves a token. func (l *entraidTokenListener) OnTokenNext(t *token.Token) { l.cp.onTokenNext(t) } +// OnTokenError is called when the token manager encounters an error. +// It notifies the credentials provider with the error. +// This function is typically called when the token manager fails to retrieve a token. func (l *entraidTokenListener) OnTokenError(err error) { l.cp.onTokenError(err) } From 89b83e6ff2b9fabaf0fe05f55569f89f107f7fb8 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 15 Apr 2025 11:35:49 +0300 Subject: [PATCH 24/44] improve answer in readme --- README.md | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e9ad564..6e05482 100644 --- a/README.md +++ b/README.md @@ -634,10 +634,37 @@ func TestRedisConnection(t *testing.T) { A: The library handles token expiration automatically. Tokens are refreshed when they reach 70% of their lifetime (configurable via `ExpirationRefreshRatio`). You can customize this behavior using `TokenManagerOptions`. ### Q: What's the difference between managed identity types? -A: -- System Assigned: Automatically created and managed by Azure -- User Assigned: Created and managed by you, can be shared across resources -- Default Azure: Uses environment-based authentication, good for development +A: There are three main types of managed identities in Azure: + +1. **System Assigned Managed Identity**: + - Automatically created and managed by Azure + - Tied directly to a specific Azure resource (VM, App Service, etc.) + - Cannot be shared between resources + - Automatically deleted when the resource is deleted + - Best for single-resource applications with dedicated identity + +2. **User Assigned Managed Identity**: + - Created and managed independently of resources + - Can be assigned to multiple Azure resources + - Has its own lifecycle independent of resources + - Can be shared across multiple resources + - Best for applications that need a shared identity or run across multiple resources + +3. **Default Azure Identity**: + - Uses environment-based authentication + - Automatically tries multiple authentication methods in sequence: + 1. Environment variables + 2. Managed Identity + 3. Visual Studio Code + 4. Azure CLI + 5. Visual Studio + - Best for development and testing environments + - Provides flexibility during development without changing code + +The choice between these types depends on your specific use case: +- Use System Assigned for single-resource applications +- Use User Assigned for shared identity scenarios +- Use Default Azure Identity for development and testing ### Q: How do I handle connection failures? A: The library includes built-in retry mechanisms in the TokenManager. You can configure retry behavior using `RetryOptions`: From 9d087bc29c12844471ca0a05fe04dea5684ad985 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 15 Apr 2025 14:13:32 +0300 Subject: [PATCH 25/44] export shared types --- entraid.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 entraid.go diff --git a/entraid.go b/entraid.go new file mode 100644 index 0000000..06896c5 --- /dev/null +++ b/entraid.go @@ -0,0 +1,12 @@ +package entraid + +import "github.com/redis-developer/go-redis-entraid/shared" + +// IdentityProvider is an alias for the shared.IdentityProvider interface. +type IdentityProvider = shared.IdentityProvider + +// IdentityProviderResponse is an alias for the shared.IdentityProviderResponse interface. +type IdentityProviderResponse = shared.IdentityProviderResponse + +// IdentityProviderResponseParser is an alias for the shared.IdentityProviderResponseParser interface. +type IdentityProviderResponseParser = shared.IdentityProviderResponseParser From 217097d87879b974e347381784a76499d1318cc5 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 15 Apr 2025 15:49:40 +0300 Subject: [PATCH 26/44] Address PR comments related to token tests Removed unused functionality. Introduced getter for TTL. Improved tests to use defined time.Time. --- token/token.go | 20 ++------ token/token_test.go | 116 +++++++++++++++++--------------------------- 2 files changed, 49 insertions(+), 87 deletions(-) diff --git a/token/token.go b/token/token.go index 66bd100..fafce60 100644 --- a/token/token.go +++ b/token/token.go @@ -64,26 +64,16 @@ func (t *Token) ExpirationOn() time.Time { return t.expiresOn } +// TTL returns the time to live of the token. +func (t *Token) TTL() int64 { + return t.ttl +} + // Copy creates a copy of the token. func (t *Token) Copy() *Token { return copyToken(t) } -// compareCredentials two tokens if they are the same credentials -func (t *Token) compareCredentials(token *Token) bool { - return t.username == token.username && t.password == token.password -} - -// compareRawCredentials two tokens if they are the same raw credentials -func (t *Token) compareRawCredentials(token *Token) bool { - return t.rawToken == token.rawToken -} - -// compareToken compares two tokens if they are the same token -func (t *Token) compareToken(token *Token) bool { - return t.compareCredentials(token) && t.compareRawCredentials(token) -} - // copyToken creates a copy of the token. func copyToken(token *Token) *Token { if token == nil { diff --git a/token/token_test.go b/token/token_test.go index f94e892..72a54ad 100644 --- a/token/token_test.go +++ b/token/token_test.go @@ -1,6 +1,7 @@ package token import ( + "fmt" "testing" "time" @@ -9,51 +10,71 @@ import ( func TestNew(t *testing.T) { t.Parallel() - token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + expiration := time.Now().Add(1 * time.Hour) + receivedAt := time.Now() + ttl := expiration.Unix() - receivedAt.Unix() + token := New("username", "password", "rawToken", expiration, receivedAt, ttl) assert.Equal(t, "username", token.username) assert.Equal(t, "password", token.password) assert.Equal(t, "rawToken", token.rawToken) - assert.Equal(t, int64(3600), token.ttl) + assert.Equal(t, expiration, token.expiresOn) + assert.Equal(t, receivedAt, token.receivedAt) + assert.Equal(t, ttl, token.ttl) } func TestBasicAuth(t *testing.T) { t.Parallel() - token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) - username, password := token.BasicAuth() - assert.Equal(t, "username", username) - assert.Equal(t, "password", password) + username := "username12" + password := "password32" + rawToken := fmt.Sprintf("%s:%s", username, password) + expiration := time.Now().Add(1 * time.Hour) + receivedAt := time.Now() + ttl := expiration.Unix() - receivedAt.Unix() + token := New(username, password, rawToken, expiration, receivedAt, ttl) + baUsername, baPassword := token.BasicAuth() + assert.Equal(t, username, baUsername) + assert.Equal(t, password, baPassword) } func TestRawCredentials(t *testing.T) { t.Parallel() - token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) + username := "username12" + password := "password32" + rawToken := fmt.Sprintf("%s:%s", username, password) + expiration := time.Now().Add(1 * time.Hour) + receivedAt := time.Now() + ttl := expiration.Unix() - receivedAt.Unix() + token := New(username, password, rawToken, expiration, receivedAt, ttl) rawCredentials := token.RawCredentials() - assert.Equal(t, "rawToken", rawCredentials) + assert.Equal(t, rawToken, rawCredentials) + assert.Contains(t, rawCredentials, username) + assert.Contains(t, rawCredentials, password) } func TestExpirationOn(t *testing.T) { t.Parallel() - token := New("username", "password", "rawToken", time.Now().Add(1*time.Hour), time.Now(), 3600) + username := "username12" + password := "password32" + rawToken := fmt.Sprintf("%s:%s", username, password) + expiration := time.Now().Add(1 * time.Hour) + receivedAt := time.Now() + ttl := expiration.Unix() - receivedAt.Unix() + token := New(username, password, rawToken, expiration, receivedAt, ttl) expirationOn := token.ExpirationOn() assert.True(t, expirationOn.After(time.Now())) -} - -func TestTokenExpiration(t *testing.T) { - t.Parallel() - token := New("username", "password", "rawToken", time.Now().Add(1*time.Hour), time.Now(), 3600) - assert.True(t, token.ExpirationOn().After(time.Now())) - - token.expiresOn = time.Now().Add(-1 * time.Hour) - assert.False(t, token.ExpirationOn().After(time.Now())) + assert.Equal(t, expiration, expirationOn) } func TestTokenTTL(t *testing.T) { t.Parallel() - token := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) - assert.Equal(t, int64(3600), token.ttl) - - token.ttl = 7200 - assert.Equal(t, int64(7200), token.ttl) + username := "username12" + password := "password32" + rawToken := fmt.Sprintf("%s:%s", username, password) + expiration := time.Now().Add(1 * time.Hour) + receivedAt := time.Now() + ttl := expiration.Unix() - receivedAt.Unix() + token := New(username, password, rawToken, expiration, receivedAt, ttl) + assert.Equal(t, ttl, token.TTL()) } func TestCopyToken(t *testing.T) { @@ -83,28 +104,6 @@ func TestCopyToken(t *testing.T) { assert.NotEqual(t, copiedToken, anotherCopy) } -func TestTokenCompare(t *testing.T) { - t.Parallel() - // Create two tokens with the same credentials - token1 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) - token2 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) - assert.True(t, token1.compareCredentials(token2)) - assert.True(t, token1.compareRawCredentials(token2)) - assert.True(t, token1.compareToken(token2)) - - // Create two tokens with different credentials and different raw credentials - token3 := New("username", "differentPassword", "differentRawToken", time.Now(), time.Now(), 3600) - assert.False(t, token1.compareCredentials(token3)) - assert.False(t, token1.compareRawCredentials(token3)) - assert.False(t, token1.compareToken(token3)) - - // Create token with same credentials but different rawCredentials - token4 := New("username", "password", "differentRawToken", time.Now(), time.Now(), 3600) - assert.False(t, token1.compareRawCredentials(token4)) - assert.False(t, token1.compareToken(token4)) - assert.True(t, token1.compareCredentials(token4)) -} - func TestTokenReceivedAt(t *testing.T) { t.Parallel() // Create a token with a specific receivedAt time @@ -165,30 +164,3 @@ func BenchmarkCopyToken(b *testing.B) { token.Copy() } } - -func BenchmarkCompareCredentials(b *testing.B) { - token1 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) - token2 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) - b.ResetTimer() - for i := 0; i < b.N; i++ { - token1.compareCredentials(token2) - } -} - -func BenchmarkCompareRawCredentials(b *testing.B) { - token1 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) - token2 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) - b.ResetTimer() - for i := 0; i < b.N; i++ { - token1.compareRawCredentials(token2) - } -} - -func BenchmarkCompareToken(b *testing.B) { - token1 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) - token2 := New("username", "password", "rawToken", time.Now(), time.Now(), 3600) - b.ResetTimer() - for i := 0; i < b.N; i++ { - token1.compareToken(token2) - } -} From 87f7d84c85af8d6a01199c6a4a4677be6b0b05e4 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 16 Apr 2025 11:11:16 +0300 Subject: [PATCH 27/44] add deadlock test Kindly contributed by @bobymicroby --- .golangci.yml | 2 + .testcoverage.yml | 4 +- manager/defaults.go | 2 +- manager/token_manager_test.go | 301 ++++++++++++++++++++++++++++++++++ 4 files changed, 306 insertions(+), 3 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index ac9f76c..cd7f519 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,4 +1,6 @@ version: "2" +run: + tests: false linters: disable: - depguard diff --git a/.testcoverage.yml b/.testcoverage.yml index 91a1592..3c88348 100644 --- a/.testcoverage.yml +++ b/.testcoverage.yml @@ -15,11 +15,11 @@ threshold: # (optional; default 0) # Minimum coverage percentage required for each package. - package: 90 + package: 85 # (optional; default 0) # Minimum overall project coverage percentage required. - total: 85 + total: 90 # Holds regexp rules which will override thresholds for matched files or packages # using their paths. diff --git a/manager/defaults.go b/manager/defaults.go index 56587d0..e7e09f1 100644 --- a/manager/defaults.go +++ b/manager/defaults.go @@ -74,7 +74,7 @@ func defaultRetryOptionsOr(retryOptions RetryOptions) RetryOptions { // The default token parser is used to parse the raw token and return a Token object. func defaultIdentityProviderResponseParserOr(idpResponseParser shared.IdentityProviderResponseParser) shared.IdentityProviderResponseParser { if idpResponseParser == nil { - return &defaultIdentityProviderResponseParser{} + return entraidIdentityProviderResponseParser } return idpResponseParser } diff --git a/manager/token_manager_test.go b/manager/token_manager_test.go index 276a958..08c57bf 100644 --- a/manager/token_manager_test.go +++ b/manager/token_manager_test.go @@ -7,6 +7,7 @@ import ( "os" "reflect" "runtime" + "strings" "sync" "sync/atomic" "testing" @@ -1373,3 +1374,303 @@ func BenchmarkTokenManager_durationToRenewal(b *testing.B) { tm.durationToRenewal() } } + +// TestConcurrentTokenManagerOperations tests concurrent operations on the TokenManager +// to verify there are no deadlocks or race conditions in the implementation. +func TestConcurrentTokenManagerOperations(t *testing.T) { + t.Parallel() + + // Create a mock identity provider that returns predictable tokens + mockIdp := &concurrentMockIdentityProvider{ + tokenCounter: 0, + } + + // Create token manager with the mock provider + options := TokenManagerOptions{ + ExpirationRefreshRatio: 0.7, + LowerRefreshBoundMs: 100, + } + tm, err := NewTokenManager(mockIdp, options) + assert.NoError(t, err) + assert.NotNil(t, tm) + + // Number of concurrent operations to perform + const numConcurrentOps = 50 + const numGoroutines = 1000 + + // Channels to track received tokens and errors + tokenCh := make(chan *token.Token, numConcurrentOps*numGoroutines) + errorCh := make(chan error, numConcurrentOps*numGoroutines) + + // Channel to signal completion of all operations + doneCh := make(chan struct{}) + + // Track closers for cleanup + var closers sync.Map + + // Start multiple goroutines that will concurrently interact with the token manager + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(routineID int) { + defer wg.Done() + + for j := 0; j < numConcurrentOps; j++ { + // Create a listener for this operation + listener := &concurrentTestTokenListener{ + onNextFunc: func(t *token.Token) { + select { + case tokenCh <- t: + default: + // Channel full, ignore + } + }, + onErrorFunc: func(err error) { + select { + case errorCh <- err: + default: + // Channel full, ignore + } + }, + } + + // Choose operation based on a pattern + // Using modulo for a deterministic pattern that exercises all operations + opType := j % 3 + + // t.Logf("Goroutine %d, Operation %d: Performing operation type %d", routineID, j, opType) + + switch opType { + case 0: + // Start the token manager with a new listener + // t.Logf("Goroutine %d, Operation %d: Attempting to start token manager", routineID, j) + closeFunc, err := tm.Start(listener) + + if err != nil { + if err != ErrTokenManagerAlreadyStarted { + // t.Logf("Goroutine %d, Operation %d: Start failed with error: %v", routineID, j, err) + select { + case errorCh <- fmt.Errorf("failed to start token manager: %w", err): + default: + t.Fatalf("Goroutine %d, Operation %d: Failed to start token manager: %v", routineID, j, err) + } + } + continue + } + + // t.Logf("Goroutine %d, Operation %d: Successfully started token manager", routineID, j) + // Store the closer for later cleanup + closerKey := fmt.Sprintf("closer-%d-%d", routineID, j) + closers.Store(closerKey, closeFunc) + + // Simulate some work + time.Sleep(time.Duration(500-rand.Intn(400)) * time.Millisecond) + + case 1: + // Get current token + //t.Logf("Goroutine %d, Operation %d: Getting token", routineID, j) + token, err := tm.GetToken(false) + if err != nil { + //t.Logf("Goroutine %d, Operation %d: GetToken failed with error: %v", routineID, j, err) + select { + case errorCh <- fmt.Errorf("failed to get token: %w", err): + default: + t.Fatalf("Goroutine %d, Operation %d: Failed to get token: %v", routineID, j, err) + } + } else if token != nil { + //t.Logf("Goroutine %d, Operation %d: Successfully got token, expires: %v", routineID, j, token.ExpirationOn()) + select { + case tokenCh <- token: + default: + // Channel full, ignore + } + } + + case 2: + // Close a previously created token manager listener + // This simulates multiple subscriptions being created and destroyed + //t.Logf("Goroutine %d, Operation %d: Attempting to close a token manager", routineID, j) + closedAny := false + + closers.Range(func(key, value interface{}) bool { + if j%10 > 7 { // Only close some of the time based on a pattern + closedAny = true + //t.Logf("Goroutine %d, Operation %d: Closing token manager with key %v", routineID, j, key) + + closeFunc := value.(CloseFunc) + if err := closeFunc(); err != nil { + if err != ErrTokenManagerAlreadyClosed { + // t.Logf("Goroutine %d, Operation %d: Close failed with error: %v", routineID, j, err) + select { + case errorCh <- fmt.Errorf("failed to close token manager: %w", err): + default: + t.Fatalf("Goroutine %d, Operation %d: Failed to close token manager: %v", routineID, j, err) + } + } else { + //t.Logf("Goroutine %d, Operation %d: TokenManager was already closed", routineID, j) + } + } else { + // t.Logf("Goroutine %d, Operation %d: Successfully closed token manager", routineID, j) + } + + closers.Delete(key) + return false // stop after finding one to close + } + return true + }) + + if !closedAny { + //t.Logf("Goroutine %d, Operation %d: No token manager to close or condition not met", routineID, j) + } + } + } + }(i) + } + + // Wait for all operations to complete or timeout + go func() { + wg.Wait() + close(doneCh) + }() + + // Use a timeout to detect deadlocks + select { + case <-doneCh: + // All operations completed successfully + t.Log("All concurrent operations completed successfully") + case <-time.After(30 * time.Second): + t.Fatal("test timed out, possible deadlock detected") + } + + // Count operations by type + var startCount, getTokenCount, closeCount int32 + + // Collect all ops from goroutines + for i := 0; i < numGoroutines; i++ { + for j := 0; j < numConcurrentOps; j++ { + opType := j % 3 + switch opType { + case 0: + atomic.AddInt32(&startCount, 1) + case 1: + atomic.AddInt32(&getTokenCount, 1) + case 2: + atomic.AddInt32(&closeCount, 1) + } + } + } + + // Clean up any remaining closers + closers.Range(func(key, value interface{}) bool { + closeFunc := value.(CloseFunc) + _ = closeFunc() // Ignore errors during cleanup + return true + }) + + // Close channels to avoid goroutine leaks + close(tokenCh) + close(errorCh) + + // Count tokens and check their validity + var tokens []*token.Token + for t := range tokenCh { + tokens = append(tokens, t) + } + + // Collect and categorize errors + var startErrors, getTokenErrors, closeErrors, otherErrors []error + for err := range errorCh { + errStr := err.Error() + if strings.Contains(errStr, "failed to start token manager") { + startErrors = append(startErrors, err) + } else if strings.Contains(errStr, "failed to get token") { + getTokenErrors = append(getTokenErrors, err) + } else if strings.Contains(errStr, "failed to close token manager") { + closeErrors = append(closeErrors, err) + } else { + otherErrors = append(otherErrors, err) + t.Fatalf("Unexpected error during concurrent operations: %v", err) + } + } + + totalOps := startCount + getTokenCount + closeCount + expectedOps := int32(numGoroutines * numConcurrentOps) + + // Report operation counts + t.Logf("Concurrent test summary:") + t.Logf("- Total operations executed: %d (expected: %d)", totalOps, expectedOps) + t.Logf("- Start operations: %d (with %d errors)", startCount, len(startErrors)) + t.Logf("- GetToken operations: %d (with %d errors, %d successful)", + getTokenCount, len(getTokenErrors), len(tokens)) + t.Logf("- Close operations: %d (with %d errors)", closeCount, len(closeErrors)) + + // Some errors are expected due to concurrent operations + // but we should have received tokens successfully + assert.Equal(t, expectedOps, totalOps, "All operations should be accounted for") + assert.True(t, len(tokens) > 0, "Should have received tokens") + + // Verify the token manager still works after all the concurrent operations + finalListener := &concurrentTestTokenListener{ + onNextFunc: func(t *token.Token) { + // Just verify we get a token - don't use assert within this callback + if t == nil { + panic("Final token should not be nil") + } + }, + onErrorFunc: func(err error) { + t.Errorf("Unexpected error in final listener: %v", err) + }, + } + + closeFunc, err := tm.Start(finalListener) + if err != nil && err != ErrTokenManagerAlreadyStarted { + t.Fatalf("Failed to start token manager after concurrent operations: %v", err) + } + if closeFunc != nil { + defer closeFunc() + } + + // Get token one more time to verify everything still works + finalToken, err := tm.GetToken(true) + assert.NoError(t, err, "Should be able to get token after concurrent operations") + assert.NotNil(t, finalToken, "Final token should not be nil") +} + +// concurrentTestTokenListener is a test implementation of TokenListener for concurrent tests +type concurrentTestTokenListener struct { + onNextFunc func(*token.Token) + onErrorFunc func(error) +} + +func (l *concurrentTestTokenListener) OnTokenNext(t *token.Token) { + if l.onNextFunc != nil { + l.onNextFunc(t) + } +} + +func (l *concurrentTestTokenListener) OnTokenError(err error) { + if l.onErrorFunc != nil { + l.onErrorFunc(err) + } +} + +// concurrentMockIdentityProvider is a mock implementation of shared.IdentityProvider for concurrent tests +type concurrentMockIdentityProvider struct { + tokenCounter int + mutex sync.Mutex +} + +func (m *concurrentMockIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.tokenCounter++ + + // Use the existing test JWT token which is already properly formatted + resp, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, testJWTToken) + if err != nil { + return nil, err + } + return resp, nil +} From 2a286c63498d360334dfdb8d603399beaddfd364 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 16 Apr 2025 13:58:15 +0300 Subject: [PATCH 28/44] Address PR comments, refactor - Change type names to make more sense (e.g. Start / Stop ) - Add context.Context to the IDP RequestToken - Add RequestTimeout to TokenManagerOptions which will be utilized by the context - Change the LowerRefreshBoundMs from int64 to time.Duration and use better name (dropping the Ms suffix) TODO: - Address changes in the documentation --- .github/workflows/build.yml | 2 +- README.md | 8 +- credentials_provider.go | 24 +-- credentials_provider_test.go | 10 +- entraid_test.go | 20 +-- go.mod | 4 +- go.sum | 6 + identity/azure_default_identity_provider.go | 4 +- .../azure_default_identity_provider_test.go | 11 +- identity/confidential_identity_provider.go | 4 +- .../confidential_identity_provider_test.go | 7 +- identity/managed_identity_provider.go | 4 +- identity/managed_identity_provider_test.go | 4 +- manager/errors.go | 4 +- manager/manager_test.go | 9 +- manager/token_manager.go | 91 ++++++---- manager/token_manager_test.go | 163 +++++++++--------- shared/identity_provider_response.go | 5 +- token_listener.go | 4 +- token_listener_test.go | 4 +- 20 files changed, 211 insertions(+), 177 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 149981e..e167232 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,7 +19,7 @@ jobs: - name: Install dependencies run: go mod tidy - name: Run tests with coverage - run: go test ./... -coverprofile=./cover.out -covermode=atomic -race -count 2 -timeout 1m + run: go test ./... -coverprofile=./cover.out -covermode=atomic -race -count 2 -timeout 5m - name: Upload coverage uses: actions/upload-artifact@v4 with: diff --git a/README.md b/README.md index 6e05482..9c9ceb5 100644 --- a/README.md +++ b/README.md @@ -231,7 +231,7 @@ type TokenManagerOptions struct { // Optional: Minimum time before expiration to refresh (ms) // Default: 10000 (10 seconds) - LowerRefreshBoundMs int64 + LowerRefreshBounds int64 // Optional: Configuration for retry behavior RetryOptions RetryOptions @@ -350,7 +350,7 @@ options := entraid.CredentialsProviderOptions{ ClientID: os.Getenv("AZURE_CLIENT_ID"), TokenManagerOptions: manager.TokenManagerOptions{ ExpirationRefreshRatio: 0.7, - LowerRefreshBoundMs: 10000, + LowerRefreshBounds: 10000, }, } ``` @@ -361,7 +361,7 @@ options := entraid.CredentialsProviderOptions{ ClientID: os.Getenv("AZURE_CLIENT_ID"), TokenManagerOptions: manager.TokenManagerOptions{ ExpirationRefreshRatio: 0.7, - LowerRefreshBoundMs: 10000, + LowerRefreshBounds: 10000, RetryOptions: manager.RetryOptions{ MaxAttempts: 3, InitialDelayMs: 1000, @@ -516,7 +516,7 @@ func main() { tokenManager, err := manager.NewTokenManager(customProvider, manager.TokenManagerOptions{ // Configure token refresh behavior ExpirationRefreshRatio: 0.7, - LowerRefreshBoundMs: 10000, + LowerRefreshBounds: 10000, }) if err != nil { log.Fatalf("Failed to create token manager: %v", err) diff --git a/credentials_provider.go b/credentials_provider.go index 50ed6b7..30ff8e5 100644 --- a/credentials_provider.go +++ b/credentials_provider.go @@ -19,8 +19,8 @@ var _ auth.StreamingCredentialsProvider = (*entraidCredentialsProvider)(nil) type entraidCredentialsProvider struct { options CredentialsProviderOptions // Configuration options for the provider. - tokenManager manager.TokenManager // Manages token retrieval. - closeTokenManager manager.CloseFunc // Function to cancel the token manager. + tokenManager manager.TokenManager // Manages token retrieval. + stopTokenManager manager.StopFunc // Function to stop the token manager. // listeners is a slice of listeners that are notified when the token manager receives a new token. listeners []auth.CredentialsListener // Slice of listeners notified on token updates. @@ -64,7 +64,7 @@ func (e *entraidCredentialsProvider) onTokenError(err error) { // - error: An error if the subscription fails, such as if the token cannot be retrieved. // // Note: If the listener is already subscribed, it will not receive duplicate notifications. -func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.CancelProviderFunc, error) { +func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) { // First try to get a token, only then subscribe the listener. token, err := e.tokenManager.GetToken(false) if err != nil { @@ -87,7 +87,7 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener } e.rwLock.Unlock() - cancel := func() error { + unsub := func() error { // Remove the listener from the list of listeners. e.rwLock.Lock() defer e.rwLock.Unlock() @@ -102,20 +102,20 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener // Clear the listeners slice if it's empty if len(e.listeners) == 0 { e.listeners = make([]auth.CredentialsListener, 0) - if e.closeTokenManager != nil { - err := e.closeTokenManager() + if e.stopTokenManager != nil { + err := e.stopTokenManager() if err != nil { return fmt.Errorf("couldn't cancel token manager: %w", err) } - // Set the cancelTokenManager to nil to indicate that it has been canceled. - // This prevents multiple calls to cancelTokenManager. - e.closeTokenManager = nil + // Set the stopTokenManager to nil to indicate that it has been stopped. + // This prevents multiple calls to stopTokenManager. + e.stopTokenManager = nil } } return nil } - return token, cancel, nil + return token, unsub, nil } // NewCredentialsProvider creates a new credentials provider with the specified token manager and options. @@ -134,10 +134,10 @@ func NewCredentialsProvider(tokenManager manager.TokenManager, options Credentia options: options, listeners: make([]auth.CredentialsListener, 0), } - cancelTokenManager, err := cp.tokenManager.Start(tokenListenerFromCP(cp)) + stopTM, err := cp.tokenManager.Start(tokenListenerFromCP(cp)) if err != nil { return nil, fmt.Errorf("couldn't start token manager: %w", err) } - cp.closeTokenManager = cancelTokenManager + cp.stopTokenManager = stopTM return cp, nil } diff --git a/credentials_provider_test.go b/credentials_provider_test.go index 7f172b4..bb4871d 100644 --- a/credentials_provider_test.go +++ b/credentials_provider_test.go @@ -343,7 +343,7 @@ func TestCredentialsProviderSubscribe(t *testing.T) { mtm.On("GetToken", false).Return(testToken, nil) mtm.On("Start", mock.Anything). Run(mockTokenManagerLoop(mtm, tokenExpiration, testToken, nil)). - Return(manager.CloseFunc(mtm.Close), nil) + Return(manager.StopFunc(mtm.Stop), nil) provider, err := NewConfidentialCredentialsProvider(options) require.NoError(t, err) require.NotNil(t, provider) @@ -396,13 +396,13 @@ func TestCredentialsProviderSubscribe(t *testing.T) { mtm.On("Start", mock.Anything). Run(mockTokenManagerLoop(mtm, tokenExpiration, nil, errTokenError)). - Return(manager.CloseFunc(mtm.Close), nil) + Return(manager.StopFunc(mtm.Stop), nil) provider, err := NewConfidentialCredentialsProvider(options) require.NoError(t, err) require.NotNil(t, provider) var wg sync.WaitGroup listeners := make([]*mockCredentialsListener, numListeners) - cancels := make([]auth.CancelProviderFunc, numListeners) + cancels := make([]auth.UnsubscribeFunc, numListeners) // Subscribe multiple listeners concurrently for i := 0; i < numListeners; i++ { @@ -467,7 +467,7 @@ func TestCredentialsProviderSubscribe(t *testing.T) { mtm.On("Start", mock.Anything). Run(mockTokenManagerLoop(mtm, tokenExpiration, nil, errTokenError)). - Return(manager.CloseFunc(mtm.Close), nil) + Return(manager.StopFunc(mtm.Stop), nil) provider, err := NewConfidentialCredentialsProvider(options) require.NoError(t, err) require.NotNil(t, provider) @@ -525,7 +525,7 @@ func TestCredentialsProviderSubscribe(t *testing.T) { require.NotNil(t, provider) var wg sync.WaitGroup listeners := make([]*mockCredentialsListener, numListeners) - cancels := make([]auth.CancelProviderFunc, numListeners) + cancels := make([]auth.UnsubscribeFunc, numListeners) // Subscribe multiple listeners concurrently for i := 0; i < numListeners; i++ { diff --git a/entraid_test.go b/entraid_test.go index 86b994f..5f64167 100644 --- a/entraid_test.go +++ b/entraid_test.go @@ -54,7 +54,7 @@ func (m *fakeTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { return m.token, m.err } -func (m *fakeTokenManager) Start(listener manager.TokenListener) (manager.CloseFunc, error) { +func (m *fakeTokenManager) Start(listener manager.TokenListener) (manager.StopFunc, error) { if m.err != nil { return nil, m.err } @@ -65,10 +65,10 @@ func (m *fakeTokenManager) Start(listener manager.TokenListener) (manager.CloseF case <-time.After(tokenExpiration): m.lock.Lock() if m.err != nil { - listener.OnTokenError(m.err) + listener.OnError(m.err) return } - listener.OnTokenNext(m.token) + listener.OnNext(m.token) m.lock.Unlock() case <-done: // Exit the loop if done channel is closed @@ -84,7 +84,7 @@ func (m *fakeTokenManager) Start(listener manager.TokenListener) (manager.CloseF }, nil } -func (m *fakeTokenManager) Close() error { +func (m *fakeTokenManager) Stop() error { return nil } @@ -147,7 +147,7 @@ func (m *mockTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { return args.Get(0).(*token.Token), args.Error(1) } -func (m *mockTokenManager) Start(listener manager.TokenListener) (manager.CloseFunc, error) { +func (m *mockTokenManager) Start(listener manager.TokenListener) (manager.StopFunc, error) { args := m.Called(listener) m.lock.Lock() if m.done == nil { @@ -161,13 +161,13 @@ func (m *mockTokenManager) Start(listener manager.TokenListener) (manager.CloseF m.listener = listener } m.lock.Unlock() - return args.Get(0).(manager.CloseFunc), args.Error(1) + return args.Get(0).(manager.StopFunc), args.Error(1) } -func (m *mockTokenManager) Close() error { +func (m *mockTokenManager) Stop() error { m.lock.Lock() defer m.lock.Unlock() if m.listener == nil { - return manager.ErrTokenManagerAlreadyClosed + return manager.ErrTokenManagerAlreadyStopped } if m.listener != nil { m.listener = nil @@ -200,9 +200,9 @@ func mockTokenManagerLoop(mtm *mockTokenManager, tokenExpiration time.Duration, case <-time.After(tokenExpiration): mtm.lock.Lock() if err != nil { - mtm.listener.OnTokenError(err) + mtm.listener.OnError(err) } else { - mtm.listener.OnTokenNext(testToken) + mtm.listener.OnNext(testToken) } mtm.lock.Unlock() } diff --git a/go.mod b/go.mod index f1cdd0d..b59f728 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,14 @@ go 1.18 require ( github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.0-beta.1 - github.com/redis/go-redis/v9 v9.5.3-0.20250331212737-c248425ade4a + github.com/redis/go-redis/v9 v9.5.3-0.20250416091253-d0a8c76d8420 github.com/stretchr/testify v1.10.0 ) require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 10a19fe..1d7d291 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,12 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkY github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= github.com/AzureAD/microsoft-authentication-library-for-go v1.4.1 h1:8BKxhZZLX/WosEeoCvWysmKUscfa9v8LIPEEU0JjE2o= github.com/AzureAD/microsoft-authentication-library-for-go v1.4.1/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -25,6 +29,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.5.3-0.20250331212737-c248425ade4a h1:R5xgk8m+CF7lVE0EGr+tLkT1eM3Zfd39BJfnANQqpKA= github.com/redis/go-redis/v9 v9.5.3-0.20250331212737-c248425ade4a/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/redis/go-redis/v9 v9.5.3-0.20250416091253-d0a8c76d8420 h1:/dxO9rhmlhKP5pyI7omDH3QQzC0AppWxHT1w5TBsdTU= +github.com/redis/go-redis/v9 v9.5.3-0.20250416091253-d0a8c76d8420/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= diff --git a/identity/azure_default_identity_provider.go b/identity/azure_default_identity_provider.go index 713ce96..26e6a57 100644 --- a/identity/azure_default_identity_provider.go +++ b/identity/azure_default_identity_provider.go @@ -58,7 +58,7 @@ func NewDefaultAzureIdentityProvider(opts DefaultAzureIdentityProviderOptions) ( // RequestToken requests a token from the Azure Default Identity provider. // It returns the token, the expiration time, and an error if any. -func (a *DefaultAzureIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) { +func (a *DefaultAzureIdentityProvider) RequestToken(ctx context.Context) (shared.IdentityProviderResponse, error) { credFactory := a.credFactory if credFactory == nil { credFactory = &defaultCredFactory{} @@ -68,7 +68,7 @@ func (a *DefaultAzureIdentityProvider) RequestToken() (shared.IdentityProviderRe return nil, fmt.Errorf("failed to create default azure credential: %w", err) } - token, err := cred.GetToken(context.TODO(), policy.TokenRequestOptions{Scopes: a.scopes}) + token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: a.scopes}) if err != nil { return nil, fmt.Errorf("failed to get token: %w", err) } diff --git a/identity/azure_default_identity_provider_test.go b/identity/azure_default_identity_provider_test.go index 305e6b4..101c760 100644 --- a/identity/azure_default_identity_provider_test.go +++ b/identity/azure_default_identity_provider_test.go @@ -1,6 +1,7 @@ package identity import ( + "context" "testing" "github.com/Azure/azure-sdk-for-go/sdk/azcore" @@ -38,7 +39,7 @@ func TestAzureDefaultIdentityProvider_RequestToken(t *testing.T) { // Request a token from the provider in incorrect environment // should fail. - token, err := provider.RequestToken() + token, err := provider.RequestToken(context.Background()) assert.Nil(t, token, "token should be nil") assert.Error(t, err, "failed to request token") @@ -51,7 +52,7 @@ func TestAzureDefaultIdentityProvider_RequestToken(t *testing.T) { mCredFactory := &mockCredFactory{} mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(mCreds, nil) provider.credFactory = mCredFactory - token, err = provider.RequestToken() + token, err = provider.RequestToken(context.Background()) assert.NotNil(t, token, "token should not be nil") assert.NoError(t, err, "failed to request token") assert.Equal(t, shared.ResponseTypeAccessToken, token.Type(), "token type should be access token") @@ -70,7 +71,7 @@ func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) { t.Run("RequestToken with custom scopes", func(t *testing.T) { // Request a token from the provider - token, err := provider.RequestToken() + token, err := provider.RequestToken(context.Background()) assert.Nil(t, token, "token should be nil") assert.Error(t, err, "failed to request token") @@ -83,7 +84,7 @@ func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) { mCredFactory := &mockCredFactory{} mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(mCreds, nil) provider.credFactory = mCredFactory - token, err = provider.RequestToken() + token, err = provider.RequestToken(context.Background()) assert.NotNil(t, token, "token should not be nil") assert.NoError(t, err, "failed to request token") assert.Equal(t, shared.ResponseTypeAccessToken, token.Type(), "token type should be access token") @@ -94,7 +95,7 @@ func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) { mCredFactory := &mockCredFactory{} mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(nil, assert.AnError) provider.credFactory = mCredFactory - token, err := provider.RequestToken() + token, err := provider.RequestToken(context.Background()) assert.Nil(t, token, "token should be nil") assert.Error(t, err, "failed to request token") }) diff --git a/identity/confidential_identity_provider.go b/identity/confidential_identity_provider.go index 97876fd..87d9a7d 100644 --- a/identity/confidential_identity_provider.go +++ b/identity/confidential_identity_provider.go @@ -155,12 +155,12 @@ func NewConfidentialIdentityProvider(opts ConfidentialIdentityProviderOptions) ( // RequestToken requests a token from the identity provider. // It returns the identity provider response, including the auth result. -func (c *ConfidentialIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) { +func (c *ConfidentialIdentityProvider) RequestToken(ctx context.Context) (shared.IdentityProviderResponse, error) { if c.client == nil { return nil, fmt.Errorf("client is not initialized") } - result, err := c.client.AcquireTokenByCredential(context.TODO(), c.scopes) + result, err := c.client.AcquireTokenByCredential(ctx, c.scopes) if err != nil { return nil, fmt.Errorf("failed to acquire token: %w", err) } diff --git a/identity/confidential_identity_provider_test.go b/identity/confidential_identity_provider_test.go index df57d17..7d063e0 100644 --- a/identity/confidential_identity_provider_test.go +++ b/identity/confidential_identity_provider_test.go @@ -1,6 +1,7 @@ package identity import ( + "context" "crypto/x509" "fmt" "testing" @@ -260,7 +261,7 @@ func TestConfidentialIdentityProvider_RequestToken(t *testing.T) { Return(confidential.AuthResult{ ExpiresOn: expiresOn, }, nil) - token, err := provider.RequestToken() + token, err := provider.RequestToken(context.Background()) if err != nil { t.Errorf("RequestToken() error = %v", err) return @@ -294,14 +295,14 @@ func TestConfidentialIdentityProvider_RequestToken(t *testing.T) { provider.client = mClient mClient.On("AcquireTokenByCredential", mock.Anything, mock.Anything). Return(confidential.AuthResult{}, fmt.Errorf("error acquiring token")) - token, err := provider.RequestToken() + token, err := provider.RequestToken(context.Background()) assert.ErrorContains(t, err, "failed to acquire token:") assert.Empty(t, token, "RequestToken() token should be empty") }) t.Run("without initialization", func(t *testing.T) { t.Parallel() provider := &ConfidentialIdentityProvider{} - token, err := provider.RequestToken() + token, err := provider.RequestToken(context.Background()) assert.ErrorContains(t, err, "client is not initialized") assert.Empty(t, token, "RequestToken() token should be empty") }) diff --git a/identity/managed_identity_provider.go b/identity/managed_identity_provider.go index 6a2153d..cb5a4af 100644 --- a/identity/managed_identity_provider.go +++ b/identity/managed_identity_provider.go @@ -101,7 +101,7 @@ func NewManagedIdentityProvider(opts ManagedIdentityProviderOptions) (*ManagedId // RequestToken requests a token from the managed identity provider. // It returns IdentityProviderResponse, which contains the Acc and the expiration time. -func (m *ManagedIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) { +func (m *ManagedIdentityProvider) RequestToken(ctx context.Context) (shared.IdentityProviderResponse, error) { if m.client == nil { return nil, errors.New("managed identity client is not initialized") } @@ -115,7 +115,7 @@ func (m *ManagedIdentityProvider) RequestToken() (shared.IdentityProviderRespons } // acquire token using the managed identity client // the resource is the URL of the resource that the identity has access to - authResult, err := m.client.AcquireToken(context.TODO(), resource) + authResult, err := m.client.AcquireToken(ctx, resource) if err != nil { return nil, fmt.Errorf("couldn't acquire token: %w", err) } diff --git a/identity/managed_identity_provider_test.go b/identity/managed_identity_provider_test.go index dc90c39..80dd661 100644 --- a/identity/managed_identity_provider_test.go +++ b/identity/managed_identity_provider_test.go @@ -137,7 +137,7 @@ func TestRequestToken(t *testing.T) { } } - response, err := tt.provider.RequestToken() + response, err := tt.provider.RequestToken(context.Background()) if tt.expectedError != "" { assert.Error(t, err) @@ -207,7 +207,7 @@ func TestRequestToken_ErrorCases(t *testing.T) { mockClient := tt.provider.client.(*MockManagedIdentityClient) tt.setupMock(mockClient) - response, err := tt.provider.RequestToken() + response, err := tt.provider.RequestToken(context.Background()) assert.Error(t, err) assert.Contains(t, err.Error(), tt.expectedError) diff --git a/manager/errors.go b/manager/errors.go index 40b707b..dac0e9a 100644 --- a/manager/errors.go +++ b/manager/errors.go @@ -2,8 +2,8 @@ package manager import "fmt" -// ErrTokenManagerAlreadyClosed is returned when the token manager is already closed. -var ErrTokenManagerAlreadyClosed = fmt.Errorf("token manager already closed") +// ErrTokenManagerAlreadyStopped is returned when the token manager is already stopped. +var ErrTokenManagerAlreadyStopped = fmt.Errorf("token manager already stopped") // ErrTokenManagerAlreadyStarted is returned when the token manager is already started. var ErrTokenManagerAlreadyStarted = fmt.Errorf("token manager already started") diff --git a/manager/manager_test.go b/manager/manager_test.go index 6d6bd32..55c85d6 100644 --- a/manager/manager_test.go +++ b/manager/manager_test.go @@ -1,6 +1,7 @@ package manager import ( + "context" "net" "os" "time" @@ -81,8 +82,8 @@ type mockIdentityProvider struct { mock.Mock } -func (m *mockIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) { - args := m.Called() +func (m *mockIdentityProvider) RequestToken(ctx context.Context) (shared.IdentityProviderResponse, error) { + args := m.Called(ctx) if args.Get(0) == nil { return nil, args.Error(1) } @@ -138,11 +139,11 @@ type mockTokenListener struct { Id int32 } -func (m *mockTokenListener) OnTokenNext(token *token.Token) { +func (m *mockTokenListener) OnNext(token *token.Token) { _ = m.Called(token) } -func (m *mockTokenListener) OnTokenError(err error) { +func (m *mockTokenListener) OnError(err error) { _ = m.Called(err) } diff --git a/manager/token_manager.go b/manager/token_manager.go index 0f232fe..6c5fe07 100644 --- a/manager/token_manager.go +++ b/manager/token_manager.go @@ -1,6 +1,7 @@ package manager import ( + "context" "fmt" "sync" "time" @@ -20,13 +21,13 @@ type TokenManagerOptions struct { // // default: 0.7 ExpirationRefreshRatio float64 - // LowerRefreshBoundMs is the lower bound for the refresh time in milliseconds. - // Represents the minimum time in milliseconds before token expiration to trigger a refresh. + // LowerRefreshBound is the lower bound for the refresh time + // Represents the minimum time before token expiration to trigger a refresh. // This value sets a fixed lower bound for when a token refresh should occur, regardless // of the token's total lifetime. // - // default: 0 ms (no lower bound, refresh based on ExpirationRefreshRatio) - LowerRefreshBoundMs int64 + // default: 0 (no lower bound, refresh based on ExpirationRefreshRatio) + LowerRefreshBound time.Duration // IdentityProviderResponseParser is an optional object that implements the IdentityProviderResponseParser interface. // It is used to parse the response from the identity provider and extract the token. @@ -41,6 +42,9 @@ type TokenManagerOptions struct { // // The default values are 3 attempts, 1000 ms initial delay, 10000 ms maximum delay, and 2.0 backoff multiplier. RetryOptions RetryOptions + + // RequestTimeout is the timeout for the request to the identity provider. + RequestTimeout time.Duration } // RetryOptions is a struct that contains the options for retrying the token request. @@ -76,22 +80,22 @@ type TokenManager interface { // It takes a boolean value forceRefresh as an argument. GetToken(forceRefresh bool) (*token.Token, error) // Start starts the token manager and returns a channel that will receive updates. - Start(listener TokenListener) (CloseFunc, error) - // Close closes the token manager and releases any resources. - Close() error + Start(listener TokenListener) (StopFunc, error) + // Stop stops the token manager and releases any resources. + Stop() error } -// CloseFunc is a function that closes the token manager. -type CloseFunc func() error +// StopFunc is a function that stops the token manager. +type StopFunc func() error // TokenListener is an interface that contains the methods for receiving updates from the token manager. // The token manager will call the listener's OnTokenNext method with the updated token. // If an error occurs, the token manager will call the listener's OnTokenError method with the error. type TokenListener interface { - // OnTokenNext is called when the token is updated. - OnTokenNext(t *token.Token) - // OnTokenError is called when an error occurs. - OnTokenError(err error) + // OnNext is called when the token is updated. + OnNext(t *token.Token) + // OnError is called when an error occurs. + OnError(err error) } // entraidIdentityProviderResponseParser is the default implementation of the IdentityProviderResponseParser interface. @@ -111,15 +115,18 @@ func NewTokenManager(idp shared.IdentityProvider, options TokenManagerOptions) ( return nil, fmt.Errorf("identity provider is required") } + ctx, ctxCancel := context.WithCancel(context.Background()) return &entraidTokenManager{ idp: idp, token: nil, closedChan: nil, + ctx: ctx, + ctxCancel: ctxCancel, expirationRefreshRatio: options.ExpirationRefreshRatio, - lowerRefreshBoundMs: options.LowerRefreshBoundMs, - lowerBoundDuration: time.Duration(options.LowerRefreshBoundMs) * time.Millisecond, + lowerBoundDuration: options.LowerRefreshBound, identityProviderResponseParser: options.IdentityProviderResponseParser, retryOptions: options.RetryOptions, + requestTimeout: options.RequestTimeout, }, nil } @@ -146,8 +153,8 @@ type entraidTokenManager struct { // listener is the single listener for the token manager. // It is used to receive updates from the token manager. - // The token manager will call the listener's OnTokenNext method with the updated token. - // If an error occurs, the token manager will call the listener's OnTokenError method with the error. + // The token manager will call the listener's OnNext method with the updated token. + // If an error occurs, the token manager will call the listener's OnError method with the error. // if listener is set, Start will fail listener TokenListener @@ -161,23 +168,27 @@ type entraidTokenManager struct { // the token will be refreshed after 45 minutes. (the token is refreshed when 75% of its lifetime has passed) expirationRefreshRatio float64 - // lowerRefreshBoundMs is the lower bound for the refresh time in milliseconds. - // Represents the minimum time in milliseconds before token expiration to trigger a refresh, in milliseconds. - // This value sets a fixed lower bound for when a token refresh should occur, regardless - // of the token's total lifetime. - lowerRefreshBoundMs int64 - // lowerBoundDuration is the lower bound for the refresh time in time.Duration. lowerBoundDuration time.Duration // closedChan is a channel that is closedChan when the token manager is closedChan. // It is used to signal the token manager to stop requesting tokens. closedChan chan struct{} + + // context is the context used to request the token from the identity provider. + ctx context.Context + + // ctxCancel is the cancel function for the context. + ctxCancel context.CancelFunc + + // requestTimeout is the timeout for the request to the identity provider. + requestTimeout time.Duration } func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { e.tokenRWLock.RLock() // check if the token is nil and if it is not expired + if !forceRefresh && e.token != nil && time.Now().Add(e.lowerBoundDuration).Before(e.token.ExpirationOn()) { t := e.token e.tokenRWLock.RUnlock() @@ -185,6 +196,12 @@ func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error) } e.tokenRWLock.RUnlock() + // start the context early, + // since at heavy concurrent load + // locks may take some time to acquire + ctx, ctxCancel := context.WithTimeout(e.ctx, e.requestTimeout) + defer ctxCancel() + // Upgrade to write lock for token update e.tokenRWLock.Lock() defer e.tokenRWLock.Unlock() @@ -194,7 +211,8 @@ func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error) return e.token, nil } - idpResult, err := e.idp.RequestToken() + // Request a new token from the identity provider + idpResult, err := e.idp.RequestToken(ctx) if err != nil { return nil, fmt.Errorf("failed to request token from idp: %w", err) } @@ -216,12 +234,12 @@ func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error) // Start starts the token manager and returns cancelFunc to stop the token manager. // It takes a TokenListener as an argument, which is used to receive updates. -// The token manager will call the listener's OnTokenNext method with the updated token. +// The token manager will call the listener's OnNext method with the updated token. // If an error occurs, the token manager will call the listener's OnError method with the error. // // Note: The initial token is delivered synchronously. // The TokenListener will receive the token immediately, before the token manager goroutine starts. -func (e *entraidTokenManager) Start(listener TokenListener) (CloseFunc, error) { +func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) { e.lock.Lock() defer e.lock.Unlock() if e.listener != nil { @@ -236,12 +254,12 @@ func (e *entraidTokenManager) Start(listener TokenListener) (CloseFunc, error) { t, err := e.GetToken(true) if err != nil { - go listener.OnTokenError(err) + go listener.OnError(err) return nil, fmt.Errorf("failed to start token manager: %w", err) } // Deliver initial token synchronously - listener.OnTokenNext(t) + listener.OnNext(t) e.closedChan = make(chan struct{}) e.listener = listener @@ -271,15 +289,15 @@ func (e *entraidTokenManager) Start(listener TokenListener) (CloseFunc, error) { for i := 0; i < e.retryOptions.MaxAttempts; i++ { t, err := e.GetToken(true) if err == nil { - listener.OnTokenNext(t) + listener.OnNext(t) break } // check if err is retriable if e.retryOptions.IsRetryable(err) { if i == e.retryOptions.MaxAttempts-1 { - // last attempt, call OnTokenError - listener.OnTokenError(fmt.Errorf("max attempts reached: %w", err)) + // last attempt, call OnError + listener.OnError(fmt.Errorf("max attempts reached: %w", err)) return } @@ -299,7 +317,7 @@ func (e *entraidTokenManager) Start(listener TokenListener) (CloseFunc, error) { } } else { // not retriable - listener.OnTokenError(err) + listener.OnError(err) return } } @@ -307,17 +325,18 @@ func (e *entraidTokenManager) Start(listener TokenListener) (CloseFunc, error) { } }(listener, e.closedChan) - return e.Close, nil + return e.Stop, nil } -// Close closes the token manager and releases any resources. -func (e *entraidTokenManager) Close() error { +// Stop closes the token manager and releases any resources. +func (e *entraidTokenManager) Stop() error { e.lock.Lock() defer e.lock.Unlock() if e.closedChan == nil || e.listener == nil { - return ErrTokenManagerAlreadyClosed + return ErrTokenManagerAlreadyStopped } + e.listener = nil close(e.closedChan) diff --git a/manager/token_manager_test.go b/manager/token_manager_test.go index 08c57bf..84a2c7c 100644 --- a/manager/token_manager_test.go +++ b/manager/token_manager_test.go @@ -1,6 +1,7 @@ package manager import ( + "context" "fmt" "log" "math/rand" @@ -168,15 +169,15 @@ func TestTokenManager_Close(t *testing.T) { assert.True(t, ok) assert.Nil(t, tm.listener) assert.NotPanics(t, func() { - err = tokenManager.Close() + err = tokenManager.Stop() assert.Error(t, err) }) rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") assert.NoError(t, err) - idp.On("RequestToken").Return(rawResponse, nil) + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) - listener.On("OnTokenNext", testTokenValid).Return() + listener.On("OnNext", testTokenValid).Return() assert.NotPanics(t, func() { cancel, err := tokenManager.Start(listener) @@ -185,12 +186,12 @@ func TestTokenManager_Close(t *testing.T) { }) assert.NotNil(t, tm.listener) - err = tokenManager.Close() + err = tokenManager.Stop() assert.Nil(t, tm.listener) assert.NoError(t, err) assert.NotPanics(t, func() { - err = tokenManager.Close() + err = tokenManager.Stop() assert.Error(t, err) }) }) @@ -214,9 +215,9 @@ func TestTokenManager_Close(t *testing.T) { rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") assert.NoError(t, err) - idp.On("RequestToken").Return(rawResponse, nil) + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) - listener.On("OnTokenNext", testTokenValid).Return() + listener.On("OnNext", testTokenValid).Return() assert.NotPanics(t, func() { cancel, err := tokenManager.Start(listener) @@ -250,9 +251,9 @@ func TestTokenManager_Close(t *testing.T) { rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") assert.NoError(t, err) - idp.On("RequestToken").Return(rawResponse, nil) + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) - listener.On("OnTokenNext", testTokenValid).Return() + listener.On("OnNext", testTokenValid).Return() assert.NotPanics(t, func() { cancel, err := tokenManager.Start(listener) @@ -271,7 +272,7 @@ func TestTokenManager_Close(t *testing.T) { go func() { defer wg.Done() time.Sleep(time.Duration(int64(rand.Intn(100)) * int64(time.Millisecond))) - err := tokenManager.Close() + err := tokenManager.Stop() if err == nil { hasStopped += 1 return @@ -311,9 +312,9 @@ func TestTokenManager_Start(t *testing.T) { rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") assert.NoError(t, err) - idp.On("RequestToken").Return(rawResponse, nil) + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) - listener.On("OnTokenNext", testTokenValid).Return() + listener.On("OnNext", testTokenValid).Return() assert.NotPanics(t, func() { var hasStarted int @@ -370,9 +371,9 @@ func TestTokenManager_Start(t *testing.T) { last := &atomic.Int32{} wg := &sync.WaitGroup{} - idp.On("RequestToken").Return(rawResponse, nil) + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) - listener.On("OnTokenNext", testTokenValid).Return() + listener.On("OnNext", testTokenValid).Return() numExecutions := int32(50000) for i := int32(0); i < numExecutions; i++ { wg.Add(1) @@ -382,14 +383,14 @@ func TestTokenManager_Start(t *testing.T) { time.Sleep(time.Duration(int64(rand.Intn(1000)+(300-int(num)/2)) * int64(time.Millisecond))) last.Store(num) if num%2 == 0 { - err = tokenManager.Close() + err = tokenManager.Stop() } else { l := &mockTokenListener{Id: num} - l.On("OnTokenNext", testTokenValid).Return() + l.On("OnNext", testTokenValid).Return() _, err = tokenManager.Start(l) } if err != nil { - if err != ErrTokenManagerAlreadyClosed && err != ErrTokenManagerAlreadyStarted { + if err != ErrTokenManagerAlreadyStopped && err != ErrTokenManagerAlreadyStarted { // this is un unexpected error, fail the test assert.Error(t, err) } @@ -413,7 +414,7 @@ func TestTokenManager_Start(t *testing.T) { assert.Nil(t, cancel) assert.Error(t, err) // Close the token manager - err = tokenManager.Close() + err = tokenManager.Stop() assert.Nil(t, err) } assert.Nil(t, tm.listener) @@ -553,9 +554,9 @@ func TestEntraidTokenManager_GetToken(t *testing.T) { RawTokenVal: "test", } - idp.On("RequestToken").Return(rawResponse, nil) + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) - listener.On("OnTokenNext", testTokenValid).Return() + listener.On("OnNext", testTokenValid).Return() cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) @@ -588,9 +589,9 @@ func TestEntraidTokenManager_GetToken(t *testing.T) { RawTokenVal: "test", } - idp.On("RequestToken").Return(rawResponse, nil) + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) mParser.On("ParseResponse", rawResponse).Return(nil, fmt.Errorf("parse error")) - listener.On("OnTokenError", mock.Anything).Return() + listener.On("OnError", mock.Anything).Return() cancel, err := tokenManager.Start(listener) assert.Error(t, err) @@ -615,7 +616,7 @@ func TestEntraidTokenManager_GetToken(t *testing.T) { assert.True(t, ok) assert.Nil(t, tm.listener) - idp.On("RequestToken").Return(idpResponse, nil) + idp.On("RequestToken", mock.Anything).Return(idpResponse, nil) token1, err := tokenManager.GetToken(false) assert.Error(t, err) @@ -636,7 +637,7 @@ func TestEntraidTokenManager_GetToken(t *testing.T) { rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") assert.NoError(t, err) - idp.On("RequestToken").Return(rawResponse, nil) + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) token1, err := tokenManager.GetToken(false) assert.Error(t, err) @@ -659,7 +660,7 @@ func TestEntraidTokenManager_GetToken(t *testing.T) { idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") assert.NoError(t, err) - idp.On("RequestToken").Return(idpResponse, nil) + idp.On("RequestToken", mock.Anything).Return(idpResponse, nil) mParser.On("ParseResponse", idpResponse).Return(nil, nil) token1, err := tokenManager.GetToken(false) @@ -681,7 +682,7 @@ func TestEntraidTokenManager_GetToken(t *testing.T) { _, ok := tokenManager.(*entraidTokenManager) assert.True(t, ok) - idp.On("RequestToken").Return(nil, fmt.Errorf("idp error")) + idp.On("RequestToken", mock.Anything).Return(nil, fmt.Errorf("idp error")) token1, err := tokenManager.GetToken(false) assert.Error(t, err) @@ -695,7 +696,7 @@ func TestEntraidTokenManager_durationToRenewal(t *testing.T) { t.Parallel() idp := &mockIdentityProvider{} tokenManager, err := NewTokenManager(idp, TokenManagerOptions{ - LowerRefreshBoundMs: 1000 * 60 * 60, // 1 hour + LowerRefreshBound: time.Hour, }) assert.NoError(t, err) assert.NotNil(t, tokenManager) @@ -712,7 +713,7 @@ func TestEntraidTokenManager_durationToRenewal(t *testing.T) { idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult, expiresSoon) assert.NoError(t, err) - idp.On("RequestToken").Return(idpResponse, nil).Once() + idp.On("RequestToken", mock.Anything).Return(idpResponse, nil).Once() tm.token = nil _, err = tm.GetToken(false) assert.NoError(t, err) @@ -730,7 +731,7 @@ func TestEntraidTokenManager_durationToRenewal(t *testing.T) { idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult, expiresAfterlb) assert.NoError(t, err) - idp.On("RequestToken").Return(idpResponse, nil).Once() + idp.On("RequestToken", mock.Anything).Return(idpResponse, nil).Once() tm.token = nil _, err = tm.GetToken(false) assert.NoError(t, err) @@ -770,7 +771,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { AuthResultVal: authResultVal, } - idp.On("RequestToken").Return(idpResponse, nil).Once() + idp.On("RequestToken", mock.Anything).Return(idpResponse, nil).Once() token1 := token.New( "test", "test", @@ -781,7 +782,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { ) mParser.On("ParseResponse", idpResponse).Return(token1, nil).Once() - listener.On("OnTokenNext", token1).Return().Once() + listener.On("OnNext", token1).Return().Once() cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) @@ -794,14 +795,14 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { assert.True(t, expiresIn > toRenewal) <-time.After(toRenewal / 10) assert.NotNil(t, tm.listener) - assert.NoError(t, tokenManager.Close()) + assert.NoError(t, tokenManager.Stop()) assert.Nil(t, tm.listener) assert.Panics(t, func() { close(tm.closedChan) }) <-time.After(toRenewal) - assert.Error(t, tokenManager.Close()) + assert.Error(t, tokenManager.Stop()) mock.AssertExpectationsForObjects(t, idp, mParser, listener) }) @@ -811,7 +812,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { listener := &mockTokenListener{} tokenManager, err := NewTokenManager(idp, TokenManagerOptions{ - LowerRefreshBoundMs: 1000 * 60 * 60, // 1 hour + LowerRefreshBound: time.Hour, }, ) assert.NoError(t, err) @@ -833,7 +834,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { done := make(chan struct{}) var twice int32 var start, stop time.Time - idp.On("RequestToken").Run(func(args mock.Arguments) { + idp.On("RequestToken", mock.Anything).Run(func(args mock.Arguments) { expiresOn := time.Now().Add(expiresIn).UTC() res := testAuthResult(expiresOn) idpResponse.AuthResultVal = res @@ -847,7 +848,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { } }).Return(idpResponse, nil) - listener.On("OnTokenNext", mock.AnythingOfType("*token.Token")).Return() + listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) @@ -867,7 +868,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { assert.InDelta(t, stop.Sub(start), time.Duration(tm.retryOptions.InitialDelayMs)*time.Millisecond, float64(200*time.Millisecond)) idp.AssertNumberOfCalls(t, "RequestToken", 2) - listener.AssertNumberOfCalls(t, "OnTokenNext", 2) + listener.AssertNumberOfCalls(t, "OnNext", 2) mock.AssertExpectationsForObjects(t, idp, listener) }) @@ -877,7 +878,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { listener := &mockTokenListener{} tokenManager, err := NewTokenManager(idp, TokenManagerOptions{ - LowerRefreshBoundMs: 1000 * 60 * 60, // 1 hour + LowerRefreshBound: time.Hour, RetryOptions: RetryOptions{ InitialDelayMs: 5000, // 5 seconds }, @@ -898,13 +899,13 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { ResultType: shared.ResponseTypeAuthResult, AuthResultVal: res, } - idp.On("RequestToken").Run(func(args mock.Arguments) { + idp.On("RequestToken", mock.Anything).Run(func(args mock.Arguments) { expiresOn := time.Now().Add(expiresIn).UTC() res := testAuthResult(expiresOn) idpResponse.AuthResultVal = res }).Return(idpResponse, nil) - listener.On("OnTokenNext", mock.AnythingOfType("*token.Token")).Return() + listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) @@ -924,7 +925,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { // called only once since the token manager was closed prior to initial delay passing idp.AssertNumberOfCalls(t, "RequestToken", 1) - listener.AssertNumberOfCalls(t, "OnTokenNext", 1) + listener.AssertNumberOfCalls(t, "OnNext", 1) mock.AssertExpectationsForObjects(t, idp, listener) }) @@ -951,13 +952,13 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { ResultType: shared.ResponseTypeAuthResult, AuthResultVal: res, } - idp.On("RequestToken").Run(func(args mock.Arguments) { + idp.On("RequestToken", mock.Anything).Run(func(args mock.Arguments) { expiresOn := time.Now().Add(expiresIn).UTC() res := testAuthResult(expiresOn) idpResponse.AuthResultVal = res }).Return(idpResponse, nil) - listener.On("OnTokenNext", mock.AnythingOfType("*token.Token")).Return() + listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) @@ -997,14 +998,14 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { AuthResultVal: res, } - noErrCall := idp.On("RequestToken").Run(func(args mock.Arguments) { + noErrCall := idp.On("RequestToken", mock.Anything).Run(func(args mock.Arguments) { expiresOn := time.Now().Add(expiresIn).UTC() res := testAuthResult(expiresOn) idpResponse.AuthResultVal = res }).Return(idpResponse, nil) - listener.On("OnTokenNext", mock.AnythingOfType("*token.Token")).Return() - listener.On("OnTokenError", mock.Anything).Run(func(args mock.Arguments) { + listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() + listener.On("OnError", mock.Anything).Run(func(args mock.Arguments) { err := args.Get(0) assert.NotNil(t, err) }).Return().Maybe() @@ -1016,7 +1017,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { noErrCall.Unset() returnErr := newMockError(true) - idp.On("RequestToken").Return(nil, returnErr) + idp.On("RequestToken", mock.Anything).Return(nil, returnErr) toRenewal := tm.durationToRenewal() assert.NotEqual(t, time.Duration(0), toRenewal) @@ -1024,7 +1025,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { assert.True(t, expiresIn > toRenewal) <-time.After(toRenewal + 100*time.Millisecond) idp.AssertNumberOfCalls(t, "RequestToken", 2) - listener.AssertNumberOfCalls(t, "OnTokenNext", 1) + listener.AssertNumberOfCalls(t, "OnNext", 1) mock.AssertExpectationsForObjects(t, idp, listener) }) @@ -1051,14 +1052,14 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { AuthResultVal: res, } - noErrCall := idp.On("RequestToken").Run(func(args mock.Arguments) { + noErrCall := idp.On("RequestToken", mock.Anything).Run(func(args mock.Arguments) { expiresOn := time.Now().Add(expiresIn).UTC() res := testAuthResult(expiresOn) idpResponse.AuthResultVal = res }).Return(idpResponse, nil) - listener.On("OnTokenNext", mock.AnythingOfType("*token.Token")).Return() - listener.On("OnTokenError", mock.Anything).Run(func(args mock.Arguments) { + listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() + listener.On("OnError", mock.Anything).Run(func(args mock.Arguments) { err := args.Get(0).(error) assert.NotNil(t, err) }).Return() @@ -1070,7 +1071,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { noErrCall.Unset() returnErr := newMockError(false) - idp.On("RequestToken").Return(nil, returnErr) + idp.On("RequestToken", mock.Anything).Return(nil, returnErr) toRenewal := tm.durationToRenewal() assert.NotEqual(t, time.Duration(0), toRenewal) @@ -1079,8 +1080,8 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { <-time.After(toRenewal + 100*time.Millisecond) idp.AssertNumberOfCalls(t, "RequestToken", 2) - listener.AssertNumberOfCalls(t, "OnTokenNext", 1) - listener.AssertNumberOfCalls(t, "OnTokenError", 1) + listener.AssertNumberOfCalls(t, "OnNext", 1) + listener.AssertNumberOfCalls(t, "OnError", 1) mock.AssertExpectationsForObjects(t, idp, listener) }) @@ -1118,7 +1119,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { AuthResultVal: res, } - noErrCall := idp.On("RequestToken").Run(func(args mock.Arguments) { + noErrCall := idp.On("RequestToken", mock.Anything).Run(func(args mock.Arguments) { expiresOn := time.Now().Add(expiresIn).UTC() res := testAuthResult(expiresOn) res.IDToken.Oid = "test" @@ -1128,12 +1129,12 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { var elapsed time.Duration _ = listener. - On("OnTokenNext", mock.AnythingOfType("*token.Token")). + On("OnNext", mock.AnythingOfType("*token.Token")). Run(func(_ mock.Arguments) { start = time.Now() }).Return() maxAttemptsReached := make(chan struct{}) - listener.On("OnTokenError", mock.Anything).Run(func(args mock.Arguments) { + listener.On("OnError", mock.Anything).Run(func(args mock.Arguments) { err := args.Get(0).(error) end = time.Now() elapsed = end.Sub(start) @@ -1154,7 +1155,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { noErrCall.Unset() returnErr := newMockError(true) - idp.On("RequestToken").Return(nil, returnErr) + idp.On("RequestToken", mock.Anything).Return(nil, returnErr) select { case <-time.After(toRenewal + time.Duration(maxAttempts*maxDelayMs)*time.Millisecond): @@ -1170,8 +1171,8 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { assert.InEpsilon(t, elapsed, allDelaysShouldBe, float64(10*time.Millisecond)) idp.AssertNumberOfCalls(t, "RequestToken", tm.retryOptions.MaxAttempts+1) - listener.AssertNumberOfCalls(t, "OnTokenNext", 1) - listener.AssertNumberOfCalls(t, "OnTokenError", 1) + listener.AssertNumberOfCalls(t, "OnNext", 1) + listener.AssertNumberOfCalls(t, "OnError", 1) mock.AssertExpectationsForObjects(t, idp, listener) }) t.Run("Start and Listen and close during retries", func(t *testing.T) { @@ -1201,15 +1202,15 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { AuthResultVal: res, } - noErrCall := idp.On("RequestToken").Run(func(args mock.Arguments) { + noErrCall := idp.On("RequestToken", mock.Anything).Run(func(args mock.Arguments) { expiresOn := time.Now().Add(expiresIn).UTC() res := testAuthResult(expiresOn) idpResponse.AuthResultVal = res }).Return(idpResponse, nil) - listener.On("OnTokenNext", mock.AnythingOfType("*token.Token")).Return() + listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() maxAttemptsReached := make(chan struct{}) - listener.On("OnTokenError", mock.Anything).Run(func(args mock.Arguments) { + listener.On("OnError", mock.Anything).Run(func(args mock.Arguments) { err := args.Get(0).(error) assert.NotNil(t, err) assert.ErrorContains(t, err, "max attempts reached") @@ -1223,7 +1224,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { noErrCall.Unset() returnErr := newMockError(true) - idp.On("RequestToken").Return(nil, returnErr) + idp.On("RequestToken", mock.Anything).Return(nil, returnErr) toRenewal := tm.durationToRenewal() assert.NotEqual(t, time.Duration(0), toRenewal) @@ -1243,7 +1244,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { // maxAttempts + the initial one idp.AssertNumberOfCalls(t, "RequestToken", 2) - listener.AssertNumberOfCalls(t, "OnTokenError", 0) + listener.AssertNumberOfCalls(t, "OnError", 0) mock.AssertExpectationsForObjects(t, idp, listener) }) } @@ -1273,7 +1274,7 @@ func BenchmarkTokenManager_GetToken(b *testing.B) { b.Fatal(err) } - idp.On("RequestToken").Return(rawResponse, nil) + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) b.ResetTimer() @@ -1300,9 +1301,9 @@ func BenchmarkTokenManager_Start(b *testing.B) { b.Fatal(err) } - idp.On("RequestToken").Return(rawResponse, nil) + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) - listener.On("OnTokenNext", testTokenValid).Return() + listener.On("OnNext", testTokenValid).Return() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -1328,9 +1329,9 @@ func BenchmarkTokenManager_Close(b *testing.B) { b.Fatal(err) } - idp.On("RequestToken").Return(rawResponse, nil) + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) - listener.On("OnTokenNext", testTokenValid).Return() + listener.On("OnNext", testTokenValid).Return() _, err = tokenManager.Start(listener) if err != nil { @@ -1339,14 +1340,14 @@ func BenchmarkTokenManager_Close(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _ = tokenManager.Close() + _ = tokenManager.Stop() } } func BenchmarkTokenManager_durationToRenewal(b *testing.B) { idp := &mockIdentityProvider{} tokenManager, err := NewTokenManager(idp, TokenManagerOptions{ - LowerRefreshBoundMs: 1000 * 60 * 60, // 1 hour + LowerRefreshBound: time.Hour, }) if err != nil { b.Fatal(err) @@ -1363,7 +1364,7 @@ func BenchmarkTokenManager_durationToRenewal(b *testing.B) { b.Fatal(err) } - idp.On("RequestToken").Return(idpResponse, nil) + idp.On("RequestToken", mock.Anything).Return(idpResponse, nil) _, err = tm.GetToken(false) if err != nil { b.Fatal(err) @@ -1388,7 +1389,7 @@ func TestConcurrentTokenManagerOperations(t *testing.T) { // Create token manager with the mock provider options := TokenManagerOptions{ ExpirationRefreshRatio: 0.7, - LowerRefreshBoundMs: 100, + LowerRefreshBound: 100 * time.Millisecond, } tm, err := NewTokenManager(mockIdp, options) assert.NoError(t, err) @@ -1498,9 +1499,9 @@ func TestConcurrentTokenManagerOperations(t *testing.T) { closedAny = true //t.Logf("Goroutine %d, Operation %d: Closing token manager with key %v", routineID, j, key) - closeFunc := value.(CloseFunc) + closeFunc := value.(StopFunc) if err := closeFunc(); err != nil { - if err != ErrTokenManagerAlreadyClosed { + if err != ErrTokenManagerAlreadyStopped { // t.Logf("Goroutine %d, Operation %d: Close failed with error: %v", routineID, j, err) select { case errorCh <- fmt.Errorf("failed to close token manager: %w", err): @@ -1508,7 +1509,7 @@ func TestConcurrentTokenManagerOperations(t *testing.T) { t.Fatalf("Goroutine %d, Operation %d: Failed to close token manager: %v", routineID, j, err) } } else { - //t.Logf("Goroutine %d, Operation %d: TokenManager was already closed", routineID, j) + //t.Logf("Goroutine %d, Operation %d: TokenManager was already stopped", routineID, j) } } else { // t.Logf("Goroutine %d, Operation %d: Successfully closed token manager", routineID, j) @@ -1563,7 +1564,7 @@ func TestConcurrentTokenManagerOperations(t *testing.T) { // Clean up any remaining closers closers.Range(func(key, value interface{}) bool { - closeFunc := value.(CloseFunc) + closeFunc := value.(StopFunc) _ = closeFunc() // Ignore errors during cleanup return true }) @@ -1643,13 +1644,13 @@ type concurrentTestTokenListener struct { onErrorFunc func(error) } -func (l *concurrentTestTokenListener) OnTokenNext(t *token.Token) { +func (l *concurrentTestTokenListener) OnNext(t *token.Token) { if l.onNextFunc != nil { l.onNextFunc(t) } } -func (l *concurrentTestTokenListener) OnTokenError(err error) { +func (l *concurrentTestTokenListener) OnError(err error) { if l.onErrorFunc != nil { l.onErrorFunc(err) } @@ -1661,7 +1662,7 @@ type concurrentMockIdentityProvider struct { mutex sync.Mutex } -func (m *concurrentMockIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) { +func (m *concurrentMockIdentityProvider) RequestToken(_ context.Context) (shared.IdentityProviderResponse, error) { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/shared/identity_provider_response.go b/shared/identity_provider_response.go index da88c8a..73f82b2 100644 --- a/shared/identity_provider_response.go +++ b/shared/identity_provider_response.go @@ -1,6 +1,8 @@ package shared import ( + "context" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" "github.com/redis-developer/go-redis-entraid/internal" @@ -38,8 +40,9 @@ type IdentityProviderResponse interface { // The identity provider is responsible for providing the raw authentication token. type IdentityProvider interface { // RequestToken requests a token from the identity provider. + // The context is passed to the request to allow for cancellation and timeouts. // It returns the token, the expiration time, and an error if any. - RequestToken() (IdentityProviderResponse, error) + RequestToken(ctx context.Context) (IdentityProviderResponse, error) } // NewIDPResponse creates a new auth result based on the type provided. diff --git a/token_listener.go b/token_listener.go index 19a0e31..8515fda 100644 --- a/token_listener.go +++ b/token_listener.go @@ -24,13 +24,13 @@ func tokenListenerFromCP(cp *entraidCredentialsProvider) manager.TokenListener { // OnTokenNext is called when the token manager receives a new token. // It notifies the credentials provider with the new token. // This function is typically called when the token manager successfully retrieves a token. -func (l *entraidTokenListener) OnTokenNext(t *token.Token) { +func (l *entraidTokenListener) OnNext(t *token.Token) { l.cp.onTokenNext(t) } // OnTokenError is called when the token manager encounters an error. // It notifies the credentials provider with the error. // This function is typically called when the token manager fails to retrieve a token. -func (l *entraidTokenListener) OnTokenError(err error) { +func (l *entraidTokenListener) OnError(err error) { l.cp.onTokenError(err) } diff --git a/token_listener_test.go b/token_listener_test.go index 2185061..d2d996c 100644 --- a/token_listener_test.go +++ b/token_listener_test.go @@ -26,7 +26,7 @@ func TestOnTokenNext(t *testing.T) { now := time.Now() testToken := token.New("test-user", "test-pass", "test-token", now.Add(time.Hour), now, 3600) - listener.OnTokenNext(testToken) + listener.OnNext(testToken) // Since we can't directly access the internal state of entraidCredentialsProvider, // we'll verify that the listener was created and the call didn't panic @@ -38,7 +38,7 @@ func TestOnTokenError(t *testing.T) { listener := tokenListenerFromCP(cp) testError := errors.New("test error") - listener.OnTokenError(testError) + listener.OnError(testError) // Since we can't directly access the internal state of entraidCredentialsProvider, // we'll verify that the listener was created and the call didn't panic From 07ab7893aae1c0b8bda9bab5bc792ce78988d081 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 16 Apr 2025 14:21:42 +0300 Subject: [PATCH 29/44] Update documentation and RetryOptions A more idiomatic approach for Go would be to use time.Duration instead of int representation of Milliseconds. --- README.md | 154 ++++++++++++++++++---------------- manager/defaults.go | 12 +-- manager/token_manager.go | 16 ++-- manager/token_manager_test.go | 24 +++--- 4 files changed, 108 insertions(+), 98 deletions(-) diff --git a/README.md b/README.md index 9c9ceb5..a105600 100644 --- a/README.md +++ b/README.md @@ -229,15 +229,18 @@ type TokenManagerOptions struct { // Default: 0.7 (refresh at 70% of token lifetime) ExpirationRefreshRatio float64 - // Optional: Minimum time before expiration to refresh (ms) - // Default: 10000 (10 seconds) - LowerRefreshBounds int64 + // Optional: Minimum time before expiration to trigger refresh + // Default: 0 (no lower bound, refresh based on ExpirationRefreshRatio) + LowerRefreshBound time.Duration + + // Optional: Custom response parser + IdentityProviderResponseParser shared.IdentityProviderResponseParser // Optional: Configuration for retry behavior RetryOptions RetryOptions - // Optional: Custom response parser - IdentityProviderResponseParser IdentityProviderResponseParser + // Optional: Timeout for token requests + RequestTimeout time.Duration } ``` @@ -245,24 +248,25 @@ type TokenManagerOptions struct { Options for retry behavior: ```go type RetryOptions struct { + // Optional: Function to determine if an error is retryable + // Default: Checks for network errors and timeouts + IsRetryable func(err error) bool + // Optional: Maximum number of retry attempts // Default: 3 MaxAttempts int - // Optional: Initial delay between retries (ms) - // Default: 1000 (1 second) - InitialDelayMs int64 + // Optional: Initial delay between retries + // Default: 1 second + InitialDelay time.Duration - // Optional: Maximum delay between retries (ms) - // Default: 30000 (30 seconds) - MaxDelayMs int64 + // Optional: Maximum delay between retries + // Default: 10 seconds + MaxDelay time.Duration // Optional: Multiplier for exponential backoff // Default: 2.0 BackoffMultiplier float64 - - // Optional: Custom retry predicate - IsRetryable func(error) bool } ``` @@ -364,8 +368,8 @@ options := entraid.CredentialsProviderOptions{ LowerRefreshBounds: 10000, RetryOptions: manager.RetryOptions{ MaxAttempts: 3, - InitialDelayMs: 1000, - MaxDelayMs: 30000, + InitialDelay: 1000 * time.Millisecond, + MaxDelay: 30000 * time.Millisecond, BackoffMultiplier: 2.0, IsRetryable: func(err error) bool { return strings.Contains(err.Error(), "network error") || @@ -475,6 +479,7 @@ import ( "fmt" "log" "os" + "strings" "time" "github.com/redis-developer/go-redis-entraid/entraid" @@ -516,7 +521,18 @@ func main() { tokenManager, err := manager.NewTokenManager(customProvider, manager.TokenManagerOptions{ // Configure token refresh behavior ExpirationRefreshRatio: 0.7, - LowerRefreshBounds: 10000, + LowerRefreshBound: time.Second * 10, + RetryOptions: manager.RetryOptions{ + MaxAttempts: 3, + InitialDelay: time.Second, + MaxDelay: time.Second * 10, + BackoffMultiplier: 2.0, + IsRetryable: func(err error) bool { + return strings.Contains(err.Error(), "network error") || + strings.Contains(err.Error(), "timeout") + }, + }, + RequestTimeout: time.Second * 30, }) if err != nil { log.Fatalf("Failed to create token manager: %v", err) @@ -561,6 +577,7 @@ Key points about this implementation: - Uses our `TokenManager` for automatic token refresh - Benefits from our retry mechanisms - Handles token caching and lifecycle + - Configurable refresh timing and retry behavior 3. **Streaming Credentials**: - Uses our `StreamingCredentialsProvider` for Redis integration @@ -631,7 +648,53 @@ func TestRedisConnection(t *testing.T) { ## FAQ ### Q: How do I handle token expiration? -A: The library handles token expiration automatically. Tokens are refreshed when they reach 70% of their lifetime (configurable via `ExpirationRefreshRatio`). You can customize this behavior using `TokenManagerOptions`. +A: The library handles token expiration automatically. Tokens are refreshed when they reach 70% of their lifetime (configurable via `ExpirationRefreshRatio`). You can also set a minimum time before expiration to trigger refresh using `LowerRefreshBound`. The token manager will automatically handle token refresh and caching. + +### Q: How do I handle connection failures? +A: The library includes built-in retry mechanisms in the TokenManager. You can configure retry behavior using `RetryOptions`: +```go +RetryOptions: manager.RetryOptions{ + MaxAttempts: 3, + InitialDelay: time.Second, + MaxDelay: time.Second * 10, + BackoffMultiplier: 2.0, + IsRetryable: func(err error) bool { + return strings.Contains(err.Error(), "network error") || + strings.Contains(err.Error(), "timeout") + }, +} +``` + +### Q: What happens if token refresh fails? +A: The library will retry according to the configured `RetryOptions`. If all retries fail, the error will be propagated to the client. You can customize the retry behavior by: +1. Setting the maximum number of attempts +2. Configuring the initial and maximum delay between retries using `time.Duration` values +3. Setting the backoff multiplier for exponential backoff +4. Providing a custom function to determine which errors are retryable + +### Q: How do I implement custom authentication? +A: You can create a custom identity provider by implementing the `IdentityProvider` interface: +```go +type IdentityProvider interface { + // RequestToken requests a token from the identity provider. + // It returns the token, the expiration time, and an error if any. + RequestToken() (IdentityProviderResponse, error) +} +``` + +The `IdentityProviderResponse` interface provides methods to access the authentication result: +```go +type IdentityProviderResponse interface { + // Type returns the type of the auth result + Type() string + AuthResult() public.AuthResult + AccessToken() azcore.AccessToken + RawToken() string +} +``` + +### Q: Can I customize how token responses are parsed? +A: Yes, you can provide a custom `IdentityProviderResponseParser` in the `TokenManagerOptions`. This allows you to handle custom token formats or implement special parsing logic. ### Q: What's the difference between managed identity types? A: There are three main types of managed identities in Azure: @@ -664,57 +727,4 @@ A: There are three main types of managed identities in Azure: The choice between these types depends on your specific use case: - Use System Assigned for single-resource applications - Use User Assigned for shared identity scenarios -- Use Default Azure Identity for development and testing - -### Q: How do I handle connection failures? -A: The library includes built-in retry mechanisms in the TokenManager. You can configure retry behavior using `RetryOptions`: -```go -RetryOptions: manager.RetryOptions{ - MaxAttempts: 3, - InitialDelayMs: 1000, - MaxDelayMs: 30000, - BackoffMultiplier: 2.0, -} -``` - -### Q: Does this work with Redis Cluster? -A: Yes, the library works with both standalone Redis and Redis Cluster. Use the appropriate Redis client constructor: -```go -// For standalone Redis -client := redis.NewClient(&redis.Options{ - Addr: "your-endpoint:6380", - StreamingCredentialsProvider: provider, -}) - -// For Redis Cluster -client := redis.NewClusterClient(&redis.ClusterOptions{ - Addrs: []string{"your-endpoint:6380"}, - StreamingCredentialsProvider: provider, -}) -``` - -### Q: How do I implement custom authentication? -A: You can create a custom identity provider by implementing the `IdentityProvider` interface: -```go -// IdentityProviderResponse is an interface that defines the methods for an identity provider authentication result. -// It is used to get the type of the authentication result, the authentication result itself (can be AuthResult or AccessToken), -type IdentityProviderResponse interface { - // Type returns the type of the auth result - Type() string - AuthResult() public.AuthResult - AccessToken() azcore.AccessToken - RawToken() string -} - -// IdentityProvider is an interface that defines the methods for an identity provider. -// It is used to request a token for authentication. -// The identity provider is responsible for providing the raw authentication token. -type IdentityProvider interface { - // RequestToken requests a token from the identity provider. - // It returns the token, the expiration time, and an error if any. - RequestToken() (IdentityProviderResponse, error) -} -``` - -### Q: What happens if token refresh fails? -A: The library will retry according to the configured `RetryOptions`. If all retries fail, the error will be propagated to the client. \ No newline at end of file +- Use Default Azure Identity for development and testing \ No newline at end of file diff --git a/manager/defaults.go b/manager/defaults.go index e7e09f1..d1b5333 100644 --- a/manager/defaults.go +++ b/manager/defaults.go @@ -15,9 +15,9 @@ import ( const ( DefaultExpirationRefreshRatio = 0.7 DefaultRetryOptionsMaxAttempts = 3 - DefaultRetryOptionsInitialDelayMs = 1000 DefaultRetryOptionsBackoffMultiplier = 2.0 - DefaultRetryOptionsMaxDelayMs = 10000 + DefaultRetryOptionsInitialDelay = 1000 * time.Millisecond + DefaultRetryOptionsMaxDelay = 10000 * time.Millisecond ) // defaultIsRetryable is a function that checks if the error is retriable. @@ -57,14 +57,14 @@ func defaultRetryOptionsOr(retryOptions RetryOptions) RetryOptions { if retryOptions.MaxAttempts <= 0 { retryOptions.MaxAttempts = DefaultRetryOptionsMaxAttempts } - if retryOptions.InitialDelayMs == 0 { - retryOptions.InitialDelayMs = DefaultRetryOptionsInitialDelayMs + if retryOptions.InitialDelay == 0 { + retryOptions.InitialDelay = DefaultRetryOptionsInitialDelay } if retryOptions.BackoffMultiplier == 0 { retryOptions.BackoffMultiplier = DefaultRetryOptionsBackoffMultiplier } - if retryOptions.MaxDelayMs == 0 { - retryOptions.MaxDelayMs = DefaultRetryOptionsMaxDelayMs + if retryOptions.MaxDelay == 0 { + retryOptions.MaxDelay = DefaultRetryOptionsMaxDelay } return retryOptions } diff --git a/manager/token_manager.go b/manager/token_manager.go index 6c5fe07..8423d2c 100644 --- a/manager/token_manager.go +++ b/manager/token_manager.go @@ -58,14 +58,14 @@ type RetryOptions struct { // // default: 3 MaxAttempts int - // InitialDelayMs is the initial delay in milliseconds before retrying the token request. + // InitialDelay is the initial delay before retrying the token request. // - // default: 1000 ms - InitialDelayMs int - // MaxDelayMs is the maximum delay in milliseconds between retry attempts. + // default: 1 second + InitialDelay time.Duration + // MaxDelay is the maximum delay between retry attempts. // - // default: 10000 ms - MaxDelayMs int + // default: 10 seconds + MaxDelay time.Duration // BackoffMultiplier is the multiplier for the backoff delay. // default: 2.0 BackoffMultiplier float64 @@ -265,8 +265,8 @@ func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) { e.listener = listener go func(listener TokenListener, closed <-chan struct{}) { - maxDelay := time.Duration(e.retryOptions.MaxDelayMs) * time.Millisecond - initialDelay := time.Duration(e.retryOptions.InitialDelayMs) * time.Millisecond + maxDelay := e.retryOptions.MaxDelay + initialDelay := e.retryOptions.InitialDelay for { timeToRenewal := e.durationToRenewal() diff --git a/manager/token_manager_test.go b/manager/token_manager_test.go index 84a2c7c..2fb2e97 100644 --- a/manager/token_manager_test.go +++ b/manager/token_manager_test.go @@ -88,8 +88,8 @@ func TestTokenManagerWithOptions(t *testing.T) { assert.NotNil(t, tm.retryOptions.IsRetryable) assertFuncNameMatches(t, tm.retryOptions.IsRetryable, defaultIsRetryable) assert.Equal(t, DefaultRetryOptionsMaxAttempts, tm.retryOptions.MaxAttempts) - assert.Equal(t, DefaultRetryOptionsInitialDelayMs, tm.retryOptions.InitialDelayMs) - assert.Equal(t, DefaultRetryOptionsMaxDelayMs, tm.retryOptions.MaxDelayMs) + assert.Equal(t, DefaultRetryOptionsInitialDelay, tm.retryOptions.InitialDelay) + assert.Equal(t, DefaultRetryOptionsMaxDelay, tm.retryOptions.MaxDelay) assert.Equal(t, DefaultRetryOptionsBackoffMultiplier, tm.retryOptions.BackoffMultiplier) }) } @@ -865,7 +865,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { <-time.After(10 * time.Millisecond) assert.NoError(t, cancel()) - assert.InDelta(t, stop.Sub(start), time.Duration(tm.retryOptions.InitialDelayMs)*time.Millisecond, float64(200*time.Millisecond)) + assert.InDelta(t, stop.Sub(start), tm.retryOptions.InitialDelay, float64(200*time.Millisecond)) idp.AssertNumberOfCalls(t, "RequestToken", 2) listener.AssertNumberOfCalls(t, "OnNext", 2) @@ -880,7 +880,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { TokenManagerOptions{ LowerRefreshBound: time.Hour, RetryOptions: RetryOptions{ - InitialDelayMs: 5000, // 5 seconds + InitialDelay: 5 * time.Second, }, }, ) @@ -916,7 +916,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { assert.Equal(t, time.Duration(0), toRenewal) assert.True(t, expiresIn > toRenewal) - <-time.After(time.Duration(tm.retryOptions.InitialDelayMs/2) * time.Millisecond) + <-time.After(time.Duration(tm.retryOptions.InitialDelay / 2)) assert.NoError(t, cancel()) assert.Nil(t, tm.listener) assert.Panics(t, func() { @@ -1090,14 +1090,14 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { idp := &mockIdentityProvider{} listener := &mockTokenListener{} maxAttempts := 3 - maxDelayMs := 500 - initialDelayMs := 100 + maxDelay := 500 * time.Millisecond + initialDelay := 100 * time.Millisecond tokenManager, err := NewTokenManager(idp, TokenManagerOptions{ RetryOptions: RetryOptions{ MaxAttempts: maxAttempts, - MaxDelayMs: maxDelayMs, - InitialDelayMs: initialDelayMs, + MaxDelay: maxDelay, + InitialDelay: initialDelay, BackoffMultiplier: 10, }, }, @@ -1158,15 +1158,15 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { idp.On("RequestToken", mock.Anything).Return(nil, returnErr) select { - case <-time.After(toRenewal + time.Duration(maxAttempts*maxDelayMs)*time.Millisecond): + case <-time.After(toRenewal + time.Duration(maxAttempts)*maxDelay): assert.Fail(t, "Timeout - max retries not reached") case <-maxAttemptsReached: } // initialRenewal window, maxAttempts - 1 * max delay + the initial one which was lower than max delay allDelaysShouldBe := toRenewal - allDelaysShouldBe += time.Duration(initialDelayMs) * time.Millisecond - allDelaysShouldBe += time.Duration(maxAttempts-1) * time.Duration(maxDelayMs) * time.Millisecond + allDelaysShouldBe += initialDelay + allDelaysShouldBe += time.Duration(maxAttempts-1) * maxDelay assert.InEpsilon(t, elapsed, allDelaysShouldBe, float64(10*time.Millisecond)) From 7ad66748019b341acc8c9efe41f9ad017beaca36 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 16 Apr 2025 14:27:48 +0300 Subject: [PATCH 30/44] close the context when token manager is stopped --- manager/token_manager.go | 1 + 1 file changed, 1 insertion(+) diff --git a/manager/token_manager.go b/manager/token_manager.go index 8423d2c..4dc2c33 100644 --- a/manager/token_manager.go +++ b/manager/token_manager.go @@ -337,6 +337,7 @@ func (e *entraidTokenManager) Stop() error { return ErrTokenManagerAlreadyStopped } + e.ctxCancel() e.listener = nil close(e.closedChan) From 263196831a0251b48285a33165d5ad0c7c818d01 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Wed, 16 Apr 2025 15:58:06 +0300 Subject: [PATCH 31/44] update go.mod go.sum --- go.mod | 2 -- go.sum | 6 ------ 2 files changed, 8 deletions(-) diff --git a/go.mod b/go.mod index b59f728..7e908f0 100644 --- a/go.mod +++ b/go.mod @@ -9,9 +9,7 @@ require ( ) require ( - github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 1d7d291..36c644a 100644 --- a/go.sum +++ b/go.sum @@ -8,12 +8,8 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkY github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= github.com/AzureAD/microsoft-authentication-library-for-go v1.4.1 h1:8BKxhZZLX/WosEeoCvWysmKUscfa9v8LIPEEU0JjE2o= github.com/AzureAD/microsoft-authentication-library-for-go v1.4.1/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= -github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -27,8 +23,6 @@ github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmd github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/redis/go-redis/v9 v9.5.3-0.20250331212737-c248425ade4a h1:R5xgk8m+CF7lVE0EGr+tLkT1eM3Zfd39BJfnANQqpKA= -github.com/redis/go-redis/v9 v9.5.3-0.20250331212737-c248425ade4a/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/redis/go-redis/v9 v9.5.3-0.20250416091253-d0a8c76d8420 h1:/dxO9rhmlhKP5pyI7omDH3QQzC0AppWxHT1w5TBsdTU= github.com/redis/go-redis/v9 v9.5.3-0.20250416091253-d0a8c76d8420/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= From 1fca220fb0177887e5ca77f3a0ee10967956bb0a Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> Date: Thu, 17 Apr 2025 15:13:54 +0300 Subject: [PATCH 32/44] Update identity/confidential_identity_provider.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- identity/confidential_identity_provider.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/identity/confidential_identity_provider.go b/identity/confidential_identity_provider.go index 87d9a7d..fe5aa3b 100644 --- a/identity/confidential_identity_provider.go +++ b/identity/confidential_identity_provider.go @@ -124,8 +124,8 @@ func NewConfidentialIdentityProvider(opts ConfidentialIdentityProviderOptions) ( } case ClientCertificateCredentialType: // ClientCertificateCredentialType is the type of credentials that uses a client certificate to authenticate. - if opts.ClientCert == nil { - return nil, fmt.Errorf("client certificate is required when using client certificate credentials") + if opts.ClientCert == nil || len(opts.ClientCert) == 0 { + return nil, fmt.Errorf("non-empty client certificate is required when using client certificate credentials") } if opts.ClientPrivateKey == nil { return nil, fmt.Errorf("client private key is required when using client certificate credentials") From 22dc6ae6106cb527225a53a874b314f78774ba24 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 17 Apr 2025 16:16:57 +0300 Subject: [PATCH 33/44] Change IdentityProviderResponse interface --- .../azure_default_identity_provider_test.go | 26 +++++++++---------- identity/confidential_identity_provider.go | 2 +- .../confidential_identity_provider_test.go | 8 +++--- internal/idp_response.go | 15 ----------- internal/idp_response_test.go | 10 ------- manager/defaults.go | 6 ++--- shared/identity_provider_response.go | 21 ++++++++++++--- shared/identity_provider_response_test.go | 10 +++---- 8 files changed, 44 insertions(+), 54 deletions(-) diff --git a/identity/azure_default_identity_provider_test.go b/identity/azure_default_identity_provider_test.go index 101c760..67a43f8 100644 --- a/identity/azure_default_identity_provider_test.go +++ b/identity/azure_default_identity_provider_test.go @@ -52,11 +52,11 @@ func TestAzureDefaultIdentityProvider_RequestToken(t *testing.T) { mCredFactory := &mockCredFactory{} mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(mCreds, nil) provider.credFactory = mCredFactory - token, err = provider.RequestToken(context.Background()) - assert.NotNil(t, token, "token should not be nil") - assert.NoError(t, err, "failed to request token") - assert.Equal(t, shared.ResponseTypeAccessToken, token.Type(), "token type should be access token") - assert.Equal(t, mToken, token.AccessToken(), "access token should be equal to testJWTToken") + resp, err := provider.RequestToken(context.Background()) + assert.NotNil(t, resp, "resp should not be nil") + assert.NoError(t, err, "failed to request resp") + assert.Equal(t, shared.ResponseTypeAccessToken, resp.Type(), "resp type should be access resp") + assert.Equal(t, mToken, resp.(shared.AccessTokenIDPResponse).AccessToken(), "access token should be equal to testJWTToken") } func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) { @@ -84,19 +84,19 @@ func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) { mCredFactory := &mockCredFactory{} mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(mCreds, nil) provider.credFactory = mCredFactory - token, err = provider.RequestToken(context.Background()) - assert.NotNil(t, token, "token should not be nil") - assert.NoError(t, err, "failed to request token") - assert.Equal(t, shared.ResponseTypeAccessToken, token.Type(), "token type should be access token") - assert.Equal(t, mToken, token.AccessToken(), "access token should be equal to testJWTToken") + resp, err := provider.RequestToken(context.Background()) + assert.NotNil(t, resp, "resp should not be nil") + assert.NoError(t, err, "failed to request resp") + assert.Equal(t, shared.ResponseTypeAccessToken, resp.Type(), "resp type should be access resp") + assert.Equal(t, mToken, resp.(shared.AccessTokenIDPResponse).AccessToken(), "access resp should be equal to testJWTToken") }) t.Run("RequestToken with error from credFactory", func(t *testing.T) { // use mockAzureCredential to simulate the environment mCredFactory := &mockCredFactory{} mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(nil, assert.AnError) provider.credFactory = mCredFactory - token, err := provider.RequestToken(context.Background()) - assert.Nil(t, token, "token should be nil") - assert.Error(t, err, "failed to request token") + resp, err := provider.RequestToken(context.Background()) + assert.Nil(t, resp, "resp should be nil") + assert.Error(t, err, "failed to request resp") }) } diff --git a/identity/confidential_identity_provider.go b/identity/confidential_identity_provider.go index fe5aa3b..532b0c8 100644 --- a/identity/confidential_identity_provider.go +++ b/identity/confidential_identity_provider.go @@ -124,7 +124,7 @@ func NewConfidentialIdentityProvider(opts ConfidentialIdentityProviderOptions) ( } case ClientCertificateCredentialType: // ClientCertificateCredentialType is the type of credentials that uses a client certificate to authenticate. - if opts.ClientCert == nil || len(opts.ClientCert) == 0 { + if len(opts.ClientCert) == 0 { return nil, fmt.Errorf("non-empty client certificate is required when using client certificate credentials") } if opts.ClientPrivateKey == nil { diff --git a/identity/confidential_identity_provider_test.go b/identity/confidential_identity_provider_test.go index 7d063e0..cef1976 100644 --- a/identity/confidential_identity_provider_test.go +++ b/identity/confidential_identity_provider_test.go @@ -40,7 +40,7 @@ func TestNewConfidentialIdentityProvider(t *testing.T) { opts := ConfidentialIdentityProviderOptions{ ClientID: "client-id", CredentialsType: "ClientCertificate", - ClientCert: []*x509.Certificate{}, + ClientCert: []*x509.Certificate{&x509.Certificate{}}, ClientPrivateKey: "private-key", Scopes: []string{"scope1", "scope2"}, Authority: AuthorityConfiguration{}, @@ -58,7 +58,7 @@ func TestNewConfidentialIdentityProvider(t *testing.T) { opts := ConfidentialIdentityProviderOptions{ ClientID: "client-id", CredentialsType: "ClientCertificate", - ClientCert: []*x509.Certificate{}, + ClientCert: []*x509.Certificate{&x509.Certificate{}}, ClientPrivateKey: "private-key", Scopes: []string{"scope1", "scope2"}, Authority: AuthorityConfiguration{}, @@ -192,7 +192,7 @@ func TestNewConfidentialIdentityProvider(t *testing.T) { opts := ConfidentialIdentityProviderOptions{ ClientID: "client-id", CredentialsType: "ClientCertificate", - ClientCert: []*x509.Certificate{}, + ClientCert: []*x509.Certificate{&x509.Certificate{}}, ClientPrivateKey: nil, Scopes: []string{"scope1", "scope2"}, Authority: AuthorityConfiguration{}, @@ -268,7 +268,7 @@ func TestConfidentialIdentityProvider_RequestToken(t *testing.T) { } assert.NotEmpty(t, token, "RequestToken() token should not be empty") assert.Equal(t, token.Type(), shared.ResponseTypeAuthResult, "RequestToken() token type should be AuthResult") - assert.Equal(t, token.AuthResult().ExpiresOn, expiresOn, "RequestToken() token expiration should match") + assert.Equal(t, token.(shared.AuthResultIDPResponse).AuthResult().ExpiresOn, expiresOn, "RequestToken() token expiration should match") }) t.Run("with error", func(t *testing.T) { t.Parallel() diff --git a/internal/idp_response.go b/internal/idp_response.go index 457c5cd..8ae0cdc 100644 --- a/internal/idp_response.go +++ b/internal/idp_response.go @@ -80,11 +80,6 @@ func (a *IDPResp) AuthResult() public.AuthResult { return *a.authResultVal } -// HasAuthResult returns true if an AuthResult is set -func (a *IDPResp) HasAuthResult() bool { - return a.authResultVal != nil -} - // AccessToken returns the AccessToken if present, or an empty AccessToken if not set // Use HasAccessToken() to check if the value is actually set func (a *IDPResp) AccessToken() azcore.AccessToken { @@ -94,17 +89,7 @@ func (a *IDPResp) AccessToken() azcore.AccessToken { return *a.accessTokenVal } -// HasAccessToken returns true if an AccessToken is set -func (a *IDPResp) HasAccessToken() bool { - return a.accessTokenVal != nil -} - // RawToken returns the raw token string func (a *IDPResp) RawToken() string { return a.rawTokenVal } - -// HasRawToken returns true if a raw token is set -func (a *IDPResp) HasRawToken() bool { - return a.rawTokenVal != "" -} diff --git a/internal/idp_response_test.go b/internal/idp_response_test.go index 59f266b..7b98226 100644 --- a/internal/idp_response_test.go +++ b/internal/idp_response_test.go @@ -171,10 +171,7 @@ func TestNewIDPResp(t *testing.T) { }, wantErr: false, checkResult: func(t *testing.T, resp *IDPResp) { - assert.True(t, resp.HasAuthResult()) assert.Equal(t, "test-token", resp.AuthResult().AccessToken) - assert.False(t, resp.HasAccessToken()) - assert.False(t, resp.HasRawToken()) }, }, { @@ -185,7 +182,6 @@ func TestNewIDPResp(t *testing.T) { }, wantErr: false, checkResult: func(t *testing.T, resp *IDPResp) { - assert.True(t, resp.HasAuthResult()) assert.Equal(t, "test-token", resp.AuthResult().AccessToken) }, }, @@ -198,7 +194,6 @@ func TestNewIDPResp(t *testing.T) { }, wantErr: false, checkResult: func(t *testing.T, resp *IDPResp) { - assert.True(t, resp.HasAccessToken()) assert.Equal(t, "test-token", resp.AccessToken().Token) assert.Equal(t, "test-token", resp.RawToken()) }, @@ -212,7 +207,6 @@ func TestNewIDPResp(t *testing.T) { }, wantErr: false, checkResult: func(t *testing.T, resp *IDPResp) { - assert.True(t, resp.HasAccessToken()) assert.Equal(t, "test-token", resp.AccessToken().Token) assert.Equal(t, "test-token", resp.RawToken()) }, @@ -223,10 +217,7 @@ func TestNewIDPResp(t *testing.T) { result: "test-token", wantErr: false, checkResult: func(t *testing.T, resp *IDPResp) { - assert.True(t, resp.HasRawToken()) assert.Equal(t, "test-token", resp.RawToken()) - assert.False(t, resp.HasAuthResult()) - assert.False(t, resp.HasAccessToken()) }, }, { @@ -235,7 +226,6 @@ func TestNewIDPResp(t *testing.T) { result: stringPtr("test-token"), wantErr: false, checkResult: func(t *testing.T, resp *IDPResp) { - assert.True(t, resp.HasRawToken()) assert.Equal(t, "test-token", resp.RawToken()) }, }, diff --git a/manager/defaults.go b/manager/defaults.go index d1b5333..992b878 100644 --- a/manager/defaults.go +++ b/manager/defaults.go @@ -104,7 +104,7 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden switch response.Type() { case shared.ResponseTypeAuthResult: - authResult := response.AuthResult() + authResult := response.(shared.AuthResultIDPResponse).AuthResult() if authResult.ExpiresOn.IsZero() { return nil, fmt.Errorf("auth result expiration time is not set") } @@ -117,10 +117,10 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden expiresOn = authResult.ExpiresOn.UTC() case shared.ResponseTypeRawToken, shared.ResponseTypeAccessToken: - tokenStr := response.RawToken() + tokenStr := response.(shared.RawTokenIDPResponse).RawToken() if response.Type() == shared.ResponseTypeAccessToken { - accessToken := response.AccessToken() + accessToken := response.(shared.AccessTokenIDPResponse).AccessToken() if accessToken.Token == "" { return nil, fmt.Errorf("access token value is empty") } diff --git a/shared/identity_provider_response.go b/shared/identity_provider_response.go index 73f82b2..f4bb7a3 100644 --- a/shared/identity_provider_response.go +++ b/shared/identity_provider_response.go @@ -25,13 +25,28 @@ type IdentityProviderResponseParser interface { ParseResponse(response IdentityProviderResponse) (*token.Token, error) } -// IdentityProviderResponse is an interface that defines the methods for an identity provider authentication result. -// It is used to get the type of the authentication result, the authentication result itself (can be AuthResult or AccessToken), +// IdentityProviderResponse is an interface that defines the +// type method for the identity provider response. It is used to +// identify the type of response returned by the identity provider. +// The type can be either AuthResult, AccessToken, or RawToken. You can +// use this interface to check the type of the response and handle it accordingly. type IdentityProviderResponse interface { - // Type returns the type of the auth result + // Type returns the type of identity provider response Type() string +} + +// AuthResultIDPResponse is an interface that defines the method for getting the auth result. +type AuthResultIDPResponse interface { AuthResult() public.AuthResult +} + +// AccessTokenIDPResponse is an interface that defines the method for getting the access token. +type AccessTokenIDPResponse interface { AccessToken() azcore.AccessToken +} + +// RawTokenIDPResponse is an interface that defines the method for getting the raw token. +type RawTokenIDPResponse interface { RawToken() string } diff --git a/shared/identity_provider_response_test.go b/shared/identity_provider_response_test.go index 0b5a014..715dce2 100644 --- a/shared/identity_provider_response_test.go +++ b/shared/identity_provider_response_test.go @@ -156,12 +156,12 @@ func TestNewIDPResponse(t *testing.T) { switch tt.responseType { case ResponseTypeAuthResult: - assert.NotNil(t, resp.AuthResult()) + assert.NotNil(t, resp.(AuthResultIDPResponse).AuthResult()) case ResponseTypeAccessToken: - assert.NotNil(t, resp.AccessToken()) - assert.NotEmpty(t, resp.AccessToken().Token) + assert.NotNil(t, resp.(AccessTokenIDPResponse).AccessToken()) + assert.NotEmpty(t, resp.(AccessTokenIDPResponse).AccessToken().Token) case ResponseTypeRawToken: - assert.NotEmpty(t, resp.RawToken()) + assert.NotEmpty(t, resp.(RawTokenIDPResponse).RawToken()) } }) } @@ -271,7 +271,7 @@ func TestIdentityProvider(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, response) assert.Equal(t, ResponseTypeRawToken, response.Type()) - assert.Equal(t, "test-token", response.RawToken()) + assert.Equal(t, "test-token", response.(RawTokenIDPResponse).RawToken()) } }) } From 15c4d5de1cdd4f91254a088c93f4887232e324dc Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 17 Apr 2025 16:46:02 +0300 Subject: [PATCH 34/44] Update Readme.md --- README.md | 73 +++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 68 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index a105600..ff184ef 100644 --- a/README.md +++ b/README.md @@ -498,14 +498,14 @@ type CustomIdentityProvider struct { } // RequestToken implements the IdentityProvider interface -func (p *CustomIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) { +func (p *CustomIdentityProvider) RequestToken(ctx context.Context) (shared.IdentityProviderResponse, error) { // Implement your custom token retrieval logic here // This could be calling your own auth service, using a different auth protocol, etc. // For this example, we'll simulate getting a JWT token token := "your.jwt.token" - // Create a response using NewIDPResponse with RawToken type + // Create a response using NewIDPResponse return shared.NewIDPResponse(shared.ResponseTypeRawToken, token) } @@ -677,22 +677,85 @@ A: You can create a custom identity provider by implementing the `IdentityProvid ```go type IdentityProvider interface { // RequestToken requests a token from the identity provider. + // The context is passed to the request to allow for cancellation and timeouts. // It returns the token, the expiration time, and an error if any. - RequestToken() (IdentityProviderResponse, error) + RequestToken(ctx context.Context) (IdentityProviderResponse, error) } ``` -The `IdentityProviderResponse` interface provides methods to access the authentication result: +The response types are defined as constants: ```go +const ( + // ResponseTypeAuthResult is the type of the auth result. + ResponseTypeAuthResult = "AuthResult" + // ResponseTypeAccessToken is the type of the access token. + ResponseTypeAccessToken = "AccessToken" + // ResponseTypeRawToken is the type of the response when you have a raw string. + ResponseTypeRawToken = "RawToken" +) +``` + +The `IdentityProviderResponse` interface and related interfaces provide methods to access the authentication result: +```go +// IdentityProviderResponse is the base interface that defines the type method type IdentityProviderResponse interface { - // Type returns the type of the auth result + // Type returns the type of identity provider response Type() string +} + +// AuthResultIDPResponse defines the method for getting the auth result +type AuthResultIDPResponse interface { AuthResult() public.AuthResult +} + +// AccessTokenIDPResponse defines the method for getting the access token +type AccessTokenIDPResponse interface { AccessToken() azcore.AccessToken +} + +// RawTokenIDPResponse defines the method for getting the raw token +type RawTokenIDPResponse interface { RawToken() string } ``` +You can create a new response using the `NewIDPResponse` function: +```go +// NewIDPResponse creates a new auth result based on the type provided. +// Type can be either AuthResult, AccessToken, or RawToken. +// Second argument is the result of the type provided in the first argument. +func NewIDPResponse(responseType string, result interface{}) (IdentityProviderResponse, error) +``` + +Here's an example of how to use these types in a custom identity provider: +```go +type CustomIdentityProvider struct { + tokenEndpoint string + clientID string + clientSecret string +} + +func (p *CustomIdentityProvider) RequestToken(ctx context.Context) (shared.IdentityProviderResponse, error) { + // Get the token from your custom auth service + token, err := p.getTokenFromCustomService() + if err != nil { + return nil, err + } + + // Create a response based on the token type + switch token.Type { + case "jwt": + return shared.NewIDPResponse(shared.ResponseTypeRawToken, token.Value) + case "access_token": + return shared.NewIDPResponse(shared.ResponseTypeAccessToken, token.Value) + case "auth_result": + return shared.NewIDPResponse(shared.ResponseTypeAuthResult, token.Value) + default: + return nil, fmt.Errorf("unsupported token type: %s", token.Type) + } +} +``` + ### Q: Can I customize how token responses are parsed? A: Yes, you can provide a custom `IdentityProviderResponseParser` in the `TokenManagerOptions`. This allows you to handle custom token formats or implement special parsing logic. From f493bdacb6d573db1f100bc2488d419cec7ef860 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 24 Apr 2025 14:05:00 +0300 Subject: [PATCH 35/44] refactor(response)!: getters will return err IdentityProviderResponse getters will return error in the case where the type is incorrect or the response is not set. --- README.md | 159 +++++++++++++++++- .../azure_default_identity_provider_test.go | 10 +- .../confidential_identity_provider_test.go | 4 +- internal/errors.go | 9 + internal/idp_response.go | 21 ++- internal/idp_response_test.go | 47 ++++-- manager/defaults.go | 22 ++- manager/manager_test.go | 19 ++- manager/token_manager.go | 8 +- shared/identity_provider_response.go | 31 +++- shared/identity_provider_response_test.go | 97 +++++++---- 11 files changed, 341 insertions(+), 86 deletions(-) create mode 100644 internal/errors.go diff --git a/README.md b/README.md index ff184ef..f485829 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ Entra ID extension for go-redis - [Examples](#examples) - [Testing](#testing) - [FAQ](#faq) +- [Error Handling](#error-handling) ## Introduction @@ -607,12 +608,23 @@ func TestManagedIdentityProvider(t *testing.T) { } // Test token retrieval - token, err := provider.GetToken(context.Background()) + response, err := provider.RequestToken(context.Background()) if err != nil { t.Fatalf("Failed to get token: %v", err) } - if token == "" { - t.Error("Expected non-empty token") + + // Check response type and get token + switch response.Type() { + case shared.ResponseTypeRawToken: + token, err := response.(shared.RawTokenIDPResponse).RawToken() + if err != nil { + t.Fatalf("Failed to get raw token: %v", err) + } + if token == "" { + t.Error("Expected non-empty token") + } + default: + t.Errorf("Unexpected response type: %s", response.Type()) } } ``` @@ -705,17 +717,23 @@ type IdentityProviderResponse interface { // AuthResultIDPResponse defines the method for getting the auth result type AuthResultIDPResponse interface { - AuthResult() public.AuthResult + // AuthResult returns the Microsoft Authentication Library AuthResult. + // Returns ErrAuthResultNotFound if the auth result is not set. + AuthResult() (public.AuthResult, error) } // AccessTokenIDPResponse defines the method for getting the access token type AccessTokenIDPResponse interface { - AccessToken() azcore.AccessToken + // AccessToken returns the Azure SDK AccessToken. + // Returns ErrAccessTokenNotFound if the access token is not set. + AccessToken() (azcore.AccessToken, error) } // RawTokenIDPResponse defines the method for getting the raw token type RawTokenIDPResponse interface { - RawToken() string + // RawToken returns the raw token string. + // Returns ErrRawTokenNotFound if the raw token is not set. + RawToken() (string, error) } ``` @@ -754,6 +772,43 @@ func (p *CustomIdentityProvider) RequestToken(ctx context.Context) (shared.Ident return nil, fmt.Errorf("unsupported token type: %s", token.Type) } } + +// Example usage: +func main() { + provider := &CustomIdentityProvider{ + tokenEndpoint: "https://your-auth-endpoint.com/token", + clientID: os.Getenv("CUSTOM_CLIENT_ID"), + clientSecret: os.Getenv("CUSTOM_CLIENT_SECRET"), + } + + response, err := provider.RequestToken(context.Background()) + if err != nil { + log.Fatalf("Failed to get token: %v", err) + } + + switch response.Type() { + case shared.ResponseTypeRawToken: + token, err := response.(shared.RawTokenIDPResponse).RawToken() + if err != nil { + log.Fatalf("Failed to get raw token: %v", err) + } + log.Printf("Got raw token: %s", token) + + case shared.ResponseTypeAccessToken: + token, err := response.(shared.AccessTokenIDPResponse).AccessToken() + if err != nil { + log.Fatalf("Failed to get access token: %v", err) + } + log.Printf("Got access token: %s", token.Token) + + case shared.ResponseTypeAuthResult: + result, err := response.(shared.AuthResultIDPResponse).AuthResult() + if err != nil { + log.Fatalf("Failed to get auth result: %v", err) + } + log.Printf("Got auth result: %s", result.AccessToken) + } +} ``` ### Q: Can I customize how token responses are parsed? @@ -790,4 +845,94 @@ A: There are three main types of managed identities in Azure: The choice between these types depends on your specific use case: - Use System Assigned for single-resource applications - Use User Assigned for shared identity scenarios -- Use Default Azure Identity for development and testing \ No newline at end of file +- Use Default Azure Identity for development and testing + +## Error Handling + +### Available Errors + +The library provides several error types that you can check against: + +```go +// Import the shared package to access error types +import "github.com/redis-developer/go-redis-entraid/shared" + +// Available error types: +var ( + // ErrInvalidIDPResponse is returned when the response from the identity provider is invalid + ErrInvalidIDPResponse = shared.ErrInvalidIDPResponse + + // ErrInvalidIDPResponseType is returned when the response type is not supported + ErrInvalidIDPResponseType = shared.ErrInvalidIDPResponseType + + // ErrAuthResultNotFound is returned when trying to get an AuthResult that is not set + ErrAuthResultNotFound = shared.ErrAuthResultNotFound + + // ErrAccessTokenNotFound is returned when trying to get an AccessToken that is not set + ErrAccessTokenNotFound = shared.ErrAccessTokenNotFound + + // ErrRawTokenNotFound is returned when trying to get a RawToken that is not set + ErrRawTokenNotFound = shared.ErrRawTokenNotFound +) +``` + +### Error Handling Example + +Here's how to handle errors when working with identity provider responses: + +```go +// Example of handling different response types and their errors +response, err := identityProvider.RequestToken(ctx) +if err != nil { + // Handle request error + return err +} + +switch response.Type() { +case shared.ResponseTypeAuthResult: + authResult, err := response.(shared.AuthResultIDPResponse).AuthResult() + if err != nil { + if errors.Is(err, shared.ErrAuthResultNotFound) { + // Handle missing auth result + } + return err + } + // Use authResult... + +case shared.ResponseTypeAccessToken: + accessToken, err := response.(shared.AccessTokenIDPResponse).AccessToken() + if err != nil { + if errors.Is(err, shared.ErrAccessTokenNotFound) { + // Handle missing access token + } + return err + } + // Use accessToken... + +case shared.ResponseTypeRawToken: + rawToken, err := response.(shared.RawTokenIDPResponse).RawToken() + if err != nil { + if errors.Is(err, shared.ErrRawTokenNotFound) { + // Handle missing raw token + } + return err + } + // Use rawToken... +} +``` + +### Response Types + +The library supports three types of identity provider responses: + +1. **AuthResult** (`ResponseTypeAuthResult`) + - Contains Microsoft Authentication Library AuthResult + - Returns `ErrAuthResultNotFound` if not set + +2. **AccessToken** (`ResponseTypeAccessToken`) + - Contains Azure SDK AccessToken + - Returns `ErrAccessTokenNotFound` if not set + +3. **RawToken** (`ResponseTypeRawToken`) + - Contains raw token string + - Returns `ErrRawTokenNotFound` if not set \ No newline at end of file diff --git a/identity/azure_default_identity_provider_test.go b/identity/azure_default_identity_provider_test.go index 67a43f8..b73b617 100644 --- a/identity/azure_default_identity_provider_test.go +++ b/identity/azure_default_identity_provider_test.go @@ -56,7 +56,10 @@ func TestAzureDefaultIdentityProvider_RequestToken(t *testing.T) { assert.NotNil(t, resp, "resp should not be nil") assert.NoError(t, err, "failed to request resp") assert.Equal(t, shared.ResponseTypeAccessToken, resp.Type(), "resp type should be access resp") - assert.Equal(t, mToken, resp.(shared.AccessTokenIDPResponse).AccessToken(), "access token should be equal to testJWTToken") + atoken, err := resp.(shared.AccessTokenIDPResponse).AccessToken() + assert.NotNil(t, atoken, "token should not be nil") + assert.NoError(t, err, "failed to get token") + assert.Equal(t, mToken, atoken, "access resp should be equal to testJWTToken") } func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) { @@ -88,7 +91,10 @@ func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) { assert.NotNil(t, resp, "resp should not be nil") assert.NoError(t, err, "failed to request resp") assert.Equal(t, shared.ResponseTypeAccessToken, resp.Type(), "resp type should be access resp") - assert.Equal(t, mToken, resp.(shared.AccessTokenIDPResponse).AccessToken(), "access resp should be equal to testJWTToken") + atoken, err := resp.(shared.AccessTokenIDPResponse).AccessToken() + assert.NotNil(t, atoken, "token should not be nil") + assert.NoError(t, err, "failed to get token") + assert.Equal(t, mToken, atoken, "access resp should be equal to testJWTToken") }) t.Run("RequestToken with error from credFactory", func(t *testing.T) { // use mockAzureCredential to simulate the environment diff --git a/identity/confidential_identity_provider_test.go b/identity/confidential_identity_provider_test.go index cef1976..c1adf4f 100644 --- a/identity/confidential_identity_provider_test.go +++ b/identity/confidential_identity_provider_test.go @@ -268,7 +268,9 @@ func TestConfidentialIdentityProvider_RequestToken(t *testing.T) { } assert.NotEmpty(t, token, "RequestToken() token should not be empty") assert.Equal(t, token.Type(), shared.ResponseTypeAuthResult, "RequestToken() token type should be AuthResult") - assert.Equal(t, token.(shared.AuthResultIDPResponse).AuthResult().ExpiresOn, expiresOn, "RequestToken() token expiration should match") + res, err := token.(shared.AuthResultIDPResponse).AuthResult() + assert.NoError(t, err, "RequestToken() token should be AuthResultIDPResponse") + assert.Equal(t, expiresOn, res.ExpiresOn, "RequestToken() token should be equal to expiresOn") }) t.Run("with error", func(t *testing.T) { t.Parallel() diff --git a/internal/errors.go b/internal/errors.go new file mode 100644 index 0000000..6e88ecf --- /dev/null +++ b/internal/errors.go @@ -0,0 +1,9 @@ +package internal + +import "fmt" + +var ErrInvalidIDPResponse = fmt.Errorf("invalid identity provider response") +var ErrInvalidIDPResponseType = fmt.Errorf("invalid identity provider response type") +var ErrAuthResultNotFound = fmt.Errorf("auth result not found") +var ErrAccessTokenNotFound = fmt.Errorf("access token not found") +var ErrRawTokenNotFound = fmt.Errorf("raw token not found") diff --git a/internal/idp_response.go b/internal/idp_response.go index 8ae0cdc..dcdb181 100644 --- a/internal/idp_response.go +++ b/internal/idp_response.go @@ -21,7 +21,7 @@ type IDPResp struct { // It validates the input and ensures the response type matches the provided value func NewIDPResp(resultType string, result interface{}) (*IDPResp, error) { if result == nil { - return nil, fmt.Errorf("result cannot be nil") + return nil, ErrInvalidIDPResponse } r := &IDPResp{resultType: resultType} @@ -73,23 +73,26 @@ func (a *IDPResp) Type() string { // AuthResult returns the AuthResult if present, or an empty AuthResult if not set // Use HasAuthResult() to check if the value is actually set -func (a *IDPResp) AuthResult() public.AuthResult { +func (a *IDPResp) AuthResult() (public.AuthResult, error) { if a.authResultVal == nil { - return public.AuthResult{} + return public.AuthResult{}, ErrAuthResultNotFound } - return *a.authResultVal + return *a.authResultVal, nil } // AccessToken returns the AccessToken if present, or an empty AccessToken if not set // Use HasAccessToken() to check if the value is actually set -func (a *IDPResp) AccessToken() azcore.AccessToken { +func (a *IDPResp) AccessToken() (azcore.AccessToken, error) { if a.accessTokenVal == nil { - return azcore.AccessToken{} + return azcore.AccessToken{}, ErrAccessTokenNotFound } - return *a.accessTokenVal + return *a.accessTokenVal, nil } // RawToken returns the raw token string -func (a *IDPResp) RawToken() string { - return a.rawTokenVal +func (a *IDPResp) RawToken() (string, error) { + if a.rawTokenVal == "" { + return "", ErrRawTokenNotFound + } + return a.rawTokenVal, nil } diff --git a/internal/idp_response_test.go b/internal/idp_response_test.go index 7b98226..13fad12 100644 --- a/internal/idp_response_test.go +++ b/internal/idp_response_test.go @@ -51,6 +51,7 @@ func TestIDPResp_AuthResult(t *testing.T) { authResult *public.AuthResult wantToken string wantExpiresOn time.Time + wantErr error }{ { name: "With AuthResult", @@ -63,6 +64,7 @@ func TestIDPResp_AuthResult(t *testing.T) { authResult: nil, wantToken: "", wantExpiresOn: time.Time{}, + wantErr: ErrAuthResultNotFound, }, } @@ -71,9 +73,9 @@ func TestIDPResp_AuthResult(t *testing.T) { resp := &IDPResp{ authResultVal: tt.authResult, } - got := resp.AuthResult() - if got.AccessToken != tt.wantToken { - t.Errorf("IDPResp.AuthResult().AccessToken = %v, want %v", got.AccessToken, tt.wantToken) + got, err := resp.AuthResult() + if got.AccessToken != tt.wantToken || err != tt.wantErr { + t.Errorf("IDPResp.AuthResult().AccessToken = %v, %v, want %v, %v", got.AccessToken, err, tt.wantToken, tt.wantErr) } if !got.ExpiresOn.Equal(tt.wantExpiresOn) { t.Errorf("IDPResp.AuthResult().ExpiresOn = %v, want %v", got.ExpiresOn, tt.wantExpiresOn) @@ -94,18 +96,21 @@ func TestIDPResp_AccessToken(t *testing.T) { accessToken *azcore.AccessToken wantToken string wantExpiresOn time.Time + wantErr error }{ { name: "With AccessToken", accessToken: accessToken, wantToken: "test-token", wantExpiresOn: now, + wantErr: nil, }, { name: "Nil AccessToken", accessToken: nil, wantToken: "", wantExpiresOn: time.Time{}, + wantErr: ErrAccessTokenNotFound, }, } @@ -114,8 +119,8 @@ func TestIDPResp_AccessToken(t *testing.T) { resp := &IDPResp{ accessTokenVal: tt.accessToken, } - got := resp.AccessToken() - if got.Token != tt.wantToken { + got, err := resp.AccessToken() + if got.Token != tt.wantToken || err != tt.wantErr { t.Errorf("IDPResp.AccessToken().Token = %v, want %v", got.Token, tt.wantToken) } if !got.ExpiresOn.Equal(tt.wantExpiresOn) { @@ -130,6 +135,7 @@ func TestIDPResp_RawToken(t *testing.T) { name string rawToken string want string + err error }{ { name: "With RawToken", @@ -140,6 +146,7 @@ func TestIDPResp_RawToken(t *testing.T) { name: "Empty RawToken", rawToken: "", want: "", + err: ErrRawTokenNotFound, }, } @@ -148,8 +155,8 @@ func TestIDPResp_RawToken(t *testing.T) { resp := &IDPResp{ rawTokenVal: tt.rawToken, } - if got := resp.RawToken(); got != tt.want { - t.Errorf("IDPResp.RawToken() = %v, want %v", got, tt.want) + if got, err := resp.RawToken(); got != tt.want || err != tt.err { + t.Errorf("IDPResp.RawToken() = %v, %v, want %v, %v", got, err, tt.want, tt.err) } }) } @@ -171,7 +178,9 @@ func TestNewIDPResp(t *testing.T) { }, wantErr: false, checkResult: func(t *testing.T, resp *IDPResp) { - assert.Equal(t, "test-token", resp.AuthResult().AccessToken) + token, err := resp.AuthResult() + assert.NoError(t, err) + assert.Equal(t, "test-token", token.AccessToken) }, }, { @@ -182,7 +191,9 @@ func TestNewIDPResp(t *testing.T) { }, wantErr: false, checkResult: func(t *testing.T, resp *IDPResp) { - assert.Equal(t, "test-token", resp.AuthResult().AccessToken) + result, err := resp.AuthResult() + assert.NoError(t, err) + assert.Equal(t, "test-token", result.AccessToken) }, }, { @@ -194,8 +205,9 @@ func TestNewIDPResp(t *testing.T) { }, wantErr: false, checkResult: func(t *testing.T, resp *IDPResp) { - assert.Equal(t, "test-token", resp.AccessToken().Token) - assert.Equal(t, "test-token", resp.RawToken()) + token, err := resp.AccessToken() + assert.NoError(t, err) + assert.Equal(t, "test-token", token.Token) }, }, { @@ -207,8 +219,9 @@ func TestNewIDPResp(t *testing.T) { }, wantErr: false, checkResult: func(t *testing.T, resp *IDPResp) { - assert.Equal(t, "test-token", resp.AccessToken().Token) - assert.Equal(t, "test-token", resp.RawToken()) + token, err := resp.AccessToken() + assert.NoError(t, err) + assert.Equal(t, "test-token", token.Token) }, }, { @@ -217,7 +230,9 @@ func TestNewIDPResp(t *testing.T) { result: "test-token", wantErr: false, checkResult: func(t *testing.T, resp *IDPResp) { - assert.Equal(t, "test-token", resp.RawToken()) + rawToken, err := resp.RawToken() + assert.NoError(t, err) + assert.Equal(t, "test-token", rawToken) }, }, { @@ -226,7 +241,9 @@ func TestNewIDPResp(t *testing.T) { result: stringPtr("test-token"), wantErr: false, checkResult: func(t *testing.T, resp *IDPResp) { - assert.Equal(t, "test-token", resp.RawToken()) + rawToken, err := resp.RawToken() + assert.NoError(t, err) + assert.Equal(t, "test-token", rawToken) }, }, { diff --git a/manager/defaults.go b/manager/defaults.go index 992b878..5303e4f 100644 --- a/manager/defaults.go +++ b/manager/defaults.go @@ -104,7 +104,10 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden switch response.Type() { case shared.ResponseTypeAuthResult: - authResult := response.(shared.AuthResultIDPResponse).AuthResult() + authResult, err := response.(shared.AuthResultIDPResponse).AuthResult() + if err != nil { + return nil, fmt.Errorf("failed to get auth result: %w", err) + } if authResult.ExpiresOn.IsZero() { return nil, fmt.Errorf("auth result expiration time is not set") } @@ -117,10 +120,19 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden expiresOn = authResult.ExpiresOn.UTC() case shared.ResponseTypeRawToken, shared.ResponseTypeAccessToken: - tokenStr := response.(shared.RawTokenIDPResponse).RawToken() - + var tokenStr string + var err error + if response.Type() == shared.ResponseTypeRawToken { + tokenStr, err = response.(shared.RawTokenIDPResponse).RawToken() + if err != nil { + return nil, fmt.Errorf("failed to get raw token: %w", err) + } + } if response.Type() == shared.ResponseTypeAccessToken { - accessToken := response.(shared.AccessTokenIDPResponse).AccessToken() + accessToken, err := response.(shared.AccessTokenIDPResponse).AccessToken() + if err != nil { + return nil, fmt.Errorf("failed to get access token: %w", err) + } if accessToken.Token == "" { return nil, fmt.Errorf("access token value is empty") } @@ -139,7 +151,7 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden // Parse the token to extract claims, but note that signature verification // should be handled by the identity provider - _, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims) + _, _, err = jwt.NewParser().ParseUnverified(tokenStr, &claims) if err != nil { return nil, fmt.Errorf("failed to parse JWT token: %w", err) } diff --git a/manager/manager_test.go b/manager/manager_test.go index 55c85d6..5150981 100644 --- a/manager/manager_test.go +++ b/manager/manager_test.go @@ -162,20 +162,23 @@ func (a *authResult) Type() string { return a.ResultType } -func (a *authResult) AuthResult() public.AuthResult { +func (a *authResult) AuthResult() (public.AuthResult, error) { if a.AuthResultVal == nil { - return public.AuthResult{} + return public.AuthResult{}, shared.ErrAuthResultNotFound } - return *a.AuthResultVal + return *a.AuthResultVal, nil } -func (a *authResult) AccessToken() azcore.AccessToken { +func (a *authResult) AccessToken() (azcore.AccessToken, error) { if a.AccessTokenVal == nil { - return azcore.AccessToken{} + return azcore.AccessToken{}, shared.ErrAccessTokenNotFound } - return *a.AccessTokenVal + return *a.AccessTokenVal, nil } -func (a *authResult) RawToken() string { - return a.RawTokenVal +func (a *authResult) RawToken() (string, error) { + if a.RawTokenVal == "" { + return "", shared.ErrRawTokenNotFound + } + return a.RawTokenVal, nil } diff --git a/manager/token_manager.go b/manager/token_manager.go index 4dc2c33..211a659 100644 --- a/manager/token_manager.go +++ b/manager/token_manager.go @@ -329,9 +329,15 @@ func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) { } // Stop closes the token manager and releases any resources. -func (e *entraidTokenManager) Stop() error { +func (e *entraidTokenManager) Stop() (err error) { e.lock.Lock() defer e.lock.Unlock() + defer func() { + // recover from panic and return the error + if r := recover(); r != nil { + err = fmt.Errorf("failed to stop token manager: %s", r) + } + }() if e.closedChan == nil || e.listener == nil { return ErrTokenManagerAlreadyStopped diff --git a/shared/identity_provider_response.go b/shared/identity_provider_response.go index f4bb7a3..1eb4b4f 100644 --- a/shared/identity_provider_response.go +++ b/shared/identity_provider_response.go @@ -18,6 +18,12 @@ const ( ResponseTypeRawToken = "RawToken" ) +var ErrInvalidIDPResponse = internal.ErrInvalidIDPResponse +var ErrInvalidIDPResponseType = internal.ErrInvalidIDPResponseType +var ErrAuthResultNotFound = internal.ErrAuthResultNotFound +var ErrAccessTokenNotFound = internal.ErrAccessTokenNotFound +var ErrRawTokenNotFound = internal.ErrRawTokenNotFound + // IdentityProviderResponseParser is an interface that defines the methods for parsing the identity provider response. // It is used to parse the response from the identity provider and extract the token. // If not provided, the default implementation will be used. @@ -30,29 +36,48 @@ type IdentityProviderResponseParser interface { // identify the type of response returned by the identity provider. // The type can be either AuthResult, AccessToken, or RawToken. You can // use this interface to check the type of the response and handle it accordingly. +// Available response types are: +// - ResponseTypeAuthResult: For Microsoft Authentication Library AuthResult +// - ResponseTypeAccessToken: For Azure SDK AccessToken +// - ResponseTypeRawToken: For raw token strings type IdentityProviderResponse interface { // Type returns the type of identity provider response Type() string } // AuthResultIDPResponse is an interface that defines the method for getting the auth result. +// Returns ErrAuthResultNotFound if the auth result is not set. type AuthResultIDPResponse interface { - AuthResult() public.AuthResult + // AuthResult returns the Microsoft Authentication Library AuthResult. + // Returns ErrAuthResultNotFound if the auth result is not set. + AuthResult() (public.AuthResult, error) } // AccessTokenIDPResponse is an interface that defines the method for getting the access token. +// Returns ErrAccessTokenNotFound if the access token is not set. type AccessTokenIDPResponse interface { - AccessToken() azcore.AccessToken + // AccessToken returns the Azure SDK AccessToken. + // Returns ErrAccessTokenNotFound if the access token is not set. + AccessToken() (azcore.AccessToken, error) } // RawTokenIDPResponse is an interface that defines the method for getting the raw token. +// Returns ErrRawTokenNotFound if the raw token is not set. type RawTokenIDPResponse interface { - RawToken() string + // RawToken returns the raw token string. + // Returns ErrRawTokenNotFound if the raw token is not set. + RawToken() (string, error) } // IdentityProvider is an interface that defines the methods for an identity provider. // It is used to request a token for authentication. // The identity provider is responsible for providing the raw authentication token. +// Available errors: +// - ErrInvalidIDPResponse: When the response from the identity provider is invalid +// - ErrInvalidIDPResponseType: When the response type is not supported +// - ErrAuthResultNotFound: When trying to get an AuthResult that is not set +// - ErrAccessTokenNotFound: When trying to get an AccessToken that is not set +// - ErrRawTokenNotFound: When trying to get a RawToken that is not set type IdentityProvider interface { // RequestToken requests a token from the identity provider. // The context is passed to the request to allow for cancellation and timeouts. diff --git a/shared/identity_provider_response_test.go b/shared/identity_provider_response_test.go index 715dce2..084ad09 100644 --- a/shared/identity_provider_response_test.go +++ b/shared/identity_provider_response_test.go @@ -22,22 +22,25 @@ func (m *mockIDPResponse) Type() string { return m.responseType } -func (m *mockIDPResponse) AuthResult() public.AuthResult { +func (m *mockIDPResponse) AuthResult() (public.AuthResult, error) { if m.authResult == nil { - return public.AuthResult{} + return public.AuthResult{}, ErrAuthResultNotFound } - return *m.authResult + return *m.authResult, nil } -func (m *mockIDPResponse) AccessToken() azcore.AccessToken { +func (m *mockIDPResponse) AccessToken() (azcore.AccessToken, error) { if m.accessToken == nil { - return azcore.AccessToken{} + return azcore.AccessToken{}, ErrAccessTokenNotFound } - return *m.accessToken + return *m.accessToken, nil } -func (m *mockIDPResponse) RawToken() string { - return m.rawToken +func (m *mockIDPResponse) RawToken() (string, error) { + if m.rawToken == "" { + return "", ErrRawTokenNotFound + } + return m.rawToken, nil } type mockIDPParser struct { @@ -105,7 +108,7 @@ func TestNewIDPResponse(t *testing.T) { name: "Nil result", responseType: ResponseTypeAuthResult, result: nil, - expectedError: "result cannot be nil", + expectedError: ErrInvalidIDPResponse.Error(), }, { name: "Nil string pointer", @@ -156,12 +159,24 @@ func TestNewIDPResponse(t *testing.T) { switch tt.responseType { case ResponseTypeAuthResult: - assert.NotNil(t, resp.(AuthResultIDPResponse).AuthResult()) + response, ok := resp.(AuthResultIDPResponse) + assert.True(t, ok) + res, err := response.AuthResult() + assert.NoError(t, err) + assert.NotNil(t, res) case ResponseTypeAccessToken: - assert.NotNil(t, resp.(AccessTokenIDPResponse).AccessToken()) - assert.NotEmpty(t, resp.(AccessTokenIDPResponse).AccessToken().Token) + response, ok := resp.(AccessTokenIDPResponse) + assert.True(t, ok) + res, err := response.AccessToken() + assert.NoError(t, err) + assert.NotNil(t, res) + assert.NotEmpty(t, res.Token) case ResponseTypeRawToken: - assert.NotEmpty(t, resp.(RawTokenIDPResponse).RawToken()) + response, ok := resp.(RawTokenIDPResponse) + assert.True(t, ok) + res, err := response.RawToken() + assert.NoError(t, err) + assert.NotNil(t, res) } }) } @@ -187,50 +202,55 @@ func TestIdentityProviderResponse(t *testing.T) { tests := []struct { name string - response *mockIDPResponse + responseType string + result interface{} expectedType string }{ { - name: "AuthResult response", - response: &mockIDPResponse{ - responseType: ResponseTypeAuthResult, - authResult: authResult, - }, + name: "AuthResult response", + responseType: ResponseTypeAuthResult, + result: authResult, expectedType: ResponseTypeAuthResult, }, { - name: "AccessToken response", - response: &mockIDPResponse{ - responseType: ResponseTypeAccessToken, - accessToken: accessToken, - }, + name: "AccessToken response", + responseType: ResponseTypeAccessToken, + result: accessToken, expectedType: ResponseTypeAccessToken, }, { - name: "RawToken response", - response: &mockIDPResponse{ - responseType: ResponseTypeRawToken, - rawToken: "test-raw-token", - }, + name: "RawToken response", + responseType: ResponseTypeRawToken, + result: "test-raw-token", expectedType: ResponseTypeRawToken, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.expectedType, tt.response.Type()) - + response, err := NewIDPResponse(tt.responseType, tt.result) + assert.NoError(t, err) switch tt.expectedType { case ResponseTypeAuthResult: - result := tt.response.AuthResult() + typedResponse, ok := response.(AuthResultIDPResponse) + assert.True(t, ok) + result, err := typedResponse.AuthResult() + assert.NoError(t, err) assert.Equal(t, authResult.AccessToken, result.AccessToken) assert.Equal(t, authResult.ExpiresOn, result.ExpiresOn) case ResponseTypeAccessToken: - token := tt.response.AccessToken() + typedResponse, ok := response.(AccessTokenIDPResponse) + assert.True(t, ok) + token, err := typedResponse.AccessToken() + assert.NoError(t, err) assert.Equal(t, accessToken.Token, token.Token) assert.Equal(t, accessToken.ExpiresOn, token.ExpiresOn) case ResponseTypeRawToken: - assert.Equal(t, "test-raw-token", tt.response.RawToken()) + typedResponse, ok := response.(RawTokenIDPResponse) + assert.True(t, ok) + rawToken, err := typedResponse.RawToken() + assert.NoError(t, err) + assert.Equal(t, "test-raw-token", rawToken) } }) } @@ -271,7 +291,14 @@ func TestIdentityProvider(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, response) assert.Equal(t, ResponseTypeRawToken, response.Type()) - assert.Equal(t, "test-token", response.(RawTokenIDPResponse).RawToken()) + rawTokenResponse, ok := response.(RawTokenIDPResponse) + assert.True(t, ok) + assert.NotNil(t, rawTokenResponse) + // Check the raw token value + rawToken, err := rawTokenResponse.RawToken() + assert.NoError(t, err) + assert.NotEmpty(t, rawToken) + assert.Equal(t, "test-token", rawToken) } }) } From aba78bf8a06865a38f49d6d90c49d9e0e2fe061c Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 28 Apr 2025 11:20:13 +0300 Subject: [PATCH 36/44] chore(comments): Updated some comments --- credentials_provider.go | 2 +- internal/utils.go | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/credentials_provider.go b/credentials_provider.go index 30ff8e5..4d25ed6 100644 --- a/credentials_provider.go +++ b/credentials_provider.go @@ -60,7 +60,7 @@ func (e *entraidCredentialsProvider) onTokenError(err error) { // // Returns: // - auth.Credentials: The current credentials for the listener. -// - auth.CancelProviderFunc: A function that can be called to unsubscribe the listener. +// - auth.UnsubscribeFunc: A function that can be called to unsubscribe the listener. // - error: An error if the subscription fails, such as if the token cannot be retrieved. // // Note: If the listener is already subscribed, it will not receive duplicate notifications. diff --git a/internal/utils.go b/internal/utils.go index ba82f1c..46b3842 100644 --- a/internal/utils.go +++ b/internal/utils.go @@ -1,6 +1,10 @@ package internal // IsClosed checks if a channel is closed. +// +// NOTE: It returns true if the channel is closed as well +// as if the channel is not empty. Used internally +// to check if the channel is closed. func IsClosed(ch <-chan struct{}) bool { select { case <-ch: From 76cefccfdfc14992d560f5f0e1cf2e6d59c14be7 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 28 Apr 2025 15:13:44 +0300 Subject: [PATCH 37/44] fix(manager): requestTimeout and parser Set default requestTimeout to 30 seconds --- manager/defaults.go | 31 +++++++++++++++++++++++++------ manager/token_manager.go | 2 ++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/manager/defaults.go b/manager/defaults.go index 5303e4f..090cea3 100644 --- a/manager/defaults.go +++ b/manager/defaults.go @@ -13,6 +13,7 @@ import ( ) const ( + DefaultRequestTimeout = 30 * time.Second DefaultExpirationRefreshRatio = 0.7 DefaultRetryOptionsMaxAttempts = 3 DefaultRetryOptionsBackoffMultiplier = 2.0 @@ -85,6 +86,9 @@ func defaultTokenManagerOptionsOr(options TokenManagerOptions) TokenManagerOptio if options.ExpirationRefreshRatio == 0 { options.ExpirationRefreshRatio = DefaultExpirationRefreshRatio } + if options.RequestTimeout == 0 { + options.RequestTimeout = DefaultRequestTimeout + } return options } @@ -108,16 +112,31 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden if err != nil { return nil, fmt.Errorf("failed to get auth result: %w", err) } - if authResult.ExpiresOn.IsZero() { - return nil, fmt.Errorf("auth result expiration time is not set") + + claims := struct { + jwt.RegisteredClaims + Oid string `json:"oid,omitempty"` + }{} + + // Parse the token to extract claims, but note that signature verification + // should be handled by the identity provider + _, _, err = jwt.NewParser().ParseUnverified(authResult.AccessToken, &claims) + if err != nil { + return nil, fmt.Errorf("failed to parse JWT token: %w", err) } - if authResult.IDToken.Oid == "" { + + if claims.Oid == "" { return nil, fmt.Errorf("auth result OID is empty") } - rawToken = authResult.IDToken.RawToken - username = authResult.IDToken.Oid + + if claims.ExpiresAt.IsZero() { + return nil, fmt.Errorf("auth result expiration time is not set") + } + + rawToken = authResult.AccessToken + username = claims.Oid password = rawToken - expiresOn = authResult.ExpiresOn.UTC() + expiresOn = claims.ExpiresAt.UTC() case shared.ResponseTypeRawToken, shared.ResponseTypeAccessToken: var tokenStr string diff --git a/manager/token_manager.go b/manager/token_manager.go index 211a659..fe022ca 100644 --- a/manager/token_manager.go +++ b/manager/token_manager.go @@ -44,6 +44,8 @@ type TokenManagerOptions struct { RetryOptions RetryOptions // RequestTimeout is the timeout for the request to the identity provider. + // + // default: 30 seconds RequestTimeout time.Duration } From 3b9fcb7f29dd27872becaea12f6c8530db831e28 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 28 Apr 2025 15:53:53 +0300 Subject: [PATCH 38/44] refactor(manager)!: remove Stop method Remove Stop method, fix tests. --- credentials_provider_test.go | 6 ++--- entraid_test.go | 4 +-- manager/defaults.go | 2 +- manager/manager_test.go | 21 ++++++++++++++++ manager/token_manager.go | 8 +++--- manager/token_manager_test.go | 47 +++++++++++++++++++---------------- 6 files changed, 55 insertions(+), 33 deletions(-) diff --git a/credentials_provider_test.go b/credentials_provider_test.go index bb4871d..d9bfb3f 100644 --- a/credentials_provider_test.go +++ b/credentials_provider_test.go @@ -343,7 +343,7 @@ func TestCredentialsProviderSubscribe(t *testing.T) { mtm.On("GetToken", false).Return(testToken, nil) mtm.On("Start", mock.Anything). Run(mockTokenManagerLoop(mtm, tokenExpiration, testToken, nil)). - Return(manager.StopFunc(mtm.Stop), nil) + Return(manager.StopFunc(mtm.stop), nil) provider, err := NewConfidentialCredentialsProvider(options) require.NoError(t, err) require.NotNil(t, provider) @@ -396,7 +396,7 @@ func TestCredentialsProviderSubscribe(t *testing.T) { mtm.On("Start", mock.Anything). Run(mockTokenManagerLoop(mtm, tokenExpiration, nil, errTokenError)). - Return(manager.StopFunc(mtm.Stop), nil) + Return(manager.StopFunc(mtm.stop), nil) provider, err := NewConfidentialCredentialsProvider(options) require.NoError(t, err) require.NotNil(t, provider) @@ -467,7 +467,7 @@ func TestCredentialsProviderSubscribe(t *testing.T) { mtm.On("Start", mock.Anything). Run(mockTokenManagerLoop(mtm, tokenExpiration, nil, errTokenError)). - Return(manager.StopFunc(mtm.Stop), nil) + Return(manager.StopFunc(mtm.stop), nil) provider, err := NewConfidentialCredentialsProvider(options) require.NoError(t, err) require.NotNil(t, provider) diff --git a/entraid_test.go b/entraid_test.go index 5f64167..0a66685 100644 --- a/entraid_test.go +++ b/entraid_test.go @@ -84,7 +84,7 @@ func (m *fakeTokenManager) Start(listener manager.TokenListener) (manager.StopFu }, nil } -func (m *fakeTokenManager) Stop() error { +func (m *fakeTokenManager) stop() error { return nil } @@ -163,7 +163,7 @@ func (m *mockTokenManager) Start(listener manager.TokenListener) (manager.StopFu m.lock.Unlock() return args.Get(0).(manager.StopFunc), args.Error(1) } -func (m *mockTokenManager) Stop() error { +func (m *mockTokenManager) stop() error { m.lock.Lock() defer m.lock.Unlock() if m.listener == nil { diff --git a/manager/defaults.go b/manager/defaults.go index 090cea3..8722f17 100644 --- a/manager/defaults.go +++ b/manager/defaults.go @@ -104,7 +104,7 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden var username, password, rawToken string var expiresOn time.Time - now := time.Now().UTC() + now := time.Now().UTC().Truncate(time.Second) switch response.Type() { case shared.ResponseTypeAuthResult: diff --git a/manager/manager_test.go b/manager/manager_test.go index 5150981..ba5cdff 100644 --- a/manager/manager_test.go +++ b/manager/manager_test.go @@ -8,6 +8,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" + "github.com/golang-jwt/jwt/v5" "github.com/redis-developer/go-redis-entraid/shared" "github.com/redis-developer/go-redis-entraid/token" "github.com/stretchr/testify/mock" @@ -63,6 +64,26 @@ var testTokenValid = token.New( int64(time.Hour), ) +func newTestJWTToken(expiresOn time.Time) string { + claims := struct { + jwt.RegisteredClaims + Oid string `json:"oid,omitempty"` + }{} + + // Parse the token to extract claims, but note that signature verification + // should be handled by the identity provider + _, _, err := jwt.NewParser().ParseUnverified(testJWTToken, &claims) + if err != nil { + panic(err) + } + claims.ExpiresAt = jwt.NewNumericDate(expiresOn) + tokenStr, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("qwertyuiopasdfghjklzxcvbnm123456")) + if err != nil { + panic(err) + } + return tokenStr +} + type mockIdentityProviderResponseParser struct { // Mock implementation of the IdentityProviderResponseParser interface mock.Mock diff --git a/manager/token_manager.go b/manager/token_manager.go index fe022ca..9933174 100644 --- a/manager/token_manager.go +++ b/manager/token_manager.go @@ -83,8 +83,6 @@ type TokenManager interface { GetToken(forceRefresh bool) (*token.Token, error) // Start starts the token manager and returns a channel that will receive updates. Start(listener TokenListener) (StopFunc, error) - // Stop stops the token manager and releases any resources. - Stop() error } // StopFunc is a function that stops the token manager. @@ -327,11 +325,11 @@ func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) { } }(listener, e.closedChan) - return e.Stop, nil + return e.stop, nil } -// Stop closes the token manager and releases any resources. -func (e *entraidTokenManager) Stop() (err error) { +// stop closes the token manager and releases any resources. +func (e *entraidTokenManager) stop() (err error) { e.lock.Lock() defer e.lock.Unlock() defer func() { diff --git a/manager/token_manager_test.go b/manager/token_manager_test.go index 2fb2e97..71dc574 100644 --- a/manager/token_manager_test.go +++ b/manager/token_manager_test.go @@ -155,6 +155,7 @@ func TestTokenManager_Close(t *testing.T) { t.Parallel() t.Run("Close", func(t *testing.T) { t.Parallel() + var err error idp := &mockIdentityProvider{} listener := &mockTokenListener{} mParser := &mockIdentityProviderResponseParser{} @@ -169,7 +170,7 @@ func TestTokenManager_Close(t *testing.T) { assert.True(t, ok) assert.Nil(t, tm.listener) assert.NotPanics(t, func() { - err = tokenManager.Stop() + err = tm.stop() assert.Error(t, err) }) rawResponse, err := shared.NewIDPResponse(shared.ResponseTypeRawToken, "test") @@ -179,19 +180,20 @@ func TestTokenManager_Close(t *testing.T) { mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) listener.On("OnNext", testTokenValid).Return() + var stopper StopFunc assert.NotPanics(t, func() { - cancel, err := tokenManager.Start(listener) - assert.NotNil(t, cancel) + stopper, err = tokenManager.Start(listener) + assert.NotNil(t, stopper) assert.NoError(t, err) }) assert.NotNil(t, tm.listener) - err = tokenManager.Stop() + err = stopper() assert.Nil(t, tm.listener) assert.NoError(t, err) assert.NotPanics(t, func() { - err = tokenManager.Stop() + err = stopper() assert.Error(t, err) }) }) @@ -256,8 +258,8 @@ func TestTokenManager_Close(t *testing.T) { listener.On("OnNext", testTokenValid).Return() assert.NotPanics(t, func() { - cancel, err := tokenManager.Start(listener) - assert.NotNil(t, cancel) + stopper, err := tokenManager.Start(listener) + assert.NotNil(t, stopper) assert.NoError(t, err) assert.NotNil(t, tm.listener) var hasStopped int @@ -272,7 +274,7 @@ func TestTokenManager_Close(t *testing.T) { go func() { defer wg.Done() time.Sleep(time.Duration(int64(rand.Intn(100)) * int64(time.Millisecond))) - err := tokenManager.Stop() + err := stopper() if err == nil { hasStopped += 1 return @@ -383,7 +385,7 @@ func TestTokenManager_Start(t *testing.T) { time.Sleep(time.Duration(int64(rand.Intn(1000)+(300-int(num)/2)) * int64(time.Millisecond))) last.Store(num) if num%2 == 0 { - err = tokenManager.Stop() + err = tm.stop() } else { l := &mockTokenListener{Id: num} l.On("OnNext", testTokenValid).Return() @@ -410,11 +412,11 @@ func TestTokenManager_Start(t *testing.T) { log.Printf("FAILING WITH lastExecution[STOPPED]: %d", lastExecution) } assert.NotNil(t, tm.listener) - cancel, err := tokenManager.Start(listener) - assert.Nil(t, cancel) + stopper, err := tokenManager.Start(listener) + assert.Nil(t, stopper) assert.Error(t, err) - // Close the token manager - err = tokenManager.Stop() + // Stop the token manager with internal stop, since stopper should be nil + err = tm.stop() assert.Nil(t, err) } assert.Nil(t, tm.listener) @@ -435,7 +437,7 @@ func TestDefaultIdentityProviderResponseParser(t *testing.T) { token1, err := parser.ParseResponse(idpResponse) assert.NoError(t, err) assert.NotNil(t, token1) - assert.Equal(t, authResultVal.ExpiresOn, token1.ExpirationOn()) + assert.InEpsilon(t, authResultVal.ExpiresOn.Unix(), token1.ExpirationOn().Unix(), 1) }) t.Run("Default IdentityProviderResponseParser with type AccessToken", func(t *testing.T) { t.Parallel() @@ -784,8 +786,8 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { mParser.On("ParseResponse", idpResponse).Return(token1, nil).Once() listener.On("OnNext", token1).Return().Once() - cancel, err := tokenManager.Start(listener) - assert.NotNil(t, cancel) + stopper, err := tokenManager.Start(listener) + assert.NotNil(t, stopper) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -795,14 +797,15 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { assert.True(t, expiresIn > toRenewal) <-time.After(toRenewal / 10) assert.NotNil(t, tm.listener) - assert.NoError(t, tokenManager.Stop()) + assert.NoError(t, stopper()) assert.Nil(t, tm.listener) assert.Panics(t, func() { close(tm.closedChan) }) <-time.After(toRenewal) - assert.Error(t, tokenManager.Stop()) + // already stopped + assert.Error(t, stopper()) mock.AssertExpectationsForObjects(t, idp, mParser, listener) }) @@ -1251,9 +1254,9 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { func testAuthResult(expiersOn time.Time) *public.AuthResult { r := &public.AuthResult{ - ExpiresOn: expiersOn, + ExpiresOn: expiersOn, + AccessToken: newTestJWTToken(expiersOn), } - r.IDToken.Oid = "test" return r } @@ -1333,14 +1336,14 @@ func BenchmarkTokenManager_Close(b *testing.B) { mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) listener.On("OnNext", testTokenValid).Return() - _, err = tokenManager.Start(listener) + stopper, err := tokenManager.Start(listener) if err != nil { b.Fatal(err) } b.ResetTimer() for i := 0; i < b.N; i++ { - _ = tokenManager.Stop() + _ = stopper() } } From 57e500d81b3c1fa7f46e9c3983c15d323761208c Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 28 Apr 2025 16:10:40 +0300 Subject: [PATCH 39/44] refactor(tokenparser): simplify default parser Simplify the default identity provider response parser by extracting the raw token from the response and then parsing it as jwt token. --- internal/idp_response.go | 4 +- manager/defaults.go | 106 ++++++++++----------------- shared/identity_provider_response.go | 17 ++--- 3 files changed, 46 insertions(+), 81 deletions(-) diff --git a/internal/idp_response.go b/internal/idp_response.go index dcdb181..c78e9dd 100644 --- a/internal/idp_response.go +++ b/internal/idp_response.go @@ -36,17 +36,17 @@ func NewIDPResp(resultType string, result interface{}) (*IDPResp, error) { default: return nil, fmt.Errorf("invalid auth result type: expected public.AuthResult or *public.AuthResult, got %T", result) } + r.rawTokenVal = r.authResultVal.AccessToken case "AccessToken": switch v := result.(type) { case *azcore.AccessToken: r.accessTokenVal = v - r.rawTokenVal = v.Token case azcore.AccessToken: r.accessTokenVal = &v - r.rawTokenVal = v.Token default: return nil, fmt.Errorf("invalid access token type: expected azcore.AccessToken or *azcore.AccessToken, got %T", result) } + r.rawTokenVal = r.accessTokenVal.Token case "RawToken": switch v := result.(type) { case string: diff --git a/manager/defaults.go b/manager/defaults.go index 8722f17..f8ff16a 100644 --- a/manager/defaults.go +++ b/manager/defaults.go @@ -96,7 +96,8 @@ type defaultIdentityProviderResponseParser struct{} // ParseResponse parses the response from the identity provider and extracts the token. // It takes an IdentityProviderResponse as an argument and returns a Token and an error if any. -// The IdentityProviderResponse contains the raw token and the expiration time. +// The raw token is extracted based on the IdentityProviderResponse Type and then +// is parsed as a JWT token to extract the claims. func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.IdentityProviderResponse) (*token.Token, error) { if response == nil { return nil, fmt.Errorf("identity provider response cannot be nil") @@ -113,82 +114,51 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden return nil, fmt.Errorf("failed to get auth result: %w", err) } - claims := struct { - jwt.RegisteredClaims - Oid string `json:"oid,omitempty"` - }{} - - // Parse the token to extract claims, but note that signature verification - // should be handled by the identity provider - _, _, err = jwt.NewParser().ParseUnverified(authResult.AccessToken, &claims) - if err != nil { - return nil, fmt.Errorf("failed to parse JWT token: %w", err) - } - - if claims.Oid == "" { - return nil, fmt.Errorf("auth result OID is empty") - } - - if claims.ExpiresAt.IsZero() { - return nil, fmt.Errorf("auth result expiration time is not set") - } - + expiresOn = authResult.ExpiresOn.UTC() rawToken = authResult.AccessToken - username = claims.Oid - password = rawToken - expiresOn = claims.ExpiresAt.UTC() - - case shared.ResponseTypeRawToken, shared.ResponseTypeAccessToken: - var tokenStr string - var err error - if response.Type() == shared.ResponseTypeRawToken { - tokenStr, err = response.(shared.RawTokenIDPResponse).RawToken() - if err != nil { - return nil, fmt.Errorf("failed to get raw token: %w", err) - } - } - if response.Type() == shared.ResponseTypeAccessToken { - accessToken, err := response.(shared.AccessTokenIDPResponse).AccessToken() - if err != nil { - return nil, fmt.Errorf("failed to get access token: %w", err) - } - if accessToken.Token == "" { - return nil, fmt.Errorf("access token value is empty") - } - tokenStr = accessToken.Token - expiresOn = accessToken.ExpiresOn.UTC() - } - - if tokenStr == "" { - return nil, fmt.Errorf("raw token is empty") + case shared.ResponseTypeAccessToken: + accessToken, err := response.(shared.AccessTokenIDPResponse).AccessToken() + if err != nil { + return nil, fmt.Errorf("failed to get access token: %w", err) } - claims := struct { - jwt.RegisteredClaims - Oid string `json:"oid,omitempty"` - }{} - - // Parse the token to extract claims, but note that signature verification - // should be handled by the identity provider - _, _, err = jwt.NewParser().ParseUnverified(tokenStr, &claims) + rawToken = accessToken.Token + expiresOn = accessToken.ExpiresOn.UTC() + case shared.ResponseTypeRawToken: + tokenStr, err := response.(shared.RawTokenIDPResponse).RawToken() if err != nil { - return nil, fmt.Errorf("failed to parse JWT token: %w", err) + return nil, fmt.Errorf("failed to get raw token: %w", err) } + rawToken = tokenStr + default: + return nil, fmt.Errorf("unsupported response type: %s", response.Type()) + } - if claims.Oid == "" { - return nil, fmt.Errorf("JWT token does not contain OID claim") - } + if rawToken == "" { + return nil, fmt.Errorf("raw token is empty") + } - rawToken = tokenStr - username = claims.Oid - password = rawToken + // Parse JWT + claims := struct { + jwt.RegisteredClaims + Oid string `json:"oid,omitempty"` + }{} + + // Parse the token to extract claims, but note that signature verification + // should be handled by the identity provider + _, _, err := jwt.NewParser().ParseUnverified(rawToken, &claims) + if err != nil { + return nil, fmt.Errorf("failed to parse JWT token: %w", err) + } - if expiresOn.IsZero() && claims.ExpiresAt != nil { - expiresOn = claims.ExpiresAt.UTC() - } + if claims.Oid == "" { + return nil, fmt.Errorf("JWT token does not contain OID claim") + } - default: - return nil, fmt.Errorf("unsupported response type: %s", response.Type()) + username = claims.Oid + password = rawToken + if expiresOn.IsZero() && claims.ExpiresAt != nil { + expiresOn = claims.ExpiresAt.UTC() } if expiresOn.IsZero() { diff --git a/shared/identity_provider_response.go b/shared/identity_provider_response.go index 1eb4b4f..c3d638a 100644 --- a/shared/identity_provider_response.go +++ b/shared/identity_provider_response.go @@ -43,11 +43,15 @@ type IdentityProviderResponseParser interface { type IdentityProviderResponse interface { // Type returns the type of identity provider response Type() string + // RawToken returns the raw token string. + // Returns ErrRawTokenNotFound if the raw token is not set. + RawToken() (string, error) } // AuthResultIDPResponse is an interface that defines the method for getting the auth result. // Returns ErrAuthResultNotFound if the auth result is not set. type AuthResultIDPResponse interface { + IdentityProviderResponse // AuthResult returns the Microsoft Authentication Library AuthResult. // Returns ErrAuthResultNotFound if the auth result is not set. AuthResult() (public.AuthResult, error) @@ -56,28 +60,19 @@ type AuthResultIDPResponse interface { // AccessTokenIDPResponse is an interface that defines the method for getting the access token. // Returns ErrAccessTokenNotFound if the access token is not set. type AccessTokenIDPResponse interface { + IdentityProviderResponse // AccessToken returns the Azure SDK AccessToken. // Returns ErrAccessTokenNotFound if the access token is not set. AccessToken() (azcore.AccessToken, error) } -// RawTokenIDPResponse is an interface that defines the method for getting the raw token. -// Returns ErrRawTokenNotFound if the raw token is not set. type RawTokenIDPResponse interface { - // RawToken returns the raw token string. - // Returns ErrRawTokenNotFound if the raw token is not set. - RawToken() (string, error) + IdentityProviderResponse } // IdentityProvider is an interface that defines the methods for an identity provider. // It is used to request a token for authentication. // The identity provider is responsible for providing the raw authentication token. -// Available errors: -// - ErrInvalidIDPResponse: When the response from the identity provider is invalid -// - ErrInvalidIDPResponseType: When the response type is not supported -// - ErrAuthResultNotFound: When trying to get an AuthResult that is not set -// - ErrAccessTokenNotFound: When trying to get an AccessToken that is not set -// - ErrRawTokenNotFound: When trying to get a RawToken that is not set type IdentityProvider interface { // RequestToken requests a token from the identity provider. // The context is passed to the request to allow for cancellation and timeouts. From 7755fb934a7d1d06b60272206cdbc1efb6d24052 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 28 Apr 2025 16:53:48 +0300 Subject: [PATCH 40/44] test(parser): add additional tests for bad token --- internal/errors.go | 1 - manager/manager_test.go | 19 +++++++++++++++++++ manager/token_manager_test.go | 28 ++++++++++++++++++++++++++++ shared/identity_provider_response.go | 1 - 4 files changed, 47 insertions(+), 2 deletions(-) diff --git a/internal/errors.go b/internal/errors.go index 6e88ecf..d45c271 100644 --- a/internal/errors.go +++ b/internal/errors.go @@ -3,7 +3,6 @@ package internal import "fmt" var ErrInvalidIDPResponse = fmt.Errorf("invalid identity provider response") -var ErrInvalidIDPResponseType = fmt.Errorf("invalid identity provider response type") var ErrAuthResultNotFound = fmt.Errorf("auth result not found") var ErrAccessTokenNotFound = fmt.Errorf("access token not found") var ErrRawTokenNotFound = fmt.Errorf("raw token not found") diff --git a/manager/manager_test.go b/manager/manager_test.go index ba5cdff..85eafa0 100644 --- a/manager/manager_test.go +++ b/manager/manager_test.go @@ -84,6 +84,25 @@ func newTestJWTToken(expiresOn time.Time) string { return tokenStr } +func newTestJWTTokenWithoutOID(expiresOn time.Time) string { + claims := struct { + jwt.RegisteredClaims + }{} + + // Parse the token to extract claims, but note that signature verification + // should be handled by the identity provider + _, _, err := jwt.NewParser().ParseUnverified(testJWTToken, &claims) + if err != nil { + panic(err) + } + claims.ExpiresAt = jwt.NewNumericDate(expiresOn) + tokenStr, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("qwertyuiopasdfghjklzxcvbnm123456")) + if err != nil { + panic(err) + } + return tokenStr +} + type mockIdentityProviderResponseParser struct { // Mock implementation of the IdentityProviderResponseParser interface mock.Mock diff --git a/manager/token_manager_test.go b/manager/token_manager_test.go index 71dc574..f576695 100644 --- a/manager/token_manager_test.go +++ b/manager/token_manager_test.go @@ -439,6 +439,34 @@ func TestDefaultIdentityProviderResponseParser(t *testing.T) { assert.NotNil(t, token1) assert.InEpsilon(t, authResultVal.ExpiresOn.Unix(), token1.ExpirationOn().Unix(), 1) }) + t.Run("Default IdentityProviderResponseParser with type AuthResult and empty token", func(t *testing.T) { + t.Parallel() + authResultVal := &public.AuthResult{ + ExpiresOn: time.Now().Add(time.Hour).UTC(), + AccessToken: "", + } + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: authResultVal, + } + token1, err := parser.ParseResponse(idpResponse) + assert.Error(t, err) + assert.Nil(t, token1) + }) + t.Run("Default IdentityProviderResponseParser with type AuthResult and token without oid", func(t *testing.T) { + t.Parallel() + authResultVal := &public.AuthResult{ + ExpiresOn: time.Now().Add(time.Hour).UTC(), + AccessToken: newTestJWTTokenWithoutOID(time.Now().Add(time.Hour).UTC()), + } + idpResponse := &authResult{ + ResultType: shared.ResponseTypeAuthResult, + AuthResultVal: authResultVal, + } + token1, err := parser.ParseResponse(idpResponse) + assert.Error(t, err) + assert.Nil(t, token1) + }) t.Run("Default IdentityProviderResponseParser with type AccessToken", func(t *testing.T) { t.Parallel() accessToken := &azcore.AccessToken{ diff --git a/shared/identity_provider_response.go b/shared/identity_provider_response.go index c3d638a..86265ec 100644 --- a/shared/identity_provider_response.go +++ b/shared/identity_provider_response.go @@ -19,7 +19,6 @@ const ( ) var ErrInvalidIDPResponse = internal.ErrInvalidIDPResponse -var ErrInvalidIDPResponseType = internal.ErrInvalidIDPResponseType var ErrAuthResultNotFound = internal.ErrAuthResultNotFound var ErrAccessTokenNotFound = internal.ErrAccessTokenNotFound var ErrRawTokenNotFound = internal.ErrRawTokenNotFound From 676cf1c0357e839ae48a4bf78fed9baf82e1ae77 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 13 May 2025 16:54:54 +0300 Subject: [PATCH 41/44] fix(manager): starting and stopping the manager Starting and stopping the manager can be executed multiple times --- credentials_provider.go | 31 +++++++++---- manager/token_manager.go | 30 ++++++++----- manager/token_manager_test.go | 83 ++++++++++++++++++++++++++--------- 3 files changed, 104 insertions(+), 40 deletions(-) diff --git a/credentials_provider.go b/credentials_provider.go index 4d25ed6..64f6f85 100644 --- a/credentials_provider.go +++ b/credentials_provider.go @@ -27,6 +27,8 @@ type entraidCredentialsProvider struct { // rwLock is a mutex that is used to synchronize access to the listeners slice. rwLock sync.RWMutex // Mutex for synchronizing access to the listeners slice. + + tmLock sync.Mutex } // onTokenNext is a method that is called when the token manager receives a new token. @@ -65,11 +67,25 @@ func (e *entraidCredentialsProvider) onTokenError(err error) { // // Note: If the listener is already subscribed, it will not receive duplicate notifications. func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) { - // First try to get a token, only then subscribe the listener. - token, err := e.tokenManager.GetToken(false) - if err != nil { - return nil, nil, err + var token *token.Token + // check if the manager is working + // If the stopTokenManager is nil, the token manager is not started. + e.tmLock.Lock() + if e.stopTokenManager == nil { + t, stopTM, err := e.tokenManager.Start(tokenListenerFromCP(e)) + if err != nil { + return nil, nil, fmt.Errorf("couldn't start token manager: %w", err) + } + e.stopTokenManager = stopTM + token = t + } else { + t, err := e.tokenManager.GetToken(false) + if err != nil { + return nil, nil, fmt.Errorf("couldn't get token: %w", err) + } + token = t } + e.tmLock.Unlock() e.rwLock.Lock() // Check if the listener is already in the list of listeners. @@ -102,6 +118,7 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener // Clear the listeners slice if it's empty if len(e.listeners) == 0 { e.listeners = make([]auth.CredentialsListener, 0) + e.tmLock.Lock() if e.stopTokenManager != nil { err := e.stopTokenManager() if err != nil { @@ -111,6 +128,7 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener // This prevents multiple calls to stopTokenManager. e.stopTokenManager = nil } + e.tmLock.Unlock() } return nil } @@ -134,10 +152,5 @@ func NewCredentialsProvider(tokenManager manager.TokenManager, options Credentia options: options, listeners: make([]auth.CredentialsListener, 0), } - stopTM, err := cp.tokenManager.Start(tokenListenerFromCP(cp)) - if err != nil { - return nil, fmt.Errorf("couldn't start token manager: %w", err) - } - cp.stopTokenManager = stopTM return cp, nil } diff --git a/manager/token_manager.go b/manager/token_manager.go index 9933174..83647f4 100644 --- a/manager/token_manager.go +++ b/manager/token_manager.go @@ -82,7 +82,7 @@ type TokenManager interface { // It takes a boolean value forceRefresh as an argument. GetToken(forceRefresh bool) (*token.Token, error) // Start starts the token manager and returns a channel that will receive updates. - Start(listener TokenListener) (StopFunc, error) + Start(listener TokenListener) (*token.Token, StopFunc, error) } // StopFunc is a function that stops the token manager. @@ -239,11 +239,11 @@ func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error) // // Note: The initial token is delivered synchronously. // The TokenListener will receive the token immediately, before the token manager goroutine starts. -func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) { +func (e *entraidTokenManager) Start(listener TokenListener) (*token.Token, StopFunc, error) { e.lock.Lock() defer e.lock.Unlock() if e.listener != nil { - return nil, ErrTokenManagerAlreadyStarted + return nil, nil, ErrTokenManagerAlreadyStarted } if e.closedChan != nil && !internal.IsClosed(e.closedChan) { @@ -252,15 +252,25 @@ func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) { close(e.closedChan) } - t, err := e.GetToken(true) + ctx, ctxCancel := context.WithCancel(context.Background()) + e.ctx = ctx + e.ctxCancel = ctxCancel + + t, err := e.GetToken(false) + // If a token was found in the cache, check if: + // - it is expired (based on the lower bound) + // - it is about to expire (based on the expiration refresh ratio) + // if so, get a new token + expirationRefreshTime := t.ReceivedAt().Add(time.Duration(float64(t.TTL()) * float64(time.Second) * e.expirationRefreshRatio)) + expirationWithoutLowerBound := t.ExpirationOn().Add(-1 * e.lowerBoundDuration) + now := time.Now() + if t != nil && (expirationWithoutLowerBound.Before(now) || expirationRefreshTime.Before(now)) { + t, err = e.GetToken(true) + } if err != nil { - go listener.OnError(err) - return nil, fmt.Errorf("failed to start token manager: %w", err) + return nil, nil, fmt.Errorf("failed to start token manager: %w", err) } - // Deliver initial token synchronously - listener.OnNext(t) - e.closedChan = make(chan struct{}) e.listener = listener @@ -325,7 +335,7 @@ func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) { } }(listener, e.closedChan) - return e.stop, nil + return t, e.stop, nil } // stop closes the token manager and releases any resources. diff --git a/manager/token_manager_test.go b/manager/token_manager_test.go index f576695..e1d5d27 100644 --- a/manager/token_manager_test.go +++ b/manager/token_manager_test.go @@ -182,7 +182,7 @@ func TestTokenManager_Close(t *testing.T) { var stopper StopFunc assert.NotPanics(t, func() { - stopper, err = tokenManager.Start(listener) + _, stopper, err = tokenManager.Start(listener) assert.NotNil(t, stopper) assert.NoError(t, err) }) @@ -222,7 +222,7 @@ func TestTokenManager_Close(t *testing.T) { listener.On("OnNext", testTokenValid).Return() assert.NotPanics(t, func() { - cancel, err := tokenManager.Start(listener) + _, cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -258,7 +258,7 @@ func TestTokenManager_Close(t *testing.T) { listener.On("OnNext", testTokenValid).Return() assert.NotPanics(t, func() { - stopper, err := tokenManager.Start(listener) + _, stopper, err := tokenManager.Start(listener) assert.NotNil(t, stopper) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -329,7 +329,7 @@ func TestTokenManager_Start(t *testing.T) { go func() { defer wg.Done() time.Sleep(time.Duration(int64(rand.Intn(100)) * int64(time.Millisecond))) - _, err := tokenManager.Start(listener) + _, _, err := tokenManager.Start(listener) if err == nil { hasStarted += 1 return @@ -344,7 +344,7 @@ func TestTokenManager_Start(t *testing.T) { assert.NotNil(t, tm.listener) assert.Equal(t, 1, hasStarted) assert.Equal(t, int32(numExecutions-1), atomic.LoadInt32(&alreadyStarted)) - cancel, err := tokenManager.Start(listener) + _, cancel, err := tokenManager.Start(listener) assert.Nil(t, cancel) assert.Error(t, err) assert.NotNil(t, tm.listener) @@ -389,7 +389,7 @@ func TestTokenManager_Start(t *testing.T) { } else { l := &mockTokenListener{Id: num} l.On("OnNext", testTokenValid).Return() - _, err = tokenManager.Start(l) + _, _, err = tokenManager.Start(l) } if err != nil { if err != ErrTokenManagerAlreadyStopped && err != ErrTokenManagerAlreadyStarted { @@ -412,7 +412,7 @@ func TestTokenManager_Start(t *testing.T) { log.Printf("FAILING WITH lastExecution[STOPPED]: %d", lastExecution) } assert.NotNil(t, tm.listener) - stopper, err := tokenManager.Start(listener) + _, stopper, err := tokenManager.Start(listener) assert.Nil(t, stopper) assert.Error(t, err) // Stop the token manager with internal stop, since stopper should be nil @@ -588,7 +588,8 @@ func TestEntraidTokenManager_GetToken(t *testing.T) { mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) listener.On("OnNext", testTokenValid).Return() - cancel, err := tokenManager.Start(listener) + initialToken, cancel, err := tokenManager.Start(listener) + assert.NotNil(t, initialToken) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -598,6 +599,46 @@ func TestEntraidTokenManager_GetToken(t *testing.T) { assert.NotNil(t, token1) }) + t.Run("GetToken with cached token", func(t *testing.T) { + t.Parallel() + idp := &mockIdentityProvider{} + listener := &mockTokenListener{} + mParser := &mockIdentityProviderResponseParser{} + tokenManager, err := NewTokenManager(idp, + TokenManagerOptions{ + IdentityProviderResponseParser: mParser, + }, + ) + assert.NoError(t, err) + assert.NotNil(t, tokenManager) + tm, ok := tokenManager.(*entraidTokenManager) + assert.True(t, ok) + assert.Nil(t, tm.listener) + + rawResponse := &authResult{ + ResultType: shared.ResponseTypeRawToken, + RawTokenVal: "test", + } + + idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) + mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) + listener.On("OnNext", testTokenValid).Return() + + initialToken, cancel, err := tokenManager.Start(listener) + assert.NotNil(t, initialToken) + assert.NotNil(t, cancel) + assert.NoError(t, err) + assert.NotNil(t, tm.listener) + + token1, err := tokenManager.GetToken(false) + assert.NoError(t, err) + assert.NotNil(t, token1) + + token2, err := tokenManager.GetToken(false) + assert.NoError(t, err) + assert.Equal(t, token1, token2) + }) + t.Run("GetToken with parse error", func(t *testing.T) { t.Parallel() idp := &mockIdentityProvider{} @@ -623,7 +664,7 @@ func TestEntraidTokenManager_GetToken(t *testing.T) { mParser.On("ParseResponse", rawResponse).Return(nil, fmt.Errorf("parse error")) listener.On("OnError", mock.Anything).Return() - cancel, err := tokenManager.Start(listener) + _, cancel, err := tokenManager.Start(listener) assert.Error(t, err) assert.Nil(t, cancel) assert.Nil(t, tm.listener) @@ -814,7 +855,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { mParser.On("ParseResponse", idpResponse).Return(token1, nil).Once() listener.On("OnNext", token1).Return().Once() - stopper, err := tokenManager.Start(listener) + _, stopper, err := tokenManager.Start(listener) assert.NotNil(t, stopper) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -881,7 +922,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() - cancel, err := tokenManager.Start(listener) + _, cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -938,7 +979,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() - cancel, err := tokenManager.Start(listener) + _, cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -991,7 +1032,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() - cancel, err := tokenManager.Start(listener) + _, cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -1041,7 +1082,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { assert.NotNil(t, err) }).Return().Maybe() - cancel, err := tokenManager.Start(listener) + _, cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -1095,7 +1136,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { assert.NotNil(t, err) }).Return() - cancel, err := tokenManager.Start(listener) + _, cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -1174,7 +1215,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { close(maxAttemptsReached) }).Return() - cancel, err := tokenManager.Start(listener) + _, cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -1248,7 +1289,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { close(maxAttemptsReached) }).Return().Maybe() - cancel, err := tokenManager.Start(listener) + _, cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -1338,7 +1379,7 @@ func BenchmarkTokenManager_Start(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, _ = tokenManager.Start(listener) + _, _, _ = tokenManager.Start(listener) } } @@ -1364,7 +1405,7 @@ func BenchmarkTokenManager_Close(b *testing.B) { mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) listener.On("OnNext", testTokenValid).Return() - stopper, err := tokenManager.Start(listener) + _, stopper, err := tokenManager.Start(listener) if err != nil { b.Fatal(err) } @@ -1477,7 +1518,7 @@ func TestConcurrentTokenManagerOperations(t *testing.T) { case 0: // Start the token manager with a new listener // t.Logf("Goroutine %d, Operation %d: Attempting to start token manager", routineID, j) - closeFunc, err := tm.Start(listener) + _, closeFunc, err := tm.Start(listener) if err != nil { if err != ErrTokenManagerAlreadyStarted { @@ -1655,7 +1696,7 @@ func TestConcurrentTokenManagerOperations(t *testing.T) { }, } - closeFunc, err := tm.Start(finalListener) + _, closeFunc, err := tm.Start(finalListener) if err != nil && err != ErrTokenManagerAlreadyStarted { t.Fatalf("Failed to start token manager after concurrent operations: %v", err) } From 41af0a7599797c024c59fa266d3abe1b2fc3d5af Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 13 May 2025 20:15:36 +0300 Subject: [PATCH 42/44] fix(manager): wip, still have tests to resolve --- credentials_provider.go | 21 +-- entraid_test.go | 2 +- manager/defaults.go | 5 +- manager/entraid_manager.go | 267 +++++++++++++++++++++++++++++++++ manager/manager_test.go | 2 +- manager/token_manager.go | 269 +--------------------------------- manager/token_manager_test.go | 118 +++++---------- providers_test.go | 2 +- token/token.go | 3 + 9 files changed, 325 insertions(+), 364 deletions(-) create mode 100644 manager/entraid_manager.go diff --git a/credentials_provider.go b/credentials_provider.go index 64f6f85..0cd86bc 100644 --- a/credentials_provider.go +++ b/credentials_provider.go @@ -67,26 +67,23 @@ func (e *entraidCredentialsProvider) onTokenError(err error) { // // Note: If the listener is already subscribed, it will not receive duplicate notifications. func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) { - var token *token.Token // check if the manager is working // If the stopTokenManager is nil, the token manager is not started. e.tmLock.Lock() if e.stopTokenManager == nil { - t, stopTM, err := e.tokenManager.Start(tokenListenerFromCP(e)) + stopTM, err := e.tokenManager.Start(tokenListenerFromCP(e)) if err != nil { return nil, nil, fmt.Errorf("couldn't start token manager: %w", err) } e.stopTokenManager = stopTM - token = t - } else { - t, err := e.tokenManager.GetToken(false) - if err != nil { - return nil, nil, fmt.Errorf("couldn't get token: %w", err) - } - token = t } e.tmLock.Unlock() + token, err := e.tokenManager.GetToken(false) + if err != nil { + return nil, nil, fmt.Errorf("couldn't get token: %w", err) + } + e.rwLock.Lock() // Check if the listener is already in the list of listeners. alreadySubscribed := false @@ -152,5 +149,11 @@ func NewCredentialsProvider(tokenManager manager.TokenManager, options Credentia options: options, listeners: make([]auth.CredentialsListener, 0), } + // Start the token manager. + stop, err := tokenManager.Start(tokenListenerFromCP(cp)) + if err != nil { + return nil, fmt.Errorf("couldn't start token manager: %w", err) + } + cp.stopTokenManager = stop return cp, nil } diff --git a/entraid_test.go b/entraid_test.go index 0a66685..a824e29 100644 --- a/entraid_test.go +++ b/entraid_test.go @@ -48,7 +48,7 @@ func (m *fakeTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { rawTokenString, time.Now().Add(tokenExpiration), time.Now(), - int64(100*time.Millisecond), + int64(tokenExpiration.Seconds()), ) } return m.token, m.err diff --git a/manager/defaults.go b/manager/defaults.go index f8ff16a..10f4ba5 100644 --- a/manager/defaults.go +++ b/manager/defaults.go @@ -3,6 +3,7 @@ package manager import ( "errors" "fmt" + "math" "net" "os" "time" @@ -105,7 +106,7 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden var username, password, rawToken string var expiresOn time.Time - now := time.Now().UTC().Truncate(time.Second) + now := time.Now().UTC().Truncate(time.Second).Add(time.Second) switch response.Type() { case shared.ResponseTypeAuthResult: @@ -176,6 +177,6 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden rawToken, expiresOn, now, - int64(time.Until(expiresOn).Seconds()), + int64(math.Ceil(time.Until(expiresOn).Seconds())), ), nil } diff --git a/manager/entraid_manager.go b/manager/entraid_manager.go new file mode 100644 index 0000000..4b8c716 --- /dev/null +++ b/manager/entraid_manager.go @@ -0,0 +1,267 @@ +package manager + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/redis-developer/go-redis-entraid/internal" + "github.com/redis-developer/go-redis-entraid/shared" + "github.com/redis-developer/go-redis-entraid/token" +) + +// entraidTokenManager is a struct that implements the TokenManager interface. +type entraidTokenManager struct { + // idp is the identity provider used to obtain the token. + idp shared.IdentityProvider + + // token is the authentication token for the user which should be kept in memory if valid. + token *token.Token + + // tokenRWLock is a read-write lock used to protect the token from concurrent access. + tokenRWLock sync.RWMutex + + // identityProviderResponseParser is the parser used to parse the response from the identity provider. + // It`s ParseResponse method will be called to parse the response and return the token. + identityProviderResponseParser shared.IdentityProviderResponseParser + + // retryOptions is a struct that contains the options for retrying the token request. + // It contains the maximum number of attempts, initial delay, maximum delay, and backoff multiplier. + // The default values are 3 attempts, 1000 ms initial delay, 10000 ms maximum delay, and 2.0 backoff multiplier. + // The values can be overridden by the user. + retryOptions RetryOptions + + // listener is the single listener for the token manager. + // It is used to receive updates from the token manager. + // The token manager will call the listener's OnNext method with the updated token. + // If an error occurs, the token manager will call the listener's OnError method with the error. + // if listener is set, Start will fail + listener TokenListener + + // lock locks the listener to prevent concurrent access. + lock sync.Mutex + + // expirationRefreshRatio is the ratio of the token expiration time to refresh the token. + // It is used to determine when to refresh the token. + // The value should be between 0 and 1. + // For example, if the expiration time is 1 hour and the ratio is 0.75, + // the token will be refreshed after 45 minutes. (the token is refreshed when 75% of its lifetime has passed) + expirationRefreshRatio float64 + + // lowerBoundDuration is the lower bound for the refresh time in time.Duration. + lowerBoundDuration time.Duration + + // closedChan is a channel that is closedChan when the token manager is closedChan. + // It is used to signal the token manager to stop requesting tokens. + closedChan chan struct{} + + // context is the context used to request the token from the identity provider. + ctx context.Context + + // ctxCancel is the cancel function for the context. + ctxCancel context.CancelFunc + + // requestTimeout is the timeout for the request to the identity provider. + requestTimeout time.Duration +} + +func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { + e.tokenRWLock.RLock() + // check if the token is nil and if it is not expired + t := e.token + duration := e.durationToRenewal(t) + if !forceRefresh && t != nil && duration > 0 { + e.tokenRWLock.RUnlock() + return t, nil + } + e.tokenRWLock.RUnlock() + + // start the context early, + // since at heavy concurrent load + // locks may take some time to acquire + ctx, ctxCancel := context.WithTimeout(e.ctx, e.requestTimeout) + defer ctxCancel() + + // Upgrade to write lock for token update + e.tokenRWLock.Lock() + defer e.tokenRWLock.Unlock() + + // Double-check pattern to avoid unnecessary token refresh + t = e.token + duration = e.durationToRenewal(t) + if !forceRefresh && t != nil && duration > 0 { + return t, nil + } + + // Request a new token from the identity provider + idpResult, err := e.idp.RequestToken(ctx) + if err != nil { + return nil, fmt.Errorf("failed to request token from idp: %w", err) + } + + t, err = e.identityProviderResponseParser.ParseResponse(idpResult) + if err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + if t == nil { + return nil, fmt.Errorf("failed to get token: token is nil") + } + + // Store the token + e.token = t + // Return the token - no need to copy since it's immutable + return t, nil +} + +// Start starts the token manager and returns cancelFunc to stop the token manager. +// It takes a TokenListener as an argument, which is used to receive updates. +// The token manager will call the listener's OnNext method with the updated token. +// If an error occurs, the token manager will call the listener's OnError method with the error. +func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) { + e.lock.Lock() + defer e.lock.Unlock() + if e.listener != nil { + return nil, ErrTokenManagerAlreadyStarted + } + + if e.closedChan != nil && !internal.IsClosed(e.closedChan) { + // there is a hanging goroutine that is waiting for the closedChan to be closed + // if the closedChan is not nil and not closed, close it + close(e.closedChan) + } + + ctx, ctxCancel := context.WithCancel(context.Background()) + e.ctx = ctx + e.ctxCancel = ctxCancel + + // make sure there is token in memory before starting the loop + _, err := e.GetToken(false) + if err != nil { + return nil, fmt.Errorf("failed to get token: %w", err) + } + + e.closedChan = make(chan struct{}) + e.listener = listener + + go func(listener TokenListener, closed <-chan struct{}) { + maxDelay := e.retryOptions.MaxDelay + initialDelay := e.retryOptions.InitialDelay + + for { + e.tokenRWLock.RLock() + timeToRenewal := e.durationToRenewal(e.token) + e.tokenRWLock.RUnlock() + select { + case <-closed: + return + case <-time.After(timeToRenewal): + if timeToRenewal == 0 { + // Token was requested immediately, guard against infinite loop + select { + case <-closed: + return + case <-time.After(initialDelay): + // continue to attempt + } + } + + // Token is about to expire, refresh it + delay := initialDelay + for i := 0; i < e.retryOptions.MaxAttempts; i++ { + t, err := e.GetToken(true) + if err == nil { + listener.OnNext(t) + break + } + + // check if err is retriable + if e.retryOptions.IsRetryable(err) { + if i == e.retryOptions.MaxAttempts-1 { + // last attempt, call OnError + listener.OnError(fmt.Errorf("max attempts reached: %w", err)) + return + } + + // Exponential backoff + if delay < maxDelay { + delay = time.Duration(float64(delay) * e.retryOptions.BackoffMultiplier) + } + if delay > maxDelay { + delay = maxDelay + } + + select { + case <-closed: + return + case <-time.After(delay): + // continue to next attempt + } + } else { + // not retriable + listener.OnError(err) + return + } + } + } + } + }(listener, e.closedChan) + + return e.stop, nil +} + +// stop closes the token manager and releases any resources. +func (e *entraidTokenManager) stop() (err error) { + e.lock.Lock() + defer e.lock.Unlock() + defer func() { + // recover from panic and return the error + if r := recover(); r != nil { + err = fmt.Errorf("failed to stop token manager: %s", r) + } + }() + + if e.closedChan == nil || e.listener == nil { + return ErrTokenManagerAlreadyStopped + } + + e.ctxCancel() + e.listener = nil + close(e.closedChan) + + return nil +} + +// durationToRenewal calculates the duration to the next token renewal. +// It returns the duration to the next token renewal based on the expiration refresh ratio and the lower bound duration. +// If the token is nil, it returns 0. +// If the time till expiration is less than the lower bound duration, it returns 0 to renew the token now. +func (e *entraidTokenManager) durationToRenewal(t *token.Token) time.Duration { + if t == nil { + return 0 + } + expirationRefreshTime := t.ReceivedAt().Add(time.Duration(float64(t.TTL()) * float64(time.Second) * e.expirationRefreshRatio)) + timeTillExpiration := time.Until(t.ExpirationOn()) + now := time.Now().UTC() + + if expirationRefreshTime.Before(now) { + return 0 + } + + // if the timeTillExpiration is less than the lower bound (or 0), return 0 to renew the token NOW + if timeTillExpiration <= e.lowerBoundDuration || timeTillExpiration <= 0 { + return 0 + } + + // Calculate the time to renew the token based on the expiration refresh ratio + duration := time.Until(expirationRefreshTime) + + // if the duration will take us past the lower bound, return the duration to lower bound + if timeTillExpiration-e.lowerBoundDuration < duration { + return timeTillExpiration - e.lowerBoundDuration + } + + // return the calculated duration + return duration +} diff --git a/manager/manager_test.go b/manager/manager_test.go index 85eafa0..8a49969 100644 --- a/manager/manager_test.go +++ b/manager/manager_test.go @@ -61,7 +61,7 @@ var testTokenValid = token.New( "test", time.Now().Add(time.Hour), time.Now(), - int64(time.Hour), + int64(time.Hour.Seconds()), ) func newTestJWTToken(expiresOn time.Time) string { diff --git a/manager/token_manager.go b/manager/token_manager.go index 83647f4..f6c8620 100644 --- a/manager/token_manager.go +++ b/manager/token_manager.go @@ -3,10 +3,8 @@ package manager import ( "context" "fmt" - "sync" "time" - "github.com/redis-developer/go-redis-entraid/internal" "github.com/redis-developer/go-redis-entraid/shared" "github.com/redis-developer/go-redis-entraid/token" ) @@ -81,8 +79,8 @@ type TokenManager interface { // GetToken returns the token for authentication. // It takes a boolean value forceRefresh as an argument. GetToken(forceRefresh bool) (*token.Token, error) - // Start starts the token manager and returns a channel that will receive updates. - Start(listener TokenListener) (*token.Token, StopFunc, error) + // Start starts the token manager and returns a stopper function to stop the token manager + Start(listener TokenListener) (StopFunc, error) } // StopFunc is a function that stops the token manager. @@ -129,266 +127,3 @@ func NewTokenManager(idp shared.IdentityProvider, options TokenManagerOptions) ( requestTimeout: options.RequestTimeout, }, nil } - -// entraidTokenManager is a struct that implements the TokenManager interface. -type entraidTokenManager struct { - // idp is the identity provider used to obtain the token. - idp shared.IdentityProvider - - // token is the authentication token for the user which should be kept in memory if valid. - token *token.Token - - // tokenRWLock is a read-write lock used to protect the token from concurrent access. - tokenRWLock sync.RWMutex - - // identityProviderResponseParser is the parser used to parse the response from the identity provider. - // It`s ParseResponse method will be called to parse the response and return the token. - identityProviderResponseParser shared.IdentityProviderResponseParser - - // retryOptions is a struct that contains the options for retrying the token request. - // It contains the maximum number of attempts, initial delay, maximum delay, and backoff multiplier. - // The default values are 3 attempts, 1000 ms initial delay, 10000 ms maximum delay, and 2.0 backoff multiplier. - // The values can be overridden by the user. - retryOptions RetryOptions - - // listener is the single listener for the token manager. - // It is used to receive updates from the token manager. - // The token manager will call the listener's OnNext method with the updated token. - // If an error occurs, the token manager will call the listener's OnError method with the error. - // if listener is set, Start will fail - listener TokenListener - - // lock locks the listener to prevent concurrent access. - lock sync.Mutex - - // expirationRefreshRatio is the ratio of the token expiration time to refresh the token. - // It is used to determine when to refresh the token. - // The value should be between 0 and 1. - // For example, if the expiration time is 1 hour and the ratio is 0.75, - // the token will be refreshed after 45 minutes. (the token is refreshed when 75% of its lifetime has passed) - expirationRefreshRatio float64 - - // lowerBoundDuration is the lower bound for the refresh time in time.Duration. - lowerBoundDuration time.Duration - - // closedChan is a channel that is closedChan when the token manager is closedChan. - // It is used to signal the token manager to stop requesting tokens. - closedChan chan struct{} - - // context is the context used to request the token from the identity provider. - ctx context.Context - - // ctxCancel is the cancel function for the context. - ctxCancel context.CancelFunc - - // requestTimeout is the timeout for the request to the identity provider. - requestTimeout time.Duration -} - -func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error) { - e.tokenRWLock.RLock() - // check if the token is nil and if it is not expired - - if !forceRefresh && e.token != nil && time.Now().Add(e.lowerBoundDuration).Before(e.token.ExpirationOn()) { - t := e.token - e.tokenRWLock.RUnlock() - return t, nil - } - e.tokenRWLock.RUnlock() - - // start the context early, - // since at heavy concurrent load - // locks may take some time to acquire - ctx, ctxCancel := context.WithTimeout(e.ctx, e.requestTimeout) - defer ctxCancel() - - // Upgrade to write lock for token update - e.tokenRWLock.Lock() - defer e.tokenRWLock.Unlock() - - // Double-check pattern to avoid unnecessary token refresh - if !forceRefresh && e.token != nil && time.Now().Add(e.lowerBoundDuration).Before(e.token.ExpirationOn()) { - return e.token, nil - } - - // Request a new token from the identity provider - idpResult, err := e.idp.RequestToken(ctx) - if err != nil { - return nil, fmt.Errorf("failed to request token from idp: %w", err) - } - - t, err := e.identityProviderResponseParser.ParseResponse(idpResult) - if err != nil { - return nil, fmt.Errorf("failed to parse token: %w", err) - } - - if t == nil { - return nil, fmt.Errorf("failed to get token: token is nil") - } - - // Store the token - e.token = t - // Return the token - no need to copy since it's immutable - return t, nil -} - -// Start starts the token manager and returns cancelFunc to stop the token manager. -// It takes a TokenListener as an argument, which is used to receive updates. -// The token manager will call the listener's OnNext method with the updated token. -// If an error occurs, the token manager will call the listener's OnError method with the error. -// -// Note: The initial token is delivered synchronously. -// The TokenListener will receive the token immediately, before the token manager goroutine starts. -func (e *entraidTokenManager) Start(listener TokenListener) (*token.Token, StopFunc, error) { - e.lock.Lock() - defer e.lock.Unlock() - if e.listener != nil { - return nil, nil, ErrTokenManagerAlreadyStarted - } - - if e.closedChan != nil && !internal.IsClosed(e.closedChan) { - // there is a hanging goroutine that is waiting for the closedChan to be closed - // if the closedChan is not nil and not closed, close it - close(e.closedChan) - } - - ctx, ctxCancel := context.WithCancel(context.Background()) - e.ctx = ctx - e.ctxCancel = ctxCancel - - t, err := e.GetToken(false) - // If a token was found in the cache, check if: - // - it is expired (based on the lower bound) - // - it is about to expire (based on the expiration refresh ratio) - // if so, get a new token - expirationRefreshTime := t.ReceivedAt().Add(time.Duration(float64(t.TTL()) * float64(time.Second) * e.expirationRefreshRatio)) - expirationWithoutLowerBound := t.ExpirationOn().Add(-1 * e.lowerBoundDuration) - now := time.Now() - if t != nil && (expirationWithoutLowerBound.Before(now) || expirationRefreshTime.Before(now)) { - t, err = e.GetToken(true) - } - if err != nil { - return nil, nil, fmt.Errorf("failed to start token manager: %w", err) - } - - e.closedChan = make(chan struct{}) - e.listener = listener - - go func(listener TokenListener, closed <-chan struct{}) { - maxDelay := e.retryOptions.MaxDelay - initialDelay := e.retryOptions.InitialDelay - - for { - timeToRenewal := e.durationToRenewal() - select { - case <-closed: - return - case <-time.After(timeToRenewal): - if timeToRenewal == 0 { - // Token was requested immediately, guard against infinite loop - select { - case <-closed: - return - case <-time.After(initialDelay): - // continue to attempt - } - } - - // Token is about to expire, refresh it - delay := initialDelay - for i := 0; i < e.retryOptions.MaxAttempts; i++ { - t, err := e.GetToken(true) - if err == nil { - listener.OnNext(t) - break - } - - // check if err is retriable - if e.retryOptions.IsRetryable(err) { - if i == e.retryOptions.MaxAttempts-1 { - // last attempt, call OnError - listener.OnError(fmt.Errorf("max attempts reached: %w", err)) - return - } - - // Exponential backoff - if delay < maxDelay { - delay = time.Duration(float64(delay) * e.retryOptions.BackoffMultiplier) - } - if delay > maxDelay { - delay = maxDelay - } - - select { - case <-closed: - return - case <-time.After(delay): - // continue to next attempt - } - } else { - // not retriable - listener.OnError(err) - return - } - } - } - } - }(listener, e.closedChan) - - return t, e.stop, nil -} - -// stop closes the token manager and releases any resources. -func (e *entraidTokenManager) stop() (err error) { - e.lock.Lock() - defer e.lock.Unlock() - defer func() { - // recover from panic and return the error - if r := recover(); r != nil { - err = fmt.Errorf("failed to stop token manager: %s", r) - } - }() - - if e.closedChan == nil || e.listener == nil { - return ErrTokenManagerAlreadyStopped - } - - e.ctxCancel() - e.listener = nil - close(e.closedChan) - - return nil -} - -// durationToRenewal calculates the duration to the next token renewal. -// It returns the duration to the next token renewal based on the expiration refresh ratio and the lower bound duration. -// If the token is nil, it returns 0. -// If the time till expiration is less than the lower bound duration, it returns 0 to renew the token now. -func (e *entraidTokenManager) durationToRenewal() time.Duration { - e.tokenRWLock.RLock() - if e.token == nil { - e.tokenRWLock.RUnlock() - return 0 - } - - timeTillExpiration := time.Until(e.token.ExpirationOn()) - e.tokenRWLock.RUnlock() - - // if the timeTillExpiration is less than the lower bound (or 0), return 0 to renew the token NOW - if timeTillExpiration <= e.lowerBoundDuration || timeTillExpiration <= 0 { - return 0 - } - - // Calculate the time to renew the token based on the expiration refresh ratio - // Since timeTillExpiration is guarded by the lower bound, we can safely multiply it by the ratio - // and assume the duration is a positive number - duration := time.Duration(float64(timeTillExpiration) * e.expirationRefreshRatio) - - // if the duration will take us past the lower bound, return the duration to lower bound - if timeTillExpiration-e.lowerBoundDuration < duration { - return timeTillExpiration - e.lowerBoundDuration - } - - // return the calculated duration - return duration -} diff --git a/manager/token_manager_test.go b/manager/token_manager_test.go index e1d5d27..f634197 100644 --- a/manager/token_manager_test.go +++ b/manager/token_manager_test.go @@ -182,7 +182,7 @@ func TestTokenManager_Close(t *testing.T) { var stopper StopFunc assert.NotPanics(t, func() { - _, stopper, err = tokenManager.Start(listener) + stopper, err = tokenManager.Start(listener) assert.NotNil(t, stopper) assert.NoError(t, err) }) @@ -222,7 +222,7 @@ func TestTokenManager_Close(t *testing.T) { listener.On("OnNext", testTokenValid).Return() assert.NotPanics(t, func() { - _, cancel, err := tokenManager.Start(listener) + cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -258,7 +258,7 @@ func TestTokenManager_Close(t *testing.T) { listener.On("OnNext", testTokenValid).Return() assert.NotPanics(t, func() { - _, stopper, err := tokenManager.Start(listener) + stopper, err := tokenManager.Start(listener) assert.NotNil(t, stopper) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -329,7 +329,7 @@ func TestTokenManager_Start(t *testing.T) { go func() { defer wg.Done() time.Sleep(time.Duration(int64(rand.Intn(100)) * int64(time.Millisecond))) - _, _, err := tokenManager.Start(listener) + _, err := tokenManager.Start(listener) if err == nil { hasStarted += 1 return @@ -344,7 +344,7 @@ func TestTokenManager_Start(t *testing.T) { assert.NotNil(t, tm.listener) assert.Equal(t, 1, hasStarted) assert.Equal(t, int32(numExecutions-1), atomic.LoadInt32(&alreadyStarted)) - _, cancel, err := tokenManager.Start(listener) + cancel, err := tokenManager.Start(listener) assert.Nil(t, cancel) assert.Error(t, err) assert.NotNil(t, tm.listener) @@ -389,7 +389,7 @@ func TestTokenManager_Start(t *testing.T) { } else { l := &mockTokenListener{Id: num} l.On("OnNext", testTokenValid).Return() - _, _, err = tokenManager.Start(l) + _, err = tokenManager.Start(l) } if err != nil { if err != ErrTokenManagerAlreadyStopped && err != ErrTokenManagerAlreadyStarted { @@ -412,7 +412,7 @@ func TestTokenManager_Start(t *testing.T) { log.Printf("FAILING WITH lastExecution[STOPPED]: %d", lastExecution) } assert.NotNil(t, tm.listener) - _, stopper, err := tokenManager.Start(listener) + stopper, err := tokenManager.Start(listener) assert.Nil(t, stopper) assert.Error(t, err) // Stop the token manager with internal stop, since stopper should be nil @@ -588,8 +588,7 @@ func TestEntraidTokenManager_GetToken(t *testing.T) { mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) listener.On("OnNext", testTokenValid).Return() - initialToken, cancel, err := tokenManager.Start(listener) - assert.NotNil(t, initialToken) + cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -599,46 +598,6 @@ func TestEntraidTokenManager_GetToken(t *testing.T) { assert.NotNil(t, token1) }) - t.Run("GetToken with cached token", func(t *testing.T) { - t.Parallel() - idp := &mockIdentityProvider{} - listener := &mockTokenListener{} - mParser := &mockIdentityProviderResponseParser{} - tokenManager, err := NewTokenManager(idp, - TokenManagerOptions{ - IdentityProviderResponseParser: mParser, - }, - ) - assert.NoError(t, err) - assert.NotNil(t, tokenManager) - tm, ok := tokenManager.(*entraidTokenManager) - assert.True(t, ok) - assert.Nil(t, tm.listener) - - rawResponse := &authResult{ - ResultType: shared.ResponseTypeRawToken, - RawTokenVal: "test", - } - - idp.On("RequestToken", mock.Anything).Return(rawResponse, nil) - mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) - listener.On("OnNext", testTokenValid).Return() - - initialToken, cancel, err := tokenManager.Start(listener) - assert.NotNil(t, initialToken) - assert.NotNil(t, cancel) - assert.NoError(t, err) - assert.NotNil(t, tm.listener) - - token1, err := tokenManager.GetToken(false) - assert.NoError(t, err) - assert.NotNil(t, token1) - - token2, err := tokenManager.GetToken(false) - assert.NoError(t, err) - assert.Equal(t, token1, token2) - }) - t.Run("GetToken with parse error", func(t *testing.T) { t.Parallel() idp := &mockIdentityProvider{} @@ -664,7 +623,7 @@ func TestEntraidTokenManager_GetToken(t *testing.T) { mParser.On("ParseResponse", rawResponse).Return(nil, fmt.Errorf("parse error")) listener.On("OnError", mock.Anything).Return() - _, cancel, err := tokenManager.Start(listener) + cancel, err := tokenManager.Start(listener) assert.Error(t, err) assert.Nil(t, cancel) assert.Nil(t, tm.listener) @@ -774,7 +733,7 @@ func TestEntraidTokenManager_durationToRenewal(t *testing.T) { tm, ok := tokenManager.(*entraidTokenManager) assert.True(t, ok) - result := tm.durationToRenewal() + result := tm.durationToRenewal(nil) // returns 0 for nil token assert.Equal(t, time.Duration(0), result) @@ -791,7 +750,7 @@ func TestEntraidTokenManager_durationToRenewal(t *testing.T) { assert.NotNil(t, tm.token) // return zero, should happen now since it expires before the lower bound - result = tm.durationToRenewal() + result = tm.durationToRenewal(tm.token) assert.Equal(t, time.Duration(0), result) }) @@ -809,7 +768,7 @@ func TestEntraidTokenManager_durationToRenewal(t *testing.T) { assert.NotNil(t, tm.token) // return time to lower bound, if the returned time will be after the lower bound - result = tm.durationToRenewal() + result = tm.durationToRenewal(tm.token) assert.InEpsilon(t, time.Until(tm.token.ExpirationOn().Add(-1*tm.lowerBoundDuration)), result, float64(time.Second)) }) @@ -853,14 +812,13 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { ) mParser.On("ParseResponse", idpResponse).Return(token1, nil).Once() - listener.On("OnNext", token1).Return().Once() - _, stopper, err := tokenManager.Start(listener) + stopper, err := tokenManager.Start(listener) assert.NotNil(t, stopper) assert.NoError(t, err) assert.NotNil(t, tm.listener) - toRenewal := tm.durationToRenewal() + toRenewal := tm.durationToRenewal(tm.token) assert.NotEqual(t, time.Duration(0), toRenewal) assert.NotEqual(t, expiresIn, toRenewal) assert.True(t, expiresIn > toRenewal) @@ -922,12 +880,12 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() - _, cancel, err := tokenManager.Start(listener) + cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) - toRenewal := tm.durationToRenewal() + toRenewal := tm.durationToRenewal(tm.token) assert.Equal(t, time.Duration(0), toRenewal) assert.True(t, expiresIn > toRenewal) @@ -940,7 +898,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { assert.InDelta(t, stop.Sub(start), tm.retryOptions.InitialDelay, float64(200*time.Millisecond)) idp.AssertNumberOfCalls(t, "RequestToken", 2) - listener.AssertNumberOfCalls(t, "OnNext", 2) + listener.AssertNumberOfCalls(t, "OnNext", 1) mock.AssertExpectationsForObjects(t, idp, listener) }) @@ -977,14 +935,12 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { idpResponse.AuthResultVal = res }).Return(idpResponse, nil) - listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() - - _, cancel, err := tokenManager.Start(listener) + cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) - toRenewal := tm.durationToRenewal() + toRenewal := tm.durationToRenewal(tm.token) assert.Equal(t, time.Duration(0), toRenewal) assert.True(t, expiresIn > toRenewal) @@ -997,7 +953,6 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { // called only once since the token manager was closed prior to initial delay passing idp.AssertNumberOfCalls(t, "RequestToken", 1) - listener.AssertNumberOfCalls(t, "OnNext", 1) mock.AssertExpectationsForObjects(t, idp, listener) }) @@ -1032,12 +987,13 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() - _, cancel, err := tokenManager.Start(listener) + cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) + assert.NotNil(t, tm.token) - toRenewal := tm.durationToRenewal() + toRenewal := tm.durationToRenewal(tm.token) assert.NotEqual(t, time.Duration(0), toRenewal) assert.NotEqual(t, expiresIn, toRenewal) assert.True(t, expiresIn > toRenewal) @@ -1076,13 +1032,12 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { idpResponse.AuthResultVal = res }).Return(idpResponse, nil) - listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() listener.On("OnError", mock.Anything).Run(func(args mock.Arguments) { err := args.Get(0) assert.NotNil(t, err) }).Return().Maybe() - _, cancel, err := tokenManager.Start(listener) + cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -1091,13 +1046,12 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { returnErr := newMockError(true) idp.On("RequestToken", mock.Anything).Return(nil, returnErr) - toRenewal := tm.durationToRenewal() + toRenewal := tm.durationToRenewal(tm.token) assert.NotEqual(t, time.Duration(0), toRenewal) assert.NotEqual(t, expiresIn, toRenewal) assert.True(t, expiresIn > toRenewal) <-time.After(toRenewal + 100*time.Millisecond) idp.AssertNumberOfCalls(t, "RequestToken", 2) - listener.AssertNumberOfCalls(t, "OnNext", 1) mock.AssertExpectationsForObjects(t, idp, listener) }) @@ -1130,13 +1084,12 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { idpResponse.AuthResultVal = res }).Return(idpResponse, nil) - listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() listener.On("OnError", mock.Anything).Run(func(args mock.Arguments) { err := args.Get(0).(error) assert.NotNil(t, err) }).Return() - _, cancel, err := tokenManager.Start(listener) + cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -1145,14 +1098,13 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { returnErr := newMockError(false) idp.On("RequestToken", mock.Anything).Return(nil, returnErr) - toRenewal := tm.durationToRenewal() + toRenewal := tm.durationToRenewal(tm.token) assert.NotEqual(t, time.Duration(0), toRenewal) assert.NotEqual(t, expiresIn, toRenewal) assert.True(t, expiresIn > toRenewal) <-time.After(toRenewal + 100*time.Millisecond) idp.AssertNumberOfCalls(t, "RequestToken", 2) - listener.AssertNumberOfCalls(t, "OnNext", 1) listener.AssertNumberOfCalls(t, "OnError", 1) mock.AssertExpectationsForObjects(t, idp, listener) }) @@ -1215,11 +1167,11 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { close(maxAttemptsReached) }).Return() - _, cancel, err := tokenManager.Start(listener) + cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) - toRenewal := tm.durationToRenewal() + toRenewal := tm.durationToRenewal(tm.token) assert.NotEqual(t, time.Duration(0), toRenewal) assert.NotEqual(t, expiresIn, toRenewal) assert.True(t, expiresIn > toRenewal) @@ -1289,7 +1241,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { close(maxAttemptsReached) }).Return().Maybe() - _, cancel, err := tokenManager.Start(listener) + cancel, err := tokenManager.Start(listener) assert.NotNil(t, cancel) assert.NoError(t, err) assert.NotNil(t, tm.listener) @@ -1298,7 +1250,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { returnErr := newMockError(true) idp.On("RequestToken", mock.Anything).Return(nil, returnErr) - toRenewal := tm.durationToRenewal() + toRenewal := tm.durationToRenewal(tm.token) assert.NotEqual(t, time.Duration(0), toRenewal) assert.NotEqual(t, expiresIn, toRenewal) assert.True(t, expiresIn > toRenewal) @@ -1379,7 +1331,7 @@ func BenchmarkTokenManager_Start(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, _, _ = tokenManager.Start(listener) + _, _ = tokenManager.Start(listener) } } @@ -1405,7 +1357,7 @@ func BenchmarkTokenManager_Close(b *testing.B) { mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil) listener.On("OnNext", testTokenValid).Return() - _, stopper, err := tokenManager.Start(listener) + stopper, err := tokenManager.Start(listener) if err != nil { b.Fatal(err) } @@ -1444,7 +1396,7 @@ func BenchmarkTokenManager_durationToRenewal(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - tm.durationToRenewal() + tm.durationToRenewal(tm.token) } } @@ -1518,7 +1470,7 @@ func TestConcurrentTokenManagerOperations(t *testing.T) { case 0: // Start the token manager with a new listener // t.Logf("Goroutine %d, Operation %d: Attempting to start token manager", routineID, j) - _, closeFunc, err := tm.Start(listener) + closeFunc, err := tm.Start(listener) if err != nil { if err != ErrTokenManagerAlreadyStarted { @@ -1696,7 +1648,7 @@ func TestConcurrentTokenManagerOperations(t *testing.T) { }, } - _, closeFunc, err := tm.Start(finalListener) + closeFunc, err := tm.Start(finalListener) if err != nil && err != ErrTokenManagerAlreadyStarted { t.Fatalf("Failed to start token manager after concurrent operations: %v", err) } diff --git a/providers_test.go b/providers_test.go index 88d88b3..604ddac 100644 --- a/providers_test.go +++ b/providers_test.go @@ -290,7 +290,7 @@ func TestCredentialsProviderInterface(t *testing.T) { rawTokenString, time.Now().Add(time.Hour), time.Now(), - int64(time.Hour), + int64(time.Hour.Seconds()), ) // Set the token manager factory in the options diff --git a/token/token.go b/token/token.go index fafce60..e696f77 100644 --- a/token/token.go +++ b/token/token.go @@ -56,6 +56,9 @@ func (t *Token) RawToken() string { // ReceivedAt returns the time when the token was received. func (t *Token) ReceivedAt() time.Time { + if t.receivedAt.IsZero() { + return time.Now() + } return t.receivedAt } From 7e3317a4d0f0c79990a4591c4179f186c1c93751 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 13 May 2025 21:21:09 +0300 Subject: [PATCH 43/44] fix(tests): fix manager tests --- manager/token_manager_test.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/manager/token_manager_test.go b/manager/token_manager_test.go index f634197..6a0944e 100644 --- a/manager/token_manager_test.go +++ b/manager/token_manager_test.go @@ -1156,7 +1156,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { On("OnNext", mock.AnythingOfType("*token.Token")). Run(func(_ mock.Arguments) { start = time.Now() - }).Return() + }).Return().Maybe() maxAttemptsReached := make(chan struct{}) listener.On("OnError", mock.Anything).Run(func(args mock.Arguments) { err := args.Get(0).(error) @@ -1195,8 +1195,6 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { assert.InEpsilon(t, elapsed, allDelaysShouldBe, float64(10*time.Millisecond)) idp.AssertNumberOfCalls(t, "RequestToken", tm.retryOptions.MaxAttempts+1) - listener.AssertNumberOfCalls(t, "OnNext", 1) - listener.AssertNumberOfCalls(t, "OnError", 1) mock.AssertExpectationsForObjects(t, idp, listener) }) t.Run("Start and Listen and close during retries", func(t *testing.T) { @@ -1232,7 +1230,6 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { idpResponse.AuthResultVal = res }).Return(idpResponse, nil) - listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return() maxAttemptsReached := make(chan struct{}) listener.On("OnError", mock.Anything).Run(func(args mock.Arguments) { err := args.Get(0).(error) @@ -1268,7 +1265,6 @@ func TestEntraidTokenManager_Streaming(t *testing.T) { // maxAttempts + the initial one idp.AssertNumberOfCalls(t, "RequestToken", 2) - listener.AssertNumberOfCalls(t, "OnError", 0) mock.AssertExpectationsForObjects(t, idp, listener) }) } From 5a37c4b81b65b98504ad87c11f26370b71881a0c Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 13 May 2025 21:35:43 +0300 Subject: [PATCH 44/44] fix(tests): 100% coverage for token package --- token/token_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/token/token_test.go b/token/token_test.go index 72a54ad..c845dea 100644 --- a/token/token_test.go +++ b/token/token_test.go @@ -123,6 +123,10 @@ func TestTokenReceivedAt(t *testing.T) { assert.NotSame(t, token, tcopiedToken) // Check if the copied token is a new instance assert.NotNil(t, tcopiedToken) + + emptyRecievedAt := &Token{} + assert.True(t, emptyRecievedAt.ReceivedAt().After(time.Now().Add(-1*time.Hour))) + assert.True(t, emptyRecievedAt.ReceivedAt().Before(time.Now().Add(1*time.Hour))) } func BenchmarkNew(b *testing.B) {