Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.nimbusds.jwt.proc.ExpiredJWTException;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.Strings;
import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.Request;
import org.springframework.cloud.client.loadbalancer.RequestDataContext;
Expand All @@ -25,7 +26,6 @@
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.web.server.ResponseStatusException;
import org.zowe.apiml.constants.ApimlConstants;
import org.zowe.apiml.gateway.caching.LoadBalancerCache;
import org.zowe.apiml.gateway.caching.LoadBalancerCache.LoadBalancerCacheRecord;
import reactor.core.publisher.Flux;
Expand All @@ -35,25 +35,20 @@
import java.text.ParseException;
import java.time.Clock;
import java.time.LocalDateTime;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.*;
import java.util.stream.Stream;

import static org.apache.commons.lang3.StringUtils.isNotBlank;
import static org.zowe.apiml.constants.ApimlConstants.X_INSTANCEID;
import static reactor.core.publisher.Flux.just;
import static reactor.core.publisher.Mono.empty;

/**
* A sticky session load balancer that ensures requests from the same user are routed to the same service instance.
*/
@Slf4j
public class DeterministicLoadBalancer extends SameInstancePreferenceServiceInstanceListSupplier {

public static final String HEADER_PREFIX = "Bearer ";
private static final String HEADER_NONE_SIGNATURE = Base64.getEncoder().encodeToString("{\"typ\":\"JWT\",\"alg\":\"none\"}".getBytes(StandardCharsets.UTF_8));

private final LoadBalancerCache cache;
Expand Down Expand Up @@ -85,22 +80,44 @@ public Flux<List<ServiceInstance>> get(Request request) {
if (serviceId == null) {
return Flux.empty();
}
AtomicReference<String> principal = new AtomicReference<>();

var requestContext = request.getContext();
var instanceId = getInstanceId(requestContext);
if (instanceId != null) {
// if instanceId is set in headers use it
try {
return delegate.get(request)
.map(serviceInstances -> checkInstanceIdHeader(instanceId, serviceInstances));
} catch (ResponseStatusException ex) {
return Flux.error(new ResponseStatusException(HttpStatus.NOT_FOUND, "Service instance not found for the provided instance ID"));
}
}

var userId = getSub(requestContext);
if (userId == null) {
// if no userId is available return all
log.debug("No authentication present on request, not filtering the service: {}", serviceId);
return delegate.get(request);
}

return delegate.get(request)
.flatMap(serviceInstances -> getSub(request.getContext())
.switchIfEmpty(Mono.just(""))
.flatMap(user -> {
if (user == null || user.isEmpty()) {
log.debug("No authentication present on request, not filtering the service: {}", serviceId);
return empty();
} else {
principal.set(user);
return cache.retrieve(user, serviceId).onErrorResume(t -> Mono.empty());
}
})
.switchIfEmpty(Mono.just(LoadBalancerCacheRecord.NONE))
.flatMapMany(cacheRecord -> filterInstances(principal.get(), serviceId, cacheRecord, serviceInstances, request.getContext()))
)
.flatMap(serviceInstances -> {
if (serviceInstances.isEmpty()) {
// no instances available - just return
return Flux.just(serviceInstances);
}

boolean stickySession = lbTypeIsAuthentication(serviceInstances.iterator().next());
if (!stickySession) {
// service does not support sticky session by userId, just return
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we log some message in debug also for this condition?

return Flux.just(serviceInstances);
}

return cache.retrieve(userId, serviceId)
.onErrorResume(t -> Mono.empty())
.flatMapMany(cacheRecord -> filterInstances(userId, serviceId, cacheRecord, serviceInstances))
.switchIfEmpty(Flux.just(serviceInstances));
})
.doOnError(e -> log.debug("Error in determining service instances", e));
}

Expand All @@ -115,30 +132,23 @@ private boolean isTooOld(LocalDateTime cachedDate) {
return now.isAfter(cachedDate);
}

private Mono<String> getSub(Object requestContext) {
private String getSub(Object requestContext) {
if (requestContext instanceof RequestDataContext ctx) {
var token = Optional.ofNullable(getTokenFromCookie(ctx))
.orElseGet(() -> getTokenFromHeader(ctx));
return Mono.just(extractSubFromToken(token));
return extractSubFromToken(token);
}
return Mono.just("");
return null;
}

private String getTokenFromCookie(RequestDataContext ctx) {
var tokens = ctx.getClientRequest().getCookies().get("apimlAuthenticationToken");
return tokens == null || tokens.isEmpty() ? null : tokens.get(0);
return ctx.getClientRequest().getCookies().getFirst("apimlAuthenticationToken");
}

private String getTokenFromHeader(RequestDataContext ctx) {
var authHeaderValues = ctx.getClientRequest().getHeaders().get(HttpHeaders.AUTHORIZATION);
var token = authHeaderValues == null || authHeaderValues.isEmpty() ? null : authHeaderValues.get(0);
if (token != null && token.startsWith(ApimlConstants.BEARER_AUTHENTICATION_PREFIX)) {
token = token.replaceFirst(ApimlConstants.BEARER_AUTHENTICATION_PREFIX, "").trim();
if (token.isEmpty()) {
return null;
}

return token;
var authHeaderValue = ctx.getClientRequest().getHeaders().getFirst(HttpHeaders.AUTHORIZATION);
if (Strings.CS.startsWith(authHeaderValue, HEADER_PREFIX)) {
return authHeaderValue.substring(HEADER_PREFIX.length());
}
return null;
}
Expand All @@ -157,27 +167,17 @@ private Flux<List<ServiceInstance>> filterInstances(
String user,
String serviceId,
LoadBalancerCacheRecord cacheRecord,
List<ServiceInstance> serviceInstances,
Object requestContext) {

Flux<List<ServiceInstance>> result;
if (shouldIgnore(serviceInstances, user)) {
var instanceId = getInstanceId(requestContext);
try {
return just(checkInstanceIdHeader(instanceId, serviceInstances));
} catch (ResponseStatusException ex) {
return Flux.error(new ResponseStatusException(HttpStatus.NOT_FOUND, "Service instance not found for the provided instance ID"));
List<ServiceInstance> serviceInstances
) {
if (isNotBlank(cacheRecord.getInstanceId())) {
if (isTooOld(cacheRecord.getCreationTime())) {
return cache.delete(user, serviceId)
.thenMany(chooseOne(user, serviceInstances));
}
return chooseOne(cacheRecord.getInstanceId(), user, serviceInstances);
}
if (isNotBlank(cacheRecord.getInstanceId()) && isTooOld(cacheRecord.getCreationTime())) {
result = cache.delete(user, serviceId)
.thenMany(chooseOne(user, serviceInstances));
} else if (isNotBlank(cacheRecord.getInstanceId())) {
result = chooseOne(cacheRecord.getInstanceId(), user, serviceInstances);
} else {
result = chooseOne(user, serviceInstances);
}
return result;

return chooseOne(user, serviceInstances);
}

/**
Expand Down Expand Up @@ -253,10 +253,6 @@ private Flux<List<ServiceInstance>> chooseOne(String user, List<ServiceInstance>
return chooseOne(null, user, serviceInstances);
}

boolean shouldIgnore(List<ServiceInstance> instances, String user) {
return StringUtils.isEmpty(user) || instances.isEmpty() || !lbTypeIsAuthentication(instances.get(0));
}

private boolean lbTypeIsAuthentication(ServiceInstance instance) {
Map<String, String> metadata = instance.getMetadata();
if (metadata != null) {
Expand Down Expand Up @@ -305,6 +301,7 @@ private String extractSubFromToken(String token) {
return claims.getSubject();
}
}
return "";
return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,8 @@

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
@TestInstance(Lifecycle.PER_CLASS)
Expand Down Expand Up @@ -162,7 +158,6 @@ void setUp() {
@Test
void whenServiceDoesNotHaveMetadata_thenUseDefaultList() {
when(instance1.getMetadata()).thenReturn(null);
when(lbCache.retrieve("USER", "service")).thenReturn(Mono.just(LoadBalancerCacheRecord.NONE));

StepVerifier.create(loadBalancer.get(request))
.assertNext(chosenInstances -> {
Expand All @@ -179,8 +174,6 @@ void whenServiceDoesNotUseSticky_thenUseDefaultList() {
metadata.put("apiml.lb.type", "somethingelse");
when(instance1.getMetadata()).thenReturn(metadata);

when(lbCache.retrieve("USER", "service")).thenReturn(Mono.just(LoadBalancerCacheRecord.NONE));

StepVerifier.create(loadBalancer.get(request))
.assertNext(chosenInstances -> {
assertNotNull(chosenInstances);
Expand Down Expand Up @@ -247,6 +240,9 @@ void whenCacheEntryExpired_thenUpdatePreference() {
assertNotNull(chosenInstances);
assertEquals(1, chosenInstances.size());
assertEquals("instance1", chosenInstances.get(0).getInstanceId());
verify(lbCache).retrieve("USER", "service");
verify(lbCache).delete("USER", "service");
verify(lbCache).store(eq("USER"), eq("service"), any());
})
.expectComplete()
.verify();
Expand All @@ -266,6 +262,8 @@ void whenNoPreferece_thenCreateOne() {
assertNotNull(chosenInstances);
assertEquals(1, chosenInstances.size());
assertEquals("instance1", chosenInstances.get(0).getInstanceId());

verify(lbCache).retrieve("USER", "service");
})
.expectComplete()
.verify();
Expand Down Expand Up @@ -368,15 +366,11 @@ class GivenInstanceIdHeaderIsPresent {
@BeforeEach
void setUp() {
var context = new RequestDataContext(requestData);
MultiValueMap<String, String> cookie = new LinkedMultiValueMap<>();
cookie.add("apimlAuthenticationToken", "invalidToken");

when(request.getContext()).thenReturn(context);
when(requestData.getCookies()).thenReturn(cookie);
}

@Test
void whenInstanceIdExists_thenChoseeIt() {
void whenInstanceIdExists_thenChooseIt() {
var headers = new HttpHeaders();
headers.add("X-InstanceId", "instance2");
when(requestData.getHeaders()).thenReturn(headers);
Expand All @@ -386,6 +380,21 @@ void whenInstanceIdExists_thenChoseeIt() {
assertNotNull(chosenInstances);
assertEquals(1, chosenInstances.size());
assertEquals("instance2", chosenInstances.get(0).getInstanceId());
verify(lbCache, never()).retrieve(any(), any());
})
.expectComplete()
.verify();
}

@Test
void whenNoToken_thenDoNotCallCache() {
when(requestData.getHeaders()).thenReturn(new HttpHeaders());
when(requestData.getCookies()).thenReturn(new LinkedMultiValueMap<>());

StepVerifier.create(loadBalancer.get(request))
.assertNext(chosenInstances -> {
assertNotNull(chosenInstances);
verify(lbCache, never()).retrieve(any(), any());
})
.expectComplete()
.verify();
Expand Down
Loading