Skip to content

Commit 0fde566

Browse files
gkeesh7gauravkuber
andauthored
Adding tests to cmd.go and refactoring the heartbeat method (AI generated) (#508)
* Added unit tests and refactored cmd.go for kraken agent * Removed PR template * Refactor heartbeat ticker for testing * Split option tests * Inject heartbeat ticker from Run * Calling stopheartbeat immediately after hearbeat with a defer --------- Co-authored-by: gauravk <gauravk@uber.com>
1 parent a63bb66 commit 0fde566

File tree

2 files changed

+210
-10
lines changed

2 files changed

+210
-10
lines changed

agent/cmd/cmd.go

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"flag"
1818
"fmt"
1919
"net/http"
20+
"sync"
2021
"time"
2122

2223
"github.com/uber/kraken/agent/agentserver"
@@ -231,32 +232,73 @@ func Run(flags *Flags, opts ...Option) {
231232
config.AgentServer, stats, cads, sched, tagClient, announceClient, containerRuntimeFactory)
232233
addr := fmt.Sprintf(":%d", flags.AgentServerPort)
233234
log.Infof("Starting agent server on %s", addr)
235+
heartbeatTicker := &timeTicker{inner: time.NewTicker(10 * time.Second)}
236+
heartbeatDone := make(chan struct{})
237+
var heartbeatStop sync.Once
238+
stopHeartbeat := func() {
239+
heartbeatStop.Do(func() {
240+
close(heartbeatDone)
241+
heartbeatTicker.Stop()
242+
})
243+
}
244+
245+
go heartbeat(stats, heartbeatTicker, heartbeatDone)
246+
defer stopHeartbeat()
234247
go func() {
235-
log.Fatal(http.ListenAndServe(addr, agentServer.Handler()))
248+
if err := http.ListenAndServe(addr, agentServer.Handler()); err != nil {
249+
stopHeartbeat()
250+
log.Fatal(err)
251+
}
236252
}()
237253

238254
log.Info("Starting registry...")
239255
go func() {
240-
log.Fatal(registry.ListenAndServe())
256+
if err := registry.ListenAndServe(); err != nil {
257+
stopHeartbeat()
258+
log.Fatal(err)
259+
}
241260
}()
242261

243-
go heartbeat(stats)
244-
245-
log.Fatal(nginx.Run(config.Nginx, map[string]interface{}{
262+
if err := nginx.Run(config.Nginx, map[string]interface{}{
246263
"allowed_cidrs": config.AllowedCidrs,
247264
"port": flags.AgentRegistryPort,
248265
"registry_server": nginx.GetServer(
249266
config.Registry.Docker.HTTP.Net, config.Registry.Docker.HTTP.Addr),
250267
"agent_server": fmt.Sprintf("127.0.0.1:%d", flags.AgentServerPort),
251268
"registry_backup": config.RegistryBackup},
252-
nginx.WithTLS(config.TLS)))
269+
nginx.WithTLS(config.TLS)); err != nil {
270+
stopHeartbeat()
271+
log.Fatal(err)
272+
}
273+
}
274+
275+
// heartbeatTicker provides the minimal ticker contract required by heartbeat.
276+
type heartbeatTicker interface {
277+
Chan() <-chan time.Time
278+
Stop()
279+
}
280+
281+
type timeTicker struct {
282+
inner *time.Ticker
283+
}
284+
285+
func (t *timeTicker) Chan() <-chan time.Time {
286+
return t.inner.C
287+
}
288+
289+
func (t *timeTicker) Stop() {
290+
t.inner.Stop()
253291
}
254292

255293
// heartbeat periodically emits a counter metric which allows us to monitor the
256-
// number of active agents.
257-
func heartbeat(stats tally.Scope) {
294+
// number of active agents, using the provided ticker and done channel to control its lifecycle.
295+
func heartbeat(stats tally.Scope, ticker heartbeatTicker, done <-chan struct{}) {
258296
for {
259-
stats.Counter("heartbeat").Inc(1)
260-
time.Sleep(10 * time.Second)
297+
select {
298+
case <-ticker.Chan():
299+
stats.Counter("heartbeat").Inc(1)
300+
case <-done:
301+
return
302+
}
261303
}
262304
}

agent/cmd/cmd_test.go

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
package cmd
2+
3+
import (
4+
"flag"
5+
"os"
6+
"sync"
7+
"testing"
8+
"time"
9+
10+
"github.com/andres-erbsen/clock"
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
"github.com/uber-go/tally"
14+
"go.uber.org/zap"
15+
)
16+
17+
func TestParseFlags(t *testing.T) {
18+
// Save original args and flagset
19+
oldArgs := os.Args
20+
oldCommandLine := flag.CommandLine
21+
defer func() {
22+
os.Args = oldArgs
23+
flag.CommandLine = oldCommandLine
24+
}()
25+
26+
// Reset flags
27+
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
28+
29+
// Set up test args
30+
os.Args = []string{
31+
"cmd",
32+
"-peer-ip=1.2.3.4",
33+
"-peer-port=1000",
34+
"-agent-server-port=2000",
35+
"-agent-registry-port=3000",
36+
"-config=config.yaml",
37+
"-zone=test-zone",
38+
"-cluster=test-cluster",
39+
"-secrets=secrets.yaml",
40+
}
41+
42+
flags := ParseFlags()
43+
44+
assert.Equal(t, "1.2.3.4", flags.PeerIP)
45+
assert.Equal(t, 1000, flags.PeerPort)
46+
assert.Equal(t, 2000, flags.AgentServerPort)
47+
assert.Equal(t, 3000, flags.AgentRegistryPort)
48+
assert.Equal(t, "config.yaml", flags.ConfigFile)
49+
assert.Equal(t, "test-zone", flags.Zone)
50+
assert.Equal(t, "test-cluster", flags.KrakenCluster)
51+
assert.Equal(t, "secrets.yaml", flags.SecretsFile)
52+
}
53+
54+
func TestWithConfigOption(t *testing.T) {
55+
var o options
56+
c := Config{RegistryBackup: "test"}
57+
WithConfig(c)(&o)
58+
assert.Equal(t, "test", o.config.RegistryBackup)
59+
}
60+
61+
func TestWithMetricsOption(t *testing.T) {
62+
var o options
63+
s := tally.NoopScope
64+
WithMetrics(s)(&o)
65+
assert.Equal(t, s, o.metrics)
66+
}
67+
68+
func TestWithLoggerOption(t *testing.T) {
69+
var o options
70+
l := zap.NewNop()
71+
WithLogger(l)(&o)
72+
assert.Equal(t, l, o.logger)
73+
}
74+
75+
func TestWithEffectOption(t *testing.T) {
76+
var o options
77+
called := false
78+
f := func() { called = true }
79+
WithEffect(f)(&o)
80+
assert.NotNil(t, o.effect)
81+
o.effect()
82+
assert.True(t, called)
83+
}
84+
85+
func TestRunValidation(t *testing.T) {
86+
tests := []struct {
87+
desc string
88+
flags Flags
89+
panic string
90+
}{
91+
{
92+
desc: "missing peer port",
93+
flags: Flags{AgentServerPort: 1, AgentRegistryPort: 1},
94+
panic: "must specify non-zero peer port",
95+
},
96+
{
97+
desc: "missing agent server port",
98+
flags: Flags{PeerPort: 1, AgentRegistryPort: 1},
99+
panic: "must specify non-zero agent server port",
100+
},
101+
{
102+
desc: "missing agent registry port",
103+
flags: Flags{PeerPort: 1, AgentServerPort: 1},
104+
panic: "must specify non-zero agent registry port",
105+
},
106+
}
107+
108+
for _, test := range tests {
109+
t.Run(test.desc, func(t *testing.T) {
110+
assert.PanicsWithValue(t, test.panic, func() {
111+
Run(&test.flags)
112+
})
113+
})
114+
}
115+
}
116+
func TestHeartbeatWithTicker(t *testing.T) {
117+
scope := tally.NewTestScope("", nil)
118+
mockClock := clock.NewMock()
119+
mockTicker := mockClock.Ticker(100 * time.Millisecond)
120+
done := make(chan struct{})
121+
122+
var wg sync.WaitGroup
123+
wg.Add(1)
124+
go func() {
125+
defer wg.Done()
126+
heartbeat(scope, clockTicker{ticker: mockTicker}, done)
127+
}()
128+
129+
for i := 0; i < 3; i++ {
130+
mockClock.Add(100 * time.Millisecond)
131+
}
132+
133+
require.Eventually(t, func() bool {
134+
snapshot := scope.Snapshot()
135+
for _, counter := range snapshot.Counters() {
136+
if counter.Name() == "heartbeat" && counter.Value() >= 3 {
137+
return true
138+
}
139+
}
140+
return false
141+
}, time.Second, 10*time.Millisecond)
142+
143+
close(done)
144+
mockTicker.Stop()
145+
wg.Wait()
146+
}
147+
148+
type clockTicker struct {
149+
ticker *clock.Ticker
150+
}
151+
152+
func (t clockTicker) Chan() <-chan time.Time {
153+
return t.ticker.C
154+
}
155+
156+
func (t clockTicker) Stop() {
157+
t.ticker.Stop()
158+
}

0 commit comments

Comments
 (0)