diff --git a/CHANGES b/CHANGES new file mode 100644 index 0000000..be3a12e --- /dev/null +++ b/CHANGES @@ -0,0 +1,80 @@ +Changes we want +=============== +Curently ssm-env recognises variables beginning with ssm://, containing an SSM parameter +store key after this prefix, and substitutes them with their stored value by making calls +to the AWS SSM API. It then executes a child process with this new environment. + +We want it to also recognise KMS-encrypted parameters and have them substituted into +their plaintext value in the child process environment. These variables are of the format + + !kms '' + +where stands in for the base64-encoded KMS-encrypted secure test. + +Example +======= +Given the following example environment + + VAR1=VALUE1 + VAR2=ssm:///omnes/caeli + VAR3=!kms 'ABCDEFG==' + +currently SSM would detect that VAR2 has the appropriate prefix, query the SSM Parameter +Database for the /omnes/caeli key and change VAR2 into the plain value it returns, while +leaving VAR1 and VAR3 unchanged since they don't have the ssm:// prefix. Supposing the +stored value of said key is 'plainvalue' (without quotes), the environment of the child +process will be + + VAR1=VALUE1 + VAR2=plainvalue + VAR3=!kms 'ABCDEFG==' + +After the changes, for that sample enviroment we want ssm-env to detect that VAR3 has +the appropriate format and query the AWS KMS service to decrypt the kms-encrypted value, +getting the following modified environment for its child process (assuming that ABCDEFG== +is the base64 kms-encrypted value, under certain key, of 'muchacha' (without quotes): + + VAR1=VALUE + VAR2=plainvalue + VAR3=muchacha + +Note that the behaviour for neither VAR1 nor VAR2 is changed from the current +implementation. + +Considerations +============== +* Express the new KMS-encrypted value format as a regular expression. +* Attempt to mimic the current implementation as much as possible, substituting SSM API + calls for the necessary KMS API calls but keeping the larger structure and spirit + faithful to the project. +* Add tests akin to the current ones, in which we mock the KMS service. + +Existing implementation +======================= +The wanted extra functionality is currently provided by a shell function: + +kms_env() { + env -0 | while IFS== read -r -d '' name value + do + if echo "$value" | grep -q '^!kms ' + then + cipher="$(echo "$value" | sed -e 's/!kms //' | sed -e "s/'//g")" + # Don't quote cipher since it could already be quoted, and it's expected + # to be a base64-encrypted value + plain="$(aws kms decrypt --ciphertext-blob fileb://<(echo $cipher | base64 -d) --output text --query Plaintext | base64 -d)" \ + || return 1 + echo "export $name=$plain" + fi + done +} + +$(kms_env) || error_exit 'Error decrypting environment with kms_env' + +where the last line represents the replacing of the environment for the next +commands. Note that the behaviour is not exactly the same -- the idea of the +kms_env() function is to be run in a shell script previous to the commands +that will inherit the new environment, while ssm-env is expected to be run +and passing the subsequent commands as parameters, which it will then exec. + +However, it's useful as a reference for the format and the means in which +the KMS-encrypted values are received and expanded. diff --git a/README.md b/README.md index 5af17b3..a78760c 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,11 @@ You can most likely find the downloaded binary in `~/go/bin/ssm-env` ssm-env [-template STRING] [-with-decryption] [-no-fail] COMMAND ``` +### Parameter Formats + +- **SSM Parameter Store**: `ssm:///parameter-path` (fetches the value from AWS SSM Parameter Store) +- **KMS Encrypted**: `!kms 'base64EncodedEncryptedValue'` (decrypts base64-encoded value using AWS KMS) + ## Details Given the following environment: @@ -23,14 +28,16 @@ Given the following environment: ``` RAILS_ENV=production COOKIE_SECRET=ssm://prod.app.cookie-secret +API_KEY=!kms 'base64EncodedEncryptedValue==' ``` -You can run the application using `ssm-env` to automatically populate the `COOKIE_SECRET` env var from SSM: +You can run the application using `ssm-env` to automatically populate the `COOKIE_SECRET` env var from SSM and decrypt the `API_KEY` using AWS KMS: ```console $ ssm-env env RAILS_ENV=production COOKIE_SECRET=super-secret +API_KEY=decrypted-value ``` You can also configure how the parameter name is determined for an environment variable, by using the `-template` flag: diff --git a/main.go b/main.go index 7f64517..5bfe6f3 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "encoding/base64" "flag" "fmt" "os" @@ -13,6 +14,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/kms" "github.com/aws/aws-sdk-go/service/ssm" ) @@ -24,6 +26,9 @@ const ( // defaultBatchSize is the default number of parameters to fetch at once. // The SSM API limits this to a maximum of 10 at the time of writing. defaultBatchSize = 10 + + // KMS prefix for variables that contain KMS-encrypted values + KMSPrefix = "!kms " ) // TemplateFuncs are helper functions provided to the template. @@ -77,6 +82,7 @@ func main() { batchSize: defaultBatchSize, t: t, ssm: &lazySSMClient{}, + kms: &lazyKMSClient{}, os: os, } must(e.expandEnviron(*decrypt, *nofail)) @@ -93,7 +99,7 @@ type lazySSMClient struct { func (c *lazySSMClient) GetParameters(input *ssm.GetParametersInput) (*ssm.GetParametersOutput, error) { // Initialize the SSM client (and AWS session) if it hasn't been already. if c.ssm == nil { - sess, err := c.awsSession() + sess, err := awsSession() if err != nil { return nil, err } @@ -102,7 +108,27 @@ func (c *lazySSMClient) GetParameters(input *ssm.GetParametersInput) (*ssm.GetPa return c.ssm.GetParameters(input) } -func (c *lazySSMClient) awsSession() (*session.Session, error) { +// lazyKMSClient wraps the AWS SDK KMS client such that the AWS session and +// KMS client are not actually initialized until Decrypt is called for +// the first time. +type lazyKMSClient struct { + kms kmsClient +} + +func (c *lazyKMSClient) Decrypt(input *kms.DecryptInput) (*kms.DecryptOutput, error) { + // Initialize the KMS client (and AWS session) if it hasn't been already. + if c.kms == nil { + sess, err := awsSession() + if err != nil { + return nil, err + } + c.kms = kms.New(sess) + } + return c.kms.Decrypt(input) +} + +// awsSession creates and configures an AWS session with region detection +func awsSession() (*session.Session, error) { sess, err := session.NewSession(&aws.Config{ CredentialsChainVerboseErrors: aws.Bool(true), }) @@ -132,6 +158,10 @@ type ssmClient interface { GetParameters(*ssm.GetParametersInput) (*ssm.GetParametersOutput, error) } +type kmsClient interface { + Decrypt(*kms.DecryptInput) (*kms.DecryptOutput, error) +} + type environ interface { Environ() []string Setenv(key, vale string) @@ -155,6 +185,7 @@ type ssmVar struct { type expander struct { t *template.Template ssm ssmClient + kms kmsClient os environ batchSize int } @@ -172,14 +203,33 @@ func (e *expander) parameter(k, v string) (*string, error) { return nil, nil } +type kmsVar struct { + envvar string + encoded string +} + func (e *expander) expandEnviron(decrypt bool, nofail bool) error { // Environment variables that point to some SSM parameters. var ssmVars []ssmVar - + + // Environment variables that are KMS encrypted. + var kmsVars []kmsVar + uniqNames := make(map[string]bool) for _, envvar := range e.os.Environ() { k, v := splitVar(envvar) + // Check if this is a KMS encrypted value + if strings.HasPrefix(v, KMSPrefix) { + // Extract the base64 value by removing the prefix and any quotes + encodedPart := strings.TrimPrefix(v, KMSPrefix) + // Remove leading and trailing quotes if present + encodedPart = strings.Trim(encodedPart, "'\" ") + + kmsVars = append(kmsVars, kmsVar{k, encodedPart}) + continue + } + parameter, err := e.parameter(k, v) if err != nil { // TODO: Should this _also_ not error if nofail is passed? @@ -197,34 +247,47 @@ func (e *expander) expandEnviron(decrypt bool, nofail bool) error { } } - if len(uniqNames) == 0 { - // Nothing to do, no SSM parameters. - return nil - } + // Process SSM parameters + if len(uniqNames) > 0 { + names := make([]string, len(uniqNames)) + i := 0 + for k := range uniqNames { + names[i] = k + i++ + } - names := make([]string, len(uniqNames)) - i := 0 - for k := range uniqNames { - names[i] = k - i++ - } + for i := 0; i < len(names); i += e.batchSize { + j := i + e.batchSize + if j > len(names) { + j = len(names) + } - for i := 0; i < len(names); i += e.batchSize { - j := i + e.batchSize - if j > len(names) { - j = len(names) - } + values, err := e.getParameters(names[i:j], decrypt, nofail) + if err != nil { + return err + } - values, err := e.getParameters(names[i:j], decrypt, nofail) - if err != nil { - return err + for _, v := range ssmVars { + val, ok := values[v.parameter] + if ok { + e.os.Setenv(v.envvar, val) + } + } } + } - for _, v := range ssmVars { - val, ok := values[v.parameter] - if ok { - e.os.Setenv(v.envvar, val) + // Process KMS encrypted values + if len(kmsVars) > 0 { + for _, kv := range kmsVars { + decryptedValue, err := e.decryptKmsValue(kv.encoded, nofail) + if err != nil { + if nofail { + fmt.Fprintf(os.Stderr, "ssm-env: failed to decrypt KMS value: %v\n", err) + continue + } + return fmt.Errorf("failed to decrypt KMS value: %v", err) } + e.os.Setenv(kv.envvar, decryptedValue) } } @@ -287,6 +350,32 @@ func (e *invalidParametersError) Error() string { return fmt.Sprintf("invalid parameters: %v", e.InvalidParameters) } +// decryptKmsValue decrypts a base64-encoded KMS-encrypted value. +func (e *expander) decryptKmsValue(encodedValue string, nofail bool) (string, error) { + // Add padding to base64 if needed + // Base64 encoding requires the string length to be a multiple of 4 + padding := len(encodedValue) % 4 + if padding != 0 { + encodedValue = encodedValue + strings.Repeat("=", 4-padding) + } + + decodedBytes, err := base64.StdEncoding.DecodeString(encodedValue) + if err != nil { + return "", fmt.Errorf("failed to decode base64 value: %v", err) + } + + input := &kms.DecryptInput{ + CiphertextBlob: decodedBytes, + } + + result, err := e.kms.Decrypt(input) + if err != nil { + return "", err + } + + return string(result.Plaintext), nil +} + func splitVar(v string) (key, val string) { parts := strings.Split(v, "=") return parts[0], parts[1] diff --git a/main_test.go b/main_test.go index f64b635..1867711 100644 --- a/main_test.go +++ b/main_test.go @@ -1,12 +1,15 @@ package main import ( + "encoding/base64" "fmt" "sort" + "strings" "testing" "text/template" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/kms" "github.com/aws/aws-sdk-go/service/ssm" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -15,10 +18,12 @@ import ( func TestExpandEnviron_NoSSMParameters(t *testing.T) { os := newFakeEnviron() c := new(mockSSM) + k := new(mockKMS) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), os: os, ssm: c, + kms: k, batchSize: defaultBatchSize, } @@ -38,10 +43,12 @@ func TestExpandEnviron_NoSSMParameters(t *testing.T) { func TestExpandEnviron_SimpleSSMParameter(t *testing.T) { os := newFakeEnviron() c := new(mockSSM) + k := new(mockKMS) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), os: os, ssm: c, + kms: k, batchSize: defaultBatchSize, } @@ -73,10 +80,12 @@ func TestExpandEnviron_SimpleSSMParameter(t *testing.T) { func TestExpandEnviron_VersionedSSMParameter(t *testing.T) { os := newFakeEnviron() c := new(mockSSM) + k := new(mockKMS) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), os: os, ssm: c, + kms: k, batchSize: defaultBatchSize, } @@ -108,10 +117,12 @@ func TestExpandEnviron_VersionedSSMParameter(t *testing.T) { func TestExpandEnviron_CustomTemplate(t *testing.T) { os := newFakeEnviron() c := new(mockSSM) + k := new(mockKMS) e := expander{ t: template.Must(parseTemplate(`{{ if eq .Name "SUPER_SECRET" }}/secret{{end}}`)), os: os, ssm: c, + kms: k, batchSize: defaultBatchSize, } @@ -143,10 +154,12 @@ func TestExpandEnviron_CustomTemplate(t *testing.T) { func TestExpandEnviron_DuplicateSSMParameter(t *testing.T) { os := newFakeEnviron() c := new(mockSSM) + k := new(mockKMS) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), os: os, ssm: c, + kms: k, batchSize: defaultBatchSize, } @@ -180,10 +193,12 @@ func TestExpandEnviron_DuplicateSSMParameter(t *testing.T) { func TestExpandEnviron_MalformedParametersFail(t *testing.T) { os := newFakeEnviron() c := new(mockSSM) + k := new(mockKMS) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), os: os, ssm: c, + kms: k, batchSize: defaultBatchSize, } @@ -198,10 +213,12 @@ func TestExpandEnviron_MalformedParametersFail(t *testing.T) { func TestExpandEnviron_MalformedParametersNofail(t *testing.T) { os := newFakeEnviron() c := new(mockSSM) + k := new(mockKMS) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), os: os, ssm: c, + kms: k, batchSize: defaultBatchSize, } @@ -231,10 +248,12 @@ func TestExpandEnviron_MalformedParametersNofail(t *testing.T) { func TestExpandEnviron_InvalidParameters(t *testing.T) { os := newFakeEnviron() c := new(mockSSM) + k := new(mockKMS) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), os: os, ssm: c, + kms: k, batchSize: defaultBatchSize, } @@ -258,10 +277,12 @@ func TestExpandEnviron_InvalidParameters(t *testing.T) { func TestExpandEnviron_InvalidParametersNoFail(t *testing.T) { os := newFakeEnviron() c := new(mockSSM) + k := new(mockKMS) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), os: os, ssm: c, + kms: k, batchSize: defaultBatchSize, } @@ -291,10 +312,12 @@ func TestExpandEnviron_InvalidParametersNoFail(t *testing.T) { func TestExpandEnviron_BatchParameters(t *testing.T) { os := newFakeEnviron() c := new(mockSSM) + k := new(mockKMS) e := expander{ t: template.Must(parseTemplate(DefaultTemplate)), os: os, ssm: c, + kms: k, batchSize: 1, } @@ -334,6 +357,340 @@ func TestExpandEnviron_BatchParameters(t *testing.T) { c.AssertExpectations(t) } +// TestExtractKmsValue tests the KMS value extraction logic. +// It verifies: +// - Values starting with `!kms ` are correctly identified as KMS-encrypted +// - The base64-encoded part is properly extracted, removing the prefix and any quotes +// - Various formats of quoted/unquoted values are handled correctly +// - Non-KMS values are correctly identified as not matching the pattern +func TestExtractKmsValue(t *testing.T) { + // Test the KMS value extraction logic + + testCases := []struct { + input string + shouldMatch bool + expectedBase64 string + }{ + {"!kms QUJDREVGRw==", true, "QUJDREVGRw=="}, + {"!kms abc", true, "abc"}, + {"!kms abc===", true, "abc==="}, + {"!kms ", true, ""}, + {"!kmsa", false, ""}, + {"ssm:///path", false, ""}, + } + + for _, tc := range testCases { + if tc.shouldMatch { + // If it should match, we expect the string to start with the KMS prefix + if strings.HasPrefix(tc.input, KMSPrefix) { + // Extract the base64 part + encodedPart := strings.TrimPrefix(tc.input, KMSPrefix) + // Remove quotes + encodedPart = strings.Trim(encodedPart, "'\" ") + assert.Equal(t, tc.expectedBase64, encodedPart, + "Wrong base64 extraction for '%s'", tc.input) + } else { + assert.Fail(t, "Expected '%s' to start with KMS prefix", tc.input) + } + } else { + // If it should not match, we expect the string NOT to start with the KMS prefix + assert.False(t, strings.HasPrefix(tc.input, KMSPrefix), + "Expected '%s' not to start with KMS prefix", tc.input) + } + } +} + +// TestDecryptKmsValue tests the decryptKmsValue function directly. +// It: +// - Verifies base64 decoding works correctly +// - Mocks the KMS client to simulate decryption +// - Confirms the decryption process returns the expected value +// - Tests the function in isolation from the rest of the environment expansion logic +func TestDecryptKmsValue(t *testing.T) { + // Test the decryptKmsValue function directly + kmsClient := new(mockKMS) + e := &expander{ + kms: kmsClient, + } + + // Verify the base64 decoding works as expected + value := "QUJDREVGR0g=" + decoded, err := base64.StdEncoding.DecodeString(value) + assert.NoError(t, err) + assert.Equal(t, "ABCDEFGH", string(decoded)) + + // Setup KMS mock + kmsClient.On("Decrypt", mock.MatchedBy(func(input *kms.DecryptInput) bool { + return string(input.CiphertextBlob) == "ABCDEFGH" + })).Return(&kms.DecryptOutput{ + Plaintext: []byte("decrypted-secret"), + }, nil) + + // Call the function directly + result, err := e.decryptKmsValue("QUJDREVGR0g=", false) + assert.NoError(t, err) + assert.Equal(t, "decrypted-secret", result) + + // Verify expectations + kmsClient.AssertExpectations(t) +} + +// A simpler implementation of test environment that works specifically for KMS values +type testEnviron struct { + env map[string]string +} + +func newTestEnviron() testEnviron { + return testEnviron{ + env: map[string]string{ + "SHELL": "/bin/bash", + "TERM": "screen-256color", + }, + } +} + +func (e testEnviron) Environ() []string { + var env []string + for k, v := range e.env { + envStr := fmt.Sprintf("%s=%s", k, v) + env = append(env, envStr) + } + return env +} + +func (e testEnviron) Setenv(key, val string) { + e.env[key] = val +} + +// TestExpandEnviron_KMSParameter mirrors the TestExpandEnviron_SimpleSSMParameter test but for KMS variables. +// It: +// - Sets up an environment with a KMS-encrypted value +// - Mocks the KMS client to return a decrypted value +// - Verifies the environment variable gets properly replaced with the decrypted value +func TestExpandEnviron_KMSParameter(t *testing.T) { + // Create a testEnviron instance that works better for our test + os := newTestEnviron() + + // Set the KMS env var with proper base64 + // ABCDEFGH -> QUJDREVGR0g= (but avoid quotes which cause problems) + kmsValue := "!kms QUJDREVGR0g=" + os.env["KMS_SECRET"] = kmsValue + + // Create mock clients + kmsClient := new(mockKMS) + ssmClient := new(mockSSM) + + // Create expander + e := expander{ + t: template.Must(parseTemplate(DefaultTemplate)), + os: os, + ssm: ssmClient, + kms: kmsClient, + batchSize: defaultBatchSize, + } + + // Setup KMS mock + kmsClient.On("Decrypt", mock.MatchedBy(func(input *kms.DecryptInput) bool { + return string(input.CiphertextBlob) == "ABCDEFGH" + })).Return(&kms.DecryptOutput{ + Plaintext: []byte("decrypted-secret"), + }, nil) + + // Call expandEnviron + err := e.expandEnviron(false, false) + assert.NoError(t, err) + + // Check the result + assert.Equal(t, "decrypted-secret", os.env["KMS_SECRET"]) + + // Verify expectations + kmsClient.AssertExpectations(t) +} + +// TestExpandEnviron_KMSAndSSMParameters verifies both KMS and SSM parameters can be processed together. +// It: +// - Creates an environment with both SSM and KMS variables +// - Mocks both clients to return appropriate values +// - Confirms both types of variables are correctly replaced +func TestExpandEnviron_KMSAndSSMParameters(t *testing.T) { + // Create a testEnviron instance that works better for our test + os := newTestEnviron() + + // Set both SSM parameter and KMS encrypted value + os.env["SSM_SECRET"] = "ssm:///secret" + // The base64 value "QUJDREVGR0g=" decodes to "ABCDEFGH" + os.env["KMS_SECRET"] = "!kms QUJDREVGR0g=" + + // Create mock clients + ssmClient := new(mockSSM) + kmsClient := new(mockKMS) + + e := expander{ + t: template.Must(parseTemplate(DefaultTemplate)), + os: os, + ssm: ssmClient, + kms: kmsClient, + batchSize: defaultBatchSize, + } + + // Setup mocks + ssmClient.On("GetParameters", &ssm.GetParametersInput{ + Names: []*string{aws.String("/secret")}, + WithDecryption: aws.Bool(false), + }).Return(&ssm.GetParametersOutput{ + Parameters: []*ssm.Parameter{ + {Name: aws.String("/secret"), Value: aws.String("ssm-value")}, + }, + }, nil) + + kmsClient.On("Decrypt", mock.MatchedBy(func(input *kms.DecryptInput) bool { + return string(input.CiphertextBlob) == "ABCDEFGH" + })).Return(&kms.DecryptOutput{ + Plaintext: []byte("kms-value"), + }, nil) + + decrypt := false + nofail := false + err := e.expandEnviron(decrypt, nofail) + assert.NoError(t, err) + + // Check the values directly + assert.Equal(t, "kms-value", os.env["KMS_SECRET"]) + assert.Equal(t, "ssm-value", os.env["SSM_SECRET"]) + + ssmClient.AssertExpectations(t) + kmsClient.AssertExpectations(t) +} + +// TestExpandEnviron_InvalidKMSParameter mirrors TestExpandEnviron_MalformedParametersFail but for KMS values. +// It: +// - Tests behavior when an invalid base64 string is provided as a KMS value +// - Verifies that without the -no-fail flag, the process errors out with a base64 decoding error +func TestExpandEnviron_InvalidKMSParameter(t *testing.T) { + // Create a testEnviron instance that works better for our test + os := newTestEnviron() + + // Set KMS encrypted value with invalid base64 + os.env["KMS_SECRET"] = "!kms INVALID-BASE64!" + + // Create mock clients + ssmClient := new(mockSSM) + kmsClient := new(mockKMS) + + e := expander{ + t: template.Must(parseTemplate(DefaultTemplate)), + os: os, + ssm: ssmClient, + kms: kmsClient, + batchSize: defaultBatchSize, + } + + // No mocks needed as it should fail at base64 decoding stage + + decrypt := false + nofail := false + err := e.expandEnviron(decrypt, nofail) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode base64 value") +} + +// TestExpandEnviron_InvalidKMSParameterNoFail mirrors TestExpandEnviron_MalformedParametersNofail for KMS values. +// It: +// - Tests the same case as TestExpandEnviron_InvalidKMSParameter but with -no-fail set to true +// - Verifies that the environment variable remains unchanged when there's an error +// - Confirms the process continues without failing, as directed by the flag +func TestExpandEnviron_InvalidKMSParameterNoFail(t *testing.T) { + // Create a testEnviron instance that works better for our test + os := newTestEnviron() + + // Set KMS encrypted value with invalid base64 + os.env["KMS_SECRET"] = "!kms INVALID-BASE64!" + + // Create mock clients + ssmClient := new(mockSSM) + kmsClient := new(mockKMS) + + e := expander{ + t: template.Must(parseTemplate(DefaultTemplate)), + os: os, + ssm: ssmClient, + kms: kmsClient, + batchSize: defaultBatchSize, + } + + // With nofail=true, it shouldn't error + decrypt := false + nofail := true + err := e.expandEnviron(decrypt, nofail) + assert.NoError(t, err) + + // And the environment variable should remain unchanged + assert.Equal(t, "!kms INVALID-BASE64!", os.env["KMS_SECRET"]) +} + +// TestExpandEnviron_KMSDecryptFails tests error handling when KMS decryption fails. +// It: +// - Sets up a valid base64-encoded value but makes the KMS client return an error +// - Tests both with and without the -no-fail flag to verify correct behavior +// - Confirms error messages are properly reported +// - Verifies that with -no-fail, the original value is preserved and execution continues +func TestExpandEnviron_KMSDecryptFails(t *testing.T) { + // Create a testEnviron instance that works better for our test + os := newTestEnviron() + + // The base64 value "QUJDREVGR0g=" decodes to "ABCDEFGH" + os.env["KMS_SECRET"] = "!kms QUJDREVGR0g=" + + // Create mock clients + ssmClient := new(mockSSM) + kmsClient := new(mockKMS) + + e := expander{ + t: template.Must(parseTemplate(DefaultTemplate)), + os: os, + ssm: ssmClient, + kms: kmsClient, + batchSize: defaultBatchSize, + } + + // Setup mock for KMS Decrypt operation to fail + kmsClient.On("Decrypt", mock.MatchedBy(func(input *kms.DecryptInput) bool { + return string(input.CiphertextBlob) == "ABCDEFGH" + })).Return(&kms.DecryptOutput{}, fmt.Errorf("KMS error: access denied")) + + // With nofail=false, it should error + decrypt := false + nofail := false + err := e.expandEnviron(decrypt, nofail) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decrypt KMS value") + assert.Contains(t, err.Error(), "KMS error: access denied") + + // Re-setup the mock for the second test since it was already used + kmsClient = new(mockKMS) + e.kms = kmsClient + kmsClient.On("Decrypt", mock.MatchedBy(func(input *kms.DecryptInput) bool { + return string(input.CiphertextBlob) == "ABCDEFGH" + })).Return(&kms.DecryptOutput{}, fmt.Errorf("KMS error: access denied")) + + // With nofail=true, it shouldn't error + nofail = true + err = e.expandEnviron(decrypt, nofail) + assert.NoError(t, err) + + // And the environment variable should remain unchanged + found := false + for _, env := range os.Environ() { + if strings.HasPrefix(env, "KMS_SECRET=") { + found = true + assert.Equal(t, "KMS_SECRET=!kms QUJDREVGR0g=", env) + } + } + assert.True(t, found, "KMS_SECRET environment variable should still exist") + + kmsClient.AssertExpectations(t) +} + type fakeEnviron map[string]string func newFakeEnviron() fakeEnviron { @@ -346,6 +703,7 @@ func newFakeEnviron() fakeEnviron { func (e fakeEnviron) Environ() []string { var env sort.StringSlice for k, v := range e { + // Force raw string to preserve all characters including single quotes and = signs env = append(env, fmt.Sprintf("%s=%s", k, v)) } env.Sort() @@ -364,3 +722,12 @@ func (m *mockSSM) GetParameters(input *ssm.GetParametersInput) (*ssm.GetParamete args := m.Called(input) return args.Get(0).(*ssm.GetParametersOutput), args.Error(1) } + +type mockKMS struct { + mock.Mock +} + +func (m *mockKMS) Decrypt(input *kms.DecryptInput) (*kms.DecryptOutput, error) { + args := m.Called(input) + return args.Get(0).(*kms.DecryptOutput), args.Error(1) +}