Skip to content
2 changes: 2 additions & 0 deletions src/ray/rpc/authentication/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ ray_cc_library(
":authentication_token",
":k8s_constants",
"//src/ray/util:logging",
"@com_google_absl//absl/strings",
"@nlohmann_json",
],
)

Expand Down
89 changes: 89 additions & 0 deletions src/ray/rpc/authentication/authentication_token_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/escaping.h"
#include "absl/strings/str_split.h"
#include "nlohmann/json.hpp"
#include "ray/rpc/authentication/authentication_mode.h"
#include "ray/rpc/authentication/k8s_constants.h"
#include "ray/util/logging.h"
Expand All @@ -44,6 +48,50 @@ constexpr const char *kNoTokenErrorMessage =
"or store the token in any file and set RAY_AUTH_TOKEN_PATH to point to it, "
"or set the RAY_AUTH_TOKEN environment variable.";

// Fallback lifetime, in seconds, given to a cached token (counted from the moment it
// is loaded) when no expiration time can be extracted from the token itself.
constexpr int kRaySATokenDefaultTTLSeconds = 600;
// Safety margin, in seconds, subtracted from a token's "exp" time so that the cached
// copy is discarded and reloaded before the token actually expires.
constexpr int kRaySATokenExpirationBufferSeconds = 300;

std::optional<std::chrono::system_clock::time_point>
AuthenticationTokenLoader::GetTokenExpiration(const std::string &token) {
std::vector<std::string> parts = absl::StrSplit(token, '.');
if (parts.size() != 3) {
RAY_LOG(WARNING) << "Invalid JWT token format.";
return std::nullopt;
}

std::string payload_b64 = parts[1];
// Convert Base64URL to Base64
for (char &c : payload_b64) {
if (c == '-') {
c = '+';
} else if (c == '_') {
c = '/';
}
}
// Add padding if necessary
while (payload_b64.size() % 4 != 0) {
payload_b64 += '=';
}

std::string payload;
if (!absl::Base64Unescape(payload_b64, &payload)) {
RAY_LOG(WARNING) << "Unable to base64 decode JWT token.";
return std::nullopt;
}

try {
auto json = nlohmann::json::parse(payload);
if (json.contains("exp") && json["exp"].is_number()) {
int64_t exp = json["exp"].get<int64_t>();
return std::chrono::system_clock::from_time_t(exp);
}
} catch (...) {
return std::nullopt;
}

return std::nullopt;
}

AuthenticationTokenLoader &AuthenticationTokenLoader::instance() {
static AuthenticationTokenLoader instance;
return instance;
Expand All @@ -53,6 +101,16 @@ std::shared_ptr<const AuthenticationToken> AuthenticationTokenLoader::GetToken(
bool ignore_auth_mode) {
absl::MutexLock lock(&token_mutex_);

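// If k8s token auth is enabled, drop the cached token once its cached expiration
// time has passed so that a fresh one is loaded below. Kubelet expires and
// auto-rotates service account tokens (every hour by default, but users can
// configure this), so the cache expiry is set 5 minutes ahead of the token's own
// expiration, or 10 minutes after loading when no expiration can be read from it.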
if (IsK8sTokenAuthEnabled()) {
if (cached_token_ &&
std::chrono::system_clock::now() >= cached_token_expiration_time_) {
cached_token_ = nullptr;
}
}

// If already loaded, return cached value
if (cached_token_) {
return cached_token_;
Expand All @@ -76,6 +134,17 @@ std::shared_ptr<const AuthenticationToken> AuthenticationTokenLoader::GetToken(
// Cache and return the loaded token
if (has_token) {
cached_token_ = std::make_shared<const AuthenticationToken>(std::move(*result.token));
if (IsK8sTokenAuthEnabled()) {
auto exp = GetTokenExpiration(cached_token_->GetRawValue());
if (exp) {
cached_token_expiration_time_ =
*exp - std::chrono::seconds(kRaySATokenExpirationBufferSeconds);
} else {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Token thrashing when remaining lifetime under buffer

High Severity

When a K8s token has less than 5 minutes remaining before expiration, cached_token_expiration_time_ is set to a past time (token's exp minus 300 seconds). This causes every subsequent GetToken() call to immediately invalidate the cached token and reload it from the filesystem. In busy systems making frequent RPC calls, this results in excessive filesystem reads until the token is actually rotated by Kubernetes, defeating the caching mechanism and potentially causing performance degradation.

Fix in Cursor Fix in Web

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't be a problem as Kubernetes enforces a minimum expiration time of 10 minutes

cached_token_expiration_time_ =
std::chrono::system_clock::now() +
std::chrono::seconds(kRaySATokenDefaultTTLSeconds);
}
}
}
return cached_token_;
}
Expand All @@ -84,6 +153,16 @@ TokenLoadResult AuthenticationTokenLoader::TryLoadToken(bool ignore_auth_mode) {
absl::MutexLock lock(&token_mutex_);
TokenLoadResult result;

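// If k8s token auth is enabled, drop the cached token once its cached expiration
// time has passed so that a fresh one is loaded below. Kubelet expires and
// auto-rotates service account tokens (every hour by default, but users can
// configure this), so the cache expiry is set 5 minutes ahead of the token's own
// expiration, or 10 minutes after loading when no expiration can be read from it.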
if (IsK8sTokenAuthEnabled()) {
if (cached_token_ &&
std::chrono::system_clock::now() >= cached_token_expiration_time_) {
cached_token_ = nullptr;
}
}

// If already loaded, return cached value
if (cached_token_) {
result.token = *cached_token_; // Copy from shared_ptr
Expand Down Expand Up @@ -112,6 +191,16 @@ TokenLoadResult AuthenticationTokenLoader::TryLoadToken(bool ignore_auth_mode) {
}
// Cache and return success
cached_token_ = std::make_shared<const AuthenticationToken>(std::move(*result.token));
if (IsK8sTokenAuthEnabled()) {
auto exp = GetTokenExpiration(cached_token_->GetRawValue());
if (exp) {
cached_token_expiration_time_ =
*exp - std::chrono::seconds(kRaySATokenExpirationBufferSeconds);
} else {
cached_token_expiration_time_ = std::chrono::system_clock::now() +
std::chrono::seconds(kRaySATokenDefaultTTLSeconds);
}
}
result.token = *cached_token_; // Copy back for return
return result;
}
Expand Down
9 changes: 9 additions & 0 deletions src/ray/rpc/authentication/authentication_token_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <chrono>
#include <memory>
#include <optional>
#include <string>
Expand Down Expand Up @@ -62,6 +63,7 @@ class AuthenticationTokenLoader {
void ResetCache() {
  // Clear both cache fields together while holding the token mutex.
  absl::MutexLock lock(&token_mutex_);
  // A value-initialized time_point is the clock's epoch, i.e. "no expiry recorded".
  cached_token_expiration_time_ = {};
  cached_token_.reset();
}

AuthenticationTokenLoader(const AuthenticationTokenLoader &) = delete;
Expand All @@ -83,8 +85,15 @@ class AuthenticationTokenLoader {
/// Trim whitespace from the beginning and end of the string.
std::string TrimWhitespace(const std::string &str);

/// Extract expiration time from a JWT token.
/// Only the payload segment is decoded and inspected for a numeric "exp" claim; the
/// token's signature is not verified by this function.
/// \param token The JWT token string.
/// \return The expiration time, or std::nullopt if not a valid JWT or no exp claim.
std::optional<std::chrono::system_clock::time_point> GetTokenExpiration(
const std::string &token);

/// Held while the cache fields below are read or updated.
absl::Mutex token_mutex_;
/// The most recently loaded token, or nullptr when nothing is cached.
std::shared_ptr<const AuthenticationToken> cached_token_;
/// Time at or after which cached_token_ is discarded and reloaded. Only consulted
/// when K8s token auth is enabled; reset to the clock's epoch by ResetCache().
std::chrono::system_clock::time_point cached_token_expiration_time_;
};

} // namespace rpc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@

#include "ray/rpc/authentication/authentication_token_loader.h"

#include <algorithm>
#include <chrono>
#include <fstream>
#include <string>
#include <thread>

#include "absl/strings/escaping.h"
#include "gtest/gtest.h"
#include "ray/common/ray_config.h"
#include "ray/util/env.h"
Expand Down Expand Up @@ -352,6 +356,49 @@ TEST_F(AuthenticationTokenLoaderTest, TestIgnoreAuthModeGetToken) {
RayConfig::instance().initialize(R"({"AUTH_MODE": "token"})");
}

// Verifies that, with K8s token auth enabled, a cached JWT is dropped once its
// buffered expiration time has passed and the token is loaded again from the source.
TEST_F(AuthenticationTokenLoaderTest, TestJWTExpiration) {
  // Turn on K8s token auth so cached tokens are subject to expiry; start clean.
  RayConfig::instance().initialize(
      R"({"AUTH_MODE": "token", "ENABLE_K8S_TOKEN_AUTH": true})");
  AuthenticationTokenLoader::instance().ResetCache();

  // Expire one second beyond the 300-second buffer, so the cached copy becomes
  // stale roughly one second after it is loaded.
  const auto current_time =
      std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
  const int64_t expiry_secs = current_time + 301;

  // Header decodes to {"alg":"none","typ":"JWT"}.
  const std::string jwt_header = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0";
  const std::string claims_json = "{\"exp\":" + std::to_string(expiry_secs) + "}";

  // Base64-encode the claims, then convert to the unpadded URL-safe alphabet.
  std::string encoded_claims;
  absl::Base64Escape(claims_json, &encoded_claims);
  std::string url_safe_claims;
  for (const char ch : encoded_claims) {
    if (ch == '=') {
      continue;
    }
    if (ch == '+') {
      url_safe_claims += '-';
    } else if (ch == '/') {
      url_safe_claims += '_';
    } else {
      url_safe_claims += ch;
    }
  }

  const std::string jwt = jwt_header + "." + url_safe_claims + ".signature";
  set_env_var("RAY_AUTH_TOKEN", jwt.c_str());

  auto &token_loader = AuthenticationTokenLoader::instance();

  // First load: the JWT is read from the environment and cached.
  auto first_token = token_loader.GetToken();
  ASSERT_TRUE(first_token != nullptr);
  EXPECT_EQ(first_token->GetRawValue(), jwt);

  // Let the buffered expiration time elapse.
  std::this_thread::sleep_for(std::chrono::seconds(2));

  // Point the environment at a different value; seeing it proves the cache was
  // invalidated and the token reloaded rather than served from the cache.
  set_env_var("RAY_AUTH_TOKEN", "new-token");

  auto second_token = token_loader.GetToken();
  ASSERT_TRUE(second_token != nullptr);
  EXPECT_EQ(second_token->GetRawValue(), "new-token");

  // Leave token auth on but switch K8s token auth back off for the other tests.
  RayConfig::instance().initialize(
      R"({"AUTH_MODE": "token", "ENABLE_K8S_TOKEN_AUTH": false})");
}

} // namespace rpc
} // namespace ray

Expand Down