Skip to content

Commit e669ff9

Browse files
committed
All tests in aws_tests work
1 parent f2c0c82 commit e669ff9

File tree

3 files changed

+101
-84
lines changed

3 files changed

+101
-84
lines changed

deps/rabbitmq_aws/src/rabbitmq_aws.erl

Lines changed: 98 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
close_connection/1,
2727
direct_request/6,
2828
endpoint/4,
29-
sign_headers/9
29+
sign_headers/10
3030
]).
3131

3232
%% Export all for unit tests
@@ -37,83 +37,7 @@
3737
-include("rabbitmq_aws.hrl").
3838
-include_lib("kernel/include/logger.hrl").
3939

40-
-type connection_handle() :: {gun:conn_ref(), string()}.
41-
%%====================================================================
42-
%% ETS-based state management
43-
%%====================================================================
44-
45-
-spec get_credentials() ->
46-
{ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}.
47-
get_credentials() ->
48-
get_credentials(10).
49-
50-
-spec get_credentials(Retries :: non_neg_integer()) ->
51-
{ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}.
52-
get_credentials(Retries) ->
53-
case ets:lookup(?AWS_CREDENTIALS_TABLE, current) of
54-
[{current, Creds}] ->
55-
case expired_credentials(Creds#aws_credentials.expiration) of
56-
false ->
57-
Region = get_region(),
58-
{ok, Creds#aws_credentials.access_key, Creds#aws_credentials.secret_key,
59-
Creds#aws_credentials.security_token, Region};
60-
true ->
61-
refresh_credentials_with_lock(Retries)
62-
end;
63-
[] ->
64-
refresh_credentials_with_lock(Retries)
65-
end.
66-
67-
-spec refresh_credentials_with_lock(Retries :: non_neg_integer()) ->
68-
{ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}.
69-
refresh_credentials_with_lock(0) ->
70-
{error, lock_timeout};
71-
refresh_credentials_with_lock(Retries) ->
72-
LockId = {aws_credentials_refresh, node()},
73-
case global:set_lock(LockId, [node()], 0) of
74-
true ->
75-
try
76-
% Double-check if someone else already refreshed
77-
case ets:lookup(?AWS_CREDENTIALS_TABLE, current) of
78-
[{current, Creds}] ->
79-
case expired_credentials(Creds#aws_credentials.expiration) of
80-
false ->
81-
Region = get_region(),
82-
{ok, Creds#aws_credentials.access_key,
83-
Creds#aws_credentials.secret_key,
84-
Creds#aws_credentials.security_token, Region};
85-
true ->
86-
do_refresh_credentials()
87-
end;
88-
[] ->
89-
do_refresh_credentials()
90-
end
91-
after
92-
global:del_lock(LockId, [node()])
93-
end;
94-
false ->
95-
% Someone else is refreshing, wait and retry
96-
timer:sleep(100),
97-
get_credentials(Retries - 1)
98-
end.
99-
100-
-spec do_refresh_credentials() ->
101-
{ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}.
102-
do_refresh_credentials() ->
103-
Region = get_region(),
104-
case rabbitmq_aws_config:credentials() of
105-
{ok, AccessKey, SecretAccessKey, Expiration, SecurityToken} ->
106-
Creds = #aws_credentials{
107-
access_key = AccessKey,
108-
secret_key = SecretAccessKey,
109-
security_token = SecurityToken,
110-
expiration = Expiration
111-
},
112-
ets:insert(?AWS_CREDENTIALS_TABLE, {current, Creds}),
113-
{ok, AccessKey, SecretAccessKey, SecurityToken, Region};
114-
{error, Reason} ->
115-
{error, Reason}
116-
end.
40+
-type connection_handle() :: {pid(), string()}.
11741

11842
-spec get_region() -> region().
11943
get_region() ->
@@ -258,7 +182,16 @@ direct_request({GunPid, Service}, Method, Path, Body, Headers, Options) ->
258182
URI = create_uri(Host, Path),
259183
BodyHash = proplists:get_value(payload_hash, Options),
260184
SignedHeaders = sign_headers(
261-
AccessKey, SecretKey, SecurityToken, Region, Service, Method, URI, Headers, Body, BodyHash
185+
AccessKey,
186+
SecretKey,
187+
SecurityToken,
188+
Region,
189+
Service,
190+
Method,
191+
URI,
192+
Headers,
193+
Body,
194+
BodyHash
262195
),
263196
direct_gun_request(GunPid, Method, Path, SignedHeaders, Body, Options);
264197
{error, Reason} ->
@@ -277,7 +210,9 @@ direct_request({GunPid, Service}, Method, Path, Body, Headers, Options) ->
277210
Body :: body(),
278211
BodyHash :: iodata()
279212
) -> headers().
280-
sign_headers(AccessKey, SecretKey, SecurityToken, Region, Service, Method, URI, Headers, Body, BodyHash) ->
213+
sign_headers(
214+
AccessKey, SecretKey, SecurityToken, Region, Service, Method, URI, Headers, Body, BodyHash
215+
) ->
281216
rabbitmq_aws_sign:headers(
282217
#request{
283218
access_key = AccessKey,
@@ -388,7 +323,16 @@ perform_request_direct(Service, Method, Headers, Path, Body, Options, Host) ->
388323
{ok, AccessKey, SecretKey, SecurityToken, Region} ->
389324
URI = endpoint(Region, Host, Service, Path),
390325
SignedHeaders = sign_headers(
391-
AccessKey, SecretKey, SecurityToken, Region, Service, Method, URI, Headers, Body
326+
AccessKey,
327+
SecretKey,
328+
SecurityToken,
329+
Region,
330+
Service,
331+
Method,
332+
URI,
333+
Headers,
334+
Body,
335+
undefined
392336
),
393337
gun_request(Method, URI, SignedHeaders, Body, Options);
394338
{error, Reason} ->
@@ -422,6 +366,79 @@ endpoint_tld("cn-northwest-1") ->
422366
endpoint_tld(_Other) ->
423367
"amazonaws.com".
424368

369+
-spec get_credentials() ->
370+
{ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}.
371+
get_credentials() ->
372+
get_credentials(10).
373+
374+
-spec get_credentials(Retries :: non_neg_integer()) ->
375+
{ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}.
376+
get_credentials(Retries) ->
377+
case ets:lookup(?AWS_CREDENTIALS_TABLE, current) of
378+
[{current, Creds}] ->
379+
case expired_credentials(Creds#aws_credentials.expiration) of
380+
false ->
381+
Region = get_region(),
382+
{ok, Creds#aws_credentials.access_key, Creds#aws_credentials.secret_key,
383+
Creds#aws_credentials.security_token, Region};
384+
true ->
385+
refresh_credentials_with_lock(Retries)
386+
end;
387+
[] ->
388+
refresh_credentials_with_lock(Retries)
389+
end.
390+
391+
-spec refresh_credentials_with_lock(Retries :: non_neg_integer()) ->
392+
{ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}.
393+
refresh_credentials_with_lock(0) ->
394+
{error, lock_timeout};
395+
refresh_credentials_with_lock(Retries) ->
396+
LockId = {aws_credentials_refresh, node()},
397+
case global:set_lock(LockId, [node()], 0) of
398+
true ->
399+
try
400+
% Double-check if someone else already refreshed
401+
case ets:lookup(?AWS_CREDENTIALS_TABLE, current) of
402+
[{current, Creds}] ->
403+
case expired_credentials(Creds#aws_credentials.expiration) of
404+
false ->
405+
Region = get_region(),
406+
{ok, Creds#aws_credentials.access_key,
407+
Creds#aws_credentials.secret_key,
408+
Creds#aws_credentials.security_token, Region};
409+
true ->
410+
do_refresh_credentials()
411+
end;
412+
[] ->
413+
do_refresh_credentials()
414+
end
415+
after
416+
global:del_lock(LockId, [node()])
417+
end;
418+
false ->
419+
% Someone else is refreshing, wait and retry
420+
timer:sleep(100),
421+
get_credentials(Retries - 1)
422+
end.
423+
424+
-spec do_refresh_credentials() ->
425+
{ok, access_key(), secret_access_key(), security_token(), region()} | {error, term()}.
426+
do_refresh_credentials() ->
427+
Region = get_region(),
428+
case rabbitmq_aws_config:credentials() of
429+
{ok, AccessKey, SecretAccessKey, Expiration, SecurityToken} ->
430+
Creds = #aws_credentials{
431+
access_key = AccessKey,
432+
secret_key = SecretAccessKey,
433+
security_token = SecurityToken,
434+
expiration = Expiration
435+
},
436+
ets:insert(?AWS_CREDENTIALS_TABLE, {current, Creds}),
437+
{ok, AccessKey, SecretAccessKey, SecurityToken, Region};
438+
{error, Reason} ->
439+
{error, Reason}
440+
end.
441+
425442
-spec format_response(Response :: httpc_result()) -> result().
426443
%% @doc Format the httpc response result, returning the request result data
427444
%% structure. The response body will attempt to be decoded by invoking the
@@ -676,7 +693,6 @@ status_text(416) -> "Range Not Satisfiable";
676693
status_text(500) -> "Internal Server Error";
677694
status_text(Code) -> integer_to_list(Code).
678695

679-
680696
-spec direct_gun_request(
681697
GunPid :: pid(),
682698
Method :: method(),

deps/rabbitmq_aws/test/rabbitmq_aws_sign_tests.erl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ request_hash_test_() ->
221221
{"Host", "iam.amazonaws.com"},
222222
{"Date", "20150830T123600Z"}
223223
],
224-
Payload = "",
224+
Payload = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
225225
Expectation = "49b454e0f20fe17f437eaa570846fc5d687efc1752c8b5a1eeee5597a7eb92a5",
226226
?assertEqual(
227227
Expectation,

deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,8 @@ sign_headers_test_() ->
360360
Method,
361361
URI,
362362
Headers,
363-
Body
363+
Body,
364+
undefined
364365
)
365366
),
366367
meck:validate(calendar)

0 commit comments

Comments
 (0)