Skip to content

Commit 8647e75

Browse files
committed
shared authenticator lookups
1 parent 009c731 commit 8647e75

File tree

3 files changed

+172
-11
lines changed

3 files changed

+172
-11
lines changed

staging/src/k8s.io/apiserver/pkg/authentication/token/cache/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ go_test(
1818
"//staging/src/k8s.io/apimachinery/pkg/util/uuid:go_default_library",
1919
"//staging/src/k8s.io/apiserver/pkg/authentication/authenticator:go_default_library",
2020
"//staging/src/k8s.io/apiserver/pkg/authentication/user:go_default_library",
21+
"//vendor/github.com/google/go-cmp/cmp:go_default_library",
2122
"//vendor/github.com/google/uuid:go_default_library",
2223
],
2324
)
@@ -32,9 +33,12 @@ go_library(
3233
importmap = "k8s.io/kubernetes/vendor/k8s.io/apiserver/pkg/authentication/token/cache",
3334
importpath = "k8s.io/apiserver/pkg/authentication/token/cache",
3435
deps = [
36+
"//staging/src/k8s.io/apimachinery/pkg/api/errors:go_default_library",
3537
"//staging/src/k8s.io/apimachinery/pkg/util/cache:go_default_library",
3638
"//staging/src/k8s.io/apimachinery/pkg/util/clock:go_default_library",
3739
"//staging/src/k8s.io/apiserver/pkg/authentication/authenticator:go_default_library",
40+
"//vendor/golang.org/x/sync/singleflight:go_default_library",
41+
"//vendor/k8s.io/klog:go_default_library",
3842
],
3943
)
4044

staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,26 @@ import (
2222
"crypto/rand"
2323
"crypto/sha256"
2424
"encoding/binary"
25+
"errors"
2526
"hash"
2627
"io"
28+
"runtime"
2729
"sync"
2830
"time"
2931
"unsafe"
3032

33+
"golang.org/x/sync/singleflight"
34+
35+
apierrors "k8s.io/apimachinery/pkg/api/errors"
3136
utilclock "k8s.io/apimachinery/pkg/util/clock"
3237
"k8s.io/apiserver/pkg/authentication/authenticator"
38+
"k8s.io/klog"
3339
)
3440

41+
var errAuthnCrash = apierrors.NewInternalError(errors.New("authentication failed unexpectedly"))
42+
43+
const sharedLookupTimeout = 30 * time.Second
44+
3545
// cacheRecord holds the three return values of the authenticator.Token AuthenticateToken method
3646
type cacheRecord struct {
3747
resp *authenticator.Response
@@ -47,6 +57,7 @@ type cachedTokenAuthenticator struct {
4757
failureTTL time.Duration
4858

4959
cache cache
60+
group singleflight.Group
5061

5162
// hashPool is a per authenticator pool of hash.Hash (to avoid allocations from building the Hash)
5263
// HMAC with SHA-256 and a random key is used to prevent precomputation and length extension attacks
@@ -98,26 +109,71 @@ func newWithClock(authenticator authenticator.Token, cacheErrs bool, successTTL,
98109

99110
// AuthenticateToken implements authenticator.Token
100111
func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token string) (*authenticator.Response, bool, error) {
101-
auds, _ := authenticator.AudiencesFrom(ctx)
112+
auds, audsOk := authenticator.AudiencesFrom(ctx)
102113

103114
key := keyFunc(a.hashPool, auds, token)
104115
if record, ok := a.cache.get(key); ok {
105116
return record.resp, record.ok, record.err
106117
}
107118

108-
resp, ok, err := a.authenticator.AuthenticateToken(ctx, token)
109-
if !a.cacheErrs && err != nil {
110-
return resp, ok, err
119+
type lookup struct {
120+
resp *authenticator.Response
121+
ok bool
111122
}
112123

113-
switch {
114-
case ok && a.successTTL > 0:
115-
a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.successTTL)
116-
case !ok && a.failureTTL > 0:
117-
a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.failureTTL)
124+
c := a.group.DoChan(key, func() (val interface{}, err error) {
125+
// We're leaving the request handling stack so we need to handle crashes
126+
// ourselves. Log a stack trace and return a 500 if something panics.
127+
defer func() {
128+
if r := recover(); r != nil {
129+
err = errAuthnCrash
130+
// Same as stdlib http server code. Manually allocate stack
131+
// trace buffer size to prevent excessively large logs
132+
const size = 64 << 10
133+
buf := make([]byte, size)
134+
buf = buf[:runtime.Stack(buf, false)]
135+
klog.Errorf("%v\n%s", r, buf)
136+
}
137+
}()
138+
139+
// Check again for a cached record. We may have raced with a fetch.
140+
if record, ok := a.cache.get(key); ok {
141+
return lookup{record.resp, record.ok}, record.err
142+
}
143+
144+
// Detach the context because the lookup may be shared by multiple callers,
145+
// however propagate the audience.
146+
ctx, cancel := context.WithTimeout(context.Background(), sharedLookupTimeout)
147+
defer cancel()
148+
149+
if audsOk {
150+
ctx = authenticator.WithAudiences(ctx, auds)
151+
}
152+
153+
resp, ok, err := a.authenticator.AuthenticateToken(ctx, token)
154+
if !a.cacheErrs && err != nil {
155+
return nil, err
156+
}
157+
158+
switch {
159+
case ok && a.successTTL > 0:
160+
a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.successTTL)
161+
case !ok && a.failureTTL > 0:
162+
a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.failureTTL)
163+
}
164+
return lookup{resp, ok}, err
165+
})
166+
167+
select {
168+
case result := <-c:
169+
if result.Err != nil {
170+
return nil, false, result.Err
171+
}
172+
lookup := result.Val.(lookup)
173+
return lookup.resp, lookup.ok, nil
174+
case <-ctx.Done():
175+
return nil, false, ctx.Err()
118176
}
119-
120-
return resp, ok, err
121177
}
122178

123179
// keyFunc generates a string key by hashing the inputs.

staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"testing"
3131
"time"
3232

33+
"github.com/google/go-cmp/cmp"
3334
utilclock "k8s.io/apimachinery/pkg/util/clock"
3435
"k8s.io/apimachinery/pkg/util/uuid"
3536
"k8s.io/apiserver/pkg/authentication/authenticator"
@@ -173,6 +174,106 @@ func BenchmarkKeyFunc(b *testing.B) {
173174
})
174175
}
175176

177+
func TestSharedLookup(t *testing.T) {
178+
var chewie = &authenticator.Response{User: &user.DefaultInfo{Name: "chewbacca"}}
179+
180+
t.Run("actually shared", func(t *testing.T) {
181+
var lookups uint32
182+
c := make(chan struct{})
183+
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
184+
<-c
185+
atomic.AddUint32(&lookups, 1)
186+
return chewie, true, nil
187+
}), true, time.Minute, 0)
188+
189+
var wg sync.WaitGroup
190+
for i := 0; i < 10; i++ {
191+
wg.Add(1)
192+
go func() {
193+
defer wg.Done()
194+
a.AuthenticateToken(context.Background(), "")
195+
}()
196+
}
197+
198+
// no good way to make sure that all the callers are queued so we sleep.
199+
time.Sleep(1 * time.Second)
200+
close(c)
201+
wg.Wait()
202+
203+
if lookups > 3 {
204+
t.Fatalf("unexpected number of lookups: got=%d, wanted less than 3", lookups)
205+
}
206+
})
207+
208+
t.Run("first caller bails, second caller gets result", func(t *testing.T) {
209+
c := make(chan struct{})
210+
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
211+
<-c
212+
return chewie, true, nil
213+
}), true, time.Minute, 0)
214+
215+
var wg sync.WaitGroup
216+
wg.Add(2)
217+
218+
ctx1, cancel1 := context.WithCancel(context.Background())
219+
go func() {
220+
defer wg.Done()
221+
a.AuthenticateToken(ctx1, "")
222+
}()
223+
224+
ctx2 := context.Background()
225+
226+
var (
227+
resp *authenticator.Response
228+
ok bool
229+
err error
230+
)
231+
go func() {
232+
defer wg.Done()
233+
resp, ok, err = a.AuthenticateToken(ctx2, "")
234+
}()
235+
236+
time.Sleep(1 * time.Second)
237+
cancel1()
238+
close(c)
239+
wg.Wait()
240+
241+
if want := chewie; !cmp.Equal(resp, want) {
242+
t.Errorf("Unexpected diff: %v", cmp.Diff(resp, want))
243+
}
244+
if !ok {
245+
t.Errorf("Expected ok response")
246+
}
247+
if err != nil {
248+
t.Errorf("Unexpected error: %v", err)
249+
}
250+
})
251+
252+
t.Run("lookup panics", func(t *testing.T) {
253+
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
254+
panic("uh oh")
255+
}), true, time.Minute, 0)
256+
257+
_, _, err := a.AuthenticateToken(context.Background(), "")
258+
if err != errAuthnCrash {
259+
t.Errorf("expected error: %v", err)
260+
}
261+
})
262+
263+
t.Run("audiences are forwarded", func(t *testing.T) {
264+
ctx := authenticator.WithAudiences(context.Background(), []string{"a"})
265+
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
266+
auds, _ := authenticator.AudiencesFrom(ctx)
267+
if got, want := auds, []string{"a"}; cmp.Equal(got, want) {
268+
t.Fatalf("unexpeced audiences: %v", cmp.Diff(got, want))
269+
}
270+
return nil, false, nil
271+
}), true, time.Minute, 0)
272+
273+
a.AuthenticateToken(ctx, "")
274+
})
275+
}
276+
176277
func BenchmarkCachedTokenAuthenticator(b *testing.B) {
177278
tokenCount := []int{100, 500, 2500, 12500, 62500}
178279
threadCount := []int{1, 16, 256}

0 commit comments

Comments
 (0)