Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 32 additions & 42 deletions modules/scylladb/scylladb.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,17 @@ func WithConfig(r io.Reader) testcontainers.CustomizeRequestOption {
ContainerFilePath: "/etc/scylla/scylla.yaml",
FileMode: 0o644,
}
req.Files = append(req.Files, cf)

return nil
return testcontainers.WithFiles(cf)(req)
}
}

// WithShardAwareness enable shard-awareness in the ScyllaDB container so you can use the `19042` port.
func WithShardAwareness() testcontainers.CustomizeRequestOption {
return func(req *testcontainers.GenericContainerRequest) error {
req.ExposedPorts = append(req.ExposedPorts, shardAwarePort)
req.WaitingFor = wait.ForAll(req.WaitingFor, wait.ForListeningPort(shardAwarePort))
return nil
if err := testcontainers.WithExposedPorts(shardAwarePort)(req); err != nil {
return err
}
return testcontainers.WithWaitStrategy(wait.ForListeningPort(shardAwarePort))(req)
}
}

Expand All @@ -52,14 +51,16 @@ func WithAlternator() testcontainers.CustomizeRequestOption {
portFlagValue := strings.ReplaceAll(alternatorPort, "/tcp", "")

return func(req *testcontainers.GenericContainerRequest) error {
req.ExposedPorts = append(req.ExposedPorts, alternatorPort)
req.WaitingFor = wait.ForAll(req.WaitingFor, wait.ForListeningPort(alternatorPort))
setCommandFlag(req, map[string]string{
if err := testcontainers.WithExposedPorts(alternatorPort)(req); err != nil {
return err
}
if err := testcontainers.WithWaitStrategy(wait.ForListeningPort(alternatorPort))(req); err != nil {
return err
}
return setCommandFlags(req, map[string]string{
"--alternator-port": portFlagValue,
"--alternator-write-isolation": "always",
})

return nil
}
}

Expand All @@ -86,8 +87,7 @@ func WithCustomCommands(flags ...string) testcontainers.CustomizeRequestOption {
}
}

setCommandFlag(req, flagsMap)
return nil
return setCommandFlags(req, flagsMap)
}
}

Expand All @@ -108,16 +108,15 @@ func (c Container) AlternatorConnectionHost(ctx context.Context) (string, error)

// Run starts a ScyllaDB container with the specified image and options
func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustomizer) (*Container, error) {
req := testcontainers.ContainerRequest{
Image: img,
ExposedPorts: []string{port},
Cmd: []string{
moduleOpts := []testcontainers.ContainerCustomizer{
testcontainers.WithExposedPorts(port),
testcontainers.WithCmd(
"--developer-mode=1",
"--overprovisioned=1",
"--smp=1",
"--memory=512M",
},
WaitingFor: wait.ForAll(
),
testcontainers.WithWaitStrategy(
wait.ForListeningPort(port),
wait.ForExec([]string{"cqlsh", "-e", "SELECT bootstrapped FROM system.local"}).WithResponseMatcher(func(body io.Reader) bool {
data, _ := io.ReadAll(body)
Expand All @@ -126,63 +125,54 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom
),
}

genericContainerReq := testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
}

for _, opt := range opts {
if err := opt.Customize(&genericContainerReq); err != nil {
return nil, fmt.Errorf("customize: %w", err)
}
}
moduleOpts = append(moduleOpts, opts...)

container, err := testcontainers.GenericContainer(ctx, genericContainerReq)
ctr, err := testcontainers.Run(ctx, img, moduleOpts...)
var c *Container
if container != nil {
c = &Container{Container: container}
if ctr != nil {
c = &Container{Container: ctr}
}

if err != nil {
return c, fmt.Errorf("generic container: %w", err)
return c, fmt.Errorf("run scylladb: %w", err)
}

return c, nil
}

// setCommandFlag sets the flags in the command line.
// It takes the array of commands from the GenericContainerRequest and a map of flags,
// setCommandFlags sets the flags in the command line.
// It takes the container request and a map of flags,
// and checks if the flag is present in the command line, overriding the value if it is.
// If the flag is not present, it's added to the command line.
func setCommandFlag(req *testcontainers.GenericContainerRequest, flags map[string]string) {
// If the flag is not present, it's added to the end of the command line.
func setCommandFlags(req *testcontainers.GenericContainerRequest, flagsMap map[string]string) error {
cmds := []string{}

for _, cmd := range req.Cmd {
before, _, hasEquals := strings.Cut(cmd, "=")
val, ok := flags[before]
val, ok := flagsMap[before]
if ok {
if hasEquals {
cmds = append(cmds, before+"="+val)
} else {
cmds = append(cmds, before)
}
// The flag is present in the command line, so it's removed from the flags map
// The flag is present in the command line, so it's removed from the flagsMap
// to avoid adding it to the end of the command line.
delete(flags, before)
delete(flagsMap, before)
} else {
cmds = append(cmds, cmd)
}
}

// The extra flags not present in the command line are added to the end of the command line,
// and this could be in any order.
for key, val := range flags {
for key, val := range flagsMap {
if val == "" {
cmds = append(cmds, key)
} else {
cmds = append(cmds, key+"="+val)
}
}

req.Cmd = cmds
return testcontainers.WithCmd(cmds...)(req)
}
16 changes: 8 additions & 8 deletions modules/scylladb/scylladb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func requireCreateTable(t *testing.T, client *dynamodb.Client) {

func TestWithCustomCommands(t *testing.T) {
t.Run("invalid-flag", func(t *testing.T) {
req := testcontainers.GenericContainerRequest{
req := &testcontainers.GenericContainerRequest{
ContainerRequest: testcontainers.ContainerRequest{
Cmd: []string{"--memory=1G", "--smp=2"},
},
Expand All @@ -206,7 +206,7 @@ func TestWithCustomCommands(t *testing.T) {
// Same commands as in the Cmd, overriding the values.
opt := scylladb.WithCustomCommands("--memory=2G", "--smp=4", "invalid-flag")

err := opt.Customize(&req)
err := opt.Customize(req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid flag")

Expand All @@ -217,7 +217,7 @@ func TestWithCustomCommands(t *testing.T) {
})

t.Run("equals-override", func(t *testing.T) {
req := testcontainers.GenericContainerRequest{
req := &testcontainers.GenericContainerRequest{
ContainerRequest: testcontainers.ContainerRequest{
Cmd: []string{"--memory=1G", "--smp=2"},
},
Expand All @@ -226,7 +226,7 @@ func TestWithCustomCommands(t *testing.T) {
// Same commands as in the Cmd, overriding the values.
opt := scylladb.WithCustomCommands("--memory=2G", "--smp=4")

err := opt.Customize(&req)
err := opt.Customize(req)
require.NoError(t, err)

require.Len(t, req.Cmd, 2)
Expand All @@ -235,7 +235,7 @@ func TestWithCustomCommands(t *testing.T) {
})

t.Run("equals-override/no-equals", func(t *testing.T) {
req := testcontainers.GenericContainerRequest{
req := &testcontainers.GenericContainerRequest{
ContainerRequest: testcontainers.ContainerRequest{
Cmd: []string{"--memory=1G", "--flag1=true", "--flag2"},
},
Expand All @@ -245,7 +245,7 @@ func TestWithCustomCommands(t *testing.T) {
// of several types: with and without equals.
opt := scylladb.WithCustomCommands("--memory=2G", "--smp=4", "--flag1=false", "--flag2", "--flag3")

err := opt.Customize(&req)
err := opt.Customize(req)
require.NoError(t, err)

require.Len(t, req.Cmd, 5)
Expand All @@ -259,7 +259,7 @@ func TestWithCustomCommands(t *testing.T) {
})

t.Run("equals-override/different-order", func(t *testing.T) {
req := testcontainers.GenericContainerRequest{
req := &testcontainers.GenericContainerRequest{
ContainerRequest: testcontainers.ContainerRequest{
Image: "scylladb/scylla:6.2",
Cmd: []string{"--memory=1G", "--smp=2"},
Expand All @@ -268,7 +268,7 @@ func TestWithCustomCommands(t *testing.T) {

opt := scylladb.WithCustomCommands("--smp=4", "--memory=2G")

err := opt.Customize(&req)
err := opt.Customize(req)
require.NoError(t, err)

require.Len(t, req.Cmd, 2)
Expand Down
Loading