Skip to content

Commit d4623d9

Browse files
authored
fix(collector): Let pgx library parse TLS parameters (#1390)
* fix(collector): Let pgx library parse TLS parameters This allows the collector to respect the sslmode parameters Fix: #1163 * Add comment * Improve postgres collector test
1 parent bc48568 commit d4623d9

File tree

4 files changed

+115
-25
lines changed

4 files changed

+115
-25
lines changed

pkg/collect/cluster_resources.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,19 +120,16 @@ func (c *CollectClusterResources) Collect(progressChan chan<- interface{}) (Coll
120120
var namespaceNames []string
121121
if len(c.Collector.Namespaces) > 0 {
122122
namespaces, namespaceErrors := getNamespaces(ctx, client, c.Collector.Namespaces)
123-
klog.V(4).Infof("checking for namespaces access: %s", string(namespaces))
124123
namespaceNames = c.Collector.Namespaces
125124
output.SaveResult(c.BundlePath, path.Join(constants.CLUSTER_RESOURCES_DIR, fmt.Sprintf("%s.json", constants.CLUSTER_RESOURCES_NAMESPACES)), bytes.NewBuffer(namespaces))
126125
output.SaveResult(c.BundlePath, path.Join(constants.CLUSTER_RESOURCES_DIR, fmt.Sprintf("%s-errors.json", constants.CLUSTER_RESOURCES_NAMESPACES)), marshalErrors(namespaceErrors))
127126
} else if c.Namespace != "" {
128127
namespace, namespaceErrors := getNamespace(ctx, client, c.Namespace)
129-
klog.V(4).Infof("checking for namespace access: %s", string(namespace))
130128
output.SaveResult(c.BundlePath, path.Join(constants.CLUSTER_RESOURCES_DIR, fmt.Sprintf("%s.json", constants.CLUSTER_RESOURCES_NAMESPACES)), bytes.NewBuffer(namespace))
131129
output.SaveResult(c.BundlePath, path.Join(constants.CLUSTER_RESOURCES_DIR, fmt.Sprintf("%s-errors.json", constants.CLUSTER_RESOURCES_NAMESPACES)), marshalErrors(namespaceErrors))
132130
namespaceNames = append(namespaceNames, c.Namespace)
133131
} else {
134132
namespaces, namespaceList, namespaceErrors := getAllNamespaces(ctx, client)
135-
klog.V(4).Infof("checking for all namespaces access: %s", string(namespaces))
136133
output.SaveResult(c.BundlePath, path.Join(constants.CLUSTER_RESOURCES_DIR, fmt.Sprintf("%s.json", constants.CLUSTER_RESOURCES_NAMESPACES)), bytes.NewBuffer(namespaces))
137134
output.SaveResult(c.BundlePath, path.Join(constants.CLUSTER_RESOURCES_DIR, fmt.Sprintf("%s-errors.json", constants.CLUSTER_RESOURCES_NAMESPACES)), marshalErrors(namespaceErrors))
138135
if namespaceList != nil {
@@ -146,6 +143,7 @@ func (c *CollectClusterResources) Collect(progressChan chan<- interface{}) (Coll
146143
reviewStatuses, reviewStatusErrors := getSelfSubjectRulesReviews(ctx, client, namespaceNames)
147144

148145
// auth cani
146+
klog.V(2).Infof("checking [%s] namespaces for permissions to collect resources", strings.Join(namespaceNames, ", "))
149147
authCanI := authCanI(reviewStatuses, namespaceNames)
150148
for k, v := range authCanI {
151149
output.SaveResult(c.BundlePath, path.Join(constants.CLUSTER_RESOURCES_DIR, constants.CLUSTER_RESOURCES_AUTH_CANI, k), bytes.NewBuffer(v))
@@ -160,8 +158,12 @@ func (c *CollectClusterResources) Collect(progressChan chan<- interface{}) (Coll
160158
filteredNamespaces = append(filteredNamespaces, ns)
161159
}
162160
}
161+
if len(filteredNamespaces) != len(namespaceNames) {
162+
klog.V(2).Infof("filtered namespaces down to [%s] after evaluating permissions", strings.Join(filteredNamespaces, ", "))
163+
} else {
164+
klog.V(2).Infof("no namespaces filtered out after evaluating permissions")
165+
}
163166
namespaceNames = filteredNamespaces
164-
klog.V(4).Infof("filtered to namespaceNames %s", namespaceNames)
165167
}
166168

167169
// pods

pkg/collect/postgres.go

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@ import (
55
"context"
66
"encoding/json"
77
"fmt"
8+
"os"
9+
"path/filepath"
810
"regexp"
911

1012
"github.com/jackc/pgx/v5"
1113
"github.com/pkg/errors"
1214
troubleshootv1beta2 "github.com/replicatedhq/troubleshoot/pkg/apis/troubleshoot/v1beta2"
1315
"k8s.io/client-go/kubernetes"
1416
"k8s.io/client-go/rest"
17+
"k8s.io/klog/v2"
1518
)
1619

1720
type CollectPostgres struct {
@@ -37,20 +40,74 @@ func (c *CollectPostgres) createConnectConfig() (*pgx.ConnConfig, error) {
3740
return nil, errors.New("postgres uri cannot be empty")
3841
}
3942

40-
cfg, err := pgx.ParseConfig(c.Collector.URI)
41-
if err != nil {
42-
return nil, errors.Wrap(err, "failed to parse postgres config")
43-
}
44-
4543
if c.Collector.TLS != nil {
46-
tlsCfg, err := createTLSConfig(c.Context, c.Client, c.Collector.TLS)
44+
klog.V(2).Infof("Connecting to postgres with TLS client config")
45+
// Set the libpq TLS environment variables since pgx parses them to
46+
// create the TLS configuration (tls.Config instance) to connect with
47+
// https://www.postgresql.org/docs/current/libpq-envars.html
48+
caCert, clientCert, clientKey, err := getTLSParamTriplet(c.Context, c.Client, c.Collector.TLS)
4749
if err != nil {
4850
return nil, err
4951
}
5052

51-
tlsCfg.ServerName = cfg.Host
52-
cfg.TLSConfig = tlsCfg
53+
// Drop the TLS params to files and set the paths to their
54+
// respective environment variables
55+
// The environment variables are unset after the connection config
56+
// is created. Their respective files are deleted as well.
57+
tmpdir, err := os.MkdirTemp("", "ts-postgres-collector")
58+
if err != nil {
59+
return nil, errors.Wrap(err, "failed to create temp dir to store postgres collector TLS files")
60+
}
61+
defer os.RemoveAll(tmpdir)
62+
63+
if caCert != "" {
64+
caCertPath := filepath.Join(tmpdir, "ca.crt")
65+
err = os.WriteFile(caCertPath, []byte(caCert), 0644)
66+
if err != nil {
67+
return nil, errors.Wrap(err, "failed to write ca cert to file")
68+
}
69+
err = os.Setenv("PGSSLROOTCERT", caCertPath)
70+
if err != nil {
71+
return nil, errors.Wrap(err, "failed to set PGSSLROOTCERT environment variable")
72+
}
73+
klog.V(2).Infof("'PGSSLROOTCERT' environment variable set to %q", caCertPath)
74+
defer os.Unsetenv("PGSSLROOTCERT")
75+
}
76+
77+
if clientCert != "" {
78+
clientCertPath := filepath.Join(tmpdir, "client.crt")
79+
err = os.WriteFile(clientCertPath, []byte(clientCert), 0644)
80+
if err != nil {
81+
return nil, errors.Wrap(err, "failed to write client cert to file")
82+
}
83+
err = os.Setenv("PGSSLCERT", clientCertPath)
84+
if err != nil {
85+
return nil, errors.Wrap(err, "failed to set PGSSLCERT environment variable")
86+
}
87+
klog.V(2).Infof("'PGSSLCERT' environment variable set to %q", clientCertPath)
88+
defer os.Unsetenv("PGSSLCERT")
89+
}
90+
91+
if clientKey != "" {
92+
clientKeyPath := filepath.Join(tmpdir, "client.key")
93+
err = os.WriteFile(clientKeyPath, []byte(clientKey), 0600)
94+
if err != nil {
95+
return nil, errors.Wrap(err, "failed to write client key to file")
96+
}
97+
err = os.Setenv("PGSSLKEY", clientKeyPath)
98+
if err != nil {
99+
return nil, errors.Wrap(err, "failed to set PGSSLKEY environment variable")
100+
}
101+
klog.V(2).Infof("'PGSSLKEY' environment variable set to %q", clientKeyPath)
102+
defer os.Unsetenv("PGSSLKEY")
103+
}
104+
}
105+
106+
cfg, err := pgx.ParseConfig(c.Collector.URI)
107+
if err != nil {
108+
return nil, errors.Wrap(err, "failed to parse postgres config")
53109
}
110+
klog.V(2).Infof("Successfully parsed postgres config")
54111

55112
return cfg, nil
56113
}
@@ -74,8 +131,10 @@ func (c *CollectPostgres) Collect(progressChan chan<- interface{}) (CollectorRes
74131

75132
conn, err := c.connect()
76133
if err != nil {
134+
klog.V(2).Infof("Postgres connection error: %s", err.Error())
77135
databaseConnection.Error = err.Error()
78136
} else {
137+
klog.V(2).Info("Successfully connected to postgres")
79138
defer conn.Close(c.Context)
80139

81140
query := `select version()`

pkg/collect/postgres_test.go

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package collect
22

33
import (
44
"context"
5+
"crypto/rsa"
6+
"crypto/x509"
7+
"encoding/pem"
58
"testing"
69

710
"github.com/replicatedhq/troubleshoot/internal/testutils"
@@ -100,7 +103,7 @@ func TestCollectPostgres_createConnectConfigTLS(t *testing.T) {
100103
Client: k8sClient,
101104
Context: context.Background(),
102105
Collector: &v1beta2.Database{
103-
URI: "postgresql://user:password@my-pghost:5432/defaultdb?sslmode=require",
106+
URI: "postgresql://user:password@my-pghost:5432/defaultdb?sslmode=verify-full",
104107
TLS: &v1beta2.TLSParams{
105108
CACert: testutils.GetTestFixture(t, "db/ca.pem"),
106109
ClientCert: testutils.GetTestFixture(t, "db/client.pem"),
@@ -113,7 +116,21 @@ func TestCollectPostgres_createConnectConfigTLS(t *testing.T) {
113116
assert.NoError(t, err)
114117
assert.NotNil(t, connCfg)
115118
assert.Equal(t, connCfg.Host, "my-pghost")
116-
assert.NotNil(t, connCfg.TLSConfig.Certificates)
119+
120+
// Check client cert
121+
require.Len(t, connCfg.TLSConfig.Certificates, 1)
122+
require.Len(t, connCfg.TLSConfig.Certificates[0].Certificate, 1)
123+
cert := connCfg.TLSConfig.Certificates[0]
124+
clientCert, err := x509.ParseCertificate(cert.Certificate[0])
125+
require.NoError(t, err)
126+
assert.Equal(t, "CN=client,L=Didcot,ST=Oxfordshire,C=UK", clientCert.Subject.String())
127+
128+
// Check client key
129+
block, _ := pem.Decode([]byte(testutils.GetTestFixture(t, "db/client-key.pem")))
130+
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
131+
require.NoError(t, err)
132+
assert.True(t, key.Equal(cert.PrivateKey.(*rsa.PrivateKey)))
133+
117134
assert.NotNil(t, connCfg.TLSConfig.RootCAs)
118135
assert.False(t, connCfg.TLSConfig.InsecureSkipVerify)
119136
}

pkg/collect/util.go

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,30 @@ func listNodesInSelector(ctx context.Context, client *kubernetes.Clientset, sele
139139

140140
nodes, err := client.CoreV1().Nodes().List(ctx, listOptions)
141141
if err != nil {
142-
return nil, fmt.Errorf("Can't get the list of nodes, got: %w", err)
142+
return nil, fmt.Errorf("can't get the list of nodes, got: %w", err)
143143
}
144144

145145
return nodes.Items, nil
146146
}
147147

148+
func getTLSParamTriplet(
149+
ctx context.Context, client kubernetes.Interface, params *troubleshootv1beta2.TLSParams,
150+
) (string, string, string, error) {
151+
var caCert, clientCert, clientKey string
152+
if params.Secret != nil {
153+
var err error
154+
caCert, clientCert, clientKey, err = getTLSParamsFromSecret(ctx, client, params.Secret)
155+
if err != nil {
156+
return caCert, clientCert, clientKey, err
157+
}
158+
} else {
159+
caCert = params.CACert
160+
clientCert = params.ClientCert
161+
clientKey = params.ClientKey
162+
}
163+
return caCert, clientCert, clientKey, nil
164+
}
165+
148166
func createTLSConfig(ctx context.Context, client kubernetes.Interface, params *troubleshootv1beta2.TLSParams) (*tls.Config, error) {
149167
rootCA, err := x509.SystemCertPool()
150168
if err != nil {
@@ -158,21 +176,15 @@ func createTLSConfig(ctx context.Context, client kubernetes.Interface, params *t
158176
return tlsCfg, nil
159177
}
160178

161-
var caCert, clientCert, clientKey string
162-
if params.Secret != nil {
163-
caCert, clientCert, clientKey, err = getTLSParamsFromSecret(ctx, client, params.Secret)
164-
if err != nil {
165-
return nil, err
166-
}
167-
} else {
168-
caCert = params.CACert
169-
clientCert = params.ClientCert
170-
clientKey = params.ClientKey
179+
caCert, clientCert, clientKey, err := getTLSParamTriplet(ctx, client, params)
180+
if err != nil {
181+
return nil, err
171182
}
172183

173184
if ok := rootCA.AppendCertsFromPEM([]byte(caCert)); !ok {
174185
return nil, fmt.Errorf("failed to append CA cert to root CA bundle")
175186
}
187+
176188
tlsCfg.RootCAs = rootCA
177189

178190
if clientCert == "" && clientKey == "" {

0 commit comments

Comments
 (0)