Skip to content

Commit 8f58775

Browse files
committed
Add initial version of the handshake command
This commit adds the command `step certificate handshake`. This command performs a handshake and displays details about it.
1 parent bfab777 commit 8f58775

File tree

3 files changed

+293
-0
lines changed

3 files changed

+293
-0
lines changed

command/certificate/certificate.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ $ step certificate uninstall root-ca.crt
8686
createCommand(),
8787
formatCommand(),
8888
inspectCommand(),
89+
handshakeCommand(),
8990
fingerprintCommand(),
9091
lintCommand(),
9192
needsRenewalCommand(),

command/certificate/handshake.go

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
package certificate
2+
3+
import (
4+
"crypto/tls"
5+
"crypto/x509"
6+
"encoding/pem"
7+
"fmt"
8+
"net"
9+
"reflect"
10+
11+
"github.com/smallstep/cli-utils/errs"
12+
"github.com/smallstep/cli/flags"
13+
"github.com/smallstep/cli/internal/cryptoutil"
14+
"github.com/urfave/cli"
15+
"go.step.sm/crypto/pemutil"
16+
"go.step.sm/crypto/x509util"
17+
)
18+
19+
func handshakeCommand() cli.Command {
20+
return cli.Command{
21+
Name: "handshake",
22+
Action: cli.ActionFunc(handshakeAction),
23+
Usage: `print handshake details`,
24+
UsageText: `**step certificate handshake** <url>`,
25+
Description: `**step certificate handshake** displays detailed handshake information for a TLS connection.`,
26+
Flags: []cli.Flag{
27+
flags.ServerName,
28+
cli.StringFlag{
29+
Name: "tls",
30+
Usage: `Defines the TLS <version> in the handshake. By default it will use TLS 1.3 or TLS 1.2.
31+
: The supported versions are **1.3**, **1.2**, **1.1**, and **1.0**.`,
32+
},
33+
cli.StringFlag{
34+
Name: "cert",
35+
Usage: `The path to the <file> containing the client certificate to use.`,
36+
},
37+
cli.StringFlag{
38+
Name: "key",
39+
Usage: `The path to the <file> or KMS <uri> containing the certificate key to use.`,
40+
},
41+
cli.StringFlag{
42+
Name: "roots",
43+
Usage: `Root certificate(s) that will be used to verify the
44+
authenticity of the remote server.
45+
46+
: <roots> is a case-sensitive string and may be one of:
47+
48+
**file**
49+
: Relative or full path to a file. All certificates in the file will be used for path validation.
50+
51+
**list of files**
52+
: Comma-separated list of relative or full file paths. Every PEM encoded certificate from each file will be used for path validation.
53+
54+
**directory**
55+
: Relative or full path to a directory. Every PEM encoded certificate from each file in the directory will be used for path validation.`,
56+
},
57+
58+
cli.StringFlag{
59+
Name: "password-file",
60+
Usage: "The path to the <file> containing the password to decrypt the private key.",
61+
},
62+
cli.BoolFlag{
63+
Name: "chain",
64+
Usage: "Print only the chain of verified certificates.",
65+
},
66+
cli.BoolFlag{
67+
Name: "peer",
68+
Usage: `Print only the peer certificates sent by the server.`,
69+
},
70+
cli.BoolFlag{
71+
Name: "insecure",
72+
Usage: `Use an insecure client to retrieve a remote peer certificate. Useful for
73+
debugging invalid certificates remotely.`,
74+
},
75+
},
76+
}
77+
}
78+
79+
func handshakeAction(ctx *cli.Context) error {
80+
if err := errs.NumberOfArguments(ctx, 1); err != nil {
81+
return err
82+
}
83+
84+
var (
85+
addr = ctx.Args().First()
86+
tlsVersion = ctx.String("tls")
87+
roots = ctx.String("roots")
88+
serverName = ctx.String("servername")
89+
certFile = ctx.String("cert")
90+
keyFile = ctx.String("key")
91+
passwordFile = ctx.String("password-file")
92+
printChains = ctx.Bool("chain")
93+
printPeer = ctx.Bool("peer")
94+
insecure = ctx.Bool("insecure")
95+
rootCAs *x509.CertPool
96+
err error
97+
)
98+
99+
switch {
100+
case certFile != "" && keyFile == "":
101+
return errs.RequiredWithFlag(ctx, "cert", "key")
102+
case keyFile != "" && certFile == "":
103+
return errs.RequiredWithFlag(ctx, "key", "cert")
104+
}
105+
106+
// Parse address
107+
if u, ok, err := trimURL(addr); err != nil {
108+
return err
109+
} else if ok {
110+
addr = u
111+
}
112+
if _, _, err := net.SplitHostPort(addr); err != nil {
113+
addr = net.JoinHostPort(addr, "443")
114+
}
115+
116+
// Load certificate and if
117+
var certificates []tls.Certificate
118+
if certFile != "" && keyFile != "" {
119+
opts := []pemutil.Options{}
120+
if passwordFile != "" {
121+
opts = append(opts, pemutil.WithPasswordFile(passwordFile))
122+
}
123+
crt, err := cryptoutil.LoadTLSCertificate(certFile, keyFile, opts...)
124+
if err != nil {
125+
return err
126+
}
127+
certificates = []tls.Certificate{crt}
128+
}
129+
130+
// Get the list of roots used to validate the certificate.
131+
if roots != "" {
132+
rootCAs, err = x509util.ReadCertPool(roots)
133+
if err != nil {
134+
return fmt.Errorf("error loading root certificate pool from %q: %w", roots, err)
135+
}
136+
} else {
137+
rootCAs, err = x509.SystemCertPool()
138+
if err != nil {
139+
return fmt.Errorf("error loading the system cert pool: %w", err)
140+
}
141+
}
142+
143+
// Get the tls version to use. Defaults to TLS 1.2+
144+
minVersion, maxVersion, err := getTLSVersions(tlsVersion)
145+
if err != nil {
146+
return err
147+
}
148+
149+
tlsConfig := &tls.Config{
150+
MinVersion: minVersion,
151+
MaxVersion: maxVersion,
152+
RootCAs: rootCAs,
153+
InsecureSkipVerify: insecure,
154+
ServerName: serverName,
155+
Certificates: certificates,
156+
}
157+
158+
cs, err := tlsDialWithFallback(addr, tlsConfig)
159+
if err != nil {
160+
return err
161+
}
162+
163+
// Print only the list of verified chains
164+
if printChains {
165+
for _, chain := range cs.VerifiedChains {
166+
for _, crt := range chain {
167+
fmt.Print(string(pem.EncodeToMemory(&pem.Block{
168+
Type: "CERTIFICATE",
169+
Bytes: crt.Raw,
170+
})))
171+
}
172+
}
173+
return nil
174+
}
175+
176+
// Print only the peer certificates
177+
if printPeer {
178+
for _, crt := range cs.PeerCertificates {
179+
fmt.Print(string(pem.EncodeToMemory(&pem.Block{
180+
Type: "CERTIFICATE", Bytes: crt.Raw,
181+
})))
182+
}
183+
return nil
184+
}
185+
186+
// Check if the certificates is verified
187+
var intermediates *x509.CertPool
188+
if len(cs.PeerCertificates) > 1 {
189+
intermediates = x509.NewCertPool()
190+
for _, crt := range cs.PeerCertificates[1:] {
191+
intermediates.AddCert(crt)
192+
}
193+
}
194+
_, verifyErr := cs.PeerCertificates[0].Verify(x509.VerifyOptions{
195+
Roots: rootCAs,
196+
Intermediates: intermediates,
197+
DNSName: serverName,
198+
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
199+
})
200+
201+
connStateValue := reflect.ValueOf(cs)
202+
curveIDField := connStateValue.FieldByName("testingOnlyCurveID")
203+
204+
fmt.Printf("Server Name: %s\n", cs.ServerName)
205+
fmt.Printf("Version: %s\n", tls.VersionName(cs.Version))
206+
fmt.Printf("Cipher Suite: %s\n", tls.CipherSuiteName(cs.CipherSuite))
207+
fmt.Printf("KEM: %s\n", curveIDName(curveIDField.Uint()))
208+
fmt.Printf("Insecure: %v\n", tlsConfig.InsecureSkipVerify)
209+
fmt.Printf("Verified: %v\n", verifyErr == nil)
210+
211+
return nil
212+
}
213+
214+
func curveIDName(curveID uint64) string {
215+
switch tls.CurveID(curveID) {
216+
case tls.CurveP256:
217+
return "P-256"
218+
case tls.CurveP384:
219+
return "P-384"
220+
case tls.CurveP521:
221+
return "P-521"
222+
case tls.X25519:
223+
return "X25519"
224+
case tls.X25519MLKEM768:
225+
return "X25519MLKEM768"
226+
default:
227+
return "Unknown"
228+
}
229+
}
230+
231+
func getTLSVersions(s string) (uint16, uint16, error) {
232+
switch s {
233+
case "":
234+
return tls.VersionTLS12, 0, nil
235+
case "1.3":
236+
return tls.VersionTLS13, tls.VersionTLS13, nil
237+
case "1.2":
238+
return tls.VersionTLS12, tls.VersionTLS12, nil
239+
case "1.1":
240+
return tls.VersionTLS11, tls.VersionTLS11, nil
241+
case "1.0":
242+
return tls.VersionTLS10, tls.VersionTLS10, nil
243+
default:
244+
return 0, 0, fmt.Errorf("unsupported TLS version %q", s)
245+
}
246+
}
247+
248+
func tlsDialWithFallback(addr string, tlsConfig *tls.Config) (tls.ConnectionState, error) {
249+
conn, err := tls.Dial("tcp", addr, tlsConfig)
250+
if err != nil {
251+
if tlsConfig.InsecureSkipVerify {
252+
return tls.ConnectionState{}, fmt.Errorf("error connecting to %q: %w", addr, err)
253+
}
254+
tlsConfig.InsecureSkipVerify = true
255+
return tlsDialWithFallback(addr, tlsConfig)
256+
}
257+
defer conn.Close()
258+
conn.Handshake()
259+
return conn.ConnectionState(), nil
260+
}

internal/cryptoutil/cryptoutil.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"crypto/ed25519"
77
"crypto/elliptic"
88
"crypto/rsa"
9+
"crypto/tls"
910
"crypto/x509"
1011
"encoding/base64"
1112
"errors"
@@ -115,6 +116,37 @@ func LoadCertificate(kmsURI, certPath string) ([]*x509.Certificate, error) {
115116
return cert, nil
116117
}
117118

119+
// LoadTLSCertificate returns a [tls.Certificate] from a certificate fine and a
120+
// key in a file or in a KMS.
121+
func LoadTLSCertificate(certFile, keyName string, opts ...pemutil.Options) (tls.Certificate, error) {
122+
bundle, err := pemutil.ReadCertificateBundle(certFile)
123+
if err != nil {
124+
return tls.Certificate{}, err
125+
}
126+
127+
var signer crypto.Signer
128+
if IsKMS(keyName) {
129+
if signer, err = CreateSigner(keyName, keyName, opts...); err != nil {
130+
return tls.Certificate{}, err
131+
}
132+
} else {
133+
if signer, err = CreateSigner("", keyName, opts...); err != nil {
134+
return tls.Certificate{}, err
135+
}
136+
}
137+
138+
cert := make([][]byte, len(bundle))
139+
for i, crt := range bundle {
140+
cert[i] = crt.Raw
141+
}
142+
143+
return tls.Certificate{
144+
Certificate: cert,
145+
PrivateKey: signer,
146+
Leaf: bundle[0],
147+
}, nil
148+
}
149+
118150
// LoadJSONWebKey returns a jose.JSONWebKey from a KMS or a file.
119151
func LoadJSONWebKey(kmsURI, name string, opts ...jose.Option) (*jose.JSONWebKey, error) {
120152
if kmsURI == "" {

0 commit comments

Comments
 (0)