@@ -2,6 +2,7 @@ package cli
22
33import (
44 "bytes"
5+ "context"
56 "encoding/json"
67 "errors"
78 "fmt"
@@ -66,6 +67,8 @@ the prediction on that.`,
6667}
6768
6869func cmdPredict (cmd * cobra.Command , args []string ) error {
70+ ctx := cmd .Context ()
71+
6972 imageName := ""
7073 volumes := []docker.Volume {}
7174 gpus := gpusFlag
@@ -85,7 +88,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
8588 if buildFast {
8689 imageName = config .DockerImageName (projectDir )
8790 } else {
88- if imageName , err = image .BuildBase (cfg , projectDir , buildUseCudaBaseImage , DetermineUseCogBaseImage (cmd ), buildProgressOutput ); err != nil {
91+ if imageName , err = image .BuildBase (ctx , cfg , projectDir , buildUseCudaBaseImage , DetermineUseCogBaseImage (cmd ), buildProgressOutput ); err != nil {
8992 return err
9093 }
9194
@@ -109,17 +112,17 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
109112 return fmt .Errorf ("Invalid image name '%s'. Did you forget `-i`?" , imageName )
110113 }
111114
112- exists , err := docker .ImageExists (imageName )
115+ exists , err := docker .ImageExists (ctx , imageName )
113116 if err != nil {
114117 return fmt .Errorf ("Failed to determine if %s exists: %w" , imageName , err )
115118 }
116119 if ! exists {
117120 console .Infof ("Pulling image: %s" , imageName )
118- if err := docker .Pull (imageName ); err != nil {
121+ if err := docker .Pull (ctx , imageName ); err != nil {
119122 return fmt .Errorf ("Failed to pull %s: %w" , imageName , err )
120123 }
121124 }
122- conf , err := image .GetConfig (imageName )
125+ conf , err := image .GetConfig (ctx , imageName )
123126 if err != nil {
124127 return err
125128 }
@@ -135,7 +138,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
135138 console .Infof ("Starting Docker image %s and running setup()..." , imageName )
136139 dockerCommand := docker .NewDockerCommand ()
137140
138- predictor , err := predict .NewPredictor (docker.RunOptions {
141+ predictor , err := predict .NewPredictor (ctx , docker.RunOptions {
139142 GPUs : gpus ,
140143 Image : imageName ,
141144 Volumes : volumes ,
@@ -152,20 +155,20 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
152155 <- captureSignal
153156
154157 console .Info ("Stopping container..." )
155- if err := predictor .Stop (); err != nil {
158+ if err := predictor .Stop (ctx ); err != nil {
156159 console .Warnf ("Failed to stop container: %s" , err )
157160 }
158161 }()
159162
160163 timeout := time .Duration (setupTimeout ) * time .Second
161- if err := predictor .Start (os .Stderr , timeout ); err != nil {
164+ if err := predictor .Start (ctx , os .Stderr , timeout ); err != nil {
162165 // Only retry if we're using a GPU but but the user didn't explicitly select a GPU with --gpus
163166 // If the user specified the wrong GPU, they are explicitly selecting a GPU and they'll want to hear about it
164167 if gpus == "all" && errors .Is (err , docker .ErrMissingDeviceDriver ) {
165168 console .Info ("Missing device driver, re-trying without GPU" )
166169
167- _ = predictor .Stop ()
168- predictor , err = predict .NewPredictor (docker.RunOptions {
170+ _ = predictor .Stop (ctx )
171+ predictor , err = predict .NewPredictor (ctx , docker.RunOptions {
169172 Image : imageName ,
170173 Volumes : volumes ,
171174 Env : envFlags ,
@@ -174,7 +177,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
174177 return err
175178 }
176179
177- if err := predictor .Start (os .Stderr , timeout ); err != nil {
180+ if err := predictor .Start (ctx , os .Stderr , timeout ); err != nil {
178181 return err
179182 }
180183 } else {
@@ -185,7 +188,8 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
185188 // FIXME: will not run on signal
186189 defer func () {
187190 console .Debugf ("Stopping container..." )
188- if err := predictor .Stop (); err != nil {
191+ // use background context to ensure stop signal is still sent after root context is canceled
192+ if err := predictor .Stop (context .Background ()); err != nil {
189193 console .Warnf ("Failed to stop container: %s" , err )
190194 }
191195 }()
0 commit comments