|
| 1 | +package oauthserver |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + |
| 6 | + "github.com/supabase/auth/internal/models" |
| 7 | +) |
| 8 | + |
| 9 | +// InferClientTypeFromAuthMethod infers client type from token_endpoint_auth_method |
| 10 | +func InferClientTypeFromAuthMethod(authMethod string) string { |
| 11 | + switch authMethod { |
| 12 | + case models.TokenEndpointAuthMethodNone: |
| 13 | + return models.OAuthServerClientTypePublic |
| 14 | + case models.TokenEndpointAuthMethodClientSecretBasic, models.TokenEndpointAuthMethodClientSecretPost: |
| 15 | + return models.OAuthServerClientTypeConfidential |
| 16 | + default: |
| 17 | + return models.OAuthServerClientTypeConfidential // Default to confidential |
| 18 | + } |
| 19 | +} |
| 20 | + |
| 21 | +// GetValidAuthMethodsForClientType returns the valid authentication methods for a client type |
| 22 | +func GetValidAuthMethodsForClientType(clientType string) []string { |
| 23 | + switch clientType { |
| 24 | + case models.OAuthServerClientTypePublic: |
| 25 | + return []string{models.TokenEndpointAuthMethodNone} |
| 26 | + case models.OAuthServerClientTypeConfidential: |
| 27 | + return []string{ |
| 28 | + models.TokenEndpointAuthMethodClientSecretBasic, |
| 29 | + models.TokenEndpointAuthMethodClientSecretPost, |
| 30 | + } |
| 31 | + default: |
| 32 | + return []string{} // Unknown client type |
| 33 | + } |
| 34 | +} |
| 35 | + |
| 36 | +// ValidateClientTypeConsistency validates consistency between client_type and token_endpoint_auth_method |
| 37 | +func ValidateClientTypeConsistency(clientType, authMethod string) error { |
| 38 | + if clientType == "" || authMethod == "" { |
| 39 | + return nil // Skip validation if either is not provided |
| 40 | + } |
| 41 | + |
| 42 | + expectedClientType := InferClientTypeFromAuthMethod(authMethod) |
| 43 | + if clientType != expectedClientType { |
| 44 | + return fmt.Errorf("client_type '%s' is inconsistent with token_endpoint_auth_method '%s' (expected client_type '%s')", |
| 45 | + clientType, authMethod, expectedClientType) |
| 46 | + } |
| 47 | + |
| 48 | + return nil |
| 49 | +} |
| 50 | + |
| 51 | +// IsValidAuthMethodForClientType checks if the auth method is valid for the given client type |
| 52 | +func IsValidAuthMethodForClientType(clientType, authMethod string) bool { |
| 53 | + validMethods := GetValidAuthMethodsForClientType(clientType) |
| 54 | + for _, method := range validMethods { |
| 55 | + if method == authMethod { |
| 56 | + return true |
| 57 | + } |
| 58 | + } |
| 59 | + return false |
| 60 | +} |
| 61 | + |
| 62 | +// DetermineClientType determines the final client type using the priority: |
| 63 | +// 1. Explicit client_type |
| 64 | +// 2. Inferred from token_endpoint_auth_method |
| 65 | +// 3. Default to confidential |
| 66 | +func DetermineClientType(explicitClientType, authMethod string) string { |
| 67 | + // Priority 1: Explicit client_type |
| 68 | + if explicitClientType != "" { |
| 69 | + return explicitClientType |
| 70 | + } |
| 71 | + |
| 72 | + // Priority 2: Infer from token_endpoint_auth_method |
| 73 | + if authMethod != "" { |
| 74 | + return InferClientTypeFromAuthMethod(authMethod) |
| 75 | + } |
| 76 | + |
| 77 | + // Priority 3: Default to confidential |
| 78 | + return models.OAuthServerClientTypeConfidential |
| 79 | +} |
| 80 | + |
| 81 | +// ValidateClientAuthentication validates client authentication based on client type |
| 82 | +func ValidateClientAuthentication(client *models.OAuthServerClient, providedSecret string) error { |
| 83 | + if client.IsPublic() { |
| 84 | + // Public clients should not provide client secrets |
| 85 | + if providedSecret != "" { |
| 86 | + return fmt.Errorf("public clients must not provide client_secret") |
| 87 | + } |
| 88 | + return nil |
| 89 | + } |
| 90 | + |
| 91 | + // Confidential clients must provide a valid client secret |
| 92 | + if providedSecret == "" { |
| 93 | + return fmt.Errorf("confidential clients must provide client_secret") |
| 94 | + } |
| 95 | + |
| 96 | + if !ValidateClientSecret(providedSecret, client.ClientSecretHash) { |
| 97 | + return fmt.Errorf("invalid client credentials") |
| 98 | + } |
| 99 | + |
| 100 | + return nil |
| 101 | +} |
| 102 | + |
| 103 | +// GetAllValidAuthMethods returns all supported authentication methods |
| 104 | +func GetAllValidAuthMethods() []string { |
| 105 | + return []string{ |
| 106 | + models.TokenEndpointAuthMethodNone, |
| 107 | + models.TokenEndpointAuthMethodClientSecretBasic, |
| 108 | + models.TokenEndpointAuthMethodClientSecretPost, |
| 109 | + } |
| 110 | +} |
0 commit comments