Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions kms/platform/kms.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto"
"crypto/x509"
"errors"
"fmt"
"net/url"
"strings"

Expand All @@ -28,16 +29,35 @@ const (

type kmsURI struct {
uri *uri.URI
backend apiv1.Type
name string
hw bool
extraValues url.Values
}

func isDefaultKey(k string) bool {
return k == nameKey ||
k == hwKey ||
k == backendKey
return k == nameKey || k == hwKey
}

func getBackend(opts apiv1.Options) (apiv1.Type, error) {
if opts.URI == "" {
return opts.Type, nil
}

typ := opts.Type
if opts.URI != "" {
u, err := uri.ParseWithScheme(Scheme, opts.URI)
if err != nil {
return apiv1.DefaultKMS, err
}
if backend := u.Get(backendKey); backend != "" {
typ = apiv1.Type(strings.ToLower(backend))
if opts.Type != apiv1.DefaultKMS && opts.Type != typ {
return apiv1.DefaultKMS, fmt.Errorf("options type %q and URI backend %q do not match", opts, typ)
}
}
}

return typ, nil
}

func parseURI(rawuri string) (*kmsURI, error) {
Expand All @@ -55,7 +75,6 @@ func parseURI(rawuri string) (*kmsURI, error) {

return &kmsURI{
uri: u,
backend: apiv1.Type(strings.ToLower(u.Get(backendKey))),
name: u.Get(nameKey),
hw: u.GetBool(hwKey),
extraValues: extraValues,
Expand Down
6 changes: 3 additions & 3 deletions kms/platform/kms_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,20 @@ func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) {
return newMacKMS(ctx, opts)
}

u, err := parseURI(opts.URI)
backend, err := getBackend(opts)
if err != nil {
return nil, err
}

switch u.backend {
switch backend {
case apiv1.TPMKMS:
return newTPMKMS(ctx, opts)
case apiv1.SoftKMS:
return newSoftKMS(ctx, opts)
case apiv1.DefaultKMS, apiv1.MacKMS:
return newMacKMS(ctx, opts)
default:
return nil, fmt.Errorf("failed parsing %q: unsupported backend %q", opts.URI, u.backend)
return nil, fmt.Errorf("failed parsing options: unsupported backend %q", backend)
}
}

Expand Down
6 changes: 3 additions & 3 deletions kms/platform/kms_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) {
return newTPMKMS(ctx, opts)
}

u, err := parseURI(opts.URI)
backend, err := getBackend(opts)
if err != nil {
return nil, err
}

switch u.backend {
switch backend {
case apiv1.SoftKMS:
return newSoftKMS(ctx, opts)
case apiv1.DefaultKMS, apiv1.TPMKMS:
return newTPMKMS(ctx, opts)
default:
return nil, fmt.Errorf("failed parsing %q: unsupported backend %q", opts.URI, u.backend)
return nil, fmt.Errorf("failed parsing options: unsupported backend %q", backend)
}
}
94 changes: 93 additions & 1 deletion kms/platform/kms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/kms/apiv1"
"go.step.sm/crypto/kms/uri"
Expand Down Expand Up @@ -1219,3 +1218,96 @@ func TestKMS_SearchKeys(t *testing.T) {
})
}
}

func Test_getBackend(t *testing.T) {
type args struct {
opts apiv1.Options
}
tests := []struct {
name string
args args
want apiv1.Type
assertion assert.ErrorAssertionFunc
}{
{"ok", args{apiv1.Options{}}, apiv1.DefaultKMS, assert.NoError},
{"ok from type", args{apiv1.Options{Type: apiv1.TPMKMS}}, apiv1.TPMKMS, assert.NoError},
{"ok from uri", args{apiv1.Options{URI: "kms:backend=softkms"}}, apiv1.SoftKMS, assert.NoError},
{"ok from both", args{apiv1.Options{Type: apiv1.CAPIKMS, URI: "kms:backend=capi"}}, apiv1.CAPIKMS, assert.NoError},
{"fail uri", args{apiv1.Options{URI: "softkms:backend=softkms"}}, apiv1.DefaultKMS, assert.Error},
{"fail mismatch", args{apiv1.Options{Type: apiv1.TPMKMS, URI: "kms:backend=softkms"}}, apiv1.DefaultKMS, assert.Error},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getBackend(tt.args.opts)
tt.assertion(t, err)
assert.Equal(t, tt.want, got)
})
}
}

func Test_parseURI(t *testing.T) {
mustURI := func(scheme, opaque, rawquery string, values url.Values) *uri.URI {
u := uri.New(scheme, values)
u.Opaque = opaque
u.RawQuery = rawquery
return u
}

type args struct {
rawuri string
}
tests := []struct {
name string
args args
want *kmsURI
assertion assert.ErrorAssertionFunc
}{
{"ok", args{"kms:"}, &kmsURI{
uri: uri.New(Scheme, url.Values{}),
extraValues: url.Values{},
}, assert.NoError},
{"ok with name", args{"kms:name=foo"}, &kmsURI{
uri: mustURI(Scheme, "name=foo", "", url.Values{"name": []string{"foo"}}),
name: "foo",
extraValues: url.Values{},
}, assert.NoError},
{"ok with hw", args{"kms:name=foo;hw=true"}, &kmsURI{
uri: mustURI(Scheme, "name=foo;hw=true", "", url.Values{
"name": []string{"foo"},
"hw": []string{"true"},
}),
name: "foo",
hw: true,
extraValues: url.Values{},
}, assert.NoError},
{"ok with hw on query", args{"kms:name=foo?hw=true"}, &kmsURI{
uri: mustURI(Scheme, "name=foo", "hw=true", url.Values{"name": []string{"foo"}}),
name: "foo",
hw: true,
extraValues: url.Values{},
}, assert.NoError},
{"ok with extra values", args{"kms:name=foo;hw=true;foo=bar;backend=softkms?bar=zar&foo=qux"}, &kmsURI{
uri: mustURI(Scheme, "name=foo;hw=true;foo=bar;backend=softkms", "bar=zar&foo=qux", url.Values{
"name": []string{"foo"},
"hw": []string{"true"},
"foo": []string{"bar"},
"backend": []string{"softkms"},
}),
name: "foo",
hw: true,
extraValues: url.Values{
"backend": []string{"softkms"},
"foo": []string{"bar", "qux"},
"bar": []string{"zar"},
},
}, assert.NoError},
{"fail parse", args{"tpmkms:name=foo"}, nil, assert.Error},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseURI(tt.args.rawuri)
tt.assertion(t, err)
assert.Equal(t, tt.want, got)
})
}
}
30 changes: 17 additions & 13 deletions kms/platform/kms_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,24 @@ import (
const tpmProvider = "Microsoft Platform Crypto Provider"

func newKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) {
if opts.URI == "" {
opts.URI = withEnableCNG(nil)
return newTPMKMS(ctx, opts)
}

u, err := parseURI(opts.URI)
backend, err := getBackend(opts)
if err != nil {
return nil, err
}

switch u.backend {
switch backend {
case apiv1.CAPIKMS:
return newCAPIKMS(ctx, opts)
case apiv1.SoftKMS:
return newSoftKMS(ctx, opts)
case apiv1.DefaultKMS, apiv1.TPMKMS:
opts.URI = withEnableCNG(u.uri)
// Add enable-cng=true if necessary
if opts.URI, err = withEnableCNG(opts.URI); err != nil {
return nil, err
}
return newTPMKMS(ctx, opts)
default:
return nil, fmt.Errorf("failed parsing %q: unsupported backend %q", opts.URI, u.backend)
return nil, fmt.Errorf("failed parsing options: unsupported backend %q", backend)
}
}

Expand All @@ -61,14 +59,20 @@ func newCAPIKMS(ctx context.Context, opts apiv1.Options) (*KMS, error) {
}, nil
}

func withEnableCNG(u *uri.URI) string {
if u == nil {
return "kms:enable-cng=true"
func withEnableCNG(rawuri string) (string, error) {
if rawuri == "" {
return "kms:enable-cng=true", nil
}

u, err := uri.ParseWithScheme(Scheme, rawuri)
if err != nil {
return "", err
}

if !u.Has("enable-cng") {
u.Set("enable-cng", "true")
}
return u.String()
return u.String(), nil
}

func transformToCAPIKMS(rawuri string) (string, error) {
Expand Down