Skip to content

Commit a8b2c82

Browse files
committed
feat(mssql): add WithScripts and GetSQLCmdPath functions
1 parent abdce5d commit a8b2c82

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed

modules/mssql/mssql.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package mssql
33
import (
44
"context"
55
"fmt"
6+
"path/filepath"
7+
"strconv"
68
"strings"
79

810
"github.com/testcontainers/testcontainers-go"
@@ -41,6 +43,47 @@ func WithPassword(password string) testcontainers.CustomizeRequestOption {
4143
}
4244
}
4345

46+
// WithScripts adds SQL scripts to be executed after the container is ready.
47+
// The scripts are executed in the order they are provided using sqlcmd tool.
48+
func WithScripts(scripts ...string) testcontainers.CustomizeRequestOption {
49+
return func(req *testcontainers.GenericContainerRequest) error {
50+
hooks := make([]testcontainers.ContainerHook, 0, len(scripts))
51+
for _, script := range scripts {
52+
hook := func(ctx context.Context, c testcontainers.Container) error {
53+
password := defaultPassword
54+
if req.Env["MSSQL_SA_PASSWORD"] != "" {
55+
password = req.Env["MSSQL_SA_PASSWORD"]
56+
}
57+
58+
targetPath := "/tmp/" + filepath.Base(script)
59+
if err := c.CopyFileToContainer(ctx, script, targetPath, 0o644); err != nil {
60+
return err
61+
}
62+
63+
cmd := testcontainers.NewRawCommand([]string{
64+
GetSQLCmdPath(req.Image),
65+
"-S", "localhost",
66+
"-U", defaultUsername,
67+
"-P", password,
68+
"-No",
69+
"-i", targetPath,
70+
})
71+
if _, _, err := c.Exec(ctx, cmd.AsCommand(), cmd.Options()...); err != nil {
72+
return fmt.Errorf("script %q: %w", script, err)
73+
}
74+
return nil
75+
}
76+
hooks = append(hooks, hook)
77+
}
78+
79+
req.LifecycleHooks = append(req.LifecycleHooks, testcontainers.ContainerLifecycleHooks{
80+
PostReadies: hooks,
81+
})
82+
83+
return nil
84+
}
85+
}
86+
4487
// Deprecated: use Run instead
4588
// RunContainer creates an instance of the MSSQLServer container type
4689
func RunContainer(ctx context.Context, opts ...testcontainers.ContainerCustomizer) (*MSSQLServerContainer, error) {
@@ -99,3 +142,25 @@ func (c *MSSQLServerContainer) ConnectionString(ctx context.Context, args ...str
99142

100143
return connStr, nil
101144
}
145+
146+
// GetSQLCmdPath helper function to return the
147+
// sqlcmd path based on the image version
148+
func GetSQLCmdPath(image string) string {
149+
const sqlCmd, sqlCmdOld = "/opt/mssql-tools18/bin/sqlcmd", "/opt/mssql-tools/bin/sqlcmd"
150+
151+
if strings.Contains(image, "2019-") || strings.Contains(image, "2017-") {
152+
return sqlCmdOld
153+
}
154+
155+
parts := strings.Split(strings.ToLower(image), "-")
156+
for _, part := range parts {
157+
if strings.HasPrefix(part, "cu") {
158+
cuNumber := strings.TrimPrefix(part, "cu")
159+
if num, err := strconv.Atoi(cuNumber); err == nil && num >= 14 {
160+
return sqlCmd
161+
}
162+
}
163+
}
164+
165+
return sqlCmd
166+
}

0 commit comments

Comments
 (0)