|
| 1 | +/* |
| 2 | + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | +package aws.smithy.kotlin.runtime.auth.awssigning |
| 6 | + |
| 7 | +import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials |
| 8 | +import aws.smithy.kotlin.runtime.auth.awscredentials.CredentialsProvider |
| 9 | +import aws.smithy.kotlin.runtime.client.endpoints.Endpoint |
| 10 | +import aws.smithy.kotlin.runtime.http.Headers |
| 11 | +import aws.smithy.kotlin.runtime.http.operation.EndpointResolver |
| 12 | +import aws.smithy.kotlin.runtime.http.operation.ResolveEndpointRequest |
| 13 | +import aws.smithy.kotlin.runtime.http.request.HttpRequest |
| 14 | +import aws.smithy.kotlin.runtime.http.request.HttpRequestBuilder |
| 15 | +import aws.smithy.kotlin.runtime.http.request.url |
| 16 | +import aws.smithy.kotlin.runtime.net.Url |
| 17 | +import aws.smithy.kotlin.runtime.operation.ExecutionContext |
| 18 | +import aws.smithy.kotlin.runtime.util.Attributes |
| 19 | +import kotlinx.coroutines.ExperimentalCoroutinesApi |
| 20 | +import kotlinx.coroutines.test.runTest |
| 21 | +import kotlin.test.Test |
| 22 | +import kotlin.test.assertEquals |
| 23 | + |
| 24 | +private const val NON_HTTPS_URL = "http://localhost:8080/path/to/resource?foo=bar" |
| 25 | + |
| 26 | +@OptIn(ExperimentalCoroutinesApi::class) |
| 27 | +class PresignerTest { |
| 28 | + // Verify that custom endpoint URL schemes aren't changed. |
| 29 | + // See https://github.com/awslabs/aws-sdk-kotlin/issues/938 |
| 30 | + @Test |
| 31 | + fun testSignedUrlAllowsHttp() = testSigningUrl("http://localhost:8080/path/to/resource?foo=bar") |
| 32 | + |
| 33 | + // Verify that custom endpoint URL schemes aren't changed. |
| 34 | + // See https://github.com/awslabs/aws-sdk-kotlin/issues/938 |
| 35 | + @Test |
| 36 | + fun testSignedUrlAllowsHttps() = testSigningUrl("https://localhost:8088/path/to/resource?bar=foo") |
| 37 | + |
| 38 | + private fun testSigningUrl(url: String) = runTest { |
| 39 | + val expectedUrl = Url.parse(url) |
| 40 | + |
| 41 | + val unsignedRequestBuilder = HttpRequestBuilder() |
| 42 | + val ctx = ExecutionContext() |
| 43 | + val credentialsProvider = TestCredentialsProvider(Credentials("foo", "bar")) |
| 44 | + val endpointResolver = TestEndpointResolver(Endpoint(expectedUrl)) |
| 45 | + val signer = TestSigner(HttpRequest { url(expectedUrl) }) |
| 46 | + val signingConfig: AwsSigningConfig.Builder.() -> Unit = { |
| 47 | + service = "launch-service" |
| 48 | + region = "the-moon" |
| 49 | + } |
| 50 | + |
| 51 | + val presignedRequest = presignRequest( |
| 52 | + unsignedRequestBuilder, |
| 53 | + ctx, |
| 54 | + credentialsProvider, |
| 55 | + endpointResolver, |
| 56 | + signer, |
| 57 | + signingConfig, |
| 58 | + ) |
| 59 | + |
| 60 | + val actualUrl = presignedRequest.url |
| 61 | + |
| 62 | + assertEquals(expectedUrl.scheme, actualUrl.scheme) |
| 63 | + assertEquals(expectedUrl.host, actualUrl.host) |
| 64 | + assertEquals(expectedUrl.port, actualUrl.port) |
| 65 | + assertEquals(expectedUrl.path, actualUrl.path) |
| 66 | + expectedUrl.parameters.forEach { key, value -> |
| 67 | + assertEquals(value, actualUrl.parameters.getAll(key)) |
| 68 | + } |
| 69 | + } |
| 70 | +} |
| 71 | + |
| 72 | +private class TestCredentialsProvider(private val credentials: Credentials) : CredentialsProvider { |
| 73 | + override suspend fun resolve(attributes: Attributes): Credentials = credentials |
| 74 | +} |
| 75 | + |
| 76 | +private class TestEndpointResolver(private val resolvedEndpoint: Endpoint) : EndpointResolver { |
| 77 | + override suspend fun resolve(request: ResolveEndpointRequest): Endpoint = resolvedEndpoint |
| 78 | +} |
| 79 | + |
| 80 | +private class TestSigner(private val signedOutput: HttpRequest) : AwsSigner { |
| 81 | + override suspend fun sign(request: HttpRequest, config: AwsSigningConfig): AwsSigningResult<HttpRequest> = |
| 82 | + AwsSigningResult(signedOutput, byteArrayOf()) |
| 83 | + |
| 84 | + override suspend fun signChunk( |
| 85 | + chunkBody: ByteArray, |
| 86 | + prevSignature: ByteArray, |
| 87 | + config: AwsSigningConfig, |
| 88 | + ): AwsSigningResult<Unit> = error("Method should not be called") |
| 89 | + |
| 90 | + override suspend fun signChunkTrailer( |
| 91 | + trailingHeaders: Headers, |
| 92 | + prevSignature: ByteArray, |
| 93 | + config: AwsSigningConfig, |
| 94 | + ): AwsSigningResult<Unit> = error("Method should not be called") |
| 95 | +} |
0 commit comments