Skip to content

Commit 4acb0f0

Browse files
committed
wip
1 parent 84e0478 commit 4acb0f0

File tree

3 files changed

+130
-17
lines changed

3 files changed

+130
-17
lines changed

deps/rabbitmq_aws/include/rabbitmq_aws.hrl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@
6868
security_token :: security_token() | undefined,
6969
region :: region() | undefined,
7070
imdsv2_token:: imdsv2token() | undefined,
71-
error :: atom() | string() | undefined}).
71+
error :: atom() | string() | undefined,
72+
gun_connections = #{} :: #{string() => pid()} % host -> gun_pid mapping
73+
}).
7274
-type state() :: #state{}.
7375

7476
-type scheme() :: atom().

deps/rabbitmq_aws/src/rabbitmq_aws.erl

Lines changed: 125 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,11 @@ init([]) ->
180180
{ok, #state{}}.
181181

182182

183-
terminate(_, _) ->
183+
terminate(_, State) ->
184+
% Close all Gun connections
185+
maps:fold(fun(_Host, ConnPid, _Acc) ->
186+
gun:close(ConnPid)
187+
end, ok, State#state.gun_connections),
184188
ok.
185189

186190

@@ -223,7 +227,8 @@ handle_msg({set_credentials, NewState}, State) ->
223227
secret_access_key = NewState#state.secret_access_key,
224228
security_token = NewState#state.security_token,
225229
expiration = NewState#state.expiration,
226-
error = NewState#state.error}};
230+
error = NewState#state.error,
231+
gun_connections = State#state.gun_connections}};
227232

228233
handle_msg({set_region, Region}, State) ->
229234
{reply, ok, State#state{region = Region}};
@@ -293,7 +298,7 @@ get_content_type(Headers) ->
293298
proplists:get_value("Content-Type", Headers, "text/xml");
294299
Other -> Other
295300
end,
296-
parse_content_type(Value).
301+
parse_content_type(Value).
297302

298303
-spec has_credentials() -> boolean().
299304
has_credentials() ->
@@ -324,7 +329,7 @@ expired_credentials(Expiration) ->
324329
%% - Credentials file
325330
%% - EC2 Instance Metadata Service
326331
%% @end
327-
load_credentials(#state{region = Region}) ->
332+
load_credentials(#state{region = Region, gun_connections = GunConnections}) ->
328333
case rabbitmq_aws_config:credentials() of
329334
{ok, AccessKey, SecretAccessKey, Expiration, SecurityToken} ->
330335
{ok, #state{region = Region,
@@ -333,7 +338,8 @@ load_credentials(#state{region = Region}) ->
333338
secret_access_key = SecretAccessKey,
334339
expiration = Expiration,
335340
security_token = SecurityToken,
336-
imdsv2_token = undefined}};
341+
imdsv2_token = undefined,
342+
gun_connections = GunConnections}};
337343
{error, Reason} ->
338344
?LOG_ERROR("Could not load AWS credentials from environment variables, AWS_CONFIG_FILE, AWS_SHARED_CREDENTIALS_FILE or EC2 metadata endpoint: ~tp. Will depend on config settings to be set~n", [Reason]),
339345
{error, #state{region = Region,
@@ -342,7 +348,8 @@ load_credentials(#state{region = Region}) ->
342348
secret_access_key = undefined,
343349
expiration = undefined,
344350
security_token = undefined,
345-
imdsv2_token = undefined}}
351+
imdsv2_token = undefined,
352+
gun_connections = GunConnections}}
346353
end.
347354

348355

@@ -383,7 +390,7 @@ parse_content_type(ContentType) ->
383390
%% @doc Make the API request and return the formatted response.
384391
%% @end
385392
perform_request(State, Service, Method, Headers, Path, Body, Options, Host) ->
386-
perform_request_has_creds(has_credentials(State), State, Service, Method,
393+
perform_request_has_creds(has_credentials(State), State, Service, Method,
387394
Headers, Path, Body, Options, Host).
388395

389396

@@ -397,7 +404,7 @@ perform_request(State, Service, Method, Headers, Path, Body, Options, Host) ->
397404
%% otherwise return an error result.
398405
%% @end
399406
perform_request_has_creds(true, State, Service, Method, Headers, Path, Body, Options, Host) ->
400-
perform_request_creds_expired(expired_credentials(State#state.expiration), State,
407+
perform_request_creds_expired(expired_credentials(State#state.expiration), State,
401408
Service, Method, Headers, Path, Body, Options, Host);
402409
perform_request_has_creds(false, State, _, _, _, _, _, _, _) ->
403410
perform_request_creds_error(State).
@@ -413,7 +420,7 @@ perform_request_has_creds(false, State, _, _, _, _, _, _, _) ->
413420
%% credentials before performing the request.
414421
%% @end
415422
perform_request_creds_expired(false, State, Service, Method, Headers, Path, Body, Options, Host) ->
416-
perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, Host);
423+
perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, Host);
417424
perform_request_creds_expired(true, State, _, _, _, _, _, _, _) ->
418425
perform_request_creds_error(State#state{error = "Credentials expired!"}).
419426

@@ -429,7 +436,7 @@ perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options,
429436
URI = endpoint(State, Host, Service, Path),
430437
SignedHeaders = sign_headers(State, Service, Method, URI, Headers, Body),
431438
ContentType = proplists:get_value("content-type", SignedHeaders, undefined),
432-
perform_request_with_creds(State, Method, URI, SignedHeaders, ContentType, Body, Options).
439+
perform_request_with_creds(State, Method, URI, SignedHeaders, ContentType, Body, Options).
433440

434441

435442
-spec perform_request_with_creds(State :: state(), Method :: method(), URI :: string(),
@@ -440,13 +447,12 @@ perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options,
440447
%% expired, perform the request and return the response.
441448
%% @end
442449
perform_request_with_creds(State, Method, URI, Headers, undefined, "", Options0) ->
443-
Options1 = ensure_timeout(Options0),
444-
Response = httpc:request(Method, {URI, Headers}, Options1, []),
445-
{format_response(Response), State};
450+
{Response, NewState} = gun_request(State, Method, URI, Headers, <<>>, Options0),
451+
{format_response(Response), NewState};
446452
perform_request_with_creds(State, Method, URI, Headers, ContentType, Body, Options0) ->
447-
Options1 = ensure_timeout(Options0),
448-
Response = httpc:request(Method, {URI, Headers, ContentType, Body}, Options1, []),
449-
{format_response(Response), State}.
453+
GunHeaders = [{"content-type", ContentType} | Headers],
454+
{Response, NewState} = gun_request(State, Method, URI, GunHeaders, Body, Options0),
455+
{format_response(Response), NewState}.
450456

451457

452458
-spec perform_request_creds_error(State :: state()) ->
@@ -567,3 +573,106 @@ api_get_request_with_retries(Service, Path, Retries, WaitTimeBetweenRetries) ->
567573
timer:sleep(WaitTimeBetweenRetries),
568574
api_get_request_with_retries(Service, Path, Retries - 1, WaitTimeBetweenRetries)
569575
end.
576+
577+
%% Gun HTTP client functions
578+
gun_request(State, Method, URI, Headers, Body, Options) ->
579+
{Host, Port, Path} = parse_uri(URI),
580+
{ConnPid, NewState} = get_or_create_gun_connection(State, Host, Port, Options),
581+
Timeout = proplists:get_value(timeout, Options, ?DEFAULT_HTTP_TIMEOUT),
582+
try
583+
StreamRef = gun:get(ConnPid, Path, Headers),
584+
case gun:await(ConnPid, StreamRef, Timeout) of
585+
{response, fin, Status, RespHeaders} ->
586+
Response = {ok, {{http_version, Status, status_text(Status)}, RespHeaders, <<>>}},
587+
{Response, NewState};
588+
{response, nofin, Status, RespHeaders} ->
589+
{ok, RespBody} = gun:await_body(ConnPid, StreamRef, Timeout),
590+
Response = {ok, {{http_version, Status, status_text(Status)}, RespHeaders, binary_to_list(RespBody)}},
591+
{Response, NewState};
592+
{error, Reason} ->
593+
{{error, Reason}, NewState}
594+
end
595+
catch
596+
_:Error ->
597+
% Connection failed, remove from pool and return error
598+
NewConnections = maps:remove(Host, NewState#state.gun_connections),
599+
gun:close(ConnPid),
600+
{{error, Error}, NewState#state{gun_connections = NewConnections}}
601+
end.
602+
603+
get_or_create_gun_connection(State, Host, Port, Options) ->
604+
HostKey = Host ++ ":" ++ integer_to_list(Port),
605+
case maps:get(HostKey, State#state.gun_connections, undefined) of
606+
undefined ->
607+
create_gun_connection(State, Host, Port, HostKey, Options);
608+
ConnPid ->
609+
case is_process_alive(ConnPid) andalso gun:info(ConnPid) =/= undefined of
610+
true ->
611+
{ConnPid, State};
612+
false ->
613+
% Connection is dead, create new one
614+
gun:close(ConnPid),
615+
create_gun_connection(State, Host, Port, HostKey, Options)
616+
end
617+
end.
618+
619+
create_gun_connection(State, Host, Port, HostKey, Options) ->
620+
% Map HTTP version to Gun protocols, always include http as fallback
621+
HttpVersion = proplists:get_value(version, Options, "HTTP/1.1"),
622+
Protocols = case HttpVersion of
623+
"HTTP/2" -> [http2, http];
624+
"HTTP/2.0" -> [http2, http];
625+
"HTTP/1.1" -> [http];
626+
"HTTP/1.0" -> [http];
627+
_ -> [http2, http] % Default: try HTTP/2, fallback to HTTP/1.1
628+
end,
629+
ConnectTimeout = proplists:get_value(connect_timeout, Options, 5000),
630+
Opts = #{
631+
transport => if Port == 443 -> tls; true -> tcp end,
632+
protocols => Protocols,
633+
connect_timeout => ConnectTimeout
634+
},
635+
application:ensure_all_started(gun),
636+
case gun:open(Host, Port, Opts) of
637+
{ok, ConnPid} ->
638+
case gun:await_up(ConnPid, ConnectTimeout) of
639+
{ok, _Protocol} ->
640+
NewConnections = maps:put(HostKey, ConnPid, State#state.gun_connections),
641+
NewState = State#state{gun_connections = NewConnections},
642+
{ConnPid, NewState};
643+
{error, Reason} ->
644+
gun:close(ConnPid),
645+
error({gun_connection_failed, Reason})
646+
end;
647+
{error, Reason} ->
648+
error({gun_open_failed, Reason})
649+
end.
650+
651+
parse_uri(URI) ->
652+
case string:split(URI, "://", leading) of
653+
[_Scheme, Rest] ->
654+
case string:split(Rest, "/", leading) of
655+
[HostPort] ->
656+
{Host, Port} = parse_host_port(HostPort),
657+
{Host, Port, "/"};
658+
[HostPort, Path] ->
659+
{Host, Port} = parse_host_port(HostPort),
660+
{Host, Port, "/" ++ Path}
661+
end
662+
end.
663+
664+
parse_host_port(HostPort) ->
665+
case string:split(HostPort, ":", trailing) of
666+
[Host] ->
667+
{Host, 443}; % Default HTTPS port
668+
[Host, PortStr] ->
669+
{Host, list_to_integer(PortStr)}
670+
end.
671+
672+
status_text(200) -> "OK";
673+
status_text(400) -> "Bad Request";
674+
status_text(401) -> "Unauthorized";
675+
status_text(403) -> "Forbidden";
676+
status_text(404) -> "Not Found";
677+
status_text(500) -> "Internal Server Error";
678+
status_text(Code) -> integer_to_list(Code).

deps/rabbitmq_aws/src/rabbitmq_aws_xml.erl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
-include_lib("xmerl/include/xmerl.hrl").
1212

1313
-spec parse(Value :: string() | binary()) -> list().
14+
parse(Value) when is_binary(Value) ->
15+
parse(binary_to_list(Value));
1416
parse(Value) ->
1517
{Element, _} = xmerl_scan:string(Value),
1618
parse_node(Element).

0 commit comments

Comments
 (0)