@@ -17,6 +17,7 @@ limitations under the License.
17
17
package netexec
18
18
19
19
import (
20
+ "context"
20
21
"encoding/json"
21
22
"fmt"
22
23
"io"
@@ -69,8 +70,14 @@ var CmdNetexec = &cobra.Command{
69
70
Acceptable values: "http", "udp", "sctp".
70
71
- "tries": The number of times the request will be performed. Default value: "1".
71
72
- "/echo": Returns the given "msg" ("/echo?msg=echoed_msg")
72
- - "/exit": Closes the server with the given code ("/exit?code=some-code"). The "code"
73
- is expected to be an integer [0-127] or empty; if it is not, it will return an error message.
73
+ - "/exit": Closes the server with the given code and graceful shutdown. The endpoint's parameters
74
+ are:
75
+ - "code": The exit code for the process. Default value: 0. Allows an integer [0-127].
76
+ - "timeout": The amount of time to wait for connections to close before shutting down.
77
+ Acceptable values are golang durations. If 0 the process will exit immediately without
78
+ shutdown.
79
+ - "wait": The amount of time to wait before starting shutdown. Acceptable values are
80
+ golang durations. If 0 the process will start shutdown immediately.
74
81
- "/healthz": Returns "200 OK" if the server is ready, "412 Status Precondition Failed"
75
82
otherwise. The server is considered not ready if the UDP server did not start yet or
76
83
it exited.
@@ -127,25 +134,27 @@ func (a *atomicBool) get() bool {
127
134
}
128
135
129
136
func main (cmd * cobra.Command , args []string ) {
137
+ exitCh := make (chan shutdownRequest )
138
+ addRoutes (exitCh )
139
+
130
140
go startUDPServer (udpPort )
131
141
if sctpPort != - 1 {
132
142
go startSCTPServer (sctpPort )
133
143
}
134
144
135
- addRoutes ()
145
+ server := & http. Server { Addr : fmt . Sprintf ( ":%d" , httpPort )}
136
146
if len (certFile ) > 0 {
137
- // only start HTTPS server if a cert is provided
138
- startHTTPSServer (httpPort , certFile , privKeyFile )
147
+ startServer (server , exitCh , func () error { return server .ListenAndServeTLS (certFile , privKeyFile ) })
139
148
} else {
140
- startHTTPServer ( httpPort )
149
+ startServer ( server , exitCh , server . ListenAndServe )
141
150
}
142
151
}
143
152
144
- func addRoutes () {
153
+ func addRoutes (exitCh chan shutdownRequest ) {
145
154
http .HandleFunc ("/" , rootHandler )
146
155
http .HandleFunc ("/clientip" , clientIPHandler )
147
156
http .HandleFunc ("/echo" , echoHandler )
148
- http .HandleFunc ("/exit" , exitHandler )
157
+ http .HandleFunc ("/exit" , func ( w http. ResponseWriter , req * http. Request ) { exitHandler ( w , req , exitCh ) } )
149
158
http .HandleFunc ("/hostname" , hostnameHandler )
150
159
http .HandleFunc ("/shell" , shellHandler )
151
160
http .HandleFunc ("/upload" , uploadHandler )
@@ -156,12 +165,23 @@ func addRoutes() {
156
165
http .HandleFunc ("/shutdown" , shutdownHandler )
157
166
}
158
167
159
- func startHTTPSServer (httpsPort int , certFile , privKeyFile string ) {
160
- log .Fatal (http .ListenAndServeTLS (fmt .Sprintf (":%d" , httpPort ), certFile , privKeyFile , nil ))
161
- }
168
+ func startServer (server * http.Server , exitCh chan shutdownRequest , fn func () error ) {
169
+ go func () {
170
+ re := <- exitCh
171
+ ctx , cancelFn := context .WithTimeout (context .Background (), re .timeout )
172
+ defer cancelFn ()
173
+ err := server .Shutdown (ctx )
174
+ log .Printf ("Graceful shutdown completed with: %v" , err )
175
+ os .Exit (re .code )
176
+ }()
162
177
163
- func startHTTPServer (httpPort int ) {
164
- log .Fatal (http .ListenAndServe (fmt .Sprintf (":%d" , httpPort ), nil ))
178
+ if err := fn (); err != nil {
179
+ if err == http .ErrServerClosed {
180
+ // wait until the goroutine calls os.Exit()
181
+ select {}
182
+ }
183
+ log .Fatal (err )
184
+ }
165
185
}
166
186
167
187
func rootHandler (w http.ResponseWriter , r * http.Request ) {
@@ -179,13 +199,37 @@ func clientIPHandler(w http.ResponseWriter, r *http.Request) {
179
199
fmt .Fprintf (w , r .RemoteAddr )
180
200
}
181
201
182
- func exitHandler (w http.ResponseWriter , r * http.Request ) {
183
- log .Printf ("GET /exit?code=%s" , r .FormValue ("code" ))
184
- code , err := strconv .Atoi (r .FormValue ("code" ))
185
- if err == nil || r .FormValue ("code" ) == "" {
202
+ type shutdownRequest struct {
203
+ code int
204
+ timeout time.Duration
205
+ }
206
+
207
+ func exitHandler (w http.ResponseWriter , r * http.Request , exitCh chan <- shutdownRequest ) {
208
+ waitString := r .FormValue ("wait" )
209
+ timeoutString := r .FormValue ("timeout" )
210
+ codeString := r .FormValue ("code" )
211
+ log .Printf ("GET /exit?code=%s&timeout=%s&wait=%s" , codeString , timeoutString , waitString )
212
+ timeout , err := time .ParseDuration (timeoutString )
213
+ if err != nil && timeoutString != "" {
214
+ fmt .Fprintf (w , "argument 'timeout' must be a valid golang duration or empty, got %q\n " , timeoutString )
215
+ return
216
+ }
217
+ wait , err := time .ParseDuration (waitString )
218
+ if err != nil && waitString != "" {
219
+ fmt .Fprintf (w , "argument 'wait' must be a valid golang duration or empty, got %q\n " , waitString )
220
+ return
221
+ }
222
+ code , err := strconv .Atoi (codeString )
223
+ if err != nil && codeString != "" {
224
+ fmt .Fprintf (w , "argument 'code' must be an integer [0-127] or empty, got %q\n " , codeString )
225
+ return
226
+ }
227
+ log .Printf ("Will begin shutdown in %s, allowing %s for connections to close, then will exit with %d" , wait , timeout , code )
228
+ time .Sleep (wait )
229
+ if timeout == 0 {
186
230
os .Exit (code )
187
231
}
188
- fmt . Fprintf ( w , "argument ' code' must be an integer [0-127] or empty, got %q" , r . FormValue ( "code" ))
232
+ exitCh <- shutdownRequest { code : code , timeout : timeout }
189
233
}
190
234
191
235
func hostnameHandler (w http.ResponseWriter , r * http.Request ) {
0 commit comments