@@ -3,6 +3,8 @@ package mssql
33import (
44 "context"
55 "fmt"
6+ "path/filepath"
7+ "strconv"
68 "strings"
79
810 "github.com/testcontainers/testcontainers-go"
@@ -41,6 +43,45 @@ func WithPassword(password string) testcontainers.CustomizeRequestOption {
4143 }
4244}
4345
46+ // WithScripts adds SQL scripts to be executed after the container is ready.
47+ // The scripts are executed in the order they are provided using sqlcmd tool.
48+ func WithScripts (scripts ... string ) testcontainers.CustomizeRequestOption {
49+ return func (req * testcontainers.GenericContainerRequest ) error {
50+ hooks := make ([]testcontainers.ContainerHook , 0 , len (scripts ))
51+ for _ , script := range scripts {
52+ hook := func (ctx context.Context , c testcontainers.Container ) error {
53+ password := defaultPassword
54+ if req .Env ["MSSQL_SA_PASSWORD" ] != "" {
55+ password = req .Env ["MSSQL_SA_PASSWORD" ]
56+ }
57+
58+ targetPath := "/tmp/" + filepath .Base (script )
59+ if err := c .CopyFileToContainer (ctx , script , targetPath , 0o644 ); err != nil {
60+ return err
61+ }
62+
63+ cmd := testcontainers .NewRawCommand ([]string {
64+ GetSQLCmdPath (req .Image ),
65+ "-S" , "localhost" ,
66+ "-U" , defaultUsername ,
67+ "-P" , password ,
68+ "-No" ,
69+ "-i" , targetPath ,
70+ })
71+ _ , _ , err := c .Exec (ctx , cmd .AsCommand (), cmd .Options ()... )
72+ return err
73+ }
74+ hooks = append (hooks , hook )
75+ }
76+
77+ req .LifecycleHooks = append (req .LifecycleHooks , testcontainers.ContainerLifecycleHooks {
78+ PostReadies : hooks ,
79+ })
80+
81+ return nil
82+ }
83+ }
84+
4485// Deprecated: use Run instead
4586// RunContainer creates an instance of the MSSQLServer container type
4687func RunContainer (ctx context.Context , opts ... testcontainers.ContainerCustomizer ) (* MSSQLServerContainer , error ) {
@@ -99,3 +140,25 @@ func (c *MSSQLServerContainer) ConnectionString(ctx context.Context, args ...str
99140
100141 return connStr , nil
101142}
143+
144+ // GetSQLCmdPath helper function to return the
145+ // sqlcmd path based on the image version
146+ func GetSQLCmdPath (image string ) string {
147+ const sqlCmd , sqlCmdOld = "/opt/mssql-tools18/bin/sqlcmd" , "/opt/mssql-tools/bin/sqlcmd"
148+
149+ if strings .Contains (image , "2019-" ) || strings .Contains (image , "2017-" ) {
150+ return sqlCmdOld
151+ }
152+
153+ parts := strings .Split (strings .ToLower (image ), "-" )
154+ for _ , part := range parts {
155+ if strings .HasPrefix (part , "cu" ) {
156+ cuNumber := strings .TrimPrefix (part , "cu" )
157+ if num , err := strconv .Atoi (cuNumber ); err == nil && num >= 14 {
158+ return sqlCmd
159+ }
160+ }
161+ }
162+
163+ return sqlCmd
164+ }
0 commit comments