@@ -3,6 +3,8 @@ package mssql
33import (
44 "context"
55 "fmt"
6+ "path/filepath"
7+ "strconv"
68 "strings"
79
810 "github.com/testcontainers/testcontainers-go"
@@ -41,6 +43,47 @@ func WithPassword(password string) testcontainers.CustomizeRequestOption {
4143 }
4244}
4345
46+ // WithScripts adds SQL scripts to be executed after the container is ready.
47+ // The scripts are executed in the order they are provided using sqlcmd tool.
48+ func WithScripts (scripts ... string ) testcontainers.CustomizeRequestOption {
49+ return func (req * testcontainers.GenericContainerRequest ) error {
50+ hooks := make ([]testcontainers.ContainerHook , 0 , len (scripts ))
51+ for _ , script := range scripts {
52+ hook := func (ctx context.Context , c testcontainers.Container ) error {
53+ password := defaultPassword
54+ if req .Env ["MSSQL_SA_PASSWORD" ] != "" {
55+ password = req .Env ["MSSQL_SA_PASSWORD" ]
56+ }
57+
58+ targetPath := "/tmp/" + filepath .Base (script )
59+ if err := c .CopyFileToContainer (ctx , script , targetPath , 0o644 ); err != nil {
60+ return err
61+ }
62+
63+ cmd := testcontainers .NewRawCommand ([]string {
64+ GetSQLCmdPath (req .Image ),
65+ "-S" , "localhost" ,
66+ "-U" , defaultUsername ,
67+ "-P" , password ,
68+ "-No" ,
69+ "-i" , targetPath ,
70+ })
71+ if _ , _ , err := c .Exec (ctx , cmd .AsCommand (), cmd .Options ()... ); err != nil {
72+ return fmt .Errorf ("script %q: %w" , script , err )
73+ }
74+ return nil
75+ }
76+ hooks = append (hooks , hook )
77+ }
78+
79+ req .LifecycleHooks = append (req .LifecycleHooks , testcontainers.ContainerLifecycleHooks {
80+ PostReadies : hooks ,
81+ })
82+
83+ return nil
84+ }
85+ }
86+
4487// Deprecated: use Run instead
4588// RunContainer creates an instance of the MSSQLServer container type
4689func RunContainer (ctx context.Context , opts ... testcontainers.ContainerCustomizer ) (* MSSQLServerContainer , error ) {
@@ -99,3 +142,25 @@ func (c *MSSQLServerContainer) ConnectionString(ctx context.Context, args ...str
99142
100143 return connStr , nil
101144}
145+
146+ // GetSQLCmdPath helper function to return the
147+ // sqlcmd path based on the image version
148+ func GetSQLCmdPath (image string ) string {
149+ const sqlCmd , sqlCmdOld = "/opt/mssql-tools18/bin/sqlcmd" , "/opt/mssql-tools/bin/sqlcmd"
150+
151+ if strings .Contains (image , "2019-" ) || strings .Contains (image , "2017-" ) {
152+ return sqlCmdOld
153+ }
154+
155+ parts := strings .Split (strings .ToLower (image ), "-" )
156+ for _ , part := range parts {
157+ if strings .HasPrefix (part , "cu" ) {
158+ cuNumber := strings .TrimPrefix (part , "cu" )
159+ if num , err := strconv .Atoi (cuNumber ); err == nil && num >= 14 {
160+ return sqlCmd
161+ }
162+ }
163+ }
164+
165+ return sqlCmd
166+ }
0 commit comments