Skip to content

Commit

Permalink
Re-evaluate stream permissions after secret update
Browse files Browse the repository at this point in the history
Re-evaluate permissions, cancel publishers and
subscriptions, send metadata update accordingly.

Move record definitions from the stream reader
to a dedicated header file to be able to write
unit tests.

Fixes #10292
  • Loading branch information
acogoluegnes committed Jan 17, 2024
1 parent 251497f commit add0eaa
Show file tree
Hide file tree
Showing 8 changed files with 508 additions and 114 deletions.
2 changes: 1 addition & 1 deletion deps/rabbit/src/rabbit_access_control.erl
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ update_state(User = #user{authz_backends = Backends0}, NewState) ->
permission_cache_can_expire(#user{authz_backends = Backends}) ->
lists:any(fun ({Module, _State}) -> Module:state_can_expire() end, Backends).

-spec expiry_timestamp(User :: rabbit_types:user()) -> integer | never.
-spec expiry_timestamp(User :: rabbit_types:user()) -> integer() | never.
expiry_timestamp(User = #user{authz_backends = Modules}) ->
lists:foldl(fun({Module, Impl}, Ts0) ->
case Module:expiry_timestamp(auth_user(User, Impl)) of
Expand Down
7 changes: 7 additions & 0 deletions deps/rabbitmq_stream/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,13 @@ rabbitmq_integration_suite(
name = "rabbit_stream_manager_SUITE",
)

rabbitmq_integration_suite(
name = "rabbit_stream_reader_SUITE",
deps = [
"//deps/rabbitmq_stream_common:erlang_app",
],
)

rabbitmq_integration_suite(
name = "rabbit_stream_SUITE",
shard_count = 3,
Expand Down
11 changes: 11 additions & 0 deletions deps/rabbitmq_stream/app.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def all_srcs(name = "all_srcs"):
)
filegroup(
name = "private_hdrs",
srcs = ["src/rabbit_stream_reader.hrl"],
)
filegroup(
name = "srcs",
Expand Down Expand Up @@ -175,3 +176,13 @@ def test_suite_beam_files(name = "test_suite_beam_files"):
erlc_opts = "//:test_erlc_opts",
deps = ["//deps/rabbit_common:erlang_app"],
)
erlang_bytecode(
name = "rabbit_stream_reader_SUITE_beam_files",
testonly = True,
srcs = ["test/rabbit_stream_reader_SUITE.erl"],
outs = ["test/rabbit_stream_reader_SUITE.beam"],
hdrs = ["src/rabbit_stream_reader.hrl"],
app_name = "rabbitmq_stream",
erlc_opts = "//:test_erlc_opts",
deps = ["//deps/rabbit_common:erlang_app", "//deps/rabbitmq_stream_common:erlang_app"],
)
227 changes: 118 additions & 109 deletions deps/rabbitmq_stream/src/rabbit_stream_reader.erl
Original file line number Diff line number Diff line change
Expand Up @@ -9,106 +9,21 @@
%% The Original Code is RabbitMQ.
%%
%% The Initial Developer of the Original Code is Pivotal Software, Inc.
%% Copyright (c) 2020-2023 VMware, Inc. or its affiliates. All rights reserved.
%% Copyright (c) 2020-2024 Broadcom. All Rights Reserved.
%% The term “Broadcom” refers to Broadcom Inc. and/or its subsidiaries. All rights reserved.
%%

-module(rabbit_stream_reader).

-behaviour(gen_statem).
-feature(maybe_expr, enable).

-include_lib("rabbit_common/include/rabbit.hrl").
-include_lib("rabbitmq_stream_common/include/rabbit_stream.hrl").
-behaviour(gen_statem).

-include("rabbit_stream_reader.hrl").
-include("rabbit_stream_metrics.hrl").

-type stream() :: binary().
-type publisher_id() :: byte().
-type publisher_reference() :: binary().
-type subscription_id() :: byte().
-type internal_id() :: integer().

-record(publisher,
{publisher_id :: publisher_id(),
stream :: stream(),
reference :: undefined | publisher_reference(),
leader :: pid(),
message_counters :: atomics:atomics_ref(),
%% use to distinguish a stale publisher from a live publisher with the same ID
%% used only for publishers without a reference (dedup off)
internal_id :: internal_id()}).
-record(consumer_configuration,
{socket :: rabbit_net:socket(), %% ranch_transport:socket(),
member_pid :: pid(),
subscription_id :: subscription_id(),
stream :: stream(),
offset :: osiris:offset(),
counters :: atomics:atomics_ref(),
properties :: map(),
active :: boolean()}).
-record(consumer,
{configuration :: #consumer_configuration{},
credit :: non_neg_integer(),
send_limit :: non_neg_integer(),
log :: undefined | osiris_log:state(),
last_listener_offset = undefined :: undefined | osiris:offset()}).
-record(request,
{start :: integer(),
content :: term()}).
-record(stream_connection_state,
{data :: rabbit_stream_core:state(), blocked :: boolean(),
consumers :: #{subscription_id() => #consumer{}}}).
-record(stream_connection,
{name :: binary(),
%% server host
host,
%% client host
peer_host,
%% server port
port,
%% client port
peer_port,
auth_mechanism,
authentication_state :: any(),
connected_at :: integer(),
helper_sup :: pid(),
socket :: rabbit_net:socket(),
publishers ::
#{publisher_id() =>
#publisher{}}, %% FIXME replace with a list (0-255 lookup faster?)
publisher_to_ids ::
#{{stream(), publisher_reference()} => publisher_id()},
stream_leaders :: #{stream() => pid()},
stream_subscriptions :: #{stream() => [subscription_id()]},
credits :: atomics:atomics_ref(),
user :: undefined | #user{},
virtual_host :: undefined | binary(),
connection_step ::
atom(), % tcp_connected, peer_properties_exchanged, authenticating, authenticated, tuning, tuned, opened, failure, closing, closing_done
frame_max :: integer(),
heartbeat :: undefined | integer(),
heartbeater :: any(),
client_properties = #{} :: #{binary() => binary()},
monitors = #{} :: #{reference() => {pid(), stream()}},
stats_timer :: undefined | rabbit_event:state(),
resource_alarm :: boolean(),
send_file_oct ::
atomics:atomics_ref(), % number of bytes sent with send_file (for metrics)
transport :: tcp | ssl,
proxy_socket :: undefined | ranch_transport:socket(),
correlation_id_sequence :: integer(),
outstanding_requests :: #{integer() => #request{}},
deliver_version :: rabbit_stream_core:command_version(),
request_timeout :: pos_integer(),
outstanding_requests_timer :: undefined | erlang:reference(),
filtering_supported :: boolean(),
%% internal sequence used for publishers
internal_sequence = 0 :: integer()}).
-record(configuration,
{initial_credits :: integer(),
credits_required_for_unblocking :: integer(),
frame_max :: integer(),
heartbeat :: integer(),
connection_negotiation_step_timeout :: integer()}).
-include_lib("rabbitmq_stream_common/include/rabbit_stream.hrl").

-record(statem_data,
{transport :: module(),
connection :: #stream_connection{},
Expand Down Expand Up @@ -184,6 +99,9 @@
tuned/3,
open/3,
close_sent/3]).
%% for unit test
-export([ensure_token_expiry_timer/2,
evaluate_state_after_secret_update/4]).

callback_mode() ->
[state_functions, state_enter].
Expand Down Expand Up @@ -999,6 +917,11 @@ open(info, check_outstanding_requests,
),
{keep_state, StatemData#statem_data{connection = Connection1}}
end;
open(info, token_expired, #statem_data{connection = Connection}) ->
_ = demonitor_all_streams(Connection),
rabbit_log_connection:info("Forcing stream connection ~tp closing because token expired",
[self()]),
{stop, {shutdown, <<"Token expired">>}};
open(info, {shutdown, Explanation} = Reason,
#statem_data{connection = Connection}) ->
%% rabbitmq_management or rabbitmq_stream_management plugin
Expand Down Expand Up @@ -1573,8 +1496,11 @@ handle_frame_pre_auth(Transport,

send(Transport, S, Frame),
%% FIXME check if vhost is alive (see rabbit_reader:is_vhost_alive/2)
Connection#stream_connection{connection_step = opened,
virtual_host = VirtualHost}

{_, Conn} = ensure_token_expiry_timer(User,
Connection#stream_connection{connection_step = opened,
virtual_host = VirtualHost}),
Conn
catch
exit:_ ->
F = rabbit_stream_core:frame({response, CorrelationId,
Expand Down Expand Up @@ -1648,18 +1574,17 @@ handle_frame_post_auth(Transport,

handle_frame_post_auth(Transport,
#stream_connection{user = #user{username = Username} = _User,
socket = S,
socket = Socket,
host = Host,
auth_mechanism = Auth_Mechanism,
authentication_state = AuthState,
resource_alarm = false} =
C1,
State,
resource_alarm = false} = C1,
S1,
{request, CorrelationId,
{sasl_authenticate, NewMechanism, NewSaslBin}}) ->
rabbit_log:debug("Open frame received sasl_authenticate for username '~ts'", [Username]),

Connection1 =
{Connection1, State1} =
case Auth_Mechanism of
{NewMechanism, AuthMechanism} -> %% Mechanism is the same used during the pre-auth phase
{C2, CmdBody} =
Expand All @@ -1668,9 +1593,10 @@ handle_frame_post_auth(Transport,
rabbit_core_metrics:auth_attempt_failed(Host,
NewUsername,
stream),
auth_fail(NewUsername, Msg, Args, C1, State),
auth_fail(NewUsername, Msg, Args, C1, S1),
rabbit_log_connection:warning(Msg, Args),
{C1#stream_connection{connection_step = failure},
S1,
{sasl_authenticate,
?RESPONSE_AUTHENTICATION_FAILURE, <<>>}};
{protocol_error, Msg, Args} ->
Expand All @@ -1683,7 +1609,7 @@ handle_frame_post_auth(Transport,
rabbit_misc:format(Msg,
Args)}],
C1,
State),
S1),
rabbit_log_connection:warning(Msg, Args),
{C1#stream_connection{connection_step = failure},
{sasl_authenticate, ?RESPONSE_SASL_ERROR, <<>>}};
Expand All @@ -1702,7 +1628,7 @@ handle_frame_post_auth(Transport,
user_authentication_success,
[],
C1,
State),
S1),
rabbit_log:debug("Successfully updated secret for username '~ts'", [Username]),
{C1#stream_connection{user = NewUser,
authentication_state = done,
Expand All @@ -1725,20 +1651,24 @@ handle_frame_post_auth(Transport,
Frame =
rabbit_stream_core:frame({response, CorrelationId,
CmdBody}),
send(Transport, S, Frame),
C2;
send(Transport, Socket, Frame),
case CmdBody of
{sasl_authenticate, ?RESPONSE_CODE_OK, _} ->
#stream_connection{user = NewUsr} = C2,
evaluate_state_after_secret_update(Transport, NewUsr, C2, S1);
_ ->
{C2, S1}
end;
{OtherMechanism, _} ->
rabbit_log_connection:warning("User '~ts' cannot change initial auth mechanism '~ts' for '~ts'",
[Username, NewMechanism, OtherMechanism]),
CmdBody =
{sasl_authenticate, ?RESPONSE_SASL_CANNOT_CHANGE_MECHANISM, <<>>},
Frame = rabbit_stream_core:frame({response, CorrelationId, CmdBody}),
send(Transport, S, Frame),
C1#stream_connection{connection_step = failure}
send(Transport, Socket, Frame),
{C1#stream_connection{connection_step = failure}, S1}
end,

{Connection1, State};

{Connection1, State1};
handle_frame_post_auth(Transport,
#stream_connection{user = User,
publishers = Publishers0,
Expand Down Expand Up @@ -3244,6 +3174,57 @@ request(Content) ->
#request{start = erlang:monotonic_time(millisecond),
content = Content}.

evaluate_state_after_secret_update(Transport,
User,
#stream_connection{socket = Socket,
publishers = Publishers,
stream_subscriptions = Subscriptions} = Conn0,
State0) ->
{_, Conn1} = ensure_token_expiry_timer(User, Conn0),
rabbit_stream_utils:clear_permission_cache(),
PublisherStreams =
lists:foldl(fun(#publisher{stream = Str}, Acc) ->
case rabbit_stream_utils:check_write_permitted(stream_r(Str, Conn0), User) of
ok ->
Acc;
_ ->
Acc#{Str => ok}
end
end, #{}, maps:values(Publishers)),
{SubscriptionStreams, Conn2, State1} =
maps:fold(fun(Str, Subs, {Acc, C0, S0}) ->
case rabbit_stream_utils:check_read_permitted(stream_r(Str, Conn0), User, #{}) of
ok ->
{Acc, C0, S0};
_ ->
{C1, S1} =
lists:foldl(fun(SubId, {Conn, St}) ->
remove_subscription(SubId, Conn, St)
end, {C0, S0}, Subs),
{Acc#{Str => ok}, C1, S1}
end
end, {#{}, Conn1, State0}, Subscriptions),
Streams = maps:merge(PublisherStreams, SubscriptionStreams),
{Conn3, State2} =
case maps:size(Streams) of
0 ->
{Conn2, State1};
_ ->
maps:fold(fun(Str, _, {C0, S0}) ->
{_, C1, S1} = clean_state_after_stream_deletion_or_failure(
undefined, Str, C0, S0),
Command = {metadata_update, Str,
?RESPONSE_CODE_STREAM_NOT_AVAILABLE},
Frame = rabbit_stream_core:frame(Command),
send(Transport, Socket, Frame),
rabbit_global_counters:increase_protocol_counter(stream,
?STREAM_NOT_AVAILABLE,
1),
{C1, S1}
end, {Conn2, State1}, Streams)
end,
{Conn3, State2}.

ensure_outstanding_requests_timer(#stream_connection{
outstanding_requests = Requests,
outstanding_requests_timer = undefined
Expand All @@ -3265,6 +3246,34 @@ ensure_outstanding_requests_timer(#stream_connection{
ensure_outstanding_requests_timer(C) ->
C.

ensure_token_expiry_timer(User, #stream_connection{token_expiry_timer = Timer} = Conn) ->
TimerRef =
maybe
rabbit_log:debug("Checking token expiry"),
true ?= rabbit_access_control:permission_cache_can_expire(User),
rabbit_log:debug("Token can expire"),
Ts = rabbit_access_control:expiry_timestamp(User),
rabbit_log:debug("Token expiry timestamp: ~tp", [Ts]),
true ?= is_integer(Ts),
Time = (Ts - os:system_time(second)) * 1000,
rabbit_log:debug("Token expires in ~tp ms, setting timer to close connection", [Time]),
true ?= Time > 0,
{ok, TRef} = timer:send_after(Time, self(), token_expired),
TRef
else
false ->
undefined;
{error, _} ->
undefined
end,
Cancel = case Timer of
undefined ->
ok;
_ ->
timer:cancel(Timer)
end,
{Cancel, Conn#stream_connection{token_expiry_timer = TimerRef}}.

maybe_unregister_consumer(_, _, false = _Sac, Requests) ->
Requests;
maybe_unregister_consumer(VirtualHost,
Expand Down
Loading

0 comments on commit add0eaa

Please # to comment.