Skip to content

Commit 47ab1c2

Browse files
authored
Merge pull request #61 from steadybit/windows-support
feat: handle signals on windows
2 parents 985c2bb + bd39bb9 commit 47ab1c2

File tree

8 files changed

+207
-57
lines changed

8 files changed

+207
-57
lines changed

extcmd/extcmd_test.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,39 @@ package extcmd
66
import (
77
"github.com/stretchr/testify/assert"
88
"os/exec"
9+
"runtime"
910
"testing"
1011
)
1112

1213
func TestNewCmdState(t *testing.T) {
1314
_, err := GetCmdState("I am unknown")
14-
assert.NotNil(t, err)
15+
assert.Error(t, err)
16+
17+
var cmd *exec.Cmd
18+
if runtime.GOOS == "windows" {
19+
cmd = exec.Command("cmd", "/C", "echo hello world")
20+
} else {
21+
cmd = exec.Command("echo", "hello", "world")
22+
}
1523

16-
cmd := exec.Command("echo", "hello", "world")
1724
cs := NewCmdState(cmd)
1825

1926
persistedState, err := GetCmdState(cs.Id)
20-
assert.Nil(t, err)
27+
assert.NoError(t, err)
2128
assert.NotNil(t, persistedState)
2229

2330
err = cmd.Start()
24-
assert.Nil(t, err)
31+
assert.NoError(t, err)
2532
err = cmd.Wait()
26-
assert.Nil(t, err)
33+
assert.NoError(t, err)
2734

2835
messages := cs.GetMessages(true)
2936
assert.Len(t, messages, 1)
3037
assert.Equal(t, "info", *messages[0].Level)
31-
assert.Equal(t, "hello world\n", messages[0].Message)
38+
assert.Contains(t, messages[0].Message, "hello world")
3239

3340
RemoveCmdState(cs.Id)
3441
persistedState, err = GetCmdState(cs.Id)
35-
assert.NotNil(t, err)
42+
assert.Error(t, err)
3643
assert.Nil(t, persistedState)
3744
}

exthttp/listener_test.go

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,18 @@ import (
1212
"crypto/tls"
1313
"crypto/x509"
1414
"fmt"
15-
"github.com/madflojo/testcerts"
16-
"github.com/phayes/freeport"
17-
"github.com/stretchr/testify/assert"
18-
"github.com/stretchr/testify/require"
1915
"net"
2016
"net/http"
2117
"net/http/httptest"
2218
"os"
2319
"path/filepath"
20+
"runtime"
2421
"testing"
22+
23+
"github.com/madflojo/testcerts"
24+
"github.com/phayes/freeport"
25+
"github.com/stretchr/testify/assert"
26+
"github.com/stretchr/testify/require"
2527
)
2628

2729
func TestValidateSpecificationSuccessHttp(t *testing.T) {
@@ -126,7 +128,16 @@ func TestStartHttpsServerMustFailWhenCertificateCannotBeFound(t *testing.T) {
126128
t.Setenv("STEADYBIT_EXTENSION_TLS_SERVER_CERT", filepath.Join(t.TempDir(), "unknown.pem"))
127129

128130
err = listen(ListenOpts{Port: port})
129-
assert.ErrorContains(t, err, "no such file or directory")
131+
132+
var expected string
133+
134+
if runtime.GOOS == "windows" {
135+
expected = "cannot find the file specified"
136+
} else {
137+
expected = "no such file or directory"
138+
}
139+
140+
assert.ErrorContains(t, err, expected)
130141
}
131142

132143
func TestStartHttpsServerMustFailWhenKeyCannotBeFound(t *testing.T) {
@@ -142,7 +153,15 @@ func TestStartHttpsServerMustFailWhenKeyCannotBeFound(t *testing.T) {
142153
t.Setenv("STEADYBIT_EXTENSION_TLS_SERVER_CERT", cert)
143154
err = listen(ListenOpts{Port: port})
144155

145-
assert.ErrorContains(t, err, "no such file or directory")
156+
var expected string
157+
158+
if runtime.GOOS == "windows" {
159+
expected = "cannot find the file specified"
160+
} else {
161+
expected = "no such file or directory"
162+
}
163+
164+
assert.ErrorContains(t, err, expected)
146165
}
147166

148167
func TestStartHttpsServerWithMutualTlsMustRefuseConnectionsWithoutMutualTls(t *testing.T) {

extlogging/extlogging.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
package extlogging
77

88
import (
9-
"github.com/rs/zerolog"
10-
"github.com/rs/zerolog/log"
119
"os"
1210
"strings"
11+
"time"
12+
13+
"github.com/rs/zerolog"
14+
"github.com/rs/zerolog/log"
1315
)
1416

1517
const RFC3339Micro = "2006-01-02T15:04:05.999Z07:00"
@@ -21,7 +23,10 @@ func InitZeroLog() {
2123

2224
var logger zerolog.Logger
2325
if strings.ToLower(os.Getenv("STEADYBIT_LOG_FORMAT")) != "json" {
24-
logger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr, NoColor: getNoColor(), TimeFormat: RFC3339Micro})
26+
logger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr, NoColor: getNoColor(), TimeFormat: RFC3339Micro, FormatTimestamp: func(i interface{}) string {
27+
timestamp, _ := time.Parse(time.RFC3339, i.(string))
28+
return timestamp.Format(RFC3339Micro)
29+
}})
2530
} else {
2631
logger = zerolog.New(os.Stderr)
2732
}

extlogging/extlogging_test.go

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
package extlogging
22

33
import (
4-
"github.com/rs/zerolog"
5-
"github.com/rs/zerolog/log"
6-
"github.com/stretchr/testify/assert"
74
"io"
85
"os"
96
"sync"
107
"testing"
118
"time"
9+
10+
"github.com/rs/zerolog"
11+
"github.com/rs/zerolog/log"
12+
"github.com/stretchr/testify/assert"
1213
)
1314

1415
var captureLock sync.Mutex
@@ -30,11 +31,12 @@ func TestInitZeroLog_Format(t *testing.T) {
3031
t.Run(tt.format, func(t *testing.T) {
3132
t.Setenv("STEADYBIT_LOG_FORMAT", tt.format)
3233

33-
msg := captureStdErr(func() {
34+
msg, err := captureStdErr(func() {
3435
InitZeroLog()
3536
log.Info().Msg("test")
3637
})
3738

39+
assert.Nil(t, err)
3840
assert.Equal(t, tt.wantedOutput, msg)
3941
})
4042
}
@@ -60,21 +62,32 @@ func TestInitZeroLog_Level(t *testing.T) {
6062
}
6163
}
6264

63-
func captureStdErr(f func()) string {
65+
func captureStdErr(f func()) (string, error) {
6466
captureLock.Lock()
6567
defer captureLock.Unlock()
6668

6769
rescueStderr := os.Stderr
68-
r, w, _ := os.Pipe()
70+
rescueStdout := os.Stdout
71+
r, w, err := os.Pipe()
72+
73+
if err != nil {
74+
log.Error().Msgf("unable to create os pipe: %s", err)
75+
return "", err
76+
}
77+
6978
os.Stderr = w
7079
os.Stdout = w
7180

7281
f()
73-
_ = w.Sync()
7482

75-
defer func() { os.Stderr = rescueStderr }()
83+
defer func() {
84+
os.Stderr = rescueStderr
85+
os.Stdout = rescueStdout
86+
}()
7687

7788
_ = r.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
89+
90+
w.Close()
7891
captured, _ := io.ReadAll(r)
79-
return string(captured)
92+
return string(captured), nil
8093
}

extsignals/signalhandler.go

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
package extsignals
22

33
import (
4-
"fmt"
5-
"github.com/rs/zerolog/log"
6-
"golang.org/x/sys/unix"
4+
"context"
75
"os"
86
"os/signal"
97
"sort"
108
"sync"
119
"syscall"
10+
11+
"github.com/rs/zerolog/log"
1212
)
1313

1414
var (
@@ -40,44 +40,64 @@ func AddSignalHandler(signalHandler SignalHandler) {
4040
handlers.Store(signalHandler.Name, signalHandler)
4141
}
4242

43+
func ClearSignalHandlers() {
44+
handlers.Range(func(key, value interface{}) bool {
45+
handlers.Delete(key)
46+
return true
47+
})
48+
}
49+
4350
// RemoveSignalHandlersByName removes signal handlers by name. This is mainly used for testing.
4451
func RemoveSignalHandlersByName(names ...string) {
4552
for _, name := range names {
4653
handlers.Delete(name)
4754
}
4855
}
4956

57+
func createSignalChannel(context context.Context) {
58+
signalChannel := make(chan os.Signal, 1)
59+
Notify(signalChannel)
60+
go func(signals <-chan os.Signal) {
61+
for {
62+
select {
63+
case <-context.Done():
64+
signal.Stop(signalChannel)
65+
return
66+
case s := <-signals:
67+
handlerList := make([]SignalHandler, 0)
68+
handlers.Range(func(key, value interface{}) bool {
69+
handlerList = append(handlerList, value.(SignalHandler))
70+
return true
71+
})
72+
sort.Sort(ByOrder(handlerList))
73+
signalName := GetSignalName(s.(syscall.Signal))
74+
for _, handler := range handlerList {
75+
log.Debug().Str("signal", signalName).Str("handler", handler.Name).Int("order", handler.Order).Msg("received signal - call handler")
76+
handler.Handler(s)
77+
}
78+
}
79+
}
80+
}(signalChannel)
81+
}
82+
5083
func ActivateSignalHandlers() {
84+
ActivateSignalHandlerWithContext(context.Background())
85+
}
86+
87+
func ActivateSignalHandlerWithContext(context context.Context) {
5188
AddSignalHandler(SignalHandler{
5289
Handler: func(signal os.Signal) {
5390
switch signal {
5491
case syscall.SIGINT:
5592
os.Exit(128 + int(signal.(syscall.Signal)))
5693

5794
case syscall.SIGTERM:
58-
fmt.Printf("Terminated: %d\n", int(signal.(syscall.Signal)))
5995
os.Exit(128 + int(signal.(syscall.Signal)))
6096
}
6197
},
6298
Order: OrderTermination,
6399
Name: "Termination",
64100
})
65101

66-
signalChannel := make(chan os.Signal, 1)
67-
signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGUSR1)
68-
go func(signals <-chan os.Signal) {
69-
for s := range signals {
70-
handlerList := make([]SignalHandler, 0)
71-
handlers.Range(func(key, value interface{}) bool {
72-
handlerList = append(handlerList, value.(SignalHandler))
73-
return true
74-
})
75-
sort.Sort(ByOrder(handlerList))
76-
signalName := unix.SignalName(s.(syscall.Signal))
77-
for _, handler := range handlerList {
78-
log.Debug().Str("signal", signalName).Str("handler", handler.Name).Int("order", handler.Order).Msg("received signal - call handler")
79-
handler.Handler(s)
80-
}
81-
}
82-
}(signalChannel)
102+
createSignalChannel(context)
83103
}

extsignals/signalhandler_test.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
11
package extsignals
22

3+
// These tests don't consistently work in IntelliJ,
4+
// use the native go test runner instead.
5+
36
import (
4-
"github.com/rs/zerolog/log"
5-
"github.com/stretchr/testify/require"
67
"os"
78
"sync/atomic"
8-
"syscall"
99
"testing"
1010
"time"
11+
12+
"github.com/rs/zerolog/log"
13+
"github.com/stretchr/testify/require"
1114
)
1215

1316
func TestSignalHandlers(t *testing.T) {
14-
//cleanup previous test
15-
RemoveSignalHandlersByName("Termination", "Handler1", "Handler2")
16-
1717
handler1Run := atomic.Bool{}
1818
handler2Run := atomic.Bool{}
1919
handlerList := atomic.Value{}
2020

21+
ClearSignalHandlers()
22+
defer ClearSignalHandlers()
2123
ActivateSignalHandlers()
2224
RemoveSignalHandlersByName("Termination")
2325
AddSignalHandler(SignalHandler{
@@ -39,7 +41,7 @@ func TestSignalHandlers(t *testing.T) {
3941
Name: "Handler2",
4042
})
4143

42-
err := syscall.Kill(os.Getpid(), syscall.SIGUSR1)
44+
err := Kill(os.Getpid())
4345
require.NoError(t, err)
4446

4547
// Wait for the signal to be processed
@@ -51,12 +53,11 @@ func TestSignalHandlers(t *testing.T) {
5153
}
5254

5355
func TestRemoveSignalHandlersByName(t *testing.T) {
54-
//cleanup previous test
55-
RemoveSignalHandlersByName("Termination", "Handler1", "Handler2")
56-
5756
handler1Run := atomic.Bool{}
5857
handler2Run := atomic.Bool{}
5958

59+
ClearSignalHandlers()
60+
defer ClearSignalHandlers()
6061
ActivateSignalHandlers()
6162
AddSignalHandler(SignalHandler{
6263
Handler: func(signal os.Signal) {
@@ -76,7 +77,7 @@ func TestRemoveSignalHandlersByName(t *testing.T) {
7677
})
7778

7879
RemoveSignalHandlersByName("Termination", "Handler1")
79-
err := syscall.Kill(os.Getpid(), syscall.SIGUSR1)
80+
err := Kill(os.Getpid())
8081
require.NoError(t, err)
8182

8283
// Wait for the signal to be processed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//go:build !windows
2+
3+
package extsignals
4+
5+
import (
6+
"os"
7+
"os/signal"
8+
"syscall"
9+
10+
"golang.org/x/sys/unix"
11+
)
12+
13+
func Notify(c chan<- os.Signal, sig ...os.Signal) {
14+
signal.Notify(c, syscall.SIGINT, syscall.SIGTERM, syscall.SIGUSR1)
15+
}
16+
17+
func GetSignalName(s syscall.Signal) string {
18+
return unix.SignalName(s)
19+
}
20+
21+
func Kill(pid int) (e error) {
22+
return syscall.Kill(os.Getpid(), syscall.SIGUSR1)
23+
}

0 commit comments

Comments
 (0)