Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: acl compatibility #3147

Merged
merged 4 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/facade/command_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <string_view>

#include "facade/facade_types.h"
#include "server/acl/acl_commands_def.h"

namespace facade {

Expand Down
1 change: 0 additions & 1 deletion src/facade/dragonfly_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,6 @@ void Connection::DispatchOperations::operator()(const MonitorMessage& msg) {
void Connection::DispatchOperations::operator()(const AclUpdateMessage& msg) {
if (self->cntx()) {
if (msg.username == self->cntx()->authed_username) {
self->cntx()->acl_categories = msg.categories;
self->cntx()->acl_commands = msg.commands;
self->cntx()->keys = msg.keys;
}
Expand Down
1 change: 0 additions & 1 deletion src/facade/dragonfly_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ class Connection : public util::Connection {
// ACL Update message, contains ACL updates to be applied to the connection.
struct AclUpdateMessage {
std::string username;
uint32_t categories;
std::vector<uint64_t> commands;
dfly::acl::AclKeys keys;
};
Expand Down
3 changes: 1 addition & 2 deletions src/server/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ cxx_test(hll_family_test dfly_test_lib LABELS DFLY)
cxx_test(bloom_family_test dfly_test_lib LABELS DFLY)
cxx_test(cluster/cluster_config_test dfly_test_lib LABELS DFLY)
cxx_test(cluster/cluster_family_test dfly_test_lib LABELS DFLY)
cxx_test(acl/user_registry_test dfly_test_lib LABELS DFLY)
cxx_test(acl/acl_family_test dfly_test_lib LABELS DFLY)
cxx_test(engine_shard_set_test dfly_test_lib LABELS DFLY)
cxx_test(search/search_family_test dfly_test_lib LABELS DFLY)
Expand All @@ -135,4 +134,4 @@ add_dependencies(check_dfly dragonfly_test json_family_test list_family_test
generic_family_test memcache_parser_test rdb_test journal_test
redis_parser_test stream_family_test string_family_test
bitops_family_test set_family_test zset_family_test hll_family_test
cluster_config_test cluster_family_test user_registry_test acl_family_test)
cluster_config_test cluster_family_test acl_family_test)
44 changes: 43 additions & 1 deletion src/server/acl/acl_commands_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
#pragma once

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "base/logging.h"
#include "facade/acl_commands_def.h"
#include "server/command_registry.h"
#include "server/conn_context.h"

namespace dfly::acl {

Expand Down Expand Up @@ -84,6 +88,14 @@ inline const std::vector<std::string> REVERSE_CATEGORY_INDEX_TABLE{
"_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED",
"BLOOM", "FT_SEARCH", "THROTTLE", "JSON"};

// bit index to index in the REVERSE_CATEGORY_INDEX_TABLE
using CategoryToIdxStore = absl::flat_hash_map<uint32_t, uint32_t>;

inline const CategoryToIdxStore& CategoryToIdx(CategoryToIdxStore store = {}) {
static CategoryToIdxStore cat_idx = std::move(store);
return cat_idx;
}

using RevCommandField = std::vector<std::string>;
using RevCommandsIndexStore = std::vector<RevCommandField>;

Expand All @@ -104,9 +116,39 @@ inline const RevCommandsIndexStore& CommandsRevIndexer(RevCommandsIndexStore sto
return rev_index_store;
}

inline void BuildIndexers(std::vector<std::vector<std::string>> families) {
using CategoryToCommandsIndexStore = absl::flat_hash_map<std::string, std::vector<uint64_t>>;

inline const CategoryToCommandsIndexStore& CategoryToCommandsIndex(
CategoryToCommandsIndexStore store = {}) {
static CategoryToCommandsIndexStore index = std::move(store);
return index;
}

inline void BuildIndexers(RevCommandsIndexStore families, CommandRegistry* cmd_registry) {
acl::NumberOfFamilies(families.size());
acl::CommandsRevIndexer(std::move(families));
CategoryToCommandsIndexStore index;
cmd_registry->Traverse([&](std::string_view name, auto& cid) {
auto cat = cid.acl_categories();
for (size_t i = 0; i < 32; ++i) {
if (cat & (1 << i)) {
std::string_view cat_name = REVERSE_CATEGORY_INDEX_TABLE[i];
if (index[cat_name].empty()) {
index[cat_name].resize(CommandsRevIndexer().size());
}
auto family = cid.GetFamily();
auto bit_index = cid.GetBitIndex();
index[cat_name][family] |= bit_index;
}
}
});

CategoryToCommandsIndex(std::move(index));
CategoryToIdxStore idx_store;
for (size_t i = 0; i < 32; ++i) {
idx_store[1 << i] = i;
}
Comment on lines +147 to +150
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could there be any reason there's not a 1-to-1 mapping? If not, we could elminiate this store allotogether

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

category has large values depending on which bit is set, so for example SCRIPTING is 1<<20 which makes the 1-1 mapping impossible. The flow is that the user passes let's a +SCRIPTING, it's parsed to 1<<20. Now if we want to find the name of the category, we normalize the 1<<n into the index and then use that index in the REVERSE_CATEGORY_TABLE. Open to ideas if it can be improved 🤷

CategoryToIdx(std::move(idx_store));
}

} // namespace dfly::acl
49 changes: 26 additions & 23 deletions src/server/acl/acl_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,23 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
std::string buffer = "user ";
const std::string_view pass = user.Password();
const std::string password = pass == "nopass" ? "nopass" : PrettyPrintSha(pass);
const std::string acl_cat = AclCatToString(user.AclCategory());
const std::string acl_commands = AclCommandToString(user.AclCommandsRef());
const std::string maybe_space_com = acl_commands.empty() ? "" : " ";

const std::string acl_cat_and_commands =
AclCatAndCommandToString(user.CatChanges(), user.CmdChanges());

const std::string acl_keys = AclKeysToString(user.Keys());
const std::string maybe_space = acl_keys.empty() ? "" : " ";
const std::string maybe_space_com = acl_keys.empty() ? "" : " ";

using namespace std::string_view_literals;

absl::StrAppend(&buffer, username, " ", user.IsActive() ? "on "sv : "off "sv, password, " ",
acl_cat, maybe_space_com, acl_commands, maybe_space, acl_keys);
acl_cat_and_commands, maybe_space_com, acl_keys);

cntx->SendSimpleString(buffer);
}
}

void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, uint32_t update_cat,
void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user,
const Commands& update_commands,
const AclKeys& update_keys) {
auto update_cb = [&]([[maybe_unused]] size_t id, util::Connection* conn) {
Expand All @@ -90,7 +91,7 @@ void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, u
if (connection->protocol() == facade::Protocol::REDIS && !connection->IsHttp() &&
connection->cntx()) {
connection->SendAclUpdateAsync(
facade::Connection::AclUpdateMessage{user, update_cat, update_commands, update_keys});
facade::Connection::AclUpdateMessage{user, update_commands, update_keys});
}
};

Expand All @@ -113,10 +114,14 @@ void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) {

auto update_case = [username, &reg, cntx, this, exists](User::UpdateRequest&& req) {
auto& user = reg.registry[username];
if (!exists) {
User::UpdateRequest default_req;
default_req.updates = {User::UpdateRequest::CategoryValueType{User::Sign::MINUS, acl::ALL}};
user.Update(std::move(default_req));
}
user.Update(std::move(req));
if (exists) {
StreamUpdatesToAllProactorConnections(std::string(username), user.AclCategory(),
user.AclCommands(), user.Keys());
StreamUpdatesToAllProactorConnections(std::string(username), user.AclCommands(), user.Keys());
}
cntx->SendOk();
};
Expand Down Expand Up @@ -194,20 +199,15 @@ std::string AclFamily::RegistryToString() const {
const std::string_view pass = user.Password();
const std::string password =
pass == "nopass" ? "nopass " : absl::StrCat("#", PrettyPrintSha(pass, true), " ");
const std::string acl_cat = AclCatToString(user.AclCategory());
const std::string acl_commands = AclCommandToString(user.AclCommandsRef());
const std::string maybe_space_com = acl_commands.empty() ? "" : " ";
const std::string acl_cat_and_commands =
AclCatAndCommandToString(user.CatChanges(), user.CmdChanges());
const std::string acl_keys = AclKeysToString(user.Keys());
const std::string maybe_space = acl_keys.empty() ? "" : " ";

using namespace std::string_view_literals;

absl::StrAppend(&result, command, username, " ", user.IsActive() ? "ON "sv : "OFF "sv, password,
acl_cat, maybe_space_com, acl_commands, maybe_space, acl_keys, "\n");
}

if (!result.empty()) {
result.pop_back();
acl_cat_and_commands, maybe_space, acl_keys, "\n");
}

return result;
Expand Down Expand Up @@ -298,7 +298,10 @@ GenericError AclFamily::LoadToRegistryFromFile(std::string_view full_path,
}

for (size_t i = 0; i < usernames.size(); ++i) {
User::UpdateRequest default_req;
default_req.updates = {User::UpdateRequest::CategoryValueType{User::Sign::MINUS, acl::ALL}};
auto& user = registry[usernames[i]];
user.Update(std::move(default_req));
user.Update(std::move(requests[i]));
}

Expand Down Expand Up @@ -446,6 +449,7 @@ void AclFamily::Cat(CmdArgList args, ConnectionContext* cntx) {

const uint32_t cid_mask = CATEGORY_INDEX_TABLE.find(category)->second;
std::vector<std::string_view> results;
// TODO replace this with indexer
auto cb = [cid_mask, &results](auto name, auto& cid) {
if (cid_mask & cid.acl_categories()) {
results.push_back(name);
Expand Down Expand Up @@ -510,10 +514,10 @@ void AclFamily::GetUser(CmdArgList args, ConnectionContext* cntx) {
}
rb->SendSimpleString("commands");

std::string acl = absl::StrCat(AclCatToString(user.AclCategory()), " ",
AclCommandToString(user.AclCommandsRef()));
const std::string acl_cat_and_commands =
AclCatAndCommandToString(user.CatChanges(), user.CmdChanges());

rb->SendSimpleString(acl);
rb->SendSimpleString(acl_cat_and_commands);

rb->SendSimpleString("keys");
std::string keys = AclKeysToString(user.Keys());
Expand Down Expand Up @@ -572,9 +576,8 @@ void AclFamily::DryRun(CmdArgList args, ConnectionContext* cntx) {
}

const auto& user = registry.find(username)->second;
const bool is_allowed = IsUserAllowedToInvokeCommandGeneric(
user.AclCategory(), user.AclCommandsRef(), {{}, true}, {}, *cid)
.first;
const bool is_allowed =
IsUserAllowedToInvokeCommandGeneric(user.AclCommandsRef(), {{}, true}, {}, *cid).first;
if (is_allowed) {
cntx->SendOk();
return;
Expand Down
2 changes: 1 addition & 1 deletion src/server/acl/acl_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class AclFamily final {
// Helper function that updates all open connections and their
// respective ACL fields on all the available proactor threads
using Commands = std::vector<uint64_t>;
void StreamUpdatesToAllProactorConnections(const std::string& user, uint32_t update_cat,
void StreamUpdatesToAllProactorConnections(const std::string& user,
const Commands& update_commands,
const AclKeys& update_keys);

Expand Down
38 changes: 19 additions & 19 deletions src/server/acl/acl_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ TEST_F(AclFamilyTest, AclSetUser) {
EXPECT_THAT(resp, "OK");
resp = Run({"ACL", "LIST"});
auto vec = resp.GetVec();
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
"user vlad off nopass +@NONE"));
EXPECT_THAT(
vec, UnorderedElementsAre("user default on nopass +@ALL ~*", "user vlad off nopass -@ALL"));

resp = Run({"ACL", "SETUSER", "vlad", "+ACL"});
EXPECT_THAT(resp, "OK");

resp = Run({"ACL", "LIST"});
vec = resp.GetVec();
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
"user vlad off nopass +@NONE +ACL"));
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL ~*",
"user vlad off nopass -@ALL +ACL"));
}

TEST_F(AclFamilyTest, AclDelUser) {
Expand All @@ -82,7 +82,7 @@ TEST_F(AclFamilyTest, AclDelUser) {
EXPECT_THAT(resp, IntArg(0));

resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetString(), "user default on nopass +@ALL +ALL ~*");
EXPECT_THAT(resp.GetString(), "user default on nopass +@ALL ~*");

Run({"ACL", "SETUSER", "michael", "ON"});
Run({"ACL", "SETUSER", "kobe", "ON"});
Expand All @@ -103,9 +103,9 @@ TEST_F(AclFamilyTest, AclList) {

resp = Run({"ACL", "LIST"});
auto vec = resp.GetVec();
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
"user kostas off d74ff0ee8da3b98 +@ADMIN",
"user adi off d74ff0ee8da3b98 +@FAST"));
EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass +@ALL ~*",
"user kostas off d74ff0ee8da3b98 -@ALL +@ADMIN",
"user adi off d74ff0ee8da3b98 -@ALL +@FAST"));
}

TEST_F(AclFamilyTest, AclAuth) {
Expand Down Expand Up @@ -154,16 +154,16 @@ TEST_F(AclFamilyTest, TestAllCategories) {

resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
absl::StrCat("user kostas off nopass ", "+@", cat)));
UnorderedElementsAre("user default on nopass +@ALL ~*",
absl::StrCat("user kostas off nopass -@ALL ", "+@", cat)));

resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-@", cat)});
EXPECT_THAT(resp, "OK");

resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
absl::StrCat("user kostas off nopass ", "+@NONE")));
UnorderedElementsAre("user default on nopass +@ALL ~*",
absl::StrCat("user kostas off nopass -@ALL ", "-@", cat)));

resp = Run({"ACL", "DELUSER", "kostas"});
EXPECT_THAT(resp, IntArg(1));
Expand Down Expand Up @@ -201,16 +201,16 @@ TEST_F(AclFamilyTest, TestAllCommands) {
EXPECT_THAT(resp, "OK");

resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
absl::StrCat("user kostas off nopass +@NONE ",
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL ~*",
absl::StrCat("user kostas off nopass -@ALL ",
"+", command_name)));

resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-", command_name)});

resp = Run({"ACL", "LIST"});
EXPECT_THAT(resp.GetVec(),
UnorderedElementsAre("user default on nopass +@ALL +ALL ~*",
absl::StrCat("user kostas off nopass ", "+@NONE")));
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL ~*",
absl::StrCat("user kostas off nopass ",
"-@ALL ", "-", command_name)));

resp = Run({"ACL", "DELUSER", "kostas"});
EXPECT_THAT(resp, IntArg(1));
Expand Down Expand Up @@ -259,7 +259,7 @@ TEST_F(AclFamilyTest, TestGetUser) {
EXPECT_THAT(vec[2], "passwords");
EXPECT_TRUE(vec[3].GetVec().empty());
EXPECT_THAT(vec[4], "commands");
EXPECT_THAT(vec[5], "+@ALL +ALL");
EXPECT_THAT(vec[5], "+@ALL");
EXPECT_THAT(vec[6], "keys");
EXPECT_THAT(vec[7], "~*");

Expand All @@ -271,7 +271,7 @@ TEST_F(AclFamilyTest, TestGetUser) {
EXPECT_THAT(kvec[2], "passwords");
EXPECT_TRUE(kvec[3].GetVec().empty());
EXPECT_THAT(kvec[4], "commands");
EXPECT_THAT(kvec[5], "+@STRING +HSET");
EXPECT_THAT(kvec[5], "-@ALL +@STRING +HSET");
}

TEST_F(AclFamilyTest, TestDryRun) {
Expand Down
Loading
Loading