Skip to content

Commit c244332

Browse files
koenpuntrstoyanchev
authored andcommitted
Raise exception if Principal is required but not present
See gh-790
1 parent ab39246 commit c244332

File tree

4 files changed

+123
-12
lines changed

4 files changed

+123
-12
lines changed

spring-graphql/src/main/java/org/springframework/graphql/data/method/annotation/support/AuthenticationPrincipalArgumentResolver.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ private static AuthenticationPrincipal findMethodAnnotation(MethodParameter para
9696

9797
@Override
9898
public Object resolveArgument(MethodParameter parameter, DataFetchingEnvironment environment) throws Exception {
99-
return getCurrentAuthentication()
99+
return getCurrentAuthentication(parameter.isOptional())
100100
.flatMap(auth -> Mono.justOrEmpty(resolvePrincipal(parameter, auth.getPrincipal())))
101101
.transform((argument) -> isParameterMonoAssignable(parameter) ? Mono.just(argument) : argument);
102102
}
@@ -106,9 +106,16 @@ private static boolean isParameterMonoAssignable(MethodParameter parameter) {
106106
return (Publisher.class.equals(type) || Mono.class.equals(type));
107107
}
108108

109-
private Mono<Authentication> getCurrentAuthentication() {
110-
return Mono.justOrEmpty(SecurityContextHolder.getContext().getAuthentication())
111-
.switchIfEmpty(ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication));
109+
@SuppressWarnings("unchecked")
110+
private Mono<Authentication> getCurrentAuthentication(boolean optional) {
111+
Object principal = PrincipalMethodArgumentResolver.doResolve(optional);
112+
if (principal instanceof Authentication) {
113+
return Mono.just((Authentication) principal);
114+
}
115+
else if (principal instanceof Mono) {
116+
return (Mono<Authentication>) principal;
117+
}
118+
return Mono.error(new IllegalStateException("Unexpected return value: " + principal));
112119
}
113120

114121
@Nullable

spring-graphql/src/main/java/org/springframework/graphql/data/method/annotation/support/BatchLoaderHandlerMethod.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ else if ("kotlin.coroutines.Continuation".equals(parameterType.getName())) {
145145
return null;
146146
}
147147
else if (springSecurityPresent && Principal.class.isAssignableFrom(parameter.getParameterType())) {
148-
return PrincipalMethodArgumentResolver.doResolve();
148+
return PrincipalMethodArgumentResolver.doResolve(parameter.isOptional());
149149
}
150150
else {
151151
throw new IllegalStateException(formatArgumentError(parameter, "Unexpected argument type."));

spring-graphql/src/main/java/org/springframework/graphql/data/method/annotation/support/PrincipalMethodArgumentResolver.java

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,19 @@
1616
package org.springframework.graphql.data.method.annotation.support;
1717

1818
import java.security.Principal;
19+
import java.util.function.Function;
1920

2021
import graphql.schema.DataFetchingEnvironment;
2122

2223
import org.springframework.core.MethodParameter;
2324
import org.springframework.graphql.data.method.HandlerMethodArgumentResolver;
25+
import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
2426
import org.springframework.security.core.Authentication;
27+
import org.springframework.security.core.AuthenticationException;
2528
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
2629
import org.springframework.security.core.context.SecurityContext;
2730
import org.springframework.security.core.context.SecurityContextHolder;
31+
import reactor.core.publisher.Mono;
2832

2933
/**
3034
* Resolver to obtain {@link Principal} from Spring Security context via
@@ -50,13 +54,29 @@ public boolean supportsParameter(MethodParameter parameter) {
5054

5155
@Override
5256
public Object resolveArgument(MethodParameter parameter, DataFetchingEnvironment environment) {
53-
return doResolve();
57+
return doResolve(parameter.isOptional());
5458
}
5559

56-
static Object doResolve() {
60+
static Object doResolve(boolean optional) {
5761
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
58-
return (authentication != null ? authentication :
59-
ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication));
62+
63+
if (authentication != null) {
64+
return authentication;
65+
}
66+
67+
return ReactiveSecurityContextHolder.getContext()
68+
.switchIfEmpty(optional ? Mono.empty() : Mono.error(new AuthenticationCredentialsNotFoundException("SecurityContext not available")))
69+
.handle((context, sink) -> {
70+
Authentication auth = context.getAuthentication();
71+
72+
if (auth != null) {
73+
sink.next(auth);
74+
} else if (!optional) {
75+
sink.error(new AuthenticationCredentialsNotFoundException("An Authentication object was not found in the SecurityContext"));
76+
} else {
77+
sink.complete();
78+
}
79+
});
6080
}
6181

6282
}

spring-graphql/src/test/java/org/springframework/graphql/data/method/annotation/support/SchemaMappingPrincipalMethodArgumentResolverTests.java

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@
2020
import java.time.Duration;
2121
import java.util.function.Function;
2222

23+
import graphql.GraphqlErrorBuilder;
2324
import io.micrometer.context.ContextSnapshot;
2425
import org.junit.jupiter.api.Nested;
2526
import org.junit.jupiter.api.Test;
2627
import org.junit.jupiter.params.ParameterizedTest;
2728
import org.junit.jupiter.params.provider.ValueSource;
29+
import org.springframework.graphql.execution.DataFetcherExceptionResolver;
30+
import org.springframework.graphql.execution.ErrorType;
2831
import reactor.core.publisher.Flux;
2932
import reactor.core.publisher.Mono;
3033
import reactor.test.StepVerifier;
@@ -63,6 +66,9 @@ public class SchemaMappingPrincipalMethodArgumentResolverTests {
6366
private final Function<Context, Context> reactiveContextWriter = context ->
6467
ReactiveSecurityContextHolder.withAuthentication(this.authentication);
6568

69+
private final Function<Context, Context> reactiveContextWriterWithoutAuthentication = context ->
70+
ReactiveSecurityContextHolder.withSecurityContext(Mono.just(SecurityContextHolder.createEmptyContext()));
71+
6672
private final Function<Context, Context> threadLocalContextWriter = context ->
6773
ContextSnapshot.captureAll().updateContext(context);
6874

@@ -100,6 +106,68 @@ void resolveFromThreadLocalContext(String field) {
100106
}
101107
}
102108

109+
@Test
110+
void nullablePrincipalDoesntRequireSecurityContext() {
111+
Mono<ExecutionGraphQlResponse> responseMono = executeAsync(
112+
"type Query { greetingMonoNullable: String }", "{ greetingMonoNullable }",
113+
context -> context);
114+
115+
ResponseHelper responseHelper = ResponseHelper.forResponse(responseMono);
116+
117+
assertThat(responseHelper.errorCount()).isEqualTo(0);
118+
}
119+
120+
@Test
121+
void nonNullPrincipalRequiresSecurityContext() {
122+
DataFetcherExceptionResolver exceptionResolver =
123+
DataFetcherExceptionResolver.forSingleError((ex, env) -> GraphqlErrorBuilder.newError(env)
124+
.message("Resolved error: " + ex.getMessage())
125+
.errorType(ErrorType.UNAUTHORIZED)
126+
.build());
127+
128+
Mono<ExecutionGraphQlResponse> responseMono = executeAsync(
129+
"type Query { greetingMono: String }", "{ greetingMono }",
130+
context -> context,
131+
exceptionResolver);
132+
133+
ResponseHelper responseHelper = ResponseHelper.forResponse(responseMono);
134+
135+
assertThat(responseHelper.errorCount()).isEqualTo(1);
136+
assertThat(responseHelper.error(0).errorType()).isEqualTo("UNAUTHORIZED");
137+
assertThat(responseHelper.error(0).message()).isEqualTo("Resolved error: SecurityContext not available");
138+
}
139+
140+
@Test
141+
void nonNullPrincipalRequiresAuthentication() {
142+
DataFetcherExceptionResolver exceptionResolver =
143+
DataFetcherExceptionResolver.forSingleError((ex, env) -> GraphqlErrorBuilder.newError(env)
144+
.message("Resolved error: " + ex.getMessage())
145+
.errorType(ErrorType.UNAUTHORIZED)
146+
.build());
147+
148+
Mono<ExecutionGraphQlResponse> responseMono = executeAsync(
149+
"type Query { greetingMono: String }", "{ greetingMono }",
150+
reactiveContextWriterWithoutAuthentication,
151+
exceptionResolver);
152+
153+
ResponseHelper responseHelper = ResponseHelper.forResponse(responseMono);
154+
155+
assertThat(responseHelper.errorCount()).isEqualTo(1);
156+
assertThat(responseHelper.error(0).errorType()).isEqualTo("UNAUTHORIZED");
157+
assertThat(responseHelper.error(0).message()).isEqualTo("Resolved error: An Authentication object was not found in the SecurityContext");
158+
}
159+
160+
@Test
161+
void nullablePrincipalDoesntRequireAuthentication() {
162+
Mono<ExecutionGraphQlResponse> responseMono = executeAsync(
163+
"type Query { greetingMonoNullable: String }", "{ greetingMonoNullable }",
164+
reactiveContextWriterWithoutAuthentication);
165+
166+
ResponseHelper responseHelper = ResponseHelper.forResponse(responseMono);
167+
168+
assertThat(responseHelper.errorCount()).isEqualTo(0);
169+
}
170+
103171
private void testQuery(String field, Function<Context, Context> contextWriter) {
104172
Mono<ExecutionGraphQlResponse> responseMono = executeAsync(
105173
"type Query { " + field + ": String }", "{ " + field + " }", contextWriter);
@@ -150,14 +218,24 @@ private void testSubscription(Function<Context, Context> contextModifier) {
150218

151219
private Mono<ExecutionGraphQlResponse> executeAsync(
152220
String schema, String document, Function<Context, Context> contextWriter) {
221+
return executeAsync(schema, document, contextWriter, null);
222+
}
223+
224+
private Mono<ExecutionGraphQlResponse> executeAsync(
225+
String schema, String document, Function<Context, Context> contextWriter, @Nullable DataFetcherExceptionResolver exceptionResolver) {
153226

154227
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
155228
context.registerBean(GreetingController.class, () -> greetingController);
156229
context.refresh();
157230

158-
TestExecutionGraphQlService graphQlService = GraphQlSetup.schemaContent(schema)
159-
.runtimeWiringForAnnotatedControllers(context)
160-
.toGraphQlService();
231+
GraphQlSetup graphQlSetup = GraphQlSetup.schemaContent(schema)
232+
.runtimeWiringForAnnotatedControllers(context);
233+
234+
if (exceptionResolver != null) {
235+
graphQlSetup.exceptionResolver(exceptionResolver);
236+
}
237+
238+
TestExecutionGraphQlService graphQlService = graphQlSetup.toGraphQlService();
161239

162240
return Mono.delay(Duration.ofMillis(10))
163241
.flatMap(aLong -> graphQlService.execute(document))
@@ -197,6 +275,12 @@ Mono<String> greetingMono(Principal principal) {
197275
return Mono.just("Hello");
198276
}
199277

278+
@QueryMapping
279+
Mono<String> greetingMonoNullable(@Nullable Principal principal) {
280+
this.principal = principal;
281+
return Mono.just("Hello");
282+
}
283+
200284
@SubscriptionMapping
201285
Flux<String> greetingSubscription(Principal principal) {
202286
this.principal = principal;

0 commit comments

Comments
 (0)