diff --git a/cmd/gen.go b/cmd/gen.go index 8893988d4..b98dae7fc 100644 --- a/cmd/gen.go +++ b/cmd/gen.go @@ -3,6 +3,7 @@ package cmd import ( "os" "os/signal" + "time" env "github.com/Netflix/go-env" "github.com/go-errors/errors" @@ -56,6 +57,7 @@ var ( }, Value: types.LangTypescript, } + queryTimeout time.Duration postgrestV9Compat bool swiftAccessControl = utils.EnumFlag{ Allowed: []string{ @@ -88,7 +90,7 @@ var ( return err } } - return types.Run(ctx, flags.ProjectRef, flags.DbConfig, lang.Value, schema, postgrestV9Compat, swiftAccessControl.Value, afero.NewOsFs()) + return types.Run(ctx, flags.ProjectRef, flags.DbConfig, lang.Value, schema, postgrestV9Compat, swiftAccessControl.Value, queryTimeout, afero.NewOsFs()) }, Example: ` supabase gen types --local supabase gen types --linked --lang=go @@ -126,8 +128,13 @@ func init() { genTypesCmd.MarkFlagsMutuallyExclusive("local", "linked", "project-id", "db-url") typeFlags.Var(&lang, "lang", "Output language of the generated types.") typeFlags.StringSliceVarP(&schema, "schema", "s", []string{}, "Comma separated list of schema to include.") + // Direct connection only flags typeFlags.Var(&swiftAccessControl, "swift-access-control", "Access control for Swift generated types.") - typeFlags.BoolVar(&postgrestV9Compat, "postgrest-v9-compat", false, "Generate types compatible with PostgREST v9 and below. Only use together with --db-url.") + genTypesCmd.MarkFlagsMutuallyExclusive("linked", "project-id", "swift-access-control") + typeFlags.BoolVar(&postgrestV9Compat, "postgrest-v9-compat", false, "Generate types compatible with PostgREST v9 and below.") + genTypesCmd.MarkFlagsMutuallyExclusive("linked", "project-id", "postgrest-v9-compat") + typeFlags.DurationVar(&queryTimeout, "query-timeout", time.Second*15, "Maximum timeout allowed for the database query.") + genTypesCmd.MarkFlagsMutuallyExclusive("linked", "project-id", "query-timeout") genCmd.AddCommand(genTypesCmd) keyFlags := genKeysCmd.Flags() keyFlags.StringVar(&flags.ProjectRef, "project-ref", "", "Project ref of the Supabase project.") diff --git a/internal/db/diff/migra.go b/internal/db/diff/migra.go index 63e0566fd..34b7d579e 100644 --- a/internal/db/diff/migra.go +++ b/internal/db/diff/migra.go @@ -23,11 +23,6 @@ var ( //go:embed templates/migra.ts diffSchemaTypeScript string - //go:embed templates/staging-ca-2021.crt - caStaging string - //go:embed templates/prod-ca-2021.crt - caProd string - managedSchemas = []string{ // Local development "_analytics", @@ -107,12 +102,10 @@ func loadSchema(ctx context.Context, dbURL string, options ...func(*pgx.ConnConf func DiffSchemaMigra(ctx context.Context, source, target string, schema []string, options ...func(*pgx.ConnConfig)) (string, error) { env := []string{"SOURCE=" + source, "TARGET=" + target} - // node-postgres does not support sslmode=prefer - if require, err := types.IsRequireSSL(ctx, target, options...); err != nil { + if ca, err := types.GetRootCA(ctx, target, options...); err != nil { return "", err - } else if require { - rootCA := caStaging + caProd - env = append(env, "SSL_CA="+rootCA) + } else if len(ca) > 0 { + env = append(env, "SSL_CA="+ca) } if len(schema) > 0 { env = append(env, "INCLUDED_SCHEMAS="+strings.Join(schema, ",")) diff --git a/internal/db/diff/templates/prod-ca-2021.crt b/internal/gen/types/templates/prod-ca-2021.crt similarity index 100% rename from internal/db/diff/templates/prod-ca-2021.crt rename to internal/gen/types/templates/prod-ca-2021.crt diff --git a/internal/db/diff/templates/staging-ca-2021.crt b/internal/gen/types/templates/staging-ca-2021.crt similarity index 100% rename from internal/db/diff/templates/staging-ca-2021.crt rename to internal/gen/types/templates/staging-ca-2021.crt diff --git a/internal/gen/types/types.go b/internal/gen/types/types.go index 5aeeb1bd6..aa22c04fa 100644 --- a/internal/gen/types/types.go +++ b/internal/gen/types/types.go @@ -2,9 +2,11 @@ package types import ( "context" + _ "embed" "fmt" "os" "strings" + "time" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/network" @@ -28,7 +30,7 @@ const ( SwiftInternalAccessControl = "internal" ) -func Run(ctx context.Context, projectId string, dbConfig pgconn.Config, lang string, schemas []string, postgrestV9Compat bool, swiftAccessControl string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { +func Run(ctx context.Context, projectId string, dbConfig pgconn.Config, lang string, schemas []string, postgrestV9Compat bool, swiftAccessControl string, queryTimeout time.Duration, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { originalURL := utils.ToPostgresURL(dbConfig) // Add default schemas if --schema flag is not specified if len(schemas) == 0 { @@ -77,26 +79,27 @@ func Run(ctx context.Context, projectId string, dbConfig pgconn.Config, lang str } fmt.Fprintln(os.Stderr, "Connecting to", dbConfig.Host, dbConfig.Port) - escaped := utils.ToPostgresURL(dbConfig) - if require, err := IsRequireSSL(ctx, originalURL, options...); err != nil { + env := []string{ + "PG_META_DB_URL=" + utils.ToPostgresURL(dbConfig), + fmt.Sprintf("PG_CONN_TIMEOUT_SECS=%.0f", queryTimeout.Seconds()), + fmt.Sprintf("PG_QUERY_TIMEOUT_SECS=%.0f", queryTimeout.Seconds()), + "PG_META_GENERATE_TYPES=" + lang, + "PG_META_GENERATE_TYPES_INCLUDED_SCHEMAS=" + included, + "PG_META_GENERATE_TYPES_SWIFT_ACCESS_CONTROL=" + swiftAccessControl, + fmt.Sprintf("PG_META_GENERATE_TYPES_DETECT_ONE_TO_ONE_RELATIONSHIPS=%v", !postgrestV9Compat), + } + if ca, err := GetRootCA(ctx, originalURL, options...); err != nil { return err - } else if require { - // node-postgres does not support sslmode=prefer - escaped += "&sslmode=require" + } else if len(ca) > 0 { + env = append(env, "PG_META_DB_SSL_ROOT_CERT="+ca) } return utils.DockerRunOnceWithConfig( ctx, container.Config{ Image: utils.Config.Studio.PgmetaImage, - Env: []string{ - "PG_META_DB_URL=" + escaped, - "PG_META_GENERATE_TYPES=" + lang, - "PG_META_GENERATE_TYPES_INCLUDED_SCHEMAS=" + included, - "PG_META_GENERATE_TYPES_SWIFT_ACCESS_CONTROL=" + swiftAccessControl, - fmt.Sprintf("PG_META_GENERATE_TYPES_DETECT_ONE_TO_ONE_RELATIONSHIPS=%v", !postgrestV9Compat), - }, - Cmd: []string{"node", "dist/server/server.js"}, + Env: env, + Cmd: []string{"node", "dist/server/server.js"}, }, hostConfig, network.NetworkingConfig{}, @@ -106,7 +109,22 @@ func Run(ctx context.Context, projectId string, dbConfig pgconn.Config, lang str ) } -func IsRequireSSL(ctx context.Context, dbUrl string, options ...func(*pgx.ConnConfig)) (bool, error) { +var ( + //go:embed templates/staging-ca-2021.crt + caStaging string + //go:embed templates/prod-ca-2021.crt + caProd string +) + +func GetRootCA(ctx context.Context, dbURL string, options ...func(*pgx.ConnConfig)) (string, error) { + // node-postgres does not support sslmode=prefer + if require, err := isRequireSSL(ctx, dbURL, options...); !require { + return "", err + } + return caStaging + caProd, nil +} + +func isRequireSSL(ctx context.Context, dbUrl string, options ...func(*pgx.ConnConfig)) (bool, error) { conn, err := utils.ConnectByUrl(ctx, dbUrl+"&sslmode=require", options...) if err != nil { if strings.HasSuffix(err.Error(), "(server refused TLS connection)") { diff --git a/internal/gen/types/types_test.go b/internal/gen/types/types_test.go index bf7ee4067..e7134b182 100644 --- a/internal/gen/types/types_test.go +++ b/internal/gen/types/types_test.go @@ -5,6 +5,7 @@ import ( "errors" "net/http" "testing" + "time" "github.com/docker/docker/api/types/container" "github.com/h2non/gock" @@ -48,7 +49,7 @@ func TestGenLocalCommand(t *testing.T) { conn := pgtest.NewConn() defer conn.Close(t) // Run test - assert.NoError(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, true, "", fsys, conn.Intercept)) + assert.NoError(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, true, "", time.Second, fsys, conn.Intercept)) // Validate api assert.Empty(t, apitest.ListUnmatchedRequests()) }) @@ -63,7 +64,7 @@ func TestGenLocalCommand(t *testing.T) { Get("/v" + utils.Docker.ClientVersion() + "/containers/" + utils.DbId). Reply(http.StatusServiceUnavailable) // Run test - assert.Error(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, true, "", fsys)) + assert.Error(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, true, "", time.Second, fsys)) // Validate api assert.Empty(t, apitest.ListUnmatchedRequests()) }) @@ -83,7 +84,7 @@ func TestGenLocalCommand(t *testing.T) { Get("/v" + utils.Docker.ClientVersion() + "/images"). Reply(http.StatusServiceUnavailable) // Run test - assert.Error(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, true, "", fsys)) + assert.Error(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{}, true, "", time.Second, fsys)) // Validate api assert.Empty(t, apitest.ListUnmatchedRequests()) }) @@ -106,7 +107,7 @@ func TestGenLocalCommand(t *testing.T) { conn := pgtest.NewConn() defer conn.Close(t) // Run test - assert.NoError(t, Run(context.Background(), "", dbConfig, LangSwift, []string{}, true, SwiftInternalAccessControl, fsys, conn.Intercept)) + assert.NoError(t, Run(context.Background(), "", dbConfig, LangSwift, []string{}, true, SwiftInternalAccessControl, time.Second, fsys, conn.Intercept)) // Validate api assert.Empty(t, apitest.ListUnmatchedRequests()) }) @@ -129,7 +130,7 @@ func TestGenLinkedCommand(t *testing.T) { Reply(200). JSON(api.TypescriptResponse{Types: ""}) // Run test - assert.NoError(t, Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, true, "", fsys)) + assert.NoError(t, Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, true, "", time.Second, fsys)) // Validate api assert.Empty(t, apitest.ListUnmatchedRequests()) }) @@ -144,7 +145,7 @@ func TestGenLinkedCommand(t *testing.T) { Get("/v1/projects/" + projectId + "/types/typescript"). ReplyError(errNetwork) // Run test - err := Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, true, "", fsys) + err := Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, true, "", time.Second, fsys) // Validate api assert.ErrorIs(t, err, errNetwork) assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -159,7 +160,7 @@ func TestGenLinkedCommand(t *testing.T) { Get("/v1/projects/" + projectId + "/types/typescript"). Reply(http.StatusServiceUnavailable) // Run test - assert.Error(t, Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, true, "", fsys)) + assert.Error(t, Run(context.Background(), projectId, pgconn.Config{}, LangTypescript, []string{}, true, "", time.Second, fsys)) }) } @@ -184,7 +185,7 @@ func TestGenRemoteCommand(t *testing.T) { conn := pgtest.NewConn() defer conn.Close(t) // Run test - assert.NoError(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{"public"}, true, "", afero.NewMemMapFs(), conn.Intercept)) + assert.NoError(t, Run(context.Background(), "", dbConfig, LangTypescript, []string{"public"}, true, "", time.Second, afero.NewMemMapFs(), conn.Intercept)) // Validate api assert.Empty(t, apitest.ListUnmatchedRequests()) })