diff --git a/modules/scylladb/scylladb.go b/modules/scylladb/scylladb.go index bf7c604014..42493807c7 100644 --- a/modules/scylladb/scylladb.go +++ b/modules/scylladb/scylladb.go @@ -29,18 +29,17 @@ func WithConfig(r io.Reader) testcontainers.CustomizeRequestOption { ContainerFilePath: "/etc/scylla/scylla.yaml", FileMode: 0o644, } - req.Files = append(req.Files, cf) - - return nil + return testcontainers.WithFiles(cf)(req) } } // WithShardAwareness enable shard-awareness in the ScyllaDB container so you can use the `19042` port. func WithShardAwareness() testcontainers.CustomizeRequestOption { return func(req *testcontainers.GenericContainerRequest) error { - req.ExposedPorts = append(req.ExposedPorts, shardAwarePort) - req.WaitingFor = wait.ForAll(req.WaitingFor, wait.ForListeningPort(shardAwarePort)) - return nil + if err := testcontainers.WithExposedPorts(shardAwarePort)(req); err != nil { + return err + } + return testcontainers.WithWaitStrategy(wait.ForListeningPort(shardAwarePort))(req) } } @@ -52,14 +51,16 @@ func WithAlternator() testcontainers.CustomizeRequestOption { portFlagValue := strings.ReplaceAll(alternatorPort, "/tcp", "") return func(req *testcontainers.GenericContainerRequest) error { - req.ExposedPorts = append(req.ExposedPorts, alternatorPort) - req.WaitingFor = wait.ForAll(req.WaitingFor, wait.ForListeningPort(alternatorPort)) - setCommandFlag(req, map[string]string{ + if err := testcontainers.WithExposedPorts(alternatorPort)(req); err != nil { + return err + } + if err := testcontainers.WithWaitStrategy(wait.ForListeningPort(alternatorPort))(req); err != nil { + return err + } + return setCommandFlags(req, map[string]string{ "--alternator-port": portFlagValue, "--alternator-write-isolation": "always", }) - - return nil } } @@ -86,8 +87,7 @@ func WithCustomCommands(flags ...string) testcontainers.CustomizeRequestOption { } } - setCommandFlag(req, flagsMap) - return nil + return setCommandFlags(req, flagsMap) } } @@ -108,16 +108,15 @@ func (c Container) AlternatorConnectionHost(ctx context.Context) (string, error) // Run starts a ScyllaDB container with the specified image and options func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustomizer) (*Container, error) { - req := testcontainers.ContainerRequest{ - Image: img, - ExposedPorts: []string{port}, - Cmd: []string{ + moduleOpts := []testcontainers.ContainerCustomizer{ + testcontainers.WithExposedPorts(port), + testcontainers.WithCmd( "--developer-mode=1", "--overprovisioned=1", "--smp=1", "--memory=512M", - }, - WaitingFor: wait.ForAll( + ), + testcontainers.WithWaitStrategy( wait.ForListeningPort(port), wait.ForExec([]string{"cqlsh", "-e", "SELECT bootstrapped FROM system.local"}).WithResponseMatcher(func(body io.Reader) bool { data, _ := io.ReadAll(body) @@ -126,49 +125,40 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom ), } - genericContainerReq := testcontainers.GenericContainerRequest{ - ContainerRequest: req, - Started: true, - } - - for _, opt := range opts { - if err := opt.Customize(&genericContainerReq); err != nil { - return nil, fmt.Errorf("customize: %w", err) - } - } + moduleOpts = append(moduleOpts, opts...) - container, err := testcontainers.GenericContainer(ctx, genericContainerReq) + ctr, err := testcontainers.Run(ctx, img, moduleOpts...) var c *Container - if container != nil { - c = &Container{Container: container} + if ctr != nil { + c = &Container{Container: ctr} } if err != nil { - return c, fmt.Errorf("generic container: %w", err) + return c, fmt.Errorf("run scylladb: %w", err) } return c, nil } -// setCommandFlag sets the flags in the command line. -// It takes the array of commands from the GenericContainerRequest and a map of flags, +// setCommandFlags sets the flags in the command line. +// It takes the container request and a map of flags, // and checks if the flag is present in the command line, overriding the value if it is. -// If the flag is not present, it's added to the command line. -func setCommandFlag(req *testcontainers.GenericContainerRequest, flags map[string]string) { +// If the flag is not present, it's added to the end of the command line. +func setCommandFlags(req *testcontainers.GenericContainerRequest, flagsMap map[string]string) error { cmds := []string{} for _, cmd := range req.Cmd { before, _, hasEquals := strings.Cut(cmd, "=") - val, ok := flags[before] + val, ok := flagsMap[before] if ok { if hasEquals { cmds = append(cmds, before+"="+val) } else { cmds = append(cmds, before) } - // The flag is present in the command line, so it's removed from the flags map + // The flag is present in the command line, so it's removed from the flagsMap // to avoid adding it to the end of the command line. - delete(flags, before) + delete(flagsMap, before) } else { cmds = append(cmds, cmd) } @@ -176,7 +166,7 @@ func setCommandFlag(req *testcontainers.GenericContainerRequest, flags map[strin // The extra flags not present in the command line are added to the end of the command line, // and this could be in any order. - for key, val := range flags { + for key, val := range flagsMap { if val == "" { cmds = append(cmds, key) } else { @@ -184,5 +174,5 @@ func setCommandFlag(req *testcontainers.GenericContainerRequest, flags map[strin } } - req.Cmd = cmds + return testcontainers.WithCmd(cmds...)(req) } diff --git a/modules/scylladb/scylladb_test.go b/modules/scylladb/scylladb_test.go index 9ec7565359..a1db016ff4 100644 --- a/modules/scylladb/scylladb_test.go +++ b/modules/scylladb/scylladb_test.go @@ -197,7 +197,7 @@ func requireCreateTable(t *testing.T, client *dynamodb.Client) { func TestWithCustomCommands(t *testing.T) { t.Run("invalid-flag", func(t *testing.T) { - req := testcontainers.GenericContainerRequest{ + req := &testcontainers.GenericContainerRequest{ ContainerRequest: testcontainers.ContainerRequest{ Cmd: []string{"--memory=1G", "--smp=2"}, }, @@ -206,7 +206,7 @@ func TestWithCustomCommands(t *testing.T) { // Same commands as in the Cmd, overriding the values. opt := scylladb.WithCustomCommands("--memory=2G", "--smp=4", "invalid-flag") - err := opt.Customize(&req) + err := opt.Customize(req) require.Error(t, err) require.Contains(t, err.Error(), "invalid flag") @@ -217,7 +217,7 @@ func TestWithCustomCommands(t *testing.T) { }) t.Run("equals-override", func(t *testing.T) { - req := testcontainers.GenericContainerRequest{ + req := &testcontainers.GenericContainerRequest{ ContainerRequest: testcontainers.ContainerRequest{ Cmd: []string{"--memory=1G", "--smp=2"}, }, @@ -226,7 +226,7 @@ func TestWithCustomCommands(t *testing.T) { // Same commands as in the Cmd, overriding the values. opt := scylladb.WithCustomCommands("--memory=2G", "--smp=4") - err := opt.Customize(&req) + err := opt.Customize(req) require.NoError(t, err) require.Len(t, req.Cmd, 2) @@ -235,7 +235,7 @@ func TestWithCustomCommands(t *testing.T) { }) t.Run("equals-override/no-equals", func(t *testing.T) { - req := testcontainers.GenericContainerRequest{ + req := &testcontainers.GenericContainerRequest{ ContainerRequest: testcontainers.ContainerRequest{ Cmd: []string{"--memory=1G", "--flag1=true", "--flag2"}, }, @@ -245,7 +245,7 @@ func TestWithCustomCommands(t *testing.T) { // of several types: with and without equals. opt := scylladb.WithCustomCommands("--memory=2G", "--smp=4", "--flag1=false", "--flag2", "--flag3") - err := opt.Customize(&req) + err := opt.Customize(req) require.NoError(t, err) require.Len(t, req.Cmd, 5) @@ -259,7 +259,7 @@ func TestWithCustomCommands(t *testing.T) { }) t.Run("equals-override/different-order", func(t *testing.T) { - req := testcontainers.GenericContainerRequest{ + req := &testcontainers.GenericContainerRequest{ ContainerRequest: testcontainers.ContainerRequest{ Image: "scylladb/scylla:6.2", Cmd: []string{"--memory=1G", "--smp=2"}, @@ -268,7 +268,7 @@ func TestWithCustomCommands(t *testing.T) { opt := scylladb.WithCustomCommands("--smp=4", "--memory=2G") - err := opt.Customize(&req) + err := opt.Customize(req) require.NoError(t, err) require.Len(t, req.Cmd, 2)