diff --git a/pkg/tempdb/factory.go b/pkg/tempdb/factory.go index 6311ae6..e65706a 100644 --- a/pkg/tempdb/factory.go +++ b/pkg/tempdb/factory.go @@ -62,6 +62,7 @@ type ( metadataTable string logger log.Logger rootDatabase string + dropTimeout time.Duration } OnInstanceFactoryOpt func(*onInstanceFactoryOptions) @@ -102,6 +103,13 @@ func WithRootDatabase(db string) OnInstanceFactoryOpt { } } +// WithDropTimeout sets the timeout used when dropping database +func WithDropTimeout(d time.Duration) OnInstanceFactoryOpt { + return func(opts *onInstanceFactoryOptions) { + opts.dropTimeout = d + } +} + type ( CreateConnPoolForDbFn func(ctx context.Context, dbName string) (*sql.DB, error) @@ -129,6 +137,7 @@ func NewOnInstanceFactory(ctx context.Context, createConnPoolForDb CreateConnPoo dbPrefix: DefaultOnInstanceDbPrefix, metadataSchema: DefaultOnInstanceMetadataSchema, metadataTable: DefaultOnInstanceMetadataTable, + dropTimeout: DefaultStatementTimeout, rootDatabase: "postgres", logger: log.SimpleLogger(), } @@ -259,6 +268,10 @@ func (o *onInstanceFactory) dropTempDatabase(ctx context.Context, dbName string) } defer rootConn.Close() + if _, err := rootConn.ExecContext(ctx, fmt.Sprintf("SET SESSION statement_timeout = %d;", o.options.dropTimeout.Milliseconds())); err != nil { + return fmt.Errorf("setting statement timeout: %w", err) + } + _, err = rootConn.ExecContext(ctx, fmt.Sprintf("DROP DATABASE %s;", dbName)) if err != nil { return fmt.Errorf("dropping temporary database: %w", err)