Skip to content

Commit 3b38b01

Browse files
Merge pull request #68 from step-security/int
Revert changes on failure
2 parents 52f1c07 + f78b9cd commit 3b38b01

File tree

12 files changed

+253
-238
lines changed

12 files changed

+253
-238
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,12 @@ jobs:
1010
contents: read
1111
runs-on: ubuntu-latest
1212
steps:
13-
- uses: step-security/harden-runner@917f7d59f22e82a5ddcaef409923426fd7aa6327
14-
with:
15-
allowed-endpoints:
16-
beta.api.stepsecurity.io:443
17-
codecov.io:443
18-
github.com:443
19-
proxy.golang.org:443
20-
storage.googleapis.com:443
2113
- name: Checkout
2214
uses: actions/checkout@629c2de402a417ea7690ca6ce3f33229e27606a5
2315
- name: Set up Go
2416
uses: actions/setup-go@37335c7bb261b353407cff977110895fa0b4f7d8
2517
with:
2618
go-version: 1.17
2719
- name: Run coverage
28-
run: sudo go test -coverprofile=coverage.txt -covermode=atomic
20+
run: sudo CI=true go test -race -coverprofile=coverage.txt -covermode=atomic
2921
- uses: codecov/codecov-action@f32b3a3741e1053eb607407145bc9619351dc93b

agent.go

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package main
33
import (
44
"context"
55
"fmt"
6-
"io"
76
"net/http"
87
"os"
98
"time"
@@ -47,7 +46,7 @@ type IPTables interface {
4746
// TODO: move all inputs into a struct
4847
func Run(ctx context.Context, configFilePath string, hostDNSServer DNSServer,
4948
dockerDNSServer DNSServer, iptables *Firewall, nflog AgentNflogger,
50-
cmd Command, resolvdConfigPath, dockerDaemonConfigPath string, stdout io.Writer) error {
49+
cmd Command, resolvdConfigPath, dockerDaemonConfigPath, tempDir string) error {
5150

5251
// Passed to each go routine, if anyone fails, the program fails
5352
errc := make(chan error)
@@ -79,6 +78,32 @@ func Run(ctx context.Context, configFilePath string, hostDNSServer DNSServer,
7978
go startDNSServer(dnsProxy, hostDNSServer, errc)
8079
go startDNSServer(dnsProxy, dockerDNSServer, errc) // this is for the docker bridge
8180

81+
if cmd == nil {
82+
procMon := &ProcessMonitor{CorrelationId: config.CorrelationId, Repo: config.Repo, ApiClient: apiclient, WorkingDirectory: config.WorkingDirectory}
83+
go procMon.MonitorProcesses(errc)
84+
writeLog("started p monitor")
85+
}
86+
87+
dnsConfig := DnsConfig{}
88+
89+
// Change DNS config on host, causes processes to use agent's DNS proxy
90+
if err := dnsConfig.SetDNSServer(cmd, resolvdConfigPath, tempDir); err != nil {
91+
writeLog(fmt.Sprintf("Error setting DNS server %v", err))
92+
RevertChanges(iptables, nflog, cmd, resolvdConfigPath, dockerDaemonConfigPath, dnsConfig)
93+
return err
94+
}
95+
96+
writeLog("updated resolved")
97+
98+
// Change DNS for docker, causes process in containers to use agent's DNS proxy
99+
if err := dnsConfig.SetDockerDNSServer(cmd, dockerDaemonConfigPath, tempDir); err != nil {
100+
writeLog(fmt.Sprintf("Error setting DNS server for docker %v", err))
101+
RevertChanges(iptables, nflog, cmd, resolvdConfigPath, dockerDaemonConfigPath, dnsConfig)
102+
return err
103+
}
104+
105+
writeLog("set docker config")
106+
82107
if len(config.Endpoints) == 0 {
83108
netMonitor := NetworkMonitor{
84109
CorrelationId: config.CorrelationId,
@@ -93,37 +118,15 @@ func Run(ctx context.Context, configFilePath string, hostDNSServer DNSServer,
93118
writeLog("before audit rules")
94119

95120
// Add logging to firewall, including NFLOG rules
96-
if err := addAuditRules(iptables); err != nil {
121+
if err := AddAuditRules(iptables); err != nil {
97122
writeLog(fmt.Sprintf("Error adding firewall rules %v", err))
123+
RevertChanges(iptables, nflog, cmd, resolvdConfigPath, dockerDaemonConfigPath, dnsConfig)
98124
return err
99125
}
100126

101127
writeLog("added audit rules")
102128
}
103129

104-
// TODO: If something did not work, revert settings
105-
if cmd == nil {
106-
procMon := &ProcessMonitor{CorrelationId: config.CorrelationId, Repo: config.Repo, ApiClient: apiclient, WorkingDirectory: config.WorkingDirectory}
107-
go procMon.MonitorProcesses(errc)
108-
writeLog("started p monitor")
109-
}
110-
111-
// Change DNS config on host, causes processes to use agent's DNS proxy
112-
if err := setDNSServer(cmd, resolvdConfigPath); err != nil {
113-
writeLog(fmt.Sprintf("Error setting DNS server %v", err))
114-
return err
115-
}
116-
117-
writeLog("updated resolved")
118-
119-
// Change DNS for docker, causes process in containers to use agent's DNS proxy
120-
if err := setDockerDNSServer(cmd, dockerDaemonConfigPath); err != nil {
121-
writeLog(fmt.Sprintf("Error setting DNS server for docker %v", err))
122-
return err
123-
}
124-
125-
writeLog("set docker config")
126-
127130
// If allowed endpoints set, resolve them, and add to firewall
128131
if len(config.Endpoints) > 0 {
129132
var ipAddressEndpoints []ipAddressEndpoint
@@ -145,6 +148,7 @@ func Run(ctx context.Context, configFilePath string, hostDNSServer DNSServer,
145148
ipAddress, err := dnsProxy.getIPByDomain(endpoint.domainName)
146149
if err != nil {
147150
writeLog(fmt.Sprintf("Error resolving allowed domain %v", err))
151+
RevertChanges(iptables, nflog, cmd, resolvdConfigPath, dockerDaemonConfigPath, dnsConfig)
148152
return err
149153
}
150154

@@ -154,15 +158,11 @@ func Run(ctx context.Context, configFilePath string, hostDNSServer DNSServer,
154158

155159
if err := addBlockRulesForGitHubHostedRunner(ipAddressEndpoints); err != nil {
156160
writeLog(fmt.Sprintf("Error setting firewall for allowed domains %v", err))
161+
RevertChanges(iptables, nflog, cmd, resolvdConfigPath, dockerDaemonConfigPath, dnsConfig)
157162
return err
158163
}
159164
}
160165

161-
// Ask API to monitor the run
162-
go apiclient.monitorRun(config.Repo, config.RunId)
163-
164-
writeLog("called monitor run")
165-
166166
writeLog("done")
167167

168168
// Write the status file
@@ -173,13 +173,31 @@ func Run(ctx context.Context, configFilePath string, hostDNSServer DNSServer,
173173
case <-ctx.Done():
174174
return nil
175175
case e := <-errc:
176-
writeLog(e.Error())
176+
writeLog(fmt.Sprintf("Error in Initialization %v", e))
177+
RevertChanges(iptables, nflog, cmd, resolvdConfigPath, dockerDaemonConfigPath, dnsConfig)
177178
return e
178179

179180
}
180181
}
181182
}
182183

184+
func RevertChanges(iptables *Firewall, nflog AgentNflogger,
185+
cmd Command, resolvdConfigPath, dockerDaemonConfigPath string, dnsConfig DnsConfig) {
186+
err := RevertFirewallChanges(iptables)
187+
if err != nil {
188+
writeLog(fmt.Sprintf("Error in RevertChanges %v", err))
189+
}
190+
err = dnsConfig.RevertDNSServer(cmd, resolvdConfigPath)
191+
if err != nil {
192+
writeLog(fmt.Sprintf("Error in reverting DNS server changes %v", err))
193+
}
194+
err = dnsConfig.RevertDockerDNSServer(cmd, dockerDaemonConfigPath)
195+
if err != nil {
196+
writeLog(fmt.Sprintf("Error in reverting docker DNS server changes %v", err))
197+
}
198+
writeLog("Reverted changes")
199+
}
200+
183201
func writeLog(message string) {
184202
f, _ := os.OpenFile("/home/agent/agent.log",
185203
os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)

agent_test.go

Lines changed: 71 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ package main
33
import (
44
"context"
55
"fmt"
6+
"os"
7+
"path"
68
"testing"
79
"time"
810

911
"github.com/florianl/go-nflog/v2"
10-
"github.com/jarcoal/httpmock"
1112
)
1213

1314
type mockDNSServer struct {
@@ -82,52 +83,8 @@ func (m *MockCommandWithError) Run() error {
8283
return fmt.Errorf("failed to run command")
8384
}
8485

85-
func TestRun(t *testing.T) {
86-
87-
ctx := context.Background()
88-
ctx, cancel := context.WithCancel(ctx)
89-
time.AfterFunc(2*time.Second, cancel)
90-
91-
httpmock.Activate()
92-
defer httpmock.DeactivateAndReset()
93-
94-
httpmock.RegisterResponder("POST", fmt.Sprintf("%s/owner/repo/actions/runs/1287185438/monitor", agentApiBaseUrl),
95-
httpmock.NewStringResponder(200, ""))
96-
97-
err := Run(ctx, "./testfiles/agent.json",
98-
&mockDNSServer{}, &mockDNSServer{}, &Firewall{&MockIPTables{}},
99-
&MockAgentNflogger{}, &MockCommand{}, createTempFileWithContents(""), createTempFileWithContents("{}"), nil)
100-
101-
if err != nil {
102-
fmt.Printf("err: %v\n", err)
103-
t.Fail()
104-
}
105-
}
106-
107-
func TestRunWithDNSFailure(t *testing.T) {
108-
109-
ctx := context.Background()
110-
ctx, cancel := context.WithCancel(ctx)
111-
time.AfterFunc(5*time.Second, cancel) // this should not be used, it should error out earlier
112-
113-
httpmock.Activate()
114-
defer httpmock.DeactivateAndReset()
115-
116-
httpmock.RegisterResponder("POST", fmt.Sprintf("%s/owner/repo/actions/runs/1287185438/monitor", agentApiBaseUrl),
117-
httpmock.NewStringResponder(200, ""))
118-
119-
err := Run(ctx, "./testfiles/agent.json",
120-
&mockDNSServer{}, &mockDNSServerWithError{}, &Firewall{&MockIPTables{}},
121-
&MockAgentNflogger{}, &MockCommand{}, createTempFileWithContents(""), createTempFileWithContents("{}"), nil)
122-
123-
// if 2 seconds pass
124-
if err == nil {
125-
t.Fail()
126-
}
127-
128-
}
129-
130-
func TestRunWithCMDFailure(t *testing.T) {
86+
/*
87+
func TestRunWithNflogError(t *testing.T) {
13188
13289
ctx := context.Background()
13390
ctx, cancel := context.WithCancel(ctx)
@@ -141,36 +98,86 @@ func TestRunWithCMDFailure(t *testing.T) {
14198
14299
err := Run(ctx, "./testfiles/agent.json",
143100
&mockDNSServer{}, &mockDNSServer{}, &Firewall{&MockIPTables{}},
144-
&MockAgentNflogger{}, &MockCommandWithError{}, createTempFileWithContents(""), createTempFileWithContents("{}"), nil)
101+
&MockAgentNfloggerWithErr{}, &MockCommand{}, createTempFileWithContents(""), createTempFileWithContents("{}"), nil)
145102
146103
// if 2 seconds pass
147104
if err == nil {
148105
t.Fail()
149106
}
150107
151108
}
109+
*/
152110

153-
/*
154-
func TestRunWithNflogError(t *testing.T) {
111+
func deleteTempFile(path string) {
112+
os.Remove(path)
113+
}
155114

115+
func getContext(seconds int) context.Context {
156116
ctx := context.Background()
157117
ctx, cancel := context.WithCancel(ctx)
158-
time.AfterFunc(5*time.Second, cancel) // this should not be used, it should error out earlier
159-
160-
httpmock.Activate()
161-
defer httpmock.DeactivateAndReset()
162-
163-
httpmock.RegisterResponder("POST", fmt.Sprintf("%s/owner/repo/actions/runs/1287185438/monitor", agentApiBaseUrl),
164-
httpmock.NewStringResponder(200, ""))
118+
time.AfterFunc(2*time.Second, cancel)
165119

166-
err := Run(ctx, "./testfiles/agent.json",
167-
&mockDNSServer{}, &mockDNSServer{}, &Firewall{&MockIPTables{}},
168-
&MockAgentNfloggerWithErr{}, &MockCommand{}, createTempFileWithContents(""), createTempFileWithContents("{}"), nil)
120+
return ctx
121+
}
169122

170-
// if 2 seconds pass
171-
if err == nil {
172-
t.Fail()
123+
func TestRun(t *testing.T) {
124+
type args struct {
125+
ctxCancelDuration int
126+
configFilePath string
127+
hostDNSServer DNSServer
128+
dockerDNSServer DNSServer
129+
iptables *Firewall
130+
nflog AgentNflogger
131+
cmd Command
132+
resolvdConfigPath string
133+
dockerDaemonConfigPath string
134+
ciTestOnly bool
173135
}
174136

137+
tests := []struct {
138+
name string
139+
args args
140+
wantErr bool
141+
}{
142+
{name: "success", args: args{ctxCancelDuration: 2, configFilePath: "./testfiles/agent.json", hostDNSServer: &mockDNSServer{}, dockerDNSServer: &mockDNSServer{},
143+
iptables: &Firewall{&MockIPTables{}}, nflog: &MockAgentNflogger{}, cmd: &MockCommand{}, resolvdConfigPath: createTempFileWithContents(""),
144+
dockerDaemonConfigPath: createTempFileWithContents("{}")}, wantErr: false},
145+
{name: "success monitor process", args: args{ctxCancelDuration: 2, configFilePath: "./testfiles/agent.json", hostDNSServer: &mockDNSServer{}, dockerDNSServer: &mockDNSServer{},
146+
iptables: &Firewall{&MockIPTables{}}, nflog: &MockAgentNflogger{}, cmd: nil, resolvdConfigPath: createTempFileWithContents(""),
147+
dockerDaemonConfigPath: createTempFileWithContents("{}"), ciTestOnly: true}, wantErr: false},
148+
{name: "success allowed endpoints", args: args{ctxCancelDuration: 2, configFilePath: "./testfiles/agent-allowed-endpoints.json",
149+
hostDNSServer: &mockDNSServer{}, dockerDNSServer: &mockDNSServer{},
150+
iptables: nil, nflog: &MockAgentNflogger{}, cmd: &MockCommand{}, resolvdConfigPath: createTempFileWithContents(""),
151+
dockerDaemonConfigPath: createTempFileWithContents("{}"), ciTestOnly: true}, wantErr: false},
152+
{name: "dns failure", args: args{ctxCancelDuration: 5, configFilePath: "./testfiles/agent.json", hostDNSServer: &mockDNSServer{}, dockerDNSServer: &mockDNSServerWithError{},
153+
iptables: &Firewall{&MockIPTables{}}, nflog: &MockAgentNflogger{}, cmd: &MockCommand{}, resolvdConfigPath: createTempFileWithContents(""),
154+
dockerDaemonConfigPath: createTempFileWithContents("{}")}, wantErr: true},
155+
{name: "cmd failure", args: args{ctxCancelDuration: 5, configFilePath: "./testfiles/agent.json", hostDNSServer: &mockDNSServer{}, dockerDNSServer: &mockDNSServer{},
156+
iptables: &Firewall{&MockIPTables{}}, nflog: &MockAgentNflogger{}, cmd: &MockCommandWithError{}, resolvdConfigPath: createTempFileWithContents(""),
157+
dockerDaemonConfigPath: createTempFileWithContents("{}")}, wantErr: true},
158+
{name: "nflog failure", args: args{ctxCancelDuration: 5, configFilePath: "./testfiles/agent.json", hostDNSServer: &mockDNSServer{}, dockerDNSServer: &mockDNSServer{},
159+
iptables: &Firewall{&MockIPTables{}}, nflog: &MockAgentNfloggerWithErr{}, cmd: &MockCommand{}, resolvdConfigPath: createTempFileWithContents(""),
160+
dockerDaemonConfigPath: createTempFileWithContents("{}")}, wantErr: true},
161+
}
162+
_, ciTest := os.LookupEnv("CI")
163+
fmt.Printf("ci-test: %t\n", ciTest)
164+
for _, tt := range tests {
165+
if !tt.args.ciTestOnly || ciTest {
166+
t.Run(tt.name, func(t *testing.T) {
167+
tempDir := os.TempDir()
168+
if err := Run(getContext(tt.args.ctxCancelDuration), tt.args.configFilePath, tt.args.hostDNSServer, tt.args.dockerDNSServer,
169+
tt.args.iptables, tt.args.nflog, tt.args.cmd, tt.args.resolvdConfigPath, tt.args.dockerDaemonConfigPath, tempDir); (err != nil) != tt.wantErr {
170+
t.Errorf("Run() error = %v, wantErr %v", err, tt.wantErr)
171+
}
172+
173+
deleteTempFile(path.Join(tempDir, "resolved.conf"))
174+
deleteTempFile(path.Join(tempDir, "daemon.json"))
175+
176+
if tt.args.ciTestOnly {
177+
fmt.Printf("Reverting firewall changes\n")
178+
RevertFirewallChanges(nil)
179+
}
180+
})
181+
}
182+
}
175183
}
176-
*/

apiclient.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,6 @@ type ApiClient struct {
4141

4242
const agentApiBaseUrl = "https://apiurl/v1"
4343

44-
func (apiclient *ApiClient) monitorRun(repo, runid string) error {
45-
url := fmt.Sprintf("%s/github/%s/actions/runs/%s/monitor", apiclient.APIURL, repo, runid)
46-
47-
return apiclient.sendApiRequest("POST", url, nil)
48-
}
49-
5044
func (apiclient *ApiClient) sendDNSRecord(correlationId, repo, domainName, ipAddress string) error {
5145

5246
dnsRecord := &DNSRecord{}

0 commit comments

Comments
 (0)