Skip to content

Commit b48d47f

Browse files
committed
pkg/settings: fix SettingMap.GetOrDefault
1 parent 9585274 commit b48d47f

File tree

3 files changed

+68
-9
lines changed

3 files changed

+68
-9
lines changed

pkg/settings/cresettings/settings.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ func reinit() {
4040
if err != nil {
4141
log.Fatalf("failed to initialize settings: %v", err)
4242
}
43+
} else {
44+
DefaultGetter = nil
4345
}
4446
}
4547

pkg/settings/cresettings/settings_test.go

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,10 @@ func TestDefaultGetter_SettingMap(t *testing.T) {
199199
}
200200
}`)
201201
reinit() // set default vars
202+
203+
// ensure merged values; defaults must remain
204+
require.Equal(t, "true", Default.PerWorkflow.ChainAllowed.Values["3379446385462418246"])
205+
// confirm
202206
got, err = limit.GetOrDefault(ctx, DefaultGetter)
203207
require.NoError(t, err)
204208
require.False(t, got)
@@ -233,13 +237,62 @@ func TestDefaultGetter_SettingMap(t *testing.T) {
233237
require.True(t, got)
234238
}
235239

236-
func TestChainAllows(t *testing.T) {
237-
gl, err := limits.MakeGateLimiter(limits.Factory{Logger: logger.Test(t)}, Default.PerWorkflow.ChainAllowed)
240+
func TestDefaultEnvVars(t *testing.T) {
241+
// confirm defaults
242+
require.Equal(t, "", Default.PerWorkflow.ChainAllowed.Values["1234"])
243+
require.Equal(t, "true", Default.PerWorkflow.ChainAllowed.Values["3379446385462418246"])
244+
245+
t.Cleanup(reinit) // restore after
246+
247+
// update defaults
248+
t.Setenv(envNameSettingsDefault, `{
249+
"PerWorkflow": {
250+
"ChainAllowed": {
251+
"Values": {
252+
"1234": "true"
253+
}
254+
}
255+
}
256+
}`)
257+
reinit() // set default vars
258+
259+
// confirm through Default
260+
require.Equal(t, "true", Default.PerWorkflow.ChainAllowed.Values["1234"])
261+
// without affecting others (they must merge)
262+
require.Equal(t, "true", Default.PerWorkflow.ChainAllowed.Values["3379446385462418246"])
263+
264+
// confirm through DefaultGetter
265+
gl, err := limits.MakeGateLimiter(limits.Factory{Logger: logger.Test(t), Settings: DefaultGetter}, Default.PerWorkflow.ChainAllowed)
238266
require.NoError(t, err)
239267

240-
ctx := contexts.WithCRE(t.Context(), contexts.CRE{Owner: "owner-id", Workflow: "foo"})
268+
ctx := contexts.WithCRE(t.Context(), contexts.CRE{Org: "foo", Owner: "owner-id", Workflow: "foo"})
269+
// defaults and global override allowed
270+
assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 3379446385462418246)))
271+
assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 12922642891491394802)))
272+
assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 1234)))
273+
274+
// update overrides
275+
t.Setenv(envNameSettingsDefault, "{}")
276+
t.Setenv(envNameSettings, `{
277+
"global": {
278+
"PerWorkflow": {
279+
"ChainAllowed": {
280+
"Values": {
281+
"1234": "true"
282+
}
283+
}
284+
}
285+
}
286+
}`)
287+
288+
reinit() // set default vars
289+
290+
// confirm through DefaultGetter
291+
gl, err = limits.MakeGateLimiter(limits.Factory{Logger: logger.Test(t), Settings: DefaultGetter}, Default.PerWorkflow.ChainAllowed)
292+
require.NoError(t, err)
241293

294+
// defaults and global override allowed
242295
assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 3379446385462418246)))
243296
assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 12922642891491394802)))
244-
assert.ErrorIs(t, gl.AllowErr(contexts.WithChainSelector(ctx, 1234)), limits.ErrorNotAllowed{})
297+
assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 1234)))
245298
}

pkg/settings/map.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,20 @@ func (s *SettingMap[T]) GetOrDefault(ctx context.Context, g Getter) (value T, er
4545
if err != nil {
4646
return s.Default.DefaultValue, fmt.Errorf("failed to get value from context: %w", err)
4747
}
48-
if g == nil {
48+
//TODO confirm via test
49+
valueOrDefault := func() (T, error) {
4950
if str, ok := s.Values[strconv.FormatUint(k, 10)]; ok {
5051
value, err = s.Default.Parse(str)
5152
if err != nil {
5253
return s.Default.DefaultValue, err
5354
}
54-
return
55+
return value, nil
5556
}
5657
return s.Default.DefaultValue, nil
5758
}
59+
if g == nil {
60+
return valueOrDefault()
61+
}
5862

5963
valueKey := s.Default.Key + ".Values." + strconv.FormatUint(k, 10)
6064
defaultKey := s.Default.Key + ".Default"
@@ -66,20 +70,20 @@ func (s *SettingMap[T]) GetOrDefault(ctx context.Context, g Getter) (value T, er
6670
} else if str != "" {
6771
value, err = s.Default.Parse(str)
6872
if err != nil {
69-
return s.Default.DefaultValue, err
73+
return valueOrDefault()
7074
}
7175
return
7276
}
7377

7478
// Default override
7579
str, err = g.GetScoped(ctx, s.Default.Scope, defaultKey)
7680
if err != nil || str == "" {
77-
return s.Default.DefaultValue, err
81+
return valueOrDefault()
7882
}
7983

8084
value, err = s.Default.Parse(str)
8185
if err != nil {
82-
return s.Default.DefaultValue, err
86+
return valueOrDefault()
8387
}
8488
return
8589
}

0 commit comments

Comments
 (0)