Skip to content

Commit 1b58c87

Browse files
authored
feat: support env loading for all string fields (#3019)
* feat: support env loading for all string fields * chore: update unit tests * chore: assert env loaded for backwards compatibility
1 parent 805dcd0 commit 1b58c87

File tree

4 files changed

+113
-69
lines changed

4 files changed

+113
-69
lines changed

internal/db/branch/switch_/switch__test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ func TestSwitchCommand(t *testing.T) {
7979
// Run test
8080
err := Run(context.Background(), "target", fsys)
8181
// Check error
82-
assert.ErrorContains(t, err, "toml: line 0: unexpected EOF; expected key separator '='")
82+
assert.ErrorContains(t, err, "toml: expected = after a key, but the document ends there")
8383
})
8484

8585
t.Run("throws error on missing database", func(t *testing.T) {

internal/start/start_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func TestStartCommand(t *testing.T) {
3737
// Run test
3838
err := Run(context.Background(), fsys, []string{}, false)
3939
// Check error
40-
assert.ErrorContains(t, err, "toml: line 0: unexpected EOF; expected key separator '='")
40+
assert.ErrorContains(t, err, "toml: expected = after a key, but the document ends there")
4141
})
4242

4343
t.Run("throws error on missing docker", func(t *testing.T) {

internal/status/status_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func TestStatusCommand(t *testing.T) {
5959
// Run test
6060
err := Run(context.Background(), CustomName{}, utils.OutputPretty, fsys)
6161
// Check error
62-
assert.ErrorContains(t, err, "toml: line 0: unexpected EOF; expected key separator '='")
62+
assert.ErrorContains(t, err, "toml: expected = after a key, but the document ends there")
6363
})
6464

6565
t.Run("throws error on missing docker", func(t *testing.T) {

pkg/config/config.go

Lines changed: 110 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"os"
1717
"path"
1818
"path/filepath"
19+
"reflect"
1920
"regexp"
2021
"sort"
2122
"strconv"
@@ -378,9 +379,64 @@ func (c *config) Eject(w io.Writer) error {
378379
return nil
379380
}
380381

382+
// Loads custom config file to struct fields tagged with toml.
383+
func (c *config) loadFromFile(filename string, fsys fs.FS) error {
384+
v := viper.New()
385+
v.SetConfigType("toml")
386+
// Load default values
387+
var buf bytes.Buffer
388+
if err := initConfigTemplate.Option("missingkey=zero").Execute(&buf, c); err != nil {
389+
return errors.Errorf("failed to initialise template config: %w", err)
390+
} else if err := c.loadFromReader(v, &buf); err != nil {
391+
return err
392+
}
393+
// Load custom config
394+
if ext := filepath.Ext(filename); len(ext) > 0 {
395+
v.SetConfigType(ext[1:])
396+
}
397+
f, err := fsys.Open(filename)
398+
if err != nil {
399+
return errors.Errorf("failed to read file config: %w", err)
400+
}
401+
defer f.Close()
402+
return c.loadFromReader(v, f)
403+
}
404+
405+
func (c *config) loadFromReader(v *viper.Viper, r io.Reader) error {
406+
if err := v.MergeConfig(r); err != nil {
407+
return errors.Errorf("failed to merge config: %w", err)
408+
}
409+
// Manually parse [functions.*] to empty struct for backwards compatibility
410+
for key, value := range v.GetStringMap("functions") {
411+
if m, ok := value.(map[string]any); ok && len(m) == 0 {
412+
v.Set("functions."+key, function{})
413+
}
414+
}
415+
if err := v.UnmarshalExact(c, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
416+
mapstructure.StringToTimeDurationHookFunc(),
417+
mapstructure.StringToIPHookFunc(),
418+
mapstructure.StringToSliceHookFunc(","),
419+
mapstructure.TextUnmarshallerHookFunc(),
420+
LoadEnvHook,
421+
// TODO: include decrypt secret hook
422+
)), func(dc *mapstructure.DecoderConfig) {
423+
dc.TagName = "toml"
424+
dc.Squash = true
425+
}); err != nil {
426+
return errors.Errorf("failed to parse config: %w", err)
427+
}
428+
return nil
429+
}
430+
431+
// Loads envs prefixed with supabase_ to struct fields tagged with mapstructure.
381432
func (c *config) loadFromEnv() error {
382-
// Allow overriding base config object with automatic env
383-
// Ref: https://github.com/spf13/viper/issues/761
433+
v := viper.New()
434+
v.SetEnvPrefix("SUPABASE")
435+
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
436+
v.AutomaticEnv()
437+
// Viper does not parse env vars automatically. Instead of calling viper.BindEnv
438+
// per key, we decode all keys from an existing struct, and merge them to viper.
439+
// Ref: https://github.com/spf13/viper/issues/761#issuecomment-859306364
384440
envKeysMap := map[string]interface{}{}
385441
if dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
386442
Result: &envKeysMap,
@@ -389,47 +445,32 @@ func (c *config) loadFromEnv() error {
389445
return errors.Errorf("failed to create decoder: %w", err)
390446
} else if err := dec.Decode(c.baseConfig); err != nil {
391447
return errors.Errorf("failed to decode env: %w", err)
392-
}
393-
v := viper.New()
394-
v.SetEnvPrefix("SUPABASE")
395-
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
396-
v.AutomaticEnv()
397-
if err := v.MergeConfigMap(envKeysMap); err != nil {
398-
return errors.Errorf("failed to merge config: %w", err)
399-
} else if err := v.Unmarshal(c); err != nil {
400-
return errors.Errorf("failed to parse env to config: %w", err)
448+
} else if err := v.MergeConfigMap(envKeysMap); err != nil {
449+
return errors.Errorf("failed to merge env config: %w", err)
450+
}
451+
// Writes viper state back to config struct, with automatic env substitution
452+
if err := v.UnmarshalExact(c, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
453+
mapstructure.StringToTimeDurationHookFunc(),
454+
mapstructure.StringToIPHookFunc(),
455+
mapstructure.StringToSliceHookFunc(","),
456+
mapstructure.TextUnmarshallerHookFunc(),
457+
// TODO: include decrypt secret hook
458+
))); err != nil {
459+
return errors.Errorf("failed to parse env override: %w", err)
401460
}
402461
return nil
403462
}
404463

405464
func (c *config) Load(path string, fsys fs.FS) error {
406465
builder := NewPathBuilder(path)
407-
// Load default values
408-
var buf bytes.Buffer
409-
if err := initConfigTemplate.Option("missingkey=zero").Execute(&buf, c); err != nil {
410-
return errors.Errorf("failed to initialise config template: %w", err)
411-
}
412-
dec := toml.NewDecoder(&buf)
413-
if _, err := dec.Decode(c); err != nil {
414-
return errors.Errorf("failed to decode config template: %w", err)
415-
}
416-
if metadata, err := toml.DecodeFS(fsys, builder.ConfigPath, c); err != nil {
417-
cwd, osErr := os.Getwd()
418-
if osErr != nil {
419-
cwd = "current directory"
420-
}
421-
return errors.Errorf("cannot read config in %s: %w", cwd, err)
422-
} else if undecoded := metadata.Undecoded(); len(undecoded) > 0 {
423-
for _, key := range undecoded {
424-
if key[0] != "remotes" {
425-
fmt.Fprintf(os.Stderr, "Unknown config field: [%s]\n", key)
426-
}
427-
}
428-
}
429466
// Load secrets from .env file
430467
if err := loadDefaultEnv(); err != nil {
431468
return err
432-
} else if err := c.loadFromEnv(); err != nil {
469+
}
470+
if err := c.loadFromFile(builder.ConfigPath, fsys); err != nil {
471+
return err
472+
}
473+
if err := c.loadFromEnv(); err != nil {
433474
return err
434475
}
435476
// Generate JWT tokens
@@ -619,17 +660,16 @@ func (c *baseConfig) Validate(fsys fs.FS) error {
619660
case 15:
620661
if len(c.Experimental.OrioleDBVersion) > 0 {
621662
c.Db.Image = "supabase/postgres:orioledb-" + c.Experimental.OrioleDBVersion
622-
var err error
623-
if c.Experimental.S3Host, err = maybeLoadEnv(c.Experimental.S3Host); err != nil {
663+
if err := assertEnvLoaded(c.Experimental.S3Host); err != nil {
624664
return err
625665
}
626-
if c.Experimental.S3Region, err = maybeLoadEnv(c.Experimental.S3Region); err != nil {
666+
if err := assertEnvLoaded(c.Experimental.S3Region); err != nil {
627667
return err
628668
}
629-
if c.Experimental.S3AccessKey, err = maybeLoadEnv(c.Experimental.S3AccessKey); err != nil {
669+
if err := assertEnvLoaded(c.Experimental.S3AccessKey); err != nil {
630670
return err
631671
}
632-
if c.Experimental.S3SecretKey, err = maybeLoadEnv(c.Experimental.S3SecretKey); err != nil {
672+
if err := assertEnvLoaded(c.Experimental.S3SecretKey); err != nil {
633673
return err
634674
}
635675
}
@@ -666,7 +706,6 @@ func (c *baseConfig) Validate(fsys fs.FS) error {
666706
} else if parsed.Host == "" || parsed.Host == c.Hostname {
667707
c.Studio.ApiUrl = c.Api.ExternalUrl
668708
}
669-
c.Studio.OpenaiApiKey, _ = maybeLoadEnv(c.Studio.OpenaiApiKey)
670709
}
671710
// Validate smtp config
672711
if c.Inbucket.Enabled {
@@ -679,12 +718,11 @@ func (c *baseConfig) Validate(fsys fs.FS) error {
679718
if c.Auth.SiteUrl == "" {
680719
return errors.New("Missing required field in config: auth.site_url")
681720
}
682-
var err error
683-
if c.Auth.SiteUrl, err = maybeLoadEnv(c.Auth.SiteUrl); err != nil {
721+
if err := assertEnvLoaded(c.Auth.SiteUrl); err != nil {
684722
return err
685723
}
686724
for i, url := range c.Auth.AdditionalRedirectUrls {
687-
if c.Auth.AdditionalRedirectUrls[i], err = maybeLoadEnv(url); err != nil {
725+
if err := assertEnvLoaded(url); err != nil {
688726
return errors.Errorf("Invalid config for auth.additional_redirect_urls[%d]: %v", i, err)
689727
}
690728
}
@@ -749,18 +787,24 @@ func (c *baseConfig) Validate(fsys fs.FS) error {
749787
return nil
750788
}
751789

752-
func maybeLoadEnv(s string) (string, error) {
753-
matches := envPattern.FindStringSubmatch(s)
754-
if len(matches) == 0 {
755-
return s, nil
790+
func assertEnvLoaded(s string) error {
791+
if matches := envPattern.FindStringSubmatch(s); len(matches) > 1 {
792+
return errors.Errorf(`Error evaluating "%s": environment variable %s is unset.`, s, matches[1])
756793
}
794+
return nil
795+
}
757796

758-
envName := matches[1]
759-
if value := os.Getenv(envName); value != "" {
760-
return value, nil
797+
func LoadEnvHook(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
798+
if f != reflect.String || t != reflect.String {
799+
return data, nil
761800
}
762-
763-
return "", errors.Errorf(`Error evaluating "%s": environment variable %s is unset.`, s, envName)
801+
value := data.(string)
802+
if matches := envPattern.FindStringSubmatch(value); len(matches) > 1 {
803+
if v, exists := os.LookupEnv(matches[1]); exists {
804+
value = v
805+
}
806+
}
807+
return value, nil
764808
}
765809

766810
func truncateText(text string, maxLen int) string {
@@ -874,7 +918,7 @@ func (e *email) validate(fsys fs.FS) (err error) {
874918
if len(e.Smtp.AdminEmail) == 0 {
875919
return errors.New("Missing required field in config: auth.email.smtp.admin_email")
876920
}
877-
if e.Smtp.Pass, err = maybeLoadEnv(e.Smtp.Pass); err != nil {
921+
if err := assertEnvLoaded(e.Smtp.Pass); err != nil {
878922
return err
879923
}
880924
}
@@ -893,7 +937,7 @@ func (s *sms) validate() (err error) {
893937
if len(s.Twilio.AuthToken) == 0 {
894938
return errors.New("Missing required field in config: auth.sms.twilio.auth_token")
895939
}
896-
if s.Twilio.AuthToken, err = maybeLoadEnv(s.Twilio.AuthToken); err != nil {
940+
if err := assertEnvLoaded(s.Twilio.AuthToken); err != nil {
897941
return err
898942
}
899943
case s.TwilioVerify.Enabled:
@@ -906,7 +950,7 @@ func (s *sms) validate() (err error) {
906950
if len(s.TwilioVerify.AuthToken) == 0 {
907951
return errors.New("Missing required field in config: auth.sms.twilio_verify.auth_token")
908952
}
909-
if s.TwilioVerify.AuthToken, err = maybeLoadEnv(s.TwilioVerify.AuthToken); err != nil {
953+
if err := assertEnvLoaded(s.TwilioVerify.AuthToken); err != nil {
910954
return err
911955
}
912956
case s.Messagebird.Enabled:
@@ -916,7 +960,7 @@ func (s *sms) validate() (err error) {
916960
if len(s.Messagebird.AccessKey) == 0 {
917961
return errors.New("Missing required field in config: auth.sms.messagebird.access_key")
918962
}
919-
if s.Messagebird.AccessKey, err = maybeLoadEnv(s.Messagebird.AccessKey); err != nil {
963+
if err := assertEnvLoaded(s.Messagebird.AccessKey); err != nil {
920964
return err
921965
}
922966
case s.Textlocal.Enabled:
@@ -926,7 +970,7 @@ func (s *sms) validate() (err error) {
926970
if len(s.Textlocal.ApiKey) == 0 {
927971
return errors.New("Missing required field in config: auth.sms.textlocal.api_key")
928972
}
929-
if s.Textlocal.ApiKey, err = maybeLoadEnv(s.Textlocal.ApiKey); err != nil {
973+
if err := assertEnvLoaded(s.Textlocal.ApiKey); err != nil {
930974
return err
931975
}
932976
case s.Vonage.Enabled:
@@ -939,10 +983,10 @@ func (s *sms) validate() (err error) {
939983
if len(s.Vonage.ApiSecret) == 0 {
940984
return errors.New("Missing required field in config: auth.sms.vonage.api_secret")
941985
}
942-
if s.Vonage.ApiKey, err = maybeLoadEnv(s.Vonage.ApiKey); err != nil {
986+
if err := assertEnvLoaded(s.Vonage.ApiKey); err != nil {
943987
return err
944988
}
945-
if s.Vonage.ApiSecret, err = maybeLoadEnv(s.Vonage.ApiSecret); err != nil {
989+
if err := assertEnvLoaded(s.Vonage.ApiSecret); err != nil {
946990
return err
947991
}
948992
case s.EnableSignup:
@@ -969,16 +1013,16 @@ func (e external) validate() (err error) {
9691013
if !sliceContains([]string{"apple", "google"}, ext) && provider.Secret == "" {
9701014
return errors.Errorf("Missing required field in config: auth.external.%s.secret", ext)
9711015
}
972-
if provider.ClientId, err = maybeLoadEnv(provider.ClientId); err != nil {
1016+
if err := assertEnvLoaded(provider.ClientId); err != nil {
9731017
return err
9741018
}
975-
if provider.Secret, err = maybeLoadEnv(provider.Secret); err != nil {
1019+
if err := assertEnvLoaded(provider.Secret); err != nil {
9761020
return err
9771021
}
978-
if provider.RedirectUri, err = maybeLoadEnv(provider.RedirectUri); err != nil {
1022+
if err := assertEnvLoaded(provider.RedirectUri); err != nil {
9791023
return err
9801024
}
981-
if provider.Url, err = maybeLoadEnv(provider.Url); err != nil {
1025+
if err := assertEnvLoaded(provider.Url); err != nil {
9821026
return err
9831027
}
9841028
e[ext] = provider
@@ -1033,7 +1077,7 @@ func (h *hookConfig) validate(hookType string) (err error) {
10331077
case "http", "https":
10341078
if len(h.Secrets) == 0 {
10351079
return errors.Errorf("Missing required field in config: auth.hook.%s.secrets", hookType)
1036-
} else if h.Secrets, err = maybeLoadEnv(h.Secrets); err != nil {
1080+
} else if err := assertEnvLoaded(h.Secrets); err != nil {
10371081
return err
10381082
}
10391083
for _, secret := range strings.Split(h.Secrets, "|") {
@@ -1119,13 +1163,13 @@ func (c *tpaCognito) issuerURL() string {
11191163
func (c *tpaCognito) validate() (err error) {
11201164
if c.UserPoolID == "" {
11211165
return errors.New("Invalid config: auth.third_party.cognito is enabled but without a user_pool_id.")
1122-
} else if c.UserPoolID, err = maybeLoadEnv(c.UserPoolID); err != nil {
1166+
} else if err := assertEnvLoaded(c.UserPoolID); err != nil {
11231167
return err
11241168
}
11251169

11261170
if c.UserPoolRegion == "" {
11271171
return errors.New("Invalid config: auth.third_party.cognito is enabled but without a user_pool_region.")
1128-
} else if c.UserPoolRegion, err = maybeLoadEnv(c.UserPoolRegion); err != nil {
1172+
} else if err := assertEnvLoaded(c.UserPoolRegion); err != nil {
11291173
return err
11301174
}
11311175

0 commit comments

Comments
 (0)