4747import java .net .URL ;
4848import java .nio .charset .StandardCharsets ;
4949import java .text .ParseException ;
50+ import java .util .ArrayList ;
5051import java .util .HashSet ;
5152import java .util .List ;
5253import java .util .Set ;
@@ -67,6 +68,25 @@ public class SSOCookieFederationFilter extends AbstractJWTFilter {
6768 public static final String X_FORWARDED_PORT = "X-Forwarded-Port" ;
6869 public static final String X_FORWARDED_PROTO = "X-Forwarded-Proto" ;
6970
71+ /* Overwrite original from header */
72+ /* Feature flag to turn the original url from header for SSO ON */
73+ public static final String SHOULD_USE_ORIGINAL_URL_FROM_HEADER = "sso.use.original.url.from.header" ;
74+ public static final String X_ORIGINAL_URL = "X-Original-URL" ;
75+ /* Users can choose to use custom header names */
76+ public static final String X_ORIGINAL_URL_HEADER_NAME = "sso.original.url.from.header.name" ;
77+ private static final boolean DEFAULT_SHOULD_USE_ORIGINAL_URL_FROM_HEADER = false ;
78+ /* Should we check for domain in configured whitelist? */
79+ public static final String VERIFY_ORIGINAL_URL_FROM_HEADER_DOMAIN = "sso.original.url.from.header.verify.domain" ;
80+ /*
81+ * This is ONLY needed when you want tighter access,
82+ * we already have `knoxsso.redirect.whitelist.regex` property
83+ * that checks for redirect URL. If you add domains to whitelist here
84+ * make sure they are added there as well.
85+ */
86+ private static final boolean DEFAULT_VERIFY_ORIGINAL_URL_FROM_HEADER_DOMAIN = false ;
87+ /* Param that specifies the whitelist for original url header domains, domains are comma seperated list */
88+ public static final String VERIFY_ORIGINAL_URL_FROM_HEADER_DOMAIN_WHITELIST = "sso.original.url.from.header.domain.whitelist" ;
89+
7090 private static final String ORIGINAL_URL_QUERY_PARAM = "originalUrl=" ;
7191 public static final String DEFAULT_SSO_COOKIE_NAME = "hadoop-jwt" ;
7292
@@ -76,7 +96,12 @@ public class SSOCookieFederationFilter extends AbstractJWTFilter {
7696 private String cookieName ;
7797 private String authenticationProviderUrl ;
7898 private String gatewayPath ;
79- private Set <String > unAuthenticatedPaths = new HashSet <>(20 );
99+ private final Set <String > unAuthenticatedPaths = new HashSet <>(20 );
100+
101+ private boolean shouldUseOriginalUrlFromHeader = DEFAULT_SHOULD_USE_ORIGINAL_URL_FROM_HEADER ;
102+ private boolean verifyOriginalUrlFromHeaderDomain = DEFAULT_VERIFY_ORIGINAL_URL_FROM_HEADER_DOMAIN ;
103+ private final List <String > verifyOriginalUrlFromHeaderDomainWhitelist = new ArrayList <>();
104+ private String originalUrlHeaderName ;
80105
81106 @ Override
82107 public void init ( FilterConfig filterConfig ) throws ServletException {
@@ -121,6 +146,46 @@ public void init( FilterConfig filterConfig ) throws ServletException {
121146 LOGGER .configuredIdleTimeout (idleTimeoutSeconds , topologyName );
122147 }
123148
149+ /* Support to overwrite originalUrl by providing an option to pick it up from the request header value */
150+ final String shouldUseOriginalUrlFromHeaderFilterParam = filterConfig .getInitParameter (SHOULD_USE_ORIGINAL_URL_FROM_HEADER );
151+ if (shouldUseOriginalUrlFromHeaderFilterParam != null ) {
152+ shouldUseOriginalUrlFromHeader = Boolean .parseBoolean (shouldUseOriginalUrlFromHeaderFilterParam );
153+ } else {
154+ shouldUseOriginalUrlFromHeader = DEFAULT_SHOULD_USE_ORIGINAL_URL_FROM_HEADER ;
155+ }
156+
157+ /*
158+ * If the feature to use update orignalurl for SSO to use headers is on populate
159+ * required fields, else don't bother
160+ */
161+ if (shouldUseOriginalUrlFromHeader ) {
162+ originalUrlHeaderName = filterConfig .getInitParameter (X_ORIGINAL_URL_HEADER_NAME );
163+ if (originalUrlHeaderName == null ) {
164+ originalUrlHeaderName = X_ORIGINAL_URL ;
165+ }
166+
167+ final String verifyOriginalUrlFromHeaderDomainFilterParam = filterConfig .getInitParameter (VERIFY_ORIGINAL_URL_FROM_HEADER_DOMAIN );
168+ if (verifyOriginalUrlFromHeaderDomainFilterParam != null ) {
169+ verifyOriginalUrlFromHeaderDomain = Boolean .parseBoolean (verifyOriginalUrlFromHeaderDomainFilterParam );
170+ } else {
171+ verifyOriginalUrlFromHeaderDomain = DEFAULT_VERIFY_ORIGINAL_URL_FROM_HEADER_DOMAIN ;
172+ }
173+
174+ /* populate the whitelisted domains */
175+ final String verifyOriginalUrlDomainWhitelistParam = filterConfig .getInitParameter (VERIFY_ORIGINAL_URL_FROM_HEADER_DOMAIN_WHITELIST );
176+ if (verifyOriginalUrlFromHeaderDomain && verifyOriginalUrlDomainWhitelistParam != null ) {
177+ final String [] domains = verifyOriginalUrlDomainWhitelistParam .split ("," );
178+ for (final String domain : domains ) {
179+ final String trimmedDomain = domain .trim ();
180+ if (!trimmedDomain .isEmpty ()) {
181+ verifyOriginalUrlFromHeaderDomainWhitelist .add (trimmedDomain );
182+ }
183+ }
184+ }
185+ }
186+
187+
188+
124189 configureExpectedParameters (filterConfig );
125190 }
126191
@@ -265,9 +330,36 @@ protected String constructLoginURL(HttpServletRequest request) {
265330 if (providerURL .contains ("?" )) {
266331 delimiter = "&" ;
267332 }
268- return providerURL + delimiter
269- + ORIGINAL_URL_QUERY_PARAM
270- + request .getRequestURL ().append (getOriginalQueryString (request ));
333+
334+ if (shouldUseOriginalUrlFromHeader && (request .getHeader (originalUrlHeaderName ) != null ) && !request .getHeader (originalUrlHeaderName ).trim ().isEmpty ()) {
335+ final String originalUrlFromHeader = request .getHeader (originalUrlHeaderName );
336+ LOGGER .usingOriginalUrlFromHeader (originalUrlFromHeader );
337+ /* verify if the original request domain and the domain in the header matches */
338+ if (verifyOriginalUrlFromHeaderDomain ) {
339+ try {
340+ final URL originalUrl = new URL (originalUrlFromHeader );
341+ final String originalDomain = originalUrl .getHost ();
342+ if (!verifyOriginalUrlFromHeaderDomainWhitelist .contains (originalDomain )) {
343+ LOGGER .invalidOriginalUrlDomain (originalDomain );
344+ throw new IllegalArgumentException ("Original URL domain '" + originalDomain +
345+ "' is not in the allowed whitelist" );
346+ }
347+ } catch (final MalformedURLException e ) {
348+ LOGGER .malformedOriginalUrlDomain (originalUrlFromHeader );
349+ throw new IllegalArgumentException ("Invalid original URL format: " + originalUrlFromHeader , e );
350+ }
351+ }
352+
353+ LOGGER .originalHeaderURLForwarding (originalUrlFromHeader , originalUrlHeaderName );
354+ return providerURL + delimiter
355+ + ORIGINAL_URL_QUERY_PARAM
356+ + originalUrlFromHeader ;
357+ } else {
358+ return providerURL + delimiter
359+ + ORIGINAL_URL_QUERY_PARAM
360+ + request .getRequestURL ().append (getOriginalQueryString (request ));
361+ }
362+
271363 }
272364
273365 public String deriveDefaultAuthenticationProviderUrl (HttpServletRequest request ) {
0 commit comments