diff --git a/cmd/gen.go b/cmd/gen.go index 48a800f17..8893988d4 100644 --- a/cmd/gen.go +++ b/cmd/gen.go @@ -13,6 +13,7 @@ import ( "github.com/supabase/cli/internal/gen/types" "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/internal/utils/flags" + "github.com/supabase/cli/pkg/config" ) var ( @@ -97,7 +98,7 @@ var ( algorithm = utils.EnumFlag{ Allowed: signingkeys.GetSupportedAlgorithms(), - Value: string(signingkeys.AlgES256), + Value: string(config.AlgES256), } appendKeys bool diff --git a/internal/gen/signingkeys/signingkeys.go b/internal/gen/signingkeys/signingkeys.go index 7394a99c4..10bdd646d 100644 --- a/internal/gen/signingkeys/signingkeys.go +++ b/internal/gen/signingkeys/signingkeys.go @@ -20,58 +20,29 @@ import ( "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/internal/utils/flags" "github.com/supabase/cli/pkg/cast" + "github.com/supabase/cli/pkg/config" ) -type Algorithm string - -const ( - AlgRS256 Algorithm = "RS256" - AlgES256 Algorithm = "ES256" -) - -type JWK struct { - KeyType string `json:"kty"` - KeyID string `json:"kid,omitempty"` - Use string `json:"use,omitempty"` - KeyOps []string `json:"key_ops,omitempty"` - Algorithm string `json:"alg,omitempty"` - Extractable *bool `json:"ext,omitempty"` - // RSA specific fields - Modulus string `json:"n,omitempty"` - Exponent string `json:"e,omitempty"` - // RSA private key fields - PrivateExponent string `json:"d,omitempty"` - FirstPrimeFactor string `json:"p,omitempty"` - SecondPrimeFactor string `json:"q,omitempty"` - FirstFactorCRTExponent string `json:"dp,omitempty"` - SecondFactorCRTExponent string `json:"dq,omitempty"` - FirstCRTCoefficient string `json:"qi,omitempty"` - // EC specific fields - Curve string `json:"crv,omitempty"` - X string `json:"x,omitempty"` - Y string `json:"y,omitempty"` -} - type KeyPair struct { - PublicKey JWK - PrivateKey JWK + PublicKey config.JWK + PrivateKey config.JWK } // GenerateKeyPair generates a new key pair for the specified algorithm -func GenerateKeyPair(alg Algorithm) (*KeyPair, error) { - keyID := uuid.New().String() +func GenerateKeyPair(alg config.Algorithm) (*KeyPair, error) { + keyID := uuid.New() switch alg { - case AlgRS256: + case config.AlgRS256: return generateRSAKeyPair(keyID) - case AlgES256: + case config.AlgES256: return generateECDSAKeyPair(keyID) default: return nil, errors.Errorf("unsupported algorithm: %s", alg) } } -func generateRSAKeyPair(keyID string) (*KeyPair, error) { +func generateRSAKeyPair(keyID uuid.UUID) (*KeyPair, error) { // Generate RSA key pair (2048 bits for RS256) privateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { @@ -84,7 +55,7 @@ func generateRSAKeyPair(keyID string) (*KeyPair, error) { privateKey.Precompute() // Convert to JWK format - privateJWK := JWK{ + privateJWK := config.JWK{ KeyType: "RSA", KeyID: keyID, Use: "sig", @@ -101,7 +72,7 @@ func generateRSAKeyPair(keyID string) (*KeyPair, error) { FirstCRTCoefficient: base64.RawURLEncoding.EncodeToString(privateKey.Precomputed.Qinv.Bytes()), } - publicJWK := JWK{ + publicJWK := config.JWK{ KeyType: "RSA", KeyID: keyID, Use: "sig", @@ -118,7 +89,7 @@ func generateRSAKeyPair(keyID string) (*KeyPair, error) { }, nil } -func generateECDSAKeyPair(keyID string) (*KeyPair, error) { +func generateECDSAKeyPair(keyID uuid.UUID) (*KeyPair, error) { // Generate ECDSA key pair (P-256 curve for ES256) privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { @@ -128,7 +99,7 @@ func generateECDSAKeyPair(keyID string) (*KeyPair, error) { publicKey := &privateKey.PublicKey // Convert to JWK format - privateJWK := JWK{ + privateJWK := config.JWK{ KeyType: "EC", KeyID: keyID, Use: "sig", @@ -141,7 +112,7 @@ func generateECDSAKeyPair(keyID string) (*KeyPair, error) { PrivateExponent: base64.RawURLEncoding.EncodeToString(privateKey.D.Bytes()), } - publicJWK := JWK{ + publicJWK := config.JWK{ KeyType: "EC", KeyID: keyID, Use: "sig", @@ -168,13 +139,13 @@ func Run(ctx context.Context, algorithm string, appendMode bool, fsys afero.Fs) outputPath := utils.Config.Auth.SigningKeysPath // Generate key pair - keyPair, err := GenerateKeyPair(Algorithm(algorithm)) + keyPair, err := GenerateKeyPair(config.Algorithm(algorithm)) if err != nil { return err } out := io.Writer(os.Stdout) - var jwkArray []JWK + var jwkArray []config.JWK if len(outputPath) > 0 { if err := utils.MkdirIfNotExistFS(fsys, filepath.Dir(outputPath)); err != nil { return err @@ -245,5 +216,5 @@ signing_keys_path = "./signing_key.json" // GetSupportedAlgorithms returns a list of supported algorithms func GetSupportedAlgorithms() []string { - return []string{string(AlgRS256), string(AlgES256)} + return []string{string(config.AlgRS256), string(config.AlgES256)} } diff --git a/internal/gen/signingkeys/signingkeys_test.go b/internal/gen/signingkeys/signingkeys_test.go index 51333887d..369811c07 100644 --- a/internal/gen/signingkeys/signingkeys_test.go +++ b/internal/gen/signingkeys/signingkeys_test.go @@ -2,22 +2,24 @@ package signingkeys import ( "testing" + + "github.com/supabase/cli/pkg/config" ) func TestGenerateKeyPair(t *testing.T) { tests := []struct { name string - algorithm Algorithm + algorithm config.Algorithm wantErr bool }{ { name: "RSA key generation", - algorithm: AlgRS256, + algorithm: config.AlgRS256, wantErr: false, }, { name: "ECDSA key generation", - algorithm: AlgES256, + algorithm: config.AlgES256, wantErr: false, }, { @@ -55,7 +57,7 @@ func TestGenerateKeyPair(t *testing.T) { // Algorithm-specific checks switch tt.algorithm { - case AlgRS256: + case config.AlgRS256: if keyPair.PublicKey.KeyType != "RSA" { t.Errorf("Expected RSA key type, got %s", keyPair.PublicKey.KeyType) } @@ -69,7 +71,7 @@ func TestGenerateKeyPair(t *testing.T) { if keyPair.PrivateKey.PrivateExponent == "" { t.Error("RSA private key missing private exponent") } - case AlgES256: + case config.AlgES256: if keyPair.PublicKey.KeyType != "EC" { t.Errorf("Expected EC key type, got %s", keyPair.PublicKey.KeyType) } diff --git a/pkg/config/apikeys.go b/pkg/config/apikeys.go new file mode 100644 index 000000000..35ed815d1 --- /dev/null +++ b/pkg/config/apikeys.go @@ -0,0 +1,169 @@ +package config + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "math/big" + "time" + + "github.com/go-errors/errors" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +// generateAPIKeys generates JWT tokens using the appropriate signing method +func (a *auth) generateAPIKeys() error { + // Generate anon key if not provided + if len(a.AnonKey.Value) == 0 { + signed, err := a.generateJWT("anon") + if err != nil { + return err + } + a.AnonKey.Value = signed + } + // Generate service_role key if not provided + if len(a.ServiceRoleKey.Value) == 0 { + signed, err := a.generateJWT("service_role") + if err != nil { + return err + } + a.ServiceRoleKey.Value = signed + } + return nil +} + +func (a auth) generateJWT(role string) (string, error) { + claims := CustomClaims{Issuer: "supabase-demo", Role: role} + if len(a.SigningKeys) > 0 { + claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour * 24 * 365 * 10)) // 10 years + return generateAsymmetricJWT(a.SigningKeys[0], claims) + } + // Fallback to generating symmetric keys + if len(a.JwtSecret.Value) < 16 { + return "", errors.Errorf("Invalid config for auth.jwt_secret. Must be at least 16 characters") + } + signed, err := claims.NewToken().SignedString([]byte(a.JwtSecret.Value)) + if err != nil { + return "", errors.Errorf("failed to generate JWT: %w", err) + } + return signed, nil +} + +// generateAsymmetricJWT generates a JWT token signed with the provided JWK private key +func generateAsymmetricJWT(jwk JWK, claims CustomClaims) (string, error) { + privateKey, err := jwkToPrivateKey(jwk) + if err != nil { + return "", errors.Errorf("failed to convert JWK to private key: %w", err) + } + + // Determine signing method based on algorithm + var token *jwt.Token + switch jwk.Algorithm { + case AlgRS256: + token = jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + case AlgES256: + token = jwt.NewWithClaims(jwt.SigningMethodES256, claims) + default: + return "", errors.Errorf("unsupported algorithm: %s", jwk.Algorithm) + } + + if jwk.KeyID != uuid.Nil { + token.Header["kid"] = jwk.KeyID.String() + } + + tokenString, err := token.SignedString(privateKey) + if err != nil { + return "", errors.Errorf("failed to sign JWT: %w", err) + } + + return tokenString, nil +} + +// jwkToPrivateKey converts a JWK to a crypto.PrivateKey +func jwkToPrivateKey(jwk JWK) (crypto.PrivateKey, error) { + switch jwk.KeyType { + case "RSA": + return jwkToRSAPrivateKey(jwk) + case "EC": + return jwkToECDSAPrivateKey(jwk) + default: + return nil, errors.Errorf("unsupported key type: %s", jwk.KeyType) + } +} + +// jwkToRSAPrivateKey converts a JWK to an RSA private key +func jwkToRSAPrivateKey(jwk JWK) (*rsa.PrivateKey, error) { + nBytes, err := base64.RawURLEncoding.DecodeString(jwk.Modulus) + if err != nil { + return nil, errors.Errorf("failed to decode modulus: %w", err) + } + n := new(big.Int).SetBytes(nBytes) + + eBytes, err := base64.RawURLEncoding.DecodeString(jwk.Exponent) + if err != nil { + return nil, errors.Errorf("failed to decode exponent: %w", err) + } + e := int(new(big.Int).SetBytes(eBytes).Int64()) + + dBytes, err := base64.RawURLEncoding.DecodeString(jwk.PrivateExponent) + if err != nil { + return nil, errors.Errorf("failed to decode private exponent: %w", err) + } + d := new(big.Int).SetBytes(dBytes) + + pBytes, err := base64.RawURLEncoding.DecodeString(jwk.FirstPrimeFactor) + if err != nil { + return nil, errors.Errorf("failed to decode first prime factor: %w", err) + } + p := new(big.Int).SetBytes(pBytes) + + qBytes, err := base64.RawURLEncoding.DecodeString(jwk.SecondPrimeFactor) + if err != nil { + return nil, errors.Errorf("failed to decode second prime factor: %w", err) + } + q := new(big.Int).SetBytes(qBytes) + + return &rsa.PrivateKey{ + PublicKey: rsa.PublicKey{N: n, E: e}, + D: d, + Primes: []*big.Int{p, q}, + }, nil +} + +// jwkToECDSAPrivateKey converts a JWK to an ECDSA private key +func jwkToECDSAPrivateKey(jwk JWK) (*ecdsa.PrivateKey, error) { + // Only support P-256 curve for ES256 + if jwk.Curve != "P-256" { + return nil, errors.Errorf("unsupported curve: %s", jwk.Curve) + } + + xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X) + if err != nil { + return nil, errors.Errorf("failed to decode x coordinate: %w", err) + } + x := new(big.Int).SetBytes(xBytes) + + yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y) + if err != nil { + return nil, errors.Errorf("failed to decode y coordinate: %w", err) + } + y := new(big.Int).SetBytes(yBytes) + + dBytes, err := base64.RawURLEncoding.DecodeString(jwk.PrivateExponent) + if err != nil { + return nil, errors.Errorf("failed to decode private key: %w", err) + } + d := new(big.Int).SetBytes(dBytes) + + return &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: x, + Y: y, + }, + D: d, + }, nil +} diff --git a/pkg/config/auth.go b/pkg/config/auth.go index 2b0e155d7..ed0345e8a 100644 --- a/pkg/config/auth.go +++ b/pkg/config/auth.go @@ -6,6 +6,7 @@ import ( "time" "github.com/go-errors/errors" + "github.com/google/uuid" "github.com/oapi-codegen/nullable" openapi_types "github.com/oapi-codegen/runtime/types" v1API "github.com/supabase/cli/pkg/api" @@ -69,6 +70,44 @@ func (p *CaptchaProvider) UnmarshalText(text []byte) error { return nil } +type Algorithm string + +const ( + AlgRS256 Algorithm = "RS256" + AlgES256 Algorithm = "ES256" +) + +func (p *Algorithm) UnmarshalText(text []byte) error { + allowed := []Algorithm{AlgRS256, AlgES256} + if *p = Algorithm(text); !sliceContains(allowed, *p) { + return errors.Errorf("must be one of %v", allowed) + } + return nil +} + +type JWK struct { + KeyType string `json:"kty"` + KeyID uuid.UUID `json:"kid,omitempty"` + Use string `json:"use,omitempty"` + KeyOps []string `json:"key_ops,omitempty"` + Algorithm Algorithm `json:"alg,omitempty"` + Extractable *bool `json:"ext,omitempty"` + // RSA specific fields + Modulus string `json:"n,omitempty"` + Exponent string `json:"e,omitempty"` + // RSA private key fields + PrivateExponent string `json:"d,omitempty"` + FirstPrimeFactor string `json:"p,omitempty"` + SecondPrimeFactor string `json:"q,omitempty"` + FirstFactorCRTExponent string `json:"dp,omitempty"` + SecondFactorCRTExponent string `json:"dq,omitempty"` + FirstCRTCoefficient string `json:"qi,omitempty"` + // EC specific fields + Curve string `json:"crv,omitempty"` + X string `json:"x,omitempty"` + Y string `json:"y,omitempty"` +} + type ( auth struct { Enabled bool `toml:"enabled"` @@ -85,6 +124,7 @@ type ( MinimumPasswordLength uint `toml:"minimum_password_length"` PasswordRequirements PasswordRequirements `toml:"password_requirements"` SigningKeysPath string `toml:"signing_keys_path"` + SigningKeys []JWK `toml:"-"` RateLimit rateLimit `toml:"rate_limit"` Captcha *captcha `toml:"captcha"` diff --git a/pkg/config/config.go b/pkg/config/config.go index d96cc3952..ddaace033 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -585,26 +585,6 @@ func (c *config) Load(path string, fsys fs.FS) error { if err := c.loadFromFile(builder.ConfigPath, fsys); err != nil { return err } - // Generate JWT tokens - if len(c.Auth.JwtSecret.Value) < 16 { - return errors.Errorf("Invalid config for auth.jwt_secret. Must be at least 16 characters") - } - if len(c.Auth.AnonKey.Value) == 0 { - anonToken := CustomClaims{Role: "anon"}.NewToken() - if signed, err := anonToken.SignedString([]byte(c.Auth.JwtSecret.Value)); err != nil { - return errors.Errorf("failed to generate anon key: %w", err) - } else { - c.Auth.AnonKey.Value = signed - } - } - if len(c.Auth.ServiceRoleKey.Value) == 0 { - anonToken := CustomClaims{Role: "service_role"}.NewToken() - if signed, err := anonToken.SignedString([]byte(c.Auth.JwtSecret.Value)); err != nil { - return errors.Errorf("failed to generate service_role key: %w", err) - } else { - c.Auth.ServiceRoleKey.Value = signed - } - } // TODO: move linked pooler connection string elsewhere if connString, err := fs.ReadFile(fsys, builder.PoolerUrlPath); err == nil && len(connString) > 0 { c.Db.Pooler.ConnectionString = string(connString) @@ -851,6 +831,18 @@ func (c *config) Validate(fsys fs.FS) error { return err } } + if len(c.Auth.SigningKeysPath) > 0 { + if f, err := fsys.Open(c.Auth.SigningKeysPath); errors.Is(err, os.ErrNotExist) { + // Ignore missing signing key path on CI + } else if err != nil { + return errors.Errorf("failed to read signing keys: %w", err) + } else if c.Auth.SigningKeys, err = fetcher.ParseJSON[[]JWK](f); err != nil { + return errors.Errorf("failed to decode signing keys: %w", err) + } + } + if err := c.Auth.generateAPIKeys(); err != nil { + return err + } if err := c.Auth.Hook.validate(); err != nil { return err } diff --git a/pkg/config/updater.go b/pkg/config/updater.go index 96b73efd2..4e2159488 100644 --- a/pkg/config/updater.go +++ b/pkg/config/updater.go @@ -6,6 +6,7 @@ import ( "os" "github.com/go-errors/errors" + "github.com/google/uuid" v1API "github.com/supabase/cli/pkg/api" ) @@ -163,6 +164,77 @@ func (u *ConfigUpdater) UpdateAuthConfig(ctx context.Context, projectRef string, return nil } +func (u *ConfigUpdater) UpdateSigningKeys(ctx context.Context, projectRef string, signingKeys []JWK, filter ...func(string) bool) error { + if len(signingKeys) == 0 { + return nil + } + resp, err := u.client.V1GetProjectSigningKeysWithResponse(ctx, projectRef) + if err != nil { + return errors.Errorf("failed to fetch signing keys: %w", err) + } else if resp.JSON200 == nil { + return errors.Errorf("unexpected status %d: %s", resp.StatusCode(), string(resp.Body)) + } + exists := map[uuid.UUID]struct{}{} + for _, k := range resp.JSON200.Keys { + if k.PublicJwk != nil { + exists[k.Id] = struct{}{} + } + } + var toInsert []JWK + for _, k := range signingKeys { + if _, ok := exists[k.KeyID]; !ok { + toInsert = append(toInsert, k) + } + } + if len(toInsert) == 0 { + fmt.Fprintln(os.Stderr, "Remote JWT signing keys are up to date.") + return nil + } + fmt.Fprintln(os.Stderr, "JWT signing keys to insert:") + for _, k := range toInsert { + fmt.Fprintln(os.Stderr, " -", k.KeyID) + } + for _, keep := range filter { + if !keep("signing keys") { + return nil + } + } + for _, k := range toInsert { + body := v1API.CreateSigningKeyBody{ + Algorithm: v1API.CreateSigningKeyBodyAlgorithm(k.Algorithm), + PrivateJwk: &v1API.CreateSigningKeyBody_PrivateJwk{}, + } + switch k.Algorithm { + case AlgRS256: + body.PrivateJwk.FromCreateSigningKeyBodyPrivateJwk0(v1API.CreateSigningKeyBodyPrivateJwk0{ + D: k.PrivateExponent, + Dp: k.FirstFactorCRTExponent, + Dq: k.SecondFactorCRTExponent, + E: v1API.CreateSigningKeyBodyPrivateJwk0E(k.Exponent), + Kty: v1API.CreateSigningKeyBodyPrivateJwk0Kty(k.KeyType), + N: k.Modulus, + P: k.FirstPrimeFactor, + Q: k.SecondPrimeFactor, + Qi: k.FirstCRTCoefficient, + }) + case AlgES256: + body.PrivateJwk.FromCreateSigningKeyBodyPrivateJwk1(v1API.CreateSigningKeyBodyPrivateJwk1{ + Crv: v1API.CreateSigningKeyBodyPrivateJwk1Crv(k.Curve), + D: k.PrivateExponent, + Kty: v1API.CreateSigningKeyBodyPrivateJwk1Kty(k.KeyType), + X: k.X, + Y: k.Y, + }) + } + if resp, err := u.client.V1CreateProjectSigningKeyWithResponse(ctx, projectRef, body); err != nil { + return errors.Errorf("failed to add signing key: %w", err) + } else if status := resp.StatusCode(); status < 200 || status >= 300 { + return errors.Errorf("unexpected status %d: %s", status, string(resp.Body)) + } + } + return nil +} + func (u *ConfigUpdater) UpdateStorageConfig(ctx context.Context, projectRef string, c storage, filter ...func(string) bool) error { if !c.Enabled { return nil