Skip to content

Commit 3148eba

Browse files
feat: add db postgres settings to config (#2787)
* feat: add db postgres settings to config * no comment * fix: add updater logic * chore: fix golang lint * chore: apply review comments * chore: refactor and apply PR comments * Apply suggestions from code review Co-authored-by: Han Qiao <[email protected]> --------- Co-authored-by: Han Qiao <[email protected]>
1 parent 0377d47 commit 3148eba

File tree

5 files changed

+376
-38
lines changed

5 files changed

+376
-38
lines changed

pkg/cast/cast.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,20 @@ func IntToUint(value int) uint {
2020
return uint(value)
2121
}
2222

23+
func UintToIntPtr(value *uint) *int {
24+
if value == nil {
25+
return nil
26+
}
27+
return Ptr(UintToInt(*value))
28+
}
29+
30+
func IntToUintPtr(value *int) *uint {
31+
if value == nil {
32+
return nil
33+
}
34+
return Ptr(IntToUint(*value))
35+
}
36+
2337
func Ptr[T any](v T) *T {
2438
return &v
2539
}

pkg/config/config.go

Lines changed: 6 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,6 @@ const (
5757
LogflareBigQuery LogflareBackend = "bigquery"
5858
)
5959

60-
type PoolMode string
61-
62-
const (
63-
TransactionMode PoolMode = "transaction"
64-
SessionMode PoolMode = "session"
65-
)
66-
6760
type AddressFamily string
6861

6962
const (
@@ -146,36 +139,6 @@ type (
146139
Remotes map[string]baseConfig `toml:"-"`
147140
}
148141

149-
db struct {
150-
Image string `toml:"-"`
151-
Port uint16 `toml:"port"`
152-
ShadowPort uint16 `toml:"shadow_port"`
153-
MajorVersion uint `toml:"major_version"`
154-
Password string `toml:"-"`
155-
RootKey string `toml:"-" mapstructure:"root_key"`
156-
Pooler pooler `toml:"pooler"`
157-
Seed seed `toml:"seed"`
158-
}
159-
160-
seed struct {
161-
Enabled bool `toml:"enabled"`
162-
GlobPatterns []string `toml:"sql_paths"`
163-
SqlPaths []string `toml:"-"`
164-
}
165-
166-
pooler struct {
167-
Enabled bool `toml:"enabled"`
168-
Image string `toml:"-"`
169-
Port uint16 `toml:"port"`
170-
PoolMode PoolMode `toml:"pool_mode"`
171-
DefaultPoolSize uint `toml:"default_pool_size"`
172-
MaxClientConn uint `toml:"max_client_conn"`
173-
ConnectionString string `toml:"-"`
174-
TenantId string `toml:"-"`
175-
EncryptionKey string `toml:"-"`
176-
SecretKeyBase string `toml:"-"`
177-
}
178-
179142
realtime struct {
180143
Enabled bool `toml:"enabled"`
181144
Image string `toml:"-"`
@@ -775,6 +738,12 @@ func (c *baseConfig) Validate(fsys fs.FS) error {
775738
}
776739
}
777740
// Validate db config
741+
if c.Db.Settings.SessionReplicationRole != nil {
742+
allowedRoles := []SessionReplicationRole{SessionReplicationRoleOrigin, SessionReplicationRoleReplica, SessionReplicationRoleLocal}
743+
if !sliceContains(allowedRoles, *c.Db.Settings.SessionReplicationRole) {
744+
return errors.Errorf("Invalid config for db.session_replication_role: %s. Must be one of: %v", *c.Db.Settings.SessionReplicationRole, allowedRoles)
745+
}
746+
}
778747
if c.Db.Port == 0 {
779748
return errors.New("Missing required field in config: db.port")
780749
}

pkg/config/db.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
package config
2+
3+
import (
4+
"github.com/google/go-cmp/cmp"
5+
v1API "github.com/supabase/cli/pkg/api"
6+
"github.com/supabase/cli/pkg/cast"
7+
"github.com/supabase/cli/pkg/diff"
8+
)
9+
10+
type PoolMode string
11+
12+
const (
13+
TransactionMode PoolMode = "transaction"
14+
SessionMode PoolMode = "session"
15+
)
16+
17+
type SessionReplicationRole string
18+
19+
const (
20+
SessionReplicationRoleOrigin SessionReplicationRole = "origin"
21+
SessionReplicationRoleReplica SessionReplicationRole = "replica"
22+
SessionReplicationRoleLocal SessionReplicationRole = "local"
23+
)
24+
25+
type (
26+
settings struct {
27+
EffectiveCacheSize *string `toml:"effective_cache_size"`
28+
LogicalDecodingWorkMem *string `toml:"logical_decoding_work_mem"`
29+
MaintenanceWorkMem *string `toml:"maintenance_work_mem"`
30+
MaxConnections *uint `toml:"max_connections"`
31+
MaxLocksPerTransaction *uint `toml:"max_locks_per_transaction"`
32+
MaxParallelMaintenanceWorkers *uint `toml:"max_parallel_maintenance_workers"`
33+
MaxParallelWorkers *uint `toml:"max_parallel_workers"`
34+
MaxParallelWorkersPerGather *uint `toml:"max_parallel_workers_per_gather"`
35+
MaxReplicationSlots *uint `toml:"max_replication_slots"`
36+
MaxSlotWalKeepSize *string `toml:"max_slot_wal_keep_size"`
37+
MaxStandbyArchiveDelay *string `toml:"max_standby_archive_delay"`
38+
MaxStandbyStreamingDelay *string `toml:"max_standby_streaming_delay"`
39+
MaxWalSize *string `toml:"max_wal_size"`
40+
MaxWalSenders *uint `toml:"max_wal_senders"`
41+
MaxWorkerProcesses *uint `toml:"max_worker_processes"`
42+
SessionReplicationRole *SessionReplicationRole `toml:"session_replication_role"`
43+
SharedBuffers *string `toml:"shared_buffers"`
44+
StatementTimeout *string `toml:"statement_timeout"`
45+
WalKeepSize *string `toml:"wal_keep_size"`
46+
WalSenderTimeout *string `toml:"wal_sender_timeout"`
47+
WorkMem *string `toml:"work_mem"`
48+
}
49+
50+
db struct {
51+
Image string `toml:"-"`
52+
Port uint16 `toml:"port"`
53+
ShadowPort uint16 `toml:"shadow_port"`
54+
MajorVersion uint `toml:"major_version"`
55+
Password string `toml:"-"`
56+
RootKey string `toml:"-" mapstructure:"root_key"`
57+
Pooler pooler `toml:"pooler"`
58+
Seed seed `toml:"seed"`
59+
Settings settings `toml:"settings"`
60+
}
61+
62+
seed struct {
63+
Enabled bool `toml:"enabled"`
64+
GlobPatterns []string `toml:"sql_paths"`
65+
SqlPaths []string `toml:"-"`
66+
}
67+
68+
pooler struct {
69+
Enabled bool `toml:"enabled"`
70+
Image string `toml:"-"`
71+
Port uint16 `toml:"port"`
72+
PoolMode PoolMode `toml:"pool_mode"`
73+
DefaultPoolSize uint `toml:"default_pool_size"`
74+
MaxClientConn uint `toml:"max_client_conn"`
75+
ConnectionString string `toml:"-"`
76+
TenantId string `toml:"-"`
77+
EncryptionKey string `toml:"-"`
78+
SecretKeyBase string `toml:"-"`
79+
}
80+
)
81+
82+
// Compare two db config, if changes requires restart return true, return false otherwise
83+
func (a settings) requireDbRestart(b settings) bool {
84+
return !cmp.Equal(a.MaxConnections, b.MaxConnections) ||
85+
!cmp.Equal(a.MaxWorkerProcesses, b.MaxWorkerProcesses) ||
86+
!cmp.Equal(a.MaxParallelWorkers, b.MaxParallelWorkers) ||
87+
!cmp.Equal(a.MaxWalSenders, b.MaxWalSenders) ||
88+
!cmp.Equal(a.MaxReplicationSlots, b.MaxReplicationSlots) ||
89+
!cmp.Equal(a.SharedBuffers, b.SharedBuffers)
90+
}
91+
92+
func (a *settings) ToUpdatePostgresConfigBody() v1API.UpdatePostgresConfigBody {
93+
body := v1API.UpdatePostgresConfigBody{}
94+
95+
// Parameters that require restart
96+
body.MaxConnections = cast.UintToIntPtr(a.MaxConnections)
97+
body.MaxWorkerProcesses = cast.UintToIntPtr(a.MaxWorkerProcesses)
98+
body.MaxParallelWorkers = cast.UintToIntPtr(a.MaxParallelWorkers)
99+
body.MaxWalSenders = cast.UintToIntPtr(a.MaxWalSenders)
100+
body.MaxReplicationSlots = cast.UintToIntPtr(a.MaxReplicationSlots)
101+
body.SharedBuffers = a.SharedBuffers
102+
103+
// Parameters that can be changed without restart
104+
body.EffectiveCacheSize = a.EffectiveCacheSize
105+
body.LogicalDecodingWorkMem = a.LogicalDecodingWorkMem
106+
body.MaintenanceWorkMem = a.MaintenanceWorkMem
107+
body.MaxLocksPerTransaction = cast.UintToIntPtr(a.MaxLocksPerTransaction)
108+
body.MaxParallelMaintenanceWorkers = cast.UintToIntPtr(a.MaxParallelMaintenanceWorkers)
109+
body.MaxParallelWorkersPerGather = cast.UintToIntPtr(a.MaxParallelWorkersPerGather)
110+
body.MaxSlotWalKeepSize = a.MaxSlotWalKeepSize
111+
body.MaxStandbyArchiveDelay = a.MaxStandbyArchiveDelay
112+
body.MaxStandbyStreamingDelay = a.MaxStandbyStreamingDelay
113+
body.MaxWalSize = a.MaxWalSize
114+
body.SessionReplicationRole = (*v1API.UpdatePostgresConfigBodySessionReplicationRole)(a.SessionReplicationRole)
115+
body.StatementTimeout = a.StatementTimeout
116+
body.WalKeepSize = a.WalKeepSize
117+
body.WalSenderTimeout = a.WalSenderTimeout
118+
body.WorkMem = a.WorkMem
119+
return body
120+
}
121+
122+
func (a *settings) fromRemoteConfig(remoteConfig v1API.PostgresConfigResponse) settings {
123+
result := *a
124+
125+
result.EffectiveCacheSize = remoteConfig.EffectiveCacheSize
126+
result.LogicalDecodingWorkMem = remoteConfig.LogicalDecodingWorkMem
127+
result.MaintenanceWorkMem = remoteConfig.MaintenanceWorkMem
128+
result.MaxConnections = cast.IntToUintPtr(remoteConfig.MaxConnections)
129+
result.MaxLocksPerTransaction = cast.IntToUintPtr(remoteConfig.MaxLocksPerTransaction)
130+
result.MaxParallelMaintenanceWorkers = cast.IntToUintPtr(remoteConfig.MaxParallelMaintenanceWorkers)
131+
result.MaxParallelWorkers = cast.IntToUintPtr(remoteConfig.MaxParallelWorkers)
132+
result.MaxParallelWorkersPerGather = cast.IntToUintPtr(remoteConfig.MaxParallelWorkersPerGather)
133+
result.MaxReplicationSlots = cast.IntToUintPtr(remoteConfig.MaxReplicationSlots)
134+
result.MaxSlotWalKeepSize = remoteConfig.MaxSlotWalKeepSize
135+
result.MaxStandbyArchiveDelay = remoteConfig.MaxStandbyArchiveDelay
136+
result.MaxStandbyStreamingDelay = remoteConfig.MaxStandbyStreamingDelay
137+
result.MaxWalSenders = cast.IntToUintPtr(remoteConfig.MaxWalSenders)
138+
result.MaxWalSize = remoteConfig.MaxWalSize
139+
result.MaxWorkerProcesses = cast.IntToUintPtr(remoteConfig.MaxWorkerProcesses)
140+
result.SessionReplicationRole = (*SessionReplicationRole)(remoteConfig.SessionReplicationRole)
141+
result.SharedBuffers = remoteConfig.SharedBuffers
142+
result.StatementTimeout = remoteConfig.StatementTimeout
143+
result.WalKeepSize = remoteConfig.WalKeepSize
144+
result.WalSenderTimeout = remoteConfig.WalSenderTimeout
145+
result.WorkMem = remoteConfig.WorkMem
146+
return result
147+
}
148+
149+
func (a *settings) DiffWithRemote(remoteConfig v1API.PostgresConfigResponse) ([]byte, error) {
150+
// Convert the config values into easily comparable remoteConfig values
151+
currentValue, err := ToTomlBytes(a)
152+
if err != nil {
153+
return nil, err
154+
}
155+
remoteCompare, err := ToTomlBytes(a.fromRemoteConfig(remoteConfig))
156+
if err != nil {
157+
return nil, err
158+
}
159+
return diff.Diff("remote[db.settings]", remoteCompare, "local[db.settings]", currentValue), nil
160+
}

0 commit comments

Comments
 (0)