|
| 1 | +package internal |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + |
| 6 | + "github.com/Azure/azure-sdk-for-go/sdk/azcore" |
| 7 | + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" |
| 8 | +) |
| 9 | + |
| 10 | +// IDPResp represents a response from an Identity Provider (IDP) |
| 11 | +// It can contain either an AuthResult, AccessToken, or a raw token string |
| 12 | +type IDPResp struct { |
| 13 | + // resultType indicates which type of response this is |
| 14 | + resultType string |
| 15 | + authResultVal *public.AuthResult |
| 16 | + accessTokenVal *azcore.AccessToken |
| 17 | + rawTokenVal string |
| 18 | +} |
| 19 | + |
| 20 | +// NewIDPResp creates a new IDPResp with the given values |
| 21 | +// It validates the input and ensures the response type matches the provided value |
| 22 | +func NewIDPResp(resultType string, result interface{}) (*IDPResp, error) { |
| 23 | + if result == nil { |
| 24 | + return nil, fmt.Errorf("result cannot be nil") |
| 25 | + } |
| 26 | + |
| 27 | + r := &IDPResp{resultType: resultType} |
| 28 | + |
| 29 | + switch resultType { |
| 30 | + case "AuthResult": |
| 31 | + switch v := result.(type) { |
| 32 | + case *public.AuthResult: |
| 33 | + r.authResultVal = v |
| 34 | + case public.AuthResult: |
| 35 | + r.authResultVal = &v |
| 36 | + default: |
| 37 | + return nil, fmt.Errorf("invalid auth result type: expected public.AuthResult or *public.AuthResult, got %T", result) |
| 38 | + } |
| 39 | + case "AccessToken": |
| 40 | + switch v := result.(type) { |
| 41 | + case *azcore.AccessToken: |
| 42 | + r.accessTokenVal = v |
| 43 | + r.rawTokenVal = v.Token |
| 44 | + case azcore.AccessToken: |
| 45 | + r.accessTokenVal = &v |
| 46 | + r.rawTokenVal = v.Token |
| 47 | + default: |
| 48 | + return nil, fmt.Errorf("invalid access token type: expected azcore.AccessToken or *azcore.AccessToken, got %T", result) |
| 49 | + } |
| 50 | + case "RawToken": |
| 51 | + switch v := result.(type) { |
| 52 | + case string: |
| 53 | + r.rawTokenVal = v |
| 54 | + case *string: |
| 55 | + if v == nil { |
| 56 | + return nil, fmt.Errorf("raw token cannot be nil") |
| 57 | + } |
| 58 | + r.rawTokenVal = *v |
| 59 | + default: |
| 60 | + return nil, fmt.Errorf("invalid raw token type: expected string or *string, got %T", result) |
| 61 | + } |
| 62 | + default: |
| 63 | + return nil, fmt.Errorf("unsupported identity provider response type: %s", resultType) |
| 64 | + } |
| 65 | + |
| 66 | + return r, nil |
| 67 | +} |
| 68 | + |
| 69 | +// Type returns the type of response this IDPResp represents |
| 70 | +func (a *IDPResp) Type() string { |
| 71 | + return a.resultType |
| 72 | +} |
| 73 | + |
| 74 | +// AuthResult returns the AuthResult if present, or an empty AuthResult if not set |
| 75 | +// Use HasAuthResult() to check if the value is actually set |
| 76 | +func (a *IDPResp) AuthResult() public.AuthResult { |
| 77 | + if a.authResultVal == nil { |
| 78 | + return public.AuthResult{} |
| 79 | + } |
| 80 | + return *a.authResultVal |
| 81 | +} |
| 82 | + |
| 83 | +// HasAuthResult returns true if an AuthResult is set |
| 84 | +func (a *IDPResp) HasAuthResult() bool { |
| 85 | + return a.authResultVal != nil |
| 86 | +} |
| 87 | + |
| 88 | +// AccessToken returns the AccessToken if present, or an empty AccessToken if not set |
| 89 | +// Use HasAccessToken() to check if the value is actually set |
| 90 | +func (a *IDPResp) AccessToken() azcore.AccessToken { |
| 91 | + if a.accessTokenVal == nil { |
| 92 | + return azcore.AccessToken{} |
| 93 | + } |
| 94 | + return *a.accessTokenVal |
| 95 | +} |
| 96 | + |
| 97 | +// HasAccessToken returns true if an AccessToken is set |
| 98 | +func (a *IDPResp) HasAccessToken() bool { |
| 99 | + return a.accessTokenVal != nil |
| 100 | +} |
| 101 | + |
| 102 | +// RawToken returns the raw token string |
| 103 | +func (a *IDPResp) RawToken() string { |
| 104 | + return a.rawTokenVal |
| 105 | +} |
| 106 | + |
| 107 | +// HasRawToken returns true if a raw token is set |
| 108 | +func (a *IDPResp) HasRawToken() bool { |
| 109 | + return a.rawTokenVal != "" |
| 110 | +} |
0 commit comments