Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions cmd/metadata-server/tokencache.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ type KnownToken struct {
expires time.Time
}

// inflightLock tracks a per-token ticket lock together with the time it was
// last handed out, so that GC can discard entries that are no longer in use.
type inflightLock struct {
	// handle serializes concurrent fetch attempts for a single token.
	handle *shared.TicketLock
	// lastUsed is refreshed each time the lock is handed out; entries whose
	// lastUsed is older than minTokenLifetime are removed during GC.
	lastUsed time.Time
}

// TokenCache is a cache for previously requested tokens
type TokenCache struct {
lock *sync.Mutex
Expand All @@ -25,6 +30,8 @@ type TokenCache struct {
hitMetric prometheus.Counter
missMetric prometheus.Counter
setMetric prometheus.Counter

inflight map[TokenUID]*inflightLock
}

// NewTokenCache creates a new token cache with a garbage collection interval.
Expand Down Expand Up @@ -68,6 +75,7 @@ func NewTokenCache(gcInterval, minLifetime time.Duration) *TokenCache {
hitMetric: hitMetric,
missMetric: missMetric,
setMetric: setMetric,
inflight: make(map[TokenUID]*inflightLock),
}

if gcInterval > 0 {
Expand All @@ -86,6 +94,30 @@ func NewTokenCache(gcInterval, minLifetime time.Duration) *TokenCache {
return cache
}

// GetTokenLock returns a ticket lock for the given token identifier.
// The lock can be used to prevent multiple parallel requests for the same token.
func (t *TokenCache) GetTokenLock(tokenIdentifier TokenLookup) *shared.TicketLock {
	t.lock.Lock()
	defer t.lock.Unlock()

	id := tokenIdentifier.ToTokenUID()

	// Hand out the existing lock as long as it was used recently enough;
	// anything older than minTokenLifetime is considered stale and replaced.
	if existing, found := t.inflight[id]; found && time.Since(existing.lastUsed) <= t.minTokenLifetime {
		existing.lastUsed = time.Now()
		return existing.handle
	}

	fresh := &inflightLock{
		handle:   shared.NewTicketLock(5 * time.Millisecond),
		lastUsed: time.Now(),
	}
	t.inflight[id] = fresh
	return fresh.handle
}

// StopGC stops the garbage collection timer.
func (t *TokenCache) StopGC() {
if t.gcTimer != nil {
Expand All @@ -110,6 +142,20 @@ func (t *TokenCache) GC() {
for _, id := range staleTokens {
delete(t.data, id)
}

// Cleanup inflight locks.
// As this is a map, we can delete keys while iterating.
for id, lock := range t.inflight {
if time.Since(lock.lastUsed) > t.minTokenLifetime {
// Locks that are held for longer than minTokenLifetime are a sign
// of a bug, like not releasing the lock. Fetching a token should
// always be _much_ shorter than minTokenLifetime.
if lock.handle.IsLocked() {
log.Warn().Msg("Timed-out inflight lock is still held by a thread, this should not happen")
}
delete(t.inflight, id)
}
}
}

// Get returns the known token for the given service account or nil
Expand Down
88 changes: 57 additions & 31 deletions cmd/metadata-server/tokenhandlers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"errors"
"identity-metadata-server/internal/shared"
"net/http"
"strings"
Expand Down Expand Up @@ -48,44 +49,69 @@ func HandleGetAccessToken(c *gin.Context) {
}

tokenID := NewLookupWithScopeAndAudience(TokenTypeAccess, srcIdentity, scopes, additionalAudiences)
cachedToken := knownTokens.Get(tokenID)

if cachedToken == nil {
// The documentation is a bit patchy here, so we don't know if we can
// actually override the token lifetime through a request.
// TODO: Reverse-engineering is required here. We need to find a
// call that sets the token lifetime and see which parameter is
// being used.
tokenLifeTime := AccessTokenLifetime

trt, err := tokenProvider.GetTokenRequestToken(c.Request.Context(), srcIdentity, tokenLifeTime, scopes, additionalAudiences)
if trt == nil {
shared.HttpError(c, http.StatusInternalServerError, err)
return
returnToken := func(cachedToken *KnownToken) {
// The format is explained in the documentation.
// https://cloud.google.com/compute/docs/access/authenticate-workloads#applications
// The response format is identical to the one used by the STS endpoint,
// which might be by design, but is not documented.
response := shared.TokenExchangeResponse{
AccessToken: cachedToken.token,
ExpiresIn: int(time.Until(cachedToken.expires).Seconds()),
TokenType: "Bearer",
}

// Get the token for the given parameters.
accessToken, err := tokenProvider.GetAccessToken(c.Request.Context(), *trt, tokenLifeTime, scopes, gsa)
if accessToken == nil {
shared.HttpError(c, http.StatusInternalServerError, err)
return
}
c.Header("Metadata-Flavor", "Google")
c.JSON(http.StatusOK, response)
}

cachedToken = knownTokens.StoreUntil(tokenID, accessToken.AccessToken, accessToken.ExpireTime)
// Try to get the token from the cache
if cachedToken := knownTokens.Get(tokenID); cachedToken != nil {
returnToken(cachedToken)
return
}

// The format is explained in the documentation.
// https://cloud.google.com/compute/docs/access/authenticate-workloads#applications
// The response format is identical to the one used by the STS endpoint,
// which might be by design, but is not documented.
response := shared.TokenExchangeResponse{
AccessToken: cachedToken.token,
ExpiresIn: int(time.Until(cachedToken.expires).Seconds()),
TokenType: "Bearer",
// Cache miss. Acquire a lock to block inflight requests for the same tokenID.
inflightLock := knownTokens.GetTokenLock(tokenID)
if inflightLock.LockWithContext(c.Request.Context()) == 0 {
c.Header("Retry-After", "5")
shared.HttpError(c, http.StatusTooManyRequests, errors.New("timed out while waiting for another token fetch to finish"))
return
}
defer inflightLock.Unlock()

c.Header("Metadata-Flavor", "Google")
c.JSON(http.StatusOK, response)
// Try to get the token from the cache again. This time we might have a token
// in the cache, as another request might have fetched the token while we were
// waiting for the lock.
if cachedToken := knownTokens.Get(tokenID); cachedToken != nil {
returnToken(cachedToken)
return
}

// True cache miss. Fetch the token from the token provider.

// The documentation is a bit patchy here, so we don't know if we can
// actually override the token lifetime through a request.
// TODO: Reverse-engineering is required here. We need to find a
// call that sets the token lifetime and see which parameter is
// being used.
tokenLifeTime := AccessTokenLifetime

trt, err := tokenProvider.GetTokenRequestToken(c.Request.Context(), srcIdentity, tokenLifeTime, scopes, additionalAudiences)
if trt == nil {
shared.HttpError(c, http.StatusInternalServerError, err)
return
}

// Get the token for the given parameters.
accessToken, err := tokenProvider.GetAccessToken(c.Request.Context(), *trt, tokenLifeTime, scopes, gsa)
if accessToken == nil {
shared.HttpError(c, http.StatusInternalServerError, err)
return
}

// Store the token in the cache and return it
cachedToken := knownTokens.StoreUntil(tokenID, accessToken.AccessToken, accessToken.ExpireTime)
returnToken(cachedToken)
}

// HandleGetIdentityToken handles an identity token request.
Expand Down
41 changes: 41 additions & 0 deletions internal/shared/intheap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package shared

// HeapUint64 implements a min-heap of uint64 values on top of a slice.
// Use this type with the container/heap package.
type HeapUint64 []uint64

// Len returns the number of elements in the heap.
func (h HeapUint64) Len() int {
return len(h)
}

// Less returns true if the element at index i is less than the element at index j.
func (h HeapUint64) Less(i, j int) bool { return h[i] < h[j] }

// Swap swaps the elements at index i and j.
func (h HeapUint64) Swap(i, j int) { h[i], h[j] = h[j], h[i] }

// Push adds a new element to the heap.
// Use heap.Push(h, x) instead of this function.
func (h *HeapUint64) Push(x any) {
*h = append(*h, x.(uint64))
}

// Peek returns the smallest element from the heap without removing it.
// It returns false if the heap is empty.
func (h *HeapUint64) Peek() (uint64, bool) {
if len(*h) == 0 {
return 0, false
}
return (*h)[0], true
}

// Pop removes and returns the smallest element from the heap.
// Use heap.Pop(h) instead of this function.
func (h *HeapUint64) Pop() any {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
111 changes: 111 additions & 0 deletions internal/shared/ticketlock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package shared

import (
"container/heap"
"context"
"sync"
"sync/atomic"
"time"
)

// TicketLock is a FIFO (ticket-based) mutual exclusion lock: waiters acquire
// the lock in the order in which they requested it. Waiting is done by
// polling with a configurable granularity so that waiters can also observe
// context cancellation. The zero value is not usable; create instances with
// NewTicketLock.
type TicketLock struct {
	// nextTicket is the ticket number handed to the next waiter (atomic).
	nextTicket uint64
	// activeTicket is the ticket currently allowed to hold the lock (atomic).
	// The lock is free when activeTicket == nextTicket.
	activeTicket uint64
	// pauseDuration is the polling interval used while waiting for a ticket.
	pauseDuration time.Duration
	// canceledTickets records tickets abandoned by canceled waiters, so that
	// Unlock can skip over them in order. Guarded by ticketGuard.
	canceledTickets *HeapUint64
	// ticketGuard serializes Unlock and the cancellation bookkeeping.
	ticketGuard *sync.Mutex
}

// NewTicketLock creates a new ticket lock with the given granularity.
// The granularity is the time to wait between each lock acquisition check.
// The granularity should be small enough to not block the main thread for too
// long, but large enough to not waste too much time.
// A granularity of 5-10 milliseconds is a good starting point.
func NewTicketLock(granularity time.Duration) *TicketLock {
	return &TicketLock{
		nextTicket:      1,
		activeTicket:    1,
		pauseDuration:   granularity,
		canceledTickets: &HeapUint64{},
		ticketGuard:     &sync.Mutex{},
	}
}

// IsLocked returns true if the lock is currently held by a thread.
// Please note that this status can change right after the call to IsLocked().
// I.e. this is not a reliable way to check if the lock is currently held by a
// thread. It is only meant for debugging purposes.
func (l *TicketLock) IsLocked() bool {
	return atomic.LoadUint64(&l.activeTicket) != atomic.LoadUint64(&l.nextTicket)
}

// Lock tries to acquire a lock in a FIFO way.
func (l *TicketLock) Lock() uint64 {
	return l.LockWithContext(context.Background())
}

// LockWithContext tries to acquire a lock in a FIFO way.
// It returns 0 when the lock failed to be acquired due to a context
// cancellation or a timeout.
// If the lock was acquired, it returns the ticket number of the lock.
func (l *TicketLock) LockWithContext(ctx context.Context) uint64 {
	ticket := atomic.AddUint64(&l.nextTicket, 1) - 1
	var pause *time.Ticker

	for {
		if atomic.LoadUint64(&l.activeTicket) == ticket {
			return ticket
		}

		// Do a lazy initialization of the ticker to avoid creating a ticker if
		// it is not needed.
		if pause == nil {
			pause = time.NewTicker(l.pauseDuration)
			defer pause.Stop()
		}

		select {
		// We use a ticker to yield the CPU during waiting and to be able to
		// check on the context while pausing.
		case <-pause.C:
			continue

		case <-ctx.Done():
			l.ticketGuard.Lock()
			defer l.ticketGuard.Unlock()

			// Re-check under the guard: the previous holder may have unlocked
			// and handed the lock to us between our last activeTicket check
			// and acquiring the guard. In that case we already own the lock
			// and must pass it on instead of recording the ticket as
			// canceled — otherwise activeTicket would be stuck at our ticket
			// forever (nobody would ever call Unlock for it) and all later
			// waiters would spin indefinitely.
			if atomic.LoadUint64(&l.activeTicket) == ticket {
				l.advance()
				return 0
			}

			// We need to keep track of canceled tickets as tickets are linearly
			// ordered. If we don't do this, we cannot properly unlock the lock
			// in the correct order.
			heap.Push(l.canceledTickets, ticket)
			return 0
		}
	}
}

// advance moves activeTicket to the next ticket, skipping over any tickets
// that were abandoned by canceled waiters.
// It must be called with ticketGuard held.
func (l *TicketLock) advance() {
	for {
		ticket := atomic.AddUint64(&l.activeTicket, 1)
		nextCanceledTicket, hasCanceledTickets := l.canceledTickets.Peek()

		switch {
		// No canceled tickets, we can return
		case !hasCanceledTickets:
			return

		// The next canceled ticket is the same as the current ticket.
		// We need to try again with the next ticket (which might also be
		// canceled).
		case nextCanceledTicket == ticket:
			heap.Pop(l.canceledTickets)

		// There are canceled tickets, but the current ticket is smaller than
		// the first canceled ticket.
		default:
			return
		}
	}
}

// Unlock releases the lock.
func (l *TicketLock) Unlock() {
	l.ticketGuard.Lock()
	defer l.ticketGuard.Unlock()
	l.advance()
}
49 changes: 49 additions & 0 deletions internal/shared/ticketlock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package shared

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

// TestTicketLock exercises the basic TicketLock contract: tickets are handed
// out in FIFO order, context cancellation aborts a waiter with ticket 0, and
// Unlock skips over tickets that were abandoned by canceled waiters.
func TestTicketLock(t *testing.T) {
	assert := assert.New(t)

	lock := NewTicketLock(time.Millisecond)

	ticket1 := lock.Lock()
	assert.NotZero(ticket1, "Lock should return a non-zero ticket")
	assert.Equal(uint64(1), ticket1, "Lock should return the first ticket")
	assert.True(lock.IsLocked(), "IsLocked should report the lock as held")

	// Test timeout: the lock is still held, so this waiter must give up when
	// the context deadline expires.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	ticket2 := lock.LockWithContext(ctx)
	assert.Zero(ticket2, "LockWithContext should return a zero ticket as it timed out before the lock was acquired")

	// Test release
	lock.Unlock()
	assert.False(lock.IsLocked(), "IsLocked should report the lock as free after it was released")

	// Test if release properly increments the active ticket
	ticket3 := lock.Lock()
	assert.NotZero(ticket3, "Lock should return a non-zero ticket after the previous lock was released")
	assert.NotEqual(ticket1, ticket3, "Lock should return a different ticket after the previous lock was released")
	assert.Equal(uint64(3), ticket3, "Lock should return the third ticket, as the second lock was aborted")

	// Test if release properly increments the active ticket with consecutive
	// discards. Note that ctx has already expired at this point, so both
	// calls abort immediately.
	ticket4 := lock.LockWithContext(ctx)
	assert.Zero(ticket4, "LockWithContext should return a zero ticket as its context had already expired")
	ticket5 := lock.LockWithContext(ctx)
	assert.Zero(ticket5, "LockWithContext should return a zero ticket as its context had already expired")

	lock.Unlock()
	ticket6 := lock.Lock()
	assert.NotZero(ticket6, "Lock should return a non-zero ticket after the previous lock was released")
	assert.NotEqual(ticket3, ticket6, "Lock should return a different ticket after the previous lock was released")
	assert.Equal(uint64(6), ticket6, "Lock should return the sixth ticket, as the fourth and fifth locks were aborted")
}
Loading