Skip to content

feat: generate jwt tokens from signing key #3969

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/supabase/cli/internal/gen/types"
"github.com/supabase/cli/internal/utils"
"github.com/supabase/cli/internal/utils/flags"
"github.com/supabase/cli/pkg/config"
)

var (
Expand Down Expand Up @@ -97,7 +98,7 @@ var (

algorithm = utils.EnumFlag{
Allowed: signingkeys.GetSupportedAlgorithms(),
Value: string(signingkeys.AlgES256),
Value: string(config.AlgES256),
}
appendKeys bool

Expand Down
61 changes: 16 additions & 45 deletions internal/gen/signingkeys/signingkeys.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,58 +20,29 @@ import (
"github.com/supabase/cli/internal/utils"
"github.com/supabase/cli/internal/utils/flags"
"github.com/supabase/cli/pkg/cast"
"github.com/supabase/cli/pkg/config"
)

type Algorithm string

const (
AlgRS256 Algorithm = "RS256"
AlgES256 Algorithm = "ES256"
)

type JWK struct {
KeyType string `json:"kty"`
KeyID string `json:"kid,omitempty"`
Use string `json:"use,omitempty"`
KeyOps []string `json:"key_ops,omitempty"`
Algorithm string `json:"alg,omitempty"`
Extractable *bool `json:"ext,omitempty"`
// RSA specific fields
Modulus string `json:"n,omitempty"`
Exponent string `json:"e,omitempty"`
// RSA private key fields
PrivateExponent string `json:"d,omitempty"`
FirstPrimeFactor string `json:"p,omitempty"`
SecondPrimeFactor string `json:"q,omitempty"`
FirstFactorCRTExponent string `json:"dp,omitempty"`
SecondFactorCRTExponent string `json:"dq,omitempty"`
FirstCRTCoefficient string `json:"qi,omitempty"`
// EC specific fields
Curve string `json:"crv,omitempty"`
X string `json:"x,omitempty"`
Y string `json:"y,omitempty"`
}

type KeyPair struct {
PublicKey JWK
PrivateKey JWK
PublicKey config.JWK
PrivateKey config.JWK
}

// GenerateKeyPair generates a new key pair for the specified algorithm
func GenerateKeyPair(alg Algorithm) (*KeyPair, error) {
keyID := uuid.New().String()
func GenerateKeyPair(alg config.Algorithm) (*KeyPair, error) {
keyID := uuid.New()

switch alg {
case AlgRS256:
case config.AlgRS256:
return generateRSAKeyPair(keyID)
case AlgES256:
case config.AlgES256:
return generateECDSAKeyPair(keyID)
default:
return nil, errors.Errorf("unsupported algorithm: %s", alg)
}
}

func generateRSAKeyPair(keyID string) (*KeyPair, error) {
func generateRSAKeyPair(keyID uuid.UUID) (*KeyPair, error) {
// Generate RSA key pair (2048 bits for RS256)
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
Expand All @@ -84,7 +55,7 @@ func generateRSAKeyPair(keyID string) (*KeyPair, error) {
privateKey.Precompute()

// Convert to JWK format
privateJWK := JWK{
privateJWK := config.JWK{
KeyType: "RSA",
KeyID: keyID,
Use: "sig",
Expand All @@ -101,7 +72,7 @@ func generateRSAKeyPair(keyID string) (*KeyPair, error) {
FirstCRTCoefficient: base64.RawURLEncoding.EncodeToString(privateKey.Precomputed.Qinv.Bytes()),
}

publicJWK := JWK{
publicJWK := config.JWK{
KeyType: "RSA",
KeyID: keyID,
Use: "sig",
Expand All @@ -118,7 +89,7 @@ func generateRSAKeyPair(keyID string) (*KeyPair, error) {
}, nil
}

func generateECDSAKeyPair(keyID string) (*KeyPair, error) {
func generateECDSAKeyPair(keyID uuid.UUID) (*KeyPair, error) {
// Generate ECDSA key pair (P-256 curve for ES256)
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
Expand All @@ -128,7 +99,7 @@ func generateECDSAKeyPair(keyID string) (*KeyPair, error) {
publicKey := &privateKey.PublicKey

// Convert to JWK format
privateJWK := JWK{
privateJWK := config.JWK{
KeyType: "EC",
KeyID: keyID,
Use: "sig",
Expand All @@ -141,7 +112,7 @@ func generateECDSAKeyPair(keyID string) (*KeyPair, error) {
PrivateExponent: base64.RawURLEncoding.EncodeToString(privateKey.D.Bytes()),
}

publicJWK := JWK{
publicJWK := config.JWK{
KeyType: "EC",
KeyID: keyID,
Use: "sig",
Expand All @@ -168,13 +139,13 @@ func Run(ctx context.Context, algorithm string, appendMode bool, fsys afero.Fs)
outputPath := utils.Config.Auth.SigningKeysPath

// Generate key pair
keyPair, err := GenerateKeyPair(Algorithm(algorithm))
keyPair, err := GenerateKeyPair(config.Algorithm(algorithm))
if err != nil {
return err
}

out := io.Writer(os.Stdout)
var jwkArray []JWK
var jwkArray []config.JWK
if len(outputPath) > 0 {
if err := utils.MkdirIfNotExistFS(fsys, filepath.Dir(outputPath)); err != nil {
return err
Expand Down Expand Up @@ -245,5 +216,5 @@ signing_keys_path = "./signing_key.json"

// GetSupportedAlgorithms returns a list of supported algorithms
func GetSupportedAlgorithms() []string {
return []string{string(AlgRS256), string(AlgES256)}
return []string{string(config.AlgRS256), string(config.AlgES256)}
}
12 changes: 7 additions & 5 deletions internal/gen/signingkeys/signingkeys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,24 @@ package signingkeys

import (
"testing"

"github.com/supabase/cli/pkg/config"
)

func TestGenerateKeyPair(t *testing.T) {
tests := []struct {
name string
algorithm Algorithm
algorithm config.Algorithm
wantErr bool
}{
{
name: "RSA key generation",
algorithm: AlgRS256,
algorithm: config.AlgRS256,
wantErr: false,
},
{
name: "ECDSA key generation",
algorithm: AlgES256,
algorithm: config.AlgES256,
wantErr: false,
},
{
Expand Down Expand Up @@ -55,7 +57,7 @@ func TestGenerateKeyPair(t *testing.T) {

// Algorithm-specific checks
switch tt.algorithm {
case AlgRS256:
case config.AlgRS256:
if keyPair.PublicKey.KeyType != "RSA" {
t.Errorf("Expected RSA key type, got %s", keyPair.PublicKey.KeyType)
}
Expand All @@ -69,7 +71,7 @@ func TestGenerateKeyPair(t *testing.T) {
if keyPair.PrivateKey.PrivateExponent == "" {
t.Error("RSA private key missing private exponent")
}
case AlgES256:
case config.AlgES256:
if keyPair.PublicKey.KeyType != "EC" {
t.Errorf("Expected EC key type, got %s", keyPair.PublicKey.KeyType)
}
Expand Down
169 changes: 169 additions & 0 deletions pkg/config/apikeys.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package config

import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"math/big"
"time"

"github.com/go-errors/errors"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)

// generateAPIKeys generates JWT tokens using the appropriate signing method
func (a *auth) generateAPIKeys() error {
// Generate anon key if not provided
if len(a.AnonKey.Value) == 0 {
signed, err := a.generateJWT("anon")
if err != nil {
return err
}
a.AnonKey.Value = signed
}
// Generate service_role key if not provided
if len(a.ServiceRoleKey.Value) == 0 {
signed, err := a.generateJWT("service_role")
if err != nil {
return err
}
a.ServiceRoleKey.Value = signed
}
return nil
}

func (a auth) generateJWT(role string) (string, error) {
claims := CustomClaims{Issuer: "supabase-demo", Role: role}
if len(a.SigningKeys) > 0 {
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour * 24 * 365 * 10)) // 10 years
return generateAsymmetricJWT(a.SigningKeys[0], claims)
}
// Fallback to generating symmetric keys
if len(a.JwtSecret.Value) < 16 {
return "", errors.Errorf("Invalid config for auth.jwt_secret. Must be at least 16 characters")
}
signed, err := claims.NewToken().SignedString([]byte(a.JwtSecret.Value))
if err != nil {
return "", errors.Errorf("failed to generate JWT: %w", err)
}
return signed, nil
}

// generateAsymmetricJWT generates a JWT token signed with the provided JWK private key
func generateAsymmetricJWT(jwk JWK, claims CustomClaims) (string, error) {
privateKey, err := jwkToPrivateKey(jwk)
if err != nil {
return "", errors.Errorf("failed to convert JWK to private key: %w", err)
}

// Determine signing method based on algorithm
var token *jwt.Token
switch jwk.Algorithm {
case AlgRS256:
token = jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
case AlgES256:
token = jwt.NewWithClaims(jwt.SigningMethodES256, claims)
default:
return "", errors.Errorf("unsupported algorithm: %s", jwk.Algorithm)
}

if jwk.KeyID != uuid.Nil {
token.Header["kid"] = jwk.KeyID.String()
}

tokenString, err := token.SignedString(privateKey)
if err != nil {
return "", errors.Errorf("failed to sign JWT: %w", err)
}

return tokenString, nil
}

// jwkToPrivateKey converts a JWK to a crypto.PrivateKey
func jwkToPrivateKey(jwk JWK) (crypto.PrivateKey, error) {
switch jwk.KeyType {
case "RSA":
return jwkToRSAPrivateKey(jwk)
case "EC":
return jwkToECDSAPrivateKey(jwk)
default:
return nil, errors.Errorf("unsupported key type: %s", jwk.KeyType)
}
}

// jwkToRSAPrivateKey converts a JWK to an RSA private key
func jwkToRSAPrivateKey(jwk JWK) (*rsa.PrivateKey, error) {
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.Modulus)
if err != nil {
return nil, errors.Errorf("failed to decode modulus: %w", err)
}
n := new(big.Int).SetBytes(nBytes)

eBytes, err := base64.RawURLEncoding.DecodeString(jwk.Exponent)
if err != nil {
return nil, errors.Errorf("failed to decode exponent: %w", err)
}
e := int(new(big.Int).SetBytes(eBytes).Int64())

dBytes, err := base64.RawURLEncoding.DecodeString(jwk.PrivateExponent)
if err != nil {
return nil, errors.Errorf("failed to decode private exponent: %w", err)
}
d := new(big.Int).SetBytes(dBytes)

pBytes, err := base64.RawURLEncoding.DecodeString(jwk.FirstPrimeFactor)
if err != nil {
return nil, errors.Errorf("failed to decode first prime factor: %w", err)
}
p := new(big.Int).SetBytes(pBytes)

qBytes, err := base64.RawURLEncoding.DecodeString(jwk.SecondPrimeFactor)
if err != nil {
return nil, errors.Errorf("failed to decode second prime factor: %w", err)
}
q := new(big.Int).SetBytes(qBytes)

return &rsa.PrivateKey{
PublicKey: rsa.PublicKey{N: n, E: e},
D: d,
Primes: []*big.Int{p, q},
}, nil
}

// jwkToECDSAPrivateKey converts a JWK to an ECDSA private key
func jwkToECDSAPrivateKey(jwk JWK) (*ecdsa.PrivateKey, error) {
// Only support P-256 curve for ES256
if jwk.Curve != "P-256" {
return nil, errors.Errorf("unsupported curve: %s", jwk.Curve)
}

xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
return nil, errors.Errorf("failed to decode x coordinate: %w", err)
}
x := new(big.Int).SetBytes(xBytes)

yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y)
if err != nil {
return nil, errors.Errorf("failed to decode y coordinate: %w", err)
}
y := new(big.Int).SetBytes(yBytes)

dBytes, err := base64.RawURLEncoding.DecodeString(jwk.PrivateExponent)
if err != nil {
return nil, errors.Errorf("failed to decode private key: %w", err)
}
d := new(big.Int).SetBytes(dBytes)

return &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: elliptic.P256(),
X: x,
Y: y,
},
D: d,
}, nil
}
Loading