Skip to content

Commit 3054a5f

Browse files
committed
Handle LoadVintage error when contexts are not in use
1 parent 7ae858b commit 3054a5f

File tree

2 files changed

+180
-13
lines changed

2 files changed

+180
-13
lines changed

step/context.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,17 @@ func (cs *CtxState) initCurrent() error {
168168

169169
func (cs *CtxState) load() error {
170170
if cs.Enabled() && cs.GetCurrent() != nil {
171-
return cs.GetCurrent().Load()
171+
if err := cs.GetCurrent().Load(); err != nil {
172+
return fmt.Errorf("failed loading current context configuration: %w", err)
173+
}
174+
175+
return nil
172176
}
173-
cs.LoadVintage("")
177+
178+
if err := cs.LoadVintage(""); err != nil {
179+
return fmt.Errorf("failed loading context configuration: %w", err)
180+
}
181+
174182
return nil
175183
}
176184

step/context_test.go

Lines changed: 170 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package step
22

33
import (
4-
"reflect"
4+
"encoding/json"
5+
"errors"
6+
"fmt"
7+
"os"
8+
"path/filepath"
59
"testing"
610

7-
"github.com/pkg/errors"
8-
"github.com/smallstep/assert"
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
913
)
1014

1115
func TestContextValidate(t *testing.T) {
@@ -21,13 +25,13 @@ func TestContextValidate(t *testing.T) {
2125
}
2226
for _, tc := range tests {
2327
t.Run(tc.name, func(t *testing.T) {
24-
if err := tc.context.Validate(); err != nil {
25-
if assert.NotNil(t, tc.err) {
26-
assert.HasPrefix(t, err.Error(), tc.err.Error())
27-
}
28-
} else {
29-
assert.Nil(t, tc.err)
28+
err := tc.context.Validate()
29+
if tc.err != nil {
30+
assert.Contains(t, err.Error(), tc.err.Error())
31+
return
3032
}
33+
34+
assert.NoError(t, err)
3135
})
3236
}
3337
}
@@ -65,8 +69,163 @@ func TestCtxState_ListAlphabetical(t *testing.T) {
6569
cs := &CtxState{
6670
contexts: tt.fields.contexts,
6771
}
68-
if got := cs.ListAlphabetical(); !reflect.DeepEqual(got, tt.want) {
69-
t.Errorf("CtxState.ListAlphabetical() = %v, want %v", got, tt.want)
72+
73+
got := cs.ListAlphabetical()
74+
assert.Equal(t, tt.want, got)
75+
})
76+
}
77+
}
78+
79+
type config struct {
80+
CA string `json:"ca-url"`
81+
Fingerprint string `json:"fingerprint"`
82+
Root string `json:"root"`
83+
Redirect string `json:"redirect-url"`
84+
}
85+
86+
func TestCtxState_load(t *testing.T) {
87+
contextsDir := t.TempDir()
88+
os.Setenv(HomeEnv, contextsDir)
89+
90+
ctx1ConfigDirectory := filepath.Join(contextsDir, ".step", "authorities", "ctx1", "config")
91+
err := os.MkdirAll(ctx1ConfigDirectory, 0o777)
92+
require.NoError(t, err)
93+
b, err := json.Marshal(config{
94+
CA: "https://127.0.0.1:8443",
95+
Fingerprint: "ctx1-fingerprint",
96+
Root: "/path/to/root.crt",
97+
Redirect: "redirect",
98+
})
99+
require.NoError(t, err)
100+
err = os.WriteFile(filepath.Join(ctx1ConfigDirectory, "defaults.json"), b, 0o644)
101+
require.NoError(t, err)
102+
103+
ctx2ConfigDirectory := filepath.Join(contextsDir, ".step", "authorities", "ctx2", "config")
104+
err = os.MkdirAll(ctx2ConfigDirectory, 0o777)
105+
require.NoError(t, err)
106+
err = os.WriteFile(filepath.Join(ctx2ConfigDirectory, "defaults.json"), []byte{0x42}, 0o644)
107+
require.NoError(t, err)
108+
109+
vintageConfigDirectory := filepath.Join(contextsDir, ".step", "config")
110+
os.MkdirAll(vintageConfigDirectory, 0o777)
111+
b, err = json.Marshal(config{
112+
CA: "https://127.0.0.1:8443",
113+
Fingerprint: "vintage-fingerprint",
114+
Root: "/path/to/root.crt",
115+
Redirect: "redirect",
116+
})
117+
require.NoError(t, err)
118+
err = os.WriteFile(filepath.Join(vintageConfigDirectory, "defaults.json"), b, 0o644)
119+
require.NoError(t, err)
120+
121+
ctx1 := &Context{
122+
Authority: "ctx1",
123+
Name: "ctx1",
124+
}
125+
ctx2 := &Context{
126+
Authority: "ctx2",
127+
Name: "ctx2",
128+
}
129+
130+
contexts := ContextMap{
131+
"ctx1": ctx1,
132+
"ctx2": ctx2,
133+
}
134+
135+
failVintageStepPath := filepath.Join(t.TempDir(), ".step")
136+
fmt.Println("fail vintage step path", failVintageStepPath)
137+
failVintageConfigDirectory := filepath.Join(failVintageStepPath, "config")
138+
err = os.MkdirAll(failVintageConfigDirectory, 0o777)
139+
require.NoError(t, err)
140+
err = os.WriteFile(filepath.Join(failVintageConfigDirectory, "defaults.json"), []byte{0x42}, 0o644)
141+
require.NoError(t, err)
142+
143+
type fields struct {
144+
current *Context
145+
contexts ContextMap
146+
}
147+
tests := []struct {
148+
name string
149+
fields fields
150+
stepPath string
151+
want map[string]any
152+
errPrefix string
153+
}{
154+
{
155+
name: "ok/ctx1",
156+
fields: fields{
157+
current: ctx1,
158+
contexts: contexts,
159+
},
160+
want: map[string]any{
161+
"ca-url": "https://127.0.0.1:8443",
162+
"fingerprint": "ctx1-fingerprint",
163+
"redirect-url": "redirect",
164+
"root": "/path/to/root.crt",
165+
},
166+
},
167+
{
168+
name: "ok/vintage",
169+
fields: fields{
170+
contexts: contexts,
171+
},
172+
want: map[string]any{
173+
"ca-url": "https://127.0.0.1:8443",
174+
"fingerprint": "vintage-fingerprint",
175+
"redirect-url": "redirect",
176+
"root": "/path/to/root.crt",
177+
},
178+
},
179+
{
180+
name: "fail/ctx2",
181+
fields: fields{
182+
current: ctx2,
183+
contexts: contexts,
184+
},
185+
errPrefix: "failed loading current context configuration:",
186+
},
187+
{
188+
name: "fail/vintage",
189+
fields: fields{
190+
contexts: contexts,
191+
},
192+
stepPath: failVintageStepPath,
193+
errPrefix: "failed loading context configuration:",
194+
},
195+
}
196+
for _, tt := range tests {
197+
t.Run(tt.name, func(t *testing.T) {
198+
if tt.stepPath != "" {
199+
// alter the state in a non-standard way, because it's
200+
// cached once.
201+
currentStepPath := cache.stepBasePath
202+
cache.stepBasePath = tt.stepPath
203+
defer func() {
204+
cache.stepBasePath = currentStepPath
205+
}()
206+
}
207+
208+
cs := &CtxState{
209+
current: tt.fields.current,
210+
contexts: tt.fields.contexts,
211+
}
212+
213+
err := cs.load()
214+
if tt.errPrefix != "" {
215+
if assert.Error(t, err) {
216+
assert.Contains(t, err.Error(), tt.errPrefix)
217+
}
218+
return
219+
}
220+
221+
assert.NoError(t, err)
222+
223+
if current := cs.GetCurrent(); current != nil {
224+
assert.Empty(t, cs.config)
225+
assert.Equal(t, tt.want, current.config)
226+
} else {
227+
assert.Nil(t, cs.current)
228+
assert.Equal(t, tt.want, cs.config)
70229
}
71230
})
72231
}

0 commit comments

Comments
 (0)