Skip to content

Commit 24e8dab

Browse files
authored
[RORDEV-1813] fix: better handling of ROR metadata embedded into header value (#1176)
1 parent 464bd3c commit 24e8dab

File tree

162 files changed

+2696
-737
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

162 files changed

+2696
-737
lines changed

core/src/main/scala/tech/beshu/ror/accesscontrol/domain/http.scala

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -69,40 +69,36 @@ object Header {
6969

7070
def apply(nameAndValue: (NonEmptyString, NonEmptyString)): Header = new Header(Name(nameAndValue._1), nameAndValue._2)
7171

72-
def fromRawHeaders(headers: Map[String, List[String]]): Set[Header] = {
73-
val (authorizationHeaders, otherHeaders) =
72+
def fromRawHeaders(headers: Map[String, List[String]]): Either[AuthorizationValueError, Set[Header]] = {
73+
val (authorizationHeaders, nonAuthorizationHeaders) =
7474
headers
7575
.map { case (name, values) => (name, values.toCovariantSet) }
76-
.flatMap { case (name, values) =>
77-
for {
78-
nonEmptyName <- NonEmptyString.unapply(name)
79-
nonEmptyValues <- NonEmptyList.fromList(values.toList.flatMap(NonEmptyString.unapply))
80-
} yield (Header.Name(nonEmptyName), nonEmptyValues)
81-
}
82-
.toSeq
83-
.partition { case (name, _) => name === Header.Name.authorization }
84-
val headersFromAuthorizationHeaderValues = authorizationHeaders
85-
.flatMap { case (_, values) =>
86-
val headersFromAuthorizationHeaderValues = values
87-
.map(fromAuthorizationValue)
88-
.toList
89-
.map(_.map(_.toList))
90-
.sequence
91-
.map(_.flatten)
92-
headersFromAuthorizationHeaderValues match {
93-
case Left(error) => throw new IllegalArgumentException(error.show)
94-
case Right(v) => v
95-
}
76+
.flatMap { case (name, values) => createHeadersFrom(name, values) }
77+
.partition(h => h.name === Header.Name.authorization)
78+
val headersFromAuthorizationHeaderValues =
79+
authorizationHeaders.toList
80+
.map(header => fromAuthorizationValue(header.value))
81+
.sequence
82+
.map(_.flatMap(_.toList))
83+
84+
headersFromAuthorizationHeaderValues
85+
.map { authHeaderBasedExtractedHeaders =>
86+
val restOfHeadersNames = nonAuthorizationHeaders.map(_.name).toCovariantSet
87+
val filteredAuthHeaderBasedExtractedHeaders = authHeaderBasedExtractedHeaders
88+
.filter { header => !restOfHeadersNames.contains(header.name) }
89+
(nonAuthorizationHeaders ++ filteredAuthHeaderBasedExtractedHeaders).toCovariantSet
9690
}
97-
.toCovariantSet
98-
val restOfHeaders = otherHeaders
99-
.flatMap { case (name, values) => values.map(new Header(name, _)).toList }
100-
.toCovariantSet
101-
val restOfHeaderNames = restOfHeaders.map(_.name)
102-
restOfHeaders ++ headersFromAuthorizationHeaderValues.filter { header => !restOfHeaderNames.contains(header.name) }
10391
}
10492

105-
def fromAuthorizationValue(value: NonEmptyString): Either[AuthorizationValueError, NonEmptyList[Header]] = {
93+
private def createHeadersFrom(name: String, values: Iterable[String]) = {
94+
val value = for {
95+
nonEmptyName <- NonEmptyString.unapply(name)
96+
nonEmptyValues <- NonEmptyList.fromList(values.toList.flatMap(NonEmptyString.unapply))
97+
} yield nonEmptyValues.toList.map(value => new Header(Header.Name(nonEmptyName), value))
98+
value.toList.flatten
99+
}
100+
101+
private def fromAuthorizationValue(value: NonEmptyString): Either[AuthorizationValueError, NonEmptyList[Header]] = {
106102
value.value.splitBy("ror_metadata=") match {
107103
case (_, None) =>
108104
Right(NonEmptyList.one(new Header(Name.authorization, value)))
@@ -182,7 +178,7 @@ object Address {
182178
Hostname.fromString(value).map(Address.Name.apply)
183179

184180
private def parseIpAddress(value: String) =
185-
(cutOffZoneIndex _ andThen IpAddress.fromString andThen (_.map(createAddressIp))) (value)
181+
(cutOffZoneIndex _ andThen IpAddress.fromString andThen (_.map(createAddressIp)))(value)
186182

187183
private def createAddressIp(ip: IpAddress) =
188184
Address.Ip(Cidr(ip, 32))
@@ -214,9 +210,9 @@ final case class UriPath private(value: NonEmptyString) {
214210
def isSqlQueryPath: Boolean = value.value.startsWith("/_sql")
215211

216212
def isXpackSqlQueryPath: Boolean = value.value.startsWith("/_xpack/sql")
217-
213+
218214
def isEsqlQueryPath: Boolean = value.value.startsWith("/_query")
219-
215+
220216
def isAliasesPath: Boolean =
221217
value.value.startsWith("/_cat/aliases") ||
222218
value.value.startsWith("/_alias") ||

core/src/main/scala/tech/beshu/ror/implicits.scala

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ import tech.beshu.ror.accesscontrol.blocks.definitions.ldap.implementations.Unbo
3333
import tech.beshu.ror.accesscontrol.blocks.definitions.ldap.implementations.UserGroupsSearchFilterConfig.UserGroupsSearchMode.*
3434
import tech.beshu.ror.accesscontrol.blocks.definitions.ldap.{Dn, LdapService}
3535
import tech.beshu.ror.accesscontrol.blocks.metadata.UserMetadata
36-
import tech.beshu.ror.accesscontrol.blocks.rules.Rule
3736
import tech.beshu.ror.accesscontrol.blocks.rules.Rule.{RuleName, RuleResult}
3837
import tech.beshu.ror.accesscontrol.blocks.rules.elasticsearch.{ActionsRule, FieldsRule, FilterRule, ResponseFieldsRule}
3938
import tech.beshu.ror.accesscontrol.blocks.rules.kibana.*
@@ -45,7 +44,6 @@ import tech.beshu.ror.accesscontrol.blocks.variables.startup.StartupResolvableVa
4544
import tech.beshu.ror.accesscontrol.blocks.variables.transformation.domain.*
4645
import tech.beshu.ror.accesscontrol.domain.*
4746
import tech.beshu.ror.accesscontrol.domain.AccessRequirement.{MustBeAbsent, MustBePresent}
48-
import tech.beshu.ror.accesscontrol.domain.Address.Ip
4947
import tech.beshu.ror.accesscontrol.domain.ClusterIndexName.Remote.ClusterName
5048
import tech.beshu.ror.accesscontrol.domain.FieldLevelSecurity.Strategy
5149
import tech.beshu.ror.accesscontrol.domain.GroupIdLike.GroupId
@@ -143,6 +141,7 @@ trait LogsShowInstances
143141
implicit val proxyAuthNameShow: Show[ProxyAuth.Name] = Show.show(_.value)
144142

145143
implicit def requestedIndexShow[T <: ClusterIndexName : Show]: Show[RequestedIndex[T]] = Show(_.name.show)
144+
146145
implicit val clusterIndexNameShow: Show[ClusterIndexName] = Show.show(_.stringify)
147146
implicit val localClusterIndexNameShow: Show[ClusterIndexName.Local] = Show.show(_.stringify)
148147
implicit val remoteClusterIndexNameShow: Show[ClusterIndexName.Remote] = Show.show(_.stringify)
@@ -381,19 +380,25 @@ trait LogsShowInstances
381380
showNamedIterable(name, option.toList)
382381
}
383382

384-
implicit val authorizationValueErrorShow: Show[AuthorizationValueError] = Show.show {
383+
val authorizationValueErrorWithDetailsShow: Show[AuthorizationValueError] = Show.show {
385384
case AuthorizationValueError.EmptyAuthorizationValue => "Empty authorization value"
386385
case AuthorizationValueError.InvalidHeaderFormat(value) => s"Unexpected header format in ror_metadata: [${value.show}]"
387386
case AuthorizationValueError.RorMetadataInvalidFormat(value, message) => s"Invalid format of ror_metadata: [${value.show}], reason: [${message.show}]"
388387
}
389388

389+
val authorizationValueErrorSanitizedShow: Show[AuthorizationValueError] = Show.show {
390+
case AuthorizationValueError.EmptyAuthorizationValue => "Empty authorization value"
391+
case AuthorizationValueError.InvalidHeaderFormat(_) => s"Unexpected header format in ror_metadata"
392+
case AuthorizationValueError.RorMetadataInvalidFormat(_, message) => s"Invalid format of ror_metadata. Reason: [${message.show}]"
393+
}
394+
390395
implicit val unresolvableErrorShow: Show[Unresolvable] = Show.show {
391396
case Unresolvable.CannotExtractValue(msg) => s"Cannot extract variable value. ${msg.show}"
392397
case Unresolvable.CannotInstantiateResolvedValue(msg) => s"Extracted value type doesn't fit. ${msg.show}"
393398
}
394399

395400
implicit def accessShow[T: Show]: Show[AccessRequirement[T]] = Show.show {
396401
case MustBePresent(value) => value.show
397-
case AccessRequirement.MustBeAbsent(value) => s"~${value.show}"
402+
case MustBeAbsent(value) => s"~${value.show}"
398403
}
399404
}

es67x/src/main/scala/tech/beshu/ror/es/RorRestChannel.scala

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,38 +22,56 @@ import squants.information.{Bytes, Information}
2222
import tech.beshu.ror.accesscontrol.domain.{Address, Header, UriPath}
2323
import tech.beshu.ror.accesscontrol.request.RequestContext.Method
2424
import tech.beshu.ror.accesscontrol.request.RestRequest
25+
import tech.beshu.ror.es.utils.ThreadRepo
2526
import tech.beshu.ror.syntax.*
2627
import tech.beshu.ror.utils.RefinedUtils.nes
2728

2829
import java.net.{InetSocketAddress, SocketAddress}
2930
import scala.jdk.CollectionConverters.*
3031

31-
final class RorRestChannel(underlying: EsRestChannel)
32+
object RorRestChannel {
33+
def from(esRestChannel: EsRestChannel): Either[Header.AuthorizationValueError, RorRestChannel] = {
34+
RorRestRequest
35+
.from(esRestChannel.request())
36+
.map(new RorRestChannel(esRestChannel, _))
37+
}
38+
}
39+
final class RorRestChannel private(underlying: EsRestChannel, val restRequest: RorRestRequest)
3240
extends AbstractRestChannel(underlying.request(), true)
3341
with ResponseFieldsFiltering
3442
with Logging {
3543

36-
val restRequest: RorRestRequest = new RorRestRequest(underlying.request())
37-
3844
override def sendResponse(response: EsRestResponse): Unit = {
45+
ThreadRepo.removeRestChannel(this)
3946
underlying.sendResponse(filterRestResponse(response))
4047
}
4148
}
4249

43-
final class RorRestRequest(underlying: EsRestRequest) extends RestRequest {
50+
object RorRestRequest {
51+
52+
def from(esRestRequest: EsRestRequest): Either[Header.AuthorizationValueError, RorRestRequest] = {
53+
headersFrom(esRestRequest).map(new RorRestRequest(esRestRequest, _))
54+
}
55+
56+
private def headersFrom(esRestRequest: EsRestRequest) = {
57+
Header.fromRawHeaders(
58+
esRestRequest
59+
.getHeaders.asScala
60+
.view.mapValues(_.asScala.toList)
61+
.toMap
62+
)
63+
}
64+
}
65+
final class RorRestRequest private(underlying: EsRestRequest,
66+
headers: Set[Header]) extends RestRequest {
4467

4568
override lazy val method: Method = Method.fromStringUnsafe(underlying.method().name())
4669

4770
override lazy val path: UriPath = UriPath
4871
.from(underlying.path())
4972
.getOrElse(UriPath.from(nes("/")))
5073

51-
override lazy val allHeaders: Set[Header] = Header.fromRawHeaders(
52-
underlying
53-
.getHeaders.asScala
54-
.view.mapValues(_.asScala.toList)
55-
.toMap
56-
)
74+
override lazy val allHeaders: Set[Header] = headers
5775

5876
override lazy val localAddress: Address =
5977
createAddressFrom(_.getLocalAddress)

es67x/src/main/scala/tech/beshu/ror/es/handler/request/context/types/MultiGetEsRequestContext.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import org.elasticsearch.threadpool.ThreadPool
2626
import tech.beshu.ror.accesscontrol.blocks.BlockContext.FilterableMultiRequestBlockContext
2727
import tech.beshu.ror.accesscontrol.blocks.BlockContext.MultiIndexRequestBlockContext.Indices
2828
import tech.beshu.ror.accesscontrol.blocks.metadata.UserMetadata
29-
import tech.beshu.ror.accesscontrol.domain
3029
import tech.beshu.ror.accesscontrol.domain.*
3130
import tech.beshu.ror.accesscontrol.domain.DocumentAccessibility.{Accessible, Inaccessible}
3231
import tech.beshu.ror.accesscontrol.domain.FieldLevelSecurity.RequestFieldsUsage

es67x/src/main/scala/tech/beshu/ror/es/handler/request/context/types/templates/GetTemplatesEsRequestContext.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class GetTemplatesEsRequestContext(actionRequest: GetIndexTemplatesRequest,
107107

108108
private[templates] object GetTemplatesEsRequestContext extends Logging {
109109

110-
def filter(templates: List[IndexTemplateMetaData],
110+
def filter(templates: Iterable[IndexTemplateMetaData],
111111
usingTemplate: Set[Template] => Set[Template])
112112
(implicit requestContextId: RequestContext.Id): List[IndexTemplateMetaData] = {
113113
val templatesMap = templates

es67x/src/main/scala/tech/beshu/ror/es/utils/ChannelInterceptingRestHandlerDecorator.scala

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,50 @@
1616
*/
1717
package tech.beshu.ror.es.utils
1818

19+
import cats.implicits.*
20+
import org.apache.logging.log4j.scala.Logging
21+
import org.elasticsearch.ElasticsearchException
1922
import org.elasticsearch.client.node.NodeClient
2023
import org.elasticsearch.common.settings.Settings
24+
import org.elasticsearch.rest.*
2125
import org.elasticsearch.rest.action.cat.RestCatAction
22-
import org.elasticsearch.rest.{RestChannel, RestHandler, RestRequest}
2326
import org.joor.Reflect.on
27+
import tech.beshu.ror.accesscontrol.domain.Header.AuthorizationValueError
2428
import tech.beshu.ror.es.RorRestChannel
2529
import tech.beshu.ror.es.actions.wrappers._cat.rest.RorWrappedRestCatAction
2630
import tech.beshu.ror.es.utils.ThreadContextOps.createThreadContextOps
31+
import tech.beshu.ror.implicits.*
2732
import tech.beshu.ror.utils.AccessControllerHelper.doPrivileged
2833

29-
import scala.util.Try
34+
import scala.util.{Failure, Success, Try}
3035

3136
class ChannelInterceptingRestHandlerDecorator private(val underlying: RestHandler,
3237
settings: Settings)
33-
extends RestHandler {
38+
extends RestHandler with Logging {
3439

3540
private val wrapped = doPrivileged {
3641
wrapSomeActions(underlying)
3742
}
3843

3944
override def handleRequest(request: RestRequest, channel: RestChannel, client: NodeClient): Unit = {
40-
val rorRestChannel = new RorRestChannel(channel)
41-
ThreadRepo.setRestChannel(rorRestChannel)
42-
addXpackUserAuthenticationHeaderForInCaseOfSecurityRequest(request, client)
43-
wrapped.handleRequest(request, rorRestChannel, client)
45+
Try {
46+
RorRestChannel.from(channel) match {
47+
case Right(rorRestChannel) =>
48+
ThreadRepo.safeSetRestChannel(rorRestChannel) {
49+
addXpackUserAuthenticationHeaderForInCaseOfSecurityRequest(request, client)
50+
wrapped.handleRequest(request, rorRestChannel, client)
51+
}
52+
case Left(error) =>
53+
logError(error)
54+
implicit val show = authorizationValueErrorSanitizedShow
55+
channel.sendResponse(new BytesRestResponse(channel, RestStatus.BAD_REQUEST, new ElasticsearchException(error.show)))
56+
}
57+
} match {
58+
case Success(_) =>
59+
case Failure(ex) =>
60+
logger.error(s"The incoming request handling error:", ex)
61+
channel.sendResponse(new BytesRestResponse(channel, RestStatus.INTERNAL_SERVER_ERROR, new ElasticsearchException("ROR internal error")))
62+
}
4463
}
4564

4665
override def canTripCircuitBreaker: Boolean = underlying.canTripCircuitBreaker
@@ -70,14 +89,25 @@ class ChannelInterceptingRestHandlerDecorator private(val underlying: RestHandle
7089
}
7190

7291
private def addXpackUserAuthenticationHeaderForInCaseOfSecurityRequest(request: RestRequest,
73-
client: NodeClient): Unit = {
92+
client: NodeClient): Unit = {
7493
if (request.path().contains("/_security") || request.path().contains("/_xpack/security")) {
7594
client
7695
.threadPool().getThreadContext
7796
.addXpackUserAuthenticationHeader(client.getLocalNodeId)
7897
}
7998
}
8099

100+
private def logError(error: AuthorizationValueError): Unit = {
101+
{
102+
implicit val show = authorizationValueErrorSanitizedShow
103+
logger.warn(s"The incoming request was malformed. Cause: ${error.show}")
104+
}
105+
if (logger.delegate.isDebugEnabled()) {
106+
implicit val show = authorizationValueErrorWithDetailsShow
107+
logger.debug(s"Malformed request detailed cause: ${error.show}")
108+
}
109+
}
110+
81111
}
82112

83113
object ChannelInterceptingRestHandlerDecorator {

es67x/src/main/scala/tech/beshu/ror/es/utils/ThreadRepo.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,19 @@ import org.elasticsearch.rest.RestRequest
2020
import tech.beshu.ror.accesscontrol.domain.UriPath
2121
import tech.beshu.ror.es.RorRestChannel
2222

23+
import scala.util.{Failure, Success, Try}
24+
2325
object ThreadRepo {
2426
private val threadLocalChannel = new ThreadLocal[RorRestChannel]
2527

26-
def setRestChannel(restChannel: RorRestChannel): Unit = {
28+
def safeSetRestChannel(restChannel: RorRestChannel)(code: => Unit): Unit = {
2729
threadLocalChannel.set(restChannel)
30+
Try(code) match {
31+
case Success(_) =>
32+
case Failure(ex) =>
33+
removeRestChannel(restChannel)
34+
throw ex
35+
}
2836
}
2937

3038
def removeRestChannel(restChannel: RorRestChannel): Unit = {

es70x/src/main/scala/tech/beshu/ror/es/RorRestChannel.scala

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,38 +23,56 @@ import squants.information.{Bytes, Information}
2323
import tech.beshu.ror.accesscontrol.domain.{Address, Header, UriPath}
2424
import tech.beshu.ror.accesscontrol.request.RequestContext.Method
2525
import tech.beshu.ror.accesscontrol.request.RestRequest
26+
import tech.beshu.ror.es.utils.ThreadRepo
2627
import tech.beshu.ror.syntax.*
2728
import tech.beshu.ror.utils.RefinedUtils.nes
2829

2930
import java.net.InetSocketAddress
3031
import scala.jdk.CollectionConverters.*
3132

32-
final class RorRestChannel(underlying: EsRestChannel)
33+
object RorRestChannel {
34+
def from(esRestChannel: EsRestChannel): Either[Header.AuthorizationValueError, RorRestChannel] = {
35+
RorRestRequest
36+
.from(esRestChannel.request())
37+
.map(new RorRestChannel(esRestChannel, _))
38+
}
39+
}
40+
final class RorRestChannel private(underlying: EsRestChannel, val restRequest: RorRestRequest)
3341
extends AbstractRestChannel(underlying.request(), true)
3442
with ResponseFieldsFiltering
3543
with Logging {
3644

37-
val restRequest: RorRestRequest = new RorRestRequest(underlying.request())
38-
3945
override def sendResponse(response: EsRestResponse): Unit = {
46+
ThreadRepo.removeRestChannel(this)
4047
underlying.sendResponse(filterRestResponse(response))
4148
}
4249
}
4350

44-
final class RorRestRequest(underlying: EsRestRequest) extends RestRequest {
51+
object RorRestRequest {
52+
53+
def from(esRestRequest: EsRestRequest): Either[Header.AuthorizationValueError, RorRestRequest] = {
54+
headersFrom(esRestRequest).map(new RorRestRequest(esRestRequest, _))
55+
}
56+
57+
private def headersFrom(esRestRequest: EsRestRequest) = {
58+
Header.fromRawHeaders(
59+
esRestRequest
60+
.getHeaders.asScala
61+
.view.mapValues(_.asScala.toList)
62+
.toMap
63+
)
64+
}
65+
}
66+
final class RorRestRequest private(underlying: EsRestRequest,
67+
headers: Set[Header]) extends RestRequest {
4568

4669
override lazy val method: Method = Method.fromStringUnsafe(underlying.method().name())
4770

4871
override lazy val path: UriPath = UriPath
4972
.from(underlying.path())
5073
.getOrElse(UriPath.from(nes("/")))
5174

52-
override lazy val allHeaders: Set[Header] = Header.fromRawHeaders(
53-
underlying
54-
.getHeaders.asScala
55-
.view.mapValues(_.asScala.toList)
56-
.toMap
57-
)
75+
override lazy val allHeaders: Set[Header] = headers
5876

5977
override lazy val localAddress: Address =
6078
createAddressFrom(_.getLocalAddress)

0 commit comments

Comments
 (0)