Skip to content

Commit 1f7636d

Browse files
authored
(fix) Remove db_type parameter Postgres detector during making connection (#3989)
* remove db_type from configuration parameter * fix integration tests * use single context for all tests instead of having separate for each
1 parent c293bb2 commit 1f7636d

File tree

2 files changed

+54
-47
lines changed

2 files changed

+54
-47
lines changed

pkg/detectors/postgres/postgres.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) ([]dete
129129
// parameters themselves.
130130
if timeout, ok := getDeadlineInSeconds(ctx); ok && timeout > 0 {
131131
params[pgConnectTimeout] = strconv.Itoa(timeout)
132-
} else if timeout <= 0 {
132+
} else if ok && timeout <= 0 {
133133
// Deadline in the context has already exceeded.
134134
break
135135
}
@@ -223,6 +223,15 @@ func verifyPostgres(params map[string]string) (bool, error) {
223223
}()
224224
}
225225

226+
// db_type is not a valid configuration parameter, so we remove it before connecting.
227+
dbType := params[pgDbType]
228+
delete(params, pgDbType)
229+
230+
// we re-add it before returning to preserve in ExtraData
231+
defer func() {
232+
params[pgDbType] = dbType
233+
}()
234+
226235
var connStr string
227236
for key, value := range params {
228237
connStr += fmt.Sprintf("%s='%s'", key, value)

pkg/detectors/postgres/postgres_integration_test.go

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ const (
3434
)
3535

3636
func TestPostgres_FromChunk(t *testing.T) {
37+
ctx := context.Background()
3738
if err := startPostgres(); err != nil {
3839
if exitErr, ok := err.(*exec.ExitError); ok {
3940
t.Fatalf("could not start local postgres: %v w/stderr:\n%s", err, string(exitErr.Stderr))
@@ -58,6 +59,10 @@ func TestPostgres_FromChunk(t *testing.T) {
5859
ctx context.Context
5960
data []byte
6061
verify bool
62+
63+
// For tests that require a timeout context, in which case the above ctx will be ignored and new ctx at the
64+
// time of test execution will be created
65+
requiresTimeoutContext bool
6166
}
6267
tests := []struct {
6368
name string
@@ -70,7 +75,7 @@ func TestPostgres_FromChunk(t *testing.T) {
7075
name: "not found",
7176
s: Scanner{},
7277
args: args{
73-
ctx: context.Background(),
78+
ctx: ctx,
7479
data: []byte("You cannot find the secret within"),
7580
verify: true,
7681
},
@@ -81,7 +86,7 @@ func TestPostgres_FromChunk(t *testing.T) {
8186
name: "found connection URI with ssl mode unset, verified",
8287
s: Scanner{detectLoopback: true},
8388
args: args{
84-
ctx: context.Background(),
89+
ctx: ctx,
8590
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s:%s/postgres2`, postgresUser, postgresPass, postgresHost, postgresPort)),
8691
verify: true,
8792
},
@@ -100,7 +105,7 @@ func TestPostgres_FromChunk(t *testing.T) {
100105
name: "found connection URI with ssl mode 'prefer', verified",
101106
s: Scanner{detectLoopback: true},
102107
args: args{
103-
ctx: context.Background(),
108+
ctx: ctx,
104109
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s:%s/postgres2?sslmode=prefer`, postgresUser, postgresPass, postgresHost, postgresPort)),
105110
verify: true,
106111
},
@@ -119,7 +124,7 @@ func TestPostgres_FromChunk(t *testing.T) {
119124
name: "found connection URI with ssl mode 'allow', verified",
120125
s: Scanner{detectLoopback: true},
121126
args: args{
122-
ctx: context.Background(),
127+
ctx: ctx,
123128
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s:%s/postgres2?sslmode=allow`, postgresUser, postgresPass, postgresHost, postgresPort)),
124129
verify: true,
125130
},
@@ -138,7 +143,7 @@ func TestPostgres_FromChunk(t *testing.T) {
138143
name: "found connection URI with requiressl=0, verified",
139144
s: Scanner{detectLoopback: true},
140145
args: args{
141-
ctx: context.Background(),
146+
ctx: ctx,
142147
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s:%s/postgres2?requiressl=0`, postgresUser, postgresPass, postgresHost, postgresPort)),
143148
verify: true,
144149
},
@@ -157,7 +162,7 @@ func TestPostgres_FromChunk(t *testing.T) {
157162
name: "found connection URI without database, verified",
158163
s: Scanner{detectLoopback: true},
159164
args: args{
160-
ctx: context.Background(),
165+
ctx: ctx,
161166
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s:%s/`, postgresUser, postgresPass, postgresHost, postgresPort)),
162167
verify: true,
163168
},
@@ -176,7 +181,7 @@ func TestPostgres_FromChunk(t *testing.T) {
176181
name: "found connection URI, unverified",
177182
s: Scanner{detectLoopback: true},
178183
args: args{
179-
ctx: context.Background(),
184+
ctx: ctx,
180185
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s:%s/postgres2`, postgresUser, inactivePass, postgresHost, postgresPort)),
181186
verify: true,
182187
},
@@ -195,7 +200,7 @@ func TestPostgres_FromChunk(t *testing.T) {
195200
name: "ignored localhost",
196201
s: Scanner{},
197202
args: args{
198-
ctx: context.Background(),
203+
ctx: ctx,
199204
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s:%s/postgres2`, postgresUser, postgresPass, "localhost", postgresPort)),
200205
verify: true,
201206
},
@@ -206,7 +211,7 @@ func TestPostgres_FromChunk(t *testing.T) {
206211
name: "ignored 127.0.0.1",
207212
s: Scanner{},
208213
args: args{
209-
ctx: context.Background(),
214+
ctx: ctx,
210215
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s:%s/postgres2`, postgresUser, postgresPass, "127.0.0.1", postgresPort)),
211216
verify: true,
212217
},
@@ -216,15 +221,12 @@ func TestPostgres_FromChunk(t *testing.T) {
216221
{
217222
name: "found connection URI, unverified due to error - inactive host",
218223
s: Scanner{},
219-
args: func() args {
220-
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
221-
defer cancel()
222-
return args{
223-
ctx: ctx,
224-
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s:%s/postgres2`, postgresUser, postgresPass, inactiveHost, postgresPort)),
225-
verify: true,
226-
}
227-
}(),
224+
args: args{
225+
ctx: ctx,
226+
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s:%s/postgres2`, postgresUser, postgresPass, inactiveHost, postgresPort)),
227+
verify: true,
228+
requiresTimeoutContext: true,
229+
},
228230
want: func() []detectors.Result {
229231
r := detectors.Result{
230232
DetectorType: detectorspb.DetectorType_Postgres,
@@ -241,15 +243,12 @@ func TestPostgres_FromChunk(t *testing.T) {
241243
{
242244
name: "found connection URI, unverified due to error - wrong port",
243245
s: Scanner{detectLoopback: true},
244-
args: func() args {
245-
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
246-
defer cancel()
247-
return args{
248-
ctx: ctx,
249-
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s/postgres2`, postgresUser, postgresPass, postgresHost)),
250-
verify: true,
251-
}
252-
}(),
246+
args: args{
247+
ctx: ctx,
248+
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s/postgres2`, postgresUser, postgresPass, postgresHost)),
249+
verify: true,
250+
requiresTimeoutContext: true,
251+
},
253252
want: func() []detectors.Result {
254253
r := detectors.Result{
255254
DetectorType: detectorspb.DetectorType_Postgres,
@@ -266,15 +265,12 @@ func TestPostgres_FromChunk(t *testing.T) {
266265
{
267266
name: "found connection URI, unverified due to error - ssl not supported (using sslmode)",
268267
s: Scanner{detectLoopback: true},
269-
args: func() args {
270-
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
271-
defer cancel()
272-
return args{
273-
ctx: ctx,
274-
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s:%s/postgres2?sslmode=require`, postgresUser, postgresPass, postgresHost, postgresPort)),
275-
verify: true,
276-
}
277-
}(),
268+
args: args{
269+
ctx: ctx,
270+
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s:%s/postgres2?sslmode=require`, postgresUser, postgresPass, postgresHost, postgresPort)),
271+
verify: true,
272+
requiresTimeoutContext: true,
273+
},
278274
want: func() []detectors.Result {
279275
r := detectors.Result{
280276
DetectorType: detectorspb.DetectorType_Postgres,
@@ -291,15 +287,11 @@ func TestPostgres_FromChunk(t *testing.T) {
291287
{
292288
name: "found connection URI, unverified due to error - ssl not supported (using requiressl)",
293289
s: Scanner{detectLoopback: true},
294-
args: func() args {
295-
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
296-
defer cancel()
297-
return args{
298-
ctx: ctx,
299-
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s:%s/postgres2?requiressl=1`, postgresUser, postgresPass, postgresHost, postgresPort)),
300-
verify: true,
301-
}
302-
}(),
290+
args: args{
291+
ctx: ctx,
292+
data: []byte(fmt.Sprintf(`postgresql://%s:%s@%s:%s/postgres2?requiressl=1`, postgresUser, postgresPass, postgresHost, postgresPort)),
293+
verify: true,
294+
},
303295
want: func() []detectors.Result {
304296
r := detectors.Result{
305297
DetectorType: detectorspb.DetectorType_Postgres,
@@ -316,7 +308,13 @@ func TestPostgres_FromChunk(t *testing.T) {
316308
}
317309
for _, tt := range tests {
318310
t.Run(tt.name, func(t *testing.T) {
319-
got, err := tt.s.FromData(tt.args.ctx, tt.args.verify, tt.args.data)
311+
ctx := tt.args.ctx
312+
var cancel context.CancelFunc
313+
if tt.args.requiresTimeoutContext {
314+
ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second)
315+
defer cancel()
316+
}
317+
got, err := tt.s.FromData(ctx, tt.args.verify, tt.args.data)
320318
if (err != nil) != tt.wantErr {
321319
t.Errorf("postgres.FromData() error = %v, wantErr %v", err, tt.wantErr)
322320
return

0 commit comments

Comments
 (0)