@@ -29,18 +29,17 @@ func WithConfig(r io.Reader) testcontainers.CustomizeRequestOption {
2929 ContainerFilePath : "/etc/scylla/scylla.yaml" ,
3030 FileMode : 0o644 ,
3131 }
32- req .Files = append (req .Files , cf )
33-
34- return nil
32+ return testcontainers .WithFiles (cf )(req )
3533 }
3634}
3735
3836// WithShardAwareness enable shard-awareness in the ScyllaDB container so you can use the `19042` port.
3937func WithShardAwareness () testcontainers.CustomizeRequestOption {
4038 return func (req * testcontainers.GenericContainerRequest ) error {
41- req .ExposedPorts = append (req .ExposedPorts , shardAwarePort )
42- req .WaitingFor = wait .ForAll (req .WaitingFor , wait .ForListeningPort (shardAwarePort ))
43- return nil
39+ if err := testcontainers .WithExposedPorts (shardAwarePort )(req ); err != nil {
40+ return err
41+ }
42+ return testcontainers .WithWaitStrategy (wait .ForListeningPort (shardAwarePort ))(req )
4443 }
4544}
4645
@@ -52,14 +51,16 @@ func WithAlternator() testcontainers.CustomizeRequestOption {
5251 portFlagValue := strings .ReplaceAll (alternatorPort , "/tcp" , "" )
5352
5453 return func (req * testcontainers.GenericContainerRequest ) error {
55- req .ExposedPorts = append (req .ExposedPorts , alternatorPort )
56- req .WaitingFor = wait .ForAll (req .WaitingFor , wait .ForListeningPort (alternatorPort ))
57- setCommandFlag (req , map [string ]string {
54+ if err := testcontainers .WithExposedPorts (alternatorPort )(req ); err != nil {
55+ return err
56+ }
57+ if err := testcontainers .WithWaitStrategy (wait .ForListeningPort (alternatorPort ))(req ); err != nil {
58+ return err
59+ }
60+ return setCommandFlags (req , map [string ]string {
5861 "--alternator-port" : portFlagValue ,
5962 "--alternator-write-isolation" : "always" ,
6063 })
61-
62- return nil
6364 }
6465}
6566
@@ -86,8 +87,7 @@ func WithCustomCommands(flags ...string) testcontainers.CustomizeRequestOption {
8687 }
8788 }
8889
89- setCommandFlag (req , flagsMap )
90- return nil
90+ return setCommandFlags (req , flagsMap )
9191 }
9292}
9393
@@ -108,16 +108,15 @@ func (c Container) AlternatorConnectionHost(ctx context.Context) (string, error)
108108
109109// Run starts a ScyllaDB container with the specified image and options
110110func Run (ctx context.Context , img string , opts ... testcontainers.ContainerCustomizer ) (* Container , error ) {
111- req := testcontainers.ContainerRequest {
112- Image : img ,
113- ExposedPorts : []string {port },
114- Cmd : []string {
111+ moduleOpts := []testcontainers.ContainerCustomizer {
112+ testcontainers .WithExposedPorts (port ),
113+ testcontainers .WithCmd (
115114 "--developer-mode=1" ,
116115 "--overprovisioned=1" ,
117116 "--smp=1" ,
118117 "--memory=512M" ,
119- } ,
120- WaitingFor : wait . ForAll (
118+ ) ,
119+ testcontainers . WithWaitStrategy (
121120 wait .ForListeningPort (port ),
122121 wait .ForExec ([]string {"cqlsh" , "-e" , "SELECT bootstrapped FROM system.local" }).WithResponseMatcher (func (body io.Reader ) bool {
123122 data , _ := io .ReadAll (body )
@@ -126,63 +125,54 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom
126125 ),
127126 }
128127
129- genericContainerReq := testcontainers.GenericContainerRequest {
130- ContainerRequest : req ,
131- Started : true ,
132- }
133-
134- for _ , opt := range opts {
135- if err := opt .Customize (& genericContainerReq ); err != nil {
136- return nil , fmt .Errorf ("customize: %w" , err )
137- }
138- }
128+ moduleOpts = append (moduleOpts , opts ... )
139129
140- container , err := testcontainers .GenericContainer (ctx , genericContainerReq )
130+ ctr , err := testcontainers .Run (ctx , img , moduleOpts ... )
141131 var c * Container
142- if container != nil {
143- c = & Container {Container : container }
132+ if ctr != nil {
133+ c = & Container {Container : ctr }
144134 }
145135
146136 if err != nil {
147- return c , fmt .Errorf ("generic container : %w" , err )
137+ return c , fmt .Errorf ("run scylladb : %w" , err )
148138 }
149139
150140 return c , nil
151141}
152142
153- // setCommandFlag sets the flags in the command line.
154- // It takes the array of commands from the GenericContainerRequest and a map of flags,
143+ // setCommandFlags sets the flags in the command line.
144+ // It takes the container request and a map of flags,
155145// and checks if the flag is present in the command line, overriding the value if it is.
156- // If the flag is not present, it's added to the command line.
157- func setCommandFlag (req * testcontainers.GenericContainerRequest , flags map [string ]string ) {
146+ // If the flag is not present, it's added to the end of the command line.
147+ func setCommandFlags (req * testcontainers.GenericContainerRequest , flagsMap map [string ]string ) error {
158148 cmds := []string {}
159149
160150 for _ , cmd := range req .Cmd {
161151 before , _ , hasEquals := strings .Cut (cmd , "=" )
162- val , ok := flags [before ]
152+ val , ok := flagsMap [before ]
163153 if ok {
164154 if hasEquals {
165155 cmds = append (cmds , before + "=" + val )
166156 } else {
167157 cmds = append (cmds , before )
168158 }
169- // The flag is present in the command line, so it's removed from the flags map
159+ // The flag is present in the command line, so it's removed from the flagsMap
170160 // to avoid adding it to the end of the command line.
171- delete (flags , before )
161+ delete (flagsMap , before )
172162 } else {
173163 cmds = append (cmds , cmd )
174164 }
175165 }
176166
177167 // The extra flags not present in the command line are added to the end of the command line,
178168 // and this could be in any order.
179- for key , val := range flags {
169+ for key , val := range flagsMap {
180170 if val == "" {
181171 cmds = append (cmds , key )
182172 } else {
183173 cmds = append (cmds , key + "=" + val )
184174 }
185175 }
186176
187- req . Cmd = cmds
177+ return testcontainers . WithCmd ( cmds ... )( req )
188178}
0 commit comments