Skip to content

Commit 6b3d992

Browse files
committed
fix: on bad payload, block join
1 parent 1d140af commit 6b3d992

File tree

4 files changed

+114
-94
lines changed

4 files changed

+114
-94
lines changed

lib/realtime_web/channels/realtime_channel.ex

Lines changed: 66 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ defmodule RealtimeWeb.RealtimeChannel do
55
use RealtimeWeb, :channel
66
use RealtimeWeb.RealtimeChannel.Logging
77

8-
alias RealtimeWeb.SocketDisconnect
98
alias DBConnection.Backoff
9+
alias Phoenix.Socket
1010

1111
alias Realtime.Crypto
1212
alias Realtime.GenCounter
@@ -22,11 +22,14 @@ defmodule RealtimeWeb.RealtimeChannel do
2222
alias Realtime.Tenants.Connect
2323

2424
alias RealtimeWeb.Channels.Payloads.Join
25+
alias RealtimeWeb.Channels.Payloads.Config
26+
alias RealtimeWeb.Channels.Payloads.PostgresChange
2527
alias RealtimeWeb.ChannelsAuthorization
2628
alias RealtimeWeb.RealtimeChannel.BroadcastHandler
2729
alias RealtimeWeb.RealtimeChannel.MessageDispatcher
2830
alias RealtimeWeb.RealtimeChannel.PresenceHandler
2931
alias RealtimeWeb.RealtimeChannel.Tracker
32+
alias RealtimeWeb.SocketDisconnect
3033

3134
@confirm_token_ms_interval :timer.minutes(5)
3235

@@ -47,20 +50,11 @@ defmodule RealtimeWeb.RealtimeChannel do
4750
Logger.metadata(external_id: tenant_id, project: tenant_id)
4851
Logger.put_process_level(self(), log_level)
4952

50-
socket =
51-
socket
52-
|> assign_access_token(params)
53-
|> assign_counter()
54-
|> assign_presence_counter()
55-
|> assign(:private?, !!params["config"]["private"])
56-
|> assign(:policies, nil)
57-
58-
case Join.validate(params) do
59-
{:ok, _join} -> nil
60-
{:error, :invalid_join_payload, errors} -> log_error(socket, "InvalidJoinPayload", errors)
61-
end
53+
# We always need to assign the access token so we can get the logs metadata working as expected
54+
socket = assign_access_token(socket, params)
6255

63-
with :ok <- SignalHandler.shutdown_in_progress?(),
56+
with {:ok, %Socket{} = socket, %Join{} = configuration} <- configure_socket(socket, params),
57+
:ok <- SignalHandler.shutdown_in_progress?(),
6458
:ok <- only_private?(tenant_id, socket),
6559
:ok <- limit_joins(socket),
6660
:ok <- limit_channels(socket),
@@ -70,7 +64,6 @@ defmodule RealtimeWeb.RealtimeChannel do
7064
{:ok, db_conn} <- Connect.lookup_or_start_connection(tenant_id),
7165
{:ok, socket} <- maybe_assign_policies(sub_topic, db_conn, socket) do
7266
tenant_topic = Tenants.tenant_topic(tenant_id, sub_topic, !socket.assigns.private?)
73-
7467
# fastlane subscription
7568
metadata =
7669
MessageDispatcher.fastlane_metadata(transport_pid, serializer, topic, socket.assigns.log_level, tenant_id)
@@ -79,15 +72,11 @@ defmodule RealtimeWeb.RealtimeChannel do
7972

8073
Phoenix.PubSub.subscribe(Realtime.PubSub, "realtime:operations:" <> tenant_id)
8174

82-
is_new_api = new_api?(params)
83-
# TODO: Default will be moved to false in the future
84-
presence_enabled? =
85-
case get_in(params, ["config", "presence", "enabled"]) do
86-
enabled when is_boolean(enabled) -> enabled
87-
_ -> true
88-
end
75+
is_new_api = new_api?(configuration)
8976

90-
pg_change_params = pg_change_params(is_new_api, params, channel_pid, claims, sub_topic)
77+
presence_enabled? = Join.presence_enabled?(configuration)
78+
79+
pg_change_params = pg_change_params(is_new_api, configuration, channel_pid, claims, sub_topic)
9180

9281
opts = %{
9382
is_new_api: is_new_api,
@@ -104,13 +93,13 @@ defmodule RealtimeWeb.RealtimeChannel do
10493
state = %{postgres_changes: add_id_to_postgres_changes(pg_change_params)}
10594

10695
assigns = %{
107-
ack_broadcast: !!params["config"]["broadcast"]["ack"],
96+
ack_broadcast: Join.ack_broadcast?(configuration),
10897
confirm_token_ref: confirm_token_ref,
10998
is_new_api: is_new_api,
11099
pg_sub_ref: nil,
111100
pg_change_params: pg_change_params,
112-
presence_key: presence_key(params),
113-
self_broadcast: !!params["config"]["broadcast"]["self"],
101+
presence_key: Join.presence_key(configuration),
102+
self_broadcast: Join.self_broadcast?(configuration),
114103
tenant_topic: tenant_topic,
115104
channel_name: sub_topic,
116105
presence_enabled?: presence_enabled?
@@ -124,6 +113,9 @@ defmodule RealtimeWeb.RealtimeChannel do
124113

125114
{:ok, state, assign(socket, assigns)}
126115
else
116+
{:error, :invalid_join_payload, errors, socket} ->
117+
log_error(socket, "InvalidJoinPayload", errors)
118+
127119
{:error, :expired_token, msg} ->
128120
maybe_log_warning(socket, "InvalidJWTToken", msg)
129121

@@ -200,6 +192,23 @@ defmodule RealtimeWeb.RealtimeChannel do
200192
end
201193
end
202194

195+
defp configure_socket(socket, params) do
196+
case Join.validate(params) do
197+
{:ok, configuration} ->
198+
socket =
199+
socket
200+
|> assign_counter()
201+
|> assign_presence_counter()
202+
|> assign(:private?, Join.private?(configuration))
203+
|> assign(:policies, nil)
204+
205+
{:ok, socket, configuration}
206+
207+
{:error, :invalid_join_payload, errors} ->
208+
{:error, :invalid_join_payload, errors, socket}
209+
end
210+
end
211+
203212
@impl true
204213
def handle_info(:update_rate_counter, %{assigns: %{limits: %{max_events_per_second: max}}} = socket) do
205214
count(socket)
@@ -537,40 +546,24 @@ defmodule RealtimeWeb.RealtimeChannel do
537546

538547
defp count(%{assigns: %{rate_counter: counter}}), do: GenCounter.add(counter.id)
539548

540-
defp presence_key(params) do
541-
case params["config"]["presence"]["key"] do
542-
key when is_binary(key) and key != "" -> key
543-
_ -> UUID.uuid1()
544-
end
545-
end
546-
547-
defp assign_access_token(%{assigns: %{headers: headers}} = socket, params) do
548-
access_token = Map.get(params, "access_token") || Map.get(params, "user_token")
549+
defp assign_access_token(socket, params) do
550+
%{assigns: %{tenant_token: tenant_token, headers: headers}} = socket
549551
{_, header} = Enum.find(headers, {nil, nil}, fn {k, _} -> k == "x-api-key" end)
550552

551-
case access_token do
552-
nil -> assign(socket, :access_token, header)
553-
"sb_" <> _ -> assign(socket, :access_token, header)
554-
_ -> handle_access_token(socket, params)
555-
end
556-
end
557-
558-
defp assign_access_token(socket, params), do: handle_access_token(socket, params)
553+
access_token = Map.get(params, "access_token")
554+
user_token = Map.get(params, "user_token")
559555

560-
defp handle_access_token(%{assigns: %{tenant_token: _tenant_token}} = socket, %{"user_token" => user_token})
561-
when is_binary(user_token) do
562-
assign(socket, :access_token, user_token)
563-
end
556+
access_token =
557+
cond do
558+
is_binary(access_token) and !String.starts_with?(access_token, "sb_") -> access_token
559+
is_binary(user_token) and !String.starts_with?(user_token, "sb_") -> user_token
560+
is_binary(tenant_token) and !String.starts_with?(tenant_token, "sb_") -> tenant_token
561+
true -> header
562+
end
564563

565-
defp handle_access_token(%{assigns: %{tenant_token: _tenant_token}} = socket, %{"access_token" => access_token})
566-
when is_binary(access_token) do
567564
assign(socket, :access_token, access_token)
568565
end
569566

570-
defp handle_access_token(%{assigns: %{tenant_token: tenant_token}} = socket, _params) when is_binary(tenant_token) do
571-
assign(socket, :access_token, tenant_token)
572-
end
573-
574567
defp confirm_token(%{assigns: assigns}) do
575568
%{jwt_secret: jwt_secret, access_token: access_token} = assigns
576569

@@ -637,28 +630,30 @@ defmodule RealtimeWeb.RealtimeChannel do
637630
})
638631
end
639632

640-
defp new_api?(%{"config" => _}), do: true
633+
defp new_api?(%Join{config: config}) when not is_nil(config), do: true
641634
defp new_api?(_), do: false
642635

643-
defp pg_change_params(true, params, channel_pid, claims, _) do
644-
case get_in(params, ["config", "postgres_changes"]) do
645-
[_ | _] = params_list ->
646-
params_list
647-
|> Enum.reject(&is_nil/1)
648-
|> Enum.map(fn params ->
649-
%{
650-
id: UUID.uuid1(),
651-
channel_pid: channel_pid,
652-
claims: claims,
653-
params: params
654-
}
655-
end)
656-
657-
_ ->
658-
[]
659-
end
636+
defp pg_change_params(true, %Join{config: %Config{postgres_changes: postgres_changes}}, channel_pid, claims, _)
637+
when not is_nil(postgres_changes) do
638+
postgres_changes
639+
|> Enum.reject(&is_nil/1)
640+
|> Enum.map(fn %PostgresChange{table: table, event: event, schema: schema, filter: filter} ->
641+
params =
642+
%{"table" => table, "filter" => filter, "schema" => schema, "event" => event}
643+
|> Enum.reject(fn {_, v} -> is_nil(v) end)
644+
|> Map.new()
645+
646+
%{
647+
id: UUID.uuid1(),
648+
channel_pid: channel_pid,
649+
claims: claims,
650+
params: params
651+
}
652+
end)
660653
end
661654

655+
defp pg_change_params(true, _, _, _, _), do: []
656+
662657
defp pg_change_params(false, _, channel_pid, claims, sub_topic) do
663658
params =
664659
case String.split(sub_topic, ":", parts: 3) do

mix.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ defmodule Realtime.MixProject do
44
def project do
55
[
66
app: :realtime,
7-
version: "2.43.2",
7+
version: "2.43.3",
88
elixir: "~> 1.17.3",
99
elixirc_paths: elixirc_paths(Mix.env()),
1010
start_permanent: Mix.env() == :prod,

test/integration/rt_channel_test.exs

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -420,28 +420,18 @@ defmodule Realtime.Integration.RtChannelTest do
420420
500
421421
end
422422

423-
test "handle nil postgres changes params as empty param changes", %{tenant: tenant} do
423+
test "nil postgres changes params identified as error", %{tenant: tenant} do
424424
{socket, _} = get_connection(tenant)
425425
topic = "realtime:any"
426426
config = %{postgres_changes: [nil]}
427427

428-
WebsocketClient.join(socket, topic, %{config: config})
429-
430-
assert_receive %Message{event: "phx_reply", payload: %{"status" => "ok"}, topic: ^topic}, 200
431-
assert_receive %Phoenix.Socket.Message{event: "presence_state", payload: %{}, topic: ^topic}, 500
428+
log =
429+
capture_log(fn ->
430+
WebsocketClient.join(socket, topic, %{config: config})
431+
Process.sleep(500)
432+
end)
432433

433-
refute_receive %Message{
434-
event: "system",
435-
payload: %{
436-
"channel" => "any",
437-
"extension" => "postgres_changes",
438-
"message" => "Subscribed to PostgreSQL",
439-
"status" => "ok"
440-
},
441-
ref: nil,
442-
topic: ^topic
443-
},
444-
1000
434+
assert log =~ "InvalidJoinPayload"
445435
end
446436
end
447437

test/realtime_web/channels/realtime_channel_test.exs

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -503,12 +503,9 @@ defmodule RealtimeWeb.RealtimeChannelTest do
503503

504504
test "expired jwt returns a error with sub data if available log_level=warning", %{tenant: tenant} do
505505
sub = random_string()
506-
507506
api_key = Generators.generate_jwt_token(tenant)
508-
509-
jwt =
510-
Generators.generate_jwt_token(tenant, %{role: "authenticated", exp: System.system_time(:second) - 1, sub: sub})
511-
507+
claims = %{role: "authenticated", exp: System.system_time(:second) - 1, sub: sub}
508+
jwt = Generators.generate_jwt_token(tenant, claims)
512509
assert {:ok, socket} = connect(UserSocket, %{"log_level" => "warning"}, conn_opts(tenant, api_key))
513510

514511
log =
@@ -675,6 +672,44 @@ defmodule RealtimeWeb.RealtimeChannelTest do
675672
end
676673
end
677674

675+
describe "join payload validations" do
676+
test "valid payload allows join", %{tenant: tenant} do
677+
jwt = Generators.generate_jwt_token(tenant)
678+
{:ok, %Socket{} = socket} = connect(UserSocket, %{"log_level" => "warning"}, conn_opts(tenant, jwt))
679+
680+
config = %{
681+
"config" => %{
682+
"private" => false,
683+
"broadcast" => %{"ack" => false, "self" => false},
684+
"presence" => %{"enabled" => true, "key" => "potato"},
685+
"postgres_changes" => [
686+
%{"event" => "INSERT", "schema" => "public", "table" => "users", "filter" => "id=eq.1"},
687+
%{"event" => "DELETE", "schema" => "public", "table" => "users", "filter" => "id=eq.2"},
688+
%{"event" => "UPDATE", "schema" => "public", "table" => "users", "filter" => "id=eq.3"}
689+
]
690+
},
691+
"access_token" => jwt
692+
}
693+
694+
assert {:ok, _, %Socket{}} = subscribe_and_join(socket, "realtime:test", config)
695+
end
696+
697+
test "invalid payload returns error", %{tenant: tenant} do
698+
jwt = Generators.generate_jwt_token(tenant)
699+
{:ok, %Socket{} = socket} = connect(UserSocket, %{"log_level" => "warning"}, conn_opts(tenant, jwt))
700+
701+
log =
702+
capture_log(fn ->
703+
assert {:error, %{reason: reason}} =
704+
subscribe_and_join(socket, "realtime:test", %{"config" => "potato"})
705+
706+
assert reason =~ "unable to parse, expected a map"
707+
end)
708+
709+
assert log =~ "InvalidJoinPayload"
710+
end
711+
end
712+
678713
test "registers transport pid and channel pid per tenant", %{tenant: tenant} do
679714
jwt = Generators.generate_jwt_token(tenant)
680715
{:ok, %Socket{} = socket} = connect(UserSocket, %{"log_level" => "warning"}, conn_opts(tenant, jwt))

0 commit comments

Comments
 (0)