Skip to content

Commit c2367bd

Browse files
author
James Munnelly
committed
Extend client-go csr package to invalidate CSRs based on signerName
1 parent c86aec0 commit c2367bd

File tree

3 files changed

+160
-1
lines changed

3 files changed

+160
-1
lines changed

staging/src/k8s.io/client-go/util/certificate/csr/BUILD

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package(default_visibility = ["//visibility:public"])
22

3-
load("@io_bazel_rules_go//go:def.bzl", "go_library")
3+
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
44

55
go_library(
66
name = "go_default_library",
@@ -35,3 +35,13 @@ filegroup(
3535
srcs = [":package-srcs"],
3636
tags = ["automanaged"],
3737
)
38+
39+
go_test(
40+
name = "go_default_test",
41+
srcs = ["csr_test.go"],
42+
embed = [":go_default_library"],
43+
deps = [
44+
"//staging/src/k8s.io/api/certificates/v1beta1:go_default_library",
45+
"//vendor/k8s.io/utils/pointer:go_default_library",
46+
],
47+
)

staging/src/k8s.io/client-go/util/certificate/csr/csr.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ func ensureCompatible(new, orig *certificates.CertificateSigningRequest, private
150150
if !reflect.DeepEqual(newCSR.Subject, origCSR.Subject) {
151151
return fmt.Errorf("csr subjects differ: new: %#v, orig: %#v", newCSR.Subject, origCSR.Subject)
152152
}
153+
if new.Spec.SignerName != nil && orig.Spec.SignerName != nil && *new.Spec.SignerName != *orig.Spec.SignerName {
154+
return fmt.Errorf("csr signerNames differ: new %q, orig: %q", *new.Spec.SignerName, *orig.Spec.SignerName)
155+
}
153156
signer, ok := privateKey.(crypto.Signer)
154157
if !ok {
155158
return fmt.Errorf("privateKey is not a signer")
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/*
2+
Copyright 2020 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package csr
18+
19+
import (
20+
"crypto"
21+
"crypto/rand"
22+
"crypto/rsa"
23+
"crypto/x509"
24+
"crypto/x509/pkix"
25+
"encoding/pem"
26+
"testing"
27+
28+
certificates "k8s.io/api/certificates/v1beta1"
29+
"k8s.io/utils/pointer"
30+
)
31+
32+
func TestEnsureCompatible(t *testing.T) {
33+
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
34+
if err != nil {
35+
t.Fatal(err)
36+
}
37+
req := pemWithPrivateKey(privateKey)
38+
39+
tests := map[string]struct {
40+
new, orig *certificates.CertificateSigningRequest
41+
privateKey interface{}
42+
err string
43+
}{
44+
"nil signerName on 'new' matches any signerName on 'orig'": {
45+
new: &certificates.CertificateSigningRequest{
46+
Spec: certificates.CertificateSigningRequestSpec{
47+
Request: req,
48+
},
49+
},
50+
orig: &certificates.CertificateSigningRequest{
51+
Spec: certificates.CertificateSigningRequestSpec{
52+
Request: req,
53+
SignerName: pointer.StringPtr("example.com/test"),
54+
},
55+
},
56+
privateKey: privateKey,
57+
},
58+
"nil signerName on 'orig' matches any signerName on 'new'": {
59+
new: &certificates.CertificateSigningRequest{
60+
Spec: certificates.CertificateSigningRequestSpec{
61+
Request: req,
62+
SignerName: pointer.StringPtr("example.com/test"),
63+
},
64+
},
65+
orig: &certificates.CertificateSigningRequest{
66+
Spec: certificates.CertificateSigningRequestSpec{
67+
Request: req,
68+
},
69+
},
70+
privateKey: privateKey,
71+
},
72+
"signerName on 'orig' matches signerName on 'new'": {
73+
new: &certificates.CertificateSigningRequest{
74+
Spec: certificates.CertificateSigningRequestSpec{
75+
Request: req,
76+
SignerName: pointer.StringPtr("example.com/test"),
77+
},
78+
},
79+
orig: &certificates.CertificateSigningRequest{
80+
Spec: certificates.CertificateSigningRequestSpec{
81+
Request: req,
82+
SignerName: pointer.StringPtr("example.com/test"),
83+
},
84+
},
85+
privateKey: privateKey,
86+
},
87+
"signerName on 'orig' does not match signerName on 'new'": {
88+
new: &certificates.CertificateSigningRequest{
89+
Spec: certificates.CertificateSigningRequestSpec{
90+
Request: req,
91+
SignerName: pointer.StringPtr("example.com/test"),
92+
},
93+
},
94+
orig: &certificates.CertificateSigningRequest{
95+
Spec: certificates.CertificateSigningRequestSpec{
96+
Request: req,
97+
SignerName: pointer.StringPtr("example.com/not-test"),
98+
},
99+
},
100+
privateKey: privateKey,
101+
err: `csr signerNames differ: new "example.com/test", orig: "example.com/not-test"`,
102+
},
103+
}
104+
for name, test := range tests {
105+
t.Run(name, func(t *testing.T) {
106+
err := ensureCompatible(test.new, test.orig, test.privateKey)
107+
if err != nil && test.err == "" {
108+
t.Errorf("expected no error, but got: %v", err)
109+
} else if err != nil && test.err != err.Error() {
110+
t.Errorf("error did not match as expected, got=%v, exp=%s", err, test.err)
111+
}
112+
if err == nil && test.err != "" {
113+
t.Errorf("expected to get an error but got none")
114+
}
115+
})
116+
}
117+
}
118+
119+
func pemWithPrivateKey(pk crypto.PrivateKey) []byte {
120+
template := &x509.CertificateRequest{
121+
Subject: pkix.Name{
122+
CommonName: "something",
123+
Organization: []string{"test"},
124+
},
125+
}
126+
return pemWithTemplate(template, pk)
127+
}
128+
129+
func pemWithTemplate(template *x509.CertificateRequest, key crypto.PrivateKey) []byte {
130+
csrDER, err := x509.CreateCertificateRequest(rand.Reader, template, key)
131+
if err != nil {
132+
panic(err)
133+
}
134+
135+
csrPemBlock := &pem.Block{
136+
Type: "CERTIFICATE REQUEST",
137+
Bytes: csrDER,
138+
}
139+
140+
p := pem.EncodeToMemory(csrPemBlock)
141+
if p == nil {
142+
panic("invalid pem block")
143+
}
144+
145+
return p
146+
}

0 commit comments

Comments
 (0)