Skip to content

Commit 54acb6f

Browse files
committed
feat(mssql): add WithScripts and GetSQLCmdPath functions
1 parent abdce5d commit 54acb6f

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

modules/mssql/mssql.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package mssql
33
import (
44
"context"
55
"fmt"
6+
"path/filepath"
7+
"strconv"
68
"strings"
79

810
"github.com/testcontainers/testcontainers-go"
@@ -41,6 +43,45 @@ func WithPassword(password string) testcontainers.CustomizeRequestOption {
4143
}
4244
}
4345

46+
// WithScripts adds SQL scripts to be executed after the container is ready.
47+
// The scripts are executed in the order they are provided using sqlcmd tool.
48+
func WithScripts(scripts ...string) testcontainers.CustomizeRequestOption {
49+
return func(req *testcontainers.GenericContainerRequest) error {
50+
hooks := make([]testcontainers.ContainerHook, 0, len(scripts))
51+
for _, script := range scripts {
52+
hook := func(ctx context.Context, c testcontainers.Container) error {
53+
password := defaultPassword
54+
if req.Env["MSSQL_SA_PASSWORD"] != "" {
55+
password = req.Env["MSSQL_SA_PASSWORD"]
56+
}
57+
58+
targetPath := "/tmp/" + filepath.Base(script)
59+
if err := c.CopyFileToContainer(ctx, script, targetPath, 0o644); err != nil {
60+
return err
61+
}
62+
63+
cmd := testcontainers.NewRawCommand([]string{
64+
GetSQLCmdPath(req.Image),
65+
"-S", "localhost",
66+
"-U", defaultUsername,
67+
"-P", password,
68+
"-No",
69+
"-i", targetPath,
70+
})
71+
_, _, err := c.Exec(ctx, cmd.AsCommand(), cmd.Options()...)
72+
return err
73+
}
74+
hooks = append(hooks, hook)
75+
}
76+
77+
req.LifecycleHooks = append(req.LifecycleHooks, testcontainers.ContainerLifecycleHooks{
78+
PostReadies: hooks,
79+
})
80+
81+
return nil
82+
}
83+
}
84+
4485
// Deprecated: use Run instead
4586
// RunContainer creates an instance of the MSSQLServer container type
4687
func RunContainer(ctx context.Context, opts ...testcontainers.ContainerCustomizer) (*MSSQLServerContainer, error) {
@@ -99,3 +140,25 @@ func (c *MSSQLServerContainer) ConnectionString(ctx context.Context, args ...str
99140

100141
return connStr, nil
101142
}
143+
144+
// GetSQLCmdPath helper function to return the
145+
// sqlcmd path based on the image version
146+
func GetSQLCmdPath(image string) string {
147+
const sqlCmd, sqlCmdOld = "/opt/mssql-tools18/bin/sqlcmd", "/opt/mssql-tools/bin/sqlcmd"
148+
149+
if strings.Contains(image, "2019-") || strings.Contains(image, "2017-") {
150+
return sqlCmdOld
151+
}
152+
153+
parts := strings.Split(strings.ToLower(image), "-")
154+
for _, part := range parts {
155+
if strings.HasPrefix(part, "cu") {
156+
cuNumber := strings.TrimPrefix(part, "cu")
157+
if num, err := strconv.Atoi(cuNumber); err == nil && num >= 14 {
158+
return sqlCmd
159+
}
160+
}
161+
}
162+
163+
return sqlCmd
164+
}

0 commit comments

Comments
 (0)